Skip to content

Commit

Permalink
Remote verifier reloads keys when encountering unknown keys.
Browse files Browse the repository at this point in the history
  • Loading branch information
bittrance committed May 24, 2024
1 parent edc52da commit c2437ec
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 17 deletions.
1 change: 1 addition & 0 deletions examples/jwks_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ async fn main() -> jwtk::Result<()> {
"http://127.0.0.1:3000/jwks".into(),
None,
Duration::from_secs(300),
None,
);
let c = j.verify::<Map<String, Value>>(&v.token).await?;

Expand Down
69 changes: 52 additions & 17 deletions src/jwk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,16 @@ impl<K: PublicKeyToJwk> PublicKeyToJwk for WithKid<K> {
#[cfg(feature = "remote-jwks")]
struct JWKSCache {
jwks: JwkSetVerifier,
valid_until: std::time::Instant,
last_retrieved: std::time::Instant,
}

impl JWKSCache {
fn fresher_than(&self, age: std::time::Duration) -> bool {
self.last_retrieved
.checked_add(age)
.and_then(|deadline| deadline.checked_duration_since(std::time::Instant::now()))
.is_some()
}
}

/// A JWK Set served from a remote url. Automatically fetched and cached.
Expand All @@ -450,6 +459,7 @@ pub struct RemoteJwksVerifier {
url: String,
client: reqwest::Client,
cache_duration: std::time::Duration,
cooldown: std::time::Duration,
cache: tokio::sync::RwLock<Option<JWKSCache>>,
require_kid: bool,
}
Expand All @@ -460,11 +470,13 @@ impl RemoteJwksVerifier {
url: String,
client: Option<reqwest::Client>,
cache_duration: std::time::Duration,
cooldown: Option<std::time::Duration>,
) -> Self {
Self {
url,
client: client.unwrap_or_default(),
cache_duration,
cooldown: cooldown.unwrap_or(std::time::Duration::from_secs(30)),
cache: tokio::sync::RwLock::new(None),
require_kid: true,
}
Expand All @@ -483,10 +495,7 @@ impl RemoteJwksVerifier {
let cache = self.cache.read().await;
// Cache still valid.
if let Some(c) = &*cache {
if c.valid_until
.checked_duration_since(std::time::Instant::now())
.is_some()
{
if c.fresher_than(self.cache_duration) {
return Ok(tokio::sync::RwLockReadGuard::map(cache, |c| {
&c.as_ref().unwrap().jwks
}));
Expand All @@ -496,15 +505,20 @@ impl RemoteJwksVerifier {

let mut cache = self.cache.write().await;
if let Some(c) = &*cache {
if c.valid_until
.checked_duration_since(std::time::Instant::now())
.is_some()
{
if c.fresher_than(self.cache_duration) {
return Ok(tokio::sync::RwLockReadGuard::map(cache.downgrade(), |c| {
&c.as_ref().unwrap().jwks
}));
}
}
self.reload_jwks(&mut cache).await?;

Ok(tokio::sync::RwLockReadGuard::map(cache.downgrade(), |c| {
&c.as_ref().unwrap().jwks
}))
}

async fn reload_jwks(&self, cache: &mut tokio::sync::RwLockWriteGuard<'_, Option<JWKSCache>>) -> Result<()> {
let response = self
.client
.get(&self.url)
Expand All @@ -513,31 +527,52 @@ impl RemoteJwksVerifier {
.await?;
let jwks: JwkSet = response.json().await?;

*cache = Some(JWKSCache {
cache.replace(JWKSCache {
jwks: {
let mut v = jwks.verifier();
v.require_kid = self.require_kid;
v
},
valid_until: std::time::Instant::now() + self.cache_duration,
last_retrieved: std::time::Instant::now(),
});

Ok(tokio::sync::RwLockReadGuard::map(cache.downgrade(), |c| {
&c.as_ref().unwrap().jwks
}))
Ok(())
}

pub async fn verify<E: DeserializeOwned>(&self, token: &str) -> Result<HeaderAndClaims<E>> {
let v = self.get_verifier().await?;
v.verify(token)
match v.verify(token) {
Ok(v) => Ok(v),
err @ Err(Error::NoKey) => {
let cache = self.cache.read().await;
if cache.as_ref().filter(|c| c.fresher_than(self.cooldown)).is_some() {
return err;
}
let mut cache = self.cache.write().await;
self.reload_jwks(&mut cache).await?;
cache.as_ref().unwrap().jwks.verify(token)
}
Err(e) => Err(e),
}
}

pub async fn verify_only<E: DeserializeOwned>(
&self,
token: &str,
) -> Result<HeaderAndClaims<E>> {
let v = self.get_verifier().await?;
v.verify_only(token)
match v.verify_only(token) {
Ok(v) => Ok(v),
err @ Err(Error::NoKey) => {
let cache = self.cache.read().await;
if !cache.as_ref().filter(|c| c.fresher_than(self.cooldown)).is_some() {
return err;
}
let mut cache = self.cache.write().await;
self.reload_jwks(&mut cache).await?;
cache.as_ref().unwrap().jwks.verify_only(token)
}
Err(e) => Err(e),
}
}
}

Expand Down

0 comments on commit c2437ec

Please sign in to comment.