From a32edf59884d89664ed4837380ac6a0811437700 Mon Sep 17 00:00:00 2001 From: Aumetra Weisman Date: Sat, 4 Jan 2025 14:42:58 +0000 Subject: [PATCH] Less allocations in flash impl (#627) * restructure cursiv cookie handling * own cookie signing using blake3 * fix routes --- Cargo.lock | 6 +- Cargo.toml | 6 +- kitsune/src/http/router.rs | 64 +++++++++++--------- lib/cursiv/src/service.rs | 45 ++++++++------ lib/flashy/Cargo.toml | 5 ++ lib/flashy/src/lib.rs | 121 ++++++++++++++++++++++++++++++------- lib/flashy/tests/basic.rs | 3 +- 7 files changed, 172 insertions(+), 78 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 56c00d80e..7a81f224c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1196,7 +1196,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4ddef33a339a91ea89fb53151bd0a4689cfce27055c291dfa69945475d22c747" dependencies = [ "base64 0.22.1", - "hkdf", "hmac", "percent-encoding", "rand 0.8.5", @@ -2105,14 +2104,19 @@ name = "flashy" version = "0.0.1-pre.6" dependencies = [ "axum-core 0.5.0", + "blake3", "cookie", "futures-test", + "hex-simd", "http", "pin-project-lite", + "rand 0.8.5", "serde", "sonic-rs", + "subtle", "tower 0.5.2", "triomphe", + "zeroize", ] [[package]] diff --git a/Cargo.toml b/Cargo.toml index 0a56f87b3..5d8d584cd 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -113,11 +113,7 @@ color-eyre = "0.6.3" colored_json = "5.0.0" const_format = "0.2.34" const-oid = { version = "0.9.6", features = ["db"] } -cookie = { version = "0.18.1", features = [ - "key-expansion", - "percent-encode", - "signed", -] } +cookie = { version = "0.18.1", features = ["percent-encode"] } derive_builder = "0.20.2" derive_more = { version = "1.0.0", features = ["from"] } diesel = { version = "2.2.6", default-features = false, features = [ diff --git a/kitsune/src/http/router.rs b/kitsune/src/http/router.rs index 64063a111..4732aed76 100644 --- a/kitsune/src/http/router.rs +++ b/kitsune/src/http/router.rs @@ -16,14 +16,15 @@ use tower_http_digest::VerifyDigestLayer; use tower_stop_using_brave::StopUsingBraveLayer; use tower_x_clacks_overhead::XClacksOverheadLayer; +#[allow(clippy::too_many_lines)] pub fn create(state: Zustand, server_config: &server::Configuration) -> eyre::Result { let router = Router::new() .route( - "/confirm-account/:confirmation_token", + "/confirm-account/{confirmation_token}", routing::get(handler::confirm_account::get), ) - .route("/emojis/:id", routing::get(handler::custom_emojis::get)) - .route("/media/:id", routing::get(handler::media::get)) + .route("/emojis/{id}", routing::get(handler::custom_emojis::get)) + .route("/media/{id}", routing::get(handler::media::get)) .route( "/nodeinfo/2.1", routing::get(handler::nodeinfo::two_one::get), @@ -42,27 +43,30 @@ pub fn create(state: Zustand, server_config: &server::Configuration) -> eyre::Re .nest( "/posts", Router::new() - .route("/:id", routing::get(handler::posts::get)) - .route("/:id/activity", routing::get(handler::posts::activity::get)), + .route("/{id}", routing::get(handler::posts::get)) + .route( + "/{id}/activity", + routing::get(handler::posts::activity::get), + ), ) .nest( "/users", Router::new() - .route("/:user_id", routing::get(handler::users::get)) + .route("/{user_id}", routing::get(handler::users::get)) .route( - "/:user_id/followers", + "/{user_id}/followers", routing::get(handler::users::followers::get), ) .route( - "/:user_id/following", + "/{user_id}/following", routing::get(handler::users::following::get), ) .route( - "/:user_id/inbox", + "/{user_id}/inbox", routing::post(handler::users::inbox::post).layer(VerifyDigestLayer::default()), ) .route( - "/:user_id/outbox", + "/{user_id}/outbox", routing::get(handler::users::outbox::get), ), ) @@ -78,7 +82,7 @@ pub fn create(state: Zustand, server_config: &server::Configuration) -> eyre::Re routing::get(handler::well_known::webfinger::get), ), ) - .route("/public/*path", routing::get(handler::public::get)); + .route("/public/{*path}", routing::get(handler::public::get)); #[cfg(feature = "oidc")] let router = router.route("/oidc/callback", routing::get(handler::oidc::callback::get)); @@ -110,19 +114,19 @@ pub fn create(state: Zustand, server_config: &server::Configuration) -> eyre::Re "/accounts", Router::new() .route( - "/:id", + "/{id}", routing::get(handler::mastodon::api::v1::accounts::get), ) .route( - "/:id/follow", + "/{id}/follow", routing::post(handler::mastodon::api::v1::accounts::follow::post), ) .route( - "/:id/statuses", + "/{id}/statuses", routing::get(handler::mastodon::api::v1::accounts::statuses::get), ) .route( - "/:id/unfollow", + "/{id}/unfollow", routing::post(handler::mastodon::api::v1::accounts::unfollow::post), ) .route( @@ -164,13 +168,13 @@ pub fn create(state: Zustand, server_config: &server::Configuration) -> eyre::Re routing::get(handler::mastodon::api::v1::follow_requests::get), ) .route( - "/:id/authorize", + "/{id}/authorize", routing::post( handler::mastodon::api::v1::follow_requests::accept::post, ), ) .route( - "/:id/reject", + "/{id}/reject", routing::post( handler::mastodon::api::v1::follow_requests::reject::post, ), @@ -192,7 +196,7 @@ pub fn create(state: Zustand, server_config: &server::Configuration) -> eyre::Re ), ) .route( - "/:id", + "/{id}", routing::get(handler::mastodon::api::v1::media::get) .put(handler::mastodon::api::v1::media::put), ), @@ -205,11 +209,11 @@ pub fn create(state: Zustand, server_config: &server::Configuration) -> eyre::Re routing::get(handler::mastodon::api::v1::notifications::get), ) .route( - "/:id", + "/{id}", routing::get(handler::mastodon::api::v1::notifications::get_by_id), ) .route( - "/:id/dismiss", + "/{id}/dismiss", routing::post( handler::mastodon::api::v1::notifications::dismiss::post, ), @@ -229,49 +233,49 @@ pub fn create(state: Zustand, server_config: &server::Configuration) -> eyre::Re routing::post(handler::mastodon::api::v1::statuses::post), ) .route( - "/:id", + "/{id}", routing::delete(handler::mastodon::api::v1::statuses::delete) .get(handler::mastodon::api::v1::statuses::get) .put(handler::mastodon::api::v1::statuses::put), ) .route( - "/:id/context", + "/{id}/context", routing::get(handler::mastodon::api::v1::statuses::context::get), ) .route( - "/:id/favourite", + "/{id}/favourite", routing::post( handler::mastodon::api::v1::statuses::favourite::post, ), ) .route( - "/:id/favourited_by", + "/{id}/favourited_by", routing::get( handler::mastodon::api::v1::statuses::favourited_by::get, ), ) .route( - "/:id/reblog", + "/{id}/reblog", routing::post(handler::mastodon::api::v1::statuses::reblog::post), ) .route( - "/:id/reblogged_by", + "/{id}/reblogged_by", routing::get( handler::mastodon::api::v1::statuses::reblogged_by::get, ), ) .route( - "/:id/source", + "/{id}/source", routing::get(handler::mastodon::api::v1::statuses::source::get), ) .route( - "/:id/unfavourite", + "/{id}/unfavourite", routing::post( handler::mastodon::api::v1::statuses::unfavourite::post, ), ) .route( - "/:id/unreblog", + "/{id}/unreblog", routing::post(handler::mastodon::api::v1::statuses::unreblog::post), ), ) @@ -303,7 +307,7 @@ pub fn create(state: Zustand, server_config: &server::Configuration) -> eyre::Re ), ) .route( - "/:id", + "/{id}", routing::get(handler::mastodon::api::v1::media::get) .put(handler::mastodon::api::v1::media::put), ), diff --git a/lib/cursiv/src/service.rs b/lib/cursiv/src/service.rs index d18fb07e5..add85cb82 100644 --- a/lib/cursiv/src/service.rs +++ b/lib/cursiv/src/service.rs @@ -35,25 +35,34 @@ where } fn call(&mut self, mut req: Request) -> Self::Future { - let csrf_cookie = req - .headers() - .get_all(header::COOKIE) - .into_iter() - .filter_map(|value| value.to_str().ok()) // Filter out all the values that aren't valid UTF-8 - .flat_map(Cookie::split_parse_encoded) // Parse all the cookie headers and flatten the resulting iterator into a contiguous one - .flatten() // Call `.flatten()` to turn `Result` -> `Cookie`, ignoring all the errors - .find(|cookie| cookie.name() == CSRF_COOKIE_NAME); // Find the cookie with the name of our CSRF cookie + let read_data = { + let mut csrf_data = None; + 'outer: for header in req.headers().get_all(header::COOKIE) { + let Ok(value_str) = header.to_str() else { + continue; + }; - let read_data = if let Some(csrf_cookie) = csrf_cookie { - csrf_cookie - .value_trimmed() - .split_once('.') - .map(|(hash, message)| CsrfData { - hash: hash.into(), - message: message.into(), - }) - } else { - None + for cookie in Cookie::split_parse_encoded(value_str) { + let Ok(cookie) = cookie else { + continue; + }; + + if cookie.name() == CSRF_COOKIE_NAME { + let Some((hash, message)) = cookie.value_trimmed().split_once('.') else { + continue; + }; + + csrf_data = Some(CsrfData { + hash: hash.into(), + message: message.into(), + }); + + break 'outer; + } + } + } + + csrf_data }; let handle = CsrfHandle { diff --git a/lib/flashy/Cargo.toml b/lib/flashy/Cargo.toml index 621afb51a..d6d203351 100644 --- a/lib/flashy/Cargo.toml +++ b/lib/flashy/Cargo.toml @@ -7,14 +7,19 @@ license = "MIT OR Apache-2.0" [dependencies] axum-core = { workspace = true, optional = true } +blake3.workspace = true cookie.workspace = true futures-test.workspace = true +hex-simd.workspace = true http.workspace = true pin-project-lite.workspace = true +rand.workspace = true serde.workspace = true sonic-rs.workspace = true +subtle.workspace = true tower.workspace = true triomphe.workspace = true +zeroize.workspace = true [features] axum = ["dep:axum-core"] diff --git a/lib/flashy/src/lib.rs b/lib/flashy/src/lib.rs index 450094dac..161542b89 100644 --- a/lib/flashy/src/lib.rs +++ b/lib/flashy/src/lib.rs @@ -1,23 +1,59 @@ -use cookie::{Cookie, CookieJar, Expiration, SameSite}; +use cookie::{Cookie, Expiration, SameSite}; +use hex_simd::Out; use http::HeaderValue; use pin_project_lite::pin_project; use serde::{Deserialize, Serialize}; use std::{ future::Future, + ops::Deref, pin::Pin, - slice, + slice, str, sync::Mutex, task::{self, ready, Poll}, }; +use subtle::ConstantTimeEq; use tower::{Layer, Service}; use triomphe::Arc; - -pub use cookie::Key; +use zeroize::{Zeroize, ZeroizeOnDrop}; const COOKIE_NAME: &str = "FLASH_MESSAGES"; type Flash = (Level, String); +#[derive(Clone, Zeroize, ZeroizeOnDrop)] +pub struct Key([u8; blake3::KEY_LEN]); + +impl Key { + #[inline] + #[must_use] + pub fn new(inner: [u8; blake3::KEY_LEN]) -> Self { + Self(inner) + } + + #[inline] + #[must_use] + pub fn derive_from(data: &[u8]) -> Self { + const CONTEXT: &str = "FLASHY-SIGN_COOKIE-BLAKE3-V1"; + + Self::new(blake3::derive_key(CONTEXT, data)) + } + + #[inline] + #[must_use] + pub fn generate() -> Self { + Self::new(rand::random()) + } +} + +impl Deref for Key { + type Target = [u8; blake3::KEY_LEN]; + + #[inline] + fn deref(&self) -> &Self::Target { + &self.0 + } +} + #[derive(Clone, Copy, Debug, Deserialize, Eq, Ord, PartialEq, PartialOrd, Serialize)] pub enum Level { Debug, @@ -142,7 +178,15 @@ where let encoded_messages = { let guard = handle.0.lock().unwrap(); - sonic_rs::to_string(&guard.flashes).expect("failed to encode messages") + let serialized = + sonic_rs::to_string(&guard.flashes).expect("failed to encode messages"); + + let signed = sign_data(&key, &serialized); + #[allow(unsafe_code)] + // SAFETY: the future returns correctly encoded UTF-8 + let signed = unsafe { str::from_utf8_unchecked(&signed) }; + + format!("{signed}.{serialized}") }; let mut cookie = Cookie::new(COOKIE_NAME, encoded_messages); @@ -150,16 +194,10 @@ where cookie.set_secure(true); cookie.set_expires(Expiration::Session); - let mut jar = CookieJar::new(); - let mut signed_jar = jar.signed_mut(&key); - signed_jar.add(cookie); - - for cookie in jar.iter() { - let encoded = cookie.encoded().to_string(); - let value = HeaderValue::from_bytes(encoded.as_ref()).unwrap(); + let encoded = cookie.encoded().to_string(); + let value = HeaderValue::from_bytes(encoded.as_ref()).unwrap(); - resp.headers_mut().insert(http::header::SET_COOKIE, value); - } + resp.headers_mut().append(http::header::SET_COOKIE, value); return Poll::Ready(Ok(resp)); } @@ -170,6 +208,34 @@ where } } +#[inline] +fn sign_data(key: &Key, value: &str) -> [u8; blake3::OUT_LEN * 2] { + let hash = blake3::keyed_hash(key, value.as_bytes()); + + let mut out = [0; blake3::OUT_LEN * 2]; + let enc_slice = hex_simd::encode( + hash.as_bytes(), + Out::from_slice(&mut out), + hex_simd::AsciiCase::Lower, + ); + assert_eq!(enc_slice.len(), out.len()); + + out +} + +#[inline] +fn verify_data(key: &Key, mac: &str, value: &str) -> bool { + let mut out = [0; blake3::KEY_LEN]; + let Ok(decoded_mac) = hex_simd::decode(mac.as_ref(), Out::from_slice(&mut out)) else { + return false; + }; + + blake3::keyed_hash(key, value.as_bytes()) + .as_bytes() + .ct_eq(decoded_mac) + .into() +} + impl Service> for FlashService where S: Service, Response = http::Response>, @@ -185,8 +251,8 @@ where #[inline] fn call(&mut self, mut req: http::Request) -> Self::Future { - let mut jar = CookieJar::new(); - for header in req.headers().get_all(http::header::COOKIE) { + let mut flash_cookie = None; + 'outer: for header in req.headers().get_all(http::header::COOKIE) { let Ok(cookie_str) = header.to_str() else { continue; }; @@ -197,15 +263,26 @@ where continue; }; - jar.add_original(cookie); + if cookie.name() == COOKIE_NAME { + flash_cookie = Some(cookie); + break 'outer; + } } } - let signed_jar = jar.signed(&self.key); - let flashes = signed_jar - .get(COOKIE_NAME) - .and_then(|cookie| sonic_rs::from_str(cookie.value()).ok()) - .unwrap_or_default(); + let flashes = if let Some(flash_cookie) = flash_cookie { + if let Some((mac, value)) = flash_cookie.value().split_once('.') { + if verify_data(&self.key, mac, value) { + sonic_rs::from_str(value).unwrap() + } else { + Vec::new() + } + } else { + Vec::new() + } + } else { + Vec::new() + }; let read_flashes = IncomingFlashes(Arc::new(flashes)); let handle = FlashHandle(Arc::new(Mutex::new(HandleInner { diff --git a/lib/flashy/tests/basic.rs b/lib/flashy/tests/basic.rs index 2916b5ee9..e68d02dc6 100644 --- a/lib/flashy/tests/basic.rs +++ b/lib/flashy/tests/basic.rs @@ -1,5 +1,4 @@ -use cookie::Key; -use flashy::{FlashHandle, FlashLayer, IncomingFlashes, Level}; +use flashy::{FlashHandle, FlashLayer, IncomingFlashes, Key, Level}; use http::header::{COOKIE, SET_COOKIE}; use std::convert::Infallible; use tower::{Layer, ServiceExt};