diff --git a/Cargo.toml b/Cargo.toml index f1757d5..081696c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,6 +28,11 @@ name = "axum" path = "examples/axum.rs" required-features = ["axum"] +[[example]] +name = "rocket" +path = "examples/rocket.rs" +required-features = ["rocket"] + [lib] doctest = false @@ -43,6 +48,7 @@ futures-util = "0.3.28" actix-rt = { version = "2.10.0", optional = true } actix-web = { version = "4.9.0", optional = true } axum = { version = "0.7.5", optional = true } +rocket = { version = "0.5.0", optional = true } axum-extra = { version = "0.9.3", features = ["cookie"], optional = true } tower = { version = "0.5.0", optional = true } async-trait = "0.1.81" @@ -67,3 +73,4 @@ native-tls = ["reqwest/native-tls"] rustls-tls = ["reqwest/rustls-tls"] actix = ["dep:actix-rt", "dep:actix-web"] axum = ["dep:axum", "dep:axum-extra", "dep:tower"] +rocket = ["dep:rocket"] diff --git a/README.md b/README.md index c355bd0..bea27ee 100644 --- a/README.md +++ b/README.md @@ -111,6 +111,49 @@ async fn main() -> std::io::Result<()> { } ``` +### Protecting a rocket endpoint with Clerk.dev: + +With the `rocket` feature enabled: + +```rust +use clerk_rs::{ + clerk::Clerk, + validators::{ + jwks::MemoryCacheJwksProvider, + rocket::{ClerkGuard, ClerkGuardConfig}, + }, + ClerkConfiguration, +}; +use rocket::{ + get, launch, routes, + serde::{Deserialize, Serialize}, +}; + +#[derive(Serialize, Deserialize)] +struct Message { + content: String, +} + +#[get("/")] +fn index(jwt: ClerkGuard) -> &'static str { + "Hello world!" +} + +#[launch] +fn rocket() -> _ { + let config = ClerkConfiguration::new(None, None, Some("sk_test_F9HM5l3WMTDMdBB0ygcMMAiL37QA6BvXYV1v18Noit".to_string()), None); + let clerk = Clerk::new(config); + let clerk_config = ClerkGuardConfig::new( + MemoryCacheJwksProvider::new(clerk), + None, + true, // validate_session_cookie + ); + + rocket::build().mount("/", routes![index]).manage(clerk_config) +} + +``` + ## Roadmap - [ ] Support other http clients along with the default reqwest client (like hyper) diff --git a/examples/rocket.rs b/examples/rocket.rs new file mode 100644 index 0000000..a2a8eff --- /dev/null +++ b/examples/rocket.rs @@ -0,0 +1,35 @@ +use clerk_rs::{ + clerk::Clerk, + validators::{ + jwks::MemoryCacheJwksProvider, + rocket::{ClerkGuard, ClerkGuardConfig}, + }, + ClerkConfiguration, +}; +use rocket::{ + get, launch, routes, + serde::{Deserialize, Serialize}, +}; + +#[derive(Serialize, Deserialize)] +struct Message { + content: String, +} + +#[get("/")] +fn index(jwt: ClerkGuard) -> &'static str { + "Hello world!" +} + +#[launch] +fn rocket() -> _ { + let config = ClerkConfiguration::new(None, None, Some("sk_test_F9HM5l3WMTDMdBB0ygcMMAiL37QA6BvXYV1v18Noit".to_string()), None); + let clerk = Clerk::new(config); + let clerk_config = ClerkGuardConfig::new( + MemoryCacheJwksProvider::new(clerk), + None, + true, // validate_session_cookie + ); + + rocket::build().mount("/", routes![index]).manage(clerk_config) +} diff --git a/src/validators/jwks.rs b/src/validators/jwks.rs index 901e045..491751c 100644 --- a/src/validators/jwks.rs +++ b/src/validators/jwks.rs @@ -41,18 +41,18 @@ impl From for ClerkError { /// A [`JwksProvider`] implementation that doesn't do any caching. /// /// The JWKS is fetched from the Clerk API on every request. -pub struct SimpleJwksProvider { +pub struct JwksProviderNoCache { clerk_client: Clerk, } -impl SimpleJwksProvider { +impl JwksProviderNoCache { pub fn new(clerk_client: Clerk) -> Self { Self { clerk_client } } } #[async_trait] -impl JwksProvider for SimpleJwksProvider { +impl JwksProvider for JwksProviderNoCache { type Error = JwksProviderError; async fn get_key(&self, kid: &str) -> Result { @@ -273,7 +273,7 @@ pub(crate) mod tests { }; let clerk = Clerk::new(config); - let jwks = SimpleJwksProvider::new(clerk); + let jwks = JwksProviderNoCache::new(clerk); let res = jwks.get_key(MOCK_KID).await.expect("should retrieve key"); assert_eq!(res.kid, MOCK_KID); @@ -293,7 +293,7 @@ pub(crate) mod tests { }; let clerk = Clerk::new(config); - let jwks = SimpleJwksProvider::new(clerk); + let jwks = JwksProviderNoCache::new(clerk); jwks.get_key(MOCK_KID).await.expect("should retrieve key"); jwks.get_key(MOCK_KID).await.expect("should retrieve key"); @@ -314,7 +314,7 @@ pub(crate) mod tests { }; let clerk = Clerk::new(config); - let jwks = SimpleJwksProvider::new(clerk); + let jwks = JwksProviderNoCache::new(clerk); // try to get a key that doesn't exist let res = jwks.get_key("unknown key").await.expect_err("should fail"); diff --git a/src/validators/mod.rs b/src/validators/mod.rs index e51edce..96f42f0 100644 --- a/src/validators/mod.rs +++ b/src/validators/mod.rs @@ -7,3 +7,5 @@ pub mod jwks; pub mod actix; #[cfg(feature = "axum")] pub mod axum; +#[cfg(feature = "rocket")] +pub mod rocket; diff --git a/src/validators/rocket.rs b/src/validators/rocket.rs new file mode 100644 index 0000000..d081a1e --- /dev/null +++ b/src/validators/rocket.rs @@ -0,0 +1,86 @@ +use crate::validators::{ + authorizer::{ClerkAuthorizer, ClerkError, ClerkRequest}, + jwks::JwksProvider, +}; +use rocket::{ + http::Status, + request::{FromRequest, Outcome}, + Request, +}; + +use super::authorizer::ClerkJwt; + +// Implement ClerkRequest for Rocket's Request +impl<'r> ClerkRequest for &'r Request<'_> { + fn get_header(&self, key: &str) -> Option { + self.headers().get_one(key).map(|s| s.to_string()) + } + + fn get_cookie(&self, key: &str) -> Option { + self.cookies().get(key).map(|cookie| cookie.value().to_string()) + } +} + +pub struct ClerkGuardConfig { + pub authorizer: ClerkAuthorizer, + pub routes: Option>, +} + +impl ClerkGuardConfig { + pub fn new(jwks_provider: J, routes: Option>, validate_session_cookie: bool) -> Self { + let authorizer = ClerkAuthorizer::new(jwks_provider, validate_session_cookie); + Self { authorizer, routes } + } +} + +pub struct ClerkGuard { + pub jwt: Option, + _marker: std::marker::PhantomData, +} + +// Implement request guard for ClerkGuard +#[rocket::async_trait] +impl<'r, J: JwksProvider + Send + Sync + 'static> FromRequest<'r> for ClerkGuard { + type Error = ClerkError; + + async fn from_request(request: &'r Request<'_>) -> Outcome { + // Retrieve the ClerkAuthorizer from managed state + let config = request + .rocket() + .state::>() + .expect("ClerkGuardConfig not found in managed state"); + + match &config.routes { + Some(route_matches) => { + // If the user only wants to apply authentication to a select amount of routes, we handle that logic here + let path = request.uri().path(); + // Check if the path was NOT contained inside of the routes specified by the user... + let path_not_in_specified_routes = route_matches.iter().find(|&route| route == &path.to_string()).is_none(); + + if path_not_in_specified_routes { + // Since the path was not inside of the listed routes we want to trigger an early exit + return Outcome::Success(ClerkGuard { + jwt: None, + _marker: std::marker::PhantomData, + }); + } + } + // Since we did find a matching route we can simply do nothing here and start the actual auth logic... + None => {} + } + + match config.authorizer.authorize(&request).await { + Ok(jwt) => { + request.local_cache(|| jwt.clone()); + return Outcome::Success(ClerkGuard { + jwt: Some(jwt), + _marker: std::marker::PhantomData, + }); + } + Err(error) => match error { + ClerkError::Unauthorized(msg) => Outcome::Error((Status::Unauthorized, ClerkError::Unauthorized(msg))), + ClerkError::InternalServerError(msg) => Outcome::Error((Status::InternalServerError, ClerkError::InternalServerError(msg))), + }, + } + } +}