Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

object_score: Support Azure Fabric OAuth Provider #6382

Merged
merged 5 commits into from
Sep 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
88 changes: 86 additions & 2 deletions object_store/src/azure/builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@

use crate::azure::client::{AzureClient, AzureConfig};
use crate::azure::credential::{
AzureAccessKey, AzureCliCredential, ClientSecretOAuthProvider, ImdsManagedIdentityProvider,
WorkloadIdentityOAuthProvider,
AzureAccessKey, AzureCliCredential, ClientSecretOAuthProvider, FabricTokenOAuthProvider,
ImdsManagedIdentityProvider, WorkloadIdentityOAuthProvider,
};
use crate::azure::{AzureCredential, AzureCredentialProvider, MicrosoftAzure, STORE};
use crate::client::TokenCredentialProvider;
Expand Down Expand Up @@ -172,6 +172,14 @@ pub struct MicrosoftAzureBuilder {
use_fabric_endpoint: ConfigValue<bool>,
/// When set to true, skips tagging objects
disable_tagging: ConfigValue<bool>,
/// Fabric token service url
fabric_token_service_url: Option<String>,
/// Fabric workload host
fabric_workload_host: Option<String>,
/// Fabric session token
fabric_session_token: Option<String>,
/// Fabric cluster identifier
fabric_cluster_identifier: Option<String>,
}

/// Configuration keys for [`MicrosoftAzureBuilder`]
Expand Down Expand Up @@ -336,6 +344,34 @@ pub enum AzureConfigKey {
/// - `disable_tagging`
DisableTagging,

/// Fabric token service url
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I double checked and this enum is marked #[non_exhaustive] and thus it is ok to add new variants without breaking the API

///
/// Supported keys:
/// - `azure_fabric_token_service_url`
/// - `fabric_token_service_url`
FabricTokenServiceUrl,

/// Fabric workload host
///
/// Supported keys:
/// - `azure_fabric_workload_host`
/// - `fabric_workload_host`
FabricWorkloadHost,

/// Fabric session token
///
/// Supported keys:
/// - `azure_fabric_session_token`
/// - `fabric_session_token`
FabricSessionToken,

/// Fabric cluster identifier
///
/// Supported keys:
/// - `azure_fabric_cluster_identifier`
/// - `fabric_cluster_identifier`
FabricClusterIdentifier,

/// Client options
Client(ClientConfigKey),
}
Expand All @@ -361,6 +397,10 @@ impl AsRef<str> for AzureConfigKey {
Self::SkipSignature => "azure_skip_signature",
Self::ContainerName => "azure_container_name",
Self::DisableTagging => "azure_disable_tagging",
Self::FabricTokenServiceUrl => "azure_fabric_token_service_url",
Self::FabricWorkloadHost => "azure_fabric_workload_host",
Self::FabricSessionToken => "azure_fabric_session_token",
Self::FabricClusterIdentifier => "azure_fabric_cluster_identifier",
Self::Client(key) => key.as_ref(),
}
}
Expand Down Expand Up @@ -406,6 +446,14 @@ impl FromStr for AzureConfigKey {
"azure_skip_signature" | "skip_signature" => Ok(Self::SkipSignature),
"azure_container_name" | "container_name" => Ok(Self::ContainerName),
"azure_disable_tagging" | "disable_tagging" => Ok(Self::DisableTagging),
"azure_fabric_token_service_url" | "fabric_token_service_url" => {
Ok(Self::FabricTokenServiceUrl)
}
"azure_fabric_workload_host" | "fabric_workload_host" => Ok(Self::FabricWorkloadHost),
"azure_fabric_session_token" | "fabric_session_token" => Ok(Self::FabricSessionToken),
"azure_fabric_cluster_identifier" | "fabric_cluster_identifier" => {
Ok(Self::FabricClusterIdentifier)
}
// Backwards compatibility
"azure_allow_http" => Ok(Self::Client(ClientConfigKey::AllowHttp)),
_ => match s.strip_prefix("azure_").unwrap_or(s).parse() {
Expand Down Expand Up @@ -525,6 +573,14 @@ impl MicrosoftAzureBuilder {
}
AzureConfigKey::ContainerName => self.container_name = Some(value.into()),
AzureConfigKey::DisableTagging => self.disable_tagging.parse(value),
AzureConfigKey::FabricTokenServiceUrl => {
self.fabric_token_service_url = Some(value.into())
}
AzureConfigKey::FabricWorkloadHost => self.fabric_workload_host = Some(value.into()),
AzureConfigKey::FabricSessionToken => self.fabric_session_token = Some(value.into()),
AzureConfigKey::FabricClusterIdentifier => {
self.fabric_cluster_identifier = Some(value.into())
}
};
self
}
Expand Down Expand Up @@ -561,6 +617,10 @@ impl MicrosoftAzureBuilder {
AzureConfigKey::Client(key) => self.client_options.get_config_value(key),
AzureConfigKey::ContainerName => self.container_name.clone(),
AzureConfigKey::DisableTagging => Some(self.disable_tagging.to_string()),
AzureConfigKey::FabricTokenServiceUrl => self.fabric_token_service_url.clone(),
AzureConfigKey::FabricWorkloadHost => self.fabric_workload_host.clone(),
AzureConfigKey::FabricSessionToken => self.fabric_session_token.clone(),
AzureConfigKey::FabricClusterIdentifier => self.fabric_cluster_identifier.clone(),
}
}

Expand Down Expand Up @@ -856,6 +916,30 @@ impl MicrosoftAzureBuilder {

let credential = if let Some(credential) = self.credentials {
credential
} else if let (
Some(fabric_token_service_url),
Some(fabric_workload_host),
Some(fabric_session_token),
Some(fabric_cluster_identifier),
) = (
&self.fabric_token_service_url,
&self.fabric_workload_host,
&self.fabric_session_token,
&self.fabric_cluster_identifier,
) {
// This case should precede the bearer token case because it is more specific and will utilize the bearer token.
let fabric_credential = FabricTokenOAuthProvider::new(
fabric_token_service_url,
fabric_workload_host,
fabric_session_token,
fabric_cluster_identifier,
self.bearer_token.clone(),
);
Arc::new(TokenCredentialProvider::new(
fabric_credential,
self.client_options.client()?,
self.retry_config.clone(),
)) as _
} else if let Some(bearer_token) = self.bearer_token {
static_creds(AzureCredential::BearerToken(bearer_token))
} else if let Some(access_key) = self.access_key {
Expand Down
114 changes: 113 additions & 1 deletion object_store/src/azure/credential.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ use crate::client::{CredentialProvider, TokenProvider};
use crate::util::hmac_sha256;
use crate::RetryConfig;
use async_trait::async_trait;
use base64::prelude::BASE64_STANDARD;
use base64::prelude::{BASE64_STANDARD, BASE64_URL_SAFE_NO_PAD};
use base64::Engine;
use chrono::{DateTime, SecondsFormat, Utc};
use reqwest::header::{
Expand Down Expand Up @@ -51,10 +51,15 @@ pub(crate) static BLOB_TYPE: HeaderName = HeaderName::from_static("x-ms-blob-typ
pub(crate) static DELETE_SNAPSHOTS: HeaderName = HeaderName::from_static("x-ms-delete-snapshots");
pub(crate) static COPY_SOURCE: HeaderName = HeaderName::from_static("x-ms-copy-source");
static CONTENT_MD5: HeaderName = HeaderName::from_static("content-md5");
static PARTNER_TOKEN: HeaderName = HeaderName::from_static("x-ms-partner-token");
static CLUSTER_IDENTIFIER: HeaderName = HeaderName::from_static("x-ms-cluster-identifier");
static WORKLOAD_RESOURCE: HeaderName = HeaderName::from_static("x-ms-workload-resource-moniker");
static PROXY_HOST: HeaderName = HeaderName::from_static("x-ms-proxy-host");
pub(crate) const RFC1123_FMT: &str = "%a, %d %h %Y %T GMT";
const CONTENT_TYPE_JSON: &str = "application/json";
const MSI_SECRET_ENV_KEY: &str = "IDENTITY_HEADER";
const MSI_API_VERSION: &str = "2019-08-01";
const TOKEN_MIN_TTL: u64 = 300;

/// OIDC scope used when interacting with OAuth2 APIs
///
Expand Down Expand Up @@ -934,6 +939,113 @@ impl AzureCliCredential {
}
}

/// Encapsulates the logic to perform an OAuth token challenge for Fabric
#[derive(Debug)]
pub struct FabricTokenOAuthProvider {
fabric_token_service_url: String,
fabric_workload_host: String,
fabric_session_token: String,
fabric_cluster_identifier: String,
storage_access_token: Option<String>,
token_expiry: Option<u64>,
}

#[derive(Debug, Deserialize)]
struct Claims {
exp: u64,
}

impl FabricTokenOAuthProvider {
/// Create a new [`FabricTokenOAuthProvider`] for an azure backed store
pub fn new(
fabric_token_service_url: impl Into<String>,
fabric_workload_host: impl Into<String>,
fabric_session_token: impl Into<String>,
fabric_cluster_identifier: impl Into<String>,
storage_access_token: Option<String>,
) -> Self {
let (storage_access_token, token_expiry) = match storage_access_token {
Some(token) => match Self::validate_and_get_expiry(&token) {
Some(expiry) if expiry > Self::get_current_timestamp() + TOKEN_MIN_TTL => {
(Some(token), Some(expiry))
}
_ => (None, None),
},
None => (None, None),
};

Self {
fabric_token_service_url: fabric_token_service_url.into(),
fabric_workload_host: fabric_workload_host.into(),
fabric_session_token: fabric_session_token.into(),
fabric_cluster_identifier: fabric_cluster_identifier.into(),
storage_access_token,
token_expiry,
}
}

fn validate_and_get_expiry(token: &str) -> Option<u64> {
let payload = token.split('.').nth(1)?;
let decoded_bytes = BASE64_URL_SAFE_NO_PAD.decode(payload).ok()?;
let decoded_str = str::from_utf8(&decoded_bytes).ok()?;
let claims: Claims = serde_json::from_str(decoded_str).ok()?;
Some(claims.exp)
}

fn get_current_timestamp() -> u64 {
SystemTime::now()
.duration_since(SystemTime::UNIX_EPOCH)
.map_or(0, |d| d.as_secs())
}
}

#[async_trait::async_trait]
impl TokenProvider for FabricTokenOAuthProvider {
type Credential = AzureCredential;

/// Fetch a token
async fn fetch_token(
&self,
client: &Client,
retry: &RetryConfig,
) -> crate::Result<TemporaryToken<Arc<AzureCredential>>> {
if let Some(storage_access_token) = &self.storage_access_token {
if let Some(expiry) = self.token_expiry {
let exp_in = expiry - Self::get_current_timestamp();
if exp_in > TOKEN_MIN_TTL {
return Ok(TemporaryToken {
token: Arc::new(AzureCredential::BearerToken(storage_access_token.clone())),
expiry: Some(Instant::now() + Duration::from_secs(exp_in)),
});
}
}
}

let query_items = vec![("resource", AZURE_STORAGE_RESOURCE)];
let access_token: String = client
.request(Method::GET, &self.fabric_token_service_url)
.header(&PARTNER_TOKEN, self.fabric_session_token.as_str())
.header(&CLUSTER_IDENTIFIER, self.fabric_cluster_identifier.as_str())
.header(&WORKLOAD_RESOURCE, self.fabric_cluster_identifier.as_str())
.header(&PROXY_HOST, self.fabric_workload_host.as_str())
.query(&query_items)
.retryable(retry)
.idempotent(true)
.send()
.await
.context(TokenRequestSnafu)?
.text()
.await
.context(TokenResponseBodySnafu)?;
let exp_in = Self::validate_and_get_expiry(&access_token)
.map_or(3600, |expiry| expiry - Self::get_current_timestamp());
Ok(TemporaryToken {
token: Arc::new(AzureCredential::BearerToken(access_token)),
expiry: Some(Instant::now() + Duration::from_secs(exp_in)),
})
}
}

#[async_trait]
impl CredentialProvider for AzureCliCredential {
type Credential = AzureCredential;
Expand Down
Loading