Skip to content

Commit

Permalink
return more errors in a standard compliant way
Browse files Browse the repository at this point in the history
  • Loading branch information
aumetra committed Dec 14, 2024
1 parent 8fa1a9c commit 042e17e
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 54 deletions.
60 changes: 60 additions & 0 deletions lib/komainu/src/error.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use thiserror::Error;
use serde::Serialize;
use strum::AsRefStr;

type BoxError = Box<dyn std::error::Error + Send + Sync>;

Expand Down Expand Up @@ -30,3 +32,61 @@ impl Error {
Self::Query(err.into())
}
}

impl From<Error> for OAuthError {
#[track_caller]
fn from(value: Error) -> Self {
debug!(error = ?value);

match value {
Error::Body(..) | Error::MissingParam | Error::Query(..) => Self::InvalidRequest,
Error::Unauthorized => Self::AccessDenied,
}
}
}

#[derive(AsRefStr, Serialize)]
#[serde(rename_all = "snake_case")]
#[strum(serialize_all = "snake_case")]
pub enum OAuthError {
InvalidRequest,
UnauthorizedClient,
AccessDenied,
UnsupportedResponseType,
InvalidScope,
ServerError,
TemporarilyUnavailable,
}

#[derive(Serialize)]
pub struct OAuthErrorResponse {
pub error: OAuthError,
}

macro_rules! fallible {
($op:expr) => {{
match { $op } {
Ok(val) => val,
Err(error) => {
debug!(?error);
$crate::error::yield_error!(error);
}
}
}};
}

macro_rules! yield_error {
(@ser $error:expr) => {{
return ::http::Response::builder()
.status(::http::StatusCode::BAD_REQUEST)
.body(sonic_rs::to_vec(&$error).unwrap().into())
.unwrap();
}};
($error:expr) => {{
$crate::error::yield_error!(@ser $crate::error::OAuthErrorResponse {
error: $error.into(),
});
}};
}

pub(crate) use {fallible, yield_error};
50 changes: 25 additions & 25 deletions lib/komainu/src/flow/authorization.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use super::TokenResponse;
use crate::{
error::Result, params::ParamStorage, Authorization, ClientExtractor, Error, OptionExt,
error::{fallible, yield_error, Result},
params::ParamStorage,
Authorization, ClientExtractor, Error, OptionExt,
};
use bytes::Bytes;
use headers::HeaderMapExt;
Expand All @@ -23,12 +25,12 @@ pub async fn perform<CE, I>(
req: http::Request<Bytes>,
client_extractor: CE,
token_issuer: I,
) -> Result<http::Response<Bytes>>
) -> http::Response<Bytes>
where
CE: ClientExtractor,
I: Issuer,
{
let body: ParamStorage<&str, &str> = crate::deserialize_body(&req)?;
let body: ParamStorage<&str, &str> = fallible!(crate::deserialize_body(&req));

let basic_auth = req
.headers()
Expand All @@ -41,54 +43,52 @@ where

// As a fallback, try to read from the body.
// Not recommended but some clients do this. Done to increase compatibility.
let client_id = body.get("client_id").or_missing_param()?;
let client_secret = body.get("client_secret").or_missing_param()?;
let client_id = fallible!(body.get("client_id").or_missing_param());
let client_secret = fallible!(body.get("client_secret").or_missing_param());

(*client_id, *client_secret)
};

let grant_type = body.get("grant_type").or_missing_param()?;
let code = body.get("code").or_missing_param()?;
let redirect_uri = body.get("redirect_uri").or_missing_param()?;
let grant_type = fallible!(body.get("grant_type").or_missing_param());
let code = fallible!(body.get("code").or_missing_param());
let redirect_uri = fallible!(body.get("redirect_uri").or_missing_param());

if *grant_type != "authorization_code" {
error!(?client_id, "grant_type is not authorization_code");
return Err(Error::Unauthorized);
yield_error!(Error::Unauthorized);
}

let client = client_extractor
.extract(client_id, Some(client_secret))
.await?;
let client = fallible!(
client_extractor
.extract(client_id, Some(client_secret))
.await
);

if client.redirect_uri != *redirect_uri {
error!(?client_id, "redirect uri doesn't match");
return Err(Error::Unauthorized);
yield_error!(Error::Unauthorized);
}

let authorization = token_issuer
.load_authorization(code)
.await?
.or_unauthorized()?;
let maybe_authorization = fallible!(token_issuer.load_authorization(code).await);
let authorization = fallible!(maybe_authorization.or_unauthorized());

// This check is constant time :3
if client != authorization.client {
return Err(Error::Unauthorized);
yield_error!(Error::Unauthorized);
}

if let Some(ref pkce) = authorization.pkce_payload {
let code_verifier = body.get("code_verifier").or_unauthorized()?;
pkce.verify(code_verifier)?;
let code_verifier = fallible!(body.get("code_verifier").or_unauthorized());
fallible!(pkce.verify(code_verifier));
}

let token = token_issuer.issue_token(&authorization).await?;
let token = fallible!(token_issuer.issue_token(&authorization).await);
let body = sonic_rs::to_vec(&token).unwrap();

debug!("token successfully issued. building response");

let response = http::Response::builder()
http::Response::builder()
.status(http::StatusCode::OK)
.body(body.into())
.unwrap();

Ok(response)
.unwrap()
}
32 changes: 16 additions & 16 deletions lib/komainu/src/flow/refresh.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use super::TokenResponse;
use crate::{
error::{Error, Result},
error::{fallible, yield_error, Error, Result},
params::ParamStorage,
Client, ClientExtractor, OptionExt,
};
Expand All @@ -21,12 +21,12 @@ pub async fn perform<CE, I>(
req: http::Request<Bytes>,
client_extractor: CE,
token_issuer: I,
) -> Result<http::Response<Bytes>>
) -> http::Response<Bytes>
where
CE: ClientExtractor,
I: Issuer,
{
let body: ParamStorage<&str, &str> = crate::deserialize_body(&req)?;
let body: ParamStorage<&str, &str> = fallible!(crate::deserialize_body(&req));

let basic_auth = req
.headers()
Expand All @@ -39,33 +39,33 @@ where

// As a fallback, try to read from the body.
// Not recommended but some clients do this. Done to increase compatibility.
let client_id = body.get("client_id").or_missing_param()?;
let client_secret = body.get("client_secret").or_missing_param()?;
let client_id = fallible!(body.get("client_id").or_missing_param());
let client_secret = fallible!(body.get("client_secret").or_missing_param());

(*client_id, *client_secret)
};

let grant_type = body.get("grant_type").or_missing_param()?;
let refresh_token = body.get("refresh_token").or_missing_param()?;
let grant_type = fallible!(body.get("grant_type").or_missing_param());
let refresh_token = fallible!(body.get("refresh_token").or_missing_param());

if *grant_type != "refresh_token" {
debug!(?client_id, "grant_type is not refresh_token");
return Err(Error::Unauthorized);
yield_error!(Error::Unauthorized);
}

let client = client_extractor
.extract(client_id, Some(client_secret))
.await?;
let client = fallible!(
client_extractor
.extract(client_id, Some(client_secret))
.await
);

let token = token_issuer.issue_token(&client, refresh_token).await?;
let token = fallible!(token_issuer.issue_token(&client, refresh_token).await);
let body = sonic_rs::to_vec(&token).unwrap();

debug!("token successfully issued. building response");

let response = http::Response::builder()
http::Response::builder()
.status(http::StatusCode::OK)
.body(body.into())
.unwrap();

Ok(response)
.unwrap()
}
17 changes: 4 additions & 13 deletions lib/komainu/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,20 +76,11 @@ pub trait ClientExtractor {
) -> impl Future<Output = Result<Client<'_>>> + Send;
}

#[derive(AsRefStr)]
#[strum(serialize_all = "snake_case")]
pub enum OAuthError {
InvalidRequest,
UnauthorizedClient,
AccessDenied,
UnsupportedResponseType,
InvalidScope,
ServerError,
TemporarilyUnavailable,
}

#[inline]
fn deserialize_body<'a, T: serde::Deserialize<'a>>(req: &'a http::Request<Bytes>) -> Result<T> {
fn deserialize_body<'a, T>(req: &'a http::Request<Bytes>) -> Result<T>
where
T: serde::Deserialize<'a>,
{
// Not part of the RFC, but a bunch of implementations allow this.
// And because they allow this, clients make use of this.
//
Expand Down

0 comments on commit 042e17e

Please sign in to comment.