diff --git a/jans-cedarling/Cargo.toml b/jans-cedarling/Cargo.toml index 623620c7158..71754fed0a8 100644 --- a/jans-cedarling/Cargo.toml +++ b/jans-cedarling/Cargo.toml @@ -9,7 +9,7 @@ thiserror = "1.0" sparkv = { path = "sparkv" } cedarling = { path = "cedarling" } test_utils = { path = "test_utils" } - +chrono = "0.4.38" [profile.release] strip = "symbols" diff --git a/jans-cedarling/cedarling/Cargo.toml b/jans-cedarling/cedarling/Cargo.toml index d9b40648d37..0b4e84c6a10 100644 --- a/jans-cedarling/cedarling/Cargo.toml +++ b/jans-cedarling/cedarling/Cargo.toml @@ -15,7 +15,7 @@ base64 = "0.22.1" url = "2.5.2" lazy_static = "1.5.0" jsonwebtoken = "9.3.0" -reqwest = { version = "0.12.8", features = ["blocking", "json"] } +reqwest = { version = "0.12.8", features = ["json"] } bytes = "1.7.2" typed-builder = "0.20.0" semver = { version = "1.0.23", features = ["serde"] } @@ -25,8 +25,9 @@ derive_more = { version = "1.0.0", features = [ "display", "error", ] } -time = { version = "0.3.36", features = ["wasm-bindgen"] } regex = "1.11.1" +chrono = { workspace = true } +tokio = { version = "1.42.0", features = ["rt", "time"] } [dev-dependencies] # is used in testing diff --git a/jans-cedarling/cedarling/src/jwt/http_client.rs b/jans-cedarling/cedarling/src/jwt/http_client.rs index 5164d7433be..4b82d9f0346 100644 --- a/jans-cedarling/cedarling/src/jwt/http_client.rs +++ b/jans-cedarling/cedarling/src/jwt/http_client.rs @@ -5,10 +5,14 @@ * Copyright (c) 2024, Gluu, Inc. */ -use reqwest::blocking::Client; -use std::{thread::sleep, time::Duration}; - -/// A wrapper around `reqwest::blocking::Client` providing HTTP request functionality +use reqwest::Client; +use serde::de::DeserializeOwned; +use tokio::{ + runtime::{Builder as RtBuilder, Runtime}, + time::Duration, +}; + +/// A wrapper around [`reqwest::Client`] providing HTTP request functionality /// with retry logic. /// /// The `HttpClient` struct allows for sending GET requests with a retry mechanism @@ -16,13 +20,48 @@ use std::{thread::sleep, time::Duration}; /// if an error occurs. #[derive(Debug)] pub struct HttpClient { - client: reqwest::blocking::Client, + client: reqwest::Client, max_retries: u32, retry_delay: Duration, + rt: Runtime, +} + +/// A wrapper around [`reqwest::Response`] +#[derive(Debug)] +pub struct Response<'rt> { + rt: &'rt Runtime, + resp: reqwest::Response, +} + +impl Response<'_> { + /// Deserializes the response into from JSON. + pub fn json(self) -> Result + where + T: DeserializeOwned, + { + let resp_json = self + .rt + .block_on(async { self.resp.json::().await }) + .map_err(HttpClientError::DeserializeJson)?; + Ok(resp_json) + } + + /// Deserializes the response into a String. + pub fn text(self) -> Result { + let resp_text = self + .rt + .block_on(async { self.resp.text().await }) + .map_err(HttpClientError::DeserializeJson)?; + Ok(resp_text) + } } impl HttpClient { pub fn new(max_retries: u32, retry_delay: Duration) -> Result { + let rt = RtBuilder::new_current_thread() + .enable_all() + .build() + .expect("Failed to create Tokio runtime"); let client = Client::builder() .build() .map_err(HttpClientError::Initialization)?; @@ -31,6 +70,7 @@ impl HttpClient { client, max_retries, retry_delay, + rt, }) } @@ -38,31 +78,35 @@ impl HttpClient { /// /// This method will attempt to fetch the resource up to 3 times, with an increasing delay /// between each attempt. - pub fn get(&self, uri: &str) -> Result { + pub fn get(&self, uri: &str) -> Result { // Fetch the JWKS from the jwks_uri let mut attempts = 0; - let response = loop { - match self.client.get(uri).send() { - // Exit loop on success - Ok(response) => break response, - - Err(e) if attempts < self.max_retries => { - attempts += 1; - // TODO: pass this message to the logger - eprintln!( - "Request failed (attempt {} of {}): {}. Retrying...", - attempts, self.max_retries, e - ); - sleep(self.retry_delay * attempts); - }, - // Exit if max retries exceeded - Err(e) => return Err(HttpClientError::MaxHttpRetriesReached(e)), + let response = self.rt.block_on(async { + loop { + match self.client.get(uri).send().await { + // Exit loop on success + Ok(response) => return Ok(response), + + Err(e) if attempts < self.max_retries => { + attempts += 1; + // TODO: pass this message to the logger + eprintln!( + "Request failed (attempt {} of {}): {}. Retrying...", + attempts, self.max_retries, e + ); + tokio::time::sleep(self.retry_delay * attempts).await + }, + // Exit if max retries exceeded + Err(e) => return Err(HttpClientError::MaxHttpRetriesReached(e)), + } } - }; + })?; - response + let resp = response .error_for_status() - .map_err(HttpClientError::HttpStatus) + .map_err(HttpClientError::HttpStatus)?; + + Ok(Response { rt: &self.rt, resp }) } } @@ -75,21 +119,24 @@ pub enum HttpClientError { /// Indicates an HTTP error response received from an endpoint. #[error("Received error HTTP status: {0}")] HttpStatus(#[source] reqwest::Error), - /// Indicates a failure to reach the endpoint after 3 attempts. #[error("Could not reach endpoint after trying 3 times: {0}")] MaxHttpRetriesReached(#[source] reqwest::Error), + /// Indicates a failure to deserialize the http response into JSON. + #[error("Failed to deserialize response into JSON: {0}")] + DeserializeJson(#[source] reqwest::Error), + /// Indicates a failure to deserialize the http response into JSON. + #[error("Failed to deserialize response into a String: {0}")] + DeserializeText(#[source] reqwest::Error), } #[cfg(test)] mod test { - use crate::jwt::http_client::HttpClientError; - - use super::HttpClient; + use crate::jwt::http_client::{HttpClient, HttpClientError}; use mockito::Server; - use serde_json::json; - use std::time::Duration; + use serde_json::{json, Value}; use test_utils::assert_eq; + use tokio::time::Duration; #[test] fn can_fetch() { @@ -108,8 +155,7 @@ mod test { .expect(1) .create(); - let client = - HttpClient::new(3, Duration::from_millis(1)).expect("Should create HttpClient."); + let client = HttpClient::new(3, Duration::new(0, 10)).expect("Should create HttpClient."); let response = client .get(&format!( @@ -117,8 +163,8 @@ mod test { mock_server.url() )) .expect("Should get response") - .json::() - .expect("Should deserialize JSON response."); + .json::() + .expect("Should deserialize response to JSON"); assert_eq!( response, expected, diff --git a/jans-cedarling/cedarling/src/jwt/jwk_store.rs b/jans-cedarling/cedarling/src/jwt/jwk_store.rs index 4b7d07d156b..69d09279675 100644 --- a/jans-cedarling/cedarling/src/jwt/jwk_store.rs +++ b/jans-cedarling/cedarling/src/jwt/jwk_store.rs @@ -8,6 +8,7 @@ use super::http_client::{HttpClient, HttpClientError}; use super::{KeyId, TrustedIssuerId}; use crate::common::policy_store::TrustedIssuer; +use chrono::prelude::*; use jsonwebtoken::jwk::Jwk; use jsonwebtoken::DecodingKey; use serde::Deserialize; @@ -15,7 +16,6 @@ use serde_json::Value; use std::collections::{HashMap, HashSet}; use std::fmt::Debug; use std::sync::Arc; -use time::OffsetDateTime; #[derive(Deserialize)] struct OpenIdConfig { @@ -38,7 +38,7 @@ pub struct JwkStore { /// A collection of keys that do not have an associated ID. keys_without_id: Vec, /// The timestamp indicating when the store was last updated. - last_updated: OffsetDateTime, + last_updated: DateTime, /// From which TrustedIssuer this struct was built (if applicable). source_iss: Option, } @@ -143,7 +143,7 @@ impl JwkStore { issuer: None, keys, keys_without_id, - last_updated: OffsetDateTime::now_utc(), + last_updated: Utc::now(), source_iss: None, }) } @@ -155,15 +155,17 @@ impl JwkStore { http_client: &HttpClient, ) -> Result { // fetch openid configuration - let response = http_client.get(&issuer.openid_configuration_endpoint)?; + let response = http_client + .get(&issuer.openid_configuration_endpoint) + .map_err(JwkStoreError::FetchOpenIdConfig)?; let openid_config = response .json::() - .map_err(JwkStoreError::FetchOpenIdConfig)?; + .map_err(JwkStoreError::DeserializeOpenIdConfig)?; // fetch jwks let response = http_client.get(&openid_config.jwks_uri)?; - let jwks = response.text().map_err(JwkStoreError::FetchJwks)?; + let jwks = response.text().map_err(JwkStoreError::DeserializeJwks)?; let mut store = Self::new_from_jwks_str(store_id, &jwks)?; store.issuer = Some(openid_config.issuer.into()); @@ -204,9 +206,11 @@ impl JwkStore { #[derive(thiserror::Error, Debug)] pub enum JwkStoreError { #[error("Failed to fetch OpenIdConfig remote server: {0}")] - FetchOpenIdConfig(#[source] reqwest::Error), - #[error("Failed to fetch JWKS from remote server: {0}")] - FetchJwks(#[source] reqwest::Error), + FetchOpenIdConfig(#[source] HttpClientError), + #[error("Failed to deserialize OpenIdConfig to JSON: {0}")] + DeserializeOpenIdConfig(#[source] HttpClientError), + #[error("Failed to fetch JWKS: {0}")] + DeserializeJwks(#[source] HttpClientError), #[error("Failed to make HTTP Request: {0}")] Http(#[from] HttpClientError), #[error("Failed to create Decoding Key from JWK: {0}")] @@ -232,11 +236,12 @@ mod test { common::policy_store::TrustedIssuer, jwt::{http_client::HttpClient, jwk_store::JwkStore}, }; + use tokio::time::Duration; + use chrono::prelude::*; use jsonwebtoken::{jwk::JwkSet, DecodingKey}; use mockito::Server; use serde_json::json; - use std::{collections::HashMap, time::Duration}; - use time::OffsetDateTime; + use std::collections::HashMap; #[test] fn can_load_from_jwkset() { @@ -266,7 +271,7 @@ mod test { .expect("Should create JwkStore"); // We edit the `last_updated` from the result so that the comparison // wont fail because of the timestamp. - result.last_updated = OffsetDateTime::from_unix_timestamp(0).unwrap(); + result.last_updated = DateTime::from_timestamp(0, 0).unwrap(); let expected_jwkset = serde_json::from_value::(jwks_json).expect("Should create JwkSet"); @@ -287,7 +292,7 @@ mod test { issuer: None, keys: expected_keys, keys_without_id: Vec::new(), - last_updated: OffsetDateTime::from_unix_timestamp(0).unwrap(), + last_updated: DateTime::from_timestamp(0, 0).unwrap(), source_iss: None, }; @@ -373,7 +378,7 @@ mod test { .expect("Should load JwkStore from Trusted Issuer"); // We edit the `last_updated` from the result so that the comparison // wont fail because of the timestamp. - result.last_updated = OffsetDateTime::from_unix_timestamp(0).unwrap(); + result.last_updated = DateTime::from_timestamp(0, 0).unwrap(); let jwkset = serde_json::from_value::(jwks_json).expect("Should create JwkSet from Value"); @@ -393,7 +398,7 @@ mod test { issuer: Some(mock_server.url().into()), keys: expected_keys, keys_without_id: Vec::new(), - last_updated: OffsetDateTime::from_unix_timestamp(0).unwrap(), + last_updated: DateTime::from_timestamp(0, 0).unwrap(), source_iss: Some(source_iss), }; @@ -442,7 +447,7 @@ mod test { .expect("Should create JwkStore"); // We edit the `last_updated` from the result so that the comparison // wont fail because of the timestamp. - result.last_updated = OffsetDateTime::from_unix_timestamp(0).unwrap(); + result.last_updated = DateTime::from_timestamp(0, 0).unwrap(); let jwkset = serde_json::from_value::(jwks_json).expect("Should create JwkSet"); let expected_keys = jwkset @@ -456,7 +461,7 @@ mod test { issuer: None, keys: HashMap::new(), keys_without_id: expected_keys, - last_updated: OffsetDateTime::from_unix_timestamp(0).unwrap(), + last_updated: DateTime::from_timestamp(0, 0).unwrap(), source_iss: None, }; @@ -489,7 +494,7 @@ mod test { .expect("Should create JwkStore"); // We edit the `last_updated` from the result so that the comparison // wont fail because of the timestamp. - result.last_updated = OffsetDateTime::from_unix_timestamp(0).unwrap(); + result.last_updated = DateTime::from_timestamp(0, 0).unwrap(); assert_eq!(result.get_keys().len(), 2, "Expected 2 keys"); } @@ -536,7 +541,7 @@ mod test { .expect("Should create JwkStore"); // We edit the `last_updated` from the result so that the comparison // wont fail because of the timestamp. - result.last_updated = OffsetDateTime::from_unix_timestamp(0).unwrap(); + result.last_updated = DateTime::from_timestamp(0, 0).unwrap(); let expected_jwkset = serde_json::from_value::(json!({"keys": [ { @@ -574,7 +579,7 @@ mod test { issuer: None, keys: expected_keys, keys_without_id: Vec::new(), - last_updated: OffsetDateTime::from_unix_timestamp(0).unwrap(), + last_updated: DateTime::from_timestamp(0, 0).unwrap(), source_iss: None, }; diff --git a/jans-cedarling/cedarling/src/jwt/key_service.rs b/jans-cedarling/cedarling/src/jwt/key_service.rs index dfbcb33a912..ed3bcf2dab1 100644 --- a/jans-cedarling/cedarling/src/jwt/key_service.rs +++ b/jans-cedarling/cedarling/src/jwt/key_service.rs @@ -13,7 +13,8 @@ use super::{ use crate::common::policy_store::TrustedIssuer; use jsonwebtoken::DecodingKey; use serde_json::{json, Value}; -use std::{collections::HashMap, sync::Arc, time::Duration}; +use std::{collections::HashMap, sync::Arc}; +use tokio::time::Duration; pub struct DecodingKeyWithIss<'a> { /// The decoding key used to validate JWT signatures. @@ -67,7 +68,7 @@ impl KeyService { pub fn new_from_trusted_issuers( trusted_issuers: &HashMap, ) -> Result { - let http_client = HttpClient::new(3, Duration::from_secs(3))?; + let http_client = HttpClient::new(3, Duration::new(3, 0))?; let mut key_stores = HashMap::new(); for (iss_id, iss) in trusted_issuers.iter() { diff --git a/jans-cedarling/cedarling/src/log/log_entry.rs b/jans-cedarling/cedarling/src/log/log_entry.rs index 686460ea70d..471b11250e1 100644 --- a/jans-cedarling/cedarling/src/log/log_entry.rs +++ b/jans-cedarling/cedarling/src/log/log_entry.rs @@ -8,11 +8,10 @@ //! # Log entry //! The module contains structs for logging events. +use chrono::prelude::*; use std::collections::HashSet; use std::fmt::Display; -use std::time::{SystemTime, UNIX_EPOCH}; - use uuid7::uuid7; use uuid7::Uuid; @@ -54,10 +53,10 @@ impl LogEntry { application_id: Option, log_kind: LogType, ) -> LogEntry { - let unix_time_sec = SystemTime::now() - .duration_since(UNIX_EPOCH) - .expect("Time went backwards") - .as_secs(); + let unix_time_sec = Utc::now() + .timestamp() + .try_into() + .expect("Failed to convert timestamp: value might be negative"); Self { // We use uuid v7 because it is generated based on the time and sortable. diff --git a/jans-cedarling/cedarling/src/log/memory_logger.rs b/jans-cedarling/cedarling/src/log/memory_logger.rs index 1c1b9e07208..b5f4c0fa521 100644 --- a/jans-cedarling/cedarling/src/log/memory_logger.rs +++ b/jans-cedarling/cedarling/src/log/memory_logger.rs @@ -8,8 +8,9 @@ use super::interface::{LogStorage, LogWriter}; use super::LogEntry; use crate::bootstrap_config::log_config::MemoryLogConfig; +use chrono::Duration; use sparkv::{Config as ConfigSparKV, SparKV}; -use std::{sync::Mutex, time::Duration}; +use std::sync::Mutex; const STORAGE_MUTEX_EXPECT_MESSAGE: &str = "MemoryLogger storage mutex should unlock"; const STORAGE_JSON_PARSE_EXPECT_MESSAGE: &str = @@ -23,7 +24,11 @@ pub(crate) struct MemoryLogger { impl MemoryLogger { pub fn new(config: MemoryLogConfig) -> Self { let sparkv_config = ConfigSparKV { - default_ttl: Duration::from_secs(config.log_ttl), + default_ttl: Duration::new( + config.log_ttl.try_into().expect("u64 that fits in a i64"), + 0, + ) + .expect("a valid duration"), ..Default::default() }; diff --git a/jans-cedarling/cedarling/src/log/stdout_logger.rs b/jans-cedarling/cedarling/src/log/stdout_logger.rs index b474a956145..ab81fe00a33 100644 --- a/jans-cedarling/cedarling/src/log/stdout_logger.rs +++ b/jans-cedarling/cedarling/src/log/stdout_logger.rs @@ -87,12 +87,9 @@ impl Write for TestWriter { #[cfg(test)] mod tests { use super::super::LogType; - + use chrono::prelude::*; use super::*; - use std::{ - io::Write, - time::{SystemTime, UNIX_EPOCH}, - }; + use std::io::Write; use uuid7::uuid7; @@ -101,10 +98,7 @@ mod tests { // Create a log entry let log_entry = LogEntry { id: uuid7(), - time: SystemTime::now() - .duration_since(UNIX_EPOCH) - .expect("Time went backwards") - .as_secs(), + time: Utc::now().timestamp().try_into().unwrap(), log_kind: LogType::Decision, pdp_id: uuid7(), application_id: Some("test_app".to_string().into()), diff --git a/jans-cedarling/cedarling/src/log/test.rs b/jans-cedarling/cedarling/src/log/test.rs index 13efc523a69..91be128e6b3 100644 --- a/jans-cedarling/cedarling/src/log/test.rs +++ b/jans-cedarling/cedarling/src/log/test.rs @@ -2,15 +2,12 @@ //! Contains unit tests for the main code flow with the `LogStrategy`` //! `LogStrategy` wraps all other logger implementations. -use std::{ - io::Write, - time::{SystemTime, UNIX_EPOCH}, -}; - use super::*; use crate::{common::app_types, log::stdout_logger::TestWriter}; +use chrono::prelude::*; use interface::LogWriter; use nop_logger::NopLogger; +use std::io::Write; use stdout_logger::StdOutLogger; use uuid7::uuid7; @@ -67,10 +64,7 @@ fn test_log_memory_logger() { let strategy = LogStrategy::new(&config); let entry = LogEntry { id: uuid7(), - time: SystemTime::now() - .duration_since(UNIX_EPOCH) - .expect("Time went backwards") - .as_secs(), + time: Utc::now().timestamp().try_into().unwrap(), log_kind: LogType::Decision, pdp_id: uuid7(), application_id: Some("test_app".to_string().into()), @@ -146,10 +140,7 @@ fn test_log_stdout_logger() { // Arrange let log_entry = LogEntry { id: uuid7(), - time: SystemTime::now() - .duration_since(UNIX_EPOCH) - .expect("Time went backwards") - .as_secs(), + time: Utc::now().timestamp().try_into().unwrap(), log_kind: LogType::Decision, pdp_id: uuid7(), application_id: Some("test_app".to_string().into()), diff --git a/jans-cedarling/sparkv/Cargo.toml b/jans-cedarling/sparkv/Cargo.toml index a09e6d7f132..ec0f51db975 100644 --- a/jans-cedarling/sparkv/Cargo.toml +++ b/jans-cedarling/sparkv/Cargo.toml @@ -13,3 +13,4 @@ homepage = "https://crates.io/crates/sparkv" [dependencies] thiserror = { workspace = true } +chrono = { workspace = true } diff --git a/jans-cedarling/sparkv/README.md b/jans-cedarling/sparkv/README.md index ab4aa01dc7e..b7278655800 100644 --- a/jans-cedarling/sparkv/README.md +++ b/jans-cedarling/sparkv/README.md @@ -26,7 +26,7 @@ sparkv.set("your-key", "your-value"); // write let value = sparkv.get("your-key").unwrap(); // read // Write with unique TTL -sparkv.set_with_ttl("diff-ttl", "your-value", std::time::Duration::from_secs(60)); +sparkv.set_with_ttl("diff-ttl", "your-value", chrono::Duration::new(60, 0)); ``` See `config.rs` for more configuration options. diff --git a/jans-cedarling/sparkv/src/config.rs b/jans-cedarling/sparkv/src/config.rs index 356ab16fbbb..d1f0c966c5c 100644 --- a/jans-cedarling/sparkv/src/config.rs +++ b/jans-cedarling/sparkv/src/config.rs @@ -5,12 +5,14 @@ * Copyright (c) 2024 U-Zyn Chua */ +use chrono::Duration; + #[derive(Debug, PartialEq, Clone, Copy)] pub struct Config { pub max_items: usize, pub max_item_size: usize, - pub max_ttl: std::time::Duration, - pub default_ttl: std::time::Duration, + pub max_ttl: Duration, + pub default_ttl: Duration, pub auto_clear_expired: bool, } @@ -19,8 +21,8 @@ impl Config { Config { max_items: 10_000, max_item_size: 500_000, - max_ttl: std::time::Duration::from_secs(60 * 60), - default_ttl: std::time::Duration::from_secs(5 * 60), // 5 minutes + max_ttl: Duration::new(60 * 60, 0).expect("a valid duration"), + default_ttl: Duration::new(5 * 60, 0).expect("a valid duration"), // 5 minutes auto_clear_expired: true, } } @@ -41,8 +43,14 @@ mod tests { let config: Config = Config::new(); assert_eq!(config.max_items, 10_000); assert_eq!(config.max_item_size, 500_000); - assert_eq!(config.max_ttl, std::time::Duration::from_secs(60 * 60)); - assert_eq!(config.default_ttl, std::time::Duration::from_secs(5 * 60)); + assert_eq!( + config.max_ttl, + Duration::new(60 * 60, 0).expect("a valid duration") + ); + assert_eq!( + config.default_ttl, + Duration::new(5 * 60, 0).expect("a valid duration") + ); assert!(config.auto_clear_expired); } } diff --git a/jans-cedarling/sparkv/src/expentry.rs b/jans-cedarling/sparkv/src/expentry.rs index 014c7f98226..af4a435daea 100644 --- a/jans-cedarling/sparkv/src/expentry.rs +++ b/jans-cedarling/sparkv/src/expentry.rs @@ -6,16 +6,18 @@ */ use super::kventry::KvEntry; +use chrono::prelude::*; +use chrono::Duration; #[derive(Debug, Clone, PartialEq, Eq)] pub struct ExpEntry { pub key: String, - pub expired_at: std::time::Instant, + pub expired_at: DateTime, } impl ExpEntry { - pub fn new(key: &str, expiration: std::time::Duration) -> Self { - let expired_at: std::time::Instant = std::time::Instant::now() + expiration; + pub fn new(key: &str, expiration: Duration) -> Self { + let expired_at: DateTime = Utc::now() + expiration; Self { key: String::from(key), expired_at, @@ -30,7 +32,7 @@ impl ExpEntry { } pub fn is_expired(&self) -> bool { - self.expired_at < std::time::Instant::now() + self.expired_at < Utc::now() } } @@ -57,10 +59,10 @@ mod tests { #[test] fn test_new() { - let item = ExpEntry::new("key", std::time::Duration::from_secs(10)); + let item = ExpEntry::new("key", Duration::new(10, 0).expect("a valid duration")); assert_eq!(item.key, "key"); - assert!(item.expired_at > std::time::Instant::now() + std::time::Duration::from_secs(9)); - assert!(item.expired_at <= std::time::Instant::now() + std::time::Duration::from_secs(10)); + assert!(item.expired_at > Utc::now() + Duration::new(9, 0).expect("a valid duration")); + assert!(item.expired_at <= Utc::now() + Duration::new(10, 0).expect("a valid duration")); } #[test] @@ -68,7 +70,7 @@ mod tests { let kv_entry = KvEntry::new( "keyFromKV", "value from KV", - std::time::Duration::from_secs(10), + Duration::new(10, 0).expect("a valid duration"), ); let exp_item = ExpEntry::from_kv_entry(&kv_entry); assert_eq!(exp_item.key, "keyFromKV"); @@ -77,17 +79,16 @@ mod tests { #[test] fn test_cmp() { - let item_small = ExpEntry::new("k1", std::time::Duration::from_secs(10)); - let item_big = ExpEntry::new("k2", std::time::Duration::from_secs(8000)); + let item_small = ExpEntry::new("k1", Duration::new(10, 0).expect("a valid duration")); + let item_big = ExpEntry::new("k2", Duration::new(8000, 0).expect("a valid duration")); assert!(item_small > item_big); // reverse order assert!(item_big < item_small); // reverse order } #[test] fn test_is_expired() { - let item = ExpEntry::new("k1", std::time::Duration::from_millis(1)); - assert!(!item.is_expired()); - std::thread::sleep(std::time::Duration::from_millis(2)); + let item = ExpEntry::new("k1", Duration::new(0, 100).expect("a valid duration")); + std::thread::sleep(std::time::Duration::from_nanos(200)); assert!(item.is_expired()); } } diff --git a/jans-cedarling/sparkv/src/kventry.rs b/jans-cedarling/sparkv/src/kventry.rs index ea817bd812e..1e8644aa117 100644 --- a/jans-cedarling/sparkv/src/kventry.rs +++ b/jans-cedarling/sparkv/src/kventry.rs @@ -5,16 +5,19 @@ * Copyright (c) 2024 U-Zyn Chua */ +use chrono::prelude::*; +use chrono::Duration; + #[derive(Debug, Clone, PartialEq, Eq)] pub struct KvEntry { pub key: String, pub value: String, - pub expired_at: std::time::Instant, + pub expired_at: DateTime, } impl KvEntry { - pub fn new(key: &str, value: &str, expiration: std::time::Duration) -> Self { - let expired_at: std::time::Instant = std::time::Instant::now() + expiration; + pub fn new(key: &str, value: &str, expiration: Duration) -> Self { + let expired_at: DateTime = Utc::now() + expiration; Self { key: String::from(key), value: String::from(value), @@ -29,10 +32,14 @@ mod tests { #[test] fn test_new() { - let item = KvEntry::new("key", "value", std::time::Duration::from_secs(10)); + let item = KvEntry::new( + "key", + "value", + Duration::new(10, 0).expect("a valid duration"), + ); assert_eq!(item.key, "key"); assert_eq!(item.value, "value"); - assert!(item.expired_at > std::time::Instant::now() + std::time::Duration::from_secs(9)); - assert!(item.expired_at <= std::time::Instant::now() + std::time::Duration::from_secs(10)); + assert!(item.expired_at > Utc::now() + Duration::new(9, 0).expect("a valid duration")); + assert!(item.expired_at <= Utc::now() + Duration::new(10, 0).expect("a valid duration")); } } diff --git a/jans-cedarling/sparkv/src/lib.rs b/jans-cedarling/sparkv/src/lib.rs index 8c76171f013..0683ec4fb9d 100644 --- a/jans-cedarling/sparkv/src/lib.rs +++ b/jans-cedarling/sparkv/src/lib.rs @@ -15,6 +15,9 @@ pub use error::Error; pub use expentry::ExpEntry; pub use kventry::KvEntry; +use chrono::prelude::*; +use chrono::Duration; + pub struct SparKV { pub config: Config, data: std::collections::BTreeMap, @@ -39,12 +42,7 @@ impl SparKV { self.set_with_ttl(key, value, self.config.default_ttl) } - pub fn set_with_ttl( - &mut self, - key: &str, - value: &str, - ttl: std::time::Duration, - ) -> Result<(), Error> { + pub fn set_with_ttl(&mut self, key: &str, value: &str, ttl: Duration) -> Result<(), Error> { self.clear_expired_if_auto(); self.ensure_capacity_ignore_key(key)?; self.ensure_item_size(value)?; @@ -66,7 +64,7 @@ impl SparKV { // Only returns if it is not yet expired pub fn get_item(&self, key: &str) -> Option<&KvEntry> { let item = self.data.get(key)?; - if item.expired_at > std::time::Instant::now() { + if item.expired_at > Utc::now() { Some(item) } else { None @@ -151,7 +149,7 @@ impl SparKV { Ok(()) } - fn ensure_max_ttl(&self, ttl: std::time::Duration) -> Result<(), Error> { + fn ensure_max_ttl(&self, ttl: Duration) -> Result<(), Error> { if ttl > self.config.max_ttl { return Err(Error::TTLTooLong); } @@ -174,7 +172,10 @@ mod tests { let config: Config = Config::new(); assert_eq!(config.max_items, 10_000); assert_eq!(config.max_item_size, 500_000); - assert_eq!(config.max_ttl, std::time::Duration::from_secs(60 * 60)); + assert_eq!( + config.max_ttl, + Duration::new(60 * 60, 0).expect("a valid duration") + ); } #[test] @@ -213,7 +214,11 @@ mod tests { #[test] fn test_get_item() { let mut sparkv = SparKV::new(); - let item = KvEntry::new("keyARaw", "value99", std::time::Duration::from_secs(1)); + let item = KvEntry::new( + "keyARaw", + "value99", + Duration::new(1, 0).expect("a valid duration"), + ); sparkv.data.insert(item.key.clone(), item); let get_result = sparkv.get_item("keyARaw"); let unwrapped = get_result.unwrap(); @@ -228,10 +233,14 @@ mod tests { #[test] fn test_get_item_return_none_if_expired() { let mut sparkv = SparKV::new(); - _ = sparkv.set_with_ttl("kkk", "value", std::time::Duration::from_millis(50)); + _ = sparkv.set_with_ttl( + "kkk", + "value", + Duration::new(0, 10000).expect("a valid duration"), + ); assert_eq!(sparkv.get("kkk"), Some(String::from("value"))); - std::thread::sleep(std::time::Duration::from_millis(60)); + std::thread::sleep(std::time::Duration::from_nanos(30000)); assert_eq!(sparkv.get("kkk"), None); } @@ -263,8 +272,16 @@ mod tests { fn test_set_with_ttl() { let mut sparkv = SparKV::new(); _ = sparkv.set("longest", "value"); - _ = sparkv.set_with_ttl("longer", "value", std::time::Duration::from_secs(2)); - _ = sparkv.set_with_ttl("shorter", "value", std::time::Duration::from_secs(1)); + _ = sparkv.set_with_ttl( + "longer", + "value", + Duration::new(2, 0).expect("a valid duration"), + ); + _ = sparkv.set_with_ttl( + "shorter", + "value", + Duration::new(1, 0).expect("a valid duration"), + ); assert_eq!(sparkv.get("longer"), Some(String::from("value"))); assert_eq!(sparkv.get("shorter"), Some(String::from("value"))); @@ -281,24 +298,33 @@ mod tests { #[test] fn test_ensure_max_ttl() { let mut config: Config = Config::new(); - config.max_ttl = std::time::Duration::from_secs(3600); - config.default_ttl = std::time::Duration::from_secs(5000); + config.max_ttl = Duration::new(3600, 0).expect("a valid duration"); + config.default_ttl = Duration::new(5000, 0).expect("a valid duration"); let mut sparkv = SparKV::with_config(config); let set_result_long_def = sparkv.set("default is longer than max", "should fail"); assert!(set_result_long_def.is_err()); assert_eq!(set_result_long_def.unwrap_err(), Error::TTLTooLong); - let set_result_ok = - sparkv.set_with_ttl("shorter", "ok", std::time::Duration::from_secs(3599)); + let set_result_ok = sparkv.set_with_ttl( + "shorter", + "ok", + Duration::new(3599, 0).expect("a valid duration"), + ); assert!(set_result_ok.is_ok()); - let set_result_ok_2 = - sparkv.set_with_ttl("exact", "ok", std::time::Duration::from_secs(3600)); + let set_result_ok_2 = sparkv.set_with_ttl( + "exact", + "ok", + Duration::new(3600, 0).expect("a valid duration"), + ); assert!(set_result_ok_2.is_ok()); - let set_result_not_ok = - sparkv.set_with_ttl("not", "not ok", std::time::Duration::from_secs(3601)); + let set_result_not_ok = sparkv.set_with_ttl( + "not", + "not ok", + Duration::new(3601, 0).expect("a valid duration"), + ); assert!(set_result_not_ok.is_err()); assert_eq!(set_result_not_ok.unwrap_err(), Error::TTLTooLong); } @@ -321,10 +347,22 @@ mod tests { let mut config: Config = Config::new(); config.auto_clear_expired = false; let mut sparkv = SparKV::with_config(config); - _ = sparkv.set_with_ttl("not-yet-expired", "v", std::time::Duration::from_secs(90)); - _ = sparkv.set_with_ttl("expiring", "value", std::time::Duration::from_millis(1)); - _ = sparkv.set_with_ttl("not-expired", "value", std::time::Duration::from_secs(60)); - std::thread::sleep(std::time::Duration::from_millis(2)); + _ = sparkv.set_with_ttl( + "not-yet-expired", + "v", + Duration::new(0, 90).expect("a valid duration"), + ); + _ = sparkv.set_with_ttl( + "expiring", + "value", + Duration::new(1, 0).expect("a valid duration"), + ); + _ = sparkv.set_with_ttl( + "not-expired", + "value", + Duration::new(60, 0).expect("a valid duration"), + ); + std::thread::sleep(std::time::Duration::from_nanos(2)); assert_eq!(sparkv.len(), 3); let cleared_count = sparkv.clear_expired(); @@ -339,10 +377,22 @@ mod tests { let mut config: Config = Config::new(); config.auto_clear_expired = false; let mut sparkv = SparKV::with_config(config); - _ = sparkv.set_with_ttl("no-longer", "value", std::time::Duration::from_millis(1)); - _ = sparkv.set_with_ttl("no-longer", "v", std::time::Duration::from_secs(90)); - _ = sparkv.set_with_ttl("not-expired", "value", std::time::Duration::from_secs(60)); - std::thread::sleep(std::time::Duration::from_millis(2)); + _ = sparkv.set_with_ttl( + "no-longer", + "value", + Duration::new(0, 1).expect("a valid duration"), + ); + _ = sparkv.set_with_ttl( + "no-longer", + "v", + Duration::new(90, 0).expect("a valid duration"), + ); + _ = sparkv.set_with_ttl( + "not-expired", + "value", + Duration::new(60, 0).expect("a valid duration"), + ); + std::thread::sleep(std::time::Duration::from_nanos(2)); assert_eq!(sparkv.expiries.len(), 3); // overwriting key does not update expiries assert_eq!(sparkv.len(), 2); @@ -357,15 +407,31 @@ mod tests { let mut config: Config = Config::new(); config.auto_clear_expired = true; // explicitly setting it to true let mut sparkv = SparKV::with_config(config); - _ = sparkv.set_with_ttl("no-longer", "value", std::time::Duration::from_millis(1)); - _ = sparkv.set_with_ttl("no-longer", "v", std::time::Duration::from_secs(90)); - std::thread::sleep(std::time::Duration::from_millis(2)); - _ = sparkv.set_with_ttl("not-expired", "value", std::time::Duration::from_secs(60)); + _ = sparkv.set_with_ttl( + "no-longer", + "value", + Duration::new(1, 0).expect("a valid duration"), + ); + _ = sparkv.set_with_ttl( + "no-longer", + "v", + Duration::new(90, 0).expect("a valid duration"), + ); + std::thread::sleep(std::time::Duration::from_secs(2)); + _ = sparkv.set_with_ttl( + "not-expired", + "value", + Duration::new(60, 0).expect("a valid duration"), + ); assert_eq!(sparkv.expiries.len(), 2); // diff from above, because of auto clear assert_eq!(sparkv.len(), 2); - // auto clear - _ = sparkv.set_with_ttl("new-", "value", std::time::Duration::from_secs(60)); + // auto clear 2 + _ = sparkv.set_with_ttl( + "new-", + "value", + Duration::new(60, 0).expect("a valid duration"), + ); assert_eq!(sparkv.expiries.len(), 3); // should have cleared the expiries assert_eq!(sparkv.len(), 3); // but not actually deleting }