Skip to content

Commit

Permalink
fix(gemini): handle token expiration
Browse files Browse the repository at this point in the history
  • Loading branch information
GreeFine committed Dec 17, 2023
1 parent 87d744a commit 8a36480
Showing 1 changed file with 18 additions and 9 deletions.
27 changes: 18 additions & 9 deletions src/features/gemini.rs
Original file line number Diff line number Diff line change
@@ -1,25 +1,27 @@
use tokio::sync::OnceCell;

use procedural_macros::command;
use yup_oauth2::{AccessToken, ServiceAccountAuthenticator};
use yup_oauth2::{
authenticator::Authenticator, hyper::client, hyper_rustls, ServiceAccountAuthenticator,
};

use crate::core::commands::{CallBackParams, CallbackReturn};

static TOKEN: OnceCell<AccessToken> = OnceCell::const_new();
type SA = Authenticator<hyper_rustls::HttpsConnector<client::HttpConnector>>;

async fn get_token() -> AccessToken {
static SERVICE_ACCOUNT: OnceCell<SA> = OnceCell::const_new();
const SCOPES: [&str; 1] = ["https://www.googleapis.com/auth/cloud-platform"];

async fn get_token() -> SA {
let key_path = std::env::var("GOOGLE_APPLICATION_CREDENTIALS")
.unwrap_or_else(|_| "blackfoot-dev-bd1f97a0d61e.json".to_string());
let creds = yup_oauth2::read_service_account_key(key_path)
.await
.unwrap();
let sa = ServiceAccountAuthenticator::builder(creds)
ServiceAccountAuthenticator::builder(creds)
.build()
.await
.unwrap();
let scopes = &["https://www.googleapis.com/auth/cloud-platform"];

sa.token(scopes).await.unwrap()
.unwrap()
}

#[derive(Debug, Serialize, Deserialize)]
Expand Down Expand Up @@ -109,9 +111,16 @@ async fn query_gemini(question: &str) -> Result<Vec<GeminiResult>, &str> {
let url = format!("https://{API_ENDPOINT}/v1beta1/projects/{PROJECT_ID}/locations/{LOCATION_ID}/publishers/google/models/{MODEL_ID}:streamGenerateContent");
let body = GeminiBody::new(question);
let client = reqwest::Client::new();
let token = SERVICE_ACCOUNT
.get_or_init(get_token)
.await
.token(&SCOPES)
.await
.expect("service account token to call google services");
let token = token.token().unwrap();
let response = client
.post(url)
.bearer_auth(TOKEN.get_or_init(get_token).await.token().unwrap())
.bearer_auth(token)
.header("Content-Type", "application/json")
.json(&body)
.send()
Expand Down

0 comments on commit 8a36480

Please sign in to comment.