Skip to content

Commit

Permalink
split into files
Browse files Browse the repository at this point in the history
  • Loading branch information
aumetra committed Dec 15, 2023
1 parent 5da9e0a commit 8dd01dd
Show file tree
Hide file tree
Showing 5 changed files with 292 additions and 267 deletions.
52 changes: 52 additions & 0 deletions lib/csurf/src/future.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
use crate::{CsrfHandle, CSRF_COOKIE_NAME};
use cookie::{Cookie, SameSite};
use http::{header, HeaderValue, Response};
use pin_project_lite::pin_project;
use std::{
future::Future,
pin::Pin,
task::{self, ready, Poll},
};

pin_project! {
pub struct ResponseFuture<F> {
#[pin]
pub(crate) inner: F,
pub(crate) handle: CsrfHandle,
}
}

impl<F, E, ResBody> Future for ResponseFuture<F>
where
F: Future<Output = Result<Response<ResBody>, E>>,
{
type Output = Result<Response<ResBody>, E>;

fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
let this = self.project();

let mut response = ready!(this.inner.poll(cx))?;
let mut cookie = Cookie::build(CSRF_COOKIE_NAME)
.permanent()
.same_site(SameSite::Strict)
.secure(true)
.build();

let guard = this.handle.inner.lock().unwrap();
if let Some(ref set_data) = guard.set_data {
let value = format!("{}.{}", set_data.hash, set_data.message);
cookie.set_value(value);
} else {
cookie.make_removal();
}

let encoded_cookie = cookie.encoded().to_string();
let header_value = HeaderValue::from_str(&encoded_cookie).unwrap();

response
.headers_mut()
.append(header::SET_COOKIE, header_value);

Poll::Ready(Ok(response))
}
}
117 changes: 117 additions & 0 deletions lib/csurf/src/handle.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
use crate::{CsrfData, HashRef, Message, MessageRef, RANDOM_DATA_LEN};
use hex_simd::{AsOut, AsciiCase};
use rand::{distributions::Alphanumeric, Rng};
use std::{
fmt::Display,
sync::{Arc, Mutex},
};
use zeroize::{Zeroize, ZeroizeOnDrop};

pub struct Shared {
pub(crate) read_data: Option<CsrfData>,
pub(crate) set_data: Option<CsrfData>,
}

#[derive(Clone, Zeroize, ZeroizeOnDrop)]
pub struct CsrfHandle {
#[zeroize(skip)]
pub(crate) inner: Arc<Mutex<Shared>>,
pub(crate) key: [u8; blake3::KEY_LEN],
}

fn raw_verify(key: &[u8; blake3::KEY_LEN], hash: &HashRef, message: &MessageRef) -> bool {
let (hash, message) = (hash.as_ref(), message.as_ref());
if hash.len() / 2 != blake3::OUT_LEN {
return false;
}

let mut decoded_hash = [0_u8; blake3::OUT_LEN];
if hex_simd::decode(hash.as_bytes(), decoded_hash.as_mut().as_out()).is_err() {
return false;
}

let expected_hash = blake3::keyed_hash(key, message.as_bytes());

// The `PartialEq` implementation on `Hash` is constant-time
expected_hash == decoded_hash
}

impl CsrfHandle {
/// Keep the current signature and message inside the cookie
#[inline]
pub fn keep_cookie(&self) {
let mut guard = self.inner.lock().unwrap();
guard.set_data = guard.read_data.clone();
}

/// Create a signature and store it inside a cookie
///
/// **Important**: The data passed into this function should reference an *authenticated session*.
/// The use of the user ID (or something similarly static) is *discouraged*, use the session ID.
#[inline]
pub fn sign<SID>(&self, session_id: SID) -> Message
where
SID: AsRef<[u8]> + Display,
{
let random = rand::thread_rng()
.sample_iter(Alphanumeric)
.map(char::from)
.take(RANDOM_DATA_LEN)
.collect::<String>();

let message = format!("{session_id}!{random}");
let hash = blake3::keyed_hash(&self.key, message.as_bytes());
let hash = hex_simd::encode_to_string(hash.as_bytes(), AsciiCase::Lower);

let message: Message = message.into();
self.inner.lock().unwrap().set_data = Some(CsrfData {
hash: hash.into(),
message: message.clone(),
});

message
}

/// Verify the CSRF request
///
/// Simply pass in the message that was submitted by the client.
/// Internally, we will compare this to the
#[inline]
#[must_use]
pub fn verify(&self, message: &MessageRef) -> bool {
let guard = self.inner.lock().unwrap();
let Some(ref read_data) = guard.read_data else {
return false;
};

raw_verify(&self.key, &read_data.hash, &read_data.message)
&& raw_verify(&self.key, &read_data.hash, message)
}
}

#[cfg(feature = "axum")]
mod axum_impl {
use super::CsrfHandle;
use async_trait::async_trait;
use axum_core::extract::FromRequestParts;
use http::request::Parts;
use std::convert::Infallible;

#[async_trait]
impl<S> FromRequestParts<S> for CsrfHandle {
type Rejection = Infallible;

async fn from_request_parts(
parts: &mut Parts,
_state: &S,
) -> Result<Self, Self::Rejection> {
let handle = parts
.extensions
.get::<Self>()
.expect("Service not wrapped by CSRF middleware")
.clone();

Ok(handle)
}
}
}
23 changes: 23 additions & 0 deletions lib/csurf/src/layer.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
use crate::CsrfService;
use tower::Layer;
use zeroize::{Zeroize, ZeroizeOnDrop};

#[derive(Clone, Zeroize, ZeroizeOnDrop)]
pub struct CsrfLayer {
key: [u8; blake3::KEY_LEN],
}

impl CsrfLayer {
#[must_use]
pub fn new(key: [u8; blake3::KEY_LEN]) -> Self {
Self { key }
}
}

impl<S> Layer<S> for CsrfLayer {
type Service = CsrfService<S>;

fn layer(&self, inner: S) -> Self::Service {
CsrfService::new(inner, self.key)
}
}
Loading

0 comments on commit 8dd01dd

Please sign in to comment.