Skip to content

Commit

Permalink
wip: advanced csrf token support
Browse files Browse the repository at this point in the history
  • Loading branch information
SergioBenitez committed May 4, 2024
1 parent 370287c commit 90daaad
Show file tree
Hide file tree
Showing 16 changed files with 840 additions and 14 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,6 @@ members = [
"contrib/sync_db_pools/lib/",
"contrib/dyn_templates/",
"contrib/ws/",
"contrib/csrf/",
"docs/tests",
]
30 changes: 30 additions & 0 deletions contrib/csrf/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
[package]
name = "rocket_csrf"
version = "0.1.0"
authors = ["Sergio Benitez <[email protected]>"]
description = "CSRF protection for Rocket."
documentation = "https://api.rocket.rs/master/rocket_csrf/"
homepage = "https://rocket.rs"
repository = "https://github.com/rwf2/Rocket/tree/master/contrib/csrf"
readme = "README.md"
keywords = ["rocket", "web", "framework", "csrf", "security"]
license = "MIT OR Apache-2.0"
edition = "2021"
rust-version = "1.75"

[dependencies]
rand = { version = "0.8.5", features = ["min_const_gen"] }
arc-swap = "1.7"
blake3 = { version = "1.5.1", features = ["serde"] }
base64 = "0.22"
zerocopy = { version = "=0.8.0-alpha.7", features = ["derive"] }
multer = { version = "3.0.0", features = ["tokio-io"] }

[dependencies.rocket]
version = "0.6.0-dev"
path = "../../core/lib"
default-features = false
features = ["secrets"]

[package.metadata.docs.rs]
all-features = true
12 changes: 12 additions & 0 deletions contrib/csrf/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
# `csrf` [![ci.svg]][ci] [![crates.io]][crate] [![docs.svg]][crate docs]

[crates.io]: https://img.shields.io/crates/v/rocket_csrf.svg
[crate]: https://crates.io/crates/rocket_csrf
[docs.svg]: https://img.shields.io/badge/web-master-red.svg?style=flat&label=docs&colorB=d33847
[crate docs]: https://api.rocket.rs/master/rocket_csrf
[ci.svg]: https://github.com/rwf2/Rocket/workflocsrf/CI/badge.svg
[ci]: https://github.com/rwf2/Rocket/actions

CSRF protection for Rocket.

See the [crate docs] for full details.
47 changes: 47 additions & 0 deletions contrib/csrf/src/config.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
use std::time::Duration;

use rocket::serde::{Deserialize, Serialize};

#[derive(Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(crate = "rocket::serde")]
pub struct Config {
pub enable: bool,
pub rotate: Rotate,
}

#[derive(Debug, Copy, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(crate = "rocket::serde")]
pub struct Rotate {
pub period: u8,
pub window: u8,
}

impl Default for Config {
fn default() -> Self {
Self { enable: true, rotate: Rotate::default() }
}
}

impl Default for Rotate {
fn default() -> Self {
Self {
period: 24,
window: 6,
}
}
}

impl Rotate {
pub const fn period(&self) -> Duration {
Duration::from_secs(self.period as u64 * 3600)
}

pub const fn window(&self) -> Duration {
Duration::from_secs(self.window as u64 * 3600)
}

pub const fn epoch(&self) -> Duration {
let wait = self.period.saturating_sub(self.window);
Duration::from_secs(wait as u64 * 3600)
}
}
129 changes: 129 additions & 0 deletions contrib/csrf/src/fairing.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
use rocket::form::Form;
use rocket::fairing::{AdHoc, Fairing, Info, Kind};
use rocket::figment::providers::Serialized;
use rocket::futures::Race;
use rocket::{Data, Orbit, Request, Rocket};
use rocket::tokio::{spawn, time::sleep};
use rocket::yansi::Paint;

use crate::{Config, Session, Token, Tokenizer};

struct TokenizerFairing {
config: Config,
tokenizer: Tokenizer,
}

impl TokenizerFairing {
const FORM_FIELD: &'static str = "_authenticity_token";

const HEADER: &'static str = "X-CSRF-Token";

fn new(config: Config) -> Option<Self> {
Some(Self { config, tokenizer: Tokenizer::new() })
}
}

impl Tokenizer {
pub fn fairing() -> impl Fairing {
AdHoc::try_on_ignite("CSRF Protection Configuration", |rocket| async {
let config = rocket.figment()
.clone()
.join(Serialized::default("csrf", Config::default()))
.extract_inner::<Config>("csrf");

match config {
Ok(config) if config.enable => match TokenizerFairing::new(config) {
Some(fairing) => Ok(rocket.attach(fairing)),
None => {
error!("{}CSRF protection failed to initialize.", "🔐 ".mask());
Err(rocket)
}
},
Ok(_) => Ok(rocket),
Err(e) => {
let kind = rocket::error::ErrorKind::Config(e);
rocket::Error::from(kind).pretty_print();
Err(rocket)
},
}
})
}
}

#[rocket::async_trait]
impl Fairing for TokenizerFairing {
fn info(&self) -> Info {
Info {
name: "Tokenizer",
kind: Kind::Singleton | Kind::Liftoff | Kind::Request | Kind::Response
}
}

async fn on_liftoff(&self, rocket: &Rocket<Orbit>) {
let rotate = self.config.rotate;
info!("{}{}", "🔐 ".mask(), "CSRF Protection:".magenta());
info_!("status: {}", "enabled".green());
info_!("rotation: {}/{}", rotate.period, rotate.window);

let tokenizer = self.tokenizer.clone();
spawn(rocket.shutdown().race(async move {
loop {
sleep(rotate.epoch()).await;
tokenizer.rotate();
info!("{}{}", "🔐 ".mask(), "CSRF Protection: keys sliding.");

sleep(rotate.window()).await;
tokenizer.rotate();
info!("{}{}", "🔐 ".mask(), "CSRF Protection: keys rotated.");
}
}));
}

async fn on_request(&self, req: &mut Request<'_>, data: &mut Data<'_>) {
let session = Session::fetch(req);
let gen_token = self.tokenizer.form_token(session.id());
dbg!(&session, &gen_token, gen_token.to_string());

if !req.method().supports_payload() {
return;
}

let token = match req.content_type() {
Some(mime) if mime.is_form() => {
std::str::from_utf8(data.peek(192).await).ok()
.into_iter()
.flat_map(Form::values)
.find(|field| field.name == Self::FORM_FIELD)
.and_then(|field| field.value.parse::<Token>().ok())
},
// TODO: Fix _method resolution for form data in Rocket proper.
Some(mime) if mime.is_form_data() => {
let token = async {
let data = data.peek(512).await;
let boundary = mime.param("boundary")?;
let mut form = multer::Multipart::with_reader(data, boundary);
while let Ok(Some(field)) = form.next_field().await {
if field.name() == Some(Self::FORM_FIELD) {
return field.text().await.ok()?.parse().ok();
}
}

None
};

token.await
},
_ => req.headers().get_one(Self::HEADER).and_then(|s| s.parse().ok()),
};

// FIXME: Check token context matches the expectation too.
if !dbg!(token.as_ref()).map_or(false, |token| self.tokenizer.validate(token, &session)) {
match token {
Some(_) => error_!("{}{}", "🔐 ".mask(), "CSRF Protection: invalid token."),
None => error_!("{}{}", "🔐 ".mask(), "CSRF Protection: missing token."),
}

req.set_uri(uri!("/__rocket/csrf/denied"));
}
}
}
58 changes: 58 additions & 0 deletions contrib/csrf/src/key.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
use rand::distributions::{Distribution, Standard};

// TODO: Make this thread-safe (ArcSwap internally). Use it to generate a
// rotating session id. Probably remove generation from this?

#[derive(Debug, Clone)]
pub struct Rotatable<T> {
generation: u32,
current: T,
previous: Option<T>,
}

impl<T> Rotatable<T> {
#[inline(always)]
pub fn new(value: T) -> Self {
Self {
generation: 0,
current: value,
previous: None
}
}

#[inline(always)]
pub fn rotate(&mut self, new: T) {
let old = std::mem::replace(&mut self.current, new);
self.previous.replace(old);
self.generation = self.generation.wrapping_add(1);
}

pub fn generation(&self) -> u32 {
self.generation
}

pub fn iter(&self) -> impl Iterator<Item = &T> {
std::iter::once(&self.current).chain(self.previous.as_ref())
}
}

impl<T> Rotatable<T>
where Standard: Distribution<T>
{
#[inline(always)]
pub fn generate() -> Self {
Self::new(rand::random())
}

#[inline(always)]
pub fn generate_and_rotate(&mut self) -> Result<(), ()> {
self.rotate(Self::generate().current);
Ok(())
}
}

impl<T> AsRef<T> for Rotatable<T> {
fn as_ref(&self) -> &T {
&self.current
}
}
Loading

0 comments on commit 90daaad

Please sign in to comment.