From 042e17e88392d96ec2d6a9f2d31a62ec79c005e0 Mon Sep 17 00:00:00 2001 From: aumetra Date: Sat, 14 Dec 2024 01:29:40 +0100 Subject: [PATCH] return more errors in a standard compliant way --- lib/komainu/src/error.rs | 60 +++++++++++++++++++++++++++ lib/komainu/src/flow/authorization.rs | 50 +++++++++++----------- lib/komainu/src/flow/refresh.rs | 32 +++++++------- lib/komainu/src/lib.rs | 17 ++------ 4 files changed, 105 insertions(+), 54 deletions(-) diff --git a/lib/komainu/src/error.rs b/lib/komainu/src/error.rs index 62dfd9e51..eaf6115fe 100644 --- a/lib/komainu/src/error.rs +++ b/lib/komainu/src/error.rs @@ -1,4 +1,6 @@ use thiserror::Error; +use serde::Serialize; +use strum::AsRefStr; type BoxError = Box; @@ -30,3 +32,61 @@ impl Error { Self::Query(err.into()) } } + +impl From for OAuthError { + #[track_caller] + fn from(value: Error) -> Self { + debug!(error = ?value); + + match value { + Error::Body(..) | Error::MissingParam | Error::Query(..) => Self::InvalidRequest, + Error::Unauthorized => Self::AccessDenied, + } + } +} + +#[derive(AsRefStr, Serialize)] +#[serde(rename_all = "snake_case")] +#[strum(serialize_all = "snake_case")] +pub enum OAuthError { + InvalidRequest, + UnauthorizedClient, + AccessDenied, + UnsupportedResponseType, + InvalidScope, + ServerError, + TemporarilyUnavailable, +} + +#[derive(Serialize)] +pub struct OAuthErrorResponse { + pub error: OAuthError, +} + +macro_rules! fallible { + ($op:expr) => {{ + match { $op } { + Ok(val) => val, + Err(error) => { + debug!(?error); + $crate::error::yield_error!(error); + } + } + }}; +} + +macro_rules! yield_error { + (@ser $error:expr) => {{ + return ::http::Response::builder() + .status(::http::StatusCode::BAD_REQUEST) + .body(sonic_rs::to_vec(&$error).unwrap().into()) + .unwrap(); + }}; + ($error:expr) => {{ + $crate::error::yield_error!(@ser $crate::error::OAuthErrorResponse { + error: $error.into(), + }); + }}; +} + +pub(crate) use {fallible, yield_error}; diff --git a/lib/komainu/src/flow/authorization.rs b/lib/komainu/src/flow/authorization.rs index a35a174c3..777289fbb 100644 --- a/lib/komainu/src/flow/authorization.rs +++ b/lib/komainu/src/flow/authorization.rs @@ -1,6 +1,8 @@ use super::TokenResponse; use crate::{ - error::Result, params::ParamStorage, Authorization, ClientExtractor, Error, OptionExt, + error::{fallible, yield_error, Result}, + params::ParamStorage, + Authorization, ClientExtractor, Error, OptionExt, }; use bytes::Bytes; use headers::HeaderMapExt; @@ -23,12 +25,12 @@ pub async fn perform( req: http::Request, client_extractor: CE, token_issuer: I, -) -> Result> +) -> http::Response where CE: ClientExtractor, I: Issuer, { - let body: ParamStorage<&str, &str> = crate::deserialize_body(&req)?; + let body: ParamStorage<&str, &str> = fallible!(crate::deserialize_body(&req)); let basic_auth = req .headers() @@ -41,54 +43,52 @@ where // As a fallback, try to read from the body. // Not recommended but some clients do this. Done to increase compatibility. - let client_id = body.get("client_id").or_missing_param()?; - let client_secret = body.get("client_secret").or_missing_param()?; + let client_id = fallible!(body.get("client_id").or_missing_param()); + let client_secret = fallible!(body.get("client_secret").or_missing_param()); (*client_id, *client_secret) }; - let grant_type = body.get("grant_type").or_missing_param()?; - let code = body.get("code").or_missing_param()?; - let redirect_uri = body.get("redirect_uri").or_missing_param()?; + let grant_type = fallible!(body.get("grant_type").or_missing_param()); + let code = fallible!(body.get("code").or_missing_param()); + let redirect_uri = fallible!(body.get("redirect_uri").or_missing_param()); if *grant_type != "authorization_code" { error!(?client_id, "grant_type is not authorization_code"); - return Err(Error::Unauthorized); + yield_error!(Error::Unauthorized); } - let client = client_extractor - .extract(client_id, Some(client_secret)) - .await?; + let client = fallible!( + client_extractor + .extract(client_id, Some(client_secret)) + .await + ); if client.redirect_uri != *redirect_uri { error!(?client_id, "redirect uri doesn't match"); - return Err(Error::Unauthorized); + yield_error!(Error::Unauthorized); } - let authorization = token_issuer - .load_authorization(code) - .await? - .or_unauthorized()?; + let maybe_authorization = fallible!(token_issuer.load_authorization(code).await); + let authorization = fallible!(maybe_authorization.or_unauthorized()); // This check is constant time :3 if client != authorization.client { - return Err(Error::Unauthorized); + yield_error!(Error::Unauthorized); } if let Some(ref pkce) = authorization.pkce_payload { - let code_verifier = body.get("code_verifier").or_unauthorized()?; - pkce.verify(code_verifier)?; + let code_verifier = fallible!(body.get("code_verifier").or_unauthorized()); + fallible!(pkce.verify(code_verifier)); } - let token = token_issuer.issue_token(&authorization).await?; + let token = fallible!(token_issuer.issue_token(&authorization).await); let body = sonic_rs::to_vec(&token).unwrap(); debug!("token successfully issued. building response"); - let response = http::Response::builder() + http::Response::builder() .status(http::StatusCode::OK) .body(body.into()) - .unwrap(); - - Ok(response) + .unwrap() } diff --git a/lib/komainu/src/flow/refresh.rs b/lib/komainu/src/flow/refresh.rs index 234fb11b8..fb661d37a 100644 --- a/lib/komainu/src/flow/refresh.rs +++ b/lib/komainu/src/flow/refresh.rs @@ -1,6 +1,6 @@ use super::TokenResponse; use crate::{ - error::{Error, Result}, + error::{fallible, yield_error, Error, Result}, params::ParamStorage, Client, ClientExtractor, OptionExt, }; @@ -21,12 +21,12 @@ pub async fn perform( req: http::Request, client_extractor: CE, token_issuer: I, -) -> Result> +) -> http::Response where CE: ClientExtractor, I: Issuer, { - let body: ParamStorage<&str, &str> = crate::deserialize_body(&req)?; + let body: ParamStorage<&str, &str> = fallible!(crate::deserialize_body(&req)); let basic_auth = req .headers() @@ -39,33 +39,33 @@ where // As a fallback, try to read from the body. // Not recommended but some clients do this. Done to increase compatibility. - let client_id = body.get("client_id").or_missing_param()?; - let client_secret = body.get("client_secret").or_missing_param()?; + let client_id = fallible!(body.get("client_id").or_missing_param()); + let client_secret = fallible!(body.get("client_secret").or_missing_param()); (*client_id, *client_secret) }; - let grant_type = body.get("grant_type").or_missing_param()?; - let refresh_token = body.get("refresh_token").or_missing_param()?; + let grant_type = fallible!(body.get("grant_type").or_missing_param()); + let refresh_token = fallible!(body.get("refresh_token").or_missing_param()); if *grant_type != "refresh_token" { debug!(?client_id, "grant_type is not refresh_token"); - return Err(Error::Unauthorized); + yield_error!(Error::Unauthorized); } - let client = client_extractor - .extract(client_id, Some(client_secret)) - .await?; + let client = fallible!( + client_extractor + .extract(client_id, Some(client_secret)) + .await + ); - let token = token_issuer.issue_token(&client, refresh_token).await?; + let token = fallible!(token_issuer.issue_token(&client, refresh_token).await); let body = sonic_rs::to_vec(&token).unwrap(); debug!("token successfully issued. building response"); - let response = http::Response::builder() + http::Response::builder() .status(http::StatusCode::OK) .body(body.into()) - .unwrap(); - - Ok(response) + .unwrap() } diff --git a/lib/komainu/src/lib.rs b/lib/komainu/src/lib.rs index acd13ca18..afa255860 100644 --- a/lib/komainu/src/lib.rs +++ b/lib/komainu/src/lib.rs @@ -76,20 +76,11 @@ pub trait ClientExtractor { ) -> impl Future>> + Send; } -#[derive(AsRefStr)] -#[strum(serialize_all = "snake_case")] -pub enum OAuthError { - InvalidRequest, - UnauthorizedClient, - AccessDenied, - UnsupportedResponseType, - InvalidScope, - ServerError, - TemporarilyUnavailable, -} - #[inline] -fn deserialize_body<'a, T: serde::Deserialize<'a>>(req: &'a http::Request) -> Result { +fn deserialize_body<'a, T>(req: &'a http::Request) -> Result +where + T: serde::Deserialize<'a>, +{ // Not part of the RFC, but a bunch of implementations allow this. // And because they allow this, clients make use of this. //