Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(jans-cedarling): improve WASM compatibility #10331

Draft
wants to merge 6 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion jans-cedarling/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ thiserror = "1.0"
sparkv = { path = "sparkv" }
cedarling = { path = "cedarling" }
test_utils = { path = "test_utils" }

chrono = "0.4.38"

[profile.release]
strip = "symbols"
Expand Down
5 changes: 3 additions & 2 deletions jans-cedarling/cedarling/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ base64 = "0.22.1"
url = "2.5.2"
lazy_static = "1.5.0"
jsonwebtoken = "9.3.0"
reqwest = { version = "0.12.8", features = ["blocking", "json"] }
reqwest = { version = "0.12.8", features = ["json"] }
bytes = "1.7.2"
typed-builder = "0.20.0"
semver = { version = "1.0.23", features = ["serde"] }
Expand All @@ -25,8 +25,9 @@ derive_more = { version = "1.0.0", features = [
"display",
"error",
] }
time = { version = "0.3.36", features = ["wasm-bindgen"] }
regex = "1.11.1"
chrono = { workspace = true }
tokio = { version = "1.42.0", features = ["rt", "time"] }

[dev-dependencies]
# is used in testing
Expand Down
116 changes: 81 additions & 35 deletions jans-cedarling/cedarling/src/jwt/http_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,63 @@
* Copyright (c) 2024, Gluu, Inc.
*/

use reqwest::blocking::Client;
use std::{thread::sleep, time::Duration};

/// A wrapper around `reqwest::blocking::Client` providing HTTP request functionality
use reqwest::Client;
use serde::de::DeserializeOwned;
use tokio::{
runtime::{Builder as RtBuilder, Runtime},
time::Duration,
};

/// A wrapper around [`reqwest::Client`] providing HTTP request functionality
/// with retry logic.
///
/// The `HttpClient` struct allows for sending GET requests with a retry mechanism
/// that attempts to fetch the requested resource up to a maximum number of times
/// if an error occurs.
#[derive(Debug)]
pub struct HttpClient {
client: reqwest::blocking::Client,
client: reqwest::Client,
max_retries: u32,
retry_delay: Duration,
rt: Runtime,
}

/// A wrapper around [`reqwest::Response`]
#[derive(Debug)]
pub struct Response<'rt> {
rt: &'rt Runtime,
resp: reqwest::Response,
}

impl Response<'_> {
/// Deserializes the response into <T> from JSON.
pub fn json<T>(self) -> Result<T, HttpClientError>
where
T: DeserializeOwned,
{
let resp_json = self
.rt
.block_on(async { self.resp.json::<T>().await })

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In general it's better to use async then block_on. Here you can do something like :

let resp_json = self
            .resp
            .json::<T>()
            .await
            .map_err(HttpClientError::DeserializeJson)?;
        Ok(resp_json)

Moreover, you can remove the lifetime parameter <'rt> from Response.

.map_err(HttpClientError::DeserializeJson)?;
Ok(resp_json)
}

/// Deserializes the response into a String.
pub fn text(self) -> Result<String, HttpClientError> {
let resp_text = self
.rt
.block_on(async { self.resp.text().await })
.map_err(HttpClientError::DeserializeJson)?;

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe it should be better to use DeserializeText instead of DeserializeJson here

Ok(resp_text)
}
}

impl HttpClient {
pub fn new(max_retries: u32, retry_delay: Duration) -> Result<Self, HttpClientError> {
let rt = RtBuilder::new_current_thread()
.enable_all()
.build()
.expect("Failed to create Tokio runtime");
let client = Client::builder()
.build()

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's a good practice to add a timeout in during an http request like:

let client = Client::builder()
.timeout(Duration::from_sec(20))
.build()
.map_err(HttpClientError::Initialization)?;

.map_err(HttpClientError::Initialization)?;
Expand All @@ -31,38 +70,43 @@ impl HttpClient {
client,
max_retries,
retry_delay,
rt,
})
}

/// Sends a GET request to the specified URI with retry logic.
///
/// This method will attempt to fetch the resource up to 3 times, with an increasing delay
/// between each attempt.
pub fn get(&self, uri: &str) -> Result<reqwest::blocking::Response, HttpClientError> {
pub fn get(&self, uri: &str) -> Result<Response, HttpClientError> {
// Fetch the JWKS from the jwks_uri
let mut attempts = 0;
let response = loop {
match self.client.get(uri).send() {
// Exit loop on success
Ok(response) => break response,

Err(e) if attempts < self.max_retries => {
attempts += 1;
// TODO: pass this message to the logger
eprintln!(
"Request failed (attempt {} of {}): {}. Retrying...",
attempts, self.max_retries, e
);
sleep(self.retry_delay * attempts);
},
// Exit if max retries exceeded
Err(e) => return Err(HttpClientError::MaxHttpRetriesReached(e)),
let response = self.rt.block_on(async {
loop {
match self.client.get(uri).send().await {
// Exit loop on success
Ok(response) => return Ok(response),

Err(e) if attempts < self.max_retries => {
attempts += 1;
// TODO: pass this message to the logger
eprintln!(
"Request failed (attempt {} of {}): {}. Retrying...",
attempts, self.max_retries, e
);
tokio::time::sleep(self.retry_delay * attempts).await
},
// Exit if max retries exceeded
Err(e) => return Err(HttpClientError::MaxHttpRetriesReached(e)),
}
}
};
})?;

response
let resp = response
.error_for_status()
.map_err(HttpClientError::HttpStatus)
.map_err(HttpClientError::HttpStatus)?;

Ok(Response { rt: &self.rt, resp })
}
}

Expand All @@ -75,21 +119,24 @@ pub enum HttpClientError {
/// Indicates an HTTP error response received from an endpoint.
#[error("Received error HTTP status: {0}")]
HttpStatus(#[source] reqwest::Error),

/// Indicates a failure to reach the endpoint after 3 attempts.
#[error("Could not reach endpoint after trying 3 times: {0}")]
MaxHttpRetriesReached(#[source] reqwest::Error),
/// Indicates a failure to deserialize the http response into JSON.
#[error("Failed to deserialize response into JSON: {0}")]
DeserializeJson(#[source] reqwest::Error),
/// Indicates a failure to deserialize the http response into JSON.
#[error("Failed to deserialize response into a String: {0}")]
DeserializeText(#[source] reqwest::Error),
}

#[cfg(test)]
mod test {
use crate::jwt::http_client::HttpClientError;

use super::HttpClient;
use crate::jwt::http_client::{HttpClient, HttpClientError};
use mockito::Server;
use serde_json::json;
use std::time::Duration;
use serde_json::{json, Value};
use test_utils::assert_eq;
use tokio::time::Duration;

#[test]
fn can_fetch() {
Expand All @@ -108,17 +155,16 @@ mod test {
.expect(1)
.create();

let client =
HttpClient::new(3, Duration::from_millis(1)).expect("Should create HttpClient.");
let client = HttpClient::new(3, Duration::new(0, 10)).expect("Should create HttpClient.");

let response = client
.get(&format!(
"{}/.well-known/openid-configuration",
mock_server.url()
))
.expect("Should get response")
.json::<serde_json::Value>()
.expect("Should deserialize JSON response.");
.json::<Value>()
.expect("Should deserialize response to JSON");

assert_eq!(
response, expected,
Expand Down
45 changes: 25 additions & 20 deletions jans-cedarling/cedarling/src/jwt/jwk_store.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,14 @@
use super::http_client::{HttpClient, HttpClientError};
use super::{KeyId, TrustedIssuerId};
use crate::common::policy_store::TrustedIssuer;
use chrono::prelude::*;
use jsonwebtoken::jwk::Jwk;
use jsonwebtoken::DecodingKey;
use serde::Deserialize;
use serde_json::Value;
use std::collections::{HashMap, HashSet};
use std::fmt::Debug;
use std::sync::Arc;
use time::OffsetDateTime;

#[derive(Deserialize)]
struct OpenIdConfig {
Expand All @@ -38,7 +38,7 @@ pub struct JwkStore {
/// A collection of keys that do not have an associated ID.
keys_without_id: Vec<DecodingKey>,
/// The timestamp indicating when the store was last updated.
last_updated: OffsetDateTime,
last_updated: DateTime<Utc>,
/// From which TrustedIssuer this struct was built (if applicable).
source_iss: Option<TrustedIssuer>,
}
Expand Down Expand Up @@ -143,7 +143,7 @@ impl JwkStore {
issuer: None,
keys,
keys_without_id,
last_updated: OffsetDateTime::now_utc(),
last_updated: Utc::now(),
source_iss: None,
})
}
Expand All @@ -155,15 +155,17 @@ impl JwkStore {
http_client: &HttpClient,
) -> Result<Self, JwkStoreError> {
// fetch openid configuration
let response = http_client.get(&issuer.openid_configuration_endpoint)?;
let response = http_client
.get(&issuer.openid_configuration_endpoint)
.map_err(JwkStoreError::FetchOpenIdConfig)?;
let openid_config = response
.json::<OpenIdConfig>()
.map_err(JwkStoreError::FetchOpenIdConfig)?;
.map_err(JwkStoreError::DeserializeOpenIdConfig)?;

// fetch jwks
let response = http_client.get(&openid_config.jwks_uri)?;

let jwks = response.text().map_err(JwkStoreError::FetchJwks)?;
let jwks = response.text().map_err(JwkStoreError::DeserializeJwks)?;

let mut store = Self::new_from_jwks_str(store_id, &jwks)?;
store.issuer = Some(openid_config.issuer.into());
Expand Down Expand Up @@ -204,9 +206,11 @@ impl JwkStore {
#[derive(thiserror::Error, Debug)]
pub enum JwkStoreError {
#[error("Failed to fetch OpenIdConfig remote server: {0}")]
FetchOpenIdConfig(#[source] reqwest::Error),
#[error("Failed to fetch JWKS from remote server: {0}")]
FetchJwks(#[source] reqwest::Error),
FetchOpenIdConfig(#[source] HttpClientError),
#[error("Failed to deserialize OpenIdConfig to JSON: {0}")]
DeserializeOpenIdConfig(#[source] HttpClientError),
#[error("Failed to fetch JWKS: {0}")]
DeserializeJwks(#[source] HttpClientError),
#[error("Failed to make HTTP Request: {0}")]
Http(#[from] HttpClientError),
#[error("Failed to create Decoding Key from JWK: {0}")]
Expand All @@ -232,11 +236,12 @@ mod test {
common::policy_store::TrustedIssuer,
jwt::{http_client::HttpClient, jwk_store::JwkStore},
};
use tokio::time::Duration;
use chrono::prelude::*;
use jsonwebtoken::{jwk::JwkSet, DecodingKey};
use mockito::Server;
use serde_json::json;
use std::{collections::HashMap, time::Duration};
use time::OffsetDateTime;
use std::collections::HashMap;

#[test]
fn can_load_from_jwkset() {
Expand Down Expand Up @@ -266,7 +271,7 @@ mod test {
.expect("Should create JwkStore");
// We edit the `last_updated` from the result so that the comparison
// wont fail because of the timestamp.
result.last_updated = OffsetDateTime::from_unix_timestamp(0).unwrap();
result.last_updated = DateTime::from_timestamp(0, 0).unwrap();

let expected_jwkset =
serde_json::from_value::<JwkSet>(jwks_json).expect("Should create JwkSet");
Expand All @@ -287,7 +292,7 @@ mod test {
issuer: None,
keys: expected_keys,
keys_without_id: Vec::new(),
last_updated: OffsetDateTime::from_unix_timestamp(0).unwrap(),
last_updated: DateTime::from_timestamp(0, 0).unwrap(),
source_iss: None,
};

Expand Down Expand Up @@ -373,7 +378,7 @@ mod test {
.expect("Should load JwkStore from Trusted Issuer");
// We edit the `last_updated` from the result so that the comparison
// wont fail because of the timestamp.
result.last_updated = OffsetDateTime::from_unix_timestamp(0).unwrap();
result.last_updated = DateTime::from_timestamp(0, 0).unwrap();

let jwkset =
serde_json::from_value::<JwkSet>(jwks_json).expect("Should create JwkSet from Value");
Expand All @@ -393,7 +398,7 @@ mod test {
issuer: Some(mock_server.url().into()),
keys: expected_keys,
keys_without_id: Vec::new(),
last_updated: OffsetDateTime::from_unix_timestamp(0).unwrap(),
last_updated: DateTime::from_timestamp(0, 0).unwrap(),
source_iss: Some(source_iss),
};

Expand Down Expand Up @@ -442,7 +447,7 @@ mod test {
.expect("Should create JwkStore");
// We edit the `last_updated` from the result so that the comparison
// wont fail because of the timestamp.
result.last_updated = OffsetDateTime::from_unix_timestamp(0).unwrap();
result.last_updated = DateTime::from_timestamp(0, 0).unwrap();

let jwkset = serde_json::from_value::<JwkSet>(jwks_json).expect("Should create JwkSet");
let expected_keys = jwkset
Expand All @@ -456,7 +461,7 @@ mod test {
issuer: None,
keys: HashMap::new(),
keys_without_id: expected_keys,
last_updated: OffsetDateTime::from_unix_timestamp(0).unwrap(),
last_updated: DateTime::from_timestamp(0, 0).unwrap(),
source_iss: None,
};

Expand Down Expand Up @@ -489,7 +494,7 @@ mod test {
.expect("Should create JwkStore");
// We edit the `last_updated` from the result so that the comparison
// wont fail because of the timestamp.
result.last_updated = OffsetDateTime::from_unix_timestamp(0).unwrap();
result.last_updated = DateTime::from_timestamp(0, 0).unwrap();

assert_eq!(result.get_keys().len(), 2, "Expected 2 keys");
}
Expand Down Expand Up @@ -536,7 +541,7 @@ mod test {
.expect("Should create JwkStore");
// We edit the `last_updated` from the result so that the comparison
// wont fail because of the timestamp.
result.last_updated = OffsetDateTime::from_unix_timestamp(0).unwrap();
result.last_updated = DateTime::from_timestamp(0, 0).unwrap();

let expected_jwkset = serde_json::from_value::<JwkSet>(json!({"keys": [
{
Expand Down Expand Up @@ -574,7 +579,7 @@ mod test {
issuer: None,
keys: expected_keys,
keys_without_id: Vec::new(),
last_updated: OffsetDateTime::from_unix_timestamp(0).unwrap(),
last_updated: DateTime::from_timestamp(0, 0).unwrap(),
source_iss: None,
};

Expand Down
5 changes: 3 additions & 2 deletions jans-cedarling/cedarling/src/jwt/key_service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ use super::{
use crate::common::policy_store::TrustedIssuer;
use jsonwebtoken::DecodingKey;
use serde_json::{json, Value};
use std::{collections::HashMap, sync::Arc, time::Duration};
use std::{collections::HashMap, sync::Arc};
use tokio::time::Duration;

pub struct DecodingKeyWithIss<'a> {
/// The decoding key used to validate JWT signatures.
Expand Down Expand Up @@ -67,7 +68,7 @@ impl KeyService {
pub fn new_from_trusted_issuers(
trusted_issuers: &HashMap<String, TrustedIssuer>,
) -> Result<Self, KeyServiceError> {
let http_client = HttpClient::new(3, Duration::from_secs(3))?;
let http_client = HttpClient::new(3, Duration::new(3, 0))?;

let mut key_stores = HashMap::new();
for (iss_id, iss) in trusted_issuers.iter() {
Expand Down
Loading