Skip to content

Commit

Permalink
Merge pull request #15 from cerberauth/pkce-support
Browse files Browse the repository at this point in the history
feat: add oauth pkce extension support
  • Loading branch information
emmanuelgautier authored Mar 14, 2024
2 parents 1541566 + d9821d0 commit 691b658
Show file tree
Hide file tree
Showing 10 changed files with 100 additions and 70 deletions.
3 changes: 3 additions & 0 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@
"kind": "bin"
}
},
"env": {
"RUST_BACKTRACE": "full"
},
"args": [],
"cwd": "${workspaceFolder}/baffao-proxy"
},
Expand Down
24 changes: 15 additions & 9 deletions baffao-core/src/handlers/authorize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,19 +14,25 @@ pub fn oauth2_authorize(
query: Option<AuthorizationQuery>,
client: OAuthClient,
CookiesConfig {
csrf: csrf_cookie, ..
oauth_csrf: oauth_csrf_cookie,
oauth_pkce: oauth_pkce_cookie,
..
}: CookiesConfig,
) -> (CookieJar, StatusCode, String) {
let scopes = query
.map(|q| q.scope.unwrap_or_default())
.unwrap_or_default()
.split_whitespace()
.map(|s| s.to_string())
.collect();
let (url, csrf_token) = client.get_authorization_url(scopes);
let scope = query
.and_then(|q| q.scope)
.map(|scope| scope.split(' ').map(String::from).collect());
let (url, csrf_token, pkce_code_verifier) = client.build_authorization_url(scope);

(
jar.add(new_cookie(csrf_cookie, csrf_token.secret().to_string())),
jar.add(new_cookie(
oauth_csrf_cookie,
csrf_token.secret().to_string(),
))
.add(new_cookie(
oauth_pkce_cookie,
pkce_code_verifier.secret().to_string(),
)),
StatusCode::TEMPORARY_REDIRECT,
url.to_string(),
)
Expand Down
51 changes: 37 additions & 14 deletions baffao-core/src/handlers/callback.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,44 +23,67 @@ pub async fn oauth2_callback(
query: AuthorizationCallbackQuery,
client: OAuthClient,
CookiesConfig {
csrf: csrf_cookie,
oauth_csrf: oauth_csrf_cookie,
oauth_pkce: oauth_pkce_cookie,
access_token: access_token_cookie,
refresh_token: refresh_token_cookie,
session: session_cookie,
..
}: CookiesConfig,
) -> Result<(CookieJar, StatusCode, String), Error> {
let pkce_code = jar
.get(csrf_cookie.name.as_str())
.map(|cookie| cookie.value().to_string())
.unwrap_or_default();
let response = match client
.exchange_code(query.code, pkce_code, query.state.clone())
.get(oauth_csrf_cookie.name.as_str())
.map(|cookie| cookie.value().to_string());
if pkce_code.is_none() {
return Err(anyhow::anyhow!("CSRF token not found"));
} else if pkce_code.unwrap() != query.state {
return Err(anyhow::anyhow!("CSRF token mismatch"));
}

let pkce_verifier = jar
.get(oauth_pkce_cookie.name.as_str())
.map(|cookie| cookie.value().to_string());
if pkce_verifier.is_none() {
return Err(anyhow::anyhow!("PKCE verifier not found"));
}

if query.code.is_empty() {
return Err(anyhow::anyhow!("Authorization code not found"));
}

let mut updated_jar = jar
.remove(Cookie::from(oauth_csrf_cookie.name))
.remove(Cookie::from(oauth_pkce_cookie.name));

let token_result = match client
.exchange_code(query.code, pkce_verifier.unwrap())
.await
{
Ok(response) => response,
Err(e) => {
return Err(e);
Err(_) => {
return Ok((
updated_jar,
StatusCode::INTERNAL_SERVER_ERROR,
"/error".to_string(),
));
}
};

let mut updated_jar = jar.clone();
updated_jar = updated_jar.remove(Cookie::from(csrf_cookie.name));
updated_jar = updated_jar.add(new_cookie(
access_token_cookie,
response.access_token().secret().to_string(),
token_result.access_token().secret().to_string(),
));
updated_jar = if response.refresh_token().is_some() {
updated_jar = if token_result.refresh_token().is_some() {
updated_jar.add(new_cookie(
refresh_token_cookie,
response.refresh_token().unwrap().secret().to_string(),
token_result.refresh_token().unwrap().secret().to_string(),
))
} else {
updated_jar.remove(Cookie::from(refresh_token_cookie.name))
};

let now = Utc::now();
let expires_in = response.expires_in().map(|duration| {
let expires_in = token_result.expires_in().map(|duration| {
now.checked_add_signed(Duration::from_std(duration).unwrap())
.unwrap()
});
Expand Down
8 changes: 0 additions & 8 deletions baffao-core/src/identity.rs

This file was deleted.

1 change: 0 additions & 1 deletion baffao-core/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,4 @@ pub mod oauth;
pub mod session;

pub mod cookies;
pub mod identity;
pub mod settings;
50 changes: 28 additions & 22 deletions baffao-core/src/oauth/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use oauth2::{
basic::{BasicClient, BasicTokenType},
reqwest::async_http_client,
AuthType, AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, EmptyExtraTokenFields,
RedirectUrl, Scope, StandardTokenResponse, TokenUrl,
PkceCodeChallenge, PkceCodeVerifier, RedirectUrl, Scope, StandardTokenResponse, TokenUrl,
};
use reqwest::Url;

Expand All @@ -25,9 +25,12 @@ impl Clone for OAuthClient {

impl OAuthClient {
pub fn new(config: OAuthConfig) -> Result<Self, Error> {
let redirect_uri = RedirectUrl::new(config.authorization_redirect_uri.clone())?;
let auth_url = AuthUrl::new(config.authorization_url.clone())?;
let token_url = TokenUrl::new(config.token_url.clone())?;
let redirect_uri = RedirectUrl::new(config.authorization_redirect_uri.clone())
.context("Failed to parse redirect uri")?;
let auth_url = AuthUrl::new(config.authorization_url.clone())
.context("Failed to parse authorization url")?;
let token_url =
TokenUrl::new(config.token_url.clone()).context("Failed to parse token url")?;

let client = BasicClient::new(
ClientId::new(config.client_id.clone()),
Expand All @@ -41,33 +44,36 @@ impl OAuthClient {
Ok(Self { config, client })
}

pub fn get_authorization_url(&self, scope: Vec<String>) -> (Url, CsrfToken) {
let mut request = self.client.authorize_url(CsrfToken::new_random);
if !scope.is_empty() {
request = request.add_scope(Scope::new(scope.join(" ")));
}
pub fn build_authorization_url(
&self,
scope: Option<Vec<String>>,
) -> (Url, CsrfToken, PkceCodeVerifier) {
let scopes =
scope.unwrap_or_else(|| self.config.default_scopes.clone().unwrap_or_default());
let (pkce_code_challenge, pkce_code_verifier) = PkceCodeChallenge::new_random_sha256();

let (url, csrf_token) = self
.client
.authorize_url(CsrfToken::new_random)
.add_scopes(scopes.iter().map(|s| Scope::new(s.clone())))
.set_pkce_challenge(pkce_code_challenge)
.url();

return request.url();
(url, csrf_token, pkce_code_verifier)
}

pub async fn exchange_code(
&self,
code: String,
csrf_token: String,
state: String,
pkce_verifier: String,
) -> Result<StandardTokenResponse<EmptyExtraTokenFields, BasicTokenType>, Error> {
if state != csrf_token {
return Err(anyhow::anyhow!("Invalid state"));
}

let code = AuthorizationCode::new(code);
let token = self
let token_result = self
.client
.exchange_code(code)
.exchange_code(AuthorizationCode::new(code))
.set_pkce_verifier(PkceCodeVerifier::new(pkce_verifier))
.request_async(async_http_client)
.await
.context("Failed to exchange code")?;
.await?;

Ok(token)
Ok(token_result)
}
}
3 changes: 1 addition & 2 deletions baffao-core/src/oauth/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,9 @@ use serde::Deserialize;
pub struct OAuthConfig {
pub client_id: String,
pub client_secret: String,
pub metadata_url: Option<String>,
pub authorization_redirect_uri: String,
pub authorization_url: String,
pub token_url: String,
pub userinfo_url: Option<String>,
pub redirect_uri: Option<String>,
pub default_scopes: Option<Vec<String>>,
}
4 changes: 2 additions & 2 deletions baffao-core/src/settings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,9 +59,9 @@ impl CookieConfig {

#[derive(Deserialize, Clone)]
pub struct CookiesConfig {
pub csrf: CookieConfig,
pub oauth_csrf: CookieConfig,
pub oauth_pkce: CookieConfig,
pub access_token: CookieConfig,
pub refresh_token: CookieConfig,
pub id_token: CookieConfig,
pub session: CookieConfig,
}
17 changes: 9 additions & 8 deletions baffao-proxy/config/default.toml
Original file line number Diff line number Diff line change
@@ -1,26 +1,26 @@
[server]
[server.cookies]

[server.cookies.csrf]
name = "oauth.csrf_token"
[server.cookies.oauth_csrf]
name = "oauth.csrf"
secure = true
http_only = true
same_site = "Strict"

[server.cookies.access_token]
name = "oauth.access_token"
[server.cookies.oauth_pkce]
name = "oauth.pkce_verifier"
secure = true
http_only = true
same_site = "Strict"

[server.cookies.refresh_token]
name = "oauth.refresh_token"
[server.cookies.access_token]
name = "oauth.access_token"
secure = true
http_only = true
same_site = "Strict"

[server.cookies.id_token]
name = "oauth.id_token"
[server.cookies.refresh_token]
name = "oauth.refresh_token"
secure = true
http_only = true
same_site = "Strict"
Expand All @@ -33,3 +33,4 @@ same_site = "Strict"

[oauth]
redirect_uri = "/"
default_scopes = ["offline_access"]
9 changes: 5 additions & 4 deletions baffao-proxy/config/development.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,22 @@ base_url = "http://localhost:3000"

[server.cookies]

[server.cookies.csrf]
[server.cookies.oauth_csrf]
domain = ""
secure = false
same_site = "Lax"

[server.cookies.access_token]
[server.cookies.oauth_pkce]
domain = ""
secure = false
same_site = "Lax"

[server.cookies.refresh_token]
[server.cookies.access_token]
domain = ""
secure = false
same_site = "Lax"

[server.cookies.id_token]
[server.cookies.refresh_token]
domain = ""
secure = false
same_site = "Lax"
Expand All @@ -34,3 +34,4 @@ same_site = "Lax"

[oauth]
authorization_redirect_uri = "http://localhost:3000/oauth/callback"
# default_scopes = ["openid", "email", "profile", "offline_access"]

0 comments on commit 691b658

Please sign in to comment.