From 9744b663c3e89a490b733c4926c84164cfbc0357 Mon Sep 17 00:00:00 2001 From: boxbeam Date: Wed, 29 May 2024 02:21:19 -0400 Subject: [PATCH] test(webserver): create mock oauth client and write unit test for oauth login flow (#2238) * test(webserver): create mock oauth client and write unit test for oauth login flow * [autofix.ci] apply automated fixes --------- Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> --- ee/tabby-webserver/src/oauth/mod.rs | 30 ++++++ ee/tabby-webserver/src/service/auth.rs | 132 +++++++++++++++++++------ 2 files changed, 134 insertions(+), 28 deletions(-) diff --git a/ee/tabby-webserver/src/oauth/mod.rs b/ee/tabby-webserver/src/oauth/mod.rs index 7b2fcc818f41..bde967a4547d 100644 --- a/ee/tabby-webserver/src/oauth/mod.rs +++ b/ee/tabby-webserver/src/oauth/mod.rs @@ -30,3 +30,33 @@ pub fn new_oauth_client( OAuthProvider::Github => Arc::new(GithubClient::new(auth)), } } + +#[cfg(test)] +pub mod test_client { + use super::*; + + pub struct TestOAuthClient { + pub access_token_response: fn() -> Result, + pub user_email: String, + pub user_name: String, + } + + #[async_trait] + impl OAuthClient for TestOAuthClient { + async fn exchange_code_for_token(&self, _code: String) -> Result { + (self.access_token_response)() + } + + async fn fetch_user_email(&self, _access_token: &str) -> Result { + Ok(self.user_email.clone()) + } + + async fn fetch_user_full_name(&self, _access_token: &str) -> Result { + Ok(self.user_name.clone()) + } + + async fn get_authorization_url(&self) -> Result { + Ok("https://example.com".into()) + } + } +} diff --git a/ee/tabby-webserver/src/service/auth.rs b/ee/tabby-webserver/src/service/auth.rs index ccaba48b1265..a5b124820dea 100644 --- a/ee/tabby-webserver/src/service/auth.rs +++ b/ee/tabby-webserver/src/service/auth.rs @@ -29,7 +29,7 @@ use super::graphql_pagination_to_filter; use crate::{ bail, jwt::{generate_jwt, validate_jwt}, - oauth, + oauth::{self, OAuthClient}, }; #[derive(Clone)] @@ -427,27 +427,21 @@ impl AuthenticationService for AuthenticationServiceImpl { provider: OAuthProvider, ) -> std::result::Result { let client = oauth::new_oauth_client(provider, Arc::new(self.clone())); - let access_token = client.exchange_code_for_token(code).await?; - let email = client.fetch_user_email(&access_token).await?; - let name = client.fetch_user_full_name(&access_token).await?; let license = self .license .read() .await .context("Failed to read license info")?; - let user_id = - get_or_create_oauth_user(&license, &self.db, &self.setting, &self.mail, &email, &name) - .await?; - let refresh_token = self.db.create_refresh_token(user_id).await?; - - let access_token = generate_jwt(user_id.as_id()).map_err(|_| OAuthError::Unknown)?; - - let resp = OAuthResponse { - access_token, - refresh_token, - }; - Ok(resp) + oauth_login( + client, + code, + &self.db, + &*self.setting, + &license, + &*self.mail, + ) + .await } async fn read_oauth_credential( @@ -521,11 +515,35 @@ impl AuthenticationService for AuthenticationServiceImpl { } } +async fn oauth_login( + client: Arc, + code: String, + db: &DbConn, + setting: &dyn SettingService, + license: &LicenseInfo, + mail: &dyn EmailService, +) -> Result { + let access_token = client.exchange_code_for_token(code).await?; + let email = client.fetch_user_email(&access_token).await?; + let name = client.fetch_user_full_name(&access_token).await?; + let user_id = get_or_create_oauth_user(license, db, setting, mail, &email, &name).await?; + + let refresh_token = db.create_refresh_token(user_id).await?; + + let access_token = generate_jwt(user_id.as_id()).map_err(|_| OAuthError::Unknown)?; + + let resp = OAuthResponse { + access_token, + refresh_token, + }; + Ok(resp) +} + async fn get_or_create_oauth_user( license: &LicenseInfo, db: &DbConn, - setting: &Arc, - mail: &Arc, + setting: &dyn SettingService, + mail: &dyn EmailService, email: &str, name: &str, ) -> Result { @@ -713,11 +731,14 @@ mod tests { use serial_test::serial; use tabby_schema::{ juniper::relay::{self, Connection}, - license::{LicenseInfo, LicenseStatus}, + license::{LicenseInfo, LicenseStatus, LicenseType}, }; use super::*; - use crate::service::email::{new_email_service, testutils::TestEmailServer}; + use crate::{ + oauth::test_client::TestOAuthClient, + service::email::{new_email_service, testutils::TestEmailServer}, + }; #[test] fn test_password_hash() { @@ -955,8 +976,8 @@ mod tests { let res = get_or_create_oauth_user( &license, &service.db, - &setting, - &service.mail, + &*setting, + &*service.mail, "test@example.com", "", ) @@ -972,8 +993,8 @@ mod tests { let res = get_or_create_oauth_user( &license, &service.db, - &setting, - &service.mail, + &*setting, + &*service.mail, "example@example.com", "Example User", ) @@ -990,8 +1011,8 @@ mod tests { let res = get_or_create_oauth_user( &license, &service.db, - &setting, - &service.mail, + &*setting, + &*service.mail, "example@gmail.com", "", ) @@ -1007,8 +1028,8 @@ mod tests { let res = get_or_create_oauth_user( &license, &service.db, - &setting, - &service.mail, + &*setting, + &*service.mail, "example@gmail.com", "User 3 by Invitation", ) @@ -1474,4 +1495,59 @@ mod tests { assert_eq!(cred.client_id, "id"); assert_eq!(cred.client_secret, "secret"); } + + #[tokio::test] + async fn test_oauth_login() { + let service = test_authentication_service().await; + let license = LicenseInfo { + r#type: LicenseType::Enterprise, + status: LicenseStatus::Ok, + seats: 1000, + seats_used: 0, + issued_at: None, + expires_at: None, + }; + + let client = Arc::new(TestOAuthClient { + access_token_response: || Ok("faketoken".into()), + user_email: "user@example.com".into(), + user_name: "user".into(), + }); + + service + .create_invitation("user@example.com".into()) + .await + .unwrap(); + + let response = oauth_login( + client, + "fakecode".into(), + &service.db, + &*service.setting, + &license, + &*service.mail, + ) + .await + .unwrap(); + + assert!(!response.access_token.is_empty()); + + let client = Arc::new(TestOAuthClient { + access_token_response: || Err(anyhow!("bad auth")), + user_email: "user@example.com".into(), + user_name: "user".into(), + }); + + let response = oauth_login( + client, + "fakecode".into(), + &service.db, + &*service.setting, + &license, + &*service.mail, + ) + .await; + + assert!(response.is_err()); + } }