diff --git a/examples/simple.rs b/examples/simple.rs index d3d840b..89e339a 100644 --- a/examples/simple.rs +++ b/examples/simple.rs @@ -8,8 +8,11 @@ async fn main() -> Result<(), Box> { .unwrap(); let token_provider = gcp_auth::provider().await?; - let _token = token_provider + let token = token_provider .token(&["https://www.googleapis.com/auth/cloud-platform"]) .await?; + + tracing::info!(email = token_provider.email().await?, token = ?token); + Ok(()) } diff --git a/src/config_default_credentials.rs b/src/config_default_credentials.rs index 10cb513..0c5958f 100644 --- a/src/config_default_credentials.rs +++ b/src/config_default_credentials.rs @@ -49,31 +49,28 @@ impl ConfigDefaultCredentials { }) } - #[instrument(level = Level::DEBUG, skip(cred, client))] + #[instrument(level = Level::DEBUG, skip(cred, client), fields(provider = "ConfigDefaultCredentials"))] async fn fetch_token( cred: &AuthorizedUserRefreshToken, client: &HttpClient, ) -> Result, Error> { client - .token( - &|| { - Request::builder() - .method(Method::POST) - .uri(DEFAULT_TOKEN_GCP_URI) - .header(CONTENT_TYPE, "application/json") - .body(Full::from(Bytes::from( - serde_json::to_vec(&RefreshRequest { - client_id: &cred.client_id, - client_secret: &cred.client_secret, - grant_type: "refresh_token", - refresh_token: &cred.refresh_token, - }) - .unwrap(), - ))) - .unwrap() - }, - "ConfigDefaultCredentials", - ) + .token(&|| { + Request::builder() + .method(Method::POST) + .uri(DEFAULT_TOKEN_GCP_URI) + .header(CONTENT_TYPE, "application/json") + .body(Full::from(Bytes::from( + serde_json::to_vec(&RefreshRequest { + client_id: &cred.client_id, + client_secret: &cred.client_secret, + grant_type: "refresh_token", + refresh_token: &cred.refresh_token, + }) + .unwrap(), + ))) + .unwrap() + }) .await } } @@ -92,6 +89,12 @@ impl TokenProvider for ConfigDefaultCredentials { Ok(token) } + async fn email(&self) -> Result { + let token = self.token(&[]).await?; + let info = self.client.token_info(&token).await?; + Ok(info.email) + } + async fn project_id(&self) -> Result, Error> { self.credentials .quota_project_id diff --git a/src/custom_service_account.rs b/src/custom_service_account.rs index a138f6f..28aa8bd 100644 --- a/src/custom_service_account.rs +++ b/src/custom_service_account.rs @@ -70,7 +70,7 @@ impl CustomServiceAccount { }) } - #[instrument(level = Level::DEBUG, skip(self))] + #[instrument(level = Level::DEBUG, skip(self), fields(provider = "CustomServiceAccount"))] async fn fetch_token(&self, scopes: &[&str]) -> Result, Error> { let jwt = Claims::new(&self.credentials, scopes, self.subject.as_deref()).to_jwt(&self.signer)?; @@ -83,15 +83,12 @@ impl CustomServiceAccount { let token = self .client - .token( - &|| { - Request::post(&self.credentials.token_uri) - .header(CONTENT_TYPE, "application/x-www-form-urlencoded") - .body(Full::from(body.clone())) - .unwrap() - }, - "CustomServiceAccount", - ) + .token(&|| { + Request::post(&self.credentials.token_uri) + .header(CONTENT_TYPE, "application/x-www-form-urlencoded") + .body(Full::from(body.clone())) + .unwrap() + }) .await?; Ok(token) @@ -135,6 +132,10 @@ impl TokenProvider for CustomServiceAccount { return Ok(token); } + async fn email(&self) -> Result { + Ok(self.credentials.client_email.clone()) + } + async fn project_id(&self) -> Result, Error> { match &self.credentials.project_id { Some(pid) => Ok(pid.clone()), diff --git a/src/gcloud_authorized_user.rs b/src/gcloud_authorized_user.rs index ffc6e59..fba2e68 100644 --- a/src/gcloud_authorized_user.rs +++ b/src/gcloud_authorized_user.rs @@ -28,7 +28,7 @@ impl GCloudAuthorizedUser { }) } - #[instrument(level = tracing::Level::DEBUG)] + #[instrument(level = tracing::Level::DEBUG, fields(provider = "GCloudAuthorizedUser"))] fn fetch_token() -> Result, Error> { Ok(Arc::new(Token::from_string( run(&["auth", "print-access-token", "--quiet"])?, @@ -51,6 +51,10 @@ impl TokenProvider for GCloudAuthorizedUser { Ok(token) } + async fn email(&self) -> Result { + run(&["auth", "print-identity-token"]) + } + async fn project_id(&self) -> Result, Error> { self.project_id .clone() diff --git a/src/lib.rs b/src/lib.rs index d170475..8639e08 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -167,6 +167,8 @@ pub trait TokenProvider: Send + Sync { /// the current token (for the given scopes) has expired. async fn token(&self, scopes: &[&str]) -> Result, Error>; + async fn email(&self) -> Result; + /// Get the project ID for the authentication context async fn project_id(&self) -> Result, Error>; } diff --git a/src/metadata_service_account.rs b/src/metadata_service_account.rs index 3f244ac..cd52e06 100644 --- a/src/metadata_service_account.rs +++ b/src/metadata_service_account.rs @@ -34,7 +34,7 @@ impl MetadataServiceAccount { debug!("getting project ID from GCP instance metadata server"); let req = metadata_request(DEFAULT_PROJECT_ID_GCP_URI); - let body = client.request(req, "MetadataServiceAccount").await?; + let body = client.request(req).await?; let project_id = match str::from_utf8(&body) { Ok(s) if !s.is_empty() => Arc::from(s), Ok(_) => { @@ -56,13 +56,10 @@ impl MetadataServiceAccount { }) } - #[instrument(level = Level::DEBUG, skip(client))] + #[instrument(level = Level::DEBUG, skip(client), fields(provider = "MetadataServiceAccount"))] async fn fetch_token(client: &HttpClient) -> Result, Error> { client - .token( - &|| metadata_request(DEFAULT_TOKEN_GCP_URI), - "MetadataServiceAccount", - ) + .token(&|| metadata_request(DEFAULT_TOKEN_GCP_URI)) .await } } @@ -81,6 +78,15 @@ impl TokenProvider for MetadataServiceAccount { Ok(token) } + async fn email(&self) -> Result { + let email = self + .client + .request(metadata_request(DEFAULT_SERVICE_ACCOUNT_EMAIL_URI)) + .await?; + + String::from_utf8(email.to_vec()).map_err(|_| Error::Str("invalid UTF-8 email")) + } + async fn project_id(&self) -> Result, Error> { Ok(self.project_id.clone()) } @@ -100,3 +106,5 @@ const DEFAULT_PROJECT_ID_GCP_URI: &str = "http://metadata.google.internal/computeMetadata/v1/project/project-id"; const DEFAULT_TOKEN_GCP_URI: &str = "http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/token"; +const DEFAULT_SERVICE_ACCOUNT_EMAIL_URI: &str = + "http://metadata.google.internal/computeMetadata/v1/instance/service-accounts/default/email"; diff --git a/src/types.rs b/src/types.rs index d2d2c33..b81cb80 100644 --- a/src/types.rs +++ b/src/types.rs @@ -7,6 +7,7 @@ use std::{env, fmt}; use bytes::Buf; use chrono::{DateTime, Utc}; +use http::Method; use http_body_util::{BodyExt, Full}; use hyper::body::Bytes; use hyper::Request; @@ -48,20 +49,16 @@ impl HttpClient { pub(crate) async fn token( &self, request: &impl Fn() -> Request>, - provider: &'static str, ) -> Result, Error> { let mut retries = 0; let body = loop { - let err = match self.request(request(), provider).await { + let err = match self.request(request()).await { // Early return when the request succeeds Ok(body) => break body, Err(err) => err, }; - warn!( - ?err, - provider, retries, "failed to refresh token, trying again..." - ); + warn!(?err, retries, "failed to refresh token, trying again..."); retries += 1; if retries >= RETRY_COUNT { @@ -73,12 +70,27 @@ impl HttpClient { .map_err(|err| Error::Json("failed to deserialize token from response", err)) } - pub(crate) async fn request( - &self, - req: Request>, - provider: &'static str, - ) -> Result { - debug!(url = ?req.uri(), provider, "requesting token"); + pub(crate) async fn token_info(&self, token: &Token) -> Result { + let req = Request::builder() + .method(Method::GET) + .uri(format!( + "https://oauth2.googleapis.com/tokeninfo?access_token={}", + token.as_str() + )) + .body(Full::from(Bytes::new())) + .map_err(|err| Error::Other("failed to build HTTP request", Box::new(err)))?; + + let body = self + .request(req) + .await + .map_err(|err| Error::Other("failed to fetch token info", Box::new(err)))?; + + serde_json::from_slice(&body) + .map_err(|err| Error::Json("failed to deserialize token info from response", err)) + } + + pub(crate) async fn request(&self, req: Request>) -> Result { + debug!(url = ?req.uri(), "requesting token"); let (parts, body) = self .inner .request(req) @@ -304,6 +316,11 @@ impl fmt::Debug for AuthorizedUserRefreshToken { } } +#[derive(Deserialize)] +pub(crate) struct TokenInfo { + pub(crate) email: String, +} + /// How many times to attempt to fetch a token from the set credentials token endpoint. const RETRY_COUNT: u8 = 5;