Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add validator support for Rocket.rs #67

Merged
merged 6 commits into from
Oct 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ name = "axum"
path = "examples/axum.rs"
required-features = ["axum"]

[[example]]
name = "rocket"
path = "examples/rocket.rs"
required-features = ["rocket"]

[lib]
doctest = false

Expand All @@ -43,6 +48,7 @@ futures-util = "0.3.28"
actix-rt = { version = "2.10.0", optional = true }
actix-web = { version = "4.9.0", optional = true }
axum = { version = "0.7.5", optional = true }
rocket = { version = "0.5.0", optional = true }
axum-extra = { version = "0.9.3", features = ["cookie"], optional = true }
tower = { version = "0.5.0", optional = true }
async-trait = "0.1.81"
Expand All @@ -67,3 +73,4 @@ native-tls = ["reqwest/native-tls"]
rustls-tls = ["reqwest/rustls-tls"]
actix = ["dep:actix-rt", "dep:actix-web"]
axum = ["dep:axum", "dep:axum-extra", "dep:tower"]
rocket = ["dep:rocket"]
43 changes: 43 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,49 @@ async fn main() -> std::io::Result<()> {
}
```

### Protecting a rocket endpoint with Clerk.dev:

With the `rocket` feature enabled:

```rust
use clerk_rs::{
clerk::Clerk,
validators::{
jwks::MemoryCacheJwksProvider,
rocket::{ClerkGuard, ClerkGuardConfig},
},
ClerkConfiguration,
};
use rocket::{
get, launch, routes,
serde::{Deserialize, Serialize},
};

#[derive(Serialize, Deserialize)]
struct Message {
content: String,
}

#[get("/")]
fn index(jwt: ClerkGuard<MemoryCacheJwksProvider>) -> &'static str {
"Hello world!"
}

#[launch]
fn rocket() -> _ {
let config = ClerkConfiguration::new(None, None, Some("sk_test_F9HM5l3WMTDMdBB0ygcMMAiL37QA6BvXYV1v18Noit".to_string()), None);
let clerk = Clerk::new(config);
let clerk_config = ClerkGuardConfig::new(
MemoryCacheJwksProvider::new(clerk),
None,
true, // validate_session_cookie
);

rocket::build().mount("/", routes![index]).manage(clerk_config)
}

```

## Roadmap

- [ ] Support other http clients along with the default reqwest client (like hyper)
Expand Down
35 changes: 35 additions & 0 deletions examples/rocket.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
use clerk_rs::{
clerk::Clerk,
validators::{
jwks::MemoryCacheJwksProvider,
rocket::{ClerkGuard, ClerkGuardConfig},
},
ClerkConfiguration,
};
use rocket::{
get, launch, routes,
serde::{Deserialize, Serialize},
};

#[derive(Serialize, Deserialize)]
struct Message {
content: String,
}

#[get("/")]
fn index(jwt: ClerkGuard<MemoryCacheJwksProvider>) -> &'static str {
"Hello world!"
}

#[launch]
fn rocket() -> _ {
let config = ClerkConfiguration::new(None, None, Some("sk_test_F9HM5l3WMTDMdBB0ygcMMAiL37QA6BvXYV1v18Noit".to_string()), None);
let clerk = Clerk::new(config);
let clerk_config = ClerkGuardConfig::new(
MemoryCacheJwksProvider::new(clerk),
None,
true, // validate_session_cookie
);

rocket::build().mount("/", routes![index]).manage(clerk_config)
}
12 changes: 6 additions & 6 deletions src/validators/jwks.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,18 +41,18 @@ impl From<JwksProviderError> for ClerkError {
/// A [`JwksProvider`] implementation that doesn't do any caching.
///
/// The JWKS is fetched from the Clerk API on every request.
pub struct SimpleJwksProvider {
pub struct JwksProviderNoCache {
clerk_client: Clerk,
}

impl SimpleJwksProvider {
impl JwksProviderNoCache {
pub fn new(clerk_client: Clerk) -> Self {
Self { clerk_client }
}
}

#[async_trait]
impl JwksProvider for SimpleJwksProvider {
impl JwksProvider for JwksProviderNoCache {
type Error = JwksProviderError;

async fn get_key(&self, kid: &str) -> Result<JwksKey, JwksProviderError> {
Expand Down Expand Up @@ -273,7 +273,7 @@ pub(crate) mod tests {
};
let clerk = Clerk::new(config);

let jwks = SimpleJwksProvider::new(clerk);
let jwks = JwksProviderNoCache::new(clerk);

let res = jwks.get_key(MOCK_KID).await.expect("should retrieve key");
assert_eq!(res.kid, MOCK_KID);
Expand All @@ -293,7 +293,7 @@ pub(crate) mod tests {
};
let clerk = Clerk::new(config);

let jwks = SimpleJwksProvider::new(clerk);
let jwks = JwksProviderNoCache::new(clerk);

jwks.get_key(MOCK_KID).await.expect("should retrieve key");
jwks.get_key(MOCK_KID).await.expect("should retrieve key");
Expand All @@ -314,7 +314,7 @@ pub(crate) mod tests {
};
let clerk = Clerk::new(config);

let jwks = SimpleJwksProvider::new(clerk);
let jwks = JwksProviderNoCache::new(clerk);

// try to get a key that doesn't exist
let res = jwks.get_key("unknown key").await.expect_err("should fail");
Expand Down
2 changes: 2 additions & 0 deletions src/validators/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,5 @@ pub mod jwks;
pub mod actix;
#[cfg(feature = "axum")]
pub mod axum;
#[cfg(feature = "rocket")]
pub mod rocket;
86 changes: 86 additions & 0 deletions src/validators/rocket.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
use crate::validators::{
authorizer::{ClerkAuthorizer, ClerkError, ClerkRequest},
jwks::JwksProvider,
};
use rocket::{
http::Status,
request::{FromRequest, Outcome},
Request,
};

use super::authorizer::ClerkJwt;

// Implement ClerkRequest for Rocket's Request
impl<'r> ClerkRequest for &'r Request<'_> {
fn get_header(&self, key: &str) -> Option<String> {
self.headers().get_one(key).map(|s| s.to_string())
}

fn get_cookie(&self, key: &str) -> Option<String> {
self.cookies().get(key).map(|cookie| cookie.value().to_string())
}
}

pub struct ClerkGuardConfig<J: JwksProvider> {
pub authorizer: ClerkAuthorizer<J>,
pub routes: Option<Vec<String>>,
}

impl<J: JwksProvider> ClerkGuardConfig<J> {
pub fn new(jwks_provider: J, routes: Option<Vec<String>>, validate_session_cookie: bool) -> Self {
let authorizer = ClerkAuthorizer::new(jwks_provider, validate_session_cookie);
Self { authorizer, routes }
}
}

pub struct ClerkGuard<J: JwksProvider + Send + Sync> {
pub jwt: Option<ClerkJwt>,
_marker: std::marker::PhantomData<J>,
}

// Implement request guard for ClerkGuard
#[rocket::async_trait]
impl<'r, J: JwksProvider + Send + Sync + 'static> FromRequest<'r> for ClerkGuard<J> {
type Error = ClerkError;

async fn from_request(request: &'r Request<'_>) -> Outcome<Self, Self::Error> {
// Retrieve the ClerkAuthorizer from managed state
let config = request
.rocket()
.state::<ClerkGuardConfig<J>>()
.expect("ClerkGuardConfig not found in managed state");

match &config.routes {
Some(route_matches) => {
// If the user only wants to apply authentication to a select amount of routes, we handle that logic here
let path = request.uri().path();
// Check if the path was NOT contained inside of the routes specified by the user...
let path_not_in_specified_routes = route_matches.iter().find(|&route| route == &path.to_string()).is_none();

if path_not_in_specified_routes {
// Since the path was not inside of the listed routes we want to trigger an early exit
return Outcome::Success(ClerkGuard {
jwt: None,
_marker: std::marker::PhantomData,
});
}
}
// Since we did find a matching route we can simply do nothing here and start the actual auth logic...
None => {}
}

match config.authorizer.authorize(&request).await {
Ok(jwt) => {
request.local_cache(|| jwt.clone());
return Outcome::Success(ClerkGuard {
jwt: Some(jwt),
_marker: std::marker::PhantomData,
});
}
Err(error) => match error {
ClerkError::Unauthorized(msg) => Outcome::Error((Status::Unauthorized, ClerkError::Unauthorized(msg))),
ClerkError::InternalServerError(msg) => Outcome::Error((Status::InternalServerError, ClerkError::InternalServerError(msg))),
},
}
}
}
Loading