From 6e66806835942472b01087e9d2055e458c70c950 Mon Sep 17 00:00:00 2001 From: Sergio Benitez Date: Wed, 13 Mar 2024 00:48:32 -0700 Subject: [PATCH] wip: http/3 support --- core/lib/Cargo.toml | 14 +- core/lib/src/config/config.rs | 8 +- core/lib/src/config/mod.rs | 425 +----------------- core/lib/src/config/tests.rs | 394 ++++++++++++++++ core/lib/src/data/data_stream.rs | 23 +- core/lib/src/data/mod.rs | 2 + core/lib/src/erased.rs | 27 +- core/lib/src/error.rs | 40 +- core/lib/src/fairing/fairings.rs | 7 + core/lib/src/fairing/mod.rs | 6 +- core/lib/src/lib.rs | 2 +- core/lib/src/listener/bounced.rs | 4 +- core/lib/src/listener/cancellable.rs | 222 +++------ core/lib/src/listener/connection.rs | 19 +- core/lib/src/listener/endpoint.rs | 2 + core/lib/src/listener/listener.rs | 13 +- core/lib/src/listener/mod.rs | 2 + core/lib/src/listener/quic.rs | 237 ++++++++++ core/lib/src/listener/tcp.rs | 4 +- core/lib/src/listener/tls.rs | 36 +- core/lib/src/listener/unix.rs | 4 +- core/lib/src/phase.rs | 7 +- core/lib/src/request/request.rs | 4 +- core/lib/src/rocket.rs | 108 ++++- core/lib/src/router/collider.rs | 4 +- core/lib/src/router/router.rs | 2 +- core/lib/src/server.rs | 163 +++---- core/lib/src/shield/shield.rs | 39 +- .../shutdown.rs => shutdown/config.rs} | 80 +--- .../src/{shutdown.rs => shutdown/handle.rs} | 92 +++- core/lib/src/shutdown/mod.rs | 13 + core/lib/src/shutdown/sig.rs | 58 +++ core/lib/src/{util => shutdown}/tripwire.rs | 46 +- core/lib/src/tls/config.rs | 2 +- core/lib/src/util/mod.rs | 46 +- docs/guide/10-configuration.md | 40 +- examples/tls/Cargo.toml | 2 +- examples/tls/Rocket.toml | 3 + examples/tls/src/main.rs | 2 +- 39 files changed, 1300 insertions(+), 902 deletions(-) create mode 100644 core/lib/src/config/tests.rs create mode 100644 core/lib/src/listener/quic.rs rename core/lib/src/{config/shutdown.rs => shutdown/config.rs} (84%) rename core/lib/src/{shutdown.rs => shutdown/handle.rs} (57%) create mode 100644 core/lib/src/shutdown/mod.rs create mode 100644 core/lib/src/shutdown/sig.rs rename core/lib/src/{util => shutdown}/tripwire.rs (77%) diff --git a/core/lib/Cargo.toml b/core/lib/Cargo.toml index cc38bb1b60..c09d416107 100644 --- a/core/lib/Cargo.toml +++ b/core/lib/Cargo.toml @@ -22,6 +22,7 @@ all-features = true [features] default = ["http2", "tokio-macros"] http2 = ["hyper/http2", "hyper-util/http2"] +http3 = ["s2n-quic", "s2n-quic-h3", "tls"] secrets = ["cookie/private", "cookie/key-expansion"] json = ["serde_json"] msgpack = ["rmp-serde"] @@ -76,8 +77,7 @@ futures = { version = "0.3.30", default-features = false, features = ["std"] } state = "0.6" [dependencies.hyper-util] -git = "https://github.com/SergioBenitez/hyper-util.git" -branch = "fix-readversion" +version = "0.1.3" default-features = false features = ["http1", "server", "tokio"] @@ -99,6 +99,16 @@ version = "0.6.0-dev" path = "../http" features = ["serde"] +[dependencies.s2n-quic] +version = "1.32" +default-features = false +features = ["provider-address-token-default", "provider-tls-rustls"] +optional = true + +[dependencies.s2n-quic-h3] +git = "https://github.com/SergioBenitez/s2n-quic-h3.git" +optional = true + [target.'cfg(unix)'.dependencies] libc = "0.2.149" diff --git a/core/lib/src/config/config.rs b/core/lib/src/config/config.rs index e208944c76..bf2e6ebc1c 100644 --- a/core/lib/src/config/config.rs +++ b/core/lib/src/config/config.rs @@ -5,7 +5,7 @@ use serde::{Deserialize, Serialize}; use yansi::{Paint, Style, Color::Primary}; use crate::log::PaintExt; -use crate::config::{LogLevel, Shutdown, Ident, CliColors}; +use crate::config::{LogLevel, ShutdownConfig, Ident, CliColors}; use crate::request::{self, Request, FromRequest}; use crate::http::uncased::Uncased; use crate::data::Limits; @@ -120,8 +120,8 @@ pub struct Config { #[cfg_attr(nightly, doc(cfg(feature = "secrets")))] #[serde(serialize_with = "SecretKey::serialize_zero")] pub secret_key: SecretKey, - /// Graceful shutdown configuration. **(default: [`Shutdown::default()`])** - pub shutdown: Shutdown, + /// Graceful shutdown configuration. **(default: [`ShutdownConfig::default()`])** + pub shutdown: ShutdownConfig, /// Max level to log. **(default: _debug_ `normal` / _release_ `critical`)** pub log_level: LogLevel, /// Whether to use colors and emoji when logging. **(default: @@ -200,7 +200,7 @@ impl Config { keep_alive: 5, #[cfg(feature = "secrets")] secret_key: SecretKey::zero(), - shutdown: Shutdown::default(), + shutdown: ShutdownConfig::default(), log_level: LogLevel::Normal, cli_colors: CliColors::Auto, __non_exhaustive: (), diff --git a/core/lib/src/config/mod.rs b/core/lib/src/config/mod.rs index 86481af1fe..9f07e9192c 100644 --- a/core/lib/src/config/mod.rs +++ b/core/lib/src/config/mod.rs @@ -113,422 +113,35 @@ #[macro_use] mod ident; mod config; -mod shutdown; mod cli_colors; mod http_header; +#[cfg(test)] +mod tests; -#[cfg(feature = "secrets")] -mod secret_key; - -#[doc(hidden)] -pub use config::{pretty_print_error, bail_with_config_error}; - -pub use config::Config; -pub use crate::log::LogLevel; -pub use shutdown::Shutdown; pub use ident::Ident; +pub use config::Config; pub use cli_colors::CliColors; -#[cfg(feature = "secrets")] -pub use secret_key::SecretKey; - -#[cfg(unix)] -pub use shutdown::Sig; - -#[cfg(test)] -mod tests { - use figment::{Figment, Profile}; - use pretty_assertions::assert_eq; - - use crate::log::LogLevel; - use crate::data::{Limits, ToByteUnit}; - use crate::config::{Config, CliColors}; - - #[test] - fn test_figment_is_default() { - figment::Jail::expect_with(|_| { - let mut default: Config = Config::figment().extract().unwrap(); - default.profile = Config::default().profile; - assert_eq!(default, Config::default()); - Ok(()) - }); - } - - #[test] - fn test_default_round_trip() { - figment::Jail::expect_with(|_| { - let original = Config::figment(); - let roundtrip = Figment::from(Config::from(&original)); - for figment in &[original, roundtrip] { - let config = Config::from(figment); - assert_eq!(config, Config::default()); - } - - Ok(()) - }); - } - - #[test] - fn test_profile_env() { - figment::Jail::expect_with(|jail| { - jail.set_env("ROCKET_PROFILE", "debug"); - let figment = Config::figment(); - assert_eq!(figment.profile(), "debug"); - - jail.set_env("ROCKET_PROFILE", "release"); - let figment = Config::figment(); - assert_eq!(figment.profile(), "release"); - - jail.set_env("ROCKET_PROFILE", "random"); - let figment = Config::figment(); - assert_eq!(figment.profile(), "random"); - - Ok(()) - }); - } - - #[test] - fn test_toml_file() { - figment::Jail::expect_with(|jail| { - jail.create_file("Rocket.toml", r#" - [default] - ident = "Something Cool" - workers = 20 - keep_alive = 10 - log_level = "off" - cli_colors = 0 - "#)?; - - let config = Config::from(Config::figment()); - assert_eq!(config, Config { - workers: 20, - ident: ident!("Something Cool"), - keep_alive: 10, - log_level: LogLevel::Off, - cli_colors: CliColors::Never, - ..Config::default() - }); - - jail.create_file("Rocket.toml", r#" - [global] - ident = "Something Else Cool" - workers = 20 - keep_alive = 10 - log_level = "off" - cli_colors = 0 - "#)?; - - let config = Config::from(Config::figment()); - assert_eq!(config, Config { - workers: 20, - ident: ident!("Something Else Cool"), - keep_alive: 10, - log_level: LogLevel::Off, - cli_colors: CliColors::Never, - ..Config::default() - }); - - jail.set_env("ROCKET_CONFIG", "Other.toml"); - jail.create_file("Other.toml", r#" - [default] - workers = 20 - keep_alive = 10 - log_level = "off" - cli_colors = 0 - "#)?; - - let config = Config::from(Config::figment()); - assert_eq!(config, Config { - workers: 20, - keep_alive: 10, - log_level: LogLevel::Off, - cli_colors: CliColors::Never, - ..Config::default() - }); - - Ok(()) - }); - } - - #[test] - fn test_cli_colors() { - figment::Jail::expect_with(|jail| { - jail.create_file("Rocket.toml", r#" - [default] - cli_colors = "never" - "#)?; - - let config = Config::from(Config::figment()); - assert_eq!(config.cli_colors, CliColors::Never); - - jail.create_file("Rocket.toml", r#" - [default] - cli_colors = "auto" - "#)?; - - let config = Config::from(Config::figment()); - assert_eq!(config.cli_colors, CliColors::Auto); - - jail.create_file("Rocket.toml", r#" - [default] - cli_colors = "always" - "#)?; - - let config = Config::from(Config::figment()); - assert_eq!(config.cli_colors, CliColors::Always); - - jail.create_file("Rocket.toml", r#" - [default] - cli_colors = true - "#)?; - - let config = Config::from(Config::figment()); - assert_eq!(config.cli_colors, CliColors::Auto); - - jail.create_file("Rocket.toml", r#" - [default] - cli_colors = false - "#)?; - - let config = Config::from(Config::figment()); - assert_eq!(config.cli_colors, CliColors::Never); - - jail.create_file("Rocket.toml", r#"[default]"#)?; - let config = Config::from(Config::figment()); - assert_eq!(config.cli_colors, CliColors::Auto); - - jail.create_file("Rocket.toml", r#" - [default] - cli_colors = 1 - "#)?; - - let config = Config::from(Config::figment()); - assert_eq!(config.cli_colors, CliColors::Auto); - - jail.create_file("Rocket.toml", r#" - [default] - cli_colors = 0 - "#)?; - - let config = Config::from(Config::figment()); - assert_eq!(config.cli_colors, CliColors::Never); - - jail.set_env("ROCKET_CLI_COLORS", 1); - let config = Config::from(Config::figment()); - assert_eq!(config.cli_colors, CliColors::Auto); - - jail.set_env("ROCKET_CLI_COLORS", 0); - let config = Config::from(Config::figment()); - assert_eq!(config.cli_colors, CliColors::Never); - - jail.set_env("ROCKET_CLI_COLORS", true); - let config = Config::from(Config::figment()); - assert_eq!(config.cli_colors, CliColors::Auto); - - jail.set_env("ROCKET_CLI_COLORS", false); - let config = Config::from(Config::figment()); - assert_eq!(config.cli_colors, CliColors::Never); - - jail.set_env("ROCKET_CLI_COLORS", "always"); - let config = Config::from(Config::figment()); - assert_eq!(config.cli_colors, CliColors::Always); - - jail.set_env("ROCKET_CLI_COLORS", "NEveR"); - let config = Config::from(Config::figment()); - assert_eq!(config.cli_colors, CliColors::Never); - - jail.set_env("ROCKET_CLI_COLORS", "auTO"); - let config = Config::from(Config::figment()); - assert_eq!(config.cli_colors, CliColors::Auto); - - Ok(()) - }) - } - - #[test] - fn test_profiles_merge() { - figment::Jail::expect_with(|jail| { - jail.create_file("Rocket.toml", r#" - [default.limits] - stream = "50kb" - - [global] - limits = { forms = "2kb" } - - [debug.limits] - file = "100kb" - "#)?; - - jail.set_env("ROCKET_PROFILE", "unknown"); - let config = Config::from(Config::figment()); - assert_eq!(config, Config { - profile: Profile::const_new("unknown"), - limits: Limits::default() - .limit("stream", 50.kilobytes()) - .limit("forms", 2.kilobytes()), - ..Config::default() - }); - - jail.set_env("ROCKET_PROFILE", "debug"); - let config = Config::from(Config::figment()); - assert_eq!(config, Config { - profile: Profile::const_new("debug"), - limits: Limits::default() - .limit("stream", 50.kilobytes()) - .limit("forms", 2.kilobytes()) - .limit("file", 100.kilobytes()), - ..Config::default() - }); - - Ok(()) - }); - } - - #[test] - fn test_env_vars_merge() { - use crate::config::{Ident, Shutdown}; - - figment::Jail::expect_with(|jail| { - jail.set_env("ROCKET_KEEP_ALIVE", 9999); - let config = Config::from(Config::figment()); - assert_eq!(config, Config { - keep_alive: 9999, - ..Config::default() - }); - - jail.set_env("ROCKET_SHUTDOWN", r#"{grace=7}"#); - let first_figment = Config::figment(); - jail.set_env("ROCKET_SHUTDOWN", r#"{mercy=10}"#); - let prev_figment = Config::figment().join(&first_figment); - let config = Config::from(&prev_figment); - assert_eq!(config, Config { - keep_alive: 9999, - shutdown: Shutdown { grace: 7, mercy: 10, ..Default::default() }, - ..Config::default() - }); - - jail.set_env("ROCKET_SHUTDOWN", r#"{mercy=20}"#); - let config = Config::from(Config::figment().join(&prev_figment)); - assert_eq!(config, Config { - keep_alive: 9999, - shutdown: Shutdown { grace: 7, mercy: 20, ..Default::default() }, - ..Config::default() - }); - - jail.set_env("ROCKET_LIMITS", r#"{stream=100kiB}"#); - let config = Config::from(Config::figment().join(&prev_figment)); - assert_eq!(config, Config { - keep_alive: 9999, - shutdown: Shutdown { grace: 7, mercy: 20, ..Default::default() }, - limits: Limits::default().limit("stream", 100.kibibytes()), - ..Config::default() - }); - - jail.set_env("ROCKET_IDENT", false); - let config = Config::from(Config::figment().join(&prev_figment)); - assert_eq!(config, Config { - keep_alive: 9999, - shutdown: Shutdown { grace: 7, mercy: 20, ..Default::default() }, - limits: Limits::default().limit("stream", 100.kibibytes()), - ident: Ident::none(), - ..Config::default() - }); - - Ok(()) - }); - } - - #[test] - fn test_precedence() { - figment::Jail::expect_with(|jail| { - jail.create_file("Rocket.toml", r#" - [global.limits] - forms = "1mib" - stream = "50kb" - file = "100kb" - "#)?; - - let config = Config::from(Config::figment()); - assert_eq!(config, Config { - limits: Limits::default() - .limit("forms", 1.mebibytes()) - .limit("stream", 50.kilobytes()) - .limit("file", 100.kilobytes()), - ..Config::default() - }); - - jail.set_env("ROCKET_LIMITS", r#"{stream=3MiB,capture=2MiB}"#); - let config = Config::from(Config::figment()); - assert_eq!(config, Config { - limits: Limits::default() - .limit("file", 100.kilobytes()) - .limit("forms", 1.mebibytes()) - .limit("stream", 3.mebibytes()) - .limit("capture", 2.mebibytes()), - ..Config::default() - }); - - jail.set_env("ROCKET_PROFILE", "foo"); - let val: Result = Config::figment().extract_inner("profile"); - assert!(val.is_err()); - - Ok(()) - }); - } +pub use crate::log::LogLevel; +pub use crate::shutdown::ShutdownConfig; - #[test] - #[cfg(feature = "secrets")] - #[should_panic] - fn test_err_on_non_debug_and_no_secret_key() { - figment::Jail::expect_with(|jail| { - jail.set_env("ROCKET_PROFILE", "release"); - let rocket = crate::custom(Config::figment()); - let _result = crate::local::blocking::Client::untracked(rocket); - Ok(()) - }); - } +#[cfg(feature = "tls")] +pub use crate::tls::TlsConfig; - #[test] - #[cfg(feature = "secrets")] - #[should_panic] - fn test_err_on_non_debug2_and_no_secret_key() { - figment::Jail::expect_with(|jail| { - jail.set_env("ROCKET_PROFILE", "boop"); - let rocket = crate::custom(Config::figment()); - let _result = crate::local::blocking::Client::tracked(rocket); - Ok(()) - }); - } +#[cfg(feature = "mtls")] +pub use crate::mtls::MtlsConfig; - #[test] - fn test_no_err_on_debug_and_no_secret_key() { - figment::Jail::expect_with(|jail| { - jail.set_env("ROCKET_PROFILE", "debug"); - let figment = Config::figment(); - assert!(crate::local::blocking::Client::untracked(crate::custom(&figment)).is_ok()); - crate::async_main(async { - let rocket = crate::custom(&figment); - assert!(crate::local::asynchronous::Client::tracked(rocket).await.is_ok()); - }); +#[cfg(feature = "secrets")] +mod secret_key; - Ok(()) - }); - } +#[cfg(unix)] +pub use crate::shutdown::Sig; - #[test] - fn test_no_err_on_release_and_custom_secret_key() { - figment::Jail::expect_with(|jail| { - jail.set_env("ROCKET_PROFILE", "release"); - let key = "Bx4Gb+aSIfuoEyMHD4DvNs92+wmzfQK98qc6MiwyPY4="; - let figment = Config::figment().merge(("secret_key", key)); +#[cfg(unix)] +pub use crate::listener::unix::UdsConfig; - assert!(crate::local::blocking::Client::tracked(crate::custom(&figment)).is_ok()); - crate::async_main(async { - let rocket = crate::custom(&figment); - assert!(crate::local::asynchronous::Client::untracked(rocket).await.is_ok()); - }); +#[cfg(feature = "secrets")] +pub use secret_key::SecretKey; - Ok(()) - }); - } -} +#[doc(hidden)] +pub use config::{pretty_print_error, bail_with_config_error}; diff --git a/core/lib/src/config/tests.rs b/core/lib/src/config/tests.rs new file mode 100644 index 0000000000..b5e3429f4c --- /dev/null +++ b/core/lib/src/config/tests.rs @@ -0,0 +1,394 @@ +use figment::{Figment, Profile}; +use pretty_assertions::assert_eq; + +use crate::log::LogLevel; +use crate::data::{Limits, ToByteUnit}; +use crate::config::{Config, CliColors}; + +#[test] +fn test_figment_is_default() { + figment::Jail::expect_with(|_| { + let mut default: Config = Config::figment().extract().unwrap(); + default.profile = Config::default().profile; + assert_eq!(default, Config::default()); + Ok(()) + }); +} + +#[test] +fn test_default_round_trip() { + figment::Jail::expect_with(|_| { + let original = Config::figment(); + let roundtrip = Figment::from(Config::from(&original)); + for figment in &[original, roundtrip] { + let config = Config::from(figment); + assert_eq!(config, Config::default()); + } + + Ok(()) + }); +} + +#[test] +fn test_profile_env() { + figment::Jail::expect_with(|jail| { + jail.set_env("ROCKET_PROFILE", "debug"); + let figment = Config::figment(); + assert_eq!(figment.profile(), "debug"); + + jail.set_env("ROCKET_PROFILE", "release"); + let figment = Config::figment(); + assert_eq!(figment.profile(), "release"); + + jail.set_env("ROCKET_PROFILE", "random"); + let figment = Config::figment(); + assert_eq!(figment.profile(), "random"); + + Ok(()) + }); +} + +#[test] +fn test_toml_file() { + figment::Jail::expect_with(|jail| { + jail.create_file("Rocket.toml", r#" + [default] + ident = "Something Cool" + workers = 20 + keep_alive = 10 + log_level = "off" + cli_colors = 0 + "#)?; + + let config = Config::from(Config::figment()); + assert_eq!(config, Config { + workers: 20, + ident: ident!("Something Cool"), + keep_alive: 10, + log_level: LogLevel::Off, + cli_colors: CliColors::Never, + ..Config::default() + }); + + jail.create_file("Rocket.toml", r#" + [global] + ident = "Something Else Cool" + workers = 20 + keep_alive = 10 + log_level = "off" + cli_colors = 0 + "#)?; + + let config = Config::from(Config::figment()); + assert_eq!(config, Config { + workers: 20, + ident: ident!("Something Else Cool"), + keep_alive: 10, + log_level: LogLevel::Off, + cli_colors: CliColors::Never, + ..Config::default() + }); + + jail.set_env("ROCKET_CONFIG", "Other.toml"); + jail.create_file("Other.toml", r#" + [default] + workers = 20 + keep_alive = 10 + log_level = "off" + cli_colors = 0 + "#)?; + + let config = Config::from(Config::figment()); + assert_eq!(config, Config { + workers: 20, + keep_alive: 10, + log_level: LogLevel::Off, + cli_colors: CliColors::Never, + ..Config::default() + }); + + Ok(()) + }); +} + +#[test] +fn test_cli_colors() { + figment::Jail::expect_with(|jail| { + jail.create_file("Rocket.toml", r#" + [default] + cli_colors = "never" + "#)?; + + let config = Config::from(Config::figment()); + assert_eq!(config.cli_colors, CliColors::Never); + + jail.create_file("Rocket.toml", r#" + [default] + cli_colors = "auto" + "#)?; + + let config = Config::from(Config::figment()); + assert_eq!(config.cli_colors, CliColors::Auto); + + jail.create_file("Rocket.toml", r#" + [default] + cli_colors = "always" + "#)?; + + let config = Config::from(Config::figment()); + assert_eq!(config.cli_colors, CliColors::Always); + + jail.create_file("Rocket.toml", r#" + [default] + cli_colors = true + "#)?; + + let config = Config::from(Config::figment()); + assert_eq!(config.cli_colors, CliColors::Auto); + + jail.create_file("Rocket.toml", r#" + [default] + cli_colors = false + "#)?; + + let config = Config::from(Config::figment()); + assert_eq!(config.cli_colors, CliColors::Never); + + jail.create_file("Rocket.toml", r#"[default]"#)?; + let config = Config::from(Config::figment()); + assert_eq!(config.cli_colors, CliColors::Auto); + + jail.create_file("Rocket.toml", r#" + [default] + cli_colors = 1 + "#)?; + + let config = Config::from(Config::figment()); + assert_eq!(config.cli_colors, CliColors::Auto); + + jail.create_file("Rocket.toml", r#" + [default] + cli_colors = 0 + "#)?; + + let config = Config::from(Config::figment()); + assert_eq!(config.cli_colors, CliColors::Never); + + jail.set_env("ROCKET_CLI_COLORS", 1); + let config = Config::from(Config::figment()); + assert_eq!(config.cli_colors, CliColors::Auto); + + jail.set_env("ROCKET_CLI_COLORS", 0); + let config = Config::from(Config::figment()); + assert_eq!(config.cli_colors, CliColors::Never); + + jail.set_env("ROCKET_CLI_COLORS", true); + let config = Config::from(Config::figment()); + assert_eq!(config.cli_colors, CliColors::Auto); + + jail.set_env("ROCKET_CLI_COLORS", false); + let config = Config::from(Config::figment()); + assert_eq!(config.cli_colors, CliColors::Never); + + jail.set_env("ROCKET_CLI_COLORS", "always"); + let config = Config::from(Config::figment()); + assert_eq!(config.cli_colors, CliColors::Always); + + jail.set_env("ROCKET_CLI_COLORS", "NEveR"); + let config = Config::from(Config::figment()); + assert_eq!(config.cli_colors, CliColors::Never); + + jail.set_env("ROCKET_CLI_COLORS", "auTO"); + let config = Config::from(Config::figment()); + assert_eq!(config.cli_colors, CliColors::Auto); + + Ok(()) + }) +} + +#[test] +fn test_profiles_merge() { + figment::Jail::expect_with(|jail| { + jail.create_file("Rocket.toml", r#" + [default.limits] + stream = "50kb" + + [global] + limits = { forms = "2kb" } + + [debug.limits] + file = "100kb" + "#)?; + + jail.set_env("ROCKET_PROFILE", "unknown"); + let config = Config::from(Config::figment()); + assert_eq!(config, Config { + profile: Profile::const_new("unknown"), + limits: Limits::default() + .limit("stream", 50.kilobytes()) + .limit("forms", 2.kilobytes()), + ..Config::default() + }); + + jail.set_env("ROCKET_PROFILE", "debug"); + let config = Config::from(Config::figment()); + assert_eq!(config, Config { + profile: Profile::const_new("debug"), + limits: Limits::default() + .limit("stream", 50.kilobytes()) + .limit("forms", 2.kilobytes()) + .limit("file", 100.kilobytes()), + ..Config::default() + }); + + Ok(()) + }); +} + +#[test] +fn test_env_vars_merge() { + use crate::config::{Ident, ShutdownConfig}; + + figment::Jail::expect_with(|jail| { + jail.set_env("ROCKET_KEEP_ALIVE", 9999); + let config = Config::from(Config::figment()); + assert_eq!(config, Config { + keep_alive: 9999, + ..Config::default() + }); + + jail.set_env("ROCKET_SHUTDOWN", r#"{grace=7}"#); + let first_figment = Config::figment(); + jail.set_env("ROCKET_SHUTDOWN", r#"{mercy=10}"#); + let prev_figment = Config::figment().join(&first_figment); + let config = Config::from(&prev_figment); + assert_eq!(config, Config { + keep_alive: 9999, + shutdown: ShutdownConfig { grace: 7, mercy: 10, ..Default::default() }, + ..Config::default() + }); + + jail.set_env("ROCKET_SHUTDOWN", r#"{mercy=20}"#); + let config = Config::from(Config::figment().join(&prev_figment)); + assert_eq!(config, Config { + keep_alive: 9999, + shutdown: ShutdownConfig { grace: 7, mercy: 20, ..Default::default() }, + ..Config::default() + }); + + jail.set_env("ROCKET_LIMITS", r#"{stream=100kiB}"#); + let config = Config::from(Config::figment().join(&prev_figment)); + assert_eq!(config, Config { + keep_alive: 9999, + shutdown: ShutdownConfig { grace: 7, mercy: 20, ..Default::default() }, + limits: Limits::default().limit("stream", 100.kibibytes()), + ..Config::default() + }); + + jail.set_env("ROCKET_IDENT", false); + let config = Config::from(Config::figment().join(&prev_figment)); + assert_eq!(config, Config { + keep_alive: 9999, + shutdown: ShutdownConfig { grace: 7, mercy: 20, ..Default::default() }, + limits: Limits::default().limit("stream", 100.kibibytes()), + ident: Ident::none(), + ..Config::default() + }); + + Ok(()) + }); +} + +#[test] +fn test_precedence() { + figment::Jail::expect_with(|jail| { + jail.create_file("Rocket.toml", r#" + [global.limits] + forms = "1mib" + stream = "50kb" + file = "100kb" + "#)?; + + let config = Config::from(Config::figment()); + assert_eq!(config, Config { + limits: Limits::default() + .limit("forms", 1.mebibytes()) + .limit("stream", 50.kilobytes()) + .limit("file", 100.kilobytes()), + ..Config::default() + }); + + jail.set_env("ROCKET_LIMITS", r#"{stream=3MiB,capture=2MiB}"#); + let config = Config::from(Config::figment()); + assert_eq!(config, Config { + limits: Limits::default() + .limit("file", 100.kilobytes()) + .limit("forms", 1.mebibytes()) + .limit("stream", 3.mebibytes()) + .limit("capture", 2.mebibytes()), + ..Config::default() + }); + + jail.set_env("ROCKET_PROFILE", "foo"); + let val: Result = Config::figment().extract_inner("profile"); + assert!(val.is_err()); + + Ok(()) + }); +} + +#[test] +#[cfg(feature = "secrets")] +#[should_panic] +fn test_err_on_non_debug_and_no_secret_key() { + figment::Jail::expect_with(|jail| { + jail.set_env("ROCKET_PROFILE", "release"); + let rocket = crate::custom(Config::figment()); + let _result = crate::local::blocking::Client::untracked(rocket); + Ok(()) + }); +} + +#[test] +#[cfg(feature = "secrets")] +#[should_panic] +fn test_err_on_non_debug2_and_no_secret_key() { + figment::Jail::expect_with(|jail| { + jail.set_env("ROCKET_PROFILE", "boop"); + let rocket = crate::custom(Config::figment()); + let _result = crate::local::blocking::Client::tracked(rocket); + Ok(()) + }); +} + +#[test] +fn test_no_err_on_debug_and_no_secret_key() { + figment::Jail::expect_with(|jail| { + jail.set_env("ROCKET_PROFILE", "debug"); + let figment = Config::figment(); + assert!(crate::local::blocking::Client::untracked(crate::custom(&figment)).is_ok()); + crate::async_main(async { + let rocket = crate::custom(&figment); + assert!(crate::local::asynchronous::Client::tracked(rocket).await.is_ok()); + }); + + Ok(()) + }); +} + +#[test] +fn test_no_err_on_release_and_custom_secret_key() { + figment::Jail::expect_with(|jail| { + jail.set_env("ROCKET_PROFILE", "release"); + let key = "Bx4Gb+aSIfuoEyMHD4DvNs92+wmzfQK98qc6MiwyPY4="; + let figment = Config::figment().merge(("secret_key", key)); + + assert!(crate::local::blocking::Client::tracked(crate::custom(&figment)).is_ok()); + crate::async_main(async { + let rocket = crate::custom(&figment); + assert!(crate::local::asynchronous::Client::untracked(rocket).await.is_ok()); + }); + + Ok(()) + }); +} diff --git a/core/lib/src/data/data_stream.rs b/core/lib/src/data/data_stream.rs index 77d033284a..c0542a1b77 100644 --- a/core/lib/src/data/data_stream.rs +++ b/core/lib/src/data/data_stream.rs @@ -68,7 +68,9 @@ pub type RawReader<'r> = StreamReader, Bytes>; /// Raw underlying data stream. pub enum RawStream<'r> { Empty, - Body(&'r mut HyperBody), + Body(HyperBody), + #[cfg(feature = "http3")] + H3Body(crate::listener::Cancellable), Multipart(multer::Field<'r>), } @@ -343,7 +345,9 @@ impl Stream for RawStream<'_> { .poll_frame(cx) .map_ok(|frame| frame.into_data().unwrap_or_else(|_| Bytes::new())) .map_err(io::Error::other) - } + }, + #[cfg(feature = "http3")] + RawStream::H3Body(stream) => Pin::new(stream).poll_next(cx), RawStream::Multipart(s) => Pin::new(s).poll_next(cx).map_err(io::Error::other), RawStream::Empty => Poll::Ready(None), } @@ -356,6 +360,8 @@ impl Stream for RawStream<'_> { let (lower, upper) = (hint.lower(), hint.upper()); (lower as usize, upper.map(|x| x as usize)) }, + #[cfg(feature = "http3")] + RawStream::H3Body(_) => (0, Some(0)), RawStream::Multipart(mp) => mp.size_hint(), RawStream::Empty => (0, Some(0)), } @@ -367,17 +373,26 @@ impl std::fmt::Display for RawStream<'_> { match self { RawStream::Empty => f.write_str("empty stream"), RawStream::Body(_) => f.write_str("request body"), + #[cfg(feature = "http3")] + RawStream::H3Body(_) => f.write_str("http3 quic stream"), RawStream::Multipart(_) => f.write_str("multipart form field"), } } } -impl<'r> From<&'r mut HyperBody> for RawStream<'r> { - fn from(value: &'r mut HyperBody) -> Self { +impl<'r> From for RawStream<'r> { + fn from(value: HyperBody) -> Self { Self::Body(value) } } +#[cfg(feature = "http3")] +impl<'r> From> for RawStream<'r> { + fn from(value: crate::listener::Cancellable) -> Self { + Self::H3Body(value) + } +} + impl<'r> From> for RawStream<'r> { fn from(value: multer::Field<'r>) -> Self { Self::Multipart(value) diff --git a/core/lib/src/data/mod.rs b/core/lib/src/data/mod.rs index e3eebdd23c..f7c879dc0d 100644 --- a/core/lib/src/data/mod.rs +++ b/core/lib/src/data/mod.rs @@ -18,3 +18,5 @@ pub use self::capped::{N, Capped}; pub use self::io_stream::{IoHandler, IoStream}; pub use ubyte::{ByteUnit, ToByteUnit}; pub use self::transform::{Transform, TransformBuf}; + +pub(crate) use self::data_stream::RawStream; diff --git a/core/lib/src/erased.rs b/core/lib/src/erased.rs index 7b62522c55..966ff2dd56 100644 --- a/core/lib/src/erased.rs +++ b/core/lib/src/erased.rs @@ -6,19 +6,18 @@ use std::task::{Poll, Context}; use futures::future::BoxFuture; use http::request::Parts; -use hyper::body::Incoming; use tokio::io::{AsyncRead, ReadBuf}; -use crate::data::{Data, IoHandler}; +use crate::data::{Data, IoHandler, RawStream}; use crate::{Request, Response, Rocket, Orbit}; // TODO: Magic with trait async fn to get rid of the box pin. // TODO: Write safety proofs. macro_rules! static_assert_covariance { - ($T:tt) => ( + ($($T:tt)*) => ( const _: () = { - fn _assert_covariance<'x: 'y, 'y>(x: &'y $T<'x>) -> &'y $T<'y> { x } + fn _assert_covariance<'x: 'y, 'y>(x: &'y $($T)*<'x>) -> &'y $($T)*<'y> { x } }; ) } @@ -40,7 +39,6 @@ pub struct ErasedResponse { // XXX: SAFETY: This (dependent) field must come first due to drop order! response: Response<'static>, _request: Arc, - _incoming: Box, } impl Drop for ErasedResponse { @@ -79,10 +77,9 @@ impl ErasedRequest { ErasedRequest { _rocket: rocket, _parts: parts, request, } } - pub async fn into_response( + pub async fn into_response( self, - incoming: Incoming, - data_builder: impl for<'r> FnOnce(&'r mut Incoming) -> Data<'r>, + raw_stream: D, preprocess: impl for<'r, 'x> FnOnce( &'r Rocket, &'r mut Request<'x>, @@ -94,14 +91,11 @@ impl ErasedRequest { &'r Request<'r>, Data<'r> ) -> BoxFuture<'r, Response<'r>>, - ) -> ErasedResponse { - let mut incoming = Box::new(incoming); - let mut data: Data<'_> = { - let incoming: &mut Incoming = &mut *incoming; - let incoming: &'static mut Incoming = unsafe { transmute(incoming) }; - data_builder(incoming) - }; - + ) -> ErasedResponse + where T: Send + Sync + 'static, + D: for<'r> Into> + { + let mut data: Data<'_> = Data::from(raw_stream); let mut parent = Arc::new(self); let token: T = { let parent: &mut ErasedRequest = Arc::get_mut(&mut parent).unwrap(); @@ -122,7 +116,6 @@ impl ErasedRequest { ErasedResponse { _request: parent, - _incoming: incoming, response: response, } } diff --git a/core/lib/src/error.rs b/core/lib/src/error.rs index 21753b1f1b..5c39cea9fc 100644 --- a/core/lib/src/error.rs +++ b/core/lib/src/error.rs @@ -7,7 +7,7 @@ use std::error::Error as StdError; use yansi::Paint; use figment::Profile; -use crate::{Rocket, Orbit}; +use crate::{Ignite, Orbit, Rocket}; /// An error that occurs during launch. /// @@ -89,6 +89,11 @@ pub enum ErrorKind { SentinelAborts(Vec), /// The configuration profile is not debug but no secret key is configured. InsecureSecretKey(Profile), + /// Liftoff failed. Contains the Rocket instance that failed to shutdown. + Liftoff( + Result, Arc>>, + Box + ), /// Shutdown failed. Contains the Rocket instance that failed to shutdown. Shutdown(Arc>), } @@ -225,6 +230,11 @@ impl Error { "aborting due to sentinel-triggered abort(s)" } + ErrorKind::Liftoff(_, error) => { + error!("Rocket liftoff failed due to panicking liftoff fairing(s)."); + error_!("{error}"); + "aborting due to failed liftoff" + } ErrorKind::Shutdown(_) => { error!("Rocket failed to shutdown gracefully."); "aborting due to failed shutdown" @@ -246,6 +256,7 @@ impl fmt::Display for ErrorKind { ErrorKind::InsecureSecretKey(_) => "insecure secret key config".fmt(f), ErrorKind::Config(_) => "failed to extract configuration".fmt(f), ErrorKind::SentinelAborts(_) => "sentinel(s) aborted".fmt(f), + ErrorKind::Liftoff(_, _) => "liftoff failed".fmt(f), ErrorKind::Shutdown(_) => "shutdown failed".fmt(f), } } @@ -293,40 +304,45 @@ impl fmt::Display for Empty { impl StdError for Empty { } /// Log an error that occurs during request processing -pub(crate) fn log_server_error(error: &Box) { +#[track_caller] +pub(crate) fn log_server_error(error: &(dyn StdError + 'static)) { struct ServerError<'a>(&'a (dyn StdError + 'static)); impl fmt::Display for ServerError<'_> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let error = &self.0; if let Some(e) = error.downcast_ref::() { - write!(f, "request processing failed: {e}")?; + write!(f, "request failed: {e}")?; } else if let Some(e) = error.downcast_ref::() { - write!(f, "connection I/O error: ")?; + write!(f, "connection error: ")?; match e.kind() { io::ErrorKind::NotConnected => write!(f, "remote disconnected")?, io::ErrorKind::UnexpectedEof => write!(f, "remote sent early eof")?, io::ErrorKind::ConnectionReset - | io::ErrorKind::ConnectionAborted - | io::ErrorKind::BrokenPipe => write!(f, "terminated by remote")?, + | io::ErrorKind::ConnectionAborted => write!(f, "terminated by remote")?, _ => write!(f, "{e}")?, } } else { write!(f, "http server error: {error}")?; } - if let Some(e) = error.source() { - write!(f, " ({})", ServerError(e))?; - } - Ok(()) } } + let mut error: &(dyn StdError + 'static) = &*error; if error.downcast_ref::().is_some() { - warn!("{}", ServerError(&**error)) + warn!("{}", ServerError(error)); + while let Some(source) = error.source() { + error = source; + warn_!("{}", ServerError(error)); + } } else { - error!("{}", ServerError(&**error)) + error!("{}", ServerError(error)); + while let Some(source) = error.source() { + error = source; + error_!("{}", ServerError(error)); + } } } diff --git a/core/lib/src/fairing/fairings.rs b/core/lib/src/fairing/fairings.rs index 12a99c08c4..f79ac639db 100644 --- a/core/lib/src/fairing/fairings.rs +++ b/core/lib/src/fairing/fairings.rs @@ -181,6 +181,13 @@ impl Fairings { } } } + + pub fn find(&self) -> Option<&T> { + self.all_fairings.iter() + .map(|f| &*f as &dyn std::any::Any) + .filter_map(|any| any.downcast_ref::()) + .next() + } } impl std::fmt::Debug for Fairings { diff --git a/core/lib/src/fairing/mod.rs b/core/lib/src/fairing/mod.rs index fec2b33f2b..ad9aaca40f 100644 --- a/core/lib/src/fairing/mod.rs +++ b/core/lib/src/fairing/mod.rs @@ -191,8 +191,8 @@ pub type Result, E = Rocket> = std::result::Result Listener for Bounced { self.listener.connect(accept).await } - fn socket_addr(&self) -> io::Result { - self.listener.socket_addr() + fn endpoint(&self) -> io::Result { + self.listener.endpoint() } } diff --git a/core/lib/src/listener/cancellable.rs b/core/lib/src/listener/cancellable.rs index fbabfb2c6d..52bce62338 100644 --- a/core/lib/src/listener/cancellable.rs +++ b/core/lib/src/listener/cancellable.rs @@ -1,178 +1,80 @@ use std::io; -use std::time::Duration; use std::task::{Poll, Context}; use std::pin::Pin; -use tokio::time::{sleep, Sleep}; +use futures::{Future, Stream}; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; -use futures::{StreamExt, future::{select, Either, Fuse, Future, FutureExt}}; +use futures::future::{poll_fn, FutureExt}; use pin_project_lite::pin_project; -use crate::{config, Shutdown}; -use crate::listener::{Listener, Connection, Certificates, Bounced, Endpoint}; - -// Rocket wraps all connections in a `CancellableIo` struct, an internal -// structure that gracefully closes I/O when it receives a signal. That signal -// is the `shutdown` future. When the future resolves, `CancellableIo` begins to -// terminate in grace, mercy, and finally force close phases. Since all -// connections are wrapped in `CancellableIo`, this eventually ends all I/O. -// -// At that point, unless a user spawned an infinite, stand-alone task that isn't -// monitoring `Shutdown`, all tasks should resolve. This means that all -// instances of the shared `Arc` are dropped and we can return the owned -// instance of `Rocket`. -// -// Unfortunately, the Hyper `server` future resolves as soon as it has finished -// processing requests without respect for ongoing responses. That is, `server` -// resolves even when there are running tasks that are generating a response. -// So, `server` resolving implies little to nothing about the state of -// connections. As a result, we depend on the timing of grace + mercy + some -// buffer to determine when all connections should be closed, thus all tasks -// should be complete, thus all references to `Arc` should be dropped -// and we can get a unique reference. -pin_project! { - pub struct CancellableListener { - pub trigger: F, - #[pin] - pub listener: L, - pub grace: Duration, - pub mercy: Duration, - } -} +use crate::util::FutureExt as _; +use crate::shutdown::Stages; pin_project! { /// I/O that can be cancelled when a future `F` resolves. #[must_use = "futures do nothing unless polled"] - pub struct CancellableIo { + pub struct Cancellable { #[pin] io: Option, - #[pin] - trigger: Fuse, + stages: Stages, state: State, - grace: Duration, - mercy: Duration, } } +#[derive(Debug)] enum State { - /// I/O has not been cancelled. Proceed as normal. + /// I/O has not been cancelled. Proceed as normal until `Shutdown`. Active, - /// I/O has been cancelled. See if we can finish before the timer expires. - Grace(Pin>), - /// Grace period elapsed. Shutdown the connection, waiting for the timer - /// until we force close. - Mercy(Pin>), + /// I/O has been cancelled. Try to finish before `Shutdown`. + Grace, + /// Grace has elapsed. Shutdown connections. After `Shutdown`, force close. + Mercy, } pub trait CancellableExt: Sized { - fn cancellable( - self, - trigger: Shutdown, - config: &config::Shutdown - ) -> CancellableListener { - if let Some(mut stream) = config.signal_stream() { - let trigger = trigger.clone(); - tokio::spawn(async move { - while let Some(sig) = stream.next().await { - if trigger.0.tripped() { - warn!("Received {}. Shutdown already in progress.", sig); - } else { - warn!("Received {}. Requesting shutdown.", sig); - } - - trigger.0.trip(); - } - }); - }; - - CancellableListener { - trigger, - listener: self, - grace: config.grace(), - mercy: config.mercy(), + fn cancellable(self, stages: Stages) -> Cancellable { + Cancellable { + io: Some(self), + state: State::Active, + stages, } } } -impl CancellableExt for L { } +impl CancellableExt for T { } fn time_out() -> io::Error { - io::Error::new(io::ErrorKind::TimedOut, "Shutdown grace timed out") + io::Error::new(io::ErrorKind::TimedOut, "shutdown grace period elapsed") } fn gone() -> io::Error { - io::Error::new(io::ErrorKind::BrokenPipe, "IO driver has terminated") + io::Error::new(io::ErrorKind::BrokenPipe, "I/O driver terminated") } -impl CancellableListener> - where L: Listener + Sync, - F: Future + Unpin + Clone + Send + Sync + 'static -{ - pub async fn accept_next(&self) -> Option<::Accept> { - let next = std::pin::pin!(self.listener.accept_next()); - match select(next, self.trigger.clone()).await { - Either::Left((next, _)) => Some(next), - Either::Right(_) => None, - } +impl Cancellable { + pub fn inner(&self) -> Option<&I> { + self.io.as_ref() } } -impl CancellableListener - where L: Listener + Sync, - F: Future + Clone + Send + Sync + 'static -{ - fn io(&self, conn: C) -> CancellableIo { - CancellableIo { - io: Some(conn), - trigger: self.trigger.clone().fuse(), - state: State::Active, - grace: self.grace, - mercy: self.mercy, - } - } +pub trait AsyncCancel { + fn poll_cancel(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll>; } -impl Listener for CancellableListener - where L: Listener + Sync, - F: Future + Clone + Send + Sync + Unpin + 'static -{ - type Accept = L::Accept; - - type Connection = CancellableIo; - - async fn accept(&self) -> io::Result { - let accept = std::pin::pin!(self.listener.accept()); - match select(accept, self.trigger.clone()).await { - Either::Left((result, _)) => result, - Either::Right(_) => Err(gone()), - } - } - - async fn connect(&self, accept: Self::Accept) -> io::Result { - let conn = std::pin::pin!(self.listener.connect(accept)); - match select(conn, self.trigger.clone()).await { - Either::Left((conn, _)) => Ok(self.io(conn?)), - Either::Right(_) => Err(gone()), - } - } - - fn socket_addr(&self) -> io::Result { - self.listener.socket_addr() +impl AsyncCancel for T { + fn poll_cancel(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + ::poll_shutdown(self, cx) } } -impl CancellableIo { - fn inner(&self) -> Option<&I> { - self.io.as_ref() - } - +impl Cancellable { /// Run `do_io` while connection processing should continue. - fn poll_trigger_then( + pub fn poll_with( mut self: Pin<&mut Self>, cx: &mut Context<'_>, do_io: impl FnOnce(Pin<&mut I>, &mut Context<'_>) -> Poll>, ) -> Poll> { - let mut me = self.as_mut().project(); + let me = self.as_mut().project(); let io = match me.io.as_pin_mut() { Some(io) => io, None => return Poll::Ready(Err(gone())), @@ -181,29 +83,29 @@ impl CancellableIo { loop { match me.state { State::Active => { - if me.trigger.as_mut().poll(cx).is_ready() { - *me.state = State::Grace(Box::pin(sleep(*me.grace))); + if me.stages.start.poll_unpin(cx).is_ready() { + *me.state = State::Grace; } else { return do_io(io, cx); } } - State::Grace(timer) => { - if timer.as_mut().poll(cx).is_ready() { - *me.state = State::Mercy(Box::pin(sleep(*me.mercy))); + State::Grace => { + if me.stages.grace.poll_unpin(cx).is_ready() { + *me.state = State::Mercy; } else { return do_io(io, cx); } } - State::Mercy(timer) => { - if timer.as_mut().poll(cx).is_ready() { + State::Mercy => { + if me.stages.mercy.poll_unpin(cx).is_ready() { self.project().io.set(None); return Poll::Ready(Err(time_out())); } else { - let result = futures::ready!(io.poll_shutdown(cx)); + let result = futures::ready!(io.poll_cancel(cx)); self.project().io.set(None); return match result { + Ok(()) => Poll::Ready(Err(gone())), Err(e) => Poll::Ready(Err(e)), - Ok(()) => Poll::Ready(Err(gone())) }; } }, @@ -212,45 +114,45 @@ impl CancellableIo { } } -impl AsyncRead for CancellableIo { +impl AsyncRead for Cancellable { fn poll_read( - mut self: Pin<&mut Self>, + self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut ReadBuf<'_>, ) -> Poll> { - self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_read(cx, buf)) + self.poll_with(cx, |io, cx| io.poll_read(cx, buf)) } } -impl AsyncWrite for CancellableIo { +impl AsyncWrite for Cancellable { fn poll_write( - mut self: Pin<&mut Self>, + self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { - self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_write(cx, buf)) + self.poll_with(cx, |io, cx| io.poll_write(cx, buf)) } fn poll_flush( - mut self: Pin<&mut Self>, + self: Pin<&mut Self>, cx: &mut Context<'_> ) -> Poll> { - self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_flush(cx)) + self.poll_with(cx, |io, cx| io.poll_flush(cx)) } fn poll_shutdown( - mut self: Pin<&mut Self>, + self: Pin<&mut Self>, cx: &mut Context<'_> ) -> Poll> { - self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_shutdown(cx)) + self.poll_with(cx, |io, cx| io.poll_shutdown(cx)) } fn poll_write_vectored( - mut self: Pin<&mut Self>, + self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[io::IoSlice<'_>], ) -> Poll> { - self.as_mut().poll_trigger_then(cx, |io, cx| io.poll_write_vectored(cx, bufs)) + self.poll_with(cx, |io, cx| io.poll_write_vectored(cx, bufs)) } fn is_write_vectored(&self) -> bool { @@ -258,16 +160,16 @@ impl AsyncWrite for CancellableIo { } } -impl Connection for CancellableIo - where F: Unpin + Send + 'static -{ - fn peer_address(&self) -> io::Result { - self.inner() - .ok_or_else(|| gone()) - .and_then(|io| io.peer_address()) - } +impl> + AsyncCancel> Stream for Cancellable { + type Item = I::Item; - fn peer_certificates(&self) -> Option> { - self.inner().and_then(|io| io.peer_certificates()) + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + use futures::ready; + + match ready!(self.poll_with(cx, |io, cx| io.poll_next(cx).map(Ok))) { + Ok(Some(v)) => Poll::Ready(Some(v)), + Ok(None) => Poll::Ready(None), + Err(e) => Poll::Ready(Some(Err(e))), + } } } diff --git a/core/lib/src/listener/connection.rs b/core/lib/src/listener/connection.rs index 68541109e0..49d1778997 100644 --- a/core/lib/src/listener/connection.rs +++ b/core/lib/src/listener/connection.rs @@ -2,7 +2,6 @@ use std::io; use std::borrow::Cow; use tokio_util::either::Either; -use tokio::io::{AsyncRead, AsyncWrite}; use super::Endpoint; @@ -10,8 +9,8 @@ use super::Endpoint; #[derive(Clone)] pub struct Certificates<'r>(Cow<'r, [der::CertificateDer<'r>]>); -pub trait Connection: AsyncRead + AsyncWrite + Send + Unpin { - fn peer_address(&self) -> io::Result; +pub trait Connection: Send + Unpin { + fn endpoint(&self) -> io::Result; /// DER-encoded X.509 certificate chain presented by the client, if any. /// @@ -21,21 +20,21 @@ pub trait Connection: AsyncRead + AsyncWrite + Send + Unpin { /// /// Defaults to an empty vector to indicate that no certificates were /// presented. - fn peer_certificates(&self) -> Option> { None } + fn certificates(&self) -> Option> { None } } impl Connection for Either { - fn peer_address(&self) -> io::Result { + fn endpoint(&self) -> io::Result { match self { - Either::Left(c) => c.peer_address(), - Either::Right(c) => c.peer_address(), + Either::Left(c) => c.endpoint(), + Either::Right(c) => c.endpoint(), } } - fn peer_certificates(&self) -> Option> { + fn certificates(&self) -> Option> { match self { - Either::Left(c) => c.peer_certificates(), - Either::Right(c) => c.peer_certificates(), + Either::Left(c) => c.certificates(), + Either::Right(c) => c.certificates(), } } } diff --git a/core/lib/src/listener/endpoint.rs b/core/lib/src/listener/endpoint.rs index 26640d1d1c..54e22e2c81 100644 --- a/core/lib/src/listener/endpoint.rs +++ b/core/lib/src/listener/endpoint.rs @@ -39,6 +39,7 @@ impl Endpoint { pub fn tcp(&self) -> Option { match self { Endpoint::Tcp(addr) => Some(*addr), + Endpoint::Tls(addr, _) => addr.tcp(), _ => None, } } @@ -46,6 +47,7 @@ impl Endpoint { pub fn unix(&self) -> Option<&Path> { match self { Endpoint::Unix(addr) => Some(addr), + Endpoint::Tls(addr, _) => addr.unix(), _ => None, } } diff --git a/core/lib/src/listener/listener.rs b/core/lib/src/listener/listener.rs index 8bdbc08c2b..a272b699c8 100644 --- a/core/lib/src/listener/listener.rs +++ b/core/lib/src/listener/listener.rs @@ -10,12 +10,13 @@ pub trait Listener: Send + Sync { type Connection: Connection; + #[crate::async_bound(Send)] async fn accept(&self) -> io::Result; #[crate::async_bound(Send)] async fn connect(&self, accept: Self::Accept) -> io::Result; - fn socket_addr(&self) -> io::Result; + fn endpoint(&self) -> io::Result; } impl Listener for &L { @@ -31,8 +32,8 @@ impl Listener for &L { ::connect(self, accept).await } - fn socket_addr(&self) -> io::Result { - ::socket_addr(self) + fn endpoint(&self) -> io::Result { + ::endpoint(self) } } @@ -56,10 +57,10 @@ impl Listener for Either { } } - fn socket_addr(&self) -> io::Result { + fn endpoint(&self) -> io::Result { match self { - Either::Left(l) => l.socket_addr(), - Either::Right(l) => l.socket_addr(), + Either::Left(l) => l.endpoint(), + Either::Right(l) => l.endpoint(), } } } diff --git a/core/lib/src/listener/mod.rs b/core/lib/src/listener/mod.rs index 244c36c604..b656c7fdad 100644 --- a/core/lib/src/listener/mod.rs +++ b/core/lib/src/listener/mod.rs @@ -13,6 +13,8 @@ pub mod unix; #[cfg_attr(nightly, doc(cfg(feature = "tls")))] pub mod tls; pub mod tcp; +#[cfg(feature = "http3")] +pub mod quic; pub use endpoint::*; pub use listener::*; diff --git a/core/lib/src/listener/quic.rs b/core/lib/src/listener/quic.rs new file mode 100644 index 0000000000..8c8ac3d09a --- /dev/null +++ b/core/lib/src/listener/quic.rs @@ -0,0 +1,237 @@ +use std::io; +use std::fmt; +use std::net::SocketAddr; + +use bytes::Bytes; +use futures::Stream; +use s2n_quic as quic; +use s2n_quic_h3 as quic_h3; +use quic_h3::h3 as h3; +use s2n_quic::provider::tls::rustls::{rustls, DEFAULT_CIPHERSUITES}; +use s2n_quic::provider::tls::rustls::Server as H3TlsServer; + +use tokio::sync::Mutex; +use tokio_stream::StreamExt; + +use crate::listener::{Bindable, Listener}; +use crate::tls::TlsConfig; + +use super::{Connection, Endpoint}; + +pub struct QuicBindable { + pub address: SocketAddr, + pub tls: TlsConfig, +} + +pub struct QuicListener { + listener: Mutex, + local_addr: SocketAddr, +} + +impl Bindable for QuicBindable { + type Listener = QuicListener; + + type Error = io::Error; + + async fn bind(self) -> Result { + // FIXME: Remove this as soon as `s2n_quic` is on rustls 0.22. + let cert_chain = crate::tls::util::load_cert_chain(&mut self.tls.certs_reader().unwrap()) + .unwrap() + .into_iter() + .map(|v| v.to_vec()) + .map(rustls::Certificate) + .collect::>(); + + let key = crate::tls::util::load_key(&mut self.tls.key_reader().unwrap()) + .unwrap() + .secret_der() + .to_vec(); + + let mut tls = rustls::server::ServerConfig::builder() + .with_cipher_suites(DEFAULT_CIPHERSUITES) + .with_safe_default_kx_groups() + .with_safe_default_protocol_versions() + .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("bad TLS config: {}", e)))? + .with_client_cert_verifier(rustls::server::NoClientAuth::boxed()) + .with_single_cert(cert_chain, rustls::PrivateKey(key)) + .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("bad TLS config: {}", e)))?; + + tls.alpn_protocols = vec![b"h3".to_vec()]; + tls.ignore_client_order = self.tls.prefer_server_cipher_order; + tls.session_storage = rustls::server::ServerSessionMemoryCache::new(1024); + tls.ticketer = rustls::Ticketer::new() + .map_err(|e| io::Error::new(io::ErrorKind::Other, format!("bad TLS ticketer: {}", e)))?; + + let listener = quic::Server::builder() + .with_tls(H3TlsServer::new(tls)) + .unwrap_or_else(|e| match e { }) + .with_io(self.address)? + .start() + .map_err(io::Error::other)?; + + let local_addr = listener.local_addr()?; + + Ok(QuicListener { listener: Mutex::new(listener), local_addr }) + } +} + +type H3Conn = h3::server::Connection; + +pub struct H3Stream(H3Conn); + +pub struct H3Connection { + pub handle: quic::connection::Handle, + pub parts: http::request::Parts, + pub tx: QuicTx, + pub rx: QuicRx, +} + +pub struct QuicRx(h3::server::RequestStream); + +pub struct QuicTx(h3::server::RequestStream, Bytes>); + +impl Listener for QuicListener { + type Accept = quic::Connection; + + type Connection = H3Stream; + + async fn accept(&self) -> io::Result { + self.listener + .lock().await + .accept().await + .ok_or_else(|| io::Error::new(io::ErrorKind::BrokenPipe, "closed")) + } + + async fn connect(&self, accept: Self::Accept) -> io::Result { + let quic_conn = quic_h3::Connection::new(accept); + let conn = H3Conn::new(quic_conn).await.map_err(io::Error::other)?; + Ok(H3Stream(conn)) + } + + fn endpoint(&self) -> io::Result { + Ok(self.local_addr.into()) + } +} + +impl H3Stream { + pub async fn accept(&mut self) -> io::Result> { + let handle = self.0.inner.conn.handle().clone(); + let ((parts, _), (tx, rx)) = match self.0.accept().await { + Ok(Some((req, stream))) => (req.into_parts(), stream.split()), + Ok(None) => return Ok(None), + Err(e) => { + if matches!(e.try_get_code().map(|c| c.value()), Some(0 | 0x100)) { + return Ok(None) + } + + return Err(io::Error::other(e)); + } + }; + + Ok(Some(H3Connection { handle, parts, tx: QuicTx(tx), rx: QuicRx(rx) })) + } +} + +impl QuicTx { + pub async fn send_response(&mut self, response: http::Response) -> io::Result<()> + where S: Stream> + { + use std::pin::pin; + + let (parts, body) = response.into_parts(); + let response = http::Response::from_parts(parts, ()); + self.0.send_response(response).await.map_err(io::Error::other)?; + + let mut body = pin!(body); + while let Some(bytes) = body.next().await { + let bytes = bytes.map_err(io::Error::other)?; + self.0.send_data(bytes).await.map_err(io::Error::other)?; + } + + self.0.finish().await.map_err(io::Error::other) + } + + pub fn cancel(&mut self) { + use s2n_quic_h3::h3; + + self.0.stop_stream(h3::error::Code::H3_NO_ERROR); + } +} + +impl Connection for H3Stream { + fn endpoint(&self) -> io::Result { + Ok(self.0.inner.conn.handle().local_addr()?.into()) + } +} +impl Connection for H3Connection { + fn endpoint(&self) -> io::Result { + Ok(self.handle.local_addr()?.into()) + } +} + +mod async_traits { + use std::io; + use std::pin::Pin; + use std::task::{ready, Context, Poll}; + + use super::{Bytes, QuicRx}; + use crate::listener::AsyncCancel; + + use futures::Stream; + use s2n_quic_h3::h3; + + impl Stream for QuicRx { + type Item = io::Result; + + fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + use bytes::Buf; + + match ready!(self.0.poll_recv_data(cx)) { + Ok(Some(mut buf)) => Poll::Ready(Some(Ok(buf.copy_to_bytes(buf.remaining())))), + Ok(None) => Poll::Ready(None), + Err(e) => Poll::Ready(Some(Err(io::Error::other(e)))), + } + } + } + + impl AsyncCancel for QuicRx { + fn poll_cancel(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + self.0.stop_sending(h3::error::Code::H3_NO_ERROR); + Poll::Ready(Ok(())) + } + } + + // impl AsyncWrite for QuicTx { + // fn poll_write( + // mut self: Pin<&mut Self>, + // cx: &mut Context<'_>, + // buf: &[u8], + // ) -> Poll> { + // let len = buf.len(); + // let result = ready!(self.0.poll_send_data(cx, Bytes::copy_from_slice(buf))); + // result.map_err(io::Error::other)?; + // Poll::Ready(Ok(len)) + // } + // + // fn poll_flush(self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + // Poll::Ready(Ok(())) + // } + // + // fn poll_shutdown(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll> { + // self.0.stop_stream(h3::error::Code::H3_NO_ERROR); + // Poll::Ready(Ok(())) + // } + // } +} + +impl fmt::Debug for H3Stream { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("H3Stream").finish() + } +} + +impl fmt::Debug for H3Connection { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("H3Connection").finish() + } +} diff --git a/core/lib/src/listener/tcp.rs b/core/lib/src/listener/tcp.rs index c2e3fd9f3f..ad223627bf 100644 --- a/core/lib/src/listener/tcp.rs +++ b/core/lib/src/listener/tcp.rs @@ -31,13 +31,13 @@ impl Listener for TcpListener { Ok(conn) } - fn socket_addr(&self) -> io::Result { + fn endpoint(&self) -> io::Result { self.local_addr().map(Endpoint::Tcp) } } impl Connection for TcpStream { - fn peer_address(&self) -> io::Result { + fn endpoint(&self) -> io::Result { self.peer_addr().map(Endpoint::Tcp) } } diff --git a/core/lib/src/listener/tls.rs b/core/lib/src/listener/tls.rs index ce2b53ffaf..d2ff5f4ef5 100644 --- a/core/lib/src/listener/tls.rs +++ b/core/lib/src/listener/tls.rs @@ -3,6 +3,7 @@ use std::sync::Arc; use serde::Deserialize; use rustls::server::{ServerSessionMemoryCache, ServerConfig, WebPkiClientVerifier}; +use tokio::io::{AsyncRead, AsyncWrite}; use tokio_rustls::TlsAcceptor; use crate::tls::{TlsConfig, Error}; @@ -27,7 +28,7 @@ pub struct TlsBindable { } impl TlsConfig { - pub(crate) fn acceptor(&self) -> Result { + pub(crate) fn server_config(&self) -> Result { let provider = rustls::crypto::CryptoProvider { cipher_suites: self.ciphers().map(|c| c.into()).collect(), ..rustls::crypto::ring::default_provider() @@ -64,52 +65,55 @@ impl TlsConfig { tls_config.alpn_protocols.insert(0, b"h2".to_vec()); } - Ok(TlsAcceptor::from(Arc::new(tls_config))) + Ok(tls_config) } } -impl Bindable for TlsBindable { +impl Bindable for TlsBindable + where I::Listener: Listener::Connection>, + ::Connection: AsyncRead + AsyncWrite +{ type Listener = TlsListener; type Error = Error; async fn bind(self) -> Result { Ok(TlsListener { - acceptor: self.tls.acceptor()?, + acceptor: TlsAcceptor::from(Arc::new(self.tls.server_config()?)), listener: self.inner.bind().await.map_err(|e| Error::Bind(Box::new(e)))?, config: self.tls, }) } } -impl Listener for TlsListener - where L::Connection: Unpin +impl Listener for TlsListener + where L: Listener::Connection>, + L::Connection: AsyncRead + AsyncWrite { - type Accept = L::Accept; + type Accept = L::Connection; type Connection = TlsStream; async fn accept(&self) -> io::Result { - self.listener.accept().await + Ok(self.listener.accept().await?) } - async fn connect(&self, accept: L::Accept) -> io::Result { - let conn = self.listener.connect(accept).await?; + async fn connect(&self, conn: L::Connection) -> io::Result { self.acceptor.accept(conn).await } - fn socket_addr(&self) -> io::Result { - Ok(self.listener.socket_addr()?.with_tls(self.config.clone())) + fn endpoint(&self) -> io::Result { + Ok(self.listener.endpoint()?.with_tls(self.config.clone())) } } -impl Connection for TlsStream { - fn peer_address(&self) -> io::Result { - Ok(self.get_ref().0.peer_address()?.assume_tls()) +impl Connection for TlsStream { + fn endpoint(&self) -> io::Result { + Ok(self.get_ref().0.endpoint()?.assume_tls()) } #[cfg(feature = "mtls")] - fn peer_certificates(&self) -> Option> { + fn certificates(&self) -> Option> { let cert_chain = self.get_ref().1.peer_certificates()?; Some(Certificates::from(cert_chain)) } diff --git a/core/lib/src/listener/unix.rs b/core/lib/src/listener/unix.rs index ea1a367e00..f5b79c76c5 100644 --- a/core/lib/src/listener/unix.rs +++ b/core/lib/src/listener/unix.rs @@ -83,13 +83,13 @@ impl Listener for UdsListener { Ok(accept) } - fn socket_addr(&self) -> io::Result { + fn endpoint(&self) -> io::Result { self.listener.local_addr()?.try_into() } } impl Connection for UnixStream { - fn peer_address(&self) -> io::Result { + fn endpoint(&self) -> io::Result { self.local_addr()?.try_into() } } diff --git a/core/lib/src/phase.rs b/core/lib/src/phase.rs index b38deeeaf6..35a1837d35 100644 --- a/core/lib/src/phase.rs +++ b/core/lib/src/phase.rs @@ -2,7 +2,8 @@ use state::TypeMap; use figment::Figment; use crate::listener::Endpoint; -use crate::{Catcher, Config, Rocket, Route, Shutdown}; +use crate::shutdown::Stages; +use crate::{Catcher, Config, Rocket, Route}; use crate::router::Router; use crate::fairing::Fairings; @@ -99,7 +100,7 @@ phases! { pub(crate) figment: Figment, pub(crate) config: Config, pub(crate) state: TypeMap![Send + Sync], - pub(crate) shutdown: Shutdown, + pub(crate) shutdown: Stages, } /// The final launch [`Phase`]. See [Rocket#orbit](`Rocket#orbit`) for @@ -113,7 +114,7 @@ phases! { pub(crate) figment: Figment, pub(crate) config: Config, pub(crate) state: TypeMap![Send + Sync], - pub(crate) shutdown: Shutdown, + pub(crate) shutdown: Stages, pub(crate) endpoint: Endpoint, } } diff --git a/core/lib/src/request/request.rs b/core/lib/src/request/request.rs index 5ff7e4b7e2..150098b92b 100644 --- a/core/lib/src/request/request.rs +++ b/core/lib/src/request/request.rs @@ -47,8 +47,8 @@ pub(crate) struct ConnectionMeta { impl From<&C> for ConnectionMeta { fn from(conn: &C) -> Self { ConnectionMeta { - peer_address: conn.peer_address().ok().map(Arc::new), - peer_certs: conn.peer_certificates().map(|c| c.into_owned()).map(Arc::new), + peer_address: conn.endpoint().ok().map(Arc::new), + peer_certs: conn.certificates().map(|c| c.into_owned()).map(Arc::new), } } } diff --git a/core/lib/src/rocket.rs b/core/lib/src/rocket.rs index 40570e7b45..6b4ec29f23 100644 --- a/core/lib/src/rocket.rs +++ b/core/lib/src/rocket.rs @@ -1,14 +1,17 @@ use std::fmt; use std::ops::{Deref, DerefMut}; +use std::sync::Arc; +use std::time::Duration; +use tokio::io::{AsyncRead, AsyncWrite}; use yansi::Paint; use either::Either; use figment::{Figment, Provider}; -use crate::{Catcher, Config, Route, Shutdown, sentinel, shield::Shield}; -use crate::listener::{Endpoint, Bindable, DefaultListener}; +use crate::shutdown::{Stages, Shutdown}; +use crate::{sentinel, shield::Shield, Catcher, Config, Route}; +use crate::listener::{Bindable, DefaultListener, Endpoint, Listener}; use crate::router::Router; -use crate::util::TripWire; use crate::fairing::{Fairing, Fairings}; use crate::phase::{Phase, Build, Building, Ignite, Igniting, Orbit, Orbiting}; use crate::phase::{Stateful, StateRef, State}; @@ -575,11 +578,11 @@ impl Rocket { // Ignite the rocket. let rocket: Rocket = Rocket(Igniting { - router, config, - shutdown: Shutdown(TripWire::new()), + shutdown: Stages::new(), figment: self.0.figment, fairings: self.0.fairings, state: self.0.state, + router, config, }); // Query the sentinels, abort if requested. @@ -630,7 +633,7 @@ impl Rocket { /// A completed graceful shutdown resolves the future returned by /// [`Rocket::launch()`]. If [`Shutdown::notify()`] is called _before_ an /// instance is launched, it will be immediately shutdown after liftoff. See - /// [`Shutdown`] and [`config::Shutdown`](crate::config::Shutdown) for + /// [`Shutdown`] and [`ShutdownConfig`](crate::config::ShutdownConfig) for /// details on graceful shutdown. /// /// # Example @@ -657,7 +660,7 @@ impl Rocket { /// } /// ``` pub fn shutdown(&self) -> Shutdown { - self.shutdown.clone() + self.shutdown.start.clone() } pub(crate) fn into_orbit(self, address: Endpoint) -> Rocket { @@ -687,13 +690,90 @@ impl Rocket { }) } - async fn _launch_on(self, bindable: B) -> Result, Error> { + async fn _launch_on(self, bindable: B) -> Result, Error> + where ::Connection: AsyncRead + AsyncWrite + { let listener = bindable.bind().await.map_err(|e| ErrorKind::Bind(Box::new(e)))?; - self.serve(listener).await + let rocket = Arc::new(self.into_orbit(listener.endpoint()?)); + rocket.shutdown.spawn_listener(&rocket.config.shutdown); + + if let Err(e) = tokio::spawn(Rocket::liftoff(rocket.clone())).await { + let rocket = rocket.try_wait_shutdown().await; + return Err(ErrorKind::Liftoff(rocket, Box::new(e)).into()); + } + + #[cfg(not(feature = "http3"))] + rocket.clone().serve(listener).await?; + + #[cfg(feature = "http3")] { + use crate::listener::quic::QuicBindable; + + let endpoint = rocket.endpoint(); + if let (Some(address), Some(tls)) = (endpoint.tcp(), endpoint.tls_config()) { + let quic_bindable = QuicBindable { address, tls: tls.clone() }; + let http3 = tokio::task::spawn(rocket.clone().serve3(quic_bindable.bind().await?)); + let http12 = tokio::task::spawn(rocket.clone().serve(listener)); + let (r1, r2) = tokio::join!(http12, http3); + r1.map_err(|e| ErrorKind::Liftoff(Err(rocket.clone()), Box::new(e)))??; + r2.map_err(|e| ErrorKind::Liftoff(Err(rocket.clone()), Box::new(e)))??; + } else { + warn!("HTTP/3 feature is enabled, but listener is not TCP/TLS."); + warn_!("HTTP/3 server cannot be started."); + rocket.clone().serve(listener).await?; + } + } + + Ok(rocket.try_wait_shutdown().await.map_err(ErrorKind::Shutdown)?) } } impl Rocket { + /// Rocket wraps all connections in a `CancellableIo` struct, an internal + /// structure that gracefully closes I/O when it receives a signal. That + /// signal is the `shutdown` future. When the future resolves, + /// `CancellableIo` begins to terminate in grace, mercy, and finally force + /// close phases. Since all connections are wrapped in `CancellableIo`, this + /// eventually ends all I/O. + /// + /// At that point, unless a user spawned an infinite, stand-alone task that + /// isn't monitoring `Shutdown`, all tasks should resolve. This means that + /// all instances of the shared `Arc` are dropped and we can return + /// the owned instance of `Rocket`. + /// + /// Unfortunately, the Hyper `server` future resolves as soon as it has + /// finished processing requests without respect for ongoing responses. That + /// is, `server` resolves even when there are running tasks that are + /// generating a response. So, `server` resolving implies little to nothing + /// about the state of connections. As a result, we depend on the timing of + /// grace + mercy + some buffer to determine when all connections should be + /// closed, thus all tasks should be complete, thus all references to + /// `Arc` should be dropped and we can get back a unique reference. + async fn try_wait_shutdown(self: Arc) -> Result, Arc> { + info!("Shutting down. Waiting for shutdown fairings and pending I/O..."); + tokio::spawn({ + let rocket = self.clone(); + async move { rocket.fairings.handle_shutdown(&*rocket).await } + }); + + let config = &self.config.shutdown; + let wait = Duration::from_micros(250); + for period in [wait, config.grace(), wait, config.mercy(), wait * 4] { + if Arc::strong_count(&self) == 1 { break } + tokio::time::sleep(period).await; + } + + match Arc::try_unwrap(self) { + Ok(rocket) => { + info!("Graceful shutdown completed successfully."); + Ok(rocket.into_ignite()) + } + Err(rocket) => { + warn!("Shutdown failed: outstanding background I/O."); + Err(rocket) + } + } + } + pub(crate) fn into_ignite(self) -> Rocket { Rocket(Igniting { router: self.0.router, @@ -751,8 +831,8 @@ impl Rocket { /// /// A completed graceful shutdown resolves the future returned by /// [`Rocket::launch()`]. See [`Shutdown`] and - /// [`config::Shutdown`](crate::config::Shutdown) for details on graceful - /// shutdown. + /// [`ShutdownConfig`](crate::config::ShutdownConfig) for details on + /// graceful shutdown. /// /// # Example /// @@ -774,7 +854,7 @@ impl Rocket { /// } /// ``` pub fn shutdown(&self) -> Shutdown { - self.shutdown.clone() + self.shutdown.start.clone() } } @@ -941,7 +1021,9 @@ impl Rocket

{ } } - pub async fn launch_on(self, bindable: B) -> Result, Error> { + pub async fn launch_on(self, bindable: B) -> Result, Error> + where ::Connection: AsyncRead + AsyncWrite + { match self.0.into_state() { State::Build(s) => Rocket::from(s).ignite().await?._launch_on(bindable).await, State::Ignite(s) => Rocket::from(s)._launch_on(bindable).await, diff --git a/core/lib/src/router/collider.rs b/core/lib/src/router/collider.rs index acf4d7c937..d6ada473bf 100644 --- a/core/lib/src/router/collider.rs +++ b/core/lib/src/router/collider.rs @@ -213,8 +213,8 @@ mod tests { use std::str::FromStr; use super::*; - use crate::route::{Route, dummy_handler}; - use crate::http::{Method, Method::*, MediaType}; + use crate::route::dummy_handler; + use crate::http::{Method, Method::*}; fn dummy_route(ranked: bool, method: impl Into>, uri: &'static str) -> Route { let method = method.into().unwrap_or(Get); diff --git a/core/lib/src/router/router.rs b/core/lib/src/router/router.rs index 5617f4fbcd..a9fb4600b0 100644 --- a/core/lib/src/router/router.rs +++ b/core/lib/src/router/router.rs @@ -102,7 +102,7 @@ mod test { use crate::route::dummy_handler; use crate::local::blocking::Client; - use crate::http::{Method, Method::*, uri::Origin}; + use crate::http::{Method::*, uri::Origin}; impl Router { fn has_collisions(&self) -> bool { diff --git a/core/lib/src/server.rs b/core/lib/src/server.rs index 3fbe2ae702..7be2634a7f 100644 --- a/core/lib/src/server.rs +++ b/core/lib/src/server.rs @@ -6,33 +6,35 @@ use std::time::Duration; use hyper::service::service_fn; use hyper_util::rt::{TokioExecutor, TokioIo, TokioTimer}; use hyper_util::server::conn::auto::Builder; -use futures::{Future, TryFutureExt, future::{select, Either::*}}; -use tokio::time::sleep; +use futures::{Future, TryFutureExt, future::Either::*}; +use tokio::io::{AsyncRead, AsyncWrite}; -use crate::{Request, Rocket, Orbit, Data, Ignite}; +use crate::{Orbit, Request, Rocket}; use crate::request::ConnectionMeta; use crate::erased::{ErasedRequest, ErasedResponse, ErasedIoHandler}; use crate::listener::{Listener, CancellableExt, BouncedExt}; -use crate::error::{Error, ErrorKind}; -use crate::data::IoStream; -use crate::util::ReaderStream; +use crate::error::log_server_error; +use crate::data::{IoStream, RawStream}; +use crate::util::{spawn_inspect, FutureExt, ReaderStream}; use crate::http::Status; +type Result = std::result::Result; + impl Rocket { - async fn service( + async fn service Into>>( self: Arc, - mut req: hyper::Request, + parts: http::request::Parts, + stream: T, + upgrade: Option, connection: ConnectionMeta, ) -> Result>, http::Error> { - let upgrade = hyper::upgrade::on(&mut req); - let (parts, incoming) = req.into_parts(); + let _http3_addr = self.endpoint().tls_config().and_then(|_| self.endpoint().tcp()); let request = ErasedRequest::new(self, parts, |rocket, parts| { Request::from_hyp(rocket, parts, connection).unwrap_or_else(|e| e) }); let mut response = request.into_response( - incoming, - |incoming| Data::from(incoming), + stream, |rocket, request, data| Box::pin(rocket.preprocess(request, data)), |token, rocket, request, data| Box::pin(async move { if !request.errors.is_empty() { @@ -46,7 +48,7 @@ impl Rocket { ).await; let io_handler = response.to_io_handler(Rocket::extract_io_handler); - if let Some(handler) = io_handler { + if let (Some(handler), Some(upgrade)) = (io_handler, upgrade) { let upgrade = upgrade.map_ok(IoStream::from).map_err(io::Error::other); tokio::task::spawn(io_handler_task(upgrade, handler)); } @@ -61,6 +63,12 @@ impl Rocket { builder = builder.header("Content-Length", size); } + // On HTTP3, add an Alt-Svc header. + #[cfg(feature = "http3")] + if let Some(addr) = _http3_addr { + builder = builder.header("alt-svc", format!("h3=\":{}\"", addr.port())); + } + let chunk_size = response.inner().body().max_chunk_size(); builder.body(ReaderStream::with_capacity(response, chunk_size)) } @@ -83,9 +91,10 @@ async fn io_handler_task(stream: S, mut handler: ErasedIoHandler) } } -impl Rocket { - pub(crate) async fn serve(self, listener: L) -> Result - where L: Listener + 'static +impl Rocket { + pub(crate) async fn serve(self: Arc, listener: L) -> Result<()> + where L: Listener + 'static, + L::Connection: AsyncRead + AsyncWrite { let mut builder = Builder::new(TokioExecutor::new()); let keep_alive = Duration::from_secs(self.config.keep_alive.into()); @@ -106,75 +115,69 @@ impl Rocket { } } - let listener = listener.bounced().cancellable(self.shutdown(), &self.config.shutdown); - let rocket = Arc::new(self.into_orbit(listener.socket_addr()?)); - let _ = tokio::spawn(Rocket::liftoff(rocket.clone())).await; - - let (server, listener) = (Arc::new(builder), Arc::new(listener)); - while let Some(accept) = listener.accept_next().await { - let (listener, rocket, server) = (listener.clone(), rocket.clone(), server.clone()); - tokio::spawn({ - let result = async move { - let conn = TokioIo::new(listener.connect(accept).await?); - let meta = ConnectionMeta::from(conn.inner()); - let service = service_fn(|req| rocket.clone().service(req, meta.clone())); - let serve = pin!(server.serve_connection_with_upgrades(conn, service)); - match select(serve, rocket.shutdown()).await { - Left((result, _)) => result, - Right((_, mut conn)) => { - conn.as_mut().graceful_shutdown(); - conn.await - } - } - }; - - result.inspect_err(crate::error::log_server_error) + let (listener, server) = (Arc::new(listener.bounced()), Arc::new(builder)); + while let Some(accept) = listener.accept().try_until(self.shutdown()).await? { + let (listener, rocket, server) = (listener.clone(), self.clone(), server.clone()); + spawn_inspect(|e| log_server_error(&**e), async move { + let conn = listener.connect(accept).io_try_until(rocket.shutdown()).await?; + let meta = ConnectionMeta::from(&conn); + let service = service_fn(|mut req| { + let upgrade = hyper::upgrade::on(&mut req); + let (parts, incoming) = req.into_parts(); + rocket.clone().service(parts, incoming, Some(upgrade), meta.clone()) + }); + + let io = TokioIo::new(conn.cancellable(rocket.shutdown.clone())); + let mut server = pin!(server.serve_connection_with_upgrades(io, service)); + match server.as_mut().or(rocket.shutdown()).await { + Left(result) => result, + Right(()) => { + server.as_mut().graceful_shutdown(); + server.await + }, + } }); } - // Rocket wraps all connections in a `CancellableIo` struct, an internal - // structure that gracefully closes I/O when it receives a signal. That - // signal is the `shutdown` future. When the future resolves, - // `CancellableIo` begins to terminate in grace, mercy, and finally - // force close phases. Since all connections are wrapped in - // `CancellableIo`, this eventually ends all I/O. - // - // At that point, unless a user spawned an infinite, stand-alone task - // that isn't monitoring `Shutdown`, all tasks should resolve. This - // means that all instances of the shared `Arc` are dropped and - // we can return the owned instance of `Rocket`. - // - // Unfortunately, the Hyper `server` future resolves as soon as it has - // finished processing requests without respect for ongoing responses. - // That is, `server` resolves even when there are running tasks that are - // generating a response. So, `server` resolving implies little to - // nothing about the state of connections. As a result, we depend on the - // timing of grace + mercy + some buffer to determine when all - // connections should be closed, thus all tasks should be complete, thus - // all references to `Arc` should be dropped and we can get back - // a unique reference. - info!("Shutting down. Waiting for shutdown fairings and pending I/O..."); - tokio::spawn({ - let rocket = rocket.clone(); - async move { rocket.fairings.handle_shutdown(&*rocket).await } - }); + Ok(()) + } +} - let config = &rocket.config.shutdown; - let wait = Duration::from_micros(250); - for period in [wait, config.grace(), wait, config.mercy(), wait * 4] { - if Arc::strong_count(&rocket) == 1 { break } - sleep(period).await; - } +#[cfg(feature = "http3")] +use crate::listener::quic::QuicListener; - match Arc::try_unwrap(rocket) { - Ok(rocket) => { - info!("Graceful shutdown completed successfully."); - Ok(rocket.into_ignite()) - } - Err(rocket) => { - warn!("Shutdown failed: outstanding background I/O."); - Err(Error::new(ErrorKind::Shutdown(rocket))) - } +#[cfg(feature = "http3")] +impl Rocket { + pub(crate) async fn serve3(self: Arc, listener: QuicListener) -> Result<()> { + let rocket = self.clone(); + let listener = Arc::new(listener.bounced()); + while let Some(accept) = listener.accept().try_until(rocket.shutdown()).await? { + let (listener, rocket) = (listener.clone(), rocket.clone()); + spawn_inspect(|e: &io::Error| log_server_error(e), async move { + let mut stream = listener.connect(accept).io_try_until(rocket.shutdown()).await?; + while let Some(mut conn) = stream.accept().io_try_until(rocket.shutdown()).await? { + let rocket = rocket.clone(); + spawn_inspect(|e: &io::Error| log_server_error(e), async move { + let meta = ConnectionMeta::from(&conn); + let rx = conn.rx.cancellable(rocket.shutdown.clone()); + let response = rocket.clone() + .service(conn.parts, rx, None, ConnectionMeta::from(meta)) + .map_err(io::Error::other) + .io_try_until(rocket.shutdown.mercy.clone()) + .await?; + + let grace = rocket.shutdown.grace.clone(); + match conn.tx.send_response(response).or(grace).await { + Left(result) => result, + Right(_) => Ok(conn.tx.cancel()), + } + }); + } + + Ok(()) + }); } + + Ok(()) } } diff --git a/core/lib/src/shield/shield.rs b/core/lib/src/shield/shield.rs index f3a3aeb241..1136e60cf9 100644 --- a/core/lib/src/shield/shield.rs +++ b/core/lib/src/shield/shield.rs @@ -1,7 +1,6 @@ use std::collections::HashMap; use std::sync::atomic::{AtomicBool, Ordering}; -use state::InitCell; use yansi::Paint; use crate::{Rocket, Request, Response, Orbit, Config}; @@ -68,11 +67,18 @@ use crate::shield::*; /// policy. pub struct Shield { /// Enabled policies where the key is the header name. - policies: HashMap<&'static UncasedStr, Box>, + policies: HashMap<&'static UncasedStr, Header<'static>>, /// Whether to enforce HSTS even though the user didn't enable it. force_hsts: AtomicBool, - /// Headers pre-rendered at liftoff from the configured policies. - rendered: InitCell>>, +} + +impl Clone for Shield { + fn clone(&self) -> Self { + Self { + policies: self.policies.clone(), + force_hsts: AtomicBool::from(self.force_hsts.load(Ordering::Acquire)), + } + } } impl Default for Shield { @@ -111,7 +117,6 @@ impl Shield { Shield { policies: HashMap::new(), force_hsts: AtomicBool::new(false), - rendered: InitCell::new(), } } @@ -129,8 +134,7 @@ impl Shield { /// let shield = Shield::new().enable(NoSniff::default()); /// ``` pub fn enable(mut self, policy: P) -> Self { - self.rendered = InitCell::new(); - self.policies.insert(P::NAME.into(), Box::new(policy)); + self.policies.insert(P::NAME.into(), policy.header()); self } @@ -145,7 +149,6 @@ impl Shield { /// let shield = Shield::default().disable::(); /// ``` pub fn disable(mut self) -> Self { - self.rendered = InitCell::new(); self.policies.remove(UncasedStr::new(P::NAME)); self } @@ -172,20 +175,6 @@ impl Shield { pub fn is_enabled(&self) -> bool { self.policies.contains_key(UncasedStr::new(P::NAME)) } - - fn headers(&self) -> &[Header<'static>] { - self.rendered.get_or_init(|| { - let mut headers: Vec<_> = self.policies.values() - .map(|p| p.header()) - .collect(); - - if self.force_hsts.load(Ordering::Acquire) { - headers.push(Policy::header(&Hsts::default())); - } - - headers - }) - } } #[crate::async_trait] @@ -206,10 +195,10 @@ impl Fairing for Shield { self.force_hsts.store(true, Ordering::Release); } - if !self.headers().is_empty() { + if !self.policies.is_empty() { info!("{}{}:", "🛡️ ".emoji(), "Shield".magenta()); - for header in self.headers() { + for header in self.policies.values() { info_!("{}: {}", header.name(), header.value().primary()); } @@ -224,7 +213,7 @@ impl Fairing for Shield { async fn on_response<'r>(&self, _: &'r Request<'_>, response: &mut Response<'r>) { // Set all of the headers in `self.policies` in `response` as long as // the header is not already in the response. - for header in self.headers() { + for header in self.policies.values() { if response.headers().contains(header.name()) { warn!("Shield: response contains a '{}' header.", header.name()); warn_!("Refusing to overwrite existing header."); diff --git a/core/lib/src/config/shutdown.rs b/core/lib/src/shutdown/config.rs similarity index 84% rename from core/lib/src/config/shutdown.rs rename to core/lib/src/shutdown/config.rs index 2353a4fbae..fc4cb9de9e 100644 --- a/core/lib/src/config/shutdown.rs +++ b/core/lib/src/shutdown/config.rs @@ -6,60 +6,7 @@ use std::collections::HashSet; use futures::stream::Stream; use serde::{Deserialize, Serialize}; -/// A Unix signal for triggering graceful shutdown. -/// -/// Each variant corresponds to a Unix process signal which can be used to -/// trigger a graceful shutdown. See [`Shutdown`] for details. -/// -/// ## (De)serialization -/// -/// A `Sig` variant serializes and deserializes as a lowercase string equal to -/// the name of the variant: `"alrm"` for [`Sig::Alrm`], `"chld"` for -/// [`Sig::Chld`], and so on. -#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] -#[serde(rename_all = "lowercase")] -#[cfg_attr(nightly, doc(cfg(unix)))] -pub enum Sig { - /// The `SIGALRM` Unix signal. - Alrm, - /// The `SIGCHLD` Unix signal. - Chld, - /// The `SIGHUP` Unix signal. - Hup, - /// The `SIGINT` Unix signal. - Int, - /// The `SIGIO` Unix signal. - Io, - /// The `SIGPIPE` Unix signal. - Pipe, - /// The `SIGQUIT` Unix signal. - Quit, - /// The `SIGTERM` Unix signal. - Term, - /// The `SIGUSR1` Unix signal. - Usr1, - /// The `SIGUSR2` Unix signal. - Usr2 -} - -impl fmt::Display for Sig { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let s = match self { - Sig::Alrm => "SIGALRM", - Sig::Chld => "SIGCHLD", - Sig::Hup => "SIGHUP", - Sig::Int => "SIGINT", - Sig::Io => "SIGIO", - Sig::Pipe => "SIGPIPE", - Sig::Quit => "SIGQUIT", - Sig::Term => "SIGTERM", - Sig::Usr1 => "SIGUSR1", - Sig::Usr2 => "SIGUSR2", - }; - - s.fmt(f) - } -} +use crate::shutdown::Sig; /// Graceful shutdown configuration. /// @@ -94,11 +41,13 @@ impl fmt::Display for Sig { /// /// Once a shutdown is triggered, Rocket stops accepting new connections and /// waits at most `grace` seconds before initiating connection shutdown. -/// Applications can `await` the [`Shutdown`](crate::Shutdown) future to detect +/// Applications can `await` the [`Shutdown`] future to detect /// a shutdown and cancel any server-initiated I/O, such as from [infinite /// responders](crate::response::stream#graceful-shutdown), to avoid abrupt I/O /// cancellation. /// +/// [`Shutdown`]: crate::Shutdown +/// /// # Mercy Period /// /// After the grace period has elapsed, Rocket initiates connection shutdown, @@ -125,7 +74,8 @@ impl fmt::Display for Sig { /// prevent _buggy_ code, such as an unintended infinite loop or unknown use of /// blocking I/O, from preventing shutdown. /// -/// This behavior can be disabled by setting [`Shutdown::force`] to `false`. +/// This behavior can be disabled by setting [`ShutdownConfig::force`] to +/// `false`. /// /// # Example /// @@ -169,13 +119,13 @@ impl fmt::Display for Sig { /// /// ```rust /// # use rocket::figment::{Figment, providers::{Format, Toml}}; -/// use rocket::config::{Config, Shutdown}; +/// use rocket::config::{Config, ShutdownConfig}; /// /// #[cfg(unix)] /// use rocket::config::Sig; /// /// let config = Config { -/// shutdown: Shutdown { +/// shutdown: ShutdownConfig { /// ctrlc: false, /// #[cfg(unix)] /// signals: { @@ -204,7 +154,7 @@ impl fmt::Display for Sig { /// } /// ``` #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)] -pub struct Shutdown { +pub struct ShutdownConfig { /// Whether `ctrl-c` (`SIGINT`) initiates a server shutdown. /// /// **default: `true`** @@ -245,9 +195,9 @@ pub struct Shutdown { /// _always_ be done using a public constructor or update syntax: /// /// ```rust - /// use rocket::config::Shutdown; + /// use rocket::config::ShutdownConfig; /// - /// let config = Shutdown { + /// let config = ShutdownConfig { /// grace: 5, /// mercy: 10, /// ..Default::default() @@ -258,7 +208,7 @@ pub struct Shutdown { pub __non_exhaustive: (), } -impl fmt::Display for Shutdown { +impl fmt::Display for ShutdownConfig { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "ctrlc = {}, force = {}, ", self.ctrlc, self.force)?; @@ -276,9 +226,9 @@ impl fmt::Display for Shutdown { } } -impl Default for Shutdown { +impl Default for ShutdownConfig { fn default() -> Self { - Shutdown { + ShutdownConfig { ctrlc: true, #[cfg(unix)] signals: { let mut set = HashSet::new(); set.insert(Sig::Term); set }, @@ -290,7 +240,7 @@ impl Default for Shutdown { } } -impl Shutdown { +impl ShutdownConfig { pub(crate) fn grace(&self) -> Duration { Duration::from_secs(self.grace as u64) } diff --git a/core/lib/src/shutdown.rs b/core/lib/src/shutdown/handle.rs similarity index 57% rename from core/lib/src/shutdown.rs rename to core/lib/src/shutdown/handle.rs index 43a667af0a..63aff6d2b4 100644 --- a/core/lib/src/shutdown.rs +++ b/core/lib/src/shutdown/handle.rs @@ -2,22 +2,22 @@ use std::future::Future; use std::task::{Context, Poll}; use std::pin::Pin; -use futures::FutureExt; +use futures::{FutureExt, StreamExt}; +use crate::shutdown::{ShutdownConfig, TripWire}; use crate::request::{FromRequest, Outcome, Request}; -use crate::util::TripWire; /// A request guard and future for graceful shutdown. /// /// A server shutdown is manually requested by calling [`Shutdown::notify()`] -/// or, if enabled, through [automatic triggers] like `Ctrl-C`. Rocket will stop accepting new -/// requests, finish handling any pending requests, wait a grace period before -/// cancelling any outstanding I/O, and return `Ok()` to the caller of -/// [`Rocket::launch()`]. Graceful shutdown is configured via -/// [`config::Shutdown`](crate::config::Shutdown). +/// or, if enabled, through [automatic triggers] like `Ctrl-C`. Rocket will stop +/// accepting new requests, finish handling any pending requests, wait a grace +/// period before cancelling any outstanding I/O, and return `Ok()` to the +/// caller of [`Rocket::launch()`]. Graceful shutdown is configured via +/// [`ShutdownConfig`](crate::config::ShutdownConfig). /// /// [`Rocket::launch()`]: crate::Rocket::launch() -/// [automatic triggers]: crate::config::Shutdown#triggers +/// [automatic triggers]: crate::shutdown::Shutdown#triggers /// /// # Detecting Shutdown /// @@ -65,9 +65,24 @@ use crate::util::TripWire; /// ``` #[derive(Debug, Clone)] #[must_use = "`Shutdown` does nothing unless polled or `notify`ed"] -pub struct Shutdown(pub(crate) TripWire); +pub struct Shutdown { + wire: TripWire, +} + +#[derive(Debug, Clone)] +pub struct Stages { + pub start: Shutdown, + pub grace: Shutdown, + pub mercy: Shutdown, +} impl Shutdown { + fn new() -> Self { + Shutdown { + wire: TripWire::new(), + } + } + /// Notify the application to shut down gracefully. /// /// This function returns immediately; pending requests will continue to run @@ -85,9 +100,24 @@ impl Shutdown { /// "Shutting down..." /// } /// ``` - #[inline] - pub fn notify(self) { - self.0.trip(); + #[inline(always)] + pub fn notify(&self) { + self.wire.trip(); + } + + /// Returns `true` if `Shutdown::notify()` has already been called. + #[must_use] + #[inline(always)] + pub fn notified(&self) -> bool { + self.wire.tripped() + } +} + +impl Future for Shutdown { + type Output = (); + + fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + self.wire.poll_unpin(cx) } } @@ -101,11 +131,41 @@ impl<'r> FromRequest<'r> for Shutdown { } } -impl Future for Shutdown { - type Output = (); +impl Stages { + pub fn new() -> Self { + Stages { + start: Shutdown::new(), + grace: Shutdown::new(), + mercy: Shutdown::new(), + } + } - fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - self.0.poll_unpin(cx) + pub(crate) fn spawn_listener(&self, config: &ShutdownConfig) { + use futures::stream; + use futures::future::{select, Either}; + + let mut signal = match config.signal_stream() { + Some(stream) => Either::Left(stream.chain(stream::pending())), + None => Either::Right(stream::pending()), + }; + + let start = self.start.clone(); + let (grace, grace_duration) = (self.grace.clone(), config.grace()); + let (mercy, mercy_duration) = (self.mercy.clone(), config.mercy()); + tokio::spawn(async move { + if let Either::Left((sig, start)) = select(signal.next(), start).await { + warn!("Received {}. Shutdown started.", sig.unwrap()); + start.notify(); + } + + tokio::time::sleep(grace_duration).await; + warn!("Shutdown grace period elapsed. Shutting down I/O."); + grace.notify(); + + tokio::time::sleep(mercy_duration).await; + warn!("Mercy period elapsed. Terminating I/O."); + mercy.notify(); + }); } } diff --git a/core/lib/src/shutdown/mod.rs b/core/lib/src/shutdown/mod.rs new file mode 100644 index 0000000000..d68fddf37a --- /dev/null +++ b/core/lib/src/shutdown/mod.rs @@ -0,0 +1,13 @@ +//! Shutdown configuration and notification handle. + +mod tripwire; +mod handle; +mod sig; +mod config; + +pub(crate) use tripwire::TripWire; +pub(crate) use handle::Stages; + +pub use config::ShutdownConfig; +pub use handle::Shutdown; +pub use sig::Sig; diff --git a/core/lib/src/shutdown/sig.rs b/core/lib/src/shutdown/sig.rs new file mode 100644 index 0000000000..2f20b7a4d4 --- /dev/null +++ b/core/lib/src/shutdown/sig.rs @@ -0,0 +1,58 @@ +use std::fmt; + +use serde::{Deserialize, Serialize}; + +/// A Unix signal for triggering graceful shutdown. +/// +/// Each variant corresponds to a Unix process signal which can be used to +/// trigger a graceful shutdown. See [`Shutdown`](crate::Shutdown) for details. +/// +/// ## (De)serialization +/// +/// A `Sig` variant serializes and deserializes as a lowercase string equal to +/// the name of the variant: `"alrm"` for [`Sig::Alrm`], `"chld"` for +/// [`Sig::Chld`], and so on. +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[serde(rename_all = "lowercase")] +#[cfg_attr(nightly, doc(cfg(unix)))] +pub enum Sig { + /// The `SIGALRM` Unix signal. + Alrm, + /// The `SIGCHLD` Unix signal. + Chld, + /// The `SIGHUP` Unix signal. + Hup, + /// The `SIGINT` Unix signal. + Int, + /// The `SIGIO` Unix signal. + Io, + /// The `SIGPIPE` Unix signal. + Pipe, + /// The `SIGQUIT` Unix signal. + Quit, + /// The `SIGTERM` Unix signal. + Term, + /// The `SIGUSR1` Unix signal. + Usr1, + /// The `SIGUSR2` Unix signal. + Usr2 +} + +impl fmt::Display for Sig { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let s = match self { + Sig::Alrm => "SIGALRM", + Sig::Chld => "SIGCHLD", + Sig::Hup => "SIGHUP", + Sig::Int => "SIGINT", + Sig::Io => "SIGIO", + Sig::Pipe => "SIGPIPE", + Sig::Quit => "SIGQUIT", + Sig::Term => "SIGTERM", + Sig::Usr1 => "SIGUSR1", + Sig::Usr2 => "SIGUSR2", + }; + + s.fmt(f) + } +} diff --git a/core/lib/src/util/tripwire.rs b/core/lib/src/shutdown/tripwire.rs similarity index 77% rename from core/lib/src/util/tripwire.rs rename to core/lib/src/shutdown/tripwire.rs index c4d649bf4a..47e3b88858 100644 --- a/core/lib/src/util/tripwire.rs +++ b/core/lib/src/shutdown/tripwire.rs @@ -3,6 +3,9 @@ use std::{ops::Deref, pin::Pin, future::Future}; use std::task::{Context, Poll}; use std::sync::{Arc, atomic::{AtomicBool, Ordering}}; +use futures::future::FusedFuture; +use pin_project_lite::pin_project; +use tokio::sync::futures::Notified; use tokio::sync::Notify; #[doc(hidden)] @@ -15,7 +18,7 @@ pub struct State { pub struct TripWire { state: Arc, // `Notified` is `!Unpin`. Even if we could name it, we'd need to pin it. - event: Option + Send + Sync>>>, + event: Option>>>, } impl Deref for TripWire { @@ -35,6 +38,13 @@ impl Clone for TripWire { } } +impl Drop for TripWire { + fn drop(&mut self) { + // SAFETY: Ensure we drop the self-reference before `self`. + self.event = None; + } +} + impl fmt::Debug for TripWire { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("TripWire") @@ -47,35 +57,20 @@ impl Future for TripWire { type Output = (); fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { - if self.tripped.load(Ordering::Acquire) { + if self.tripped() { self.event = None; return Poll::Ready(()); } if self.event.is_none() { - let state = self.state.clone(); - self.event = Some(Box::pin(async move { - let notified = state.notify.notified(); - notified.await - })); + let notified = self.state.notify.notified(); + self.event = Some(Box::pin(unsafe { std::mem::transmute(notified) })); } if let Some(ref mut event) = self.event { - if event.as_mut().poll(cx).is_ready() { - // We need to call `trip()` to avoid a race condition where: - // 1) many trip wires have seen !self.tripped but have not - // polled for `self.event` yet, so are not subscribed - // 2) trip() is called, adding a permit to `event` - // 3) some trip wires poll `event` for the first time - // 4) one of those wins, returns `Ready()` - // 5) the rest return pending - // - // Without this `self.trip()` those will never be awoken. With - // the call to self.trip(), those that made it to poll() in 3) - // will be awoken by `notify_waiters()`. For those the didn't, - // one will be awoken by `notify_one()`, which will in-turn call - // self.trip(), awaking more until there are no more to awake. - self.trip(); + // The order here is important! We need to know: + // !self.tripped() => not notified == notified => self.tripped() + if event.as_mut().poll(cx).is_ready() || self.tripped() { self.event = None; return Poll::Ready(()); } @@ -85,6 +80,12 @@ impl Future for TripWire { } } +impl FusedFuture for TripWire { + fn is_terminated(&self) -> bool { + self.tripped() + } +} + impl TripWire { pub fn new() -> Self { TripWire { @@ -99,7 +100,6 @@ impl TripWire { pub fn trip(&self) { self.tripped.store(true, Ordering::Release); self.notify.notify_waiters(); - self.notify.notify_one(); } #[inline(always)] diff --git a/core/lib/src/tls/config.rs b/core/lib/src/tls/config.rs index 3131e16d5c..533b5793af 100644 --- a/core/lib/src/tls/config.rs +++ b/core/lib/src/tls/config.rs @@ -427,7 +427,7 @@ impl TlsConfig { } pub fn validate(&self) -> Result<(), crate::tls::Error> { - self.acceptor().map(|_| ()) + self.server_config().map(|_| ()) } } diff --git a/core/lib/src/util/mod.rs b/core/lib/src/util/mod.rs index d3055f36ce..ed591453c9 100644 --- a/core/lib/src/util/mod.rs +++ b/core/lib/src/util/mod.rs @@ -1,5 +1,4 @@ mod chain; -mod tripwire; mod reader_stream; mod join; @@ -7,6 +6,49 @@ mod join; pub mod unix; pub use chain::Chain; -pub use tripwire::TripWire; pub use reader_stream::ReaderStream; pub use join::join; + +#[track_caller] +pub fn spawn_inspect(or: F, future: Fut) + where F: FnOnce(&E) + Send + Sync + 'static, + E: Send + Sync + 'static, + Fut: std::future::Future> + Send + 'static, +{ + use futures::TryFutureExt; + tokio::spawn(future.inspect_err(or)); +} + +use std::io; +use std::pin::pin; +use std::future::Future; +use futures::future::{select, Either}; + +pub trait FutureExt: Future + Sized { + async fn or(self, other: B) -> Either { + match futures::future::select(pin!(self), pin!(other)).await { + Either::Left((v, _)) => Either::Left(v), + Either::Right((v, _)) => Either::Right(v), + } + } + + async fn try_until(self, trigger: K) -> Result, E> + where Self: Future> + { + match select(pin!(self), pin!(trigger)).await { + Either::Left((v, _)) => Ok(Some(v?)), + Either::Right((_, _)) => Ok(None), + } + } + + async fn io_try_until(self, trigger: K) -> std::io::Result + where Self: Future> + { + match select(pin!(self), pin!(trigger)).await { + Either::Left((v, _)) => v, + Either::Right((_, _)) => Err(io::Error::other("I/O terminated")), + } + } +} + +impl FutureExt for F { } diff --git a/docs/guide/10-configuration.md b/docs/guide/10-configuration.md index 3865b94bbc..e86d897082 100644 --- a/docs/guide/10-configuration.md +++ b/docs/guide/10-configuration.md @@ -21,24 +21,24 @@ is configured with. This means that no matter which configuration provider Rocket is asked to use, it must be able to read the following configuration values: -| key | kind | description | debug/release default | -|----------------------|-------------------|-------------------------------------------------|-------------------------| -| `address` | `IpAddr` | IP address to serve on. | `127.0.0.1` | -| `port` | `u16` | Port to serve on. | `8000` | -| `workers`* | `usize` | Number of threads to use for executing futures. | cpu core count | -| `max_blocking`* | `usize` | Limit on threads to start for blocking tasks. | `512` | -| `ident` | `string`, `false` | If and how to identify via the `Server` header. | `"Rocket"` | -| `ip_header` | `string`, `false` | IP header to inspect to get [client's real IP]. | `"X-Real-IP"` | -| `proxy_proto_header` | `string`, `false` | Header identifying [client to proxy protocol]. | `None` | -| `keep_alive` | `u32` | Keep-alive timeout seconds; disabled when `0`. | `5` | -| `log_level` | [`LogLevel`] | Max level to log. (off/normal/debug/critical) | `normal`/`critical` | -| `cli_colors` | [`CliColors`] | Whether to use colors and emoji when logging. | `"auto"` | -| `secret_key` | [`SecretKey`] | Secret key for signing and encrypting values. | `None` | -| `tls` | [`TlsConfig`] | TLS configuration, if any. | `None` | -| `limits` | [`Limits`] | Streaming read size limits. | [`Limits::default()`] | -| `limits.$name` | `&str`/`uint` | Read limit for `$name`. | form = "32KiB" | -| `ctrlc` | `bool` | Whether `ctrl-c` initiates a server shutdown. | `true` | -| `shutdown`* | [`Shutdown`] | Graceful shutdown configuration. | [`Shutdown::default()`] | +| key | kind | description | debug/release default | +|----------------------|--------------------|-------------------------------------------------|-------------------------------| +| `address` | `IpAddr` | IP address to serve on. | `127.0.0.1` | +| `port` | `u16` | Port to serve on. | `8000` | +| `workers`* | `usize` | Number of threads to use for executing futures. | cpu core count | +| `max_blocking`* | `usize` | Limit on threads to start for blocking tasks. | `512` | +| `ident` | `string`, `false` | If and how to identify via the `Server` header. | `"Rocket"` | +| `ip_header` | `string`, `false` | IP header to inspect to get [client's real IP]. | `"X-Real-IP"` | +| `proxy_proto_header` | `string`, `false` | Header identifying [client to proxy protocol]. | `None` | +| `keep_alive` | `u32` | Keep-alive timeout seconds; disabled when `0`. | `5` | +| `log_level` | [`LogLevel`] | Max level to log. (off/normal/debug/critical) | `normal`/`critical` | +| `cli_colors` | [`CliColors`] | Whether to use colors and emoji when logging. | `"auto"` | +| `secret_key` | [`SecretKey`] | Secret key for signing and encrypting values. | `None` | +| `tls` | [`TlsConfig`] | TLS configuration, if any. | `None` | +| `limits` | [`Limits`] | Streaming read size limits. | [`Limits::default()`] | +| `limits.$name` | `&str`/`uint` | Read limit for `$name`. | form = "32KiB" | +| `ctrlc` | `bool` | Whether `ctrl-c` initiates a server shutdown. | `true` | +| `shutdown`* | [`ShutdownConfig`] | Graceful shutdown configuration. | [`ShutdownConfig::default()`] | * Note: the `workers`, `max_blocking`, and `shutdown.force` configuration @@ -77,8 +77,8 @@ profile supplant any values with the same name in any profile. [`SecretKey`]: @api/master/rocket/config/struct.SecretKey.html [`CliColors`]: @api/master/rocket/config/enum.CliColors.html [`TlsConfig`]: @api/master/rocket/tls/struct.TlsConfig.html -[`Shutdown`]: @api/master/rocket/config/struct.Shutdown.html -[`Shutdown::default()`]: @api/master/rocket/config/struct.Shutdown.html#fields +[`ShutdownConfig`]: @api/master/rocket/shutdown/struct.ShutdownConfig.html +[`ShutdownConfig::default()`]: @api/master/rocket/shutdown/struct.ShutdownConfig.html#fields ## Default Provider diff --git a/examples/tls/Cargo.toml b/examples/tls/Cargo.toml index b37758ffdd..344ed7b491 100644 --- a/examples/tls/Cargo.toml +++ b/examples/tls/Cargo.toml @@ -6,5 +6,5 @@ edition = "2021" publish = false [dependencies] -rocket = { path = "../../core/lib", features = ["tls", "mtls", "secrets"] } +rocket = { path = "../../core/lib", features = ["tls", "mtls", "secrets", "http3"] } yansi = "1.0.1" diff --git a/examples/tls/Rocket.toml b/examples/tls/Rocket.toml index b7a538f96f..8157e5f6fe 100644 --- a/examples/tls/Rocket.toml +++ b/examples/tls/Rocket.toml @@ -5,6 +5,9 @@ # directly for your browser to show connections as secure. You should NEVER use # these certificate/key pairs. They are here for DEMONSTRATION PURPOSES ONLY. +[default] +port = 443 + [default.tls] certs = "private/rsa_sha256_cert.pem" key = "private/rsa_sha256_key.pem" diff --git a/examples/tls/src/main.rs b/examples/tls/src/main.rs index 4ce4254c24..31c8d09cd6 100644 --- a/examples/tls/src/main.rs +++ b/examples/tls/src/main.rs @@ -22,5 +22,5 @@ fn rocket() -> _ { // Run `./private/gen_certs.sh` to generate a CA and key pairs. rocket::build() .mount("/", routes![hello, mutual]) - .attach(redirector::Redirector::on(3000)) + // .attach(redirector::Redirector::on(3000)) }