From a6cc3ab75dd0965d61eec8a9fb2aace2f543e25d Mon Sep 17 00:00:00 2001 From: Yi Lin Date: Tue, 10 Sep 2024 13:07:11 +0000 Subject: [PATCH 1/5] Update Azure dependencies and add support for Fabric token authentication --- object_store/Cargo.toml | 5 +- object_store/src/azure/builder.rs | 86 +++++++++++++++++++++++++- object_store/src/azure/credential.rs | 91 +++++++++++++++++++++++++++- 3 files changed, 177 insertions(+), 5 deletions(-) diff --git a/object_store/Cargo.toml b/object_store/Cargo.toml index a878c0c605cf..2c9b4248fe82 100644 --- a/object_store/Cargo.toml +++ b/object_store/Cargo.toml @@ -17,7 +17,7 @@ [package] name = "object_store" -version = "0.11.0" +version = "0.10.2" edition = "2021" license = "MIT/Apache-2.0" readme = "README.md" @@ -55,13 +55,14 @@ ring = { version = "0.17", default-features = false, features = ["std"], optiona rustls-pemfile = { version = "2.0", default-features = false, features = ["std"], optional = true } tokio = { version = "1.25.0", features = ["sync", "macros", "rt", "time", "io-util"] } md-5 = { version = "0.10.6", default-features = false, optional = true } +jsonwebtoken = { version = "9.3.0", default-features = false, optional = true } [target.'cfg(target_family="unix")'.dev-dependencies] nix = { version = "0.29.0", features = ["fs"] } [features] cloud = ["serde", "serde_json", "quick-xml", "hyper", "reqwest", "reqwest/json", "reqwest/stream", "chrono/serde", "base64", "rand", "ring"] -azure = ["cloud"] +azure = ["cloud", "jsonwebtoken"] gcp = ["cloud", "rustls-pemfile"] aws = ["cloud", "md-5"] http = ["cloud"] diff --git a/object_store/src/azure/builder.rs b/object_store/src/azure/builder.rs index 0208073e85c6..9818a3228ba1 100644 --- a/object_store/src/azure/builder.rs +++ b/object_store/src/azure/builder.rs @@ -17,8 +17,8 @@ use crate::azure::client::{AzureClient, AzureConfig}; use crate::azure::credential::{ - AzureAccessKey, AzureCliCredential, ClientSecretOAuthProvider, ImdsManagedIdentityProvider, - WorkloadIdentityOAuthProvider, + AzureAccessKey, AzureCliCredential, ClientSecretOAuthProvider, FabricTokenOAuthProvider, + ImdsManagedIdentityProvider, WorkloadIdentityOAuthProvider, }; use crate::azure::{AzureCredential, AzureCredentialProvider, MicrosoftAzure, STORE}; use crate::client::TokenCredentialProvider; @@ -172,6 +172,14 @@ pub struct MicrosoftAzureBuilder { use_fabric_endpoint: ConfigValue, /// When set to true, skips tagging objects disable_tagging: ConfigValue, + /// Fabric token service url + fabric_token_service_url: Option, + /// Fabric workload host + fabric_workload_host: Option, + /// Fabric session token + fabric_session_token: Option, + /// Fabric cluster identifier + fabric_cluster_identifier: Option, } /// Configuration keys for [`MicrosoftAzureBuilder`] @@ -336,6 +344,34 @@ pub enum AzureConfigKey { /// - `disable_tagging` DisableTagging, + /// Fabric token service url + /// + /// Supported keys: + /// - `azure_fabric_token_service_url` + /// - `fabric_token_service_url` + FabricTokenServiceUrl, + + /// Fabric workload host + /// + /// Supported keys: + /// - `azure_fabric_workload_host` + /// - `fabric_workload_host` + FabricWorkloadHost, + + /// Fabric session token + /// + /// Supported keys: + /// - `azure_fabric_session_token` + /// - `fabric_session_token` + FabricSessionToken, + + /// Fabric cluster identifier + /// + /// Supported keys: + /// - `azure_fabric_cluster_identifier` + /// - `fabric_cluster_identifier` + FabricClusterIdentifier, + /// Client options Client(ClientConfigKey), } @@ -361,6 +397,10 @@ impl AsRef for AzureConfigKey { Self::SkipSignature => "azure_skip_signature", Self::ContainerName => "azure_container_name", Self::DisableTagging => "azure_disable_tagging", + Self::FabricTokenServiceUrl => "azure_fabric_token_service_url", + Self::FabricWorkloadHost => "azure_fabric_workload_host", + Self::FabricSessionToken => "azure_fabric_session_token", + Self::FabricClusterIdentifier => "azure_fabric_cluster_identifier", Self::Client(key) => key.as_ref(), } } @@ -406,6 +446,14 @@ impl FromStr for AzureConfigKey { "azure_skip_signature" | "skip_signature" => Ok(Self::SkipSignature), "azure_container_name" | "container_name" => Ok(Self::ContainerName), "azure_disable_tagging" | "disable_tagging" => Ok(Self::DisableTagging), + "azure_fabric_token_service_url" | "fabric_token_service_url" => { + Ok(Self::FabricTokenServiceUrl) + } + "azure_fabric_workload_host" | "fabric_workload_host" => Ok(Self::FabricWorkloadHost), + "azure_fabric_session_token" | "fabric_session_token" => Ok(Self::FabricSessionToken), + "azure_fabric_cluster_identifier" | "fabric_cluster_identifier" => { + Ok(Self::FabricClusterIdentifier) + } // Backwards compatibility "azure_allow_http" => Ok(Self::Client(ClientConfigKey::AllowHttp)), _ => match s.strip_prefix("azure_").unwrap_or(s).parse() { @@ -525,6 +573,14 @@ impl MicrosoftAzureBuilder { } AzureConfigKey::ContainerName => self.container_name = Some(value.into()), AzureConfigKey::DisableTagging => self.disable_tagging.parse(value), + AzureConfigKey::FabricTokenServiceUrl => { + self.fabric_token_service_url = Some(value.into()) + } + AzureConfigKey::FabricWorkloadHost => self.fabric_workload_host = Some(value.into()), + AzureConfigKey::FabricSessionToken => self.fabric_session_token = Some(value.into()), + AzureConfigKey::FabricClusterIdentifier => { + self.fabric_cluster_identifier = Some(value.into()) + } }; self } @@ -561,6 +617,10 @@ impl MicrosoftAzureBuilder { AzureConfigKey::Client(key) => self.client_options.get_config_value(key), AzureConfigKey::ContainerName => self.container_name.clone(), AzureConfigKey::DisableTagging => Some(self.disable_tagging.to_string()), + AzureConfigKey::FabricTokenServiceUrl => self.fabric_token_service_url.clone(), + AzureConfigKey::FabricWorkloadHost => self.fabric_workload_host.clone(), + AzureConfigKey::FabricSessionToken => self.fabric_session_token.clone(), + AzureConfigKey::FabricClusterIdentifier => self.fabric_cluster_identifier.clone(), } } @@ -895,6 +955,28 @@ impl MicrosoftAzureBuilder { static_creds(AzureCredential::SASToken(split_sas(&sas)?)) } else if self.use_azure_cli.get()? { Arc::new(AzureCliCredential::new()) as _ + } else if let ( + Some(fabric_token_service_url), + Some(fabric_workload_host), + Some(fabric_session_token), + Some(fabric_cluster_identifier), + ) = ( + &self.fabric_token_service_url, + &self.fabric_workload_host, + &self.fabric_session_token, + &self.fabric_cluster_identifier, + ) { + let fabric_credential = FabricTokenOAuthProvider::new( + fabric_token_service_url, + fabric_workload_host, + fabric_session_token, + fabric_cluster_identifier, + ); + Arc::new(TokenCredentialProvider::new( + fabric_credential, + self.client_options.client()?, + self.retry_config.clone(), + )) as _ } else { let msi_credential = ImdsManagedIdentityProvider::new( self.client_id, diff --git a/object_store/src/azure/credential.rs b/object_store/src/azure/credential.rs index 7808c7c4a7c8..31f195410ee3 100644 --- a/object_store/src/azure/credential.rs +++ b/object_store/src/azure/credential.rs @@ -25,13 +25,14 @@ use async_trait::async_trait; use base64::prelude::BASE64_STANDARD; use base64::Engine; use chrono::{DateTime, SecondsFormat, Utc}; +use jsonwebtoken::{decode, Algorithm, DecodingKey, TokenData, Validation}; use reqwest::header::{ HeaderMap, HeaderName, HeaderValue, ACCEPT, AUTHORIZATION, CONTENT_ENCODING, CONTENT_LANGUAGE, CONTENT_LENGTH, CONTENT_TYPE, DATE, IF_MATCH, IF_MODIFIED_SINCE, IF_NONE_MATCH, IF_UNMODIFIED_SINCE, RANGE, }; use reqwest::{Client, Method, Request, RequestBuilder}; -use serde::Deserialize; +use serde::{Deserialize, Serialize}; use snafu::{ResultExt, Snafu}; use std::borrow::Cow; use std::collections::HashMap; @@ -934,6 +935,94 @@ impl AzureCliCredential { } } +/// Encapsulates the logic to perform an OAuth token challenge for Fabric +#[derive(Debug)] +pub struct FabricTokenOAuthProvider { + fabric_token_service_url: String, + fabric_workload_host: String, + fabric_session_token: String, + fabric_cluster_identifier: String, +} + +#[derive(Debug, Serialize, Deserialize)] +struct Claims { + exp: usize, +} + +impl FabricTokenOAuthProvider { + /// Create a new [`FabricTokenOAuthProvider`] for an azure backed store + pub fn new( + fabric_token_service_url: impl Into, + fabric_workload_host: impl Into, + fabric_session_token: impl Into, + fabric_cluster_identifier: impl Into, + ) -> Self { + Self { + fabric_token_service_url: fabric_token_service_url.into(), + fabric_workload_host: fabric_workload_host.into(), + fabric_session_token: fabric_session_token.into(), + fabric_cluster_identifier: fabric_cluster_identifier.into(), + } + } +} + +#[async_trait::async_trait] +impl TokenProvider for FabricTokenOAuthProvider { + type Credential = AzureCredential; + + /// Fetch a token + async fn fetch_token( + &self, + client: &Client, + retry: &RetryConfig, + ) -> crate::Result>> { + let query_items = vec![("resource", AZURE_STORAGE_RESOURCE)]; + let access_token: String = client + .request(Method::GET, &self.fabric_token_service_url) + .header("Content-Type", "application/json;charset=utf-8") + .header("x-ms-partner-token", self.fabric_session_token.as_str()) + .header( + "x-ms-cluster-identifier", + self.fabric_cluster_identifier.as_str(), + ) + .header( + "x-ms-workload-resource-moniker", + self.fabric_cluster_identifier.as_str(), + ) + .header("x-ms-proxy-host", self.fabric_workload_host.as_str()) + .query(&query_items) + .retryable(retry) + .idempotent(true) + .send() + .await + .context(TokenRequestSnafu)? + .text() + .await + .context(TokenResponseBodySnafu)?; + + let mut validation: Validation = Validation::new(Algorithm::HS256); + validation.insecure_disable_signature_validation(); + let token_data: Result, _> = + decode::(&access_token, &DecodingKey::from_secret(&[]), &validation); + let exp = match token_data { + Ok(data) => data.claims.exp as u64, + Err(_) => { + let current_time = SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .unwrap() + .as_secs(); + (current_time + 3600) as u64 + } + }; + print!("exp: {}", exp); + + Ok(TemporaryToken { + token: Arc::new(AzureCredential::BearerToken(access_token)), + expiry: Some(Instant::now() + Duration::from_secs(exp)), + }) + } +} + #[async_trait] impl CredentialProvider for AzureCliCredential { type Credential = AzureCredential; From 5a9fe4352b04cd82a85f903961ca29c0a2f1d373 Mon Sep 17 00:00:00 2001 From: Yi Lin Date: Wed, 11 Sep 2024 13:14:52 +0000 Subject: [PATCH 2/5] Refactor Azure credential provider to support Fabric token authentication --- object_store/src/azure/builder.rs | 46 +++++++------- object_store/src/azure/credential.rs | 90 +++++++++++++++++++--------- 2 files changed, 85 insertions(+), 51 deletions(-) diff --git a/object_store/src/azure/builder.rs b/object_store/src/azure/builder.rs index 9818a3228ba1..35cedeafc049 100644 --- a/object_store/src/azure/builder.rs +++ b/object_store/src/azure/builder.rs @@ -916,6 +916,30 @@ impl MicrosoftAzureBuilder { let credential = if let Some(credential) = self.credentials { credential + } else if let ( + Some(fabric_token_service_url), + Some(fabric_workload_host), + Some(fabric_session_token), + Some(fabric_cluster_identifier), + ) = ( + &self.fabric_token_service_url, + &self.fabric_workload_host, + &self.fabric_session_token, + &self.fabric_cluster_identifier, + ) { + // This case should precede the bearer token case because it is more specific and will utilize the bearer token. + let fabric_credential = FabricTokenOAuthProvider::new( + fabric_token_service_url, + fabric_workload_host, + fabric_session_token, + fabric_cluster_identifier, + self.bearer_token.clone(), + ); + Arc::new(TokenCredentialProvider::new( + fabric_credential, + self.client_options.client()?, + self.retry_config.clone(), + )) as _ } else if let Some(bearer_token) = self.bearer_token { static_creds(AzureCredential::BearerToken(bearer_token)) } else if let Some(access_key) = self.access_key { @@ -955,28 +979,6 @@ impl MicrosoftAzureBuilder { static_creds(AzureCredential::SASToken(split_sas(&sas)?)) } else if self.use_azure_cli.get()? { Arc::new(AzureCliCredential::new()) as _ - } else if let ( - Some(fabric_token_service_url), - Some(fabric_workload_host), - Some(fabric_session_token), - Some(fabric_cluster_identifier), - ) = ( - &self.fabric_token_service_url, - &self.fabric_workload_host, - &self.fabric_session_token, - &self.fabric_cluster_identifier, - ) { - let fabric_credential = FabricTokenOAuthProvider::new( - fabric_token_service_url, - fabric_workload_host, - fabric_session_token, - fabric_cluster_identifier, - ); - Arc::new(TokenCredentialProvider::new( - fabric_credential, - self.client_options.client()?, - self.retry_config.clone(), - )) as _ } else { let msi_credential = ImdsManagedIdentityProvider::new( self.client_id, diff --git a/object_store/src/azure/credential.rs b/object_store/src/azure/credential.rs index 31f195410ee3..cb34c012ef41 100644 --- a/object_store/src/azure/credential.rs +++ b/object_store/src/azure/credential.rs @@ -25,7 +25,7 @@ use async_trait::async_trait; use base64::prelude::BASE64_STANDARD; use base64::Engine; use chrono::{DateTime, SecondsFormat, Utc}; -use jsonwebtoken::{decode, Algorithm, DecodingKey, TokenData, Validation}; +use jsonwebtoken::{decode, get_current_timestamp, Algorithm, DecodingKey, Validation}; use reqwest::header::{ HeaderMap, HeaderName, HeaderValue, ACCEPT, AUTHORIZATION, CONTENT_ENCODING, CONTENT_LANGUAGE, CONTENT_LENGTH, CONTENT_TYPE, DATE, IF_MATCH, IF_MODIFIED_SINCE, IF_NONE_MATCH, @@ -52,10 +52,15 @@ pub(crate) static BLOB_TYPE: HeaderName = HeaderName::from_static("x-ms-blob-typ pub(crate) static DELETE_SNAPSHOTS: HeaderName = HeaderName::from_static("x-ms-delete-snapshots"); pub(crate) static COPY_SOURCE: HeaderName = HeaderName::from_static("x-ms-copy-source"); static CONTENT_MD5: HeaderName = HeaderName::from_static("content-md5"); +static PARTNER_TOKEN: HeaderName = HeaderName::from_static("x-ms-partner-token"); +static CLUSTER_IDENTIFIER: HeaderName = HeaderName::from_static("x-ms-cluster-identifier"); +static WORKLOAD_RESOURCE: HeaderName = HeaderName::from_static("x-ms-workload-resource-moniker"); +static PROXY_HOST: HeaderName = HeaderName::from_static("x-ms-proxy-host"); pub(crate) const RFC1123_FMT: &str = "%a, %d %h %Y %T GMT"; const CONTENT_TYPE_JSON: &str = "application/json"; const MSI_SECRET_ENV_KEY: &str = "IDENTITY_HEADER"; const MSI_API_VERSION: &str = "2019-08-01"; +const TOKEN_MIN_TTL: u64 = 300; /// OIDC scope used when interacting with OAuth2 APIs /// @@ -942,11 +947,13 @@ pub struct FabricTokenOAuthProvider { fabric_workload_host: String, fabric_session_token: String, fabric_cluster_identifier: String, + storage_access_token: Option, + token_expiry: Option, } #[derive(Debug, Serialize, Deserialize)] struct Claims { - exp: usize, + exp: u64, } impl FabricTokenOAuthProvider { @@ -956,14 +963,41 @@ impl FabricTokenOAuthProvider { fabric_workload_host: impl Into, fabric_session_token: impl Into, fabric_cluster_identifier: impl Into, + storage_access_token: Option, ) -> Self { + let (storage_access_token, token_expiry) = if let Some(token) = storage_access_token { + if let Some(expiry) = Self::validate_and_get_expiry(&token) { + if expiry > get_current_timestamp() + TOKEN_MIN_TTL { + (Some(token), Some(expiry)) + } else { + (None, None) + } + } else { + (None, None) + } + } else { + (None, None) + }; + Self { fabric_token_service_url: fabric_token_service_url.into(), fabric_workload_host: fabric_workload_host.into(), fabric_session_token: fabric_session_token.into(), fabric_cluster_identifier: fabric_cluster_identifier.into(), + storage_access_token, + token_expiry, } } + + fn validate_and_get_expiry(token: &str) -> Option { + let mut validation: Validation = Validation::new(Algorithm::HS256); + validation.insecure_disable_signature_validation(); + validation.set_audience(&[AZURE_STORAGE_RESOURCE]); + let key = DecodingKey::from_secret(&[]); + decode::(token, &key, &validation) + .ok() + .map(|data| data.claims.exp) + } } #[async_trait::async_trait] @@ -976,20 +1010,30 @@ impl TokenProvider for FabricTokenOAuthProvider { client: &Client, retry: &RetryConfig, ) -> crate::Result>> { + if let Some(storage_access_token) = &self.storage_access_token { + if let Some(expiry) = self.token_expiry { + let exp_in = expiry - get_current_timestamp(); + if exp_in > TOKEN_MIN_TTL { + return Ok(TemporaryToken { + token: Arc::new(AzureCredential::BearerToken(storage_access_token.clone())), + expiry: Some(Instant::now() + Duration::from_secs(exp_in)), + }); + } else { + println!("access token is expired"); + } + } else { + println!("access token is invalid"); + } + } + + println!("requesting new token"); let query_items = vec![("resource", AZURE_STORAGE_RESOURCE)]; let access_token: String = client .request(Method::GET, &self.fabric_token_service_url) - .header("Content-Type", "application/json;charset=utf-8") - .header("x-ms-partner-token", self.fabric_session_token.as_str()) - .header( - "x-ms-cluster-identifier", - self.fabric_cluster_identifier.as_str(), - ) - .header( - "x-ms-workload-resource-moniker", - self.fabric_cluster_identifier.as_str(), - ) - .header("x-ms-proxy-host", self.fabric_workload_host.as_str()) + .header(&PARTNER_TOKEN, self.fabric_session_token.as_str()) + .header(&CLUSTER_IDENTIFIER, self.fabric_cluster_identifier.as_str()) + .header(&WORKLOAD_RESOURCE, self.fabric_cluster_identifier.as_str()) + .header(&PROXY_HOST, self.fabric_workload_host.as_str()) .query(&query_items) .retryable(retry) .idempotent(true) @@ -1000,25 +1044,13 @@ impl TokenProvider for FabricTokenOAuthProvider { .await .context(TokenResponseBodySnafu)?; - let mut validation: Validation = Validation::new(Algorithm::HS256); - validation.insecure_disable_signature_validation(); - let token_data: Result, _> = - decode::(&access_token, &DecodingKey::from_secret(&[]), &validation); - let exp = match token_data { - Ok(data) => data.claims.exp as u64, - Err(_) => { - let current_time = SystemTime::now() - .duration_since(SystemTime::UNIX_EPOCH) - .unwrap() - .as_secs(); - (current_time + 3600) as u64 - } - }; - print!("exp: {}", exp); + let exp_in = Self::validate_and_get_expiry(&access_token) + .map_or(3600, |expiry| expiry - get_current_timestamp()); + println!("exp_in: {}", exp_in); Ok(TemporaryToken { token: Arc::new(AzureCredential::BearerToken(access_token)), - expiry: Some(Instant::now() + Duration::from_secs(exp)), + expiry: Some(Instant::now() + Duration::from_secs(exp_in)), }) } } From 816cc79de25e2c9da06b439df21a4667f81413f3 Mon Sep 17 00:00:00 2001 From: Yi Lin Date: Wed, 11 Sep 2024 13:17:43 +0000 Subject: [PATCH 3/5] Refactor Azure credential provider to remove unnecessary print statements and improve token handling --- object_store/src/azure/credential.rs | 8 -------- 1 file changed, 8 deletions(-) diff --git a/object_store/src/azure/credential.rs b/object_store/src/azure/credential.rs index cb34c012ef41..8da2b7ef11f4 100644 --- a/object_store/src/azure/credential.rs +++ b/object_store/src/azure/credential.rs @@ -1018,15 +1018,10 @@ impl TokenProvider for FabricTokenOAuthProvider { token: Arc::new(AzureCredential::BearerToken(storage_access_token.clone())), expiry: Some(Instant::now() + Duration::from_secs(exp_in)), }); - } else { - println!("access token is expired"); } - } else { - println!("access token is invalid"); } } - println!("requesting new token"); let query_items = vec![("resource", AZURE_STORAGE_RESOURCE)]; let access_token: String = client .request(Method::GET, &self.fabric_token_service_url) @@ -1043,11 +1038,8 @@ impl TokenProvider for FabricTokenOAuthProvider { .text() .await .context(TokenResponseBodySnafu)?; - let exp_in = Self::validate_and_get_expiry(&access_token) .map_or(3600, |expiry| expiry - get_current_timestamp()); - println!("exp_in: {}", exp_in); - Ok(TemporaryToken { token: Arc::new(AzureCredential::BearerToken(access_token)), expiry: Some(Instant::now() + Duration::from_secs(exp_in)), From a0682bc59bfcfda93953ad638319c0bbde8b52b2 Mon Sep 17 00:00:00 2001 From: Yi Lin Date: Wed, 11 Sep 2024 13:20:21 +0000 Subject: [PATCH 4/5] Bump object_store version to 0.11.0 --- object_store/Cargo.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/object_store/Cargo.toml b/object_store/Cargo.toml index 2c9b4248fe82..4d943cfd2adf 100644 --- a/object_store/Cargo.toml +++ b/object_store/Cargo.toml @@ -17,7 +17,7 @@ [package] name = "object_store" -version = "0.10.2" +version = "0.11.0" edition = "2021" license = "MIT/Apache-2.0" readme = "README.md" From 92144b20927e312bc63e1a53a5225257a7bc24d9 Mon Sep 17 00:00:00 2001 From: Yi Lin Date: Fri, 13 Sep 2024 00:52:20 +0000 Subject: [PATCH 5/5] Refactor Azure credential provider to remove unnecessary print statements and improve token handling --- object_store/Cargo.toml | 3 +- object_store/src/azure/credential.rs | 45 ++++++++++++++-------------- 2 files changed, 23 insertions(+), 25 deletions(-) diff --git a/object_store/Cargo.toml b/object_store/Cargo.toml index 4d943cfd2adf..a878c0c605cf 100644 --- a/object_store/Cargo.toml +++ b/object_store/Cargo.toml @@ -55,14 +55,13 @@ ring = { version = "0.17", default-features = false, features = ["std"], optiona rustls-pemfile = { version = "2.0", default-features = false, features = ["std"], optional = true } tokio = { version = "1.25.0", features = ["sync", "macros", "rt", "time", "io-util"] } md-5 = { version = "0.10.6", default-features = false, optional = true } -jsonwebtoken = { version = "9.3.0", default-features = false, optional = true } [target.'cfg(target_family="unix")'.dev-dependencies] nix = { version = "0.29.0", features = ["fs"] } [features] cloud = ["serde", "serde_json", "quick-xml", "hyper", "reqwest", "reqwest/json", "reqwest/stream", "chrono/serde", "base64", "rand", "ring"] -azure = ["cloud", "jsonwebtoken"] +azure = ["cloud"] gcp = ["cloud", "rustls-pemfile"] aws = ["cloud", "md-5"] http = ["cloud"] diff --git a/object_store/src/azure/credential.rs b/object_store/src/azure/credential.rs index 8da2b7ef11f4..6b5fa19d154b 100644 --- a/object_store/src/azure/credential.rs +++ b/object_store/src/azure/credential.rs @@ -22,17 +22,16 @@ use crate::client::{CredentialProvider, TokenProvider}; use crate::util::hmac_sha256; use crate::RetryConfig; use async_trait::async_trait; -use base64::prelude::BASE64_STANDARD; +use base64::prelude::{BASE64_STANDARD, BASE64_URL_SAFE_NO_PAD}; use base64::Engine; use chrono::{DateTime, SecondsFormat, Utc}; -use jsonwebtoken::{decode, get_current_timestamp, Algorithm, DecodingKey, Validation}; use reqwest::header::{ HeaderMap, HeaderName, HeaderValue, ACCEPT, AUTHORIZATION, CONTENT_ENCODING, CONTENT_LANGUAGE, CONTENT_LENGTH, CONTENT_TYPE, DATE, IF_MATCH, IF_MODIFIED_SINCE, IF_NONE_MATCH, IF_UNMODIFIED_SINCE, RANGE, }; use reqwest::{Client, Method, Request, RequestBuilder}; -use serde::{Deserialize, Serialize}; +use serde::Deserialize; use snafu::{ResultExt, Snafu}; use std::borrow::Cow; use std::collections::HashMap; @@ -951,7 +950,7 @@ pub struct FabricTokenOAuthProvider { token_expiry: Option, } -#[derive(Debug, Serialize, Deserialize)] +#[derive(Debug, Deserialize)] struct Claims { exp: u64, } @@ -965,18 +964,14 @@ impl FabricTokenOAuthProvider { fabric_cluster_identifier: impl Into, storage_access_token: Option, ) -> Self { - let (storage_access_token, token_expiry) = if let Some(token) = storage_access_token { - if let Some(expiry) = Self::validate_and_get_expiry(&token) { - if expiry > get_current_timestamp() + TOKEN_MIN_TTL { + let (storage_access_token, token_expiry) = match storage_access_token { + Some(token) => match Self::validate_and_get_expiry(&token) { + Some(expiry) if expiry > Self::get_current_timestamp() + TOKEN_MIN_TTL => { (Some(token), Some(expiry)) - } else { - (None, None) } - } else { - (None, None) - } - } else { - (None, None) + _ => (None, None), + }, + None => (None, None), }; Self { @@ -990,13 +985,17 @@ impl FabricTokenOAuthProvider { } fn validate_and_get_expiry(token: &str) -> Option { - let mut validation: Validation = Validation::new(Algorithm::HS256); - validation.insecure_disable_signature_validation(); - validation.set_audience(&[AZURE_STORAGE_RESOURCE]); - let key = DecodingKey::from_secret(&[]); - decode::(token, &key, &validation) - .ok() - .map(|data| data.claims.exp) + let payload = token.split('.').nth(1)?; + let decoded_bytes = BASE64_URL_SAFE_NO_PAD.decode(payload).ok()?; + let decoded_str = str::from_utf8(&decoded_bytes).ok()?; + let claims: Claims = serde_json::from_str(decoded_str).ok()?; + Some(claims.exp) + } + + fn get_current_timestamp() -> u64 { + SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .map_or(0, |d| d.as_secs()) } } @@ -1012,7 +1011,7 @@ impl TokenProvider for FabricTokenOAuthProvider { ) -> crate::Result>> { if let Some(storage_access_token) = &self.storage_access_token { if let Some(expiry) = self.token_expiry { - let exp_in = expiry - get_current_timestamp(); + let exp_in = expiry - Self::get_current_timestamp(); if exp_in > TOKEN_MIN_TTL { return Ok(TemporaryToken { token: Arc::new(AzureCredential::BearerToken(storage_access_token.clone())), @@ -1039,7 +1038,7 @@ impl TokenProvider for FabricTokenOAuthProvider { .await .context(TokenResponseBodySnafu)?; let exp_in = Self::validate_and_get_expiry(&access_token) - .map_or(3600, |expiry| expiry - get_current_timestamp()); + .map_or(3600, |expiry| expiry - Self::get_current_timestamp()); Ok(TemporaryToken { token: Arc::new(AzureCredential::BearerToken(access_token)), expiry: Some(Instant::now() + Duration::from_secs(exp_in)),