From 814d30bd493282d87bec67afdd18c4c3d3aa6dd6 Mon Sep 17 00:00:00 2001 From: steven leadbeater Date: Sat, 16 Nov 2024 19:03:19 +0000 Subject: [PATCH] OAuth open URL in default browser, set success message and throw descriptive errors --- CHANGELOG.md | 3 + Cargo.lock | 37 +++++++++ oauth/Cargo.toml | 1 + oauth/examples/oauth.rs | 2 +- oauth/src/lib.rs | 161 +++++++++++++++++++++++++++++++++++----- src/main.rs | 1 + 6 files changed, 187 insertions(+), 18 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 82bfb094a..a4c08ac93 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ## [Unreleased] - YYYY-MM-DD ### Changed +- [oauth] Open authorization URL in default browser +- [oauth] Allow optionally passing success message to display on browser return page +- [oauth] Throw specific errors on failure states ### Added diff --git a/Cargo.lock b/Cargo.lock index 4938ac8df..828b19c94 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1681,6 +1681,25 @@ version = "2.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ddc24109865250148c2e0f3d25d4f0f479571723792d3802153c60922a4fb708" +[[package]] +name = "is-docker" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "928bae27f42bc99b60d9ac7334e3a21d10ad8f1835a4e12ec3ec0464765ed1b3" +dependencies = [ + "once_cell", +] + +[[package]] +name = "is-wsl" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "173609498df190136aa7dea1a91db051746d339e18476eed5ca40521f02d7aa5" +dependencies = [ + "is-docker", + "once_cell", +] + [[package]] name = "is_terminal_polyfill" version = "1.70.1" @@ -2061,6 +2080,7 @@ dependencies = [ "env_logger", "log", "oauth2", + "open", "thiserror", "url", ] @@ -2471,6 +2491,17 @@ version = "1.20.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1261fe7e33c73b354eab43b1273a57c8f967d0391e80353e51f764ac02cf6775" +[[package]] +name = "open" +version = "5.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3ecd52f0b8d15c40ce4820aa251ed5de032e5d91fab27f7db2f40d42a8bdf69c" +dependencies = [ + "is-wsl", + "libc", + "pathdiff", +] + [[package]] name = "openssl-probe" version = "0.1.5" @@ -2534,6 +2565,12 @@ version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" +[[package]] +name = "pathdiff" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d61c5ce1153ab5b689d0c074c4e7fc613e942dfb7dd9eea5ab202d2ad91fe361" + [[package]] name = "pbkdf2" version = "0.12.2" diff --git a/oauth/Cargo.toml b/oauth/Cargo.toml index 32148b598..decefef78 100644 --- a/oauth/Cargo.toml +++ b/oauth/Cargo.toml @@ -13,6 +13,7 @@ log = "0.4" oauth2 = "4.4" thiserror = "1.0" url = "2.2" +open = "5.3.1" [dev-dependencies] env_logger = { version = "0.11.2", default-features = false, features = ["color", "humantime", "auto-color"] } diff --git a/oauth/examples/oauth.rs b/oauth/examples/oauth.rs index 76ff088e3..836fb7747 100644 --- a/oauth/examples/oauth.rs +++ b/oauth/examples/oauth.rs @@ -25,7 +25,7 @@ fn main() { return; }; - match get_access_token(client_id, redirect_uri, scopes) { + match get_access_token(client_id, redirect_uri, scopes, None) { Ok(token) => println!("Success: {token:#?}"), Err(e) => println!("Failed: {e}"), }; diff --git a/oauth/src/lib.rs b/oauth/src/lib.rs index 591e65594..2b1e41c40 100644 --- a/oauth/src/lib.rs +++ b/oauth/src/lib.rs @@ -34,6 +34,9 @@ pub enum OAuthError { #[error("Auth code param not found in URI {uri}")] AuthCodeNotFound { uri: String }, + #[error("CSRF token param not found in URI {uri}")] + CsrfTokenNotFound { uri: String }, + #[error("Failed to read redirect URI from stdin")] AuthCodeStdinRead, @@ -63,6 +66,12 @@ pub enum OAuthError { #[error("Failed to exchange code for access token ({e})")] ExchangeCode { e: String }, + + #[error("Spotify did not provide a refresh token")] + NoRefreshToken, + + #[error("Spotify did not return the token scopes")] + NoTokenScopes, } #[derive(Debug)] @@ -74,20 +83,38 @@ pub struct OAuthToken { pub scopes: Vec, } -/// Return code query-string parameter from the redirect URI. -fn get_code(redirect_url: &str) -> Result { +/// Return URL from the redirect URI &str. +fn get_url(redirect_url: &str) -> Result { let url = Url::parse(redirect_url).map_err(|e| OAuthError::AuthCodeBadUri { uri: redirect_url.to_string(), e, })?; - let code = url - .query_pairs() - .find(|(key, _)| key == "code") - .map(|(_, code)| AuthorizationCode::new(code.into_owned())) + Ok(url) +} + +/// Return a query-string parameter from the redirect URI. +fn get_query_string_parameter(url: &Url, query_string_parameter_key: &str) -> Option { + url.query_pairs() + .find(|(key, _)| key == query_string_parameter_key) + .map(|(_, query_string_parameter)| query_string_parameter.into_owned()) +} + +/// Return state query-string parameter from the redirect URI (CSRF token). +fn get_state(url: &Url) -> Result { + let state = get_query_string_parameter(url, "state").ok_or(OAuthError::CsrfTokenNotFound { + uri: url.to_string(), + })?; + + Ok(state) +} + +/// Return code query-string parameter from the redirect URI. +fn get_code(url: &Url) -> Result { + let code = get_query_string_parameter(url, "code") + .map(AuthorizationCode::new) .ok_or(OAuthError::AuthCodeNotFound { - uri: redirect_url.to_string(), + uri: url.to_string(), })?; - Ok(code) } @@ -100,11 +127,16 @@ fn get_authcode_stdin() -> Result { .read_line(&mut buffer) .map_err(|_| OAuthError::AuthCodeStdinRead)?; - get_code(buffer.trim()) + let url = get_url(buffer.trim())?; + get_code(&url) } /// Spawn HTTP server at provided socket address to accept OAuth callback and return auth code. -fn get_authcode_listener(socket_address: SocketAddr) -> Result { +fn get_authcode_listener( + socket_address: SocketAddr, + csrf_token: CsrfToken, + success_message: Option, +) -> Result { let listener = TcpListener::bind(socket_address).map_err(|e| OAuthError::AuthCodeListenerBind { addr: socket_address, @@ -128,19 +160,28 @@ fn get_authcode_listener(socket_address: SocketAddr) -> Result, + success_message: Option, ) -> Result { let auth_url = AuthUrl::new("https://accounts.spotify.com/authorize".to_string()) .map_err(|_| OAuthError::InvalidSpotifyUri)?; @@ -195,16 +237,19 @@ pub fn get_access_token( .into_iter() .map(|s| Scope::new(s.into())) .collect(); - let (auth_url, _) = client + let (auth_url, csrf_token) = client .authorize_url(CsrfToken::new_random) .add_scopes(request_scopes) .set_pkce_challenge(pkce_challenge) .url(); println!("Browse to: {}", auth_url); + if let Err(err) = open::that(auth_url.to_string()) { + eprintln!("An error occurred when opening '{}': {}", auth_url, err) + } let code = match get_socket_address(redirect_uri) { - Some(addr) => get_authcode_listener(addr), + Some(addr) => get_authcode_listener(addr, csrf_token, success_message), _ => get_authcode_stdin(), }?; trace!("Exchange {code:?} for access token"); @@ -226,11 +271,17 @@ pub fn get_access_token( let token_scopes: Vec = match token.scopes() { Some(s) => s.iter().map(|s| s.to_string()).collect(), - _ => scopes.into_iter().map(|s| s.to_string()).collect(), + None => { + error!("Spotify did not return the token scopes."); + return Err(OAuthError::NoTokenScopes); + } }; let refresh_token = match token.refresh_token() { Some(t) => t.secret().to_string(), - _ => "".to_string(), // Spotify always provides a refresh token. + None => { + error!("Spotify did not provide a refresh token."); + return Err(OAuthError::NoRefreshToken); + } }; Ok(OAuthToken { access_token: token.access_token().secret().to_string(), @@ -284,4 +335,80 @@ mod test { Some(localhost_v6) ); } + #[test] + fn test_get_url_valid() { + let redirect_url = "https://example.com/callback?code=1234&state=abcd"; + let result = get_url(redirect_url); + assert!(result.is_ok()); + let url = result.unwrap(); + assert_eq!(url.as_str(), redirect_url); + } + + #[test] + fn test_get_url_invalid() { + let redirect_url = "invalid_url"; + let result = get_url(redirect_url); + assert!(result.is_err()); + if let Err(OAuthError::AuthCodeBadUri { uri, .. }) = result { + assert_eq!(uri, redirect_url.to_string()); + } else { + panic!("Expected OAuthError::AuthCodeBadUri"); + } + } + + #[test] + fn test_get_query_string_parameter_found() { + let url = Url::parse("https://example.com/callback?code=1234&state=abcd").unwrap(); + let key = "code"; + let result = get_query_string_parameter(&url, key); + assert_eq!(result, Some("1234".to_string())); + } + + #[test] + fn test_get_query_string_parameter_not_found() { + let url = Url::parse("https://example.com/callback?code=1234&state=abcd").unwrap(); + let key = "missing_key"; + let result = get_query_string_parameter(&url, key); + assert!(result.is_none()); + } + + #[test] + fn test_get_state_valid() { + let url = Url::parse("https://example.com/callback?state=abcd").unwrap(); + let result = get_state(&url); + assert!(result.is_ok()); + assert_eq!(result.unwrap(), "abcd"); + } + + #[test] + fn test_get_state_missing() { + let url = Url::parse("https://example.com/callback").unwrap(); + let result = get_state(&url); + assert!(result.is_err()); + if let Err(OAuthError::CsrfTokenNotFound { uri }) = result { + assert_eq!(uri, url.to_string()); + } else { + panic!("Expected OAuthError::CsrfTokenNotFound"); + } + } + + #[test] + fn test_get_code_valid() { + let url = Url::parse("https://example.com/callback?code=1234").unwrap(); + let result = get_code(&url); + assert!(result.is_ok()); + assert_eq!(result.unwrap().secret(), "1234"); + } + + #[test] + fn test_get_code_missing() { + let url = Url::parse("https://example.com/callback").unwrap(); + let result = get_code(&url); + assert!(result.is_err()); + if let Err(OAuthError::AuthCodeNotFound { uri }) = result { + assert_eq!(uri, url.to_string()); + } else { + panic!("Expected OAuthError::AuthCodeNotFound"); + } + } } diff --git a/src/main.rs b/src/main.rs index 2da9323a1..a1a396d24 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1896,6 +1896,7 @@ async fn main() { &setup.session_config.client_id, &format!("http://127.0.0.1{port_str}/login"), OAUTH_SCOPES.to_vec(), + None, ) { Ok(token) => token.access_token, Err(e) => {