diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index d6ea1edd38..4158a8cd16 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -2,7 +2,7 @@ name: CI env: CARGO_TERM_COLOR: always - MSRV: '1.66' + MSRV: '1.75' on: push: @@ -12,37 +12,46 @@ on: jobs: check: - runs-on: ubuntu-latest + runs-on: ubuntu-24.04 steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: taiki-e/install-action@protoc - uses: dtolnay/rust-toolchain@beta with: components: clippy, rustfmt - uses: Swatinem/rust-cache@v2 + with: + save-if: ${{ github.ref == 'refs/heads/main' }} + prefix-key: "v0-rust-ubuntu-24.04" - name: Check run: cargo clippy --workspace --all-targets --all-features -- -D warnings - name: rustfmt run: cargo fmt --all --check check-docs: - runs-on: ubuntu-latest + runs-on: ubuntu-24.04 steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@stable - uses: Swatinem/rust-cache@v2 + with: + save-if: ${{ github.ref == 'refs/heads/main' }} + prefix-key: "v0-rust-ubuntu-24.04" - name: cargo doc env: RUSTDOCFLAGS: "-D rustdoc::all -A rustdoc::private-doc-tests" run: cargo doc --all-features --no-deps cargo-hack: - runs-on: ubuntu-latest + runs-on: ubuntu-24.04 steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: taiki-e/install-action@protoc - uses: dtolnay/rust-toolchain@stable - uses: Swatinem/rust-cache@v2 + with: + save-if: ${{ github.ref == 'refs/heads/main' }} + prefix-key: "v0-rust-ubuntu-24.04" - name: Install cargo-hack run: | curl -LsSf https://github.com/taiki-e/cargo-hack/releases/latest/download/cargo-hack-x86_64-unknown-linux-gnu.tar.gz | tar xzf - -C ~/.cargo/bin @@ -50,42 +59,54 @@ jobs: run: cargo hack check --each-feature --no-dev-deps --all cargo-public-api-crates: - runs-on: ubuntu-latest + runs-on: ubuntu-24.04 strategy: matrix: crate: [axum, axum-core, axum-extra, axum-macros] steps: - - uses: actions/checkout@v3 - - uses: dtolnay/rust-toolchain@nightly + - uses: actions/checkout@v4 + # Pinned version due to failing `cargo-public-api-crates`. + - uses: dtolnay/rust-toolchain@master + with: + toolchain: nightly-2024-06-06 - uses: Swatinem/rust-cache@v2 + with: + save-if: ${{ github.ref == 'refs/heads/main' }} + prefix-key: "v0-rust-ubuntu-24.04" - name: Install cargo-public-api-crates run: | cargo install --git https://github.com/davidpdrsn/cargo-public-api-crates + - name: Build rustdoc + run: | + cargo rustdoc --all-features --manifest-path ${{ matrix.crate }}/Cargo.toml -- -Z unstable-options --output-format json - name: cargo public-api-crates check - run: cargo public-api-crates --manifest-path ${{ matrix.crate }}/Cargo.toml check + run: cargo public-api-crates --manifest-path ${{ matrix.crate }}/Cargo.toml --skip-build check test-versions: needs: check - runs-on: ubuntu-latest + runs-on: ubuntu-24.04 strategy: matrix: rust: [stable, beta] steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: taiki-e/install-action@protoc - uses: dtolnay/rust-toolchain@master with: toolchain: ${{ matrix.rust }} - uses: Swatinem/rust-cache@v2 + with: + save-if: ${{ github.ref == 'refs/heads/main' }} + prefix-key: "v0-rust-ubuntu-24.04" - name: Run tests run: cargo test --workspace --all-features --all-targets - # some examples doesn't support our MSRV so we only test axum itself on our MSRV + # some examples don't support our MSRV so we only test axum itself on our MSRV test-nightly: needs: check - runs-on: ubuntu-latest + runs-on: ubuntu-24.04 steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - name: Get rust-toolchain version id: rust-toolchain run: echo "version=$(cat axum-macros/rust-toolchain)" >> $GITHUB_OUTPUT @@ -93,23 +114,29 @@ jobs: with: toolchain: ${{ steps.rust-toolchain.outputs.version }} - uses: Swatinem/rust-cache@v2 + with: + save-if: ${{ github.ref == 'refs/heads/main' }} + prefix-key: "v0-rust-ubuntu-24.04" - name: Run nightly tests working-directory: axum-macros run: cargo test - # some examples doesn't support our MSRV (such as async-graphql) + # some examples don't support our MSRV (such as async-graphql) # so we only test axum itself on our MSRV test-msrv: needs: check - runs-on: ubuntu-latest + runs-on: ubuntu-24.04 steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@master with: toolchain: ${{ env.MSRV }} - name: "install Rust nightly" uses: dtolnay/rust-toolchain@nightly - uses: Swatinem/rust-cache@v2 + with: + save-if: ${{ github.ref == 'refs/heads/main' }} + prefix-key: "v0-rust-ubuntu-24.04" - name: Select minimal version run: cargo +nightly update -Z minimal-versions - name: Fix up Cargo.lock @@ -137,17 +164,20 @@ jobs: test-docs: needs: check - runs-on: ubuntu-latest + runs-on: ubuntu-24.04 steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@stable - uses: Swatinem/rust-cache@v2 + with: + save-if: ${{ github.ref == 'refs/heads/main' }} + prefix-key: "v0-rust-ubuntu-24.04" - name: Run doc tests run: cargo test --all-features --doc deny-check: name: cargo-deny check - runs-on: ubuntu-latest + runs-on: ubuntu-24.04 continue-on-error: ${{ matrix.checks == 'advisories' }} strategy: matrix: @@ -155,21 +185,24 @@ jobs: - advisories - bans licenses sources steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: EmbarkStudios/cargo-deny-action@v1 with: command: check ${{ matrix.checks }} - arguments: --all-features --manifest-path axum/Cargo.toml + manifest-path: axum/Cargo.toml armv5te-unknown-linux-musleabi: needs: check - runs-on: ubuntu-latest + runs-on: ubuntu-24.04 steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@stable with: target: armv5te-unknown-linux-musleabi - uses: Swatinem/rust-cache@v2 + with: + save-if: ${{ github.ref == 'refs/heads/main' }} + prefix-key: "v0-rust-ubuntu-24.04" - name: Check env: # Clang has native cross-compilation support @@ -187,13 +220,16 @@ jobs: wasm32-unknown-unknown: needs: check - runs-on: ubuntu-latest + runs-on: ubuntu-24.04 steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@stable with: target: wasm32-unknown-unknown - uses: Swatinem/rust-cache@v2 + with: + save-if: ${{ github.ref == 'refs/heads/main' }} + prefix-key: "v0-rust-ubuntu-24.04" - name: Check run: > cargo @@ -202,11 +238,14 @@ jobs: --target wasm32-unknown-unknown dependencies-are-sorted: - runs-on: ubuntu-latest + runs-on: ubuntu-24.04 steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v4 - uses: dtolnay/rust-toolchain@beta - uses: Swatinem/rust-cache@v2 + with: + save-if: ${{ github.ref == 'refs/heads/main' }} + prefix-key: "v0-rust-ubuntu-24.04" - name: Install cargo-sort run: | cargo install cargo-sort @@ -219,12 +258,12 @@ jobs: typos: name: Spell Check with Typos - runs-on: ubuntu-latest + runs-on: ubuntu-24.04 if: github.event_name == 'push' || !github.event.pull_request.draft steps: - name: Checkout Actions Repository - uses: actions/checkout@v3 + uses: actions/checkout@v4 - name: Check the spelling of the files in our repo - uses: crate-ci/typos@v1.16.2 + uses: crate-ci/typos@v1.20.8 diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 9fb4fc70ba..39ce75691f 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -267,7 +267,7 @@ If a Pull Request appears to be abandoned or stalled, it is polite to first check with the contributor to see if they intend to continue the work before checking if they would mind if you took it over (especially if it just has nits left). When doing so, it is courteous to give the original contributor credit -for the work they started (either by preserving their name and email address in +for the work they started, either by preserving their name and email address in the commit log, or by using an `Author: ` meta-data tag in the commit. [hiding-a-comment]: https://help.github.com/articles/managing-disruptive-comments/#hiding-a-comment diff --git a/Cargo.toml b/Cargo.toml index a68aaab16a..f9c9d027b2 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,3 +5,6 @@ default-members = ["axum", "axum-*"] # Example has been deleted, but README.md remains exclude = ["examples/async-graphql"] resolver = "2" + +[workspace.package] +rust-version = "1.75" diff --git a/ECOSYSTEM.md b/ECOSYSTEM.md index e774333810..b97e917dd0 100644 --- a/ECOSYSTEM.md +++ b/ECOSYSTEM.md @@ -26,8 +26,8 @@ If your project isn't listed here and you would like it to be, please feel free - [axum-guard-logic](https://github.com/sjud/axum_guard_logic): Use AND/OR logic to extract types and check their values against `Service` inputs. - [axum-casbin-auth](https://github.com/casbin-rs/axum-casbin-auth): Casbin access control middleware for axum framework - [aide](https://docs.rs/aide): Code-first Open API documentation generator with [axum integration](https://docs.rs/aide/latest/aide/axum/index.html). +- [axum-typed-routing](https://docs.rs/axum-typed-routing/latest/axum_typed_routing/): Statically typed routing macros with OpenAPI generation using aide. - [axum-jsonschema](https://docs.rs/axum-jsonschema/): A `Json` extractor that does JSON schema validation of requests. -- [axum-sessions](https://docs.rs/axum-sessions): Cookie-based sessions for axum via async-session. - [axum-login](https://docs.rs/axum-login): Session-based user authentication for axum. - [axum-csrf-sync-pattern](https://crates.io/crates/axum-csrf-sync-pattern): A middleware implementing CSRF STP for AJAX backends and API endpoints. - [axum-otel-metrics](https://github.com/ttys3/axum-otel-metrics/): A axum OpenTelemetry Metrics middleware with prometheus exporter supported. @@ -45,6 +45,11 @@ If your project isn't listed here and you would like it to be, please feel free - [socketioxide](https://github.com/totodore/socketioxide): An easy to use socket.io server implementation working as a `tower` layer/service. - [axum-serde](https://github.com/gengteng/axum-serde): Provides multiple serde-based extractors / responses, also offers a macro to easily customize serde-based extractors / responses. - [loco.rs](https://github.com/loco-rs/loco): A full stack Web and API productivity framework similar to Rails, based on Axum. +- [axum-test](https://crates.io/crates/axum-test): High level library for writing Cargo tests that run against Axum. +- [axum-messages](https://github.com/maxcountryman/axum-messages): One-time notification messages for Axum. +- [spring-rs](https://github.com/spring-rs/spring-rs): spring-rs is a microservice framework written in rust inspired by java's spring-boot, based on axum +- [zino](https://github.com/zino-rs/zino): Zino is a next-generation framework for composable applications which provides full integrations with axum. +- [axum-rails-cookie](https://github.com/endoze/axum-rails-cookie): Extract rails session cookies in axum based apps. ## Project showcase @@ -58,6 +63,8 @@ If your project isn't listed here and you would like it to be, please feel free - [realworld-axum-sqlx](https://github.com/launchbadge/realworld-axum-sqlx): A Rust implementation of the [Realworld] demo app spec using Axum and [SQLx]. See https://github.com/davidpdrsn/realworld-axum-sqlx for a fork with up to date dependencies. - [Rustapi](https://github.com/ndelvalle/rustapi): RESTful API template using MongoDB +- [axum-postgres-template](https://github.com/koskeller/axum-postgres-template): Production-ready Axum + PostgreSQL application template +- [RUSTfulapi](https://github.com/robatipoor/rustfulapi): Reusable template for building REST Web Services in Rust. Uses Axum HTTP web framework and SeaORM. - [Jotsy](https://github.com/ohsayan/jotsy): Self-hosted notes app powered by Skytable, Axum and Tokio - [Svix](https://www.svix.com) ([repository](https://github.com/svix/svix-webhooks)): Enterprise-ready webhook service - [emojied](https://emojied.net) ([repository](https://github.com/sekunho/emojied)): Shorten URLs to emojis! @@ -78,10 +85,12 @@ If your project isn't listed here and you would like it to be, please feel free - [cobrust](https://github.com/scotow/cobrust): Multiplayer web based snake game. - [meta-cross](https://github.com/scotow/meta-cross): Tweaked version of Tic-Tac-Toe. - [httq](https://github.com/scotow/httq) HTTP to MQTT trivial proxy. +- [Pods-Blitz](https://pods-blitz.org) Self-hosted podcast publisher. Uses the crates axum-login, password-auth, sqlx and handlebars (for HTML templates). - [ReductStore](https://github.com/reductstore/reductstore): A time series database for storing and managing large amounts of blob data - [randoku](https://github.com/stchris/randoku): A tiny web service which generates random numbers and shuffles lists randomly - [sero](https://github.com/clowzed/sero): Host static sites with custom subdomains as surge.sh does. But with full control and cool new features. (axum, sea-orm, postgresql) - [Hatsu](https://github.com/importantimport/hatsu): 🩵 Self-hosted & Fully-automated ActivityPub Bridge for Static Sites. +- [Mini RPS](https://github.com/marcodpt/minirps): Mini reverse proxy server, HTTPS, CORS, static file hosting and template engine (minijinja). [Realworld]: https://github.com/gothinkster/realworld [SQLx]: https://github.com/launchbadge/sqlx @@ -97,6 +106,7 @@ If your project isn't listed here and you would like it to be, please feel free - [Introduction to axum]: YouTube playlist - [Rust Axum Full Course]: YouTube video - [Deploying Axum projects with Shuttle] +- [API Development with Rust](https://rust-api.dev/docs/front-matter/preface/): REST APIs based on Axum [axum-tutorial]: https://github.com/programatik29/axum-tutorial [axum-tutorial-website]: https://programatik29.github.io/axum-tutorial/ diff --git a/axum-core/CHANGELOG.md b/axum-core/CHANGELOG.md index 0ae41366a8..f2932a42e4 100644 --- a/axum-core/CHANGELOG.md +++ b/axum-core/CHANGELOG.md @@ -5,9 +5,44 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). -# Unreleased +# 0.5.0 -- None. +## alpha.1 + +- **breaking:** Replace `#[async_trait]` with [return-position `impl Trait` in traits][RPITIT] ([#2308]) +- **change:** Update minimum rust version to 1.75 ([#2943]) + +[RPITIT]: https://blog.rust-lang.org/2023/12/21/async-fn-rpit-in-traits.html +[#2308]: https://github.com/tokio-rs/axum/pull/2308 +[#2943]: https://github.com/tokio-rs/axum/pull/2943 + +# 0.4.5 + +- **fixed:** Compile errors from the internal `__log_rejection` macro under + certain Cargo feature combinations between axum crates ([#2933]) + +[#2933]: https://github.com/tokio-rs/axum/pull/2933 + +# 0.4.4 + +- **added:** Derive `Clone` and `Copy` for `AppendHeaders` ([#2776]) +- **added:** `must_use` attribute on `AppendHeaders` ([#2846]) +- **added:** `must_use` attribute on `ErrorResponse` ([#2846]) +- **added:** `must_use` attribute on `IntoResponse::into_response` ([#2846]) +- **added:** `must_use` attribute on `IntoResponseParts` trait methods ([#2846]) +- **added:** Implement `Copy` for `DefaultBodyLimit` ([#2875]) +- **added**: `DefaultBodyLimit::max` and `DefaultBodyLimit::disable` are now + allowed in const context ([#2875]) + +[#2776]: https://github.com/tokio-rs/axum/pull/2776 +[#2846]: https://github.com/tokio-rs/axum/pull/2846 +[#2875]: https://github.com/tokio-rs/axum/pull/2875 + +# 0.4.3 (13. January, 2024) + +- **added:** Implement `IntoResponseParts` for `()` ([#2471]) + +[#2471]: https://github.com/tokio-rs/axum/pull/2471 # 0.4.2 (29. December, 2023) diff --git a/axum-core/Cargo.toml b/axum-core/Cargo.toml index fcceb3f866..d8207e1399 100644 --- a/axum-core/Cargo.toml +++ b/axum-core/Cargo.toml @@ -2,14 +2,14 @@ categories = ["asynchronous", "network-programming", "web-programming"] description = "Core types and traits for axum" edition = "2021" -rust-version = "1.56" +rust-version = { workspace = true } homepage = "https://github.com/tokio-rs/axum" keywords = ["http", "web", "framework"] license = "MIT" name = "axum-core" readme = "README.md" repository = "https://github.com/tokio-rs/axum" -version = "0.4.2" # remember to also bump the version that axum and axum-extra depend on +version = "0.5.0-alpha.1" # remember to bump the version that axum and axum-extra depend on [features] tracing = ["dep:tracing"] @@ -18,32 +18,29 @@ tracing = ["dep:tracing"] __private_docs = ["dep:tower-http"] [dependencies] -async-trait = "0.1.67" -bytes = "1.0" +bytes = "1.2" futures-util = { version = "0.3", default-features = false, features = ["alloc"] } http = "1.0.0" http-body = "1.0.0" http-body-util = "0.1.0" mime = "0.3.16" pin-project-lite = "0.2.7" -sync_wrapper = "0.1.1" +rustversion = "1.0.9" +sync_wrapper = "1.0.0" tower-layer = "0.3" tower-service = "0.3" # optional dependencies -tower-http = { version = "0.5.0", optional = true, features = ["limit"] } +tower-http = { version = "0.6.0", optional = true, features = ["limit"] } tracing = { version = "0.1.37", default-features = false, optional = true } -[build-dependencies] -rustversion = "1.0.9" - [dev-dependencies] -axum = { path = "../axum", version = "0.7.2" } +axum = { path = "../axum" } axum-extra = { path = "../axum-extra", features = ["typed-header"] } futures-util = { version = "0.3", default-features = false, features = ["alloc"] } hyper = "1.0.0" tokio = { version = "1.25.0", features = ["macros"] } -tower-http = { version = "0.5.0", features = ["limit"] } +tower-http = { version = "0.6.0", features = ["limit"] } [package.metadata.cargo-public-api-crates] allowed = [ @@ -57,6 +54,8 @@ allowed = [ "http_body", ] +[package.metadata.cargo-machete] +ignored = ["tower-http"] # See __private_docs feature + [package.metadata.docs.rs] all-features = true -rustdoc-args = ["--cfg", "docsrs"] diff --git a/axum-core/README.md b/axum-core/README.md index 01ff4e5105..600ec33791 100644 --- a/axum-core/README.md +++ b/axum-core/README.md @@ -14,7 +14,7 @@ This crate uses `#![forbid(unsafe_code)]` to ensure everything is implemented in ## Minimum supported Rust version -axum-core's MSRV is 1.56. +axum-core's MSRV is 1.75. ## Getting Help diff --git a/axum-core/build.rs b/axum-core/build.rs deleted file mode 100644 index b52885c626..0000000000 --- a/axum-core/build.rs +++ /dev/null @@ -1,7 +0,0 @@ -#[rustversion::nightly] -fn main() { - println!("cargo:rustc-cfg=nightly_error_messages"); -} - -#[rustversion::not(nightly)] -fn main() {} diff --git a/axum-core/src/ext_traits/mod.rs b/axum-core/src/ext_traits/mod.rs index 02595fbeac..951a12d70c 100644 --- a/axum-core/src/ext_traits/mod.rs +++ b/axum-core/src/ext_traits/mod.rs @@ -6,13 +6,11 @@ mod tests { use std::convert::Infallible; use crate::extract::{FromRef, FromRequestParts}; - use async_trait::async_trait; use http::request::Parts; #[derive(Debug, Default, Clone, Copy)] pub(crate) struct State(pub(crate) S); - #[async_trait] impl FromRequestParts for State where InnerState: FromRef, @@ -30,9 +28,9 @@ mod tests { } // some extractor that requires the state, such as `SignedCookieJar` + #[allow(dead_code)] pub(crate) struct RequiresState(pub(crate) String); - #[async_trait] impl FromRequestParts for RequiresState where S: Send + Sync, diff --git a/axum-core/src/ext_traits/request.rs b/axum-core/src/ext_traits/request.rs index 5b7aee783a..1123fdd3d6 100644 --- a/axum-core/src/ext_traits/request.rs +++ b/axum-core/src/ext_traits/request.rs @@ -1,6 +1,6 @@ use crate::body::Body; use crate::extract::{DefaultBodyLimitKind, FromRequest, FromRequestParts, Request}; -use futures_util::future::BoxFuture; +use std::future::Future; mod sealed { pub trait Sealed {} @@ -20,7 +20,6 @@ pub trait RequestExt: sealed::Sealed + Sized { /// /// ``` /// use axum::{ - /// async_trait, /// extract::{Request, FromRequest}, /// body::Body, /// http::{header::CONTENT_TYPE, StatusCode}, @@ -30,7 +29,6 @@ pub trait RequestExt: sealed::Sealed + Sized { /// /// struct FormOrJson(T); /// - /// #[async_trait] /// impl FromRequest for FormOrJson /// where /// Json: FromRequest<()>, @@ -67,7 +65,7 @@ pub trait RequestExt: sealed::Sealed + Sized { /// } /// } /// ``` - fn extract(self) -> BoxFuture<'static, Result> + fn extract(self) -> impl Future> + Send where E: FromRequest<(), M> + 'static, M: 'static; @@ -83,7 +81,6 @@ pub trait RequestExt: sealed::Sealed + Sized { /// /// ``` /// use axum::{ - /// async_trait, /// body::Body, /// extract::{Request, FromRef, FromRequest}, /// RequestExt, @@ -93,7 +90,6 @@ pub trait RequestExt: sealed::Sealed + Sized { /// requires_state: RequiresState, /// } /// - /// #[async_trait] /// impl FromRequest for MyExtractor /// where /// String: FromRef, @@ -111,7 +107,6 @@ pub trait RequestExt: sealed::Sealed + Sized { /// // some extractor that consumes the request body and requires state /// struct RequiresState { /* ... */ } /// - /// #[async_trait] /// impl FromRequest for RequiresState /// where /// String: FromRef, @@ -124,7 +119,10 @@ pub trait RequestExt: sealed::Sealed + Sized { /// # } /// } /// ``` - fn extract_with_state(self, state: &S) -> BoxFuture<'_, Result> + fn extract_with_state( + self, + state: &S, + ) -> impl Future> + Send where E: FromRequest + 'static, S: Send + Sync; @@ -137,7 +135,6 @@ pub trait RequestExt: sealed::Sealed + Sized { /// /// ``` /// use axum::{ - /// async_trait, /// extract::{Path, Request, FromRequest}, /// response::{IntoResponse, Response}, /// body::Body, @@ -154,7 +151,6 @@ pub trait RequestExt: sealed::Sealed + Sized { /// payload: T, /// } /// - /// #[async_trait] /// impl FromRequest for MyExtractor /// where /// S: Send + Sync, @@ -179,7 +175,7 @@ pub trait RequestExt: sealed::Sealed + Sized { /// } /// } /// ``` - fn extract_parts(&mut self) -> BoxFuture<'_, Result> + fn extract_parts(&mut self) -> impl Future> + Send where E: FromRequestParts<()> + 'static; @@ -191,7 +187,6 @@ pub trait RequestExt: sealed::Sealed + Sized { /// /// ``` /// use axum::{ - /// async_trait, /// extract::{Request, FromRef, FromRequest, FromRequestParts}, /// http::request::Parts, /// response::{IntoResponse, Response}, @@ -204,7 +199,6 @@ pub trait RequestExt: sealed::Sealed + Sized { /// payload: T, /// } /// - /// #[async_trait] /// impl FromRequest for MyExtractor /// where /// String: FromRef, @@ -234,7 +228,6 @@ pub trait RequestExt: sealed::Sealed + Sized { /// /// struct RequiresState {} /// - /// #[async_trait] /// impl FromRequestParts for RequiresState /// where /// String: FromRef, @@ -250,7 +243,7 @@ pub trait RequestExt: sealed::Sealed + Sized { fn extract_parts_with_state<'a, E, S>( &'a mut self, state: &'a S, - ) -> BoxFuture<'a, Result> + ) -> impl Future> + Send + 'a where E: FromRequestParts + 'static, S: Send + Sync; @@ -267,7 +260,7 @@ pub trait RequestExt: sealed::Sealed + Sized { } impl RequestExt for Request { - fn extract(self) -> BoxFuture<'static, Result> + fn extract(self) -> impl Future> + Send where E: FromRequest<(), M> + 'static, M: 'static, @@ -275,7 +268,10 @@ impl RequestExt for Request { self.extract_with_state(&()) } - fn extract_with_state(self, state: &S) -> BoxFuture<'_, Result> + fn extract_with_state( + self, + state: &S, + ) -> impl Future> + Send where E: FromRequest + 'static, S: Send + Sync, @@ -283,17 +279,17 @@ impl RequestExt for Request { E::from_request(self, state) } - fn extract_parts(&mut self) -> BoxFuture<'_, Result> + fn extract_parts(&mut self) -> impl Future> + Send where E: FromRequestParts<()> + 'static, { self.extract_parts_with_state(&()) } - fn extract_parts_with_state<'a, E, S>( + async fn extract_parts_with_state<'a, E, S>( &'a mut self, state: &'a S, - ) -> BoxFuture<'a, Result> + ) -> Result where E: FromRequestParts + 'static, S: Send + Sync, @@ -306,17 +302,15 @@ impl RequestExt for Request { *req.extensions_mut() = std::mem::take(self.extensions_mut()); let (mut parts, ()) = req.into_parts(); - Box::pin(async move { - let result = E::from_request_parts(&mut parts, state).await; + let result = E::from_request_parts(&mut parts, state).await; - *self.version_mut() = parts.version; - *self.method_mut() = parts.method.clone(); - *self.uri_mut() = parts.uri.clone(); - *self.headers_mut() = std::mem::take(&mut parts.headers); - *self.extensions_mut() = std::mem::take(&mut parts.extensions); + *self.version_mut() = parts.version; + *self.method_mut() = parts.method.clone(); + *self.uri_mut() = parts.uri.clone(); + *self.headers_mut() = std::mem::take(&mut parts.headers); + *self.extensions_mut() = std::mem::take(&mut parts.extensions); - result - }) + result } fn with_limited_body(self) -> Request { @@ -345,7 +339,6 @@ mod tests { ext_traits::tests::{RequiresState, State}, extract::FromRef, }; - use async_trait::async_trait; use http::Method; #[tokio::test] @@ -414,7 +407,6 @@ mod tests { body: String, } - #[async_trait] impl FromRequest for WorksForCustomExtractor where S: Send + Sync, diff --git a/axum-core/src/ext_traits/request_parts.rs b/axum-core/src/ext_traits/request_parts.rs index e7063f4d8b..9e1a3d1c16 100644 --- a/axum-core/src/ext_traits/request_parts.rs +++ b/axum-core/src/ext_traits/request_parts.rs @@ -1,6 +1,6 @@ use crate::extract::FromRequestParts; -use futures_util::future::BoxFuture; use http::request::Parts; +use std::future::Future; mod sealed { pub trait Sealed {} @@ -21,7 +21,6 @@ pub trait RequestPartsExt: sealed::Sealed + Sized { /// response::{Response, IntoResponse}, /// http::request::Parts, /// RequestPartsExt, - /// async_trait, /// }; /// use std::collections::HashMap; /// @@ -30,7 +29,6 @@ pub trait RequestPartsExt: sealed::Sealed + Sized { /// query_params: HashMap, /// } /// - /// #[async_trait] /// impl FromRequestParts for MyExtractor /// where /// S: Send + Sync, @@ -54,7 +52,7 @@ pub trait RequestPartsExt: sealed::Sealed + Sized { /// } /// } /// ``` - fn extract(&mut self) -> BoxFuture<'_, Result> + fn extract(&mut self) -> impl Future> + Send where E: FromRequestParts<()> + 'static; @@ -70,14 +68,12 @@ pub trait RequestPartsExt: sealed::Sealed + Sized { /// response::{Response, IntoResponse}, /// http::request::Parts, /// RequestPartsExt, - /// async_trait, /// }; /// /// struct MyExtractor { /// requires_state: RequiresState, /// } /// - /// #[async_trait] /// impl FromRequestParts for MyExtractor /// where /// String: FromRef, @@ -97,7 +93,6 @@ pub trait RequestPartsExt: sealed::Sealed + Sized { /// struct RequiresState { /* ... */ } /// /// // some extractor that requires a `String` in the state - /// #[async_trait] /// impl FromRequestParts for RequiresState /// where /// String: FromRef, @@ -113,14 +108,14 @@ pub trait RequestPartsExt: sealed::Sealed + Sized { fn extract_with_state<'a, E, S>( &'a mut self, state: &'a S, - ) -> BoxFuture<'a, Result> + ) -> impl Future> + Send + 'a where E: FromRequestParts + 'static, S: Send + Sync; } impl RequestPartsExt for Parts { - fn extract(&mut self) -> BoxFuture<'_, Result> + fn extract(&mut self) -> impl Future> + Send where E: FromRequestParts<()> + 'static, { @@ -130,7 +125,7 @@ impl RequestPartsExt for Parts { fn extract_with_state<'a, E, S>( &'a mut self, state: &'a S, - ) -> BoxFuture<'a, Result> + ) -> impl Future> + Send + 'a where E: FromRequestParts + 'static, S: Send + Sync, @@ -148,7 +143,6 @@ mod tests { ext_traits::tests::{RequiresState, State}, extract::FromRef, }; - use async_trait::async_trait; use http::{Method, Request}; #[tokio::test] @@ -181,7 +175,6 @@ mod tests { from_state: String, } - #[async_trait] impl FromRequestParts for WorksForCustomExtractor where S: Send + Sync, diff --git a/axum-core/src/extract/default_body_limit.rs b/axum-core/src/extract/default_body_limit.rs index 2ec82febc6..a045d1cd3f 100644 --- a/axum-core/src/extract/default_body_limit.rs +++ b/axum-core/src/extract/default_body_limit.rs @@ -72,7 +72,7 @@ use tower_layer::Layer; /// [`RequestBodyLimit`]: tower_http::limit::RequestBodyLimit /// [`RequestExt::with_limited_body`]: crate::RequestExt::with_limited_body /// [`RequestExt::into_limited_body`]: crate::RequestExt::into_limited_body -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Copy)] #[must_use] pub struct DefaultBodyLimit { kind: DefaultBodyLimitKind, @@ -116,7 +116,7 @@ impl DefaultBodyLimit { /// [`Bytes`]: bytes::Bytes /// [`Json`]: https://docs.rs/axum/0.7/axum/struct.Json.html /// [`Form`]: https://docs.rs/axum/0.7/axum/struct.Form.html - pub fn disable() -> Self { + pub const fn disable() -> Self { Self { kind: DefaultBodyLimitKind::Disable, } @@ -149,7 +149,7 @@ impl DefaultBodyLimit { /// [`Bytes::from_request`]: bytes::Bytes /// [`Json`]: https://docs.rs/axum/0.7/axum/struct.Json.html /// [`Form`]: https://docs.rs/axum/0.7/axum/struct.Form.html - pub fn max(limit: usize) -> Self { + pub const fn max(limit: usize) -> Self { Self { kind: DefaultBodyLimitKind::Limit(limit), } diff --git a/axum-core/src/extract/mod.rs b/axum-core/src/extract/mod.rs index c8e2d2196f..1baa893555 100644 --- a/axum-core/src/extract/mod.rs +++ b/axum-core/src/extract/mod.rs @@ -5,9 +5,9 @@ //! [`axum::extract`]: https://docs.rs/axum/0.7/axum/extract/index.html use crate::{body::Body, response::IntoResponse}; -use async_trait::async_trait; use http::request::Parts; use std::convert::Infallible; +use std::future::Future; pub mod rejection; @@ -42,9 +42,8 @@ mod private { /// See [`axum::extract`] for more general docs about extractors. /// /// [`axum::extract`]: https://docs.rs/axum/0.7/axum/extract/index.html -#[async_trait] -#[cfg_attr( - nightly_error_messages, +#[rustversion::attr( + since(1.78), diagnostic::on_unimplemented( note = "Function argument is not a valid axum extractor. \nSee `https://docs.rs/axum/0.7/axum/extract/index.html` for details", ) @@ -55,7 +54,10 @@ pub trait FromRequestParts: Sized { type Rejection: IntoResponse; /// Perform the extraction. - async fn from_request_parts(parts: &mut Parts, state: &S) -> Result; + fn from_request_parts( + parts: &mut Parts, + state: &S, + ) -> impl Future> + Send; } /// Types that can be created from requests. @@ -69,9 +71,8 @@ pub trait FromRequestParts: Sized { /// See [`axum::extract`] for more general docs about extractors. /// /// [`axum::extract`]: https://docs.rs/axum/0.7/axum/extract/index.html -#[async_trait] -#[cfg_attr( - nightly_error_messages, +#[rustversion::attr( + since(1.78), diagnostic::on_unimplemented( note = "Function argument is not a valid axum extractor. \nSee `https://docs.rs/axum/0.7/axum/extract/index.html` for details", ) @@ -82,10 +83,12 @@ pub trait FromRequest: Sized { type Rejection: IntoResponse; /// Perform the extraction. - async fn from_request(req: Request, state: &S) -> Result; + fn from_request( + req: Request, + state: &S, + ) -> impl Future> + Send; } -#[async_trait] impl FromRequest for T where S: Send + Sync, @@ -99,7 +102,6 @@ where } } -#[async_trait] impl FromRequestParts for Option where T: FromRequestParts, @@ -115,7 +117,6 @@ where } } -#[async_trait] impl FromRequest for Option where T: FromRequest, @@ -128,7 +129,6 @@ where } } -#[async_trait] impl FromRequestParts for Result where T: FromRequestParts, @@ -141,7 +141,6 @@ where } } -#[async_trait] impl FromRequest for Result where T: FromRequest, diff --git a/axum-core/src/extract/rejection.rs b/axum-core/src/extract/rejection.rs index 34b8115bd4..c5c3b1db3d 100644 --- a/axum-core/src/extract/rejection.rs +++ b/axum-core/src/extract/rejection.rs @@ -42,7 +42,7 @@ define_rejection! { #[body = "Failed to buffer the request body"] /// Encountered some other error when buffering the body. /// - /// This can _only_ happen when you're using [`tower_http::limit::RequestBodyLimitLayer`] or + /// This can _only_ happen when you're using [`tower_http::limit::RequestBodyLimitLayer`] or /// otherwise wrapping request bodies in [`http_body_util::Limited`]. pub struct LengthLimitError(Error); } diff --git a/axum-core/src/extract/request_parts.rs b/axum-core/src/extract/request_parts.rs index 73f54db793..695f7e1e9e 100644 --- a/axum-core/src/extract/request_parts.rs +++ b/axum-core/src/extract/request_parts.rs @@ -1,12 +1,10 @@ use super::{rejection::*, FromRequest, FromRequestParts, Request}; use crate::{body::Body, RequestExt}; -use async_trait::async_trait; -use bytes::Bytes; +use bytes::{BufMut, Bytes, BytesMut}; use http::{request::Parts, Extensions, HeaderMap, Method, Uri, Version}; use http_body_util::BodyExt; use std::convert::Infallible; -#[async_trait] impl FromRequest for Request where S: Send + Sync, @@ -18,7 +16,6 @@ where } } -#[async_trait] impl FromRequestParts for Method where S: Send + Sync, @@ -30,7 +27,6 @@ where } } -#[async_trait] impl FromRequestParts for Uri where S: Send + Sync, @@ -42,7 +38,6 @@ where } } -#[async_trait] impl FromRequestParts for Version where S: Send + Sync, @@ -59,7 +54,6 @@ where /// Prefer using [`TypedHeader`] to extract only the headers you need. /// /// [`TypedHeader`]: https://docs.rs/axum/0.7/axum/extract/struct.TypedHeader.html -#[async_trait] impl FromRequestParts for HeaderMap where S: Send + Sync, @@ -71,7 +65,36 @@ where } } -#[async_trait] +impl FromRequest for BytesMut +where + S: Send + Sync, +{ + type Rejection = BytesRejection; + + async fn from_request(req: Request, _: &S) -> Result { + let mut body = req.into_limited_body(); + let mut bytes = BytesMut::new(); + body_to_bytes_mut(&mut body, &mut bytes).await?; + Ok(bytes) + } +} + +async fn body_to_bytes_mut(body: &mut Body, bytes: &mut BytesMut) -> Result<(), BytesRejection> { + while let Some(frame) = body + .frame() + .await + .transpose() + .map_err(FailedToBufferBody::from_err)? + { + let Ok(data) = frame.into_data() else { + return Ok(()); + }; + bytes.put(data); + } + + Ok(()) +} + impl FromRequest for Bytes where S: Send + Sync, @@ -90,7 +113,6 @@ where } } -#[async_trait] impl FromRequest for String where S: Send + Sync, @@ -106,15 +128,12 @@ where } })?; - let string = std::str::from_utf8(&bytes) - .map_err(InvalidUtf8::from_err)? - .to_owned(); + let string = String::from_utf8(bytes.into()).map_err(InvalidUtf8::from_err)?; Ok(string) } } -#[async_trait] impl FromRequestParts for Parts where S: Send + Sync, @@ -126,7 +145,6 @@ where } } -#[async_trait] impl FromRequestParts for Extensions where S: Send + Sync, @@ -138,7 +156,6 @@ where } } -#[async_trait] impl FromRequest for Body where S: Send + Sync, diff --git a/axum-core/src/extract/tuple.rs b/axum-core/src/extract/tuple.rs index 021b9616df..cbd91a7fb3 100644 --- a/axum-core/src/extract/tuple.rs +++ b/axum-core/src/extract/tuple.rs @@ -1,10 +1,8 @@ use super::{FromRequest, FromRequestParts, Request}; use crate::response::{IntoResponse, Response}; -use async_trait::async_trait; use http::request::Parts; use std::convert::Infallible; -#[async_trait] impl FromRequestParts for () where S: Send + Sync, @@ -20,7 +18,6 @@ macro_rules! impl_from_request { ( [$($ty:ident),*], $last:ident ) => { - #[async_trait] #[allow(non_snake_case, unused_mut, unused_variables)] impl FromRequestParts for ($($ty,)* $last,) where @@ -46,7 +43,6 @@ macro_rules! impl_from_request { // This impl must not be generic over M, otherwise it would conflict with the blanket // implementation of `FromRequest` for `T: FromRequestParts`. - #[async_trait] #[allow(non_snake_case, unused_mut, unused_variables)] impl FromRequest for ($($ty,)* $last,) where diff --git a/axum-core/src/lib.rs b/axum-core/src/lib.rs index a4dd6cd969..134c566b30 100644 --- a/axum-core/src/lib.rs +++ b/axum-core/src/lib.rs @@ -1,4 +1,3 @@ -#![cfg_attr(nightly_error_messages, feature(diagnostic_namespace))] //! Core types and traits for [`axum`]. //! //! Libraries authors that want to provide [`FromRequest`] or [`IntoResponse`] implementations @@ -22,7 +21,6 @@ clippy::needless_borrow, clippy::match_wildcard_for_single_variants, clippy::if_let_mutex, - clippy::mismatched_target_os, clippy::await_holding_lock, clippy::match_on_vec_items, clippy::imprecise_flops, @@ -52,6 +50,11 @@ #[macro_use] pub(crate) mod macros; +#[doc(hidden)] // macro helpers +pub mod __private { + #[cfg(feature = "tracing")] + pub use tracing; +} mod error; mod ext_traits; diff --git a/axum-core/src/macros.rs b/axum-core/src/macros.rs index 3fa61576be..aa99ba402e 100644 --- a/axum-core/src/macros.rs +++ b/axum-core/src/macros.rs @@ -1,4 +1,5 @@ /// Private API. +#[cfg(feature = "tracing")] #[doc(hidden)] #[macro_export] macro_rules! __log_rejection { @@ -7,20 +8,30 @@ macro_rules! __log_rejection { body_text = $body_text:expr, status = $status:expr, ) => { - #[cfg(feature = "tracing")] { - tracing::event!( + $crate::__private::tracing::event!( target: "axum::rejection", - tracing::Level::TRACE, + $crate::__private::tracing::Level::TRACE, status = $status.as_u16(), body = $body_text, - rejection_type = std::any::type_name::<$ty>(), + rejection_type = ::std::any::type_name::<$ty>(), "rejecting request", ); } }; } +#[cfg(not(feature = "tracing"))] +#[doc(hidden)] +#[macro_export] +macro_rules! __log_rejection { + ( + rejection_type = $ty:ident, + body_text = $body_text:expr, + status = $status:expr, + ) => {}; +} + /// Private API. #[doc(hidden)] #[macro_export] @@ -303,8 +314,6 @@ mod composite_rejection_tests { #[allow(dead_code, unreachable_pub)] mod defs { - use crate::{__composite_rejection, __define_rejection}; - __define_rejection! { #[status = BAD_REQUEST] #[body = "error message 1"] diff --git a/axum-core/src/response/append_headers.rs b/axum-core/src/response/append_headers.rs index e4ac4812f9..aa8f2dbdfb 100644 --- a/axum-core/src/response/append_headers.rs +++ b/axum-core/src/response/append_headers.rs @@ -29,7 +29,7 @@ use std::fmt; /// ) /// } /// ``` -#[derive(Debug)] +#[derive(Debug, Clone, Copy)] #[must_use] pub struct AppendHeaders(pub I); diff --git a/axum-core/src/response/into_response.rs b/axum-core/src/response/into_response.rs index dbc1b57fb6..9bc0c2fec2 100644 --- a/axum-core/src/response/into_response.rs +++ b/axum-core/src/response/into_response.rs @@ -49,7 +49,7 @@ use std::{ /// MyError::SomethingElseWentWrong => "something else went wrong", /// }; /// -/// // its often easiest to implement `IntoResponse` by calling other implementations +/// // it's often easiest to implement `IntoResponse` by calling other implementations /// (StatusCode::INTERNAL_SERVER_ERROR, body).into_response() /// } /// } @@ -113,6 +113,7 @@ use std::{ /// ``` pub trait IntoResponse { /// Create a response. + #[must_use] fn into_response(self) -> Response; } diff --git a/axum-core/src/response/into_response_parts.rs b/axum-core/src/response/into_response_parts.rs index 94ce03092d..2deabb6448 100644 --- a/axum-core/src/response/into_response_parts.rs +++ b/axum-core/src/response/into_response_parts.rs @@ -44,7 +44,7 @@ use std::{convert::Infallible, fmt}; /// } /// } /// -/// // Its also recommended to implement `IntoResponse` so `SetHeader` can be used on its own as +/// // It's also recommended to implement `IntoResponse` so `SetHeader` can be used on its own as /// // the response /// impl<'a> IntoResponse for SetHeader<'a> { /// fn into_response(self) -> Response { @@ -105,21 +105,25 @@ pub struct ResponseParts { impl ResponseParts { /// Gets a reference to the response headers. + #[must_use] pub fn headers(&self) -> &HeaderMap { self.res.headers() } /// Gets a mutable reference to the response headers. + #[must_use] pub fn headers_mut(&mut self) -> &mut HeaderMap { self.res.headers_mut() } /// Gets a reference to the response extensions. + #[must_use] pub fn extensions(&self) -> &Extensions { self.res.extensions() } /// Gets a mutable reference to the response extensions. + #[must_use] pub fn extensions_mut(&mut self) -> &mut Extensions { self.res.extensions_mut() } @@ -260,3 +264,11 @@ impl IntoResponseParts for Extensions { Ok(res) } } + +impl IntoResponseParts for () { + type Error = Infallible; + + fn into_response_parts(self, res: ResponseParts) -> Result { + Ok(res) + } +} diff --git a/axum-core/src/response/mod.rs b/axum-core/src/response/mod.rs index 60a98a063e..0a9a478219 100644 --- a/axum-core/src/response/mod.rs +++ b/axum-core/src/response/mod.rs @@ -121,6 +121,7 @@ where /// /// See [`Result`] for more details. #[derive(Debug)] +#[must_use] pub struct ErrorResponse(Response); impl From for ErrorResponse diff --git a/axum-extra/CHANGELOG.md b/axum-extra/CHANGELOG.md index 2dfa296f07..327e3edbb4 100644 --- a/axum-extra/CHANGELOG.md +++ b/axum-extra/CHANGELOG.md @@ -7,7 +7,59 @@ and this project adheres to [Semantic Versioning]. # Unreleased -- None. +- **fixed:** `Host` extractor includes port number when parsing authority ([#2242]) +- **added:** Add `RouterExt::typed_connect` ([#2961]) +- **added:** Add `json!` for easy construction of JSON responses ([#2962]) + +[#2242]: https://github.com/tokio-rs/axum/pull/2242 +[#2961]: https://github.com/tokio-rs/axum/pull/2961 +[#2962]: https://github.com/tokio-rs/axum/pull/2962 + +# 0.10.0 + +# alpha.1 + +- **breaking:** Update to prost 0.13. Used for the `Protobuf` extractor ([#2829]) +- **change:** Update minimum rust version to 1.75 ([#2943]) + +[#2829]: https://github.com/tokio-rs/axum/pull/2829 +[#2943]: https://github.com/tokio-rs/axum/pull/2943 + +# 0.9.6 + +- **docs:** Add links to features table ([#3030]) + +[#3030]: https://github.com/tokio-rs/axum/pull/3030 + +# 0.9.5 + +- **added:** Add `RouterExt::typed_connect` ([#2961]) +- **added:** Add `json!` for easy construction of JSON responses ([#2962]) + +[#2961]: https://github.com/tokio-rs/axum/pull/2961 +[#2962]: https://github.com/tokio-rs/axum/pull/2962 + +# 0.9.4 + +- **added:** The `response::Attachment` type ([#2789]) + +[#2789]: https://github.com/tokio-rs/axum/pull/2789 + +# 0.9.3 (24. March, 2024) + +- **added:** New `tracing` feature which enables logging rejections from + built-in extractor with the `axum::rejection=trace` target ([#2584]) + +[#2584]: https://github.com/tokio-rs/axum/pull/2584 + +# 0.9.2 (13. January, 2024) + +- **added:** Implement `TypedPath` for `WithRejection` +- **fixed:** Documentation link to `serde::Deserialize` in `JsonDeserializer` extractor ([#2498]) +- **added:** Add `is_missing` function for `TypedHeaderRejection` and `TypedHeaderRejectionReason` ([#2503]) + +[#2498]: https://github.com/tokio-rs/axum/pull/2498 +[#2503]: https://github.com/tokio-rs/axum/pull/2503 # 0.9.1 (29. December, 2023) diff --git a/axum-extra/Cargo.toml b/axum-extra/Cargo.toml index 6d3365a903..c357975ad4 100644 --- a/axum-extra/Cargo.toml +++ b/axum-extra/Cargo.toml @@ -2,24 +2,26 @@ categories = ["asynchronous", "network-programming", "web-programming"] description = "Extra utilities for axum" edition = "2021" -rust-version = "1.66" +rust-version = { workspace = true } homepage = "https://github.com/tokio-rs/axum" keywords = ["http", "web", "framework"] license = "MIT" name = "axum-extra" readme = "README.md" repository = "https://github.com/tokio-rs/axum" -version = "0.9.1" +version = "0.10.0-alpha.1" [features] -default = [] +default = ["tracing", "multipart"] async-read-body = ["dep:tokio-util", "tokio-util?/io", "dep:tokio"] +attachment = ["dep:tracing"] +error_response = ["dep:tracing", "tracing/std"] cookie = ["dep:cookie"] cookie-private = ["cookie", "cookie?/private"] cookie-signed = ["cookie", "cookie?/signed"] cookie-key-expansion = ["cookie", "cookie?/key-expansion"] -erased-json = ["dep:serde_json"] +erased-json = ["dep:serde_json", "dep:typed-json"] form = ["dep:serde_html_form"] json-deserializer = ["dep:serde_json", "dep:serde_path_to_error"] json-lines = [ @@ -30,15 +32,17 @@ json-lines = [ "tokio-stream?/io-util", "dep:tokio", ] -multipart = ["dep:multer"] +multipart = ["dep:multer", "dep:fastrand"] protobuf = ["dep:prost"] +scheme = [] query = ["dep:serde_html_form"] +tracing = ["axum-core/tracing", "axum/tracing"] typed-header = ["dep:headers"] typed-routing = ["dep:axum-macros", "dep:percent-encoding", "dep:serde_html_form", "dep:form_urlencoded"] [dependencies] -axum = { path = "../axum", version = "0.7.2", default-features = false } -axum-core = { path = "../axum-core", version = "0.4.2" } +axum = { path = "../axum", version = "0.8.0-alpha.1", default-features = false, features = ["original-uri"] } +axum-core = { path = "../axum-core", version = "0.5.0-alpha.1" } bytes = "1.1.0" futures-util = { version = "0.3", default-features = false, features = ["alloc"] } http = "1.0.0" @@ -47,38 +51,41 @@ http-body-util = "0.1.0" mime = "0.3" pin-project-lite = "0.2" serde = "1.0" -tower = { version = "0.4", default_features = false, features = ["util"] } +tower = { version = "0.5.1", default-features = false, features = ["util"] } tower-layer = "0.3" tower-service = "0.3" # optional dependencies -axum-macros = { path = "../axum-macros", version = "0.4.0", optional = true } +axum-macros = { path = "../axum-macros", version = "0.5.0-alpha.1", optional = true } cookie = { package = "cookie", version = "0.18.0", features = ["percent-encode"], optional = true } +fastrand = { version = "2.1.0", optional = true } form_urlencoded = { version = "1.1.0", optional = true } headers = { version = "0.4.0", optional = true } multer = { version = "3.0.0", optional = true } percent-encoding = { version = "2.1", optional = true } -prost = { version = "0.12", optional = true } +prost = { version = "0.13", optional = true } serde_html_form = { version = "0.2.0", optional = true } serde_json = { version = "1.0.71", optional = true } serde_path_to_error = { version = "0.1.8", optional = true } tokio = { version = "1.19", optional = true } tokio-stream = { version = "0.1.9", optional = true } tokio-util = { version = "0.7", optional = true } +tracing = { version = "0.1.37", default-features = false, optional = true } +typed-json = { version = "0.1.1", optional = true } [dev-dependencies] -axum = { path = "../axum", version = "0.7.2" } +axum = { path = "../axum", features = ["macros"] } +axum-macros = { path = "../axum-macros", features = ["__private"] } hyper = "1.0.0" -reqwest = { version = "0.11", default-features = false, features = ["json", "stream", "multipart"] } +reqwest = { version = "0.12", default-features = false, features = ["json", "stream", "multipart"] } serde = { version = "1.0", features = ["derive"] } serde_json = "1.0.71" tokio = { version = "1.14", features = ["full"] } -tower = { version = "0.4", features = ["util"] } -tower-http = { version = "0.5.0", features = ["map-response-body", "timeout"] } +tower = { version = "0.5.1", features = ["util"] } +tower-http = { version = "0.6.0", features = ["map-response-body", "timeout"] } [package.metadata.docs.rs] all-features = true -rustdoc-args = ["--cfg", "docsrs"] [package.metadata.cargo-public-api-crates] allowed = [ @@ -93,6 +100,7 @@ allowed = [ "headers_core", "http", "http_body", + "pin_project_lite", "prost", "serde", "tokio", diff --git a/axum-extra/README.md b/axum-extra/README.md index 16b96cc8c9..7d3e904e9c 100644 --- a/axum-extra/README.md +++ b/axum-extra/README.md @@ -14,7 +14,7 @@ This crate uses `#![forbid(unsafe_code)]` to ensure everything is implemented in ## Minimum supported Rust version -axum-extra's MSRV is 1.66. +axum-extra's MSRV is 1.75. ## Getting Help diff --git a/axum-extra/src/either.rs b/axum-extra/src/either.rs index 2742debb85..9fa1f82f3f 100755 --- a/axum-extra/src/either.rs +++ b/axum-extra/src/either.rs @@ -7,7 +7,6 @@ //! use axum::{ //! body::Bytes, //! Router, -//! async_trait, //! routing::get, //! extract::FromRequestParts, //! }; @@ -15,7 +14,6 @@ //! // extractors for checking permissions //! struct AdminPermissions {} //! -//! #[async_trait] //! impl FromRequestParts for AdminPermissions //! where //! S: Send + Sync, @@ -29,7 +27,6 @@ //! //! struct User {} //! -//! #[async_trait] //! impl FromRequestParts for User //! where //! S: Send + Sync, @@ -96,7 +93,6 @@ use std::task::{Context, Poll}; use axum::{ - async_trait, extract::FromRequestParts, response::{IntoResponse, Response}, }; @@ -236,7 +232,6 @@ macro_rules! impl_traits_for_either { [$($ident:ident),* $(,)?], $last:ident $(,)? ) => { - #[async_trait] impl FromRequestParts for $either<$($ident),*, $last> where $($ident: FromRequestParts),*, @@ -247,12 +242,12 @@ macro_rules! impl_traits_for_either { async fn from_request_parts(parts: &mut Parts, state: &S) -> Result { $( - if let Ok(value) = FromRequestParts::from_request_parts(parts, state).await { + if let Ok(value) = <$ident as FromRequestParts>::from_request_parts(parts, state).await { return Ok(Self::$ident(value)); } )* - FromRequestParts::from_request_parts(parts, state).await.map(Self::$last) + <$last as FromRequestParts>::from_request_parts(parts, state).await.map(Self::$last) } } diff --git a/axum-extra/src/extract/cached.rs b/axum-extra/src/extract/cached.rs index 6f7d6227b7..64b4c3056f 100644 --- a/axum-extra/src/extract/cached.rs +++ b/axum-extra/src/extract/cached.rs @@ -1,7 +1,4 @@ -use axum::{ - async_trait, - extract::{Extension, FromRequestParts}, -}; +use axum::extract::{Extension, FromRequestParts}; use http::request::Parts; /// Cache results of other extractors. @@ -19,7 +16,6 @@ use http::request::Parts; /// ```rust /// use axum_extra::extract::Cached; /// use axum::{ -/// async_trait, /// extract::FromRequestParts, /// response::{IntoResponse, Response}, /// http::{StatusCode, request::Parts}, @@ -28,7 +24,6 @@ use http::request::Parts; /// #[derive(Clone)] /// struct Session { /* ... */ } /// -/// #[async_trait] /// impl FromRequestParts for Session /// where /// S: Send + Sync, @@ -43,7 +38,6 @@ use http::request::Parts; /// /// struct CurrentUser { /* ... */ } /// -/// #[async_trait] /// impl FromRequestParts for CurrentUser /// where /// S: Send + Sync, @@ -86,7 +80,6 @@ pub struct Cached(pub T); #[derive(Clone)] struct CachedEntry(T); -#[async_trait] impl FromRequestParts for Cached where S: Send + Sync, @@ -111,8 +104,7 @@ axum_core::__impl_deref!(Cached); #[cfg(test)] mod tests { use super::*; - use axum::{extract::FromRequestParts, http::Request, routing::get, Router}; - use http::request::Parts; + use axum::{http::Request, routing::get, Router}; use std::{ convert::Infallible, sync::atomic::{AtomicU32, Ordering}, @@ -126,7 +118,6 @@ mod tests { #[derive(Clone, Debug, PartialEq, Eq)] struct Extractor(Instant); - #[async_trait] impl FromRequestParts for Extractor where S: Send + Sync, diff --git a/axum-extra/src/extract/cookie/mod.rs b/axum-extra/src/extract/cookie/mod.rs index efd2dcdf86..50fa6031ac 100644 --- a/axum-extra/src/extract/cookie/mod.rs +++ b/axum-extra/src/extract/cookie/mod.rs @@ -3,7 +3,6 @@ //! See [`CookieJar`], [`SignedCookieJar`], and [`PrivateCookieJar`] for more details. use axum::{ - async_trait, extract::FromRequestParts, response::{IntoResponse, IntoResponseParts, Response, ResponseParts}, }; @@ -90,7 +89,6 @@ pub struct CookieJar { jar: cookie::CookieJar, } -#[async_trait] impl FromRequestParts for CookieJar where S: Send + Sync, diff --git a/axum-extra/src/extract/cookie/private.rs b/axum-extra/src/extract/cookie/private.rs index 911b0ef2ec..f852b8c4ba 100644 --- a/axum-extra/src/extract/cookie/private.rs +++ b/axum-extra/src/extract/cookie/private.rs @@ -1,6 +1,5 @@ use super::{cookies_from_request, set_cookies, Cookie, Key}; use axum::{ - async_trait, extract::{FromRef, FromRequestParts}, response::{IntoResponse, IntoResponseParts, Response, ResponseParts}, }; @@ -49,11 +48,11 @@ use std::{convert::Infallible, fmt, marker::PhantomData}; /// // our application state /// #[derive(Clone)] /// struct AppState { -/// // that holds the key used to sign cookies +/// // that holds the key used to encrypt cookies /// key: Key, /// } /// -/// // this impl tells `SignedCookieJar` how to access the key from our state +/// // this impl tells `PrivateCookieJar` how to access the key from our state /// impl FromRef for Key { /// fn from_ref(state: &AppState) -> Self { /// state.key.clone() @@ -122,7 +121,6 @@ impl fmt::Debug for PrivateCookieJar { } } -#[async_trait] impl FromRequestParts for PrivateCookieJar where S: Send + Sync, @@ -291,7 +289,7 @@ struct PrivateCookieJarIter<'a, K> { iter: cookie::Iter<'a>, } -impl<'a, K> Iterator for PrivateCookieJarIter<'a, K> { +impl Iterator for PrivateCookieJarIter<'_, K> { type Item = Cookie<'static>; fn next(&mut self) -> Option { diff --git a/axum-extra/src/extract/cookie/signed.rs b/axum-extra/src/extract/cookie/signed.rs index b65df79f95..92bf917145 100644 --- a/axum-extra/src/extract/cookie/signed.rs +++ b/axum-extra/src/extract/cookie/signed.rs @@ -1,6 +1,5 @@ use super::{cookies_from_request, set_cookies}; use axum::{ - async_trait, extract::{FromRef, FromRequestParts}, response::{IntoResponse, IntoResponseParts, Response, ResponseParts}, }; @@ -139,7 +138,6 @@ impl fmt::Debug for SignedCookieJar { } } -#[async_trait] impl FromRequestParts for SignedCookieJar where S: Send + Sync, @@ -309,7 +307,7 @@ struct SignedCookieJarIter<'a, K> { iter: cookie::Iter<'a>, } -impl<'a, K> Iterator for SignedCookieJarIter<'a, K> { +impl Iterator for SignedCookieJarIter<'_, K> { type Item = Cookie<'static>; fn next(&mut self) -> Option { diff --git a/axum-extra/src/extract/form.rs b/axum-extra/src/extract/form.rs index 8729fb5a88..a7ca9305aa 100644 --- a/axum-extra/src/extract/form.rs +++ b/axum-extra/src/extract/form.rs @@ -1,5 +1,4 @@ use axum::{ - async_trait, extract::{rejection::RawFormRejection, FromRequest, RawForm, Request}, response::{IntoResponse, Response}, Error, RequestExt, @@ -44,7 +43,6 @@ pub struct Form(pub T); axum_core::__impl_deref!(Form); -#[async_trait] impl FromRequest for Form where T: DeserializeOwned, @@ -81,11 +79,16 @@ impl IntoResponse for FormRejection { fn into_response(self) -> Response { match self { Self::RawFormRejection(inner) => inner.into_response(), - Self::FailedToDeserializeForm(inner) => ( - StatusCode::BAD_REQUEST, - format!("Failed to deserialize form: {inner}"), - ) - .into_response(), + Self::FailedToDeserializeForm(inner) => { + let body = format!("Failed to deserialize form: {inner}"); + let status = StatusCode::BAD_REQUEST; + axum_core::__log_rejection!( + rejection_type = Self, + body_text = body, + status = status, + ); + (status, body).into_response() + } } } } @@ -113,7 +116,7 @@ mod tests { use super::*; use crate::test_helpers::*; use axum::{routing::post, Router}; - use http::{header::CONTENT_TYPE, StatusCode}; + use http::header::CONTENT_TYPE; use serde::Deserialize; #[tokio::test] @@ -135,7 +138,6 @@ mod tests { .post("/") .header(CONTENT_TYPE, "application/x-www-form-urlencoded") .body("value=one&value=two") - .send() .await; assert_eq!(res.status(), StatusCode::OK); diff --git a/axum/src/extract/host.rs b/axum-extra/src/extract/host.rs similarity index 74% rename from axum/src/extract/host.rs rename to axum-extra/src/extract/host.rs index d5be6a978d..a6828d3004 100644 --- a/axum/src/extract/host.rs +++ b/axum-extra/src/extract/host.rs @@ -1,29 +1,29 @@ -use super::{ - rejection::{FailedToResolveHost, HostRejection}, - FromRequestParts, -}; -use async_trait::async_trait; +use super::rejection::{FailedToResolveHost, HostRejection}; +use axum::extract::FromRequestParts; use http::{ header::{HeaderMap, FORWARDED}, request::Parts, + uri::Authority, }; const X_FORWARDED_HOST_HEADER_KEY: &str = "X-Forwarded-Host"; -/// Extractor that resolves the hostname of the request. +/// Extractor that resolves the host of the request. /// -/// Hostname is resolved through the following, in order: +/// Host is resolved through the following, in order: /// - `Forwarded` header /// - `X-Forwarded-Host` header /// - `Host` header -/// - request target / URI +/// - Authority of the request URI +/// +/// See for the definition of +/// host. /// /// Note that user agents can set `X-Forwarded-Host` and `Host` headers to arbitrary values so make /// sure to validate them to avoid security issues. #[derive(Debug, Clone)] pub struct Host(pub String); -#[async_trait] impl FromRequestParts for Host where S: Send + Sync, @@ -51,8 +51,8 @@ where return Ok(Host(host.to_owned())); } - if let Some(host) = parts.uri.host() { - return Ok(Host(host.to_owned())); + if let Some(authority) = parts.uri.authority() { + return Ok(Host(parse_authority(authority).to_owned())); } Err(HostRejection::FailedToResolveHost(FailedToResolveHost)) @@ -76,11 +76,19 @@ fn parse_forwarded(headers: &HeaderMap) -> Option<&str> { }) } +fn parse_authority(auth: &Authority) -> &str { + auth.as_str() + .rsplit('@') + .next() + .expect("split always has at least 1 item") +} + #[cfg(test)] mod tests { use super::*; - use crate::{routing::get, test_helpers::TestClient, Router}; - use http::header::HeaderName; + use crate::test_helpers::TestClient; + use axum::{routing::get, Router}; + use http::{header::HeaderName, Request}; fn test_client() -> TestClient { async fn host_as_body(Host(host): Host) -> String { @@ -96,7 +104,6 @@ mod tests { let host = test_client() .get("/") .header(http::header::HOST, original_host) - .send() .await .text() .await; @@ -109,7 +116,6 @@ mod tests { let host = test_client() .get("/") .header(X_FORWARDED_HOST_HEADER_KEY, original_host) - .send() .await .text() .await; @@ -124,7 +130,6 @@ mod tests { .get("/") .header(X_FORWARDED_HOST_HEADER_KEY, x_forwarded_host_header) .header(http::header::HOST, host_header) - .send() .await .text() .await; @@ -133,8 +138,26 @@ mod tests { #[crate::test] async fn uri_host() { - let host = test_client().get("/").send().await.text().await; - assert!(host.contains("127.0.0.1")); + let client = test_client(); + let port = client.server_port(); + let host = client.get("/").await.text().await; + assert_eq!(host, format!("127.0.0.1:{port}")); + } + + #[crate::test] + async fn ip4_uri_host() { + let mut parts = Request::new(()).into_parts().0; + parts.uri = "https://127.0.0.1:1234/image.jpg".parse().unwrap(); + let host = Host::from_request_parts(&mut parts, &()).await.unwrap(); + assert_eq!(host.0, "127.0.0.1:1234"); + } + + #[crate::test] + async fn ip6_uri_host() { + let mut parts = Request::new(()).into_parts().0; + parts.uri = "http://cool:user@[::1]:456/file.txt".parse().unwrap(); + let host = Host::from_request_parts(&mut parts, &()).await.unwrap(); + assert_eq!(host.0, "[::1]:456"); } #[test] diff --git a/axum-extra/src/extract/json_deserializer.rs b/axum-extra/src/extract/json_deserializer.rs index 0a30798755..051ab0f1bd 100644 --- a/axum-extra/src/extract/json_deserializer.rs +++ b/axum-extra/src/extract/json_deserializer.rs @@ -1,4 +1,3 @@ -use axum::async_trait; use axum::extract::{FromRequest, Request}; use axum_core::__composite_rejection as composite_rejection; use axum_core::__define_rejection as define_rejection; @@ -10,7 +9,7 @@ use std::marker::PhantomData; /// JSON Extractor for zero-copy deserialization. /// -/// Deserialize request bodies into some type that implements [`serde::Deserialize<'de>`]. +/// Deserialize request bodies into some type that implements [`serde::Deserialize<'de>`][serde::Deserialize]. /// Parsing JSON is delayed until [`deserialize`](JsonDeserializer::deserialize) is called. /// If the type implements [`serde::de::DeserializeOwned`], the [`Json`](axum::Json) extractor should /// be preferred. @@ -23,8 +22,7 @@ use std::marker::PhantomData; /// Additionally, a `JsonRejection` error will be returned, when calling `deserialize` if: /// /// - The body doesn't contain syntactically valid JSON. -/// - The body contains syntactically valid JSON, but it couldn't be deserialized into the target -/// type. +/// - The body contains syntactically valid JSON, but it couldn't be deserialized into the target type. /// - Attempting to deserialize escaped JSON into a type that must be borrowed (e.g. `&'a str`). /// /// ⚠️ `serde` will implicitly try to borrow for `&str` and `&[u8]` types, but will error if the @@ -85,7 +83,6 @@ pub struct JsonDeserializer { _marker: PhantomData, } -#[async_trait] impl FromRequest for JsonDeserializer where T: Deserialize<'static>, @@ -205,7 +202,7 @@ fn json_content_type(headers: &HeaderMap) -> bool { }; let is_json_content_type = mime.type_() == "application" - && (mime.subtype() == "json" || mime.suffix().map_or(false, |name| name == "json")); + && (mime.subtype() == "json" || mime.suffix().is_some_and(|name| name == "json")); is_json_content_type } @@ -245,7 +242,7 @@ mod tests { let app = Router::new().route("/", post(handler)); let client = TestClient::new(app); - let res = client.post("/").json(&json!({ "foo": "bar" })).send().await; + let res = client.post("/").json(&json!({ "foo": "bar" })).await; let body = res.text().await; assert_eq!(body, "bar"); @@ -277,11 +274,7 @@ mod tests { let client = TestClient::new(app); // The escaped characters prevent serde_json from borrowing. - let res = client - .post("/") - .json(&json!({ "foo": "\"bar\"" })) - .send() - .await; + let res = client.post("/").json(&json!({ "foo": "\"bar\"" })).await; let body = res.text().await; @@ -308,19 +301,11 @@ mod tests { let client = TestClient::new(app); - let res = client - .post("/") - .json(&json!({ "foo": "good" })) - .send() - .await; + let res = client.post("/").json(&json!({ "foo": "good" })).await; let body = res.text().await; assert_eq!(body, "good"); - let res = client - .post("/") - .json(&json!({ "foo": "\"bad\"" })) - .send() - .await; + let res = client.post("/").json(&json!({ "foo": "\"bad\"" })).await; assert_eq!(res.status(), StatusCode::UNPROCESSABLE_ENTITY); let body_text = res.text().await; assert_eq!( @@ -344,7 +329,7 @@ mod tests { let app = Router::new().route("/", post(handler)); let client = TestClient::new(app); - let res = client.post("/").body(r#"{ "foo": "bar" }"#).send().await; + let res = client.post("/").body(r#"{ "foo": "bar" }"#).await; let status = res.status(); @@ -366,7 +351,6 @@ mod tests { .post("/") .header("content-type", content_type) .body("{}") - .send() .await; res.status() == StatusCode::OK @@ -395,7 +379,6 @@ mod tests { .post("/") .body("{") .header("content-type", "application/json") - .send() .await; assert_eq!(res.status(), StatusCode::BAD_REQUEST); @@ -433,7 +416,6 @@ mod tests { .post("/") .body("{\"a\": 1, \"b\": [{\"x\": 2}]}") .header("content-type", "application/json") - .send() .await; assert_eq!(res.status(), StatusCode::UNPROCESSABLE_ENTITY); diff --git a/axum-extra/src/extract/mod.rs b/axum-extra/src/extract/mod.rs index 1f9974de02..7d2a5b2433 100644 --- a/axum-extra/src/extract/mod.rs +++ b/axum-extra/src/extract/mod.rs @@ -1,7 +1,9 @@ //! Additional extractors. mod cached; +mod host; mod optional_path; +pub mod rejection; mod with_rejection; #[cfg(feature = "form")] @@ -19,7 +21,12 @@ mod query; #[cfg(feature = "multipart")] pub mod multipart; -pub use self::{cached::Cached, optional_path::OptionalPath, with_rejection::WithRejection}; +#[cfg(feature = "scheme")] +mod scheme; + +pub use self::{ + cached::Cached, host::Host, optional_path::OptionalPath, with_rejection::WithRejection, +}; #[cfg(feature = "cookie")] pub use self::cookie::CookieJar; @@ -39,6 +46,10 @@ pub use self::query::{OptionalQuery, OptionalQueryRejection, Query, QueryRejecti #[cfg(feature = "multipart")] pub use self::multipart::Multipart; +#[cfg(feature = "scheme")] +#[doc(no_inline)] +pub use self::scheme::{Scheme, SchemeMissing}; + #[cfg(feature = "json-deserializer")] pub use self::json_deserializer::{ JsonDataError, JsonDeserializer, JsonDeserializerRejection, JsonSyntaxError, diff --git a/axum-extra/src/extract/multipart.rs b/axum-extra/src/extract/multipart.rs index 8c78a77722..bcfae64eee 100644 --- a/axum-extra/src/extract/multipart.rs +++ b/axum-extra/src/extract/multipart.rs @@ -3,7 +3,6 @@ //! See [`Multipart`] for more details. use axum::{ - async_trait, body::{Body, Bytes}, extract::FromRequest, response::{IntoResponse, Response}, @@ -75,7 +74,7 @@ use std::{ /// to keep `Field`s around from previous loop iterations. That will minimize the risk of runtime /// errors. /// -/// # Differences between this and `axum::extract::Multipart` +/// # Differences between this and `axum::extract::Multipart` /// /// `axum::extract::Multipart` uses lifetimes to enforce field exclusivity at compile time, however /// that leads to significant usability issues such as `Field` not being `'static`. @@ -90,7 +89,6 @@ pub struct Multipart { inner: multer::Multipart<'static>, } -#[async_trait] impl FromRequest for Multipart where S: Send + Sync, @@ -379,7 +377,13 @@ pub struct InvalidBoundary; impl IntoResponse for InvalidBoundary { fn into_response(self) -> Response { - (self.status(), self.body_text()).into_response() + let body = self.body_text(); + axum_core::__log_rejection!( + rejection_type = Self, + body_text = body, + status = self.status(), + ); + (self.status(), body).into_response() } } @@ -407,7 +411,7 @@ impl std::error::Error for InvalidBoundary {} mod tests { use super::*; use crate::test_helpers::*; - use axum::{extract::DefaultBodyLimit, response::IntoResponse, routing::post, Router}; + use axum::{extract::DefaultBodyLimit, routing::post, Router}; #[tokio::test] async fn content_type_with_encoding() { @@ -437,7 +441,7 @@ mod tests { .unwrap(), ); - client.post("/").multipart(form).send().await; + client.post("/").multipart(form).await; } // No need for this to be a #[test], we just want to make sure it compiles @@ -466,7 +470,7 @@ mod tests { let form = reqwest::multipart::Form::new().part("file", reqwest::multipart::Part::bytes(BYTES)); - let res = client.post("/").multipart(form).send().await; + let res = client.post("/").multipart(form).await; assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE); } } diff --git a/axum-extra/src/extract/optional_path.rs b/axum-extra/src/extract/optional_path.rs index ca1634500e..53443f1952 100644 --- a/axum-extra/src/extract/optional_path.rs +++ b/axum-extra/src/extract/optional_path.rs @@ -1,5 +1,4 @@ use axum::{ - async_trait, extract::{path::ErrorKind, rejection::PathRejection, FromRequestParts, Path}, RequestPartsExt, }; @@ -29,13 +28,12 @@ use serde::de::DeserializeOwned; /// /// let app = Router::new() /// .route("/blog", get(render_blog)) -/// .route("/blog/:page", get(render_blog)); +/// .route("/blog/{page}", get(render_blog)); /// # let app: Router = app; /// ``` #[derive(Debug)] pub struct OptionalPath(pub Option); -#[async_trait] impl FromRequestParts for OptionalPath where T: DeserializeOwned + Send + 'static, @@ -77,26 +75,26 @@ mod tests { let app = Router::new() .route("/", get(handle)) - .route("/:num", get(handle)); + .route("/{num}", get(handle)); let client = TestClient::new(app); - let res = client.get("/").send().await; + let res = client.get("/").await; assert_eq!(res.text().await, "Success: 0"); - let res = client.get("/1").send().await; + let res = client.get("/1").await; assert_eq!(res.text().await, "Success: 1"); - let res = client.get("/0").send().await; + let res = client.get("/0").await; assert_eq!( res.text().await, "Invalid URL: invalid value: integer `0`, expected a nonzero u32" ); - let res = client.get("/NaN").send().await; + let res = client.get("/NaN").await; assert_eq!( res.text().await, - "Invalid URL: Cannot parse `\"NaN\"` to a `u32`" + "Invalid URL: Cannot parse `NaN` to a `u32`" ); } } diff --git a/axum-extra/src/extract/query.rs b/axum-extra/src/extract/query.rs index bdeaf78e8a..695ea9576b 100644 --- a/axum-extra/src/extract/query.rs +++ b/axum-extra/src/extract/query.rs @@ -1,5 +1,4 @@ use axum::{ - async_trait, extract::FromRequestParts, response::{IntoResponse, Response}, Error, @@ -51,11 +50,37 @@ use std::fmt; /// example. /// /// [example]: https://github.com/tokio-rs/axum/blob/main/examples/query-params-with-empty-strings/src/main.rs +/// +/// While `Option` will handle empty parameters (e.g. `param=`), beware when using this with a +/// `Vec`. If your list is optional, use `Vec` in combination with `#[serde(default)]` +/// instead of `Option>`. `Option>` will handle 0, 2, or more arguments, but not one +/// argument. +/// +/// # Example +/// +/// ```rust,no_run +/// use axum::{routing::get, Router}; +/// use axum_extra::extract::Query; +/// use serde::Deserialize; +/// +/// #[derive(Deserialize)] +/// struct Params { +/// #[serde(default)] +/// items: Vec, +/// } +/// +/// // This will parse 0 occurrences of `items` as an empty `Vec`. +/// async fn process_items(Query(params): Query) { +/// // ... +/// } +/// +/// let app = Router::new().route("/process_items", get(process_items)); +/// # let _: Router = app; +/// ``` #[cfg_attr(docsrs, doc(cfg(feature = "query")))] #[derive(Debug, Clone, Copy, Default)] pub struct Query(pub T); -#[async_trait] impl FromRequestParts for Query where T: DeserializeOwned, @@ -87,11 +112,16 @@ pub enum QueryRejection { impl IntoResponse for QueryRejection { fn into_response(self) -> Response { match self { - Self::FailedToDeserializeQueryString(inner) => ( - StatusCode::BAD_REQUEST, - format!("Failed to deserialize query string: {inner}"), - ) - .into_response(), + Self::FailedToDeserializeQueryString(inner) => { + let body = format!("Failed to deserialize query string: {inner}"); + let status = StatusCode::BAD_REQUEST; + axum_core::__log_rejection!( + rejection_type = Self, + body_text = body, + status = status, + ); + (status, body).into_response() + } } } } @@ -155,7 +185,6 @@ impl std::error::Error for QueryRejection { #[derive(Debug, Clone, Copy, Default)] pub struct OptionalQuery(pub Option); -#[async_trait] impl FromRequestParts for OptionalQuery where T: DeserializeOwned, @@ -235,7 +264,7 @@ mod tests { use super::*; use crate::test_helpers::*; use axum::{routing::post, Router}; - use http::{header::CONTENT_TYPE, StatusCode}; + use http::header::CONTENT_TYPE; use serde::Deserialize; #[tokio::test] @@ -257,7 +286,6 @@ mod tests { .post("/?value=one&value=two") .header(CONTENT_TYPE, "application/x-www-form-urlencoded") .body("") - .send() .await; assert_eq!(res.status(), StatusCode::OK); @@ -286,7 +314,6 @@ mod tests { .post("/?value=one&value=two") .header(CONTENT_TYPE, "application/x-www-form-urlencoded") .body("") - .send() .await; assert_eq!(res.status(), StatusCode::OK); @@ -312,7 +339,7 @@ mod tests { let client = TestClient::new(app); - let res = client.post("/").body("").send().await; + let res = client.post("/").body("").await; assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, "None"); @@ -341,7 +368,6 @@ mod tests { .post("/?other=something") .header(CONTENT_TYPE, "application/x-www-form-urlencoded") .body("") - .send() .await; assert_eq!(res.status(), StatusCode::BAD_REQUEST); diff --git a/axum-extra/src/extract/rejection.rs b/axum-extra/src/extract/rejection.rs new file mode 100644 index 0000000000..f48473293b --- /dev/null +++ b/axum-extra/src/extract/rejection.rs @@ -0,0 +1,23 @@ +//! Rejection response types. + +use axum_core::{ + __composite_rejection as composite_rejection, __define_rejection as define_rejection, +}; + +define_rejection! { + #[status = BAD_REQUEST] + #[body = "No host found in request"] + /// Rejection type used if the [`Host`](super::Host) extractor is unable to + /// resolve a host. + pub struct FailedToResolveHost; +} + +composite_rejection! { + /// Rejection used for [`Host`](super::Host). + /// + /// Contains one variant for each way the [`Host`](super::Host) extractor + /// can fail. + pub enum HostRejection { + FailedToResolveHost, + } +} diff --git a/axum-extra/src/extract/scheme.rs b/axum-extra/src/extract/scheme.rs new file mode 100644 index 0000000000..891d5c0bdd --- /dev/null +++ b/axum-extra/src/extract/scheme.rs @@ -0,0 +1,152 @@ +//! Extractor that parses the scheme of a request. +//! See [`Scheme`] for more details. + +use axum::{ + extract::FromRequestParts, + response::{IntoResponse, Response}, +}; +use http::{ + header::{HeaderMap, FORWARDED}, + request::Parts, +}; +const X_FORWARDED_PROTO_HEADER_KEY: &str = "X-Forwarded-Proto"; + +/// Extractor that resolves the scheme / protocol of a request. +/// +/// The scheme is resolved through the following, in order: +/// - `Forwarded` header +/// - `X-Forwarded-Proto` header +/// - Request URI (If the request is an HTTP/2 request! e.g. use `--http2(-prior-knowledge)` with cURL) +/// +/// Note that user agents can set the `X-Forwarded-Proto` header to arbitrary values so make +/// sure to validate them to avoid security issues. +#[derive(Debug, Clone)] +pub struct Scheme(pub String); + +/// Rejection type used if the [`Scheme`] extractor is unable to +/// resolve a scheme. +#[derive(Debug)] +pub struct SchemeMissing; + +impl IntoResponse for SchemeMissing { + fn into_response(self) -> Response { + (http::StatusCode::BAD_REQUEST, "No scheme found in request").into_response() + } +} + +impl FromRequestParts for Scheme +where + S: Send + Sync, +{ + type Rejection = SchemeMissing; + + async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { + // Within Forwarded header + if let Some(scheme) = parse_forwarded(&parts.headers) { + return Ok(Scheme(scheme.to_owned())); + } + + // X-Forwarded-Proto + if let Some(scheme) = parts + .headers + .get(X_FORWARDED_PROTO_HEADER_KEY) + .and_then(|scheme| scheme.to_str().ok()) + { + return Ok(Scheme(scheme.to_owned())); + } + + // From parts of an HTTP/2 request + if let Some(scheme) = parts.uri.scheme_str() { + return Ok(Scheme(scheme.to_owned())); + } + + Err(SchemeMissing) + } +} + +fn parse_forwarded(headers: &HeaderMap) -> Option<&str> { + // if there are multiple `Forwarded` `HeaderMap::get` will return the first one + let forwarded_values = headers.get(FORWARDED)?.to_str().ok()?; + + // get the first set of values + let first_value = forwarded_values.split(',').next()?; + + // find the value of the `proto` field + first_value.split(';').find_map(|pair| { + let (key, value) = pair.split_once('=')?; + key.trim() + .eq_ignore_ascii_case("proto") + .then(|| value.trim().trim_matches('"')) + }) +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::test_helpers::TestClient; + use axum::{routing::get, Router}; + use http::header::HeaderName; + + fn test_client() -> TestClient { + async fn scheme_as_body(Scheme(scheme): Scheme) -> String { + scheme + } + + TestClient::new(Router::new().route("/", get(scheme_as_body))) + } + + #[crate::test] + async fn forwarded_scheme_parsing() { + // the basic case + let headers = header_map(&[(FORWARDED, "host=192.0.2.60;proto=http;by=203.0.113.43")]); + let value = parse_forwarded(&headers).unwrap(); + assert_eq!(value, "http"); + + // is case insensitive + let headers = header_map(&[(FORWARDED, "host=192.0.2.60;PROTO=https;by=203.0.113.43")]); + let value = parse_forwarded(&headers).unwrap(); + assert_eq!(value, "https"); + + // multiple values in one header + let headers = header_map(&[(FORWARDED, "proto=ftp, proto=https")]); + let value = parse_forwarded(&headers).unwrap(); + assert_eq!(value, "ftp"); + + // multiple header values + let headers = header_map(&[(FORWARDED, "proto=ftp"), (FORWARDED, "proto=https")]); + let value = parse_forwarded(&headers).unwrap(); + assert_eq!(value, "ftp"); + } + + #[crate::test] + async fn x_forwarded_scheme_header() { + let original_scheme = "https"; + let scheme = test_client() + .get("/") + .header(X_FORWARDED_PROTO_HEADER_KEY, original_scheme) + .await + .text() + .await; + assert_eq!(scheme, original_scheme); + } + + #[crate::test] + async fn precedence_forwarded_over_x_forwarded() { + let scheme = test_client() + .get("/") + .header(X_FORWARDED_PROTO_HEADER_KEY, "https") + .header(FORWARDED, "proto=ftp") + .await + .text() + .await; + assert_eq!(scheme, "ftp"); + } + + fn header_map(values: &[(HeaderName, &str)]) -> HeaderMap { + let mut headers = HeaderMap::new(); + for (key, value) in values { + headers.append(key, value.parse().unwrap()); + } + headers + } +} diff --git a/axum-extra/src/extract/with_rejection.rs b/axum-extra/src/extract/with_rejection.rs index 1227a1ab13..c093f6fa47 100644 --- a/axum-extra/src/extract/with_rejection.rs +++ b/axum-extra/src/extract/with_rejection.rs @@ -1,11 +1,13 @@ -use axum::async_trait; use axum::extract::{FromRequest, FromRequestParts, Request}; use axum::response::IntoResponse; use http::request::Parts; -use std::fmt::Debug; +use std::fmt::{Debug, Display}; use std::marker::PhantomData; use std::ops::{Deref, DerefMut}; +#[cfg(feature = "typed-routing")] +use crate::routing::TypedPath; + /// Extractor for customizing extractor rejections /// /// `WithRejection` wraps another extractor and gives you the result. If the @@ -107,7 +109,6 @@ impl DerefMut for WithRejection { } } -#[async_trait] impl FromRequest for WithRejection where S: Send + Sync, @@ -122,7 +123,6 @@ where } } -#[async_trait] impl FromRequestParts for WithRejection where S: Send + Sync, @@ -137,22 +137,35 @@ where } } +#[cfg(feature = "typed-routing")] +impl TypedPath for WithRejection +where + E: TypedPath, +{ + const PATH: &'static str = E::PATH; +} + +impl Display for WithRejection +where + E: Display, +{ + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "{}", self.0) + } +} + #[cfg(test)] mod tests { + use super::*; use axum::body::Body; - use axum::extract::FromRequestParts; use axum::http::Request; use axum::response::Response; - use http::request::Parts; - - use super::*; #[tokio::test] async fn extractor_rejection_is_transformed() { struct TestExtractor; struct TestRejection; - #[async_trait] impl FromRequestParts for TestExtractor where S: Send + Sync, diff --git a/axum-extra/src/handler/mod.rs b/axum-extra/src/handler/mod.rs index 4017e149a6..571ab67707 100644 --- a/axum-extra/src/handler/mod.rs +++ b/axum-extra/src/handler/mod.rs @@ -47,7 +47,6 @@ pub trait HandlerCallWithExtractors: Sized { /// use axum_extra::handler::HandlerCallWithExtractors; /// use axum::{ /// Router, - /// async_trait, /// routing::get, /// extract::FromRequestParts, /// }; @@ -68,7 +67,6 @@ pub trait HandlerCallWithExtractors: Sized { /// // extractors for checking permissions /// struct AdminPermissions {} /// - /// #[async_trait] /// impl FromRequestParts for AdminPermissions /// where /// S: Send + Sync, @@ -82,7 +80,6 @@ pub trait HandlerCallWithExtractors: Sized { /// /// struct User {} /// - /// #[async_trait] /// impl FromRequestParts for User /// where /// S: Send + Sync, @@ -95,7 +92,7 @@ pub trait HandlerCallWithExtractors: Sized { /// } /// /// let app = Router::new().route( - /// "/users/:id", + /// "/users/{id}", /// get( /// // first try `admin`, if that rejects run `user`, finally falling back /// // to `guest` @@ -168,7 +165,7 @@ pub struct IntoHandler { impl Handler for IntoHandler where - H: HandlerCallWithExtractors + Clone + Send + 'static, + H: HandlerCallWithExtractors + Clone + Send + Sync + 'static, T: FromRequest + Send + 'static, T::Rejection: Send, S: Send + Sync + 'static, diff --git a/axum-extra/src/handler/or.rs b/axum-extra/src/handler/or.rs index 7b78fe3d73..f15ccc70b0 100644 --- a/axum-extra/src/handler/or.rs +++ b/axum-extra/src/handler/or.rs @@ -54,8 +54,8 @@ where impl Handler<(M, Lt, Rt), S> for Or where - L: HandlerCallWithExtractors + Clone + Send + 'static, - R: HandlerCallWithExtractors + Clone + Send + 'static, + L: HandlerCallWithExtractors + Clone + Send + Sync + 'static, + R: HandlerCallWithExtractors + Clone + Send + Sync + 'static, Lt: FromRequestParts + Send + 'static, Rt: FromRequest + Send + 'static, Lt::Rejection: Send, @@ -134,17 +134,17 @@ mod tests { "fallback" } - let app = Router::new().route("/:id", get(one.or(two).or(three))); + let app = Router::new().route("/{id}", get(one.or(two).or(three))); let client = TestClient::new(app); - let res = client.get("/123").send().await; + let res = client.get("/123").await; assert_eq!(res.text().await, "123"); - let res = client.get("/foo?a=bar").send().await; + let res = client.get("/foo?a=bar").await; assert_eq!(res.text().await, "bar"); - let res = client.get("/foo").send().await; + let res = client.get("/foo").await; assert_eq!(res.text().await, "fallback"); } } diff --git a/axum-extra/src/json_lines.rs b/axum-extra/src/json_lines.rs index d72c23b6c6..7c513f96cd 100644 --- a/axum-extra/src/json_lines.rs +++ b/axum-extra/src/json_lines.rs @@ -1,7 +1,6 @@ //! Newline delimited JSON extractor and response. use axum::{ - async_trait, body::Body, extract::{FromRequest, Request}, response::{IntoResponse, Response}, @@ -55,7 +54,7 @@ pin_project! { /// JsonLines::new(stream_of_values()).into_response() /// } /// ``` - // we use `AsExtractor` as the default because you're more likely to name this type if its used + // we use `AsExtractor` as the default because you're more likely to name this type if it's used // as an extractor #[must_use] pub struct JsonLines { @@ -99,7 +98,6 @@ impl JsonLines { } } -#[async_trait] impl FromRequest for JsonLines where T: DeserializeOwned, @@ -184,7 +182,7 @@ mod tests { use futures_util::StreamExt; use http::StatusCode; use serde::Deserialize; - use std::{convert::Infallible, error::Error}; + use std::error::Error; #[derive(Serialize, Deserialize, PartialEq, Eq, Debug)] struct User { @@ -224,7 +222,6 @@ mod tests { ] .join("\n"), ) - .send() .await; assert_eq!(res.status(), StatusCode::OK); } @@ -245,7 +242,7 @@ mod tests { let client = TestClient::new(app); - let res = client.get("/").send().await; + let res = client.get("/").await; let values = res .text() diff --git a/axum-extra/src/lib.rs b/axum-extra/src/lib.rs index eb93b0a312..473c42742c 100644 --- a/axum-extra/src/lib.rs +++ b/axum-extra/src/lib.rs @@ -9,20 +9,22 @@ //! //! Name | Description | Default? //! ---|---|--- -//! `async-read-body` | Enables the `AsyncReadBody` body | No -//! `cookie` | Enables the `CookieJar` extractor | No -//! `cookie-private` | Enables the `PrivateCookieJar` extractor | No -//! `cookie-signed` | Enables the `SignedCookieJar` extractor | No -//! `cookie-key-expansion` | Enables the `Key::derive_from` method | No -//! `erased-json` | Enables the `ErasedJson` response | No -//! `form` | Enables the `Form` extractor | No -//! `json-deserializer` | Enables the `JsonDeserializer` extractor | No -//! `json-lines` | Enables the `JsonLines` extractor and response | No -//! `multipart` | Enables the `Multipart` extractor | No -//! `protobuf` | Enables the `Protobuf` extractor and response | No -//! `query` | Enables the `Query` extractor | No -//! `typed-routing` | Enables the `TypedPath` routing utilities | No -//! `typed-header` | Enables the `TypedHeader` extractor and response | No +//! `async-read-body` | Enables the [`AsyncReadBody`](crate::body::AsyncReadBody) body | No +//! `attachment` | Enables the [`Attachment`](crate::response::Attachment) response | No +//! `cookie` | Enables the [`CookieJar`](crate::extract::CookieJar) extractor | No +//! `cookie-private` | Enables the [`PrivateCookieJar`](crate::extract::PrivateCookieJar) extractor | No +//! `cookie-signed` | Enables the [`SignedCookieJar`](crate::extract::SignedCookieJar) extractor | No +//! `cookie-key-expansion` | Enables the [`Key::derive_from`](crate::extract::cookie::Key::derive_from) method | No +//! `erased-json` | Enables the [`ErasedJson`](crate::response::ErasedJson) response | No +//! `form` | Enables the [`Form`](crate::extract::Form) extractor | No +//! `json-deserializer` | Enables the [`JsonDeserializer`](crate::extract::JsonDeserializer) extractor | No +//! `json-lines` | Enables the [`JsonLines`](crate::extract::JsonLines) extractor and response | No +//! `multipart` | Enables the [`Multipart`](crate::extract::Multipart) extractor | No +//! `protobuf` | Enables the [`Protobuf`](crate::protobuf::Protobuf) extractor and response | No +//! `query` | Enables the [`Query`](crate::extract::Query) extractor | No +//! `tracing` | Log rejections from built-in extractors | Yes +//! `typed-routing` | Enables the [`TypedPath`](crate::routing::TypedPath) routing utilities | No +//! `typed-header` | Enables the [`TypedHeader`] extractor and response | No //! //! [`axum`]: https://crates.io/crates/axum @@ -39,7 +41,6 @@ clippy::needless_borrow, clippy::match_wildcard_for_single_variants, clippy::if_let_mutex, - clippy::mismatched_target_os, clippy::await_holding_lock, clippy::match_on_vec_items, clippy::imprecise_flops, @@ -96,11 +97,10 @@ pub use typed_header::TypedHeader; #[cfg(feature = "protobuf")] pub mod protobuf; +/// _not_ public API #[cfg(feature = "typed-routing")] #[doc(hidden)] pub mod __private { - //! _not_ public API - use percent_encoding::{AsciiSet, CONTROLS}; pub use percent_encoding::utf8_percent_encode; @@ -115,9 +115,8 @@ pub mod __private { use axum_macros::__private_axum_test as test; #[cfg(test)] +#[allow(unused_imports)] pub(crate) mod test_helpers { - #![allow(unused_imports)] - use axum::{extract::Request, response::Response, serve}; mod test_client { diff --git a/axum-extra/src/protobuf.rs b/axum-extra/src/protobuf.rs index 912f4a1d3f..faaca151ce 100644 --- a/axum-extra/src/protobuf.rs +++ b/axum-extra/src/protobuf.rs @@ -1,7 +1,6 @@ //! Protocol Buffer extractor and response. use axum::{ - async_trait, extract::{rejection::BytesRejection, FromRequest, Request}, response::{IntoResponse, IntoResponseFailed, Response}, }; @@ -82,7 +81,7 @@ use prost::Message; /// # unimplemented!() /// } /// -/// let app = Router::new().route("/users/:id", get(get_user)); +/// let app = Router::new().route("/users/{id}", get(get_user)); /// # let _: Router = app; /// ``` #[derive(Debug, Clone, Copy, Default)] @@ -90,7 +89,6 @@ use prost::Message; #[must_use] pub struct Protobuf(pub T); -#[async_trait] impl FromRequest for Protobuf where T: Message + Default, @@ -206,7 +204,6 @@ mod tests { use super::*; use crate::test_helpers::*; use axum::{routing::post, Router}; - use http::StatusCode; #[tokio::test] async fn decode_body() { @@ -226,7 +223,7 @@ mod tests { }; let client = TestClient::new(app); - let res = client.post("/").body(input.encode_to_vec()).send().await; + let res = client.post("/").body(input.encode_to_vec()).await; let body = res.text().await; @@ -254,7 +251,7 @@ mod tests { }; let client = TestClient::new(app); - let res = client.post("/").body(input.encode_to_vec()).send().await; + let res = client.post("/").body(input.encode_to_vec()).await; assert_eq!(res.status(), StatusCode::UNPROCESSABLE_ENTITY); } @@ -289,7 +286,7 @@ mod tests { }; let client = TestClient::new(app); - let res = client.post("/").body(input.encode_to_vec()).send().await; + let res = client.post("/").body(input.encode_to_vec()).await; assert_eq!( res.headers()["content-type"], diff --git a/axum-extra/src/response/attachment.rs b/axum-extra/src/response/attachment.rs new file mode 100644 index 0000000000..2063d30f05 --- /dev/null +++ b/axum-extra/src/response/attachment.rs @@ -0,0 +1,103 @@ +use axum::response::IntoResponse; +use http::{header, HeaderMap, HeaderValue}; +use tracing::error; + +/// A file attachment response. +/// +/// This type will set the `Content-Disposition` header to `attachment`. In response a webbrowser +/// will offer to download the file instead of displaying it directly. +/// +/// Use the `filename` and `content_type` methods to set the filename or content-type of the +/// attachment. If these values are not set they will not be sent. +/// +/// +/// # Example +/// +/// ```rust +/// use axum::{http::StatusCode, routing::get, Router}; +/// use axum_extra::response::Attachment; +/// +/// async fn cargo_toml() -> Result, (StatusCode, String)> { +/// let file_contents = tokio::fs::read_to_string("Cargo.toml") +/// .await +/// .map_err(|err| (StatusCode::NOT_FOUND, format!("File not found: {err}")))?; +/// Ok(Attachment::new(file_contents) +/// .filename("Cargo.toml") +/// .content_type("text/x-toml")) +/// } +/// +/// let app = Router::new().route("/Cargo.toml", get(cargo_toml)); +/// let _: Router = app; +/// ``` +/// +/// # Note +/// +/// If you use axum with hyper, hyper will set the `Content-Length` if it is known. +#[derive(Debug)] +#[must_use] +pub struct Attachment { + inner: T, + filename: Option, + content_type: Option, +} + +impl Attachment { + /// Creates a new [`Attachment`]. + pub fn new(inner: T) -> Self { + Self { + inner, + filename: None, + content_type: None, + } + } + + /// Sets the filename of the [`Attachment`]. + /// + /// This updates the `Content-Disposition` header to add a filename. + pub fn filename>(mut self, value: H) -> Self { + self.filename = if let Ok(filename) = value.try_into() { + Some(filename) + } else { + error!("Attachment filename contains invalid characters"); + None + }; + self + } + + /// Sets the content-type of the [`Attachment`] + pub fn content_type>(mut self, value: H) -> Self { + if let Ok(content_type) = value.try_into() { + self.content_type = Some(content_type); + } else { + error!("Attachment content-type contains invalid characters"); + } + self + } +} + +impl IntoResponse for Attachment +where + T: IntoResponse, +{ + fn into_response(self) -> axum::response::Response { + let mut headers = HeaderMap::new(); + + if let Some(content_type) = self.content_type { + headers.append(header::CONTENT_TYPE, content_type); + } + + let content_disposition = if let Some(filename) = self.filename { + let mut bytes = b"attachment; filename=\"".to_vec(); + bytes.extend_from_slice(filename.as_bytes()); + bytes.push(b'\"'); + + HeaderValue::from_bytes(&bytes).expect("This was a HeaderValue so this can not fail") + } else { + HeaderValue::from_static("attachment") + }; + + headers.append(header::CONTENT_DISPOSITION, content_disposition); + + (headers, self.inner).into_response() + } +} diff --git a/axum-extra/src/response/erased_json.rs b/axum-extra/src/response/erased_json.rs index f820b1227c..781f3aeb1e 100644 --- a/axum-extra/src/response/erased_json.rs +++ b/axum-extra/src/response/erased_json.rs @@ -12,6 +12,15 @@ use serde::Serialize; /// This allows returning a borrowing type from a handler, or returning different response /// types as JSON from different branches inside a handler. /// +/// Like [`axum::Json`], +/// if the [`Serialize`] implementation fails +/// or if a map with non-string keys is used, +/// a 500 response will be issued +/// whose body is the error message in UTF-8. +/// +/// This can be constructed using [`new`](ErasedJson::new) +/// or the [`json!`](crate::json) macro. +/// /// # Example /// /// ```rust @@ -77,3 +86,65 @@ impl IntoResponse for ErasedJson { } } } + +/// Construct an [`ErasedJson`] response from a JSON literal. +/// +/// A `Content-Type: application/json` header is automatically added. +/// Any variable or expression implementing [`Serialize`] +/// can be interpolated as a value in the literal. +/// If the [`Serialize`] implementation fails, +/// or if a map with non-string keys is used, +/// a 500 response will be issued +/// whose body is the error message in UTF-8. +/// +/// Internally, +/// this function uses the [`typed_json::json!`] macro, +/// allowing it to perform far fewer allocations +/// than a dynamic macro like [`serde_json::json!`] would – +/// it's equivalent to if you had just written +/// `derive(Serialize)` on a struct. +/// +/// # Examples +/// +/// ``` +/// use axum::{ +/// Router, +/// extract::Path, +/// response::Response, +/// routing::get, +/// }; +/// use axum_extra::response::ErasedJson; +/// +/// async fn get_user(Path(user_id) : Path) -> ErasedJson { +/// let user_name = find_user_name(user_id).await; +/// axum_extra::json!({ "name": user_name }) +/// } +/// +/// async fn find_user_name(user_id: u64) -> String { +/// // ... +/// # unimplemented!() +/// } +/// +/// let app = Router::new().route("/users/{id}", get(get_user)); +/// # let _: Router = app; +/// ``` +/// +/// Trailing commas are allowed in both arrays and objects. +/// +/// ``` +/// let response = axum_extra::json!(["trailing",]); +/// ``` +#[macro_export] +macro_rules! json { + ($($t:tt)*) => { + $crate::response::ErasedJson::new( + $crate::response::__private_erased_json::typed_json::json!($($t)*) + ) + } +} + +/// Not public API. Re-exported as `crate::response::__private_erased_json`. +#[doc(hidden)] +pub mod private { + pub use typed_json; +} diff --git a/axum-extra/src/response/error_response.rs b/axum-extra/src/response/error_response.rs new file mode 100644 index 0000000000..0706950555 --- /dev/null +++ b/axum-extra/src/response/error_response.rs @@ -0,0 +1,51 @@ +use axum_core::response::{IntoResponse, Response}; +use http::StatusCode; +use std::error::Error; +use tracing::error; + +/// Convenience response to create an error response from a non-[`IntoResponse`] error +/// +/// This provides a method to quickly respond with an error that does not implement +/// the `IntoResponse` trait itself. This type should only be used for debugging purposes or internal +/// facing applications, as it includes the full error chain with descriptions, +/// thus leaking information that could possibly be sensitive. +/// +/// ```rust +/// use axum_extra::response::InternalServerError; +/// use axum_core::response::IntoResponse; +/// # use std::io::{Error, ErrorKind}; +/// # fn try_thing() -> Result<(), Error> { +/// # Err(Error::new(ErrorKind::Other, "error")) +/// # } +/// +/// async fn maybe_error() -> Result> { +/// try_thing().map_err(InternalServerError)?; +/// // do something on success +/// # Ok(String::from("ok")) +/// } +/// ``` +#[derive(Debug)] +pub struct InternalServerError(pub T); + +impl IntoResponse for InternalServerError { + fn into_response(self) -> Response { + error!(error = &self.0 as &dyn Error); + ( + StatusCode::INTERNAL_SERVER_ERROR, + "An error occurred while processing your request.", + ) + .into_response() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use std::io::{Error, ErrorKind}; + + #[test] + fn internal_server_error() { + let response = InternalServerError(Error::new(ErrorKind::Other, "Test")).into_response(); + assert_eq!(response.status(), StatusCode::INTERNAL_SERVER_ERROR); + } +} diff --git a/axum-extra/src/response/mod.rs b/axum-extra/src/response/mod.rs index dda382cf02..bac7d040fe 100644 --- a/axum-extra/src/response/mod.rs +++ b/axum-extra/src/response/mod.rs @@ -3,13 +3,33 @@ #[cfg(feature = "erased-json")] mod erased_json; +#[cfg(feature = "attachment")] +mod attachment; + +#[cfg(feature = "multipart")] +pub mod multiple; + +#[cfg(feature = "error_response")] +mod error_response; + +#[cfg(feature = "error_response")] +pub use error_response::InternalServerError; + #[cfg(feature = "erased-json")] pub use erased_json::ErasedJson; +/// _not_ public API +#[cfg(feature = "erased-json")] +#[doc(hidden)] +pub use erased_json::private as __private_erased_json; + #[cfg(feature = "json-lines")] #[doc(no_inline)] pub use crate::json_lines::JsonLines; +#[cfg(feature = "attachment")] +pub use attachment::Attachment; + macro_rules! mime_response { ( $(#[$m:meta])* @@ -57,14 +77,6 @@ macro_rules! mime_response { }; } -mime_response! { - /// A HTML response. - /// - /// Will automatically get `Content-Type: text/html; charset=utf-8`. - Html, - TEXT_HTML_UTF_8, -} - mime_response! { /// A JavaScript response. /// diff --git a/axum-extra/src/response/multiple.rs b/axum-extra/src/response/multiple.rs new file mode 100644 index 0000000000..250dc02457 --- /dev/null +++ b/axum-extra/src/response/multiple.rs @@ -0,0 +1,295 @@ +//! Generate forms to use in responses. + +use axum::response::{IntoResponse, Response}; +use fastrand; +use http::{header, HeaderMap, StatusCode}; +use mime::Mime; + +/// Create multipart forms to be used in API responses. +/// +/// This struct implements [`IntoResponse`], and so it can be returned from a handler. +#[derive(Debug)] +pub struct MultipartForm { + parts: Vec, +} + +impl MultipartForm { + /// Initialize a new multipart form with the provided vector of parts. + /// + /// # Examples + /// + /// ```rust + /// use axum_extra::response::multiple::{MultipartForm, Part}; + /// + /// let parts: Vec = vec![Part::text("foo".to_string(), "abc"), Part::text("bar".to_string(), "def")]; + /// let form = MultipartForm::with_parts(parts); + /// ``` + pub fn with_parts(parts: Vec) -> Self { + MultipartForm { parts } + } +} + +impl IntoResponse for MultipartForm { + fn into_response(self) -> Response { + // see RFC5758 for details + let boundary = generate_boundary(); + let mut headers = HeaderMap::new(); + let mime_type: Mime = match format!("multipart/form-data; boundary={}", boundary).parse() { + Ok(m) => m, + // Realistically this should never happen unless the boundary generation code + // is modified, and that will be caught by unit tests + Err(_) => { + return ( + StatusCode::INTERNAL_SERVER_ERROR, + "Invalid multipart boundary generated", + ) + .into_response() + } + }; + // The use of unwrap is safe here because mime types are inherently string representable + headers.insert(header::CONTENT_TYPE, mime_type.to_string().parse().unwrap()); + let mut serialized_form: Vec = Vec::new(); + for part in self.parts { + // for each part, the boundary is preceded by two dashes + serialized_form.extend_from_slice(format!("--{}\r\n", boundary).as_bytes()); + serialized_form.extend_from_slice(&part.serialize()); + } + serialized_form.extend_from_slice(format!("--{}--", boundary).as_bytes()); + (headers, serialized_form).into_response() + } +} + +// Valid settings for that header are: "base64", "quoted-printable", "8bit", "7bit", and "binary". +/// A single part of a multipart form as defined by +/// +/// and RFC5758. +#[derive(Debug)] +pub struct Part { + // Every part is expected to contain: + // - a [Content-Disposition](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Disposition + // header, where `Content-Disposition` is set to `form-data`, with a parameter of `name` that is set to + // the name of the field in the form. In the below example, the name of the field is `user`: + // ``` + // Content-Disposition: form-data; name="user" + // ``` + // If the field contains a file, then the `filename` parameter may be set to the name of the file. + // Handling for non-ascii field names is not done here, support for non-ascii characters may be encoded using + // methodology described in RFC 2047. + // - (optionally) a `Content-Type` header, which if not set, defaults to `text/plain`. + // If the field contains a file, then the file should be identified with that file's MIME type (eg: `image/gif`). + // If the `MIME` type is not known or specified, then the MIME type should be set to `application/octet-stream`. + /// The name of the part in question + name: String, + /// If the part should be treated as a file, the filename that should be attached that part + filename: Option, + /// The `Content-Type` header. While not strictly required, it is always set here + mime_type: Mime, + /// The content/body of the part + contents: Vec, +} + +impl Part { + /// Create a new part with `Content-Type` of `text/plain` with the supplied name and contents. + /// + /// This form will not have a defined file name. + /// + /// # Examples + /// + /// ```rust + /// use axum_extra::response::multiple::{MultipartForm, Part}; + /// + /// // create a form with a single part that has a field with a name of "foo", + /// // and a value of "abc" + /// let parts: Vec = vec![Part::text("foo".to_string(), "abc")]; + /// let form = MultipartForm::from_iter(parts); + /// ``` + pub fn text(name: String, contents: &str) -> Self { + Self { + name, + filename: None, + mime_type: mime::TEXT_PLAIN_UTF_8, + contents: contents.as_bytes().to_vec(), + } + } + + /// Create a new part containing a generic file, with a `Content-Type` of `application/octet-stream` + /// using the provided file name, field name, and contents. + /// + /// If the MIME type of the file is known, consider using `Part::raw_part`. + /// + /// # Examples + /// + /// ```rust + /// use axum_extra::response::multiple::{MultipartForm, Part}; + /// + /// // create a form with a single part that has a field with a name of "foo", + /// // with a file name of "foo.txt", and with the specified contents + /// let parts: Vec = vec![Part::file("foo", "foo.txt", vec![0x68, 0x68, 0x20, 0x6d, 0x6f, 0x6d])]; + /// let form = MultipartForm::from_iter(parts); + /// ``` + pub fn file(field_name: &str, file_name: &str, contents: Vec) -> Self { + Self { + name: field_name.to_owned(), + filename: Some(file_name.to_owned()), + // If the `MIME` type is not known or specified, then the MIME type should be set to `application/octet-stream`. + // See RFC2388 section 3 for specifics. + mime_type: mime::APPLICATION_OCTET_STREAM, + contents, + } + } + + /// Create a new part with more fine-grained control over the semantics of that part. + /// + /// The caller is assumed to have set a valid MIME type. + /// + /// This function will return an error if the provided MIME type is not valid. + /// + /// # Examples + /// + /// ```rust + /// use axum_extra::response::multiple::{MultipartForm, Part}; + /// + /// // create a form with a single part that has a field with a name of "part_name", + /// // with a MIME type of "application/json", and the supplied contents. + /// let parts: Vec = vec![Part::raw_part("part_name", "application/json", vec![0x68, 0x68, 0x20, 0x6d, 0x6f, 0x6d], None).expect("MIME type must be valid")]; + /// let form = MultipartForm::from_iter(parts); + /// ``` + pub fn raw_part( + name: &str, + mime_type: &str, + contents: Vec, + filename: Option<&str>, + ) -> Result { + let mime_type = mime_type.parse().map_err(|_| "Invalid MIME type")?; + Ok(Self { + name: name.to_owned(), + filename: filename.map(|f| f.to_owned()), + mime_type, + contents, + }) + } + + /// Serialize this part into a chunk that can be easily inserted into a larger form + pub(super) fn serialize(&self) -> Vec { + // A part is serialized in this general format: + // // the filename is optional + // Content-Disposition: form-data; name="FIELD_NAME"; filename="FILENAME"\r\n + // // the mime type (not strictly required by the spec, but always sent here) + // Content-Type: mime/type\r\n + // // a blank line, then the contents of the file start + // \r\n + // CONTENTS\r\n + + // Format what we can as a string, then handle the rest at a byte level + let mut serialized_part = format!("Content-Disposition: form-data; name=\"{}\"", self.name); + // specify a filename if one was set + if let Some(filename) = &self.filename { + serialized_part += &format!("; filename=\"{}\"", filename); + } + serialized_part += "\r\n"; + // specify the MIME type + serialized_part += &format!("Content-Type: {}\r\n", self.mime_type); + serialized_part += "\r\n"; + let mut part_bytes = serialized_part.as_bytes().to_vec(); + part_bytes.extend_from_slice(&self.contents); + part_bytes.extend_from_slice(b"\r\n"); + + part_bytes + } +} + +impl FromIterator for MultipartForm { + fn from_iter>(iter: T) -> Self { + Self { + parts: iter.into_iter().collect(), + } + } +} + +/// A boundary is defined as a user defined (arbitrary) value that does not occur in any of the data. +/// +/// Because the specification does not clearly define a methodology for generating boundaries, this implementation +/// follow's Reqwest's, and generates a boundary in the format of `XXXXXXXX-XXXXXXXX-XXXXXXXX-XXXXXXXX` where `XXXXXXXX` +/// is a hexadecimal representation of a pseudo randomly generated u64. +fn generate_boundary() -> String { + let a = fastrand::u64(0..u64::MAX); + let b = fastrand::u64(0..u64::MAX); + let c = fastrand::u64(0..u64::MAX); + let d = fastrand::u64(0..u64::MAX); + format!("{a:016x}-{b:016x}-{c:016x}-{d:016x}") +} + +#[cfg(test)] +mod tests { + use super::{generate_boundary, MultipartForm, Part}; + use axum::{body::Body, http}; + use axum::{routing::get, Router}; + use http::{Request, Response}; + use http_body_util::BodyExt; + use mime::Mime; + use tower::ServiceExt; + + #[tokio::test] + async fn process_form() -> Result<(), Box> { + // create a boilerplate handle that returns a form + async fn handle() -> MultipartForm { + let parts: Vec = vec![ + Part::text("part1".to_owned(), "basictext"), + Part::file( + "part2", + "file.txt", + vec![0x68, 0x69, 0x20, 0x6d, 0x6f, 0x6d], + ), + Part::raw_part("part3", "text/plain", b"rawpart".to_vec(), None).unwrap(), + ]; + MultipartForm::from_iter(parts) + } + + // make a request to that handle + let app = Router::new().route("/", get(handle)); + let response: Response<_> = app + .oneshot(Request::builder().uri("/").body(Body::empty())?) + .await?; + // content_type header + let ct_header = response.headers().get("content-type").unwrap().to_str()?; + let boundary = ct_header.split("boundary=").nth(1).unwrap().to_owned(); + let body: &[u8] = &response.into_body().collect().await?.to_bytes(); + assert_eq!( + std::str::from_utf8(body)?, + &format!( + "--{boundary}\r\n\ + Content-Disposition: form-data; name=\"part1\"\r\n\ + Content-Type: text/plain; charset=utf-8\r\n\ + \r\n\ + basictext\r\n\ + --{boundary}\r\n\ + Content-Disposition: form-data; name=\"part2\"; filename=\"file.txt\"\r\n\ + Content-Type: application/octet-stream\r\n\ + \r\n\ + hi mom\r\n\ + --{boundary}\r\n\ + Content-Disposition: form-data; name=\"part3\"\r\n\ + Content-Type: text/plain\r\n\ + \r\n\ + rawpart\r\n\ + --{boundary}--", + boundary = boundary + ) + ); + + Ok(()) + } + + #[test] + fn valid_boundary_generation() { + for _ in 0..256 { + let boundary = generate_boundary(); + let mime_type: Result = + format!("multipart/form-data; boundary={}", boundary).parse(); + assert!( + mime_type.is_ok(), + "The generated boundary was unable to be parsed into a valid mime type." + ); + } + } +} diff --git a/axum-extra/src/routing/mod.rs b/axum-extra/src/routing/mod.rs index 40cb336df8..5732f8a308 100644 --- a/axum-extra/src/routing/mod.rs +++ b/axum-extra/src/routing/mod.rs @@ -1,7 +1,7 @@ //! Additional types for defining routes. use axum::{ - extract::Request, + extract::{OriginalUri, Request}, response::{IntoResponse, Redirect, Response}, routing::{any, MethodRouter}, Router, @@ -131,6 +131,19 @@ pub trait RouterExt: sealed::Sealed { T: SecondElementIs

+ 'static, P: TypedPath; + /// Add a typed `CONNECT` route to the router. + /// + /// The path will be inferred from the first argument to the handler function which must + /// implement [`TypedPath`]. + /// + /// See [`TypedPath`] for more details and examples. + #[cfg(feature = "typed-routing")] + fn typed_connect(self, handler: H) -> Self + where + H: axum::handler::Handler, + T: SecondElementIs

+ 'static, + P: TypedPath; + /// Add another route to the router with an additional "trailing slash redirect" route. /// /// If you add a route _without_ a trailing slash, such as `/foo`, this method will also add a @@ -165,7 +178,7 @@ pub trait RouterExt: sealed::Sealed { /// This works like [`RouterExt::route_with_tsr`] but accepts any [`Service`]. fn route_service_with_tsr(self, path: &str, service: T) -> Self where - T: Service + Clone + Send + 'static, + T: Service + Clone + Send + Sync + 'static, T::Response: IntoResponse, T::Future: Send + 'static, Self: Sized; @@ -255,6 +268,16 @@ where self.route(P::PATH, axum::routing::trace(handler)) } + #[cfg(feature = "typed-routing")] + fn typed_connect(self, handler: H) -> Self + where + H: axum::handler::Handler, + T: SecondElementIs

+ 'static, + P: TypedPath, + { + self.route(P::PATH, axum::routing::connect(handler)) + } + #[track_caller] fn route_with_tsr(mut self, path: &str, method_router: MethodRouter) -> Self where @@ -268,7 +291,7 @@ where #[track_caller] fn route_service_with_tsr(mut self, path: &str, service: T) -> Self where - T: Service + Clone + Send + 'static, + T: Service + Clone + Send + Sync + 'static, T::Response: IntoResponse, T::Future: Send + 'static, Self: Sized, @@ -290,7 +313,7 @@ fn add_tsr_redirect_route(router: Router, path: &str) -> Router where S: Clone + Send + Sync + 'static, { - async fn redirect_handler(uri: Uri) -> Response { + async fn redirect_handler(OriginalUri(uri): OriginalUri) -> Response { let new_uri = map_path(uri, |path| { path.strip_suffix('/') .map(Cow::Borrowed) @@ -342,7 +365,7 @@ mod sealed { mod tests { use super::*; use crate::test_helpers::*; - use axum::{extract::Path, http::StatusCode, routing::get}; + use axum::{extract::Path, routing::get}; #[tokio::test] async fn test_tsr() { @@ -352,17 +375,17 @@ mod tests { let client = TestClient::new(app); - let res = client.get("/foo").send().await; + let res = client.get("/foo").await; assert_eq!(res.status(), StatusCode::OK); - let res = client.get("/foo/").send().await; + let res = client.get("/foo/").await; assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT); assert_eq!(res.headers()["location"], "/foo"); - let res = client.get("/bar/").send().await; + let res = client.get("/bar/").await; assert_eq!(res.status(), StatusCode::OK); - let res = client.get("/bar").send().await; + let res = client.get("/bar").await; assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT); assert_eq!(res.headers()["location"], "/bar/"); } @@ -371,29 +394,29 @@ mod tests { async fn tsr_with_params() { let app = Router::new() .route_with_tsr( - "/a/:a", + "/a/{a}", get(|Path(param): Path| async move { param }), ) .route_with_tsr( - "/b/:b/", + "/b/{b}/", get(|Path(param): Path| async move { param }), ); let client = TestClient::new(app); - let res = client.get("/a/foo").send().await; + let res = client.get("/a/foo").await; assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, "foo"); - let res = client.get("/a/foo/").send().await; + let res = client.get("/a/foo/").await; assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT); assert_eq!(res.headers()["location"], "/a/foo"); - let res = client.get("/b/foo/").send().await; + let res = client.get("/b/foo/").await; assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, "foo"); - let res = client.get("/b/foo").send().await; + let res = client.get("/b/foo").await; assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT); assert_eq!(res.headers()["location"], "/b/foo/"); } @@ -404,11 +427,27 @@ mod tests { let client = TestClient::new(app); - let res = client.get("/foo/?a=a").send().await; + let res = client.get("/foo/?a=a").await; assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT); assert_eq!(res.headers()["location"], "/foo?a=a"); } + #[tokio::test] + async fn tsr_works_in_nested_router() { + let app = Router::new().nest( + "/neko", + Router::new().route_with_tsr("/nyan/", get(|| async {})), + ); + + let client = TestClient::new(app); + let res = client.get("/neko/nyan/").await; + assert_eq!(res.status(), StatusCode::OK); + + let res = client.get("/neko/nyan").await; + assert_eq!(res.status(), StatusCode::PERMANENT_REDIRECT); + assert_eq!(res.headers()["location"], "/neko/nyan/"); + } + #[test] #[should_panic = "Cannot add a trailing slash redirect route for `/`"] fn tsr_at_root() { diff --git a/axum-extra/src/routing/resource.rs b/axum-extra/src/routing/resource.rs index 61929c8ffb..96c15c5533 100644 --- a/axum-extra/src/routing/resource.rs +++ b/axum-extra/src/routing/resource.rs @@ -19,13 +19,13 @@ use axum::{ /// .create(|| async {}) /// // `GET /users/new` /// .new(|| async {}) -/// // `GET /users/:users_id` +/// // `GET /users/{users_id}` /// .show(|Path(user_id): Path| async {}) -/// // `GET /users/:users_id/edit` +/// // `GET /users/{users_id}/edit` /// .edit(|Path(user_id): Path| async {}) -/// // `PUT or PATCH /users/:users_id` +/// // `PUT or PATCH /users/{users_id}` /// .update(|Path(user_id): Path| async {}) -/// // `DELETE /users/:users_id` +/// // `DELETE /users/{users_id}` /// .destroy(|Path(user_id): Path| async {}); /// /// let app = Router::new().merge(users); @@ -82,7 +82,9 @@ where self.route(&path, get(handler)) } - /// Add a handler at `GET /{resource_name}/:{resource_name}_id`. + /// Add a handler at `GET //{_id}`. + /// + /// For example when the resources are posts: `GET /post/{post_id}`. pub fn show(self, handler: H) -> Self where H: Handler, @@ -92,17 +94,21 @@ where self.route(&path, get(handler)) } - /// Add a handler at `GET /{resource_name}/:{resource_name}_id/edit`. + /// Add a handler at `GET //{_id}/edit`. + /// + /// For example when the resources are posts: `GET /post/{post_id}/edit`. pub fn edit(self, handler: H) -> Self where H: Handler, T: 'static, { - let path = format!("/{0}/:{0}_id/edit", self.name); + let path = format!("/{0}/{{{0}_id}}/edit", self.name); self.route(&path, get(handler)) } - /// Add a handler at `PUT or PATCH /resource_name/:{resource_name}_id`. + /// Add a handler at `PUT or PATCH //{_id}`. + /// + /// For example when the resources are posts: `PUT /post/{post_id}`. pub fn update(self, handler: H) -> Self where H: Handler, @@ -115,7 +121,9 @@ where ) } - /// Add a handler at `DELETE /{resource_name}/:{resource_name}_id`. + /// Add a handler at `DELETE //{_id}`. + /// + /// For example when the resources are posts: `DELETE /post/{post_id}`. pub fn destroy(self, handler: H) -> Self where H: Handler, @@ -130,7 +138,7 @@ where } fn show_update_destroy_path(&self) -> String { - format!("/{0}/:{0}_id", self.name) + format!("/{0}/{{{0}_id}}", self.name) } fn route(mut self, path: &str, method_router: MethodRouter) -> Self { @@ -149,7 +157,7 @@ impl From> for Router { mod tests { #[allow(unused_imports)] use super::*; - use axum::{body::Body, extract::Path, http::Method, Router}; + use axum::{body::Body, extract::Path, http::Method}; use http::Request; use http_body_util::BodyExt; use tower::ServiceExt; diff --git a/axum-extra/src/routing/typed.rs b/axum-extra/src/routing/typed.rs index ef30462aab..02c5be672c 100644 --- a/axum-extra/src/routing/typed.rs +++ b/axum-extra/src/routing/typed.rs @@ -19,15 +19,15 @@ use serde::Serialize; /// RouterExt, // for `Router::typed_*` /// }; /// -/// // A type safe route with `/users/:id` as its associated path. +/// // A type safe route with `/users/{id}` as its associated path. /// #[derive(TypedPath, Deserialize)] -/// #[typed_path("/users/:id")] +/// #[typed_path("/users/{id}")] /// struct UsersMember { /// id: u32, /// } /// /// // A regular handler function that takes `UsersMember` as the first argument -/// // and thus creates a typed connection between this handler and the `/users/:id` path. +/// // and thus creates a typed connection between this handler and the `/users/{id}` path. /// // /// // The `TypedPath` must be the first argument to the function. /// async fn users_show( @@ -39,7 +39,7 @@ use serde::Serialize; /// let app = Router::new() /// // Add our typed route to the router. /// // -/// // The path will be inferred to `/users/:id` since `users_show`'s +/// // The path will be inferred to `/users/{id}` since `users_show`'s /// // first argument is `UsersMember` which implements `TypedPath` /// .typed_get(users_show) /// .typed_post(users_create) @@ -75,7 +75,7 @@ use serde::Serialize; /// use axum_extra::routing::TypedPath; /// /// #[derive(TypedPath, Deserialize)] -/// #[typed_path("/users/:id")] +/// #[typed_path("/users/{id}")] /// struct UsersMember { /// id: u32, /// } @@ -85,12 +85,12 @@ use serde::Serialize; /// /// - A `TypedPath` implementation. /// - A [`FromRequest`] implementation compatible with [`RouterExt::typed_get`], -/// [`RouterExt::typed_post`], etc. This implementation uses [`Path`] and thus your struct must -/// also implement [`serde::Deserialize`], unless it's a unit struct. +/// [`RouterExt::typed_post`], etc. This implementation uses [`Path`] and thus your struct must +/// also implement [`serde::Deserialize`], unless it's a unit struct. /// - A [`Display`] implementation that interpolates the captures. This can be used to, among other -/// things, create links to known paths and have them verified statically. Note that the -/// [`Display`] implementation for each field must return something that's compatible with its -/// [`Deserialize`] implementation. +/// things, create links to known paths and have them verified statically. Note that the +/// [`Display`] implementation for each field must return something that's compatible with its +/// [`Deserialize`] implementation. /// /// Additionally the macro will verify the captures in the path matches the fields of the struct. /// For example this fails to compile since the struct doesn't have a `team_id` field: @@ -100,7 +100,7 @@ use serde::Serialize; /// use axum_extra::routing::TypedPath; /// /// #[derive(TypedPath, Deserialize)] -/// #[typed_path("/users/:id/teams/:team_id")] +/// #[typed_path("/users/{id}/teams/{team_id}")] /// struct UsersMember { /// id: u32, /// } @@ -117,7 +117,7 @@ use serde::Serialize; /// struct UsersCollection; /// /// #[derive(TypedPath, Deserialize)] -/// #[typed_path("/users/:id")] +/// #[typed_path("/users/{id}")] /// struct UsersMember(u32); /// ``` /// @@ -130,7 +130,7 @@ use serde::Serialize; /// use axum_extra::routing::TypedPath; /// /// #[derive(TypedPath, Deserialize)] -/// #[typed_path("/users/:id")] +/// #[typed_path("/users/{id}")] /// struct UsersMember { /// id: String, /// } @@ -158,7 +158,7 @@ use serde::Serialize; /// }; /// /// #[derive(TypedPath, Deserialize)] -/// #[typed_path("/users/:id", rejection(UsersMemberRejection))] +/// #[typed_path("/users/{id}", rejection(UsersMemberRejection))] /// struct UsersMember { /// id: String, /// } @@ -215,7 +215,7 @@ use serde::Serialize; /// [`Deserialize`]: serde::Deserialize /// [`PathRejection`]: axum::extract::rejection::PathRejection pub trait TypedPath: std::fmt::Display { - /// The path with optional captures such as `/users/:id`. + /// The path with optional captures such as `/users/{id}`. const PATH: &'static str; /// Convert the path into a `Uri`. @@ -321,7 +321,7 @@ where /// Utility trait used with [`RouterExt`] to ensure the second element of a tuple type is a /// given type. /// -/// If you see it in type errors its most likely because the second argument to your handler doesn't +/// If you see it in type errors it's most likely because the second argument to your handler doesn't /// implement [`TypedPath`]. /// /// You normally shouldn't have to use this trait directly. @@ -386,11 +386,19 @@ impl_second_element_is!(T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, T13, #[cfg(test)] mod tests { use super::*; - use crate::routing::TypedPath; + use crate::{ + extract::WithRejection, + routing::{RouterExt, TypedPath}, + }; + use axum::{ + extract::rejection::PathRejection, + response::{IntoResponse, Response}, + Router, + }; use serde::Deserialize; #[derive(TypedPath, Deserialize)] - #[typed_path("/users/:id")] + #[typed_path("/users/{id}")] struct UsersShow { id: i32, } @@ -434,4 +442,25 @@ mod tests { assert_eq!(uri, "/users/1?&foo=foo&bar=123&baz=true&qux=1337"); } + + #[allow(dead_code)] // just needs to compile + fn supports_with_rejection() { + async fn handler(_: WithRejection) {} + + struct MyRejection {} + + impl IntoResponse for MyRejection { + fn into_response(self) -> Response { + unimplemented!() + } + } + + impl From for MyRejection { + fn from(_: PathRejection) -> Self { + unimplemented!() + } + } + + let _: Router = Router::new().typed_get(handler); + } } diff --git a/axum-extra/src/typed_header.rs b/axum-extra/src/typed_header.rs index aa89e81ddf..ef94c3779c 100644 --- a/axum-extra/src/typed_header.rs +++ b/axum-extra/src/typed_header.rs @@ -1,12 +1,11 @@ //! Extractor and response for typed headers. use axum::{ - async_trait, extract::FromRequestParts, response::{IntoResponse, IntoResponseParts, Response, ResponseParts}, }; use headers::{Header, HeaderMapExt}; -use http::request::Parts; +use http::{request::Parts, StatusCode}; use std::convert::Infallible; /// Extractor and response that works with typed header values from [`headers`]. @@ -30,7 +29,7 @@ use std::convert::Infallible; /// // ... /// } /// -/// let app = Router::new().route("/users/:user_id/team/:team_id", get(users_teams_show)); +/// let app = Router::new().route("/users/{user_id}/team/{team_id}", get(users_teams_show)); /// # let _: Router = app; /// ``` /// @@ -55,7 +54,6 @@ use std::convert::Infallible; #[must_use] pub struct TypedHeader(pub T); -#[async_trait] impl FromRequestParts for TypedHeader where T: Header, @@ -123,6 +121,14 @@ impl TypedHeaderRejection { pub fn reason(&self) -> &TypedHeaderRejectionReason { &self.reason } + + /// Returns `true` if the typed header rejection reason is [`Missing`]. + /// + /// [`Missing`]: TypedHeaderRejectionReason::Missing + #[must_use] + pub fn is_missing(&self) -> bool { + self.reason.is_missing() + } } /// Additional information regarding a [`TypedHeaderRejection`] @@ -136,9 +142,22 @@ pub enum TypedHeaderRejectionReason { Error(headers::Error), } +impl TypedHeaderRejectionReason { + /// Returns `true` if the typed header rejection reason is [`Missing`]. + /// + /// [`Missing`]: TypedHeaderRejectionReason::Missing + #[must_use] + pub fn is_missing(&self) -> bool { + matches!(self, Self::Missing) + } +} + impl IntoResponse for TypedHeaderRejection { fn into_response(self) -> Response { - (http::StatusCode::BAD_REQUEST, self.to_string()).into_response() + let status = StatusCode::BAD_REQUEST; + let body = self.to_string(); + axum_core::__log_rejection!(rejection_type = Self, body_text = body, status = status,); + (status, body).into_response() } } @@ -168,7 +187,7 @@ impl std::error::Error for TypedHeaderRejection { mod tests { use super::*; use crate::test_helpers::*; - use axum::{response::IntoResponse, routing::get, Router}; + use axum::{routing::get, Router}; #[tokio::test] async fn typed_header() { @@ -190,7 +209,6 @@ mod tests { .header("user-agent", "foobar") .header("cookie", "a=1; b=2") .header("cookie", "c=3") - .send() .await; let body = res.text().await; assert_eq!( @@ -198,11 +216,11 @@ mod tests { r#"User-Agent="foobar", Cookie=[("a", "1"), ("b", "2"), ("c", "3")]"# ); - let res = client.get("/").header("user-agent", "foobar").send().await; + let res = client.get("/").header("user-agent", "foobar").await; let body = res.text().await; assert_eq!(body, r#"User-Agent="foobar", Cookie=[]"#); - let res = client.get("/").header("cookie", "a=1").send().await; + let res = client.get("/").header("cookie", "a=1").await; let body = res.text().await; assert_eq!(body, "Header of type `user-agent` was missing"); } diff --git a/axum-macros/CHANGELOG.md b/axum-macros/CHANGELOG.md index 90dfd9bc59..11b7820f89 100644 --- a/axum-macros/CHANGELOG.md +++ b/axum-macros/CHANGELOG.md @@ -5,9 +5,27 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). -# Unreleased +# 0.5.0 -- None. +## alpha.1 + +- **breaking:** Update code generation for axum-core 0.5.0-alpha.1 +- **change:** Update minimum rust version to 1.75 ([#2943]) + +[#2943]: https://github.com/tokio-rs/axum/pull/2943 + +# 0.4.2 + +- **added:** Add `#[debug_middleware]` ([#1993], [#2725]) + +[#1993]: https://github.com/tokio-rs/axum/pull/1993 +[#2725]: https://github.com/tokio-rs/axum/pull/2725 + +# 0.4.1 (13. January, 2024) + +- **fixed:** Improve `debug_handler` on tuple response types ([#2201]) + +[#2201]: https://github.com/tokio-rs/axum/pull/2201 # 0.4.0 (27. November, 2023) diff --git a/axum-macros/Cargo.toml b/axum-macros/Cargo.toml index 72356ee672..1960a465da 100644 --- a/axum-macros/Cargo.toml +++ b/axum-macros/Cargo.toml @@ -2,14 +2,14 @@ categories = ["asynchronous", "network-programming", "web-programming"] description = "Macros for axum" edition = "2021" -rust-version = "1.66" +rust-version = { workspace = true } homepage = "https://github.com/tokio-rs/axum" keywords = ["axum"] license = "MIT" name = "axum-macros" readme = "README.md" repository = "https://github.com/tokio-rs/axum" -version = "0.4.0" # remember to also bump the version that axum and axum-extra depends on +version = "0.5.0-alpha.1" # remember to also bump the version that axum and axum-extra depends on [features] default = [] @@ -19,7 +19,6 @@ __private = ["syn/visit-mut"] proc-macro = true [dependencies] -heck = "0.4" proc-macro2 = "1.0" quote = "1.0" syn = { version = "2.0", features = [ @@ -30,8 +29,8 @@ syn = { version = "2.0", features = [ ] } [dev-dependencies] -axum = { path = "../axum", version = "0.7.2", features = ["macros"] } -axum-extra = { path = "../axum-extra", version = "0.9.0", features = ["typed-routing", "cookie-private", "typed-header"] } +axum = { path = "../axum", features = ["macros"] } +axum-extra = { path = "../axum-extra", features = ["typed-routing", "cookie-private", "typed-header"] } rustversion = "1.0" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" @@ -44,4 +43,3 @@ allowed = [] [package.metadata.docs.rs] all-features = true -rustdoc-args = ["--cfg", "docsrs"] diff --git a/axum-macros/README.md b/axum-macros/README.md index 8fcde01ed2..c3967b19ae 100644 --- a/axum-macros/README.md +++ b/axum-macros/README.md @@ -14,7 +14,7 @@ This crate uses `#![forbid(unsafe_code)]` to ensure everything is implemented in ## Minimum supported Rust version -axum-macros's MSRV is 1.66. +axum-macros's MSRV is 1.75. ## Getting Help diff --git a/axum-macros/rust-toolchain b/axum-macros/rust-toolchain index 2c3dbd2c7b..eca143c73f 100644 --- a/axum-macros/rust-toolchain +++ b/axum-macros/rust-toolchain @@ -1 +1 @@ -nightly-2023-09-23 +nightly-2024-06-22 diff --git a/axum-macros/src/debug_handler.rs b/axum-macros/src/debug_handler.rs index 0fbfc0e9af..456bd643ea 100644 --- a/axum-macros/src/debug_handler.rs +++ b/axum-macros/src/debug_handler.rs @@ -1,21 +1,26 @@ -use std::collections::HashSet; +use std::{collections::HashSet, fmt}; use crate::{ attr_parsing::{parse_assignment_attribute, second}, with_position::{Position, WithPosition}, }; -use proc_macro2::{Span, TokenStream}; +use proc_macro2::{Ident, Span, TokenStream}; use quote::{format_ident, quote, quote_spanned}; -use syn::{parse::Parse, spanned::Spanned, FnArg, ItemFn, Token, Type}; +use syn::{parse::Parse, spanned::Spanned, FnArg, ItemFn, ReturnType, Token, Type}; -pub(crate) fn expand(attr: Attrs, item_fn: ItemFn) -> TokenStream { +pub(crate) fn expand(attr: Attrs, item_fn: ItemFn, kind: FunctionKind) -> TokenStream { let Attrs { state_ty } = attr; let mut state_ty = state_ty.map(second); - let check_extractor_count = check_extractor_count(&item_fn); - let check_path_extractor = check_path_extractor(&item_fn); - let check_output_impls_into_response = check_output_impls_into_response(&item_fn); + let check_extractor_count = check_extractor_count(&item_fn, kind); + let check_path_extractor = check_path_extractor(&item_fn, kind); + let check_output_tuples = check_output_tuples(&item_fn); + let check_output_impls_into_response = if check_output_tuples.is_empty() { + check_output_impls_into_response(&item_fn) + } else { + check_output_tuples + }; // If the function is generic, we can't reliably check its inputs or whether the future it // returns is `Send`. Skip those checks to avoid unhelpful additional compiler errors. @@ -32,8 +37,10 @@ pub(crate) fn expand(attr: Attrs, item_fn: ItemFn) -> TokenStream { err = Some( syn::Error::new( Span::call_site(), - "can't infer state type, please add set it explicitly, as in \ - `#[debug_handler(state = MyStateType)]`", + format!( + "can't infer state type, please add set it explicitly, as in \ + `#[axum_macros::debug_{kind}(state = MyStateType)]`" + ), ) .into_compile_error(), ); @@ -43,16 +50,16 @@ pub(crate) fn expand(attr: Attrs, item_fn: ItemFn) -> TokenStream { err.unwrap_or_else(|| { let state_ty = state_ty.unwrap_or_else(|| syn::parse_quote!(())); - let check_future_send = check_future_send(&item_fn); + let check_future_send = check_future_send(&item_fn, kind); - if let Some(check_input_order) = check_input_order(&item_fn) { + if let Some(check_input_order) = check_input_order(&item_fn, kind) { quote! { #check_input_order #check_future_send } } else { let check_inputs_impls_from_request = - check_inputs_impls_from_request(&item_fn, state_ty); + check_inputs_impls_from_request(&item_fn, state_ty, kind); quote! { #check_inputs_impls_from_request @@ -63,17 +70,45 @@ pub(crate) fn expand(attr: Attrs, item_fn: ItemFn) -> TokenStream { } else { syn::Error::new_spanned( &item_fn.sig.generics, - "`#[axum_macros::debug_handler]` doesn't support generic functions", + format!("`#[axum_macros::debug_{kind}]` doesn't support generic functions"), ) .into_compile_error() }; + let middleware_takes_next_as_last_arg = + matches!(kind, FunctionKind::Middleware).then(|| next_is_last_input(&item_fn)); + quote! { #item_fn #check_extractor_count #check_path_extractor #check_output_impls_into_response #check_inputs_and_future_send + #middleware_takes_next_as_last_arg + } +} + +#[derive(Clone, Copy)] +pub(crate) enum FunctionKind { + Handler, + Middleware, +} + +impl fmt::Display for FunctionKind { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + FunctionKind::Handler => f.write_str("handler"), + FunctionKind::Middleware => f.write_str("middleware"), + } + } +} + +impl FunctionKind { + fn name_uppercase_plural(&self) -> &'static str { + match self { + FunctionKind::Handler => "Handlers", + FunctionKind::Middleware => "Middleware", + } } } @@ -105,25 +140,36 @@ impl Parse for Attrs { } } -fn check_extractor_count(item_fn: &ItemFn) -> Option { +fn check_extractor_count(item_fn: &ItemFn, kind: FunctionKind) -> Option { let max_extractors = 16; - if item_fn.sig.inputs.len() <= max_extractors { + let inputs = item_fn + .sig + .inputs + .iter() + .filter(|arg| skip_next_arg(arg, kind)) + .count(); + if inputs <= max_extractors { None } else { let error_message = format!( - "Handlers cannot take more than {max_extractors} arguments. \ + "{} cannot take more than {max_extractors} arguments. \ Use `(a, b): (ExtractorA, ExtractorA)` to further nest extractors", + kind.name_uppercase_plural(), ); let error = syn::Error::new_spanned(&item_fn.sig.inputs, error_message).to_compile_error(); Some(error) } } -fn extractor_idents(item_fn: &ItemFn) -> impl Iterator { +fn extractor_idents( + item_fn: &ItemFn, + kind: FunctionKind, +) -> impl Iterator { item_fn .sig .inputs .iter() + .filter(move |arg| skip_next_arg(arg, kind)) .enumerate() .filter_map(|(idx, fn_arg)| match fn_arg { FnArg::Receiver(_) => None, @@ -141,8 +187,8 @@ fn extractor_idents(item_fn: &ItemFn) -> impl Iterator TokenStream { - let path_extractors = extractor_idents(item_fn) +fn check_path_extractor(item_fn: &ItemFn, kind: FunctionKind) -> TokenStream { + let path_extractors = extractor_idents(item_fn, kind) .filter(|(_, _, ident)| *ident == "Path") .collect::>(); @@ -174,121 +220,294 @@ fn is_self_pat_type(typed: &syn::PatType) -> bool { ident == "self" } -fn check_inputs_impls_from_request(item_fn: &ItemFn, state_ty: Type) -> TokenStream { - let takes_self = item_fn.sig.inputs.first().map_or(false, |arg| match arg { +fn check_inputs_impls_from_request( + item_fn: &ItemFn, + state_ty: Type, + kind: FunctionKind, +) -> TokenStream { + let takes_self = item_fn.sig.inputs.first().is_some_and(|arg| match arg { FnArg::Receiver(_) => true, FnArg::Typed(typed) => is_self_pat_type(typed), }); - WithPosition::new(item_fn.sig.inputs.iter()) - .enumerate() - .map(|(idx, arg)| { - let must_impl_from_request_parts = match &arg { - Position::First(_) | Position::Middle(_) => true, - Position::Last(_) | Position::Only(_) => false, - }; + WithPosition::new( + item_fn + .sig + .inputs + .iter() + .filter(|arg| skip_next_arg(arg, kind)), + ) + .enumerate() + .map(|(idx, arg)| { + let must_impl_from_request_parts = match &arg { + Position::First(_) | Position::Middle(_) => true, + Position::Last(_) | Position::Only(_) => false, + }; - let arg = arg.into_inner(); + let arg = arg.into_inner(); - let (span, ty) = match arg { - FnArg::Receiver(receiver) => { - if receiver.reference.is_some() { - return syn::Error::new_spanned( - receiver, - "Handlers must only take owned values", - ) - .into_compile_error(); - } + let (span, ty) = match arg { + FnArg::Receiver(receiver) => { + if receiver.reference.is_some() { + return syn::Error::new_spanned( + receiver, + "Handlers must only take owned values", + ) + .into_compile_error(); + } + + let span = receiver.span(); + (span, syn::parse_quote!(Self)) + } + FnArg::Typed(typed) => { + let ty = &typed.ty; + let span = ty.span(); - let span = receiver.span(); + if is_self_pat_type(typed) { (span, syn::parse_quote!(Self)) + } else { + (span, ty.clone()) } - FnArg::Typed(typed) => { - let ty = &typed.ty; - let span = ty.span(); + } + }; + + let consumes_request = request_consuming_type_name(&ty).is_some(); + + let check_fn = format_ident!( + "__axum_macros_check_{}_{}_from_request_check", + item_fn.sig.ident, + idx, + span = span, + ); + + let call_check_fn = format_ident!( + "__axum_macros_check_{}_{}_from_request_call_check", + item_fn.sig.ident, + idx, + span = span, + ); - if is_self_pat_type(typed) { - (span, syn::parse_quote!(Self)) + let call_check_fn_body = if takes_self { + quote_spanned! {span=> + Self::#check_fn(); + } + } else { + quote_spanned! {span=> + #check_fn(); + } + }; + + let check_fn_generics = if must_impl_from_request_parts || consumes_request { + quote! {} + } else { + quote! { } + }; + + let from_request_bound = if must_impl_from_request_parts { + quote_spanned! {span=> + #ty: ::axum::extract::FromRequestParts<#state_ty> + Send + } + } else if consumes_request { + quote_spanned! {span=> + #ty: ::axum::extract::FromRequest<#state_ty> + Send + } + } else { + quote_spanned! {span=> + #ty: ::axum::extract::FromRequest<#state_ty, M> + Send + } + }; + + quote_spanned! {span=> + #[allow(warnings)] + #[doc(hidden)] + fn #check_fn #check_fn_generics() + where + #from_request_bound, + {} + + // we have to call the function to actually trigger a compile error + // since the function is generic, just defining it is not enough + #[allow(warnings)] + #[doc(hidden)] + fn #call_check_fn() + { + #call_check_fn_body + } + } + }) + .collect::() +} + +fn check_output_tuples(item_fn: &ItemFn) -> TokenStream { + let elems = match &item_fn.sig.output { + ReturnType::Type(_, ty) => match &**ty { + Type::Tuple(tuple) => &tuple.elems, + _ => return quote! {}, + }, + ReturnType::Default => return quote! {}, + }; + + let handler_ident = &item_fn.sig.ident; + + match elems.len() { + 0 => quote! {}, + n if n > 17 => syn::Error::new_spanned( + &item_fn.sig.output, + "Cannot return tuples with more than 17 elements", + ) + .to_compile_error(), + _ => WithPosition::new(elems) + .enumerate() + .map(|(idx, arg)| match arg { + Position::First(ty) => match extract_clean_typename(ty).as_deref() { + Some("StatusCode" | "Response") => quote! {}, + Some("Parts") => check_is_response_parts(ty, handler_ident, idx), + Some(_) | None => { + if let Some(tn) = well_known_last_response_type(ty) { + syn::Error::new_spanned( + ty, + format!( + "`{tn}` must be the last element \ + in a response tuple" + ), + ) + .to_compile_error() + } else { + check_into_response_parts(ty, handler_ident, idx) + } + } + }, + Position::Middle(ty) => { + if let Some(tn) = well_known_last_response_type(ty) { + syn::Error::new_spanned( + ty, + format!("`{tn}` must be the last element in a response tuple"), + ) + .to_compile_error() } else { - (span, ty.clone()) + check_into_response_parts(ty, handler_ident, idx) } } - }; + Position::Last(ty) | Position::Only(ty) => check_into_response(handler_ident, ty), + }) + .collect::(), + } +} - let consumes_request = request_consuming_type_name(&ty).is_some(); +fn check_into_response(handler: &Ident, ty: &Type) -> TokenStream { + let (span, ty) = (ty.span(), ty.clone()); - let check_fn = format_ident!( - "__axum_macros_check_{}_{}_from_request_check", - item_fn.sig.ident, - idx, - span = span, - ); + let check_fn = format_ident!( + "__axum_macros_check_{handler}_into_response_check", + span = span, + ); - let call_check_fn = format_ident!( - "__axum_macros_check_{}_{}_from_request_call_check", - item_fn.sig.ident, - idx, - span = span, - ); + let call_check_fn = format_ident!( + "__axum_macros_check_{handler}_into_response_call_check", + span = span, + ); - let call_check_fn_body = if takes_self { - quote_spanned! {span=> - Self::#check_fn(); - } - } else { - quote_spanned! {span=> - #check_fn(); - } - }; + let call_check_fn_body = quote_spanned! {span=> + #check_fn(); + }; - let check_fn_generics = if must_impl_from_request_parts || consumes_request { - quote! {} - } else { - quote! { } - }; + let from_request_bound = quote_spanned! {span=> + #ty: ::axum::response::IntoResponse + }; + quote_spanned! {span=> + #[allow(warnings)] + #[allow(unreachable_code)] + #[doc(hidden)] + fn #check_fn() + where + #from_request_bound, + {} - let from_request_bound = if must_impl_from_request_parts { - quote_spanned! {span=> - #ty: ::axum::extract::FromRequestParts<#state_ty> + Send - } - } else if consumes_request { - quote_spanned! {span=> - #ty: ::axum::extract::FromRequest<#state_ty> + Send - } - } else { - quote_spanned! {span=> - #ty: ::axum::extract::FromRequest<#state_ty, M> + Send - } - }; + // we have to call the function to actually trigger a compile error + // since the function is generic, just defining it is not enough + #[allow(warnings)] + #[allow(unreachable_code)] + #[doc(hidden)] + fn #call_check_fn() { + #call_check_fn_body + } + } +} - quote_spanned! {span=> - #[allow(warnings)] - #[allow(unreachable_code)] - #[doc(hidden)] - fn #check_fn #check_fn_generics() - where - #from_request_bound, - {} +fn check_is_response_parts(ty: &Type, ident: &Ident, index: usize) -> TokenStream { + let (span, ty) = (ty.span(), ty.clone()); - // we have to call the function to actually trigger a compile error - // since the function is generic, just defining it is not enough - #[allow(warnings)] - #[allow(unreachable_code)] - #[doc(hidden)] - fn #call_check_fn() - { - #call_check_fn_body - } - } - }) - .collect::() + let check_fn = format_ident!( + "__axum_macros_check_{}_is_response_parts_{index}_check", + ident, + span = span, + ); + + quote_spanned! {span=> + #[allow(warnings)] + #[allow(unreachable_code)] + #[doc(hidden)] + fn #check_fn(parts: #ty) -> ::axum::http::response::Parts { + parts + } + } } -fn check_input_order(item_fn: &ItemFn) -> Option { +fn check_into_response_parts(ty: &Type, ident: &Ident, index: usize) -> TokenStream { + let (span, ty) = (ty.span(), ty.clone()); + + let check_fn = format_ident!( + "__axum_macros_check_{}_into_response_parts_{index}_check", + ident, + span = span, + ); + + let call_check_fn = format_ident!( + "__axum_macros_check_{}_into_response_parts_{index}_call_check", + ident, + span = span, + ); + + let call_check_fn_body = quote_spanned! {span=> + #check_fn(); + }; + + let from_request_bound = quote_spanned! {span=> + #ty: ::axum::response::IntoResponseParts + }; + quote_spanned! {span=> + #[allow(warnings)] + #[allow(unreachable_code)] + #[doc(hidden)] + fn #check_fn() + where + #from_request_bound, + {} + + // we have to call the function to actually trigger a compile error + // since the function is generic, just defining it is not enough + #[allow(warnings)] + #[allow(unreachable_code)] + #[doc(hidden)] + fn #call_check_fn() { + #call_check_fn_body + } + } +} + +fn check_input_order(item_fn: &ItemFn, kind: FunctionKind) -> Option { + let number_of_inputs = item_fn + .sig + .inputs + .iter() + .filter(|arg| skip_next_arg(arg, kind)) + .count(); + let types_that_consume_the_request = item_fn .sig .inputs .iter() + .filter(|arg| skip_next_arg(arg, kind)) .enumerate() .filter_map(|(idx, arg)| { let ty = match arg { @@ -308,7 +527,7 @@ fn check_input_order(item_fn: &ItemFn) -> Option { // exactly one type that consumes the request if types_that_consume_the_request.len() == 1 { // and that is not the last - if types_that_consume_the_request[0].0 != item_fn.sig.inputs.len() - 1 { + if types_that_consume_the_request[0].0 != number_of_inputs - 1 { let (_idx, type_name, span) = &types_that_consume_the_request[0]; let error = format!( "`{type_name}` consumes the request body and thus must be \ @@ -334,7 +553,7 @@ fn check_input_order(item_fn: &ItemFn) -> Option { compile_error!(#error); }) } else { - let types = WithPosition::new(types_that_consume_the_request.into_iter()) + let types = WithPosition::new(types_that_consume_the_request) .map(|pos| match pos { Position::First((_, type_name, _)) | Position::Middle((_, type_name, _)) => { format!("`{type_name}`, ") @@ -355,18 +574,18 @@ fn check_input_order(item_fn: &ItemFn) -> Option { } } -fn request_consuming_type_name(ty: &Type) -> Option<&'static str> { +fn extract_clean_typename(ty: &Type) -> Option { let path = match ty { Type::Path(type_path) => &type_path.path, _ => return None, }; + path.segments.last().map(|p| p.ident.to_string()) +} - let ident = match path.segments.last() { - Some(path_segment) => &path_segment.ident, - None => return None, - }; +fn request_consuming_type_name(ty: &Type) -> Option<&'static str> { + let typename = extract_clean_typename(ty)?; - let type_name = match &*ident.to_string() { + let type_name = match &*typename { "Json" => "Json<_>", "RawBody" => "RawBody<_>", "RawForm" => "RawForm", @@ -384,6 +603,22 @@ fn request_consuming_type_name(ty: &Type) -> Option<&'static str> { Some(type_name) } +fn well_known_last_response_type(ty: &Type) -> Option<&'static str> { + let typename = extract_clean_typename(ty)?; + + let type_name = match &*typename { + "Json" => "Json<_>", + "Protobuf" => "Protobuf", + "JsonLines" => "JsonLines<_>", + "Form" => "Form<_>", + "Bytes" => "Bytes", + "String" => "String", + _ => return None, + }; + + Some(type_name) +} + fn check_output_impls_into_response(item_fn: &ItemFn) -> TokenStream { let ty = match &item_fn.sig.output { syn::ReturnType::Default => return quote! {}, @@ -473,13 +708,13 @@ fn check_output_impls_into_response(item_fn: &ItemFn) -> TokenStream { } } -fn check_future_send(item_fn: &ItemFn) -> TokenStream { +fn check_future_send(item_fn: &ItemFn, kind: FunctionKind) -> TokenStream { if item_fn.sig.asyncness.is_none() { match &item_fn.sig.output { syn::ReturnType::Default => { return syn::Error::new_spanned( item_fn.sig.fn_token, - "Handlers must be `async fn`s", + format!("{} must be `async fn`s", kind.name_uppercase_plural()), ) .into_compile_error(); } @@ -583,7 +818,69 @@ fn state_types_from_args(item_fn: &ItemFn) -> HashSet { crate::infer_state_types(types).collect() } +fn next_is_last_input(item_fn: &ItemFn) -> TokenStream { + let next_args = item_fn + .sig + .inputs + .iter() + .enumerate() + .filter(|(_, arg)| !skip_next_arg(arg, FunctionKind::Middleware)) + .collect::>(); + + if next_args.is_empty() { + return quote! { + compile_error!( + "Middleware functions must take `axum::middleware::Next` as the last argument", + ); + }; + } + + if next_args.len() == 1 { + let (idx, arg) = &next_args[0]; + if *idx != item_fn.sig.inputs.len() - 1 { + return quote_spanned! {arg.span()=> + compile_error!("`axum::middleware::Next` must the last argument"); + }; + } + } + + if next_args.len() >= 2 { + return quote! { + compile_error!( + "Middleware functions can only take one argument of type `axum::middleware::Next`", + ); + }; + } + + quote! {} +} + +fn skip_next_arg(arg: &FnArg, kind: FunctionKind) -> bool { + match kind { + FunctionKind::Handler => true, + FunctionKind::Middleware => match arg { + FnArg::Receiver(_) => true, + FnArg::Typed(pat_type) => { + if let Type::Path(type_path) = &*pat_type.ty { + type_path + .path + .segments + .last() + .map_or(true, |path_segment| path_segment.ident != "Next") + } else { + true + } + } + }, + } +} + #[test] -fn ui() { +fn ui_debug_handler() { crate::run_ui_tests("debug_handler"); } + +#[test] +fn ui_debug_middleware() { + crate::run_ui_tests("debug_middleware"); +} diff --git a/axum-macros/src/from_ref.rs b/axum-macros/src/from_ref.rs index 2ab69eb54d..1a27765a4f 100644 --- a/axum-macros/src/from_ref.rs +++ b/axum-macros/src/from_ref.rs @@ -54,7 +54,7 @@ fn expand_field(state: &Ident, idx: usize, field: &Field) -> TokenStream { }; quote_spanned! {span=> - #[allow(clippy::clone_on_copy)] + #[allow(clippy::clone_on_copy, clippy::clone_on_ref_ptr)] impl ::axum::extract::FromRef<#state> for #field_ty { fn from_ref(state: &#state) -> Self { #body diff --git a/axum-macros/src/from_request.rs b/axum-macros/src/from_request.rs index 474dd0cd65..145fdad361 100644 --- a/axum-macros/src/from_request.rs +++ b/axum-macros/src/from_request.rs @@ -180,7 +180,7 @@ pub(crate) fn expand(item: syn::Item, tr: Trait) -> syn::Result { variants, } = item; - let generics_error = format!("`#[derive({tr})] on enums don't support generics"); + let generics_error = format!("`#[derive({tr})]` on enums don't support generics"); if !generics.params.is_empty() { return Err(syn::Error::new_spanned(generics, generics_error)); @@ -290,11 +290,7 @@ fn parse_single_generic_type_on_struct( let field = fields_unnamed.unnamed.first().unwrap(); if let syn::Type::Path(type_path) = &field.ty { - if type_path - .path - .get_ident() - .map_or(true, |field_type_ident| field_type_ident != ty_ident) - { + if type_path.path.get_ident() != Some(ty_ident) { return Err(syn::Error::new_spanned( type_path, format_args!( @@ -373,7 +369,6 @@ fn impl_struct_by_extracting_each_field( Ok(match tr { Trait::FromRequest => quote! { - #[::axum::async_trait] #[automatically_derived] impl<#impl_generics> ::axum::extract::FromRequest<#trait_generics> for #ident where @@ -390,7 +385,6 @@ fn impl_struct_by_extracting_each_field( } }, Trait::FromRequestParts => quote! { - #[::axum::async_trait] #[automatically_derived] impl<#impl_generics> ::axum::extract::FromRequestParts<#trait_generics> for #ident where @@ -435,7 +429,7 @@ fn extract_fields( } } - fn into_inner(via: Option<(attr::kw::via, syn::Path)>, ty_span: Span) -> TokenStream { + fn into_inner(via: &Option<(attr::kw::via, syn::Path)>, ty_span: Span) -> TokenStream { if let Some((_, path)) = via { let span = path.span(); quote_spanned! {span=> @@ -448,6 +442,23 @@ fn extract_fields( } } + fn into_outer( + via: &Option<(attr::kw::via, syn::Path)>, + ty_span: Span, + field_ty: &Type, + ) -> TokenStream { + if let Some((_, path)) = via { + let span = path.span(); + quote_spanned! {span=> + #path<#field_ty> + } + } else { + quote_spanned! {ty_span=> + #field_ty + } + } + } + let mut fields_iter = fields.iter(); let last = match tr { @@ -464,16 +475,17 @@ fn extract_fields( let member = member(field, index); let ty_span = field.ty.span(); - let into_inner = into_inner(via, ty_span); + let into_inner = into_inner(&via, ty_span); if peel_option(&field.ty).is_some() { + let field_ty = into_outer(&via, ty_span, peel_option(&field.ty).unwrap()); let tokens = match tr { Trait::FromRequest => { quote_spanned! {ty_span=> #member: { let (mut parts, body) = req.into_parts(); let value = - ::axum::extract::FromRequestParts::from_request_parts( + <#field_ty as ::axum::extract::FromRequestParts<_>>::from_request_parts( &mut parts, state, ) @@ -488,7 +500,7 @@ fn extract_fields( Trait::FromRequestParts => { quote_spanned! {ty_span=> #member: { - ::axum::extract::FromRequestParts::from_request_parts( + <#field_ty as ::axum::extract::FromRequestParts<_>>::from_request_parts( parts, state, ) @@ -501,13 +513,14 @@ fn extract_fields( }; Ok(tokens) } else if peel_result_ok(&field.ty).is_some() { + let field_ty = into_outer(&via,ty_span, peel_result_ok(&field.ty).unwrap()); let tokens = match tr { Trait::FromRequest => { quote_spanned! {ty_span=> #member: { let (mut parts, body) = req.into_parts(); let value = - ::axum::extract::FromRequestParts::from_request_parts( + <#field_ty as ::axum::extract::FromRequestParts<_>>::from_request_parts( &mut parts, state, ) @@ -521,7 +534,7 @@ fn extract_fields( Trait::FromRequestParts => { quote_spanned! {ty_span=> #member: { - ::axum::extract::FromRequestParts::from_request_parts( + <#field_ty as ::axum::extract::FromRequestParts<_>>::from_request_parts( parts, state, ) @@ -533,6 +546,7 @@ fn extract_fields( }; Ok(tokens) } else { + let field_ty = into_outer(&via,ty_span,&field.ty); let map_err = if let Some(rejection) = rejection { quote! { <#rejection as ::std::convert::From<_>>::from } } else { @@ -545,7 +559,7 @@ fn extract_fields( #member: { let (mut parts, body) = req.into_parts(); let value = - ::axum::extract::FromRequestParts::from_request_parts( + <#field_ty as ::axum::extract::FromRequestParts<_>>::from_request_parts( &mut parts, state, ) @@ -560,7 +574,7 @@ fn extract_fields( Trait::FromRequestParts => { quote_spanned! {ty_span=> #member: { - ::axum::extract::FromRequestParts::from_request_parts( + <#field_ty as ::axum::extract::FromRequestParts<_>>::from_request_parts( parts, state, ) @@ -582,26 +596,29 @@ fn extract_fields( let member = member(field, fields.len() - 1); let ty_span = field.ty.span(); - let into_inner = into_inner(via, ty_span); + let into_inner = into_inner(&via, ty_span); let item = if peel_option(&field.ty).is_some() { + let field_ty = into_outer(&via, ty_span, peel_option(&field.ty).unwrap()); quote_spanned! {ty_span=> #member: { - ::axum::extract::FromRequest::from_request(req, state) + <#field_ty as ::axum::extract::FromRequest<_, _>>::from_request(req, state) .await .ok() .map(#into_inner) }, } } else if peel_result_ok(&field.ty).is_some() { + let field_ty = into_outer(&via, ty_span, peel_result_ok(&field.ty).unwrap()); quote_spanned! {ty_span=> #member: { - ::axum::extract::FromRequest::from_request(req, state) + <#field_ty as ::axum::extract::FromRequest<_, _>>::from_request(req, state) .await .map(#into_inner) }, } } else { + let field_ty = into_outer(&via, ty_span, &field.ty); let map_err = if let Some(rejection) = rejection { quote! { <#rejection as ::std::convert::From<_>>::from } } else { @@ -610,7 +627,7 @@ fn extract_fields( quote_spanned! {ty_span=> #member: { - ::axum::extract::FromRequest::from_request(req, state) + <#field_ty as ::axum::extract::FromRequest<_, _>>::from_request(req, state) .await .map(#into_inner) .map_err(#map_err)? @@ -807,7 +824,6 @@ fn impl_struct_by_extracting_all_at_once( let tokens = match tr { Trait::FromRequest => { quote_spanned! {path_span=> - #[::axum::async_trait] #[automatically_derived] impl<#impl_generics> ::axum::extract::FromRequest<#trait_generics> for #ident #ident_generics where @@ -821,7 +837,7 @@ fn impl_struct_by_extracting_all_at_once( req: ::axum::http::Request<::axum::body::Body>, state: &#state, ) -> ::std::result::Result { - ::axum::extract::FromRequest::from_request(req, state) + <#via_path<#via_type_generics> as ::axum::extract::FromRequest<_, _>>::from_request(req, state) .await .map(|#via_path(value)| #value_to_self) .map_err(#map_err) @@ -831,7 +847,6 @@ fn impl_struct_by_extracting_all_at_once( } Trait::FromRequestParts => { quote_spanned! {path_span=> - #[::axum::async_trait] #[automatically_derived] impl<#impl_generics> ::axum::extract::FromRequestParts<#trait_generics> for #ident #ident_generics where @@ -845,7 +860,7 @@ fn impl_struct_by_extracting_all_at_once( parts: &mut ::axum::http::request::Parts, state: &#state, ) -> ::std::result::Result { - ::axum::extract::FromRequestParts::from_request_parts(parts, state) + <#via_path<#via_type_generics> as ::axum::extract::FromRequestParts<_>>::from_request_parts(parts, state) .await .map(|#via_path(value)| #value_to_self) .map_err(#map_err) @@ -920,7 +935,6 @@ fn impl_enum_by_extracting_all_at_once( let tokens = match tr { Trait::FromRequest => { quote_spanned! {path_span=> - #[::axum::async_trait] #[automatically_derived] impl<#impl_generics> ::axum::extract::FromRequest<#trait_generics> for #ident where @@ -932,7 +946,7 @@ fn impl_enum_by_extracting_all_at_once( req: ::axum::http::Request<::axum::body::Body>, state: &#state, ) -> ::std::result::Result { - ::axum::extract::FromRequest::from_request(req, state) + <#path::<#ident> as ::axum::extract::FromRequest<_, _>>::from_request(req, state) .await .map(|#path(inner)| inner) .map_err(#map_err) @@ -942,7 +956,6 @@ fn impl_enum_by_extracting_all_at_once( } Trait::FromRequestParts => { quote_spanned! {path_span=> - #[::axum::async_trait] #[automatically_derived] impl<#impl_generics> ::axum::extract::FromRequestParts<#trait_generics> for #ident where @@ -954,7 +967,7 @@ fn impl_enum_by_extracting_all_at_once( parts: &mut ::axum::http::request::Parts, state: &#state, ) -> ::std::result::Result { - ::axum::extract::FromRequestParts::from_request_parts(parts, state) + <#path::<#ident> as ::axum::extract::FromRequestParts<_>>::from_request_parts(parts, state) .await .map(|#path(inner)| inner) .map_err(#map_err) @@ -1003,7 +1016,7 @@ fn infer_state_type_from_field_attributes(fields: &Fields) -> impl Iterator { Box::new(fields_named.named.iter().filter_map(|field| { - // TODO(david): its a little wasteful to parse the attributes again here + // TODO(david): it's a little wasteful to parse the attributes again here // ideally we should parse things once and pass the data down let FromRequestFieldAttrs { via } = parse_attrs("from_request", &field.attrs).ok()?; @@ -1013,7 +1026,7 @@ fn infer_state_type_from_field_attributes(fields: &Fields) -> impl Iterator { Box::new(fields_unnamed.unnamed.iter().filter_map(|field| { - // TODO(david): its a little wasteful to parse the attributes again here + // TODO(david): it's a little wasteful to parse the attributes again here // ideally we should parse things once and pass the data down let FromRequestFieldAttrs { via } = parse_attrs("from_request", &field.attrs).ok()?; diff --git a/axum-macros/src/lib.rs b/axum-macros/src/lib.rs index 36f9b50c2a..f5aeaab748 100644 --- a/axum-macros/src/lib.rs +++ b/axum-macros/src/lib.rs @@ -15,7 +15,6 @@ clippy::needless_borrow, clippy::match_wildcard_for_single_variants, clippy::if_let_mutex, - clippy::mismatched_target_os, clippy::await_holding_lock, clippy::match_on_vec_items, clippy::imprecise_flops, @@ -44,6 +43,7 @@ #![cfg_attr(test, allow(clippy::float_cmp))] #![cfg_attr(not(test), warn(clippy::print_stdout, clippy::dbg_macro))] +use debug_handler::FunctionKind; use proc_macro::TokenStream; use quote::{quote, ToTokens}; use syn::{parse::Parse, Type}; @@ -233,6 +233,54 @@ use from_request::Trait::{FromRequest, FromRequestParts}; /// } /// ``` /// +/// ## Concrete state +/// +/// If the extraction can be done only for a concrete state, that type can be specified with +/// `#[from_request(state(YourState))]`: +/// +/// ``` +/// use axum::extract::{FromRequest, FromRequestParts}; +/// +/// #[derive(Clone)] +/// struct CustomState; +/// +/// struct MyInnerType; +/// +/// impl FromRequestParts for MyInnerType { +/// // ... +/// # type Rejection = (); +/// +/// # async fn from_request_parts( +/// # _parts: &mut axum::http::request::Parts, +/// # _state: &CustomState +/// # ) -> Result { +/// # todo!() +/// # } +/// } +/// +/// #[derive(FromRequest)] +/// #[from_request(state(CustomState))] +/// struct MyExtractor { +/// custom: MyInnerType, +/// body: String, +/// } +/// ``` +/// +/// This is not needed for a `State` as the type is inferred in that case. +/// +/// ``` +/// use axum::extract::{FromRequest, FromRequestParts, State}; +/// +/// #[derive(Clone)] +/// struct CustomState; +/// +/// #[derive(FromRequest)] +/// struct MyExtractor { +/// custom: State, +/// body: String, +/// } +/// ``` +/// /// # The whole type at once /// /// By using `#[from_request(via(...))]` on the container you can extract the whole type at once, @@ -349,7 +397,7 @@ use from_request::Trait::{FromRequest, FromRequestParts}; /// /// # Known limitations /// -/// Generics are only supported on tuple structs with exactly on field. Thus this doesn't work +/// Generics are only supported on tuple structs with exactly one field. Thus this doesn't work /// /// ```compile_fail /// #[derive(axum_macros::FromRequest)] @@ -415,7 +463,7 @@ pub fn derive_from_request_parts(item: TokenStream) -> TokenStream { expand_with(item, |item| from_request::expand(item, FromRequestParts)) } -/// Generates better error messages when applied handler functions. +/// Generates better error messages when applied to handler functions. /// /// While using [`axum`], you can get long error messages for simple mistakes. For example: /// @@ -466,17 +514,15 @@ pub fn derive_from_request_parts(item: TokenStream) -> TokenStream { /// /// As the error message says, handler function needs to be async. /// -/// ``` +/// ```no_run /// use axum::{routing::get, Router, debug_handler}; /// /// #[tokio::main] /// async fn main() { -/// # async { /// let app = Router::new().route("/", get(handler)); /// /// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap(); /// axum::serve(listener, app).await.unwrap(); -/// # }; /// } /// /// #[debug_handler] @@ -569,7 +615,65 @@ pub fn debug_handler(_attr: TokenStream, input: TokenStream) -> TokenStream { return input; #[cfg(debug_assertions)] - return expand_attr_with(_attr, input, debug_handler::expand); + return expand_attr_with(_attr, input, |attrs, item_fn| { + debug_handler::expand(attrs, item_fn, FunctionKind::Handler) + }); +} + +/// Generates better error messages when applied to middleware functions. +/// +/// This works similarly to [`#[debug_handler]`](macro@debug_handler) except for middleware using +/// [`axum::middleware::from_fn`]. +/// +/// # Example +/// +/// ```no_run +/// use axum::{ +/// routing::get, +/// extract::Request, +/// response::Response, +/// Router, +/// middleware::{self, Next}, +/// debug_middleware, +/// }; +/// +/// #[tokio::main] +/// async fn main() { +/// let app = Router::new() +/// .route("/", get(|| async {})) +/// .layer(middleware::from_fn(my_middleware)); +/// +/// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap(); +/// axum::serve(listener, app).await.unwrap(); +/// } +/// +/// // if this wasn't a valid middleware function #[debug_middleware] would +/// // improve compile error +/// #[debug_middleware] +/// async fn my_middleware( +/// request: Request, +/// next: Next, +/// ) -> Response { +/// next.run(request).await +/// } +/// ``` +/// +/// # Performance +/// +/// This macro has no effect when compiled with the release profile. (eg. `cargo build --release`) +/// +/// [`axum`]: https://docs.rs/axum/latest +/// [`axum::middleware::from_fn`]: https://docs.rs/axum/0.7/axum/middleware/fn.from_fn.html +/// [`debug_middleware`]: macro@debug_middleware +#[proc_macro_attribute] +pub fn debug_middleware(_attr: TokenStream, input: TokenStream) -> TokenStream { + #[cfg(not(debug_assertions))] + return input; + + #[cfg(debug_assertions)] + return expand_attr_with(_attr, input, |attrs, item_fn| { + debug_handler::expand(attrs, item_fn, FunctionKind::Middleware) + }); } /// Private API: Do no use this! diff --git a/axum-macros/src/typed_path.rs b/axum-macros/src/typed_path.rs index 61db3eb9ae..fa272252be 100644 --- a/axum-macros/src/typed_path.rs +++ b/axum-macros/src/typed_path.rs @@ -133,7 +133,6 @@ fn expand_named_fields( let map_err_rejection = map_err_rejection(&rejection); let from_request_impl = quote! { - #[::axum::async_trait] #[automatically_derived] impl ::axum::extract::FromRequestParts for #ident where @@ -238,7 +237,6 @@ fn expand_unnamed_fields( let map_err_rejection = map_err_rejection(&rejection); let from_request_impl = quote! { - #[::axum::async_trait] #[automatically_derived] impl ::axum::extract::FromRequestParts for #ident where @@ -322,7 +320,6 @@ fn expand_unit_fields( }; let from_request_impl = quote! { - #[::axum::async_trait] #[automatically_derived] impl ::axum::extract::FromRequestParts for #ident where @@ -386,8 +383,12 @@ fn parse_path(path: &LitStr) -> syn::Result> { .split('/') .map(|segment| { if let Some(capture) = segment - .strip_prefix(':') - .or_else(|| segment.strip_prefix('*')) + .strip_prefix('{') + .and_then(|segment| segment.strip_suffix('}')) + .and_then(|segment| { + (!segment.starts_with('{') && !segment.ends_with('}')).then_some(segment) + }) + .map(|capture| capture.strip_prefix('*').unwrap_or(capture)) { Ok(Segment::Capture(capture.to_owned(), path.span())) } else { diff --git a/axum-macros/src/with_position.rs b/axum-macros/src/with_position.rs index 2e0caa5022..e064a3f01e 100644 --- a/axum-macros/src/with_position.rs +++ b/axum-macros/src/with_position.rs @@ -40,10 +40,10 @@ impl WithPosition where I: Iterator, { - pub(crate) fn new(iter: I) -> WithPosition { + pub(crate) fn new(iter: impl IntoIterator) -> WithPosition { WithPosition { handled_first: false, - peekable: iter.fuse().peekable(), + peekable: iter.into_iter().fuse().peekable(), } } } diff --git a/axum-macros/tests/debug_handler/fail/argument_not_extractor.rs b/axum-macros/tests/debug_handler/fail/argument_not_extractor.rs index 2d386c82f5..85a4c1d283 100644 --- a/axum-macros/tests/debug_handler/fail/argument_not_extractor.rs +++ b/axum-macros/tests/debug_handler/fail/argument_not_extractor.rs @@ -1,4 +1,3 @@ -#![feature(diagnostic_namespace)] use axum_macros::debug_handler; #[debug_handler] diff --git a/axum-macros/tests/debug_handler/fail/argument_not_extractor.stderr b/axum-macros/tests/debug_handler/fail/argument_not_extractor.stderr index d946782586..f5687df0e8 100644 --- a/axum-macros/tests/debug_handler/fail/argument_not_extractor.stderr +++ b/axum-macros/tests/debug_handler/fail/argument_not_extractor.stderr @@ -1,24 +1,24 @@ -error[E0277]: the trait bound `bool: FromRequestParts<()>` is not satisfied - --> tests/debug_handler/fail/argument_not_extractor.rs:5:24 +error[E0277]: the trait bound `bool: FromRequest<(), axum_core::extract::private::ViaParts>` is not satisfied + --> tests/debug_handler/fail/argument_not_extractor.rs:4:24 | -5 | async fn handler(_foo: bool) {} - | ^^^^ the trait `FromRequestParts<()>` is not implemented for `bool` +4 | async fn handler(_foo: bool) {} + | ^^^^ the trait `FromRequestParts<()>` is not implemented for `bool`, which is required by `bool: FromRequest<(), _>` | = note: Function argument is not a valid axum extractor. See `https://docs.rs/axum/0.7/axum/extract/index.html` for details = help: the following other types implement trait `FromRequestParts`: - > - as FromRequestParts> - > - > - > - > - > - as FromRequestParts> + `()` implements `FromRequestParts` + `(T1, T2)` implements `FromRequestParts` + `(T1, T2, T3)` implements `FromRequestParts` + `(T1, T2, T3, T4)` implements `FromRequestParts` + `(T1, T2, T3, T4, T5)` implements `FromRequestParts` + `(T1, T2, T3, T4, T5, T6)` implements `FromRequestParts` + `(T1, T2, T3, T4, T5, T6, T7)` implements `FromRequestParts` + `(T1, T2, T3, T4, T5, T6, T7, T8)` implements `FromRequestParts` and $N others = note: required for `bool` to implement `FromRequest<(), axum_core::extract::private::ViaParts>` note: required by a bound in `__axum_macros_check_handler_0_from_request_check` - --> tests/debug_handler/fail/argument_not_extractor.rs:5:24 + --> tests/debug_handler/fail/argument_not_extractor.rs:4:24 | -5 | async fn handler(_foo: bool) {} +4 | async fn handler(_foo: bool) {} | ^^^^ required by this bound in `__axum_macros_check_handler_0_from_request_check` diff --git a/axum-macros/tests/debug_handler/fail/extension_not_clone.rs b/axum-macros/tests/debug_handler/fail/extension_not_clone.rs new file mode 100644 index 0000000000..6bed79e195 --- /dev/null +++ b/axum-macros/tests/debug_handler/fail/extension_not_clone.rs @@ -0,0 +1,9 @@ +use axum::extract::Extension; +use axum_macros::debug_handler; + +struct NonCloneType; + +#[debug_handler] +async fn test_extension_non_clone(_: Extension) {} + +fn main() {} diff --git a/axum-macros/tests/debug_handler/fail/extension_not_clone.stderr b/axum-macros/tests/debug_handler/fail/extension_not_clone.stderr new file mode 100644 index 0000000000..81bec91835 --- /dev/null +++ b/axum-macros/tests/debug_handler/fail/extension_not_clone.stderr @@ -0,0 +1,28 @@ +error[E0277]: the trait bound `NonCloneType: Clone` is not satisfied + --> tests/debug_handler/fail/extension_not_clone.rs:7:38 + | +7 | async fn test_extension_non_clone(_: Extension) {} + | ^^^^^^^^^^^^^^^^^^^^^^^ the trait `Clone` is not implemented for `NonCloneType`, which is required by `Extension: FromRequest<(), _>` + | + = help: the following other types implement trait `FromRequest`: + (T1, T2) + (T1, T2, T3) + (T1, T2, T3, T4) + (T1, T2, T3, T4, T5) + (T1, T2, T3, T4, T5, T6) + (T1, T2, T3, T4, T5, T6, T7) + (T1, T2, T3, T4, T5, T6, T7, T8) + (T1, T2, T3, T4, T5, T6, T7, T8, T9) + and $N others + = note: required for `Extension` to implement `FromRequestParts<()>` + = note: required for `Extension` to implement `FromRequest<(), axum_core::extract::private::ViaParts>` +note: required by a bound in `__axum_macros_check_test_extension_non_clone_0_from_request_check` + --> tests/debug_handler/fail/extension_not_clone.rs:7:38 + | +7 | async fn test_extension_non_clone(_: Extension) {} + | ^^^^^^^^^^^^^^^^^^^^^^^ required by this bound in `__axum_macros_check_test_extension_non_clone_0_from_request_check` +help: consider annotating `NonCloneType` with `#[derive(Clone)]` + | +4 + #[derive(Clone)] +5 | struct NonCloneType; + | diff --git a/axum-macros/tests/debug_handler/fail/extract_self_mut.rs b/axum-macros/tests/debug_handler/fail/extract_self_mut.rs index 21ae99d6b8..eb17c1df52 100644 --- a/axum-macros/tests/debug_handler/fail/extract_self_mut.rs +++ b/axum-macros/tests/debug_handler/fail/extract_self_mut.rs @@ -1,12 +1,8 @@ -use axum::{ - async_trait, - extract::{Request, FromRequest}, -}; +use axum::extract::{FromRequest, Request}; use axum_macros::debug_handler; struct A; -#[async_trait] impl FromRequest for A where S: Send + Sync, diff --git a/axum-macros/tests/debug_handler/fail/extract_self_mut.stderr b/axum-macros/tests/debug_handler/fail/extract_self_mut.stderr index 595786bf4e..0610a22a3b 100644 --- a/axum-macros/tests/debug_handler/fail/extract_self_mut.stderr +++ b/axum-macros/tests/debug_handler/fail/extract_self_mut.stderr @@ -1,5 +1,5 @@ error: Handlers must only take owned values - --> tests/debug_handler/fail/extract_self_mut.rs:23:22 + --> tests/debug_handler/fail/extract_self_mut.rs:19:22 | -23 | async fn handler(&mut self) {} +19 | async fn handler(&mut self) {} | ^^^^^^^^^ diff --git a/axum-macros/tests/debug_handler/fail/extract_self_ref.rs b/axum-macros/tests/debug_handler/fail/extract_self_ref.rs index 8e32811994..d70c5f2318 100644 --- a/axum-macros/tests/debug_handler/fail/extract_self_ref.rs +++ b/axum-macros/tests/debug_handler/fail/extract_self_ref.rs @@ -1,12 +1,8 @@ -use axum::{ - async_trait, - extract::{Request, FromRequest}, -}; +use axum::extract::{FromRequest, Request}; use axum_macros::debug_handler; struct A; -#[async_trait] impl FromRequest for A where S: Send + Sync, diff --git a/axum-macros/tests/debug_handler/fail/extract_self_ref.stderr b/axum-macros/tests/debug_handler/fail/extract_self_ref.stderr index 4c0b4950c7..d475c5092f 100644 --- a/axum-macros/tests/debug_handler/fail/extract_self_ref.stderr +++ b/axum-macros/tests/debug_handler/fail/extract_self_ref.stderr @@ -1,5 +1,5 @@ error: Handlers must only take owned values - --> tests/debug_handler/fail/extract_self_ref.rs:23:22 + --> tests/debug_handler/fail/extract_self_ref.rs:19:22 | -23 | async fn handler(&self) {} +19 | async fn handler(&self) {} | ^^^^^ diff --git a/axum-macros/tests/debug_handler/fail/json_not_deserialize.rs b/axum-macros/tests/debug_handler/fail/json_not_deserialize.rs index d894b0bd96..b6be5f87d4 100644 --- a/axum-macros/tests/debug_handler/fail/json_not_deserialize.rs +++ b/axum-macros/tests/debug_handler/fail/json_not_deserialize.rs @@ -4,6 +4,6 @@ use axum_macros::debug_handler; struct Struct {} #[debug_handler] -async fn handler(foo: Json) {} +async fn handler(_foo: Json) {} fn main() {} diff --git a/axum-macros/tests/debug_handler/fail/json_not_deserialize.stderr b/axum-macros/tests/debug_handler/fail/json_not_deserialize.stderr index ee30bfed0d..afda86b65d 100644 --- a/axum-macros/tests/debug_handler/fail/json_not_deserialize.stderr +++ b/axum-macros/tests/debug_handler/fail/json_not_deserialize.stderr @@ -1,20 +1,51 @@ error[E0277]: the trait bound `for<'de> Struct: serde::de::Deserialize<'de>` is not satisfied - --> tests/debug_handler/fail/json_not_deserialize.rs:7:23 + --> tests/debug_handler/fail/json_not_deserialize.rs:7:24 | -7 | async fn handler(foo: Json) {} - | ^^^^^^^^^^^^ the trait `for<'de> serde::de::Deserialize<'de>` is not implemented for `Struct` +7 | async fn handler(_foo: Json) {} + | ^^^^^^^^^^^^ the trait `for<'de> serde::de::Deserialize<'de>` is not implemented for `Struct`, which is required by `Json: FromRequest<()>` | + = note: for local types consider adding `#[derive(serde::Deserialize)]` to your `Struct` type + = note: for types from other crates check whether the crate offers a `serde` feature flag = help: the following other types implement trait `serde::de::Deserialize<'de>`: - bool - char - isize - i8 - i16 - i32 - i64 - i128 + &'a [u8] + &'a serde_json::raw::RawValue + &'a std::path::Path + &'a str + () + (T,) + (T0, T1) + (T0, T1, T2) and $N others = note: required for `Struct` to implement `serde::de::DeserializeOwned` = note: required for `Json` to implement `FromRequest<()>` = help: see issue #48214 - = help: add `#![feature(trivial_bounds)]` to the crate attributes to enable +help: add `#![feature(trivial_bounds)]` to the crate attributes to enable + | +1 + #![feature(trivial_bounds)] + | + +error[E0277]: the trait bound `for<'de> Struct: serde::de::Deserialize<'de>` is not satisfied + --> tests/debug_handler/fail/json_not_deserialize.rs:7:24 + | +7 | async fn handler(_foo: Json) {} + | ^^^^^^^^^^^^ the trait `for<'de> serde::de::Deserialize<'de>` is not implemented for `Struct`, which is required by `Json: FromRequest<()>` + | + = note: for local types consider adding `#[derive(serde::Deserialize)]` to your `Struct` type + = note: for types from other crates check whether the crate offers a `serde` feature flag + = help: the following other types implement trait `serde::de::Deserialize<'de>`: + &'a [u8] + &'a serde_json::raw::RawValue + &'a std::path::Path + &'a str + () + (T,) + (T0, T1) + (T0, T1, T2) + and $N others + = note: required for `Struct` to implement `serde::de::DeserializeOwned` + = note: required for `Json` to implement `FromRequest<()>` +note: required by a bound in `__axum_macros_check_handler_0_from_request_check` + --> tests/debug_handler/fail/json_not_deserialize.rs:7:24 + | +7 | async fn handler(_foo: Json) {} + | ^^^^^^^^^^^^ required by this bound in `__axum_macros_check_handler_0_from_request_check` diff --git a/axum-macros/tests/debug_handler/fail/multiple_request_consumers.rs b/axum-macros/tests/debug_handler/fail/multiple_request_consumers.rs index 4c86ceae69..77235dfdb6 100644 --- a/axum-macros/tests/debug_handler/fail/multiple_request_consumers.rs +++ b/axum-macros/tests/debug_handler/fail/multiple_request_consumers.rs @@ -1,5 +1,9 @@ +use axum::{ + body::Bytes, + http::{Method, Uri}, + Json, +}; use axum_macros::debug_handler; -use axum::{Json, body::Bytes, http::{Method, Uri}}; #[debug_handler] async fn one(_: Json<()>, _: String, _: Uri) {} diff --git a/axum-macros/tests/debug_handler/fail/multiple_request_consumers.stderr b/axum-macros/tests/debug_handler/fail/multiple_request_consumers.stderr index ba2ff7adff..011ce89934 100644 --- a/axum-macros/tests/debug_handler/fail/multiple_request_consumers.stderr +++ b/axum-macros/tests/debug_handler/fail/multiple_request_consumers.stderr @@ -1,11 +1,11 @@ error: Can't have two extractors that consume the request body. `Json<_>` and `String` both do that. - --> tests/debug_handler/fail/multiple_request_consumers.rs:5:14 + --> tests/debug_handler/fail/multiple_request_consumers.rs:9:14 | -5 | async fn one(_: Json<()>, _: String, _: Uri) {} +9 | async fn one(_: Json<()>, _: String, _: Uri) {} | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ error: Can't have more than one extractor that consume the request body. `Json<_>`, `Bytes`, and `String` all do that. - --> tests/debug_handler/fail/multiple_request_consumers.rs:8:14 - | -8 | async fn two(_: Json<()>, _: Method, _: Bytes, _: Uri, _: String) {} - | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + --> tests/debug_handler/fail/multiple_request_consumers.rs:12:14 + | +12 | async fn two(_: Json<()>, _: Method, _: Bytes, _: Uri, _: String) {} + | ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/axum-macros/tests/debug_handler/fail/not_send.stderr b/axum-macros/tests/debug_handler/fail/not_send.stderr index 82f5bb2b3e..13ab83ddc0 100644 --- a/axum-macros/tests/debug_handler/fail/not_send.stderr +++ b/axum-macros/tests/debug_handler/fail/not_send.stderr @@ -4,7 +4,7 @@ error: future cannot be sent between threads safely 3 | #[debug_handler] | ^^^^^^^^^^^^^^^^ future returned by `handler` is not `Send` | - = help: within `impl Future`, the trait `Send` is not implemented for `Rc<()>` + = help: within `impl Future`, the trait `Send` is not implemented for `Rc<()>`, which is required by `impl Future: Send` note: future is not `Send` as this value is used across an await --> tests/debug_handler/fail/not_send.rs:6:14 | @@ -12,8 +12,6 @@ note: future is not `Send` as this value is used across an await | --- has type `Rc<()>` which is not `Send` 6 | async {}.await; | ^^^^^ await occurs here, with `_rc` maybe used later -7 | } - | - `_rc` is later dropped here note: required by a bound in `check` --> tests/debug_handler/fail/not_send.rs:3:1 | diff --git a/axum-macros/tests/debug_handler/fail/output_tuple_too_many.rs b/axum-macros/tests/debug_handler/fail/output_tuple_too_many.rs new file mode 100644 index 0000000000..ea15e66a37 --- /dev/null +++ b/axum-macros/tests/debug_handler/fail/output_tuple_too_many.rs @@ -0,0 +1,28 @@ +use axum::response::AppendHeaders; + +#[axum::debug_handler] +async fn handler() -> ( + axum::http::StatusCode, + AppendHeaders<[(axum::http::HeaderName, &'static str); 1]>, + AppendHeaders<[(axum::http::HeaderName, &'static str); 1]>, + AppendHeaders<[(axum::http::HeaderName, &'static str); 1]>, + AppendHeaders<[(axum::http::HeaderName, &'static str); 1]>, + AppendHeaders<[(axum::http::HeaderName, &'static str); 1]>, + AppendHeaders<[(axum::http::HeaderName, &'static str); 1]>, + AppendHeaders<[(axum::http::HeaderName, &'static str); 1]>, + AppendHeaders<[(axum::http::HeaderName, &'static str); 1]>, + AppendHeaders<[(axum::http::HeaderName, &'static str); 1]>, + AppendHeaders<[(axum::http::HeaderName, &'static str); 1]>, + AppendHeaders<[(axum::http::HeaderName, &'static str); 1]>, + AppendHeaders<[(axum::http::HeaderName, &'static str); 1]>, + AppendHeaders<[(axum::http::HeaderName, &'static str); 1]>, + AppendHeaders<[(axum::http::HeaderName, &'static str); 1]>, + AppendHeaders<[(axum::http::HeaderName, &'static str); 1]>, + AppendHeaders<[(axum::http::HeaderName, &'static str); 1]>, + AppendHeaders<[(axum::http::HeaderName, &'static str); 1]>, + axum::http::StatusCode, +) { + panic!() +} + +fn main() {} diff --git a/axum-macros/tests/debug_handler/fail/output_tuple_too_many.stderr b/axum-macros/tests/debug_handler/fail/output_tuple_too_many.stderr new file mode 100644 index 0000000000..fb31388a4e --- /dev/null +++ b/axum-macros/tests/debug_handler/fail/output_tuple_too_many.stderr @@ -0,0 +1,12 @@ +error: Cannot return tuples with more than 17 elements + --> tests/debug_handler/fail/output_tuple_too_many.rs:4:20 + | +4 | async fn handler() -> ( + | ____________________^ +5 | | axum::http::StatusCode, +6 | | AppendHeaders<[(axum::http::HeaderName, &'static str); 1]>, +7 | | AppendHeaders<[(axum::http::HeaderName, &'static str); 1]>, +... | +23 | | axum::http::StatusCode, +24 | | ) { + | |_^ diff --git a/axum-macros/tests/debug_handler/fail/returning_request_parts.rs b/axum-macros/tests/debug_handler/fail/returning_request_parts.rs new file mode 100644 index 0000000000..0658dc02bf --- /dev/null +++ b/axum-macros/tests/debug_handler/fail/returning_request_parts.rs @@ -0,0 +1,9 @@ +#[axum::debug_handler] +async fn handler() -> ( + axum::http::request::Parts, // this should be response parts, not request parts + axum::http::StatusCode, +) { + panic!() +} + +fn main() {} diff --git a/axum-macros/tests/debug_handler/fail/returning_request_parts.stderr b/axum-macros/tests/debug_handler/fail/returning_request_parts.stderr new file mode 100644 index 0000000000..440f20fe58 --- /dev/null +++ b/axum-macros/tests/debug_handler/fail/returning_request_parts.stderr @@ -0,0 +1,8 @@ +error[E0308]: mismatched types + --> tests/debug_handler/fail/returning_request_parts.rs:3:5 + | +3 | axum::http::request::Parts, // this should be response parts, not request parts + | ^^^^^^^^^^^^^^^^^^^^^^^^^^ + | | + | expected `axum::http::response::Parts`, found `axum::http::request::Parts` + | expected `axum::http::response::Parts` because of return type diff --git a/axum-macros/tests/debug_handler/fail/single_wrong_return_tuple.rs b/axum-macros/tests/debug_handler/fail/single_wrong_return_tuple.rs new file mode 100644 index 0000000000..452fa21d2e --- /dev/null +++ b/axum-macros/tests/debug_handler/fail/single_wrong_return_tuple.rs @@ -0,0 +1,10 @@ +#![allow(unused_parens)] + +struct NotIntoResponse; + +#[axum::debug_handler] +async fn handler() -> (NotIntoResponse) { + panic!() +} + +fn main() {} diff --git a/axum-macros/tests/debug_handler/fail/single_wrong_return_tuple.stderr b/axum-macros/tests/debug_handler/fail/single_wrong_return_tuple.stderr new file mode 100644 index 0000000000..8909373553 --- /dev/null +++ b/axum-macros/tests/debug_handler/fail/single_wrong_return_tuple.stderr @@ -0,0 +1,21 @@ +error[E0277]: the trait bound `NotIntoResponse: IntoResponse` is not satisfied + --> tests/debug_handler/fail/single_wrong_return_tuple.rs:6:23 + | +6 | async fn handler() -> (NotIntoResponse) { + | ^^^^^^^^^^^^^^^^^ the trait `IntoResponse` is not implemented for `NotIntoResponse` + | + = help: the following other types implement trait `IntoResponse`: + &'static [u8; N] + &'static [u8] + &'static str + () + (R,) + (Response<()>, R) + (Response<()>, T1, R) + (Response<()>, T1, T2, R) + and $N others +note: required by a bound in `__axum_macros_check_handler_into_response::{closure#0}::check` + --> tests/debug_handler/fail/single_wrong_return_tuple.rs:6:23 + | +6 | async fn handler() -> (NotIntoResponse) { + | ^^^^^^^^^^^^^^^^^ required by this bound in `check` diff --git a/axum-macros/tests/debug_handler/fail/too_many_extractors.rs b/axum-macros/tests/debug_handler/fail/too_many_extractors.rs index 441d0f0059..894a4e0d46 100644 --- a/axum-macros/tests/debug_handler/fail/too_many_extractors.rs +++ b/axum-macros/tests/debug_handler/fail/too_many_extractors.rs @@ -1,5 +1,5 @@ -use axum_macros::debug_handler; use axum::http::Uri; +use axum_macros::debug_handler; #[debug_handler] async fn handler( @@ -20,6 +20,7 @@ async fn handler( _e15: Uri, _e16: Uri, _e17: Uri, -) {} +) { +} fn main() {} diff --git a/axum-macros/tests/debug_handler/fail/wrong_order.rs b/axum-macros/tests/debug_handler/fail/wrong_order.rs index 7d22bf5251..8dfd73670f 100644 --- a/axum-macros/tests/debug_handler/fail/wrong_order.rs +++ b/axum-macros/tests/debug_handler/fail/wrong_order.rs @@ -1,5 +1,5 @@ +use axum::{http::Uri, Json}; use axum_macros::debug_handler; -use axum::{Json, http::Uri}; #[debug_handler] async fn one(_: Json<()>, _: Uri) {} diff --git a/axum-macros/tests/debug_handler/fail/wrong_return_tuple.rs b/axum-macros/tests/debug_handler/fail/wrong_return_tuple.rs new file mode 100644 index 0000000000..0b2afa168e --- /dev/null +++ b/axum-macros/tests/debug_handler/fail/wrong_return_tuple.rs @@ -0,0 +1,27 @@ +#![allow(unused_parens)] + +#[axum::debug_handler] +async fn named_type() -> ( + axum::http::StatusCode, + axum::Json<&'static str>, + axum::response::AppendHeaders<[(axum::http::HeaderName, &'static str); 1]>, +) { + panic!() +} + +struct CustomIntoResponse {} +impl axum::response::IntoResponse for CustomIntoResponse { + fn into_response(self) -> axum::response::Response { + todo!() + } +} +#[axum::debug_handler] +async fn custom_type() -> ( + axum::http::StatusCode, + CustomIntoResponse, + axum::response::AppendHeaders<[(axum::http::HeaderName, &'static str); 1]>, +) { + panic!() +} + +fn main() {} diff --git a/axum-macros/tests/debug_handler/fail/wrong_return_tuple.stderr b/axum-macros/tests/debug_handler/fail/wrong_return_tuple.stderr new file mode 100644 index 0000000000..8779d35805 --- /dev/null +++ b/axum-macros/tests/debug_handler/fail/wrong_return_tuple.stderr @@ -0,0 +1,49 @@ +error: `Json<_>` must be the last element in a response tuple + --> tests/debug_handler/fail/wrong_return_tuple.rs:6:5 + | +6 | axum::Json<&'static str>, + | ^^^^^^^^^^^^^^^^^^^^^^^^ + +error[E0277]: the trait bound `CustomIntoResponse: IntoResponseParts` is not satisfied + --> tests/debug_handler/fail/wrong_return_tuple.rs:21:5 + | +21 | CustomIntoResponse, + | ^^^^^^^^^^^^^^^^^^ the trait `IntoResponseParts` is not implemented for `CustomIntoResponse` + | + = help: the following other types implement trait `IntoResponseParts`: + () + (T1, T2) + (T1, T2, T3) + (T1, T2, T3, T4) + (T1, T2, T3, T4, T5) + (T1, T2, T3, T4, T5, T6) + (T1, T2, T3, T4, T5, T6, T7) + (T1, T2, T3, T4, T5, T6, T7, T8) + and $N others + = help: see issue #48214 +help: add `#![feature(trivial_bounds)]` to the crate attributes to enable + | +3 + #![feature(trivial_bounds)] + | + +error[E0277]: the trait bound `CustomIntoResponse: IntoResponseParts` is not satisfied + --> tests/debug_handler/fail/wrong_return_tuple.rs:21:5 + | +21 | CustomIntoResponse, + | ^^^^^^^^^^^^^^^^^^ the trait `IntoResponseParts` is not implemented for `CustomIntoResponse` + | + = help: the following other types implement trait `IntoResponseParts`: + () + (T1, T2) + (T1, T2, T3) + (T1, T2, T3, T4) + (T1, T2, T3, T4, T5) + (T1, T2, T3, T4, T5, T6) + (T1, T2, T3, T4, T5, T6, T7) + (T1, T2, T3, T4, T5, T6, T7, T8) + and $N others +note: required by a bound in `__axum_macros_check_custom_type_into_response_parts_1_check` + --> tests/debug_handler/fail/wrong_return_tuple.rs:21:5 + | +21 | CustomIntoResponse, + | ^^^^^^^^^^^^^^^^^^ required by this bound in `__axum_macros_check_custom_type_into_response_parts_1_check` diff --git a/axum-macros/tests/debug_handler/fail/wrong_return_type.stderr b/axum-macros/tests/debug_handler/fail/wrong_return_type.stderr index cc718aae0c..c305e7e781 100644 --- a/axum-macros/tests/debug_handler/fail/wrong_return_type.stderr +++ b/axum-macros/tests/debug_handler/fail/wrong_return_type.stderr @@ -5,14 +5,14 @@ error[E0277]: the trait bound `bool: IntoResponse` is not satisfied | ^^^^ the trait `IntoResponse` is not implemented for `bool` | = help: the following other types implement trait `IntoResponse`: - Box - Box<[u8]> - axum::body::Bytes - Body - axum::extract::rejection::FailedToBufferBody - axum::extract::rejection::LengthLimitError - axum::extract::rejection::UnknownBodyError - axum::extract::rejection::InvalidUtf8 + &'static [u8; N] + &'static [u8] + &'static str + () + (R,) + (Response<()>, R) + (Response<()>, T1, R) + (Response<()>, T1, T2, R) and $N others note: required by a bound in `__axum_macros_check_handler_into_response::{closure#0}::check` --> tests/debug_handler/fail/wrong_return_type.rs:4:23 diff --git a/axum-macros/tests/debug_handler/pass/impl_into_response.rs b/axum-macros/tests/debug_handler/pass/impl_into_response.rs index 69b884e390..f15f29c4d0 100644 --- a/axum-macros/tests/debug_handler/pass/impl_into_response.rs +++ b/axum-macros/tests/debug_handler/pass/impl_into_response.rs @@ -1,5 +1,5 @@ -use axum_macros::debug_handler; use axum::response::IntoResponse; +use axum_macros::debug_handler; #[debug_handler] async fn handler() -> impl IntoResponse { diff --git a/axum-macros/tests/debug_handler/pass/infer_state.rs b/axum-macros/tests/debug_handler/pass/infer_state.rs index 9f21a8a626..fceeb78acc 100644 --- a/axum-macros/tests/debug_handler/pass/infer_state.rs +++ b/axum-macros/tests/debug_handler/pass/infer_state.rs @@ -1,5 +1,5 @@ -use axum_macros::debug_handler; use axum::extract::State; +use axum_macros::debug_handler; #[debug_handler] async fn handler(_: State) {} @@ -8,22 +8,13 @@ async fn handler(_: State) {} async fn handler_2(_: axum::extract::State) {} #[debug_handler] -async fn handler_3( - _: axum::extract::State, - _: axum::extract::State, -) {} +async fn handler_3(_: axum::extract::State, _: axum::extract::State) {} #[debug_handler] -async fn handler_4( - _: State, - _: State, -) {} +async fn handler_4(_: State, _: State) {} #[debug_handler] -async fn handler_5( - _: axum::extract::State, - _: State, -) {} +async fn handler_5(_: axum::extract::State, _: State) {} #[derive(Clone)] struct AppState; diff --git a/axum-macros/tests/debug_handler/pass/multiple_extractors.rs b/axum-macros/tests/debug_handler/pass/multiple_extractors.rs index 6cc05b5166..e54c43e61a 100644 --- a/axum-macros/tests/debug_handler/pass/multiple_extractors.rs +++ b/axum-macros/tests/debug_handler/pass/multiple_extractors.rs @@ -1,5 +1,5 @@ -use axum_macros::debug_handler; use axum::http::{Method, Uri}; +use axum_macros::debug_handler; #[debug_handler] async fn handler(_one: Method, _two: Uri, _three: String) {} diff --git a/axum-macros/tests/debug_handler/pass/ready.rs b/axum-macros/tests/debug_handler/pass/ready.rs index 4ee73e99c8..d705b8fefb 100644 --- a/axum-macros/tests/debug_handler/pass/ready.rs +++ b/axum-macros/tests/debug_handler/pass/ready.rs @@ -1,5 +1,5 @@ use axum_macros::debug_handler; -use std::future::{Ready, ready}; +use std::future::{ready, Ready}; #[debug_handler] fn handler() -> Ready<()> { diff --git a/axum-macros/tests/debug_handler/pass/result_impl_into_response.rs b/axum-macros/tests/debug_handler/pass/result_impl_into_response.rs index 782fc9301c..f23c9b627c 100644 --- a/axum-macros/tests/debug_handler/pass/result_impl_into_response.rs +++ b/axum-macros/tests/debug_handler/pass/result_impl_into_response.rs @@ -1,4 +1,4 @@ -use axum::{async_trait, extract::FromRequestParts, http::request::Parts, response::IntoResponse}; +use axum::{extract::FromRequestParts, http::request::Parts, response::IntoResponse}; use axum_macros::debug_handler; fn main() {} @@ -115,7 +115,6 @@ impl A { } } -#[async_trait] impl FromRequestParts for A where S: Send + Sync, diff --git a/axum-macros/tests/debug_handler/pass/self_receiver.rs b/axum-macros/tests/debug_handler/pass/self_receiver.rs index 9b72284502..3fbcc4e03b 100644 --- a/axum-macros/tests/debug_handler/pass/self_receiver.rs +++ b/axum-macros/tests/debug_handler/pass/self_receiver.rs @@ -1,12 +1,8 @@ -use axum::{ - async_trait, - extract::{Request, FromRequest}, -}; +use axum::extract::{FromRequest, Request}; use axum_macros::debug_handler; struct A; -#[async_trait] impl FromRequest for A where S: Send + Sync, @@ -18,7 +14,6 @@ where } } -#[async_trait] impl FromRequest for Box where S: Send + Sync, diff --git a/axum-macros/tests/debug_handler/pass/set_state.rs b/axum-macros/tests/debug_handler/pass/set_state.rs index 60a7a3304e..72bba5aede 100644 --- a/axum-macros/tests/debug_handler/pass/set_state.rs +++ b/axum-macros/tests/debug_handler/pass/set_state.rs @@ -1,6 +1,5 @@ +use axum::extract::{FromRef, FromRequest, Request}; use axum_macros::debug_handler; -use axum::extract::{Request, FromRef, FromRequest}; -use axum::async_trait; #[debug_handler(state = AppState)] async fn handler(_: A) {} @@ -10,7 +9,6 @@ struct AppState; struct A; -#[async_trait] impl FromRequest for A where S: Send + Sync, diff --git a/axum-macros/tests/debug_handler/pass/state_and_body.rs b/axum-macros/tests/debug_handler/pass/state_and_body.rs index f348360b3a..629023aa03 100644 --- a/axum-macros/tests/debug_handler/pass/state_and_body.rs +++ b/axum-macros/tests/debug_handler/pass/state_and_body.rs @@ -1,5 +1,5 @@ +use axum::{extract::Request, extract::State}; use axum_macros::debug_handler; -use axum::{extract::State, extract::Request}; #[debug_handler(state = AppState)] async fn handler(_: State, _: Request) {} diff --git a/axum-macros/tests/debug_middleware/fail/doesnt_take_next.rs b/axum-macros/tests/debug_middleware/fail/doesnt_take_next.rs new file mode 100644 index 0000000000..12092e857b --- /dev/null +++ b/axum-macros/tests/debug_middleware/fail/doesnt_take_next.rs @@ -0,0 +1,13 @@ +use axum::{ + debug_middleware, + extract::Request, + response::{IntoResponse, Response}, +}; + +#[debug_middleware] +async fn my_middleware(request: Request) -> Response { + let _ = request; + ().into_response() +} + +fn main() {} diff --git a/axum-macros/tests/debug_middleware/fail/doesnt_take_next.stderr b/axum-macros/tests/debug_middleware/fail/doesnt_take_next.stderr new file mode 100644 index 0000000000..2474a4ebb4 --- /dev/null +++ b/axum-macros/tests/debug_middleware/fail/doesnt_take_next.stderr @@ -0,0 +1,7 @@ +error: Middleware functions must take `axum::middleware::Next` as the last argument + --> tests/debug_middleware/fail/doesnt_take_next.rs:7:1 + | +7 | #[debug_middleware] + | ^^^^^^^^^^^^^^^^^^^ + | + = note: this error originates in the attribute macro `debug_middleware` (in Nightly builds, run with -Z macro-backtrace for more info) diff --git a/axum-macros/tests/debug_middleware/fail/next_not_last.rs b/axum-macros/tests/debug_middleware/fail/next_not_last.rs new file mode 100644 index 0000000000..d2596ffb50 --- /dev/null +++ b/axum-macros/tests/debug_middleware/fail/next_not_last.rs @@ -0,0 +1,8 @@ +use axum::{debug_middleware, extract::Request, middleware::Next, response::Response}; + +#[debug_middleware] +async fn my_middleware(next: Next, request: Request) -> Response { + next.run(request).await +} + +fn main() {} diff --git a/axum-macros/tests/debug_middleware/fail/next_not_last.stderr b/axum-macros/tests/debug_middleware/fail/next_not_last.stderr new file mode 100644 index 0000000000..4a5fea4546 --- /dev/null +++ b/axum-macros/tests/debug_middleware/fail/next_not_last.stderr @@ -0,0 +1,5 @@ +error: `axum::middleware::Next` must the last argument + --> tests/debug_middleware/fail/next_not_last.rs:4:24 + | +4 | async fn my_middleware(next: Next, request: Request) -> Response { + | ^^^^^^^^^^ diff --git a/axum-macros/tests/debug_middleware/fail/takes_next_twice.rs b/axum-macros/tests/debug_middleware/fail/takes_next_twice.rs new file mode 100644 index 0000000000..995a97bda6 --- /dev/null +++ b/axum-macros/tests/debug_middleware/fail/takes_next_twice.rs @@ -0,0 +1,9 @@ +use axum::{debug_middleware, extract::Request, middleware::Next, response::Response}; + +#[debug_middleware] +async fn my_middleware(request: Request, next: Next, next2: Next) -> Response { + let _ = next2; + next.run(request).await +} + +fn main() {} diff --git a/axum-macros/tests/debug_middleware/fail/takes_next_twice.stderr b/axum-macros/tests/debug_middleware/fail/takes_next_twice.stderr new file mode 100644 index 0000000000..596f55817f --- /dev/null +++ b/axum-macros/tests/debug_middleware/fail/takes_next_twice.stderr @@ -0,0 +1,7 @@ +error: Middleware functions can only take one argument of type `axum::middleware::Next` + --> tests/debug_middleware/fail/takes_next_twice.rs:3:1 + | +3 | #[debug_middleware] + | ^^^^^^^^^^^^^^^^^^^ + | + = note: this error originates in the attribute macro `debug_middleware` (in Nightly builds, run with -Z macro-backtrace for more info) diff --git a/axum-macros/tests/debug_middleware/pass/basic.rs b/axum-macros/tests/debug_middleware/pass/basic.rs new file mode 100644 index 0000000000..1d2a412ac4 --- /dev/null +++ b/axum-macros/tests/debug_middleware/pass/basic.rs @@ -0,0 +1,8 @@ +use axum::{debug_middleware, extract::Request, middleware::Next, response::Response}; + +#[debug_middleware] +async fn my_middleware(request: Request, next: Next) -> Response { + next.run(request).await +} + +fn main() {} diff --git a/axum-macros/tests/from_ref/pass/basic.rs b/axum-macros/tests/from_ref/pass/basic.rs index e410e11a05..2b66d4064d 100644 --- a/axum-macros/tests/from_ref/pass/basic.rs +++ b/axum-macros/tests/from_ref/pass/basic.rs @@ -1,4 +1,8 @@ -use axum::{Router, routing::get, extract::{State, FromRef}}; +use axum::{ + extract::{FromRef, State}, + routing::get, + Router, +}; // This will implement `FromRef` for each field in the struct. #[derive(Clone, FromRef)] @@ -14,7 +18,5 @@ fn main() { auth_token: Default::default(), }; - let _: axum::Router = Router::new() - .route("/", get(handler)) - .with_state(state); + let _: axum::Router = Router::new().route("/", get(handler)).with_state(state); } diff --git a/axum-macros/tests/from_request/fail/enum_from_request_ident_in_variant.rs b/axum-macros/tests/from_request/fail/enum_from_request_ident_in_variant.rs index 336850e5a4..69942e4476 100644 --- a/axum-macros/tests/from_request/fail/enum_from_request_ident_in_variant.rs +++ b/axum-macros/tests/from_request/fail/enum_from_request_ident_in_variant.rs @@ -6,7 +6,7 @@ enum Extractor { Foo { #[from_request(via(axum::Extension))] foo: (), - } + }, } fn main() {} diff --git a/axum-macros/tests/from_request/fail/generic_without_via.rs b/axum-macros/tests/from_request/fail/generic_without_via.rs index c6b0668a91..f0d54acfa9 100644 --- a/axum-macros/tests/from_request/fail/generic_without_via.rs +++ b/axum-macros/tests/from_request/fail/generic_without_via.rs @@ -1,4 +1,3 @@ -#![feature(diagnostic_namespace)] use axum::{routing::get, Router}; use axum_macros::FromRequest; diff --git a/axum-macros/tests/from_request/fail/generic_without_via.stderr b/axum-macros/tests/from_request/fail/generic_without_via.stderr index 9620e64373..daabab098d 100644 --- a/axum-macros/tests/from_request/fail/generic_without_via.stderr +++ b/axum-macros/tests/from_request/fail/generic_without_via.stderr @@ -1,21 +1,21 @@ error: #[derive(FromRequest)] only supports generics when used with #[from_request(via)] - --> tests/from_request/fail/generic_without_via.rs:6:18 + --> tests/from_request/fail/generic_without_via.rs:5:18 | -6 | struct Extractor(T); +5 | struct Extractor(T); | ^ error[E0277]: the trait bound `fn(Extractor<()>) -> impl Future {foo}: Handler<_, _>` is not satisfied - --> tests/from_request/fail/generic_without_via.rs:11:44 + --> tests/from_request/fail/generic_without_via.rs:10:44 | -11 | _ = Router::<()>::new().route("/", get(foo)); +10 | _ = Router::<()>::new().route("/", get(foo)); | --- ^^^ the trait `Handler<_, _>` is not implemented for fn item `fn(Extractor<()>) -> impl Future {foo}` | | | required by a bound introduced by this call | = note: Consider using `#[axum::debug_handler]` to improve the error message = help: the following other types implement trait `Handler`: - as Handler> - as Handler<(), S>> + `Layered` implements `Handler` + `MethodRouter` implements `Handler<(), S>` note: required by a bound in `axum::routing::get` --> $WORKSPACE/axum/src/routing/method_routing.rs | diff --git a/axum-macros/tests/from_request/fail/generic_without_via_rejection.rs b/axum-macros/tests/from_request/fail/generic_without_via_rejection.rs index 9b74829348..b1ce072cb4 100644 --- a/axum-macros/tests/from_request/fail/generic_without_via_rejection.rs +++ b/axum-macros/tests/from_request/fail/generic_without_via_rejection.rs @@ -1,4 +1,3 @@ -#![feature(diagnostic_namespace)] use axum::{routing::get, Router}; use axum_macros::FromRequest; diff --git a/axum-macros/tests/from_request/fail/generic_without_via_rejection.stderr b/axum-macros/tests/from_request/fail/generic_without_via_rejection.stderr index ee6739af63..66f90281ca 100644 --- a/axum-macros/tests/from_request/fail/generic_without_via_rejection.stderr +++ b/axum-macros/tests/from_request/fail/generic_without_via_rejection.stderr @@ -1,21 +1,21 @@ error: #[derive(FromRequest)] only supports generics when used with #[from_request(via)] - --> tests/from_request/fail/generic_without_via_rejection.rs:7:18 + --> tests/from_request/fail/generic_without_via_rejection.rs:6:18 | -7 | struct Extractor(T); +6 | struct Extractor(T); | ^ error[E0277]: the trait bound `fn(Extractor<()>) -> impl Future {foo}: Handler<_, _>` is not satisfied - --> tests/from_request/fail/generic_without_via_rejection.rs:12:44 + --> tests/from_request/fail/generic_without_via_rejection.rs:11:44 | -12 | _ = Router::<()>::new().route("/", get(foo)); +11 | _ = Router::<()>::new().route("/", get(foo)); | --- ^^^ the trait `Handler<_, _>` is not implemented for fn item `fn(Extractor<()>) -> impl Future {foo}` | | | required by a bound introduced by this call | = note: Consider using `#[axum::debug_handler]` to improve the error message = help: the following other types implement trait `Handler`: - as Handler> - as Handler<(), S>> + `Layered` implements `Handler` + `MethodRouter` implements `Handler<(), S>` note: required by a bound in `axum::routing::get` --> $WORKSPACE/axum/src/routing/method_routing.rs | diff --git a/axum-macros/tests/from_request/fail/override_rejection_on_enum_without_via.rs b/axum-macros/tests/from_request/fail/override_rejection_on_enum_without_via.rs index 744d31aa70..e855d5f65c 100644 --- a/axum-macros/tests/from_request/fail/override_rejection_on_enum_without_via.rs +++ b/axum-macros/tests/from_request/fail/override_rejection_on_enum_without_via.rs @@ -1,4 +1,3 @@ -#![feature(diagnostic_namespace)] use axum::{ extract::rejection::ExtensionRejection, response::{IntoResponse, Response}, diff --git a/axum-macros/tests/from_request/fail/override_rejection_on_enum_without_via.stderr b/axum-macros/tests/from_request/fail/override_rejection_on_enum_without_via.stderr index 41252bc3e7..e70248f3a6 100644 --- a/axum-macros/tests/from_request/fail/override_rejection_on_enum_without_via.stderr +++ b/axum-macros/tests/from_request/fail/override_rejection_on_enum_without_via.stderr @@ -1,21 +1,21 @@ error: cannot use `rejection` without `via` - --> tests/from_request/fail/override_rejection_on_enum_without_via.rs:19:16 + --> tests/from_request/fail/override_rejection_on_enum_without_via.rs:18:16 | -19 | #[from_request(rejection(MyRejection))] +18 | #[from_request(rejection(MyRejection))] | ^^^^^^^^^ error[E0277]: the trait bound `fn(MyExtractor) -> impl Future {handler}: Handler<_, _>` is not satisfied - --> tests/from_request/fail/override_rejection_on_enum_without_via.rs:11:50 + --> tests/from_request/fail/override_rejection_on_enum_without_via.rs:10:50 | -11 | let _: Router = Router::new().route("/", get(handler).post(handler_result)); +10 | let _: Router = Router::new().route("/", get(handler).post(handler_result)); | --- ^^^^^^^ the trait `Handler<_, _>` is not implemented for fn item `fn(MyExtractor) -> impl Future {handler}` | | | required by a bound introduced by this call | = note: Consider using `#[axum::debug_handler]` to improve the error message = help: the following other types implement trait `Handler`: - as Handler> - as Handler<(), S>> + `Layered` implements `Handler` + `MethodRouter` implements `Handler<(), S>` note: required by a bound in `axum::routing::get` --> $WORKSPACE/axum/src/routing/method_routing.rs | @@ -27,17 +27,17 @@ note: required by a bound in `axum::routing::get` = note: this error originates in the macro `top_level_handler_fn` (in Nightly builds, run with -Z macro-backtrace for more info) error[E0277]: the trait bound `fn(Result) -> impl Future {handler_result}: Handler<_, _>` is not satisfied - --> tests/from_request/fail/override_rejection_on_enum_without_via.rs:11:64 + --> tests/from_request/fail/override_rejection_on_enum_without_via.rs:10:64 | -11 | let _: Router = Router::new().route("/", get(handler).post(handler_result)); +10 | let _: Router = Router::new().route("/", get(handler).post(handler_result)); | ---- ^^^^^^^^^^^^^^ the trait `Handler<_, _>` is not implemented for fn item `fn(Result) -> impl Future {handler_result}` | | | required by a bound introduced by this call | = note: Consider using `#[axum::debug_handler]` to improve the error message = help: the following other types implement trait `Handler`: - as Handler> - as Handler<(), S>> + `Layered` implements `Handler` + `MethodRouter` implements `Handler<(), S>` note: required by a bound in `MethodRouter::::post` --> $WORKSPACE/axum/src/routing/method_routing.rs | diff --git a/axum-macros/tests/from_request/fail/parts_extracting_body.rs b/axum-macros/tests/from_request/fail/parts_extracting_body.rs index 753d92a981..45a93777a9 100644 --- a/axum-macros/tests/from_request/fail/parts_extracting_body.rs +++ b/axum-macros/tests/from_request/fail/parts_extracting_body.rs @@ -1,4 +1,3 @@ -#![feature(diagnostic_namespace)] use axum::{extract::FromRequestParts, response::Response}; #[derive(FromRequestParts)] diff --git a/axum-macros/tests/from_request/fail/parts_extracting_body.stderr b/axum-macros/tests/from_request/fail/parts_extracting_body.stderr index fbd58ea013..d2401803dd 100644 --- a/axum-macros/tests/from_request/fail/parts_extracting_body.stderr +++ b/axum-macros/tests/from_request/fail/parts_extracting_body.stderr @@ -1,18 +1,18 @@ -error[E0277]: the trait bound `String: FromRequestParts` is not satisfied - --> tests/from_request/fail/parts_extracting_body.rs:6:11 +error[E0277]: the trait bound `String: FromRequestParts<_>` is not satisfied + --> tests/from_request/fail/parts_extracting_body.rs:5:11 | -6 | body: String, - | ^^^^^^ the trait `FromRequestParts` is not implemented for `String` +5 | body: String, + | ^^^^^^ the trait `FromRequestParts<_>` is not implemented for `String` | = note: Function argument is not a valid axum extractor. See `https://docs.rs/axum/0.7/axum/extract/index.html` for details = help: the following other types implement trait `FromRequestParts`: - > - > - as FromRequestParts> - > - > - > - > - > + `()` implements `FromRequestParts` + `(T1, T2)` implements `FromRequestParts` + `(T1, T2, T3)` implements `FromRequestParts` + `(T1, T2, T3, T4)` implements `FromRequestParts` + `(T1, T2, T3, T4, T5)` implements `FromRequestParts` + `(T1, T2, T3, T4, T5, T6)` implements `FromRequestParts` + `(T1, T2, T3, T4, T5, T6, T7)` implements `FromRequestParts` + `(T1, T2, T3, T4, T5, T6, T7, T8)` implements `FromRequestParts` and $N others diff --git a/axum-macros/tests/from_request/fail/state_infer_multiple_different_types.rs b/axum-macros/tests/from_request/fail/state_infer_multiple_different_types.rs index 6533d3276a..18c0698f9f 100644 --- a/axum-macros/tests/from_request/fail/state_infer_multiple_different_types.rs +++ b/axum-macros/tests/from_request/fail/state_infer_multiple_different_types.rs @@ -1,5 +1,5 @@ -use axum_macros::FromRequest; use axum::extract::State; +use axum_macros::FromRequest; #[derive(FromRequest)] struct Extractor { diff --git a/axum-macros/tests/from_request/pass/container_parts.rs b/axum-macros/tests/from_request/pass/container_parts.rs index dedc1719a7..c90703d0fc 100644 --- a/axum-macros/tests/from_request/pass/container_parts.rs +++ b/axum-macros/tests/from_request/pass/container_parts.rs @@ -1,5 +1,5 @@ use axum::{ - extract::{FromRequestParts, Extension}, + extract::{Extension, FromRequestParts}, response::Response, }; diff --git a/axum-macros/tests/from_request/pass/named.rs b/axum-macros/tests/from_request/pass/named.rs index f63ae8e9db..d396847cce 100644 --- a/axum-macros/tests/from_request/pass/named.rs +++ b/axum-macros/tests/from_request/pass/named.rs @@ -1,11 +1,8 @@ -use axum::{ - extract::FromRequest, - response::Response, -}; +use axum::{extract::FromRequest, response::Response}; use axum_extra::{ - TypedHeader, - typed_header::TypedHeaderRejection, headers::{self, UserAgent}, + typed_header::TypedHeaderRejection, + TypedHeader, }; #[derive(FromRequest)] diff --git a/axum-macros/tests/from_request/pass/named_parts.rs b/axum-macros/tests/from_request/pass/named_parts.rs index cbb67e61da..1168b3a119 100644 --- a/axum-macros/tests/from_request/pass/named_parts.rs +++ b/axum-macros/tests/from_request/pass/named_parts.rs @@ -1,11 +1,8 @@ -use axum::{ - extract::FromRequestParts, - response::Response, -}; +use axum::{extract::FromRequestParts, response::Response}; use axum_extra::{ - TypedHeader, - typed_header::TypedHeaderRejection, headers::{self, UserAgent}, + typed_header::TypedHeaderRejection, + TypedHeader, }; #[derive(FromRequestParts)] diff --git a/axum-macros/tests/from_request/pass/named_via.rs b/axum-macros/tests/from_request/pass/named_via.rs index 691627b08d..5159c49b88 100644 --- a/axum-macros/tests/from_request/pass/named_via.rs +++ b/axum-macros/tests/from_request/pass/named_via.rs @@ -1,11 +1,11 @@ use axum::{ - response::Response, extract::{Extension, FromRequest}, + response::Response, }; use axum_extra::{ - TypedHeader, - typed_header::TypedHeaderRejection, headers::{self, UserAgent}, + typed_header::TypedHeaderRejection, + TypedHeader, }; #[derive(FromRequest)] diff --git a/axum-macros/tests/from_request/pass/named_via_parts.rs b/axum-macros/tests/from_request/pass/named_via_parts.rs index 0377af7b10..38fe0964c0 100644 --- a/axum-macros/tests/from_request/pass/named_via_parts.rs +++ b/axum-macros/tests/from_request/pass/named_via_parts.rs @@ -1,11 +1,11 @@ use axum::{ - response::Response, extract::{Extension, FromRequestParts}, + response::Response, }; use axum_extra::{ - TypedHeader, - typed_header::TypedHeaderRejection, headers::{self, UserAgent}, + typed_header::TypedHeaderRejection, + TypedHeader, }; #[derive(FromRequestParts)] diff --git a/axum-macros/tests/from_request/pass/override_rejection.rs b/axum-macros/tests/from_request/pass/override_rejection.rs index 779058b9fc..ee9a4540d4 100644 --- a/axum-macros/tests/from_request/pass/override_rejection.rs +++ b/axum-macros/tests/from_request/pass/override_rejection.rs @@ -1,10 +1,8 @@ use axum::{ - async_trait, - extract::{Request, rejection::ExtensionRejection, FromRequest}, + extract::{rejection::ExtensionRejection, FromRequest, Request}, http::StatusCode, response::{IntoResponse, Response}, routing::get, - body::Body, Extension, Router, }; @@ -27,7 +25,6 @@ struct MyExtractor { struct OtherExtractor; -#[async_trait] impl FromRequest for OtherExtractor where S: Send + Sync, diff --git a/axum-macros/tests/from_request/pass/override_rejection_non_generic.rs b/axum-macros/tests/from_request/pass/override_rejection_non_generic.rs index 6c4d87fe01..b88b483e3b 100644 --- a/axum-macros/tests/from_request/pass/override_rejection_non_generic.rs +++ b/axum-macros/tests/from_request/pass/override_rejection_non_generic.rs @@ -5,8 +5,8 @@ use axum::{ Router, }; use axum_macros::FromRequest; -use std::collections::HashMap; use serde::Deserialize; +use std::collections::HashMap; fn main() { let _: Router = Router::new().route("/", get(handler).post(handler_result)); @@ -17,10 +17,7 @@ async fn handler(_: MyJson) {} async fn handler_result(_: Result) {} #[derive(FromRequest, Deserialize)] -#[from_request( - via(axum::extract::Json), - rejection(MyJsonRejection), -)] +#[from_request(via(axum::extract::Json), rejection(MyJsonRejection))] #[serde(transparent)] struct MyJson(HashMap); diff --git a/axum-macros/tests/from_request/pass/override_rejection_non_generic_parts.rs b/axum-macros/tests/from_request/pass/override_rejection_non_generic_parts.rs index 9aca7345dd..bcf213d1e2 100644 --- a/axum-macros/tests/from_request/pass/override_rejection_non_generic_parts.rs +++ b/axum-macros/tests/from_request/pass/override_rejection_non_generic_parts.rs @@ -5,8 +5,8 @@ use axum::{ Router, }; use axum_macros::FromRequestParts; -use std::collections::HashMap; use serde::Deserialize; +use std::collections::HashMap; fn main() { let _: Router = Router::new().route("/", get(handler).post(handler_result)); @@ -17,10 +17,7 @@ async fn handler(_: MyQuery) {} async fn handler_result(_: Result) {} #[derive(FromRequestParts, Deserialize)] -#[from_request( - via(axum::extract::Query), - rejection(MyQueryRejection), -)] +#[from_request(via(axum::extract::Query), rejection(MyQueryRejection))] #[serde(transparent)] struct MyQuery(HashMap); diff --git a/axum-macros/tests/from_request/pass/override_rejection_parts.rs b/axum-macros/tests/from_request/pass/override_rejection_parts.rs index 8ef9cb22db..7cc27de24c 100644 --- a/axum-macros/tests/from_request/pass/override_rejection_parts.rs +++ b/axum-macros/tests/from_request/pass/override_rejection_parts.rs @@ -1,5 +1,4 @@ use axum::{ - async_trait, extract::{rejection::ExtensionRejection, FromRequestParts}, http::{request::Parts, StatusCode}, response::{IntoResponse, Response}, @@ -26,7 +25,6 @@ struct MyExtractor { struct OtherExtractor; -#[async_trait] impl FromRequestParts for OtherExtractor where S: Send + Sync, diff --git a/axum-macros/tests/from_request/pass/override_rejection_with_via_on_struct.rs b/axum-macros/tests/from_request/pass/override_rejection_with_via_on_struct.rs index 2a046fae50..8dab4165e4 100644 --- a/axum-macros/tests/from_request/pass/override_rejection_with_via_on_struct.rs +++ b/axum-macros/tests/from_request/pass/override_rejection_with_via_on_struct.rs @@ -19,10 +19,7 @@ async fn handler(_: MyJson) {} async fn handler_result(_: Result, MyJsonRejection>) {} #[derive(FromRequest)] -#[from_request( - via(axum::Json), - rejection(MyJsonRejection), -)] +#[from_request(via(axum::Json), rejection(MyJsonRejection))] struct MyJson(T); struct MyJsonRejection {} diff --git a/axum-macros/tests/from_request/pass/override_rejection_with_via_on_struct_parts.rs b/axum-macros/tests/from_request/pass/override_rejection_with_via_on_struct_parts.rs index eaeeeacf23..19e7307678 100644 --- a/axum-macros/tests/from_request/pass/override_rejection_with_via_on_struct_parts.rs +++ b/axum-macros/tests/from_request/pass/override_rejection_with_via_on_struct_parts.rs @@ -19,10 +19,7 @@ async fn handler(_: MyQuery) {} async fn handler_result(_: Result, MyQueryRejection>) {} #[derive(FromRequestParts)] -#[from_request( - via(axum::extract::Query), - rejection(MyQueryRejection), -)] +#[from_request(via(axum::extract::Query), rejection(MyQueryRejection))] struct MyQuery(T); struct MyQueryRejection {} diff --git a/axum-macros/tests/from_request/pass/state_cookie.rs b/axum-macros/tests/from_request/pass/state_cookie.rs index 6e2aa1f4ed..ce935d67fa 100644 --- a/axum-macros/tests/from_request/pass/state_cookie.rs +++ b/axum-macros/tests/from_request/pass/state_cookie.rs @@ -1,6 +1,6 @@ -use axum_macros::FromRequest; use axum::extract::FromRef; -use axum_extra::extract::cookie::{PrivateCookieJar, Key}; +use axum_extra::extract::cookie::{Key, PrivateCookieJar}; +use axum_macros::FromRequest; #[derive(FromRequest)] #[from_request(state(AppState))] diff --git a/axum-macros/tests/from_request/pass/state_explicit.rs b/axum-macros/tests/from_request/pass/state_explicit.rs index aed9dad6d5..df32a72568 100644 --- a/axum-macros/tests/from_request/pass/state_explicit.rs +++ b/axum-macros/tests/from_request/pass/state_explicit.rs @@ -1,9 +1,9 @@ -use axum_macros::FromRequest; use axum::{ extract::{FromRef, State}, - Router, routing::get, + Router, }; +use axum_macros::FromRequest; fn main() { let _: axum::Router = Router::new() diff --git a/axum-macros/tests/from_request/pass/state_explicit_parts.rs b/axum-macros/tests/from_request/pass/state_explicit_parts.rs index 94f37cf6b8..2822699379 100644 --- a/axum-macros/tests/from_request/pass/state_explicit_parts.rs +++ b/axum-macros/tests/from_request/pass/state_explicit_parts.rs @@ -1,9 +1,9 @@ -use axum_macros::FromRequestParts; use axum::{ - extract::{FromRef, State, Query}, - Router, + extract::{FromRef, Query, State}, routing::get, + Router, }; +use axum_macros::FromRequestParts; use std::collections::HashMap; fn main() { diff --git a/axum-macros/tests/from_request/pass/state_field_explicit.rs b/axum-macros/tests/from_request/pass/state_field_explicit.rs index b6d003dc00..90a903261e 100644 --- a/axum-macros/tests/from_request/pass/state_field_explicit.rs +++ b/axum-macros/tests/from_request/pass/state_field_explicit.rs @@ -1,5 +1,5 @@ use axum::{ - extract::{State, FromRef}, + extract::{FromRef, State}, routing::get, Router, }; diff --git a/axum-macros/tests/from_request/pass/state_field_infer.rs b/axum-macros/tests/from_request/pass/state_field_infer.rs index a24861a162..5e399c1bf0 100644 --- a/axum-macros/tests/from_request/pass/state_field_infer.rs +++ b/axum-macros/tests/from_request/pass/state_field_infer.rs @@ -1,8 +1,4 @@ -use axum::{ - extract::State, - routing::get, - Router, -}; +use axum::{extract::State, routing::get, Router}; use axum_macros::FromRequest; fn main() { diff --git a/axum-macros/tests/from_request/pass/state_infer.rs b/axum-macros/tests/from_request/pass/state_infer.rs index 07545ab074..39966b7c9b 100644 --- a/axum-macros/tests/from_request/pass/state_infer.rs +++ b/axum-macros/tests/from_request/pass/state_infer.rs @@ -1,5 +1,5 @@ -use axum_macros::FromRequest; use axum::extract::State; +use axum_macros::FromRequest; #[derive(FromRequest)] struct Extractor { diff --git a/axum-macros/tests/from_request/pass/state_infer_multiple.rs b/axum-macros/tests/from_request/pass/state_infer_multiple.rs index cb8de1d59c..d727a7e337 100644 --- a/axum-macros/tests/from_request/pass/state_infer_multiple.rs +++ b/axum-macros/tests/from_request/pass/state_infer_multiple.rs @@ -1,5 +1,5 @@ -use axum_macros::FromRequest; use axum::extract::State; +use axum_macros::FromRequest; #[derive(FromRequest)] struct Extractor { diff --git a/axum-macros/tests/from_request/pass/state_infer_parts.rs b/axum-macros/tests/from_request/pass/state_infer_parts.rs index f3f078c5f3..35ffcc31f0 100644 --- a/axum-macros/tests/from_request/pass/state_infer_parts.rs +++ b/axum-macros/tests/from_request/pass/state_infer_parts.rs @@ -1,5 +1,5 @@ -use axum_macros::FromRequestParts; use axum::extract::State; +use axum_macros::FromRequestParts; #[derive(FromRequestParts)] struct Extractor { diff --git a/axum-macros/tests/from_request/pass/state_via_infer.rs b/axum-macros/tests/from_request/pass/state_via_infer.rs index 40c52d8d4d..685ff1c7fe 100644 --- a/axum-macros/tests/from_request/pass/state_via_infer.rs +++ b/axum-macros/tests/from_request/pass/state_via_infer.rs @@ -1,8 +1,4 @@ -use axum::{ - extract::State, - routing::get, - Router, -}; +use axum::{extract::State, routing::get, Router}; use axum_macros::FromRequest; fn main() { diff --git a/axum-macros/tests/from_request/pass/state_with_rejection.rs b/axum-macros/tests/from_request/pass/state_with_rejection.rs index 9921add02b..8e730c961e 100644 --- a/axum-macros/tests/from_request/pass/state_with_rejection.rs +++ b/axum-macros/tests/from_request/pass/state_with_rejection.rs @@ -1,4 +1,3 @@ -use std::convert::Infallible; use axum::{ extract::State, response::{IntoResponse, Response}, @@ -6,6 +5,7 @@ use axum::{ Router, }; use axum_macros::FromRequest; +use std::convert::Infallible; fn main() { let _: axum::Router = Router::new() diff --git a/axum-macros/tests/from_request/pass/tuple_same_type_twice.rs b/axum-macros/tests/from_request/pass/tuple_same_type_twice.rs index 343563ddb6..56c8fd0f12 100644 --- a/axum-macros/tests/from_request/pass/tuple_same_type_twice.rs +++ b/axum-macros/tests/from_request/pass/tuple_same_type_twice.rs @@ -3,10 +3,7 @@ use axum_macros::FromRequest; use serde::Deserialize; #[derive(FromRequest)] -struct Extractor( - Query, - axum::extract::Json, -); +struct Extractor(Query, axum::extract::Json); #[derive(Deserialize)] struct Payload {} diff --git a/axum-macros/tests/from_request/pass/tuple_same_type_twice_parts.rs b/axum-macros/tests/from_request/pass/tuple_same_type_twice_parts.rs index 44c42dc5ab..a781baf605 100644 --- a/axum-macros/tests/from_request/pass/tuple_same_type_twice_parts.rs +++ b/axum-macros/tests/from_request/pass/tuple_same_type_twice_parts.rs @@ -3,10 +3,7 @@ use axum_macros::FromRequestParts; use serde::Deserialize; #[derive(FromRequestParts)] -struct Extractor( - Query, - axum::extract::Path, -); +struct Extractor(Query, axum::extract::Path); #[derive(Deserialize)] struct Payload {} diff --git a/axum-macros/tests/typed_path/fail/missing_field.rs b/axum-macros/tests/typed_path/fail/missing_field.rs index 2e211769ab..7f1f3be06d 100644 --- a/axum-macros/tests/typed_path/fail/missing_field.rs +++ b/axum-macros/tests/typed_path/fail/missing_field.rs @@ -2,8 +2,7 @@ use axum_macros::TypedPath; use serde::Deserialize; #[derive(TypedPath, Deserialize)] -#[typed_path("/users/:id")] +#[typed_path("/users/{id}")] struct MyPath {} -fn main() { -} +fn main() {} diff --git a/axum-macros/tests/typed_path/fail/missing_field.stderr b/axum-macros/tests/typed_path/fail/missing_field.stderr index faf2d4b681..2a85e74938 100644 --- a/axum-macros/tests/typed_path/fail/missing_field.stderr +++ b/axum-macros/tests/typed_path/fail/missing_field.stderr @@ -1,5 +1,5 @@ error[E0026]: struct `MyPath` does not have a field named `id` --> tests/typed_path/fail/missing_field.rs:5:14 | -5 | #[typed_path("/users/:id")] - | ^^^^^^^^^^^^ struct `MyPath` does not have this field +5 | #[typed_path("/users/{id}")] + | ^^^^^^^^^^^^^ struct `MyPath` does not have this field diff --git a/axum-macros/tests/typed_path/fail/not_deserialize.rs b/axum-macros/tests/typed_path/fail/not_deserialize.rs index b569186651..1d99e8f2aa 100644 --- a/axum-macros/tests/typed_path/fail/not_deserialize.rs +++ b/axum-macros/tests/typed_path/fail/not_deserialize.rs @@ -1,7 +1,7 @@ use axum_macros::TypedPath; #[derive(TypedPath)] -#[typed_path("/users/:id")] +#[typed_path("/users/{id}")] struct MyPath { id: u32, } diff --git a/axum-macros/tests/typed_path/fail/not_deserialize.stderr b/axum-macros/tests/typed_path/fail/not_deserialize.stderr index 3ae15fbef0..ed2c9d7571 100644 --- a/axum-macros/tests/typed_path/fail/not_deserialize.stderr +++ b/axum-macros/tests/typed_path/fail/not_deserialize.stderr @@ -2,17 +2,57 @@ error[E0277]: the trait bound `for<'de> MyPath: serde::de::Deserialize<'de>` is --> tests/typed_path/fail/not_deserialize.rs:3:10 | 3 | #[derive(TypedPath)] - | ^^^^^^^^^ the trait `for<'de> serde::de::Deserialize<'de>` is not implemented for `MyPath` + | ^^^^^^^^^ the trait `for<'de> serde::de::Deserialize<'de>` is not implemented for `MyPath`, which is required by `axum::extract::Path: FromRequestParts` | + = note: for local types consider adding `#[derive(serde::Deserialize)]` to your `MyPath` type + = note: for types from other crates check whether the crate offers a `serde` feature flag = help: the following other types implement trait `serde::de::Deserialize<'de>`: - bool - char - isize - i8 - i16 - i32 - i64 - i128 + &'a [u8] + &'a serde_json::raw::RawValue + &'a std::path::Path + &'a str + () + (T,) + (T0, T1) + (T0, T1, T2) + and $N others + = note: required for `MyPath` to implement `serde::de::DeserializeOwned` + = note: required for `axum::extract::Path` to implement `FromRequestParts` + +error[E0277]: the trait bound `MyPath: serde::de::DeserializeOwned` is not satisfied + --> tests/typed_path/fail/not_deserialize.rs:3:10 + | +3 | #[derive(TypedPath)] + | ^^^^^^^^^ the trait `for<'de> serde::de::Deserialize<'de>` is not implemented for `MyPath`, which is required by `axum::extract::Path: FromRequestParts` + | + = help: the following other types implement trait `serde::de::Deserialize<'de>`: + &'a [u8] + &'a serde_json::raw::RawValue + &'a std::path::Path + &'a str + () + (T,) + (T0, T1) + (T0, T1, T2) + and $N others + = note: required for `MyPath` to implement `serde::de::DeserializeOwned` + = note: required for `axum::extract::Path` to implement `FromRequestParts` + +error[E0277]: the trait bound `MyPath: serde::de::DeserializeOwned` is not satisfied + --> tests/typed_path/fail/not_deserialize.rs:3:10 + | +3 | #[derive(TypedPath)] + | ^^^^^^^^^ the trait `for<'de> serde::de::Deserialize<'de>` is not implemented for `MyPath`, which is required by `axum::extract::Path: FromRequestParts` + | + = help: the following other types implement trait `serde::de::Deserialize<'de>`: + &'a [u8] + &'a serde_json::raw::RawValue + &'a std::path::Path + &'a str + () + (T,) + (T0, T1) + (T0, T1, T2) and $N others = note: required for `MyPath` to implement `serde::de::DeserializeOwned` = note: required for `axum::extract::Path` to implement `FromRequestParts` diff --git a/axum-macros/tests/typed_path/fail/route_not_starting_with_slash_non_empty.rs b/axum-macros/tests/typed_path/fail/route_not_starting_with_slash_non_empty.rs index 33ae38d699..9d45b99964 100644 --- a/axum-macros/tests/typed_path/fail/route_not_starting_with_slash_non_empty.rs +++ b/axum-macros/tests/typed_path/fail/route_not_starting_with_slash_non_empty.rs @@ -1,7 +1,7 @@ use axum_extra::routing::TypedPath; #[derive(TypedPath)] -#[typed_path(":foo")] +#[typed_path("{foo}")] struct MyPath; fn main() {} diff --git a/axum-macros/tests/typed_path/fail/route_not_starting_with_slash_non_empty.stderr b/axum-macros/tests/typed_path/fail/route_not_starting_with_slash_non_empty.stderr index db8e40f024..f1b7b2caf3 100644 --- a/axum-macros/tests/typed_path/fail/route_not_starting_with_slash_non_empty.stderr +++ b/axum-macros/tests/typed_path/fail/route_not_starting_with_slash_non_empty.stderr @@ -1,5 +1,5 @@ error: paths must start with a `/` --> tests/typed_path/fail/route_not_starting_with_slash_non_empty.rs:4:14 | -4 | #[typed_path(":foo")] - | ^^^^^^ +4 | #[typed_path("{foo}")] + | ^^^^^^^ diff --git a/axum-macros/tests/typed_path/fail/unit_with_capture.rs b/axum-macros/tests/typed_path/fail/unit_with_capture.rs index 49979cf725..ddd544f658 100644 --- a/axum-macros/tests/typed_path/fail/unit_with_capture.rs +++ b/axum-macros/tests/typed_path/fail/unit_with_capture.rs @@ -2,7 +2,7 @@ use axum_macros::TypedPath; use serde::Deserialize; #[derive(TypedPath, Deserialize)] -#[typed_path("/users/:id")] +#[typed_path("/users/{id}")] struct MyPath; fn main() {} diff --git a/axum-macros/tests/typed_path/fail/unit_with_capture.stderr b/axum-macros/tests/typed_path/fail/unit_with_capture.stderr index d290308c8e..058ca6f974 100644 --- a/axum-macros/tests/typed_path/fail/unit_with_capture.stderr +++ b/axum-macros/tests/typed_path/fail/unit_with_capture.stderr @@ -1,5 +1,5 @@ error: Typed paths for unit structs cannot contain captures --> tests/typed_path/fail/unit_with_capture.rs:5:14 | -5 | #[typed_path("/users/:id")] - | ^^^^^^^^^^^^ +5 | #[typed_path("/users/{id}")] + | ^^^^^^^^^^^^^ diff --git a/axum-macros/tests/typed_path/pass/customize_rejection.rs b/axum-macros/tests/typed_path/pass/customize_rejection.rs index 01f11fc94c..080bc3f2d3 100644 --- a/axum-macros/tests/typed_path/pass/customize_rejection.rs +++ b/axum-macros/tests/typed_path/pass/customize_rejection.rs @@ -6,7 +6,7 @@ use axum_extra::routing::{RouterExt, TypedPath}; use serde::Deserialize; #[derive(TypedPath, Deserialize)] -#[typed_path("/:foo", rejection(MyRejection))] +#[typed_path("/{foo}", rejection(MyRejection))] struct MyPathNamed { foo: String, } @@ -16,7 +16,7 @@ struct MyPathNamed { struct MyPathUnit; #[derive(TypedPath, Deserialize)] -#[typed_path("/:foo", rejection(MyRejection))] +#[typed_path("/{foo}", rejection(MyRejection))] struct MyPathUnnamed(String); struct MyRejection; diff --git a/axum-macros/tests/typed_path/pass/into_uri.rs b/axum-macros/tests/typed_path/pass/into_uri.rs index 5276627c2f..20b01c1d57 100644 --- a/axum-macros/tests/typed_path/pass/into_uri.rs +++ b/axum-macros/tests/typed_path/pass/into_uri.rs @@ -1,15 +1,15 @@ -use axum_extra::routing::TypedPath; use axum::http::Uri; +use axum_extra::routing::TypedPath; use serde::Deserialize; #[derive(TypedPath, Deserialize)] -#[typed_path("/:id")] +#[typed_path("/{id}")] struct Named { id: u32, } #[derive(TypedPath, Deserialize)] -#[typed_path("/:id")] +#[typed_path("/{id}")] struct Unnamed(u32); #[derive(TypedPath, Deserialize)] diff --git a/axum-macros/tests/typed_path/pass/named_fields_struct.rs b/axum-macros/tests/typed_path/pass/named_fields_struct.rs index 042936fe02..5decd89c89 100644 --- a/axum-macros/tests/typed_path/pass/named_fields_struct.rs +++ b/axum-macros/tests/typed_path/pass/named_fields_struct.rs @@ -2,7 +2,7 @@ use axum_extra::routing::TypedPath; use serde::Deserialize; #[derive(TypedPath, Deserialize)] -#[typed_path("/users/:user_id/teams/:team_id")] +#[typed_path("/users/{user_id}/teams/{team_id}")] struct MyPath { user_id: u32, team_id: u32, @@ -11,7 +11,7 @@ struct MyPath { fn main() { _ = axum::Router::<()>::new().route("/", axum::routing::get(|_: MyPath| async {})); - assert_eq!(MyPath::PATH, "/users/:user_id/teams/:team_id"); + assert_eq!(MyPath::PATH, "/users/{user_id}/teams/{team_id}"); assert_eq!( format!( "{}", diff --git a/axum-macros/tests/typed_path/pass/option_result.rs b/axum-macros/tests/typed_path/pass/option_result.rs index 1bd2359010..81cfb29482 100644 --- a/axum-macros/tests/typed_path/pass/option_result.rs +++ b/axum-macros/tests/typed_path/pass/option_result.rs @@ -1,9 +1,9 @@ -use axum_extra::routing::{TypedPath, RouterExt}; use axum::{extract::rejection::PathRejection, http::StatusCode}; +use axum_extra::routing::{RouterExt, TypedPath}; use serde::Deserialize; #[derive(TypedPath, Deserialize)] -#[typed_path("/users/:id")] +#[typed_path("/users/{id}")] struct UsersShow { id: String, } diff --git a/axum-macros/tests/typed_path/pass/tuple_struct.rs b/axum-macros/tests/typed_path/pass/tuple_struct.rs index 3ee8370402..0c85bae5ec 100644 --- a/axum-macros/tests/typed_path/pass/tuple_struct.rs +++ b/axum-macros/tests/typed_path/pass/tuple_struct.rs @@ -4,12 +4,12 @@ use serde::Deserialize; pub type Result = std::result::Result; #[derive(TypedPath, Deserialize)] -#[typed_path("/users/:user_id/teams/:team_id")] +#[typed_path("/users/{user_id}/teams/{team_id}")] struct MyPath(u32, u32); fn main() { _ = axum::Router::<()>::new().route("/", axum::routing::get(|_: MyPath| async {})); - assert_eq!(MyPath::PATH, "/users/:user_id/teams/:team_id"); + assert_eq!(MyPath::PATH, "/users/{user_id}/teams/{team_id}"); assert_eq!(format!("{}", MyPath(1, 2)), "/users/1/teams/2"); } diff --git a/axum-macros/tests/typed_path/pass/unit_struct.rs b/axum-macros/tests/typed_path/pass/unit_struct.rs index f3bb164075..832e1001d3 100644 --- a/axum-macros/tests/typed_path/pass/unit_struct.rs +++ b/axum-macros/tests/typed_path/pass/unit_struct.rs @@ -5,8 +5,7 @@ use axum_extra::routing::TypedPath; struct MyPath; fn main() { - _ = axum::Router::<()>::new() - .route("/", axum::routing::get(|_: MyPath| async {})); + _ = axum::Router::<()>::new().route("/", axum::routing::get(|_: MyPath| async {})); assert_eq!(MyPath::PATH, "/users"); assert_eq!(format!("{}", MyPath), "/users"); diff --git a/axum-macros/tests/typed_path/pass/url_encoding.rs b/axum-macros/tests/typed_path/pass/url_encoding.rs index db1c3700ab..5e773d698e 100644 --- a/axum-macros/tests/typed_path/pass/url_encoding.rs +++ b/axum-macros/tests/typed_path/pass/url_encoding.rs @@ -2,13 +2,13 @@ use axum_extra::routing::TypedPath; use serde::Deserialize; #[derive(TypedPath, Deserialize)] -#[typed_path("/:param")] +#[typed_path("/{param}")] struct Named { param: String, } #[derive(TypedPath, Deserialize)] -#[typed_path("/:param")] +#[typed_path("/{param}")] struct Unnamed(String); fn main() { @@ -22,11 +22,5 @@ fn main() { "/a%20b" ); - assert_eq!( - format!( - "{}", - Unnamed("a b".to_string()), - ), - "/a%20b" - ); + assert_eq!(format!("{}", Unnamed("a b".to_string()),), "/a%20b"); } diff --git a/axum-macros/tests/typed_path/pass/wildcards.rs b/axum-macros/tests/typed_path/pass/wildcards.rs index 98aa5f5153..51f0c3f540 100644 --- a/axum-macros/tests/typed_path/pass/wildcards.rs +++ b/axum-macros/tests/typed_path/pass/wildcards.rs @@ -2,7 +2,7 @@ use axum_extra::routing::{RouterExt, TypedPath}; use serde::Deserialize; #[derive(TypedPath, Deserialize)] -#[typed_path("/*rest")] +#[typed_path("/{*rest}")] struct MyPath { rest: String, } diff --git a/axum/CHANGELOG.md b/axum/CHANGELOG.md index 6d52cd6c4e..ce98bca697 100644 --- a/axum/CHANGELOG.md +++ b/axum/CHANGELOG.md @@ -7,7 +7,112 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 # Unreleased -- None. +- **fixed:** Skip SSE incompatible chars of `serde_json::RawValue` in `Event::json_data` ([#2992]) +- **fixed:** Don't panic when array type is used for path segment ([#3039]) +- **breaking:** Move `Host` extractor to `axum-extra` ([#2956]) +- **added:** Add `method_not_allowed_fallback` to set a fallback when a path matches but there is no handler for the given HTTP method ([#2903]) +- **added:** Add `NoContent` as a self-described shortcut for `StatusCode::NO_CONTENT` ([#2978]) +- **added:** Add support for WebSockets over HTTP/2. + They can be enabled by changing `get(ws_endpoint)` handlers to `any(ws_endpoint)` ([#2894]) +- **added:** Add `MethodFilter::CONNECT`, `routing::connect[_service]` + and `MethodRouter::connect[_service]` ([#2961]) +- **fixed:** Avoid setting `content-length` before middleware. + This allows middleware to add bodies to requests without needing to manually set `content-length` ([#2897]) +- **breaking:** Remove `WebSocket::close`. + Users should explicitly send close messages themselves. ([#2974]) +- **added:** Extend `FailedToDeserializePathParams::kind` enum with (`ErrorKind::DeserializeError`) + This new variant captures both `key`, `value`, and `message` from named path parameters parse errors, + instead of only deserialization error message in `ErrorKind::Message`. ([#2720]) + +[#2897]: https://github.com/tokio-rs/axum/pull/2897 +[#2903]: https://github.com/tokio-rs/axum/pull/2903 +[#2894]: https://github.com/tokio-rs/axum/pull/2894 +[#2956]: https://github.com/tokio-rs/axum/pull/2956 +[#2961]: https://github.com/tokio-rs/axum/pull/2961 +[#2974]: https://github.com/tokio-rs/axum/pull/2974 +[#2978]: https://github.com/tokio-rs/axum/pull/2978 +[#2992]: https://github.com/tokio-rs/axum/pull/2992 +[#2720]: https://github.com/tokio-rs/axum/pull/2720 +[#3039]: https://github.com/tokio-rs/axum/pull/3039 + +# 0.8.0 + +## alpha.1 + +- **breaking:** Require `Sync` for all handlers and services added to `Router` + and `MethodRouter` ([#2473]) +- **breaking:** The tuple and tuple_struct `Path` extractor deserializers now check that the number of parameters matches the tuple length exactly ([#2931]) +- **breaking:** Upgrade matchit to 0.8, changing the path parameter syntax from `/:single` and `/*many` + to `/{single}` and `/{*many}`; the old syntax produces a panic to avoid silent change in behavior ([#2645]) +- **change:** Update minimum rust version to 1.75 ([#2943]) + +[#2473]: https://github.com/tokio-rs/axum/pull/2473 +[#2645]: https://github.com/tokio-rs/axum/pull/2645 +[#2931]: https://github.com/tokio-rs/axum/pull/2931 +[#2943]: https://github.com/tokio-rs/axum/pull/2943 + +# 0.7.9 + +- **fixed:** Avoid setting content-length before middleware ([#3031]) + +[#3031]:https://github.com/tokio-rs/axum/pull/3031 + +# 0.7.8 + +- **fixed:** Skip SSE incompatible chars of `serde_json::RawValue` in `Event::json_data` ([#2992]) +- **added:** Add `method_not_allowed_fallback` to set a fallback when a path matches but there is no handler for the given HTTP method ([#2903]) +- **added:** Add `MethodFilter::CONNECT`, `routing::connect[_service]` + and `MethodRouter::connect[_service]` ([#2961]) +- **added:** Add `NoContent` as a self-described shortcut for `StatusCode::NO_CONTENT` ([#2978]) + +[#2903]: https://github.com/tokio-rs/axum/pull/2903 +[#2961]: https://github.com/tokio-rs/axum/pull/2961 +[#2978]: https://github.com/tokio-rs/axum/pull/2978 +[#2992]: https://github.com/tokio-rs/axum/pull/2992 + +# 0.7.7 + +- **change**: Remove manual tables of content from the documentation, since + rustdoc now generates tables of content in the sidebar ([#2921]) + +[#2921]: https://github.com/tokio-rs/axum/pull/2921 + +# 0.7.6 + +- **change:** Avoid cloning `Arc` during deserialization of `Path` +- **added:** `axum::serve::Serve::tcp_nodelay` and `axum::serve::WithGracefulShutdown::tcp_nodelay` ([#2653]) +- **added:** `Router::has_routes` function ([#2790]) +- **change:** Update tokio-tungstenite to 0.23 ([#2841]) +- **added:** `Serve::local_addr` and `WithGracefulShutdown::local_addr` functions ([#2881]) + +[#2653]: https://github.com/tokio-rs/axum/pull/2653 +[#2790]: https://github.com/tokio-rs/axum/pull/2790 +[#2841]: https://github.com/tokio-rs/axum/pull/2841 +[#2881]: https://github.com/tokio-rs/axum/pull/2881 + +# 0.7.5 (24. March, 2024) + +- **fixed:** Fixed layers being cloned when calling `axum::serve` directly with + a `Router` or `MethodRouter` ([#2586]) +- **fixed:** `h2` is no longer pulled as a dependency unless the `http2` feature + is enabled ([#2605]) +- **added:** Add `#[debug_middleware]` ([#1993], [#2725]) + +[#1993]: https://github.com/tokio-rs/axum/pull/1993 +[#2725]: https://github.com/tokio-rs/axum/pull/2725 +[#2586]: https://github.com/tokio-rs/axum/pull/2586 +[#2605]: https://github.com/tokio-rs/axum/pull/2605 + +# 0.7.4 (13. January, 2024) + +- **fixed:** Fix performance regression present since axum 0.7.0 ([#2483]) +- **fixed:** Improve `debug_handler` on tuple response types ([#2201]) +- **added:** Add `must_use` attribute to `Serve` and `WithGracefulShutdown` ([#2484]) +- **added:** Re-export `axum_core::body::BodyDataStream` from axum + +[#2201]: https://github.com/tokio-rs/axum/pull/2201 +[#2483]: https://github.com/tokio-rs/axum/pull/2483 +[#2484]: https://github.com/tokio-rs/axum/pull/2484 # 0.7.3 (29. December, 2023) @@ -546,7 +651,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ```rust use axum::{Json, http::HeaderMap}; - // This wont compile on 0.6 because both `Json` and `String` need to consume + // This won't compile on 0.6 because both `Json` and `String` need to consume // the request body. You can use either `Json` or `String`, but not both. async fn handler_1( json: Json, @@ -577,7 +682,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ```rust struct MyExtractor { /* ... */ } - #[async_trait] impl FromRequest for MyExtractor where B: Send, @@ -596,13 +700,11 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 use axum::{ extract::{FromRequest, FromRequestParts}, http::{StatusCode, Request, request::Parts}, - async_trait, }; struct MyExtractor { /* ... */ } // implement `FromRequestParts` if you don't need to consume the request body - #[async_trait] impl FromRequestParts for MyExtractor where S: Send + Sync, @@ -615,7 +717,6 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 } // implement `FromRequest` if you do need to consume the request body - #[async_trait] impl FromRequest for MyExtractor where S: Send + Sync, @@ -1132,7 +1233,7 @@ Yanked, as it didn't compile in release mode. ```rust use axum::{Json, http::HeaderMap}; - // This wont compile on 0.6 because both `Json` and `String` need to consume + // This won't compile on 0.6 because both `Json` and `String` need to consume // the request body. You can use either `Json` or `String`, but not both. async fn handler_1( json: Json, @@ -1163,7 +1264,6 @@ Yanked, as it didn't compile in release mode. ```rust struct MyExtractor { /* ... */ } - #[async_trait] impl FromRequest for MyExtractor where B: Send, @@ -1182,13 +1282,11 @@ Yanked, as it didn't compile in release mode. use axum::{ extract::{FromRequest, FromRequestParts}, http::{StatusCode, Request, request::Parts}, - async_trait, }; struct MyExtractor { /* ... */ } // implement `FromRequestParts` if you don't need to consume the request body - #[async_trait] impl FromRequestParts for MyExtractor where S: Send + Sync, @@ -1201,7 +1299,6 @@ Yanked, as it didn't compile in release mode. } // implement `FromRequest` if you do need to consume the request body - #[async_trait] impl FromRequest for MyExtractor where S: Send + Sync, diff --git a/axum/Cargo.toml b/axum/Cargo.toml index da8f5ef24f..ef113c335b 100644 --- a/axum/Cargo.toml +++ b/axum/Cargo.toml @@ -1,10 +1,10 @@ [package] name = "axum" -version = "0.7.3" +version = "0.8.0-alpha.1" # remember to bump the version that axum-extra and axum-macros depend on categories = ["asynchronous", "network-programming", "web-programming::http-server"] description = "Web framework that focuses on ergonomics and modularity" edition = "2021" -rust-version = "1.66" +rust-version = { workspace = true } homepage = "https://github.com/tokio-rs/axum" keywords = ["http", "web", "framework"] license = "MIT" @@ -24,8 +24,8 @@ default = [ "tracing", ] form = ["dep:serde_urlencoded"] -http1 = ["dep:hyper", "hyper?/http1"] -http2 = ["dep:hyper", "hyper?/http2"] +http1 = ["dep:hyper", "hyper?/http1", "hyper-util?/http1"] +http2 = ["dep:hyper", "hyper?/http2", "hyper-util?/http2"] json = ["dep:serde_json", "dep:serde_path_to_error"] macros = ["dep:axum-macros"] matched-path = [] @@ -37,45 +37,50 @@ tower-log = ["tower/log"] tracing = ["dep:tracing", "axum-core/tracing"] ws = ["dep:hyper", "tokio", "dep:tokio-tungstenite", "dep:sha1", "dep:base64"] -# Required for intra-doc links to resolve correctly -__private_docs = ["tower/full", "dep:tower-http"] +__private_docs = [ + # We re-export some docs from axum-core via #[doc(inline)], + # but they need the same sort of treatment as below to be complete + "axum-core/__private_docs", + # Enables upstream things linked to in docs + "tower/full", "dep:tower-http", +] [dependencies] -async-trait = "0.1.67" -axum-core = { path = "../axum-core", version = "0.4.2" } +axum-core = { path = "../axum-core", version = "0.5.0-alpha.1" } bytes = "1.0" futures-util = { version = "0.3", default-features = false, features = ["alloc"] } http = "1.0.0" http-body = "1.0.0" http-body-util = "0.1.0" itoa = "1.0.5" -matchit = "0.7" +matchit = "=0.8.4" memchr = "2.4.1" mime = "0.3.16" percent-encoding = "2.1" pin-project-lite = "0.2.7" +rustversion = "1.0.9" serde = "1.0" -sync_wrapper = "0.1.1" -tower = { version = "0.4.13", default-features = false, features = ["util"] } +sync_wrapper = "1.0.0" +tower = { version = "0.5.1", default-features = false, features = ["util"] } tower-layer = "0.3.2" tower-service = "0.3" # optional dependencies -axum-macros = { path = "../axum-macros", version = "0.4.0", optional = true } -base64 = { version = "0.21.0", optional = true } +axum-macros = { path = "../axum-macros", version = "0.5.0-alpha.1", optional = true } +base64 = { version = "0.22.1", optional = true } hyper = { version = "1.1.0", optional = true } -hyper-util = { version = "0.1.2", features = ["tokio", "server", "server-auto"], optional = true } +hyper-util = { version = "0.1.3", features = ["tokio", "server", "service"], optional = true } multer = { version = "3.0.0", optional = true } serde_json = { version = "1.0", features = ["raw_value"], optional = true } serde_path_to_error = { version = "0.1.8", optional = true } serde_urlencoded = { version = "0.7", optional = true } sha1 = { version = "0.10", optional = true } tokio = { package = "tokio", version = "1.25.0", features = ["time"], optional = true } -tokio-tungstenite = { version = "0.21", optional = true } +tokio-tungstenite = { version = "0.24.0", optional = true } tracing = { version = "0.1", default-features = false, optional = true } [dependencies.tower-http] -version = "0.5.0" +version = "0.6.0" optional = true features = [ # all tower-http features except (de)?compression-zstd which doesn't @@ -109,32 +114,29 @@ features = [ "validate-request", ] -[build-dependencies] -rustversion = "1.0.9" - [dev-dependencies] anyhow = "1.0" -axum-macros = { path = "../axum-macros", version = "0.4.0", features = ["__private"] } +axum-macros = { path = "../axum-macros", features = ["__private"] } +hyper = { version = "1.1.0", features = ["client"] } quickcheck = "1.0" quickcheck_macros = "1.0" -reqwest = { version = "0.11.14", default-features = false, features = ["json", "stream", "multipart"] } -rustversion = "1.0.9" +reqwest = { version = "0.12", default-features = false, features = ["json", "stream", "multipart"] } serde = { version = "1.0", features = ["derive"] } -serde_json = "1.0" +serde_json = { version = "1.0", features = ["raw_value"] } time = { version = "0.3", features = ["serde-human-readable"] } tokio = { package = "tokio", version = "1.25.0", features = ["macros", "rt", "rt-multi-thread", "net", "test-util"] } tokio-stream = "0.1" +tokio-tungstenite = "0.24.0" tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["json"] } uuid = { version = "1.0", features = ["serde", "v4"] } [package.metadata.docs.rs] all-features = true -rustdoc-args = ["--cfg", "docsrs"] [dev-dependencies.tower] package = "tower" -version = "0.4.10" +version = "0.5.1" features = [ "util", "timeout", @@ -145,7 +147,7 @@ features = [ ] [dev-dependencies.tower-http] -version = "0.5.0" +version = "0.6.0" features = [ # all tower-http features except (de)?compression-zstd which doesn't # build on `--target armv5te-unknown-linux-musleabi` @@ -197,11 +199,11 @@ allowed = [ "futures_core", "futures_sink", "futures_util", + "pin_project_lite", "tower_layer", "tower_service", # >=1.0 - "async_trait", "bytes", "http", "http_body", diff --git a/axum/README.md b/axum/README.md index dc7f1a95dd..344484a8ec 100644 --- a/axum/README.md +++ b/axum/README.md @@ -4,7 +4,7 @@ [![Build status](https://github.com/tokio-rs/axum/actions/workflows/CI.yml/badge.svg?branch=main)](https://github.com/tokio-rs/axum/actions/workflows/CI.yml) [![Crates.io](https://img.shields.io/crates/v/axum)](https://crates.io/crates/axum) -[![Documentation](https://docs.rs/axum/badge.svg)](https://docs.rs/axum) +[![Documentation](https://docs.rs/axum/badge.svg)][docs] More information about this crate can be found in the [crate documentation][docs]. @@ -23,6 +23,13 @@ In particular the last point is what sets `axum` apart from other frameworks. authorization, and more, for free. It also enables you to share middleware with applications written using [`hyper`] or [`tonic`]. +## ⚠ Breaking changes ⚠ + +We are currently working towards axum 0.8 so the `main` branch contains breaking +changes. See the [`0.7.x`] branch for what's released to crates.io. + +[`0.7.x`]: https://github.com/tokio-rs/axum/tree/v0.7.x + ## Usage example ```rust @@ -104,7 +111,7 @@ This crate uses `#![forbid(unsafe_code)]` to ensure everything is implemented in ## Minimum supported Rust version -axum's MSRV is 1.66. +axum's MSRV is 1.75. ## Examples diff --git a/axum/benches/benches.rs b/axum/benches/benches.rs index 5bcdc906f9..bb1c303dd1 100644 --- a/axum/benches/benches.rs +++ b/axum/benches/benches.rs @@ -198,7 +198,7 @@ impl BenchmarkBuilder { eprintln!("Running {:?} benchmark", self.name); - // indent output from `rewrk` so its easier to read when running multiple benchmarks + // indent output from `rewrk` so it's easier to read when running multiple benchmarks let mut child = cmd.spawn().unwrap(); let stdout = child.stdout.take().unwrap(); let stdout = std::io::BufReader::new(stdout); diff --git a/axum/build.rs b/axum/build.rs deleted file mode 100644 index b52885c626..0000000000 --- a/axum/build.rs +++ /dev/null @@ -1,7 +0,0 @@ -#[rustversion::nightly] -fn main() { - println!("cargo:rustc-cfg=nightly_error_messages"); -} - -#[rustversion::not(nightly)] -fn main() {} diff --git a/axum/src/body/mod.rs b/axum/src/body/mod.rs index b52db8231a..d32a89956d 100644 --- a/axum/src/body/mod.rs +++ b/axum/src/body/mod.rs @@ -7,7 +7,7 @@ pub use http_body::Body as HttpBody; pub use bytes::Bytes; #[doc(inline)] -pub use axum_core::body::Body; +pub use axum_core::body::{Body, BodyDataStream}; use http_body_util::{BodyExt, Limited}; diff --git a/axum/src/box_clone_service.rs b/axum/src/box_clone_service.rs new file mode 100644 index 0000000000..25c0b205b8 --- /dev/null +++ b/axum/src/box_clone_service.rs @@ -0,0 +1,80 @@ +use futures_util::future::BoxFuture; +use std::{ + fmt, + task::{Context, Poll}, +}; +use tower::ServiceExt; +use tower_service::Service; + +/// Like `tower::BoxCloneService` but `Sync` +pub(crate) struct BoxCloneService( + Box< + dyn CloneService>> + + Send + + Sync, + >, +); + +impl BoxCloneService { + pub(crate) fn new(inner: S) -> Self + where + S: Service + Clone + Send + Sync + 'static, + S::Future: Send + 'static, + { + let inner = inner.map_future(|f| Box::pin(f) as _); + BoxCloneService(Box::new(inner)) + } +} + +impl Service for BoxCloneService { + type Response = U; + type Error = E; + type Future = BoxFuture<'static, Result>; + + #[inline] + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.0.poll_ready(cx) + } + + #[inline] + fn call(&mut self, request: T) -> Self::Future { + self.0.call(request) + } +} + +impl Clone for BoxCloneService { + fn clone(&self) -> Self { + Self(self.0.clone_box()) + } +} + +trait CloneService: Service { + fn clone_box( + &self, + ) -> Box< + dyn CloneService + + Send + + Sync, + >; +} + +impl CloneService for T +where + T: Service + Send + Sync + Clone + 'static, +{ + fn clone_box( + &self, + ) -> Box< + dyn CloneService + + Send + + Sync, + > { + Box::new(self.clone()) + } +} + +impl fmt::Debug for BoxCloneService { + fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result { + fmt.debug_struct("BoxCloneService").finish() + } +} diff --git a/axum/src/boxed.rs b/axum/src/boxed.rs index 5473b2491b..0d65b2d38f 100644 --- a/axum/src/boxed.rs +++ b/axum/src/boxed.rs @@ -32,7 +32,7 @@ impl BoxedIntoRoute { where S: 'static, E: 'static, - F: FnOnce(Route) -> Route + Clone + Send + 'static, + F: FnOnce(Route) -> Route + Clone + Send + Sync + 'static, E2: 'static, { BoxedIntoRoute(Box::new(Map { @@ -58,11 +58,12 @@ impl fmt::Debug for BoxedIntoRoute { } } -pub(crate) trait ErasedIntoRoute: Send { +pub(crate) trait ErasedIntoRoute: Send + Sync { fn clone_box(&self) -> Box>; fn into_route(self: Box, state: S) -> Route; + #[allow(dead_code)] fn call_with_state(self: Box, request: Request, state: S) -> RouteFuture; } @@ -73,7 +74,7 @@ pub(crate) struct MakeErasedHandler { impl ErasedIntoRoute for MakeErasedHandler where - H: Clone + Send + 'static, + H: Clone + Send + Sync + 'static, S: 'static, { fn clone_box(&self) -> Box> { @@ -101,6 +102,7 @@ where } } +#[allow(dead_code)] pub(crate) struct MakeErasedRouter { pub(crate) router: Router, pub(crate) into_route: fn(Router, S) -> Route, @@ -118,7 +120,7 @@ where (self.into_route)(self.router, state) } - fn call_with_state(mut self: Box, request: Request, state: S) -> RouteFuture { + fn call_with_state(self: Box, request: Request, state: S) -> RouteFuture { self.router.call_with_state(request, state) } } @@ -162,13 +164,13 @@ where } } -pub(crate) trait LayerFn: FnOnce(Route) -> Route + Send { +pub(crate) trait LayerFn: FnOnce(Route) -> Route + Send + Sync { fn clone_box(&self) -> Box>; } impl LayerFn for F where - F: FnOnce(Route) -> Route + Clone + Send + 'static, + F: FnOnce(Route) -> Route + Clone + Send + Sync + 'static, { fn clone_box(&self) -> Box> { Box::new(self.clone()) diff --git a/axum/src/docs/error_handling.md b/axum/src/docs/error_handling.md index ea7d8a74cb..7d7e14ee05 100644 --- a/axum/src/docs/error_handling.md +++ b/axum/src/docs/error_handling.md @@ -1,12 +1,5 @@ Error handling model and utilities -# Table of contents - -- [axum's error handling model](#axums-error-handling-model) -- [Routing to fallible services](#routing-to-fallible-services) -- [Applying fallible middleware](#applying-fallible-middleware) -- [Running extractors for error handling](#running-extractors-for-error-handling) - # axum's error handling model axum is based on [`tower::Service`] which bundles errors through its associated @@ -43,10 +36,10 @@ that can ultimately be converted to `Response`. This allows using `?` operator in handlers. See those examples: * [`anyhow-error-response`][anyhow] for generic boxed errors -* [`error-handling-and-dependency-injection`][ehdi] for application-specific detailed errors +* [`error-handling`][error-handling] for application-specific detailed errors [anyhow]: https://github.com/tokio-rs/axum/blob/main/examples/anyhow-error-response/src/main.rs -[ehdi]: https://github.com/tokio-rs/axum/blob/main/examples/error-handling-and-dependency-injection/src/main.rs +[error-handling]: https://github.com/tokio-rs/axum/blob/main/examples/error-handling/src/main.rs This also applies to extractors. If an extractor doesn't match the request the request will be rejected and a response will be returned without calling your diff --git a/axum/src/docs/extract.md b/axum/src/docs/extract.md index 13d27171d6..244528d6a8 100644 --- a/axum/src/docs/extract.md +++ b/axum/src/docs/extract.md @@ -1,27 +1,10 @@ Types and traits for extracting data from requests. -# Table of contents - -- [Intro](#intro) -- [Common extractors](#common-extractors) -- [Applying multiple extractors](#applying-multiple-extractors) -- [The order of extractors](#the-order-of-extractors) -- [Optional extractors](#optional-extractors) -- [Customizing extractor responses](#customizing-extractor-responses) -- [Accessing inner errors](#accessing-inner-errors) -- [Defining custom extractors](#defining-custom-extractors) -- [Accessing other extractors in `FromRequest` or `FromRequestParts` implementations](#accessing-other-extractors-in-fromrequest-or-fromrequestparts-implementations) -- [Request body limits](#request-body-limits) -- [Request body extractors](#request-body-extractors) -- [Wrapping extractors](#wrapping-extractors) -- [Logging rejections](#logging-rejections) - # Intro A handler function is an async function that takes any number of "extractors" as arguments. An extractor is a type that implements -[`FromRequest`](crate::extract::FromRequest) -or [`FromRequestParts`](crate::extract::FromRequestParts). +[`FromRequest`] or [`FromRequestParts`]. For example, [`Json`] is an extractor that consumes the request body and deserializes it as JSON into some target type: @@ -94,7 +77,7 @@ async fn extension(Extension(state): Extension) {} struct State { /* ... */ } let app = Router::new() - .route("/path/:user_id", post(path)) + .route("/path/{user_id}", post(path)) .route("/query", post(query)) .route("/string", post(string)) .route("/bytes", post(bytes)) @@ -117,7 +100,7 @@ use axum::{ use uuid::Uuid; use serde::Deserialize; -let app = Router::new().route("/users/:id/things", get(get_user_things)); +let app = Router::new().route("/users/{id}/things", get(get_user_things)); #[derive(Deserialize)] struct Pagination { @@ -282,10 +265,15 @@ let app = Router::new().route("/users", post(create_user)); # let _: Router = app; ``` +Another option is to make use of the optional extractors in [axum-extra] that +either returns `None` if there are no query parameters in the request URI, +or returns `Some(T)` if deserialization was successful. +If the deserialization was not successful, the request is rejected. + # Customizing extractor responses If an extractor fails it will return a response with the error and your -handler will not be called. To customize the error response you have a two +handler will not be called. To customize the error response you have two options: 1. Use `Result` as your extractor like shown in ["Optional @@ -421,7 +409,6 @@ request body: ```rust,no_run use axum::{ - async_trait, extract::FromRequestParts, routing::get, Router, @@ -434,7 +421,6 @@ use axum::{ struct ExtractUserAgent(HeaderValue); -#[async_trait] impl FromRequestParts for ExtractUserAgent where S: Send + Sync, @@ -464,7 +450,6 @@ If your extractor needs to consume the request body you must implement [`FromReq ```rust,no_run use axum::{ - async_trait, extract::{Request, FromRequest}, response::{Response, IntoResponse}, body::{Bytes, Body}, @@ -478,7 +463,6 @@ use axum::{ struct ValidatedBody(Bytes); -#[async_trait] impl FromRequest for ValidatedBody where Bytes: FromRequest, @@ -518,7 +502,6 @@ use axum::{ extract::{FromRequest, Request, FromRequestParts}, http::request::Parts, body::Body, - async_trait, }; use std::convert::Infallible; @@ -526,7 +509,6 @@ use std::convert::Infallible; struct MyExtractor; // `MyExtractor` implements both `FromRequest` -#[async_trait] impl FromRequest for MyExtractor where S: Send + Sync, @@ -540,7 +522,6 @@ where } // and `FromRequestParts` -#[async_trait] impl FromRequestParts for MyExtractor where S: Send + Sync, @@ -574,7 +555,6 @@ in your implementation. ```rust use axum::{ - async_trait, extract::{Extension, FromRequestParts}, http::{StatusCode, HeaderMap, request::Parts}, response::{IntoResponse, Response}, @@ -591,7 +571,6 @@ struct AuthenticatedUser { // ... } -#[async_trait] impl FromRequestParts for AuthenticatedUser where S: Send + Sync, @@ -645,7 +624,6 @@ use axum::{ routing::get, extract::{Request, FromRequest, FromRequestParts}, http::{HeaderMap, request::Parts}, - async_trait, }; use std::time::{Instant, Duration}; @@ -656,7 +634,6 @@ struct Timing { } // we must implement both `FromRequestParts` -#[async_trait] impl FromRequestParts for Timing where S: Send + Sync, @@ -676,7 +653,6 @@ where } // and `FromRequest` -#[async_trait] impl FromRequest for Timing where S: Send + Sync, @@ -711,6 +687,7 @@ logs, enable the `tracing` feature for axum (enabled by default) and the `axum::rejection=trace` tracing target, for example with `RUST_LOG=info,axum::rejection=trace cargo run`. +[axum-extra]: https://docs.rs/axum-extra/latest/axum_extra/extract/index.html [`body::Body`]: crate::body::Body [`Bytes`]: crate::body::Bytes [customize-extractor-error]: https://github.com/tokio-rs/axum/blob/main/examples/customize-extractor-error/src/main.rs diff --git a/axum/src/docs/middleware.md b/axum/src/docs/middleware.md index 4ba977b4c8..bef5a8b592 100644 --- a/axum/src/docs/middleware.md +++ b/axum/src/docs/middleware.md @@ -1,22 +1,10 @@ -# Table of contents - -- [Intro](#intro) -- [Applying middleware](#applying-middleware) -- [Commonly used middleware](#commonly-used-middleware) -- [Ordering](#ordering) -- [Writing middleware](#writing-middleware) -- [Routing to services/middleware and backpressure](#routing-to-servicesmiddleware-and-backpressure) -- [Accessing state in middleware](#accessing-state-in-middleware) -- [Passing state from middleware to handlers](#passing-state-from-middleware-to-handlers) -- [Rewriting request URI in middleware](#rewriting-request-uri-in-middleware) - # Intro axum is unique in that it doesn't have its own bespoke middleware system and instead integrates with [`tower`]. This means the ecosystem of [`tower`] and [`tower-http`] middleware all work with axum. -While its not necessary to fully understand tower to write or use middleware +While it's not necessary to fully understand tower to write or use middleware with axum, having at least a basic understanding of tower's concepts is recommended. See [tower's guides][tower-guides] for a general introduction. Reading the documentation for [`tower::ServiceBuilder`] is also recommended. @@ -31,7 +19,7 @@ axum allows you to add middleware just about anywhere ## Applying multiple middleware -Its recommended to use [`tower::ServiceBuilder`] to apply multiple middleware at +It's recommended to use [`tower::ServiceBuilder`] to apply multiple middleware at once, instead of calling `layer` (or `route_layer`) repeatedly: ```rust @@ -128,9 +116,9 @@ That is: It's a little more complicated in practice because any middleware is free to return early and not call the next layer, for example if a request cannot be -authorized, but its a useful mental model to have. +authorized, but it's a useful mental model to have. -As previously mentioned its recommended to add multiple middleware using +As previously mentioned it's recommended to add multiple middleware using `tower::ServiceBuilder`, however this impacts ordering: ```rust @@ -202,7 +190,7 @@ You should use these when ## `tower::Service` and `Pin>` -For maximum control (and a more low level API) you can write you own middleware +For maximum control (and a more low level API) you can write your own middleware by implementing [`tower::Service`]: Use [`tower::Service`] with `Pin>` to write your middleware when: @@ -352,11 +340,11 @@ readiness inside the response future returned by `Service::call`. This works well when your services don't care about backpressure and are always ready anyway. -axum expects that all services used in your app wont care about +axum expects that all services used in your app won't care about backpressure and so it uses the latter strategy. However that means you should avoid routing to a service (or using a middleware) that _does_ care -about backpressure. At the very least you should [load shed] so requests are -dropped quickly and don't keep piling up. +about backpressure. At the very least you should [load shed][tower::load_shed] +so requests are dropped quickly and don't keep piling up. It also means that if `poll_ready` returns an error then that error will be returned in the response future from `call` and _not_ from `poll_ready`. In @@ -388,8 +376,7 @@ let app = ServiceBuilder::new() ``` However when applying middleware around your whole application in this way -you have to take care that errors are still being handled with -appropriately. +you have to take care that errors are still being handled appropriately. Also note that handlers created from async functions don't care about backpressure and are always ready. So if you're not using any Tower diff --git a/axum/src/docs/response.md b/axum/src/docs/response.md index a5761c34ed..c0974fb640 100644 --- a/axum/src/docs/response.md +++ b/axum/src/docs/response.md @@ -1,11 +1,5 @@ Types and traits for generating responses. -# Table of contents - -- [Building responses](#building-responses) -- [Returning different response types](#returning-different-response-types) -- [Regarding `impl IntoResponse`](#regarding-impl-intoresponse) - # Building responses Anything that implements [`IntoResponse`] can be returned from a handler. axum @@ -166,7 +160,7 @@ In general you can return tuples like: This means you cannot accidentally override the status or body as [`IntoResponseParts`] only allows setting headers and extensions. -Use [`Response`](crate::response::Response) for more low level control: +Use [`Response`] for more low level control: ```rust,no_run use axum::{ diff --git a/axum/src/docs/routing/fallback.md b/axum/src/docs/routing/fallback.md index 27fb76a59e..a864b7a45d 100644 --- a/axum/src/docs/routing/fallback.md +++ b/axum/src/docs/routing/fallback.md @@ -23,7 +23,11 @@ async fn fallback(uri: Uri) -> (StatusCode, String) { Fallbacks only apply to routes that aren't matched by anything in the router. If a handler is matched by a request but returns 404 the -fallback is not called. +fallback is not called. Note that this applies to [`MethodRouter`]s too: if the +request hits a valid path but the [`MethodRouter`] does not have an appropriate +method handler installed, the fallback is not called (use +[`MethodRouter::fallback`] for this purpose instead). + # Handling all requests without other routes diff --git a/axum/src/docs/routing/merge.md b/axum/src/docs/routing/merge.md index e8f668712e..ddea660879 100644 --- a/axum/src/docs/routing/merge.md +++ b/axum/src/docs/routing/merge.md @@ -16,7 +16,7 @@ use axum::{ // define some routes separately let user_routes = Router::new() .route("/users", get(users_list)) - .route("/users/:id", get(users_show)); + .route("/users/{id}", get(users_show)); let team_routes = Router::new() .route("/teams", get(teams_list)); @@ -30,7 +30,7 @@ let app = Router::new() // Our app now accepts // - GET /users -// - GET /users/:id +// - GET /users/{id} // - GET /teams # let _: Router = app; ``` diff --git a/axum/src/docs/routing/method_not_allowed_fallback.md b/axum/src/docs/routing/method_not_allowed_fallback.md new file mode 100644 index 0000000000..22905cd941 --- /dev/null +++ b/axum/src/docs/routing/method_not_allowed_fallback.md @@ -0,0 +1,38 @@ +Add a fallback [`Handler`] for the case where a route exists, but the method of the request is not supported. + +Sets a fallback on all previously registered [`MethodRouter`]s, +to be called when no matching method handler is set. + +```rust,no_run +use axum::{response::IntoResponse, routing::get, Router}; + +async fn hello_world() -> impl IntoResponse { + "Hello, world!\n" +} + +async fn default_fallback() -> impl IntoResponse { + "Default fallback\n" +} + +async fn handle_405() -> impl IntoResponse { + "Method not allowed fallback" +} + +#[tokio::main] +async fn main() { + let router = Router::new() + .route("/", get(hello_world)) + .fallback(default_fallback) + .method_not_allowed_fallback(handle_405); + + let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap(); + + axum::serve(listener, router).await.unwrap(); +} +``` + +The fallback only applies if there is a `MethodRouter` registered for a given path, +but the method used in the request is not specified. In the example, a `GET` on +`http://localhost:3000` causes the `hello_world` handler to react, while issuing a +`POST` triggers `handle_405`. Calling an entirely different route, like `http://localhost:3000/hello` +causes `default_fallback` to run. diff --git a/axum/src/docs/routing/nest.md b/axum/src/docs/routing/nest.md index c3f7308fdb..bb5b2ea6cb 100644 --- a/axum/src/docs/routing/nest.md +++ b/axum/src/docs/routing/nest.md @@ -11,7 +11,7 @@ use axum::{ Router, }; -let user_routes = Router::new().route("/:id", get(|| async {})); +let user_routes = Router::new().route("/{id}", get(|| async {})); let team_routes = Router::new().route("/", post(|| async {})); @@ -22,7 +22,7 @@ let api_routes = Router::new() let app = Router::new().nest("/api", api_routes); // Our app now accepts -// - GET /api/users/:id +// - GET /api/users/{id} // - POST /api/teams # let _: Router = app; ``` @@ -54,9 +54,9 @@ async fn users_get(Path(params): Path>) { let id = params.get("id"); } -let users_api = Router::new().route("/users/:id", get(users_get)); +let users_api = Router::new().route("/users/{id}", get(users_get)); -let app = Router::new().nest("/:version/api", users_api); +let app = Router::new().nest("/{version}/api", users_api); # let _: Router = app; ``` @@ -75,13 +75,18 @@ let nested_router = Router::new() })); let app = Router::new() - .route("/foo/*rest", get(|uri: Uri| async { + .route("/foo/{*rest}", get(|uri: Uri| async { // `uri` will contain `/foo` })) .nest("/bar", nested_router); # let _: Router = app; ``` +Additionally, while the wildcard route `/foo/*rest` will not match the +paths `/foo` or `/foo/`, a nested router at `/foo` will match the path `/foo` +(but not `/foo/`), and a nested router at `/foo/` will match the path `/foo/` +(but not `/foo`). + # Fallbacks If a nested router doesn't have its own fallback then it will inherit the @@ -181,7 +186,7 @@ router. # Panics - If the route overlaps with another route. See [`Router::route`] -for more details. + for more details. - If the route contains a wildcard (`*`). - If `path` is empty. diff --git a/axum/src/docs/routing/route.md b/axum/src/docs/routing/route.md index eefbb21beb..01be9152ed 100644 --- a/axum/src/docs/routing/route.md +++ b/axum/src/docs/routing/route.md @@ -5,8 +5,7 @@ can be either static, a capture, or a wildcard. `method_router` is the [`MethodRouter`] that should receive the request if the path matches `path`. `method_router` will commonly be a handler wrapped in a method -router like [`get`](crate::routing::get). See [`handler`](crate::handler) for -more details on handlers. +router like [`get`]. See [`handler`](crate::handler) for more details on handlers. # Static paths @@ -21,15 +20,15 @@ be called. # Captures -Paths can contain segments like `/:key` which matches any single segment and +Paths can contain segments like `/{key}` which matches any single segment and will store the value captured at `key`. The value captured can be zero-length except for in the invalid path `//`. Examples: -- `/:key` -- `/users/:id` -- `/users/:id/tweets` +- `/{key}` +- `/users/{id}` +- `/users/{id}/tweets` Captures can be extracted using [`Path`](crate::extract::Path). See its documentation for more details. @@ -42,22 +41,37 @@ path rather than the actual path. # Wildcards -Paths can end in `/*key` which matches all segments and will store the segments +Paths can end in `/{*key}` which matches all segments and will store the segments captured at `key`. Examples: -- `/*key` -- `/assets/*path` -- `/:id/:repo/*tree` +- `/{*key}` +- `/assets/{*path}` +- `/{id}/{repo}/{*tree}` -Note that `/*key` doesn't match empty segments. Thus: +Note that `/{*key}` doesn't match empty segments. Thus: -- `/*key` doesn't match `/` but does match `/a`, `/a/`, etc. -- `/x/*key` doesn't match `/x` or `/x/` but does match `/x/a`, `/x/a/`, etc. +- `/{*key}` doesn't match `/` but does match `/a`, `/a/`, etc. +- `/x/{*key}` doesn't match `/x` or `/x/` but does match `/x/a`, `/x/a/`, etc. -Wildcard captures can also be extracted using [`Path`](crate::extract::Path). -Note that the leading slash is not included, i.e. for the route `/foo/*rest` and +Wildcard captures can also be extracted using [`Path`](crate::extract::Path): + +```rust +use axum::{ + Router, + routing::get, + extract::Path, +}; + +let app: Router = Router::new().route("/{*key}", get(handler)); + +async fn handler(Path(path): Path) -> String { + path +} +``` + +Note that the leading slash is not included, i.e. for the route `/foo/{*rest}` and the path `/foo/bar/baz` the value of `rest` will be `bar/baz`. # Accepting multiple methods @@ -106,9 +120,9 @@ use axum::{Router, routing::{get, delete}, extract::Path}; let app = Router::new() .route("/", get(root)) .route("/users", get(list_users).post(create_user)) - .route("/users/:id", get(show_user)) - .route("/api/:version/users/:id/action", delete(do_users_action)) - .route("/assets/*path", get(serve_asset)); + .route("/users/{id}", get(show_user)) + .route("/api/{version}/users/{id}/action", delete(do_users_action)) + .route("/assets/{*path}", get(serve_asset)); async fn root() {} @@ -137,7 +151,7 @@ let app = Router::new() # let _: Router = app; ``` -The static route `/foo` and the dynamic route `/:key` are not considered to +The static route `/foo` and the dynamic route `/{key}` are not considered to overlap and `/foo` will take precedence. Also panics if `path` is empty. diff --git a/axum/src/docs/routing/route_layer.md b/axum/src/docs/routing/route_layer.md index bc7b219742..9cce3ea79e 100644 --- a/axum/src/docs/routing/route_layer.md +++ b/axum/src/docs/routing/route_layer.md @@ -11,6 +11,10 @@ the request matches a route. This is useful for middleware that return early (such as authorization) which might otherwise convert a `404 Not Found` into a `401 Unauthorized`. +This function will panic if no routes have been declared yet on the router, +since the new layer will have no effect, and this is typically a bug. +In generic code, you can test if that is the case first, by calling [`Router::has_routes`]. + # Example ```rust diff --git a/axum/src/docs/routing/with_state.md b/axum/src/docs/routing/with_state.md index bece920fe0..197741cf55 100644 --- a/axum/src/docs/routing/with_state.md +++ b/axum/src/docs/routing/with_state.md @@ -1,4 +1,5 @@ -Provide the state for the router. +Provide the state for the router. State passed to this method is global and will be used +for all requests this router receives. That means it is not suitable for holding state derived from a request, such as authorization data extracted in a middleware. Use [`Extension`] instead for such data. ```rust use axum::{Router, routing::get, extract::State}; @@ -20,7 +21,7 @@ axum::serve(listener, routes).await.unwrap(); # Returning routers with states from functions -When returning `Router`s from functions it is generally recommend not set the +When returning `Router`s from functions, it is generally recommended not to set the state directly: ```rust @@ -94,13 +95,6 @@ axum::serve(listener, routes).await.unwrap(); # }; ``` -# State is global within the router - -The state passed to this method will be used for all requests this router -receives. That means it is not suitable for holding state derived from a -request, such as authorization data extracted in a middleware. Use [`Extension`] -instead for such data. - # What `S` in `Router` means `Router` means a router that is _missing_ a state of type `S` to be able to @@ -171,7 +165,7 @@ work: # #[derive(Clone)] # struct AppState {} # -// This wont work because we're returning a `Router` +// This won't work because we're returning a `Router` // i.e. we're saying we're still missing an `AppState` fn routes(state: AppState) -> Router { Router::new() diff --git a/axum/src/docs/routing/without_v07_checks.md b/axum/src/docs/routing/without_v07_checks.md new file mode 100644 index 0000000000..f1b465ea9e --- /dev/null +++ b/axum/src/docs/routing/without_v07_checks.md @@ -0,0 +1,43 @@ +Turn off checks for compatibility with route matching syntax from 0.7. + +This allows usage of paths starting with a colon `:` or an asterisk `*` which are otherwise prohibited. + +# Example + +```rust +use axum::{ + routing::get, + Router, +}; + +let app = Router::<()>::new() + .without_v07_checks() + .route("/:colon", get(|| async {})) + .route("/*asterisk", get(|| async {})); + +// Our app now accepts +// - GET /:colon +// - GET /*asterisk +# let _: Router = app; +``` + +Adding such routes without calling this method first will panic. + +```rust,should_panic +use axum::{ + routing::get, + Router, +}; + +// This panics... +let app = Router::<()>::new() + .route("/:colon", get(|| async {})); +``` + +# Merging + +When two routers are merged, v0.7 checks are disabled for route registrations on the resulting router if both of the two routers had them also disabled. + +# Nesting + +Each router needs to have the checks explicitly disabled. Nesting a router with the checks either enabled or disabled has no effect on the outer router. diff --git a/axum/src/extension.rs b/axum/src/extension.rs index a72568ac6c..9485443232 100644 --- a/axum/src/extension.rs +++ b/axum/src/extension.rs @@ -1,5 +1,4 @@ use crate::{extract::rejection::*, response::IntoResponseParts}; -use async_trait::async_trait; use axum_core::{ extract::FromRequestParts, response::{IntoResponse, Response, ResponseParts}, @@ -70,7 +69,6 @@ use tower_service::Service; #[must_use] pub struct Extension(pub T); -#[async_trait] impl FromRequestParts for Extension where T: Clone + Send + Sync + 'static, @@ -87,8 +85,7 @@ where "Extension of type `{}` was not found. Perhaps you forgot to add it? See `axum::Extension`.", std::any::type_name::() )) - }) - .map(|x| x.clone())?; + }).cloned()?; Ok(Extension(value)) } diff --git a/axum/src/extract/connect_info.rs b/axum/src/extract/connect_info.rs index 3036d375be..3d8f9a0163 100644 --- a/axum/src/extract/connect_info.rs +++ b/axum/src/extract/connect_info.rs @@ -7,7 +7,6 @@ use crate::extension::AddExtension; use super::{Extension, FromRequestParts}; -use async_trait::async_trait; use http::request::Parts; use std::{ convert::Infallible, @@ -139,7 +138,6 @@ opaque_future! { #[derive(Clone, Copy, Debug)] pub struct ConnectInfo(pub T); -#[async_trait] impl FromRequestParts for ConnectInfo where S: Send + Sync, @@ -225,7 +223,6 @@ where mod tests { use super::*; use crate::{routing::get, serve::IncomingStream, test_helpers::TestClient, Router}; - use std::net::SocketAddr; use tokio::net::TcpListener; #[crate::test] @@ -311,7 +308,7 @@ mod tests { let client = TestClient::new(app); - let res = client.get("/").send().await; + let res = client.get("/").await; let body = res.text().await; assert!(body.starts_with("0.0.0.0:1337")); } diff --git a/axum/src/extract/matched_path.rs b/axum/src/extract/matched_path.rs index cdb49b4db3..124154f7ef 100644 --- a/axum/src/extract/matched_path.rs +++ b/axum/src/extract/matched_path.rs @@ -1,6 +1,5 @@ use super::{rejection::*, FromRequestParts}; use crate::routing::{RouteId, NEST_TAIL_PARAM_CAPTURE}; -use async_trait::async_trait; use http::request::Parts; use std::{collections::HashMap, sync::Arc}; @@ -14,10 +13,10 @@ use std::{collections::HashMap, sync::Arc}; /// }; /// /// let app = Router::new().route( -/// "/users/:id", +/// "/users/{id}", /// get(|path: MatchedPath| async move { /// let path = path.as_str(); -/// // `path` will be "/users/:id" +/// // `path` will be "/users/{id}" /// }) /// ); /// # let _: Router = app; @@ -39,7 +38,7 @@ use std::{collections::HashMap, sync::Arc}; /// use tower_http::trace::TraceLayer; /// /// let app = Router::new() -/// .route("/users/:id", get(|| async { /* ... */ })) +/// .route("/users/{id}", get(|| async { /* ... */ })) /// .layer( /// TraceLayer::new_for_http().make_span_with(|req: &Request<_>| { /// let path = if let Some(path) = req.extensions().get::() { @@ -63,7 +62,6 @@ impl MatchedPath { } } -#[async_trait] impl FromRequestParts for MatchedPath where S: Send + Sync, @@ -143,40 +141,40 @@ mod tests { #[crate::test] async fn extracting_on_handler() { let app = Router::new().route( - "/:a", + "/{a}", get(|path: MatchedPath| async move { path.as_str().to_owned() }), ); let client = TestClient::new(app); - let res = client.get("/foo").send().await; - assert_eq!(res.text().await, "/:a"); + let res = client.get("/foo").await; + assert_eq!(res.text().await, "/{a}"); } #[crate::test] async fn extracting_on_handler_in_nested_router() { let app = Router::new().nest( - "/:a", + "/{a}", Router::new().route( - "/:b", + "/{b}", get(|path: MatchedPath| async move { path.as_str().to_owned() }), ), ); let client = TestClient::new(app); - let res = client.get("/foo/bar").send().await; - assert_eq!(res.text().await, "/:a/:b"); + let res = client.get("/foo/bar").await; + assert_eq!(res.text().await, "/{a}/{b}"); } #[crate::test] async fn extracting_on_handler_in_deeply_nested_router() { let app = Router::new().nest( - "/:a", + "/{a}", Router::new().nest( - "/:b", + "/{b}", Router::new().route( - "/:c", + "/{c}", get(|path: MatchedPath| async move { path.as_str().to_owned() }), ), ), @@ -184,8 +182,8 @@ mod tests { let client = TestClient::new(app); - let res = client.get("/foo/bar/baz").send().await; - assert_eq!(res.text().await, "/:a/:b/:c"); + let res = client.get("/foo/bar/baz").await; + assert_eq!(res.text().await, "/{a}/{b}/{c}"); } #[crate::test] @@ -199,12 +197,12 @@ mod tests { } let app = Router::new() - .nest_service("/:a", Router::new().route("/:b", get(|| async move {}))) + .nest_service("/{a}", Router::new().route("/{b}", get(|| async move {}))) .layer(map_request(extract_matched_path)); let client = TestClient::new(app); - let res = client.get("/foo/bar").send().await; + let res = client.get("/foo/bar").await; assert_eq!(res.status(), StatusCode::OK); } @@ -214,17 +212,17 @@ mod tests { matched_path: Option, req: Request, ) -> Request { - assert_eq!(matched_path.unwrap().as_str(), "/:a/:b"); + assert_eq!(matched_path.unwrap().as_str(), "/{a}/{b}"); req } let app = Router::new() - .nest("/:a", Router::new().route("/:b", get(|| async move {}))) + .nest("/{a}", Router::new().route("/{b}", get(|| async move {}))) .layer(map_request(extract_matched_path)); let client = TestClient::new(app); - let res = client.get("/foo/bar").send().await; + let res = client.get("/foo/bar").await; assert_eq!(res.status(), StatusCode::OK); } @@ -236,12 +234,12 @@ mod tests { } let app = Router::new() - .nest_service("/:a", Router::new().route("/:b", get(|| async move {}))) + .nest_service("/{a}", Router::new().route("/{b}", get(|| async move {}))) .layer(map_request(assert_no_matched_path)); let client = TestClient::new(app); - let res = client.get("/foo/bar").send().await; + let res = client.get("/foo/bar").await; assert_eq!(res.status(), StatusCode::OK); } @@ -253,32 +251,32 @@ mod tests { } let app = Router::new() - .nest("/:a", Router::new().route("/:b", get(|| async move {}))) + .nest("/{a}", Router::new().route("/{b}", get(|| async move {}))) .layer(map_request(assert_matched_path)); let client = TestClient::new(app); - let res = client.get("/foo/bar").send().await; + let res = client.get("/foo/bar").await; assert_eq!(res.status(), StatusCode::OK); } #[crate::test] async fn can_extract_nested_matched_path_in_middleware_on_nested_router() { async fn extract_matched_path(matched_path: MatchedPath, req: Request) -> Request { - assert_eq!(matched_path.as_str(), "/:a/:b"); + assert_eq!(matched_path.as_str(), "/{a}/{b}"); req } let app = Router::new().nest( - "/:a", + "/{a}", Router::new() - .route("/:b", get(|| async move {})) + .route("/{b}", get(|| async move {})) .layer(map_request(extract_matched_path)), ); let client = TestClient::new(app); - let res = client.get("/foo/bar").send().await; + let res = client.get("/foo/bar").await; assert_eq!(res.status(), StatusCode::OK); } @@ -286,20 +284,20 @@ mod tests { async fn can_extract_nested_matched_path_in_middleware_on_nested_router_via_extension() { async fn extract_matched_path(req: Request) -> Request { let matched_path = req.extensions().get::().unwrap(); - assert_eq!(matched_path.as_str(), "/:a/:b"); + assert_eq!(matched_path.as_str(), "/{a}/{b}"); req } let app = Router::new().nest( - "/:a", + "/{a}", Router::new() - .route("/:b", get(|| async move {})) + .route("/{b}", get(|| async move {})) .layer(map_request(extract_matched_path)), ); let client = TestClient::new(app); - let res = client.get("/foo/bar").send().await; + let res = client.get("/foo/bar").await; assert_eq!(res.status(), StatusCode::OK); } @@ -309,11 +307,11 @@ mod tests { assert!(path.is_none()); } - let app = Router::new().nest_service("/:a", handler.into_service()); + let app = Router::new().nest_service("/{a}", handler.into_service()); let client = TestClient::new(app); - let res = client.get("/foo/bar").send().await; + let res = client.get("/foo/bar").await; assert_eq!(res.status(), StatusCode::OK); } @@ -323,17 +321,17 @@ mod tests { use tower::ServiceExt; let app = Router::new().route( - "/*path", + "/{*path}", any(|req: Request| { Router::new() - .nest("/", Router::new().route("/foo", get(|| async {}))) + .nest("/foo", Router::new().route("/bar", get(|| async {}))) .oneshot(req) }), ); let client = TestClient::new(app); - let res = client.get("/foo").send().await; + let res = client.get("/foo/bar").await; assert_eq!(res.status(), StatusCode::OK); } @@ -348,7 +346,47 @@ mod tests { let client = TestClient::new(app); - let res = client.get("/foo/bar").send().await; + let res = client.get("/foo/bar").await; assert_eq!(res.status(), StatusCode::OK); } + + #[crate::test] + async fn matching_colon() { + let app = Router::new().without_v07_checks().route( + "/:foo", + get(|path: MatchedPath| async move { path.as_str().to_owned() }), + ); + + let client = TestClient::new(app); + + let res = client.get("/:foo").await; + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.text().await, "/:foo"); + + let res = client.get("/:bar").await; + assert_eq!(res.status(), StatusCode::NOT_FOUND); + + let res = client.get("/foo").await; + assert_eq!(res.status(), StatusCode::NOT_FOUND); + } + + #[crate::test] + async fn matching_asterisk() { + let app = Router::new().without_v07_checks().route( + "/*foo", + get(|path: MatchedPath| async move { path.as_str().to_owned() }), + ); + + let client = TestClient::new(app); + + let res = client.get("/*foo").await; + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.text().await, "/*foo"); + + let res = client.get("/*bar").await; + assert_eq!(res.status(), StatusCode::NOT_FOUND); + + let res = client.get("/foo").await; + assert_eq!(res.status(), StatusCode::NOT_FOUND); + } } diff --git a/axum/src/extract/mod.rs b/axum/src/extract/mod.rs index 719083d11f..6d5e8de857 100644 --- a/axum/src/extract/mod.rs +++ b/axum/src/extract/mod.rs @@ -10,7 +10,6 @@ pub mod rejection; #[cfg(feature = "ws")] pub mod ws; -mod host; pub(crate) mod nested_path; mod raw_form; mod raw_query; @@ -24,9 +23,7 @@ pub use axum_core::extract::{DefaultBodyLimit, FromRef, FromRequest, FromRequest pub use axum_macros::{FromRef, FromRequest, FromRequestParts}; #[doc(inline)] -#[allow(deprecated)] pub use self::{ - host::Host, nested_path::NestedPath, path::{Path, RawPathParams}, raw_form::RawForm, @@ -104,7 +101,7 @@ mod tests { let app = Router::new().route("/", get(|body: String| async { body })); let client = TestClient::new(app); - let res = client.get("/").body("foo").send().await; + let res = client.get("/").body("foo").await; let body = res.text().await; assert_eq!(body, "foo"); diff --git a/axum/src/extract/multipart.rs b/axum/src/extract/multipart.rs index 2c592ecea0..9f278f6d25 100644 --- a/axum/src/extract/multipart.rs +++ b/axum/src/extract/multipart.rs @@ -4,7 +4,6 @@ use super::{FromRequest, Request}; use crate::body::Bytes; -use async_trait::async_trait; use axum_core::{ __composite_rejection as composite_rejection, __define_rejection as define_rejection, response::{IntoResponse, Response}, @@ -65,7 +64,6 @@ pub struct Multipart { inner: multer::Multipart<'static>, } -#[async_trait] impl FromRequest for Multipart where S: Send + Sync, @@ -109,7 +107,7 @@ pub struct Field<'a> { _multipart: &'a mut Multipart, } -impl<'a> Stream for Field<'a> { +impl Stream for Field<'_> { type Item = Result; fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { @@ -119,7 +117,7 @@ impl<'a> Stream for Field<'a> { } } -impl<'a> Field<'a> { +impl Field<'_> { /// The field name found in the /// [`Content-Disposition`](https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Content-Disposition) /// header. @@ -274,12 +272,13 @@ impl std::error::Error for MultipartError { impl IntoResponse for MultipartError { fn into_response(self) -> Response { + let body = self.body_text(); axum_core::__log_rejection!( rejection_type = Self, - body_text = self.body_text(), + body_text = body, status = self.status(), ); - (self.status(), self.body_text()).into_response() + (self.status(), body).into_response() } } @@ -310,7 +309,7 @@ mod tests { use axum_core::extract::DefaultBodyLimit; use super::*; - use crate::{response::IntoResponse, routing::post, test_helpers::*, Router}; + use crate::{routing::post, test_helpers::*, Router}; #[crate::test] async fn content_type_with_encoding() { @@ -345,7 +344,7 @@ mod tests { )])), ); - client.post("/").multipart(form).send().await; + client.post("/").multipart(form).await; } // No need for this to be a #[test], we just want to make sure it compiles @@ -376,7 +375,7 @@ mod tests { let form = reqwest::multipart::Form::new().part("file", reqwest::multipart::Part::bytes(BYTES)); - let res = client.post("/").multipart(form).send().await; + let res = client.post("/").multipart(form).await; assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE); } } diff --git a/axum/src/extract/nested_path.rs b/axum/src/extract/nested_path.rs index f31fe3faba..77d1316781 100644 --- a/axum/src/extract/nested_path.rs +++ b/axum/src/extract/nested_path.rs @@ -4,7 +4,6 @@ use std::{ }; use crate::extract::Request; -use async_trait::async_trait; use axum_core::extract::FromRequestParts; use http::request::Parts; use tower_layer::{layer_fn, Layer}; @@ -47,7 +46,6 @@ impl NestedPath { } } -#[async_trait] impl FromRequestParts for NestedPath where S: Send + Sync, @@ -135,7 +133,7 @@ mod tests { let client = TestClient::new(app); - let res = client.get("/api/users").send().await; + let res = client.get("/api/users").await; assert_eq!(res.status(), StatusCode::OK); } @@ -153,7 +151,7 @@ mod tests { let client = TestClient::new(app); - let res = client.get("/api/users").send().await; + let res = client.get("/api/users").await; assert_eq!(res.status(), StatusCode::OK); } @@ -171,7 +169,7 @@ mod tests { let client = TestClient::new(app); - let res = client.get("/api/v2/users").send().await; + let res = client.get("/api/v2/users").await; assert_eq!(res.status(), StatusCode::OK); } @@ -189,43 +187,7 @@ mod tests { let client = TestClient::new(app); - let res = client.get("/api/v2/users").send().await; - assert_eq!(res.status(), StatusCode::OK); - } - - #[crate::test] - async fn nested_at_root() { - let api = Router::new().route( - "/users", - get(|nested_path: NestedPath| { - assert_eq!(nested_path.as_str(), "/"); - async {} - }), - ); - - let app = Router::new().nest("/", api); - - let client = TestClient::new(app); - - let res = client.get("/users").send().await; - assert_eq!(res.status(), StatusCode::OK); - } - - #[crate::test] - async fn deeply_nested_from_root() { - let api = Router::new().route( - "/users", - get(|nested_path: NestedPath| { - assert_eq!(nested_path.as_str(), "/api"); - async {} - }), - ); - - let app = Router::new().nest("/", Router::new().nest("/api", api)); - - let client = TestClient::new(app); - - let res = client.get("/api/users").send().await; + let res = client.get("/api/v2/users").await; assert_eq!(res.status(), StatusCode::OK); } @@ -240,7 +202,7 @@ mod tests { let client = TestClient::new(app); - let res = client.get("/api/doesnt-exist").send().await; + let res = client.get("/api/doesnt-exist").await; assert_eq!(res.status(), StatusCode::OK); } @@ -259,7 +221,7 @@ mod tests { let client = TestClient::new(app); - let res = client.get("/api/users").send().await; + let res = client.get("/api/users").await; assert_eq!(res.status(), StatusCode::OK); } } diff --git a/axum/src/extract/path/de.rs b/axum/src/extract/path/de.rs index bbc0c85c9b..ca78bb9e23 100644 --- a/axum/src/extract/path/de.rs +++ b/axum/src/extract/path/de.rs @@ -94,7 +94,21 @@ impl<'de> Deserializer<'de> for PathDeserializer<'de> { .got(self.url_params.len()) .expected(1)); } - visitor.visit_borrowed_str(&self.url_params[0].1) + let key = &self.url_params[0].0; + let value = &self.url_params[0].1; + visitor + .visit_borrowed_str(value) + .map_err(|e: PathDeserializationError| { + if let ErrorKind::Message(message) = &e.kind { + PathDeserializationError::new(ErrorKind::DeserializeError { + key: key.to_string(), + value: value.as_str().to_owned(), + message: message.to_owned(), + }) + } else { + e + } + }) } fn deserialize_unit(self, visitor: V) -> Result @@ -140,7 +154,7 @@ impl<'de> Deserializer<'de> for PathDeserializer<'de> { where V: Visitor<'de>, { - if self.url_params.len() < len { + if self.url_params.len() != len { return Err(PathDeserializationError::wrong_number_of_parameters() .got(self.url_params.len()) .expected(len)); @@ -160,7 +174,7 @@ impl<'de> Deserializer<'de> for PathDeserializer<'de> { where V: Visitor<'de>, { - if self.url_params.len() < len { + if self.url_params.len() != len { return Err(PathDeserializationError::wrong_number_of_parameters() .got(self.url_params.len()) .expected(len)); @@ -210,14 +224,14 @@ impl<'de> Deserializer<'de> for PathDeserializer<'de> { } visitor.visit_enum(EnumDeserializer { - value: self.url_params[0].1.clone().into_inner(), + value: &self.url_params[0].1, }) } } struct MapDeserializer<'de> { params: &'de [(Arc, PercentDecodedStr)], - key: Option, + key: Option>, value: Option<&'de PercentDecodedStr>, } @@ -232,11 +246,8 @@ impl<'de> MapAccess<'de> for MapDeserializer<'de> { Some(((key, value), tail)) => { self.value = Some(value); self.params = tail; - self.key = Some(KeyOrIdx::Key(key.clone())); - seed.deserialize(KeyDeserializer { - key: Arc::clone(key), - }) - .map(Some) + self.key = Some(KeyOrIdx::Key(key)); + seed.deserialize(KeyDeserializer { key }).map(Some) } None => Ok(None), } @@ -256,8 +267,8 @@ impl<'de> MapAccess<'de> for MapDeserializer<'de> { } } -struct KeyDeserializer { - key: Arc, +struct KeyDeserializer<'de> { + key: &'de str, } macro_rules! parse_key { @@ -271,7 +282,7 @@ macro_rules! parse_key { }; } -impl<'de> Deserializer<'de> for KeyDeserializer { +impl<'de> Deserializer<'de> for KeyDeserializer<'de> { type Error = PathDeserializationError; parse_key!(deserialize_identifier); @@ -302,7 +313,7 @@ macro_rules! parse_value { if let Some(key) = self.key.take() { let kind = match key { KeyOrIdx::Key(key) => ErrorKind::ParseErrorAtKey { - key: key.to_string(), + key: key.to_owned(), value: self.value.as_str().to_owned(), expected_type: $ty, }, @@ -327,7 +338,7 @@ macro_rules! parse_value { #[derive(Debug)] struct ValueDeserializer<'de> { - key: Option, + key: Option>, value: &'de PercentDecodedStr, } @@ -365,7 +376,19 @@ impl<'de> Deserializer<'de> for ValueDeserializer<'de> { where V: Visitor<'de>, { - visitor.visit_borrowed_str(self.value) + visitor + .visit_borrowed_str(self.value) + .map_err(|e: PathDeserializationError| { + if let (ErrorKind::Message(message), Some(key)) = (&e.kind, self.key.as_ref()) { + PathDeserializationError::new(ErrorKind::DeserializeError { + key: key.key().to_owned(), + value: self.value.as_str().to_owned(), + message: message.to_owned(), + }) + } else { + e + } + }) } fn deserialize_bytes(self, visitor: V) -> Result @@ -416,7 +439,7 @@ impl<'de> Deserializer<'de> for ValueDeserializer<'de> { V: Visitor<'de>, { struct PairDeserializer<'de> { - key: Option, + key: Option>, value: Option<&'de PercentDecodedStr>, } @@ -431,9 +454,11 @@ impl<'de> Deserializer<'de> for ValueDeserializer<'de> { Some(KeyOrIdx::Idx { idx: _, key }) => { return seed.deserialize(KeyDeserializer { key }).map(Some); } - // `KeyOrIdx::Key` is only used when deserializing maps so `deserialize_seq` - // wouldn't be called for that - Some(KeyOrIdx::Key(_)) => unreachable!(), + Some(KeyOrIdx::Key(_)) => { + return Err(PathDeserializationError::custom( + "array types are not supported", + )); + } None => {} }; @@ -507,9 +532,7 @@ impl<'de> Deserializer<'de> for ValueDeserializer<'de> { where V: Visitor<'de>, { - visitor.visit_enum(EnumDeserializer { - value: self.value.clone().into_inner(), - }) + visitor.visit_enum(EnumDeserializer { value: self.value }) } fn deserialize_ignored_any(self, visitor: V) -> Result @@ -520,11 +543,11 @@ impl<'de> Deserializer<'de> for ValueDeserializer<'de> { } } -struct EnumDeserializer { - value: Arc, +struct EnumDeserializer<'de> { + value: &'de str, } -impl<'de> EnumAccess<'de> for EnumDeserializer { +impl<'de> EnumAccess<'de> for EnumDeserializer<'de> { type Error = PathDeserializationError; type Variant = UnitVariant; @@ -598,10 +621,7 @@ impl<'de> SeqAccess<'de> for SeqDeserializer<'de> { let idx = self.idx; self.idx += 1; Ok(Some(seed.deserialize(ValueDeserializer { - key: Some(KeyOrIdx::Idx { - idx, - key: key.clone(), - }), + key: Some(KeyOrIdx::Idx { idx, key }), value, })?)) } @@ -611,9 +631,18 @@ impl<'de> SeqAccess<'de> for SeqDeserializer<'de> { } #[derive(Debug, Clone)] -enum KeyOrIdx { - Key(Arc), - Idx { idx: usize, key: Arc }, +enum KeyOrIdx<'de> { + Key(&'de str), + Idx { idx: usize, key: &'de str }, +} + +impl<'de> KeyOrIdx<'de> { + fn key(&self) -> &'de str { + match &self { + Self::Key(key) => key, + Self::Idx { key, .. } => key, + } + } } #[cfg(test)] @@ -781,20 +810,6 @@ mod tests { ); } - #[test] - fn test_parse_tuple_ignoring_additional_fields() { - let url_params = create_url_params(vec![ - ("a", "abc"), - ("b", "true"), - ("c", "1"), - ("d", "false"), - ]); - assert_eq!( - <(&str, bool, u32)>::deserialize(PathDeserializer::new(&url_params)).unwrap(), - ("abc", true, 1) - ); - } - #[test] fn test_parse_map() { let url_params = create_url_params(vec![("a", "1"), ("b", "true"), ("c", "abc")]); @@ -821,6 +836,18 @@ mod tests { }; } + #[test] + fn test_parse_tuple_too_many_fields() { + test_parse_error!( + vec![("a", "abc"), ("b", "true"), ("c", "1"), ("d", "false"),], + (&str, bool, u32), + ErrorKind::WrongNumberOfParameters { + got: 4, + expected: 3, + } + ); + } + #[test] fn test_wrong_number_of_parameters_error() { test_parse_error!( @@ -936,4 +963,17 @@ mod tests { } ); } + + #[test] + fn test_deserialize_key_value() { + test_parse_error!( + vec![("id", "123123-123-123123")], + uuid::Uuid, + ErrorKind::DeserializeError { + key: "id".to_owned(), + value: "123123-123-123123".to_owned(), + message: "UUID parsing failed: invalid group count: expected 5, found 3".to_owned(), + } + ); + } } diff --git a/axum/src/extract/path/mod.rs b/axum/src/extract/path/mod.rs index 0be9008803..427db8f20d 100644 --- a/axum/src/extract/path/mod.rs +++ b/axum/src/extract/path/mod.rs @@ -8,7 +8,6 @@ use crate::{ routing::url_params::UrlParams, util::PercentDecodedStr, }; -use async_trait::async_trait; use axum_core::response::{IntoResponse, Response}; use http::{request::Parts, StatusCode}; use serde::de::DeserializeOwned; @@ -44,7 +43,7 @@ use std::{fmt, sync::Arc}; /// // ... /// } /// -/// let app = Router::new().route("/users/:user_id/team/:team_id", get(users_teams_show)); +/// let app = Router::new().route("/users/{user_id}/team/{team_id}", get(users_teams_show)); /// # let _: Router = app; /// ``` /// @@ -62,7 +61,7 @@ use std::{fmt, sync::Arc}; /// // ... /// } /// -/// let app = Router::new().route("/users/:user_id", get(user_info)); +/// let app = Router::new().route("/users/{user_id}", get(user_info)); /// # let _: Router = app; /// ``` /// @@ -99,7 +98,7 @@ use std::{fmt, sync::Arc}; /// } /// /// let app = Router::new().route( -/// "/users/:user_id/team/:team_id", +/// "/users/{user_id}/team/{team_id}", /// get(users_teams_show).post(users_teams_create), /// ); /// # let _: Router = app; @@ -128,7 +127,7 @@ use std::{fmt, sync::Arc}; /// } /// /// let app = Router::new() -/// .route("/users/:user_id/team/:team_id", get(params_map).post(params_vec)); +/// .route("/users/{user_id}/team/{team_id}", get(params_map).post(params_vec)); /// # let _: Router = app; /// ``` /// @@ -145,7 +144,6 @@ pub struct Path(pub T); axum_core::__impl_deref!(Path); -#[async_trait] impl FromRequestParts for Path where T: DeserializeOwned + Send, @@ -305,6 +303,16 @@ pub enum ErrorKind { name: &'static str, }, + /// Failed to deserialize the value with a custom deserialization error. + DeserializeError { + /// The key at which the invalid value was located. + key: String, + /// The value that failed to deserialize. + value: String, + /// The deserializaation failure message. + message: String, + }, + /// Catch-all variant for errors that don't fit any other variant. Message(String), } @@ -333,20 +341,25 @@ impl fmt::Display for ErrorKind { expected_type, } => write!( f, - "Cannot parse `{key}` with value `{value:?}` to a `{expected_type}`" + "Cannot parse `{key}` with value `{value}` to a `{expected_type}`" ), ErrorKind::ParseError { value, expected_type, - } => write!(f, "Cannot parse `{value:?}` to a `{expected_type}`"), + } => write!(f, "Cannot parse `{value}` to a `{expected_type}`"), ErrorKind::ParseErrorAtIndex { index, value, expected_type, } => write!( f, - "Cannot parse value at index {index} with value `{value:?}` to a `{expected_type}`" + "Cannot parse value at index {index} with value `{value}` to a `{expected_type}`" ), + ErrorKind::DeserializeError { + key, + value, + message, + } => write!(f, "Cannot parse `{key}` with value `{value}`: {message}"), } } } @@ -371,6 +384,7 @@ impl FailedToDeserializePathParams { pub fn body_text(&self) -> String { match self.0.kind { ErrorKind::Message(_) + | ErrorKind::DeserializeError { .. } | ErrorKind::InvalidUtf8InPathParam { .. } | ErrorKind::ParseError { .. } | ErrorKind::ParseErrorAtIndex { .. } @@ -385,6 +399,7 @@ impl FailedToDeserializePathParams { pub fn status(&self) -> StatusCode { match self.0.kind { ErrorKind::Message(_) + | ErrorKind::DeserializeError { .. } | ErrorKind::InvalidUtf8InPathParam { .. } | ErrorKind::ParseError { .. } | ErrorKind::ParseErrorAtIndex { .. } @@ -398,12 +413,13 @@ impl FailedToDeserializePathParams { impl IntoResponse for FailedToDeserializePathParams { fn into_response(self) -> Response { + let body = self.body_text(); axum_core::__log_rejection!( rejection_type = Self, - body_text = self.body_text(), + body_text = body, status = self.status(), ); - (self.status(), self.body_text()).into_response() + (self.status(), body).into_response() } } @@ -439,13 +455,12 @@ impl std::error::Error for FailedToDeserializePathParams {} /// } /// } /// -/// let app = Router::new().route("/users/:user_id/team/:team_id", get(users_teams_show)); +/// let app = Router::new().route("/users/{user_id}/team/{team_id}", get(users_teams_show)); /// # let _: Router = app; /// ``` #[derive(Debug)] pub struct RawPathParams(Vec<(Arc, PercentDecodedStr)>); -#[async_trait] impl FromRequestParts for RawPathParams where S: Send + Sync, @@ -530,7 +545,13 @@ impl std::error::Error for InvalidUtf8InPathParam {} impl IntoResponse for InvalidUtf8InPathParam { fn into_response(self) -> Response { - (self.status(), self.body_text()).into_response() + let body = self.body_text(); + axum_core::__log_rejection!( + rejection_type = Self, + body_text = body, + status = self.status(), + ); + (self.status(), body).into_response() } } @@ -538,14 +559,13 @@ impl IntoResponse for InvalidUtf8InPathParam { mod tests { use super::*; use crate::{routing::get, test_helpers::*, Router}; - use http::StatusCode; use serde::Deserialize; use std::collections::HashMap; #[crate::test] async fn extracting_url_params() { let app = Router::new().route( - "/users/:id", + "/users/{id}", get(|Path(id): Path| async move { assert_eq!(id, 42); }) @@ -556,33 +576,33 @@ mod tests { let client = TestClient::new(app); - let res = client.get("/users/42").send().await; + let res = client.get("/users/42").await; assert_eq!(res.status(), StatusCode::OK); - let res = client.post("/users/1337").send().await; + let res = client.post("/users/1337").await; assert_eq!(res.status(), StatusCode::OK); } #[crate::test] async fn extracting_url_params_multiple_times() { - let app = Router::new().route("/users/:id", get(|_: Path, _: Path| async {})); + let app = Router::new().route("/users/{id}", get(|_: Path, _: Path| async {})); let client = TestClient::new(app); - let res = client.get("/users/42").send().await; + let res = client.get("/users/42").await; assert_eq!(res.status(), StatusCode::OK); } #[crate::test] async fn percent_decoding() { let app = Router::new().route( - "/:key", + "/{key}", get(|Path(param): Path| async move { param }), ); let client = TestClient::new(app); - let res = client.get("/one%20two").send().await; + let res = client.get("/one%20two").await; assert_eq!(res.text().await, "one two"); } @@ -591,20 +611,20 @@ mod tests { async fn supports_128_bit_numbers() { let app = Router::new() .route( - "/i/:key", + "/i/{key}", get(|Path(param): Path| async move { param.to_string() }), ) .route( - "/u/:key", + "/u/{key}", get(|Path(param): Path| async move { param.to_string() }), ); let client = TestClient::new(app); - let res = client.get("/i/123").send().await; + let res = client.get("/i/123").await; assert_eq!(res.text().await, "123"); - let res = client.get("/u/123").send().await; + let res = client.get("/u/123").await; assert_eq!(res.text().await, "123"); } @@ -612,11 +632,11 @@ mod tests { async fn wildcard() { let app = Router::new() .route( - "/foo/*rest", + "/foo/{*rest}", get(|Path(param): Path| async move { param }), ) .route( - "/bar/*rest", + "/bar/{*rest}", get(|Path(params): Path>| async move { params.get("rest").unwrap().clone() }), @@ -624,80 +644,80 @@ mod tests { let client = TestClient::new(app); - let res = client.get("/foo/bar/baz").send().await; + let res = client.get("/foo/bar/baz").await; assert_eq!(res.text().await, "bar/baz"); - let res = client.get("/bar/baz/qux").send().await; + let res = client.get("/bar/baz/qux").await; assert_eq!(res.text().await, "baz/qux"); } #[crate::test] async fn captures_dont_match_empty_path() { - let app = Router::new().route("/:key", get(|| async {})); + let app = Router::new().route("/{key}", get(|| async {})); let client = TestClient::new(app); - let res = client.get("/").send().await; + let res = client.get("/").await; assert_eq!(res.status(), StatusCode::NOT_FOUND); - let res = client.get("/foo").send().await; + let res = client.get("/foo").await; assert_eq!(res.status(), StatusCode::OK); } #[crate::test] async fn captures_match_empty_inner_segments() { let app = Router::new().route( - "/:key/method", + "/{key}/method", get(|Path(param): Path| async move { param.to_string() }), ); let client = TestClient::new(app); - let res = client.get("/abc/method").send().await; + let res = client.get("/abc/method").await; assert_eq!(res.text().await, "abc"); - let res = client.get("//method").send().await; + let res = client.get("//method").await; assert_eq!(res.text().await, ""); } #[crate::test] async fn captures_match_empty_inner_segments_near_end() { let app = Router::new().route( - "/method/:key/", + "/method/{key}/", get(|Path(param): Path| async move { param.to_string() }), ); let client = TestClient::new(app); - let res = client.get("/method/abc").send().await; + let res = client.get("/method/abc").await; assert_eq!(res.status(), StatusCode::NOT_FOUND); - let res = client.get("/method/abc/").send().await; + let res = client.get("/method/abc/").await; assert_eq!(res.text().await, "abc"); - let res = client.get("/method//").send().await; + let res = client.get("/method//").await; assert_eq!(res.text().await, ""); } #[crate::test] async fn captures_match_empty_trailing_segment() { let app = Router::new().route( - "/method/:key", + "/method/{key}", get(|Path(param): Path| async move { param.to_string() }), ); let client = TestClient::new(app); - let res = client.get("/method/abc/").send().await; + let res = client.get("/method/abc/").await; assert_eq!(res.status(), StatusCode::NOT_FOUND); - let res = client.get("/method/abc").send().await; + let res = client.get("/method/abc").await; assert_eq!(res.text().await, "abc"); - let res = client.get("/method/").send().await; + let res = client.get("/method/").await; assert_eq!(res.text().await, ""); - let res = client.get("/method").send().await; + let res = client.get("/method").await; assert_eq!(res.status(), StatusCode::NOT_FOUND); } @@ -714,25 +734,28 @@ mod tests { } } - let app = Router::new().route("/:key", get(|param: Path| async move { param.0 .0 })); + let app = Router::new().route( + "/{key}", + get(|param: Path| async move { param.0 .0 }), + ); let client = TestClient::new(app); - let res = client.get("/foo").send().await; + let res = client.get("/foo").await; assert_eq!(res.text().await, "foo"); // percent decoding should also work - let res = client.get("/foo%20bar").send().await; + let res = client.get("/foo%20bar").await; assert_eq!(res.text().await, "foo bar"); } #[crate::test] async fn two_path_extractors() { - let app = Router::new().route("/:a/:b", get(|_: Path, _: Path| async {})); + let app = Router::new().route("/{a}/{b}", get(|_: Path, _: Path| async {})); let client = TestClient::new(app); - let res = client.get("/a/b").send().await; + let res = client.get("/a/b").await; assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR); assert_eq!( res.text().await, @@ -741,10 +764,40 @@ mod tests { ); } + #[crate::test] + async fn tuple_param_matches_exactly() { + #[allow(dead_code)] + #[derive(Deserialize)] + struct Tuple(String, String); + + let app = Router::new() + .route( + "/foo/{a}/{b}/{c}", + get(|_: Path<(String, String)>| async {}), + ) + .route("/bar/{a}/{b}/{c}", get(|_: Path| async {})); + + let client = TestClient::new(app); + + let res = client.get("/foo/a/b/c").await; + assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR); + assert_eq!( + res.text().await, + "Wrong number of path arguments for `Path`. Expected 2 but got 3", + ); + + let res = client.get("/bar/a/b/c").await; + assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR); + assert_eq!( + res.text().await, + "Wrong number of path arguments for `Path`. Expected 2 but got 3", + ); + } + #[crate::test] async fn deserialize_into_vec_of_tuples() { let app = Router::new().route( - "/:a/:b", + "/{a}/{b}", get(|Path(params): Path>| async move { assert_eq!( params, @@ -758,7 +811,7 @@ mod tests { let client = TestClient::new(app); - let res = client.get("/foo/bar").send().await; + let res = client.get("/foo/bar").await; assert_eq!(res.status(), StatusCode::OK); } @@ -775,31 +828,31 @@ mod tests { let app = Router::new() .route( - "/single/:a", + "/single/{a}", get(|Path(a): Path| async move { format!("single: {a}") }), ) .route( - "/tuple/:a/:b/:c", + "/tuple/{a}/{b}/{c}", get(|Path((a, b, c)): Path<(Date, Date, Date)>| async move { format!("tuple: {a} {b} {c}") }), ) .route( - "/vec/:a/:b/:c", + "/vec/{a}/{b}/{c}", get(|Path(vec): Path>| async move { let [a, b, c]: [Date; 3] = vec.try_into().unwrap(); format!("vec: {a} {b} {c}") }), ) .route( - "/vec_pairs/:a/:b/:c", + "/vec_pairs/{a}/{b}/{c}", get(|Path(vec): Path>| async move { let [(_, a), (_, b), (_, c)]: [(String, Date); 3] = vec.try_into().unwrap(); format!("vec_pairs: {a} {b} {c}") }), ) .route( - "/map/:a/:b/:c", + "/map/{a}/{b}/{c}", get(|Path(mut map): Path>| async move { let a = map.remove("a").unwrap(); let b = map.remove("b").unwrap(); @@ -808,7 +861,7 @@ mod tests { }), ) .route( - "/struct/:a/:b/:c", + "/struct/{a}/{b}/{c}", get(|Path(params): Path| async move { format!("struct: {} {} {}", params.a, params.b, params.c) }), @@ -816,40 +869,27 @@ mod tests { let client = TestClient::new(app); - let res = client.get("/single/2023-01-01").send().await; + let res = client.get("/single/2023-01-01").await; assert_eq!(res.text().await, "single: 2023-01-01"); - let res = client - .get("/tuple/2023-01-01/2023-01-02/2023-01-03") - .send() - .await; + let res = client.get("/tuple/2023-01-01/2023-01-02/2023-01-03").await; assert_eq!(res.text().await, "tuple: 2023-01-01 2023-01-02 2023-01-03"); - let res = client - .get("/vec/2023-01-01/2023-01-02/2023-01-03") - .send() - .await; + let res = client.get("/vec/2023-01-01/2023-01-02/2023-01-03").await; assert_eq!(res.text().await, "vec: 2023-01-01 2023-01-02 2023-01-03"); let res = client .get("/vec_pairs/2023-01-01/2023-01-02/2023-01-03") - .send() .await; assert_eq!( res.text().await, "vec_pairs: 2023-01-01 2023-01-02 2023-01-03", ); - let res = client - .get("/map/2023-01-01/2023-01-02/2023-01-03") - .send() - .await; + let res = client.get("/map/2023-01-01/2023-01-02/2023-01-03").await; assert_eq!(res.text().await, "map: 2023-01-01 2023-01-02 2023-01-03"); - let res = client - .get("/struct/2023-01-01/2023-01-02/2023-01-03") - .send() - .await; + let res = client.get("/struct/2023-01-01/2023-01-02/2023-01-03").await; assert_eq!(res.text().await, "struct: 2023-01-01 2023-01-02 2023-01-03"); } @@ -858,18 +898,18 @@ mod tests { use serde_json::Value; let app = Router::new() - .route("/one/:a", get(|_: Path<(Value, Value)>| async {})) - .route("/two/:a/:b", get(|_: Path| async {})); + .route("/one/{a}", get(|_: Path<(Value, Value)>| async {})) + .route("/two/{a}/{b}", get(|_: Path| async {})); let client = TestClient::new(app); - let res = client.get("/one/1").send().await; + let res = client.get("/one/1").await; assert!(res .text() .await .starts_with("Wrong number of path arguments for `Path`. Expected 2 but got 1")); - let res = client.get("/two/1/2").send().await; + let res = client.get("/two/1/2").await; assert!(res .text() .await @@ -879,7 +919,7 @@ mod tests { #[crate::test] async fn raw_path_params() { let app = Router::new().route( - "/:a/:b/:c", + "/{a}/{b}/{c}", get(|params: RawPathParams| async move { params .into_iter() @@ -890,8 +930,68 @@ mod tests { ); let client = TestClient::new(app); - let res = client.get("/foo/bar/baz").send().await; + let res = client.get("/foo/bar/baz").await; let body = res.text().await; assert_eq!(body, "a=foo b=bar c=baz"); } + + #[crate::test] + async fn deserialize_error_single_value() { + let app = Router::new().route( + "/resources/{res}", + get(|res: Path| async move { + let _res = res; + }), + ); + + let client = TestClient::new(app); + let res = client.get("/resources/123123-123-123123").await; + let body = res.text().await; + assert_eq!( + body, + r#"Invalid URL: Cannot parse `res` with value `123123-123-123123`: UUID parsing failed: invalid group count: expected 5, found 3"# + ); + } + + #[crate::test] + async fn deserialize_error_multi_value() { + let app = Router::new().route( + "/resources/{res}/sub/{sub}", + get( + |Path((res, sub)): Path<(uuid::Uuid, uuid::Uuid)>| async move { + let _res = res; + let _sub = sub; + }, + ), + ); + + let client = TestClient::new(app); + let res = client.get("/resources/456456-123-456456/sub/123").await; + let body = res.text().await; + assert_eq!( + body, + r#"Invalid URL: Cannot parse `res` with value `456456-123-456456`: UUID parsing failed: invalid group count: expected 5, found 3"# + ); + } + + #[crate::test] + async fn regression_3038() { + #[derive(Deserialize)] + #[allow(dead_code)] + struct MoreChars { + first_two: [char; 2], + second_two: [char; 2], + crate_name: String, + } + + let app = Router::new().route( + "/{first_two}/{second_two}/{crate_name}", + get(|Path(_): Path| async move {}), + ); + + let client = TestClient::new(app); + let res = client.get("/te/st/_thing").await; + let body = res.text().await; + assert_eq!(body, r#"Invalid URL: array types are not supported"#); + } } diff --git a/axum/src/extract/query.rs b/axum/src/extract/query.rs index 37a40771c2..371612b71a 100644 --- a/axum/src/extract/query.rs +++ b/axum/src/extract/query.rs @@ -1,5 +1,4 @@ use super::{rejection::*, FromRequestParts}; -use async_trait::async_trait; use http::{request::Parts, Uri}; use serde::de::DeserializeOwned; @@ -42,11 +41,15 @@ use serde::de::DeserializeOwned; /// example. /// /// [example]: https://github.com/tokio-rs/axum/blob/main/examples/query-params-with-empty-strings/src/main.rs +/// +/// For handling multiple values for the same query parameter, in a `?foo=1&foo=2&foo=3` +/// fashion, use [`axum_extra::extract::Query`] instead. +/// +/// [`axum_extra::extract::Query`]: https://docs.rs/axum-extra/latest/axum_extra/extract/struct.Query.html #[cfg_attr(docsrs, doc(cfg(feature = "query")))] #[derive(Debug, Clone, Copy, Default)] pub struct Query(pub T); -#[async_trait] impl FromRequestParts for Query where T: DeserializeOwned, @@ -162,7 +165,7 @@ mod tests { let app = Router::new().route("/", get(handler)); let client = TestClient::new(app); - let res = client.get("/?n=hi").send().await; + let res = client.get("/?n=hi").await; assert_eq!(res.status(), StatusCode::BAD_REQUEST); } diff --git a/axum/src/extract/raw_form.rs b/axum/src/extract/raw_form.rs index a4e0d6c57c..29cb4c6dd3 100644 --- a/axum/src/extract/raw_form.rs +++ b/axum/src/extract/raw_form.rs @@ -1,4 +1,3 @@ -use async_trait::async_trait; use axum_core::extract::{FromRequest, Request}; use bytes::Bytes; use http::Method; @@ -30,7 +29,6 @@ use super::{ #[derive(Debug)] pub struct RawForm(pub Bytes); -#[async_trait] impl FromRequest for RawForm where S: Send + Sync, diff --git a/axum/src/extract/raw_query.rs b/axum/src/extract/raw_query.rs index d8c56f84a4..c792960a1b 100644 --- a/axum/src/extract/raw_query.rs +++ b/axum/src/extract/raw_query.rs @@ -1,5 +1,4 @@ use super::FromRequestParts; -use async_trait::async_trait; use http::request::Parts; use std::convert::Infallible; @@ -25,7 +24,6 @@ use std::convert::Infallible; #[derive(Debug)] pub struct RawQuery(pub Option); -#[async_trait] impl FromRequestParts for RawQuery where S: Send + Sync, diff --git a/axum/src/extract/rejection.rs b/axum/src/extract/rejection.rs index cba76af054..cd49cc78b1 100644 --- a/axum/src/extract/rejection.rs +++ b/axum/src/extract/rejection.rs @@ -65,14 +65,6 @@ define_rejection! { pub struct InvalidFormContentType; } -define_rejection! { - #[status = BAD_REQUEST] - #[body = "No host found in request"] - /// Rejection type used if the [`Host`](super::Host) extractor is unable to - /// resolve a host. - pub struct FailedToResolveHost; -} - define_rejection! { #[status = BAD_REQUEST] #[body = "Failed to deserialize form"] @@ -178,16 +170,6 @@ composite_rejection! { } } -composite_rejection! { - /// Rejection used for [`Host`](super::Host). - /// - /// Contains one variant for each way the [`Host`](super::Host) extractor - /// can fail. - pub enum HostRejection { - FailedToResolveHost, - } -} - #[cfg(feature = "matched-path")] define_rejection! { #[status = INTERNAL_SERVER_ERROR] diff --git a/axum/src/extract/request_parts.rs b/axum/src/extract/request_parts.rs index 9756665b6c..6d9adc672c 100644 --- a/axum/src/extract/request_parts.rs +++ b/axum/src/extract/request_parts.rs @@ -1,5 +1,4 @@ use super::{Extension, FromRequestParts}; -use async_trait::async_trait; use http::{request::Parts, Uri}; use std::convert::Infallible; @@ -47,7 +46,7 @@ use std::convert::Infallible; /// use tower_http::trace::TraceLayer; /// /// let api_routes = Router::new() -/// .route("/users/:id", get(|| async { /* ... */ })) +/// .route("/users/{id}", get(|| async { /* ... */ })) /// .layer( /// TraceLayer::new_for_http().make_span_with(|req: &Request<_>| { /// let path = if let Some(path) = req.extensions().get::() { @@ -70,7 +69,6 @@ use std::convert::Infallible; pub struct OriginalUri(pub Uri); #[cfg(feature = "original-uri")] -#[async_trait] impl FromRequestParts for OriginalUri where S: Send + Sync, @@ -109,7 +107,7 @@ mod tests { let client = TestClient::new(Router::new().route("/", get(handler)).layer(Extension(Ext))); - let res = client.get("/").header("x-foo", "123").send().await; + let res = client.get("/").header("x-foo", "123").await; assert_eq!(res.status(), StatusCode::OK); } } diff --git a/axum/src/extract/state.rs b/axum/src/extract/state.rs index fb401c00d8..b95deb39bb 100644 --- a/axum/src/extract/state.rs +++ b/axum/src/extract/state.rs @@ -1,4 +1,3 @@ -use async_trait::async_trait; use axum_core::extract::{FromRef, FromRequestParts}; use http::request::Parts; use std::{ @@ -11,7 +10,11 @@ use std::{ /// See ["Accessing state in middleware"][state-from-middleware] for how to /// access state in middleware. /// +/// State is global and used in every request a router with state receives. +/// For accessing data derived from requests, such as authorization data, see [`Extension`]. +/// /// [state-from-middleware]: crate::middleware#accessing-state-in-middleware +/// [`Extension`]: crate::Extension /// /// # With `Router` /// @@ -22,9 +25,6 @@ use std::{ /// // /// // here you can put configuration, database connection pools, or whatever /// // state you need -/// // -/// // see "When states need to implement `Clone`" for more details on why we need -/// // `#[derive(Clone)]` here. /// #[derive(Clone)] /// struct AppState {} /// @@ -219,13 +219,11 @@ use std::{ /// ```rust /// use axum_core::extract::{FromRequestParts, FromRef}; /// use http::request::Parts; -/// use async_trait::async_trait; /// use std::convert::Infallible; /// /// // the extractor your library provides /// struct MyLibraryExtractor; /// -/// #[async_trait] /// impl FromRequestParts for MyLibraryExtractor /// where /// // keep `S` generic but require that it can produce a `MyLibraryState` @@ -250,53 +248,6 @@ use std::{ /// } /// ``` /// -/// # When states need to implement `Clone` -/// -/// Your top level state type must implement `Clone` to be extractable with `State`: -/// -/// ``` -/// use axum::extract::State; -/// -/// // no substates, so to extract to `State` we must implement `Clone` for `AppState` -/// #[derive(Clone)] -/// struct AppState {} -/// -/// async fn handler(State(state): State) { -/// // ... -/// } -/// ``` -/// -/// This works because of [`impl FromRef for S where S: Clone`][`FromRef`]. -/// -/// This is also true if you're extracting substates, unless you _never_ extract the top level -/// state itself: -/// -/// ``` -/// use axum::extract::{State, FromRef}; -/// -/// // we never extract `State`, just `State`. So `AppState` doesn't need to -/// // implement `Clone` -/// struct AppState { -/// inner: InnerState, -/// } -/// -/// #[derive(Clone)] -/// struct InnerState {} -/// -/// impl FromRef for InnerState { -/// fn from_ref(app_state: &AppState) -> InnerState { -/// app_state.inner.clone() -/// } -/// } -/// -/// async fn api_users(State(inner): State) { -/// // ... -/// } -/// ``` -/// -/// In general however we recommend you implement `Clone` for all your state types to avoid -/// potential type errors. -/// /// # Shared mutable state /// /// [As state is global within a `Router`][global] you can't directly get a mutable reference to @@ -344,7 +295,6 @@ use std::{ #[derive(Debug, Default, Clone, Copy)] pub struct State(pub S); -#[async_trait] impl FromRequestParts for State where InnerState: FromRef, diff --git a/axum/src/extract/ws.rs b/axum/src/extract/ws.rs index 7781d6c2aa..6bf3b44628 100644 --- a/axum/src/extract/ws.rs +++ b/axum/src/extract/ws.rs @@ -5,12 +5,12 @@ //! ``` //! use axum::{ //! extract::ws::{WebSocketUpgrade, WebSocket}, -//! routing::get, +//! routing::any, //! response::{IntoResponse, Response}, //! Router, //! }; //! -//! let app = Router::new().route("/ws", get(handler)); +//! let app = Router::new().route("/ws", any(handler)); //! //! async fn handler(ws: WebSocketUpgrade) -> Response { //! ws.on_upgrade(handle_socket) @@ -40,7 +40,7 @@ //! use axum::{ //! extract::{ws::{WebSocketUpgrade, WebSocket}, State}, //! response::Response, -//! routing::get, +//! routing::any, //! Router, //! }; //! @@ -58,7 +58,7 @@ //! } //! //! let app = Router::new() -//! .route("/ws", get(handler)) +//! .route("/ws", any(handler)) //! .with_state(AppState { /* ... */ }); //! # let _: Router = app; //! ``` @@ -93,7 +93,6 @@ use self::rejection::*; use super::FromRequestParts; use crate::{body::Bytes, response::Response, Error}; -use async_trait::async_trait; use axum_core::body::Body; use futures_util::{ sink::{Sink, SinkExt}, @@ -102,7 +101,7 @@ use futures_util::{ use http::{ header::{self, HeaderMap, HeaderName, HeaderValue}, request::Parts, - Method, StatusCode, + Method, StatusCode, Version, }; use hyper_util::rt::TokioIo; use sha1::{Digest, Sha1}; @@ -122,17 +121,20 @@ use tokio_tungstenite::{ /// Extractor for establishing WebSocket connections. /// -/// Note: This extractor requires the request method to be `GET` so it should -/// always be used with [`get`](crate::routing::get). Requests with other methods will be -/// rejected. +/// For HTTP/1.1 requests, this extractor requires the request method to be `GET`; +/// in later versions, `CONNECT` is used instead. +/// To support both, it should be used with [`any`](crate::routing::any). /// /// See the [module docs](self) for an example. +/// +/// [`MethodFilter`]: crate::routing::MethodFilter #[cfg_attr(docsrs, doc(cfg(feature = "ws")))] pub struct WebSocketUpgrade { config: WebSocketConfig, /// The chosen protocol sent in the `Sec-WebSocket-Protocol` header of the response. protocol: Option, - sec_websocket_key: HeaderValue, + /// `None` if HTTP/2+ WebSockets are used. + sec_websocket_key: Option, on_upgrade: hyper::upgrade::OnUpgrade, on_failed_upgrade: F, sec_websocket_protocol: Option, @@ -213,12 +215,12 @@ impl WebSocketUpgrade { /// ``` /// use axum::{ /// extract::ws::{WebSocketUpgrade, WebSocket}, - /// routing::get, + /// routing::any, /// response::{IntoResponse, Response}, /// Router, /// }; /// - /// let app = Router::new().route("/ws", get(handler)); + /// let app = Router::new().route("/ws", any(handler)); /// /// async fn handler(ws: WebSocketUpgrade) -> Response { /// ws.protocols(["graphql-ws", "graphql-transport-ws"]) @@ -330,25 +332,34 @@ impl WebSocketUpgrade { callback(socket).await; }); - #[allow(clippy::declare_interior_mutable_const)] - const UPGRADE: HeaderValue = HeaderValue::from_static("upgrade"); - #[allow(clippy::declare_interior_mutable_const)] - const WEBSOCKET: HeaderValue = HeaderValue::from_static("websocket"); - - let mut builder = Response::builder() - .status(StatusCode::SWITCHING_PROTOCOLS) - .header(header::CONNECTION, UPGRADE) - .header(header::UPGRADE, WEBSOCKET) - .header( - header::SEC_WEBSOCKET_ACCEPT, - sign(self.sec_websocket_key.as_bytes()), - ); - - if let Some(protocol) = self.protocol { - builder = builder.header(header::SEC_WEBSOCKET_PROTOCOL, protocol); - } + if let Some(sec_websocket_key) = &self.sec_websocket_key { + // If `sec_websocket_key` was `Some`, we are using HTTP/1.1. + + #[allow(clippy::declare_interior_mutable_const)] + const UPGRADE: HeaderValue = HeaderValue::from_static("upgrade"); + #[allow(clippy::declare_interior_mutable_const)] + const WEBSOCKET: HeaderValue = HeaderValue::from_static("websocket"); + + let mut builder = Response::builder() + .status(StatusCode::SWITCHING_PROTOCOLS) + .header(header::CONNECTION, UPGRADE) + .header(header::UPGRADE, WEBSOCKET) + .header( + header::SEC_WEBSOCKET_ACCEPT, + sign(sec_websocket_key.as_bytes()), + ); + + if let Some(protocol) = self.protocol { + builder = builder.header(header::SEC_WEBSOCKET_PROTOCOL, protocol); + } - builder.body(Body::empty()).unwrap() + builder.body(Body::empty()).unwrap() + } else { + // Otherwise, we are HTTP/2+. As established in RFC 9113 section 8.5, we just respond + // with a 2XX with an empty body: + // . + Response::new(Body::empty()) + } } } @@ -381,7 +392,6 @@ impl OnFailedUpgrade for DefaultOnFailedUpgrade { fn call(self, _error: Error) {} } -#[async_trait] impl FromRequestParts for WebSocketUpgrade where S: Send + Sync, @@ -389,28 +399,49 @@ where type Rejection = WebSocketUpgradeRejection; async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { - if parts.method != Method::GET { - return Err(MethodNotGet.into()); - } + let sec_websocket_key = if parts.version <= Version::HTTP_11 { + if parts.method != Method::GET { + return Err(MethodNotGet.into()); + } - if !header_contains(&parts.headers, header::CONNECTION, "upgrade") { - return Err(InvalidConnectionHeader.into()); - } + if !header_contains(&parts.headers, header::CONNECTION, "upgrade") { + return Err(InvalidConnectionHeader.into()); + } - if !header_eq(&parts.headers, header::UPGRADE, "websocket") { - return Err(InvalidUpgradeHeader.into()); - } + if !header_eq(&parts.headers, header::UPGRADE, "websocket") { + return Err(InvalidUpgradeHeader.into()); + } + + Some( + parts + .headers + .get(header::SEC_WEBSOCKET_KEY) + .ok_or(WebSocketKeyHeaderMissing)? + .clone(), + ) + } else { + if parts.method != Method::CONNECT { + return Err(MethodNotConnect.into()); + } + + // if this feature flag is disabled, we won’t be receiving an HTTP/2 request to begin + // with. + #[cfg(feature = "http2")] + if parts + .extensions + .get::() + .map_or(true, |p| p.as_str() != "websocket") + { + return Err(InvalidProtocolPseudoheader.into()); + } + + None + }; if !header_eq(&parts.headers, header::SEC_WEBSOCKET_VERSION, "13") { return Err(InvalidWebSocketVersionHeader.into()); } - let sec_websocket_key = parts - .headers - .get(header::SEC_WEBSOCKET_KEY) - .ok_or(WebSocketKeyHeaderMissing)? - .clone(); - let on_upgrade = parts .extensions .remove::() @@ -476,11 +507,6 @@ impl WebSocket { .map_err(Error::new) } - /// Gracefully close this WebSocket. - pub async fn close(mut self) -> Result<(), Error> { - self.inner.close(None).await.map_err(Error::new) - } - /// Return the selected WebSocket subprotocol, if one has been chosen. pub fn protocol(&self) -> Option<&HeaderValue> { self.protocol.as_ref() @@ -584,6 +610,24 @@ pub enum Message { /// [unidirectional heartbeat](https://tools.ietf.org/html/rfc6455#section-5.5.3). Pong(Vec), /// A close message with the optional close frame. + /// + /// You may "uncleanly" close a WebSocket connection at any time + /// by simply dropping the [`WebSocket`]. + /// However, you may also use the graceful closing protocol, in which + /// 1. peer A sends a close frame, and does not send any further messages; + /// 2. peer B responds with a close frame, and does not send any further messages; + /// 3. peer A processes the remaining messages sent by peer B, before finally + /// 4. both peers close the connection. + /// + /// After sending a close frame, + /// you may still read messages, + /// but attempts to send another message will error. + /// After receiving a close frame, + /// axum will automatically respond with a close frame if necessary + /// (you do not have to deal with this yourself). + /// Since no further messages will be received, + /// you may either do nothing + /// or explicitly drop the connection. Close(Option>), } @@ -708,6 +752,13 @@ pub mod rejection { pub struct MethodNotGet; } + define_rejection! { + #[status = METHOD_NOT_ALLOWED] + #[body = "Request method must be `CONNECT`"] + /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade). + pub struct MethodNotConnect; + } + define_rejection! { #[status = BAD_REQUEST] #[body = "Connection header did not include 'upgrade'"] @@ -722,6 +773,13 @@ pub mod rejection { pub struct InvalidUpgradeHeader; } + define_rejection! { + #[status = BAD_REQUEST] + #[body = "`:protocol` pseudo-header did not include 'websocket'"] + /// Rejection type for [`WebSocketUpgrade`](super::WebSocketUpgrade). + pub struct InvalidProtocolPseudoheader; + } + define_rejection! { #[status = BAD_REQUEST] #[body = "`Sec-WebSocket-Version` header did not include '13'"] @@ -757,8 +815,10 @@ pub mod rejection { /// extractor can fail. pub enum WebSocketUpgradeRejection { MethodNotGet, + MethodNotConnect, InvalidConnectionHeader, InvalidUpgradeHeader, + InvalidProtocolPseudoheader, InvalidWebSocketVersionHeader, WebSocketKeyHeaderMissing, ConnectionNotUpgradable, @@ -783,8 +843,9 @@ pub mod close_code { pub const PROTOCOL: u16 = 1002; /// Indicates that an endpoint is terminating the connection because it has received a type of - /// data it cannot accept (e.g., an endpoint that understands only text data MAY send this if - /// it receives a binary message). + /// data that it cannot accept. + /// + /// For example, an endpoint MAY send this if it understands only text data, but receives a binary message. pub const UNSUPPORTED: u16 = 1003; /// Indicates that no status code was included in a closing frame. @@ -794,12 +855,15 @@ pub mod close_code { pub const ABNORMAL: u16 = 1006; /// Indicates that an endpoint is terminating the connection because it has received data - /// within a message that was not consistent with the type of the message (e.g., non-UTF-8 - /// RFC3629 data within a text message). + /// within a message that was not consistent with the type of the message. + /// + /// For example, an endpoint received non-UTF-8 RFC3629 data within a text message. pub const INVALID: u16 = 1007; /// Indicates that an endpoint is terminating the connection because it has received a message - /// that violates its policy. This is a generic status code that can be returned when there is + /// that violates its policy. + /// + /// This is a generic status code that can be returned when there is /// no other more suitable status code (e.g., `UNSUPPORTED` or `SIZE`) or if there is a need to /// hide specific details about the policy. pub const POLICY: u16 = 1008; @@ -808,10 +872,13 @@ pub mod close_code { /// that is too big for it to process. pub const SIZE: u16 = 1009; - /// Indicates that an endpoint (client) is terminating the connection because it has expected - /// the server to negotiate one or more extension, but the server didn't return them in the - /// response message of the WebSocket handshake. The list of extensions that are needed should - /// be given as the reason for closing. Note that this status code is not used by the server, + /// Indicates that an endpoint (client) is terminating the connection because the server + /// did not respond to extension negotiation correctly. + /// + /// Specifically, the client has expected the server to negotiate one or more extension(s), + /// but the server didn't return them in the response message of the WebSocket handshake. + /// The list of extensions that are needed should be given as the reason for closing. + /// Note that this status code is not used by the server, /// because it can fail the WebSocket handshake instead. pub const EXTENSION: u16 = 1010; @@ -830,14 +897,21 @@ pub mod close_code { #[cfg(test)] mod tests { + use std::future::ready; + use super::*; - use crate::{body::Body, routing::get, Router}; + use crate::{routing::any, test_helpers::spawn_service, Router}; use http::{Request, Version}; + use http_body_util::BodyExt as _; + use hyper_util::rt::TokioExecutor; + use tokio::io::{AsyncRead, AsyncWrite}; + use tokio::net::TcpStream; + use tokio_tungstenite::tungstenite; use tower::ServiceExt; #[crate::test] async fn rejects_http_1_0_requests() { - let svc = get(|ws: Result| { + let svc = any(|ws: Result| { let rejection = ws.unwrap_err(); assert!(matches!( rejection, @@ -866,7 +940,7 @@ mod tests { async fn handler(ws: WebSocketUpgrade) -> Response { ws.on_upgrade(|_| async {}) } - let _: Router = Router::new().route("/", get(handler)); + let _: Router = Router::new().route("/", any(handler)); } #[allow(dead_code)] @@ -875,6 +949,96 @@ mod tests { ws.on_failed_upgrade(|_error: Error| println!("oops!")) .on_upgrade(|_| async {}) } - let _: Router = Router::new().route("/", get(handler)); + let _: Router = Router::new().route("/", any(handler)); + } + + #[crate::test] + async fn integration_test() { + let addr = spawn_service(echo_app()); + let (socket, _response) = tokio_tungstenite::connect_async(format!("ws://{addr}/echo")) + .await + .unwrap(); + test_echo_app(socket).await; + } + + #[crate::test] + #[cfg(feature = "http2")] + async fn http2() { + let addr = spawn_service(echo_app()); + let io = TokioIo::new(TcpStream::connect(addr).await.unwrap()); + let (mut send_request, conn) = + hyper::client::conn::http2::Builder::new(TokioExecutor::new()) + .handshake(io) + .await + .unwrap(); + + // Wait a little for the SETTINGS frame to go through… + for _ in 0..10 { + tokio::task::yield_now().await; + } + assert!(conn.is_extended_connect_protocol_enabled()); + tokio::spawn(async { + conn.await.unwrap(); + }); + + let req = Request::builder() + .method(Method::CONNECT) + .extension(hyper::ext::Protocol::from_static("websocket")) + .uri("/echo") + .header("sec-websocket-version", "13") + .header("Host", "server.example.com") + .body(Body::empty()) + .unwrap(); + + let response = send_request.send_request(req).await.unwrap(); + let status = response.status(); + if status != 200 { + let body = response.into_body().collect().await.unwrap().to_bytes(); + let body = std::str::from_utf8(&body).unwrap(); + panic!("response status was {}: {body}", status); + } + let upgraded = hyper::upgrade::on(response).await.unwrap(); + let upgraded = TokioIo::new(upgraded); + let socket = WebSocketStream::from_raw_socket(upgraded, protocol::Role::Client, None).await; + test_echo_app(socket).await; + } + + fn echo_app() -> Router { + async fn handle_socket(mut socket: WebSocket) { + while let Some(Ok(msg)) = socket.recv().await { + match msg { + Message::Text(_) | Message::Binary(_) | Message::Close(_) => { + if socket.send(msg).await.is_err() { + break; + } + } + Message::Ping(_) | Message::Pong(_) => { + // tungstenite will respond to pings automatically + } + } + } + } + + Router::new().route( + "/echo", + any(|ws: WebSocketUpgrade| ready(ws.on_upgrade(handle_socket))), + ) + } + + async fn test_echo_app(mut socket: WebSocketStream) { + let input = tungstenite::Message::Text("foobar".to_owned()); + socket.send(input.clone()).await.unwrap(); + let output = socket.next().await.unwrap().unwrap(); + assert_eq!(input, output); + + socket + .send(tungstenite::Message::Ping("ping".to_owned().into_bytes())) + .await + .unwrap(); + let output = socket.next().await.unwrap().unwrap(); + assert_eq!( + output, + tungstenite::Message::Pong("ping".to_owned().into_bytes()) + ); } } diff --git a/axum/src/form.rs b/axum/src/form.rs index b38363b4db..40b7585b56 100644 --- a/axum/src/form.rs +++ b/axum/src/form.rs @@ -1,6 +1,5 @@ use crate::extract::Request; use crate::extract::{rejection::*, FromRequest, RawForm}; -use async_trait::async_trait; use axum_core::response::{IntoResponse, IntoResponseFailed, Response}; use axum_core::RequestExt; use http::header::CONTENT_TYPE; @@ -72,7 +71,6 @@ use serde::Serialize; #[must_use] pub struct Form(pub T); -#[async_trait] impl FromRequest for Form where T: DeserializeOwned, @@ -257,14 +255,13 @@ mod tests { let client = TestClient::new(app); - let res = client.get("/?a=false").send().await; + let res = client.get("/?a=false").await; assert_eq!(res.status(), StatusCode::BAD_REQUEST); let res = client .post("/") .header(CONTENT_TYPE, APPLICATION_WWW_FORM_URLENCODED.as_ref()) .body("a=false") - .send() .await; assert_eq!(res.status(), StatusCode::UNPROCESSABLE_ENTITY); } diff --git a/axum/src/handler/mod.rs b/axum/src/handler/mod.rs index 41596dd001..4821fe1873 100644 --- a/axum/src/handler/mod.rs +++ b/axum/src/handler/mod.rs @@ -10,8 +10,8 @@ //! // Handler that immediately returns an empty `200 OK` response. //! async fn unit_handler() {} //! -//! // Handler that immediately returns an empty `200 OK` response with a plain -//! // text body. +//! // Handler that immediately returns a `200 OK` response with a plain text +//! // body. //! async fn string_handler() -> String { //! "Hello, World!".to_string() //! } @@ -37,10 +37,10 @@ //! in handlers. See those examples: //! //! * [`anyhow-error-response`][anyhow] for generic boxed errors -//! * [`error-handling-and-dependency-injection`][ehdi] for application-specific detailed errors +//! * [`error-handling`][error-handling] for application-specific detailed errors //! //! [anyhow]: https://github.com/tokio-rs/axum/blob/main/examples/anyhow-error-response/src/main.rs -//! [ehdi]: https://github.com/tokio-rs/axum/blob/main/examples/error-handling-and-dependency-injection/src/main.rs +//! [error-handling]: https://github.com/tokio-rs/axum/blob/main/examples/error-handling/src/main.rs //! #![doc = include_str!("../docs/debugging_handler_type_errors.md")] @@ -125,13 +125,13 @@ pub use self::service::HandlerService; /// ))); /// # let _: Router = app; /// ``` -#[cfg_attr( - nightly_error_messages, +#[rustversion::attr( + since(1.78), diagnostic::on_unimplemented( note = "Consider using `#[axum::debug_handler]` to improve the error message" ) )] -pub trait Handler: Clone + Send + Sized + 'static { +pub trait Handler: Clone + Send + Sync + Sized + 'static { /// The type of future calling this handler returns. type Future: Future + Send + 'static; @@ -192,7 +192,7 @@ pub trait Handler: Clone + Send + Sized + 'static { impl Handler<((),), S> for F where - F: FnOnce() -> Fut + Clone + Send + 'static, + F: FnOnce() -> Fut + Clone + Send + Sync + 'static, Fut: Future + Send, Res: IntoResponse, { @@ -210,7 +210,7 @@ macro_rules! impl_handler { #[allow(non_snake_case, unused_mut)] impl Handler<(M, $($ty,)* $last,), S> for F where - F: FnOnce($($ty,)* $last,) -> Fut + Clone + Send + 'static, + F: FnOnce($($ty,)* $last,) -> Fut + Clone + Send + Sync + 'static, Fut: Future + Send, S: Send + Sync + 'static, Res: IntoResponse, @@ -257,7 +257,7 @@ mod private { impl Handler for T where - T: IntoResponse + Clone + Send + 'static, + T: IntoResponse + Clone + Send + Sync + 'static, { type Future = std::future::Ready; @@ -302,7 +302,7 @@ where impl Handler for Layered where - L: Layer> + Clone + Send + 'static, + L: Layer> + Clone + Send + Sync + 'static, H: Handler, L::Service: Service + Clone + Send + 'static, >::Response: IntoResponse, @@ -403,7 +403,7 @@ mod tests { let client = TestClient::new(handle.into_service()); - let res = client.post("/").body("hi there!").send().await; + let res = client.post("/").body("hi there!").await; assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, "you said: hi there!"); } @@ -424,7 +424,7 @@ mod tests { .with_state("foo"); let client = TestClient::new(svc); - let res = client.get("/").send().await; + let res = client.get("/").await; assert_eq!(res.text().await, "foo"); } } diff --git a/axum/src/json.rs b/axum/src/json.rs index 422a186493..8f66e70469 100644 --- a/axum/src/json.rs +++ b/axum/src/json.rs @@ -1,6 +1,5 @@ use crate::extract::Request; use crate::extract::{rejection::*, FromRequest}; -use async_trait::async_trait; use axum_core::response::{IntoResponse, IntoResponseFailed, Response}; use bytes::{BufMut, Bytes, BytesMut}; use http::{ @@ -17,8 +16,7 @@ use serde::{de::DeserializeOwned, Serialize}; /// /// - The request doesn't have a `Content-Type: application/json` (or similar) header. /// - The body doesn't contain syntactically valid JSON. -/// - The body contains syntactically valid JSON, but it couldn't be deserialized into the target -/// type. +/// - The body contains syntactically valid JSON, but it couldn't be deserialized into the target type. /// - Buffering the request body fails. /// /// ⚠️ Since parsing JSON requires consuming the request body, the `Json` extractor must be @@ -56,6 +54,11 @@ use serde::{de::DeserializeOwned, Serialize}; /// When used as a response, it can serialize any type that implements [`serde::Serialize`] to /// `JSON`, and will automatically set `Content-Type: application/json` header. /// +/// If the [`Serialize`] implementation decides to fail +/// or if a map with non-string keys is used, +/// a 500 response will be issued +/// whose body is the error message in UTF-8. +/// /// # Response example /// /// ``` @@ -84,7 +87,7 @@ use serde::{de::DeserializeOwned, Serialize}; /// # unimplemented!() /// } /// -/// let app = Router::new().route("/users/:id", get(get_user)); +/// let app = Router::new().route("/users/{id}", get(get_user)); /// # let _: Router = app; /// ``` #[derive(Debug, Clone, Copy, Default)] @@ -92,7 +95,6 @@ use serde::{de::DeserializeOwned, Serialize}; #[must_use] pub struct Json(pub T); -#[async_trait] impl FromRequest for Json where T: DeserializeOwned, @@ -130,7 +132,7 @@ fn json_content_type(headers: &HeaderMap) -> bool { }; let is_json_content_type = mime.type_() == "application" - && (mime.subtype() == "json" || mime.suffix().map_or(false, |name| name == "json")); + && (mime.subtype() == "json" || mime.suffix().is_some_and(|name| name == "json")); is_json_content_type } @@ -227,7 +229,7 @@ mod tests { let app = Router::new().route("/", post(|input: Json| async { input.0.foo })); let client = TestClient::new(app); - let res = client.post("/").json(&json!({ "foo": "bar" })).send().await; + let res = client.post("/").json(&json!({ "foo": "bar" })).await; let body = res.text().await; assert_eq!(body, "bar"); @@ -243,7 +245,7 @@ mod tests { let app = Router::new().route("/", post(|input: Json| async { input.0.foo })); let client = TestClient::new(app); - let res = client.post("/").body(r#"{ "foo": "bar" }"#).send().await; + let res = client.post("/").body(r#"{ "foo": "bar" }"#).await; let status = res.status(); @@ -261,7 +263,6 @@ mod tests { .post("/") .header("content-type", content_type) .body("{}") - .send() .await; res.status() == StatusCode::OK @@ -283,7 +284,6 @@ mod tests { .post("/") .body("{") .header("content-type", "application/json") - .send() .await; assert_eq!(res.status(), StatusCode::BAD_REQUEST); @@ -314,7 +314,6 @@ mod tests { .post("/") .body("{\"a\": 1, \"b\": [{\"x\": 2}]}") .header("content-type", "application/json") - .send() .await; assert_eq!(res.status(), StatusCode::UNPROCESSABLE_ENTITY); diff --git a/axum/src/lib.rs b/axum/src/lib.rs index 42d5d41afc..fcc929a6ab 100644 --- a/axum/src/lib.rs +++ b/axum/src/lib.rs @@ -1,23 +1,5 @@ -#![cfg_attr(nightly_error_messages, feature(diagnostic_namespace))] //! axum is a web application framework that focuses on ergonomics and modularity. //! -//! # Table of contents -//! -//! - [High-level features](#high-level-features) -//! - [Compatibility](#compatibility) -//! - [Example](#example) -//! - [Routing](#routing) -//! - [Handlers](#handlers) -//! - [Extractors](#extractors) -//! - [Responses](#responses) -//! - [Error handling](#error-handling) -//! - [Middleware](#middleware) -//! - [Sharing state with handlers](#sharing-state-with-handlers) -//! - [Building integrations for axum](#building-integrations-for-axum) -//! - [Required dependencies](#required-dependencies) -//! - [Examples](#examples) -//! - [Feature flags](#feature-flags) -//! //! # High-level features //! //! - Route requests to handlers with a macro-free API. @@ -269,7 +251,7 @@ //! }), //! ) //! .route( -//! "/users/:id", +//! "/users/{id}", //! get({ //! let shared_state = Arc::clone(&shared_state); //! move |path| get_user(path, shared_state) @@ -294,6 +276,67 @@ //! The downside to this approach is that it's a little more verbose than using //! [`State`] or extensions. //! +//! ## Using [tokio's `task_local` macro](https://docs.rs/tokio/1/tokio/macro.task_local.html): +//! +//! This allows to share state with `IntoResponse` implementations. +//! +//! ```rust,no_run +//! use axum::{ +//! extract::Request, +//! http::{header, StatusCode}, +//! middleware::{self, Next}, +//! response::{IntoResponse, Response}, +//! routing::get, +//! Router, +//! }; +//! use tokio::task_local; +//! +//! #[derive(Clone)] +//! struct CurrentUser { +//! name: String, +//! } +//! task_local! { +//! pub static USER: CurrentUser; +//! } +//! +//! async fn auth(req: Request, next: Next) -> Result { +//! let auth_header = req +//! .headers() +//! .get(header::AUTHORIZATION) +//! .and_then(|header| header.to_str().ok()) +//! .ok_or(StatusCode::UNAUTHORIZED)?; +//! if let Some(current_user) = authorize_current_user(auth_header).await { +//! // State is setup here in the middleware +//! Ok(USER.scope(current_user, next.run(req)).await) +//! } else { +//! Err(StatusCode::UNAUTHORIZED) +//! } +//! } +//! async fn authorize_current_user(auth_token: &str) -> Option { +//! Some(CurrentUser { +//! name: auth_token.to_string(), +//! }) +//! } +//! +//! struct UserResponse; +//! +//! impl IntoResponse for UserResponse { +//! fn into_response(self) -> Response { +//! // State is accessed here in the IntoResponse implementation +//! let current_user = USER.with(|u| u.clone()); +//! (StatusCode::OK, current_user.name).into_response() +//! } +//! } +//! +//! async fn handler() -> UserResponse { +//! UserResponse +//! } +//! +//! let app: Router = Router::new() +//! .route("/", get(handler)) +//! .route_layer(middleware::from_fn(auth)); +//! ``` +//! //! # Building integrations for axum //! //! Libraries authors that want to provide [`FromRequest`], [`FromRequestParts`], or @@ -340,7 +383,7 @@ //! `original-uri` | Enables capturing of every request's original URI and the [`OriginalUri`] extractor | Yes //! `tokio` | Enables `tokio` as a dependency and `axum::serve`, `SSE` and `extract::connect_info` types. | Yes //! `tower-log` | Enables `tower`'s `log` feature | Yes -//! `tracing` | Log rejections from built-in extractors | No +//! `tracing` | Log rejections from built-in extractors | Yes //! `ws` | Enables WebSockets support via [`extract::ws`] | No //! `form` | Enables the `Form` extractor | Yes //! `query` | Enables the `Query` extractor | Yes @@ -389,7 +432,6 @@ clippy::needless_borrow, clippy::match_wildcard_for_single_variants, clippy::if_let_mutex, - clippy::mismatched_target_os, clippy::await_holding_lock, clippy::match_on_vec_items, clippy::imprecise_flops, @@ -421,6 +463,7 @@ #[macro_use] pub(crate) mod macros; +mod box_clone_service; mod boxed; mod extension; #[cfg(feature = "form")] @@ -443,8 +486,6 @@ pub mod serve; #[cfg(test)] mod test_helpers; -#[doc(no_inline)] -pub use async_trait::async_trait; #[doc(no_inline)] pub use http; @@ -464,7 +505,7 @@ pub use self::form::Form; pub use axum_core::{BoxError, Error, RequestExt, RequestPartsExt}; #[cfg(feature = "macros")] -pub use axum_macros::debug_handler; +pub use axum_macros::{debug_handler, debug_middleware}; #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] #[doc(inline)] diff --git a/axum/src/middleware/from_extractor.rs b/axum/src/middleware/from_extractor.rs index e120ffc1fc..cda0d97798 100644 --- a/axum/src/middleware/from_extractor.rs +++ b/axum/src/middleware/from_extractor.rs @@ -26,7 +26,7 @@ use tower_service::Service; /// without repeating it in the function signature. /// /// Note that if the extractor consumes the request body, as `String` or -/// [`Bytes`] does, an empty body will be left in its place. Thus wont be +/// [`Bytes`] does, an empty body will be left in its place. Thus won't be /// accessible to subsequent extractors or handlers. /// /// # Example @@ -39,12 +39,10 @@ use tower_service::Service; /// Router, /// http::{header, StatusCode, request::Parts}, /// }; -/// use async_trait::async_trait; /// /// // An extractor that performs authorization. /// struct RequireAuth; /// -/// #[async_trait] /// impl FromRequestParts for RequireAuth /// where /// S: Send + Sync, @@ -303,7 +301,7 @@ where #[cfg(test)] mod tests { use super::*; - use crate::{async_trait, handler::Handler, routing::get, test_helpers::*, Router}; + use crate::{handler::Handler, routing::get, test_helpers::*, Router}; use axum_core::extract::FromRef; use http::{header, request::Parts, StatusCode}; use tower_http::limit::RequestBodyLimitLayer; @@ -315,7 +313,6 @@ mod tests { struct RequireAuth; - #[async_trait::async_trait] impl FromRequestParts for RequireAuth where S: Send + Sync, @@ -352,13 +349,12 @@ mod tests { let client = TestClient::new(app); - let res = client.get("/").send().await; + let res = client.get("/").await; assert_eq!(res.status(), StatusCode::UNAUTHORIZED); let res = client .get("/") .header(http::header::AUTHORIZATION, "secret") - .send() .await; assert_eq!(res.status(), StatusCode::OK); } @@ -368,7 +364,6 @@ mod tests { fn works_with_request_body_limit() { struct MyExtractor; - #[async_trait] impl FromRequestParts for MyExtractor where S: Send + Sync, diff --git a/axum/src/middleware/from_fn.rs b/axum/src/middleware/from_fn.rs index e4c44c74f5..abee97baf0 100644 --- a/axum/src/middleware/from_fn.rs +++ b/axum/src/middleware/from_fn.rs @@ -1,3 +1,4 @@ +use crate::box_clone_service::BoxCloneService; use crate::response::{IntoResponse, Response}; use axum_core::extract::{FromRequest, FromRequestParts, Request}; use futures_util::future::BoxFuture; @@ -10,7 +11,7 @@ use std::{ pin::Pin, task::{Context, Poll}, }; -use tower::{util::BoxCloneService, ServiceBuilder}; +use tower::ServiceBuilder; use tower_layer::Layer; use tower_service::Service; @@ -19,9 +20,10 @@ use tower_service::Service; /// `from_fn` requires the function given to /// /// 1. Be an `async fn`. -/// 2. Take one or more [extractors] as the first arguments. -/// 3. Take [`Next`](Next) as the final argument. -/// 4. Return something that implements [`IntoResponse`]. +/// 2. Take zero or more [`FromRequestParts`] extractors. +/// 3. Take exactly one [`FromRequest`] extractor as the second to last argument. +/// 4. Take [`Next`](Next) as the last argument. +/// 5. Return something that implements [`IntoResponse`]. /// /// Note that this function doesn't support extracting [`State`]. For that, use [`from_fn_with_state`]. /// @@ -112,6 +114,8 @@ pub fn from_fn(f: F) -> FromFnLayer { /// Create a middleware from an async function with the given state. /// +/// For the requirements for the function supplied see [`from_fn`]. +/// /// See [`State`](crate::extract::State) for more details about accessing state. /// /// # Example @@ -166,7 +170,7 @@ pub fn from_fn_with_state(state: S, f: F) -> FromFnLayer { /// /// [`tower::Layer`] is used to apply middleware to [`Router`](crate::Router)'s. /// -/// Created with [`from_fn`]. See that function for more details. +/// Created with [`from_fn`] or [`from_fn_with_state`]. See those functions for more details. #[must_use] pub struct FromFnLayer { f: F, @@ -220,7 +224,7 @@ where /// A middleware created from an async function. /// -/// Created with [`from_fn`]. See that function for more details. +/// Created with [`from_fn`] or [`from_fn_with_state`]. See those functions for more details. pub struct FromFn { f: F, inner: I, @@ -259,6 +263,7 @@ macro_rules! impl_service { I: Service + Clone + Send + + Sync + 'static, I::Response: IntoResponse, I::Future: Send + 'static, @@ -297,7 +302,7 @@ macro_rules! impl_service { }; let inner = ServiceBuilder::new() - .boxed_clone() + .layer_fn(BoxCloneService::new) .map_response(IntoResponse::into_response) .service(ready_inner); let next = Next { inner }; diff --git a/axum/src/middleware/map_request.rs b/axum/src/middleware/map_request.rs index d36a7cc958..596b6c3c87 100644 --- a/axum/src/middleware/map_request.rs +++ b/axum/src/middleware/map_request.rs @@ -411,7 +411,7 @@ mod tests { .layer(map_request(add_header)); let client = TestClient::new(app); - let res = client.get("/").send().await; + let res = client.get("/").await; assert_eq!(res.text().await, "foo"); } @@ -431,7 +431,7 @@ mod tests { .layer(map_request(add_header)); let client = TestClient::new(app); - let res = client.get("/").send().await; + let res = client.get("/").await; assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR); assert_eq!(res.text().await, "something went wrong"); diff --git a/axum/src/middleware/map_response.rs b/axum/src/middleware/map_response.rs index 06f9825740..2510cdc256 100644 --- a/axum/src/middleware/map_response.rs +++ b/axum/src/middleware/map_response.rs @@ -357,7 +357,7 @@ mod tests { let app = Router::new().layer(map_response(add_header)); let client = TestClient::new(app); - let res = client.get("/").send().await; + let res = client.get("/").await; assert_eq!(res.headers()["x-foo"], "foo"); } diff --git a/axum/src/response/mod.rs b/axum/src/response/mod.rs index 0710c64df3..00dc00ccf5 100644 --- a/axum/src/response/mod.rs +++ b/axum/src/response/mod.rs @@ -1,7 +1,6 @@ #![doc = include_str!("../docs/response.md")] -use axum_core::body::Body; -use http::{header, HeaderValue}; +use http::{header, HeaderValue, StatusCode}; mod redirect; @@ -41,7 +40,7 @@ pub struct Html(pub T); impl IntoResponse for Html where - T: Into, + T: IntoResponse, { fn into_response(self) -> Response { ( @@ -49,7 +48,7 @@ where header::CONTENT_TYPE, HeaderValue::from_static(mime::TEXT_HTML_UTF_8.as_ref()), )], - self.0.into(), + self.0, ) .into_response() } @@ -61,6 +60,31 @@ impl From for Html { } } +/// An empty response with 204 No Content status. +/// +/// Due to historical and implementation reasons, the `IntoResponse` implementation of `()` +/// (unit type) returns an empty response with 200 [`StatusCode::OK`] status. +/// If you specifically want a 204 [`StatusCode::NO_CONTENT`] status, you can use either `StatusCode` type +/// directly, or this shortcut struct for self-documentation. +/// +/// ``` +/// use axum::{extract::Path, response::NoContent}; +/// +/// async fn delete_user(Path(user): Path) -> Result { +/// // ...access database... +/// # drop(user); +/// Ok(NoContent) +/// } +/// ``` +#[derive(Debug, Clone, Copy)] +pub struct NoContent; + +impl IntoResponse for NoContent { + fn into_response(self) -> Response { + StatusCode::NO_CONTENT.into_response() + } +} + #[cfg(test)] mod tests { use crate::extract::Extension; @@ -365,7 +389,6 @@ mod tests { assert_eq!(res.headers()["x-foo"], "foo"); assert_eq!(res.headers()["x-bar"], "bar"); } - #[test] fn into_response_parts_failing_sets_extension() { struct Fail; @@ -411,7 +434,6 @@ mod tests { .get::() .is_some()); } - #[test] fn doenst_override_status_code_when_using_into_response_failed_at_same_level() { assert_eq!( @@ -447,7 +469,6 @@ mod tests { StatusCode::OK, ); } - #[test] fn force_overriding_status_code() { assert_eq!( @@ -482,7 +503,6 @@ mod tests { StatusCode::IM_A_TEAPOT ); } - #[crate::test] async fn status_code_tuple_doesnt_override_error_json() { let app = Router::new() @@ -512,4 +532,11 @@ mod tests { let res = client.get("/two").send().await; assert_eq!(res.status(), StatusCode::IM_A_TEAPOT); } + #[test] + fn no_content() { + assert_eq!( + super::NoContent.into_response().status(), + StatusCode::NO_CONTENT, + ) + } } diff --git a/axum/src/response/sse.rs b/axum/src/response/sse.rs index 45e054109a..b414f05725 100644 --- a/axum/src/response/sse.rs +++ b/axum/src/response/sse.rs @@ -208,12 +208,29 @@ impl Event { where T: serde::Serialize, { + struct IgnoreNewLines<'a>(bytes::buf::Writer<&'a mut BytesMut>); + impl std::io::Write for IgnoreNewLines<'_> { + fn write(&mut self, buf: &[u8]) -> std::io::Result { + let mut last_split = 0; + for delimiter in memchr::memchr2_iter(b'\n', b'\r', buf) { + self.0.write_all(&buf[last_split..delimiter])?; + last_split = delimiter + 1; + } + self.0.write_all(&buf[last_split..])?; + Ok(buf.len()) + } + + fn flush(&mut self) -> std::io::Result<()> { + self.0.flush() + } + } if self.flags.contains(EventFlags::HAS_DATA) { panic!("Called `EventBuilder::json_data` multiple times"); } self.buffer.extend_from_slice(b"data: "); - serde_json::to_writer((&mut self.buffer).writer(), &data).map_err(axum_core::Error::new)?; + serde_json::to_writer(IgnoreNewLines((&mut self.buffer).writer()), &data) + .map_err(axum_core::Error::new)?; self.buffer.put_u8(b'\n'); self.flags.insert(EventFlags::HAS_DATA); @@ -515,6 +532,7 @@ mod tests { use super::*; use crate::{routing::get, test_helpers::*, Router}; use futures_util::stream; + use serde_json::value::RawValue; use std::{collections::HashMap, convert::Infallible}; use tokio_stream::StreamExt as _; @@ -527,6 +545,18 @@ mod tests { assert_eq!(&*leading_space.finalize(), b"data: foobar\n\n"); } + #[test] + fn valid_json_raw_value_chars_stripped() { + let json_string = "{\r\"foo\": \n\r\r \"bar\\n\"\n}"; + let json_raw_value_event = Event::default() + .json_data(serde_json::from_str::<&RawValue>(json_string).unwrap()) + .unwrap(); + assert_eq!( + &*json_raw_value_event.finalize(), + format!("data: {}\n\n", json_string.replace(['\n', '\r'], "")).as_bytes() + ); + } + #[crate::test] async fn basic() { let app = Router::new().route( @@ -548,7 +578,7 @@ mod tests { ); let client = TestClient::new(app); - let mut stream = client.get("/").send().await; + let mut stream = client.get("/").await; assert_eq!(stream.headers()["content-type"], "text/event-stream"); assert_eq!(stream.headers()["cache-control"], "no-cache"); @@ -559,13 +589,13 @@ mod tests { let event_fields = parse_event(&stream.chunk_text().await.unwrap()); assert_eq!(event_fields.get("data").unwrap(), "{\"foo\":\"bar\"}"); - assert!(event_fields.get("comment").is_none()); + assert!(!event_fields.contains_key("comment")); let event_fields = parse_event(&stream.chunk_text().await.unwrap()); assert_eq!(event_fields.get("event").unwrap(), "three"); assert_eq!(event_fields.get("retry").unwrap(), "30000"); assert_eq!(event_fields.get("id").unwrap(), "unique-id"); - assert!(event_fields.get("comment").is_none()); + assert!(!event_fields.contains_key("comment")); assert!(stream.chunk_text().await.is_none()); } @@ -590,7 +620,7 @@ mod tests { ); let client = TestClient::new(app); - let mut stream = client.get("/").send().await; + let mut stream = client.get("/").await; for _ in 0..5 { // first message should be an event @@ -627,7 +657,7 @@ mod tests { ); let client = TestClient::new(app); - let mut stream = client.get("/").send().await; + let mut stream = client.get("/").await; // first message should be an event let event_fields = parse_event(&stream.chunk_text().await.unwrap()); diff --git a/axum/src/routing/method_filter.rs b/axum/src/routing/method_filter.rs index 1cea4235e5..040783ec33 100644 --- a/axum/src/routing/method_filter.rs +++ b/axum/src/routing/method_filter.rs @@ -9,6 +9,24 @@ use std::{ pub struct MethodFilter(u16); impl MethodFilter { + /// Match `CONNECT` requests. + /// + /// This is useful for implementing HTTP/2's [extended CONNECT method], + /// in which the `:protocol` pseudoheader is read + /// (using [`hyper::ext::Protocol`]) + /// and the connection upgraded to a bidirectional byte stream + /// (using [`hyper::upgrade::on`]). + /// + /// As seen in the [HTTP Upgrade Token Registry], + /// common uses include WebSockets and proxying UDP or IP – + /// though note that when using [`WebSocketUpgrade`] + /// it's more useful to use [`any`](crate::routing::any) + /// as HTTP/1.1 WebSockets need to support `GET`. + /// + /// [extended CONNECT]: https://www.rfc-editor.org/rfc/rfc8441.html#section-4 + /// [HTTP Upgrade Token Registry]: https://www.iana.org/assignments/http-upgrade-tokens/http-upgrade-tokens.xhtml + /// [`WebSocketUpgrade`]: crate::extract::WebSocketUpgrade + pub const CONNECT: Self = Self::from_bits(0b0_0000_0001); /// Match `DELETE` requests. pub const DELETE: Self = Self::from_bits(0b0_0000_0010); /// Match `GET` requests. @@ -71,6 +89,7 @@ impl TryFrom for MethodFilter { fn try_from(m: Method) -> Result { match m { + Method::CONNECT => Ok(MethodFilter::CONNECT), Method::DELETE => Ok(MethodFilter::DELETE), Method::GET => Ok(MethodFilter::GET), Method::HEAD => Ok(MethodFilter::HEAD), @@ -90,6 +109,11 @@ mod tests { #[test] fn from_http_method() { + assert_eq!( + MethodFilter::try_from(Method::CONNECT).unwrap(), + MethodFilter::CONNECT + ); + assert_eq!( MethodFilter::try_from(Method::DELETE).unwrap(), MethodFilter::DELETE @@ -130,9 +154,11 @@ mod tests { MethodFilter::TRACE ); - assert!(MethodFilter::try_from(http::Method::CONNECT) - .unwrap_err() - .to_string() - .contains("CONNECT")); + assert!( + MethodFilter::try_from(http::Method::from_bytes(b"CUSTOM").unwrap()) + .unwrap_err() + .to_string() + .contains("CUSTOM") + ); } } diff --git a/axum/src/routing/method_routing.rs b/axum/src/routing/method_routing.rs index 962a440111..cfc47e1f7f 100644 --- a/axum/src/routing/method_routing.rs +++ b/axum/src/routing/method_routing.rs @@ -59,6 +59,19 @@ macro_rules! top_level_service_fn { ); }; + ( + $name:ident, CONNECT + ) => { + top_level_service_fn!( + /// Route `CONNECT` requests to the given service. + /// + /// See [`MethodFilter::CONNECT`] for when you'd want to use this, + /// and [`get_service`] for an example. + $name, + CONNECT + ); + }; + ( $name:ident, $method:ident ) => { @@ -78,7 +91,7 @@ macro_rules! top_level_service_fn { $(#[$m])+ pub fn $name(svc: T) -> MethodRouter where - T: Service + Clone + Send + 'static, + T: Service + Clone + Send + Sync + 'static, T::Response: IntoResponse + 'static, T::Future: Send + 'static, S: Clone, @@ -118,6 +131,19 @@ macro_rules! top_level_handler_fn { ); }; + ( + $name:ident, CONNECT + ) => { + top_level_handler_fn!( + /// Route `CONNECT` requests to the given handler. + /// + /// See [`MethodFilter::CONNECT`] for when you'd want to use this, + /// and [`get`] for an example. + $name, + CONNECT + ); + }; + ( $name:ident, $method:ident ) => { @@ -187,6 +213,19 @@ macro_rules! chained_service_fn { ); }; + ( + $name:ident, CONNECT + ) => { + chained_service_fn!( + /// Chain an additional service that will only accept `CONNECT` requests. + /// + /// See [`MethodFilter::CONNECT`] for when you'd want to use this, + /// and [`MethodRouter::get_service`] for an example. + $name, + CONNECT + ); + }; + ( $name:ident, $method:ident ) => { @@ -210,6 +249,7 @@ macro_rules! chained_service_fn { T: Service + Clone + Send + + Sync + 'static, T::Response: IntoResponse + 'static, T::Future: Send + 'static, @@ -249,6 +289,19 @@ macro_rules! chained_handler_fn { ); }; + ( + $name:ident, CONNECT + ) => { + chained_handler_fn!( + /// Chain an additional handler that will only accept `CONNECT` requests. + /// + /// See [`MethodFilter::CONNECT`] for when you'd want to use this, + /// and [`MethodRouter::get`] for an example. + $name, + CONNECT + ); + }; + ( $name:ident, $method:ident ) => { @@ -278,6 +331,7 @@ macro_rules! chained_handler_fn { }; } +top_level_service_fn!(connect_service, CONNECT); top_level_service_fn!(delete_service, DELETE); top_level_service_fn!(get_service, GET); top_level_service_fn!(head_service, HEAD); @@ -312,7 +366,7 @@ top_level_service_fn!(trace_service, TRACE); /// ``` pub fn on_service(filter: MethodFilter, svc: T) -> MethodRouter where - T: Service + Clone + Send + 'static, + T: Service + Clone + Send + Sync + 'static, T::Response: IntoResponse + 'static, T::Future: Send + 'static, S: Clone, @@ -371,7 +425,7 @@ where /// ``` pub fn any_service(svc: T) -> MethodRouter where - T: Service + Clone + Send + 'static, + T: Service + Clone + Send + Sync + 'static, T::Response: IntoResponse + 'static, T::Future: Send + 'static, S: Clone, @@ -381,6 +435,7 @@ where .skip_allow_header() } +top_level_handler_fn!(connect, CONNECT); top_level_handler_fn!(delete, DELETE); top_level_handler_fn!(get, GET); top_level_handler_fn!(head, HEAD); @@ -497,6 +552,7 @@ pub struct MethodRouter { post: MethodEndpoint, put: MethodEndpoint, trace: MethodEndpoint, + connect: MethodEndpoint, fallback: Fallback, allow_header: AllowHeader, } @@ -538,6 +594,7 @@ impl fmt::Debug for MethodRouter { .field("post", &self.post) .field("put", &self.put) .field("trace", &self.trace) + .field("connect", &self.connect) .field("fallback", &self.fallback) .field("allow_header", &self.allow_header) .finish() @@ -582,6 +639,7 @@ where ) } + chained_handler_fn!(connect, CONNECT); chained_handler_fn!(delete, DELETE); chained_handler_fn!(get, GET); chained_handler_fn!(head, HEAD); @@ -601,6 +659,19 @@ where self.fallback = Fallback::BoxedHandler(BoxedIntoRoute::from_handler(handler)); self } + + /// Add a fallback [`Handler`] if no custom one has been provided. + pub(crate) fn default_fallback(self, handler: H) -> Self + where + H: Handler, + T: 'static, + S: Send + Sync + 'static, + { + match self.fallback { + Fallback::Default(_) => self.fallback(handler), + _ => self, + } + } } impl MethodRouter<(), Infallible> { @@ -689,6 +760,7 @@ where post: MethodEndpoint::None, put: MethodEndpoint::None, trace: MethodEndpoint::None, + connect: MethodEndpoint::None, allow_header: AllowHeader::None, fallback: Fallback::Default(fallback), } @@ -705,6 +777,7 @@ where post: self.post.with_state(&state), put: self.put.with_state(&state), trace: self.trace.with_state(&state), + connect: self.connect.with_state(&state), allow_header: self.allow_header, fallback: self.fallback.with_state(state), } @@ -736,7 +809,7 @@ where #[track_caller] pub fn on_service(self, filter: MethodFilter, svc: T) -> Self where - T: Service + Clone + Send + 'static, + T: Service + Clone + Send + Sync + 'static, T::Response: IntoResponse + 'static, T::Future: Send + 'static, { @@ -853,9 +926,20 @@ where &["DELETE"], ); + set_endpoint( + "CONNECT", + &mut self.options, + &endpoint, + filter, + MethodFilter::CONNECT, + &mut self.allow_header, + &["CONNECT"], + ); + self } + chained_service_fn!(connect_service, CONNECT); chained_service_fn!(delete_service, DELETE); chained_service_fn!(get_service, GET); chained_service_fn!(head_service, HEAD); @@ -868,7 +952,7 @@ where #[doc = include_str!("../docs/method_routing/fallback.md")] pub fn fallback_service(mut self, svc: T) -> Self where - T: Service + Clone + Send + 'static, + T: Service + Clone + Send + Sync + 'static, T::Response: IntoResponse + 'static, T::Future: Send + 'static, { @@ -879,8 +963,8 @@ where #[doc = include_str!("../docs/method_routing/layer.md")] pub fn layer(self, layer: L) -> MethodRouter where - L: Layer> + Clone + Send + 'static, - L::Service: Service + Clone + Send + 'static, + L: Layer> + Clone + Send + Sync + 'static, + L::Service: Service + Clone + Send + Sync + 'static, >::Response: IntoResponse + 'static, >::Error: Into + 'static, >::Future: Send + 'static, @@ -899,6 +983,7 @@ where post: self.post.map(layer_fn.clone()), put: self.put.map(layer_fn.clone()), trace: self.trace.map(layer_fn.clone()), + connect: self.connect.map(layer_fn.clone()), fallback: self.fallback.map(layer_fn), allow_header: self.allow_header, } @@ -908,8 +993,8 @@ where #[track_caller] pub fn route_layer(mut self, layer: L) -> MethodRouter where - L: Layer> + Clone + Send + 'static, - L::Service: Service + Clone + Send + 'static, + L: Layer> + Clone + Send + Sync + 'static, + L::Service: Service + Clone + Send + Sync + 'static, >::Response: IntoResponse + 'static, >::Future: Send + 'static, E: 'static, @@ -923,6 +1008,7 @@ where && self.post.is_none() && self.put.is_none() && self.trace.is_none() + && self.connect.is_none() { panic!( "Adding a route_layer before any routes is a no-op. \ @@ -943,7 +1029,8 @@ where self.patch = self.patch.map(layer_fn.clone()); self.post = self.post.map(layer_fn.clone()); self.put = self.put.map(layer_fn.clone()); - self.trace = self.trace.map(layer_fn); + self.trace = self.trace.map(layer_fn.clone()); + self.connect = self.connect.map(layer_fn); self } @@ -984,6 +1071,7 @@ where self.post = merge_inner(path, "POST", self.post, other.post); self.put = merge_inner(path, "PUT", self.put, other.put); self.trace = merge_inner(path, "TRACE", self.trace, other.trace); + self.connect = merge_inner(path, "CONNECT", self.connect, other.connect); self.fallback = self .fallback @@ -1022,33 +1110,28 @@ where self } - pub(crate) fn call_with_state(&mut self, req: Request, state: S) -> RouteFuture { + pub(crate) fn call_with_state(&self, req: Request, state: S) -> RouteFuture { macro_rules! call { ( $req:expr, - $method:expr, $method_variant:ident, $svc:expr ) => { - if $method == Method::$method_variant { + if *req.method() == Method::$method_variant { match $svc { MethodEndpoint::None => {} MethodEndpoint::Route(route) => { - return RouteFuture::from_future(route.oneshot_inner($req)) - .strip_body($method == Method::HEAD); + return route.clone().oneshot_inner_owned($req); } MethodEndpoint::BoxedHandler(handler) => { - let mut route = handler.clone().into_route(state); - return RouteFuture::from_future(route.oneshot_inner($req)) - .strip_body($method == Method::HEAD); + let route = handler.clone().into_route(state); + return route.oneshot_inner_owned($req); } } } }; } - let method = req.method().clone(); - // written with a pattern match like this to ensure we call all routes let Self { get, @@ -1059,21 +1142,23 @@ where post, put, trace, + connect, fallback, allow_header, } = self; - call!(req, method, HEAD, head); - call!(req, method, HEAD, get); - call!(req, method, GET, get); - call!(req, method, POST, post); - call!(req, method, OPTIONS, options); - call!(req, method, PATCH, patch); - call!(req, method, PUT, put); - call!(req, method, DELETE, delete); - call!(req, method, TRACE, trace); + call!(req, HEAD, head); + call!(req, HEAD, get); + call!(req, GET, get); + call!(req, POST, post); + call!(req, OPTIONS, options); + call!(req, PATCH, patch); + call!(req, PUT, put); + call!(req, DELETE, delete); + call!(req, TRACE, trace); + call!(req, CONNECT, connect); - let future = fallback.call_with_state(req, state); + let future = fallback.clone().call_with_state(req, state); match allow_header { AllowHeader::None => future.allow_header(Bytes::new()), @@ -1114,6 +1199,7 @@ impl Clone for MethodRouter { post: self.post.clone(), put: self.put.clone(), trace: self.trace.clone(), + connect: self.connect.clone(), fallback: self.fallback.clone(), allow_header: self.allow_header.clone(), } @@ -1151,7 +1237,7 @@ where where S: 'static, E: 'static, - F: FnOnce(Route) -> Route + Clone + Send + 'static, + F: FnOnce(Route) -> Route + Clone + Send + Sync + 'static, E2: 'static, { match self { @@ -1219,7 +1305,7 @@ where { type Future = InfallibleRouteFuture; - fn call(mut self, req: Request, state: S) -> Self::Future { + fn call(self, req: Request, state: S) -> Self::Future { InfallibleRouteFuture::new(self.call_with_state(req, state)) } } @@ -1239,7 +1325,7 @@ const _: () = { } fn call(&mut self, _req: IncomingStream<'_>) -> Self::Future { - std::future::ready(Ok(self.clone())) + std::future::ready(Ok(self.clone().with_state(()))) } } }; @@ -1247,12 +1333,11 @@ const _: () = { #[cfg(test)] mod tests { use super::*; - use crate::{body::Body, extract::State, handler::HandlerWithoutStateExt}; - use axum_core::response::IntoResponse; + use crate::{extract::State, handler::HandlerWithoutStateExt}; use http::{header::ALLOW, HeaderMap}; use http_body_util::BodyExt; use std::time::Duration; - use tower::{Service, ServiceExt}; + use tower::ServiceExt; use tower_http::{ services::fs::ServeDir, timeout::TimeoutLayer, validate_request::ValidateRequestHeaderLayer, }; diff --git a/axum/src/routing/mod.rs b/axum/src/routing/mod.rs index 13b5725549..d1e84d6aa9 100644 --- a/axum/src/routing/mod.rs +++ b/axum/src/routing/mod.rs @@ -17,6 +17,7 @@ use std::{ convert::Infallible, fmt, marker::PhantomData, + sync::Arc, task::{Context, Poll}, }; use tower_layer::Layer; @@ -39,9 +40,9 @@ mod tests; pub use self::{into_make_service::IntoMakeService, method_filter::MethodFilter, route::Route}; pub use self::method_routing::{ - any, any_service, delete, delete_service, get, get_service, head, head_service, on, on_service, - options, options_service, patch, patch_service, post, post_service, put, put_service, trace, - trace_service, MethodRouter, + any, any_service, connect, connect_service, delete, delete_service, get, get_service, head, + head_service, on, on_service, options, options_service, patch, patch_service, post, + post_service, put, put_service, trace, trace_service, MethodRouter, }; macro_rules! panic_on_err { @@ -59,23 +60,24 @@ pub(crate) struct RouteId(u32); /// The router type for composing handlers and services. #[must_use] pub struct Router { - path_router: PathRouter, - fallback_router: PathRouter, - default_fallback: bool, - catch_all_fallback: Fallback, + inner: Arc>, } impl Clone for Router { fn clone(&self) -> Self { Self { - path_router: self.path_router.clone(), - fallback_router: self.fallback_router.clone(), - default_fallback: self.default_fallback, - catch_all_fallback: self.catch_all_fallback.clone(), + inner: Arc::clone(&self.inner), } } } +struct RouterInner { + path_router: PathRouter, + fallback_router: PathRouter, + default_fallback: bool, + catch_all_fallback: Fallback, +} + impl Default for Router where S: Clone + Send + Sync + 'static, @@ -88,18 +90,43 @@ where impl fmt::Debug for Router { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("Router") - .field("path_router", &self.path_router) - .field("fallback_router", &self.fallback_router) - .field("default_fallback", &self.default_fallback) - .field("catch_all_fallback", &self.catch_all_fallback) + .field("path_router", &self.inner.path_router) + .field("fallback_router", &self.inner.fallback_router) + .field("default_fallback", &self.inner.default_fallback) + .field("catch_all_fallback", &self.inner.catch_all_fallback) .finish() } } pub(crate) const NEST_TAIL_PARAM: &str = "__private__axum_nest_tail_param"; -pub(crate) const NEST_TAIL_PARAM_CAPTURE: &str = "/*__private__axum_nest_tail_param"; +pub(crate) const NEST_TAIL_PARAM_CAPTURE: &str = "/{*__private__axum_nest_tail_param}"; pub(crate) const FALLBACK_PARAM: &str = "__private__axum_fallback"; -pub(crate) const FALLBACK_PARAM_PATH: &str = "/*__private__axum_fallback"; +pub(crate) const FALLBACK_PARAM_PATH: &str = "/{*__private__axum_fallback}"; + +macro_rules! map_inner { + ( $self_:ident, $inner:pat_param => $expr:expr) => { + #[allow(redundant_semicolons)] + { + let $inner = $self_.into_inner(); + Router { + inner: Arc::new($expr), + } + } + }; +} + +macro_rules! tap_inner { + ( $self_:ident, mut $inner:ident => { $($stmt:stmt)* } ) => { + #[allow(redundant_semicolons)] + { + let mut $inner = $self_.into_inner(); + $($stmt)* + Router { + inner: Arc::new($inner), + } + } + }; +} impl Router where @@ -111,24 +138,46 @@ where /// all requests. pub fn new() -> Self { Self { - path_router: Default::default(), - fallback_router: PathRouter::new_fallback(), - default_fallback: true, - catch_all_fallback: Fallback::Default(Route::new(NotFound)), + inner: Arc::new(RouterInner { + path_router: Default::default(), + fallback_router: PathRouter::new_fallback(), + default_fallback: true, + catch_all_fallback: Fallback::Default(Route::new(NotFound)), + }), } } + fn into_inner(self) -> RouterInner { + match Arc::try_unwrap(self.inner) { + Ok(inner) => inner, + Err(arc) => RouterInner { + path_router: arc.path_router.clone(), + fallback_router: arc.fallback_router.clone(), + default_fallback: arc.default_fallback, + catch_all_fallback: arc.catch_all_fallback.clone(), + }, + } + } + + #[doc = include_str!("../docs/routing/without_v07_checks.md")] + pub fn without_v07_checks(self) -> Self { + tap_inner!(self, mut this => { + this.path_router.without_v07_checks(); + }) + } + #[doc = include_str!("../docs/routing/route.md")] #[track_caller] - pub fn route(mut self, path: &str, method_router: MethodRouter) -> Self { - panic_on_err!(self.path_router.route(path, method_router)); - self + pub fn route(self, path: &str, method_router: MethodRouter) -> Self { + tap_inner!(self, mut this => { + panic_on_err!(this.path_router.route(path, method_router)); + }) } #[doc = include_str!("../docs/routing/route_service.md")] - pub fn route_service(mut self, path: &str, service: T) -> Self + pub fn route_service(self, path: &str, service: T) -> Self where - T: Service + Clone + Send + 'static, + T: Service + Clone + Send + Sync + 'static, T::Response: IntoResponse, T::Future: Send + 'static, { @@ -142,14 +191,20 @@ where Err(service) => service, }; - panic_on_err!(self.path_router.route_service(path, service)); - self + tap_inner!(self, mut this => { + panic_on_err!(this.path_router.route_service(path, service)); + }) } #[doc = include_str!("../docs/routing/nest.md")] + #[doc(alias = "scope")] // Some web frameworks like actix-web use this term #[track_caller] - pub fn nest(mut self, path: &str, router: Router) -> Self { - let Router { + pub fn nest(self, path: &str, router: Router) -> Self { + if path.is_empty() || path == "/" { + panic!("Nesting at the root is no longer supported. Use merge instead."); + } + + let RouterInner { path_router, fallback_router, default_fallback, @@ -157,167 +212,200 @@ where // requests with an empty path. If we were to inherit the catch-all fallback // it would end up matching `/{path}/*` which doesn't match empty paths. catch_all_fallback: _, - } = router; - - panic_on_err!(self.path_router.nest(path, path_router)); + } = router.into_inner(); - if !default_fallback { - panic_on_err!(self.fallback_router.nest(path, fallback_router)); - } + tap_inner!(self, mut this => { + panic_on_err!(this.path_router.nest(path, path_router)); - self + if !default_fallback { + panic_on_err!(this.fallback_router.nest(path, fallback_router)); + } + }) } /// Like [`nest`](Self::nest), but accepts an arbitrary `Service`. #[track_caller] - pub fn nest_service(mut self, path: &str, service: T) -> Self + pub fn nest_service(self, path: &str, service: T) -> Self where - T: Service + Clone + Send + 'static, + T: Service + Clone + Send + Sync + 'static, T::Response: IntoResponse, T::Future: Send + 'static, { - panic_on_err!(self.path_router.nest_service(path, service)); - self + if path.is_empty() || path == "/" { + panic!("Nesting at the root is no longer supported. Use fallback_service instead."); + } + + tap_inner!(self, mut this => { + panic_on_err!(this.path_router.nest_service(path, service)); + }) } #[doc = include_str!("../docs/routing/merge.md")] #[track_caller] - pub fn merge(mut self, other: R) -> Self + pub fn merge(self, other: R) -> Self where R: Into>, { const PANIC_MSG: &str = "Failed to merge fallbacks. This is a bug in axum. Please file an issue"; - let Router { + let other: Router = other.into(); + let RouterInner { path_router, fallback_router: mut other_fallback, default_fallback, catch_all_fallback, - } = other.into(); - - panic_on_err!(self.path_router.merge(path_router)); - - match (self.default_fallback, default_fallback) { - // both have the default fallback - // use the one from other - (true, true) => { - self.fallback_router.merge(other_fallback).expect(PANIC_MSG); - } - // self has default fallback, other has a custom fallback - (true, false) => { - self.fallback_router.merge(other_fallback).expect(PANIC_MSG); - self.default_fallback = false; - } - // self has a custom fallback, other has a default - (false, true) => { - let fallback_router = std::mem::take(&mut self.fallback_router); - other_fallback.merge(fallback_router).expect(PANIC_MSG); - self.fallback_router = other_fallback; - } - // both have a custom fallback, not allowed - (false, false) => { - panic!("Cannot merge two `Router`s that both have a fallback") - } - }; - - self.catch_all_fallback = self - .catch_all_fallback - .merge(catch_all_fallback) - .unwrap_or_else(|| panic!("Cannot merge two `Router`s that both have a fallback")); - - self + } = other.into_inner(); + + map_inner!(self, mut this => { + panic_on_err!(this.path_router.merge(path_router)); + + match (this.default_fallback, default_fallback) { + // both have the default fallback + // use the one from other + (true, true) => { + this.fallback_router.merge(other_fallback).expect(PANIC_MSG); + } + // this has default fallback, other has a custom fallback + (true, false) => { + this.fallback_router.merge(other_fallback).expect(PANIC_MSG); + this.default_fallback = false; + } + // this has a custom fallback, other has a default + (false, true) => { + let fallback_router = std::mem::take(&mut this.fallback_router); + other_fallback.merge(fallback_router).expect(PANIC_MSG); + this.fallback_router = other_fallback; + } + // both have a custom fallback, not allowed + (false, false) => { + panic!("Cannot merge two `Router`s that both have a fallback") + } + }; + + this.catch_all_fallback = this + .catch_all_fallback + .merge(catch_all_fallback) + .unwrap_or_else(|| panic!("Cannot merge two `Router`s that both have a fallback")); + + this + }) } #[doc = include_str!("../docs/routing/layer.md")] pub fn layer(self, layer: L) -> Router where - L: Layer + Clone + Send + 'static, - L::Service: Service + Clone + Send + 'static, + L: Layer + Clone + Send + Sync + 'static, + L::Service: Service + Clone + Send + Sync + 'static, >::Response: IntoResponse + 'static, >::Error: Into + 'static, >::Future: Send + 'static, { - Router { - path_router: self.path_router.layer(layer.clone()), - fallback_router: self.fallback_router.layer(layer.clone()), - default_fallback: self.default_fallback, - catch_all_fallback: self.catch_all_fallback.map(|route| route.layer(layer)), - } + map_inner!(self, this => RouterInner { + path_router: this.path_router.layer(layer.clone()), + fallback_router: this.fallback_router.layer(layer.clone()), + default_fallback: this.default_fallback, + catch_all_fallback: this.catch_all_fallback.map(|route| route.layer(layer)), + }) } #[doc = include_str!("../docs/routing/route_layer.md")] #[track_caller] pub fn route_layer(self, layer: L) -> Self where - L: Layer + Clone + Send + 'static, - L::Service: Service + Clone + Send + 'static, + L: Layer + Clone + Send + Sync + 'static, + L::Service: Service + Clone + Send + Sync + 'static, >::Response: IntoResponse + 'static, >::Error: Into + 'static, >::Future: Send + 'static, { - Router { - path_router: self.path_router.route_layer(layer), - fallback_router: self.fallback_router, - default_fallback: self.default_fallback, - catch_all_fallback: self.catch_all_fallback, - } + map_inner!(self, this => RouterInner { + path_router: this.path_router.route_layer(layer), + fallback_router: this.fallback_router, + default_fallback: this.default_fallback, + catch_all_fallback: this.catch_all_fallback, + }) + } + + /// True if the router currently has at least one route added. + pub fn has_routes(&self) -> bool { + self.inner.path_router.has_routes() } #[track_caller] #[doc = include_str!("../docs/routing/fallback.md")] - pub fn fallback(mut self, handler: H) -> Self + pub fn fallback(self, handler: H) -> Self where H: Handler, T: 'static, { - self.catch_all_fallback = - Fallback::BoxedHandler(BoxedIntoRoute::from_handler(handler.clone())); - self.fallback_endpoint(Endpoint::MethodRouter(any(handler))) + tap_inner!(self, mut this => { + this.catch_all_fallback = + Fallback::BoxedHandler(BoxedIntoRoute::from_handler(handler.clone())); + }) + .fallback_endpoint(Endpoint::MethodRouter(any(handler))) } /// Add a fallback [`Service`] to the router. /// /// See [`Router::fallback`] for more details. - pub fn fallback_service(mut self, service: T) -> Self + pub fn fallback_service(self, service: T) -> Self where - T: Service + Clone + Send + 'static, + T: Service + Clone + Send + Sync + 'static, T::Response: IntoResponse, T::Future: Send + 'static, { let route = Route::new(service); - self.catch_all_fallback = Fallback::Service(route.clone()); - self.fallback_endpoint(Endpoint::Route(route)) + tap_inner!(self, mut this => { + this.catch_all_fallback = Fallback::Service(route.clone()); + }) + .fallback_endpoint(Endpoint::Route(route)) } - fn fallback_endpoint(mut self, endpoint: Endpoint) -> Self { - self.fallback_router.set_fallback(endpoint); - self.default_fallback = false; - self + #[doc = include_str!("../docs/routing/method_not_allowed_fallback.md")] + pub fn method_not_allowed_fallback(self, handler: H) -> Self + where + H: Handler, + T: 'static, + { + tap_inner!(self, mut this => { + this.path_router + .method_not_allowed_fallback(handler.clone()) + }) + } + + fn fallback_endpoint(self, endpoint: Endpoint) -> Self { + tap_inner!(self, mut this => { + this.fallback_router.set_fallback(endpoint); + this.default_fallback = false; + }) } #[doc = include_str!("../docs/routing/with_state.md")] pub fn with_state(self, state: S) -> Router { - Router { - path_router: self.path_router.with_state(state.clone()), - fallback_router: self.fallback_router.with_state(state.clone()), - default_fallback: self.default_fallback, - catch_all_fallback: self.catch_all_fallback.with_state(state), - } + map_inner!(self, this => RouterInner { + path_router: this.path_router.with_state(state.clone()), + fallback_router: this.fallback_router.with_state(state.clone()), + default_fallback: this.default_fallback, + catch_all_fallback: this.catch_all_fallback.with_state(state), + }) } - pub(crate) fn call_with_state(&mut self, req: Request, state: S) -> RouteFuture { - let (req, state) = match self.path_router.call_with_state(req, state) { + pub(crate) fn call_with_state(&self, req: Request, state: S) -> RouteFuture { + let (req, state) = match self.inner.path_router.call_with_state(req, state) { Ok(future) => return future, Err((req, state)) => (req, state), }; - let (req, state) = match self.fallback_router.call_with_state(req, state) { + let (req, state) = match self.inner.fallback_router.call_with_state(req, state) { Ok(future) => return future, Err((req, state)) => (req, state), }; - self.catch_all_fallback.call_with_state(req, state) + self.inner + .catch_all_fallback + .clone() + .call_with_state(req, state) } /// Convert the router into a borrowed [`Service`] with a fixed request body type, to aid type @@ -442,7 +530,9 @@ const _: () = { } fn call(&mut self, _req: IncomingStream<'_>) -> Self::Future { - std::future::ready(Ok(self.clone())) + // call `Router::with_state` such that everything is turned into `Route` eagerly + // rather than doing that per request + std::future::ready(Ok(self.clone().with_state(()))) } } }; @@ -476,7 +566,7 @@ pub struct RouterAsService<'a, B, S = ()> { _marker: PhantomData, } -impl<'a, B> Service> for RouterAsService<'a, B, ()> +impl Service> for RouterAsService<'_, B, ()> where B: HttpBody + Send + 'static, B::Error: Into, @@ -496,7 +586,7 @@ where } } -impl<'a, B, S> fmt::Debug for RouterAsService<'a, B, S> +impl fmt::Debug for RouterAsService<'_, B, S> where S: fmt::Debug, { @@ -580,7 +670,7 @@ where where S: 'static, E: 'static, - F: FnOnce(Route) -> Route + Clone + Send + 'static, + F: FnOnce(Route) -> Route + Clone + Send + Sync + 'static, E2: 'static, { match self { @@ -598,14 +688,12 @@ where } } - fn call_with_state(&mut self, req: Request, state: S) -> RouteFuture { + fn call_with_state(self, req: Request, state: S) -> RouteFuture { match self { - Fallback::Default(route) | Fallback::Service(route) => { - RouteFuture::from_future(route.oneshot_inner(req)) - } + Fallback::Default(route) | Fallback::Service(route) => route.oneshot_inner_owned(req), Fallback::BoxedHandler(handler) => { - let mut route = handler.clone().into_route(state); - RouteFuture::from_future(route.oneshot_inner(req)) + let route = handler.clone().into_route(state); + route.oneshot_inner_owned(req) } } } @@ -643,8 +731,8 @@ where { fn layer(self, layer: L) -> Endpoint where - L: Layer + Clone + Send + 'static, - L::Service: Service + Clone + Send + 'static, + L: Layer + Clone + Send + Sync + 'static, + L::Service: Service + Clone + Send + Sync + 'static, >::Response: IntoResponse + 'static, >::Error: Into + 'static, >::Future: Send + 'static, @@ -682,4 +770,5 @@ impl fmt::Debug for Endpoint { fn traits() { use crate::test_helpers::*; assert_send::>(); + assert_sync::>(); } diff --git a/axum/src/routing/path_router.rs b/axum/src/routing/path_router.rs index b4ef4cb412..68ab4d9e5d 100644 --- a/axum/src/routing/path_router.rs +++ b/axum/src/routing/path_router.rs @@ -1,4 +1,7 @@ -use crate::extract::{nested_path::SetNestedPath, Request}; +use crate::{ + extract::{nested_path::SetNestedPath, Request}, + handler::Handler, +}; use axum_core::response::IntoResponse; use matchit::MatchError; use std::{borrow::Cow, collections::HashMap, convert::Infallible, fmt, sync::Arc}; @@ -14,6 +17,7 @@ pub(super) struct PathRouter { routes: HashMap>, node: Arc, prev_route_id: RouteId, + v7_checks: bool, } impl PathRouter @@ -32,26 +36,56 @@ where } } +fn validate_path(v7_checks: bool, path: &str) -> Result<(), &'static str> { + if path.is_empty() { + return Err("Paths must start with a `/`. Use \"/\" for root routes"); + } else if !path.starts_with('/') { + return Err("Paths must start with a `/`"); + } + + if v7_checks { + validate_v07_paths(path)?; + } + + Ok(()) +} + +fn validate_v07_paths(path: &str) -> Result<(), &'static str> { + path.split('/') + .find_map(|segment| { + if segment.starts_with(':') { + Some(Err( + "Path segments must not start with `:`. For capture groups, use \ + `{capture}`. If you meant to literally match a segment starting with \ + a colon, call `without_v07_checks` on the router.", + )) + } else if segment.starts_with('*') { + Some(Err( + "Path segments must not start with `*`. For wildcard capture, use \ + `{*wildcard}`. If you meant to literally match a segment starting with \ + an asterisk, call `without_v07_checks` on the router.", + )) + } else { + None + } + }) + .unwrap_or(Ok(())) +} + impl PathRouter where S: Clone + Send + Sync + 'static, { + pub(super) fn without_v07_checks(&mut self) { + self.v7_checks = false; + } + pub(super) fn route( &mut self, path: &str, method_router: MethodRouter, ) -> Result<(), Cow<'static, str>> { - fn validate_path(path: &str) -> Result<(), &'static str> { - if path.is_empty() { - return Err("Paths must start with a `/`. Use \"/\" for root routes"); - } else if !path.starts_with('/') { - return Err("Paths must start with a `/`"); - } - - Ok(()) - } - - validate_path(path)?; + validate_path(self.v7_checks, path)?; let endpoint = if let Some((route_id, Endpoint::MethodRouter(prev_method_router))) = self .node @@ -79,13 +113,25 @@ where Ok(()) } + pub(super) fn method_not_allowed_fallback(&mut self, handler: H) + where + H: Handler, + T: 'static, + { + for (_, endpoint) in self.routes.iter_mut() { + if let Endpoint::MethodRouter(rt) = endpoint { + *rt = rt.clone().default_fallback(handler.clone()); + } + } + } + pub(super) fn route_service( &mut self, path: &str, service: T, ) -> Result<(), Cow<'static, str>> where - T: Service + Clone + Send + 'static, + T: Service + Clone + Send + Sync + 'static, T::Response: IntoResponse, T::Future: Send + 'static, { @@ -97,11 +143,7 @@ where path: &str, endpoint: Endpoint, ) -> Result<(), Cow<'static, str>> { - if path.is_empty() { - return Err("Paths must start with a `/`. Use \"/\" for root routes".into()); - } else if !path.starts_with('/') { - return Err("Paths must start with a `/`".into()); - } + validate_path(self.v7_checks, path)?; let id = self.next_route_id(); self.set_node(path, id)?; @@ -111,13 +153,10 @@ where } fn set_node(&mut self, path: &str, id: RouteId) -> Result<(), String> { - let mut node = - Arc::try_unwrap(Arc::clone(&self.node)).unwrap_or_else(|node| (*node).clone()); - if let Err(err) = node.insert(path, id) { - return Err(format!("Invalid route {path:?}: {err}")); - } - self.node = Arc::new(node); - Ok(()) + let node = Arc::make_mut(&mut self.node); + + node.insert(path, id) + .map_err(|err| format!("Invalid route {path:?}: {err}")) } pub(super) fn merge( @@ -128,8 +167,12 @@ where routes, node, prev_route_id: _, + v7_checks, } = other; + // If either of the two did not allow paths starting with `:` or `*`, do not allow them for the merged router either. + self.v7_checks |= v7_checks; + for (id, route) in routes { let path = node .route_id_to_path @@ -165,12 +208,14 @@ where path_to_nest_at: &str, router: PathRouter, ) -> Result<(), Cow<'static, str>> { - let prefix = validate_nest_path(path_to_nest_at); + let prefix = validate_nest_path(self.v7_checks, path_to_nest_at); let PathRouter { routes, node, prev_route_id: _, + // Ignore the configuration of the nested router + v7_checks: _, } = router; for (id, endpoint) in routes { @@ -204,17 +249,17 @@ where svc: T, ) -> Result<(), Cow<'static, str>> where - T: Service + Clone + Send + 'static, + T: Service + Clone + Send + Sync + 'static, T::Response: IntoResponse, T::Future: Send + 'static, { - let path = validate_nest_path(path_to_nest_at); + let path = validate_nest_path(self.v7_checks, path_to_nest_at); let prefix = path; let path = if path.ends_with('/') { - format!("{path}*{NEST_TAIL_PARAM}") + format!("{path}{{*{NEST_TAIL_PARAM}}}") } else { - format!("{path}/*{NEST_TAIL_PARAM}") + format!("{path}/{{*{NEST_TAIL_PARAM}}}") }; let layer = ( @@ -225,7 +270,7 @@ where self.route_endpoint(&path, endpoint.clone())?; - // `/*rest` is not matched by `/` so we need to also register a router at the + // `/{*rest}` is not matched by `/` so we need to also register a router at the // prefix itself. Otherwise if you were to nest at `/foo` then `/foo` itself // wouldn't match, which it should self.route_endpoint(prefix, endpoint.clone())?; @@ -239,8 +284,8 @@ where pub(super) fn layer(self, layer: L) -> PathRouter where - L: Layer + Clone + Send + 'static, - L::Service: Service + Clone + Send + 'static, + L: Layer + Clone + Send + Sync + 'static, + L::Service: Service + Clone + Send + Sync + 'static, >::Response: IntoResponse + 'static, >::Error: Into + 'static, >::Future: Send + 'static, @@ -258,14 +303,15 @@ where routes, node: self.node, prev_route_id: self.prev_route_id, + v7_checks: self.v7_checks, } } #[track_caller] pub(super) fn route_layer(self, layer: L) -> Self where - L: Layer + Clone + Send + 'static, - L::Service: Service + Clone + Send + 'static, + L: Layer + Clone + Send + Sync + 'static, + L::Service: Service + Clone + Send + Sync + 'static, >::Response: IntoResponse + 'static, >::Error: Into + 'static, >::Future: Send + 'static, @@ -290,9 +336,14 @@ where routes, node: self.node, prev_route_id: self.prev_route_id, + v7_checks: self.v7_checks, } } + pub(super) fn has_routes(&self) -> bool { + !self.routes.is_empty() + } + pub(super) fn with_state(self, state: S) -> PathRouter { let routes = self .routes @@ -312,11 +363,12 @@ where routes, node: self.node, prev_route_id: self.prev_route_id, + v7_checks: self.v7_checks, } } pub(super) fn call_with_state( - &mut self, + &self, mut req: Request, state: S, ) -> Result, (Request, S)> { @@ -330,9 +382,9 @@ where } } - let path = req.uri().path().to_owned(); + let (mut parts, body) = req.into_parts(); - match self.node.at(&path) { + match self.node.at(parts.uri.path()) { Ok(match_) => { let id = *match_.value; @@ -341,31 +393,28 @@ where crate::extract::matched_path::set_matched_path_for_request( id, &self.node.route_id_to_path, - req.extensions_mut(), + &mut parts.extensions, ); } - url_params::insert_url_params(req.extensions_mut(), match_.params); + url_params::insert_url_params(&mut parts.extensions, match_.params); let endpoint = self .routes - .get_mut(&id) + .get(&id) .expect("no route for id. This is a bug in axum. Please file an issue"); + let req = Request::from_parts(parts, body); match endpoint { Endpoint::MethodRouter(method_router) => { Ok(method_router.call_with_state(req, state)) } - Endpoint::Route(route) => Ok(route.clone().call(req)), + Endpoint::Route(route) => Ok(route.clone().call_owned(req)), } } // explicitly handle all variants in case matchit adds // new ones we need to handle differently - Err( - MatchError::NotFound - | MatchError::ExtraTrailingSlash - | MatchError::MissingTrailingSlash, - ) => Err((req, state)), + Err(MatchError::NotFound) => Err((Request::from_parts(parts, body), state)), } } @@ -398,6 +447,7 @@ impl Default for PathRouter { routes: Default::default(), node: Default::default(), prev_route_id: RouteId(0), + v7_checks: true, } } } @@ -417,6 +467,7 @@ impl Clone for PathRouter { routes: self.routes.clone(), node: self.node.clone(), prev_route_id: self.prev_route_id, + v7_checks: self.v7_checks, } } } @@ -463,16 +514,20 @@ impl fmt::Debug for Node { } #[track_caller] -fn validate_nest_path(path: &str) -> &str { - if path.is_empty() { - // nesting at `""` and `"/"` should mean the same thing - return "/"; - } +fn validate_nest_path(v7_checks: bool, path: &str) -> &str { + assert!(path.starts_with('/')); + assert!(path.len() > 1); - if path.contains('*') { + if path.split('/').any(|segment| { + segment.starts_with("{*") && segment.ends_with('}') && !segment.ends_with("}}") + }) { panic!("Invalid route: nested routes cannot contain wildcards (*)"); } + if v7_checks { + validate_v07_paths(path).unwrap(); + } + path } diff --git a/axum/src/routing/route.rs b/axum/src/routing/route.rs index fd63d0e8d0..30492c388c 100644 --- a/axum/src/routing/route.rs +++ b/axum/src/routing/route.rs @@ -1,12 +1,13 @@ use crate::{ body::{Body, HttpBody}, + box_clone_service::BoxCloneService, response::Response, }; use axum_core::{extract::Request, response::IntoResponse}; use bytes::Bytes; use http::{ header::{self, CONTENT_LENGTH}, - HeaderMap, HeaderValue, + HeaderMap, HeaderValue, Method, }; use pin_project_lite::pin_project; use std::{ @@ -14,10 +15,10 @@ use std::{ fmt, future::Future, pin::Pin, - task::{Context, Poll}, + task::{ready, Context, Poll}, }; use tower::{ - util::{BoxCloneService, MapErrLayer, MapRequestLayer, MapResponseLayer, Oneshot}, + util::{MapErrLayer, MapResponseLayer, Oneshot}, ServiceExt, }; use tower_layer::Layer; @@ -32,7 +33,7 @@ pub struct Route(BoxCloneService); impl Route { pub(crate) fn new(svc: T) -> Self where - T: Service + Clone + Send + 'static, + T: Service + Clone + Send + Sync + 'static, T::Response: IntoResponse + 'static, T::Future: Send + 'static, { @@ -41,24 +42,33 @@ impl Route { )) } - pub(crate) fn oneshot_inner( - &mut self, - req: Request, - ) -> Oneshot, Request> { - self.0.clone().oneshot(req) + /// Variant of [`Route::call`] that takes ownership of the route to avoid cloning. + pub(crate) fn call_owned(self, req: Request) -> RouteFuture { + let req = req.map(Body::new); + self.oneshot_inner_owned(req).not_top_level() + } + + pub(crate) fn oneshot_inner(&mut self, req: Request) -> RouteFuture { + let method = req.method().clone(); + RouteFuture::new(method, self.0.clone().oneshot(req)) + } + + /// Variant of [`Route::oneshot_inner`] that takes ownership of the route to avoid cloning. + pub(crate) fn oneshot_inner_owned(self, req: Request) -> RouteFuture { + let method = req.method().clone(); + RouteFuture::new(method, self.0.oneshot(req)) } pub(crate) fn layer(self, layer: L) -> Route where L: Layer> + Clone + Send + 'static, - L::Service: Service + Clone + Send + 'static, + L::Service: Service + Clone + Send + Sync + 'static, >::Response: IntoResponse + 'static, >::Error: Into + 'static, >::Future: Send + 'static, NewError: 'static, { let layer = ( - MapRequestLayer::new(|req: Request<_>| req.map(Body::new)), MapErrLayer::new(Into::into), MapResponseLayer::new(IntoResponse::into_response), layer, @@ -69,6 +79,7 @@ impl Route { } impl Clone for Route { + #[track_caller] fn clone(&self) -> Self { Self(self.0.clone()) } @@ -96,8 +107,7 @@ where #[inline] fn call(&mut self, req: Request) -> Self::Future { - let req = req.map(Body::new); - RouteFuture::from_future(self.oneshot_inner(req)) + self.oneshot_inner(req.map(Body::new)).not_top_level() } } @@ -105,46 +115,30 @@ pin_project! { /// Response future for [`Route`]. pub struct RouteFuture { #[pin] - kind: RouteFutureKind, - strip_body: bool, + inner: Oneshot, Request>, + method: Method, allow_header: Option, - } -} - -pin_project! { - #[project = RouteFutureKindProj] - enum RouteFutureKind { - Future { - #[pin] - future: Oneshot< - BoxCloneService, - Request, - >, - }, - Response { - response: Option, - } + top_level: bool, } } impl RouteFuture { - pub(crate) fn from_future( - future: Oneshot, Request>, - ) -> Self { + fn new(method: Method, inner: Oneshot, Request>) -> Self { Self { - kind: RouteFutureKind::Future { future }, - strip_body: false, + inner, + method, allow_header: None, + top_level: true, } } - pub(crate) fn strip_body(mut self, strip_body: bool) -> Self { - self.strip_body = strip_body; + pub(crate) fn allow_header(mut self, allow_header: Bytes) -> Self { + self.allow_header = Some(allow_header); self } - pub(crate) fn allow_header(mut self, allow_header: Bytes) -> Self { - self.allow_header = Some(allow_header); + pub(crate) fn not_top_level(mut self) -> Self { + self.top_level = false; self } } @@ -155,28 +149,30 @@ impl Future for RouteFuture { #[inline] fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { let this = self.project(); - - let mut res = match this.kind.project() { - RouteFutureKindProj::Future { future } => match future.poll(cx) { - Poll::Ready(Ok(res)) => res, - Poll::Ready(Err(err)) => return Poll::Ready(Err(err)), - Poll::Pending => return Poll::Pending, - }, - RouteFutureKindProj::Response { response } => { - response.take().expect("future polled after completion") + let mut res = ready!(this.inner.poll(cx))?; + + if *this.method == Method::CONNECT && res.status().is_success() { + // From https://httpwg.org/specs/rfc9110.html#CONNECT: + // > A server MUST NOT send any Transfer-Encoding or + // > Content-Length header fields in a 2xx (Successful) + // > response to CONNECT. + if res.headers().contains_key(&CONTENT_LENGTH) + || res.headers().contains_key(&header::TRANSFER_ENCODING) + || res.size_hint().lower() != 0 + { + error!("response to CONNECT with nonempty body"); + res = res.map(|_| Body::empty()); } - }; + } else if *this.top_level { + set_allow_header(res.headers_mut(), this.allow_header); - set_allow_header(res.headers_mut(), this.allow_header); + // make sure to set content-length before removing the body + set_content_length(res.size_hint(), res.headers_mut()); - // make sure to set content-length before removing the body - set_content_length(res.size_hint(), res.headers_mut()); - - let res = if *this.strip_body { - res.map(|_| Body::empty()) - } else { - res - }; + if *this.method == Method::HEAD { + *res.body_mut() = Body::empty(); + } + } Poll::Ready(Ok(res)) } diff --git a/axum/src/routing/strip_prefix.rs b/axum/src/routing/strip_prefix.rs index 0b06db4d28..3209da3b12 100644 --- a/axum/src/routing/strip_prefix.rs +++ b/axum/src/routing/strip_prefix.rs @@ -56,7 +56,7 @@ fn strip_prefix(uri: &Uri, prefix: &str) -> Option { // ^^^^ this much is matched and the length is 4. Thus if we chop off the first 4 // characters we get the remainder // - // prefix = /api/:version + // prefix = /api/{version} // path = /api/v0/users // ^^^^^^^ this much is matched and the length is 7. let mut matching_prefix_length = Some(0); @@ -66,7 +66,7 @@ fn strip_prefix(uri: &Uri, prefix: &str) -> Option { match item { Item::Both(path_segment, prefix_segment) => { - if prefix_segment.starts_with(':') || path_segment == prefix_segment { + if is_capture(prefix_segment) || path_segment == prefix_segment { // the prefix segment is either a param, which matches anything, or // it actually matches the path segment *matching_prefix_length.as_mut().unwrap() += path_segment.len(); @@ -104,7 +104,7 @@ fn strip_prefix(uri: &Uri, prefix: &str) -> Option { } // if the prefix matches it will always do so up until a `/`, it cannot match only - // part of a segment. Therefore this will always be at a char boundary and `split_at` wont + // part of a segment. Therefore this will always be at a char boundary and `split_at` won't // panic let after_prefix = uri.path().split_at(matching_prefix_length?).1; @@ -148,6 +148,14 @@ where }) } +fn is_capture(segment: &str) -> bool { + segment.starts_with('{') + && segment.ends_with('}') + && !segment.starts_with("{{") + && !segment.ends_with("}}") + && !segment.starts_with("{*") +} + #[derive(Debug)] enum Item { Both(T, T), @@ -279,74 +287,89 @@ mod tests { expected = Some("/"), ); - test!(param_0, uri = "/", prefix = "/:param", expected = Some("/"),); + test!( + param_0, + uri = "/", + prefix = "/{param}", + expected = Some("/"), + ); test!( param_1, uri = "/a", - prefix = "/:param", + prefix = "/{param}", expected = Some("/"), ); test!( param_2, uri = "/a/b", - prefix = "/:param", + prefix = "/{param}", expected = Some("/b"), ); test!( param_3, uri = "/b/a", - prefix = "/:param", + prefix = "/{param}", expected = Some("/a"), ); test!( param_4, uri = "/a/b", - prefix = "/a/:param", + prefix = "/a/{param}", expected = Some("/"), ); - test!(param_5, uri = "/b/a", prefix = "/a/:param", expected = None,); + test!( + param_5, + uri = "/b/a", + prefix = "/a/{param}", + expected = None, + ); - test!(param_6, uri = "/a/b", prefix = "/:param/a", expected = None,); + test!( + param_6, + uri = "/a/b", + prefix = "/{param}/a", + expected = None, + ); test!( param_7, uri = "/b/a", - prefix = "/:param/a", + prefix = "/{param}/a", expected = Some("/"), ); test!( param_8, uri = "/a/b/c", - prefix = "/a/:param/c", + prefix = "/a/{param}/c", expected = Some("/"), ); test!( param_9, uri = "/c/b/a", - prefix = "/a/:param/c", + prefix = "/a/{param}/c", expected = None, ); test!( param_10, uri = "/a/", - prefix = "/:param", + prefix = "/{param}", expected = Some("/"), ); - test!(param_11, uri = "/a", prefix = "/:param/", expected = None,); + test!(param_11, uri = "/a", prefix = "/{param}/", expected = None,); test!( param_12, uri = "/a/", - prefix = "/:param/", + prefix = "/{param}/", expected = Some("/"), ); diff --git a/axum/src/routing/tests/fallback.rs b/axum/src/routing/tests/fallback.rs index d0800467b8..3c8755bb85 100644 --- a/axum/src/routing/tests/fallback.rs +++ b/axum/src/routing/tests/fallback.rs @@ -9,9 +9,9 @@ async fn basic() { let client = TestClient::new(app); - assert_eq!(client.get("/foo").send().await.status(), StatusCode::OK); + assert_eq!(client.get("/foo").await.status(), StatusCode::OK); - let res = client.get("/does-not-exist").send().await; + let res = client.get("/does-not-exist").await; assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, "fallback"); } @@ -24,9 +24,21 @@ async fn nest() { let client = TestClient::new(app); - assert_eq!(client.get("/foo/bar").send().await.status(), StatusCode::OK); + assert_eq!(client.get("/foo/bar").await.status(), StatusCode::OK); - let res = client.get("/does-not-exist").send().await; + let res = client.get("/does-not-exist").await; + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.text().await, "fallback"); +} + +#[crate::test] +async fn two() { + let app = Router::new() + .route("/first", get(|| async {})) + .route("/second", get(|| async {})) + .fallback(get(|| async { "fallback" })); + let client = TestClient::new(app); + let res = client.get("/does-not-exist").await; assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, "fallback"); } @@ -40,10 +52,10 @@ async fn or() { let client = TestClient::new(app); - assert_eq!(client.get("/one").send().await.status(), StatusCode::OK); - assert_eq!(client.get("/two").send().await.status(), StatusCode::OK); + assert_eq!(client.get("/one").await.status(), StatusCode::OK); + assert_eq!(client.get("/two").await.status(), StatusCode::OK); - let res = client.get("/does-not-exist").send().await; + let res = client.get("/does-not-exist").await; assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, "fallback"); } @@ -56,7 +68,7 @@ async fn fallback_accessing_state() { let client = TestClient::new(app); - let res = client.get("/does-not-exist").send().await; + let res = client.get("/does-not-exist").await; assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, "state"); } @@ -76,7 +88,7 @@ async fn nested_router_inherits_fallback() { let client = TestClient::new(app); - let res = client.get("/foo/bar").send().await; + let res = client.get("/foo/bar").await; assert_eq!(res.status(), StatusCode::NOT_FOUND); assert_eq!(res.text().await, "outer"); } @@ -88,11 +100,11 @@ async fn doesnt_inherit_fallback_if_overridden() { let client = TestClient::new(app); - let res = client.get("/foo/bar").send().await; + let res = client.get("/foo/bar").await; assert_eq!(res.status(), StatusCode::NOT_FOUND); assert_eq!(res.text().await, "inner"); - let res = client.get("/").send().await; + let res = client.get("/").await; assert_eq!(res.status(), StatusCode::NOT_FOUND); assert_eq!(res.text().await, "outer"); } @@ -105,7 +117,7 @@ async fn deeply_nested_inherit_from_top() { let client = TestClient::new(app); - let res = client.get("/foo/bar/baz").send().await; + let res = client.get("/foo/bar/baz").await; assert_eq!(res.status(), StatusCode::NOT_FOUND); assert_eq!(res.text().await, "outer"); } @@ -121,7 +133,7 @@ async fn deeply_nested_inherit_from_middle() { let client = TestClient::new(app); - let res = client.get("/foo/bar/baz").send().await; + let res = client.get("/foo/bar/baz").await; assert_eq!(res.status(), StatusCode::NOT_FOUND); assert_eq!(res.text().await, "outer"); } @@ -137,7 +149,7 @@ async fn with_middleware_on_inner_fallback() { let client = TestClient::new(app); - let res = client.get("/foo/bar").send().await; + let res = client.get("/foo/bar").await; assert_eq!(res.status(), StatusCode::NOT_FOUND); assert_eq!(res.text().await, "outer"); } @@ -158,7 +170,7 @@ async fn also_inherits_default_layered_fallback() { let client = TestClient::new(app); - let res = client.get("/foo/bar").send().await; + let res = client.get("/foo/bar").await; assert_eq!(res.status(), StatusCode::NOT_FOUND); assert_eq!(res.headers()["x-from-fallback"], "1"); assert_eq!(res.text().await, "outer"); @@ -177,7 +189,7 @@ async fn nest_fallback_on_inner() { let client = TestClient::new(app); - let res = client.get("/foo/not-found").send().await; + let res = client.get("/foo/not-found").await; assert_eq!(res.status(), StatusCode::NOT_FOUND); assert_eq!(res.text().await, "inner fallback"); } @@ -188,13 +200,13 @@ async fn doesnt_panic_if_used_with_nested_router() { async fn handler() {} let routes_static = - Router::new().nest_service("/", crate::routing::get_service(handler.into_service())); + Router::new().nest_service("/foo", crate::routing::get_service(handler.into_service())); let routes_all = Router::new().fallback_service(routes_static); let client = TestClient::new(routes_all); - let res = client.get("/foobar").send().await; + let res = client.get("/foo/bar").await; assert_eq!(res.status(), StatusCode::OK); } @@ -208,11 +220,11 @@ async fn issue_2072() { let client = TestClient::new(app); - let res = client.get("/nested/does-not-exist").send().await; + let res = client.get("/nested/does-not-exist").await; assert_eq!(res.status(), StatusCode::NOT_FOUND); assert_eq!(res.text().await, "inner"); - let res = client.get("/does-not-exist").send().await; + let res = client.get("/does-not-exist").await; assert_eq!(res.status(), StatusCode::NOT_FOUND); assert_eq!(res.text().await, ""); } @@ -228,11 +240,11 @@ async fn issue_2072_outer_fallback_before_merge() { let client = TestClient::new(app); - let res = client.get("/nested/does-not-exist").send().await; + let res = client.get("/nested/does-not-exist").await; assert_eq!(res.status(), StatusCode::NOT_FOUND); assert_eq!(res.text().await, "inner"); - let res = client.get("/does-not-exist").send().await; + let res = client.get("/does-not-exist").await; assert_eq!(res.status(), StatusCode::NOT_FOUND); assert_eq!(res.text().await, "outer"); } @@ -248,11 +260,11 @@ async fn issue_2072_outer_fallback_after_merge() { let client = TestClient::new(app); - let res = client.get("/nested/does-not-exist").send().await; + let res = client.get("/nested/does-not-exist").await; assert_eq!(res.status(), StatusCode::NOT_FOUND); assert_eq!(res.text().await, "inner"); - let res = client.get("/does-not-exist").send().await; + let res = client.get("/does-not-exist").await; assert_eq!(res.status(), StatusCode::NOT_FOUND); assert_eq!(res.text().await, "outer"); } @@ -267,11 +279,11 @@ async fn merge_router_with_fallback_into_nested_router_with_fallback() { let client = TestClient::new(app); - let res = client.get("/nested/does-not-exist").send().await; + let res = client.get("/nested/does-not-exist").await; assert_eq!(res.status(), StatusCode::NOT_FOUND); assert_eq!(res.text().await, "inner"); - let res = client.get("/does-not-exist").send().await; + let res = client.get("/does-not-exist").await; assert_eq!(res.status(), StatusCode::NOT_FOUND); assert_eq!(res.text().await, "outer"); } @@ -286,11 +298,11 @@ async fn merging_nested_router_with_fallback_into_router_with_fallback() { let client = TestClient::new(app); - let res = client.get("/nested/does-not-exist").send().await; + let res = client.get("/nested/does-not-exist").await; assert_eq!(res.status(), StatusCode::NOT_FOUND); assert_eq!(res.text().await, "inner"); - let res = client.get("/does-not-exist").send().await; + let res = client.get("/does-not-exist").await; assert_eq!(res.status(), StatusCode::NOT_FOUND); assert_eq!(res.text().await, "outer"); } @@ -301,7 +313,7 @@ async fn merge_empty_into_router_with_fallback() { let client = TestClient::new(app); - let res = client.get("/does-not-exist").send().await; + let res = client.get("/does-not-exist").await; assert_eq!(res.status(), StatusCode::NOT_FOUND); assert_eq!(res.text().await, "outer"); } @@ -312,7 +324,70 @@ async fn merge_router_with_fallback_into_empty() { let client = TestClient::new(app); - let res = client.get("/does-not-exist").send().await; + let res = client.get("/does-not-exist").await; assert_eq!(res.status(), StatusCode::NOT_FOUND); assert_eq!(res.text().await, "outer"); } + +#[crate::test] +async fn mna_fallback_with_existing_fallback() { + let app = Router::new() + .route( + "/", + get(|| async { "test" }).fallback(|| async { "index fallback" }), + ) + .route("/path", get(|| async { "path" })) + .method_not_allowed_fallback(|| async { "method not allowed fallback" }); + + let client = TestClient::new(app); + let index_fallback = client.post("/").await; + let method_not_allowed_fallback = client.post("/path").await; + + assert_eq!(index_fallback.text().await, "index fallback"); + assert_eq!( + method_not_allowed_fallback.text().await, + "method not allowed fallback" + ); +} + +#[crate::test] +async fn mna_fallback_with_state() { + let app = Router::new() + .route("/", get(|| async { "index" })) + .method_not_allowed_fallback(|State(state): State<&'static str>| async move { state }) + .with_state("state"); + + let client = TestClient::new(app); + let res = client.post("/").await; + assert_eq!(res.text().await, "state"); +} + +#[crate::test] +async fn mna_fallback_with_unused_state() { + let app = Router::new() + .route("/", get(|| async { "index" })) + .with_state(()) + .method_not_allowed_fallback(|| async move { "bla" }); + + let client = TestClient::new(app); + let res = client.post("/").await; + assert_eq!(res.text().await, "bla"); +} + +#[crate::test] +async fn state_isnt_cloned_too_much_with_fallback() { + let state = CountingCloneableState::new(); + + let app = Router::new() + .fallback(|_: State| async {}) + .with_state(state.clone()); + + let client = TestClient::new(app); + + // ignore clones made during setup + state.setup_done(); + + client.get("/does-not-exist").await; + + assert_eq!(state.count(), 3); +} diff --git a/axum/src/routing/tests/get_to_head.rs b/axum/src/routing/tests/get_to_head.rs index 811e5390c7..b20e8cd032 100644 --- a/axum/src/routing/tests/get_to_head.rs +++ b/axum/src/routing/tests/get_to_head.rs @@ -4,7 +4,6 @@ use tower::ServiceExt; mod for_handlers { use super::*; - use http::HeaderMap; #[crate::test] async fn get_handles_head() { @@ -39,7 +38,6 @@ mod for_handlers { mod for_services { use super::*; - use crate::routing::get_service; #[crate::test] async fn get_handles_head() { diff --git a/axum/src/routing/tests/handle_error.rs b/axum/src/routing/tests/handle_error.rs index 9b81a20f1d..a2fd2e6828 100644 --- a/axum/src/routing/tests/handle_error.rs +++ b/axum/src/routing/tests/handle_error.rs @@ -1,5 +1,5 @@ use super::*; -use std::future::{pending, ready}; +use std::future::pending; use tower::timeout::TimeoutLayer; async fn unit() {} @@ -12,23 +12,6 @@ fn timeout() -> TimeoutLayer { TimeoutLayer::new(Duration::from_millis(10)) } -#[derive(Clone)] -struct Svc; - -impl Service for Svc { - type Response = Response; - type Error = hyper::Error; - type Future = Ready>; - - fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn call(&mut self, _req: R) -> Self::Future { - ready(Ok(Response::new(Body::empty()))) - } -} - #[crate::test] async fn handler() { let app = Router::new().route( @@ -41,7 +24,7 @@ async fn handler() { let client = TestClient::new(app); - let res = client.get("/").send().await; + let res = client.get("/").await; assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT); } @@ -58,7 +41,7 @@ async fn handler_multiple_methods_first() { let client = TestClient::new(app); - let res = client.get("/").send().await; + let res = client.get("/").await; assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT); } @@ -76,7 +59,7 @@ async fn handler_multiple_methods_middle() { let client = TestClient::new(app); - let res = client.get("/").send().await; + let res = client.get("/").await; assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT); } @@ -92,7 +75,7 @@ async fn handler_multiple_methods_last() { let client = TestClient::new(app); - let res = client.get("/").send().await; + let res = client.get("/").await; assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT); } @@ -106,6 +89,6 @@ async fn handler_service_ext() { let client = TestClient::new(app); - let res = client.get("/").send().await; + let res = client.get("/").await; assert_eq!(res.status(), StatusCode::INTERNAL_SERVER_ERROR); } diff --git a/axum/src/routing/tests/merge.rs b/axum/src/routing/tests/merge.rs index 44b0ce8df7..b760184f54 100644 --- a/axum/src/routing/tests/merge.rs +++ b/axum/src/routing/tests/merge.rs @@ -1,8 +1,7 @@ use super::*; -use crate::{extract::OriginalUri, response::IntoResponse, Json}; +use crate::extract::OriginalUri; use serde_json::{json, Value}; use tower::limit::ConcurrencyLimitLayer; -use tower_http::timeout::TimeoutLayer; #[crate::test] async fn basic() { @@ -14,16 +13,16 @@ async fn basic() { let client = TestClient::new(app); - let res = client.get("/foo").send().await; + let res = client.get("/foo").await; assert_eq!(res.status(), StatusCode::OK); - let res = client.get("/bar").send().await; + let res = client.get("/bar").await; assert_eq!(res.status(), StatusCode::OK); - let res = client.get("/baz").send().await; + let res = client.get("/baz").await; assert_eq!(res.status(), StatusCode::OK); - let res = client.get("/qux").send().await; + let res = client.get("/qux").await; assert_eq!(res.status(), StatusCode::NOT_FOUND); } @@ -65,7 +64,7 @@ async fn multiple_ors_balanced_differently() { for n in ["one", "two", "three", "four"].iter() { println!("running: {name} / {n}"); - let res = client.get(&format!("/{n}")).send().await; + let res = client.get(&format!("/{n}")).await; assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, *n); } @@ -80,12 +79,12 @@ async fn nested_or() { let bar_or_baz = bar.merge(baz); let client = TestClient::new(bar_or_baz.clone()); - assert_eq!(client.get("/bar").send().await.text().await, "bar"); - assert_eq!(client.get("/baz").send().await.text().await, "baz"); + assert_eq!(client.get("/bar").await.text().await, "bar"); + assert_eq!(client.get("/baz").await.text().await, "baz"); let client = TestClient::new(Router::new().nest("/foo", bar_or_baz)); - assert_eq!(client.get("/foo/bar").send().await.text().await, "bar"); - assert_eq!(client.get("/foo/baz").send().await.text().await, "baz"); + assert_eq!(client.get("/foo/bar").await.text().await, "bar"); + assert_eq!(client.get("/foo/baz").await.text().await, "baz"); } #[crate::test] @@ -96,13 +95,13 @@ async fn or_with_route_following() { let client = TestClient::new(app); - let res = client.get("/one").send().await; + let res = client.get("/one").await; assert_eq!(res.status(), StatusCode::OK); - let res = client.get("/two").send().await; + let res = client.get("/two").await; assert_eq!(res.status(), StatusCode::OK); - let res = client.get("/three").send().await; + let res = client.get("/three").await; assert_eq!(res.status(), StatusCode::OK); } @@ -116,10 +115,10 @@ async fn layer() { let client = TestClient::new(app); - let res = client.get("/foo").send().await; + let res = client.get("/foo").await; assert_eq!(res.status(), StatusCode::OK); - let res = client.get("/bar").send().await; + let res = client.get("/bar").await; assert_eq!(res.status(), StatusCode::OK); } @@ -133,9 +132,9 @@ async fn layer_and_handle_error() { let client = TestClient::new(app); - let res = client.get("/timeout").send().await; + let res = client.get("/timeout").await; assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT); - let res = client.get("/foo").send().await; + let res = client.get("/foo").await; assert_eq!(res.status(), StatusCode::OK); } @@ -147,7 +146,7 @@ async fn nesting() { let client = TestClient::new(app); - let res = client.get("/bar/baz").send().await; + let res = client.get("/bar/baz").await; assert_eq!(res.status(), StatusCode::OK); } @@ -159,7 +158,7 @@ async fn boxed() { let client = TestClient::new(app); - let res = client.get("/bar").send().await; + let res = client.get("/bar").await; assert_eq!(res.status(), StatusCode::OK); } @@ -177,11 +176,11 @@ async fn many_ors() { let client = TestClient::new(app); for n in 1..=7 { - let res = client.get(&format!("/r{n}")).send().await; + let res = client.get(&format!("/r{n}")).await; assert_eq!(res.status(), StatusCode::OK); } - let res = client.get("/r8").send().await; + let res = client.get("/r8").await; assert_eq!(res.status(), StatusCode::NOT_FOUND); } @@ -205,10 +204,10 @@ async fn services() { let client = TestClient::new(app); - let res = client.get("/foo").send().await; + let res = client.get("/foo").await; assert_eq!(res.status(), StatusCode::OK); - let res = client.get("/bar").send().await; + let res = client.get("/bar").await; assert_eq!(res.status(), StatusCode::OK); } @@ -231,7 +230,7 @@ async fn nesting_and_seeing_the_right_uri() { let client = TestClient::new(one.merge(two)); - let res = client.get("/foo/bar").send().await; + let res = client.get("/foo/bar").await; assert_eq!(res.status(), StatusCode::OK); assert_eq!( res.json::().await, @@ -242,7 +241,7 @@ async fn nesting_and_seeing_the_right_uri() { }) ); - let res = client.get("/foo").send().await; + let res = client.get("/foo").await; assert_eq!(res.status(), StatusCode::OK); assert_eq!( res.json::().await, @@ -264,7 +263,7 @@ async fn nesting_and_seeing_the_right_uri_at_more_levels_of_nesting() { let client = TestClient::new(one.merge(two)); - let res = client.get("/foo/bar/baz").send().await; + let res = client.get("/foo/bar/baz").await; assert_eq!(res.status(), StatusCode::OK); assert_eq!( res.json::().await, @@ -275,7 +274,7 @@ async fn nesting_and_seeing_the_right_uri_at_more_levels_of_nesting() { }) ); - let res = client.get("/foo").send().await; + let res = client.get("/foo").await; assert_eq!(res.status(), StatusCode::OK); assert_eq!( res.json::().await, @@ -298,7 +297,7 @@ async fn nesting_and_seeing_the_right_uri_ors_with_nesting() { let client = TestClient::new(one.merge(two).merge(three)); - let res = client.get("/one/bar/baz").send().await; + let res = client.get("/one/bar/baz").await; assert_eq!(res.status(), StatusCode::OK); assert_eq!( res.json::().await, @@ -309,7 +308,7 @@ async fn nesting_and_seeing_the_right_uri_ors_with_nesting() { }) ); - let res = client.get("/two/qux").send().await; + let res = client.get("/two/qux").await; assert_eq!(res.status(), StatusCode::OK); assert_eq!( res.json::().await, @@ -320,7 +319,7 @@ async fn nesting_and_seeing_the_right_uri_ors_with_nesting() { }) ); - let res = client.get("/three").send().await; + let res = client.get("/three").await; assert_eq!(res.status(), StatusCode::OK); assert_eq!( res.json::().await, @@ -342,7 +341,7 @@ async fn nesting_and_seeing_the_right_uri_ors_with_multi_segment_uris() { let client = TestClient::new(one.merge(two)); - let res = client.get("/one/foo/bar").send().await; + let res = client.get("/one/foo/bar").await; assert_eq!(res.status(), StatusCode::OK); assert_eq!( res.json::().await, @@ -353,7 +352,7 @@ async fn nesting_and_seeing_the_right_uri_ors_with_multi_segment_uris() { }) ); - let res = client.get("/two/foo").send().await; + let res = client.get("/two/foo").await; assert_eq!(res.status(), StatusCode::OK); assert_eq!( res.json::().await, @@ -375,22 +374,18 @@ async fn middleware_that_return_early() { let client = TestClient::new(private.merge(public)); - assert_eq!( - client.get("/").send().await.status(), - StatusCode::UNAUTHORIZED - ); + assert_eq!(client.get("/").await.status(), StatusCode::UNAUTHORIZED); assert_eq!( client .get("/") .header("authorization", "Bearer password") - .send() .await .status(), StatusCode::OK ); assert_eq!( - client.get("/doesnt-exist").send().await.status(), + client.get("/doesnt-exist").await.status(), StatusCode::NOT_FOUND ); - assert_eq!(client.get("/public").send().await.status(), StatusCode::OK); + assert_eq!(client.get("/public").await.status(), StatusCode::OK); } diff --git a/axum/src/routing/tests/mod.rs b/axum/src/routing/tests/mod.rs index 6c00a6de67..1993373e8c 100644 --- a/axum/src/routing/tests/mod.rs +++ b/axum/src/routing/tests/mod.rs @@ -3,6 +3,7 @@ use crate::{ error_handling::HandleErrorLayer, extract::{self, DefaultBodyLimit, FromRef, Path, State}, handler::{Handler, HandlerWithoutStateExt}, + middleware::{self, Next}, response::{IntoResponse, Response}, routing::{ delete, get, get_service, on, on_service, patch, patch_service, @@ -15,6 +16,7 @@ use crate::{ BoxError, Extension, Json, Router, ServiceExt, }; use axum_core::extract::Request; +use counting_cloneable_state::CountingCloneableState; use futures_util::stream::StreamExt; use http::{ header::{ALLOW, CONTENT_LENGTH, HOST}, @@ -25,8 +27,8 @@ use serde::Deserialize; use serde_json::json; use std::{ convert::Infallible, - future::{ready, Ready}, - sync::atomic::{AtomicBool, AtomicUsize, Ordering}, + future::{ready, IntoFuture, Ready}, + sync::atomic::{AtomicUsize, Ordering}, task::{Context, Poll}, time::Duration, }; @@ -63,15 +65,15 @@ async fn hello_world() { let client = TestClient::new(app); - let res = client.get("/").send().await; + let res = client.get("/").await; let body = res.text().await; assert_eq!(body, "Hello, World!"); - let res = client.post("/").send().await; + let res = client.post("/").await; let body = res.text().await; assert_eq!(body, "foo"); - let res = client.post("/users").send().await; + let res = client.post("/users").await; let body = res.text().await; assert_eq!(body, "users#create"); } @@ -83,30 +85,30 @@ async fn routing() { "/users", get(|_: Request| async { "users#index" }).post(|_: Request| async { "users#create" }), ) - .route("/users/:id", get(|_: Request| async { "users#show" })) + .route("/users/{id}", get(|_: Request| async { "users#show" })) .route( - "/users/:id/action", + "/users/{id}/action", get(|_: Request| async { "users#action" }), ); let client = TestClient::new(app); - let res = client.get("/").send().await; + let res = client.get("/").await; assert_eq!(res.status(), StatusCode::NOT_FOUND); - let res = client.get("/users").send().await; + let res = client.get("/users").await; assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, "users#index"); - let res = client.post("/users").send().await; + let res = client.post("/users").await; assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, "users#create"); - let res = client.get("/users/1").send().await; + let res = client.get("/users/1").await; assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, "users#show"); - let res = client.get("/users/1/action").send().await; + let res = client.get("/users/1/action").await; assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, "users#action"); } @@ -123,11 +125,11 @@ async fn router_type_doesnt_change() { let client = TestClient::new(app); - let res = client.get("/").send().await; + let res = client.get("/").await; assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, "hi from GET"); - let res = client.post("/").send().await; + let res = client.post("/").await; assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, "hi from POST"); } @@ -161,19 +163,19 @@ async fn routing_between_services() { let client = TestClient::new(app); - let res = client.get("/one").send().await; + let res = client.get("/one").await; assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, "one get"); - let res = client.post("/one").send().await; + let res = client.post("/one").await; assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, "one post"); - let res = client.put("/one").send().await; + let res = client.put("/one").await; assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, "one put"); - let res = client.get("/two").send().await; + let res = client.get("/two").await; assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, "handler"); } @@ -190,7 +192,7 @@ async fn middleware_on_single_route() { let client = TestClient::new(app); - let res = client.get("/").send().await; + let res = client.get("/").await; let body = res.text().await; assert_eq!(body, "Hello, World!"); @@ -215,18 +217,18 @@ async fn wrong_method_handler() { let client = TestClient::new(app); - let res = client.patch("/").send().await; + let res = client.patch("/").await; assert_eq!(res.status(), StatusCode::METHOD_NOT_ALLOWED); assert_eq!(res.headers()[ALLOW], "GET,HEAD,POST"); - let res = client.patch("/foo").send().await; + let res = client.patch("/foo").await; assert_eq!(res.status(), StatusCode::OK); - let res = client.post("/foo").send().await; + let res = client.post("/foo").await; assert_eq!(res.status(), StatusCode::METHOD_NOT_ALLOWED); assert_eq!(res.headers()[ALLOW], "PATCH"); - let res = client.get("/bar").send().await; + let res = client.get("/bar").await; assert_eq!(res.status(), StatusCode::NOT_FOUND); } @@ -255,18 +257,18 @@ async fn wrong_method_service() { let client = TestClient::new(app); - let res = client.patch("/").send().await; + let res = client.patch("/").await; assert_eq!(res.status(), StatusCode::METHOD_NOT_ALLOWED); assert_eq!(res.headers()[ALLOW], "GET,HEAD,POST"); - let res = client.patch("/foo").send().await; + let res = client.patch("/foo").await; assert_eq!(res.status(), StatusCode::OK); - let res = client.post("/foo").send().await; + let res = client.post("/foo").await; assert_eq!(res.status(), StatusCode::METHOD_NOT_ALLOWED); assert_eq!(res.headers()[ALLOW], "PATCH"); - let res = client.get("/bar").send().await; + let res = client.get("/bar").await; assert_eq!(res.status(), StatusCode::NOT_FOUND); } @@ -280,20 +282,23 @@ async fn multiple_methods_for_one_handler() { let client = TestClient::new(app); - let res = client.get("/").send().await; + let res = client.get("/").await; assert_eq!(res.status(), StatusCode::OK); - let res = client.post("/").send().await; + let res = client.post("/").await; assert_eq!(res.status(), StatusCode::OK); } #[crate::test] async fn wildcard_sees_whole_url() { - let app = Router::new().route("/api/*rest", get(|uri: Uri| async move { uri.to_string() })); + let app = Router::new().route( + "/api/{*rest}", + get(|uri: Uri| async move { uri.to_string() }), + ); let client = TestClient::new(app); - let res = client.get("/api/foo/bar").send().await; + let res = client.get("/api/foo/bar").await; assert_eq!(res.text().await, "/api/foo/bar"); } @@ -306,10 +311,10 @@ async fn middleware_applies_to_routes_above() { let client = TestClient::new(app); - let res = client.get("/one").send().await; + let res = client.get("/one").await; assert_eq!(res.status(), StatusCode::REQUEST_TIMEOUT); - let res = client.get("/two").send().await; + let res = client.get("/two").await; assert_eq!(res.status(), StatusCode::OK); } @@ -319,10 +324,10 @@ async fn not_found_for_extra_trailing_slash() { let client = TestClient::new(app); - let res = client.get("/foo/").send().await; + let res = client.get("/foo/").await; assert_eq!(res.status(), StatusCode::NOT_FOUND); - let res = client.get("/foo").send().await; + let res = client.get("/foo").await; assert_eq!(res.status(), StatusCode::OK); } @@ -332,7 +337,7 @@ async fn not_found_for_missing_trailing_slash() { let client = TestClient::new(app); - let res = client.get("/foo").send().await; + let res = client.get("/foo").await; assert_eq!(res.status(), StatusCode::NOT_FOUND); } @@ -344,11 +349,11 @@ async fn with_and_without_trailing_slash() { let client = TestClient::new(app); - let res = client.get("/foo/").send().await; + let res = client.get("/foo/").await; assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, "with tsr"); - let res = client.get("/foo").send().await; + let res = client.get("/foo").await; assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, "without tsr"); } @@ -357,19 +362,19 @@ async fn with_and_without_trailing_slash() { #[crate::test] async fn wildcard_doesnt_match_just_trailing_slash() { let app = Router::new().route( - "/x/*path", + "/x/{*path}", get(|Path(path): Path| async move { path }), ); let client = TestClient::new(app); - let res = client.get("/x").send().await; + let res = client.get("/x").await; assert_eq!(res.status(), StatusCode::NOT_FOUND); - let res = client.get("/x/").send().await; + let res = client.get("/x/").await; assert_eq!(res.status(), StatusCode::NOT_FOUND); - let res = client.get("/x/foo/bar").send().await; + let res = client.get("/x/foo/bar").await; assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, "foo/bar"); } @@ -377,14 +382,14 @@ async fn wildcard_doesnt_match_just_trailing_slash() { #[crate::test] async fn what_matches_wildcard() { let app = Router::new() - .route("/*key", get(|| async { "root" })) - .route("/x/*key", get(|| async { "x" })) + .route("/{*key}", get(|| async { "root" })) + .route("/x/{*key}", get(|| async { "x" })) .fallback(|| async { "fallback" }); let client = TestClient::new(app); let get = |path| { - let f = client.get(path).send(); + let f = client.get(path); async move { f.await.text().await } }; @@ -406,17 +411,17 @@ async fn what_matches_wildcard() { async fn static_and_dynamic_paths() { let app = Router::new() .route( - "/:key", + "/{key}", get(|Path(key): Path| async move { format!("dynamic: {key}") }), ) .route("/foo", get(|| async { "static" })); let client = TestClient::new(app); - let res = client.get("/bar").send().await; + let res = client.get("/bar").await; assert_eq!(res.text().await, "dynamic: bar"); - let res = client.get("/foo").send().await; + let res = client.get("/foo").await; assert_eq!(res.text().await, "static"); } @@ -460,10 +465,10 @@ async fn middleware_still_run_for_unmatched_requests() { assert_eq!(COUNT.load(Ordering::SeqCst), 0); - client.get("/").send().await; + client.get("/").await; assert_eq!(COUNT.load(Ordering::SeqCst), 1); - client.get("/not-found").send().await; + client.get("/not-found").await; assert_eq!(COUNT.load(Ordering::SeqCst), 2); } @@ -487,20 +492,19 @@ async fn route_layer() { let res = client .get("/foo") .header("authorization", "Bearer password") - .send() .await; assert_eq!(res.status(), StatusCode::OK); - let res = client.get("/foo").send().await; + let res = client.get("/foo").await; assert_eq!(res.status(), StatusCode::UNAUTHORIZED); - let res = client.get("/not-found").send().await; + let res = client.get("/not-found").await; assert_eq!(res.status(), StatusCode::NOT_FOUND); // it would be nice if this would return `405 Method Not Allowed` // but that requires knowing more about which method route we're calling, which we - // don't know currently since its just a generic `Service` - let res = client.post("/foo").send().await; + // don't know currently since it's just a generic `Service` + let res = client.post("/foo").await; assert_eq!(res.status(), StatusCode::UNAUTHORIZED); } @@ -512,11 +516,11 @@ async fn different_methods_added_in_different_routes() { let client = TestClient::new(app); - let res = client.get("/").send().await; + let res = client.get("/").await; let body = res.text().await; assert_eq!(body, "GET"); - let res = client.post("/").send().await; + let res = client.post("/").await; let body = res.text().await; assert_eq!(body, "POST"); } @@ -554,11 +558,11 @@ async fn merging_routers_with_same_paths_but_different_methods() { let client = TestClient::new(one.merge(two)); - let res = client.get("/").send().await; + let res = client.get("/").await; let body = res.text().await; assert_eq!(body, "GET"); - let res = client.post("/").send().await; + let res = client.post("/").await; let body = res.text().await; assert_eq!(body, "POST"); } @@ -571,11 +575,11 @@ async fn head_content_length_through_hyper_server() { let client = TestClient::new(app); - let res = client.head("/").send().await; + let res = client.head("/").await; assert_eq!(res.headers()["content-length"], "3"); assert!(res.text().await.is_empty()); - let res = client.head("/json").send().await; + let res = client.head("/json").await; assert_eq!(res.headers()["content-length"], "9"); assert!(res.text().await.is_empty()); } @@ -586,7 +590,7 @@ async fn head_content_length_through_hyper_server_that_hits_fallback() { let client = TestClient::new(app); - let res = client.head("/").send().await; + let res = client.head("/").await; assert_eq!(res.headers()["content-length"], "3"); } @@ -596,7 +600,7 @@ async fn head_with_middleware_applied() { let app = Router::new() .nest( - "/", + "/foo", Router::new().route("/", get(|| async { "Hello, World!" })), ) .layer(CompressionLayer::new().compress_when(SizeAbove::new(0))); @@ -604,21 +608,13 @@ async fn head_with_middleware_applied() { let client = TestClient::new(app); // send GET request - let res = client - .get("/") - .header("accept-encoding", "gzip") - .send() - .await; + let res = client.get("/foo").header("accept-encoding", "gzip").await; assert_eq!(res.headers()["transfer-encoding"], "chunked"); // cannot have `transfer-encoding: chunked` and `content-length` assert!(!res.headers().contains_key("content-length")); // send HEAD request - let res = client - .head("/") - .header("accept-encoding", "gzip") - .send() - .await; + let res = client.head("/foo").header("accept-encoding", "gzip").await; // no response body so no `transfer-encoding` assert!(!res.headers().contains_key("transfer-encoding")); // no content-length since we cannot know it since the response @@ -652,7 +648,7 @@ async fn body_limited_by_default() { .post(uri) .header("content-type", "application/json") .body(body) - .send(); + .into_future(); let res = tokio::time::timeout(Duration::from_secs(3), res_future) .await .expect("never got response"); @@ -672,7 +668,7 @@ async fn disabling_the_default_limit() { // `DEFAULT_LIMIT` is 2mb so make a body larger than that let body = reqwest::Body::from("a".repeat(3_000_000)); - let res = client.post("/").body(body).send().await; + let res = client.post("/").body(body).await; assert_eq!(res.status(), StatusCode::OK); } @@ -692,10 +688,10 @@ async fn limited_body_with_content_length() { let client = TestClient::new(app); - let res = client.post("/").body("a".repeat(LIMIT)).send().await; + let res = client.post("/").body("a".repeat(LIMIT)).await; assert_eq!(res.status(), StatusCode::OK); - let res = client.post("/").body("a".repeat(LIMIT * 2)).send().await; + let res = client.post("/").body("a".repeat(LIMIT * 2)).await; assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE); } @@ -712,14 +708,12 @@ async fn changing_the_default_limit() { let res = client .post("/") .body(reqwest::Body::from("a".repeat(new_limit))) - .send() .await; assert_eq!(res.status(), StatusCode::OK); let res = client .post("/") .body(reqwest::Body::from("a".repeat(new_limit + 1))) - .send() .await; assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE); } @@ -745,42 +739,36 @@ async fn changing_the_default_limit_differently_on_different_routes() { let res = client .post("/limit1") .body(reqwest::Body::from("a".repeat(limit1))) - .send() .await; assert_eq!(res.status(), StatusCode::OK); let res = client .post("/limit1") .body(reqwest::Body::from("a".repeat(limit2))) - .send() .await; assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE); let res = client .post("/limit2") .body(reqwest::Body::from("a".repeat(limit1))) - .send() .await; assert_eq!(res.status(), StatusCode::OK); let res = client .post("/limit2") .body(reqwest::Body::from("a".repeat(limit2))) - .send() .await; assert_eq!(res.status(), StatusCode::OK); let res = client .post("/limit2") .body(reqwest::Body::from("a".repeat(limit1 + limit2))) - .send() .await; assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE); let res = client .post("/default") .body(reqwest::Body::from("a".repeat(limit1 + limit2))) - .send() .await; assert_eq!(res.status(), StatusCode::OK); @@ -788,7 +776,6 @@ async fn changing_the_default_limit_differently_on_different_routes() { .post("/default") // `DEFAULT_LIMIT` is 2mb so make a body larger than that .body(reqwest::Body::from("a".repeat(3_000_000))) - .send() .await; assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE); } @@ -812,7 +799,6 @@ async fn limited_body_with_streaming_body() { let res = client .post("/") .body(reqwest::Body::wrap_stream(stream)) - .send() .await; assert_eq!(res.status(), StatusCode::OK); @@ -820,7 +806,6 @@ async fn limited_body_with_streaming_body() { let res = client .post("/") .body(reqwest::Body::wrap_stream(stream)) - .send() .await; assert_eq!(res.status(), StatusCode::PAYLOAD_TOO_LARGE); } @@ -857,7 +842,7 @@ async fn extract_state() { let app = Router::new().route("/", get(handler)).with_state(state); let client = TestClient::new(app); - let res = client.get("/").send().await; + let res = client.get("/").await; assert_eq!(res.status(), StatusCode::OK); } @@ -871,7 +856,7 @@ async fn explicitly_set_state() { .with_state("..."); let client = TestClient::new(app); - let res = client.get("/").send().await; + let res = client.get("/").await; assert_eq!(res.text().await, "foo"); } @@ -889,7 +874,7 @@ async fn layer_response_into_response() { let client = TestClient::new(app); - let res = client.get("/").send().await; + let res = client.get("/").await; assert_eq!(res.headers()["x-foo"], "bar"); assert_eq!(res.status(), StatusCode::IM_A_TEAPOT); } @@ -924,54 +909,40 @@ fn test_path_for_nested_route() { #[crate::test] async fn state_isnt_cloned_too_much() { - static SETUP_DONE: AtomicBool = AtomicBool::new(false); - static COUNT: AtomicUsize = AtomicUsize::new(0); + let state = CountingCloneableState::new(); - struct AppState; - - impl Clone for AppState { - fn clone(&self) -> Self { - #[rustversion::since(1.66)] - #[track_caller] - fn count() { - if SETUP_DONE.load(Ordering::SeqCst) { - let bt = std::backtrace::Backtrace::force_capture(); - let bt = bt - .to_string() - .lines() - .filter(|line| line.contains("axum") || line.contains("./src")) - .collect::>() - .join("\n"); - println!("AppState::Clone:\n===============\n{bt}\n"); - COUNT.fetch_add(1, Ordering::SeqCst); - } - } - - #[rustversion::not(since(1.66))] - fn count() { - if SETUP_DONE.load(Ordering::SeqCst) { - COUNT.fetch_add(1, Ordering::SeqCst); - } - } - - count(); - - Self - } + let app = Router::new() + .route("/", get(|_: State| async {})) + .with_state(state.clone()); + + let client = TestClient::new(app); + + // ignore clones made during setup + state.setup_done(); + + client.get("/").await; + + assert_eq!(state.count(), 3); +} + +#[crate::test] +async fn state_isnt_cloned_too_much_in_layer() { + async fn layer(State(_): State, req: Request, next: Next) -> Response { + next.run(req).await } - let app = Router::new() - .route("/", get(|_: State| async {})) - .with_state(AppState); + let state = CountingCloneableState::new(); + + let app = Router::new().layer(middleware::from_fn_with_state(state.clone(), layer)); let client = TestClient::new(app); // ignore clones made during setup - SETUP_DONE.store(true, Ordering::SeqCst); + state.setup_done(); - client.get("/").send().await; + client.get("/").await; - assert_eq!(COUNT.load(Ordering::SeqCst), 5); + assert_eq!(state.count(), 3); } #[crate::test] @@ -985,7 +956,7 @@ async fn logging_rejections() { rejection_type: String, } - let events = capture_tracing::(|| async { + let events = capture_tracing::(|| async { let app = Router::new() .route("/extension", get(|_: Extension| async {})) .route("/string", post(|_: String| async {})); @@ -993,7 +964,7 @@ async fn logging_rejections() { let client = TestClient::new(app); assert_eq!( - client.get("/extension").send().await.status(), + client.get("/extension").await.status(), StatusCode::INTERNAL_SERVER_ERROR ); @@ -1001,12 +972,12 @@ async fn logging_rejections() { client .post("/string") .body(Vec::from([0, 159, 146, 150])) - .send() .await .status(), StatusCode::BAD_REQUEST, ); }) + .with_filter("axum::rejection=trace") .await; assert_eq!( @@ -1083,7 +1054,48 @@ async fn impl_handler_for_into_response() { let client = TestClient::new(app); - let res = client.post("/things").send().await; + let res = client.post("/things").await; assert_eq!(res.status(), StatusCode::CREATED); assert_eq!(res.text().await, "thing created"); } + +#[crate::test] +#[should_panic( + expected = "Path segments must not start with `:`. For capture groups, use `{capture}`. If you meant to literally match a segment starting with a colon, call `without_v07_checks` on the router." +)] +async fn colon_in_route() { + _ = Router::<()>::new().route("/:foo", get(|| async move {})); +} + +#[crate::test] +#[should_panic( + expected = "Path segments must not start with `*`. For wildcard capture, use `{*wildcard}`. If you meant to literally match a segment starting with an asterisk, call `without_v07_checks` on the router." +)] +async fn asterisk_in_route() { + _ = Router::<()>::new().route("/*foo", get(|| async move {})); +} + +#[crate::test] +async fn middleware_adding_body() { + let app = Router::new() + .route("/", get(())) + .layer(MapResponseLayer::new(|mut res: Response| -> Response { + // If there is a content-length header, its value will be zero and Axum will avoid + // overwriting it. But this means our content-length doesn’t match the length of the + // body, which leads to panics in Hyper. Thus we have to ensure that Axum doesn’t add + // on content-length headers until after middleware has been run. + assert!(!res.headers().contains_key("content-length")); + *res.body_mut() = "…".into(); + res + })); + + let client = TestClient::new(app); + let res = client.get("/").await; + + let headers = res.headers(); + let header = headers.get("content-length"); + assert!(header.is_some()); + assert_eq!(header.unwrap().to_str().unwrap(), "3"); + + assert_eq!(res.text().await, "…"); +} diff --git a/axum/src/routing/tests/nest.rs b/axum/src/routing/tests/nest.rs index a9119eb3cf..6e14203662 100644 --- a/axum/src/routing/tests/nest.rs +++ b/axum/src/routing/tests/nest.rs @@ -1,5 +1,4 @@ use super::*; -use crate::extract::Extension; use std::collections::HashMap; use tower_http::services::ServeDir; @@ -11,7 +10,7 @@ async fn nesting_apps() { get(|| async { "users#index" }).post(|| async { "users#create" }), ) .route( - "/users/:id", + "/users/{id}", get( |params: extract::Path>| async move { format!( @@ -23,7 +22,7 @@ async fn nesting_apps() { ), ) .route( - "/games/:id", + "/games/{id}", get( |params: extract::Path>| async move { format!( @@ -37,23 +36,23 @@ async fn nesting_apps() { let app = Router::new() .route("/", get(|| async { "hi" })) - .nest("/:version/api", api_routes); + .nest("/{version}/api", api_routes); let client = TestClient::new(app); - let res = client.get("/").send().await; + let res = client.get("/").await; assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, "hi"); - let res = client.get("/v0/api/users").send().await; + let res = client.get("/v0/api/users").await; assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, "users#index"); - let res = client.get("/v0/api/users/123").send().await; + let res = client.get("/v0/api/users/123").await; assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, "v0: users#show (123)"); - let res = client.get("/v0/api/games/123").send().await; + let res = client.get("/v0/api/games/123").await; assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, "v0: games#show (123)"); } @@ -61,74 +60,49 @@ async fn nesting_apps() { #[crate::test] async fn wrong_method_nest() { let nested_app = Router::new().route("/", get(|| async {})); - let app = Router::new().nest("/", nested_app); + let app = Router::new().nest("/foo", nested_app); let client = TestClient::new(app); - let res = client.get("/").send().await; + let res = client.get("/foo").await; assert_eq!(res.status(), StatusCode::OK); - let res = client.post("/").send().await; + let res = client.post("/foo").await; assert_eq!(res.status(), StatusCode::METHOD_NOT_ALLOWED); assert_eq!(res.headers()[ALLOW], "GET,HEAD"); - let res = client.patch("/foo").send().await; + let res = client.patch("/foo/bar").await; assert_eq!(res.status(), StatusCode::NOT_FOUND); } -#[crate::test] -async fn nesting_router_at_root() { - let nested = Router::new().route("/foo", get(|uri: Uri| async move { uri.to_string() })); - let app = Router::new().nest("/", nested); - - let client = TestClient::new(app); - - let res = client.get("/").send().await; - assert_eq!(res.status(), StatusCode::NOT_FOUND); - - let res = client.get("/foo").send().await; - assert_eq!(res.status(), StatusCode::OK); - assert_eq!(res.text().await, "/foo"); - - let res = client.get("/foo/bar").send().await; - assert_eq!(res.status(), StatusCode::NOT_FOUND); +#[test] +#[should_panic(expected = "Nesting at the root is no longer supported. Use merge instead.")] +fn nest_router_at_root() { + let nested = Router::new().route("/foo", get(|| async {})); + let _: Router = Router::new().nest("/", nested); } -#[crate::test] -async fn nesting_router_at_empty_path() { - let nested = Router::new().route("/foo", get(|uri: Uri| async move { uri.to_string() })); - let app = Router::new().nest("", nested); - - let client = TestClient::new(app); - - let res = client.get("/").send().await; - assert_eq!(res.status(), StatusCode::NOT_FOUND); - - let res = client.get("/foo").send().await; - assert_eq!(res.status(), StatusCode::OK); - assert_eq!(res.text().await, "/foo"); - - let res = client.get("/foo/bar").send().await; - assert_eq!(res.status(), StatusCode::NOT_FOUND); +#[test] +#[should_panic(expected = "Nesting at the root is no longer supported. Use merge instead.")] +fn nest_router_at_empty_path() { + let nested = Router::new().route("/foo", get(|| async {})); + let _: Router = Router::new().nest("", nested); } -#[crate::test] -async fn nesting_handler_at_root() { - let app = Router::new().nest_service("/", get(|uri: Uri| async move { uri.to_string() })); - - let client = TestClient::new(app); - - let res = client.get("/").send().await; - assert_eq!(res.status(), StatusCode::OK); - assert_eq!(res.text().await, "/"); - - let res = client.get("/foo").send().await; - assert_eq!(res.status(), StatusCode::OK); - assert_eq!(res.text().await, "/foo"); +#[test] +#[should_panic( + expected = "Nesting at the root is no longer supported. Use fallback_service instead." +)] +fn nest_service_at_root() { + let _: Router = Router::new().nest_service("/", get(|| async {})); +} - let res = client.get("/foo/bar").send().await; - assert_eq!(res.status(), StatusCode::OK); - assert_eq!(res.text().await, "/foo/bar"); +#[test] +#[should_panic( + expected = "Nesting at the root is no longer supported. Use fallback_service instead." +)] +fn nest_service_at_empty_path() { + let _: Router = Router::new().nest_service("", get(|| async {})); } #[crate::test] @@ -148,11 +122,11 @@ async fn nested_url_extractor() { let client = TestClient::new(app); - let res = client.get("/foo/bar/baz").send().await; + let res = client.get("/foo/bar/baz").await; assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, "/baz"); - let res = client.get("/foo/bar/qux").send().await; + let res = client.get("/foo/bar/qux").await; assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, "/qux"); } @@ -172,7 +146,7 @@ async fn nested_url_original_extractor() { let client = TestClient::new(app); - let res = client.get("/foo/bar/baz").send().await; + let res = client.get("/foo/bar/baz").await; assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, "/foo/bar/baz"); } @@ -195,7 +169,7 @@ async fn nested_service_sees_stripped_uri() { let client = TestClient::new(app); - let res = client.get("/foo/bar/baz").send().await; + let res = client.get("/foo/bar/baz").await; assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, "/baz"); } @@ -206,7 +180,7 @@ async fn nest_static_file_server() { let client = TestClient::new(app); - let res = client.get("/static/README.md").send().await; + let res = client.get("/static/README.md").await; assert_eq!(res.status(), StatusCode::OK); } @@ -223,24 +197,9 @@ async fn nested_multiple_routes() { let client = TestClient::new(app); - assert_eq!(client.get("/").send().await.text().await, "root"); - assert_eq!(client.get("/api/users").send().await.text().await, "users"); - assert_eq!(client.get("/api/teams").send().await.text().await, "teams"); -} - -#[test] -#[should_panic = "Invalid route \"/\": insertion failed due to conflict with previously registered route: /*__private__axum_nest_tail_param"] -fn nested_service_at_root_with_other_routes() { - let _: Router = Router::new() - .nest_service("/", Router::new().route("/users", get(|| async {}))) - .route("/", get(|| async {})); -} - -#[test] -fn nested_at_root_with_other_routes() { - let _: Router = Router::new() - .nest("/", Router::new().route("/users", get(|| async {}))) - .route("/", get(|| async {})); + assert_eq!(client.get("/").await.text().await, "root"); + assert_eq!(client.get("/api/users").await.text().await, "users"); + assert_eq!(client.get("/api/teams").await.text().await, "teams"); } #[crate::test] @@ -257,14 +216,14 @@ async fn multiple_top_level_nests() { let client = TestClient::new(app); - assert_eq!(client.get("/one/route").send().await.text().await, "one"); - assert_eq!(client.get("/two/route").send().await.text().await, "two"); + assert_eq!(client.get("/one/route").await.text().await, "one"); + assert_eq!(client.get("/two/route").await.text().await, "two"); } #[crate::test] #[should_panic(expected = "Invalid route: nested routes cannot contain wildcards (*)")] async fn nest_cannot_contain_wildcards() { - _ = Router::<()>::new().nest("/one/*rest", Router::new()); + _ = Router::<()>::new().nest("/one/{*rest}", Router::new()); } #[crate::test] @@ -308,28 +267,25 @@ async fn outer_middleware_still_see_whole_url() { let client = TestClient::new(app); - assert_eq!(client.get("/").send().await.text().await, "/"); - assert_eq!(client.get("/foo").send().await.text().await, "/foo"); - assert_eq!(client.get("/foo/bar").send().await.text().await, "/foo/bar"); - assert_eq!( - client.get("/not-found").send().await.text().await, - "/not-found" - ); - assert_eq!(client.get("/one/two").send().await.text().await, "/one/two"); + assert_eq!(client.get("/").await.text().await, "/"); + assert_eq!(client.get("/foo").await.text().await, "/foo"); + assert_eq!(client.get("/foo/bar").await.text().await, "/foo/bar"); + assert_eq!(client.get("/not-found").await.text().await, "/not-found"); + assert_eq!(client.get("/one/two").await.text().await, "/one/two"); } #[crate::test] async fn nest_at_capture() { let api_routes = Router::new().route( - "/:b", + "/{b}", get(|Path((a, b)): Path<(String, String)>| async move { format!("a={a} b={b}") }), ); - let app = Router::new().nest("/:a", api_routes); + let app = Router::new().nest("/{a}", api_routes); let client = TestClient::new(app); - let res = client.get("/foo/bar").send().await; + let res = client.get("/foo/bar").await; assert_eq!(res.status(), StatusCode::OK); assert_eq!(res.text().await, "a=foo b=bar"); } @@ -340,13 +296,13 @@ async fn nest_with_and_without_trailing() { let client = TestClient::new(app); - let res = client.get("/foo").send().await; + let res = client.get("/foo").await; assert_eq!(res.status(), StatusCode::OK); - let res = client.get("/foo/").send().await; + let res = client.get("/foo/").await; assert_eq!(res.status(), StatusCode::OK); - let res = client.get("/foo/bar").send().await; + let res = client.get("/foo/bar").await; assert_eq!(res.status(), StatusCode::OK); } @@ -361,28 +317,28 @@ async fn nesting_with_root_inner_router() { // `/service/` does match the `/service` prefix and the remaining path is technically // empty, which is the same as `/` which matches `.route("/", _)` - let res = client.get("/service").send().await; + let res = client.get("/service").await; assert_eq!(res.status(), StatusCode::OK); // `/service/` does match the `/service` prefix and the remaining path is `/` // which matches `.route("/", _)` // // this is perhaps a little surprising but don't think there is much we can do - let res = client.get("/service/").send().await; + let res = client.get("/service/").await; assert_eq!(res.status(), StatusCode::OK); // at least it does work like you'd expect when using `nest` - let res = client.get("/router").send().await; + let res = client.get("/router").await; assert_eq!(res.status(), StatusCode::OK); - let res = client.get("/router/").send().await; + let res = client.get("/router/").await; assert_eq!(res.status(), StatusCode::NOT_FOUND); - let res = client.get("/router-slash").send().await; + let res = client.get("/router-slash").await; assert_eq!(res.status(), StatusCode::NOT_FOUND); - let res = client.get("/router-slash/").send().await; + let res = client.get("/router-slash/").await; assert_eq!(res.status(), StatusCode::OK); } @@ -401,7 +357,7 @@ macro_rules! nested_route_test { let inner = Router::new().route($route_path, get(|| async {})); let app = Router::new().nest($nested_path, inner); let client = TestClient::new(app); - let res = client.get($expected_path).send().await; + let res = client.get($expected_path).await; let status = res.status(); assert_eq!(status, StatusCode::OK, "Router"); } @@ -409,15 +365,25 @@ macro_rules! nested_route_test { } // test cases taken from https://github.com/tokio-rs/axum/issues/714#issuecomment-1058144460 -nested_route_test!(nest_1, nest = "", route = "/", expected = "/"); -nested_route_test!(nest_2, nest = "", route = "/a", expected = "/a"); -nested_route_test!(nest_3, nest = "", route = "/a/", expected = "/a/"); -nested_route_test!(nest_4, nest = "/", route = "/", expected = "/"); -nested_route_test!(nest_5, nest = "/", route = "/a", expected = "/a"); -nested_route_test!(nest_6, nest = "/", route = "/a/", expected = "/a/"); -nested_route_test!(nest_7, nest = "/a", route = "/", expected = "/a"); -nested_route_test!(nest_8, nest = "/a", route = "/a", expected = "/a/a"); -nested_route_test!(nest_9, nest = "/a", route = "/a/", expected = "/a/a/"); -nested_route_test!(nest_11, nest = "/a/", route = "/", expected = "/a/"); -nested_route_test!(nest_12, nest = "/a/", route = "/a", expected = "/a/a"); -nested_route_test!(nest_13, nest = "/a/", route = "/a/", expected = "/a/a/"); +nested_route_test!(nest_1, nest = "/a", route = "/", expected = "/a"); +nested_route_test!(nest_2, nest = "/a", route = "/a", expected = "/a/a"); +nested_route_test!(nest_3, nest = "/a", route = "/a/", expected = "/a/a/"); +nested_route_test!(nest_4, nest = "/a/", route = "/", expected = "/a/"); +nested_route_test!(nest_5, nest = "/a/", route = "/a", expected = "/a/a"); +nested_route_test!(nest_6, nest = "/a/", route = "/a/", expected = "/a/a/"); + +#[crate::test] +#[should_panic( + expected = "Path segments must not start with `:`. For capture groups, use `{capture}`. If you meant to literally match a segment starting with a colon, call `without_v07_checks` on the router." +)] +async fn colon_in_route() { + _ = Router::<()>::new().nest("/:foo", Router::new()); +} + +#[crate::test] +#[should_panic( + expected = "Path segments must not start with `*`. For wildcard capture, use `{*wildcard}`. If you meant to literally match a segment starting with an asterisk, call `without_v07_checks` on the router." +)] +async fn asterisk_in_route() { + _ = Router::<()>::new().nest("/*foo", Router::new()); +} diff --git a/axum/src/serve.rs b/axum/src/serve.rs index 9850af2787..87b103ee90 100644 --- a/axum/src/serve.rs +++ b/axum/src/serve.rs @@ -7,25 +7,21 @@ use std::{ io, marker::PhantomData, net::SocketAddr, - pin::Pin, sync::Arc, - task::{Context, Poll}, time::Duration, }; use axum_core::{body::Body, extract::Request, response::Response}; use futures_util::{pin_mut, FutureExt}; use hyper::body::Incoming; -use hyper_util::{ - rt::{TokioExecutor, TokioIo}, - server::conn::auto::Builder, -}; -use pin_project_lite::pin_project; +use hyper_util::rt::{TokioExecutor, TokioIo}; +#[cfg(any(feature = "http1", feature = "http2"))] +use hyper_util::{server::conn::auto::Builder, service::TowerToHyperService}; use tokio::{ net::{TcpListener, TcpStream}, sync::watch, }; -use tower::util::{Oneshot, ServiceExt}; +use tower::ServiceExt as _; use tower_service::Service; /// Serve the service with the supplied listener. @@ -102,15 +98,18 @@ where Serve { tcp_listener, make_service, + tcp_nodelay: None, _marker: PhantomData, } } /// Future returned by [`serve`]. #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] +#[must_use = "futures must be awaited or polled"] pub struct Serve { tcp_listener: TcpListener, make_service: M, + tcp_nodelay: Option, _marker: PhantomData, } @@ -145,9 +144,40 @@ impl Serve { tcp_listener: self.tcp_listener, make_service: self.make_service, signal, + tcp_nodelay: self.tcp_nodelay, _marker: PhantomData, } } + + /// Instructs the server to set the value of the `TCP_NODELAY` option on every accepted connection. + /// + /// See also [`TcpStream::set_nodelay`]. + /// + /// # Example + /// ``` + /// use axum::{Router, routing::get}; + /// + /// # async { + /// let router = Router::new().route("/", get(|| async { "Hello, World!" })); + /// + /// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap(); + /// axum::serve(listener, router) + /// .tcp_nodelay(true) + /// .await + /// .unwrap(); + /// # }; + /// ``` + pub fn tcp_nodelay(self, nodelay: bool) -> Self { + Self { + tcp_nodelay: Some(nodelay), + ..self + } + } + + /// Returns the local address this server is bound to. + pub fn local_addr(&self) -> io::Result { + self.tcp_listener.local_addr() + } } #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] @@ -159,12 +189,14 @@ where let Self { tcp_listener, make_service, + tcp_nodelay, _marker: _, } = self; f.debug_struct("Serve") .field("tcp_listener", tcp_listener) .field("make_service", make_service) + .field("tcp_nodelay", tcp_nodelay) .finish() } } @@ -181,66 +213,59 @@ where type IntoFuture = private::ServeFuture; fn into_future(self) -> Self::IntoFuture { - private::ServeFuture(Box::pin(async move { - let Self { - tcp_listener, - mut make_service, - _marker: _, - } = self; - - loop { - let (tcp_stream, remote_addr) = match tcp_accept(&tcp_listener).await { - Some(conn) => conn, - None => continue, - }; - let tcp_stream = TokioIo::new(tcp_stream); - - poll_fn(|cx| make_service.poll_ready(cx)) - .await - .unwrap_or_else(|err| match err {}); - - let tower_service = make_service - .call(IncomingStream { - tcp_stream: &tcp_stream, - remote_addr, - }) - .await - .unwrap_or_else(|err| match err {}); - - let hyper_service = TowerToHyperService { - service: tower_service, - }; - - tokio::spawn(async move { - match Builder::new(TokioExecutor::new()) - // upgrades needed for websockets - .serve_connection_with_upgrades(tcp_stream, hyper_service) - .await - { - Ok(()) => {} - Err(_err) => { - // This error only appears when the client doesn't send a request and - // terminate the connection. - // - // If client sends one request then terminate connection whenever, it doesn't - // appear. - } - } - }); - } - })) + self.with_graceful_shutdown(std::future::pending()) + .into_future() } } /// Serve future with graceful shutdown enabled. #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] +#[must_use = "futures must be awaited or polled"] pub struct WithGracefulShutdown { tcp_listener: TcpListener, make_service: M, signal: F, + tcp_nodelay: Option, _marker: PhantomData, } +impl WithGracefulShutdown { + /// Instructs the server to set the value of the `TCP_NODELAY` option on every accepted connection. + /// + /// See also [`TcpStream::set_nodelay`]. + /// + /// # Example + /// ``` + /// use axum::{Router, routing::get}; + /// + /// # async { + /// let router = Router::new().route("/", get(|| async { "Hello, World!" })); + /// + /// let listener = tokio::net::TcpListener::bind("0.0.0.0:3000").await.unwrap(); + /// axum::serve(listener, router) + /// .with_graceful_shutdown(shutdown_signal()) + /// .tcp_nodelay(true) + /// .await + /// .unwrap(); + /// # }; + /// + /// async fn shutdown_signal() { + /// // ... + /// } + /// ``` + pub fn tcp_nodelay(self, nodelay: bool) -> Self { + Self { + tcp_nodelay: Some(nodelay), + ..self + } + } + + /// Returns the local address this server is bound to. + pub fn local_addr(&self) -> io::Result { + self.tcp_listener.local_addr() + } +} + #[cfg(all(feature = "tokio", any(feature = "http1", feature = "http2")))] impl Debug for WithGracefulShutdown where @@ -253,6 +278,7 @@ where tcp_listener, make_service, signal, + tcp_nodelay, _marker: _, } = self; @@ -260,6 +286,7 @@ where .field("tcp_listener", tcp_listener) .field("make_service", make_service) .field("signal", signal) + .field("tcp_nodelay", tcp_nodelay) .finish() } } @@ -281,20 +308,21 @@ where tcp_listener, mut make_service, signal, + tcp_nodelay, _marker: _, } = self; - let (signal_tx, signal_rx) = watch::channel(()); - let signal_tx = Arc::new(signal_tx); - tokio::spawn(async move { - signal.await; - trace!("received graceful shutdown signal. Telling tasks to shutdown"); - drop(signal_rx); - }); + private::ServeFuture(Box::pin(async move { + let (signal_tx, signal_rx) = watch::channel(()); + let signal_tx = Arc::new(signal_tx); + tokio::spawn(async move { + signal.await; + trace!("received graceful shutdown signal. Telling tasks to shutdown"); + drop(signal_rx); + }); - let (close_tx, close_rx) = watch::channel(()); + let (close_tx, close_rx) = watch::channel(()); - private::ServeFuture(Box::pin(async move { loop { let (tcp_stream, remote_addr) = tokio::select! { conn = tcp_accept(&tcp_listener) => { @@ -308,6 +336,13 @@ where break; } }; + + if let Some(nodelay) = tcp_nodelay { + if let Err(err) = tcp_stream.set_nodelay(nodelay) { + trace!("failed to set TCP_NODELAY on incoming connection: {err:#}"); + } + } + let tcp_stream = TokioIo::new(tcp_stream); trace!("connection {remote_addr} accepted"); @@ -322,18 +357,21 @@ where remote_addr, }) .await - .unwrap_or_else(|err| match err {}); + .unwrap_or_else(|err| match err {}) + .map_request(|req: Request| req.map(Body::new)); - let hyper_service = TowerToHyperService { - service: tower_service, - }; + let hyper_service = TowerToHyperService::new(tower_service); let signal_tx = Arc::clone(&signal_tx); let close_rx = close_rx.clone(); tokio::spawn(async move { - let builder = Builder::new(TokioExecutor::new()); + #[allow(unused_mut)] + let mut builder = Builder::new(TokioExecutor::new()); + // CONNECT protocol needed for HTTP/2 websockets + #[cfg(feature = "http2")] + builder.http2().enable_connect_protocol(); let conn = builder.serve_connection_with_upgrades(tcp_stream, hyper_service); pin_mut!(conn); @@ -436,49 +474,6 @@ mod private { } } -#[derive(Debug, Copy, Clone)] -struct TowerToHyperService { - service: S, -} - -impl hyper::service::Service> for TowerToHyperService -where - S: tower_service::Service + Clone, -{ - type Response = S::Response; - type Error = S::Error; - type Future = TowerToHyperServiceFuture; - - fn call(&self, req: Request) -> Self::Future { - let req = req.map(Body::new); - TowerToHyperServiceFuture { - future: self.service.clone().oneshot(req), - } - } -} - -pin_project! { - struct TowerToHyperServiceFuture - where - S: tower_service::Service, - { - #[pin] - future: Oneshot, - } -} - -impl Future for TowerToHyperServiceFuture -where - S: tower_service::Service, -{ - type Output = Result; - - #[inline] - fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - self.project().future.poll(cx) - } -} - /// An incoming stream. /// /// Used with [`serve`] and [`IntoMakeServiceWithConnectInfo`]. @@ -510,6 +505,10 @@ mod tests { routing::get, Router, }; + use std::{ + future::pending, + net::{IpAddr, Ipv4Addr}, + }; #[allow(dead_code, unused_must_use)] async fn if_it_compiles_it_works() { @@ -556,7 +555,62 @@ mod tests { TcpListener::bind(addr).await.unwrap(), handler.into_make_service_with_connect_info::(), ); + + // nodelay + serve( + TcpListener::bind(addr).await.unwrap(), + handler.into_service(), + ) + .tcp_nodelay(true); + + serve( + TcpListener::bind(addr).await.unwrap(), + handler.into_service(), + ) + .with_graceful_shutdown(async { /*...*/ }) + .tcp_nodelay(true); } async fn handler() {} + + #[crate::test] + async fn test_serve_local_addr() { + let router: Router = Router::new(); + let addr = "0.0.0.0:0"; + + let server = serve(TcpListener::bind(addr).await.unwrap(), router.clone()); + let address = server.local_addr().unwrap(); + + assert_eq!(address.ip(), IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0))); + assert_ne!(address.port(), 0); + } + + #[crate::test] + async fn test_with_graceful_shutdown_local_addr() { + let router: Router = Router::new(); + let addr = "0.0.0.0:0"; + + let server = serve(TcpListener::bind(addr).await.unwrap(), router.clone()) + .with_graceful_shutdown(pending()); + let address = server.local_addr().unwrap(); + + assert_eq!(address.ip(), IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0))); + assert_ne!(address.port(), 0); + } + + #[test] + fn into_future_outside_tokio() { + let router: Router = Router::new(); + let addr = "0.0.0.0:0"; + + let rt = tokio::runtime::Builder::new_multi_thread() + .enable_all() + .build() + .unwrap(); + + let listener = rt.block_on(tokio::net::TcpListener::bind(addr)).unwrap(); + + // Call Serve::into_future outside of a tokio context. This used to panic. + _ = serve(listener, router).into_future(); + } } diff --git a/axum/src/test_helpers/counting_cloneable_state.rs b/axum/src/test_helpers/counting_cloneable_state.rs new file mode 100644 index 0000000000..762d5ce972 --- /dev/null +++ b/axum/src/test_helpers/counting_cloneable_state.rs @@ -0,0 +1,52 @@ +use std::sync::{ + atomic::{AtomicBool, AtomicUsize, Ordering}, + Arc, +}; + +pub(crate) struct CountingCloneableState { + state: Arc, +} + +struct InnerState { + setup_done: AtomicBool, + count: AtomicUsize, +} + +impl CountingCloneableState { + pub(crate) fn new() -> Self { + let inner_state = InnerState { + setup_done: AtomicBool::new(false), + count: AtomicUsize::new(0), + }; + CountingCloneableState { + state: Arc::new(inner_state), + } + } + + pub(crate) fn setup_done(&self) { + self.state.setup_done.store(true, Ordering::SeqCst); + } + + pub(crate) fn count(&self) -> usize { + self.state.count.load(Ordering::SeqCst) + } +} + +impl Clone for CountingCloneableState { + fn clone(&self) -> Self { + let state = self.state.clone(); + if state.setup_done.load(Ordering::SeqCst) { + let bt = std::backtrace::Backtrace::force_capture(); + let bt = bt + .to_string() + .lines() + .filter(|line| line.contains("axum") || line.contains("./src")) + .collect::>() + .join("\n"); + println!("AppState::Clone:\n===============\n{bt}\n"); + state.count.fetch_add(1, Ordering::SeqCst); + } + + CountingCloneableState { state } + } +} diff --git a/axum/src/test_helpers/mod.rs b/axum/src/test_helpers/mod.rs index 3bb1535e40..5c29f78da2 100644 --- a/axum/src/test_helpers/mod.rs +++ b/axum/src/test_helpers/mod.rs @@ -7,7 +7,10 @@ pub(crate) use self::test_client::*; pub(crate) mod tracing_helpers; +pub(crate) mod counting_cloneable_state; + pub(crate) fn assert_send() {} pub(crate) fn assert_sync() {} +#[allow(dead_code)] pub(crate) struct NotSendSync(*const ()); diff --git a/axum/src/test_helpers/test_client.rs b/axum/src/test_helpers/test_client.rs index 31e7074059..2dfa95a01f 100644 --- a/axum/src/test_helpers/test_client.rs +++ b/axum/src/test_helpers/test_client.rs @@ -1,14 +1,36 @@ use super::{serve, Request, Response}; use bytes::Bytes; +use futures_util::future::BoxFuture; use http::{ header::{HeaderName, HeaderValue}, StatusCode, }; -use std::{convert::Infallible, net::SocketAddr, str::FromStr}; +use std::{convert::Infallible, future::IntoFuture, net::SocketAddr}; use tokio::net::TcpListener; use tower::make::Shared; use tower_service::Service; +pub(crate) fn spawn_service(svc: S) -> SocketAddr +where + S: Service + Clone + Send + 'static, + S::Future: Send, +{ + let std_listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); + std_listener.set_nonblocking(true).unwrap(); + let listener = TcpListener::from_std(std_listener).unwrap(); + + let addr = listener.local_addr().unwrap(); + println!("Listening on {addr}"); + + tokio::spawn(async move { + serve(listener, Shared::new(svc)) + .await + .expect("server error") + }); + + addr +} + pub(crate) struct TestClient { client: reqwest::Client, addr: SocketAddr, @@ -20,18 +42,7 @@ impl TestClient { S: Service + Clone + Send + 'static, S::Future: Send, { - let std_listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); - std_listener.set_nonblocking(true).unwrap(); - let listener = TcpListener::from_std(std_listener).unwrap(); - - let addr = listener.local_addr().unwrap(); - println!("Listening on {addr}"); - - tokio::spawn(async move { - serve(listener, Shared::new(svc)) - .await - .expect("server error") - }); + let addr = spawn_service(svc); let client = reqwest::Client::builder() .redirect(reqwest::redirect::Policy::none()) @@ -72,6 +83,11 @@ impl TestClient { builder: self.client.patch(format!("http://{}{}", self.addr, url)), } } + + #[allow(dead_code)] + pub(crate) fn server_port(&self) -> u16 { + self.addr.port() + } } pub(crate) struct RequestBuilder { @@ -79,12 +95,6 @@ pub(crate) struct RequestBuilder { } impl RequestBuilder { - pub(crate) async fn send(self) -> TestResponse { - TestResponse { - response: self.builder.send().await.unwrap(), - } - } - pub(crate) fn body(mut self, body: impl Into) -> Self { self.builder = self.builder.body(body); self @@ -105,15 +115,7 @@ impl RequestBuilder { HeaderValue: TryFrom, >::Error: Into, { - // reqwest still uses http 0.2 - let key: HeaderName = key.try_into().map_err(Into::into).unwrap(); - let key = reqwest::header::HeaderName::from_bytes(key.as_ref()).unwrap(); - - let value: HeaderValue = value.try_into().map_err(Into::into).unwrap(); - let value = reqwest::header::HeaderValue::from_bytes(value.as_bytes()).unwrap(); - self.builder = self.builder.header(key, value); - self } @@ -124,6 +126,19 @@ impl RequestBuilder { } } +impl IntoFuture for RequestBuilder { + type Output = TestResponse; + type IntoFuture = BoxFuture<'static, Self::Output>; + + fn into_future(self) -> Self::IntoFuture { + Box::pin(async { + TestResponse { + response: self.builder.send().await.unwrap(), + } + }) + } +} + #[derive(Debug)] pub(crate) struct TestResponse { response: reqwest::Response, @@ -152,14 +167,7 @@ impl TestResponse { } pub(crate) fn headers(&self) -> http::HeaderMap { - // reqwest still uses http 0.2 so have to convert into http 1.0 - let mut headers = http::HeaderMap::new(); - for (key, value) in self.response.headers() { - let key = http::HeaderName::from_str(key.as_str()).unwrap(); - let value = http::HeaderValue::from_bytes(value.as_bytes()).unwrap(); - headers.insert(key, value); - } - headers + self.response.headers().clone() } pub(crate) async fn chunk(&mut self) -> Option { diff --git a/axum/src/test_helpers/tracing_helpers.rs b/axum/src/test_helpers/tracing_helpers.rs index 3d5cf18149..f7769ee9d7 100644 --- a/axum/src/test_helpers/tracing_helpers.rs +++ b/axum/src/test_helpers/tracing_helpers.rs @@ -1,10 +1,13 @@ use std::{ - future::Future, + future::{Future, IntoFuture}, io, + marker::PhantomData, + pin::Pin, sync::{Arc, Mutex}, }; use serde::{de::DeserializeOwned, Deserialize}; +use tracing::instrument::WithSubscriber; use tracing_subscriber::prelude::*; use tracing_subscriber::{filter::Targets, fmt::MakeWriter}; @@ -17,36 +20,69 @@ pub(crate) struct TracingEvent { } /// Run an async closure and capture the tracing output it produces. -pub(crate) async fn capture_tracing(f: F) -> Vec> +pub(crate) fn capture_tracing(f: F) -> CaptureTracing where - F: Fn() -> Fut, - Fut: Future, T: DeserializeOwned, { - let (make_writer, handle) = TestMakeWriter::new(); - - let subscriber = tracing_subscriber::registry().with( - tracing_subscriber::fmt::layer() - .with_writer(make_writer) - .with_target(true) - .without_time() - .with_ansi(false) - .json() - .flatten_event(false) - .with_filter("axum=trace".parse::().unwrap()), - ); - - let guard = tracing::subscriber::set_default(subscriber); - - f().await; - - drop(guard); - - handle - .take() - .lines() - .map(|line| serde_json::from_str(line).unwrap()) - .collect() + CaptureTracing { + f, + filter: None, + _phantom: PhantomData, + } +} + +pub(crate) struct CaptureTracing { + f: F, + filter: Option, + _phantom: PhantomData T>, +} + +impl CaptureTracing { + pub(crate) fn with_filter(mut self, filter_string: &str) -> Self { + self.filter = Some(filter_string.parse().unwrap()); + self + } +} + +impl IntoFuture for CaptureTracing +where + F: Fn() -> Fut + Send + Sync + 'static, + Fut: Future + Send, + T: DeserializeOwned, +{ + type Output = Vec>; + type IntoFuture = Pin + Send>>; + + fn into_future(self) -> Self::IntoFuture { + let Self { f, filter, .. } = self; + Box::pin(async move { + let (make_writer, handle) = TestMakeWriter::new(); + + let filter = filter.unwrap_or_else(|| "axum=trace".parse().unwrap()); + let subscriber = tracing_subscriber::registry().with( + tracing_subscriber::fmt::layer() + .with_writer(make_writer) + .with_target(true) + .without_time() + .with_ansi(false) + .json() + .flatten_event(false) + .with_filter(filter), + ); + + let guard = tracing::subscriber::set_default(subscriber); + + f().with_current_subscriber().await; + + drop(guard); + + handle + .take() + .lines() + .map(|line| serde_json::from_str(line).unwrap()) + .collect() + }) + } } struct TestMakeWriter { @@ -76,7 +112,7 @@ impl<'a> MakeWriter<'a> for TestMakeWriter { struct Writer<'a>(&'a TestMakeWriter); -impl<'a> io::Write for Writer<'a> { +impl io::Write for Writer<'_> { fn write(&mut self, buf: &[u8]) -> io::Result { match &mut *self.0.write.lock().unwrap() { Some(vec) => { diff --git a/axum/src/util.rs b/axum/src/util.rs index f7fc6ae149..7c9b7864e9 100644 --- a/axum/src/util.rs +++ b/axum/src/util.rs @@ -18,10 +18,6 @@ impl PercentDecodedStr { pub(crate) fn as_str(&self) -> &str { &self.0 } - - pub(crate) fn into_inner(self) -> Arc { - self.0 - } } impl Deref for PercentDecodedStr { diff --git a/clippy.toml b/clippy.toml new file mode 100644 index 0000000000..625309989d --- /dev/null +++ b/clippy.toml @@ -0,0 +1,3 @@ +disallowed-types = [ + { path = "tower::util::BoxCloneService", reason = "Use our internal BoxCloneService which is Sync" }, +] diff --git a/deny.toml b/deny.toml index d95bdf294b..c32c8715a8 100644 --- a/deny.toml +++ b/deny.toml @@ -33,6 +33,11 @@ skip-tree = [ { name = "regex-automata" }, # pulled in by hyper { name = "socket2" }, + # hyper-util hasn't upgraded to 0.5 yet, but it's the same service / layer + # crates beneath + { name = "tower" }, + # tower hasn't upgraded to 1.0 yet + { name = "sync_wrapper" }, ] [sources] diff --git a/examples/README.md b/examples/README.md index 10339295b8..565c1796d6 100644 --- a/examples/README.md +++ b/examples/README.md @@ -4,4 +4,4 @@ This folder contains numerous examples showing how to use axum. Each example is setup as its own crate so its dependencies are clear. For a list of what the community built with axum, please see the list -[here](../ECOSYSTEM.md). \ No newline at end of file +[here](../ECOSYSTEM.md). diff --git a/examples/auto-reload/src/main.rs b/examples/auto-reload/src/main.rs index 5813ec1ff1..8d71805ba1 100644 --- a/examples/auto-reload/src/main.rs +++ b/examples/auto-reload/src/main.rs @@ -1,7 +1,7 @@ //! Run with //! //! ```not_rust -//! cargo run -p example-hello-world +//! cargo run -p auto-reload //! ``` use axum::{response::Html, routing::get, Router}; @@ -16,7 +16,10 @@ async fn main() { let mut listenfd = ListenFd::from_env(); let listener = match listenfd.take_tcp_listener(0).unwrap() { // if we are given a tcp listener on listen fd 0, we use that one - Some(listener) => TcpListener::from_std(listener).unwrap(), + Some(listener) => { + listener.set_nonblocking(true).unwrap(); + TcpListener::from_std(listener).unwrap() + } // otherwise fall back to local listening None => TcpListener::bind("127.0.0.1:3000").await.unwrap(), }; diff --git a/examples/chat/Cargo.toml b/examples/chat/Cargo.toml index 90d88246fa..2beb99cf85 100644 --- a/examples/chat/Cargo.toml +++ b/examples/chat/Cargo.toml @@ -8,6 +8,5 @@ publish = false axum = { path = "../../axum", features = ["ws"] } futures = "0.3" tokio = { version = "1", features = ["full"] } -tower = { version = "0.4", features = ["util"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/examples/chat/src/main.rs b/examples/chat/src/main.rs index 02e3bdc060..77baada1b5 100644 --- a/examples/chat/src/main.rs +++ b/examples/chat/src/main.rs @@ -36,7 +36,7 @@ async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "example_chat=trace".into()), + .unwrap_or_else(|_| format!("{}=trace", env!("CARGO_CRATE_NAME")).into()), ) .with(tracing_subscriber::fmt::layer()) .init(); @@ -130,8 +130,8 @@ async fn websocket(stream: WebSocket, state: Arc) { // If any one of the tasks run to completion, we abort the other. tokio::select! { - _ = (&mut send_task) => recv_task.abort(), - _ = (&mut recv_task) => send_task.abort(), + _ = &mut send_task => recv_task.abort(), + _ = &mut recv_task => send_task.abort(), }; // Send "user left" message (similar to "joined" above). diff --git a/examples/compression/Cargo.toml b/examples/compression/Cargo.toml new file mode 100644 index 0000000000..a65d9d0d1b --- /dev/null +++ b/examples/compression/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "example-compression" +version = "0.1.0" +edition = "2021" +publish = false + +[dependencies] +axum = { path = "../../axum" } +serde_json = "1" +tokio = { version = "1", features = ["macros", "rt-multi-thread"] } +tower = "0.5.1" +tower-http = { version = "0.6.1", features = ["compression-full", "decompression-full"] } +tracing = "0.1" +tracing-subscriber = { version = "0.3", features = ["env-filter"] } + +[dev-dependencies] +assert-json-diff = "2.0" +brotli = "6.0" +flate2 = "1" +http = "1" +zstd = "0.13" diff --git a/examples/compression/README.md b/examples/compression/README.md new file mode 100644 index 0000000000..3f0ed94da7 --- /dev/null +++ b/examples/compression/README.md @@ -0,0 +1,32 @@ +# compression + +This example shows how to: +- automatically decompress request bodies when necessary +- compress response bodies based on the `accept` header. + +## Running + +``` +cargo run -p example-compression +``` + +## Sending compressed requests + +``` +curl -v -g 'http://localhost:3000/' \ + -H "Content-Type: application/json" \ + -H "Content-Encoding: gzip" \ + --compressed \ + --data-binary @data/products.json.gz +``` + +(Notice the `Content-Encoding: gzip` in the request, and `content-encoding: gzip` in the response.) + +## Sending uncompressed requests + +``` +curl -v -g 'http://localhost:3000/' \ + -H "Content-Type: application/json" \ + --compressed \ + --data-binary @data/products.json +``` diff --git a/examples/compression/data/products.json b/examples/compression/data/products.json new file mode 100644 index 0000000000..a234fbdd2a --- /dev/null +++ b/examples/compression/data/products.json @@ -0,0 +1,12 @@ +{ + "products": [ + { + "id": 1, + "name": "Product 1" + }, + { + "id": 2, + "name": "Product 2" + } + ] +} diff --git a/examples/compression/data/products.json.gz b/examples/compression/data/products.json.gz new file mode 100644 index 0000000000..91d398955b Binary files /dev/null and b/examples/compression/data/products.json.gz differ diff --git a/examples/compression/src/main.rs b/examples/compression/src/main.rs new file mode 100644 index 0000000000..b487f34e4f --- /dev/null +++ b/examples/compression/src/main.rs @@ -0,0 +1,39 @@ +use axum::{routing::post, Json, Router}; +use serde_json::Value; +use tower::ServiceBuilder; +use tower_http::{compression::CompressionLayer, decompression::RequestDecompressionLayer}; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; + +#[cfg(test)] +mod tests; + +#[tokio::main] +async fn main() { + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| format!("{}=trace", env!("CARGO_CRATE_NAME")).into()), + ) + .with(tracing_subscriber::fmt::layer()) + .init(); + + let app: Router = app(); + + let listener = tokio::net::TcpListener::bind("127.0.0.1:3000") + .await + .unwrap(); + tracing::debug!("listening on {}", listener.local_addr().unwrap()); + axum::serve(listener, app).await.unwrap(); +} + +fn app() -> Router { + Router::new().route("/", post(root)).layer( + ServiceBuilder::new() + .layer(RequestDecompressionLayer::new()) + .layer(CompressionLayer::new()), + ) +} + +async fn root(Json(value): Json) -> Json { + Json(value) +} diff --git a/examples/compression/src/tests.rs b/examples/compression/src/tests.rs new file mode 100644 index 0000000000..c91ccaa649 --- /dev/null +++ b/examples/compression/src/tests.rs @@ -0,0 +1,245 @@ +use assert_json_diff::assert_json_eq; +use axum::{ + body::{Body, Bytes}, + response::Response, +}; +use brotli::enc::BrotliEncoderParams; +use flate2::{read::GzDecoder, write::GzEncoder, Compression}; +use http::{header, StatusCode}; +use serde_json::{json, Value}; +use std::io::{Read, Write}; +use tower::ServiceExt; + +use super::*; + +#[tokio::test] +async fn handle_uncompressed_request_bodies() { + // Given + + let body = json(); + + let compressed_request = http::Request::post("/") + .header(header::CONTENT_TYPE, "application/json") + .body(json_body(&body)) + .unwrap(); + + // When + + let response = app().oneshot(compressed_request).await.unwrap(); + + // Then + + assert_eq!(response.status(), StatusCode::OK); + assert_json_eq!(json_from_response(response).await, json()); +} + +#[tokio::test] +async fn decompress_gzip_request_bodies() { + // Given + + let body = compress_gzip(&json()); + + let compressed_request = http::Request::post("/") + .header(header::CONTENT_TYPE, "application/json") + .header(header::CONTENT_ENCODING, "gzip") + .body(Body::from(body)) + .unwrap(); + + // When + + let response = app().oneshot(compressed_request).await.unwrap(); + + // Then + + assert_eq!(response.status(), StatusCode::OK); + assert_json_eq!(json_from_response(response).await, json()); +} + +#[tokio::test] +async fn decompress_br_request_bodies() { + // Given + + let body = compress_br(&json()); + + let compressed_request = http::Request::post("/") + .header(header::CONTENT_TYPE, "application/json") + .header(header::CONTENT_ENCODING, "br") + .body(Body::from(body)) + .unwrap(); + + // When + + let response = app().oneshot(compressed_request).await.unwrap(); + + // Then + + assert_eq!(response.status(), StatusCode::OK); + assert_json_eq!(json_from_response(response).await, json()); +} + +#[tokio::test] +async fn decompress_zstd_request_bodies() { + // Given + + let body = compress_zstd(&json()); + + let compressed_request = http::Request::post("/") + .header(header::CONTENT_TYPE, "application/json") + .header(header::CONTENT_ENCODING, "zstd") + .body(Body::from(body)) + .unwrap(); + + // When + + let response = app().oneshot(compressed_request).await.unwrap(); + + // Then + + assert_eq!(response.status(), StatusCode::OK); + assert_json_eq!(json_from_response(response).await, json()); +} + +#[tokio::test] +async fn do_not_compress_response_bodies() { + // Given + let request = http::Request::post("/") + .header(header::CONTENT_TYPE, "application/json") + .body(json_body(&json())) + .unwrap(); + + // When + + let response = app().oneshot(request).await.unwrap(); + + // Then + + assert_eq!(response.status(), StatusCode::OK); + assert_json_eq!(json_from_response(response).await, json()); +} + +#[tokio::test] +async fn compress_response_bodies_with_gzip() { + // Given + let request = http::Request::post("/") + .header(header::CONTENT_TYPE, "application/json") + .header(header::ACCEPT_ENCODING, "gzip") + .body(json_body(&json())) + .unwrap(); + + // When + + let response = app().oneshot(request).await.unwrap(); + + // Then + + assert_eq!(response.status(), StatusCode::OK); + let response_body = byte_from_response(response).await; + let mut decoder = GzDecoder::new(response_body.as_ref()); + let mut decompress_body = String::new(); + decoder.read_to_string(&mut decompress_body).unwrap(); + assert_json_eq!( + serde_json::from_str::(&decompress_body).unwrap(), + json() + ); +} + +#[tokio::test] +async fn compress_response_bodies_with_br() { + // Given + let request = http::Request::post("/") + .header(header::CONTENT_TYPE, "application/json") + .header(header::ACCEPT_ENCODING, "br") + .body(json_body(&json())) + .unwrap(); + + // When + + let response = app().oneshot(request).await.unwrap(); + + // Then + + assert_eq!(response.status(), StatusCode::OK); + let response_body = byte_from_response(response).await; + let mut decompress_body = Vec::new(); + brotli::BrotliDecompress(&mut response_body.as_ref(), &mut decompress_body).unwrap(); + assert_json_eq!( + serde_json::from_slice::(&decompress_body).unwrap(), + json() + ); +} + +#[tokio::test] +async fn compress_response_bodies_with_zstd() { + // Given + let request = http::Request::post("/") + .header(header::CONTENT_TYPE, "application/json") + .header(header::ACCEPT_ENCODING, "zstd") + .body(json_body(&json())) + .unwrap(); + + // When + + let response = app().oneshot(request).await.unwrap(); + + // Then + + assert_eq!(response.status(), StatusCode::OK); + let response_body = byte_from_response(response).await; + let decompress_body = zstd::stream::decode_all(std::io::Cursor::new(response_body)).unwrap(); + assert_json_eq!( + serde_json::from_slice::(&decompress_body).unwrap(), + json() + ); +} + +fn json() -> Value { + json!({ + "name": "foo", + "mainProduct": { + "typeId": "product", + "id": "p1" + }, + }) +} + +fn json_body(input: &Value) -> Body { + Body::from(serde_json::to_vec(&input).unwrap()) +} + +async fn json_from_response(response: Response) -> Value { + let body = byte_from_response(response).await; + body_as_json(body) +} + +async fn byte_from_response(response: Response) -> Bytes { + axum::body::to_bytes(response.into_body(), usize::MAX) + .await + .unwrap() +} + +fn body_as_json(body: Bytes) -> Value { + serde_json::from_slice(body.as_ref()).unwrap() +} + +fn compress_gzip(json: &Value) -> Vec { + let request_body = serde_json::to_vec(&json).unwrap(); + + let mut encoder = GzEncoder::new(Vec::new(), Compression::default()); + encoder.write_all(&request_body).unwrap(); + encoder.finish().unwrap() +} + +fn compress_br(json: &Value) -> Vec { + let request_body = serde_json::to_vec(&json).unwrap(); + let mut result = Vec::new(); + + let params = BrotliEncoderParams::default(); + let _ = brotli::enc::BrotliCompress(&mut &request_body[..], &mut result, ¶ms).unwrap(); + + result +} + +fn compress_zstd(json: &Value) -> Vec { + let request_body = serde_json::to_vec(&json).unwrap(); + zstd::stream::encode_all(std::io::Cursor::new(request_body), 4).unwrap() +} diff --git a/examples/consume-body-in-extractor-or-middleware/Cargo.toml b/examples/consume-body-in-extractor-or-middleware/Cargo.toml index 9aeb864d61..6688588582 100644 --- a/examples/consume-body-in-extractor-or-middleware/Cargo.toml +++ b/examples/consume-body-in-extractor-or-middleware/Cargo.toml @@ -7,9 +7,6 @@ publish = false [dependencies] axum = { path = "../../axum" } http-body-util = "0.1.0" -hyper = "1.0.0" tokio = { version = "1.0", features = ["full"] } -tower = "0.4" -tower-http = { version = "0.5.0", features = ["map-request-body", "util"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/examples/consume-body-in-extractor-or-middleware/src/main.rs b/examples/consume-body-in-extractor-or-middleware/src/main.rs index 107edb6f1b..3239d6ac6d 100644 --- a/examples/consume-body-in-extractor-or-middleware/src/main.rs +++ b/examples/consume-body-in-extractor-or-middleware/src/main.rs @@ -5,7 +5,6 @@ //! ``` use axum::{ - async_trait, body::{Body, Bytes}, extract::{FromRequest, Request}, http::StatusCode, @@ -22,7 +21,7 @@ async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "example_consume_body_in_extractor_or_middleware=debug".into()), + .unwrap_or_else(|_| format!("{}=debug", env!("CARGO_CRATE_NAME")).into()), ) .with(tracing_subscriber::fmt::layer()) .init(); @@ -50,7 +49,7 @@ async fn print_request_body(request: Request, next: Next) -> Result Result { let (parts, body) = request.into_parts(); - // this wont work if the body is an long running stream + // this won't work if the body is an long running stream let bytes = body .collect() .await @@ -74,7 +73,6 @@ async fn handler(BufferRequestBody(body): BufferRequestBody) { struct BufferRequestBody(Bytes); // we must implement `FromRequest` (and not `FromRequestParts`) to consume the body -#[async_trait] impl FromRequest for BufferRequestBody where S: Send + Sync, diff --git a/examples/cors/Cargo.toml b/examples/cors/Cargo.toml index 5d5d2edae5..654538fa22 100644 --- a/examples/cors/Cargo.toml +++ b/examples/cors/Cargo.toml @@ -7,4 +7,4 @@ publish = false [dependencies] axum = { path = "../../axum" } tokio = { version = "1.0", features = ["full"] } -tower-http = { version = "0.5.0", features = ["cors"] } +tower-http = { version = "0.6.1", features = ["cors"] } diff --git a/examples/customize-extractor-error/src/custom_extractor.rs b/examples/customize-extractor-error/src/custom_extractor.rs index 3611fba796..4f75fb440d 100644 --- a/examples/customize-extractor-error/src/custom_extractor.rs +++ b/examples/customize-extractor-error/src/custom_extractor.rs @@ -5,7 +5,6 @@ //! - Boilerplate: Requires creating a new extractor for every custom rejection //! - Complexity: Manually implementing `FromRequest` results on more complex code use axum::{ - async_trait, extract::{rejection::JsonRejection, FromRequest, MatchedPath, Request}, http::StatusCode, response::IntoResponse, @@ -20,7 +19,6 @@ pub async fn handler(Json(value): Json) -> impl IntoResponse { // We define our own `Json` extractor that customizes the error from `axum::Json` pub struct Json(pub T); -#[async_trait] impl FromRequest for Json where axum::Json: FromRequest, diff --git a/examples/customize-extractor-error/src/main.rs b/examples/customize-extractor-error/src/main.rs index e8820326f9..48188352e5 100644 --- a/examples/customize-extractor-error/src/main.rs +++ b/examples/customize-extractor-error/src/main.rs @@ -16,7 +16,7 @@ async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "example_customize_extractor_error=trace".into()), + .unwrap_or_else(|_| format!("{}=trace", env!("CARGO_CRATE_NAME")).into()), ) .with(tracing_subscriber::fmt::layer()) .init(); diff --git a/examples/customize-path-rejection/Cargo.toml b/examples/customize-path-rejection/Cargo.toml index 8f5b1e4487..c1c4884d43 100644 --- a/examples/customize-path-rejection/Cargo.toml +++ b/examples/customize-path-rejection/Cargo.toml @@ -7,7 +7,6 @@ publish = false [dependencies] axum = { path = "../../axum" } serde = { version = "1.0", features = ["derive"] } -serde_json = "1.0" tokio = { version = "1.0", features = ["full"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/examples/customize-path-rejection/src/main.rs b/examples/customize-path-rejection/src/main.rs index 4231eabf60..e784a969b8 100644 --- a/examples/customize-path-rejection/src/main.rs +++ b/examples/customize-path-rejection/src/main.rs @@ -5,7 +5,6 @@ //! ``` use axum::{ - async_trait, extract::{path::ErrorKind, rejection::PathRejection, FromRequestParts}, http::{request::Parts, StatusCode}, response::IntoResponse, @@ -20,13 +19,13 @@ async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "example_customize_path_rejection=debug".into()), + .unwrap_or_else(|_| format!("{}=debug", env!("CARGO_CRATE_NAME")).into()), ) .with(tracing_subscriber::fmt::layer()) .init(); // build our application with a route - let app = Router::new().route("/users/:user_id/teams/:team_id", get(handler)); + let app = Router::new().route("/users/{user_id}/teams/{team_id}", get(handler)); // run it let listener = tokio::net::TcpListener::bind("127.0.0.1:3000") @@ -49,7 +48,6 @@ struct Params { // We define our own `Path` extractor that customizes the error from `axum::extract::Path` struct Path(T); -#[async_trait] impl FromRequestParts for Path where // these trait bounds are copied from `impl FromRequest for axum::extract::path::Path` diff --git a/examples/dependency-injection/Cargo.toml b/examples/dependency-injection/Cargo.toml new file mode 100644 index 0000000000..1f2801b0a0 --- /dev/null +++ b/examples/dependency-injection/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "example-dependency-injection" +version = "0.1.0" +edition = "2021" +publish = false + +[dependencies] +axum = { path = "../../axum", features = ["tracing", "macros"] } +serde = { version = "1.0", features = ["derive"] } +tokio = { version = "1.0", features = ["full"] } +tracing = "0.1" +tracing-subscriber = { version = "0.3", features = ["env-filter"] } +uuid = { version = "1.0", features = ["serde", "v4"] } diff --git a/examples/dependency-injection/src/main.rs b/examples/dependency-injection/src/main.rs new file mode 100644 index 0000000000..7a4719e768 --- /dev/null +++ b/examples/dependency-injection/src/main.rs @@ -0,0 +1,169 @@ +//! Run with +//! +//! ```not_rust +//! cargo run -p example-dependency-injection +//! ``` + +use std::{ + collections::HashMap, + sync::{Arc, Mutex}, +}; + +use axum::{ + extract::{Path, State}, + http::StatusCode, + routing::{get, post}, + Json, Router, +}; +use serde::{Deserialize, Serialize}; +use tokio::net::TcpListener; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; +use uuid::Uuid; + +#[tokio::main] +async fn main() { + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| format!("{}=debug", env!("CARGO_CRATE_NAME")).into()), + ) + .with(tracing_subscriber::fmt::layer()) + .init(); + + let user_repo = InMemoryUserRepo::default(); + + // We generally have two ways to inject dependencies: + // + // 1. Using trait objects (`dyn SomeTrait`) + // - Pros + // - Likely leads to simpler code due to fewer type parameters. + // - Cons + // - Less flexible because we can only use object safe traits + // - Small amount of additional runtime overhead due to dynamic dispatch. + // This is likely to be negligible. + // 2. Using generics (`T where T: SomeTrait`) + // - Pros + // - More flexible since all traits can be used. + // - No runtime overhead. + // - Cons: + // - Additional type parameters and trait bounds can lead to more complex code and + // boilerplate. + // + // Using trait objects is recommended unless you really need generics. + + let using_dyn = Router::new() + .route("/users/{id}", get(get_user_dyn)) + .route("/users", post(create_user_dyn)) + .with_state(AppStateDyn { + user_repo: Arc::new(user_repo.clone()), + }); + + let using_generic = Router::new() + .route("/users/{id}", get(get_user_generic::)) + .route("/users", post(create_user_generic::)) + .with_state(AppStateGeneric { user_repo }); + + let app = Router::new() + .nest("/dyn", using_dyn) + .nest("/generic", using_generic); + + let listener = TcpListener::bind("127.0.0.1:3000").await.unwrap(); + tracing::debug!("listening on {}", listener.local_addr().unwrap()); + axum::serve(listener, app).await.unwrap(); +} + +#[derive(Clone)] +struct AppStateDyn { + user_repo: Arc, +} + +#[derive(Clone)] +struct AppStateGeneric { + user_repo: T, +} + +#[derive(Debug, Serialize, Clone)] +struct User { + id: Uuid, + name: String, +} + +#[derive(Deserialize)] +struct UserParams { + name: String, +} + +async fn create_user_dyn( + State(state): State, + Json(params): Json, +) -> Json { + let user = User { + id: Uuid::new_v4(), + name: params.name, + }; + + state.user_repo.save_user(&user); + + Json(user) +} + +async fn get_user_dyn( + State(state): State, + Path(id): Path, +) -> Result, StatusCode> { + match state.user_repo.get_user(id) { + Some(user) => Ok(Json(user)), + None => Err(StatusCode::NOT_FOUND), + } +} + +async fn create_user_generic( + State(state): State>, + Json(params): Json, +) -> Json +where + T: UserRepo, +{ + let user = User { + id: Uuid::new_v4(), + name: params.name, + }; + + state.user_repo.save_user(&user); + + Json(user) +} + +async fn get_user_generic( + State(state): State>, + Path(id): Path, +) -> Result, StatusCode> +where + T: UserRepo, +{ + match state.user_repo.get_user(id) { + Some(user) => Ok(Json(user)), + None => Err(StatusCode::NOT_FOUND), + } +} + +trait UserRepo: Send + Sync { + fn get_user(&self, id: Uuid) -> Option; + + fn save_user(&self, user: &User); +} + +#[derive(Debug, Clone, Default)] +struct InMemoryUserRepo { + map: Arc>>, +} + +impl UserRepo for InMemoryUserRepo { + fn get_user(&self, id: Uuid) -> Option { + self.map.lock().unwrap().get(&id).cloned() + } + + fn save_user(&self, user: &User) { + self.map.lock().unwrap().insert(user.id, user.clone()); + } +} diff --git a/examples/diesel-async-postgres/Cargo.toml b/examples/diesel-async-postgres/Cargo.toml index d86db1516d..efec344044 100644 --- a/examples/diesel-async-postgres/Cargo.toml +++ b/examples/diesel-async-postgres/Cargo.toml @@ -8,9 +8,8 @@ publish = false axum = { path = "../../axum", features = ["macros"] } bb8 = "0.8" diesel = "2" -diesel-async = { version = "0.3", features = ["postgres", "bb8"] } +diesel-async = { version = "0.5", features = ["postgres", "bb8"] } serde = { version = "1.0", features = ["derive"] } -serde_json = "1" tokio = { version = "1.0", features = ["full"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/examples/diesel-async-postgres/src/main.rs b/examples/diesel-async-postgres/src/main.rs index ee42ac1002..44fbb54643 100644 --- a/examples/diesel-async-postgres/src/main.rs +++ b/examples/diesel-async-postgres/src/main.rs @@ -13,7 +13,6 @@ //! for a real world application using axum and diesel use axum::{ - async_trait, extract::{FromRef, FromRequestParts, State}, http::{request::Parts, StatusCode}, response::Json, @@ -57,7 +56,7 @@ async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "example_diesel_async_postgres=debug".into()), + .unwrap_or_else(|_| format!("{}=debug", env!("CARGO_CRATE_NAME")).into()), ) .with(tracing_subscriber::fmt::layer()) .init(); @@ -102,7 +101,6 @@ struct DatabaseConnection( bb8::PooledConnection<'static, AsyncDieselConnectionManager>, ); -#[async_trait] impl FromRequestParts for DatabaseConnection where S: Send + Sync, diff --git a/examples/diesel-postgres/Cargo.toml b/examples/diesel-postgres/Cargo.toml index ff42a0db68..a68b9df89f 100644 --- a/examples/diesel-postgres/Cargo.toml +++ b/examples/diesel-postgres/Cargo.toml @@ -6,11 +6,10 @@ publish = false [dependencies] axum = { path = "../../axum", features = ["macros"] } -deadpool-diesel = { version = "0.4.1", features = ["postgres"] } +deadpool-diesel = { version = "0.6.1", features = ["postgres"] } diesel = { version = "2", features = ["postgres"] } diesel_migrations = "2" serde = { version = "1.0", features = ["derive"] } -serde_json = "1" tokio = { version = "1.0", features = ["full"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/examples/diesel-postgres/src/main.rs b/examples/diesel-postgres/src/main.rs index 605660d073..0c5852d20b 100644 --- a/examples/diesel-postgres/src/main.rs +++ b/examples/diesel-postgres/src/main.rs @@ -54,7 +54,7 @@ async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "example_tokio_postgres=debug".into()), + .unwrap_or_else(|_| format!("{}=debug", env!("CARGO_CRATE_NAME")).into()), ) .with(tracing_subscriber::fmt::layer()) .init(); diff --git a/examples/error-handling/Cargo.toml b/examples/error-handling/Cargo.toml index 26fc3b98ee..7aebc903b8 100644 --- a/examples/error-handling/Cargo.toml +++ b/examples/error-handling/Cargo.toml @@ -8,6 +8,6 @@ publish = false axum = { path = "../../axum", features = ["macros"] } serde = { version = "1.0", features = ["derive"] } tokio = { version = "1.0", features = ["full"] } -tower-http = { version = "0.5", features = ["trace"] } +tower-http = { version = "0.6.1", features = ["trace"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/examples/error-handling/src/main.rs b/examples/error-handling/src/main.rs index 6981f59eee..0ad9f43cfa 100644 --- a/examples/error-handling/src/main.rs +++ b/examples/error-handling/src/main.rs @@ -45,8 +45,9 @@ use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; async fn main() { tracing_subscriber::registry() .with( - tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "example_error_handling=debug,tower_http=debug".into()), + tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| { + format!("{}=debug,tower_http=debug", env!("CARGO_CRATE_NAME")).into() + }), ) .with(tracing_subscriber::fmt::layer()) .init(); diff --git a/examples/form/src/main.rs b/examples/form/src/main.rs index 3f9ed09560..02ea23525b 100644 --- a/examples/form/src/main.rs +++ b/examples/form/src/main.rs @@ -13,7 +13,7 @@ async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "example_form=debug".into()), + .unwrap_or_else(|_| format!("{}=debug", env!("CARGO_CRATE_NAME")).into()), ) .with(tracing_subscriber::fmt::layer()) .init(); diff --git a/examples/global-404-handler/Cargo.toml b/examples/global-404-handler/Cargo.toml index 9848d9e830..a453cab57b 100644 --- a/examples/global-404-handler/Cargo.toml +++ b/examples/global-404-handler/Cargo.toml @@ -7,6 +7,5 @@ publish = false [dependencies] axum = { path = "../../axum" } tokio = { version = "1.0", features = ["full"] } -tower = { version = "0.4", features = ["util"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/examples/global-404-handler/src/main.rs b/examples/global-404-handler/src/main.rs index 38b029439b..bf1d8a95ac 100644 --- a/examples/global-404-handler/src/main.rs +++ b/examples/global-404-handler/src/main.rs @@ -17,7 +17,7 @@ async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "example_global_404_handler=debug".into()), + .unwrap_or_else(|_| format!("{}=debug", env!("CARGO_CRATE_NAME")).into()), ) .with(tracing_subscriber::fmt::layer()) .init(); diff --git a/examples/graceful-shutdown/Cargo.toml b/examples/graceful-shutdown/Cargo.toml index 86dfd52763..c7a5727423 100644 --- a/examples/graceful-shutdown/Cargo.toml +++ b/examples/graceful-shutdown/Cargo.toml @@ -6,10 +6,6 @@ publish = false [dependencies] axum = { path = "../../axum", features = ["tracing"] } -hyper = { version = "1.0", features = [] } -hyper-util = { version = "0.1", features = ["tokio", "server-auto", "http1"] } tokio = { version = "1.0", features = ["full"] } -tower = { version = "0.4", features = ["util"] } -tower-http = { version = "0.5", features = ["timeout", "trace"] } -tracing = "0.1" +tower-http = { version = "0.6.1", features = ["timeout", "trace"] } tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/examples/graceful-shutdown/src/main.rs b/examples/graceful-shutdown/src/main.rs index d3388c8359..533cf8f145 100644 --- a/examples/graceful-shutdown/src/main.rs +++ b/examples/graceful-shutdown/src/main.rs @@ -21,7 +21,11 @@ async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| { - "example_graceful_shutdown=debug,tower_http=debug,axum=trace".into() + format!( + "{}=debug,tower_http=debug,axum=trace", + env!("CARGO_CRATE_NAME") + ) + .into() }), ) .with(tracing_subscriber::fmt::layer().without_time()) diff --git a/examples/handle-head-request/Cargo.toml b/examples/handle-head-request/Cargo.toml index 83a8a66e25..8497b08957 100644 --- a/examples/handle-head-request/Cargo.toml +++ b/examples/handle-head-request/Cargo.toml @@ -11,4 +11,4 @@ tokio = { version = "1.0", features = ["full"] } [dev-dependencies] http-body-util = "0.1.0" hyper = { version = "1.0.0", features = ["full"] } -tower = { version = "0.4", features = ["util"] } +tower = { version = "0.5.1", features = ["util"] } diff --git a/examples/http-proxy/Cargo.toml b/examples/http-proxy/Cargo.toml index aa6070020a..8dc2f19539 100644 --- a/examples/http-proxy/Cargo.toml +++ b/examples/http-proxy/Cargo.toml @@ -9,6 +9,6 @@ axum = { path = "../../axum" } hyper = { version = "1", features = ["full"] } hyper-util = "0.1.1" tokio = { version = "1.0", features = ["full"] } -tower = { version = "0.4", features = ["make"] } +tower = { version = "0.5.1", features = ["make", "util"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/examples/http-proxy/src/main.rs b/examples/http-proxy/src/main.rs index b60ed03daa..90aa5aa817 100644 --- a/examples/http-proxy/src/main.rs +++ b/examples/http-proxy/src/main.rs @@ -36,8 +36,9 @@ use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; async fn main() { tracing_subscriber::registry() .with( - tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "example_http_proxy=trace,tower_http=debug".into()), + tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| { + format!("{}=trace,tower_http=debug", env!("CARGO_CRATE_NAME")).into() + }), ) .with(tracing_subscriber::fmt::layer()) .init(); diff --git a/examples/jwt/Cargo.toml b/examples/jwt/Cargo.toml index b0c76c25d1..c36511a27d 100644 --- a/examples/jwt/Cargo.toml +++ b/examples/jwt/Cargo.toml @@ -7,8 +7,7 @@ publish = false [dependencies] axum = { path = "../../axum" } axum-extra = { path = "../../axum-extra", features = ["typed-header"] } -jsonwebtoken = "8.0" -once_cell = "1.8" +jsonwebtoken = "9.3" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" tokio = { version = "1.0", features = ["full"] } diff --git a/examples/jwt/src/main.rs b/examples/jwt/src/main.rs index 85211851b2..f7877d745f 100644 --- a/examples/jwt/src/main.rs +++ b/examples/jwt/src/main.rs @@ -7,7 +7,6 @@ //! ``` use axum::{ - async_trait, extract::FromRequestParts, http::{request::Parts, StatusCode}, response::{IntoResponse, Response}, @@ -19,10 +18,10 @@ use axum_extra::{ TypedHeader, }; use jsonwebtoken::{decode, encode, DecodingKey, EncodingKey, Header, Validation}; -use once_cell::sync::Lazy; use serde::{Deserialize, Serialize}; use serde_json::json; use std::fmt::Display; +use std::sync::LazyLock; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; // Quick instructions @@ -51,7 +50,7 @@ use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; // -H 'Authorization: Bearer blahblahblah' \ // http://localhost:3000/protected -static KEYS: Lazy = Lazy::new(|| { +static KEYS: LazyLock = LazyLock::new(|| { let secret = std::env::var("JWT_SECRET").expect("JWT_SECRET must be set"); Keys::new(secret.as_bytes()) }); @@ -61,7 +60,7 @@ async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "example_jwt=debug".into()), + .unwrap_or_else(|_| format!("{}=debug", env!("CARGO_CRATE_NAME")).into()), ) .with(tracing_subscriber::fmt::layer()) .init(); @@ -122,7 +121,6 @@ impl AuthBody { } } -#[async_trait] impl FromRequestParts for Claims where S: Send + Sync, diff --git a/examples/key-value-store/Cargo.toml b/examples/key-value-store/Cargo.toml index c23b28d268..ccd28c2558 100644 --- a/examples/key-value-store/Cargo.toml +++ b/examples/key-value-store/Cargo.toml @@ -7,14 +7,13 @@ publish = false [dependencies] axum = { path = "../../axum" } tokio = { version = "1.0", features = ["full"] } -tower = { version = "0.4", features = ["util", "timeout", "load-shed", "limit"] } -tower-http = { version = "0.5.0", features = [ +tower = { version = "0.5.1", features = ["util", "timeout", "load-shed", "limit"] } +tower-http = { version = "0.6.1", features = [ "add-extension", "auth", "compression-full", "limit", "trace", ] } -tower-layer = "0.3.2" tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/examples/key-value-store/src/main.rs b/examples/key-value-store/src/main.rs index 1e2a5e748c..c2b3f51cda 100644 --- a/examples/key-value-store/src/main.rs +++ b/examples/key-value-store/src/main.rs @@ -33,8 +33,9 @@ use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; async fn main() { tracing_subscriber::registry() .with( - tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "example_key_value_store=debug,tower_http=debug".into()), + tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| { + format!("{}=debug,tower_http=debug", env!("CARGO_CRATE_NAME")).into() + }), ) .with(tracing_subscriber::fmt::layer()) .init(); @@ -44,7 +45,7 @@ async fn main() { // Build our application by composing routes let app = Router::new() .route( - "/:key", + "/{key}", // Add compression to `kv_get` get(kv_get.layer(CompressionLayer::new())) // But don't compress `kv_set` @@ -124,7 +125,7 @@ fn admin_routes() -> Router { Router::new() .route("/keys", delete(delete_all_keys)) - .route("/key/:key", delete(remove_key)) + .route("/key/{key}", delete(remove_key)) // Require bearer auth for all admin routes .layer(ValidateRequestHeaderLayer::bearer("secret-token")) } diff --git a/examples/listen-multiple-addrs/Cargo.toml b/examples/listen-multiple-addrs/Cargo.toml deleted file mode 100644 index 8940b94332..0000000000 --- a/examples/listen-multiple-addrs/Cargo.toml +++ /dev/null @@ -1,12 +0,0 @@ -[package] -name = "listen-multiple-addrs" -version = "0.1.0" -edition = "2021" -publish = false - -[dependencies] -axum = { path = "../../axum" } -hyper = { version = "1.0.0", features = ["full"] } -hyper-util = { version = "0.1", features = ["tokio", "server-auto", "http1"] } -tokio = { version = "1", features = ["full"] } -tower = { version = "0.4", features = ["util"] } diff --git a/examples/listen-multiple-addrs/src/main.rs b/examples/listen-multiple-addrs/src/main.rs deleted file mode 100644 index dafd4d64fc..0000000000 --- a/examples/listen-multiple-addrs/src/main.rs +++ /dev/null @@ -1,57 +0,0 @@ -//! Showcases how listening on multiple addrs is possible. -//! -//! This may be useful in cases where the platform does not -//! listen on both IPv4 and IPv6 when the IPv6 catch-all listener is used (`::`), -//! [like older versions of Windows.](https://docs.microsoft.com/en-us/windows/win32/winsock/dual-stack-sockets) - -use axum::{extract::Request, routing::get, Router}; -use hyper::body::Incoming; -use hyper_util::{ - rt::{TokioExecutor, TokioIo}, - server, -}; -use std::net::{Ipv4Addr, Ipv6Addr, SocketAddr}; -use tokio::net::TcpListener; -use tower::Service; - -#[tokio::main] -async fn main() { - let app: Router = Router::new().route("/", get(|| async { "Hello, World!" })); - - let localhost_v4 = SocketAddr::new(Ipv4Addr::LOCALHOST.into(), 8080); - let listener_v4 = TcpListener::bind(&localhost_v4).await.unwrap(); - - let localhost_v6 = SocketAddr::new(Ipv6Addr::LOCALHOST.into(), 8080); - let listener_v6 = TcpListener::bind(&localhost_v6).await.unwrap(); - - // See https://github.com/tokio-rs/axum/blob/main/examples/serve-with-hyper/src/main.rs for - // more details about this setup - loop { - // Accept connections from `listener_v4` and `listener_v6` at the same time - let (socket, _remote_addr) = tokio::select! { - result = listener_v4.accept() => { - result.unwrap() - } - result = listener_v6.accept() => { - result.unwrap() - } - }; - - let tower_service = app.clone(); - - tokio::spawn(async move { - let socket = TokioIo::new(socket); - - let hyper_service = hyper::service::service_fn(move |request: Request| { - tower_service.clone().call(request) - }); - - if let Err(err) = server::conn::auto::Builder::new(TokioExecutor::new()) - .serve_connection_with_upgrades(socket, hyper_service) - .await - { - eprintln!("failed to serve connection: {err:#}"); - } - }); - } -} diff --git a/examples/low-level-native-tls/Cargo.toml b/examples/low-level-native-tls/Cargo.toml new file mode 100644 index 0000000000..eee80081c9 --- /dev/null +++ b/examples/low-level-native-tls/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "example-low-level-native-tls" +version = "0.1.0" +edition = "2021" +publish = false + +[dependencies] +axum = { path = "../../axum" } +futures-util = { version = "0.3", default-features = false } +hyper = { version = "1.0.0", features = ["full"] } +hyper-util = { version = "0.1" } +tokio = { version = "1", features = ["full"] } +tokio-native-tls = "0.3.1" +tower = { version = "0.5.1", features = ["make"] } +tower-service = "0.3.2" +tracing = "0.1" +tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/examples/low-level-native-tls/self_signed_certs/cert.pem b/examples/low-level-native-tls/self_signed_certs/cert.pem new file mode 100644 index 0000000000..656aa88055 --- /dev/null +++ b/examples/low-level-native-tls/self_signed_certs/cert.pem @@ -0,0 +1,22 @@ +-----BEGIN CERTIFICATE----- +MIIDkzCCAnugAwIBAgIUXVYkRCrM/ge03DVymDtXCuybp7gwDQYJKoZIhvcNAQEL +BQAwWTELMAkGA1UEBhMCVVMxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM +GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDESMBAGA1UEAwwJbG9jYWxob3N0MB4X +DTIxMDczMTE0MjIxMloXDTIyMDczMTE0MjIxMlowWTELMAkGA1UEBhMCVVMxEzAR +BgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5 +IEx0ZDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A +MIIBCgKCAQEA02V5ZjmqLB/VQwTarrz/35qsa83L+DbAoa0001+jVmmC+G9Nufi0 +daroFWj/Uicv2fZWETU8JoZKUrX4BK9og5cg5rln/CtBRWCUYIwRgY9R/CdBGPn4 +kp+XkSJaCw74ZIyLy/Zfux6h8ES1m9YRnBza+s7U+ImRBRf4MRPtXQ3/mqJxAZYq +dOnKnvssRyD2qutgVTAxwMUvJWIivRhRYDj7WOpS4CEEeQxP1iH1/T5P7FdtTGdT +bVBABCA8JhL96uFGPpOYHcM/7R5EIA3yZ5FNg931QzoDITjtXGtQ6y9/l/IYkWm6 +J67RWcN0IoTsZhz0WNU4gAeslVtJLofn8QIDAQABo1MwUTAdBgNVHQ4EFgQUzFnK +NfS4LAYuKeWwHbzooER0yZ0wHwYDVR0jBBgwFoAUzFnKNfS4LAYuKeWwHbzooER0 +yZ0wDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAk4O+e9jia59W +ZwetN4GU7OWcYhmOgSizRSs6u7mTfp62LDMt96WKU3THksOnZ44HnqWQxsSfdFVU +XJD12tjvVU8Z4FWzQajcHeemUYiDze8EAh6TnxnUcOrU8IcwiKGxCWRY/908jnWg ++MMscfMCMYTRdeTPqD8fGzAlUCtmyzH6KLE3s4Oo/r5+NR+Uvrwpdvb7xe0MwwO9 +Q/zR4N8ep/HwHVEObcaBofE1ssZLksX7ZgCP9wMgXRWpNAtC5EWxMbxYjBfWFH24 +fDJlBMiGJWg8HHcxK7wQhFh+fuyNzE+xEWPsI9VL1zDftd9x8/QsOagyEOnY8Vxr +AopvZ09uEQ== +-----END CERTIFICATE----- diff --git a/examples/low-level-native-tls/self_signed_certs/key.pem b/examples/low-level-native-tls/self_signed_certs/key.pem new file mode 100644 index 0000000000..3de14eb32f --- /dev/null +++ b/examples/low-level-native-tls/self_signed_certs/key.pem @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDTZXlmOaosH9VD +BNquvP/fmqxrzcv4NsChrTTTX6NWaYL4b025+LR1qugVaP9SJy/Z9lYRNTwmhkpS +tfgEr2iDlyDmuWf8K0FFYJRgjBGBj1H8J0EY+fiSn5eRIloLDvhkjIvL9l+7HqHw +RLWb1hGcHNr6ztT4iZEFF/gxE+1dDf+aonEBlip06cqe+yxHIPaq62BVMDHAxS8l +YiK9GFFgOPtY6lLgIQR5DE/WIfX9Pk/sV21MZ1NtUEAEIDwmEv3q4UY+k5gdwz/t +HkQgDfJnkU2D3fVDOgMhOO1ca1DrL3+X8hiRabonrtFZw3QihOxmHPRY1TiAB6yV +W0kuh+fxAgMBAAECggEADltu8k1qTFLhJgsXWxTFAAe+PBgfCT2WuaRM2So+qqjB +12Of0MieYPt5hbK63HaC3nfHgqWt7yPhulpXfOH45C8IcgMXl93MMg0MJr58leMI ++2ojFrIrerHSFm5R1TxwDEwrVm/mMowzDWFtQCc6zPJ8wNn5RuP48HKfTZ3/2fjw +zEjSwPO2wFMfo1EJNTjlI303lFbdFBs67NaX6puh30M7Tn+gznHKyO5a7F57wkIt +fkgnEy/sgMedQlwX7bRpUoD6f0fZzV8Qz4cHFywtYErczZJh3VGitJoO/VCIDdty +RPXOAqVDd7EpP1UUehZlKVWZ0OZMEfRgKbRCel5abQKBgQDwgwrIQ5+BiZv6a0VT +ETeXB+hRbvBinRykNo/RvLc3j1enRh9/zO/ShadZIXgOAiM1Jnr5Gp8KkNGca6K1 +myhtad7xYPODYzNXXp6T1OPgZxHZLIYzVUj6ypXeV64Te5ZiDaJ1D49czsq+PqsQ +XRcgBJSNpFtDFiXWpjXWfx8PxwKBgQDhAnLY5Sl2eeQo+ud0MvjwftB/mN2qCzJY +5AlQpRI4ThWxJgGPuHTR29zVa5iWNYuA5LWrC1y/wx+t5HKUwq+5kxvs+npYpDJD +ZX/w0Glc6s0Jc/mFySkbw9B2LePedL7lRF5OiAyC6D106Sc9V2jlL4IflmOzt4CD +ZTNbLtC6hwKBgHfIzBXxl/9sCcMuqdg1Ovp9dbcZCaATn7ApfHd5BccmHQGyav27 +k7XF2xMJGEHhzqcqAxUNrSgV+E9vTBomrHvRvrd5Ec7eGTPqbBA0d0nMC5eeFTh7 +wV0miH20LX6Gjt9G6yJiHYSbeV5G1+vOcTYBEft5X/qJjU7aePXbWh0BAoGBAJlV +5tgCCuhvFloK6fHYzqZtdT6O+PfpW20SMXrgkvMF22h2YvgDFrDwqKRUB47NfHzg +3yBpxNH1ccA5/w97QO8w3gX3h6qicpJVOAPusu6cIBACFZfjRv1hyszOZwvw+Soa +Fj5kHkqTY1YpkREPYS9V2dIW1Wjic1SXgZDw7VM/AoGAP/cZ3ZHTSCDTFlItqy5C +rIy2AiY0WJsx+K0qcvtosPOOwtnGjWHb1gdaVdfX/IRkSsX4PAOdnsyidNC5/l/m +y8oa+5WEeGFclWFhr4dnTA766o8HrM2UjIgWWYBF2VKdptGnHxFeJWFUmeQC/xeW +w37pCS7ykL+7gp7V0WShYsw= +-----END PRIVATE KEY----- diff --git a/examples/low-level-native-tls/src/main.rs b/examples/low-level-native-tls/src/main.rs new file mode 100644 index 0000000000..d676238dfa --- /dev/null +++ b/examples/low-level-native-tls/src/main.rs @@ -0,0 +1,101 @@ +//! Run with +//! +//! ```not_rust +//! cargo run -p example-low-level-native-tls +//! ``` + +use axum::{extract::Request, routing::get, Router}; +use futures_util::pin_mut; +use hyper::body::Incoming; +use hyper_util::rt::{TokioExecutor, TokioIo}; +use std::path::PathBuf; +use tokio::net::TcpListener; +use tokio_native_tls::{ + native_tls::{Identity, Protocol, TlsAcceptor as NativeTlsAcceptor}, + TlsAcceptor, +}; +use tower_service::Service; +use tracing::{error, info, warn}; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; + +#[tokio::main] +async fn main() { + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| "example_low_level_rustls=debug".into()), + ) + .with(tracing_subscriber::fmt::layer()) + .init(); + + let tls_acceptor = native_tls_acceptor( + PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join("self_signed_certs") + .join("key.pem"), + PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join("self_signed_certs") + .join("cert.pem"), + ); + + let tls_acceptor = TlsAcceptor::from(tls_acceptor); + let bind = "[::1]:3000"; + let tcp_listener = TcpListener::bind(bind).await.unwrap(); + info!("HTTPS server listening on {bind}. To contact curl -k https://localhost:3000"); + let app = Router::new().route("/", get(handler)); + + pin_mut!(tcp_listener); + loop { + let tower_service = app.clone(); + let tls_acceptor = tls_acceptor.clone(); + + // Wait for new tcp connection + let (cnx, addr) = tcp_listener.accept().await.unwrap(); + + tokio::spawn(async move { + // Wait for tls handshake to happen + let Ok(stream) = tls_acceptor.accept(cnx).await else { + error!("error during tls handshake connection from {}", addr); + return; + }; + + // Hyper has its own `AsyncRead` and `AsyncWrite` traits and doesn't use tokio. + // `TokioIo` converts between them. + let stream = TokioIo::new(stream); + + // Hyper also has its own `Service` trait and doesn't use tower. We can use + // `hyper::service::service_fn` to create a hyper `Service` that calls our app through + // `tower::Service::call`. + let hyper_service = hyper::service::service_fn(move |request: Request| { + // We have to clone `tower_service` because hyper's `Service` uses `&self` whereas + // tower's `Service` requires `&mut self`. + // + // We don't need to call `poll_ready` since `Router` is always ready. + tower_service.clone().call(request) + }); + + let ret = hyper_util::server::conn::auto::Builder::new(TokioExecutor::new()) + .serve_connection_with_upgrades(stream, hyper_service) + .await; + + if let Err(err) = ret { + warn!("error serving connection from {addr}: {err}"); + } + }); + } +} + +async fn handler() -> &'static str { + "Hello, World!" +} + +fn native_tls_acceptor(key_file: PathBuf, cert_file: PathBuf) -> NativeTlsAcceptor { + let key_pem = std::fs::read_to_string(&key_file).unwrap(); + let cert_pem = std::fs::read_to_string(&cert_file).unwrap(); + + let id = Identity::from_pkcs8(cert_pem.as_bytes(), key_pem.as_bytes()).unwrap(); + NativeTlsAcceptor::builder(id) + // let's be modern + .min_protocol_version(Some(Protocol::Tlsv12)) + .build() + .unwrap() +} diff --git a/examples/low-level-openssl/Cargo.toml b/examples/low-level-openssl/Cargo.toml index c5247dec9c..a74a950e56 100644 --- a/examples/low-level-openssl/Cargo.toml +++ b/examples/low-level-openssl/Cargo.toml @@ -12,6 +12,6 @@ hyper-util = { version = "0.1" } openssl = "0.10" tokio = { version = "1", features = ["full"] } tokio-openssl = "0.6" -tower = { version = "0.4", features = ["make"] } +tower = { version = "0.5.1", features = ["make"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/examples/low-level-openssl/src/main.rs b/examples/low-level-openssl/src/main.rs index 1b47375621..7c483010c5 100644 --- a/examples/low-level-openssl/src/main.rs +++ b/examples/low-level-openssl/src/main.rs @@ -15,7 +15,7 @@ async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "example_low_level_openssl=debug".into()), + .unwrap_or_else(|_| format!("{}=debug", env!("CARGO_CRATE_NAME")).into()), ) .with(tracing_subscriber::fmt::layer()) .init(); @@ -73,7 +73,7 @@ async fn main() { // `TokioIo` converts between them. let stream = TokioIo::new(tls_stream); - // Hyper has also its own `Service` trait and doesn't use tower. We can use + // Hyper also has its own `Service` trait and doesn't use tower. We can use // `hyper::service::service_fn` to create a hyper `Service` that calls our app through // `tower::Service::call`. let hyper_service = hyper::service::service_fn(move |request: Request| { diff --git a/examples/low-level-rustls/Cargo.toml b/examples/low-level-rustls/Cargo.toml index 3975fcb917..eace54a846 100644 --- a/examples/low-level-rustls/Cargo.toml +++ b/examples/low-level-rustls/Cargo.toml @@ -9,10 +9,8 @@ axum = { path = "../../axum" } futures-util = { version = "0.3", default-features = false } hyper = { version = "1.0.0", features = ["full"] } hyper-util = { version = "0.1" } -rustls-pemfile = "1.0.4" tokio = { version = "1", features = ["full"] } -tokio-rustls = "0.24.1" -tower = { version = "0.4", features = ["make"] } +tokio-rustls = "0.26" tower-service = "0.3.2" tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/examples/low-level-rustls/src/main.rs b/examples/low-level-rustls/src/main.rs index 0a28e979fe..d0627ed832 100644 --- a/examples/low-level-rustls/src/main.rs +++ b/examples/low-level-rustls/src/main.rs @@ -8,16 +8,14 @@ use axum::{extract::Request, routing::get, Router}; use futures_util::pin_mut; use hyper::body::Incoming; use hyper_util::rt::{TokioExecutor, TokioIo}; -use rustls_pemfile::{certs, pkcs8_private_keys}; use std::{ - fs::File, - io::BufReader, path::{Path, PathBuf}, sync::Arc, }; use tokio::net::TcpListener; use tokio_rustls::{ - rustls::{Certificate, PrivateKey, ServerConfig}, + rustls::pki_types::{pem::PemObject, CertificateDer, PrivateKeyDer}, + rustls::ServerConfig, TlsAcceptor, }; use tower_service::Service; @@ -29,7 +27,7 @@ async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "example_low_level_rustls=debug".into()), + .unwrap_or_else(|_| format!("{}=debug", env!("CARGO_CRATE_NAME")).into()), ) .with(tracing_subscriber::fmt::layer()) .init(); @@ -68,7 +66,7 @@ async fn main() { // `TokioIo` converts between them. let stream = TokioIo::new(stream); - // Hyper has also its own `Service` trait and doesn't use tower. We can use + // Hyper also has its own `Service` trait and doesn't use tower. We can use // `hyper::service::service_fn` to create a hyper `Service` that calls our app through // `tower::Service::call`. let hyper_service = hyper::service::service_fn(move |request: Request| { @@ -95,18 +93,14 @@ async fn handler() -> &'static str { } fn rustls_server_config(key: impl AsRef, cert: impl AsRef) -> Arc { - let mut key_reader = BufReader::new(File::open(key).unwrap()); - let mut cert_reader = BufReader::new(File::open(cert).unwrap()); + let key = PrivateKeyDer::from_pem_file(key).unwrap(); - let key = PrivateKey(pkcs8_private_keys(&mut key_reader).unwrap().remove(0)); - let certs = certs(&mut cert_reader) + let certs = CertificateDer::pem_file_iter(cert) .unwrap() - .into_iter() - .map(Certificate) + .map(|cert| cert.unwrap()) .collect(); let mut config = ServerConfig::builder() - .with_safe_defaults() .with_no_client_auth() .with_single_cert(certs, key) .expect("bad certificate/key"); diff --git a/examples/mongodb/Cargo.toml b/examples/mongodb/Cargo.toml new file mode 100644 index 0000000000..c084a36f7d --- /dev/null +++ b/examples/mongodb/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "example-mongodb" +version = "0.1.0" +edition = "2021" +publish = false + +[dependencies] +axum = { path = "../../axum" } +mongodb = "3.1.0" +serde = { version = "1.0", features = ["derive"] } +tokio = { version = "1.0", features = ["full"] } +tower-http = { version = "0.6.1", features = ["add-extension", "trace"] } +tracing = "0.1" +tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/examples/mongodb/src/main.rs b/examples/mongodb/src/main.rs new file mode 100644 index 0000000000..2cf25f4f31 --- /dev/null +++ b/examples/mongodb/src/main.rs @@ -0,0 +1,132 @@ +//! Run with +//! +//! ```not_rust +//! cargo run -p example-mongodb +//! ``` + +use axum::{ + extract::{Path, State}, + http::StatusCode, + routing::{delete, get, post, put}, + Json, Router, +}; +use mongodb::{ + bson::doc, + results::{DeleteResult, InsertOneResult, UpdateResult}, + Client, Collection, +}; +use serde::{Deserialize, Serialize}; +use tower_http::trace::TraceLayer; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; + +#[tokio::main] +async fn main() { + // connecting to mongodb + let db_connection_str = std::env::var("DATABASE_URL").unwrap_or_else(|_| { + "mongodb://admin:password@127.0.0.1:27017/?authSource=admin".to_string() + }); + let client = Client::with_uri_str(db_connection_str).await.unwrap(); + + // pinging the database + client + .database("axum-mongo") + .run_command(doc! { "ping": 1 }) + .await + .unwrap(); + println!("Pinged your database. Successfully connected to MongoDB!"); + + // logging middleware + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| { + format!("{}=debug,tower_http=debug", env!("CARGO_CRATE_NAME")).into() + }), + ) + .with(tracing_subscriber::fmt::layer()) + .init(); + + // run it + let listener = tokio::net::TcpListener::bind("127.0.0.1:3000") + .await + .unwrap(); + tracing::debug!("Listening on {}", listener.local_addr().unwrap()); + axum::serve(listener, app(client)).await.unwrap(); +} + +// defining routes and state +fn app(client: Client) -> Router { + let collection: Collection = client.database("axum-mongo").collection("members"); + + Router::new() + .route("/create", post(create_member)) + .route("/read/:id", get(read_member)) + .route("/update", put(update_member)) + .route("/delete/:id", delete(delete_member)) + .layer(TraceLayer::new_for_http()) + .with_state(collection) +} + +// handler to create a new member +async fn create_member( + State(db): State>, + Json(input): Json, +) -> Result, (StatusCode, String)> { + let result = db.insert_one(input).await.map_err(internal_error)?; + + Ok(Json(result)) +} + +// handler to read an existing member +async fn read_member( + State(db): State>, + Path(id): Path, +) -> Result>, (StatusCode, String)> { + let result = db + .find_one(doc! { "_id": id }) + .await + .map_err(internal_error)?; + + Ok(Json(result)) +} + +// handler to update an existing member +async fn update_member( + State(db): State>, + Json(input): Json, +) -> Result, (StatusCode, String)> { + let result = db + .replace_one(doc! { "_id": input.id }, input) + .await + .map_err(internal_error)?; + + Ok(Json(result)) +} + +// handler to delete an existing member +async fn delete_member( + State(db): State>, + Path(id): Path, +) -> Result, (StatusCode, String)> { + let result = db + .delete_one(doc! { "_id": id }) + .await + .map_err(internal_error)?; + + Ok(Json(result)) +} + +fn internal_error(err: E) -> (StatusCode, String) +where + E: std::error::Error, +{ + (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()) +} + +// defining Member type +#[derive(Debug, Deserialize, Serialize)] +struct Member { + #[serde(rename = "_id")] + id: u32, + name: String, + active: bool, +} diff --git a/examples/multipart-form/Cargo.toml b/examples/multipart-form/Cargo.toml index d93b9c08e8..143154e89d 100644 --- a/examples/multipart-form/Cargo.toml +++ b/examples/multipart-form/Cargo.toml @@ -7,6 +7,6 @@ publish = false [dependencies] axum = { path = "../../axum", features = ["multipart"] } tokio = { version = "1.0", features = ["full"] } -tower-http = { version = "0.5.0", features = ["limit", "trace"] } +tower-http = { version = "0.6.1", features = ["limit", "trace"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/examples/multipart-form/src/main.rs b/examples/multipart-form/src/main.rs index ecf5191f2a..30fcfc70a9 100644 --- a/examples/multipart-form/src/main.rs +++ b/examples/multipart-form/src/main.rs @@ -17,8 +17,9 @@ use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; async fn main() { tracing_subscriber::registry() .with( - tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "example_multipart_form=debug,tower_http=debug".into()), + tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| { + format!("{}=debug,tower_http=debug", env!("CARGO_CRATE_NAME")).into() + }), ) .with(tracing_subscriber::fmt::layer()) .init(); diff --git a/examples/oauth/Cargo.toml b/examples/oauth/Cargo.toml index 0186b95bfe..aafa29f4dc 100644 --- a/examples/oauth/Cargo.toml +++ b/examples/oauth/Cargo.toml @@ -12,7 +12,7 @@ axum-extra = { path = "../../axum-extra", features = ["typed-header"] } http = "1.0.0" oauth2 = "4.1" # Use Rustls because it makes it easier to cross-compile on CI -reqwest = { version = "0.11", default-features = false, features = ["rustls-tls", "json"] } +reqwest = { version = "0.12", default-features = false, features = ["rustls-tls", "json"] } serde = { version = "1.0", features = ["derive"] } tokio = { version = "1.0", features = ["full"] } tracing = "0.1" diff --git a/examples/oauth/src/main.rs b/examples/oauth/src/main.rs index 6ceffe8f84..4241db2a7c 100644 --- a/examples/oauth/src/main.rs +++ b/examples/oauth/src/main.rs @@ -8,10 +8,9 @@ //! CLIENT_ID=REPLACE_ME CLIENT_SECRET=REPLACE_ME cargo run -p example-oauth //! ``` -use anyhow::{Context, Result}; +use anyhow::{anyhow, Context, Result}; use async_session::{MemoryStore, Session, SessionStore}; use axum::{ - async_trait, extract::{FromRef, FromRequestParts, Query, State}, http::{header::SET_COOKIE, HeaderMap}, response::{IntoResponse, Redirect, Response}, @@ -29,13 +28,14 @@ use std::env; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; static COOKIE_NAME: &str = "SESSION"; +static CSRF_TOKEN: &str = "csrf_token"; #[tokio::main] async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "example_oauth=debug".into()), + .unwrap_or_else(|_| format!("{}=debug", env!("CARGO_CRATE_NAME")).into()), ) .with(tracing_subscriber::fmt::layer()) .init(); @@ -142,14 +142,37 @@ async fn index(user: Option) -> impl IntoResponse { } } -async fn discord_auth(State(client): State) -> impl IntoResponse { - let (auth_url, _csrf_token) = client +async fn discord_auth( + State(client): State, + State(store): State, +) -> Result { + let (auth_url, csrf_token) = client .authorize_url(CsrfToken::new_random) .add_scope(Scope::new("identify".to_string())) .url(); - // Redirect to Discord's oauth service - Redirect::to(auth_url.as_ref()) + // Create session to store csrf_token + let mut session = Session::new(); + session + .insert(CSRF_TOKEN, &csrf_token) + .context("failed in inserting CSRF token into session")?; + + // Store the session in MemoryStore and retrieve the session cookie + let cookie = store + .store_session(session) + .await + .context("failed to store CSRF token session")? + .context("unexpected error retrieving CSRF cookie value")?; + + // Attach the session cookie to the response header + let cookie = format!("{COOKIE_NAME}={cookie}; SameSite=Lax; HttpOnly; Secure; Path=/"); + let mut headers = HeaderMap::new(); + headers.insert( + SET_COOKIE, + cookie.parse().context("failed to parse cookie")?, + ); + + Ok((headers, Redirect::to(auth_url.as_ref()))) } // Valid user session required. If there is none, redirect to the auth page @@ -190,11 +213,55 @@ struct AuthRequest { state: String, } +async fn csrf_token_validation_workflow( + auth_request: &AuthRequest, + cookies: &headers::Cookie, + store: &MemoryStore, +) -> Result<(), AppError> { + // Extract the cookie from the request + let cookie = cookies + .get(COOKIE_NAME) + .context("unexpected error getting cookie name")? + .to_string(); + + // Load the session + let session = match store + .load_session(cookie) + .await + .context("failed to load session")? + { + Some(session) => session, + None => return Err(anyhow!("Session not found").into()), + }; + + // Extract the CSRF token from the session + let stored_csrf_token = session + .get::(CSRF_TOKEN) + .context("CSRF token not found in session")? + .to_owned(); + + // Cleanup the CSRF token session + store + .destroy_session(session) + .await + .context("Failed to destroy old session")?; + + // Validate CSRF token is the same as the one in the auth request + if *stored_csrf_token.secret() != auth_request.state { + return Err(anyhow!("CSRF token mismatch").into()); + } + + Ok(()) +} + async fn login_authorized( Query(query): Query, State(store): State, State(oauth_client): State, + TypedHeader(cookies): TypedHeader, ) -> Result { + csrf_token_validation_workflow(&query, &cookies, &store).await?; + // Get an auth token let token = oauth_client .exchange_code(AuthorizationCode::new(query.code.clone())) @@ -229,7 +296,7 @@ async fn login_authorized( .context("unexpected error retrieving cookie value")?; // Build the cookie - let cookie = format!("{COOKIE_NAME}={cookie}; SameSite=Lax; Path=/"); + let cookie = format!("{COOKIE_NAME}={cookie}; SameSite=Lax; HttpOnly; Secure; Path=/"); // Set cookie let mut headers = HeaderMap::new(); @@ -249,7 +316,6 @@ impl IntoResponse for AuthRedirect { } } -#[async_trait] impl FromRequestParts for User where MemoryStore: FromRef, diff --git a/examples/parse-body-based-on-content-type/src/main.rs b/examples/parse-body-based-on-content-type/src/main.rs index bae4ec1d29..1e4fc1ac43 100644 --- a/examples/parse-body-based-on-content-type/src/main.rs +++ b/examples/parse-body-based-on-content-type/src/main.rs @@ -7,7 +7,6 @@ //! ``` use axum::{ - async_trait, extract::{FromRequest, Request}, http::{header::CONTENT_TYPE, StatusCode}, response::{IntoResponse, Response}, @@ -22,7 +21,7 @@ async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| { - "example_parse_body_based_on_content_type=debug,tower_http=debug".into() + format!("{}=debug,tower_http=debug", env!("CARGO_CRATE_NAME")).into() }), ) .with(tracing_subscriber::fmt::layer()) @@ -48,7 +47,6 @@ async fn handler(JsonOrForm(payload): JsonOrForm) { struct JsonOrForm(T); -#[async_trait] impl FromRequest for JsonOrForm where S: Send + Sync, diff --git a/examples/print-request-response/Cargo.toml b/examples/print-request-response/Cargo.toml index a314b5b7fe..d6e064bb63 100644 --- a/examples/print-request-response/Cargo.toml +++ b/examples/print-request-response/Cargo.toml @@ -7,8 +7,6 @@ publish = false [dependencies] axum = { path = "../../axum" } http-body-util = "0.1.0" -hyper = { version = "1.0.0", features = ["full"] } tokio = { version = "1.0", features = ["full"] } -tower = { version = "0.4", features = ["util", "filter"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/examples/print-request-response/src/main.rs b/examples/print-request-response/src/main.rs index 5e0d4d1d97..84f14f2d50 100644 --- a/examples/print-request-response/src/main.rs +++ b/examples/print-request-response/src/main.rs @@ -20,8 +20,9 @@ use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; async fn main() { tracing_subscriber::registry() .with( - tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "example_print_request_response=debug,tower_http=debug".into()), + tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| { + format!("{}=debug,tower_http=debug", env!("CARGO_CRATE_NAME")).into() + }), ) .with(tracing_subscriber::fmt::layer()) .init(); diff --git a/examples/prometheus-metrics/Cargo.toml b/examples/prometheus-metrics/Cargo.toml index c29993de74..56ccdd05b0 100644 --- a/examples/prometheus-metrics/Cargo.toml +++ b/examples/prometheus-metrics/Cargo.toml @@ -6,8 +6,8 @@ publish = false [dependencies] axum = { path = "../../axum" } -metrics = "0.18" -metrics-exporter-prometheus = "0.8" +metrics = { version = "0.23", default-features = false } +metrics-exporter-prometheus = { version = "0.15", default-features = false } tokio = { version = "1.0", features = ["full"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/examples/prometheus-metrics/src/main.rs b/examples/prometheus-metrics/src/main.rs index 1944317c34..fe76121ce9 100644 --- a/examples/prometheus-metrics/src/main.rs +++ b/examples/prometheus-metrics/src/main.rs @@ -63,8 +63,9 @@ async fn start_metrics_server() { async fn main() { tracing_subscriber::registry() .with( - tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "example_todos=debug,tower_http=debug".into()), + tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| { + format!("{}=debug,tower_http=debug", env!("CARGO_CRATE_NAME")).into() + }), ) .with(tracing_subscriber::fmt::layer()) .init(); @@ -110,8 +111,8 @@ async fn track_metrics(req: Request, next: Next) -> impl IntoResponse { ("status", status), ]; - metrics::increment_counter!("http_requests_total", &labels); - metrics::histogram!("http_requests_duration_seconds", latency, &labels); + metrics::counter!("http_requests_total", &labels).increment(1); + metrics::histogram!("http_requests_duration_seconds", &labels).record(latency); response } diff --git a/examples/query-params-with-empty-strings/Cargo.toml b/examples/query-params-with-empty-strings/Cargo.toml index 7a52e98d1e..6dde9a3ac6 100644 --- a/examples/query-params-with-empty-strings/Cargo.toml +++ b/examples/query-params-with-empty-strings/Cargo.toml @@ -7,7 +7,6 @@ publish = false [dependencies] axum = { path = "../../axum" } http-body-util = "0.1.0" -hyper = "1.0.0" serde = { version = "1.0", features = ["derive"] } tokio = { version = "1.0", features = ["full"] } -tower = { version = "0.4", features = ["util"] } +tower = { version = "0.5.1", features = ["util"] } diff --git a/examples/readme/Cargo.toml b/examples/readme/Cargo.toml index 4a79c9bb88..17669567da 100644 --- a/examples/readme/Cargo.toml +++ b/examples/readme/Cargo.toml @@ -7,7 +7,6 @@ publish = false [dependencies] axum = { path = "../../axum" } serde = { version = "1.0", features = ["derive"] } -serde_json = "1.0.68" tokio = { version = "1.0", features = ["full"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/examples/request-id/Cargo.toml b/examples/request-id/Cargo.toml new file mode 100644 index 0000000000..22879e0824 --- /dev/null +++ b/examples/request-id/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "example-request-id" +version = "0.1.0" +edition = "2021" +publish = false + +[dependencies] +axum = { path = "../../axum" } +tokio = { version = "1.0", features = ["full"] } +tower = "0.5" +tower-http = { version = "0.5", features = ["request-id", "trace"] } +tracing = "0.1" +tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/examples/request-id/src/main.rs b/examples/request-id/src/main.rs new file mode 100644 index 0000000000..552d8d4a81 --- /dev/null +++ b/examples/request-id/src/main.rs @@ -0,0 +1,81 @@ +//! Run with +//! +//! ```not_rust +//! cargo run -p example-request-id +//! ``` + +use axum::{ + http::{HeaderName, Request}, + response::Html, + routing::get, + Router, +}; +use tower::ServiceBuilder; +use tower_http::{ + request_id::{MakeRequestUuid, PropagateRequestIdLayer, SetRequestIdLayer}, + trace::TraceLayer, +}; +use tracing::{error, info, info_span}; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; + +const REQUEST_ID_HEADER: &str = "x-request-id"; + +#[tokio::main] +async fn main() { + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| { + // axum logs rejections from built-in extractors with the `axum::rejection` + // target, at `TRACE` level. `axum::rejection=trace` enables showing those events + format!( + "{}=debug,tower_http=debug,axum::rejection=trace", + env!("CARGO_CRATE_NAME") + ) + .into() + }), + ) + .with(tracing_subscriber::fmt::layer()) + .init(); + + let x_request_id = HeaderName::from_static(REQUEST_ID_HEADER); + + let middleware = ServiceBuilder::new() + .layer(SetRequestIdLayer::new( + x_request_id.clone(), + MakeRequestUuid, + )) + .layer( + TraceLayer::new_for_http().make_span_with(|request: &Request<_>| { + // Log the request id as generated. + let request_id = request.headers().get(REQUEST_ID_HEADER); + + match request_id { + Some(request_id) => info_span!( + "http_request", + request_id = ?request_id, + ), + None => { + error!("could not extract request_id"); + info_span!("http_request") + } + } + }), + ) + // send headers from request to response headers + .layer(PropagateRequestIdLayer::new(x_request_id)); + + // build our application with a route + let app = Router::new().route("/", get(handler)).layer(middleware); + + // run it + let listener = tokio::net::TcpListener::bind("127.0.0.1:3000") + .await + .unwrap(); + println!("listening on {}", listener.local_addr().unwrap()); + axum::serve(listener, app).await.unwrap(); +} + +async fn handler() -> Html<&'static str> { + info!("Hello world!"); + Html("

Hello, World!

") +} diff --git a/examples/reqwest-response/Cargo.toml b/examples/reqwest-response/Cargo.toml index 18320e9f66..3ea740e3cb 100644 --- a/examples/reqwest-response/Cargo.toml +++ b/examples/reqwest-response/Cargo.toml @@ -6,9 +6,9 @@ publish = false [dependencies] axum = { path = "../../axum" } -reqwest = { version = "0.11", features = ["stream"] } +reqwest = { version = "0.12", features = ["stream"] } tokio = { version = "1.0", features = ["full"] } tokio-stream = "0.1" -tower-http = { version = "0.5.0", features = ["trace"] } +tower-http = { version = "0.6.1", features = ["trace"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/examples/reqwest-response/src/main.rs b/examples/reqwest-response/src/main.rs index b6dfeb70b5..a1ba33cf26 100644 --- a/examples/reqwest-response/src/main.rs +++ b/examples/reqwest-response/src/main.rs @@ -4,18 +4,16 @@ //! cargo run -p example-reqwest-response //! ``` -use std::{convert::Infallible, time::Duration}; - -use axum::http::{HeaderMap, StatusCode}; use axum::{ body::{Body, Bytes}, extract::State, - http::{HeaderName, HeaderValue}, + http::StatusCode, response::{IntoResponse, Response}, routing::get, Router, }; use reqwest::Client; +use std::{convert::Infallible, time::Duration}; use tokio_stream::StreamExt; use tower_http::trace::TraceLayer; use tracing::Span; @@ -25,8 +23,9 @@ use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; async fn main() { tracing_subscriber::registry() .with( - tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "example_reqwest_response=debug,tower_http=debug".into()), + tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| { + format!("{}=debug,tower_http=debug", env!("CARGO_CRATE_NAME")).into() + }), ) .with(tracing_subscriber::fmt::layer()) .init(); @@ -34,7 +33,7 @@ async fn main() { let client = Client::new(); let app = Router::new() - .route("/", get(proxy_via_reqwest)) + .route("/", get(stream_reqwest_response)) .route("/stream", get(stream_some_data)) // Add some logging so we can see the streams going through .layer(TraceLayer::new_for_http().on_body_chunk( @@ -51,7 +50,7 @@ async fn main() { axum::serve(listener, app).await.unwrap(); } -async fn proxy_via_reqwest(State(client): State) -> Response { +async fn stream_reqwest_response(State(client): State) -> Response { let reqwest_response = match client.get("http://127.0.0.1:3000/stream").send().await { Ok(res) => res, Err(err) => { @@ -60,16 +59,8 @@ async fn proxy_via_reqwest(State(client): State) -> Response { } }; - let response_builder = Response::builder().status(reqwest_response.status().as_u16()); - - // Here the mapping of headers is required due to reqwest and axum differ on the http crate versions - let mut headers = HeaderMap::with_capacity(reqwest_response.headers().len()); - headers.extend(reqwest_response.headers().into_iter().map(|(name, value)| { - let name = HeaderName::from_bytes(name.as_ref()).unwrap(); - let value = HeaderValue::from_bytes(value.as_ref()).unwrap(); - (name, value) - })); - + let mut response_builder = Response::builder().status(reqwest_response.status()); + *response_builder.headers_mut().unwrap() = reqwest_response.headers().clone(); response_builder .body(Body::from_stream(reqwest_response.bytes_stream())) // This unwrap is fine because the body is empty here diff --git a/examples/rest-grpc-multiplex/Cargo.toml b/examples/rest-grpc-multiplex/Cargo.toml index 11a6a3a2b4..69ece3632e 100644 --- a/examples/rest-grpc-multiplex/Cargo.toml +++ b/examples/rest-grpc-multiplex/Cargo.toml @@ -8,13 +8,13 @@ publish = false axum = { path = "../../axum" } futures = "0.3" hyper = { version = "1.0.0", features = ["full"] } -prost = "0.11" -tokio = { version = "1", features = ["full"] } -tonic = { version = "0.9" } -tonic-reflection = "0.9" -tower = { version = "0.4", features = ["full"] } -tracing = "0.1" -tracing-subscriber = { version = "0.3", features = ["env-filter"] } +#prost = "0.11" +#tokio = { version = "1", features = ["full"] } +#tonic = { version = "0.9" } +#tonic-reflection = "0.9" +tower = { version = "0.5.1", features = ["full"] } +#tracing = "0.1" +#tracing-subscriber = { version = "0.3", features = ["env-filter"] } [build-dependencies] tonic-build = { version = "0.9", features = ["prost"] } diff --git a/examples/rest-grpc-multiplex/src/multiplex_service.rs b/examples/rest-grpc-multiplex/src/multiplex_service.rs index 80b612e12e..51550ec5ba 100644 --- a/examples/rest-grpc-multiplex/src/multiplex_service.rs +++ b/examples/rest-grpc-multiplex/src/multiplex_service.rs @@ -38,7 +38,7 @@ where Self { rest: self.rest.clone(), grpc: self.grpc.clone(), - // the cloned services probably wont be ready + // the cloned services probably won't be ready rest_ready: false, grpc_ready: false, } diff --git a/examples/serve-with-hyper/Cargo.toml b/examples/serve-with-hyper/Cargo.toml index 06f4607053..81553eb08b 100644 --- a/examples/serve-with-hyper/Cargo.toml +++ b/examples/serve-with-hyper/Cargo.toml @@ -9,4 +9,4 @@ axum = { path = "../../axum" } hyper = { version = "1.0", features = [] } hyper-util = { version = "0.1", features = ["tokio", "server-auto", "http1"] } tokio = { version = "1.0", features = ["full"] } -tower = { version = "0.4", features = ["util"] } +tower = { version = "0.5.1", features = ["util"] } diff --git a/examples/serve-with-hyper/src/main.rs b/examples/serve-with-hyper/src/main.rs index 0b9c93d295..8bad2acd49 100644 --- a/examples/serve-with-hyper/src/main.rs +++ b/examples/serve-with-hyper/src/main.rs @@ -6,7 +6,7 @@ //! //! This example shows how to run axum using hyper's low level API. //! -//! The [hyper-util] crate exists to provide high level utilities but its still in early stages of +//! The [hyper-util] crate exists to provide high level utilities but it's still in early stages of //! development. //! //! [hyper-util]: https://crates.io/crates/hyper-util @@ -43,7 +43,7 @@ async fn serve_plain() { // We don't need to call `poll_ready` because `Router` is always ready. let tower_service = app.clone(); - // Spawn a task to handle the connection. That way we can multiple connections + // Spawn a task to handle the connection. That way we can handle multiple connections // concurrently. tokio::spawn(async move { // Hyper has its own `AsyncRead` and `AsyncWrite` traits and doesn't use tokio. diff --git a/examples/simple-router-wasm/Cargo.toml b/examples/simple-router-wasm/Cargo.toml index c3041e4a58..d250cc2387 100644 --- a/examples/simple-router-wasm/Cargo.toml +++ b/examples/simple-router-wasm/Cargo.toml @@ -14,3 +14,6 @@ axum-extra = { path = "../../axum-extra", default-features = false } futures-executor = "0.3.21" http = "1.0.0" tower-service = "0.3.1" + +[package.metadata.cargo-machete] +ignored = ["axum-extra"] diff --git a/examples/sqlx-postgres/Cargo.toml b/examples/sqlx-postgres/Cargo.toml index 3bc40302ed..0a0c437630 100644 --- a/examples/sqlx-postgres/Cargo.toml +++ b/examples/sqlx-postgres/Cargo.toml @@ -10,4 +10,4 @@ tokio = { version = "1.0", features = ["full"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } -sqlx = { version = "0.7", features = ["runtime-tokio-rustls", "any", "postgres"] } +sqlx = { version = "0.8", features = ["runtime-tokio-rustls", "any", "postgres"] } diff --git a/examples/sqlx-postgres/src/main.rs b/examples/sqlx-postgres/src/main.rs index 465711157e..904a5a8aad 100644 --- a/examples/sqlx-postgres/src/main.rs +++ b/examples/sqlx-postgres/src/main.rs @@ -14,7 +14,6 @@ //! ``` use axum::{ - async_trait, extract::{FromRef, FromRequestParts, State}, http::{request::Parts, StatusCode}, routing::get, @@ -31,7 +30,7 @@ async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "example_tokio_postgres=debug".into()), + .unwrap_or_else(|_| format!("{}=debug", env!("CARGO_CRATE_NAME")).into()), ) .with(tracing_subscriber::fmt::layer()) .init(); @@ -75,7 +74,6 @@ async fn using_connection_pool_extractor( // which setup is appropriate depends on your application struct DatabaseConnection(sqlx::pool::PoolConnection); -#[async_trait] impl FromRequestParts for DatabaseConnection where PgPool: FromRef, diff --git a/examples/sse/Cargo.toml b/examples/sse/Cargo.toml index dec104bc36..138820db16 100644 --- a/examples/sse/Cargo.toml +++ b/examples/sse/Cargo.toml @@ -11,6 +11,11 @@ futures = "0.3" headers = "0.4" tokio = { version = "1.0", features = ["full"] } tokio-stream = "0.1" -tower-http = { version = "0.5.0", features = ["fs", "trace"] } +tower-http = { version = "0.6.1", features = ["fs", "trace"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } + +[dev-dependencies] +eventsource-stream = "0.2" +reqwest = { version = "0.12", features = ["stream"] } +reqwest-eventsource = "0.6" diff --git a/examples/sse/src/main.rs b/examples/sse/src/main.rs index 679412eab9..4f616f6b05 100644 --- a/examples/sse/src/main.rs +++ b/examples/sse/src/main.rs @@ -3,13 +3,17 @@ //! ```not_rust //! cargo run -p example-sse //! ``` +//! Test with +//! ```not_rust +//! cargo test -p example-sse +//! ``` use axum::{ response::sse::{Event, Sse}, routing::get, Router, }; -use axum_extra::{headers, TypedHeader}; +use axum_extra::TypedHeader; use futures::stream::{self, Stream}; use std::{convert::Infallible, path::PathBuf, time::Duration}; use tokio_stream::StreamExt as _; @@ -20,21 +24,15 @@ use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; async fn main() { tracing_subscriber::registry() .with( - tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "example_sse=debug,tower_http=debug".into()), + tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| { + format!("{}=debug,tower_http=debug", env!("CARGO_CRATE_NAME")).into() + }), ) .with(tracing_subscriber::fmt::layer()) .init(); - let assets_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("assets"); - - let static_files_service = ServeDir::new(assets_dir).append_index_html_on_directories(true); - - // build our application with a route - let app = Router::new() - .fallback_service(static_files_service) - .route("/sse", get(sse_handler)) - .layer(TraceLayer::new_for_http()); + // build our application + let app = app(); // run it let listener = tokio::net::TcpListener::bind("127.0.0.1:3000") @@ -44,6 +42,16 @@ async fn main() { axum::serve(listener, app).await.unwrap(); } +fn app() -> Router { + let assets_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("assets"); + let static_files_service = ServeDir::new(assets_dir).append_index_html_on_directories(true); + // build our application with a route + Router::new() + .fallback_service(static_files_service) + .route("/sse", get(sse_handler)) + .layer(TraceLayer::new_for_http()) +} + async fn sse_handler( TypedHeader(user_agent): TypedHeader, ) -> Sse>> { @@ -63,3 +71,58 @@ async fn sse_handler( .text("keep-alive-text"), ) } + +#[cfg(test)] +mod tests { + use eventsource_stream::Eventsource; + use tokio::net::TcpListener; + + use super::*; + + #[tokio::test] + async fn integration_test() { + // A helper function that spawns our application in the background + async fn spawn_app(host: impl Into) -> String { + let host = host.into(); + // Bind to localhost at the port 0, which will let the OS assign an available port to us + let listener = TcpListener::bind(format!("{}:0", host)).await.unwrap(); + // Retrieve the port assigned to us by the OS + let port = listener.local_addr().unwrap().port(); + tokio::spawn(async { + axum::serve(listener, app()).await.unwrap(); + }); + // Returns address (e.g. http://127.0.0.1{random_port}) + format!("http://{}:{}", host, port) + } + let listening_url = spawn_app("127.0.0.1").await; + + let mut event_stream = reqwest::Client::new() + .get(format!("{}/sse", listening_url)) + .header("User-Agent", "integration_test") + .send() + .await + .unwrap() + .bytes_stream() + .eventsource() + .take(1); + + let mut event_data: Vec = vec![]; + while let Some(event) = event_stream.next().await { + match event { + Ok(event) => { + // break the loop at the end of SSE stream + if event.data == "[DONE]" { + break; + } + + event_data.push(event.data); + } + Err(_) => { + panic!("Error in event stream"); + } + } + } + + assert!(event_data[0] == "hi!"); + } +} diff --git a/examples/static-file-server/Cargo.toml b/examples/static-file-server/Cargo.toml index 3f41d60816..ce1955432f 100644 --- a/examples/static-file-server/Cargo.toml +++ b/examples/static-file-server/Cargo.toml @@ -6,9 +6,8 @@ publish = false [dependencies] axum = { path = "../../axum" } -axum-extra = { path = "../../axum-extra" } tokio = { version = "1.0", features = ["full"] } -tower = { version = "0.4", features = ["util"] } -tower-http = { version = "0.5.0", features = ["fs", "trace"] } +tower = { version = "0.5.1", features = ["util"] } +tower-http = { version = "0.6.1", features = ["fs", "trace"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/examples/static-file-server/src/main.rs b/examples/static-file-server/src/main.rs index 707d2ee3f3..148af57c04 100644 --- a/examples/static-file-server/src/main.rs +++ b/examples/static-file-server/src/main.rs @@ -19,8 +19,9 @@ use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; async fn main() { tracing_subscriber::registry() .with( - tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "example_static_file_server=debug,tower_http=debug".into()), + tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| { + format!("{}=debug,tower_http=debug", env!("CARGO_CRATE_NAME")).into() + }), ) .with(tracing_subscriber::fmt::layer()) .init(); diff --git a/examples/stream-to-file/src/main.rs b/examples/stream-to-file/src/main.rs index a595d0d834..7c44286d87 100644 --- a/examples/stream-to-file/src/main.rs +++ b/examples/stream-to-file/src/main.rs @@ -25,7 +25,7 @@ async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "example_stream_to_file=debug".into()), + .unwrap_or_else(|_| format!("{}=debug", env!("CARGO_CRATE_NAME")).into()), ) .with(tracing_subscriber::fmt::layer()) .init(); @@ -37,7 +37,7 @@ async fn main() { let app = Router::new() .route("/", get(show_form).post(accept_form)) - .route("/file/:file_name", post(save_request_body)); + .route("/file/{file_name}", post(save_request_body)); let listener = tokio::net::TcpListener::bind("127.0.0.1:3000") .await diff --git a/examples/templates-minijinja/Cargo.toml b/examples/templates-minijinja/Cargo.toml new file mode 100644 index 0000000000..692ea0ca12 --- /dev/null +++ b/examples/templates-minijinja/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "example-templates-minijinja" +version = "0.1.0" +edition = "2021" +publish = false + +[dependencies] +axum = { path = "../../axum" } +minijinja = "2.3.1" +tokio = { version = "1.0", features = ["full"] } diff --git a/examples/templates-minijinja/src/main.rs b/examples/templates-minijinja/src/main.rs new file mode 100644 index 0000000000..2903b824bf --- /dev/null +++ b/examples/templates-minijinja/src/main.rs @@ -0,0 +1,87 @@ +//! Run with +//! +//! ```not_rust +//! cargo run -p example-templates-minijinja +//! ``` +//! Demo for the MiniJinja templating engine. +//! Exposes three pages all sharing the same layout with a minimal nav menu. + +use axum::extract::State; +use axum::http::StatusCode; +use axum::{response::Html, routing::get, Router}; +use minijinja::{context, Environment}; +use std::sync::Arc; + +struct AppState { + env: Environment<'static>, +} + +#[tokio::main] +async fn main() { + // init template engine and add templates + let mut env = Environment::new(); + env.add_template("layout", include_str!("../templates/layout.jinja")) + .unwrap(); + env.add_template("home", include_str!("../templates/home.jinja")) + .unwrap(); + env.add_template("content", include_str!("../templates/content.jinja")) + .unwrap(); + env.add_template("about", include_str!("../templates/about.jinja")) + .unwrap(); + + // pass env to handlers via state + let app_state = Arc::new(AppState { env }); + + // define routes + let app = Router::new() + .route("/", get(handler_home)) + .route("/content", get(handler_content)) + .route("/about", get(handler_about)) + .with_state(app_state); + + // run it + let listener = tokio::net::TcpListener::bind("127.0.0.1:3000") + .await + .unwrap(); + println!("listening on {}", listener.local_addr().unwrap()); + axum::serve(listener, app).await.unwrap(); +} + +async fn handler_home(State(state): State>) -> Result, StatusCode> { + let template = state.env.get_template("home").unwrap(); + + let rendered = template + .render(context! { + title => "Home", + welcome_text => "Hello World!", + }) + .unwrap(); + + Ok(Html(rendered)) +} + +async fn handler_content(State(state): State>) -> Result, StatusCode> { + let template = state.env.get_template("content").unwrap(); + + let some_example_entries = vec!["Data 1", "Data 2", "Data 3"]; + + let rendered = template + .render(context! { + title => "Content", + entries => some_example_entries, + }) + .unwrap(); + + Ok(Html(rendered)) +} + +async fn handler_about(State(state): State>) -> Result, StatusCode> { + let template = state.env.get_template("about").unwrap(); + + let rendered = template.render(context!{ + title => "About", + about_text => "Simple demonstration layout for an axum project with minijinja as templating engine.", + }).unwrap(); + + Ok(Html(rendered)) +} diff --git a/examples/templates-minijinja/templates/about.jinja b/examples/templates-minijinja/templates/about.jinja new file mode 100644 index 0000000000..ba8c97e3ce --- /dev/null +++ b/examples/templates-minijinja/templates/about.jinja @@ -0,0 +1,6 @@ +{% extends "layout" %} +{% block title %}{{ super() }} | {{ title }} {% endblock %} +{% block body %} +

{{ title }}

+

{{ about_text }}

+{% endblock %} diff --git a/examples/templates-minijinja/templates/content.jinja b/examples/templates-minijinja/templates/content.jinja new file mode 100644 index 0000000000..b3fbfa6c79 --- /dev/null +++ b/examples/templates-minijinja/templates/content.jinja @@ -0,0 +1,10 @@ +{% extends "layout" %} +{% block title %}{{ super() }} | {{ title }} {% endblock %} +{% block body %} +

{{ title }}

+{% for data_entry in entries %} +
    +
  • {{ data_entry }}
  • +
+{% endfor %} +{% endblock %} diff --git a/examples/templates-minijinja/templates/home.jinja b/examples/templates-minijinja/templates/home.jinja new file mode 100644 index 0000000000..2d231db34b --- /dev/null +++ b/examples/templates-minijinja/templates/home.jinja @@ -0,0 +1,6 @@ +{% extends "layout" %} +{% block title %}{{ super() }} | {{ title }} {% endblock %} +{% block body %} +

{{ title }}

+

{{ welcome_text }}

+{% endblock %} diff --git a/examples/templates-minijinja/templates/layout.jinja b/examples/templates-minijinja/templates/layout.jinja new file mode 100644 index 0000000000..9ef56205e3 --- /dev/null +++ b/examples/templates-minijinja/templates/layout.jinja @@ -0,0 +1,14 @@ + + + {% block title %}Website Name{% endblock %} + +
+ {% block body %}{% endblock %} + + diff --git a/examples/templates/Cargo.toml b/examples/templates/Cargo.toml index 2f5aba2791..6cba09469f 100644 --- a/examples/templates/Cargo.toml +++ b/examples/templates/Cargo.toml @@ -5,7 +5,7 @@ edition = "2021" publish = false [dependencies] -askama = "0.11" +askama = "0.12" axum = { path = "../../axum" } tokio = { version = "1.0", features = ["full"] } tracing = "0.1" diff --git a/examples/templates/src/main.rs b/examples/templates/src/main.rs index 3a6c82316c..872471c235 100644 --- a/examples/templates/src/main.rs +++ b/examples/templates/src/main.rs @@ -19,13 +19,13 @@ async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "example_templates=debug".into()), + .unwrap_or_else(|_| format!("{}=debug", env!("CARGO_CRATE_NAME")).into()), ) .with(tracing_subscriber::fmt::layer()) .init(); // build our application with some routes - let app = Router::new().route("/greet/:name", get(greet)); + let app = Router::new().route("/greet/{name}", get(greet)); // run it let listener = tokio::net::TcpListener::bind("127.0.0.1:3000") diff --git a/examples/testing-websockets/Cargo.toml b/examples/testing-websockets/Cargo.toml index 842624c9e5..31ed2601f0 100644 --- a/examples/testing-websockets/Cargo.toml +++ b/examples/testing-websockets/Cargo.toml @@ -7,6 +7,5 @@ publish = false [dependencies] axum = { path = "../../axum", features = ["ws"] } futures = "0.3" -hyper = { version = "1.0.0", features = ["full"] } tokio = { version = "1.0", features = ["full"] } -tokio-tungstenite = "0.21" +tokio-tungstenite = "0.24" diff --git a/examples/testing/Cargo.toml b/examples/testing/Cargo.toml index 00e8132f73..811e4f6056 100644 --- a/examples/testing/Cargo.toml +++ b/examples/testing/Cargo.toml @@ -7,14 +7,13 @@ publish = false [dependencies] axum = { path = "../../axum" } http-body-util = "0.1.0" -hyper = { version = "1.0.0", features = ["full"] } hyper-util = { version = "0.1", features = ["client", "http1", "client-legacy"] } mime = "0.3" serde_json = "1.0" tokio = { version = "1.0", features = ["full"] } -tower-http = { version = "0.5.0", features = ["trace"] } +tower-http = { version = "0.6.1", features = ["trace"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } [dev-dependencies] -tower = { version = "0.4", features = ["util"] } +tower = { version = "0.5.1", features = ["util"] } diff --git a/examples/testing/src/main.rs b/examples/testing/src/main.rs index 2879b69a74..c6dbcf5c16 100644 --- a/examples/testing/src/main.rs +++ b/examples/testing/src/main.rs @@ -18,8 +18,9 @@ use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; async fn main() { tracing_subscriber::registry() .with( - tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "example_testing=debug,tower_http=debug".into()), + tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| { + format!("{}=debug,tower_http=debug", env!("CARGO_CRATE_NAME")).into() + }), ) .with(tracing_subscriber::fmt::layer()) .init(); @@ -43,7 +44,7 @@ fn app() -> Router { }), ) .route( - "/requires-connect-into", + "/requires-connect-info", get(|ConnectInfo(addr): ConnectInfo| async move { format!("Hi {addr}") }), ) // We can still add middleware @@ -60,7 +61,6 @@ mod tests { }; use http_body_util::BodyExt; // for `collect` use serde_json::{json, Value}; - use std::net::SocketAddr; use tokio::net::TcpListener; use tower::{Service, ServiceExt}; // for `call`, `oneshot`, and `ready` @@ -179,7 +179,7 @@ mod tests { assert_eq!(response.status(), StatusCode::OK); } - // Here we're calling `/requires-connect-into` which requires `ConnectInfo` + // Here we're calling `/requires-connect-info` which requires `ConnectInfo` // // That is normally set with `Router::into_make_service_with_connect_info` but we can't easily // use that during tests. The solution is instead to set the `MockConnectInfo` layer during @@ -191,7 +191,7 @@ mod tests { .into_service(); let request = Request::builder() - .uri("/requires-connect-into") + .uri("/requires-connect-info") .body(Body::empty()) .unwrap(); let response = app.ready().await.unwrap().call(request).await.unwrap(); diff --git a/examples/tls-graceful-shutdown/Cargo.toml b/examples/tls-graceful-shutdown/Cargo.toml index 7b0169ba8f..517e5a38c3 100644 --- a/examples/tls-graceful-shutdown/Cargo.toml +++ b/examples/tls-graceful-shutdown/Cargo.toml @@ -6,8 +6,8 @@ publish = false [dependencies] axum = { path = "../../axum" } -axum-server = { version = "0.3", features = ["tls-rustls"] } -hyper = { version = "0.14", features = ["full"] } +axum-extra = { path = "../../axum-extra" } +axum-server = { version = "0.7", features = ["tls-rustls"] } tokio = { version = "1", features = ["full"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/examples/tls-graceful-shutdown/src/main.rs b/examples/tls-graceful-shutdown/src/main.rs index 13251846de..344ecf9d05 100644 --- a/examples/tls-graceful-shutdown/src/main.rs +++ b/examples/tls-graceful-shutdown/src/main.rs @@ -4,140 +4,138 @@ //! cargo run -p example-tls-graceful-shutdown //! ``` -fn main() { - // This example has not yet been updated to Hyper 1.0 +use axum::{ + handler::HandlerWithoutStateExt, + http::{StatusCode, Uri}, + response::Redirect, + routing::get, + BoxError, Router, +}; +use axum_extra::extract::Host; +use axum_server::tls_rustls::RustlsConfig; +use std::{future::Future, net::SocketAddr, path::PathBuf, time::Duration}; +use tokio::signal; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; + +#[derive(Clone, Copy)] +struct Ports { + http: u16, + https: u16, } -//use axum::{ -// extract::Host, -// handler::HandlerWithoutStateExt, -// http::{StatusCode, Uri}, -// response::Redirect, -// routing::get, -// BoxError, Router, -//}; -//use axum_server::tls_rustls::RustlsConfig; -//use std::{future::Future, net::SocketAddr, path::PathBuf, time::Duration}; -//use tokio::signal; -//use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; - -//#[derive(Clone, Copy)] -//struct Ports { -// http: u16, -// https: u16, -//} - -//#[tokio::main] -//async fn main() { -// tracing_subscriber::registry() -// .with( -// tracing_subscriber::EnvFilter::try_from_default_env() -// .unwrap_or_else(|_| "example_tls_graceful_shutdown=debug".into()), -// ) -// .with(tracing_subscriber::fmt::layer()) -// .init(); - -// let ports = Ports { -// http: 7878, -// https: 3000, -// }; - -// //Create a handle for our TLS server so the shutdown signal can all shutdown -// let handle = axum_server::Handle::new(); -// //save the future for easy shutting down of redirect server -// let shutdown_future = shutdown_signal(handle.clone()); - -// // optional: spawn a second server to redirect http requests to this server -// tokio::spawn(redirect_http_to_https(ports, shutdown_future)); - -// // configure certificate and private key used by https -// let config = RustlsConfig::from_pem_file( -// PathBuf::from(env!("CARGO_MANIFEST_DIR")) -// .join("self_signed_certs") -// .join("cert.pem"), -// PathBuf::from(env!("CARGO_MANIFEST_DIR")) -// .join("self_signed_certs") -// .join("key.pem"), -// ) -// .await -// .unwrap(); - -// let app = Router::new().route("/", get(handler)); - -// // run https server -// let addr = SocketAddr::from(([127, 0, 0, 1], ports.https)); -// tracing::debug!("listening on {addr}"); -// axum_server::bind_rustls(addr, config) -// .handle(handle) -// .serve(app.into_make_service()) -// .await -// .unwrap(); -//} - -//async fn shutdown_signal(handle: axum_server::Handle) { -// let ctrl_c = async { -// signal::ctrl_c() -// .await -// .expect("failed to install Ctrl+C handler"); -// }; - -// #[cfg(unix)] -// let terminate = async { -// signal::unix::signal(signal::unix::SignalKind::terminate()) -// .expect("failed to install signal handler") -// .recv() -// .await; -// }; - -// #[cfg(not(unix))] -// let terminate = std::future::pending::<()>(); - -// tokio::select! { -// _ = ctrl_c => {}, -// _ = terminate => {}, -// } - -// tracing::info!("Received termination signal shutting down"); -// handle.graceful_shutdown(Some(Duration::from_secs(10))); // 10 secs is how long docker will wait -// // to force shutdown -//} - -//async fn handler() -> &'static str { -// "Hello, World!" -//} - -//async fn redirect_http_to_https(ports: Ports, signal: impl Future) { -// fn make_https(host: String, uri: Uri, ports: Ports) -> Result { -// let mut parts = uri.into_parts(); - -// parts.scheme = Some(axum::http::uri::Scheme::HTTPS); - -// if parts.path_and_query.is_none() { -// parts.path_and_query = Some("/".parse().unwrap()); -// } - -// let https_host = host.replace(&ports.http.to_string(), &ports.https.to_string()); -// parts.authority = Some(https_host.parse()?); - -// Ok(Uri::from_parts(parts)?) -// } - -// let redirect = move |Host(host): Host, uri: Uri| async move { -// match make_https(host, uri, ports) { -// Ok(uri) => Ok(Redirect::permanent(&uri.to_string())), -// Err(error) => { -// tracing::warn!(%error, "failed to convert URI to HTTPS"); -// Err(StatusCode::BAD_REQUEST) -// } -// } -// }; - -// let addr = SocketAddr::from(([127, 0, 0, 1], ports.http)); -// //let listener = tokio::net::TcpListener::bind(addr).await.unwrap(); -// tracing::debug!("listening on {addr}"); -// hyper::Server::bind(&addr) -// .serve(redirect.into_make_service()) -// .with_graceful_shutdown(signal) -// .await -// .unwrap(); -//} +#[tokio::main] +async fn main() { + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| format!("{}=debug", env!("CARGO_CRATE_NAME")).into()), + ) + .with(tracing_subscriber::fmt::layer()) + .init(); + + let ports = Ports { + http: 7878, + https: 3000, + }; + + //Create a handle for our TLS server so the shutdown signal can all shutdown + let handle = axum_server::Handle::new(); + //save the future for easy shutting down of redirect server + let shutdown_future = shutdown_signal(handle.clone()); + + // optional: spawn a second server to redirect http requests to this server + tokio::spawn(redirect_http_to_https(ports, shutdown_future)); + + // configure certificate and private key used by https + let config = RustlsConfig::from_pem_file( + PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join("self_signed_certs") + .join("cert.pem"), + PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join("self_signed_certs") + .join("key.pem"), + ) + .await + .unwrap(); + + let app = Router::new().route("/", get(handler)); + + // run https server + let addr = SocketAddr::from(([127, 0, 0, 1], ports.https)); + tracing::debug!("listening on {addr}"); + axum_server::bind_rustls(addr, config) + .handle(handle) + .serve(app.into_make_service()) + .await + .unwrap(); +} + +async fn shutdown_signal(handle: axum_server::Handle) { + let ctrl_c = async { + signal::ctrl_c() + .await + .expect("failed to install Ctrl+C handler"); + }; + + #[cfg(unix)] + let terminate = async { + signal::unix::signal(signal::unix::SignalKind::terminate()) + .expect("failed to install signal handler") + .recv() + .await; + }; + + #[cfg(not(unix))] + let terminate = std::future::pending::<()>(); + + tokio::select! { + _ = ctrl_c => {}, + _ = terminate => {}, + } + + tracing::info!("Received termination signal shutting down"); + handle.graceful_shutdown(Some(Duration::from_secs(10))); // 10 secs is how long docker will wait + // to force shutdown +} + +async fn handler() -> &'static str { + "Hello, World!" +} + +async fn redirect_http_to_https(ports: Ports, signal: F) +where + F: Future + Send + 'static, +{ + fn make_https(host: String, uri: Uri, ports: Ports) -> Result { + let mut parts = uri.into_parts(); + + parts.scheme = Some(axum::http::uri::Scheme::HTTPS); + + if parts.path_and_query.is_none() { + parts.path_and_query = Some("/".parse().unwrap()); + } + + let https_host = host.replace(&ports.http.to_string(), &ports.https.to_string()); + parts.authority = Some(https_host.parse()?); + + Ok(Uri::from_parts(parts)?) + } + + let redirect = move |Host(host): Host, uri: Uri| async move { + match make_https(host, uri, ports) { + Ok(uri) => Ok(Redirect::permanent(&uri.to_string())), + Err(error) => { + tracing::warn!(%error, "failed to convert URI to HTTPS"); + Err(StatusCode::BAD_REQUEST) + } + } + }; + + let addr = SocketAddr::from(([127, 0, 0, 1], ports.http)); + let listener = tokio::net::TcpListener::bind(addr).await.unwrap(); + tracing::debug!("listening on {addr}"); + axum::serve(listener, redirect.into_make_service()) + .with_graceful_shutdown(signal) + .await + .unwrap(); +} diff --git a/examples/tls-rustls/Cargo.toml b/examples/tls-rustls/Cargo.toml index 4c255c2763..c5cd65ad41 100644 --- a/examples/tls-rustls/Cargo.toml +++ b/examples/tls-rustls/Cargo.toml @@ -6,7 +6,8 @@ publish = false [dependencies] axum = { path = "../../axum" } -axum-server = { version = "0.6", features = ["tls-rustls"] } +axum-extra = { path = "../../axum-extra" } +axum-server = { version = "0.7", features = ["tls-rustls"] } tokio = { version = "1", features = ["full"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/examples/tls-rustls/src/main.rs b/examples/tls-rustls/src/main.rs index 3649c75cb1..bd8b07fcfd 100644 --- a/examples/tls-rustls/src/main.rs +++ b/examples/tls-rustls/src/main.rs @@ -7,13 +7,13 @@ #![allow(unused_imports)] use axum::{ - extract::Host, handler::HandlerWithoutStateExt, http::{StatusCode, Uri}, response::Redirect, routing::get, BoxError, Router, }; +use axum_extra::extract::Host; use axum_server::tls_rustls::RustlsConfig; use std::{net::SocketAddr, path::PathBuf}; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; @@ -30,7 +30,7 @@ async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "example_tls_rustls=debug".into()), + .unwrap_or_else(|_| format!("{}=debug", env!("CARGO_CRATE_NAME")).into()), ) .with(tracing_subscriber::fmt::layer()) .init(); diff --git a/examples/todos/Cargo.toml b/examples/todos/Cargo.toml index dbd8b7125a..127fbee47c 100644 --- a/examples/todos/Cargo.toml +++ b/examples/todos/Cargo.toml @@ -8,8 +8,8 @@ publish = false axum = { path = "../../axum" } serde = { version = "1.0", features = ["derive"] } tokio = { version = "1.0", features = ["full"] } -tower = { version = "0.4", features = ["util", "timeout"] } -tower-http = { version = "0.5.0", features = ["add-extension", "trace"] } +tower = { version = "0.5.1", features = ["util", "timeout"] } +tower-http = { version = "0.6.1", features = ["add-extension", "trace"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } uuid = { version = "1.0", features = ["serde", "v4"] } diff --git a/examples/todos/src/main.rs b/examples/todos/src/main.rs index 2fdac41bd5..cef395b715 100644 --- a/examples/todos/src/main.rs +++ b/examples/todos/src/main.rs @@ -4,8 +4,8 @@ //! //! - `GET /todos`: return a JSON list of Todos. //! - `POST /todos`: create a new Todo. -//! - `PATCH /todos/:id`: update a specific Todo. -//! - `DELETE /todos/:id`: delete a specific Todo. +//! - `PATCH /todos/{id}`: update a specific Todo. +//! - `DELETE /todos/{id}`: delete a specific Todo. //! //! Run with //! @@ -36,8 +36,9 @@ use uuid::Uuid; async fn main() { tracing_subscriber::registry() .with( - tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "example_todos=debug,tower_http=debug".into()), + tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| { + format!("{}=debug,tower_http=debug", env!("CARGO_CRATE_NAME")).into() + }), ) .with(tracing_subscriber::fmt::layer()) .init(); @@ -47,7 +48,7 @@ async fn main() { // Compose the routes let app = Router::new() .route("/todos", get(todos_index).post(todos_create)) - .route("/todos/:id", patch(todos_update).delete(todos_delete)) + .route("/todos/{id}", patch(todos_update).delete(todos_delete)) // Add middleware to all routes .layer( ServiceBuilder::new() @@ -81,14 +82,9 @@ pub struct Pagination { pub limit: Option, } -async fn todos_index( - pagination: Option>, - State(db): State, -) -> impl IntoResponse { +async fn todos_index(pagination: Query, State(db): State) -> impl IntoResponse { let todos = db.read().unwrap(); - let Query(pagination) = pagination.unwrap_or_default(); - let todos = todos .values() .skip(pagination.offset.unwrap_or(0)) diff --git a/examples/tokio-postgres/Cargo.toml b/examples/tokio-postgres/Cargo.toml index 74806044ca..e14520d23e 100644 --- a/examples/tokio-postgres/Cargo.toml +++ b/examples/tokio-postgres/Cargo.toml @@ -6,8 +6,8 @@ publish = false [dependencies] axum = { path = "../../axum" } -bb8 = "0.7.1" -bb8-postgres = "0.7.0" +bb8 = "0.8.5" +bb8-postgres = "0.8.1" tokio = { version = "1.0", features = ["full"] } tokio-postgres = "0.7.2" tracing = "0.1" diff --git a/examples/tokio-postgres/src/main.rs b/examples/tokio-postgres/src/main.rs index effc032089..7df9917b92 100644 --- a/examples/tokio-postgres/src/main.rs +++ b/examples/tokio-postgres/src/main.rs @@ -5,7 +5,6 @@ //! ``` use axum::{ - async_trait, extract::{FromRef, FromRequestParts, State}, http::{request::Parts, StatusCode}, routing::get, @@ -21,7 +20,7 @@ async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "example_tokio_postgres=debug".into()), + .unwrap_or_else(|_| format!("{}=debug", env!("CARGO_CRATE_NAME")).into()), ) .with(tracing_subscriber::fmt::layer()) .init(); @@ -68,7 +67,6 @@ async fn using_connection_pool_extractor( // which setup is appropriate depends on your application struct DatabaseConnection(PooledConnection<'static, PostgresConnectionManager>); -#[async_trait] impl FromRequestParts for DatabaseConnection where ConnectionPool: FromRef, diff --git a/examples/tokio-redis/Cargo.toml b/examples/tokio-redis/Cargo.toml new file mode 100644 index 0000000000..86d8513b90 --- /dev/null +++ b/examples/tokio-redis/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "example-tokio-redis" +version = "0.1.0" +edition = "2021" +publish = false + +[dependencies] +axum = { path = "../../axum" } +bb8 = "0.8.5" +bb8-redis = "0.17.0" +redis = "0.27.2" +tokio = { version = "1.0", features = ["full"] } +tracing = "0.1" +tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/examples/tokio-redis/src/main.rs b/examples/tokio-redis/src/main.rs new file mode 100644 index 0000000000..105b1de46c --- /dev/null +++ b/examples/tokio-redis/src/main.rs @@ -0,0 +1,104 @@ +//! Run with +//! +//! ```not_rust +//! cargo run -p example-tokio-redis +//! ``` + +use axum::{ + extract::{FromRef, FromRequestParts, State}, + http::{request::Parts, StatusCode}, + routing::get, + Router, +}; +use bb8::{Pool, PooledConnection}; +use bb8_redis::RedisConnectionManager; +use redis::AsyncCommands; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; + +use bb8_redis::bb8; + +#[tokio::main] +async fn main() { + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| format!("{}=debug", env!("CARGO_CRATE_NAME")).into()), + ) + .with(tracing_subscriber::fmt::layer()) + .init(); + + tracing::debug!("connecting to redis"); + let manager = RedisConnectionManager::new("redis://localhost").unwrap(); + let pool = bb8::Pool::builder().build(manager).await.unwrap(); + + { + // ping the database before starting + let mut conn = pool.get().await.unwrap(); + conn.set::<&str, &str, ()>("foo", "bar").await.unwrap(); + let result: String = conn.get("foo").await.unwrap(); + assert_eq!(result, "bar"); + } + tracing::debug!("successfully connected to redis and pinged it"); + + // build our application with some routes + let app = Router::new() + .route( + "/", + get(using_connection_pool_extractor).post(using_connection_extractor), + ) + .with_state(pool); + + // run it + let listener = tokio::net::TcpListener::bind("127.0.0.1:3000") + .await + .unwrap(); + tracing::debug!("listening on {}", listener.local_addr().unwrap()); + axum::serve(listener, app).await.unwrap(); +} + +type ConnectionPool = Pool; + +async fn using_connection_pool_extractor( + State(pool): State, +) -> Result { + let mut conn = pool.get().await.map_err(internal_error)?; + let result: String = conn.get("foo").await.map_err(internal_error)?; + Ok(result) +} + +// we can also write a custom extractor that grabs a connection from the pool +// which setup is appropriate depends on your application +struct DatabaseConnection(PooledConnection<'static, RedisConnectionManager>); + +impl FromRequestParts for DatabaseConnection +where + ConnectionPool: FromRef, + S: Send + Sync, +{ + type Rejection = (StatusCode, String); + + async fn from_request_parts(_parts: &mut Parts, state: &S) -> Result { + let pool = ConnectionPool::from_ref(state); + + let conn = pool.get_owned().await.map_err(internal_error)?; + + Ok(Self(conn)) + } +} + +async fn using_connection_extractor( + DatabaseConnection(mut conn): DatabaseConnection, +) -> Result { + let result: String = conn.get("foo").await.map_err(internal_error)?; + + Ok(result) +} + +/// Utility function for mapping any error into a `500 Internal Server Error` +/// response. +fn internal_error(err: E) -> (StatusCode, String) +where + E: std::error::Error, +{ + (StatusCode::INTERNAL_SERVER_ERROR, err.to_string()) +} diff --git a/examples/tracing-aka-logging/Cargo.toml b/examples/tracing-aka-logging/Cargo.toml index 4004cd596b..3d1204723d 100644 --- a/examples/tracing-aka-logging/Cargo.toml +++ b/examples/tracing-aka-logging/Cargo.toml @@ -7,6 +7,6 @@ publish = false [dependencies] axum = { path = "../../axum", features = ["tracing"] } tokio = { version = "1.0", features = ["full"] } -tower-http = { version = "0.5.0", features = ["trace"] } +tower-http = { version = "0.6.1", features = ["trace"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/examples/tracing-aka-logging/src/main.rs b/examples/tracing-aka-logging/src/main.rs index 74a2055a07..30c16f1962 100644 --- a/examples/tracing-aka-logging/src/main.rs +++ b/examples/tracing-aka-logging/src/main.rs @@ -25,7 +25,11 @@ async fn main() { tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| { // axum logs rejections from built-in extractors with the `axum::rejection` // target, at `TRACE` level. `axum::rejection=trace` enables showing those events - "example_tracing_aka_logging=debug,tower_http=debug,axum::rejection=trace".into() + format!( + "{}=debug,tower_http=debug,axum::rejection=trace", + env!("CARGO_CRATE_NAME") + ) + .into() }), ) .with(tracing_subscriber::fmt::layer()) diff --git a/examples/unix-domain-socket/Cargo.toml b/examples/unix-domain-socket/Cargo.toml index 7f157c7dcb..94ceb04080 100644 --- a/examples/unix-domain-socket/Cargo.toml +++ b/examples/unix-domain-socket/Cargo.toml @@ -10,6 +10,5 @@ http-body-util = "0.1" hyper = { version = "1.0.0", features = ["full"] } hyper-util = { version = "0.1", features = ["tokio", "server-auto", "http1"] } tokio = { version = "1.0", features = ["full"] } -tower = { version = "0.4", features = ["util"] } -tracing = "0.1" +tower = { version = "0.5.1", features = ["util"] } tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/examples/unix-domain-socket/src/main.rs b/examples/unix-domain-socket/src/main.rs index d11792dd70..697f31a557 100644 --- a/examples/unix-domain-socket/src/main.rs +++ b/examples/unix-domain-socket/src/main.rs @@ -3,7 +3,6 @@ //! ```not_rust //! cargo run -p example-unix-domain-socket //! ``` - #[cfg(unix)] #[tokio::main] async fn main() { diff --git a/examples/validator/Cargo.toml b/examples/validator/Cargo.toml index a1adc075a8..8a7e6928d8 100644 --- a/examples/validator/Cargo.toml +++ b/examples/validator/Cargo.toml @@ -5,12 +5,10 @@ publish = false version = "0.1.0" [dependencies] -async-trait = "0.1.67" axum = { path = "../../axum" } -http-body = "1.0.0" serde = { version = "1.0", features = ["derive"] } thiserror = "1.0.29" tokio = { version = "1.0", features = ["full"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } -validator = { version = "0.14.0", features = ["derive"] } +validator = { version = "0.18.1", features = ["derive"] } diff --git a/examples/validator/src/main.rs b/examples/validator/src/main.rs index 85c4ac1843..00e46173c4 100644 --- a/examples/validator/src/main.rs +++ b/examples/validator/src/main.rs @@ -10,7 +10,6 @@ //! ->

Hello, LT!

//! ``` -use async_trait::async_trait; use axum::{ extract::{rejection::FormRejection, Form, FromRequest, Request}, http::StatusCode, @@ -29,7 +28,7 @@ async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "example_validator=debug".into()), + .unwrap_or_else(|_| format!("{}=debug", env!("CARGO_CRATE_NAME")).into()), ) .with(tracing_subscriber::fmt::layer()) .init(); @@ -56,7 +55,6 @@ async fn handler(ValidatedForm(input): ValidatedForm) -> Html #[derive(Debug, Clone, Copy, Default)] pub struct ValidatedForm(pub T); -#[async_trait] impl FromRequest for ValidatedForm where T: DeserializeOwned + Validate, diff --git a/examples/versioning/src/main.rs b/examples/versioning/src/main.rs index a1d96e8340..7b3ca5a581 100644 --- a/examples/versioning/src/main.rs +++ b/examples/versioning/src/main.rs @@ -5,7 +5,6 @@ //! ``` use axum::{ - async_trait, extract::{FromRequestParts, Path}, http::{request::Parts, StatusCode}, response::{IntoResponse, Response}, @@ -20,13 +19,13 @@ async fn main() { tracing_subscriber::registry() .with( tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "example_versioning=debug".into()), + .unwrap_or_else(|_| format!("{}=debug", env!("CARGO_CRATE_NAME")).into()), ) .with(tracing_subscriber::fmt::layer()) .init(); // build our application with some routes - let app = Router::new().route("/:version/foo", get(handler)); + let app = Router::new().route("/{version}/foo", get(handler)); // run it let listener = tokio::net::TcpListener::bind("127.0.0.1:3000") @@ -47,7 +46,6 @@ enum Version { V3, } -#[async_trait] impl FromRequestParts for Version where S: Send + Sync, diff --git a/examples/websockets-http2/Cargo.toml b/examples/websockets-http2/Cargo.toml new file mode 100644 index 0000000000..19a8d0d781 --- /dev/null +++ b/examples/websockets-http2/Cargo.toml @@ -0,0 +1,13 @@ +[package] +name = "example-websockets-http2" +version = "0.1.0" +edition = "2021" +publish = false + +[dependencies] +axum = { path = "../../axum", features = ["ws", "http2"] } +axum-server = { version = "0.6", features = ["tls-rustls"] } +tokio = { version = "1", features = ["full"] } +tower-http = { version = "0.5.0", features = ["fs"] } +tracing = "0.1" +tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/examples/websockets-http2/assets/index.html b/examples/websockets-http2/assets/index.html new file mode 100644 index 0000000000..6a373782f4 --- /dev/null +++ b/examples/websockets-http2/assets/index.html @@ -0,0 +1,7 @@ +

Open this page in two windows and try sending some messages!

+
+ + +
+
+ diff --git a/examples/websockets-http2/assets/script.js b/examples/websockets-http2/assets/script.js new file mode 100644 index 0000000000..952c21dadc --- /dev/null +++ b/examples/websockets-http2/assets/script.js @@ -0,0 +1,11 @@ +const socket = new WebSocket('wss://localhost:3000/ws'); + +socket.addEventListener('message', e => { + document.getElementById("messages").append(e.data, document.createElement("br")); +}); + +const form = document.querySelector("form"); +form.addEventListener("submit", () => { + socket.send(form.elements.namedItem("content").value); + form.elements.namedItem("content").value = ""; +}); diff --git a/examples/websockets-http2/self_signed_certs/cert.pem b/examples/websockets-http2/self_signed_certs/cert.pem new file mode 100644 index 0000000000..656aa88055 --- /dev/null +++ b/examples/websockets-http2/self_signed_certs/cert.pem @@ -0,0 +1,22 @@ +-----BEGIN CERTIFICATE----- +MIIDkzCCAnugAwIBAgIUXVYkRCrM/ge03DVymDtXCuybp7gwDQYJKoZIhvcNAQEL +BQAwWTELMAkGA1UEBhMCVVMxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoM +GEludGVybmV0IFdpZGdpdHMgUHR5IEx0ZDESMBAGA1UEAwwJbG9jYWxob3N0MB4X +DTIxMDczMTE0MjIxMloXDTIyMDczMTE0MjIxMlowWTELMAkGA1UEBhMCVVMxEzAR +BgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoMGEludGVybmV0IFdpZGdpdHMgUHR5 +IEx0ZDESMBAGA1UEAwwJbG9jYWxob3N0MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A +MIIBCgKCAQEA02V5ZjmqLB/VQwTarrz/35qsa83L+DbAoa0001+jVmmC+G9Nufi0 +daroFWj/Uicv2fZWETU8JoZKUrX4BK9og5cg5rln/CtBRWCUYIwRgY9R/CdBGPn4 +kp+XkSJaCw74ZIyLy/Zfux6h8ES1m9YRnBza+s7U+ImRBRf4MRPtXQ3/mqJxAZYq +dOnKnvssRyD2qutgVTAxwMUvJWIivRhRYDj7WOpS4CEEeQxP1iH1/T5P7FdtTGdT +bVBABCA8JhL96uFGPpOYHcM/7R5EIA3yZ5FNg931QzoDITjtXGtQ6y9/l/IYkWm6 +J67RWcN0IoTsZhz0WNU4gAeslVtJLofn8QIDAQABo1MwUTAdBgNVHQ4EFgQUzFnK +NfS4LAYuKeWwHbzooER0yZ0wHwYDVR0jBBgwFoAUzFnKNfS4LAYuKeWwHbzooER0 +yZ0wDwYDVR0TAQH/BAUwAwEB/zANBgkqhkiG9w0BAQsFAAOCAQEAk4O+e9jia59W +ZwetN4GU7OWcYhmOgSizRSs6u7mTfp62LDMt96WKU3THksOnZ44HnqWQxsSfdFVU +XJD12tjvVU8Z4FWzQajcHeemUYiDze8EAh6TnxnUcOrU8IcwiKGxCWRY/908jnWg ++MMscfMCMYTRdeTPqD8fGzAlUCtmyzH6KLE3s4Oo/r5+NR+Uvrwpdvb7xe0MwwO9 +Q/zR4N8ep/HwHVEObcaBofE1ssZLksX7ZgCP9wMgXRWpNAtC5EWxMbxYjBfWFH24 +fDJlBMiGJWg8HHcxK7wQhFh+fuyNzE+xEWPsI9VL1zDftd9x8/QsOagyEOnY8Vxr +AopvZ09uEQ== +-----END CERTIFICATE----- diff --git a/examples/websockets-http2/self_signed_certs/key.pem b/examples/websockets-http2/self_signed_certs/key.pem new file mode 100644 index 0000000000..3de14eb32f --- /dev/null +++ b/examples/websockets-http2/self_signed_certs/key.pem @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQDTZXlmOaosH9VD +BNquvP/fmqxrzcv4NsChrTTTX6NWaYL4b025+LR1qugVaP9SJy/Z9lYRNTwmhkpS +tfgEr2iDlyDmuWf8K0FFYJRgjBGBj1H8J0EY+fiSn5eRIloLDvhkjIvL9l+7HqHw +RLWb1hGcHNr6ztT4iZEFF/gxE+1dDf+aonEBlip06cqe+yxHIPaq62BVMDHAxS8l +YiK9GFFgOPtY6lLgIQR5DE/WIfX9Pk/sV21MZ1NtUEAEIDwmEv3q4UY+k5gdwz/t +HkQgDfJnkU2D3fVDOgMhOO1ca1DrL3+X8hiRabonrtFZw3QihOxmHPRY1TiAB6yV +W0kuh+fxAgMBAAECggEADltu8k1qTFLhJgsXWxTFAAe+PBgfCT2WuaRM2So+qqjB +12Of0MieYPt5hbK63HaC3nfHgqWt7yPhulpXfOH45C8IcgMXl93MMg0MJr58leMI ++2ojFrIrerHSFm5R1TxwDEwrVm/mMowzDWFtQCc6zPJ8wNn5RuP48HKfTZ3/2fjw +zEjSwPO2wFMfo1EJNTjlI303lFbdFBs67NaX6puh30M7Tn+gznHKyO5a7F57wkIt +fkgnEy/sgMedQlwX7bRpUoD6f0fZzV8Qz4cHFywtYErczZJh3VGitJoO/VCIDdty +RPXOAqVDd7EpP1UUehZlKVWZ0OZMEfRgKbRCel5abQKBgQDwgwrIQ5+BiZv6a0VT +ETeXB+hRbvBinRykNo/RvLc3j1enRh9/zO/ShadZIXgOAiM1Jnr5Gp8KkNGca6K1 +myhtad7xYPODYzNXXp6T1OPgZxHZLIYzVUj6ypXeV64Te5ZiDaJ1D49czsq+PqsQ +XRcgBJSNpFtDFiXWpjXWfx8PxwKBgQDhAnLY5Sl2eeQo+ud0MvjwftB/mN2qCzJY +5AlQpRI4ThWxJgGPuHTR29zVa5iWNYuA5LWrC1y/wx+t5HKUwq+5kxvs+npYpDJD +ZX/w0Glc6s0Jc/mFySkbw9B2LePedL7lRF5OiAyC6D106Sc9V2jlL4IflmOzt4CD +ZTNbLtC6hwKBgHfIzBXxl/9sCcMuqdg1Ovp9dbcZCaATn7ApfHd5BccmHQGyav27 +k7XF2xMJGEHhzqcqAxUNrSgV+E9vTBomrHvRvrd5Ec7eGTPqbBA0d0nMC5eeFTh7 +wV0miH20LX6Gjt9G6yJiHYSbeV5G1+vOcTYBEft5X/qJjU7aePXbWh0BAoGBAJlV +5tgCCuhvFloK6fHYzqZtdT6O+PfpW20SMXrgkvMF22h2YvgDFrDwqKRUB47NfHzg +3yBpxNH1ccA5/w97QO8w3gX3h6qicpJVOAPusu6cIBACFZfjRv1hyszOZwvw+Soa +Fj5kHkqTY1YpkREPYS9V2dIW1Wjic1SXgZDw7VM/AoGAP/cZ3ZHTSCDTFlItqy5C +rIy2AiY0WJsx+K0qcvtosPOOwtnGjWHb1gdaVdfX/IRkSsX4PAOdnsyidNC5/l/m +y8oa+5WEeGFclWFhr4dnTA766o8HrM2UjIgWWYBF2VKdptGnHxFeJWFUmeQC/xeW +w37pCS7ykL+7gp7V0WShYsw= +-----END PRIVATE KEY----- diff --git a/examples/websockets-http2/src/main.rs b/examples/websockets-http2/src/main.rs new file mode 100644 index 0000000000..dbc682c4d9 --- /dev/null +++ b/examples/websockets-http2/src/main.rs @@ -0,0 +1,97 @@ +//! Run with +//! +//! ```not_rust +//! cargo run -p example-websockets-http2 +//! ``` + +use axum::{ + extract::{ + ws::{self, WebSocketUpgrade}, + State, + }, + http::Version, + routing::any, + Router, +}; +use axum_server::tls_rustls::RustlsConfig; +use std::{net::SocketAddr, path::PathBuf}; +use tokio::sync::broadcast; +use tower_http::services::ServeDir; +use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; + +#[tokio::main] +async fn main() { + tracing_subscriber::registry() + .with( + tracing_subscriber::EnvFilter::try_from_default_env() + .unwrap_or_else(|_| format!("{}=debug", env!("CARGO_CRATE_NAME")).into()), + ) + .with(tracing_subscriber::fmt::layer()) + .init(); + + let assets_dir = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("assets"); + + // configure certificate and private key used by https + let config = RustlsConfig::from_pem_file( + PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join("self_signed_certs") + .join("cert.pem"), + PathBuf::from(env!("CARGO_MANIFEST_DIR")) + .join("self_signed_certs") + .join("key.pem"), + ) + .await + .unwrap(); + + // build our application with some routes and a broadcast channel + let app = Router::new() + .fallback_service(ServeDir::new(assets_dir).append_index_html_on_directories(true)) + .route("/ws", any(ws_handler)) + .with_state(broadcast::channel::(16).0); + + let addr = SocketAddr::from(([127, 0, 0, 1], 3000)); + tracing::debug!("listening on {}", addr); + + let mut server = axum_server::bind_rustls(addr, config); + + // IMPORTANT: This is required to advertise our support for HTTP/2 websockets to the client. + // If you use axum::serve, it is enabled by default. + server.http_builder().http2().enable_connect_protocol(); + + server.serve(app.into_make_service()).await.unwrap(); +} + +async fn ws_handler( + ws: WebSocketUpgrade, + version: Version, + State(sender): State>, +) -> axum::response::Response { + tracing::debug!("accepted a WebSocket using {version:?}"); + let mut receiver = sender.subscribe(); + ws.on_upgrade(|mut ws| async move { + loop { + tokio::select! { + // Since `ws` is a `Stream`, it is by nature cancel-safe. + res = ws.recv() => { + match res { + Some(Ok(ws::Message::Text(s))) => { + let _ = sender.send(s); + } + Some(Ok(_)) => {} + Some(Err(e)) => tracing::debug!("client disconnected abruptly: {e}"), + None => break, + } + } + // Tokio guarantees that `broadcast::Receiver::recv` is cancel-safe. + res = receiver.recv() => { + match res { + Ok(msg) => if let Err(e) = ws.send(ws::Message::Text(msg)).await { + tracing::debug!("client disconnected abruptly: {e}"); + } + Err(_) => continue, + } + } + } + } + }) +} diff --git a/examples/websockets/Cargo.toml b/examples/websockets/Cargo.toml index f62a8b03ee..541d82805a 100644 --- a/examples/websockets/Cargo.toml +++ b/examples/websockets/Cargo.toml @@ -11,9 +11,8 @@ futures = "0.3" futures-util = { version = "0.3", default-features = false, features = ["sink", "std"] } headers = "0.4" tokio = { version = "1.0", features = ["full"] } -tokio-tungstenite = "0.21" -tower = { version = "0.4", features = ["util"] } -tower-http = { version = "0.5.0", features = ["fs", "trace"] } +tokio-tungstenite = "0.24.0" +tower-http = { version = "0.6.1", features = ["fs", "trace"] } tracing = "0.1" tracing-subscriber = { version = "0.3", features = ["env-filter"] } diff --git a/examples/websockets/src/main.rs b/examples/websockets/src/main.rs index 62b00d34d2..7c4a9801af 100644 --- a/examples/websockets/src/main.rs +++ b/examples/websockets/src/main.rs @@ -19,7 +19,7 @@ use axum::{ extract::ws::{Message, WebSocket, WebSocketUpgrade}, response::IntoResponse, - routing::get, + routing::any, Router, }; use axum_extra::TypedHeader; @@ -45,8 +45,9 @@ use futures::{sink::SinkExt, stream::StreamExt}; async fn main() { tracing_subscriber::registry() .with( - tracing_subscriber::EnvFilter::try_from_default_env() - .unwrap_or_else(|_| "example_websockets=debug,tower_http=debug".into()), + tracing_subscriber::EnvFilter::try_from_default_env().unwrap_or_else(|_| { + format!("{}=debug,tower_http=debug", env!("CARGO_CRATE_NAME")).into() + }), ) .with(tracing_subscriber::fmt::layer()) .init(); @@ -56,7 +57,7 @@ async fn main() { // build our application with some routes let app = Router::new() .fallback_service(ServeDir::new(assets_dir).append_index_html_on_directories(true)) - .route("/ws", get(ws_handler)) + .route("/ws", any(ws_handler)) // logging so we can see whats going on .layer( TraceLayer::new_for_http() @@ -76,7 +77,7 @@ async fn main() { .unwrap(); } -/// The handler for the HTTP request (this gets called when the HTTP GET lands at the start +/// The handler for the HTTP request (this gets called when the HTTP request lands at the start /// of websocket negotiation). After this completes, the actual switching from HTTP to /// websocket protocol will occur. /// This is the last point where we can extract TCP/IP metadata such as IP address of the client @@ -99,7 +100,7 @@ async fn ws_handler( /// Actual websocket statemachine (one will be spawned per connection) async fn handle_socket(mut socket: WebSocket, who: SocketAddr) { - //send a ping (unsupported by some browsers) just to kick things off and get a response + // send a ping (unsupported by some browsers) just to kick things off and get a response if socket.send(Message::Ping(vec![1, 2, 3])).await.is_ok() { println!("Pinged {who}..."); } else {