From dccbd348e25a93db606fcc632c021e982b3c6bd8 Mon Sep 17 00:00:00 2001 From: Nex <60712924+NexRX@users.noreply.github.com> Date: Thu, 10 Aug 2023 02:13:37 +0100 Subject: [PATCH] Added ability to return an error for failed `SecuritySchema` checker. (#625) * feat: added ability for `securityscheme` checker to return `option` or `result` --- poem-openapi-derive/src/security_scheme.rs | 18 ++-- poem-openapi/src/auth/mod.rs | 32 ++++++- poem-openapi/src/docs/security_scheme.md | 33 ++++--- poem-openapi/src/lib.rs | 2 +- poem-openapi/src/types/mod.rs | 1 + poem-openapi/tests/security_scheme.rs | 105 ++++++++++++++++++++- 6 files changed, 164 insertions(+), 27 deletions(-) diff --git a/poem-openapi-derive/src/security_scheme.rs b/poem-openapi-derive/src/security_scheme.rs index c3787793b0..5619850fb5 100644 --- a/poem-openapi-derive/src/security_scheme.rs +++ b/poem-openapi-derive/src/security_scheme.rs @@ -440,11 +440,16 @@ pub(crate) fn generate(args: DeriveInput) -> GeneratorResult { let register_security_scheme = args.generate_register_security_scheme(&crate_name, &oai_typename)?; let from_request = args.generate_from_request(&crate_name); - let checker = args.checker.as_ref().map(|path| { - quote! { - let output = ::std::option::Option::ok_or(#path(&req, output).await, #crate_name::error::AuthorizationError)?; - } - }); + let path = args.checker.as_ref(); + + let output = match path { + Some(_) => quote! { + let output = #crate_name::__private::CheckerReturn::from(#path(&req, #from_request?).await).into_result()?; + }, + None => quote! { + let output = #from_request?; + }, + }; let expanded = quote! { #[#crate_name::__private::poem::async_trait] @@ -468,8 +473,7 @@ pub(crate) fn generate(args: DeriveInput) -> GeneratorResult { _param_opts: #crate_name::ExtractParamOptions, ) -> #crate_name::__private::poem::Result { let query = req.extensions().get::<#crate_name::__private::UrlQuery>().unwrap(); - let output = #from_request?; - #checker + #output ::std::result::Result::Ok(Self(output)) } } diff --git a/poem-openapi/src/auth/mod.rs b/poem-openapi/src/auth/mod.rs index 8814123d7b..4723342a2b 100644 --- a/poem-openapi/src/auth/mod.rs +++ b/poem-openapi/src/auth/mod.rs @@ -7,7 +7,7 @@ mod bearer; use poem::{Request, Result}; pub use self::{api_key::ApiKey, basic::Basic, bearer::Bearer}; -use crate::{base::UrlQuery, registry::MetaParamIn}; +use crate::{base::UrlQuery, error::AuthorizationError, registry::MetaParamIn}; /// Represents a basic authorization extractor. pub trait BasicAuthorization: Sized { @@ -31,3 +31,33 @@ pub trait ApiKeyAuthorization: Sized { in_type: MetaParamIn, ) -> Result; } + +/// Facilitates the conversion of `Option` into `Results`, for `SecuritySchema` checker. +#[doc(hidden)] +pub enum CheckerReturn { + Result(Result), + Option(Option), +} + +impl CheckerReturn { + pub fn into_result(self) -> Result { + match self { + Self::Result(result) => result, + Self::Option(option) => option.ok_or(AuthorizationError.into()), + } + } +} + +impl From> for CheckerReturn { + #[inline] + fn from(result: Result) -> Self { + Self::Result(result) + } +} + +impl From> for CheckerReturn { + #[inline] + fn from(option: Option) -> Self { + Self::Option(option) + } +} diff --git a/poem-openapi/src/docs/security_scheme.md b/poem-openapi/src/docs/security_scheme.md index 215afd6f9b..971477ccf4 100644 --- a/poem-openapi/src/docs/security_scheme.md +++ b/poem-openapi/src/docs/security_scheme.md @@ -2,21 +2,21 @@ Define a OpenAPI Security Scheme. # Macro parameters -| Attribute | Description | Type | Optional | -|--------------------|-----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|------------|----------| -| rename | Rename the security scheme. | string | Y | -| ty | The type of the security scheme. (api_key, basic, bearer, oauth2, openid_connect) | string | N | -| key_in | `api_key` The location of the API key. Valid values are "query", "header" or "cookie". (query, header, cookie) | string | Y | -| key_name | `api_key` The name of the header, query or cookie parameter to be used.. | string | Y | -| bearer_format | `bearer` A hint to the client to identify how the bearer token is formatted. Bearer tokens are usually generated by an authorization server, so this information is primarily for documentation purposes. | string | Y | -| flows | `oauth2` An object containing configuration information for the flow types supported. | OAuthFlows | Y | -| openid_connect_url | OpenId Connect URL to discover OAuth2 configuration values. | string | Y | -| checker | Specify a function to check the original authentication information and convert it to the return type of this function. This function must return `Option`, and return `None` if check fails. | string | Y | +| Attribute | Description | Type | Optional | +| ------------------ | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ | ---------- | -------- | +| rename | Rename the security scheme. | string | Y | +| ty | The type of the security scheme. (api_key, basic, bearer, oauth2, openid_connect) | string | N | +| key_in | `api_key` The location of the API key. Valid values are "query", "header" or "cookie". (query, header, cookie) | string | Y | +| key_name | `api_key` The name of the header, query or cookie parameter to be used.. | string | Y | +| bearer_format | `bearer` A hint to the client to identify how the bearer token is formatted. Bearer tokens are usually generated by an authorization server, so this information is primarily for documentation purposes. | string | Y | +| flows | `oauth2` An object containing configuration information for the flow types supported. | OAuthFlows | Y | +| openid_connect_url | OpenId Connect URL to discover OAuth2 configuration values. | string | Y | +| checker | Specify a function to check the original authentication information and convert it to the return type of this function. This function must return `Option` or `poem::Result`, with `None` meaning a General Authorization error and anĀ `Err` reflecting the error supplied. | string | Y | # OAuthFlows | Attribute | description | Type | Optional | -|--------------------|----------------------------------------------------------|-----------|----------| +| ------------------ | -------------------------------------------------------- | --------- | -------- | | implicit | Configuration for the OAuth Implicit flow | OAuthFlow | Y | | password | Configuration for the OAuth Resource Owner Password flow | OAuthFlow | Y | | client_credentials | Configuration for the OAuth Client Credentials flow | OAuthFlow | Y | @@ -24,10 +24,9 @@ Define a OpenAPI Security Scheme. # OAuthFlow -| Attribute | description | Type | Optional | -|-------------------|----------------------------------------------------------------------------------------------|-------------|----------| -| authorization_url | `implicit` `authorization_code` The authorization URL to be used for this flow. | string | Y | +| Attribute | description | Type | Optional | +| ----------------- | -------------------------------------------------------------------------------------------------- | ----------- | -------- | +| authorization_url | `implicit` `authorization_code` The authorization URL to be used for this flow. | string | Y | | token_url | `password` `client_credentials` `authorization_code` The token URL to be used for this flow. | string | Y | -| refresh_url | The URL to be used for obtaining refresh tokens. | string | Y | -| scopes | The available scopes for the OAuth2 security scheme. | OAuthScopes | Y | - +| refresh_url | The URL to be used for obtaining refresh tokens. | string | Y | +| scopes | The available scopes for the OAuth2 security scheme. | OAuthScopes | Y | diff --git a/poem-openapi/src/lib.rs b/poem-openapi/src/lib.rs index 254f1e248b..a3a1538c93 100644 --- a/poem-openapi/src/lib.rs +++ b/poem-openapi/src/lib.rs @@ -183,5 +183,5 @@ pub mod __private { pub use serde; pub use serde_json; - pub use crate::base::UrlQuery; + pub use crate::{auth::CheckerReturn, base::UrlQuery}; } diff --git a/poem-openapi/src/types/mod.rs b/poem-openapi/src/types/mod.rs index e1f72e652d..d240366ba2 100644 --- a/poem-openapi/src/types/mod.rs +++ b/poem-openapi/src/types/mod.rs @@ -463,6 +463,7 @@ mod tests { #[test] #[allow(clippy::assertions_on_constants)] + #[allow(unused_allocation)] fn box_type() { assert!(Box::::IS_REQUIRED); assert_eq!(Box::::name(), "integer(int32)"); diff --git a/poem-openapi/tests/security_scheme.rs b/poem-openapi/tests/security_scheme.rs index dd69d318de..30f0cf3d7d 100644 --- a/poem-openapi/tests/security_scheme.rs +++ b/poem-openapi/tests/security_scheme.rs @@ -1,7 +1,9 @@ use poem::{ - http::header, + error::ResponseError, + http::{header, StatusCode}, test::TestClient, web::{cookie::Cookie, headers}, + Request, }; use poem_openapi::{ auth::{ApiKey, Basic, Bearer}, @@ -435,3 +437,104 @@ async fn oauth2_auth() { } ); } + +#[tokio::test] +async fn checker_result() { + #[derive(SecurityScheme)] + #[oai(rename = "Checker Option", ty = "basic", checker = "extract_string")] + struct MySecurityScheme(Basic); + + #[derive(Debug, thiserror::Error)] + #[error("Your account is disabled")] + struct AccountDisabledError; + + impl ResponseError for AccountDisabledError { + fn status(&self) -> StatusCode { + StatusCode::FORBIDDEN + } + } + + async fn extract_string(_req: &Request, basic: Basic) -> poem::Result { + if basic.username != "Disabled" { + Ok(basic) + } else { + Err(AccountDisabledError)? + } + } + + let mut registry = Registry::new(); + MySecurityScheme::register(&mut registry); + + struct MyApi; + + #[OpenApi] + impl MyApi { + #[oai(path = "/test", method = "get")] + async fn test(&self, auth: MySecurityScheme) -> PlainText { + PlainText(format!("Authed: {}", auth.0.username)) + } + } + + let service = OpenApiService::new(MyApi, "test", "1.0"); + let client = TestClient::new(service); + let resp = client + .get("/test") + .typed_header(headers::Authorization::basic("Enabled", "password")) + .send() + .await; + resp.assert_status_is_ok(); + resp.assert_text("Authed: Enabled".to_string()).await; + + let resp = client + .get("/test") + .typed_header(headers::Authorization::basic("Disabled", "password")) + .send() + .await; + resp.assert_status(StatusCode::FORBIDDEN); + resp.assert_text("Your account is disabled").await; +} + +#[tokio::test] +async fn checker_option() { + #[derive(SecurityScheme)] + #[oai(rename = "Checker Option", ty = "basic", checker = "extract_string")] + struct MySecurityScheme(Basic); + + async fn extract_string(_req: &Request, basic: Basic) -> Option { + if basic.username != "Disabled" { + Some(basic) + } else { + None + } + } + + let mut registry = Registry::new(); + MySecurityScheme::register(&mut registry); + + struct MyApi; + + #[OpenApi] + impl MyApi { + #[oai(path = "/test", method = "get")] + async fn test(&self, auth: MySecurityScheme) -> PlainText { + PlainText(format!("Authed: {}", auth.0.username)) + } + } + + let service = OpenApiService::new(MyApi, "test", "1.0"); + let client = TestClient::new(service); + let resp = client + .get("/test") + .typed_header(headers::Authorization::basic("Enabled", "password")) + .send() + .await; + resp.assert_status_is_ok(); + resp.assert_text("Authed: Enabled".to_string()).await; + + let resp = client + .get("/test") + .typed_header(headers::Authorization::basic("Disabled", "password")) + .send() + .await; + resp.assert_status(StatusCode::UNAUTHORIZED); +}