Skip to content

Commit

Permalink
[PLATFORM-1400]: Include jwks_client in bridge.rs (#153)
Browse files Browse the repository at this point in the history
* Replace actual jwks impl to use jwks_client

* Formatting

* Pin to v1.72

* Fix command

* Fix clippy warnings

* Remove chrono-tz
  • Loading branch information
cottinisimone authored Mar 11, 2024
1 parent c584ab1 commit f4e486c
Show file tree
Hide file tree
Showing 11 changed files with 40 additions and 172 deletions.
2 changes: 2 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- run: rustup toolchain install 1.72.0
- uses: Swatinem/rust-cache@v2
- uses: taiki-e/install-action@cargo-make
- run: cargo make fmt-check
Expand All @@ -23,6 +24,7 @@ jobs:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- run: rustup toolchain install 1.72.0
- uses: Swatinem/rust-cache@v2
- uses: taiki-e/install-action@cargo-make
- run: cargo make test
Expand Down
4 changes: 2 additions & 2 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ rust-version = "1.72"
[features]
default = ["tracing_opentelemetry"]

auth0 = ["rand", "redis", "jsonwebtoken", "chrono", "chrono-tz", "aes", "cbc", "dashmap", "tracing"]
auth0 = ["rand", "redis", "jsonwebtoken", "jwks_client_rs", "chrono", "aes", "cbc", "dashmap", "tracing"]
gzip = ["reqwest/gzip"]
redis-tls = ["redis/tls", "redis/tokio-native-tls-comp"]
tracing_opentelemetry = [ "tracing_opentelemetry_0_21" ]
Expand All @@ -29,11 +29,11 @@ async-trait = "0.1"
bytes = "1.2"
cbc = {version = "0.1", features = ["std"], optional = true}
chrono = {version = "0.4", default-features = false, features = ["clock", "std", "serde"], optional = true}
chrono-tz = {version = "0.8", optional = true}
dashmap = {version = "5.1", optional = true}
futures = "0.3"
futures-util = "0.3"
jsonwebtoken = {version = "9.0", optional = true}
jwks_client_rs = {version = "0.5", optional = true}
rand = {version = "0.8", optional = true}
redis = {version = "0.23", features = ["tokio-comp"], optional = true}
reqwest = {version = "0.11", features = ["json", "multipart", "stream"]}
Expand Down
25 changes: 0 additions & 25 deletions src/auth0/cache/inmemory.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use dashmap::DashMap;

use crate::auth0::cache::{self, crypto};
use crate::auth0::errors::Auth0Error;
use crate::auth0::keyset::JsonWebKeySet;
use crate::auth0::token::Token;
use crate::auth0::{cache::Cache, Config};

Expand Down Expand Up @@ -40,20 +39,6 @@ impl Cache for InMemoryCache {
let _ = self.key_value.insert(key, encrypted_value);
Ok(())
}

async fn get_jwks(&self) -> Result<Option<JsonWebKeySet>, Auth0Error> {
self.key_value
.get(&cache::jwks_key(&self.caller, &self.audience))
.map(|value| crypto::decrypt(self.encryption_key.as_str(), value.as_slice()))
.transpose()
}

async fn put_jwks(&self, value_ref: &JsonWebKeySet, _expiration: Option<usize>) -> Result<(), Auth0Error> {
let key: String = cache::jwks_key(&self.caller, &self.audience);
let encrypted_value: Vec<u8> = crypto::encrypt(value_ref, self.encryption_key.as_str())?;
let _ = self.key_value.insert(key, encrypted_value);
Ok(())
}
}

#[cfg(test)]
Expand All @@ -70,22 +55,12 @@ mod tests {
let result: Option<Token> = cache.get_token().await.unwrap();
assert!(result.is_none());

let result: Option<JsonWebKeySet> = cache.get_jwks().await.unwrap();
assert!(result.is_none());

let token_str: &str = "token";
let token: Token = Token::new(token_str.to_string(), Utc::now(), Utc::now());
cache.put_token(&token).await.unwrap();

let result: Option<Token> = cache.get_token().await.unwrap();
assert!(result.is_some());
assert_eq!(result.unwrap().as_str(), token_str);

let string: &str = "{\"keys\": []}";
let jwks: JsonWebKeySet = serde_json::from_str(string).unwrap();
cache.put_jwks(&jwks, None).await.unwrap();

let result: Option<JsonWebKeySet> = cache.get_jwks().await.unwrap();
assert!(result.is_some());
}
}
10 changes: 0 additions & 10 deletions src/auth0/cache/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,31 +2,21 @@ pub use inmemory::InMemoryCache;
pub use redis_impl::RedisCache;

use crate::auth0::errors::Auth0Error;
use crate::auth0::keyset::JsonWebKeySet;
use crate::auth0::Token;

mod crypto;
mod inmemory;
mod redis_impl;

const TOKEN_PREFIX: &str = "auth0rs_tokens";
const JWKS_PREFIX: &str = "auth0rs_jwks";

#[async_trait::async_trait]
pub trait Cache: Send + Sync + std::fmt::Debug {
async fn get_token(&self) -> Result<Option<Token>, Auth0Error>;

async fn put_token(&self, value_ref: &Token) -> Result<(), Auth0Error>;

async fn get_jwks(&self) -> Result<Option<JsonWebKeySet>, Auth0Error>;

async fn put_jwks(&self, value_ref: &JsonWebKeySet, expiration: Option<usize>) -> Result<(), Auth0Error>;
}

pub(in crate::auth0::cache) fn token_key(caller: &str, audience: &str) -> String {
format!("{}:{}:{}", TOKEN_PREFIX, caller, audience)
}

pub(in crate::auth0::cache) fn jwks_key(caller: &str, audience: &str) -> String {
format!("{}:{}:{}", JWKS_PREFIX, caller, audience)
}
15 changes: 0 additions & 15 deletions src/auth0/cache/redis_impl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use redis::AsyncCommands;
use serde::Deserialize;

use crate::auth0::cache::{self, crypto, Cache};
use crate::auth0::keyset::JsonWebKeySet;
use crate::auth0::token::Token;
use crate::auth0::{Auth0Error, Config};

Expand Down Expand Up @@ -57,20 +56,6 @@ impl Cache for RedisCache {
connection.set_ex(key, encrypted_value, expiration).await?;
Ok(())
}

async fn get_jwks(&self) -> Result<Option<JsonWebKeySet>, Auth0Error> {
let key: &str = &cache::jwks_key(&self.caller, &self.audience);
self.get(key).await
}

async fn put_jwks(&self, value_ref: &JsonWebKeySet, expiration: Option<usize>) -> Result<(), Auth0Error> {
let key: &str = &cache::jwks_key(&self.caller, &self.audience);
let mut connection = self.client.get_async_connection().await?;
let encrypted_value: Vec<u8> = crypto::encrypt(value_ref, self.encryption_key.as_str())?;
let expiration: usize = expiration.unwrap_or(86400);
connection.set_ex(key, encrypted_value, expiration).await?;
Ok(())
}
}

// To run this test (it works):
Expand Down
8 changes: 4 additions & 4 deletions src/auth0/errors.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ pub enum Auth0Error {
JwtFetchError(u16, String, reqwest::Error),
#[error("failed to deserialize jwt from {0}. {1}")]
JwtFetchDeserializationError(String, reqwest::Error),
#[error("failed to fetch jwks from {0}. Status code: {0}; error: {1}")]
JwksFetchError(u16, String, reqwest::Error),
#[error("failed to deserialize jwks from {0}. {1}")]
JwksFetchDeserializationError(String, reqwest::Error),
#[error(transparent)]
JwksClientError(#[from] jwks_client_rs::JwksClientError),
#[error("failed to fetch jwt from {0}. Status code: {0}; error: {1}")]
JwksHttpError(String, reqwest::Error),
#[error("redis error: {0}")]
RedisError(#[from] redis::RedisError),
#[error(transparent)]
Expand Down
70 changes: 0 additions & 70 deletions src/auth0/keyset.rs

This file was deleted.

57 changes: 19 additions & 38 deletions src/auth0/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
//! Stuff used to provide JWT authentication via Auth0
use std::sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard};
use std::time::Duration;

use jwks_client_rs::source::WebSource;
use jwks_client_rs::JwksClient;
use reqwest::Client;
use tokio::task::JoinHandle;
use tokio::time::Interval;
Expand All @@ -11,17 +14,16 @@ pub use errors::Auth0Error;
use util::ResultExt;

use crate::auth0::cache::Cache;
use crate::auth0::keyset::JsonWebKeySet;
use crate::auth0::token::Claims;
pub use crate::auth0::token::Token;

mod cache;
mod config;
mod errors;
mod keyset;
mod token;
mod util;

#[derive(Clone, Debug)]
#[derive(Debug)]
pub struct Auth0 {
token_lock: Arc<RwLock<Token>>,
}
Expand All @@ -34,15 +36,20 @@ impl Auth0 {
Arc::new(cache::RedisCache::new(&config).await?)
};

let jwks: JsonWebKeySet = get_jwks(client_ref, &cache, &config).await?;
let source: WebSource = WebSource::builder()
.with_timeout(Duration::from_secs(5))
.with_connect_timeout(Duration::from_secs(55))
.build(config.jwks_url().to_owned())
.map_err(|err| Auth0Error::JwksHttpError(config.token_url().as_str().to_string(), err))?;

let jwks_client = JwksClient::builder().build(source);
let token: Token = get_token(client_ref, &cache, &config).await?;

let jwks_lock: Arc<RwLock<JsonWebKeySet>> = Arc::new(RwLock::new(jwks));
let token_lock: Arc<RwLock<Token>> = Arc::new(RwLock::new(token));

start(
jwks_lock.clone(),
token_lock.clone(),
jwks_client.clone(),
client_ref.clone(),
cache.clone(),
config,
Expand All @@ -58,8 +65,8 @@ impl Auth0 {
}

async fn start(
jwks_lock: Arc<RwLock<JsonWebKeySet>>,
token_lock: Arc<RwLock<Token>>,
jwks_client: JwksClient<WebSource>,
client: Client,
cache: Arc<dyn Cache>,
config: Config,
Expand All @@ -82,22 +89,13 @@ async fn start(
if token.needs_refresh(&config) {
tracing::info!("Refreshing JWT and JWKS");

let jwks_opt = match JsonWebKeySet::fetch(&client, &config).await {
Ok(jwks) => {
let _ = cache.put_jwks(&jwks, None).await.log_err("Error caching JWKS");
write(&jwks_lock, jwks.clone());
Some(jwks)
}
Err(error) => {
tracing::error!("Failed to fetch JWKS. Reason: {:?}", error);
None
}
};

match Token::fetch(&client, &config).await {
Ok(token) => {
let is_signed: Option<bool> = jwks_opt.map(|j| j.is_signed(&token));
tracing::info!("is signed: {}", is_signed.unwrap_or_default());
let is_signed: bool = jwks_client
.decode::<Claims>(token.as_str(), &[config.audience()])
.await
.is_ok();
tracing::info!("is signed: {}", is_signed);

let _ = cache.put_token(&token).await.log_err("Error caching JWT");
write(&token_lock, token);
Expand All @@ -111,23 +109,6 @@ async fn start(
})
}

// Try to fetch the jwks from cache. If it's found return it; fetch from auth0 and put in cache otherwise
async fn get_jwks(
client_ref: &Client,
cache_ref: &Arc<dyn Cache>,
config_ref: &Config,
) -> Result<JsonWebKeySet, Auth0Error> {
match cache_ref.get_jwks().await? {
Some(jwks) => Ok(jwks),
None => {
let jwks: JsonWebKeySet = JsonWebKeySet::fetch(client_ref, config_ref).await?;
let _ = cache_ref.put_jwks(&jwks, None).await.log_err("JWKS cache set failed");

Ok(jwks)
}
}
}

// Try to fetch the token from cache. If it's found return it; fetch from auth0 and put in cache otherwise
async fn get_token(client_ref: &Client, cache_ref: &Arc<dyn Cache>, config_ref: &Config) -> Result<Token, Auth0Error> {
match cache_ref.get_token().await? {
Expand Down
12 changes: 9 additions & 3 deletions src/auth0/token.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use chrono::{DateTime, Duration, Utc};
use chrono::{DateTime, Utc};
use reqwest::{Client, Response};
use serde::{Deserialize, Serialize};
use std::time::Duration;

use crate::auth0::errors::Auth0Error;
use crate::auth0::Config;
Expand Down Expand Up @@ -54,8 +55,7 @@ impl Token {
// the exact issued_at (iat) and expiration (exp)
// reference: https://www.iana.org/assignments/jwt/jwt.xhtml
let issue_date: DateTime<Utc> = Utc::now();
let expire_date: DateTime<Utc> =
Utc::now() + Duration::try_seconds(response.expires_in as i64).unwrap_or(Duration::max_value());
let expire_date: DateTime<Utc> = Utc::now() + Duration::from_secs(response.expires_in as u64);

Ok(Self {
token: access_token,
Expand Down Expand Up @@ -132,3 +132,9 @@ impl From<&Config> for FetchTokenRequest {
}
}
}

#[derive(Deserialize, Debug)]
pub struct Claims {
#[serde(default)]
pub permissions: Vec<String>,
}
Loading

0 comments on commit f4e486c

Please sign in to comment.