diff --git a/examples/github_async.rs b/examples/github_async.rs index 4add907..937e05b 100644 --- a/examples/github_async.rs +++ b/examples/github_async.rs @@ -14,7 +14,6 @@ //! use oauth2::basic::BasicClient; -// Alternatively, this can be `oauth2::curl::http_client` or a custom client. use oauth2::reqwest::async_http_client; use oauth2::{ AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, RedirectUrl, Scope, diff --git a/examples/google.rs b/examples/google.rs index fce2d0c..1f74b66 100644 --- a/examples/google.rs +++ b/examples/google.rs @@ -13,7 +13,7 @@ //! ...and follow the instructions. //! -use oauth2::{basic::BasicClient, revocation::StandardRevocableToken, TokenResponse}; +use oauth2::{basic::BasicClient, StandardRevocableToken, TokenResponse}; // Alternatively, this can be oauth2::curl::http_client or a custom. use oauth2::reqwest::http_client; use oauth2::{ diff --git a/examples/google_devicecode.rs b/examples/google_devicecode.rs index 0a38a92..6f4489f 100644 --- a/examples/google_devicecode.rs +++ b/examples/google_devicecode.rs @@ -15,9 +15,11 @@ use oauth2::basic::BasicClient; // Alternatively, this can be oauth2::curl::http_client or a custom. -use oauth2::devicecode::{DeviceAuthorizationResponse, ExtraDeviceAuthorizationFields}; use oauth2::reqwest::http_client; -use oauth2::{AuthType, AuthUrl, ClientId, ClientSecret, DeviceAuthorizationUrl, Scope, TokenUrl}; +use oauth2::{ + AuthType, AuthUrl, ClientId, ClientSecret, DeviceAuthorizationResponse, DeviceAuthorizationUrl, + ExtraDeviceAuthorizationFields, Scope, TokenUrl, +}; use serde::{Deserialize, Serialize}; use std::collections::HashMap; diff --git a/examples/microsoft_devicecode_common_user.rs b/examples/microsoft_devicecode_common_user.rs index fa3ddfd..a80800c 100644 --- a/examples/microsoft_devicecode_common_user.rs +++ b/examples/microsoft_devicecode_common_user.rs @@ -1,7 +1,8 @@ use oauth2::basic::BasicClient; -use oauth2::devicecode::StandardDeviceAuthorizationResponse; use oauth2::reqwest::async_http_client; -use oauth2::{AuthUrl, ClientId, DeviceAuthorizationUrl, Scope, TokenUrl}; +use oauth2::{ + AuthUrl, ClientId, DeviceAuthorizationUrl, Scope, StandardDeviceAuthorizationResponse, TokenUrl, +}; use std::error::Error; diff --git a/examples/microsoft_devicecode_tenant_user.rs b/examples/microsoft_devicecode_tenant_user.rs index 56a0ef0..29a2102 100644 --- a/examples/microsoft_devicecode_tenant_user.rs +++ b/examples/microsoft_devicecode_tenant_user.rs @@ -1,6 +1,6 @@ use oauth2::basic::BasicClient; -use oauth2::devicecode::StandardDeviceAuthorizationResponse; use oauth2::reqwest::async_http_client; +use oauth2::StandardDeviceAuthorizationResponse; use oauth2::{AuthUrl, ClientId, DeviceAuthorizationUrl, Scope, TokenUrl}; use std::error::Error; diff --git a/examples/wunderlist.rs b/examples/wunderlist.rs index e133e91..a93be24 100644 --- a/examples/wunderlist.rs +++ b/examples/wunderlist.rs @@ -14,16 +14,13 @@ //! ...and follow the instructions. //! -use oauth2::TokenType; -use oauth2::{ - basic::{ - BasicErrorResponse, BasicRevocationErrorResponse, BasicTokenIntrospectionResponse, - BasicTokenType, - }, - revocation::StandardRevocableToken, +use oauth2::basic::{ + BasicErrorResponse, BasicRevocationErrorResponse, BasicTokenIntrospectionResponse, + BasicTokenType, }; -// Alternatively, this can be `oauth2::curl::http_client` or a custom client. use oauth2::helpers; +use oauth2::{StandardRevocableToken, TokenType}; +// Alternatively, this can be `oauth2::curl::http_client` or a custom client. use oauth2::reqwest::http_client; use oauth2::{ AccessToken, AuthUrl, AuthorizationCode, Client, ClientId, ClientSecret, CsrfToken, diff --git a/src/client.rs b/src/client.rs new file mode 100644 index 0000000..5925500 --- /dev/null +++ b/src/client.rs @@ -0,0 +1,830 @@ +use crate::{ + AccessToken, AuthType, AuthUrl, AuthorizationCode, AuthorizationRequest, + ClientCredentialsTokenRequest, ClientId, ClientSecret, CodeTokenRequest, ConfigurationError, + CsrfToken, DeviceAccessTokenRequest, DeviceAuthorizationRequest, DeviceAuthorizationResponse, + DeviceAuthorizationUrl, ErrorResponse, ExtraDeviceAuthorizationFields, IntrospectionRequest, + IntrospectionUrl, PasswordTokenRequest, RedirectUrl, RefreshToken, RefreshTokenRequest, + ResourceOwnerPassword, ResourceOwnerUsername, RevocableToken, RevocationRequest, RevocationUrl, + TokenIntrospectionResponse, TokenResponse, TokenType, TokenUrl, +}; + +use chrono::Utc; + +use std::borrow::Cow; +use std::marker::PhantomData; +use std::sync::Arc; + +/// Stores the configuration for an OAuth2 client. +/// +/// This type implements the +/// [Builder Pattern](https://doc.rust-lang.org/1.0.0/style/ownership/builders.html) together with +/// [typestates](https://cliffle.com/blog/rust-typestate/#what-are-typestates) to encode whether +/// certain fields have been set that are prerequisites to certain authentication flows. For +/// example, the authorization endpoint must be set via [`Client::set_auth_url`] before +/// [`Client::authorize_url`] can be called. Each endpoint has a corresponding const generic +/// parameter (e.g., `HAS_AUTH_URL`) used to statically enforce these dependencies. These generics +/// are set automatically by the corresponding setter functions, and in most cases user code should +/// not need to deal with them directly. +/// +/// # Error Types +/// +/// To enable compile time verification that only the correct and complete set of errors for the `Client` function being +/// invoked are exposed to the caller, the `Client` type is specialized on multiple implementations of the +/// [`ErrorResponse`] trait. The exact [`ErrorResponse`] implementation returned varies by the RFC that the invoked +/// `Client` function implements: +/// +/// - Generic type `TE` (aka Token Error) for errors defined by [RFC 6749 OAuth 2.0 Authorization Framework](https://tools.ietf.org/html/rfc6749). +/// - Generic type `TRE` (aka Token Revocation Error) for errors defined by [RFC 7009 OAuth 2.0 Token Revocation](https://tools.ietf.org/html/rfc7009). +/// +/// For example when revoking a token, error code `unsupported_token_type` (from RFC 7009) may be returned: +/// ```rust +/// # use thiserror::Error; +/// # use http::status::StatusCode; +/// # use http::header::{HeaderValue, CONTENT_TYPE}; +/// # use oauth2::{*, basic::*}; +/// # let client = BasicClient::new(ClientId::new("aaa".to_string())) +/// # .set_client_secret(ClientSecret::new("bbb".to_string())) +/// # .set_auth_url(AuthUrl::new("https://example.com/auth".to_string()).unwrap()) +/// # .set_token_url(TokenUrl::new("https://example.com/token".to_string()).unwrap()) +/// # .set_revocation_uri(RevocationUrl::new("https://revocation/url".to_string()).unwrap()); +/// # +/// # #[derive(Debug, Error)] +/// # enum FakeError { +/// # #[error("error")] +/// # Err, +/// # } +/// # +/// # let http_client = |_| -> Result { +/// # Ok(HttpResponse { +/// # status_code: StatusCode::BAD_REQUEST, +/// # headers: vec![( +/// # CONTENT_TYPE, +/// # HeaderValue::from_str("application/json").unwrap(), +/// # )] +/// # .into_iter() +/// # .collect(), +/// # body: "{\"error\": \"unsupported_token_type\", \"error_description\": \"stuff happened\", \ +/// # \"error_uri\": \"https://errors\"}" +/// # .to_string() +/// # .into_bytes(), +/// # }) +/// # }; +/// # +/// let res = client +/// .revoke_token(AccessToken::new("some token".to_string()).into()) +/// .unwrap() +/// .request(http_client); +/// +/// assert!(matches!(res, Err( +/// RequestTokenError::ServerResponse(err)) if matches!(err.error(), +/// RevocationErrorResponseType::UnsupportedTokenType))); +/// ``` +/// +/// # Examples +/// +/// See the [crate] root documentation for usage examples. +#[derive(Clone, Debug)] +pub struct Client< + TE, + TR, + TT, + TIR, + RT, + TRE, + const HAS_AUTH_URL: bool = false, + const HAS_DEVICE_AUTH_URL: bool = false, + const HAS_INTROSPECTION_URL: bool = false, + const HAS_REVOCATION_URL: bool = false, + const HAS_TOKEN_URL: bool = false, +> where + TE: ErrorResponse, + TR: TokenResponse, + TT: TokenType, + TIR: TokenIntrospectionResponse, + RT: RevocableToken, + TRE: ErrorResponse, +{ + pub(crate) client_id: ClientId, + pub(crate) client_secret: Option, + pub(crate) auth_url: Option, + pub(crate) auth_type: AuthType, + pub(crate) token_url: Option, + pub(crate) redirect_url: Option, + pub(crate) introspection_url: Option, + pub(crate) revocation_url: Option, + pub(crate) device_authorization_url: Option, + pub(crate) phantom: PhantomData<(TE, TR, TT, TIR, RT, TRE)>, +} +impl Client +where + TE: ErrorResponse + 'static, + TR: TokenResponse, + TT: TokenType, + TIR: TokenIntrospectionResponse, + RT: RevocableToken, + TRE: ErrorResponse + 'static, +{ + /// Initializes an OAuth2 client with the specified client ID. + pub fn new(client_id: ClientId) -> Self { + Self { + client_id, + client_secret: None, + auth_url: None, + auth_type: AuthType::BasicAuth, + token_url: None, + redirect_url: None, + introspection_url: None, + revocation_url: None, + device_authorization_url: None, + phantom: PhantomData, + } + } +} +impl< + TE, + TR, + TT, + TIR, + RT, + TRE, + const HAS_AUTH_URL: bool, + const HAS_DEVICE_AUTH_URL: bool, + const HAS_INTROSPECTION_URL: bool, + const HAS_REVOCATION_URL: bool, + const HAS_TOKEN_URL: bool, + > + Client< + TE, + TR, + TT, + TIR, + RT, + TRE, + HAS_AUTH_URL, + HAS_DEVICE_AUTH_URL, + HAS_INTROSPECTION_URL, + HAS_REVOCATION_URL, + HAS_TOKEN_URL, + > +where + TE: ErrorResponse + 'static, + TR: TokenResponse, + TT: TokenType, + TIR: TokenIntrospectionResponse, + RT: RevocableToken, + TRE: ErrorResponse + 'static, +{ + /// Configures the type of client authentication used for communicating with the authorization + /// server. + /// + /// The default is to use HTTP Basic authentication, as recommended in + /// [Section 2.3.1 of RFC 6749](https://tools.ietf.org/html/rfc6749#section-2.3.1). Note that + /// if a client secret is omitted (i.e., `client_secret` is set to `None` when calling + /// [`Client::new`]), [`AuthType::RequestBody`] is used regardless of the `auth_type` passed to + /// this function. + pub fn set_auth_type(mut self, auth_type: AuthType) -> Self { + self.auth_type = auth_type; + + self + } + + /// Sets the authorization endpoint. + /// + /// The client uses the authorization endpoint to obtain authorization from the resource owner + /// via user-agent redirection. This URL is used in all standard OAuth2 flows except the + /// [Resource Owner Password Credentials Grant](https://tools.ietf.org/html/rfc6749#section-4.3) + /// and the [Client Credentials Grant](https://tools.ietf.org/html/rfc6749#section-4.4). + pub fn set_auth_url( + self, + auth_url: AuthUrl, + ) -> Client< + TE, + TR, + TT, + TIR, + RT, + TRE, + true, + HAS_DEVICE_AUTH_URL, + HAS_INTROSPECTION_URL, + HAS_REVOCATION_URL, + HAS_TOKEN_URL, + > { + Client { + client_id: self.client_id, + client_secret: self.client_secret, + auth_url: Some(auth_url), + auth_type: self.auth_type, + token_url: self.token_url, + redirect_url: self.redirect_url, + introspection_url: self.introspection_url, + revocation_url: self.revocation_url, + device_authorization_url: self.device_authorization_url, + phantom: self.phantom, + } + } + + /// Sets the client secret. + /// + /// A client secret is generally used for confidential (i.e., server-side) OAuth2 clients and + /// omitted from public (browser or native app) OAuth2 clients (see + /// [RFC 8252](https://tools.ietf.org/html/rfc8252)). + pub fn set_client_secret(mut self, client_secret: ClientSecret) -> Self { + self.client_secret = Some(client_secret); + + self + } + + /// Sets the device authorization URL used by the device authorization endpoint. + /// Used for Device Code Flow, as per [RFC 8628](https://tools.ietf.org/html/rfc8628). + pub fn set_device_authorization_url( + self, + device_authorization_url: DeviceAuthorizationUrl, + ) -> Client< + TE, + TR, + TT, + TIR, + RT, + TRE, + HAS_AUTH_URL, + true, + HAS_INTROSPECTION_URL, + HAS_REVOCATION_URL, + HAS_TOKEN_URL, + > { + Client { + client_id: self.client_id, + client_secret: self.client_secret, + auth_url: self.auth_url, + auth_type: self.auth_type, + token_url: self.token_url, + redirect_url: self.redirect_url, + introspection_url: self.introspection_url, + revocation_url: self.revocation_url, + device_authorization_url: Some(device_authorization_url), + phantom: self.phantom, + } + } + + /// Sets the introspection URL for contacting the ([RFC 7662](https://tools.ietf.org/html/rfc7662)) + /// introspection endpoint. + pub fn set_introspection_uri( + self, + introspection_url: IntrospectionUrl, + ) -> Client< + TE, + TR, + TT, + TIR, + RT, + TRE, + HAS_TOKEN_URL, + HAS_DEVICE_AUTH_URL, + true, + HAS_REVOCATION_URL, + HAS_TOKEN_URL, + > { + Client { + client_id: self.client_id, + client_secret: self.client_secret, + auth_url: self.auth_url, + auth_type: self.auth_type, + token_url: self.token_url, + redirect_url: self.redirect_url, + introspection_url: Some(introspection_url), + revocation_url: self.revocation_url, + device_authorization_url: self.device_authorization_url, + phantom: self.phantom, + } + } + + /// Sets the redirect URL used by the authorization endpoint. + pub fn set_redirect_uri(mut self, redirect_url: RedirectUrl) -> Self { + self.redirect_url = Some(redirect_url); + + self + } + + /// Sets the revocation URL for contacting the revocation endpoint ([RFC 7009](https://tools.ietf.org/html/rfc7009)). + /// + /// See: [`revoke_token()`](Self::revoke_token()) + pub fn set_revocation_uri( + self, + revocation_url: RevocationUrl, + ) -> Client< + TE, + TR, + TT, + TIR, + RT, + TRE, + HAS_TOKEN_URL, + HAS_DEVICE_AUTH_URL, + HAS_INTROSPECTION_URL, + true, + HAS_TOKEN_URL, + > { + Client { + client_id: self.client_id, + client_secret: self.client_secret, + auth_url: self.auth_url, + auth_type: self.auth_type, + token_url: self.token_url, + redirect_url: self.redirect_url, + introspection_url: self.introspection_url, + revocation_url: Some(revocation_url), + device_authorization_url: self.device_authorization_url, + phantom: self.phantom, + } + } + + /// Sets the token endpoint. + /// + /// The client uses the token endpoint to exchange an authorization code for an access token, + /// typically with client authentication. This URL is used in + /// all standard OAuth2 flows except the + /// [Implicit Grant](https://tools.ietf.org/html/rfc6749#section-4.2). + pub fn set_token_url( + self, + token_url: TokenUrl, + ) -> Client< + TE, + TR, + TT, + TIR, + RT, + TRE, + HAS_AUTH_URL, + HAS_DEVICE_AUTH_URL, + HAS_INTROSPECTION_URL, + HAS_REVOCATION_URL, + true, + > { + Client { + client_id: self.client_id, + client_secret: self.client_secret, + auth_url: self.auth_url, + auth_type: self.auth_type, + token_url: Some(token_url), + redirect_url: self.redirect_url, + introspection_url: self.introspection_url, + revocation_url: self.revocation_url, + device_authorization_url: self.device_authorization_url, + phantom: self.phantom, + } + } + + /// Returns the Client ID. + pub fn client_id(&self) -> &ClientId { + &self.client_id + } + + /// Returns the type of client authentication used for communicating with the authorization + /// server. + pub fn auth_type(&self) -> &AuthType { + &self.auth_type + } + + /// Returns the redirect URL used by the authorization endpoint. + pub fn redirect_url(&self) -> Option<&RedirectUrl> { + self.redirect_url.as_ref() + } +} + +// Methods requiring an authorization endpoint. +impl< + TE, + TR, + TT, + TIR, + RT, + TRE, + const HAS_DEVICE_AUTH_URL: bool, + const HAS_INTROSPECTION_URL: bool, + const HAS_REVOCATION_URL: bool, + const HAS_TOKEN_URL: bool, + > + Client< + TE, + TR, + TT, + TIR, + RT, + TRE, + true, + HAS_DEVICE_AUTH_URL, + HAS_INTROSPECTION_URL, + HAS_REVOCATION_URL, + HAS_TOKEN_URL, + > +where + TE: ErrorResponse + 'static, + TR: TokenResponse, + TT: TokenType, + TIR: TokenIntrospectionResponse, + RT: RevocableToken, + TRE: ErrorResponse + 'static, +{ + /// Returns the authorization endpoint. + pub fn auth_url(&self) -> &AuthUrl { + // This is enforced statically via the HAS_AUTH_URL const generic. + self.auth_url.as_ref().expect("should have auth_url") + } + + /// Generates an authorization URL for a new authorization request. + /// + /// # Arguments + /// + /// * `state_fn` - A function that returns an opaque value used by the client to maintain state + /// between the request and callback. The authorization server includes this value when + /// redirecting the user-agent back to the client. + /// + /// # Security Warning + /// + /// Callers should use a fresh, unpredictable `state` for each authorization request and verify + /// that this value matches the `state` parameter passed by the authorization server to the + /// redirect URI. Doing so mitigates + /// [Cross-Site Request Forgery](https://tools.ietf.org/html/rfc6749#section-10.12) + /// attacks. To disable CSRF protections (NOT recommended), use `insecure::authorize_url` + /// instead. + pub fn authorize_url(&self, state_fn: S) -> AuthorizationRequest + where + S: FnOnce() -> CsrfToken, + { + AuthorizationRequest { + // This is enforced statically via the HAS_AUTH_URL const generic. + auth_url: self.auth_url(), + client_id: &self.client_id, + extra_params: Vec::new(), + pkce_challenge: None, + redirect_url: self.redirect_url.as_ref().map(Cow::Borrowed), + response_type: "code".into(), + scopes: Vec::new(), + state: state_fn(), + } + } +} + +// Methods requiring a token endpoint. +impl< + TE, + TR, + TT, + TIR, + RT, + TRE, + const HAS_AUTH_URL: bool, + const HAS_DEVICE_AUTH_URL: bool, + const HAS_INTROSPECTION_URL: bool, + const HAS_REVOCATION_URL: bool, + > + Client< + TE, + TR, + TT, + TIR, + RT, + TRE, + HAS_AUTH_URL, + HAS_DEVICE_AUTH_URL, + HAS_INTROSPECTION_URL, + HAS_REVOCATION_URL, + true, + > +where + TE: ErrorResponse + 'static, + TR: TokenResponse, + TT: TokenType, + TIR: TokenIntrospectionResponse, + RT: RevocableToken, + TRE: ErrorResponse + 'static, +{ + /// Requests an access token for the *client credentials* grant type. + /// + /// See . + pub fn exchange_client_credentials(&self) -> ClientCredentialsTokenRequest { + ClientCredentialsTokenRequest { + auth_type: &self.auth_type, + client_id: &self.client_id, + client_secret: self.client_secret.as_ref(), + extra_params: Vec::new(), + scopes: Vec::new(), + // This is enforced statically via the HAS_TOKEN_URL const generic. + token_url: self.token_url.as_ref().expect("should have token_url"), + _phantom: PhantomData, + } + } + + /// Exchanges a code produced by a successful authorization process with an access token. + /// + /// Acquires ownership of the `code` because authorization codes may only be used once to + /// retrieve an access token from the authorization server. + /// + /// See . + pub fn exchange_code(&self, code: AuthorizationCode) -> CodeTokenRequest { + CodeTokenRequest { + auth_type: &self.auth_type, + client_id: &self.client_id, + client_secret: self.client_secret.as_ref(), + code, + extra_params: Vec::new(), + pkce_verifier: None, + // This is enforced statically via the HAS_TOKEN_URL const generic. + token_url: self.token_url.as_ref().expect("should have token_url"), + redirect_url: self.redirect_url.as_ref().map(Cow::Borrowed), + _phantom: PhantomData, + } + } + + /// Perform a device access token request as per + /// . + pub fn exchange_device_access_token<'a, 'b, 'c, EF>( + &'a self, + auth_response: &'b DeviceAuthorizationResponse, + ) -> DeviceAccessTokenRequest<'b, 'c, TR, TT, EF> + where + 'a: 'b, + EF: ExtraDeviceAuthorizationFields, + { + DeviceAccessTokenRequest { + auth_type: &self.auth_type, + client_id: &self.client_id, + client_secret: self.client_secret.as_ref(), + extra_params: Vec::new(), + // This is enforced statically via the HAS_TOKEN_URL const generic. + token_url: self.token_url.as_ref().expect("should have token_url"), + dev_auth_resp: auth_response, + time_fn: Arc::new(Utc::now), + max_backoff_interval: None, + _phantom: PhantomData, + } + } + + /// Requests an access token for the *password* grant type. + /// + /// See . + pub fn exchange_password<'a, 'b>( + &'a self, + username: &'b ResourceOwnerUsername, + password: &'b ResourceOwnerPassword, + ) -> PasswordTokenRequest<'b, TE, TR, TT> + where + 'a: 'b, + { + PasswordTokenRequest::<'b> { + auth_type: &self.auth_type, + client_id: &self.client_id, + client_secret: self.client_secret.as_ref(), + username, + password, + extra_params: Vec::new(), + scopes: Vec::new(), + // This is enforced statically via the HAS_TOKEN_URL const generic. + token_url: self.token_url.as_ref().expect("should have token_url"), + _phantom: PhantomData, + } + } + + /// Exchanges a refresh token for an access token + /// + /// See . + pub fn exchange_refresh_token<'a, 'b>( + &'a self, + refresh_token: &'b RefreshToken, + ) -> RefreshTokenRequest<'b, TE, TR, TT> + where + 'a: 'b, + { + RefreshTokenRequest { + auth_type: &self.auth_type, + client_id: &self.client_id, + client_secret: self.client_secret.as_ref(), + extra_params: Vec::new(), + refresh_token, + scopes: Vec::new(), + // This is enforced statically via the HAS_TOKEN_URL const generic. + token_url: self.token_url.as_ref().expect("should have token_url"), + _phantom: PhantomData, + } + } + + /// Returns the token endpoint. + pub fn token_url(&self) -> &TokenUrl { + // This is enforced statically via the HAS_TOKEN_URL const generic. + self.token_url.as_ref().expect("should have token_url") + } +} + +// Methods requiring a device authorization endpoint. +impl< + TE, + TR, + TT, + TIR, + RT, + TRE, + const HAS_AUTH_URL: bool, + const HAS_INTROSPECTION_URL: bool, + const HAS_REVOCATION_URL: bool, + const HAS_TOKEN_URL: bool, + > + Client< + TE, + TR, + TT, + TIR, + RT, + TRE, + HAS_AUTH_URL, + true, + HAS_INTROSPECTION_URL, + HAS_REVOCATION_URL, + HAS_TOKEN_URL, + > +where + TE: ErrorResponse + 'static, + TR: TokenResponse, + TT: TokenType, + TIR: TokenIntrospectionResponse, + RT: RevocableToken, + TRE: ErrorResponse + 'static, +{ + /// Perform a device authorization request as per + /// . + pub fn exchange_device_code(&self) -> DeviceAuthorizationRequest { + DeviceAuthorizationRequest { + auth_type: &self.auth_type, + client_id: &self.client_id, + client_secret: self.client_secret.as_ref(), + extra_params: Vec::new(), + scopes: Vec::new(), + // This is enforced statically via the HAS_DEVICE_AUTH_URL const generic. + device_authorization_url: self + .device_authorization_url + .as_ref() + .expect("should have device_authorization_url"), + _phantom: PhantomData, + } + } + + /// Returns the device authorization URL used by the device authorization endpoint. + pub fn device_authorization_url(&self) -> &DeviceAuthorizationUrl { + // This is enforced statically via the HAS_DEVICE_AUTH_URL const generic. + self.device_authorization_url + .as_ref() + .expect("should have device_authorization_url") + } +} + +// Methods requiring an introspection endpoint. +impl< + TE, + TR, + TT, + TIR, + RT, + TRE, + const HAS_AUTH_URL: bool, + const HAS_DEVICE_AUTH_URL: bool, + const HAS_REVOCATION_URL: bool, + const HAS_TOKEN_URL: bool, + > + Client< + TE, + TR, + TT, + TIR, + RT, + TRE, + HAS_AUTH_URL, + HAS_DEVICE_AUTH_URL, + true, + HAS_REVOCATION_URL, + HAS_TOKEN_URL, + > +where + TE: ErrorResponse + 'static, + TR: TokenResponse, + TT: TokenType, + TIR: TokenIntrospectionResponse, + RT: RevocableToken, + TRE: ErrorResponse + 'static, +{ + /// Query the authorization server [`RFC 7662 compatible`](https://tools.ietf.org/html/rfc7662) introspection + /// endpoint to determine the set of metadata for a previously received token. + /// + /// Requires [`set_introspection_uri()`](Self::set_introspection_uri) to have been previously + /// called to set the introspection endpoint URL. + pub fn introspect<'a>( + &'a self, + token: &'a AccessToken, + ) -> Result, ConfigurationError> { + Ok(IntrospectionRequest { + auth_type: &self.auth_type, + client_id: &self.client_id, + client_secret: self.client_secret.as_ref(), + extra_params: Vec::new(), + // This is enforced statically via the HAS_INTROSPECTION_URL const generic. + introspection_url: self + .introspection_url + .as_ref() + .expect("should have introspection_url"), + token, + token_type_hint: None, + _phantom: PhantomData, + }) + } + + /// Returns the introspection URL for contacting the ([RFC 7662](https://tools.ietf.org/html/rfc7662)) + /// introspection endpoint. + pub fn introspection_url(&self) -> &IntrospectionUrl { + // This is enforced statically via the HAS_INTROSPECTION_URL const generic. + self.introspection_url + .as_ref() + .expect("should have introspection_url") + } +} + +// Methods requiring a revocation endpoint. +impl< + TE, + TR, + TT, + TIR, + RT, + TRE, + const HAS_AUTH_URL: bool, + const HAS_DEVICE_AUTH_URL: bool, + const HAS_INTROSPECTION_URL: bool, + const HAS_TOKEN_URL: bool, + > + Client< + TE, + TR, + TT, + TIR, + RT, + TRE, + HAS_AUTH_URL, + HAS_DEVICE_AUTH_URL, + HAS_INTROSPECTION_URL, + true, + HAS_TOKEN_URL, + > +where + TE: ErrorResponse + 'static, + TR: TokenResponse, + TT: TokenType, + TIR: TokenIntrospectionResponse, + RT: RevocableToken, + TRE: ErrorResponse + 'static, +{ + /// Attempts to revoke the given previously received token using an + /// [RFC 7009 OAuth 2.0 Token Revocation](https://tools.ietf.org/html/rfc7009) compatible + /// endpoint. + /// + /// Requires [`set_revocation_uri()`](Self::set_revocation_uri) to have been previously + /// called to set the revocation endpoint URL. + pub fn revoke_token( + &self, + token: RT, + ) -> Result, ConfigurationError> { + // https://tools.ietf.org/html/rfc7009#section-2 states: + // "The client requests the revocation of a particular token by making an + // HTTP POST request to the token revocation endpoint URL. This URL + // MUST conform to the rules given in [RFC6749], Section 3.1. Clients + // MUST verify that the URL is an HTTPS URL." + + // This is enforced statically via the HAS_REVOCATION_URL const generic. + let revocation_url = self + .revocation_url + .as_ref() + .expect("should have revocation_url"); + + if revocation_url.url().scheme() != "https" { + return Err(ConfigurationError::InsecureUrl("revocation")); + } + + Ok(RevocationRequest { + auth_type: &self.auth_type, + client_id: &self.client_id, + client_secret: self.client_secret.as_ref(), + extra_params: Vec::new(), + revocation_url, + token, + _phantom: PhantomData, + }) + } + + /// Returns the revocation URL for contacting the revocation endpoint + /// ([RFC 7009](https://tools.ietf.org/html/rfc7009)). + /// + /// See: [`revoke_token()`](Self::revoke_token()) + pub fn revocation_url(&self) -> &RevocationUrl { + // This is enforced statically via the HAS_REVOCATION_URL const generic. + self.revocation_url + .as_ref() + .expect("should have revocation_url") + } +} diff --git a/src/code.rs b/src/code.rs new file mode 100644 index 0000000..abd241c --- /dev/null +++ b/src/code.rs @@ -0,0 +1,347 @@ +use crate::{AuthUrl, ClientId, CsrfToken, PkceCodeChallenge, RedirectUrl, ResponseType, Scope}; + +use url::Url; + +use std::borrow::Cow; + +/// A request to the authorization endpoint +#[derive(Debug)] +pub struct AuthorizationRequest<'a> { + pub(crate) auth_url: &'a AuthUrl, + pub(crate) client_id: &'a ClientId, + pub(crate) extra_params: Vec<(Cow<'a, str>, Cow<'a, str>)>, + pub(crate) pkce_challenge: Option, + pub(crate) redirect_url: Option>, + pub(crate) response_type: Cow<'a, str>, + pub(crate) scopes: Vec>, + pub(crate) state: CsrfToken, +} +impl<'a> AuthorizationRequest<'a> { + /// Appends a new scope to the authorization URL. + pub fn add_scope(mut self, scope: Scope) -> Self { + self.scopes.push(Cow::Owned(scope)); + self + } + + /// Appends a collection of scopes to the token request. + pub fn add_scopes(mut self, scopes: I) -> Self + where + I: IntoIterator, + { + self.scopes.extend(scopes.into_iter().map(Cow::Owned)); + self + } + + /// Appends an extra param to the authorization URL. + /// + /// This method allows extensions to be used without direct support from + /// this crate. If `name` conflicts with a parameter managed by this crate, the + /// behavior is undefined. In particular, do not set parameters defined by + /// [RFC 6749](https://tools.ietf.org/html/rfc6749) or + /// [RFC 7636](https://tools.ietf.org/html/rfc7636). + /// + /// # Security Warning + /// + /// Callers should follow the security recommendations for any OAuth2 extensions used with + /// this function, which are beyond the scope of + /// [RFC 6749](https://tools.ietf.org/html/rfc6749). + pub fn add_extra_param(mut self, name: N, value: V) -> Self + where + N: Into>, + V: Into>, + { + self.extra_params.push((name.into(), value.into())); + self + } + + /// Enables the [Implicit Grant](https://tools.ietf.org/html/rfc6749#section-4.2) flow. + pub fn use_implicit_flow(mut self) -> Self { + self.response_type = "token".into(); + self + } + + /// Enables custom flows other than the `code` and `token` (implicit flow) grant. + pub fn set_response_type(mut self, response_type: &ResponseType) -> Self { + self.response_type = (**response_type).to_owned().into(); + self + } + + /// Enables the use of [Proof Key for Code Exchange](https://tools.ietf.org/html/rfc7636) + /// (PKCE). + /// + /// PKCE is *highly recommended* for all public clients (i.e., those for which there + /// is no client secret or for which the client secret is distributed with the client, + /// such as in a native, mobile app, or browser app). + pub fn set_pkce_challenge(mut self, pkce_code_challenge: PkceCodeChallenge) -> Self { + self.pkce_challenge = Some(pkce_code_challenge); + self + } + + /// Overrides the `redirect_url` to the one specified. + pub fn set_redirect_uri(mut self, redirect_url: Cow<'a, RedirectUrl>) -> Self { + self.redirect_url = Some(redirect_url); + self + } + + /// Returns the full authorization URL and CSRF state for this authorization + /// request. + pub fn url(self) -> (Url, CsrfToken) { + let scopes = self + .scopes + .iter() + .map(|s| s.to_string()) + .collect::>() + .join(" "); + + let url = { + let mut pairs: Vec<(&str, &str)> = vec![ + ("response_type", self.response_type.as_ref()), + ("client_id", self.client_id), + ("state", self.state.secret()), + ]; + + if let Some(ref pkce_challenge) = self.pkce_challenge { + pairs.push(("code_challenge", pkce_challenge.as_str())); + pairs.push(("code_challenge_method", pkce_challenge.method().as_str())); + } + + if let Some(ref redirect_url) = self.redirect_url { + pairs.push(("redirect_uri", redirect_url.as_str())); + } + + if !scopes.is_empty() { + pairs.push(("scope", &scopes)); + } + + let mut url: Url = self.auth_url.url().to_owned(); + + url.query_pairs_mut() + .extend_pairs(pairs.iter().map(|&(k, v)| (k, v))); + + url.query_pairs_mut() + .extend_pairs(self.extra_params.iter().cloned()); + url + }; + + (url, self.state) + } +} + +#[cfg(test)] +mod tests { + use crate::basic::BasicClient; + use crate::tests::new_client; + use crate::{ + AuthUrl, ClientId, ClientSecret, CsrfToken, PkceCodeChallenge, PkceCodeVerifier, + RedirectUrl, ResponseType, Scope, TokenUrl, + }; + + use url::form_urlencoded::byte_serialize; + use url::Url; + + use std::borrow::Cow; + + #[test] + fn test_authorize_url() { + let client = new_client(); + let (url, _) = client + .authorize_url(|| CsrfToken::new("csrf_token".to_string())) + .url(); + + assert_eq!( + Url::parse( + "https://example.com/auth?response_type=code&client_id=aaa&state=csrf_token" + ) + .unwrap(), + url + ); + } + + #[test] + fn test_authorize_random() { + let client = new_client(); + let (url, csrf_state) = client.authorize_url(CsrfToken::new_random).url(); + + assert_eq!( + Url::parse(&format!( + "https://example.com/auth?response_type=code&client_id=aaa&state={}", + byte_serialize(csrf_state.secret().clone().into_bytes().as_slice()) + .collect::>() + .join("") + )) + .unwrap(), + url + ); + } + + #[test] + fn test_authorize_url_pkce() { + // Example from https://tools.ietf.org/html/rfc7636#appendix-B + let client = new_client(); + + let (url, _) = client + .authorize_url(|| CsrfToken::new("csrf_token".to_string())) + .set_pkce_challenge(PkceCodeChallenge::from_code_verifier_sha256( + &PkceCodeVerifier::new("dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk".to_string()), + )) + .url(); + assert_eq!( + Url::parse(concat!( + "https://example.com/auth", + "?response_type=code&client_id=aaa", + "&state=csrf_token", + "&code_challenge=E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM", + "&code_challenge_method=S256", + )) + .unwrap(), + url + ); + } + + #[test] + fn test_authorize_url_implicit() { + let client = new_client(); + + let (url, _) = client + .authorize_url(|| CsrfToken::new("csrf_token".to_string())) + .use_implicit_flow() + .url(); + + assert_eq!( + Url::parse( + "https://example.com/auth?response_type=token&client_id=aaa&state=csrf_token" + ) + .unwrap(), + url + ); + } + + #[test] + fn test_authorize_url_with_param() { + let client = BasicClient::new(ClientId::new("aaa".to_string())) + .set_client_secret(ClientSecret::new("bbb".to_string())) + .set_auth_url(AuthUrl::new("https://example.com/auth?foo=bar".to_string()).unwrap()) + .set_token_url(TokenUrl::new("https://example.com/token".to_string()).unwrap()); + + let (url, _) = client + .authorize_url(|| CsrfToken::new("csrf_token".to_string())) + .url(); + + assert_eq!( + Url::parse( + "https://example.com/auth?foo=bar&response_type=code&client_id=aaa&state=csrf_token" + ) + .unwrap(), + url + ); + } + + #[test] + fn test_authorize_url_with_scopes() { + let scopes = vec![ + Scope::new("read".to_string()), + Scope::new("write".to_string()), + ]; + let (url, _) = new_client() + .authorize_url(|| CsrfToken::new("csrf_token".to_string())) + .add_scopes(scopes) + .url(); + + assert_eq!( + Url::parse( + "https://example.com/auth\ + ?response_type=code\ + &client_id=aaa\ + &state=csrf_token\ + &scope=read+write" + ) + .unwrap(), + url + ); + } + + #[test] + fn test_authorize_url_with_one_scope() { + let (url, _) = new_client() + .authorize_url(|| CsrfToken::new("csrf_token".to_string())) + .add_scope(Scope::new("read".to_string())) + .url(); + + assert_eq!( + Url::parse( + "https://example.com/auth\ + ?response_type=code\ + &client_id=aaa\ + &state=csrf_token\ + &scope=read" + ) + .unwrap(), + url + ); + } + + #[test] + fn test_authorize_url_with_extension_response_type() { + let client = new_client(); + + let (url, _) = client + .authorize_url(|| CsrfToken::new("csrf_token".to_string())) + .set_response_type(&ResponseType::new("code token".to_string())) + .add_extra_param("foo", "bar") + .url(); + + assert_eq!( + Url::parse( + "https://example.com/auth?response_type=code+token&client_id=aaa&state=csrf_token\ + &foo=bar" + ) + .unwrap(), + url + ); + } + + #[test] + fn test_authorize_url_with_redirect_url() { + let client = new_client() + .set_redirect_uri(RedirectUrl::new("https://localhost/redirect".to_string()).unwrap()); + + let (url, _) = client + .authorize_url(|| CsrfToken::new("csrf_token".to_string())) + .url(); + + assert_eq!( + Url::parse( + "https://example.com/auth?response_type=code\ + &client_id=aaa\ + &state=csrf_token\ + &redirect_uri=https%3A%2F%2Flocalhost%2Fredirect" + ) + .unwrap(), + url + ); + } + + #[test] + fn test_authorize_url_with_redirect_url_override() { + let client = new_client() + .set_redirect_uri(RedirectUrl::new("https://localhost/redirect".to_string()).unwrap()); + + let (url, _) = client + .authorize_url(|| CsrfToken::new("csrf_token".to_string())) + .set_redirect_uri(Cow::Owned( + RedirectUrl::new("https://localhost/alternative".to_string()).unwrap(), + )) + .url(); + + assert_eq!( + Url::parse( + "https://example.com/auth?response_type=code\ + &client_id=aaa\ + &state=csrf_token\ + &redirect_uri=https%3A%2F%2Flocalhost%2Falternative" + ) + .unwrap(), + url + ); + } +} diff --git a/src/devicecode.rs b/src/devicecode.rs index 385112e..7b373cc 100644 --- a/src/devicecode.rs +++ b/src/devicecode.rs @@ -1,19 +1,376 @@ use crate::basic::BasicErrorResponseType; +use crate::endpoint::{endpoint_request, endpoint_response}; use crate::types::VerificationUriComplete; use crate::{ - DeviceCode, EndUserVerificationUrl, ErrorResponse, ErrorResponseType, RequestTokenError, - StandardErrorResponse, TokenResponse, TokenType, UserCode, + AuthType, ClientId, ClientSecret, DeviceAuthorizationUrl, DeviceCode, EndUserVerificationUrl, + ErrorResponse, ErrorResponseType, HttpRequest, HttpResponse, RequestTokenError, Scope, + StandardErrorResponse, TokenResponse, TokenType, TokenUrl, UserCode, }; +use chrono::{DateTime, Utc}; use serde::de::DeserializeOwned; use serde::{Deserialize, Serialize}; +use std::borrow::Cow; use std::error::Error; use std::fmt::Error as FormatterError; use std::fmt::{Debug, Display, Formatter}; +use std::future::Future; use std::marker::PhantomData; +use std::sync::Arc; use std::time::Duration; +/// The request for a set of verification codes from the authorization server. +/// +/// See . +#[derive(Debug)] +pub struct DeviceAuthorizationRequest<'a, TE> +where + TE: ErrorResponse, +{ + pub(crate) auth_type: &'a AuthType, + pub(crate) client_id: &'a ClientId, + pub(crate) client_secret: Option<&'a ClientSecret>, + pub(crate) extra_params: Vec<(Cow<'a, str>, Cow<'a, str>)>, + pub(crate) scopes: Vec>, + pub(crate) device_authorization_url: &'a DeviceAuthorizationUrl, + pub(crate) _phantom: PhantomData, +} + +impl<'a, TE> DeviceAuthorizationRequest<'a, TE> +where + TE: ErrorResponse + 'static, +{ + /// Appends an extra param to the token request. + /// + /// This method allows extensions to be used without direct support from + /// this crate. If `name` conflicts with a parameter managed by this crate, the + /// behavior is undefined. In particular, do not set parameters defined by + /// [RFC 6749](https://tools.ietf.org/html/rfc6749) or + /// [RFC 7636](https://tools.ietf.org/html/rfc7636). + /// + /// # Security Warning + /// + /// Callers should follow the security recommendations for any OAuth2 extensions used with + /// this function, which are beyond the scope of + /// [RFC 6749](https://tools.ietf.org/html/rfc6749). + pub fn add_extra_param(mut self, name: N, value: V) -> Self + where + N: Into>, + V: Into>, + { + self.extra_params.push((name.into(), value.into())); + self + } + + /// Appends a new scope to the token request. + pub fn add_scope(mut self, scope: Scope) -> Self { + self.scopes.push(Cow::Owned(scope)); + self + } + + /// Appends a collection of scopes to the token request. + pub fn add_scopes(mut self, scopes: I) -> Self + where + I: IntoIterator, + { + self.scopes.extend(scopes.into_iter().map(Cow::Owned)); + self + } + + fn prepare_request(self) -> HttpRequest { + endpoint_request( + self.auth_type, + self.client_id, + self.client_secret, + &self.extra_params, + None, + Some(&self.scopes), + self.device_authorization_url.url(), + vec![], + ) + } + + /// Synchronously sends the request to the authorization server and awaits a response. + pub fn request( + self, + http_client: F, + ) -> Result, RequestTokenError> + where + F: FnOnce(HttpRequest) -> Result, + RE: Error + 'static, + EF: ExtraDeviceAuthorizationFields, + { + endpoint_response(http_client(self.prepare_request())?) + } + + /// Asynchronously sends the request to the authorization server and returns a Future. + pub async fn request_async( + self, + http_client: C, + ) -> Result, RequestTokenError> + where + C: FnOnce(HttpRequest) -> F, + F: Future>, + RE: Error + 'static, + EF: ExtraDeviceAuthorizationFields, + { + let http_response = http_client(self.prepare_request()).await?; + endpoint_response(http_response) + } +} + +/// The request for a device access token from the authorization server. +/// +/// See . +#[derive(Clone)] +pub struct DeviceAccessTokenRequest<'a, 'b, TR, TT, EF> +where + TR: TokenResponse, + TT: TokenType, + EF: ExtraDeviceAuthorizationFields, +{ + pub(crate) auth_type: &'a AuthType, + pub(crate) client_id: &'a ClientId, + pub(crate) client_secret: Option<&'a ClientSecret>, + pub(crate) extra_params: Vec<(Cow<'a, str>, Cow<'a, str>)>, + pub(crate) token_url: &'a TokenUrl, + pub(crate) dev_auth_resp: &'a DeviceAuthorizationResponse, + pub(crate) time_fn: Arc DateTime + 'b + Send + Sync>, + pub(crate) max_backoff_interval: Option, + pub(crate) _phantom: PhantomData<(TR, TT, EF)>, +} + +impl<'a, 'b, TR, TT, EF> DeviceAccessTokenRequest<'a, 'b, TR, TT, EF> +where + TR: TokenResponse, + TT: TokenType, + EF: ExtraDeviceAuthorizationFields, +{ + /// Appends an extra param to the token request. + /// + /// This method allows extensions to be used without direct support from + /// this crate. If `name` conflicts with a parameter managed by this crate, the + /// behavior is undefined. In particular, do not set parameters defined by + /// [RFC 6749](https://tools.ietf.org/html/rfc6749) or + /// [RFC 7636](https://tools.ietf.org/html/rfc7636). + /// + /// # Security Warning + /// + /// Callers should follow the security recommendations for any OAuth2 extensions used with + /// this function, which are beyond the scope of + /// [RFC 6749](https://tools.ietf.org/html/rfc6749). + pub fn add_extra_param(mut self, name: N, value: V) -> Self + where + N: Into>, + V: Into>, + { + self.extra_params.push((name.into(), value.into())); + self + } + + /// Specifies a function for returning the current time. + /// + /// This function is used while polling the authorization server. + pub fn set_time_fn(mut self, time_fn: T) -> Self + where + T: Fn() -> DateTime + 'b + Send + Sync, + { + self.time_fn = Arc::new(time_fn); + self + } + + /// Sets the upper limit of the sleep interval to use for polling the token endpoint when the + /// HTTP client returns an error (e.g., in case of connection timeout). + pub fn set_max_backoff_interval(mut self, interval: Duration) -> Self { + self.max_backoff_interval = Some(interval); + self + } + + /// Synchronously polls the authorization server for a response, waiting + /// using a user defined sleep function. + pub fn request( + self, + http_client: F, + sleep_fn: S, + timeout: Option, + ) -> Result> + where + F: Fn(HttpRequest) -> Result, + S: Fn(Duration), + RE: Error + 'static, + { + // Get the request timeout and starting interval + let timeout_dt = self.compute_timeout(timeout)?; + let mut interval = self.dev_auth_resp.interval(); + + // Loop while requesting a token. + loop { + let now = (*self.time_fn)(); + if now > timeout_dt { + break Err(RequestTokenError::ServerResponse( + DeviceCodeErrorResponse::new( + DeviceCodeErrorResponseType::ExpiredToken, + Some(String::from("This device code has expired.")), + None, + ), + )); + } + + match self.process_response(http_client(self.prepare_request()), interval) { + DeviceAccessTokenPollResult::ContinueWithNewPollInterval(new_interval) => { + interval = new_interval + } + DeviceAccessTokenPollResult::Done(res, _) => break res, + } + + // Sleep here using the provided sleep function. + sleep_fn(interval); + } + } + + /// Asynchronously sends the request to the authorization server and awaits a response. + pub async fn request_async( + self, + http_client: C, + sleep_fn: S, + timeout: Option, + ) -> Result> + where + C: Fn(HttpRequest) -> F, + F: Future>, + S: Fn(Duration) -> SF, + SF: Future, + RE: Error + 'static, + { + // Get the request timeout and starting interval + let timeout_dt = self.compute_timeout(timeout)?; + let mut interval = self.dev_auth_resp.interval(); + + // Loop while requesting a token. + loop { + let now = (*self.time_fn)(); + if now > timeout_dt { + break Err(RequestTokenError::ServerResponse( + DeviceCodeErrorResponse::new( + DeviceCodeErrorResponseType::ExpiredToken, + Some(String::from("This device code has expired.")), + None, + ), + )); + } + + match self.process_response(http_client(self.prepare_request()).await, interval) { + DeviceAccessTokenPollResult::ContinueWithNewPollInterval(new_interval) => { + interval = new_interval + } + DeviceAccessTokenPollResult::Done(res, _) => break res, + } + + // Sleep here using the provided sleep function. + sleep_fn(interval).await; + } + } + + fn prepare_request(&self) -> HttpRequest { + endpoint_request( + self.auth_type, + self.client_id, + self.client_secret, + &self.extra_params, + None, + None, + self.token_url.url(), + vec![ + ("grant_type", "urn:ietf:params:oauth:grant-type:device_code"), + ("device_code", self.dev_auth_resp.device_code().secret()), + ], + ) + } + + fn process_response( + &self, + res: Result, + current_interval: Duration, + ) -> DeviceAccessTokenPollResult + where + RE: Error + 'static, + { + let http_response = match res { + Ok(inner) => inner, + Err(_) => { + // RFC 8628 requires a backoff in cases of connection timeout, but we can't + // distinguish between connection timeouts and other HTTP client request errors + // here. Set a maximum backoff so that the client doesn't effectively backoff + // infinitely when there are network issues unrelated to server load. + const DEFAULT_MAX_BACKOFF_INTERVAL: Duration = Duration::from_secs(10); + let new_interval = std::cmp::min( + current_interval.checked_mul(2).unwrap_or(current_interval), + self.max_backoff_interval + .unwrap_or(DEFAULT_MAX_BACKOFF_INTERVAL), + ); + return DeviceAccessTokenPollResult::ContinueWithNewPollInterval(new_interval); + } + }; + + // Explicitly process the response with a DeviceCodeErrorResponse + let res = endpoint_response::(http_response); + match res { + // On a ServerResponse error, the error needs inspecting as a DeviceCodeErrorResponse + // to work out whether a retry needs to happen. + Err(RequestTokenError::ServerResponse(dcer)) => { + match dcer.error() { + // On AuthorizationPending, a retry needs to happen with the same poll interval. + DeviceCodeErrorResponseType::AuthorizationPending => { + DeviceAccessTokenPollResult::ContinueWithNewPollInterval(current_interval) + } + // On SlowDown, a retry needs to happen with a larger poll interval. + DeviceCodeErrorResponseType::SlowDown => { + DeviceAccessTokenPollResult::ContinueWithNewPollInterval( + current_interval + Duration::from_secs(5), + ) + } + + // On any other error, just return the error. + _ => DeviceAccessTokenPollResult::Done( + Err(RequestTokenError::ServerResponse(dcer)), + PhantomData, + ), + } + } + + // On any other success or failure, return the failure. + res => DeviceAccessTokenPollResult::Done(res, PhantomData), + } + } + + fn compute_timeout( + &self, + timeout: Option, + ) -> Result, RequestTokenError> + where + RE: Error + 'static, + { + // Calculate the request timeout - if the user specified a timeout, + // use that, otherwise use the value given by the device authorization + // response. + let timeout_dur = timeout.unwrap_or_else(|| self.dev_auth_resp.expires_in()); + let chrono_timeout = chrono::Duration::from_std(timeout_dur).map_err(|e| { + RequestTokenError::Other(format!( + "Failed to convert `{:?}` to `chrono::Duration`: {}", + timeout_dur, e + )) + })?; + + // Calculate the DateTime at which the request times out. + let timeout_dt = (*self.time_fn)() + .checked_add_signed(chrono_timeout) + .ok_or_else(|| RequestTokenError::Other("Failed to calculate timeout".to_string()))?; + + Ok(timeout_dt) + } +} + /// The minimum amount of time in seconds that the client SHOULD wait /// between polling requests to the token endpoint. If no value is /// provided, clients MUST use 5 as the default. @@ -217,3 +574,468 @@ where ContinueWithNewPollInterval(Duration), Done(Result>, PhantomData), } + +#[cfg(test)] +mod tests { + use crate::basic::BasicTokenType; + use crate::tests::{mock_http_client, mock_http_client_success_fail, new_client}; + use crate::{ + DeviceAuthorizationUrl, DeviceCodeErrorResponse, DeviceCodeErrorResponseType, HttpResponse, + RequestTokenError, Scope, StandardDeviceAuthorizationResponse, TokenResponse, + }; + + use chrono::{DateTime, Utc}; + use http::header::{ACCEPT, AUTHORIZATION, CONTENT_TYPE}; + use http::{HeaderValue, StatusCode}; + + use std::time::Duration; + + fn new_device_auth_details(expires_in: u32) -> StandardDeviceAuthorizationResponse { + let body = format!( + "{{\ + \"device_code\": \"12345\", \ + \"verification_uri\": \"https://verify/here\", \ + \"user_code\": \"abcde\", \ + \"verification_uri_complete\": \"https://verify/here?abcde\", \ + \"expires_in\": {}, \ + \"interval\": 1 \ + }}", + expires_in + ); + + let device_auth_url = + DeviceAuthorizationUrl::new("https://deviceauth/here".to_string()).unwrap(); + + let client = new_client().set_device_authorization_url(device_auth_url.clone()); + client + .exchange_device_code() + .add_extra_param("foo", "bar") + .add_scope(Scope::new("openid".to_string())) + .request(mock_http_client( + vec![ + (ACCEPT, "application/json"), + (CONTENT_TYPE, "application/x-www-form-urlencoded"), + (AUTHORIZATION, "Basic YWFhOmJiYg=="), + ], + "scope=openid&foo=bar", + Some(device_auth_url.url().to_owned()), + HttpResponse { + status_code: StatusCode::OK, + headers: vec![( + CONTENT_TYPE, + HeaderValue::from_str("application/json").unwrap(), + )] + .into_iter() + .collect(), + body: body.into_bytes(), + }, + )) + .unwrap() + } + + #[test] + fn test_device_token_pending_then_success() { + let details = new_device_auth_details(20); + assert_eq!("12345", details.device_code().secret()); + assert_eq!("https://verify/here", details.verification_uri().as_str()); + assert_eq!("abcde", details.user_code().secret().as_str()); + assert_eq!( + "https://verify/here?abcde", + details + .verification_uri_complete() + .unwrap() + .secret() + .as_str() + ); + assert_eq!(Duration::from_secs(20), details.expires_in()); + assert_eq!(Duration::from_secs(1), details.interval()); + + let token = new_client() + .exchange_device_access_token(&details) + .set_time_fn(mock_time_fn()) + .request(mock_http_client_success_fail( + None, + vec![ + (ACCEPT, "application/json"), + (CONTENT_TYPE, "application/x-www-form-urlencoded"), + (AUTHORIZATION, "Basic YWFhOmJiYg=="), + ], + "grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Adevice_code&device_code=12345", + HttpResponse { + status_code: StatusCode::from_u16(400).unwrap(), + headers: vec![( + CONTENT_TYPE, + HeaderValue::from_str("application/json").unwrap(), + )] + .into_iter() + .collect(), + body: "{\ + \"error\": \"authorization_pending\", \ + \"error_description\": \"Still waiting for user\"\ + }" + .to_string() + .into_bytes(), + }, + 5, + HttpResponse { + status_code: StatusCode::OK, + headers: vec![( + CONTENT_TYPE, + HeaderValue::from_str("application/json").unwrap(), + )] + .into_iter() + .collect(), + body: "{\ + \"access_token\": \"12/34\", \ + \"token_type\": \"bearer\", \ + \"scope\": \"openid\"\ + }" + .to_string() + .into_bytes(), + }, + ), + mock_sleep_fn, + None) + .unwrap(); + + assert_eq!("12/34", token.access_token().secret()); + assert_eq!(BasicTokenType::Bearer, *token.token_type()); + assert_eq!( + Some(&vec![Scope::new("openid".to_string()),]), + token.scopes() + ); + assert_eq!(None, token.expires_in()); + assert!(token.refresh_token().is_none()); + } + + #[test] + fn test_device_token_slowdown_then_success() { + let details = new_device_auth_details(3600); + assert_eq!("12345", details.device_code().secret()); + assert_eq!("https://verify/here", details.verification_uri().as_str()); + assert_eq!("abcde", details.user_code().secret().as_str()); + assert_eq!( + "https://verify/here?abcde", + details + .verification_uri_complete() + .unwrap() + .secret() + .as_str() + ); + assert_eq!(Duration::from_secs(3600), details.expires_in()); + assert_eq!(Duration::from_secs(1), details.interval()); + + let token = new_client() + .exchange_device_access_token(&details) + .set_time_fn(mock_time_fn()) + .request(mock_http_client_success_fail( + None, + vec![ + (ACCEPT, "application/json"), + (CONTENT_TYPE, "application/x-www-form-urlencoded"), + (AUTHORIZATION, "Basic YWFhOmJiYg=="), + ], + "grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Adevice_code&device_code=12345", + HttpResponse { + status_code: StatusCode::from_u16(400).unwrap(), + headers: vec![( + CONTENT_TYPE, + HeaderValue::from_str("application/json").unwrap(), + )] + .into_iter() + .collect(), + body: "{\ + \"error\": \"slow_down\", \ + \"error_description\": \"Woah there partner\"\ + }" + .to_string() + .into_bytes(), + }, + 5, + HttpResponse { + status_code: StatusCode::OK, + headers: vec![( + CONTENT_TYPE, + HeaderValue::from_str("application/json").unwrap(), + )] + .into_iter() + .collect(), + body: "{\ + \"access_token\": \"12/34\", \ + \"token_type\": \"bearer\", \ + \"scope\": \"openid\"\ + }" + .to_string() + .into_bytes(), + }, + ), + mock_sleep_fn, + None) + .unwrap(); + + assert_eq!("12/34", token.access_token().secret()); + assert_eq!(BasicTokenType::Bearer, *token.token_type()); + assert_eq!( + Some(&vec![Scope::new("openid".to_string()),]), + token.scopes() + ); + assert_eq!(None, token.expires_in()); + assert!(token.refresh_token().is_none()); + } + + struct IncreasingTime { + times: std::ops::RangeFrom, + } + + impl IncreasingTime { + fn new() -> Self { + Self { times: (0..) } + } + fn next(&mut self) -> DateTime { + let next_value = self.times.next().unwrap(); + let naive = chrono::NaiveDateTime::from_timestamp(next_value, 0); + DateTime::::from_utc(naive, chrono::Utc) + } + } + + /// Creates a time function that increments by one second each time. + fn mock_time_fn() -> impl Fn() -> DateTime + Send + Sync { + let timer = std::sync::Mutex::new(IncreasingTime::new()); + move || timer.lock().unwrap().next() + } + + /// Mock sleep function that doesn't actually sleep. + fn mock_sleep_fn(_: Duration) {} + + #[test] + fn test_exchange_device_code_and_token() { + let details = new_device_auth_details(3600); + assert_eq!("12345", details.device_code().secret()); + assert_eq!("https://verify/here", details.verification_uri().as_str()); + assert_eq!("abcde", details.user_code().secret().as_str()); + assert_eq!( + "https://verify/here?abcde", + details + .verification_uri_complete() + .unwrap() + .secret() + .as_str() + ); + assert_eq!(Duration::from_secs(3600), details.expires_in()); + assert_eq!(Duration::from_secs(1), details.interval()); + + let token = new_client() + .exchange_device_access_token(&details) + .set_time_fn(mock_time_fn()) + .request(mock_http_client( + vec![ + (ACCEPT, "application/json"), + (CONTENT_TYPE, "application/x-www-form-urlencoded"), + (AUTHORIZATION, "Basic YWFhOmJiYg=="), + ], + "grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Adevice_code&device_code=12345", + None, + HttpResponse { + status_code: StatusCode::OK, + headers: vec![( + CONTENT_TYPE, + HeaderValue::from_str("application/json").unwrap(), + )] + .into_iter() + .collect(), + body: "{\ + \"access_token\": \"12/34\", \ + \"token_type\": \"bearer\", \ + \"scope\": \"openid\"\ + }" + .to_string() + .into_bytes(), + }, + ), + mock_sleep_fn, + None) + .unwrap(); + + assert_eq!("12/34", token.access_token().secret()); + assert_eq!(BasicTokenType::Bearer, *token.token_type()); + assert_eq!( + Some(&vec![Scope::new("openid".to_string()),]), + token.scopes() + ); + assert_eq!(None, token.expires_in()); + assert!(token.refresh_token().is_none()); + } + + #[test] + fn test_device_token_authorization_timeout() { + let details = new_device_auth_details(2); + assert_eq!("12345", details.device_code().secret()); + assert_eq!("https://verify/here", details.verification_uri().as_str()); + assert_eq!("abcde", details.user_code().secret().as_str()); + assert_eq!( + "https://verify/here?abcde", + details + .verification_uri_complete() + .unwrap() + .secret() + .as_str() + ); + assert_eq!(Duration::from_secs(2), details.expires_in()); + assert_eq!(Duration::from_secs(1), details.interval()); + + let token = new_client() + .exchange_device_access_token(&details) + .set_time_fn(mock_time_fn()) + .request(mock_http_client( + vec![ + (ACCEPT, "application/json"), + (CONTENT_TYPE, "application/x-www-form-urlencoded"), + (AUTHORIZATION, "Basic YWFhOmJiYg=="), + ], + "grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Adevice_code&device_code=12345", + None, + HttpResponse { + status_code: StatusCode::from_u16(400).unwrap(), + headers: vec![( + CONTENT_TYPE, + HeaderValue::from_str("application/json").unwrap(), + )] + .into_iter() + .collect(), + body: "{\ + \"error\": \"authorization_pending\", \ + \"error_description\": \"Still waiting for user\"\ + }" + .to_string() + .into_bytes(), + }, + ), + mock_sleep_fn, + None) + .err() + .unwrap(); + match token { + RequestTokenError::ServerResponse(msg) => assert_eq!( + msg, + DeviceCodeErrorResponse::new( + DeviceCodeErrorResponseType::ExpiredToken, + Some(String::from("This device code has expired.")), + None, + ) + ), + _ => unreachable!("Error should be an expiry"), + } + } + + #[test] + fn test_device_token_access_denied() { + let details = new_device_auth_details(2); + assert_eq!("12345", details.device_code().secret()); + assert_eq!("https://verify/here", details.verification_uri().as_str()); + assert_eq!("abcde", details.user_code().secret().as_str()); + assert_eq!( + "https://verify/here?abcde", + details + .verification_uri_complete() + .unwrap() + .secret() + .as_str() + ); + assert_eq!(Duration::from_secs(2), details.expires_in()); + assert_eq!(Duration::from_secs(1), details.interval()); + + let token = new_client() + .exchange_device_access_token(&details) + .set_time_fn(mock_time_fn()) + .request(mock_http_client( + vec![ + (ACCEPT, "application/json"), + (CONTENT_TYPE, "application/x-www-form-urlencoded"), + (AUTHORIZATION, "Basic YWFhOmJiYg=="), + ], + "grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Adevice_code&device_code=12345", + None, + HttpResponse { + status_code: StatusCode::from_u16(400).unwrap(), + headers: vec![( + CONTENT_TYPE, + HeaderValue::from_str("application/json").unwrap(), + )] + .into_iter() + .collect(), + body: "{\ + \"error\": \"access_denied\", \ + \"error_description\": \"Access Denied\"\ + }" + .to_string() + .into_bytes(), + }, + ), + mock_sleep_fn, + None) + .err() + .unwrap(); + match token { + RequestTokenError::ServerResponse(msg) => { + assert_eq!(msg.error(), &DeviceCodeErrorResponseType::AccessDenied) + } + _ => unreachable!("Error should be Access Denied"), + } + } + + #[test] + fn test_device_token_expired() { + let details = new_device_auth_details(2); + assert_eq!("12345", details.device_code().secret()); + assert_eq!("https://verify/here", details.verification_uri().as_str()); + assert_eq!("abcde", details.user_code().secret().as_str()); + assert_eq!( + "https://verify/here?abcde", + details + .verification_uri_complete() + .unwrap() + .secret() + .as_str() + ); + assert_eq!(Duration::from_secs(2), details.expires_in()); + assert_eq!(Duration::from_secs(1), details.interval()); + + let token = new_client() + .exchange_device_access_token(&details) + .set_time_fn(mock_time_fn()) + .request(mock_http_client( + vec![ + (ACCEPT, "application/json"), + (CONTENT_TYPE, "application/x-www-form-urlencoded"), + (AUTHORIZATION, "Basic YWFhOmJiYg=="), + ], + "grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Adevice_code&device_code=12345", + None, + HttpResponse { + status_code: StatusCode::from_u16(400).unwrap(), + headers: vec![( + CONTENT_TYPE, + HeaderValue::from_str("application/json").unwrap(), + )] + .into_iter() + .collect(), + body: "{\ + \"error\": \"expired_token\", \ + \"error_description\": \"Token has expired\"\ + }" + .to_string() + .into_bytes(), + }, + ), + mock_sleep_fn, + None) + .err() + .unwrap(); + match token { + RequestTokenError::ServerResponse(msg) => { + assert_eq!(msg.error(), &DeviceCodeErrorResponseType::ExpiredToken) + } + _ => unreachable!("Error should be ExpiredToken"), + } + } +} diff --git a/src/endpoint.rs b/src/endpoint.rs new file mode 100644 index 0000000..3956973 --- /dev/null +++ b/src/endpoint.rs @@ -0,0 +1,220 @@ +use crate::{ + AuthType, ClientId, ClientSecret, ErrorResponse, RedirectUrl, RequestTokenError, Scope, + CONTENT_TYPE_FORMENCODED, CONTENT_TYPE_JSON, +}; + +use http::header::{ACCEPT, AUTHORIZATION, CONTENT_TYPE}; +use http::{HeaderMap, HeaderValue, StatusCode}; +use serde::de::DeserializeOwned; +use url::{form_urlencoded, Url}; + +use std::borrow::Cow; +use std::error::Error; + +/// An HTTP request. +#[derive(Clone, Debug)] +pub struct HttpRequest { + // These are all owned values so that the request can safely be passed between + // threads. + /// URL to which the HTTP request is being made. + pub url: Url, + /// HTTP request method for this request. + pub method: http::method::Method, + /// HTTP request headers to send. + pub headers: HeaderMap, + /// HTTP request body (typically for POST requests only). + pub body: Vec, +} + +/// An HTTP response. +#[derive(Clone, Debug)] +pub struct HttpResponse { + /// HTTP status code returned by the server. + pub status_code: StatusCode, + /// HTTP response headers returned by the server. + pub headers: HeaderMap, + /// HTTP response body returned by the server. + pub body: Vec, +} + +#[allow(clippy::too_many_arguments)] +pub(crate) fn endpoint_request<'a>( + auth_type: &'a AuthType, + client_id: &'a ClientId, + client_secret: Option<&'a ClientSecret>, + extra_params: &'a [(Cow<'a, str>, Cow<'a, str>)], + redirect_url: Option>, + scopes: Option<&'a Vec>>, + url: &'a Url, + params: Vec<(&'a str, &'a str)>, +) -> HttpRequest { + let mut headers = HeaderMap::new(); + headers.append(ACCEPT, HeaderValue::from_static(CONTENT_TYPE_JSON)); + headers.append( + CONTENT_TYPE, + HeaderValue::from_static(CONTENT_TYPE_FORMENCODED), + ); + + let scopes_opt = scopes.and_then(|scopes| { + if !scopes.is_empty() { + Some( + scopes + .iter() + .map(|s| s.to_string()) + .collect::>() + .join(" "), + ) + } else { + None + } + }); + + let mut params: Vec<(&str, &str)> = params; + if let Some(ref scopes) = scopes_opt { + params.push(("scope", scopes)); + } + + // FIXME: add support for auth extensions? e.g., client_secret_jwt and private_key_jwt + match (auth_type, client_secret) { + // Basic auth only makes sense when a client secret is provided. Otherwise, always pass the + // client ID in the request body. + (AuthType::BasicAuth, Some(secret)) => { + // Section 2.3.1 of RFC 6749 requires separately url-encoding the id and secret + // before using them as HTTP Basic auth username and password. Note that this is + // not standard for ordinary Basic auth, so curl won't do it for us. + let urlencoded_id: String = + form_urlencoded::byte_serialize(client_id.as_bytes()).collect(); + let urlencoded_secret: String = + form_urlencoded::byte_serialize(secret.secret().as_bytes()).collect(); + let b64_credential = + base64::encode(format!("{}:{}", &urlencoded_id, urlencoded_secret)); + headers.append( + AUTHORIZATION, + HeaderValue::from_str(&format!("Basic {}", &b64_credential)).unwrap(), + ); + } + (AuthType::RequestBody, _) | (AuthType::BasicAuth, None) => { + params.push(("client_id", client_id)); + if let Some(client_secret) = client_secret { + params.push(("client_secret", client_secret.secret())); + } + } + } + + if let Some(ref redirect_url) = redirect_url { + params.push(("redirect_uri", redirect_url.as_str())); + } + + params.extend_from_slice( + extra_params + .iter() + .map(|(k, v)| (k.as_ref(), v.as_ref())) + .collect::>() + .as_slice(), + ); + + let body = form_urlencoded::Serializer::new(String::new()) + .extend_pairs(params) + .finish() + .into_bytes(); + + HttpRequest { + url: url.to_owned(), + method: http::method::Method::POST, + headers, + body, + } +} + +pub(crate) fn endpoint_response( + http_response: HttpResponse, +) -> Result> +where + RE: Error + 'static, + TE: ErrorResponse, + DO: DeserializeOwned, +{ + check_response_status(&http_response)?; + + check_response_body(&http_response)?; + + let response_body = http_response.body.as_slice(); + serde_path_to_error::deserialize(&mut serde_json::Deserializer::from_slice(response_body)) + .map_err(|e| RequestTokenError::Parse(e, response_body.to_vec())) +} + +pub(crate) fn endpoint_response_status_only( + http_response: HttpResponse, +) -> Result<(), RequestTokenError> +where + RE: Error + 'static, + TE: ErrorResponse, +{ + check_response_status(&http_response) +} + +fn check_response_status( + http_response: &HttpResponse, +) -> Result<(), RequestTokenError> +where + RE: Error + 'static, + TE: ErrorResponse, +{ + if http_response.status_code != StatusCode::OK { + let reason = http_response.body.as_slice(); + if reason.is_empty() { + Err(RequestTokenError::Other( + "Server returned empty error response".to_string(), + )) + } else { + let error = match serde_path_to_error::deserialize::<_, TE>( + &mut serde_json::Deserializer::from_slice(reason), + ) { + Ok(error) => RequestTokenError::ServerResponse(error), + Err(error) => RequestTokenError::Parse(error, reason.to_vec()), + }; + Err(error) + } + } else { + Ok(()) + } +} + +fn check_response_body( + http_response: &HttpResponse, +) -> Result<(), RequestTokenError> +where + RE: Error + 'static, + TE: ErrorResponse, +{ + // Validate that the response Content-Type is JSON. + http_response + .headers + .get(CONTENT_TYPE) + .map_or(Ok(()), |content_type| + // Section 3.1.1.1 of RFC 7231 indicates that media types are case-insensitive and + // may be followed by optional whitespace and/or a parameter (e.g., charset). + // See https://tools.ietf.org/html/rfc7231#section-3.1.1.1. + if content_type.to_str().ok().filter(|ct| ct.to_lowercase().starts_with(CONTENT_TYPE_JSON)).is_none() { + Err( + RequestTokenError::Other( + format!( + "Unexpected response Content-Type: {:?}, should be `{}`", + content_type, + CONTENT_TYPE_JSON + ) + ) + ) + } else { + Ok(()) + } + )?; + + if http_response.body.is_empty() { + return Err(RequestTokenError::Other( + "Server returned empty response body".to_string(), + )); + } + + Ok(()) +} diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..bebc96f --- /dev/null +++ b/src/error.rs @@ -0,0 +1,167 @@ +use serde::de::DeserializeOwned; +use serde::{Deserialize, Serialize}; + +use std::error::Error; +use std::fmt::{Debug, Display, Formatter}; + +/// Server Error Response +/// +/// See [Section 5.2](https://datatracker.ietf.org/doc/html/rfc6749#section-5.2) of RFC 6749. +/// This trait exists separately from the `StandardErrorResponse` struct +/// to support customization by clients, such as supporting interoperability with +/// non-standards-complaint OAuth2 providers. +/// +/// The [`Display`] trait implementation for types implementing [`ErrorResponse`] should be a +/// human-readable string suitable for printing (e.g., within a [`RequestTokenError`]). +pub trait ErrorResponse: Debug + Display + DeserializeOwned + Serialize {} + +/// Error types enum. +/// +/// NOTE: The serialization must return the `snake_case` representation of +/// this error type. This value must match the error type from the relevant OAuth 2.0 standards +/// (RFC 6749 or an extension). +pub trait ErrorResponseType: Debug + DeserializeOwned + Serialize {} + +/// Error response returned by server after requesting an access token. +/// +/// The fields in this structure are defined in +/// [Section 5.2 of RFC 6749](https://tools.ietf.org/html/rfc6749#section-5.2). This +/// trait is parameterized by a `ErrorResponseType` to support error types specific to future OAuth2 +/// authentication schemes and extensions. +#[derive(Clone, Debug, Deserialize, PartialEq, Eq, Serialize)] +pub struct StandardErrorResponse { + #[serde(bound = "T: ErrorResponseType")] + pub(crate) error: T, + #[serde(default)] + #[serde(skip_serializing_if = "Option::is_none")] + pub(crate) error_description: Option, + #[serde(default)] + #[serde(skip_serializing_if = "Option::is_none")] + pub(crate) error_uri: Option, +} + +impl StandardErrorResponse { + /// Instantiate a new `ErrorResponse`. + /// + /// # Arguments + /// + /// * `error` - REQUIRED. A single ASCII error code deserialized to the generic parameter. + /// `ErrorResponseType`. + /// * `error_description` - OPTIONAL. Human-readable ASCII text providing additional + /// information, used to assist the client developer in understanding the error that + /// occurred. Values for this parameter MUST NOT include characters outside the set + /// `%x20-21 / %x23-5B / %x5D-7E`. + /// * `error_uri` - OPTIONAL. A URI identifying a human-readable web page with information + /// about the error used to provide the client developer with additional information about + /// the error. Values for the "error_uri" parameter MUST conform to the URI-reference + /// syntax and thus MUST NOT include characters outside the set `%x21 / %x23-5B / %x5D-7E`. + pub fn new(error: T, error_description: Option, error_uri: Option) -> Self { + Self { + error, + error_description, + error_uri, + } + } + + /// REQUIRED. A single ASCII error code deserialized to the generic parameter + /// `ErrorResponseType`. + pub fn error(&self) -> &T { + &self.error + } + /// OPTIONAL. Human-readable ASCII text providing additional information, used to assist + /// the client developer in understanding the error that occurred. Values for this + /// parameter MUST NOT include characters outside the set `%x20-21 / %x23-5B / %x5D-7E`. + pub fn error_description(&self) -> Option<&String> { + self.error_description.as_ref() + } + /// OPTIONAL. URI identifying a human-readable web page with information about the error, + /// used to provide the client developer with additional information about the error. + /// Values for the "error_uri" parameter MUST conform to the URI-reference syntax and + /// thus MUST NOT include characters outside the set `%x21 / %x23-5B / %x5D-7E`. + pub fn error_uri(&self) -> Option<&String> { + self.error_uri.as_ref() + } +} + +impl ErrorResponse for StandardErrorResponse where T: ErrorResponseType + Display + 'static {} + +impl Display for StandardErrorResponse +where + TE: ErrorResponseType + Display, +{ + fn fmt(&self, f: &mut Formatter) -> Result<(), std::fmt::Error> { + let mut formatted = self.error().to_string(); + + if let Some(error_description) = self.error_description() { + formatted.push_str(": "); + formatted.push_str(error_description); + } + + if let Some(error_uri) = self.error_uri() { + formatted.push_str(" (see "); + formatted.push_str(error_uri); + formatted.push(')'); + } + + write!(f, "{}", formatted) + } +} + +/// Error encountered while requesting access token. +#[derive(Debug, thiserror::Error)] +pub enum RequestTokenError +where + RE: Error + 'static, + T: ErrorResponse + 'static, +{ + /// Error response returned by authorization server. Contains the parsed `ErrorResponse` + /// returned by the server. + #[error("Server returned error response: {0}")] + ServerResponse(T), + /// An error occurred while sending the request or receiving the response (e.g., network + /// connectivity failed). + #[error("Request failed")] + Request(#[from] RE), + /// Failed to parse server response. Parse errors may occur while parsing either successful + /// or error responses. + #[error("Failed to parse server response")] + Parse( + #[source] serde_path_to_error::Error, + Vec, + ), + /// Some other type of error occurred (e.g., an unexpected server response). + #[error("Other error: {}", _0)] + Other(String), +} + +#[cfg(test)] +mod tests { + use crate::basic::{BasicErrorResponse, BasicErrorResponseType}; + + #[test] + fn test_error_response_serializer() { + assert_eq!( + "{\"error\":\"unauthorized_client\"}", + serde_json::to_string(&BasicErrorResponse::new( + BasicErrorResponseType::UnauthorizedClient, + None, + None, + )) + .unwrap(), + ); + + assert_eq!( + "{\ + \"error\":\"invalid_client\",\ + \"error_description\":\"Invalid client_id\",\ + \"error_uri\":\"https://example.com/errors/invalid_client\"\ + }", + serde_json::to_string(&BasicErrorResponse::new( + BasicErrorResponseType::InvalidClient, + Some("Invalid client_id".to_string()), + Some("https://example.com/errors/invalid_client".to_string()), + )) + .unwrap(), + ); + } +} diff --git a/src/helpers.rs b/src/helpers.rs index 41875ed..038e6c1 100644 --- a/src/helpers.rs +++ b/src/helpers.rs @@ -356,3 +356,38 @@ pub fn variant_name(t: &T) -> &'static str { t.serialize(VariantName).unwrap() } + +#[cfg(test)] +mod tests { + use serde::Deserialize; + + #[derive(Deserialize, Debug, Clone)] + pub struct ObjectWithOptionalStringOrVecString { + #[serde(deserialize_with = "crate::helpers::deserialize_optional_string_or_vec_string")] + pub strings: Option>, + } + + #[test] + fn test_deserialize_optional_string_or_vec_string_none() { + let list_of_strings: ObjectWithOptionalStringOrVecString = + serde_json::from_str(r#"{ "strings": null }"#).unwrap(); + assert_eq!(None, list_of_strings.strings); + } + + #[test] + fn test_deserialize_optional_string_or_vec_string_single_value() { + let list_of_strings: ObjectWithOptionalStringOrVecString = + serde_json::from_str(r#"{ "strings": "v1" }"#).unwrap(); + assert_eq!(Some(vec!["v1".to_string()]), list_of_strings.strings); + } + + #[test] + fn test_deserialize_optional_string_or_vec_string_vec() { + let list_of_strings: ObjectWithOptionalStringOrVecString = + serde_json::from_str(r#"{ "strings": ["v1", "v2"] }"#).unwrap(); + assert_eq!( + Some(vec!["v1".to_string(), "v2".to_string()]), + list_of_strings.strings + ); + } +} diff --git a/src/introspection.rs b/src/introspection.rs new file mode 100644 index 0000000..3441394 --- /dev/null +++ b/src/introspection.rs @@ -0,0 +1,544 @@ +use crate::endpoint::{endpoint_request, endpoint_response}; +use crate::{ + AccessToken, AuthType, ClientId, ClientSecret, ErrorResponse, ExtraTokenFields, HttpRequest, + HttpResponse, IntrospectionUrl, RequestTokenError, Scope, TokenType, +}; + +use chrono::serde::ts_seconds_option; +use chrono::{DateTime, Utc}; +use serde::de::DeserializeOwned; +use serde::{Deserialize, Serialize}; + +use std::borrow::Cow; +use std::error::Error; +use std::fmt::Debug; +use std::future::Future; +use std::marker::PhantomData; + +/// A request to introspect an access token. +/// +/// See . +#[derive(Debug)] +pub struct IntrospectionRequest<'a, TE, TIR, TT> +where + TE: ErrorResponse, + TIR: TokenIntrospectionResponse, + TT: TokenType, +{ + pub(crate) token: &'a AccessToken, + pub(crate) token_type_hint: Option>, + pub(crate) auth_type: &'a AuthType, + pub(crate) client_id: &'a ClientId, + pub(crate) client_secret: Option<&'a ClientSecret>, + pub(crate) extra_params: Vec<(Cow<'a, str>, Cow<'a, str>)>, + pub(crate) introspection_url: &'a IntrospectionUrl, + pub(crate) _phantom: PhantomData<(TE, TIR, TT)>, +} + +impl<'a, TE, TIR, TT> IntrospectionRequest<'a, TE, TIR, TT> +where + TE: ErrorResponse + 'static, + TIR: TokenIntrospectionResponse, + TT: TokenType, +{ + /// Sets the optional token_type_hint parameter. + /// + /// See . + /// + /// OPTIONAL. A hint about the type of the token submitted for + /// introspection. The protected resource MAY pass this parameter to + /// help the authorization server optimize the token lookup. If the + /// server is unable to locate the token using the given hint, it MUST + /// extend its search across all of its supported token types. An + /// authorization server MAY ignore this parameter, particularly if it + /// is able to detect the token type automatically. Values for this + /// field are defined in the "OAuth Token Type Hints" registry defined + /// in OAuth Token Revocation [RFC7009](https://tools.ietf.org/html/rfc7009). + pub fn set_token_type_hint(mut self, value: V) -> Self + where + V: Into>, + { + self.token_type_hint = Some(value.into()); + + self + } + + /// Appends an extra param to the token introspection request. + /// + /// This method allows extensions to be used without direct support from + /// this crate. If `name` conflicts with a parameter managed by this crate, the + /// behavior is undefined. In particular, do not set parameters defined by + /// [RFC 6749](https://tools.ietf.org/html/rfc6749) or + /// [RFC 7662](https://tools.ietf.org/html/rfc7662). + /// + /// # Security Warning + /// + /// Callers should follow the security recommendations for any OAuth2 extensions used with + /// this function, which are beyond the scope of + /// [RFC 6749](https://tools.ietf.org/html/rfc6749). + pub fn add_extra_param(mut self, name: N, value: V) -> Self + where + N: Into>, + V: Into>, + { + self.extra_params.push((name.into(), value.into())); + self + } + + fn prepare_request(self) -> HttpRequest { + let mut params: Vec<(&str, &str)> = vec![("token", self.token.secret())]; + if let Some(ref token_type_hint) = self.token_type_hint { + params.push(("token_type_hint", token_type_hint)); + } + + endpoint_request( + self.auth_type, + self.client_id, + self.client_secret, + &self.extra_params, + None, + None, + self.introspection_url.url(), + params, + ) + } + + /// Synchronously sends the request to the authorization server and awaits a response. + pub fn request(self, http_client: F) -> Result> + where + F: FnOnce(HttpRequest) -> Result, + RE: Error + 'static, + { + endpoint_response(http_client(self.prepare_request())?) + } + + /// Asynchronously sends the request to the authorization server and returns a Future. + pub async fn request_async( + self, + http_client: C, + ) -> Result> + where + C: FnOnce(HttpRequest) -> F, + F: Future>, + RE: Error + 'static, + { + let http_response = http_client(self.prepare_request()).await?; + endpoint_response(http_response) + } +} + +/// Common methods shared by all OAuth2 token introspection implementations. +/// +/// The methods in this trait are defined in +/// [Section 2.2 of RFC 7662](https://tools.ietf.org/html/rfc7662#section-2.2). This trait exists +/// separately from the `StandardTokenIntrospectionResponse` struct to support customization by +/// clients, such as supporting interoperability with non-standards-complaint OAuth2 providers. +pub trait TokenIntrospectionResponse: Debug + DeserializeOwned + Serialize +where + TT: TokenType, +{ + /// REQUIRED. Boolean indicator of whether or not the presented token + /// is currently active. The specifics of a token's "active" state + /// will vary depending on the implementation of the authorization + /// server and the information it keeps about its tokens, but a "true" + /// value return for the "active" property will generally indicate + /// that a given token has been issued by this authorization server, + /// has not been revoked by the resource owner, and is within its + /// given time window of validity (e.g., after its issuance time and + /// before its expiration time). + fn active(&self) -> bool; + /// OPTIONAL. A JSON string containing a space-separated list of + /// scopes associated with this token, in the format described in + /// [Section 3.3 of RFC 7662](https://tools.ietf.org/html/rfc7662#section-3.3). + /// If included in the response, + /// this space-delimited field is parsed into a `Vec` of individual scopes. If omitted from + /// the response, this field is `None`. + fn scopes(&self) -> Option<&Vec>; + /// OPTIONAL. Client identifier for the OAuth 2.0 client that + /// requested this token. + fn client_id(&self) -> Option<&ClientId>; + /// OPTIONAL. Human-readable identifier for the resource owner who + /// authorized this token. + fn username(&self) -> Option<&str>; + /// OPTIONAL. Type of the token as defined in + /// [Section 5.1 of RFC 7662](https://tools.ietf.org/html/rfc7662#section-5.1). + /// Value is case insensitive and deserialized to the generic `TokenType` parameter. + fn token_type(&self) -> Option<&TT>; + /// OPTIONAL. Integer timestamp, measured in the number of seconds + /// since January 1 1970 UTC, indicating when this token will expire, + /// as defined in JWT [RFC7519](https://tools.ietf.org/html/rfc7519). + fn exp(&self) -> Option>; + /// OPTIONAL. Integer timestamp, measured in the number of seconds + /// since January 1 1970 UTC, indicating when this token was + /// originally issued, as defined in JWT [RFC7519](https://tools.ietf.org/html/rfc7519). + fn iat(&self) -> Option>; + /// OPTIONAL. Integer timestamp, measured in the number of seconds + /// since January 1 1970 UTC, indicating when this token is not to be + /// used before, as defined in JWT [RFC7519](https://tools.ietf.org/html/rfc7519). + fn nbf(&self) -> Option>; + /// OPTIONAL. Subject of the token, as defined in JWT [RFC7519](https://tools.ietf.org/html/rfc7519). + /// Usually a machine-readable identifier of the resource owner who + /// authorized this token. + fn sub(&self) -> Option<&str>; + /// OPTIONAL. Service-specific string identifier or list of string + /// identifiers representing the intended audience for this token, as + /// defined in JWT [RFC7519](https://tools.ietf.org/html/rfc7519). + fn aud(&self) -> Option<&Vec>; + /// OPTIONAL. String representing the issuer of this token, as + /// defined in JWT [RFC7519](https://tools.ietf.org/html/rfc7519). + fn iss(&self) -> Option<&str>; + /// OPTIONAL. String identifier for the token, as defined in JWT + /// [RFC7519](https://tools.ietf.org/html/rfc7519). + fn jti(&self) -> Option<&str>; +} + +/// Standard OAuth2 token introspection response. +/// +/// This struct includes the fields defined in +/// [Section 2.2 of RFC 7662](https://tools.ietf.org/html/rfc7662#section-2.2), as well as +/// extensions defined by the `EF` type parameter. +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct StandardTokenIntrospectionResponse +where + EF: ExtraTokenFields, + TT: TokenType + 'static, +{ + active: bool, + #[serde(rename = "scope")] + #[serde(deserialize_with = "crate::helpers::deserialize_space_delimited_vec")] + #[serde(serialize_with = "crate::helpers::serialize_space_delimited_vec")] + #[serde(skip_serializing_if = "Option::is_none")] + #[serde(default)] + scopes: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + client_id: Option, + #[serde(skip_serializing_if = "Option::is_none")] + username: Option, + #[serde( + bound = "TT: TokenType", + skip_serializing_if = "Option::is_none", + deserialize_with = "crate::helpers::deserialize_untagged_enum_case_insensitive", + default = "none_field" + )] + token_type: Option, + #[serde(skip_serializing_if = "Option::is_none")] + #[serde(with = "ts_seconds_option")] + #[serde(default)] + exp: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + #[serde(with = "ts_seconds_option")] + #[serde(default)] + iat: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + #[serde(with = "ts_seconds_option")] + #[serde(default)] + nbf: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + sub: Option, + #[serde(skip_serializing_if = "Option::is_none")] + #[serde(default)] + #[serde(deserialize_with = "crate::helpers::deserialize_optional_string_or_vec_string")] + aud: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + iss: Option, + #[serde(skip_serializing_if = "Option::is_none")] + jti: Option, + + #[serde(bound = "EF: ExtraTokenFields")] + #[serde(flatten)] + extra_fields: EF, +} + +fn none_field() -> Option { + None +} + +impl StandardTokenIntrospectionResponse +where + EF: ExtraTokenFields, + TT: TokenType, +{ + /// Instantiate a new OAuth2 token introspection response. + pub fn new(active: bool, extra_fields: EF) -> Self { + Self { + active, + + scopes: None, + client_id: None, + username: None, + token_type: None, + exp: None, + iat: None, + nbf: None, + sub: None, + aud: None, + iss: None, + jti: None, + extra_fields, + } + } + + /// Sets the `set_active` field. + pub fn set_active(&mut self, active: bool) { + self.active = active; + } + /// Sets the `set_scopes` field. + pub fn set_scopes(&mut self, scopes: Option>) { + self.scopes = scopes; + } + /// Sets the `set_client_id` field. + pub fn set_client_id(&mut self, client_id: Option) { + self.client_id = client_id; + } + /// Sets the `set_username` field. + pub fn set_username(&mut self, username: Option) { + self.username = username; + } + /// Sets the `set_token_type` field. + pub fn set_token_type(&mut self, token_type: Option) { + self.token_type = token_type; + } + /// Sets the `set_exp` field. + pub fn set_exp(&mut self, exp: Option>) { + self.exp = exp; + } + /// Sets the `set_iat` field. + pub fn set_iat(&mut self, iat: Option>) { + self.iat = iat; + } + /// Sets the `set_nbf` field. + pub fn set_nbf(&mut self, nbf: Option>) { + self.nbf = nbf; + } + /// Sets the `set_sub` field. + pub fn set_sub(&mut self, sub: Option) { + self.sub = sub; + } + /// Sets the `set_aud` field. + pub fn set_aud(&mut self, aud: Option>) { + self.aud = aud; + } + /// Sets the `set_iss` field. + pub fn set_iss(&mut self, iss: Option) { + self.iss = iss; + } + /// Sets the `set_jti` field. + pub fn set_jti(&mut self, jti: Option) { + self.jti = jti; + } + /// Extra fields defined by the client application. + pub fn extra_fields(&self) -> &EF { + &self.extra_fields + } + /// Sets the `set_extra_fields` field. + pub fn set_extra_fields(&mut self, extra_fields: EF) { + self.extra_fields = extra_fields; + } +} +impl TokenIntrospectionResponse for StandardTokenIntrospectionResponse +where + EF: ExtraTokenFields, + TT: TokenType, +{ + fn active(&self) -> bool { + self.active + } + + fn scopes(&self) -> Option<&Vec> { + self.scopes.as_ref() + } + + fn client_id(&self) -> Option<&ClientId> { + self.client_id.as_ref() + } + + fn username(&self) -> Option<&str> { + self.username.as_deref() + } + + fn token_type(&self) -> Option<&TT> { + self.token_type.as_ref() + } + + fn exp(&self) -> Option> { + self.exp + } + + fn iat(&self) -> Option> { + self.iat + } + + fn nbf(&self) -> Option> { + self.nbf + } + + fn sub(&self) -> Option<&str> { + self.sub.as_deref() + } + + fn aud(&self) -> Option<&Vec> { + self.aud.as_ref() + } + + fn iss(&self) -> Option<&str> { + self.iss.as_deref() + } + + fn jti(&self) -> Option<&str> { + self.jti.as_deref() + } +} + +#[cfg(test)] +mod tests { + use crate::basic::BasicTokenType; + use crate::tests::{mock_http_client, new_client}; + use crate::{ + AccessToken, AuthType, ClientId, HttpResponse, IntrospectionUrl, RedirectUrl, Scope, + }; + + use chrono::{TimeZone, Utc}; + use http::header::{ACCEPT, AUTHORIZATION, CONTENT_TYPE}; + use http::{HeaderValue, StatusCode}; + + #[test] + fn test_token_introspection_successful_with_basic_auth_minimal_response() { + let client = new_client() + .set_auth_type(AuthType::BasicAuth) + .set_redirect_uri(RedirectUrl::new("https://redirect/here".to_string()).unwrap()) + .set_introspection_uri( + IntrospectionUrl::new("https://introspection/url".to_string()).unwrap(), + ); + + let introspection_response = client + .introspect(&AccessToken::new("access_token_123".to_string())) + .unwrap() + .request(mock_http_client( + vec![ + (ACCEPT, "application/json"), + (CONTENT_TYPE, "application/x-www-form-urlencoded"), + (AUTHORIZATION, "Basic YWFhOmJiYg=="), + ], + "token=access_token_123", + Some("https://introspection/url".parse().unwrap()), + HttpResponse { + status_code: StatusCode::OK, + headers: vec![( + CONTENT_TYPE, + HeaderValue::from_str("application/json").unwrap(), + )] + .into_iter() + .collect(), + body: "{\ + \"active\": true\ + }" + .to_string() + .into_bytes(), + }, + )) + .unwrap(); + + assert!(introspection_response.active); + assert_eq!(None, introspection_response.scopes); + assert_eq!(None, introspection_response.client_id); + assert_eq!(None, introspection_response.username); + assert_eq!(None, introspection_response.token_type); + assert_eq!(None, introspection_response.exp); + assert_eq!(None, introspection_response.iat); + assert_eq!(None, introspection_response.nbf); + assert_eq!(None, introspection_response.sub); + assert_eq!(None, introspection_response.aud); + assert_eq!(None, introspection_response.iss); + assert_eq!(None, introspection_response.jti); + } + + #[test] + fn test_token_introspection_successful_with_basic_auth_full_response() { + let client = new_client() + .set_auth_type(AuthType::BasicAuth) + .set_redirect_uri(RedirectUrl::new("https://redirect/here".to_string()).unwrap()) + .set_introspection_uri( + IntrospectionUrl::new("https://introspection/url".to_string()).unwrap(), + ); + + let introspection_response = client + .introspect(&AccessToken::new("access_token_123".to_string())) + .unwrap() + .set_token_type_hint("access_token") + .request(mock_http_client( + vec![ + (ACCEPT, "application/json"), + (CONTENT_TYPE, "application/x-www-form-urlencoded"), + (AUTHORIZATION, "Basic YWFhOmJiYg=="), + ], + "token=access_token_123&token_type_hint=access_token", + Some("https://introspection/url".parse().unwrap()), + HttpResponse { + status_code: StatusCode::OK, + headers: vec![( + CONTENT_TYPE, + HeaderValue::from_str("application/json").unwrap(), + )] + .into_iter() + .collect(), + body: r#"{ + "active": true, + "scope": "email profile", + "client_id": "aaa", + "username": "demo", + "token_type": "bearer", + "exp": 1604073517, + "iat": 1604073217, + "nbf": 1604073317, + "sub": "demo", + "aud": "demo", + "iss": "http://127.0.0.1:8080/auth/realms/test-realm", + "jti": "be1b7da2-fc18-47b3-bdf1-7a4f50bcf53f" + }"# + .to_string() + .into_bytes(), + }, + )) + .unwrap(); + + assert!(introspection_response.active); + assert_eq!( + Some(vec![ + Scope::new("email".to_string()), + Scope::new("profile".to_string()) + ]), + introspection_response.scopes + ); + assert_eq!( + Some(ClientId::new("aaa".to_string())), + introspection_response.client_id + ); + assert_eq!(Some("demo".to_string()), introspection_response.username); + assert_eq!( + Some(BasicTokenType::Bearer), + introspection_response.token_type + ); + assert_eq!( + Some(Utc.timestamp(1604073517, 0)), + introspection_response.exp + ); + assert_eq!( + Some(Utc.timestamp(1604073217, 0)), + introspection_response.iat + ); + assert_eq!( + Some(Utc.timestamp(1604073317, 0)), + introspection_response.nbf + ); + assert_eq!(Some("demo".to_string()), introspection_response.sub); + assert_eq!(Some(vec!["demo".to_string()]), introspection_response.aud); + assert_eq!( + Some("http://127.0.0.1:8080/auth/realms/test-realm".to_string()), + introspection_response.iss + ); + assert_eq!( + Some("be1b7da2-fc18-47b3-bdf1-7a4f50bcf53f".to_string()), + introspection_response.jti + ); + } +} diff --git a/src/lib.rs b/src/lib.rs index c766560..902f722 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -76,7 +76,8 @@ //! OAuth flows require comparing secrets received from the provider servers. To do so securely //! while avoiding [timing side-channels](https://en.wikipedia.org/wiki/Timing_attack), the //! comparison must be done in constant time, either using a constant-time crate such as -//! [`constant_time_eq`] (which could break if a future compiler version decides to be overly smart +//! [`constant_time_eq`](https://crates.io/crates/constant_time_eq) (which could break if a future +//! compiler version decides to be overly smart //! about its optimizations), or by first computing a cryptographically-secure hash (e.g., SHA-256) //! of both values and then comparing the hashes using `==`. //! @@ -413,29 +414,15 @@ //! //! - [`actix-web-oauth2`](https://github.com/pka/actix-web-oauth2) (version 2.x of this crate) //! -use crate::devicecode::DeviceAccessTokenPollResult; - -use chrono::serde::ts_seconds_option; -use chrono::{DateTime, Utc}; -use http::header::{HeaderMap, HeaderValue, ACCEPT, AUTHORIZATION, CONTENT_TYPE}; -use http::status::StatusCode; -use serde::de::DeserializeOwned; -use serde::{Deserialize, Serialize}; -use url::{form_urlencoded, Url}; - -use std::borrow::Cow; -use std::error::Error; -use std::fmt::Error as FormatterError; -use std::fmt::{Debug, Display, Formatter}; -use std::future::Future; -use std::marker::PhantomData; -use std::sync::Arc; -use std::time::Duration; /// Basic OAuth2 implementation with no extensions /// ([RFC 6749](https://tools.ietf.org/html/rfc6749)). pub mod basic; +mod client; + +mod code; + /// HTTP client backed by the [curl](https://crates.io/crates/curl) crate. /// Requires "curl" feature. #[cfg(all(feature = "curl", not(target_arch = "wasm32")))] @@ -446,23 +433,31 @@ compile_error!("wasm32 is not supported with the `curl` feature. Use the `reqwes /// Device Code Flow OAuth2 implementation /// ([RFC 8628](https://tools.ietf.org/html/rfc8628)). -pub mod devicecode; +mod devicecode; -/// OAuth 2.0 Token Revocation implementation -/// ([RFC 7009](https://tools.ietf.org/html/rfc7009)). -pub mod revocation; +mod endpoint; + +mod error; /// Helper methods used by OAuth2 implementations/extensions. pub mod helpers; +mod introspection; + /// HTTP client backed by the [reqwest](https://crates.io/crates/reqwest) crate. /// Requires "reqwest" feature. #[cfg(feature = "reqwest")] pub mod reqwest; +/// OAuth 2.0 Token Revocation implementation +/// ([RFC 7009](https://tools.ietf.org/html/rfc7009)). +mod revocation; + #[cfg(test)] mod tests; +mod token; + mod types; /// HTTP client backed by the [ureq](https://crates.io/crates/ureq) crate. @@ -470,17 +465,28 @@ mod types; #[cfg(feature = "ureq")] pub mod ureq; -/// Public re-exports of types used for HTTP client interfaces. -pub use http; -pub use url; - -pub use devicecode::{ - DeviceAuthorizationResponse, DeviceCodeErrorResponse, DeviceCodeErrorResponseType, - EmptyExtraDeviceAuthorizationFields, ExtraDeviceAuthorizationFields, - StandardDeviceAuthorizationResponse, +pub use crate::client::Client; +pub use crate::code::AuthorizationRequest; +pub use crate::devicecode::{ + DeviceAccessTokenRequest, DeviceAuthorizationRequest, DeviceAuthorizationResponse, + DeviceCodeErrorResponse, DeviceCodeErrorResponseType, EmptyExtraDeviceAuthorizationFields, + ExtraDeviceAuthorizationFields, StandardDeviceAuthorizationResponse, }; - -pub use types::{ +pub use crate::endpoint::{HttpRequest, HttpResponse}; +pub use crate::error::{ + ErrorResponse, ErrorResponseType, RequestTokenError, StandardErrorResponse, +}; +pub use crate::introspection::{ + IntrospectionRequest, StandardTokenIntrospectionResponse, TokenIntrospectionResponse, +}; +pub use crate::revocation::{ + RevocableToken, RevocationErrorResponseType, RevocationRequest, StandardRevocableToken, +}; +pub use crate::token::{ + ClientCredentialsTokenRequest, CodeTokenRequest, EmptyExtraTokenFields, ExtraTokenFields, + PasswordTokenRequest, RefreshTokenRequest, StandardTokenResponse, TokenResponse, TokenType, +}; +pub use crate::types::{ AccessToken, AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, DeviceAuthorizationUrl, DeviceCode, EndUserVerificationUrl, IntrospectionUrl, PkceCodeChallenge, PkceCodeChallengeMethod, PkceCodeVerifier, RedirectUrl, RefreshToken, @@ -488,7 +494,9 @@ pub use types::{ UserCode, VerificationUriComplete, }; -pub use revocation::{RevocableToken, RevocationErrorResponseType, StandardRevocableToken}; +/// Public re-exports of types used for HTTP client interfaces. +pub use http; +pub use url; const CONTENT_TYPE_JSON: &str = "application/json"; const CONTENT_TYPE_FORMENCODED: &str = "application/x-www-form-urlencoded"; @@ -515,2691 +523,3 @@ pub enum AuthType { /// The client_id and client_secret will be included using the basic auth authentication scheme. BasicAuth, } - -/// Stores the configuration for an OAuth2 client. -/// -/// This type implements the -/// [Builder Pattern](https://doc.rust-lang.org/1.0.0/style/ownership/builders.html) together with -/// [typestates](https://cliffle.com/blog/rust-typestate/#what-are-typestates) to encode whether -/// certain fields have been set that are prerequisites to certain authentication flows. For -/// example, the authorization endpoint must be set via [`Client::set_auth_url`] before -/// [`Client::authorize_url`] can be called. Each endpoint has a corresponding const generic -/// parameter (e.g., `HAS_AUTH_URL`) used to statically enforce these dependencies. These generics -/// are set automatically by the corresponding setter functions, and in most cases user code should -/// not need to deal with them directly. -/// -/// # Error Types -/// -/// To enable compile time verification that only the correct and complete set of errors for the `Client` function being -/// invoked are exposed to the caller, the `Client` type is specialized on multiple implementations of the -/// [`ErrorResponse`] trait. The exact [`ErrorResponse`] implementation returned varies by the RFC that the invoked -/// `Client` function implements: -/// -/// - Generic type `TE` (aka Token Error) for errors defined by [RFC 6749 OAuth 2.0 Authorization Framework](https://tools.ietf.org/html/rfc6749). -/// - Generic type `TRE` (aka Token Revocation Error) for errors defined by [RFC 7009 OAuth 2.0 Token Revocation](https://tools.ietf.org/html/rfc7009). -/// -/// For example when revoking a token, error code `unsupported_token_type` (from RFC 7009) may be returned: -/// ```rust -/// # use thiserror::Error; -/// # use http::status::StatusCode; -/// # use http::header::{HeaderValue, CONTENT_TYPE}; -/// # use oauth2::{*, basic::*}; -/// # let client = BasicClient::new(ClientId::new("aaa".to_string())) -/// # .set_client_secret(ClientSecret::new("bbb".to_string())) -/// # .set_auth_url(AuthUrl::new("https://example.com/auth".to_string()).unwrap()) -/// # .set_token_url(TokenUrl::new("https://example.com/token".to_string()).unwrap()) -/// # .set_revocation_uri(RevocationUrl::new("https://revocation/url".to_string()).unwrap()); -/// # -/// # #[derive(Debug, Error)] -/// # enum FakeError { -/// # #[error("error")] -/// # Err, -/// # } -/// # -/// # let http_client = |_| -> Result { -/// # Ok(HttpResponse { -/// # status_code: StatusCode::BAD_REQUEST, -/// # headers: vec![( -/// # CONTENT_TYPE, -/// # HeaderValue::from_str("application/json").unwrap(), -/// # )] -/// # .into_iter() -/// # .collect(), -/// # body: "{\"error\": \"unsupported_token_type\", \"error_description\": \"stuff happened\", \ -/// # \"error_uri\": \"https://errors\"}" -/// # .to_string() -/// # .into_bytes(), -/// # }) -/// # }; -/// # -/// let res = client -/// .revoke_token(AccessToken::new("some token".to_string()).into()) -/// .unwrap() -/// .request(http_client); -/// -/// assert!(matches!(res, Err( -/// RequestTokenError::ServerResponse(err)) if matches!(err.error(), -/// RevocationErrorResponseType::UnsupportedTokenType))); -/// ``` -/// -/// # Examples -/// -/// See the [crate] root documentation for usage examples. -#[derive(Clone, Debug)] -pub struct Client< - TE, - TR, - TT, - TIR, - RT, - TRE, - const HAS_AUTH_URL: bool = false, - const HAS_DEVICE_AUTH_URL: bool = false, - const HAS_INTROSPECTION_URL: bool = false, - const HAS_REVOCATION_URL: bool = false, - const HAS_TOKEN_URL: bool = false, -> where - TE: ErrorResponse, - TR: TokenResponse, - TT: TokenType, - TIR: TokenIntrospectionResponse, - RT: RevocableToken, - TRE: ErrorResponse, -{ - client_id: ClientId, - client_secret: Option, - auth_url: Option, - auth_type: AuthType, - token_url: Option, - redirect_url: Option, - introspection_url: Option, - revocation_url: Option, - device_authorization_url: Option, - phantom: PhantomData<(TE, TR, TT, TIR, RT, TRE)>, -} -impl Client -where - TE: ErrorResponse + 'static, - TR: TokenResponse, - TT: TokenType, - TIR: TokenIntrospectionResponse, - RT: RevocableToken, - TRE: ErrorResponse + 'static, -{ - /// Initializes an OAuth2 client with the specified client ID. - pub fn new(client_id: ClientId) -> Self { - Self { - client_id, - client_secret: None, - auth_url: None, - auth_type: AuthType::BasicAuth, - token_url: None, - redirect_url: None, - introspection_url: None, - revocation_url: None, - device_authorization_url: None, - phantom: PhantomData, - } - } -} -impl< - TE, - TR, - TT, - TIR, - RT, - TRE, - const HAS_AUTH_URL: bool, - const HAS_DEVICE_AUTH_URL: bool, - const HAS_INTROSPECTION_URL: bool, - const HAS_REVOCATION_URL: bool, - const HAS_TOKEN_URL: bool, - > - Client< - TE, - TR, - TT, - TIR, - RT, - TRE, - HAS_AUTH_URL, - HAS_DEVICE_AUTH_URL, - HAS_INTROSPECTION_URL, - HAS_REVOCATION_URL, - HAS_TOKEN_URL, - > -where - TE: ErrorResponse + 'static, - TR: TokenResponse, - TT: TokenType, - TIR: TokenIntrospectionResponse, - RT: RevocableToken, - TRE: ErrorResponse + 'static, -{ - /// Configures the type of client authentication used for communicating with the authorization - /// server. - /// - /// The default is to use HTTP Basic authentication, as recommended in - /// [Section 2.3.1 of RFC 6749](https://tools.ietf.org/html/rfc6749#section-2.3.1). Note that - /// if a client secret is omitted (i.e., `client_secret` is set to `None` when calling - /// [`Client::new`]), [`AuthType::RequestBody`] is used regardless of the `auth_type` passed to - /// this function. - pub fn set_auth_type(mut self, auth_type: AuthType) -> Self { - self.auth_type = auth_type; - - self - } - - /// Sets the authorization endpoint. - /// - /// The client uses the authorization endpoint to obtain authorization from the resource owner - /// via user-agent redirection. This URL is used in all standard OAuth2 flows except the - /// [Resource Owner Password Credentials Grant](https://tools.ietf.org/html/rfc6749#section-4.3) - /// and the [Client Credentials Grant](https://tools.ietf.org/html/rfc6749#section-4.4). - pub fn set_auth_url( - self, - auth_url: AuthUrl, - ) -> Client< - TE, - TR, - TT, - TIR, - RT, - TRE, - true, - HAS_DEVICE_AUTH_URL, - HAS_INTROSPECTION_URL, - HAS_REVOCATION_URL, - HAS_TOKEN_URL, - > { - Client { - client_id: self.client_id, - client_secret: self.client_secret, - auth_url: Some(auth_url), - auth_type: self.auth_type, - token_url: self.token_url, - redirect_url: self.redirect_url, - introspection_url: self.introspection_url, - revocation_url: self.revocation_url, - device_authorization_url: self.device_authorization_url, - phantom: self.phantom, - } - } - - /// Sets the client secret. - /// - /// A client secret is generally used for confidential (i.e., server-side) OAuth2 clients and - /// omitted from public (browser or native app) OAuth2 clients (see - /// [RFC 8252](https://tools.ietf.org/html/rfc8252)). - pub fn set_client_secret(mut self, client_secret: ClientSecret) -> Self { - self.client_secret = Some(client_secret); - - self - } - - /// Sets the the device authorization URL used by the device authorization endpoint. - /// Used for Device Code Flow, as per [RFC 8628](https://tools.ietf.org/html/rfc8628). - pub fn set_device_authorization_url( - self, - device_authorization_url: DeviceAuthorizationUrl, - ) -> Client< - TE, - TR, - TT, - TIR, - RT, - TRE, - HAS_AUTH_URL, - true, - HAS_INTROSPECTION_URL, - HAS_REVOCATION_URL, - HAS_TOKEN_URL, - > { - Client { - client_id: self.client_id, - client_secret: self.client_secret, - auth_url: self.auth_url, - auth_type: self.auth_type, - token_url: self.token_url, - redirect_url: self.redirect_url, - introspection_url: self.introspection_url, - revocation_url: self.revocation_url, - device_authorization_url: Some(device_authorization_url), - phantom: self.phantom, - } - } - - /// Sets the introspection URL for contacting the ([RFC 7662](https://tools.ietf.org/html/rfc7662)) - /// introspection endpoint. - pub fn set_introspection_uri( - self, - introspection_url: IntrospectionUrl, - ) -> Client< - TE, - TR, - TT, - TIR, - RT, - TRE, - HAS_TOKEN_URL, - HAS_DEVICE_AUTH_URL, - true, - HAS_REVOCATION_URL, - HAS_TOKEN_URL, - > { - Client { - client_id: self.client_id, - client_secret: self.client_secret, - auth_url: self.auth_url, - auth_type: self.auth_type, - token_url: self.token_url, - redirect_url: self.redirect_url, - introspection_url: Some(introspection_url), - revocation_url: self.revocation_url, - device_authorization_url: self.device_authorization_url, - phantom: self.phantom, - } - } - - /// Sets the redirect URL used by the authorization endpoint. - pub fn set_redirect_uri(mut self, redirect_url: RedirectUrl) -> Self { - self.redirect_url = Some(redirect_url); - - self - } - - /// Sets the revocation URL for contacting the revocation endpoint ([RFC 7009](https://tools.ietf.org/html/rfc7009)). - /// - /// See: [`revoke_token()`](Self::revoke_token()) - pub fn set_revocation_uri( - self, - revocation_url: RevocationUrl, - ) -> Client< - TE, - TR, - TT, - TIR, - RT, - TRE, - HAS_TOKEN_URL, - HAS_DEVICE_AUTH_URL, - HAS_INTROSPECTION_URL, - true, - HAS_TOKEN_URL, - > { - Client { - client_id: self.client_id, - client_secret: self.client_secret, - auth_url: self.auth_url, - auth_type: self.auth_type, - token_url: self.token_url, - redirect_url: self.redirect_url, - introspection_url: self.introspection_url, - revocation_url: Some(revocation_url), - device_authorization_url: self.device_authorization_url, - phantom: self.phantom, - } - } - - /// Sets the token endpoint. - /// - /// The client uses the token endpoint to exchange an authorization code for an access token, - /// typically with client authentication. This URL is used in - /// all standard OAuth2 flows except the - /// [Implicit Grant](https://tools.ietf.org/html/rfc6749#section-4.2). - pub fn set_token_url( - self, - token_url: TokenUrl, - ) -> Client< - TE, - TR, - TT, - TIR, - RT, - TRE, - HAS_AUTH_URL, - HAS_DEVICE_AUTH_URL, - HAS_INTROSPECTION_URL, - HAS_REVOCATION_URL, - true, - > { - Client { - client_id: self.client_id, - client_secret: self.client_secret, - auth_url: self.auth_url, - auth_type: self.auth_type, - token_url: Some(token_url), - redirect_url: self.redirect_url, - introspection_url: self.introspection_url, - revocation_url: self.revocation_url, - device_authorization_url: self.device_authorization_url, - phantom: self.phantom, - } - } - - /// Returns the Client ID. - pub fn client_id(&self) -> &ClientId { - &self.client_id - } - - /// Returns the type of client authentication used for communicating with the authorization - /// server. - pub fn auth_type(&self) -> &AuthType { - &self.auth_type - } - - /// Returns the redirect URL used by the authorization endpoint. - pub fn redirect_url(&self) -> Option<&RedirectUrl> { - self.redirect_url.as_ref() - } -} - -// Methods requiring an authorization endpoint. -impl< - TE, - TR, - TT, - TIR, - RT, - TRE, - const HAS_DEVICE_AUTH_URL: bool, - const HAS_INTROSPECTION_URL: bool, - const HAS_REVOCATION_URL: bool, - const HAS_TOKEN_URL: bool, - > - Client< - TE, - TR, - TT, - TIR, - RT, - TRE, - true, - HAS_DEVICE_AUTH_URL, - HAS_INTROSPECTION_URL, - HAS_REVOCATION_URL, - HAS_TOKEN_URL, - > -where - TE: ErrorResponse + 'static, - TR: TokenResponse, - TT: TokenType, - TIR: TokenIntrospectionResponse, - RT: RevocableToken, - TRE: ErrorResponse + 'static, -{ - /// Returns the authorization endpoint. - pub fn auth_url(&self) -> &AuthUrl { - // This is enforced statically via the HAS_AUTH_URL const generic. - self.auth_url.as_ref().expect("should have auth_url") - } - - /// Generates an authorization URL for a new authorization request. - /// - /// # Arguments - /// - /// * `state_fn` - A function that returns an opaque value used by the client to maintain state - /// between the request and callback. The authorization server includes this value when - /// redirecting the user-agent back to the client. - /// - /// # Security Warning - /// - /// Callers should use a fresh, unpredictable `state` for each authorization request and verify - /// that this value matches the `state` parameter passed by the authorization server to the - /// redirect URI. Doing so mitigates - /// [Cross-Site Request Forgery](https://tools.ietf.org/html/rfc6749#section-10.12) - /// attacks. To disable CSRF protections (NOT recommended), use `insecure::authorize_url` - /// instead. - pub fn authorize_url(&self, state_fn: S) -> AuthorizationRequest - where - S: FnOnce() -> CsrfToken, - { - AuthorizationRequest { - // This is enforced statically via the HAS_AUTH_URL const generic. - auth_url: self.auth_url(), - client_id: &self.client_id, - extra_params: Vec::new(), - pkce_challenge: None, - redirect_url: self.redirect_url.as_ref().map(Cow::Borrowed), - response_type: "code".into(), - scopes: Vec::new(), - state: state_fn(), - } - } -} - -// Methods requiring a token endpoint. -impl< - TE, - TR, - TT, - TIR, - RT, - TRE, - const HAS_AUTH_URL: bool, - const HAS_DEVICE_AUTH_URL: bool, - const HAS_INTROSPECTION_URL: bool, - const HAS_REVOCATION_URL: bool, - > - Client< - TE, - TR, - TT, - TIR, - RT, - TRE, - HAS_AUTH_URL, - HAS_DEVICE_AUTH_URL, - HAS_INTROSPECTION_URL, - HAS_REVOCATION_URL, - true, - > -where - TE: ErrorResponse + 'static, - TR: TokenResponse, - TT: TokenType, - TIR: TokenIntrospectionResponse, - RT: RevocableToken, - TRE: ErrorResponse + 'static, -{ - /// Requests an access token for the *client credentials* grant type. - /// - /// See . - pub fn exchange_client_credentials(&self) -> ClientCredentialsTokenRequest { - ClientCredentialsTokenRequest { - auth_type: &self.auth_type, - client_id: &self.client_id, - client_secret: self.client_secret.as_ref(), - extra_params: Vec::new(), - scopes: Vec::new(), - // This is enforced statically via the HAS_TOKEN_URL const generic. - token_url: self.token_url.as_ref().expect("should have token_url"), - _phantom: PhantomData, - } - } - - /// Exchanges a code produced by a successful authorization process with an access token. - /// - /// Acquires ownership of the `code` because authorization codes may only be used once to - /// retrieve an access token from the authorization server. - /// - /// See . - pub fn exchange_code(&self, code: AuthorizationCode) -> CodeTokenRequest { - CodeTokenRequest { - auth_type: &self.auth_type, - client_id: &self.client_id, - client_secret: self.client_secret.as_ref(), - code, - extra_params: Vec::new(), - pkce_verifier: None, - // This is enforced statically via the HAS_TOKEN_URL const generic. - token_url: self.token_url.as_ref().expect("should have token_url"), - redirect_url: self.redirect_url.as_ref().map(Cow::Borrowed), - _phantom: PhantomData, - } - } - - /// Perform a device access token request as per - /// . - pub fn exchange_device_access_token<'a, 'b, 'c, EF>( - &'a self, - auth_response: &'b DeviceAuthorizationResponse, - ) -> DeviceAccessTokenRequest<'b, 'c, TR, TT, EF> - where - 'a: 'b, - EF: ExtraDeviceAuthorizationFields, - { - DeviceAccessTokenRequest { - auth_type: &self.auth_type, - client_id: &self.client_id, - client_secret: self.client_secret.as_ref(), - extra_params: Vec::new(), - // This is enforced statically via the HAS_TOKEN_URL const generic. - token_url: self.token_url.as_ref().expect("should have token_url"), - dev_auth_resp: auth_response, - time_fn: Arc::new(Utc::now), - max_backoff_interval: None, - _phantom: PhantomData, - } - } - - /// Requests an access token for the *password* grant type. - /// - /// See . - pub fn exchange_password<'a, 'b>( - &'a self, - username: &'b ResourceOwnerUsername, - password: &'b ResourceOwnerPassword, - ) -> PasswordTokenRequest<'b, TE, TR, TT> - where - 'a: 'b, - { - PasswordTokenRequest::<'b> { - auth_type: &self.auth_type, - client_id: &self.client_id, - client_secret: self.client_secret.as_ref(), - username, - password, - extra_params: Vec::new(), - scopes: Vec::new(), - // This is enforced statically via the HAS_TOKEN_URL const generic. - token_url: self.token_url.as_ref().expect("should have token_url"), - _phantom: PhantomData, - } - } - - /// Exchanges a refresh token for an access token - /// - /// See . - pub fn exchange_refresh_token<'a, 'b>( - &'a self, - refresh_token: &'b RefreshToken, - ) -> RefreshTokenRequest<'b, TE, TR, TT> - where - 'a: 'b, - { - RefreshTokenRequest { - auth_type: &self.auth_type, - client_id: &self.client_id, - client_secret: self.client_secret.as_ref(), - extra_params: Vec::new(), - refresh_token, - scopes: Vec::new(), - // This is enforced statically via the HAS_TOKEN_URL const generic. - token_url: self.token_url.as_ref().expect("should have token_url"), - _phantom: PhantomData, - } - } - - /// Returns the token endpoint. - pub fn token_url(&self) -> &TokenUrl { - // This is enforced statically via the HAS_TOKEN_URL const generic. - self.token_url.as_ref().expect("should have token_url") - } -} - -// Methods requiring a device authorization endpoint. -impl< - TE, - TR, - TT, - TIR, - RT, - TRE, - const HAS_AUTH_URL: bool, - const HAS_INTROSPECTION_URL: bool, - const HAS_REVOCATION_URL: bool, - const HAS_TOKEN_URL: bool, - > - Client< - TE, - TR, - TT, - TIR, - RT, - TRE, - HAS_AUTH_URL, - true, - HAS_INTROSPECTION_URL, - HAS_REVOCATION_URL, - HAS_TOKEN_URL, - > -where - TE: ErrorResponse + 'static, - TR: TokenResponse, - TT: TokenType, - TIR: TokenIntrospectionResponse, - RT: RevocableToken, - TRE: ErrorResponse + 'static, -{ - /// Perform a device authorization request as per - /// . - pub fn exchange_device_code(&self) -> DeviceAuthorizationRequest { - DeviceAuthorizationRequest { - auth_type: &self.auth_type, - client_id: &self.client_id, - client_secret: self.client_secret.as_ref(), - extra_params: Vec::new(), - scopes: Vec::new(), - // This is enforced statically via the HAS_DEVICE_AUTH_URL const generic. - device_authorization_url: self - .device_authorization_url - .as_ref() - .expect("should have device_authorization_url"), - _phantom: PhantomData, - } - } - - /// Returns the device authorization URL used by the device authorization endpoint. - pub fn device_authorization_url(&self) -> &DeviceAuthorizationUrl { - // This is enforced statically via the HAS_DEVICE_AUTH_URL const generic. - self.device_authorization_url - .as_ref() - .expect("should have device_authorization_url") - } -} - -// Methods requiring an introspection endpoint. -impl< - TE, - TR, - TT, - TIR, - RT, - TRE, - const HAS_AUTH_URL: bool, - const HAS_DEVICE_AUTH_URL: bool, - const HAS_REVOCATION_URL: bool, - const HAS_TOKEN_URL: bool, - > - Client< - TE, - TR, - TT, - TIR, - RT, - TRE, - HAS_AUTH_URL, - HAS_DEVICE_AUTH_URL, - true, - HAS_REVOCATION_URL, - HAS_TOKEN_URL, - > -where - TE: ErrorResponse + 'static, - TR: TokenResponse, - TT: TokenType, - TIR: TokenIntrospectionResponse, - RT: RevocableToken, - TRE: ErrorResponse + 'static, -{ - /// Query the authorization server [`RFC 7662 compatible`](https://tools.ietf.org/html/rfc7662) introspection - /// endpoint to determine the set of metadata for a previously received token. - /// - /// Requires [`set_introspection_uri()`](Self::set_introspection_uri) to have been previously - /// called to set the introspection endpoint URL. - pub fn introspect<'a>( - &'a self, - token: &'a AccessToken, - ) -> Result, ConfigurationError> { - Ok(IntrospectionRequest { - auth_type: &self.auth_type, - client_id: &self.client_id, - client_secret: self.client_secret.as_ref(), - extra_params: Vec::new(), - // This is enforced statically via the HAS_INTROSPECTION_URL const generic. - introspection_url: self - .introspection_url - .as_ref() - .expect("should have introspection_url"), - token, - token_type_hint: None, - _phantom: PhantomData, - }) - } - - /// Returns the introspection URL for contacting the ([RFC 7662](https://tools.ietf.org/html/rfc7662)) - /// introspection endpoint. - pub fn introspection_url(&self) -> &IntrospectionUrl { - // This is enforced statically via the HAS_INTROSPECTION_URL const generic. - self.introspection_url - .as_ref() - .expect("should have introspection_url") - } -} - -// Methods requiring a revocation endpoint. -impl< - TE, - TR, - TT, - TIR, - RT, - TRE, - const HAS_AUTH_URL: bool, - const HAS_DEVICE_AUTH_URL: bool, - const HAS_INTROSPECTION_URL: bool, - const HAS_TOKEN_URL: bool, - > - Client< - TE, - TR, - TT, - TIR, - RT, - TRE, - HAS_AUTH_URL, - HAS_DEVICE_AUTH_URL, - HAS_INTROSPECTION_URL, - true, - HAS_TOKEN_URL, - > -where - TE: ErrorResponse + 'static, - TR: TokenResponse, - TT: TokenType, - TIR: TokenIntrospectionResponse, - RT: RevocableToken, - TRE: ErrorResponse + 'static, -{ - /// Attempts to revoke the given previously received token using an - /// [RFC 7009 OAuth 2.0 Token Revocation](https://tools.ietf.org/html/rfc7009) compatible - /// endpoint. - /// - /// Requires [`set_revocation_uri()`](Self::set_revocation_uri) to have been previously - /// called to set the revocation endpoint URL. - pub fn revoke_token( - &self, - token: RT, - ) -> Result, ConfigurationError> { - // https://tools.ietf.org/html/rfc7009#section-2 states: - // "The client requests the revocation of a particular token by making an - // HTTP POST request to the token revocation endpoint URL. This URL - // MUST conform to the rules given in [RFC6749], Section 3.1. Clients - // MUST verify that the URL is an HTTPS URL." - - // This is enforced statically via the HAS_REVOCATION_URL const generic. - let revocation_url = self - .revocation_url - .as_ref() - .expect("should have revocation_url"); - - if revocation_url.url().scheme() != "https" { - return Err(ConfigurationError::InsecureUrl("revocation")); - } - - Ok(RevocationRequest { - auth_type: &self.auth_type, - client_id: &self.client_id, - client_secret: self.client_secret.as_ref(), - extra_params: Vec::new(), - revocation_url, - token, - _phantom: PhantomData, - }) - } - - /// Returns the revocation URL for contacting the revocation endpoint - /// ([RFC 7009](https://tools.ietf.org/html/rfc7009)). - /// - /// See: [`revoke_token()`](Self::revoke_token()) - pub fn revocation_url(&self) -> &RevocationUrl { - // This is enforced statically via the HAS_REVOCATION_URL const generic. - self.revocation_url - .as_ref() - .expect("should have revocation_url") - } -} - -/// A request to the authorization endpoint -#[derive(Debug)] -pub struct AuthorizationRequest<'a> { - auth_url: &'a AuthUrl, - client_id: &'a ClientId, - extra_params: Vec<(Cow<'a, str>, Cow<'a, str>)>, - pkce_challenge: Option, - redirect_url: Option>, - response_type: Cow<'a, str>, - scopes: Vec>, - state: CsrfToken, -} -impl<'a> AuthorizationRequest<'a> { - /// Appends a new scope to the authorization URL. - pub fn add_scope(mut self, scope: Scope) -> Self { - self.scopes.push(Cow::Owned(scope)); - self - } - - /// Appends a collection of scopes to the token request. - pub fn add_scopes(mut self, scopes: I) -> Self - where - I: IntoIterator, - { - self.scopes.extend(scopes.into_iter().map(Cow::Owned)); - self - } - - /// Appends an extra param to the authorization URL. - /// - /// This method allows extensions to be used without direct support from - /// this crate. If `name` conflicts with a parameter managed by this crate, the - /// behavior is undefined. In particular, do not set parameters defined by - /// [RFC 6749](https://tools.ietf.org/html/rfc6749) or - /// [RFC 7636](https://tools.ietf.org/html/rfc7636). - /// - /// # Security Warning - /// - /// Callers should follow the security recommendations for any OAuth2 extensions used with - /// this function, which are beyond the scope of - /// [RFC 6749](https://tools.ietf.org/html/rfc6749). - pub fn add_extra_param(mut self, name: N, value: V) -> Self - where - N: Into>, - V: Into>, - { - self.extra_params.push((name.into(), value.into())); - self - } - - /// Enables the [Implicit Grant](https://tools.ietf.org/html/rfc6749#section-4.2) flow. - pub fn use_implicit_flow(mut self) -> Self { - self.response_type = "token".into(); - self - } - - /// Enables custom flows other than the `code` and `token` (implicit flow) grant. - pub fn set_response_type(mut self, response_type: &ResponseType) -> Self { - self.response_type = (**response_type).to_owned().into(); - self - } - - /// Enables the use of [Proof Key for Code Exchange](https://tools.ietf.org/html/rfc7636) - /// (PKCE). - /// - /// PKCE is *highly recommended* for all public clients (i.e., those for which there - /// is no client secret or for which the client secret is distributed with the client, - /// such as in a native, mobile app, or browser app). - pub fn set_pkce_challenge(mut self, pkce_code_challenge: PkceCodeChallenge) -> Self { - self.pkce_challenge = Some(pkce_code_challenge); - self - } - - /// Overrides the `redirect_url` to the one specified. - pub fn set_redirect_uri(mut self, redirect_url: Cow<'a, RedirectUrl>) -> Self { - self.redirect_url = Some(redirect_url); - self - } - - /// Returns the full authorization URL and CSRF state for this authorization - /// request. - pub fn url(self) -> (Url, CsrfToken) { - let scopes = self - .scopes - .iter() - .map(|s| s.to_string()) - .collect::>() - .join(" "); - - let url = { - let mut pairs: Vec<(&str, &str)> = vec![ - ("response_type", self.response_type.as_ref()), - ("client_id", self.client_id), - ("state", self.state.secret()), - ]; - - if let Some(ref pkce_challenge) = self.pkce_challenge { - pairs.push(("code_challenge", pkce_challenge.as_str())); - pairs.push(("code_challenge_method", pkce_challenge.method().as_str())); - } - - if let Some(ref redirect_url) = self.redirect_url { - pairs.push(("redirect_uri", redirect_url.as_str())); - } - - if !scopes.is_empty() { - pairs.push(("scope", &scopes)); - } - - let mut url: Url = self.auth_url.url().to_owned(); - - url.query_pairs_mut() - .extend_pairs(pairs.iter().map(|&(k, v)| (k, v))); - - url.query_pairs_mut() - .extend_pairs(self.extra_params.iter().cloned()); - url - }; - - (url, self.state) - } -} - -/// An HTTP request. -#[derive(Clone, Debug)] -pub struct HttpRequest { - // These are all owned values so that the request can safely be passed between - // threads. - /// URL to which the HTTP request is being made. - pub url: Url, - /// HTTP request method for this request. - pub method: http::method::Method, - /// HTTP request headers to send. - pub headers: HeaderMap, - /// HTTP request body (typically for POST requests only). - pub body: Vec, -} - -/// An HTTP response. -#[derive(Clone, Debug)] -pub struct HttpResponse { - /// HTTP status code returned by the server. - pub status_code: http::status::StatusCode, - /// HTTP response headers returned by the server. - pub headers: HeaderMap, - /// HTTP response body returned by the server. - pub body: Vec, -} - -/// A request to exchange an authorization code for an access token. -/// -/// See . -#[derive(Debug)] -pub struct CodeTokenRequest<'a, TE, TR, TT> -where - TE: ErrorResponse, - TR: TokenResponse, - TT: TokenType, -{ - auth_type: &'a AuthType, - client_id: &'a ClientId, - client_secret: Option<&'a ClientSecret>, - code: AuthorizationCode, - extra_params: Vec<(Cow<'a, str>, Cow<'a, str>)>, - pkce_verifier: Option, - token_url: &'a TokenUrl, - redirect_url: Option>, - _phantom: PhantomData<(TE, TR, TT)>, -} -impl<'a, TE, TR, TT> CodeTokenRequest<'a, TE, TR, TT> -where - TE: ErrorResponse + 'static, - TR: TokenResponse, - TT: TokenType, -{ - /// Appends an extra param to the token request. - /// - /// This method allows extensions to be used without direct support from - /// this crate. If `name` conflicts with a parameter managed by this crate, the - /// behavior is undefined. In particular, do not set parameters defined by - /// [RFC 6749](https://tools.ietf.org/html/rfc6749) or - /// [RFC 7636](https://tools.ietf.org/html/rfc7636). - /// - /// # Security Warning - /// - /// Callers should follow the security recommendations for any OAuth2 extensions used with - /// this function, which are beyond the scope of - /// [RFC 6749](https://tools.ietf.org/html/rfc6749). - pub fn add_extra_param(mut self, name: N, value: V) -> Self - where - N: Into>, - V: Into>, - { - self.extra_params.push((name.into(), value.into())); - self - } - - /// Completes the [Proof Key for Code Exchange](https://tools.ietf.org/html/rfc7636) - /// (PKCE) protocol flow. - /// - /// This method must be called if [`AuthorizationRequest::set_pkce_challenge`] was used during - /// the authorization request. - pub fn set_pkce_verifier(mut self, pkce_verifier: PkceCodeVerifier) -> Self { - self.pkce_verifier = Some(pkce_verifier); - self - } - - /// Overrides the `redirect_url` to the one specified. - pub fn set_redirect_uri(mut self, redirect_url: Cow<'a, RedirectUrl>) -> Self { - self.redirect_url = Some(redirect_url); - self - } - - fn prepare_request(self) -> HttpRequest { - let mut params = vec![ - ("grant_type", "authorization_code"), - ("code", self.code.secret()), - ]; - if let Some(ref pkce_verifier) = self.pkce_verifier { - params.push(("code_verifier", pkce_verifier.secret())); - } - - endpoint_request( - self.auth_type, - self.client_id, - self.client_secret, - &self.extra_params, - self.redirect_url, - None, - self.token_url.url(), - params, - ) - } - - /// Synchronously sends the request to the authorization server and awaits a response. - pub fn request(self, http_client: F) -> Result> - where - F: FnOnce(HttpRequest) -> Result, - RE: Error + 'static, - { - endpoint_response(http_client(self.prepare_request())?) - } - - /// Asynchronously sends the request to the authorization server and returns a Future. - pub async fn request_async( - self, - http_client: C, - ) -> Result> - where - C: FnOnce(HttpRequest) -> F, - F: Future>, - RE: Error + 'static, - { - let http_response = http_client(self.prepare_request()).await?; - endpoint_response(http_response) - } -} - -/// A request to exchange a refresh token for an access token. -/// -/// See . -#[derive(Debug)] -pub struct RefreshTokenRequest<'a, TE, TR, TT> -where - TE: ErrorResponse, - TR: TokenResponse, - TT: TokenType, -{ - auth_type: &'a AuthType, - client_id: &'a ClientId, - client_secret: Option<&'a ClientSecret>, - extra_params: Vec<(Cow<'a, str>, Cow<'a, str>)>, - refresh_token: &'a RefreshToken, - scopes: Vec>, - token_url: &'a TokenUrl, - _phantom: PhantomData<(TE, TR, TT)>, -} -impl<'a, TE, TR, TT> RefreshTokenRequest<'a, TE, TR, TT> -where - TE: ErrorResponse + 'static, - TR: TokenResponse, - TT: TokenType, -{ - /// Appends an extra param to the token request. - /// - /// This method allows extensions to be used without direct support from - /// this crate. If `name` conflicts with a parameter managed by this crate, the - /// behavior is undefined. In particular, do not set parameters defined by - /// [RFC 6749](https://tools.ietf.org/html/rfc6749) or - /// [RFC 7636](https://tools.ietf.org/html/rfc7636). - /// - /// # Security Warning - /// - /// Callers should follow the security recommendations for any OAuth2 extensions used with - /// this function, which are beyond the scope of - /// [RFC 6749](https://tools.ietf.org/html/rfc6749). - pub fn add_extra_param(mut self, name: N, value: V) -> Self - where - N: Into>, - V: Into>, - { - self.extra_params.push((name.into(), value.into())); - self - } - - /// Appends a new scope to the token request. - pub fn add_scope(mut self, scope: Scope) -> Self { - self.scopes.push(Cow::Owned(scope)); - self - } - - /// Appends a collection of scopes to the token request. - pub fn add_scopes(mut self, scopes: I) -> Self - where - I: IntoIterator, - { - self.scopes.extend(scopes.into_iter().map(Cow::Owned)); - self - } - - /// Synchronously sends the request to the authorization server and awaits a response. - pub fn request(self, http_client: F) -> Result> - where - F: FnOnce(HttpRequest) -> Result, - RE: Error + 'static, - { - endpoint_response(http_client(self.prepare_request()?)?) - } - /// Asynchronously sends the request to the authorization server and awaits a response. - pub async fn request_async( - self, - http_client: C, - ) -> Result> - where - C: FnOnce(HttpRequest) -> F, - F: Future>, - RE: Error + 'static, - { - let http_request = self.prepare_request()?; - let http_response = http_client(http_request).await?; - endpoint_response(http_response) - } - - fn prepare_request(&self) -> Result> - where - RE: Error + 'static, - { - Ok(endpoint_request( - self.auth_type, - self.client_id, - self.client_secret, - &self.extra_params, - None, - Some(&self.scopes), - self.token_url.url(), - vec![ - ("grant_type", "refresh_token"), - ("refresh_token", self.refresh_token.secret()), - ], - )) - } -} - -/// A request to exchange resource owner credentials for an access token. -/// -/// See . -#[derive(Debug)] -pub struct PasswordTokenRequest<'a, TE, TR, TT> -where - TE: ErrorResponse, - TR: TokenResponse, - TT: TokenType, -{ - auth_type: &'a AuthType, - client_id: &'a ClientId, - client_secret: Option<&'a ClientSecret>, - extra_params: Vec<(Cow<'a, str>, Cow<'a, str>)>, - username: &'a ResourceOwnerUsername, - password: &'a ResourceOwnerPassword, - scopes: Vec>, - token_url: &'a TokenUrl, - _phantom: PhantomData<(TE, TR, TT)>, -} -impl<'a, TE, TR, TT> PasswordTokenRequest<'a, TE, TR, TT> -where - TE: ErrorResponse + 'static, - TR: TokenResponse, - TT: TokenType, -{ - /// Appends an extra param to the token request. - /// - /// This method allows extensions to be used without direct support from - /// this crate. If `name` conflicts with a parameter managed by this crate, the - /// behavior is undefined. In particular, do not set parameters defined by - /// [RFC 6749](https://tools.ietf.org/html/rfc6749) or - /// [RFC 7636](https://tools.ietf.org/html/rfc7636). - /// - /// # Security Warning - /// - /// Callers should follow the security recommendations for any OAuth2 extensions used with - /// this function, which are beyond the scope of - /// [RFC 6749](https://tools.ietf.org/html/rfc6749). - pub fn add_extra_param(mut self, name: N, value: V) -> Self - where - N: Into>, - V: Into>, - { - self.extra_params.push((name.into(), value.into())); - self - } - - /// Appends a new scope to the token request. - pub fn add_scope(mut self, scope: Scope) -> Self { - self.scopes.push(Cow::Owned(scope)); - self - } - - /// Appends a collection of scopes to the token request. - pub fn add_scopes(mut self, scopes: I) -> Self - where - I: IntoIterator, - { - self.scopes.extend(scopes.into_iter().map(Cow::Owned)); - self - } - - /// Synchronously sends the request to the authorization server and awaits a response. - pub fn request(self, http_client: F) -> Result> - where - F: FnOnce(HttpRequest) -> Result, - RE: Error + 'static, - { - endpoint_response(http_client(self.prepare_request()?)?) - } - - /// Asynchronously sends the request to the authorization server and awaits a response. - pub async fn request_async( - self, - http_client: C, - ) -> Result> - where - C: FnOnce(HttpRequest) -> F, - F: Future>, - RE: Error + 'static, - { - let http_request = self.prepare_request()?; - let http_response = http_client(http_request).await?; - endpoint_response(http_response) - } - - fn prepare_request(&self) -> Result> - where - RE: Error + 'static, - { - Ok(endpoint_request( - self.auth_type, - self.client_id, - self.client_secret, - &self.extra_params, - None, - Some(&self.scopes), - self.token_url.url(), - vec![ - ("grant_type", "password"), - ("username", self.username), - ("password", self.password.secret()), - ], - )) - } -} - -/// A request to exchange client credentials for an access token. -/// -/// See . -#[derive(Debug)] -pub struct ClientCredentialsTokenRequest<'a, TE, TR, TT> -where - TE: ErrorResponse, - TR: TokenResponse, - TT: TokenType, -{ - auth_type: &'a AuthType, - client_id: &'a ClientId, - client_secret: Option<&'a ClientSecret>, - extra_params: Vec<(Cow<'a, str>, Cow<'a, str>)>, - scopes: Vec>, - token_url: &'a TokenUrl, - _phantom: PhantomData<(TE, TR, TT)>, -} -impl<'a, TE, TR, TT> ClientCredentialsTokenRequest<'a, TE, TR, TT> -where - TE: ErrorResponse + 'static, - TR: TokenResponse, - TT: TokenType, -{ - /// Appends an extra param to the token request. - /// - /// This method allows extensions to be used without direct support from - /// this crate. If `name` conflicts with a parameter managed by this crate, the - /// behavior is undefined. In particular, do not set parameters defined by - /// [RFC 6749](https://tools.ietf.org/html/rfc6749) or - /// [RFC 7636](https://tools.ietf.org/html/rfc7636). - /// - /// # Security Warning - /// - /// Callers should follow the security recommendations for any OAuth2 extensions used with - /// this function, which are beyond the scope of - /// [RFC 6749](https://tools.ietf.org/html/rfc6749). - pub fn add_extra_param(mut self, name: N, value: V) -> Self - where - N: Into>, - V: Into>, - { - self.extra_params.push((name.into(), value.into())); - self - } - - /// Appends a new scope to the token request. - pub fn add_scope(mut self, scope: Scope) -> Self { - self.scopes.push(Cow::Owned(scope)); - self - } - - /// Appends a collection of scopes to the token request. - pub fn add_scopes(mut self, scopes: I) -> Self - where - I: IntoIterator, - { - self.scopes.extend(scopes.into_iter().map(Cow::Owned)); - self - } - - /// Synchronously sends the request to the authorization server and awaits a response. - pub fn request(self, http_client: F) -> Result> - where - F: FnOnce(HttpRequest) -> Result, - RE: Error + 'static, - { - endpoint_response(http_client(self.prepare_request()?)?) - } - - /// Asynchronously sends the request to the authorization server and awaits a response. - pub async fn request_async( - self, - http_client: C, - ) -> Result> - where - C: FnOnce(HttpRequest) -> F, - F: Future>, - RE: Error + 'static, - { - let http_request = self.prepare_request()?; - let http_response = http_client(http_request).await?; - endpoint_response(http_response) - } - - fn prepare_request(&self) -> Result> - where - RE: Error + 'static, - { - Ok(endpoint_request( - self.auth_type, - self.client_id, - self.client_secret, - &self.extra_params, - None, - Some(&self.scopes), - self.token_url.url(), - vec![("grant_type", "client_credentials")], - )) - } -} - -/// A request to introspect an access token. -/// -/// See . -#[derive(Debug)] -pub struct IntrospectionRequest<'a, TE, TIR, TT> -where - TE: ErrorResponse, - TIR: TokenIntrospectionResponse, - TT: TokenType, -{ - token: &'a AccessToken, - token_type_hint: Option>, - - auth_type: &'a AuthType, - client_id: &'a ClientId, - client_secret: Option<&'a ClientSecret>, - extra_params: Vec<(Cow<'a, str>, Cow<'a, str>)>, - introspection_url: &'a IntrospectionUrl, - - _phantom: PhantomData<(TE, TIR, TT)>, -} - -impl<'a, TE, TIR, TT> IntrospectionRequest<'a, TE, TIR, TT> -where - TE: ErrorResponse + 'static, - TIR: TokenIntrospectionResponse, - TT: TokenType, -{ - /// Sets the optional token_type_hint parameter. - /// - /// See . - /// - /// OPTIONAL. A hint about the type of the token submitted for - /// introspection. The protected resource MAY pass this parameter to - /// help the authorization server optimize the token lookup. If the - /// server is unable to locate the token using the given hint, it MUST - /// extend its search across all of its supported token types. An - /// authorization server MAY ignore this parameter, particularly if it - /// is able to detect the token type automatically. Values for this - /// field are defined in the "OAuth Token Type Hints" registry defined - /// in OAuth Token Revocation [RFC7009](https://tools.ietf.org/html/rfc7009). - pub fn set_token_type_hint(mut self, value: V) -> Self - where - V: Into>, - { - self.token_type_hint = Some(value.into()); - - self - } - - /// Appends an extra param to the token introspection request. - /// - /// This method allows extensions to be used without direct support from - /// this crate. If `name` conflicts with a parameter managed by this crate, the - /// behavior is undefined. In particular, do not set parameters defined by - /// [RFC 6749](https://tools.ietf.org/html/rfc6749) or - /// [RFC 7662](https://tools.ietf.org/html/rfc7662). - /// - /// # Security Warning - /// - /// Callers should follow the security recommendations for any OAuth2 extensions used with - /// this function, which are beyond the scope of - /// [RFC 6749](https://tools.ietf.org/html/rfc6749). - pub fn add_extra_param(mut self, name: N, value: V) -> Self - where - N: Into>, - V: Into>, - { - self.extra_params.push((name.into(), value.into())); - self - } - - fn prepare_request(self) -> HttpRequest { - let mut params: Vec<(&str, &str)> = vec![("token", self.token.secret())]; - if let Some(ref token_type_hint) = self.token_type_hint { - params.push(("token_type_hint", token_type_hint)); - } - - endpoint_request( - self.auth_type, - self.client_id, - self.client_secret, - &self.extra_params, - None, - None, - self.introspection_url.url(), - params, - ) - } - - /// Synchronously sends the request to the authorization server and awaits a response. - pub fn request(self, http_client: F) -> Result> - where - F: FnOnce(HttpRequest) -> Result, - RE: Error + 'static, - { - endpoint_response(http_client(self.prepare_request())?) - } - - /// Asynchronously sends the request to the authorization server and returns a Future. - pub async fn request_async( - self, - http_client: C, - ) -> Result> - where - C: FnOnce(HttpRequest) -> F, - F: Future>, - RE: Error + 'static, - { - let http_response = http_client(self.prepare_request()).await?; - endpoint_response(http_response) - } -} - -/// A request to revoke a token via an [`RFC 7009`](https://tools.ietf.org/html/rfc7009#section-2.1) compatible -/// endpoint. -#[derive(Debug)] -pub struct RevocationRequest<'a, RT, TE> -where - RT: RevocableToken, - TE: ErrorResponse, -{ - token: RT, - - auth_type: &'a AuthType, - client_id: &'a ClientId, - client_secret: Option<&'a ClientSecret>, - extra_params: Vec<(Cow<'a, str>, Cow<'a, str>)>, - revocation_url: &'a RevocationUrl, - - _phantom: PhantomData<(RT, TE)>, -} - -impl<'a, RT, TE> RevocationRequest<'a, RT, TE> -where - RT: RevocableToken, - TE: ErrorResponse + 'static, -{ - /// Appends an extra param to the token revocation request. - /// - /// This method allows extensions to be used without direct support from - /// this crate. If `name` conflicts with a parameter managed by this crate, the - /// behavior is undefined. In particular, do not set parameters defined by - /// [RFC 6749](https://tools.ietf.org/html/rfc6749) or - /// [RFC 7662](https://tools.ietf.org/html/rfc7662). - /// - /// # Security Warning - /// - /// Callers should follow the security recommendations for any OAuth2 extensions used with - /// this function, which are beyond the scope of - /// [RFC 6749](https://tools.ietf.org/html/rfc6749). - pub fn add_extra_param(mut self, name: N, value: V) -> Self - where - N: Into>, - V: Into>, - { - self.extra_params.push((name.into(), value.into())); - self - } - - fn prepare_request(self) -> HttpRequest { - let mut params: Vec<(&str, &str)> = vec![("token", self.token.secret())]; - if let Some(type_hint) = self.token.type_hint() { - params.push(("token_type_hint", type_hint)); - } - - endpoint_request( - self.auth_type, - self.client_id, - self.client_secret, - &self.extra_params, - None, - None, - self.revocation_url.url(), - params, - ) - } - - /// Synchronously sends the request to the authorization server and awaits a response. - /// - /// A successful response indicates that the server either revoked the token or the token was not known to the - /// server. - /// - /// Error [`UnsupportedTokenType`](crate::revocation::RevocationErrorResponseType::UnsupportedTokenType) will be returned if the - /// type of token type given is not supported by the server. - pub fn request(self, http_client: F) -> Result<(), RequestTokenError> - where - F: FnOnce(HttpRequest) -> Result, - RE: Error + 'static, - { - // From https://tools.ietf.org/html/rfc7009#section-2.2: - // "The content of the response body is ignored by the client as all - // necessary information is conveyed in the response code." - endpoint_response_status_only(http_client(self.prepare_request())?) - } - - /// Asynchronously sends the request to the authorization server and returns a Future. - pub async fn request_async( - self, - http_client: C, - ) -> Result<(), RequestTokenError> - where - C: FnOnce(HttpRequest) -> F, - F: Future>, - RE: Error + 'static, - { - let http_response = http_client(self.prepare_request()).await?; - endpoint_response_status_only(http_response) - } -} - -#[allow(clippy::too_many_arguments)] -fn endpoint_request<'a>( - auth_type: &'a AuthType, - client_id: &'a ClientId, - client_secret: Option<&'a ClientSecret>, - extra_params: &'a [(Cow<'a, str>, Cow<'a, str>)], - redirect_url: Option>, - scopes: Option<&'a Vec>>, - url: &'a Url, - params: Vec<(&'a str, &'a str)>, -) -> HttpRequest { - let mut headers = HeaderMap::new(); - headers.append(ACCEPT, HeaderValue::from_static(CONTENT_TYPE_JSON)); - headers.append( - CONTENT_TYPE, - HeaderValue::from_static(CONTENT_TYPE_FORMENCODED), - ); - - let scopes_opt = scopes.and_then(|scopes| { - if !scopes.is_empty() { - Some( - scopes - .iter() - .map(|s| s.to_string()) - .collect::>() - .join(" "), - ) - } else { - None - } - }); - - let mut params: Vec<(&str, &str)> = params; - if let Some(ref scopes) = scopes_opt { - params.push(("scope", scopes)); - } - - // FIXME: add support for auth extensions? e.g., client_secret_jwt and private_key_jwt - match (auth_type, client_secret) { - // Basic auth only makes sense when a client secret is provided. Otherwise, always pass the - // client ID in the request body. - (AuthType::BasicAuth, Some(secret)) => { - // Section 2.3.1 of RFC 6749 requires separately url-encoding the id and secret - // before using them as HTTP Basic auth username and password. Note that this is - // not standard for ordinary Basic auth, so curl won't do it for us. - let urlencoded_id: String = - form_urlencoded::byte_serialize(client_id.as_bytes()).collect(); - let urlencoded_secret: String = - form_urlencoded::byte_serialize(secret.secret().as_bytes()).collect(); - let b64_credential = - base64::encode(format!("{}:{}", &urlencoded_id, urlencoded_secret)); - headers.append( - AUTHORIZATION, - HeaderValue::from_str(&format!("Basic {}", &b64_credential)).unwrap(), - ); - } - (AuthType::RequestBody, _) | (AuthType::BasicAuth, None) => { - params.push(("client_id", client_id)); - if let Some(client_secret) = client_secret { - params.push(("client_secret", client_secret.secret())); - } - } - } - - if let Some(ref redirect_url) = redirect_url { - params.push(("redirect_uri", redirect_url.as_str())); - } - - params.extend_from_slice( - extra_params - .iter() - .map(|(k, v)| (k.as_ref(), v.as_ref())) - .collect::>() - .as_slice(), - ); - - let body = url::form_urlencoded::Serializer::new(String::new()) - .extend_pairs(params) - .finish() - .into_bytes(); - - HttpRequest { - url: url.to_owned(), - method: http::method::Method::POST, - headers, - body, - } -} - -fn endpoint_response( - http_response: HttpResponse, -) -> Result> -where - RE: Error + 'static, - TE: ErrorResponse, - DO: DeserializeOwned, -{ - check_response_status(&http_response)?; - - check_response_body(&http_response)?; - - let response_body = http_response.body.as_slice(); - serde_path_to_error::deserialize(&mut serde_json::Deserializer::from_slice(response_body)) - .map_err(|e| RequestTokenError::Parse(e, response_body.to_vec())) -} - -fn endpoint_response_status_only( - http_response: HttpResponse, -) -> Result<(), RequestTokenError> -where - RE: Error + 'static, - TE: ErrorResponse, -{ - check_response_status(&http_response) -} - -fn check_response_status( - http_response: &HttpResponse, -) -> Result<(), RequestTokenError> -where - RE: Error + 'static, - TE: ErrorResponse, -{ - if http_response.status_code != StatusCode::OK { - let reason = http_response.body.as_slice(); - if reason.is_empty() { - return Err(RequestTokenError::Other( - "Server returned empty error response".to_string(), - )); - } else { - let error = match serde_path_to_error::deserialize::<_, TE>( - &mut serde_json::Deserializer::from_slice(reason), - ) { - Ok(error) => RequestTokenError::ServerResponse(error), - Err(error) => RequestTokenError::Parse(error, reason.to_vec()), - }; - return Err(error); - } - } - - Ok(()) -} - -fn check_response_body( - http_response: &HttpResponse, -) -> Result<(), RequestTokenError> -where - RE: Error + 'static, - TE: ErrorResponse, -{ - // Validate that the response Content-Type is JSON. - http_response - .headers - .get(CONTENT_TYPE) - .map_or(Ok(()), |content_type| - // Section 3.1.1.1 of RFC 7231 indicates that media types are case insensitive and - // may be followed by optional whitespace and/or a parameter (e.g., charset). - // See https://tools.ietf.org/html/rfc7231#section-3.1.1.1. - if content_type.to_str().ok().filter(|ct| ct.to_lowercase().starts_with(CONTENT_TYPE_JSON)).is_none() { - Err( - RequestTokenError::Other( - format!( - "Unexpected response Content-Type: {:?}, should be `{}`", - content_type, - CONTENT_TYPE_JSON - ) - ) - ) - } else { - Ok(()) - } - )?; - - if http_response.body.is_empty() { - return Err(RequestTokenError::Other( - "Server returned empty response body".to_string(), - )); - } - - Ok(()) -} - -/// The request for a set of verification codes from the authorization server. -/// -/// See . -#[derive(Debug)] -pub struct DeviceAuthorizationRequest<'a, TE> -where - TE: ErrorResponse, -{ - auth_type: &'a AuthType, - client_id: &'a ClientId, - client_secret: Option<&'a ClientSecret>, - extra_params: Vec<(Cow<'a, str>, Cow<'a, str>)>, - scopes: Vec>, - device_authorization_url: &'a DeviceAuthorizationUrl, - _phantom: PhantomData, -} - -impl<'a, TE> DeviceAuthorizationRequest<'a, TE> -where - TE: ErrorResponse + 'static, -{ - /// Appends an extra param to the token request. - /// - /// This method allows extensions to be used without direct support from - /// this crate. If `name` conflicts with a parameter managed by this crate, the - /// behavior is undefined. In particular, do not set parameters defined by - /// [RFC 6749](https://tools.ietf.org/html/rfc6749) or - /// [RFC 7636](https://tools.ietf.org/html/rfc7636). - /// - /// # Security Warning - /// - /// Callers should follow the security recommendations for any OAuth2 extensions used with - /// this function, which are beyond the scope of - /// [RFC 6749](https://tools.ietf.org/html/rfc6749). - pub fn add_extra_param(mut self, name: N, value: V) -> Self - where - N: Into>, - V: Into>, - { - self.extra_params.push((name.into(), value.into())); - self - } - - /// Appends a new scope to the token request. - pub fn add_scope(mut self, scope: Scope) -> Self { - self.scopes.push(Cow::Owned(scope)); - self - } - - /// Appends a collection of scopes to the token request. - pub fn add_scopes(mut self, scopes: I) -> Self - where - I: IntoIterator, - { - self.scopes.extend(scopes.into_iter().map(Cow::Owned)); - self - } - - fn prepare_request(self) -> HttpRequest { - endpoint_request( - self.auth_type, - self.client_id, - self.client_secret, - &self.extra_params, - None, - Some(&self.scopes), - self.device_authorization_url.url(), - vec![], - ) - } - - /// Synchronously sends the request to the authorization server and awaits a response. - pub fn request( - self, - http_client: F, - ) -> Result, RequestTokenError> - where - F: FnOnce(HttpRequest) -> Result, - RE: Error + 'static, - EF: ExtraDeviceAuthorizationFields, - { - endpoint_response(http_client(self.prepare_request())?) - } - - /// Asynchronously sends the request to the authorization server and returns a Future. - pub async fn request_async( - self, - http_client: C, - ) -> Result, RequestTokenError> - where - C: FnOnce(HttpRequest) -> F, - F: Future>, - RE: Error + 'static, - EF: ExtraDeviceAuthorizationFields, - { - let http_response = http_client(self.prepare_request()).await?; - endpoint_response(http_response) - } -} - -/// The request for an device access token from the authorization server. -/// -/// See . -#[derive(Clone)] -pub struct DeviceAccessTokenRequest<'a, 'b, TR, TT, EF> -where - TR: TokenResponse, - TT: TokenType, - EF: ExtraDeviceAuthorizationFields, -{ - auth_type: &'a AuthType, - client_id: &'a ClientId, - client_secret: Option<&'a ClientSecret>, - extra_params: Vec<(Cow<'a, str>, Cow<'a, str>)>, - token_url: &'a TokenUrl, - dev_auth_resp: &'a DeviceAuthorizationResponse, - time_fn: Arc DateTime + 'b + Send + Sync>, - max_backoff_interval: Option, - _phantom: PhantomData<(TR, TT, EF)>, -} - -impl<'a, 'b, TR, TT, EF> DeviceAccessTokenRequest<'a, 'b, TR, TT, EF> -where - TR: TokenResponse, - TT: TokenType, - EF: ExtraDeviceAuthorizationFields, -{ - /// Appends an extra param to the token request. - /// - /// This method allows extensions to be used without direct support from - /// this crate. If `name` conflicts with a parameter managed by this crate, the - /// behavior is undefined. In particular, do not set parameters defined by - /// [RFC 6749](https://tools.ietf.org/html/rfc6749) or - /// [RFC 7636](https://tools.ietf.org/html/rfc7636). - /// - /// # Security Warning - /// - /// Callers should follow the security recommendations for any OAuth2 extensions used with - /// this function, which are beyond the scope of - /// [RFC 6749](https://tools.ietf.org/html/rfc6749). - pub fn add_extra_param(mut self, name: N, value: V) -> Self - where - N: Into>, - V: Into>, - { - self.extra_params.push((name.into(), value.into())); - self - } - - /// Specifies a function for returning the current time. - /// - /// This function is used while polling the authorization server. - pub fn set_time_fn(mut self, time_fn: T) -> Self - where - T: Fn() -> DateTime + 'b + Send + Sync, - { - self.time_fn = Arc::new(time_fn); - self - } - - /// Sets the upper limit of the sleep interval to use for polling the token endpoint when the - /// HTTP client returns an error (e.g., in case of connection timeout). - pub fn set_max_backoff_interval(mut self, interval: Duration) -> Self { - self.max_backoff_interval = Some(interval); - self - } - - /// Synchronously polls the authorization server for a response, waiting - /// using a user defined sleep function. - pub fn request( - self, - http_client: F, - sleep_fn: S, - timeout: Option, - ) -> Result> - where - F: Fn(HttpRequest) -> Result, - S: Fn(Duration), - RE: Error + 'static, - { - // Get the request timeout and starting interval - let timeout_dt = self.compute_timeout(timeout)?; - let mut interval = self.dev_auth_resp.interval(); - - // Loop while requesting a token. - loop { - let now = (*self.time_fn)(); - if now > timeout_dt { - break Err(RequestTokenError::ServerResponse( - DeviceCodeErrorResponse::new( - DeviceCodeErrorResponseType::ExpiredToken, - Some(String::from("This device code has expired.")), - None, - ), - )); - } - - match self.process_response(http_client(self.prepare_request()), interval) { - DeviceAccessTokenPollResult::ContinueWithNewPollInterval(new_interval) => { - interval = new_interval - } - DeviceAccessTokenPollResult::Done(res, _) => break res, - } - - // Sleep here using the provided sleep function. - sleep_fn(interval); - } - } - - /// Asynchronously sends the request to the authorization server and awaits a response. - pub async fn request_async( - self, - http_client: C, - sleep_fn: S, - timeout: Option, - ) -> Result> - where - C: Fn(HttpRequest) -> F, - F: Future>, - S: Fn(Duration) -> SF, - SF: Future, - RE: Error + 'static, - { - // Get the request timeout and starting interval - let timeout_dt = self.compute_timeout(timeout)?; - let mut interval = self.dev_auth_resp.interval(); - - // Loop while requesting a token. - loop { - let now = (*self.time_fn)(); - if now > timeout_dt { - break Err(RequestTokenError::ServerResponse( - DeviceCodeErrorResponse::new( - DeviceCodeErrorResponseType::ExpiredToken, - Some(String::from("This device code has expired.")), - None, - ), - )); - } - - match self.process_response(http_client(self.prepare_request()).await, interval) { - DeviceAccessTokenPollResult::ContinueWithNewPollInterval(new_interval) => { - interval = new_interval - } - DeviceAccessTokenPollResult::Done(res, _) => break res, - } - - // Sleep here using the provided sleep function. - sleep_fn(interval).await; - } - } - - fn prepare_request(&self) -> HttpRequest { - endpoint_request( - self.auth_type, - self.client_id, - self.client_secret, - &self.extra_params, - None, - None, - self.token_url.url(), - vec![ - ("grant_type", "urn:ietf:params:oauth:grant-type:device_code"), - ("device_code", self.dev_auth_resp.device_code().secret()), - ], - ) - } - - fn process_response( - &self, - res: Result, - current_interval: Duration, - ) -> DeviceAccessTokenPollResult - where - RE: Error + 'static, - { - let http_response = match res { - Ok(inner) => inner, - Err(_) => { - // RFC 8628 requires a backoff in cases of connection timeout, but we can't - // distinguish between connection timeouts and other HTTP client request errors - // here. Set a maximum backoff so that the client doesn't effectively backoff - // infinitely when there are network issues unrelated to server load. - const DEFAULT_MAX_BACKOFF_INTERVAL: Duration = Duration::from_secs(10); - let new_interval = std::cmp::min( - current_interval.checked_mul(2).unwrap_or(current_interval), - self.max_backoff_interval - .unwrap_or(DEFAULT_MAX_BACKOFF_INTERVAL), - ); - return DeviceAccessTokenPollResult::ContinueWithNewPollInterval(new_interval); - } - }; - - // Explicitly process the response with a DeviceCodeErrorResponse - let res = endpoint_response::(http_response); - match res { - // On a ServerResponse error, the error needs inspecting as a DeviceCodeErrorResponse - // to work out whether a retry needs to happen. - Err(RequestTokenError::ServerResponse(dcer)) => { - match dcer.error() { - // On AuthorizationPending, a retry needs to happen with the same poll interval. - DeviceCodeErrorResponseType::AuthorizationPending => { - DeviceAccessTokenPollResult::ContinueWithNewPollInterval(current_interval) - } - // On SlowDown, a retry needs to happen with a larger poll interval. - DeviceCodeErrorResponseType::SlowDown => { - DeviceAccessTokenPollResult::ContinueWithNewPollInterval( - current_interval + Duration::from_secs(5), - ) - } - - // On any other error, just return the error. - _ => DeviceAccessTokenPollResult::Done( - Err(RequestTokenError::ServerResponse(dcer)), - PhantomData, - ), - } - } - - // On any other success or failure, return the failure. - res => DeviceAccessTokenPollResult::Done(res, PhantomData), - } - } - - fn compute_timeout( - &self, - timeout: Option, - ) -> Result, RequestTokenError> - where - RE: Error + 'static, - { - // Calculate the request timeout - if the user specified a timeout, - // use that, otherwise use the value given by the device authorization - // response. - let timeout_dur = timeout.unwrap_or_else(|| self.dev_auth_resp.expires_in()); - let chrono_timeout = chrono::Duration::from_std(timeout_dur).map_err(|e| { - RequestTokenError::Other(format!( - "Failed to convert `{:?}` to `chrono::Duration`: {}", - timeout_dur, e - )) - })?; - - // Calculate the DateTime at which the request times out. - let timeout_dt = (*self.time_fn)() - .checked_add_signed(chrono_timeout) - .ok_or_else(|| RequestTokenError::Other("Failed to calculate timeout".to_string()))?; - - Ok(timeout_dt) - } -} - -/// Trait for OAuth2 access tokens. -pub trait TokenType: Clone + DeserializeOwned + Debug + PartialEq + Serialize {} - -/// Trait for adding extra fields to the `TokenResponse`. -pub trait ExtraTokenFields: DeserializeOwned + Debug + Serialize {} - -/// Empty (default) extra token fields. -#[derive(Clone, Debug, Deserialize, PartialEq, Eq, Serialize)] -pub struct EmptyExtraTokenFields {} -impl ExtraTokenFields for EmptyExtraTokenFields {} - -/// Common methods shared by all OAuth2 token implementations. -/// -/// The methods in this trait are defined in -/// [Section 5.1 of RFC 6749](https://tools.ietf.org/html/rfc6749#section-5.1). This trait exists -/// separately from the `StandardTokenResponse` struct to support customization by clients, -/// such as supporting interoperability with non-standards-complaint OAuth2 providers. -pub trait TokenResponse: Debug + DeserializeOwned + Serialize -where - TT: TokenType, -{ - /// REQUIRED. The access token issued by the authorization server. - fn access_token(&self) -> &AccessToken; - /// REQUIRED. The type of the token issued as described in - /// [Section 7.1](https://tools.ietf.org/html/rfc6749#section-7.1). - /// Value is case insensitive and deserialized to the generic `TokenType` parameter. - fn token_type(&self) -> &TT; - /// RECOMMENDED. The lifetime in seconds of the access token. For example, the value 3600 - /// denotes that the access token will expire in one hour from the time the response was - /// generated. If omitted, the authorization server SHOULD provide the expiration time via - /// other means or document the default value. - fn expires_in(&self) -> Option; - /// OPTIONAL. The refresh token, which can be used to obtain new access tokens using the same - /// authorization grant as described in - /// [Section 6](https://tools.ietf.org/html/rfc6749#section-6). - fn refresh_token(&self) -> Option<&RefreshToken>; - /// OPTIONAL, if identical to the scope requested by the client; otherwise, REQUIRED. The - /// scope of the access token as described by - /// [Section 3.3](https://tools.ietf.org/html/rfc6749#section-3.3). If included in the response, - /// this space-delimited field is parsed into a `Vec` of individual scopes. If omitted from - /// the response, this field is `None`. - fn scopes(&self) -> Option<&Vec>; -} - -/// Standard OAuth2 token response. -/// -/// This struct includes the fields defined in -/// [Section 5.1 of RFC 6749](https://tools.ietf.org/html/rfc6749#section-5.1), as well as -/// extensions defined by the `EF` type parameter. -#[derive(Clone, Debug, Deserialize, Serialize)] -pub struct StandardTokenResponse -where - EF: ExtraTokenFields, - TT: TokenType, -{ - access_token: AccessToken, - #[serde(bound = "TT: TokenType")] - #[serde(deserialize_with = "helpers::deserialize_untagged_enum_case_insensitive")] - token_type: TT, - #[serde(skip_serializing_if = "Option::is_none")] - expires_in: Option, - #[serde(skip_serializing_if = "Option::is_none")] - refresh_token: Option, - #[serde(rename = "scope")] - #[serde(deserialize_with = "helpers::deserialize_space_delimited_vec")] - #[serde(serialize_with = "helpers::serialize_space_delimited_vec")] - #[serde(skip_serializing_if = "Option::is_none")] - #[serde(default)] - scopes: Option>, - - #[serde(bound = "EF: ExtraTokenFields")] - #[serde(flatten)] - extra_fields: EF, -} -impl StandardTokenResponse -where - EF: ExtraTokenFields, - TT: TokenType, -{ - /// Instantiate a new OAuth2 token response. - pub fn new(access_token: AccessToken, token_type: TT, extra_fields: EF) -> Self { - Self { - access_token, - token_type, - expires_in: None, - refresh_token: None, - scopes: None, - extra_fields, - } - } - - /// Set the `access_token` field. - pub fn set_access_token(&mut self, access_token: AccessToken) { - self.access_token = access_token; - } - - /// Set the `token_type` field. - pub fn set_token_type(&mut self, token_type: TT) { - self.token_type = token_type; - } - - /// Set the `expires_in` field. - pub fn set_expires_in(&mut self, expires_in: Option<&Duration>) { - self.expires_in = expires_in.map(Duration::as_secs); - } - - /// Set the `refresh_token` field. - pub fn set_refresh_token(&mut self, refresh_token: Option) { - self.refresh_token = refresh_token; - } - - /// Set the `scopes` field. - pub fn set_scopes(&mut self, scopes: Option>) { - self.scopes = scopes; - } - - /// Extra fields defined by the client application. - pub fn extra_fields(&self) -> &EF { - &self.extra_fields - } - - /// Set the extra fields defined by the client application. - pub fn set_extra_fields(&mut self, extra_fields: EF) { - self.extra_fields = extra_fields; - } -} -impl TokenResponse for StandardTokenResponse -where - EF: ExtraTokenFields, - TT: TokenType, -{ - /// REQUIRED. The access token issued by the authorization server. - fn access_token(&self) -> &AccessToken { - &self.access_token - } - /// REQUIRED. The type of the token issued as described in - /// [Section 7.1](https://tools.ietf.org/html/rfc6749#section-7.1). - /// Value is case insensitive and deserialized to the generic `TokenType` parameter. - fn token_type(&self) -> &TT { - &self.token_type - } - /// RECOMMENDED. The lifetime in seconds of the access token. For example, the value 3600 - /// denotes that the access token will expire in one hour from the time the response was - /// generated. If omitted, the authorization server SHOULD provide the expiration time via - /// other means or document the default value. - fn expires_in(&self) -> Option { - self.expires_in.map(Duration::from_secs) - } - /// OPTIONAL. The refresh token, which can be used to obtain new access tokens using the same - /// authorization grant as described in - /// [Section 6](https://tools.ietf.org/html/rfc6749#section-6). - fn refresh_token(&self) -> Option<&RefreshToken> { - self.refresh_token.as_ref() - } - /// OPTIONAL, if identical to the scope requested by the client; otherwise, REQUIRED. The - /// scope of the access token as described by - /// [Section 3.3](https://tools.ietf.org/html/rfc6749#section-3.3). If included in the response, - /// this space-delimited field is parsed into a `Vec` of individual scopes. If omitted from - /// the response, this field is `None`. - fn scopes(&self) -> Option<&Vec> { - self.scopes.as_ref() - } -} - -/// Common methods shared by all OAuth2 token introspection implementations. -/// -/// The methods in this trait are defined in -/// [Section 2.2 of RFC 7662](https://tools.ietf.org/html/rfc7662#section-2.2). This trait exists -/// separately from the `StandardTokenIntrospectionResponse` struct to support customization by -/// clients, such as supporting interoperability with non-standards-complaint OAuth2 providers. -pub trait TokenIntrospectionResponse: Debug + DeserializeOwned + Serialize -where - TT: TokenType, -{ - /// REQUIRED. Boolean indicator of whether or not the presented token - /// is currently active. The specifics of a token's "active" state - /// will vary depending on the implementation of the authorization - /// server and the information it keeps about its tokens, but a "true" - /// value return for the "active" property will generally indicate - /// that a given token has been issued by this authorization server, - /// has not been revoked by the resource owner, and is within its - /// given time window of validity (e.g., after its issuance time and - /// before its expiration time). - fn active(&self) -> bool; - /// OPTIONAL. A JSON string containing a space-separated list of - /// scopes associated with this token, in the format described in - /// [Section 3.3 of RFC 7662](https://tools.ietf.org/html/rfc7662#section-3.3). - /// If included in the response, - /// this space-delimited field is parsed into a `Vec` of individual scopes. If omitted from - /// the response, this field is `None`. - fn scopes(&self) -> Option<&Vec>; - /// OPTIONAL. Client identifier for the OAuth 2.0 client that - /// requested this token. - fn client_id(&self) -> Option<&ClientId>; - /// OPTIONAL. Human-readable identifier for the resource owner who - /// authorized this token. - fn username(&self) -> Option<&str>; - /// OPTIONAL. Type of the token as defined in - /// [Section 5.1 of RFC 7662](https://tools.ietf.org/html/rfc7662#section-5.1). - /// Value is case insensitive and deserialized to the generic `TokenType` parameter. - fn token_type(&self) -> Option<&TT>; - /// OPTIONAL. Integer timestamp, measured in the number of seconds - /// since January 1 1970 UTC, indicating when this token will expire, - /// as defined in JWT [RFC7519](https://tools.ietf.org/html/rfc7519). - fn exp(&self) -> Option>; - /// OPTIONAL. Integer timestamp, measured in the number of seconds - /// since January 1 1970 UTC, indicating when this token was - /// originally issued, as defined in JWT [RFC7519](https://tools.ietf.org/html/rfc7519). - fn iat(&self) -> Option>; - /// OPTIONAL. Integer timestamp, measured in the number of seconds - /// since January 1 1970 UTC, indicating when this token is not to be - /// used before, as defined in JWT [RFC7519](https://tools.ietf.org/html/rfc7519). - fn nbf(&self) -> Option>; - /// OPTIONAL. Subject of the token, as defined in JWT [RFC7519](https://tools.ietf.org/html/rfc7519). - /// Usually a machine-readable identifier of the resource owner who - /// authorized this token. - fn sub(&self) -> Option<&str>; - /// OPTIONAL. Service-specific string identifier or list of string - /// identifiers representing the intended audience for this token, as - /// defined in JWT [RFC7519](https://tools.ietf.org/html/rfc7519). - fn aud(&self) -> Option<&Vec>; - /// OPTIONAL. String representing the issuer of this token, as - /// defined in JWT [RFC7519](https://tools.ietf.org/html/rfc7519). - fn iss(&self) -> Option<&str>; - /// OPTIONAL. String identifier for the token, as defined in JWT - /// [RFC7519](https://tools.ietf.org/html/rfc7519). - fn jti(&self) -> Option<&str>; -} - -/// Standard OAuth2 token introspection response. -/// -/// This struct includes the fields defined in -/// [Section 2.2 of RFC 7662](https://tools.ietf.org/html/rfc7662#section-2.2), as well as -/// extensions defined by the `EF` type parameter. -#[derive(Clone, Debug, Deserialize, Serialize)] -pub struct StandardTokenIntrospectionResponse -where - EF: ExtraTokenFields, - TT: TokenType + 'static, -{ - active: bool, - #[serde(rename = "scope")] - #[serde(deserialize_with = "helpers::deserialize_space_delimited_vec")] - #[serde(serialize_with = "helpers::serialize_space_delimited_vec")] - #[serde(skip_serializing_if = "Option::is_none")] - #[serde(default)] - scopes: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - client_id: Option, - #[serde(skip_serializing_if = "Option::is_none")] - username: Option, - #[serde( - bound = "TT: TokenType", - skip_serializing_if = "Option::is_none", - deserialize_with = "helpers::deserialize_untagged_enum_case_insensitive", - default = "none_field" - )] - token_type: Option, - #[serde(skip_serializing_if = "Option::is_none")] - #[serde(with = "ts_seconds_option")] - #[serde(default)] - exp: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - #[serde(with = "ts_seconds_option")] - #[serde(default)] - iat: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - #[serde(with = "ts_seconds_option")] - #[serde(default)] - nbf: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - sub: Option, - #[serde(skip_serializing_if = "Option::is_none")] - #[serde(default)] - #[serde(deserialize_with = "helpers::deserialize_optional_string_or_vec_string")] - aud: Option>, - #[serde(skip_serializing_if = "Option::is_none")] - iss: Option, - #[serde(skip_serializing_if = "Option::is_none")] - jti: Option, - - #[serde(bound = "EF: ExtraTokenFields")] - #[serde(flatten)] - extra_fields: EF, -} - -fn none_field() -> Option { - None -} - -impl StandardTokenIntrospectionResponse -where - EF: ExtraTokenFields, - TT: TokenType, -{ - /// Instantiate a new OAuth2 token introspection response. - pub fn new(active: bool, extra_fields: EF) -> Self { - Self { - active, - - scopes: None, - client_id: None, - username: None, - token_type: None, - exp: None, - iat: None, - nbf: None, - sub: None, - aud: None, - iss: None, - jti: None, - extra_fields, - } - } - - /// Sets the `set_active` field. - pub fn set_active(&mut self, active: bool) { - self.active = active; - } - /// Sets the `set_scopes` field. - pub fn set_scopes(&mut self, scopes: Option>) { - self.scopes = scopes; - } - /// Sets the `set_client_id` field. - pub fn set_client_id(&mut self, client_id: Option) { - self.client_id = client_id; - } - /// Sets the `set_username` field. - pub fn set_username(&mut self, username: Option) { - self.username = username; - } - /// Sets the `set_token_type` field. - pub fn set_token_type(&mut self, token_type: Option) { - self.token_type = token_type; - } - /// Sets the `set_exp` field. - pub fn set_exp(&mut self, exp: Option>) { - self.exp = exp; - } - /// Sets the `set_iat` field. - pub fn set_iat(&mut self, iat: Option>) { - self.iat = iat; - } - /// Sets the `set_nbf` field. - pub fn set_nbf(&mut self, nbf: Option>) { - self.nbf = nbf; - } - /// Sets the `set_sub` field. - pub fn set_sub(&mut self, sub: Option) { - self.sub = sub; - } - /// Sets the `set_aud` field. - pub fn set_aud(&mut self, aud: Option>) { - self.aud = aud; - } - /// Sets the `set_iss` field. - pub fn set_iss(&mut self, iss: Option) { - self.iss = iss; - } - /// Sets the `set_jti` field. - pub fn set_jti(&mut self, jti: Option) { - self.jti = jti; - } - /// Extra fields defined by the client application. - pub fn extra_fields(&self) -> &EF { - &self.extra_fields - } - /// Sets the `set_extra_fields` field. - pub fn set_extra_fields(&mut self, extra_fields: EF) { - self.extra_fields = extra_fields; - } -} -impl TokenIntrospectionResponse for StandardTokenIntrospectionResponse -where - EF: ExtraTokenFields, - TT: TokenType, -{ - fn active(&self) -> bool { - self.active - } - - fn scopes(&self) -> Option<&Vec> { - self.scopes.as_ref() - } - - fn client_id(&self) -> Option<&ClientId> { - self.client_id.as_ref() - } - - fn username(&self) -> Option<&str> { - self.username.as_deref() - } - - fn token_type(&self) -> Option<&TT> { - self.token_type.as_ref() - } - - fn exp(&self) -> Option> { - self.exp - } - - fn iat(&self) -> Option> { - self.iat - } - - fn nbf(&self) -> Option> { - self.nbf - } - - fn sub(&self) -> Option<&str> { - self.sub.as_deref() - } - - fn aud(&self) -> Option<&Vec> { - self.aud.as_ref() - } - - fn iss(&self) -> Option<&str> { - self.iss.as_deref() - } - - fn jti(&self) -> Option<&str> { - self.jti.as_deref() - } -} - -/// Server Error Response -/// -/// See [Section 5.2](https://datatracker.ietf.org/doc/html/rfc6749#section-5.2) of RFC 6749. -/// This trait exists separately from the `StandardErrorResponse` struct -/// to support customization by clients, such as supporting interoperability with -/// non-standards-complaint OAuth2 providers. -/// -/// The [`Display`] trait implementation for types implementing [`ErrorResponse`] should be a -/// human-readable string suitable for printing (e.g., within a [`RequestTokenError`]). -pub trait ErrorResponse: Debug + Display + DeserializeOwned + Serialize {} - -/// Error types enum. -/// -/// NOTE: The serialization must return the `snake_case` representation of -/// this error type. This value must match the error type from the relevant OAuth 2.0 standards -/// (RFC 6749 or an extension). -pub trait ErrorResponseType: Debug + DeserializeOwned + Serialize {} - -/// Error response returned by server after requesting an access token. -/// -/// The fields in this structure are defined in -/// [Section 5.2 of RFC 6749](https://tools.ietf.org/html/rfc6749#section-5.2). This -/// trait is parameterized by a `ErrorResponseType` to support error types specific to future OAuth2 -/// authentication schemes and extensions. -#[derive(Clone, Debug, Deserialize, PartialEq, Eq, Serialize)] -pub struct StandardErrorResponse { - #[serde(bound = "T: ErrorResponseType")] - error: T, - #[serde(default)] - #[serde(skip_serializing_if = "Option::is_none")] - error_description: Option, - #[serde(default)] - #[serde(skip_serializing_if = "Option::is_none")] - error_uri: Option, -} - -impl StandardErrorResponse { - /// Instantiate a new `ErrorResponse`. - /// - /// # Arguments - /// - /// * `error` - REQUIRED. A single ASCII error code deserialized to the generic parameter. - /// `ErrorResponseType`. - /// * `error_description` - OPTIONAL. Human-readable ASCII text providing additional - /// information, used to assist the client developer in understanding the error that - /// occurred. Values for this parameter MUST NOT include characters outside the set - /// `%x20-21 / %x23-5B / %x5D-7E`. - /// * `error_uri` - OPTIONAL. A URI identifying a human-readable web page with information - /// about the error used to provide the client developer with additional information about - /// the error. Values for the "error_uri" parameter MUST conform to the URI-reference - /// syntax and thus MUST NOT include characters outside the set `%x21 / %x23-5B / %x5D-7E`. - pub fn new(error: T, error_description: Option, error_uri: Option) -> Self { - Self { - error, - error_description, - error_uri, - } - } - - /// REQUIRED. A single ASCII error code deserialized to the generic parameter - /// `ErrorResponseType`. - pub fn error(&self) -> &T { - &self.error - } - /// OPTIONAL. Human-readable ASCII text providing additional information, used to assist - /// the client developer in understanding the error that occurred. Values for this - /// parameter MUST NOT include characters outside the set `%x20-21 / %x23-5B / %x5D-7E`. - pub fn error_description(&self) -> Option<&String> { - self.error_description.as_ref() - } - /// OPTIONAL. URI identifying a human-readable web page with information about the error, - /// used to provide the client developer with additional information about the error. - /// Values for the "error_uri" parameter MUST conform to the URI-reference syntax and - /// thus MUST NOT include characters outside the set `%x21 / %x23-5B / %x5D-7E`. - pub fn error_uri(&self) -> Option<&String> { - self.error_uri.as_ref() - } -} - -impl ErrorResponse for StandardErrorResponse where T: ErrorResponseType + Display + 'static {} - -impl Display for StandardErrorResponse -where - TE: ErrorResponseType + Display, -{ - fn fmt(&self, f: &mut Formatter) -> Result<(), FormatterError> { - let mut formatted = self.error().to_string(); - - if let Some(error_description) = self.error_description() { - formatted.push_str(": "); - formatted.push_str(error_description); - } - - if let Some(error_uri) = self.error_uri() { - formatted.push_str(" (see "); - formatted.push_str(error_uri); - formatted.push(')'); - } - - write!(f, "{}", formatted) - } -} - -/// Error encountered while requesting access token. -#[derive(Debug, thiserror::Error)] -pub enum RequestTokenError -where - RE: Error + 'static, - T: ErrorResponse + 'static, -{ - /// Error response returned by authorization server. Contains the parsed `ErrorResponse` - /// returned by the server. - #[error("Server returned error response: {0}")] - ServerResponse(T), - /// An error occurred while sending the request or receiving the response (e.g., network - /// connectivity failed). - #[error("Request failed")] - Request(#[from] RE), - /// Failed to parse server response. Parse errors may occur while parsing either successful - /// or error responses. - #[error("Failed to parse server response")] - Parse( - #[source] serde_path_to_error::Error, - Vec, - ), - /// Some other type of error occurred (e.g., an unexpected server response). - #[error("Other error: {}", _0)] - Other(String), -} diff --git a/src/revocation.rs b/src/revocation.rs index 7bf276f..1e57b0a 100644 --- a/src/revocation.rs +++ b/src/revocation.rs @@ -1,10 +1,18 @@ -use crate::{basic::BasicErrorResponseType, ErrorResponseType}; -use crate::{AccessToken, RefreshToken}; +use crate::basic::BasicErrorResponseType; +use crate::endpoint::{endpoint_request, endpoint_response_status_only}; +use crate::{ + AccessToken, AuthType, ClientId, ClientSecret, ErrorResponse, ErrorResponseType, HttpRequest, + HttpResponse, RefreshToken, RequestTokenError, RevocationUrl, +}; use serde::{Deserialize, Serialize}; +use std::borrow::Cow; +use std::error::Error; use std::fmt::Error as FormatterError; use std::fmt::{Debug, Display, Formatter}; +use std::future::Future; +use std::marker::PhantomData; /// A revocable token. /// @@ -102,6 +110,101 @@ impl From<&RefreshToken> for StandardRevocableToken { } } +/// A request to revoke a token via an [`RFC 7009`](https://tools.ietf.org/html/rfc7009#section-2.1) compatible +/// endpoint. +#[derive(Debug)] +pub struct RevocationRequest<'a, RT, TE> +where + RT: RevocableToken, + TE: ErrorResponse, +{ + pub(crate) token: RT, + pub(crate) auth_type: &'a AuthType, + pub(crate) client_id: &'a ClientId, + pub(crate) client_secret: Option<&'a ClientSecret>, + pub(crate) extra_params: Vec<(Cow<'a, str>, Cow<'a, str>)>, + pub(crate) revocation_url: &'a RevocationUrl, + pub(crate) _phantom: PhantomData<(RT, TE)>, +} + +impl<'a, RT, TE> RevocationRequest<'a, RT, TE> +where + RT: RevocableToken, + TE: ErrorResponse + 'static, +{ + /// Appends an extra param to the token revocation request. + /// + /// This method allows extensions to be used without direct support from + /// this crate. If `name` conflicts with a parameter managed by this crate, the + /// behavior is undefined. In particular, do not set parameters defined by + /// [RFC 6749](https://tools.ietf.org/html/rfc6749) or + /// [RFC 7662](https://tools.ietf.org/html/rfc7662). + /// + /// # Security Warning + /// + /// Callers should follow the security recommendations for any OAuth2 extensions used with + /// this function, which are beyond the scope of + /// [RFC 6749](https://tools.ietf.org/html/rfc6749). + pub fn add_extra_param(mut self, name: N, value: V) -> Self + where + N: Into>, + V: Into>, + { + self.extra_params.push((name.into(), value.into())); + self + } + + fn prepare_request(self) -> HttpRequest { + let mut params: Vec<(&str, &str)> = vec![("token", self.token.secret())]; + if let Some(type_hint) = self.token.type_hint() { + params.push(("token_type_hint", type_hint)); + } + + endpoint_request( + self.auth_type, + self.client_id, + self.client_secret, + &self.extra_params, + None, + None, + self.revocation_url.url(), + params, + ) + } + + /// Synchronously sends the request to the authorization server and awaits a response. + /// + /// A successful response indicates that the server either revoked the token or the token was not known to the + /// server. + /// + /// Error [`UnsupportedTokenType`](RevocationErrorResponseType::UnsupportedTokenType) will be returned if the + /// type of token type given is not supported by the server. + pub fn request(self, http_client: F) -> Result<(), RequestTokenError> + where + F: FnOnce(HttpRequest) -> Result, + RE: Error + 'static, + { + // From https://tools.ietf.org/html/rfc7009#section-2.2: + // "The content of the response body is ignored by the client as all + // necessary information is conveyed in the response code." + endpoint_response_status_only(http_client(self.prepare_request())?) + } + + /// Asynchronously sends the request to the authorization server and returns a Future. + pub async fn request_async( + self, + http_client: C, + ) -> Result<(), RequestTokenError> + where + C: FnOnce(HttpRequest) -> F, + F: Future>, + RE: Error + 'static, + { + let http_response = http_client(self.prepare_request()).await?; + endpoint_response_status_only(http_response) + } +} + /// OAuth 2.0 Token Revocation error response types. /// /// These error types are defined in @@ -162,3 +265,223 @@ impl Display for RevocationErrorResponseType { write!(f, "{}", self.as_ref()) } } + +#[cfg(test)] +mod tests { + use crate::basic::BasicRevocationErrorResponse; + use crate::tests::colorful_extension::{ColorfulClient, ColorfulRevocableToken}; + use crate::tests::{mock_http_client, new_client}; + use crate::{ + AccessToken, AuthUrl, ClientId, ClientSecret, HttpResponse, RefreshToken, + RequestTokenError, RevocationErrorResponseType, RevocationUrl, TokenUrl, + }; + + use http::header::{ACCEPT, AUTHORIZATION, CONTENT_TYPE}; + use http::{HeaderValue, StatusCode}; + + #[test] + fn test_token_revocation_with_non_https_url() { + let client = new_client(); + + let result = client + .set_revocation_uri(RevocationUrl::new("http://revocation/url".to_string()).unwrap()) + .revoke_token(AccessToken::new("access_token_123".to_string()).into()) + .unwrap_err(); + + assert_eq!( + format!("{}", result), + "Scheme for revocation endpoint URL must be HTTPS" + ); + } + + #[test] + fn test_token_revocation_with_unsupported_token_type() { + let client = new_client() + .set_revocation_uri(RevocationUrl::new("https://revocation/url".to_string()).unwrap()); + + let revocation_response = client + .revoke_token(AccessToken::new("access_token_123".to_string()).into()).unwrap() + .request(mock_http_client( + vec![ + (ACCEPT, "application/json"), + (CONTENT_TYPE, "application/x-www-form-urlencoded"), + (AUTHORIZATION, "Basic YWFhOmJiYg=="), + ], + "token=access_token_123&token_type_hint=access_token", + Some("https://revocation/url".parse().unwrap()), + HttpResponse { + status_code: StatusCode::BAD_REQUEST, + headers: vec![( + CONTENT_TYPE, + HeaderValue::from_str("application/json").unwrap(), + )] + .into_iter() + .collect(), + body: "{\"error\": \"unsupported_token_type\", \"error_description\": \"stuff happened\", \ + \"error_uri\": \"https://errors\"}" + .to_string() + .into_bytes(), + }, + )); + + assert!(matches!( + revocation_response, + Err(RequestTokenError::ServerResponse( + BasicRevocationErrorResponse { + error: RevocationErrorResponseType::UnsupportedTokenType, + .. + } + )) + )); + } + + #[test] + fn test_token_revocation_with_access_token_and_empty_json_response() { + let client = new_client() + .set_revocation_uri(RevocationUrl::new("https://revocation/url".to_string()).unwrap()); + + client + .revoke_token(AccessToken::new("access_token_123".to_string()).into()) + .unwrap() + .request(mock_http_client( + vec![ + (ACCEPT, "application/json"), + (CONTENT_TYPE, "application/x-www-form-urlencoded"), + (AUTHORIZATION, "Basic YWFhOmJiYg=="), + ], + "token=access_token_123&token_type_hint=access_token", + Some("https://revocation/url".parse().unwrap()), + HttpResponse { + status_code: StatusCode::OK, + headers: vec![( + CONTENT_TYPE, + HeaderValue::from_str("application/json").unwrap(), + )] + .into_iter() + .collect(), + body: b"{}".to_vec(), + }, + )) + .unwrap(); + } + + #[test] + fn test_token_revocation_with_access_token_and_empty_response() { + let client = new_client() + .set_revocation_uri(RevocationUrl::new("https://revocation/url".to_string()).unwrap()); + + client + .revoke_token(AccessToken::new("access_token_123".to_string()).into()) + .unwrap() + .request(mock_http_client( + vec![ + (ACCEPT, "application/json"), + (CONTENT_TYPE, "application/x-www-form-urlencoded"), + (AUTHORIZATION, "Basic YWFhOmJiYg=="), + ], + "token=access_token_123&token_type_hint=access_token", + Some("https://revocation/url".parse().unwrap()), + HttpResponse { + status_code: StatusCode::OK, + headers: vec![].into_iter().collect(), + body: vec![], + }, + )) + .unwrap(); + } + + #[test] + fn test_token_revocation_with_access_token_and_non_json_response() { + let client = new_client() + .set_revocation_uri(RevocationUrl::new("https://revocation/url".to_string()).unwrap()); + + client + .revoke_token(AccessToken::new("access_token_123".to_string()).into()) + .unwrap() + .request(mock_http_client( + vec![ + (ACCEPT, "application/json"), + (CONTENT_TYPE, "application/x-www-form-urlencoded"), + (AUTHORIZATION, "Basic YWFhOmJiYg=="), + ], + "token=access_token_123&token_type_hint=access_token", + Some("https://revocation/url".parse().unwrap()), + HttpResponse { + status_code: StatusCode::OK, + headers: vec![( + CONTENT_TYPE, + HeaderValue::from_str("application/octet-stream").unwrap(), + )] + .into_iter() + .collect(), + body: vec![1, 2, 3], + }, + )) + .unwrap(); + } + + #[test] + fn test_token_revocation_with_refresh_token() { + let client = new_client() + .set_revocation_uri(RevocationUrl::new("https://revocation/url".to_string()).unwrap()); + + client + .revoke_token(RefreshToken::new("refresh_token_123".to_string()).into()) + .unwrap() + .request(mock_http_client( + vec![ + (ACCEPT, "application/json"), + (CONTENT_TYPE, "application/x-www-form-urlencoded"), + (AUTHORIZATION, "Basic YWFhOmJiYg=="), + ], + "token=refresh_token_123&token_type_hint=refresh_token", + Some("https://revocation/url".parse().unwrap()), + HttpResponse { + status_code: StatusCode::OK, + headers: vec![( + CONTENT_TYPE, + HeaderValue::from_str("application/json").unwrap(), + )] + .into_iter() + .collect(), + body: b"{}".to_vec(), + }, + )) + .unwrap(); + } + + #[test] + fn test_extension_token_revocation_successful() { + let client = ColorfulClient::new(ClientId::new("aaa".to_string())) + .set_client_secret(ClientSecret::new("bbb".to_string())) + .set_auth_url(AuthUrl::new("https://example.com/auth".to_string()).unwrap()) + .set_token_url(TokenUrl::new("https://example.com/token".to_string()).unwrap()) + .set_revocation_uri(RevocationUrl::new("https://revocation/url".to_string()).unwrap()); + + client + .revoke_token(ColorfulRevocableToken::Red( + "colorful_token_123".to_string(), + )) + .unwrap() + .request(mock_http_client( + vec![ + (ACCEPT, "application/json"), + (CONTENT_TYPE, "application/x-www-form-urlencoded"), + (AUTHORIZATION, "Basic YWFhOmJiYg=="), + ], + "token=colorful_token_123&token_type_hint=red_token", + Some("https://revocation/url".parse().unwrap()), + HttpResponse { + status_code: StatusCode::OK, + headers: vec![( + CONTENT_TYPE, + HeaderValue::from_str("application/json").unwrap(), + )] + .into_iter() + .collect(), + body: b"{}".to_vec(), + }, + )) + .unwrap(); + } +} diff --git a/src/tests.rs b/src/tests.rs index e94182f..a826ff1 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -1,46 +1,32 @@ use crate::basic::{ - BasicClient, BasicErrorResponse, BasicErrorResponseType, BasicRevocationErrorResponse, - BasicTokenResponse, BasicTokenType, + BasicClient, BasicErrorResponseType, BasicRevocationErrorResponse, BasicTokenType, }; -use crate::revocation::StandardRevocableToken; -use crate::tests::colorful_extension::{ - ColorfulClient, ColorfulErrorResponseType, ColorfulFields, ColorfulRevocableToken, - ColorfulTokenResponse, ColorfulTokenType, -}; -use crate::tests::custom_errors::CustomErrorClient; use crate::{ AccessToken, AuthType, AuthUrl, AuthorizationCode, AuthorizationRequest, Client, ClientCredentialsTokenRequest, ClientId, ClientSecret, CodeTokenRequest, CsrfToken, DeviceAccessTokenRequest, DeviceAuthorizationRequest, DeviceAuthorizationUrl, DeviceCode, DeviceCodeErrorResponse, DeviceCodeErrorResponseType, EmptyExtraDeviceAuthorizationFields, - EmptyExtraTokenFields, EndUserVerificationUrl, ExtraTokenFields, HttpRequest, HttpResponse, - IntrospectionUrl, PasswordTokenRequest, PkceCodeChallenge, PkceCodeChallengeMethod, - PkceCodeVerifier, RedirectUrl, RefreshToken, RefreshTokenRequest, RequestTokenError, - ResourceOwnerPassword, ResourceOwnerUsername, ResponseType, RevocationErrorResponseType, - RevocationUrl, Scope, StandardDeviceAuthorizationResponse, StandardErrorResponse, - StandardTokenIntrospectionResponse, StandardTokenResponse, TokenResponse, TokenType, TokenUrl, + EmptyExtraTokenFields, EndUserVerificationUrl, HttpRequest, HttpResponse, PasswordTokenRequest, + PkceCodeChallenge, PkceCodeChallengeMethod, PkceCodeVerifier, RedirectUrl, RefreshToken, + RefreshTokenRequest, RequestTokenError, ResourceOwnerPassword, ResourceOwnerUsername, + ResponseType, Scope, StandardDeviceAuthorizationResponse, StandardErrorResponse, + StandardRevocableToken, StandardTokenIntrospectionResponse, StandardTokenResponse, TokenUrl, UserCode, }; -use chrono::{DateTime, TimeZone, Utc}; -use http::header::{HeaderMap, HeaderName, HeaderValue, ACCEPT, AUTHORIZATION, CONTENT_TYPE}; -use http::status::StatusCode; -use serde::Deserialize; +use http::header::HeaderName; +use http::HeaderValue; use thiserror::Error; -use url::form_urlencoded::byte_serialize; use url::Url; -use std::borrow::Cow; -use std::time::Duration; - -fn new_client() -> BasicClient { +pub(crate) fn new_client() -> BasicClient { BasicClient::new(ClientId::new("aaa".to_string())) .set_auth_url(AuthUrl::new("https://example.com/auth".to_string()).unwrap()) .set_token_url(TokenUrl::new("https://example.com/token".to_string()).unwrap()) .set_client_secret(ClientSecret::new("bbb".to_string())) } -fn mock_http_client( +pub(crate) fn mock_http_client( request_headers: Vec<(HeaderName, &'static str)>, request_body: &'static str, request_url: Option, @@ -66,1159 +52,13 @@ fn mock_http_client( } } -#[test] -#[should_panic] -fn test_code_verifier_too_short() { - PkceCodeChallenge::new_random_sha256_len(31); -} - -#[test] -#[should_panic] -fn test_code_verifier_too_long() { - PkceCodeChallenge::new_random_sha256_len(97); -} - -#[test] -fn test_code_verifier_min() { - let code = PkceCodeChallenge::new_random_sha256_len(32); - assert_eq!(code.1.secret().len(), 43); -} - -#[test] -fn test_code_verifier_max() { - let code = PkceCodeChallenge::new_random_sha256_len(96); - assert_eq!(code.1.secret().len(), 128); -} - -#[test] -fn test_code_verifier_challenge() { - // Example from https://tools.ietf.org/html/rfc7636#appendix-B - let code_verifier = - PkceCodeVerifier::new("dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk".to_string()); - assert_eq!( - PkceCodeChallenge::from_code_verifier_sha256(&code_verifier).as_str(), - "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM", - ); -} - -#[test] -fn test_authorize_url() { - let client = new_client(); - let (url, _) = client - .authorize_url(|| CsrfToken::new("csrf_token".to_string())) - .url(); - - assert_eq!( - Url::parse("https://example.com/auth?response_type=code&client_id=aaa&state=csrf_token") - .unwrap(), - url - ); -} - -#[test] -fn test_authorize_random() { - let client = new_client(); - let (url, csrf_state) = client.authorize_url(CsrfToken::new_random).url(); - - assert_eq!( - Url::parse(&format!( - "https://example.com/auth?response_type=code&client_id=aaa&state={}", - byte_serialize(csrf_state.secret().clone().into_bytes().as_slice()) - .collect::>() - .join("") - )) - .unwrap(), - url - ); -} - -#[test] -fn test_authorize_url_pkce() { - // Example from https://tools.ietf.org/html/rfc7636#appendix-B - let client = new_client(); - - let (url, _) = client - .authorize_url(|| CsrfToken::new("csrf_token".to_string())) - .set_pkce_challenge(PkceCodeChallenge::from_code_verifier_sha256( - &PkceCodeVerifier::new("dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk".to_string()), - )) - .url(); - assert_eq!( - Url::parse(concat!( - "https://example.com/auth", - "?response_type=code&client_id=aaa", - "&state=csrf_token", - "&code_challenge=E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM", - "&code_challenge_method=S256", - )) - .unwrap(), - url - ); -} - -#[test] -fn test_authorize_url_implicit() { - let client = new_client(); - - let (url, _) = client - .authorize_url(|| CsrfToken::new("csrf_token".to_string())) - .use_implicit_flow() - .url(); - - assert_eq!( - Url::parse("https://example.com/auth?response_type=token&client_id=aaa&state=csrf_token") - .unwrap(), - url - ); -} - -#[test] -fn test_authorize_url_with_param() { - let client = BasicClient::new(ClientId::new("aaa".to_string())) - .set_client_secret(ClientSecret::new("bbb".to_string())) - .set_auth_url(AuthUrl::new("https://example.com/auth?foo=bar".to_string()).unwrap()) - .set_token_url(TokenUrl::new("https://example.com/token".to_string()).unwrap()); - - let (url, _) = client - .authorize_url(|| CsrfToken::new("csrf_token".to_string())) - .url(); - - assert_eq!( - Url::parse( - "https://example.com/auth?foo=bar&response_type=code&client_id=aaa&state=csrf_token" - ) - .unwrap(), - url - ); -} - -#[test] -fn test_authorize_url_with_scopes() { - let scopes = vec![ - Scope::new("read".to_string()), - Scope::new("write".to_string()), - ]; - let (url, _) = new_client() - .authorize_url(|| CsrfToken::new("csrf_token".to_string())) - .add_scopes(scopes) - .url(); - - assert_eq!( - Url::parse( - "https://example.com/auth\ - ?response_type=code\ - &client_id=aaa\ - &state=csrf_token\ - &scope=read+write" - ) - .unwrap(), - url - ); -} - -#[test] -fn test_authorize_url_with_one_scope() { - let (url, _) = new_client() - .authorize_url(|| CsrfToken::new("csrf_token".to_string())) - .add_scope(Scope::new("read".to_string())) - .url(); - - assert_eq!( - Url::parse( - "https://example.com/auth\ - ?response_type=code\ - &client_id=aaa\ - &state=csrf_token\ - &scope=read" - ) - .unwrap(), - url - ); -} - -#[test] -fn test_authorize_url_with_extension_response_type() { - let client = new_client(); - - let (url, _) = client - .authorize_url(|| CsrfToken::new("csrf_token".to_string())) - .set_response_type(&ResponseType::new("code token".to_string())) - .add_extra_param("foo", "bar") - .url(); - - assert_eq!( - Url::parse( - "https://example.com/auth?response_type=code+token&client_id=aaa&state=csrf_token\ - &foo=bar" - ) - .unwrap(), - url - ); -} - -#[test] -fn test_authorize_url_with_redirect_url() { - let client = new_client() - .set_redirect_uri(RedirectUrl::new("https://localhost/redirect".to_string()).unwrap()); - - let (url, _) = client - .authorize_url(|| CsrfToken::new("csrf_token".to_string())) - .url(); - - assert_eq!( - Url::parse( - "https://example.com/auth?response_type=code\ - &client_id=aaa\ - &state=csrf_token\ - &redirect_uri=https%3A%2F%2Flocalhost%2Fredirect" - ) - .unwrap(), - url - ); -} - -#[test] -fn test_authorize_url_with_redirect_url_override() { - let client = new_client() - .set_redirect_uri(RedirectUrl::new("https://localhost/redirect".to_string()).unwrap()); - - let (url, _) = client - .authorize_url(|| CsrfToken::new("csrf_token".to_string())) - .set_redirect_uri(Cow::Owned( - RedirectUrl::new("https://localhost/alternative".to_string()).unwrap(), - )) - .url(); - - assert_eq!( - Url::parse( - "https://example.com/auth?response_type=code\ - &client_id=aaa\ - &state=csrf_token\ - &redirect_uri=https%3A%2F%2Flocalhost%2Falternative" - ) - .unwrap(), - url - ); -} - #[derive(Debug, Error)] -enum FakeError { +pub(crate) enum FakeError { #[error("error")] Err, } -// Because the secret types don't implement PartialEq, we can't directly use == to compare tokens. -fn assert_token_eq(a: &StandardTokenResponse, b: &StandardTokenResponse) -where - EF: ExtraTokenFields + PartialEq, - TT: TokenType, -{ - assert_eq!(a.access_token().secret(), b.access_token().secret()); - assert_eq!(a.token_type(), b.token_type()); - assert_eq!(a.expires_in(), b.expires_in()); - assert_eq!( - a.refresh_token().map(RefreshToken::secret), - b.refresh_token().map(RefreshToken::secret) - ); - assert_eq!(a.scopes(), b.scopes()); - assert_eq!(a.extra_fields(), b.extra_fields()); -} - -#[test] -fn test_exchange_code_successful_with_minimal_json_response() { - let client = BasicClient::new(ClientId::new("aaa".to_string())) - .set_client_secret(ClientSecret::new("bbb".to_string())) - .set_auth_url(AuthUrl::new("https://example.com/auth".to_string()).unwrap()) - .set_token_url(TokenUrl::new("https://example.com/token".to_string()).unwrap()); - - let token = client - .exchange_code(AuthorizationCode::new("ccc".to_string())) - .request(mock_http_client( - vec![ - (ACCEPT, "application/json"), - (CONTENT_TYPE, "application/x-www-form-urlencoded"), - (AUTHORIZATION, "Basic YWFhOmJiYg=="), - ], - "grant_type=authorization_code&code=ccc", - None, - HttpResponse { - status_code: StatusCode::OK, - headers: HeaderMap::new(), - body: "{\"access_token\": \"12/34\", \"token_type\": \"BEARER\"}" - .to_string() - .into_bytes(), - }, - )) - .unwrap(); - - assert_eq!("12/34", token.access_token().secret()); - assert_eq!(BasicTokenType::Bearer, *token.token_type()); - assert_eq!(None, token.expires_in()); - assert!(token.refresh_token().is_none()); - - // Ensure that serialization produces an equivalent JSON value. - let serialized_json = serde_json::to_string(&token).unwrap(); - assert_eq!( - "{\"access_token\":\"12/34\",\"token_type\":\"bearer\"}".to_string(), - serialized_json - ); - - let deserialized_token = serde_json::from_str::(&serialized_json).unwrap(); - assert_token_eq(&token, &deserialized_token); -} - -#[test] -fn test_exchange_code_successful_with_complete_json_response() { - let client = new_client().set_auth_type(AuthType::RequestBody); - let token = client - .exchange_code(AuthorizationCode::new("ccc".to_string())) - .request(mock_http_client( - vec![ - (ACCEPT, "application/json"), - (CONTENT_TYPE, "application/x-www-form-urlencoded"), - ], - "grant_type=authorization_code&code=ccc&client_id=aaa&client_secret=bbb", - None, - HttpResponse { - status_code: StatusCode::OK, - headers: vec![( - CONTENT_TYPE, - HeaderValue::from_str("application/json").unwrap(), - )] - .into_iter() - .collect(), - body: "{\ - \"access_token\": \"12/34\", \ - \"token_type\": \"bearer\", \ - \"scope\": \"read write\", \ - \"expires_in\": 3600, \ - \"refresh_token\": \"foobar\"\ - }" - .to_string() - .into_bytes(), - }, - )) - .unwrap(); - - assert_eq!("12/34", token.access_token().secret()); - assert_eq!(BasicTokenType::Bearer, *token.token_type()); - assert_eq!( - Some(&vec![ - Scope::new("read".to_string()), - Scope::new("write".to_string()), - ]), - token.scopes() - ); - assert_eq!(3600, token.expires_in().unwrap().as_secs()); - assert_eq!("foobar", token.refresh_token().unwrap().secret()); - - // Ensure that serialization produces an equivalent JSON value. - let serialized_json = serde_json::to_string(&token).unwrap(); - assert_eq!( - "{\"access_token\":\"12/34\",\"token_type\":\"bearer\",\"expires_in\":3600,\ - \"refresh_token\":\"foobar\",\"scope\":\"read write\"}" - .to_string(), - serialized_json - ); - - let deserialized_token = serde_json::from_str::(&serialized_json).unwrap(); - assert_token_eq(&token, &deserialized_token); -} - -#[test] -fn test_exchange_client_credentials_with_basic_auth() { - let client = BasicClient::new(ClientId::new("aaa/;&".to_string())) - .set_client_secret(ClientSecret::new("bbb/;&".to_string())) - .set_auth_url(AuthUrl::new("https://example.com/auth".to_string()).unwrap()) - .set_token_url(TokenUrl::new("https://example.com/token".to_string()).unwrap()) - .set_auth_type(AuthType::BasicAuth); - - let token = client - .exchange_client_credentials() - .request(mock_http_client( - vec![ - (ACCEPT, "application/json"), - (CONTENT_TYPE, "application/x-www-form-urlencoded"), - (AUTHORIZATION, "Basic YWFhJTJGJTNCJTI2OmJiYiUyRiUzQiUyNg=="), - ], - "grant_type=client_credentials", - None, - HttpResponse { - status_code: StatusCode::OK, - headers: HeaderMap::new(), - body: "{\ - \"access_token\": \"12/34\", \ - \"token_type\": \"bearer\", \ - \"scope\": \"read write\"\ - }" - .to_string() - .into_bytes(), - }, - )) - .unwrap(); - - assert_eq!("12/34", token.access_token().secret()); - assert_eq!(BasicTokenType::Bearer, *token.token_type()); - assert_eq!( - Some(&vec![ - Scope::new("read".to_string()), - Scope::new("write".to_string()), - ]), - token.scopes() - ); - assert_eq!(None, token.expires_in()); - assert!(token.refresh_token().is_none()); -} - -#[test] -fn test_exchange_client_credentials_with_basic_auth_but_no_client_secret() { - let client = BasicClient::new(ClientId::new("aaa/;&".to_string())) - .set_auth_url(AuthUrl::new("https://example.com/auth".to_string()).unwrap()) - .set_token_url(TokenUrl::new("https://example.com/token".to_string()).unwrap()) - .set_auth_type(AuthType::BasicAuth); - - let token = client - .exchange_client_credentials() - .request(mock_http_client( - vec![ - (ACCEPT, "application/json"), - (CONTENT_TYPE, "application/x-www-form-urlencoded"), - ], - "grant_type=client_credentials&client_id=aaa%2F%3B%26", - None, - HttpResponse { - status_code: StatusCode::OK, - headers: HeaderMap::new(), - body: "{\ - \"access_token\": \"12/34\", \ - \"token_type\": \"bearer\", \ - \"scope\": \"read write\"\ - }" - .to_string() - .into_bytes(), - }, - )) - .unwrap(); - - assert_eq!("12/34", token.access_token().secret()); - assert_eq!(BasicTokenType::Bearer, *token.token_type()); - assert_eq!( - Some(&vec![ - Scope::new("read".to_string()), - Scope::new("write".to_string()), - ]), - token.scopes() - ); - assert_eq!(None, token.expires_in()); - assert!(token.refresh_token().is_none()); -} - -#[test] -fn test_exchange_client_credentials_with_body_auth_and_scope() { - let client = new_client().set_auth_type(AuthType::RequestBody); - let token = client - .exchange_client_credentials() - .add_scope(Scope::new("read".to_string())) - .add_scope(Scope::new("write".to_string())) - .request(mock_http_client( - vec![ - (ACCEPT, "application/json"), - (CONTENT_TYPE, "application/x-www-form-urlencoded"), - ], - "grant_type=client_credentials&scope=read+write&client_id=aaa&client_secret=bbb", - None, - HttpResponse { - status_code: StatusCode::OK, - headers: vec![( - CONTENT_TYPE, - HeaderValue::from_str("APPLICATION/jSoN").unwrap(), - )] - .into_iter() - .collect(), - body: "{\ - \"access_token\": \"12/34\", \ - \"token_type\": \"bearer\", \ - \"scope\": \"read write\"\ - }" - .to_string() - .into_bytes(), - }, - )) - .unwrap(); - - assert_eq!("12/34", token.access_token().secret()); - assert_eq!(BasicTokenType::Bearer, *token.token_type()); - assert_eq!( - Some(&vec![ - Scope::new("read".to_string()), - Scope::new("write".to_string()), - ]), - token.scopes() - ); - assert_eq!(None, token.expires_in()); - assert!(token.refresh_token().is_none()); -} - -#[test] -fn test_exchange_refresh_token_with_basic_auth() { - let client = new_client().set_auth_type(AuthType::BasicAuth); - let token = client - .exchange_refresh_token(&RefreshToken::new("ccc".to_string())) - .request(mock_http_client( - vec![ - (ACCEPT, "application/json"), - (CONTENT_TYPE, "application/x-www-form-urlencoded"), - (AUTHORIZATION, "Basic YWFhOmJiYg=="), - ], - "grant_type=refresh_token&refresh_token=ccc", - None, - HttpResponse { - status_code: StatusCode::OK, - headers: HeaderMap::new(), - body: "{\"access_token\": \"12/34\", \ - \"token_type\": \"bearer\", \ - \"scope\": \"read write\"\ - }" - .to_string() - .into_bytes(), - }, - )) - .unwrap(); - - assert_eq!("12/34", token.access_token().secret()); - assert_eq!(BasicTokenType::Bearer, *token.token_type()); - assert_eq!( - Some(&vec![ - Scope::new("read".to_string()), - Scope::new("write".to_string()), - ]), - token.scopes() - ); - assert_eq!(None, token.expires_in()); - assert!(token.refresh_token().is_none()); -} - -#[test] -fn test_exchange_refresh_token_with_json_response() { - let client = new_client(); - let token = client - .exchange_refresh_token(&RefreshToken::new("ccc".to_string())) - .request(mock_http_client( - vec![ - (ACCEPT, "application/json"), - (CONTENT_TYPE, "application/x-www-form-urlencoded"), - (AUTHORIZATION, "Basic YWFhOmJiYg=="), - ], - "grant_type=refresh_token&refresh_token=ccc", - None, - HttpResponse { - status_code: StatusCode::OK, - headers: HeaderMap::new(), - body: "{\ - \"access_token\": \"12/34\", \ - \"token_type\": \"bearer\", \ - \"scope\": \"read write\"\ - }" - .to_string() - .into_bytes(), - }, - )) - .unwrap(); - - assert_eq!("12/34", token.access_token().secret()); - assert_eq!(BasicTokenType::Bearer, *token.token_type()); - assert_eq!( - Some(&vec![ - Scope::new("read".to_string()), - Scope::new("write".to_string()), - ]), - token.scopes() - ); - assert_eq!(None, token.expires_in()); - assert!(token.refresh_token().is_none()); -} - -#[test] -fn test_exchange_password_with_json_response() { - let client = new_client(); - let token = client - .exchange_password( - &ResourceOwnerUsername::new("user".to_string()), - &ResourceOwnerPassword::new("pass".to_string()), - ) - .add_scope(Scope::new("read".to_string())) - .add_scope(Scope::new("write".to_string())) - .request(mock_http_client( - vec![ - (ACCEPT, "application/json"), - (CONTENT_TYPE, "application/x-www-form-urlencoded"), - (AUTHORIZATION, "Basic YWFhOmJiYg=="), - ], - "grant_type=password&username=user&password=pass&scope=read+write", - None, - HttpResponse { - status_code: StatusCode::OK, - headers: vec![( - CONTENT_TYPE, - HeaderValue::from_str("application/json").unwrap(), - )] - .into_iter() - .collect(), - body: "{\ - \"access_token\": \"12/34\", \ - \"token_type\": \"bearer\", \ - \"scope\": \"read write\"\ - }" - .to_string() - .into_bytes(), - }, - )) - .unwrap(); - - assert_eq!("12/34", token.access_token().secret()); - assert_eq!(BasicTokenType::Bearer, *token.token_type()); - assert_eq!( - Some(&vec![ - Scope::new("read".to_string()), - Scope::new("write".to_string()), - ]), - token.scopes() - ); - assert_eq!(None, token.expires_in()); - assert!(token.refresh_token().is_none()); -} - -#[test] -fn test_exchange_code_successful_with_redirect_url() { - let client = new_client() - .set_auth_type(AuthType::RequestBody) - .set_redirect_uri(RedirectUrl::new("https://redirect/here".to_string()).unwrap()); - - let token = client - .exchange_code(AuthorizationCode::new("ccc".to_string())) - .request(mock_http_client( - vec![ - (ACCEPT, "application/json"), - (CONTENT_TYPE, "application/x-www-form-urlencoded"), - ], - "grant_type=authorization_code&code=ccc&client_id=aaa&client_secret=bbb&\ - redirect_uri=https%3A%2F%2Fredirect%2Fhere", - None, - HttpResponse { - status_code: StatusCode::OK, - headers: vec![( - CONTENT_TYPE, - HeaderValue::from_str("application/json").unwrap(), - )] - .into_iter() - .collect(), - body: "{\ - \"access_token\": \"12/34\", \ - \"token_type\": \"bearer\", \ - \"scope\": \"read write\"\ - }" - .to_string() - .into_bytes(), - }, - )) - .unwrap(); - - assert_eq!("12/34", token.access_token().secret()); - assert_eq!(BasicTokenType::Bearer, *token.token_type()); - assert_eq!( - Some(&vec![ - Scope::new("read".to_string()), - Scope::new("write".to_string()), - ]), - token.scopes() - ); - assert_eq!(None, token.expires_in()); - assert!(token.refresh_token().is_none()); -} - -#[test] -fn test_exchange_code_successful_with_redirect_url_override() { - let client = new_client() - .set_auth_type(AuthType::RequestBody) - .set_redirect_uri(RedirectUrl::new("https://redirect/here".to_string()).unwrap()); - - let token = client - .exchange_code(AuthorizationCode::new("ccc".to_string())) - .set_redirect_uri(Cow::Owned( - RedirectUrl::new("https://redirect/alternative".to_string()).unwrap(), - )) - .request(mock_http_client( - vec![ - (ACCEPT, "application/json"), - (CONTENT_TYPE, "application/x-www-form-urlencoded"), - ], - "grant_type=authorization_code&code=ccc&client_id=aaa&client_secret=bbb&\ - redirect_uri=https%3A%2F%2Fredirect%2Falternative", - None, - HttpResponse { - status_code: StatusCode::OK, - headers: vec![( - CONTENT_TYPE, - HeaderValue::from_str("application/json").unwrap(), - )] - .into_iter() - .collect(), - body: "{\ - \"access_token\": \"12/34\", \ - \"token_type\": \"bearer\", \ - \"scope\": \"read write\"\ - }" - .to_string() - .into_bytes(), - }, - )) - .unwrap(); - - assert_eq!("12/34", token.access_token().secret()); - assert_eq!(BasicTokenType::Bearer, *token.token_type()); - assert_eq!( - Some(&vec![ - Scope::new("read".to_string()), - Scope::new("write".to_string()), - ]), - token.scopes() - ); - assert_eq!(None, token.expires_in()); - assert!(token.refresh_token().is_none()); -} - -#[test] -fn test_exchange_code_successful_with_basic_auth() { - let client = new_client() - .set_auth_type(AuthType::BasicAuth) - .set_redirect_uri(RedirectUrl::new("https://redirect/here".to_string()).unwrap()); - - let token = client - .exchange_code(AuthorizationCode::new("ccc".to_string())) - .request(mock_http_client( - vec![ - (ACCEPT, "application/json"), - (CONTENT_TYPE, "application/x-www-form-urlencoded"), - (AUTHORIZATION, "Basic YWFhOmJiYg=="), - ], - "grant_type=authorization_code&code=ccc&redirect_uri=https%3A%2F%2Fredirect%2Fhere", - None, - HttpResponse { - status_code: StatusCode::OK, - headers: vec![( - CONTENT_TYPE, - HeaderValue::from_str("application/json").unwrap(), - )] - .into_iter() - .collect(), - body: "{\ - \"access_token\": \"12/34\", \ - \"token_type\": \"bearer\", \ - \"scope\": \"read write\"\ - }" - .to_string() - .into_bytes(), - }, - )) - .unwrap(); - - assert_eq!("12/34", token.access_token().secret()); - assert_eq!(BasicTokenType::Bearer, *token.token_type()); - assert_eq!( - Some(&vec![ - Scope::new("read".to_string()), - Scope::new("write".to_string()), - ]), - token.scopes() - ); - assert_eq!(None, token.expires_in()); - assert!(token.refresh_token().is_none()); -} - -#[test] -fn test_exchange_code_successful_with_pkce_and_extension() { - let client = new_client() - .set_auth_type(AuthType::BasicAuth) - .set_redirect_uri(RedirectUrl::new("https://redirect/here".to_string()).unwrap()); - - let token = client - .exchange_code(AuthorizationCode::new("ccc".to_string())) - .set_pkce_verifier(PkceCodeVerifier::new( - "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk".to_string(), - )) - .add_extra_param("foo", "bar") - .request(mock_http_client( - vec![ - (ACCEPT, "application/json"), - (CONTENT_TYPE, "application/x-www-form-urlencoded"), - (AUTHORIZATION, "Basic YWFhOmJiYg=="), - ], - "grant_type=authorization_code\ - &code=ccc\ - &code_verifier=dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk\ - &redirect_uri=https%3A%2F%2Fredirect%2Fhere\ - &foo=bar", - None, - HttpResponse { - status_code: StatusCode::OK, - headers: vec![( - CONTENT_TYPE, - HeaderValue::from_str("application/json").unwrap(), - )] - .into_iter() - .collect(), - body: "{\ - \"access_token\": \"12/34\", \ - \"token_type\": \"bearer\", \ - \"scope\": \"read write\"\ - }" - .to_string() - .into_bytes(), - }, - )) - .unwrap(); - - assert_eq!("12/34", token.access_token().secret()); - assert_eq!(BasicTokenType::Bearer, *token.token_type()); - assert_eq!( - Some(&vec![ - Scope::new("read".to_string()), - Scope::new("write".to_string()), - ]), - token.scopes() - ); - assert_eq!(None, token.expires_in()); - assert!(token.refresh_token().is_none()); -} - -#[test] -fn test_exchange_refresh_token_successful_with_extension() { - let client = new_client() - .set_auth_type(AuthType::BasicAuth) - .set_redirect_uri(RedirectUrl::new("https://redirect/here".to_string()).unwrap()); - - let token = client - .exchange_refresh_token(&RefreshToken::new("ccc".to_string())) - .add_extra_param("foo", "bar") - .request(mock_http_client( - vec![ - (ACCEPT, "application/json"), - (CONTENT_TYPE, "application/x-www-form-urlencoded"), - (AUTHORIZATION, "Basic YWFhOmJiYg=="), - ], - "grant_type=refresh_token&refresh_token=ccc&foo=bar", - None, - HttpResponse { - status_code: StatusCode::OK, - headers: vec![( - CONTENT_TYPE, - HeaderValue::from_str("application/json").unwrap(), - )] - .into_iter() - .collect(), - body: "{\ - \"access_token\": \"12/34\", \ - \"token_type\": \"bearer\", \ - \"scope\": \"read write\"\ - }" - .to_string() - .into_bytes(), - }, - )) - .unwrap(); - - assert_eq!("12/34", token.access_token().secret()); - assert_eq!(BasicTokenType::Bearer, *token.token_type()); - assert_eq!( - Some(&vec![ - Scope::new("read".to_string()), - Scope::new("write".to_string()), - ]), - token.scopes() - ); - assert_eq!(None, token.expires_in()); - assert!(token.refresh_token().is_none()); -} - -#[test] -fn test_exchange_code_with_simple_json_error() { - let client = new_client(); - let token = client - .exchange_code(AuthorizationCode::new("ccc".to_string())) - .request(mock_http_client( - vec![ - (ACCEPT, "application/json"), - (CONTENT_TYPE, "application/x-www-form-urlencoded"), - (AUTHORIZATION, "Basic YWFhOmJiYg=="), - ], - "grant_type=authorization_code&code=ccc", - None, - HttpResponse { - status_code: StatusCode::BAD_REQUEST, - headers: vec![( - CONTENT_TYPE, - HeaderValue::from_str("application/json").unwrap(), - )] - .into_iter() - .collect(), - body: "{\ - \"error\": \"invalid_request\", \ - \"error_description\": \"stuff happened\"\ - }" - .to_string() - .into_bytes(), - }, - )); - - assert!(token.is_err()); - - let token_err = token.err().unwrap(); - match token_err { - RequestTokenError::ServerResponse(ref error_response) => { - assert_eq!( - BasicErrorResponseType::InvalidRequest, - *error_response.error() - ); - assert_eq!( - Some(&"stuff happened".to_string()), - error_response.error_description() - ); - assert_eq!(None, error_response.error_uri()); - - // Test Debug trait for ErrorResponse - assert_eq!( - "StandardErrorResponse { error: invalid_request, \ - error_description: Some(\"stuff happened\"), error_uri: None }", - format!("{:?}", error_response) - ); - // Test Display trait for ErrorResponse - assert_eq!( - "invalid_request: stuff happened", - format!("{}", error_response) - ); - - // Test Debug trait for BasicErrorResponseType - assert_eq!("invalid_request", format!("{:?}", error_response.error())); - // Test Display trait for BasicErrorResponseType - assert_eq!("invalid_request", format!("{}", error_response.error())); - - // Ensure that serialization produces an equivalent JSON value. - let serialized_json = serde_json::to_string(&error_response).unwrap(); - assert_eq!( - "{\"error\":\"invalid_request\",\"error_description\":\"stuff happened\"}" - .to_string(), - serialized_json - ); - - let deserialized_error = - serde_json::from_str::(&serialized_json).unwrap(); - assert_eq!(error_response, &deserialized_error); - } - other => panic!("Unexpected error: {:?}", other), - } - - // Test Debug trait for RequestTokenError - assert_eq!( - "ServerResponse(StandardErrorResponse { error: invalid_request, \ - error_description: Some(\"stuff happened\"), error_uri: None })", - format!("{:?}", token_err) - ); - // Test Display trait for RequestTokenError - assert_eq!( - "Server returned error response: invalid_request: stuff happened", - token_err.to_string() - ); -} - -#[test] -fn test_exchange_code_with_json_parse_error() { - let client = new_client(); - let token = client - .exchange_code(AuthorizationCode::new("ccc".to_string())) - .request(mock_http_client( - vec![ - (ACCEPT, "application/json"), - (CONTENT_TYPE, "application/x-www-form-urlencoded"), - (AUTHORIZATION, "Basic YWFhOmJiYg=="), - ], - "grant_type=authorization_code&code=ccc", - None, - HttpResponse { - status_code: StatusCode::OK, - headers: vec![( - CONTENT_TYPE, - HeaderValue::from_str("application/json").unwrap(), - )] - .into_iter() - .collect(), - body: "broken json".to_string().into_bytes(), - }, - )); - - assert!(token.is_err()); - - match token.err().unwrap() { - RequestTokenError::Parse(json_err, _) => { - assert_eq!(".", json_err.path().to_string()); - assert_eq!(1, json_err.inner().line()); - assert_eq!(1, json_err.inner().column()); - assert_eq!( - serde_json::error::Category::Syntax, - json_err.inner().classify() - ); - } - other => panic!("Unexpected error: {:?}", other), - } -} - -#[test] -fn test_exchange_code_with_unexpected_content_type() { - let client = new_client(); - let token = client - .exchange_code(AuthorizationCode::new("ccc".to_string())) - .request(mock_http_client( - vec![ - (ACCEPT, "application/json"), - (CONTENT_TYPE, "application/x-www-form-urlencoded"), - (AUTHORIZATION, "Basic YWFhOmJiYg=="), - ], - "grant_type=authorization_code&code=ccc", - None, - HttpResponse { - status_code: StatusCode::OK, - headers: vec![(CONTENT_TYPE, HeaderValue::from_str("text/plain").unwrap())] - .into_iter() - .collect(), - body: "broken json".to_string().into_bytes(), - }, - )); - - assert!(token.is_err()); - - match token.err().unwrap() { - RequestTokenError::Other(error_str) => { - assert_eq!( - "Unexpected response Content-Type: \"text/plain\", should be `application/json`", - error_str - ); - } - other => panic!("Unexpected error: {:?}", other), - } -} - -#[test] -fn test_exchange_code_with_invalid_token_type() { - let client = BasicClient::new(ClientId::new("aaa".to_string())) - .set_auth_url(AuthUrl::new("https://example.com/auth".to_string()).unwrap()) - .set_token_url(TokenUrl::new("https://example.com/token".to_string()).unwrap()); - - let token = client - .exchange_code(AuthorizationCode::new("ccc".to_string())) - .request(mock_http_client( - vec![ - (ACCEPT, "application/json"), - (CONTENT_TYPE, "application/x-www-form-urlencoded"), - ], - "grant_type=authorization_code&code=ccc&client_id=aaa", - None, - HttpResponse { - status_code: StatusCode::OK, - headers: vec![( - CONTENT_TYPE, - HeaderValue::from_str("application/json").unwrap(), - )] - .into_iter() - .collect(), - body: "{\"access_token\": \"12/34\", \"token_type\": 123}" - .to_string() - .into_bytes(), - }, - )); - - assert!(token.is_err()); - match token.err().unwrap() { - RequestTokenError::Parse(json_err, _) => { - assert_eq!("token_type", json_err.path().to_string()); - assert_eq!(1, json_err.inner().line()); - assert_eq!(43, json_err.inner().column()); - assert_eq!( - serde_json::error::Category::Data, - json_err.inner().classify() - ); - } - other => panic!("Unexpected error: {:?}", other), - } -} - -#[test] -fn test_exchange_code_with_400_status_code() { - let body = r#"{"error":"invalid_request","error_description":"Expired code."}"#; - let client = new_client(); - let token_err = client - .exchange_code(AuthorizationCode::new("ccc".to_string())) - .request(mock_http_client( - vec![ - (ACCEPT, "application/json"), - (CONTENT_TYPE, "application/x-www-form-urlencoded"), - (AUTHORIZATION, "Basic YWFhOmJiYg=="), - ], - "grant_type=authorization_code&code=ccc", - None, - HttpResponse { - status_code: StatusCode::BAD_REQUEST, - headers: vec![( - CONTENT_TYPE, - HeaderValue::from_str("application/json").unwrap(), - )] - .into_iter() - .collect(), - body: body.to_string().into_bytes(), - }, - )) - .err() - .unwrap(); - - match token_err { - RequestTokenError::ServerResponse(ref error_response) => { - assert_eq!( - BasicErrorResponseType::InvalidRequest, - *error_response.error() - ); - assert_eq!( - Some(&"Expired code.".to_string()), - error_response.error_description() - ); - assert_eq!(None, error_response.error_uri()); - } - other => panic!("Unexpected error: {:?}", other), - } - - assert_eq!( - "Server returned error response: invalid_request: Expired code.", - token_err.to_string(), - ); -} - -#[test] -fn test_exchange_code_fails_gracefully_on_transport_error() { - let client = BasicClient::new(ClientId::new("aaa".to_string())) - .set_client_secret(ClientSecret::new("bbb".to_string())) - .set_auth_url(AuthUrl::new("https://auth".to_string()).unwrap()) - .set_token_url(TokenUrl::new("https://token".to_string()).unwrap()); - - let token = client - .exchange_code(AuthorizationCode::new("ccc".to_string())) - .request(|_| Err(FakeError::Err)); - - assert!(token.is_err()); - - match token.err().unwrap() { - RequestTokenError::Request(FakeError::Err) => (), - other => panic!("Unexpected error: {:?}", other), - } -} - -mod colorful_extension { +pub(crate) mod colorful_extension { extern crate serde_json; use crate::{ @@ -1331,1027 +171,7 @@ mod colorful_extension { } } -#[test] -fn test_extension_successful_with_minimal_json_response() { - let client = ColorfulClient::new(ClientId::new("aaa".to_string())) - .set_client_secret(ClientSecret::new("bbb".to_string())) - .set_auth_url(AuthUrl::new("https://example.com/auth".to_string()).unwrap()) - .set_token_url(TokenUrl::new("https://example.com/token".to_string()).unwrap()); - - let token = client - .exchange_code(AuthorizationCode::new("ccc".to_string())) - .request(mock_http_client( - vec![ - (ACCEPT, "application/json"), - (CONTENT_TYPE, "application/x-www-form-urlencoded"), - (AUTHORIZATION, "Basic YWFhOmJiYg=="), - ], - "grant_type=authorization_code&code=ccc", - None, - HttpResponse { - status_code: StatusCode::OK, - headers: vec![( - CONTENT_TYPE, - HeaderValue::from_str("application/json").unwrap(), - )] - .into_iter() - .collect(), - body: "{\"access_token\": \"12/34\", \"token_type\": \"green\", \"height\": 10}" - .to_string() - .into_bytes(), - }, - )) - .unwrap(); - - assert_eq!("12/34", token.access_token().secret()); - assert_eq!(ColorfulTokenType::Green, *token.token_type()); - assert_eq!(None, token.expires_in()); - assert!(token.refresh_token().is_none()); - assert_eq!(None, token.extra_fields().shape()); - assert_eq!(10, token.extra_fields().height()); - - // Ensure that serialization produces an equivalent JSON value. - let serialized_json = serde_json::to_string(&token).unwrap(); - assert_eq!( - "{\"access_token\":\"12/34\",\"token_type\":\"green\",\"height\":10}".to_string(), - serialized_json - ); - - let deserialized_token = - serde_json::from_str::(&serialized_json).unwrap(); - assert_token_eq(&token, &deserialized_token); -} - -#[test] -fn test_extension_successful_with_complete_json_response() { - let client = ColorfulClient::new(ClientId::new("aaa".to_string())) - .set_client_secret(ClientSecret::new("bbb".to_string())) - .set_auth_url(AuthUrl::new("https://example.com/auth".to_string()).unwrap()) - .set_token_url(TokenUrl::new("https://example.com/token".to_string()).unwrap()) - .set_auth_type(AuthType::RequestBody); - - let token = client - .exchange_code(AuthorizationCode::new("ccc".to_string())) - .request(mock_http_client( - vec![ - (ACCEPT, "application/json"), - (CONTENT_TYPE, "application/x-www-form-urlencoded"), - ], - "grant_type=authorization_code&code=ccc&client_id=aaa&client_secret=bbb", - None, - HttpResponse { - status_code: StatusCode::OK, - headers: vec![( - CONTENT_TYPE, - HeaderValue::from_str("application/json").unwrap(), - )] - .into_iter() - .collect(), - body: "{\ - \"access_token\": \"12/34\", \ - \"token_type\": \"red\", \ - \"scope\": \"read write\", \ - \"expires_in\": 3600, \ - \"refresh_token\": \"foobar\", \ - \"shape\": \"round\", \ - \"height\": 12\ - }" - .to_string() - .into_bytes(), - }, - )) - .unwrap(); - - assert_eq!("12/34", token.access_token().secret()); - assert_eq!(ColorfulTokenType::Red, *token.token_type()); - assert_eq!( - Some(&vec![ - Scope::new("read".to_string()), - Scope::new("write".to_string()), - ]), - token.scopes() - ); - assert_eq!(3600, token.expires_in().unwrap().as_secs()); - assert_eq!("foobar", token.refresh_token().unwrap().secret()); - assert_eq!(Some(&"round".to_string()), token.extra_fields().shape()); - assert_eq!(12, token.extra_fields().height()); - - // Ensure that serialization produces an equivalent JSON value. - let serialized_json = serde_json::to_string(&token).unwrap(); - assert_eq!( - "{\"access_token\":\"12/34\",\"token_type\":\"red\",\"expires_in\":3600,\ - \"refresh_token\":\"foobar\",\"scope\":\"read write\",\"shape\":\"round\",\"height\":12}" - .to_string(), - serialized_json - ); - - let deserialized_token = - serde_json::from_str::(&serialized_json).unwrap(); - assert_token_eq(&token, &deserialized_token); -} - -#[test] -fn test_extension_with_simple_json_error() { - let client = ColorfulClient::new(ClientId::new("aaa".to_string())) - .set_client_secret(ClientSecret::new("bbb".to_string())) - .set_auth_url(AuthUrl::new("https://example.com/auth".to_string()).unwrap()) - .set_token_url(TokenUrl::new("https://example.com/token".to_string()).unwrap()); - - let token = client - .exchange_code(AuthorizationCode::new("ccc".to_string())) - .request(mock_http_client( - vec![ - (ACCEPT, "application/json"), - (CONTENT_TYPE, "application/x-www-form-urlencoded"), - (AUTHORIZATION, "Basic YWFhOmJiYg=="), - ], - "grant_type=authorization_code&code=ccc", - None, - HttpResponse { - status_code: StatusCode::BAD_REQUEST, - headers: vec![( - CONTENT_TYPE, - HeaderValue::from_str("application/json").unwrap(), - )] - .into_iter() - .collect(), - body: "{\"error\": \"too_light\", \"error_description\": \"stuff happened\", \ - \"error_uri\": \"https://errors\"}" - .to_string() - .into_bytes(), - }, - )); - - assert!(token.is_err()); - - let token_err = token.err().unwrap(); - match token_err { - RequestTokenError::ServerResponse(ref error_response) => { - assert_eq!(ColorfulErrorResponseType::TooLight, *error_response.error()); - assert_eq!( - Some(&"stuff happened".to_string()), - error_response.error_description() - ); - assert_eq!( - Some(&"https://errors".to_string()), - error_response.error_uri() - ); - - // Ensure that serialization produces an equivalent JSON value. - let serialized_json = serde_json::to_string(&error_response).unwrap(); - assert_eq!( - "{\"error\":\"too_light\",\"error_description\":\"stuff happened\",\ - \"error_uri\":\"https://errors\"}" - .to_string(), - serialized_json - ); - - let deserialized_error = serde_json::from_str::< - StandardErrorResponse, - >(&serialized_json) - .unwrap(); - assert_eq!(error_response, &deserialized_error); - } - other => panic!("Unexpected error: {:?}", other), - } - - // Test Debug trait for RequestTokenError - assert_eq!( - "ServerResponse(StandardErrorResponse { error: too_light, \ - error_description: Some(\"stuff happened\"), error_uri: Some(\"https://errors\") })", - format!("{:?}", token_err) - ); - // Test Display trait for RequestTokenError - assert_eq!( - "Server returned error response: too_light: stuff happened (see https://errors)", - token_err.to_string() - ); -} - -mod custom_errors { - use crate::tests::colorful_extension::{ - ColorfulFields, ColorfulRevocableToken, ColorfulTokenType, - }; - use crate::{Client, ErrorResponse, StandardTokenIntrospectionResponse, StandardTokenResponse}; - - use serde::{Deserialize, Serialize}; - - use std::fmt::Error as FormatterError; - use std::fmt::{Display, Formatter}; - - extern crate serde_json; - - #[derive(Serialize, Deserialize, Debug)] - pub struct CustomErrorResponse { - pub custom_error: String, - } - - impl Display for CustomErrorResponse { - fn fmt(&self, f: &mut Formatter) -> Result<(), FormatterError> { - write!(f, "Custom Error from server") - } - } - - impl ErrorResponse for CustomErrorResponse {} - - pub type CustomErrorClient< - const HAS_AUTH_URL: bool, - const HAS_DEVICE_AUTH_URL: bool, - const HAS_INTROSPECTION_URL: bool, - const HAS_REVOCATION_URL: bool, - const HAS_TOKEN_URL: bool, - > = Client< - CustomErrorResponse, - StandardTokenResponse, - ColorfulTokenType, - StandardTokenIntrospectionResponse, - ColorfulRevocableToken, - CustomErrorResponse, - HAS_AUTH_URL, - HAS_DEVICE_AUTH_URL, - HAS_INTROSPECTION_URL, - HAS_REVOCATION_URL, - HAS_TOKEN_URL, - >; -} - -#[test] -fn test_extension_with_custom_json_error() { - let client = CustomErrorClient::new(ClientId::new("aaa".to_string())) - .set_client_secret(ClientSecret::new("bbb".to_string())) - .set_auth_url(AuthUrl::new("https://example.com/auth".to_string()).unwrap()) - .set_token_url(TokenUrl::new("https://example.com/token".to_string()).unwrap()); - - let token = client - .exchange_code(AuthorizationCode::new("ccc".to_string())) - .request(mock_http_client( - vec![ - (ACCEPT, "application/json"), - (CONTENT_TYPE, "application/x-www-form-urlencoded"), - (AUTHORIZATION, "Basic YWFhOmJiYg=="), - ], - "grant_type=authorization_code&code=ccc", - None, - HttpResponse { - status_code: StatusCode::BAD_REQUEST, - headers: vec![( - CONTENT_TYPE, - HeaderValue::from_str("application/json").unwrap(), - )] - .into_iter() - .collect(), - body: "{\"custom_error\": \"non-compliant oauth implementation ;-)\"}" - .to_string() - .into_bytes(), - }, - )); - - assert!(token.is_err()); - - match token.err().unwrap() { - RequestTokenError::ServerResponse(e) => { - assert_eq!("non-compliant oauth implementation ;-)", e.custom_error) - } - e => panic!("failed to correctly parse custom server error, got {:?}", e), - }; -} - -#[test] -fn test_extension_serializer() { - let mut token_response = ColorfulTokenResponse::new( - AccessToken::new("mysecret".to_string()), - ColorfulTokenType::Red, - ColorfulFields { - shape: Some("circle".to_string()), - height: 10, - }, - ); - token_response.set_expires_in(Some(&Duration::from_secs(3600))); - token_response.set_refresh_token(Some(RefreshToken::new("myothersecret".to_string()))); - let serialized = serde_json::to_string(&token_response).unwrap(); - assert_eq!( - "{\ - \"access_token\":\"mysecret\",\ - \"token_type\":\"red\",\ - \"expires_in\":3600,\ - \"refresh_token\":\"myothersecret\",\ - \"shape\":\"circle\",\ - \"height\":10\ - }", - serialized, - ); -} - -#[test] -fn test_error_response_serializer() { - assert_eq!( - "{\"error\":\"unauthorized_client\"}", - serde_json::to_string(&BasicErrorResponse::new( - BasicErrorResponseType::UnauthorizedClient, - None, - None, - )) - .unwrap(), - ); - - assert_eq!( - "{\ - \"error\":\"invalid_client\",\ - \"error_description\":\"Invalid client_id\",\ - \"error_uri\":\"https://example.com/errors/invalid_client\"\ - }", - serde_json::to_string(&BasicErrorResponse::new( - BasicErrorResponseType::InvalidClient, - Some("Invalid client_id".to_string()), - Some("https://example.com/errors/invalid_client".to_string()), - )) - .unwrap(), - ); -} - -#[derive(Deserialize, Debug, Clone)] -pub struct ObjectWithOptionalStringOrVecString { - #[serde(deserialize_with = "crate::helpers::deserialize_optional_string_or_vec_string")] - pub strings: Option>, -} - -#[test] -fn test_deserialize_optional_string_or_vec_string_none() { - let list_of_strings: ObjectWithOptionalStringOrVecString = - serde_json::from_str(r#"{ "strings": null }"#).unwrap(); - assert_eq!(None, list_of_strings.strings); -} - -#[test] -fn test_deserialize_optional_string_or_vec_string_single_value() { - let list_of_strings: ObjectWithOptionalStringOrVecString = - serde_json::from_str(r#"{ "strings": "v1" }"#).unwrap(); - assert_eq!(Some(vec!["v1".to_string()]), list_of_strings.strings); -} - -#[test] -fn test_deserialize_optional_string_or_vec_string_vec() { - let list_of_strings: ObjectWithOptionalStringOrVecString = - serde_json::from_str(r#"{ "strings": ["v1", "v2"] }"#).unwrap(); - assert_eq!( - Some(vec!["v1".to_string(), "v2".to_string()]), - list_of_strings.strings - ); -} - -#[test] -fn test_token_introspection_successful_with_basic_auth_minimal_response() { - let client = new_client() - .set_auth_type(AuthType::BasicAuth) - .set_redirect_uri(RedirectUrl::new("https://redirect/here".to_string()).unwrap()) - .set_introspection_uri( - IntrospectionUrl::new("https://introspection/url".to_string()).unwrap(), - ); - - let introspection_response = client - .introspect(&AccessToken::new("access_token_123".to_string())) - .unwrap() - .request(mock_http_client( - vec![ - (ACCEPT, "application/json"), - (CONTENT_TYPE, "application/x-www-form-urlencoded"), - (AUTHORIZATION, "Basic YWFhOmJiYg=="), - ], - "token=access_token_123", - Some("https://introspection/url".parse().unwrap()), - HttpResponse { - status_code: StatusCode::OK, - headers: vec![( - CONTENT_TYPE, - HeaderValue::from_str("application/json").unwrap(), - )] - .into_iter() - .collect(), - body: "{\ - \"active\": true\ - }" - .to_string() - .into_bytes(), - }, - )) - .unwrap(); - - assert!(introspection_response.active); - assert_eq!(None, introspection_response.scopes); - assert_eq!(None, introspection_response.client_id); - assert_eq!(None, introspection_response.username); - assert_eq!(None, introspection_response.token_type); - assert_eq!(None, introspection_response.exp); - assert_eq!(None, introspection_response.iat); - assert_eq!(None, introspection_response.nbf); - assert_eq!(None, introspection_response.sub); - assert_eq!(None, introspection_response.aud); - assert_eq!(None, introspection_response.iss); - assert_eq!(None, introspection_response.jti); -} - -#[test] -fn test_token_introspection_successful_with_basic_auth_full_response() { - let client = new_client() - .set_auth_type(AuthType::BasicAuth) - .set_redirect_uri(RedirectUrl::new("https://redirect/here".to_string()).unwrap()) - .set_introspection_uri( - IntrospectionUrl::new("https://introspection/url".to_string()).unwrap(), - ); - - let introspection_response = client - .introspect(&AccessToken::new("access_token_123".to_string())) - .unwrap() - .set_token_type_hint("access_token") - .request(mock_http_client( - vec![ - (ACCEPT, "application/json"), - (CONTENT_TYPE, "application/x-www-form-urlencoded"), - (AUTHORIZATION, "Basic YWFhOmJiYg=="), - ], - "token=access_token_123&token_type_hint=access_token", - Some("https://introspection/url".parse().unwrap()), - HttpResponse { - status_code: StatusCode::OK, - headers: vec![( - CONTENT_TYPE, - HeaderValue::from_str("application/json").unwrap(), - )] - .into_iter() - .collect(), - body: r#"{ - "active": true, - "scope": "email profile", - "client_id": "aaa", - "username": "demo", - "token_type": "bearer", - "exp": 1604073517, - "iat": 1604073217, - "nbf": 1604073317, - "sub": "demo", - "aud": "demo", - "iss": "http://127.0.0.1:8080/auth/realms/test-realm", - "jti": "be1b7da2-fc18-47b3-bdf1-7a4f50bcf53f" - }"# - .to_string() - .into_bytes(), - }, - )) - .unwrap(); - - assert!(introspection_response.active); - assert_eq!( - Some(vec![ - Scope::new("email".to_string()), - Scope::new("profile".to_string()) - ]), - introspection_response.scopes - ); - assert_eq!( - Some(ClientId::new("aaa".to_string())), - introspection_response.client_id - ); - assert_eq!(Some("demo".to_string()), introspection_response.username); - assert_eq!( - Some(BasicTokenType::Bearer), - introspection_response.token_type - ); - assert_eq!( - Some(Utc.timestamp(1604073517, 0)), - introspection_response.exp - ); - assert_eq!( - Some(Utc.timestamp(1604073217, 0)), - introspection_response.iat - ); - assert_eq!( - Some(Utc.timestamp(1604073317, 0)), - introspection_response.nbf - ); - assert_eq!(Some("demo".to_string()), introspection_response.sub); - assert_eq!(Some(vec!["demo".to_string()]), introspection_response.aud); - assert_eq!( - Some("http://127.0.0.1:8080/auth/realms/test-realm".to_string()), - introspection_response.iss - ); - assert_eq!( - Some("be1b7da2-fc18-47b3-bdf1-7a4f50bcf53f".to_string()), - introspection_response.jti - ); -} - -#[test] -fn test_token_revocation_with_non_https_url() { - let client = new_client(); - - let result = client - .set_revocation_uri(RevocationUrl::new("http://revocation/url".to_string()).unwrap()) - .revoke_token(AccessToken::new("access_token_123".to_string()).into()) - .unwrap_err(); - - assert_eq!( - format!("{}", result), - "Scheme for revocation endpoint URL must be HTTPS" - ); -} - -#[test] -fn test_token_revocation_with_unsupported_token_type() { - let client = new_client() - .set_revocation_uri(RevocationUrl::new("https://revocation/url".to_string()).unwrap()); - - let revocation_response = client - .revoke_token(AccessToken::new("access_token_123".to_string()).into()).unwrap() - .request(mock_http_client( - vec![ - (ACCEPT, "application/json"), - (CONTENT_TYPE, "application/x-www-form-urlencoded"), - (AUTHORIZATION, "Basic YWFhOmJiYg=="), - ], - "token=access_token_123&token_type_hint=access_token", - Some("https://revocation/url".parse().unwrap()), - HttpResponse { - status_code: StatusCode::BAD_REQUEST, - headers: vec![( - CONTENT_TYPE, - HeaderValue::from_str("application/json").unwrap(), - )] - .into_iter() - .collect(), - body: "{\"error\": \"unsupported_token_type\", \"error_description\": \"stuff happened\", \ - \"error_uri\": \"https://errors\"}" - .to_string() - .into_bytes(), - }, - )); - - assert!(matches!( - revocation_response, - Err(RequestTokenError::ServerResponse( - BasicRevocationErrorResponse { - error: RevocationErrorResponseType::UnsupportedTokenType, - .. - } - )) - )); -} - -#[test] -fn test_token_revocation_with_access_token_and_empty_json_response() { - let client = new_client() - .set_revocation_uri(RevocationUrl::new("https://revocation/url".to_string()).unwrap()); - - client - .revoke_token(AccessToken::new("access_token_123".to_string()).into()) - .unwrap() - .request(mock_http_client( - vec![ - (ACCEPT, "application/json"), - (CONTENT_TYPE, "application/x-www-form-urlencoded"), - (AUTHORIZATION, "Basic YWFhOmJiYg=="), - ], - "token=access_token_123&token_type_hint=access_token", - Some("https://revocation/url".parse().unwrap()), - HttpResponse { - status_code: StatusCode::OK, - headers: vec![( - CONTENT_TYPE, - HeaderValue::from_str("application/json").unwrap(), - )] - .into_iter() - .collect(), - body: b"{}".to_vec(), - }, - )) - .unwrap(); -} - -#[test] -fn test_token_revocation_with_access_token_and_empty_response() { - let client = new_client() - .set_revocation_uri(RevocationUrl::new("https://revocation/url".to_string()).unwrap()); - - client - .revoke_token(AccessToken::new("access_token_123".to_string()).into()) - .unwrap() - .request(mock_http_client( - vec![ - (ACCEPT, "application/json"), - (CONTENT_TYPE, "application/x-www-form-urlencoded"), - (AUTHORIZATION, "Basic YWFhOmJiYg=="), - ], - "token=access_token_123&token_type_hint=access_token", - Some("https://revocation/url".parse().unwrap()), - HttpResponse { - status_code: StatusCode::OK, - headers: vec![].into_iter().collect(), - body: vec![], - }, - )) - .unwrap(); -} - -#[test] -fn test_token_revocation_with_access_token_and_non_json_response() { - let client = new_client() - .set_revocation_uri(RevocationUrl::new("https://revocation/url".to_string()).unwrap()); - - client - .revoke_token(AccessToken::new("access_token_123".to_string()).into()) - .unwrap() - .request(mock_http_client( - vec![ - (ACCEPT, "application/json"), - (CONTENT_TYPE, "application/x-www-form-urlencoded"), - (AUTHORIZATION, "Basic YWFhOmJiYg=="), - ], - "token=access_token_123&token_type_hint=access_token", - Some("https://revocation/url".parse().unwrap()), - HttpResponse { - status_code: StatusCode::OK, - headers: vec![( - CONTENT_TYPE, - HeaderValue::from_str("application/octet-stream").unwrap(), - )] - .into_iter() - .collect(), - body: vec![1, 2, 3], - }, - )) - .unwrap(); -} - -#[test] -fn test_token_revocation_with_refresh_token() { - let client = new_client() - .set_revocation_uri(RevocationUrl::new("https://revocation/url".to_string()).unwrap()); - - client - .revoke_token(RefreshToken::new("refresh_token_123".to_string()).into()) - .unwrap() - .request(mock_http_client( - vec![ - (ACCEPT, "application/json"), - (CONTENT_TYPE, "application/x-www-form-urlencoded"), - (AUTHORIZATION, "Basic YWFhOmJiYg=="), - ], - "token=refresh_token_123&token_type_hint=refresh_token", - Some("https://revocation/url".parse().unwrap()), - HttpResponse { - status_code: StatusCode::OK, - headers: vec![( - CONTENT_TYPE, - HeaderValue::from_str("application/json").unwrap(), - )] - .into_iter() - .collect(), - body: b"{}".to_vec(), - }, - )) - .unwrap(); -} - -#[test] -fn test_extension_token_revocation_successful() { - let client = ColorfulClient::new(ClientId::new("aaa".to_string())) - .set_client_secret(ClientSecret::new("bbb".to_string())) - .set_auth_url(AuthUrl::new("https://example.com/auth".to_string()).unwrap()) - .set_token_url(TokenUrl::new("https://example.com/token".to_string()).unwrap()) - .set_revocation_uri(RevocationUrl::new("https://revocation/url".to_string()).unwrap()); - - client - .revoke_token(ColorfulRevocableToken::Red( - "colorful_token_123".to_string(), - )) - .unwrap() - .request(mock_http_client( - vec![ - (ACCEPT, "application/json"), - (CONTENT_TYPE, "application/x-www-form-urlencoded"), - (AUTHORIZATION, "Basic YWFhOmJiYg=="), - ], - "token=colorful_token_123&token_type_hint=red_token", - Some("https://revocation/url".parse().unwrap()), - HttpResponse { - status_code: StatusCode::OK, - headers: vec![( - CONTENT_TYPE, - HeaderValue::from_str("application/json").unwrap(), - )] - .into_iter() - .collect(), - body: b"{}".to_vec(), - }, - )) - .unwrap(); -} - -#[test] -fn test_secret_redaction() { - let secret = ClientSecret::new("top_secret".to_string()); - assert_eq!("ClientSecret([redacted])", format!("{:?}", secret)); -} - -fn new_device_auth_details(expires_in: u32) -> StandardDeviceAuthorizationResponse { - let body = format!( - "{{\ - \"device_code\": \"12345\", \ - \"verification_uri\": \"https://verify/here\", \ - \"user_code\": \"abcde\", \ - \"verification_uri_complete\": \"https://verify/here?abcde\", \ - \"expires_in\": {}, \ - \"interval\": 1 \ - }}", - expires_in - ); - - let device_auth_url = - DeviceAuthorizationUrl::new("https://deviceauth/here".to_string()).unwrap(); - - let client = new_client().set_device_authorization_url(device_auth_url.clone()); - client - .exchange_device_code() - .add_extra_param("foo", "bar") - .add_scope(Scope::new("openid".to_string())) - .request(mock_http_client( - vec![ - (ACCEPT, "application/json"), - (CONTENT_TYPE, "application/x-www-form-urlencoded"), - (AUTHORIZATION, "Basic YWFhOmJiYg=="), - ], - "scope=openid&foo=bar", - Some(device_auth_url.url().to_owned()), - HttpResponse { - status_code: StatusCode::OK, - headers: vec![( - CONTENT_TYPE, - HeaderValue::from_str("application/json").unwrap(), - )] - .into_iter() - .collect(), - body: body.into_bytes(), - }, - )) - .unwrap() -} - -struct IncreasingTime { - times: std::ops::RangeFrom, -} - -impl IncreasingTime { - fn new() -> Self { - Self { times: (0..) } - } - fn next(&mut self) -> DateTime { - let next_value = self.times.next().unwrap(); - let naive = chrono::NaiveDateTime::from_timestamp(next_value, 0); - DateTime::::from_utc(naive, chrono::Utc) - } -} - -/// Creates a time function that increments by one second each time. -fn mock_time_fn() -> impl Fn() -> DateTime + Send + Sync { - let timer = std::sync::Mutex::new(IncreasingTime::new()); - move || timer.lock().unwrap().next() -} - -/// Mock sleep function that doesn't actually sleep. -fn mock_sleep_fn(_: Duration) {} - -#[test] -fn test_exchange_device_code_and_token() { - let details = new_device_auth_details(3600); - assert_eq!("12345", details.device_code().secret()); - assert_eq!("https://verify/here", details.verification_uri().as_str()); - assert_eq!("abcde", details.user_code().secret().as_str()); - assert_eq!( - "https://verify/here?abcde", - details - .verification_uri_complete() - .unwrap() - .secret() - .as_str() - ); - assert_eq!(Duration::from_secs(3600), details.expires_in()); - assert_eq!(Duration::from_secs(1), details.interval()); - - let token = new_client() - .exchange_device_access_token(&details) - .set_time_fn(mock_time_fn()) - .request(mock_http_client( - vec![ - (ACCEPT, "application/json"), - (CONTENT_TYPE, "application/x-www-form-urlencoded"), - (AUTHORIZATION, "Basic YWFhOmJiYg=="), - ], - "grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Adevice_code&device_code=12345", - None, - HttpResponse { - status_code: StatusCode::OK, - headers: vec![( - CONTENT_TYPE, - HeaderValue::from_str("application/json").unwrap(), - )] - .into_iter() - .collect(), - body: "{\ - \"access_token\": \"12/34\", \ - \"token_type\": \"bearer\", \ - \"scope\": \"openid\"\ - }" - .to_string() - .into_bytes(), - }, - ), - mock_sleep_fn, - None) - .unwrap(); - - assert_eq!("12/34", token.access_token().secret()); - assert_eq!(BasicTokenType::Bearer, *token.token_type()); - assert_eq!( - Some(&vec![Scope::new("openid".to_string()),]), - token.scopes() - ); - assert_eq!(None, token.expires_in()); - assert!(token.refresh_token().is_none()); -} - -#[test] -fn test_device_token_authorization_timeout() { - let details = new_device_auth_details(2); - assert_eq!("12345", details.device_code().secret()); - assert_eq!("https://verify/here", details.verification_uri().as_str()); - assert_eq!("abcde", details.user_code().secret().as_str()); - assert_eq!( - "https://verify/here?abcde", - details - .verification_uri_complete() - .unwrap() - .secret() - .as_str() - ); - assert_eq!(Duration::from_secs(2), details.expires_in()); - assert_eq!(Duration::from_secs(1), details.interval()); - - let token = new_client() - .exchange_device_access_token(&details) - .set_time_fn(mock_time_fn()) - .request(mock_http_client( - vec![ - (ACCEPT, "application/json"), - (CONTENT_TYPE, "application/x-www-form-urlencoded"), - (AUTHORIZATION, "Basic YWFhOmJiYg=="), - ], - "grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Adevice_code&device_code=12345", - None, - HttpResponse { - status_code: StatusCode::from_u16(400).unwrap(), - headers: vec![( - CONTENT_TYPE, - HeaderValue::from_str("application/json").unwrap(), - )] - .into_iter() - .collect(), - body: "{\ - \"error\": \"authorization_pending\", \ - \"error_description\": \"Still waiting for user\"\ - }" - .to_string() - .into_bytes(), - }, - ), - mock_sleep_fn, - None) - .err() - .unwrap(); - match token { - RequestTokenError::ServerResponse(msg) => assert_eq!( - msg, - DeviceCodeErrorResponse::new( - DeviceCodeErrorResponseType::ExpiredToken, - Some(String::from("This device code has expired.")), - None, - ) - ), - _ => unreachable!("Error should be an expiry"), - } -} - -#[test] -fn test_device_token_access_denied() { - let details = new_device_auth_details(2); - assert_eq!("12345", details.device_code().secret()); - assert_eq!("https://verify/here", details.verification_uri().as_str()); - assert_eq!("abcde", details.user_code().secret().as_str()); - assert_eq!( - "https://verify/here?abcde", - details - .verification_uri_complete() - .unwrap() - .secret() - .as_str() - ); - assert_eq!(Duration::from_secs(2), details.expires_in()); - assert_eq!(Duration::from_secs(1), details.interval()); - - let token = new_client() - .exchange_device_access_token(&details) - .set_time_fn(mock_time_fn()) - .request(mock_http_client( - vec![ - (ACCEPT, "application/json"), - (CONTENT_TYPE, "application/x-www-form-urlencoded"), - (AUTHORIZATION, "Basic YWFhOmJiYg=="), - ], - "grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Adevice_code&device_code=12345", - None, - HttpResponse { - status_code: StatusCode::from_u16(400).unwrap(), - headers: vec![( - CONTENT_TYPE, - HeaderValue::from_str("application/json").unwrap(), - )] - .into_iter() - .collect(), - body: "{\ - \"error\": \"access_denied\", \ - \"error_description\": \"Access Denied\"\ - }" - .to_string() - .into_bytes(), - }, - ), - mock_sleep_fn, - None) - .err() - .unwrap(); - match token { - RequestTokenError::ServerResponse(msg) => { - assert_eq!(msg.error(), &DeviceCodeErrorResponseType::AccessDenied) - } - _ => unreachable!("Error should be Access Denied"), - } -} - -#[test] -fn test_device_token_expired() { - let details = new_device_auth_details(2); - assert_eq!("12345", details.device_code().secret()); - assert_eq!("https://verify/here", details.verification_uri().as_str()); - assert_eq!("abcde", details.user_code().secret().as_str()); - assert_eq!( - "https://verify/here?abcde", - details - .verification_uri_complete() - .unwrap() - .secret() - .as_str() - ); - assert_eq!(Duration::from_secs(2), details.expires_in()); - assert_eq!(Duration::from_secs(1), details.interval()); - - let token = new_client() - .exchange_device_access_token(&details) - .set_time_fn(mock_time_fn()) - .request(mock_http_client( - vec![ - (ACCEPT, "application/json"), - (CONTENT_TYPE, "application/x-www-form-urlencoded"), - (AUTHORIZATION, "Basic YWFhOmJiYg=="), - ], - "grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Adevice_code&device_code=12345", - None, - HttpResponse { - status_code: StatusCode::from_u16(400).unwrap(), - headers: vec![( - CONTENT_TYPE, - HeaderValue::from_str("application/json").unwrap(), - )] - .into_iter() - .collect(), - body: "{\ - \"error\": \"expired_token\", \ - \"error_description\": \"Token has expired\"\ - }" - .to_string() - .into_bytes(), - }, - ), - mock_sleep_fn, - None) - .err() - .unwrap(); - match token { - RequestTokenError::ServerResponse(msg) => { - assert_eq!(msg.error(), &DeviceCodeErrorResponseType::ExpiredToken) - } - _ => unreachable!("Error should be ExpiredToken"), - } -} - -fn mock_http_client_success_fail( +pub(crate) fn mock_http_client_success_fail( request_url: Option, request_headers: Vec<(HeaderName, &'static str)>, request_body: &'static str, @@ -2392,156 +212,6 @@ fn mock_http_client_success_fail( } } -#[test] -fn test_device_token_pending_then_success() { - let details = new_device_auth_details(20); - assert_eq!("12345", details.device_code().secret()); - assert_eq!("https://verify/here", details.verification_uri().as_str()); - assert_eq!("abcde", details.user_code().secret().as_str()); - assert_eq!( - "https://verify/here?abcde", - details - .verification_uri_complete() - .unwrap() - .secret() - .as_str() - ); - assert_eq!(Duration::from_secs(20), details.expires_in()); - assert_eq!(Duration::from_secs(1), details.interval()); - - let token = new_client() - .exchange_device_access_token(&details) - .set_time_fn(mock_time_fn()) - .request(mock_http_client_success_fail( - None, - vec![ - (ACCEPT, "application/json"), - (CONTENT_TYPE, "application/x-www-form-urlencoded"), - (AUTHORIZATION, "Basic YWFhOmJiYg=="), - ], - "grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Adevice_code&device_code=12345", - HttpResponse { - status_code: StatusCode::from_u16(400).unwrap(), - headers: vec![( - CONTENT_TYPE, - HeaderValue::from_str("application/json").unwrap(), - )] - .into_iter() - .collect(), - body: "{\ - \"error\": \"authorization_pending\", \ - \"error_description\": \"Still waiting for user\"\ - }" - .to_string() - .into_bytes(), - }, - 5, - HttpResponse { - status_code: StatusCode::OK, - headers: vec![( - CONTENT_TYPE, - HeaderValue::from_str("application/json").unwrap(), - )] - .into_iter() - .collect(), - body: "{\ - \"access_token\": \"12/34\", \ - \"token_type\": \"bearer\", \ - \"scope\": \"openid\"\ - }" - .to_string() - .into_bytes(), - }, - ), - mock_sleep_fn, - None) - .unwrap(); - - assert_eq!("12/34", token.access_token().secret()); - assert_eq!(BasicTokenType::Bearer, *token.token_type()); - assert_eq!( - Some(&vec![Scope::new("openid".to_string()),]), - token.scopes() - ); - assert_eq!(None, token.expires_in()); - assert!(token.refresh_token().is_none()); -} - -#[test] -fn test_device_token_slowdown_then_success() { - let details = new_device_auth_details(3600); - assert_eq!("12345", details.device_code().secret()); - assert_eq!("https://verify/here", details.verification_uri().as_str()); - assert_eq!("abcde", details.user_code().secret().as_str()); - assert_eq!( - "https://verify/here?abcde", - details - .verification_uri_complete() - .unwrap() - .secret() - .as_str() - ); - assert_eq!(Duration::from_secs(3600), details.expires_in()); - assert_eq!(Duration::from_secs(1), details.interval()); - - let token = new_client() - .exchange_device_access_token(&details) - .set_time_fn(mock_time_fn()) - .request(mock_http_client_success_fail( - None, - vec![ - (ACCEPT, "application/json"), - (CONTENT_TYPE, "application/x-www-form-urlencoded"), - (AUTHORIZATION, "Basic YWFhOmJiYg=="), - ], - "grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Adevice_code&device_code=12345", - HttpResponse { - status_code: StatusCode::from_u16(400).unwrap(), - headers: vec![( - CONTENT_TYPE, - HeaderValue::from_str("application/json").unwrap(), - )] - .into_iter() - .collect(), - body: "{\ - \"error\": \"slow_down\", \ - \"error_description\": \"Woah there partner\"\ - }" - .to_string() - .into_bytes(), - }, - 5, - HttpResponse { - status_code: StatusCode::OK, - headers: vec![( - CONTENT_TYPE, - HeaderValue::from_str("application/json").unwrap(), - )] - .into_iter() - .collect(), - body: "{\ - \"access_token\": \"12/34\", \ - \"token_type\": \"bearer\", \ - \"scope\": \"openid\"\ - }" - .to_string() - .into_bytes(), - }, - ), - mock_sleep_fn, - None) - .unwrap(); - - assert_eq!("12/34", token.access_token().secret()); - assert_eq!(BasicTokenType::Bearer, *token.token_type()); - assert_eq!( - Some(&vec![Scope::new("openid".to_string()),]), - token.scopes() - ); - assert_eq!(None, token.expires_in()); - assert!(token.refresh_token().is_none()); -} - #[test] fn test_send_sync_impl() { fn is_sync_and_send() {} diff --git a/src/token/mod.rs b/src/token/mod.rs new file mode 100644 index 0000000..5bc93ec --- /dev/null +++ b/src/token/mod.rs @@ -0,0 +1,607 @@ +use crate::endpoint::{endpoint_request, endpoint_response}; +use crate::{ + AccessToken, AuthType, AuthorizationCode, ClientId, ClientSecret, ErrorResponse, HttpRequest, + HttpResponse, PkceCodeVerifier, RedirectUrl, RefreshToken, RequestTokenError, + ResourceOwnerPassword, ResourceOwnerUsername, Scope, TokenUrl, +}; + +use serde::de::DeserializeOwned; +use serde::{Deserialize, Serialize}; + +use std::borrow::Cow; +use std::error::Error; +use std::fmt::Debug; +use std::future::Future; +use std::marker::PhantomData; +use std::time::Duration; + +#[cfg(test)] +mod tests; + +/// A request to exchange an authorization code for an access token. +/// +/// See . +#[derive(Debug)] +pub struct CodeTokenRequest<'a, TE, TR, TT> +where + TE: ErrorResponse, + TR: TokenResponse, + TT: TokenType, +{ + pub(crate) auth_type: &'a AuthType, + pub(crate) client_id: &'a ClientId, + pub(crate) client_secret: Option<&'a ClientSecret>, + pub(crate) code: AuthorizationCode, + pub(crate) extra_params: Vec<(Cow<'a, str>, Cow<'a, str>)>, + pub(crate) pkce_verifier: Option, + pub(crate) token_url: &'a TokenUrl, + pub(crate) redirect_url: Option>, + pub(crate) _phantom: PhantomData<(TE, TR, TT)>, +} +impl<'a, TE, TR, TT> CodeTokenRequest<'a, TE, TR, TT> +where + TE: ErrorResponse + 'static, + TR: TokenResponse, + TT: TokenType, +{ + /// Appends an extra param to the token request. + /// + /// This method allows extensions to be used without direct support from + /// this crate. If `name` conflicts with a parameter managed by this crate, the + /// behavior is undefined. In particular, do not set parameters defined by + /// [RFC 6749](https://tools.ietf.org/html/rfc6749) or + /// [RFC 7636](https://tools.ietf.org/html/rfc7636). + /// + /// # Security Warning + /// + /// Callers should follow the security recommendations for any OAuth2 extensions used with + /// this function, which are beyond the scope of + /// [RFC 6749](https://tools.ietf.org/html/rfc6749). + pub fn add_extra_param(mut self, name: N, value: V) -> Self + where + N: Into>, + V: Into>, + { + self.extra_params.push((name.into(), value.into())); + self + } + + /// Completes the [Proof Key for Code Exchange](https://tools.ietf.org/html/rfc7636) + /// (PKCE) protocol flow. + /// + /// This method must be called if [`crate::AuthorizationRequest::set_pkce_challenge`] was used + /// during the authorization request. + pub fn set_pkce_verifier(mut self, pkce_verifier: PkceCodeVerifier) -> Self { + self.pkce_verifier = Some(pkce_verifier); + self + } + + /// Overrides the `redirect_url` to the one specified. + pub fn set_redirect_uri(mut self, redirect_url: Cow<'a, RedirectUrl>) -> Self { + self.redirect_url = Some(redirect_url); + self + } + + fn prepare_request(self) -> HttpRequest { + let mut params = vec![ + ("grant_type", "authorization_code"), + ("code", self.code.secret()), + ]; + if let Some(ref pkce_verifier) = self.pkce_verifier { + params.push(("code_verifier", pkce_verifier.secret())); + } + + endpoint_request( + self.auth_type, + self.client_id, + self.client_secret, + &self.extra_params, + self.redirect_url, + None, + self.token_url.url(), + params, + ) + } + + /// Synchronously sends the request to the authorization server and awaits a response. + pub fn request(self, http_client: F) -> Result> + where + F: FnOnce(HttpRequest) -> Result, + RE: Error + 'static, + { + endpoint_response(http_client(self.prepare_request())?) + } + + /// Asynchronously sends the request to the authorization server and returns a Future. + pub async fn request_async( + self, + http_client: C, + ) -> Result> + where + C: FnOnce(HttpRequest) -> F, + F: Future>, + RE: Error + 'static, + { + let http_response = http_client(self.prepare_request()).await?; + endpoint_response(http_response) + } +} + +/// A request to exchange a refresh token for an access token. +/// +/// See . +#[derive(Debug)] +pub struct RefreshTokenRequest<'a, TE, TR, TT> +where + TE: ErrorResponse, + TR: TokenResponse, + TT: TokenType, +{ + pub(crate) auth_type: &'a AuthType, + pub(crate) client_id: &'a ClientId, + pub(crate) client_secret: Option<&'a ClientSecret>, + pub(crate) extra_params: Vec<(Cow<'a, str>, Cow<'a, str>)>, + pub(crate) refresh_token: &'a RefreshToken, + pub(crate) scopes: Vec>, + pub(crate) token_url: &'a TokenUrl, + pub(crate) _phantom: PhantomData<(TE, TR, TT)>, +} +impl<'a, TE, TR, TT> RefreshTokenRequest<'a, TE, TR, TT> +where + TE: ErrorResponse + 'static, + TR: TokenResponse, + TT: TokenType, +{ + /// Appends an extra param to the token request. + /// + /// This method allows extensions to be used without direct support from + /// this crate. If `name` conflicts with a parameter managed by this crate, the + /// behavior is undefined. In particular, do not set parameters defined by + /// [RFC 6749](https://tools.ietf.org/html/rfc6749) or + /// [RFC 7636](https://tools.ietf.org/html/rfc7636). + /// + /// # Security Warning + /// + /// Callers should follow the security recommendations for any OAuth2 extensions used with + /// this function, which are beyond the scope of + /// [RFC 6749](https://tools.ietf.org/html/rfc6749). + pub fn add_extra_param(mut self, name: N, value: V) -> Self + where + N: Into>, + V: Into>, + { + self.extra_params.push((name.into(), value.into())); + self + } + + /// Appends a new scope to the token request. + pub fn add_scope(mut self, scope: Scope) -> Self { + self.scopes.push(Cow::Owned(scope)); + self + } + + /// Appends a collection of scopes to the token request. + pub fn add_scopes(mut self, scopes: I) -> Self + where + I: IntoIterator, + { + self.scopes.extend(scopes.into_iter().map(Cow::Owned)); + self + } + + /// Synchronously sends the request to the authorization server and awaits a response. + pub fn request(self, http_client: F) -> Result> + where + F: FnOnce(HttpRequest) -> Result, + RE: Error + 'static, + { + endpoint_response(http_client(self.prepare_request()?)?) + } + /// Asynchronously sends the request to the authorization server and awaits a response. + pub async fn request_async( + self, + http_client: C, + ) -> Result> + where + C: FnOnce(HttpRequest) -> F, + F: Future>, + RE: Error + 'static, + { + let http_request = self.prepare_request()?; + let http_response = http_client(http_request).await?; + endpoint_response(http_response) + } + + fn prepare_request(&self) -> Result> + where + RE: Error + 'static, + { + Ok(endpoint_request( + self.auth_type, + self.client_id, + self.client_secret, + &self.extra_params, + None, + Some(&self.scopes), + self.token_url.url(), + vec![ + ("grant_type", "refresh_token"), + ("refresh_token", self.refresh_token.secret()), + ], + )) + } +} + +/// A request to exchange resource owner credentials for an access token. +/// +/// See . +#[derive(Debug)] +pub struct PasswordTokenRequest<'a, TE, TR, TT> +where + TE: ErrorResponse, + TR: TokenResponse, + TT: TokenType, +{ + pub(crate) auth_type: &'a AuthType, + pub(crate) client_id: &'a ClientId, + pub(crate) client_secret: Option<&'a ClientSecret>, + pub(crate) extra_params: Vec<(Cow<'a, str>, Cow<'a, str>)>, + pub(crate) username: &'a ResourceOwnerUsername, + pub(crate) password: &'a ResourceOwnerPassword, + pub(crate) scopes: Vec>, + pub(crate) token_url: &'a TokenUrl, + pub(crate) _phantom: PhantomData<(TE, TR, TT)>, +} +impl<'a, TE, TR, TT> PasswordTokenRequest<'a, TE, TR, TT> +where + TE: ErrorResponse + 'static, + TR: TokenResponse, + TT: TokenType, +{ + /// Appends an extra param to the token request. + /// + /// This method allows extensions to be used without direct support from + /// this crate. If `name` conflicts with a parameter managed by this crate, the + /// behavior is undefined. In particular, do not set parameters defined by + /// [RFC 6749](https://tools.ietf.org/html/rfc6749) or + /// [RFC 7636](https://tools.ietf.org/html/rfc7636). + /// + /// # Security Warning + /// + /// Callers should follow the security recommendations for any OAuth2 extensions used with + /// this function, which are beyond the scope of + /// [RFC 6749](https://tools.ietf.org/html/rfc6749). + pub fn add_extra_param(mut self, name: N, value: V) -> Self + where + N: Into>, + V: Into>, + { + self.extra_params.push((name.into(), value.into())); + self + } + + /// Appends a new scope to the token request. + pub fn add_scope(mut self, scope: Scope) -> Self { + self.scopes.push(Cow::Owned(scope)); + self + } + + /// Appends a collection of scopes to the token request. + pub fn add_scopes(mut self, scopes: I) -> Self + where + I: IntoIterator, + { + self.scopes.extend(scopes.into_iter().map(Cow::Owned)); + self + } + + /// Synchronously sends the request to the authorization server and awaits a response. + pub fn request(self, http_client: F) -> Result> + where + F: FnOnce(HttpRequest) -> Result, + RE: Error + 'static, + { + endpoint_response(http_client(self.prepare_request()?)?) + } + + /// Asynchronously sends the request to the authorization server and awaits a response. + pub async fn request_async( + self, + http_client: C, + ) -> Result> + where + C: FnOnce(HttpRequest) -> F, + F: Future>, + RE: Error + 'static, + { + let http_request = self.prepare_request()?; + let http_response = http_client(http_request).await?; + endpoint_response(http_response) + } + + fn prepare_request(&self) -> Result> + where + RE: Error + 'static, + { + Ok(endpoint_request( + self.auth_type, + self.client_id, + self.client_secret, + &self.extra_params, + None, + Some(&self.scopes), + self.token_url.url(), + vec![ + ("grant_type", "password"), + ("username", self.username), + ("password", self.password.secret()), + ], + )) + } +} + +/// A request to exchange client credentials for an access token. +/// +/// See . +#[derive(Debug)] +pub struct ClientCredentialsTokenRequest<'a, TE, TR, TT> +where + TE: ErrorResponse, + TR: TokenResponse, + TT: TokenType, +{ + pub(crate) auth_type: &'a AuthType, + pub(crate) client_id: &'a ClientId, + pub(crate) client_secret: Option<&'a ClientSecret>, + pub(crate) extra_params: Vec<(Cow<'a, str>, Cow<'a, str>)>, + pub(crate) scopes: Vec>, + pub(crate) token_url: &'a TokenUrl, + pub(crate) _phantom: PhantomData<(TE, TR, TT)>, +} +impl<'a, TE, TR, TT> ClientCredentialsTokenRequest<'a, TE, TR, TT> +where + TE: ErrorResponse + 'static, + TR: TokenResponse, + TT: TokenType, +{ + /// Appends an extra param to the token request. + /// + /// This method allows extensions to be used without direct support from + /// this crate. If `name` conflicts with a parameter managed by this crate, the + /// behavior is undefined. In particular, do not set parameters defined by + /// [RFC 6749](https://tools.ietf.org/html/rfc6749) or + /// [RFC 7636](https://tools.ietf.org/html/rfc7636). + /// + /// # Security Warning + /// + /// Callers should follow the security recommendations for any OAuth2 extensions used with + /// this function, which are beyond the scope of + /// [RFC 6749](https://tools.ietf.org/html/rfc6749). + pub fn add_extra_param(mut self, name: N, value: V) -> Self + where + N: Into>, + V: Into>, + { + self.extra_params.push((name.into(), value.into())); + self + } + + /// Appends a new scope to the token request. + pub fn add_scope(mut self, scope: Scope) -> Self { + self.scopes.push(Cow::Owned(scope)); + self + } + + /// Appends a collection of scopes to the token request. + pub fn add_scopes(mut self, scopes: I) -> Self + where + I: IntoIterator, + { + self.scopes.extend(scopes.into_iter().map(Cow::Owned)); + self + } + + /// Synchronously sends the request to the authorization server and awaits a response. + pub fn request(self, http_client: F) -> Result> + where + F: FnOnce(HttpRequest) -> Result, + RE: Error + 'static, + { + endpoint_response(http_client(self.prepare_request()?)?) + } + + /// Asynchronously sends the request to the authorization server and awaits a response. + pub async fn request_async( + self, + http_client: C, + ) -> Result> + where + C: FnOnce(HttpRequest) -> F, + F: Future>, + RE: Error + 'static, + { + let http_request = self.prepare_request()?; + let http_response = http_client(http_request).await?; + endpoint_response(http_response) + } + + fn prepare_request(&self) -> Result> + where + RE: Error + 'static, + { + Ok(endpoint_request( + self.auth_type, + self.client_id, + self.client_secret, + &self.extra_params, + None, + Some(&self.scopes), + self.token_url.url(), + vec![("grant_type", "client_credentials")], + )) + } +} + +/// Trait for OAuth2 access tokens. +pub trait TokenType: Clone + DeserializeOwned + Debug + PartialEq + Serialize {} + +/// Trait for adding extra fields to the `TokenResponse`. +pub trait ExtraTokenFields: DeserializeOwned + Debug + Serialize {} + +/// Empty (default) extra token fields. +#[derive(Clone, Debug, Deserialize, PartialEq, Eq, Serialize)] +pub struct EmptyExtraTokenFields {} +impl ExtraTokenFields for EmptyExtraTokenFields {} + +/// Common methods shared by all OAuth2 token implementations. +/// +/// The methods in this trait are defined in +/// [Section 5.1 of RFC 6749](https://tools.ietf.org/html/rfc6749#section-5.1). This trait exists +/// separately from the `StandardTokenResponse` struct to support customization by clients, +/// such as supporting interoperability with non-standards-complaint OAuth2 providers. +pub trait TokenResponse: Debug + DeserializeOwned + Serialize +where + TT: TokenType, +{ + /// REQUIRED. The access token issued by the authorization server. + fn access_token(&self) -> &AccessToken; + /// REQUIRED. The type of the token issued as described in + /// [Section 7.1](https://tools.ietf.org/html/rfc6749#section-7.1). + /// Value is case insensitive and deserialized to the generic `TokenType` parameter. + fn token_type(&self) -> &TT; + /// RECOMMENDED. The lifetime in seconds of the access token. For example, the value 3600 + /// denotes that the access token will expire in one hour from the time the response was + /// generated. If omitted, the authorization server SHOULD provide the expiration time via + /// other means or document the default value. + fn expires_in(&self) -> Option; + /// OPTIONAL. The refresh token, which can be used to obtain new access tokens using the same + /// authorization grant as described in + /// [Section 6](https://tools.ietf.org/html/rfc6749#section-6). + fn refresh_token(&self) -> Option<&RefreshToken>; + /// OPTIONAL, if identical to the scope requested by the client; otherwise, REQUIRED. The + /// scope of the access token as described by + /// [Section 3.3](https://tools.ietf.org/html/rfc6749#section-3.3). If included in the response, + /// this space-delimited field is parsed into a `Vec` of individual scopes. If omitted from + /// the response, this field is `None`. + fn scopes(&self) -> Option<&Vec>; +} + +/// Standard OAuth2 token response. +/// +/// This struct includes the fields defined in +/// [Section 5.1 of RFC 6749](https://tools.ietf.org/html/rfc6749#section-5.1), as well as +/// extensions defined by the `EF` type parameter. +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct StandardTokenResponse +where + EF: ExtraTokenFields, + TT: TokenType, +{ + access_token: AccessToken, + #[serde(bound = "TT: TokenType")] + #[serde(deserialize_with = "crate::helpers::deserialize_untagged_enum_case_insensitive")] + token_type: TT, + #[serde(skip_serializing_if = "Option::is_none")] + expires_in: Option, + #[serde(skip_serializing_if = "Option::is_none")] + refresh_token: Option, + #[serde(rename = "scope")] + #[serde(deserialize_with = "crate::helpers::deserialize_space_delimited_vec")] + #[serde(serialize_with = "crate::helpers::serialize_space_delimited_vec")] + #[serde(skip_serializing_if = "Option::is_none")] + #[serde(default)] + scopes: Option>, + + #[serde(bound = "EF: ExtraTokenFields")] + #[serde(flatten)] + extra_fields: EF, +} +impl StandardTokenResponse +where + EF: ExtraTokenFields, + TT: TokenType, +{ + /// Instantiate a new OAuth2 token response. + pub fn new(access_token: AccessToken, token_type: TT, extra_fields: EF) -> Self { + Self { + access_token, + token_type, + expires_in: None, + refresh_token: None, + scopes: None, + extra_fields, + } + } + + /// Set the `access_token` field. + pub fn set_access_token(&mut self, access_token: AccessToken) { + self.access_token = access_token; + } + + /// Set the `token_type` field. + pub fn set_token_type(&mut self, token_type: TT) { + self.token_type = token_type; + } + + /// Set the `expires_in` field. + pub fn set_expires_in(&mut self, expires_in: Option<&Duration>) { + self.expires_in = expires_in.map(Duration::as_secs); + } + + /// Set the `refresh_token` field. + pub fn set_refresh_token(&mut self, refresh_token: Option) { + self.refresh_token = refresh_token; + } + + /// Set the `scopes` field. + pub fn set_scopes(&mut self, scopes: Option>) { + self.scopes = scopes; + } + + /// Extra fields defined by the client application. + pub fn extra_fields(&self) -> &EF { + &self.extra_fields + } + + /// Set the extra fields defined by the client application. + pub fn set_extra_fields(&mut self, extra_fields: EF) { + self.extra_fields = extra_fields; + } +} +impl TokenResponse for StandardTokenResponse +where + EF: ExtraTokenFields, + TT: TokenType, +{ + /// REQUIRED. The access token issued by the authorization server. + fn access_token(&self) -> &AccessToken { + &self.access_token + } + /// REQUIRED. The type of the token issued as described in + /// [Section 7.1](https://tools.ietf.org/html/rfc6749#section-7.1). + /// Value is case insensitive and deserialized to the generic `TokenType` parameter. + fn token_type(&self) -> &TT { + &self.token_type + } + /// RECOMMENDED. The lifetime in seconds of the access token. For example, the value 3600 + /// denotes that the access token will expire in one hour from the time the response was + /// generated. If omitted, the authorization server SHOULD provide the expiration time via + /// other means or document the default value. + fn expires_in(&self) -> Option { + self.expires_in.map(Duration::from_secs) + } + /// OPTIONAL. The refresh token, which can be used to obtain new access tokens using the same + /// authorization grant as described in + /// [Section 6](https://tools.ietf.org/html/rfc6749#section-6). + fn refresh_token(&self) -> Option<&RefreshToken> { + self.refresh_token.as_ref() + } + /// OPTIONAL, if identical to the scope requested by the client; otherwise, REQUIRED. The + /// scope of the access token as described by + /// [Section 3.3](https://tools.ietf.org/html/rfc6749#section-3.3). If included in the response, + /// this space-delimited field is parsed into a `Vec` of individual scopes. If omitted from + /// the response, this field is `None`. + fn scopes(&self) -> Option<&Vec> { + self.scopes.as_ref() + } +} diff --git a/src/token/tests.rs b/src/token/tests.rs new file mode 100644 index 0000000..dfd6103 --- /dev/null +++ b/src/token/tests.rs @@ -0,0 +1,1243 @@ +use crate::basic::{ + BasicClient, BasicErrorResponse, BasicErrorResponseType, BasicTokenResponse, BasicTokenType, +}; +use crate::tests::colorful_extension::{ + ColorfulClient, ColorfulErrorResponseType, ColorfulFields, ColorfulTokenResponse, + ColorfulTokenType, +}; +use crate::tests::{mock_http_client, new_client, FakeError}; +use crate::token::tests::custom_errors::CustomErrorClient; +use crate::{ + AccessToken, AuthType, AuthUrl, AuthorizationCode, ClientId, ClientSecret, ExtraTokenFields, + HttpResponse, PkceCodeVerifier, RedirectUrl, RefreshToken, RequestTokenError, + ResourceOwnerPassword, ResourceOwnerUsername, Scope, StandardErrorResponse, + StandardTokenResponse, TokenResponse, TokenType, TokenUrl, +}; + +use http::header::{ACCEPT, AUTHORIZATION, CONTENT_TYPE}; +use http::{HeaderMap, HeaderValue, StatusCode}; + +use std::borrow::Cow; +use std::time::Duration; + +// Because the secret types don't implement PartialEq, we can't directly use == to compare tokens. +fn assert_token_eq(a: &StandardTokenResponse, b: &StandardTokenResponse) +where + EF: ExtraTokenFields + PartialEq, + TT: TokenType, +{ + assert_eq!(a.access_token().secret(), b.access_token().secret()); + assert_eq!(a.token_type(), b.token_type()); + assert_eq!(a.expires_in(), b.expires_in()); + assert_eq!( + a.refresh_token().map(RefreshToken::secret), + b.refresh_token().map(RefreshToken::secret) + ); + assert_eq!(a.scopes(), b.scopes()); + assert_eq!(a.extra_fields(), b.extra_fields()); +} + +#[test] +fn test_exchange_code_successful_with_minimal_json_response() { + let client = BasicClient::new(ClientId::new("aaa".to_string())) + .set_client_secret(ClientSecret::new("bbb".to_string())) + .set_auth_url(AuthUrl::new("https://example.com/auth".to_string()).unwrap()) + .set_token_url(TokenUrl::new("https://example.com/token".to_string()).unwrap()); + + let token = client + .exchange_code(AuthorizationCode::new("ccc".to_string())) + .request(mock_http_client( + vec![ + (ACCEPT, "application/json"), + (CONTENT_TYPE, "application/x-www-form-urlencoded"), + (AUTHORIZATION, "Basic YWFhOmJiYg=="), + ], + "grant_type=authorization_code&code=ccc", + None, + HttpResponse { + status_code: StatusCode::OK, + headers: HeaderMap::new(), + body: "{\"access_token\": \"12/34\", \"token_type\": \"BEARER\"}" + .to_string() + .into_bytes(), + }, + )) + .unwrap(); + + assert_eq!("12/34", token.access_token().secret()); + assert_eq!(BasicTokenType::Bearer, *token.token_type()); + assert_eq!(None, token.expires_in()); + assert!(token.refresh_token().is_none()); + + // Ensure that serialization produces an equivalent JSON value. + let serialized_json = serde_json::to_string(&token).unwrap(); + assert_eq!( + "{\"access_token\":\"12/34\",\"token_type\":\"bearer\"}".to_string(), + serialized_json + ); + + let deserialized_token = serde_json::from_str::(&serialized_json).unwrap(); + assert_token_eq(&token, &deserialized_token); +} + +#[test] +fn test_exchange_code_successful_with_complete_json_response() { + let client = new_client().set_auth_type(AuthType::RequestBody); + let token = client + .exchange_code(AuthorizationCode::new("ccc".to_string())) + .request(mock_http_client( + vec![ + (ACCEPT, "application/json"), + (CONTENT_TYPE, "application/x-www-form-urlencoded"), + ], + "grant_type=authorization_code&code=ccc&client_id=aaa&client_secret=bbb", + None, + HttpResponse { + status_code: StatusCode::OK, + headers: vec![( + CONTENT_TYPE, + HeaderValue::from_str("application/json").unwrap(), + )] + .into_iter() + .collect(), + body: "{\ + \"access_token\": \"12/34\", \ + \"token_type\": \"bearer\", \ + \"scope\": \"read write\", \ + \"expires_in\": 3600, \ + \"refresh_token\": \"foobar\"\ + }" + .to_string() + .into_bytes(), + }, + )) + .unwrap(); + + assert_eq!("12/34", token.access_token().secret()); + assert_eq!(BasicTokenType::Bearer, *token.token_type()); + assert_eq!( + Some(&vec![ + Scope::new("read".to_string()), + Scope::new("write".to_string()), + ]), + token.scopes() + ); + assert_eq!(3600, token.expires_in().unwrap().as_secs()); + assert_eq!("foobar", token.refresh_token().unwrap().secret()); + + // Ensure that serialization produces an equivalent JSON value. + let serialized_json = serde_json::to_string(&token).unwrap(); + assert_eq!( + "{\"access_token\":\"12/34\",\"token_type\":\"bearer\",\"expires_in\":3600,\ + \"refresh_token\":\"foobar\",\"scope\":\"read write\"}" + .to_string(), + serialized_json + ); + + let deserialized_token = serde_json::from_str::(&serialized_json).unwrap(); + assert_token_eq(&token, &deserialized_token); +} + +#[test] +fn test_exchange_client_credentials_with_basic_auth() { + let client = BasicClient::new(ClientId::new("aaa/;&".to_string())) + .set_client_secret(ClientSecret::new("bbb/;&".to_string())) + .set_auth_url(AuthUrl::new("https://example.com/auth".to_string()).unwrap()) + .set_token_url(TokenUrl::new("https://example.com/token".to_string()).unwrap()) + .set_auth_type(AuthType::BasicAuth); + + let token = client + .exchange_client_credentials() + .request(mock_http_client( + vec![ + (ACCEPT, "application/json"), + (CONTENT_TYPE, "application/x-www-form-urlencoded"), + (AUTHORIZATION, "Basic YWFhJTJGJTNCJTI2OmJiYiUyRiUzQiUyNg=="), + ], + "grant_type=client_credentials", + None, + HttpResponse { + status_code: StatusCode::OK, + headers: HeaderMap::new(), + body: "{\ + \"access_token\": \"12/34\", \ + \"token_type\": \"bearer\", \ + \"scope\": \"read write\"\ + }" + .to_string() + .into_bytes(), + }, + )) + .unwrap(); + + assert_eq!("12/34", token.access_token().secret()); + assert_eq!(BasicTokenType::Bearer, *token.token_type()); + assert_eq!( + Some(&vec![ + Scope::new("read".to_string()), + Scope::new("write".to_string()), + ]), + token.scopes() + ); + assert_eq!(None, token.expires_in()); + assert!(token.refresh_token().is_none()); +} + +#[test] +fn test_exchange_client_credentials_with_basic_auth_but_no_client_secret() { + let client = BasicClient::new(ClientId::new("aaa/;&".to_string())) + .set_auth_url(AuthUrl::new("https://example.com/auth".to_string()).unwrap()) + .set_token_url(TokenUrl::new("https://example.com/token".to_string()).unwrap()) + .set_auth_type(AuthType::BasicAuth); + + let token = client + .exchange_client_credentials() + .request(mock_http_client( + vec![ + (ACCEPT, "application/json"), + (CONTENT_TYPE, "application/x-www-form-urlencoded"), + ], + "grant_type=client_credentials&client_id=aaa%2F%3B%26", + None, + HttpResponse { + status_code: StatusCode::OK, + headers: HeaderMap::new(), + body: "{\ + \"access_token\": \"12/34\", \ + \"token_type\": \"bearer\", \ + \"scope\": \"read write\"\ + }" + .to_string() + .into_bytes(), + }, + )) + .unwrap(); + + assert_eq!("12/34", token.access_token().secret()); + assert_eq!(BasicTokenType::Bearer, *token.token_type()); + assert_eq!( + Some(&vec![ + Scope::new("read".to_string()), + Scope::new("write".to_string()), + ]), + token.scopes() + ); + assert_eq!(None, token.expires_in()); + assert!(token.refresh_token().is_none()); +} + +#[test] +fn test_exchange_client_credentials_with_body_auth_and_scope() { + let client = new_client().set_auth_type(AuthType::RequestBody); + let token = client + .exchange_client_credentials() + .add_scope(Scope::new("read".to_string())) + .add_scope(Scope::new("write".to_string())) + .request(mock_http_client( + vec![ + (ACCEPT, "application/json"), + (CONTENT_TYPE, "application/x-www-form-urlencoded"), + ], + "grant_type=client_credentials&scope=read+write&client_id=aaa&client_secret=bbb", + None, + HttpResponse { + status_code: StatusCode::OK, + headers: vec![( + CONTENT_TYPE, + HeaderValue::from_str("APPLICATION/jSoN").unwrap(), + )] + .into_iter() + .collect(), + body: "{\ + \"access_token\": \"12/34\", \ + \"token_type\": \"bearer\", \ + \"scope\": \"read write\"\ + }" + .to_string() + .into_bytes(), + }, + )) + .unwrap(); + + assert_eq!("12/34", token.access_token().secret()); + assert_eq!(BasicTokenType::Bearer, *token.token_type()); + assert_eq!( + Some(&vec![ + Scope::new("read".to_string()), + Scope::new("write".to_string()), + ]), + token.scopes() + ); + assert_eq!(None, token.expires_in()); + assert!(token.refresh_token().is_none()); +} + +#[test] +fn test_exchange_refresh_token_with_basic_auth() { + let client = new_client().set_auth_type(AuthType::BasicAuth); + let token = client + .exchange_refresh_token(&RefreshToken::new("ccc".to_string())) + .request(mock_http_client( + vec![ + (ACCEPT, "application/json"), + (CONTENT_TYPE, "application/x-www-form-urlencoded"), + (AUTHORIZATION, "Basic YWFhOmJiYg=="), + ], + "grant_type=refresh_token&refresh_token=ccc", + None, + HttpResponse { + status_code: StatusCode::OK, + headers: HeaderMap::new(), + body: "{\"access_token\": \"12/34\", \ + \"token_type\": \"bearer\", \ + \"scope\": \"read write\"\ + }" + .to_string() + .into_bytes(), + }, + )) + .unwrap(); + + assert_eq!("12/34", token.access_token().secret()); + assert_eq!(BasicTokenType::Bearer, *token.token_type()); + assert_eq!( + Some(&vec![ + Scope::new("read".to_string()), + Scope::new("write".to_string()), + ]), + token.scopes() + ); + assert_eq!(None, token.expires_in()); + assert!(token.refresh_token().is_none()); +} + +#[test] +fn test_exchange_refresh_token_with_json_response() { + let client = new_client(); + let token = client + .exchange_refresh_token(&RefreshToken::new("ccc".to_string())) + .request(mock_http_client( + vec![ + (ACCEPT, "application/json"), + (CONTENT_TYPE, "application/x-www-form-urlencoded"), + (AUTHORIZATION, "Basic YWFhOmJiYg=="), + ], + "grant_type=refresh_token&refresh_token=ccc", + None, + HttpResponse { + status_code: StatusCode::OK, + headers: HeaderMap::new(), + body: "{\ + \"access_token\": \"12/34\", \ + \"token_type\": \"bearer\", \ + \"scope\": \"read write\"\ + }" + .to_string() + .into_bytes(), + }, + )) + .unwrap(); + + assert_eq!("12/34", token.access_token().secret()); + assert_eq!(BasicTokenType::Bearer, *token.token_type()); + assert_eq!( + Some(&vec![ + Scope::new("read".to_string()), + Scope::new("write".to_string()), + ]), + token.scopes() + ); + assert_eq!(None, token.expires_in()); + assert!(token.refresh_token().is_none()); +} + +#[test] +fn test_exchange_password_with_json_response() { + let client = new_client(); + let token = client + .exchange_password( + &ResourceOwnerUsername::new("user".to_string()), + &ResourceOwnerPassword::new("pass".to_string()), + ) + .add_scope(Scope::new("read".to_string())) + .add_scope(Scope::new("write".to_string())) + .request(mock_http_client( + vec![ + (ACCEPT, "application/json"), + (CONTENT_TYPE, "application/x-www-form-urlencoded"), + (AUTHORIZATION, "Basic YWFhOmJiYg=="), + ], + "grant_type=password&username=user&password=pass&scope=read+write", + None, + HttpResponse { + status_code: StatusCode::OK, + headers: vec![( + CONTENT_TYPE, + HeaderValue::from_str("application/json").unwrap(), + )] + .into_iter() + .collect(), + body: "{\ + \"access_token\": \"12/34\", \ + \"token_type\": \"bearer\", \ + \"scope\": \"read write\"\ + }" + .to_string() + .into_bytes(), + }, + )) + .unwrap(); + + assert_eq!("12/34", token.access_token().secret()); + assert_eq!(BasicTokenType::Bearer, *token.token_type()); + assert_eq!( + Some(&vec![ + Scope::new("read".to_string()), + Scope::new("write".to_string()), + ]), + token.scopes() + ); + assert_eq!(None, token.expires_in()); + assert!(token.refresh_token().is_none()); +} + +#[test] +fn test_exchange_code_successful_with_redirect_url() { + let client = new_client() + .set_auth_type(AuthType::RequestBody) + .set_redirect_uri(RedirectUrl::new("https://redirect/here".to_string()).unwrap()); + + let token = client + .exchange_code(AuthorizationCode::new("ccc".to_string())) + .request(mock_http_client( + vec![ + (ACCEPT, "application/json"), + (CONTENT_TYPE, "application/x-www-form-urlencoded"), + ], + "grant_type=authorization_code&code=ccc&client_id=aaa&client_secret=bbb&\ + redirect_uri=https%3A%2F%2Fredirect%2Fhere", + None, + HttpResponse { + status_code: StatusCode::OK, + headers: vec![( + CONTENT_TYPE, + HeaderValue::from_str("application/json").unwrap(), + )] + .into_iter() + .collect(), + body: "{\ + \"access_token\": \"12/34\", \ + \"token_type\": \"bearer\", \ + \"scope\": \"read write\"\ + }" + .to_string() + .into_bytes(), + }, + )) + .unwrap(); + + assert_eq!("12/34", token.access_token().secret()); + assert_eq!(BasicTokenType::Bearer, *token.token_type()); + assert_eq!( + Some(&vec![ + Scope::new("read".to_string()), + Scope::new("write".to_string()), + ]), + token.scopes() + ); + assert_eq!(None, token.expires_in()); + assert!(token.refresh_token().is_none()); +} + +#[test] +fn test_exchange_code_successful_with_redirect_url_override() { + let client = new_client() + .set_auth_type(AuthType::RequestBody) + .set_redirect_uri(RedirectUrl::new("https://redirect/here".to_string()).unwrap()); + + let token = client + .exchange_code(AuthorizationCode::new("ccc".to_string())) + .set_redirect_uri(Cow::Owned( + RedirectUrl::new("https://redirect/alternative".to_string()).unwrap(), + )) + .request(mock_http_client( + vec![ + (ACCEPT, "application/json"), + (CONTENT_TYPE, "application/x-www-form-urlencoded"), + ], + "grant_type=authorization_code&code=ccc&client_id=aaa&client_secret=bbb&\ + redirect_uri=https%3A%2F%2Fredirect%2Falternative", + None, + HttpResponse { + status_code: StatusCode::OK, + headers: vec![( + CONTENT_TYPE, + HeaderValue::from_str("application/json").unwrap(), + )] + .into_iter() + .collect(), + body: "{\ + \"access_token\": \"12/34\", \ + \"token_type\": \"bearer\", \ + \"scope\": \"read write\"\ + }" + .to_string() + .into_bytes(), + }, + )) + .unwrap(); + + assert_eq!("12/34", token.access_token().secret()); + assert_eq!(BasicTokenType::Bearer, *token.token_type()); + assert_eq!( + Some(&vec![ + Scope::new("read".to_string()), + Scope::new("write".to_string()), + ]), + token.scopes() + ); + assert_eq!(None, token.expires_in()); + assert!(token.refresh_token().is_none()); +} + +#[test] +fn test_exchange_code_successful_with_basic_auth() { + let client = new_client() + .set_auth_type(AuthType::BasicAuth) + .set_redirect_uri(RedirectUrl::new("https://redirect/here".to_string()).unwrap()); + + let token = client + .exchange_code(AuthorizationCode::new("ccc".to_string())) + .request(mock_http_client( + vec![ + (ACCEPT, "application/json"), + (CONTENT_TYPE, "application/x-www-form-urlencoded"), + (AUTHORIZATION, "Basic YWFhOmJiYg=="), + ], + "grant_type=authorization_code&code=ccc&redirect_uri=https%3A%2F%2Fredirect%2Fhere", + None, + HttpResponse { + status_code: StatusCode::OK, + headers: vec![( + CONTENT_TYPE, + HeaderValue::from_str("application/json").unwrap(), + )] + .into_iter() + .collect(), + body: "{\ + \"access_token\": \"12/34\", \ + \"token_type\": \"bearer\", \ + \"scope\": \"read write\"\ + }" + .to_string() + .into_bytes(), + }, + )) + .unwrap(); + + assert_eq!("12/34", token.access_token().secret()); + assert_eq!(BasicTokenType::Bearer, *token.token_type()); + assert_eq!( + Some(&vec![ + Scope::new("read".to_string()), + Scope::new("write".to_string()), + ]), + token.scopes() + ); + assert_eq!(None, token.expires_in()); + assert!(token.refresh_token().is_none()); +} + +#[test] +fn test_exchange_code_successful_with_pkce_and_extension() { + let client = new_client() + .set_auth_type(AuthType::BasicAuth) + .set_redirect_uri(RedirectUrl::new("https://redirect/here".to_string()).unwrap()); + + let token = client + .exchange_code(AuthorizationCode::new("ccc".to_string())) + .set_pkce_verifier(PkceCodeVerifier::new( + "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk".to_string(), + )) + .add_extra_param("foo", "bar") + .request(mock_http_client( + vec![ + (ACCEPT, "application/json"), + (CONTENT_TYPE, "application/x-www-form-urlencoded"), + (AUTHORIZATION, "Basic YWFhOmJiYg=="), + ], + "grant_type=authorization_code\ + &code=ccc\ + &code_verifier=dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk\ + &redirect_uri=https%3A%2F%2Fredirect%2Fhere\ + &foo=bar", + None, + HttpResponse { + status_code: StatusCode::OK, + headers: vec![( + CONTENT_TYPE, + HeaderValue::from_str("application/json").unwrap(), + )] + .into_iter() + .collect(), + body: "{\ + \"access_token\": \"12/34\", \ + \"token_type\": \"bearer\", \ + \"scope\": \"read write\"\ + }" + .to_string() + .into_bytes(), + }, + )) + .unwrap(); + + assert_eq!("12/34", token.access_token().secret()); + assert_eq!(BasicTokenType::Bearer, *token.token_type()); + assert_eq!( + Some(&vec![ + Scope::new("read".to_string()), + Scope::new("write".to_string()), + ]), + token.scopes() + ); + assert_eq!(None, token.expires_in()); + assert!(token.refresh_token().is_none()); +} + +#[test] +fn test_exchange_refresh_token_successful_with_extension() { + let client = new_client() + .set_auth_type(AuthType::BasicAuth) + .set_redirect_uri(RedirectUrl::new("https://redirect/here".to_string()).unwrap()); + + let token = client + .exchange_refresh_token(&RefreshToken::new("ccc".to_string())) + .add_extra_param("foo", "bar") + .request(mock_http_client( + vec![ + (ACCEPT, "application/json"), + (CONTENT_TYPE, "application/x-www-form-urlencoded"), + (AUTHORIZATION, "Basic YWFhOmJiYg=="), + ], + "grant_type=refresh_token&refresh_token=ccc&foo=bar", + None, + HttpResponse { + status_code: StatusCode::OK, + headers: vec![( + CONTENT_TYPE, + HeaderValue::from_str("application/json").unwrap(), + )] + .into_iter() + .collect(), + body: "{\ + \"access_token\": \"12/34\", \ + \"token_type\": \"bearer\", \ + \"scope\": \"read write\"\ + }" + .to_string() + .into_bytes(), + }, + )) + .unwrap(); + + assert_eq!("12/34", token.access_token().secret()); + assert_eq!(BasicTokenType::Bearer, *token.token_type()); + assert_eq!( + Some(&vec![ + Scope::new("read".to_string()), + Scope::new("write".to_string()), + ]), + token.scopes() + ); + assert_eq!(None, token.expires_in()); + assert!(token.refresh_token().is_none()); +} + +#[test] +fn test_exchange_code_with_simple_json_error() { + let client = new_client(); + let token = client + .exchange_code(AuthorizationCode::new("ccc".to_string())) + .request(mock_http_client( + vec![ + (ACCEPT, "application/json"), + (CONTENT_TYPE, "application/x-www-form-urlencoded"), + (AUTHORIZATION, "Basic YWFhOmJiYg=="), + ], + "grant_type=authorization_code&code=ccc", + None, + HttpResponse { + status_code: StatusCode::BAD_REQUEST, + headers: vec![( + CONTENT_TYPE, + HeaderValue::from_str("application/json").unwrap(), + )] + .into_iter() + .collect(), + body: "{\ + \"error\": \"invalid_request\", \ + \"error_description\": \"stuff happened\"\ + }" + .to_string() + .into_bytes(), + }, + )); + + assert!(token.is_err()); + + let token_err = token.err().unwrap(); + match token_err { + RequestTokenError::ServerResponse(ref error_response) => { + assert_eq!( + BasicErrorResponseType::InvalidRequest, + *error_response.error() + ); + assert_eq!( + Some(&"stuff happened".to_string()), + error_response.error_description() + ); + assert_eq!(None, error_response.error_uri()); + + // Test Debug trait for ErrorResponse + assert_eq!( + "StandardErrorResponse { error: invalid_request, \ + error_description: Some(\"stuff happened\"), error_uri: None }", + format!("{:?}", error_response) + ); + // Test Display trait for ErrorResponse + assert_eq!( + "invalid_request: stuff happened", + format!("{}", error_response) + ); + + // Test Debug trait for BasicErrorResponseType + assert_eq!("invalid_request", format!("{:?}", error_response.error())); + // Test Display trait for BasicErrorResponseType + assert_eq!("invalid_request", format!("{}", error_response.error())); + + // Ensure that serialization produces an equivalent JSON value. + let serialized_json = serde_json::to_string(&error_response).unwrap(); + assert_eq!( + "{\"error\":\"invalid_request\",\"error_description\":\"stuff happened\"}" + .to_string(), + serialized_json + ); + + let deserialized_error = + serde_json::from_str::(&serialized_json).unwrap(); + assert_eq!(error_response, &deserialized_error); + } + other => panic!("Unexpected error: {:?}", other), + } + + // Test Debug trait for RequestTokenError + assert_eq!( + "ServerResponse(StandardErrorResponse { error: invalid_request, \ + error_description: Some(\"stuff happened\"), error_uri: None })", + format!("{:?}", token_err) + ); + // Test Display trait for RequestTokenError + assert_eq!( + "Server returned error response: invalid_request: stuff happened", + token_err.to_string() + ); +} + +#[test] +fn test_exchange_code_with_json_parse_error() { + let client = new_client(); + let token = client + .exchange_code(AuthorizationCode::new("ccc".to_string())) + .request(mock_http_client( + vec![ + (ACCEPT, "application/json"), + (CONTENT_TYPE, "application/x-www-form-urlencoded"), + (AUTHORIZATION, "Basic YWFhOmJiYg=="), + ], + "grant_type=authorization_code&code=ccc", + None, + HttpResponse { + status_code: StatusCode::OK, + headers: vec![( + CONTENT_TYPE, + HeaderValue::from_str("application/json").unwrap(), + )] + .into_iter() + .collect(), + body: "broken json".to_string().into_bytes(), + }, + )); + + assert!(token.is_err()); + + match token.err().unwrap() { + RequestTokenError::Parse(json_err, _) => { + assert_eq!(".", json_err.path().to_string()); + assert_eq!(1, json_err.inner().line()); + assert_eq!(1, json_err.inner().column()); + assert_eq!( + serde_json::error::Category::Syntax, + json_err.inner().classify() + ); + } + other => panic!("Unexpected error: {:?}", other), + } +} + +#[test] +fn test_exchange_code_with_unexpected_content_type() { + let client = new_client(); + let token = client + .exchange_code(AuthorizationCode::new("ccc".to_string())) + .request(mock_http_client( + vec![ + (ACCEPT, "application/json"), + (CONTENT_TYPE, "application/x-www-form-urlencoded"), + (AUTHORIZATION, "Basic YWFhOmJiYg=="), + ], + "grant_type=authorization_code&code=ccc", + None, + HttpResponse { + status_code: StatusCode::OK, + headers: vec![(CONTENT_TYPE, HeaderValue::from_str("text/plain").unwrap())] + .into_iter() + .collect(), + body: "broken json".to_string().into_bytes(), + }, + )); + + assert!(token.is_err()); + + match token.err().unwrap() { + RequestTokenError::Other(error_str) => { + assert_eq!( + "Unexpected response Content-Type: \"text/plain\", should be `application/json`", + error_str + ); + } + other => panic!("Unexpected error: {:?}", other), + } +} + +#[test] +fn test_exchange_code_with_invalid_token_type() { + let client = BasicClient::new(ClientId::new("aaa".to_string())) + .set_auth_url(AuthUrl::new("https://example.com/auth".to_string()).unwrap()) + .set_token_url(TokenUrl::new("https://example.com/token".to_string()).unwrap()); + + let token = client + .exchange_code(AuthorizationCode::new("ccc".to_string())) + .request(mock_http_client( + vec![ + (ACCEPT, "application/json"), + (CONTENT_TYPE, "application/x-www-form-urlencoded"), + ], + "grant_type=authorization_code&code=ccc&client_id=aaa", + None, + HttpResponse { + status_code: StatusCode::OK, + headers: vec![( + CONTENT_TYPE, + HeaderValue::from_str("application/json").unwrap(), + )] + .into_iter() + .collect(), + body: "{\"access_token\": \"12/34\", \"token_type\": 123}" + .to_string() + .into_bytes(), + }, + )); + + assert!(token.is_err()); + match token.err().unwrap() { + RequestTokenError::Parse(json_err, _) => { + assert_eq!("token_type", json_err.path().to_string()); + assert_eq!(1, json_err.inner().line()); + assert_eq!(43, json_err.inner().column()); + assert_eq!( + serde_json::error::Category::Data, + json_err.inner().classify() + ); + } + other => panic!("Unexpected error: {:?}", other), + } +} + +#[test] +fn test_exchange_code_with_400_status_code() { + let body = r#"{"error":"invalid_request","error_description":"Expired code."}"#; + let client = new_client(); + let token_err = client + .exchange_code(AuthorizationCode::new("ccc".to_string())) + .request(mock_http_client( + vec![ + (ACCEPT, "application/json"), + (CONTENT_TYPE, "application/x-www-form-urlencoded"), + (AUTHORIZATION, "Basic YWFhOmJiYg=="), + ], + "grant_type=authorization_code&code=ccc", + None, + HttpResponse { + status_code: StatusCode::BAD_REQUEST, + headers: vec![( + CONTENT_TYPE, + HeaderValue::from_str("application/json").unwrap(), + )] + .into_iter() + .collect(), + body: body.to_string().into_bytes(), + }, + )) + .err() + .unwrap(); + + match token_err { + RequestTokenError::ServerResponse(ref error_response) => { + assert_eq!( + BasicErrorResponseType::InvalidRequest, + *error_response.error() + ); + assert_eq!( + Some(&"Expired code.".to_string()), + error_response.error_description() + ); + assert_eq!(None, error_response.error_uri()); + } + other => panic!("Unexpected error: {:?}", other), + } + + assert_eq!( + "Server returned error response: invalid_request: Expired code.", + token_err.to_string(), + ); +} + +#[test] +fn test_exchange_code_fails_gracefully_on_transport_error() { + let client = BasicClient::new(ClientId::new("aaa".to_string())) + .set_client_secret(ClientSecret::new("bbb".to_string())) + .set_auth_url(AuthUrl::new("https://auth".to_string()).unwrap()) + .set_token_url(TokenUrl::new("https://token".to_string()).unwrap()); + + let token = client + .exchange_code(AuthorizationCode::new("ccc".to_string())) + .request(|_| Err(FakeError::Err)); + + assert!(token.is_err()); + + match token.err().unwrap() { + RequestTokenError::Request(FakeError::Err) => (), + other => panic!("Unexpected error: {:?}", other), + } +} + +#[test] +fn test_extension_successful_with_minimal_json_response() { + let client = ColorfulClient::new(ClientId::new("aaa".to_string())) + .set_client_secret(ClientSecret::new("bbb".to_string())) + .set_auth_url(AuthUrl::new("https://example.com/auth".to_string()).unwrap()) + .set_token_url(TokenUrl::new("https://example.com/token".to_string()).unwrap()); + + let token = client + .exchange_code(AuthorizationCode::new("ccc".to_string())) + .request(mock_http_client( + vec![ + (ACCEPT, "application/json"), + (CONTENT_TYPE, "application/x-www-form-urlencoded"), + (AUTHORIZATION, "Basic YWFhOmJiYg=="), + ], + "grant_type=authorization_code&code=ccc", + None, + HttpResponse { + status_code: StatusCode::OK, + headers: vec![( + CONTENT_TYPE, + HeaderValue::from_str("application/json").unwrap(), + )] + .into_iter() + .collect(), + body: "{\"access_token\": \"12/34\", \"token_type\": \"green\", \"height\": 10}" + .to_string() + .into_bytes(), + }, + )) + .unwrap(); + + assert_eq!("12/34", token.access_token().secret()); + assert_eq!(ColorfulTokenType::Green, *token.token_type()); + assert_eq!(None, token.expires_in()); + assert!(token.refresh_token().is_none()); + assert_eq!(None, token.extra_fields().shape()); + assert_eq!(10, token.extra_fields().height()); + + // Ensure that serialization produces an equivalent JSON value. + let serialized_json = serde_json::to_string(&token).unwrap(); + assert_eq!( + "{\"access_token\":\"12/34\",\"token_type\":\"green\",\"height\":10}".to_string(), + serialized_json + ); + + let deserialized_token = + serde_json::from_str::(&serialized_json).unwrap(); + assert_token_eq(&token, &deserialized_token); +} + +#[test] +fn test_extension_successful_with_complete_json_response() { + let client = ColorfulClient::new(ClientId::new("aaa".to_string())) + .set_client_secret(ClientSecret::new("bbb".to_string())) + .set_auth_url(AuthUrl::new("https://example.com/auth".to_string()).unwrap()) + .set_token_url(TokenUrl::new("https://example.com/token".to_string()).unwrap()) + .set_auth_type(AuthType::RequestBody); + + let token = client + .exchange_code(AuthorizationCode::new("ccc".to_string())) + .request(mock_http_client( + vec![ + (ACCEPT, "application/json"), + (CONTENT_TYPE, "application/x-www-form-urlencoded"), + ], + "grant_type=authorization_code&code=ccc&client_id=aaa&client_secret=bbb", + None, + HttpResponse { + status_code: StatusCode::OK, + headers: vec![( + CONTENT_TYPE, + HeaderValue::from_str("application/json").unwrap(), + )] + .into_iter() + .collect(), + body: "{\ + \"access_token\": \"12/34\", \ + \"token_type\": \"red\", \ + \"scope\": \"read write\", \ + \"expires_in\": 3600, \ + \"refresh_token\": \"foobar\", \ + \"shape\": \"round\", \ + \"height\": 12\ + }" + .to_string() + .into_bytes(), + }, + )) + .unwrap(); + + assert_eq!("12/34", token.access_token().secret()); + assert_eq!(ColorfulTokenType::Red, *token.token_type()); + assert_eq!( + Some(&vec![ + Scope::new("read".to_string()), + Scope::new("write".to_string()), + ]), + token.scopes() + ); + assert_eq!(3600, token.expires_in().unwrap().as_secs()); + assert_eq!("foobar", token.refresh_token().unwrap().secret()); + assert_eq!(Some(&"round".to_string()), token.extra_fields().shape()); + assert_eq!(12, token.extra_fields().height()); + + // Ensure that serialization produces an equivalent JSON value. + let serialized_json = serde_json::to_string(&token).unwrap(); + assert_eq!( + "{\"access_token\":\"12/34\",\"token_type\":\"red\",\"expires_in\":3600,\ + \"refresh_token\":\"foobar\",\"scope\":\"read write\",\"shape\":\"round\",\"height\":12}" + .to_string(), + serialized_json + ); + + let deserialized_token = + serde_json::from_str::(&serialized_json).unwrap(); + assert_token_eq(&token, &deserialized_token); +} + +#[test] +fn test_extension_with_simple_json_error() { + let client = ColorfulClient::new(ClientId::new("aaa".to_string())) + .set_client_secret(ClientSecret::new("bbb".to_string())) + .set_auth_url(AuthUrl::new("https://example.com/auth".to_string()).unwrap()) + .set_token_url(TokenUrl::new("https://example.com/token".to_string()).unwrap()); + + let token = client + .exchange_code(AuthorizationCode::new("ccc".to_string())) + .request(mock_http_client( + vec![ + (ACCEPT, "application/json"), + (CONTENT_TYPE, "application/x-www-form-urlencoded"), + (AUTHORIZATION, "Basic YWFhOmJiYg=="), + ], + "grant_type=authorization_code&code=ccc", + None, + HttpResponse { + status_code: StatusCode::BAD_REQUEST, + headers: vec![( + CONTENT_TYPE, + HeaderValue::from_str("application/json").unwrap(), + )] + .into_iter() + .collect(), + body: "{\"error\": \"too_light\", \"error_description\": \"stuff happened\", \ + \"error_uri\": \"https://errors\"}" + .to_string() + .into_bytes(), + }, + )); + + assert!(token.is_err()); + + let token_err = token.err().unwrap(); + match token_err { + RequestTokenError::ServerResponse(ref error_response) => { + assert_eq!(ColorfulErrorResponseType::TooLight, *error_response.error()); + assert_eq!( + Some(&"stuff happened".to_string()), + error_response.error_description() + ); + assert_eq!( + Some(&"https://errors".to_string()), + error_response.error_uri() + ); + + // Ensure that serialization produces an equivalent JSON value. + let serialized_json = serde_json::to_string(&error_response).unwrap(); + assert_eq!( + "{\"error\":\"too_light\",\"error_description\":\"stuff happened\",\ + \"error_uri\":\"https://errors\"}" + .to_string(), + serialized_json + ); + + let deserialized_error = serde_json::from_str::< + StandardErrorResponse, + >(&serialized_json) + .unwrap(); + assert_eq!(error_response, &deserialized_error); + } + other => panic!("Unexpected error: {:?}", other), + } + + // Test Debug trait for RequestTokenError + assert_eq!( + "ServerResponse(StandardErrorResponse { error: too_light, \ + error_description: Some(\"stuff happened\"), error_uri: Some(\"https://errors\") })", + format!("{:?}", token_err) + ); + // Test Display trait for RequestTokenError + assert_eq!( + "Server returned error response: too_light: stuff happened (see https://errors)", + token_err.to_string() + ); +} + +mod custom_errors { + use crate::tests::colorful_extension::{ + ColorfulFields, ColorfulRevocableToken, ColorfulTokenType, + }; + use crate::{Client, ErrorResponse, StandardTokenIntrospectionResponse, StandardTokenResponse}; + + use serde::{Deserialize, Serialize}; + + use std::fmt::Error as FormatterError; + use std::fmt::{Display, Formatter}; + + extern crate serde_json; + + #[derive(Serialize, Deserialize, Debug)] + pub struct CustomErrorResponse { + pub custom_error: String, + } + + impl Display for CustomErrorResponse { + fn fmt(&self, f: &mut Formatter) -> Result<(), FormatterError> { + write!(f, "Custom Error from server") + } + } + + impl ErrorResponse for CustomErrorResponse {} + + pub type CustomErrorClient< + const HAS_AUTH_URL: bool, + const HAS_DEVICE_AUTH_URL: bool, + const HAS_INTROSPECTION_URL: bool, + const HAS_REVOCATION_URL: bool, + const HAS_TOKEN_URL: bool, + > = Client< + CustomErrorResponse, + StandardTokenResponse, + ColorfulTokenType, + StandardTokenIntrospectionResponse, + ColorfulRevocableToken, + CustomErrorResponse, + HAS_AUTH_URL, + HAS_DEVICE_AUTH_URL, + HAS_INTROSPECTION_URL, + HAS_REVOCATION_URL, + HAS_TOKEN_URL, + >; +} + +#[test] +fn test_extension_with_custom_json_error() { + let client = CustomErrorClient::new(ClientId::new("aaa".to_string())) + .set_client_secret(ClientSecret::new("bbb".to_string())) + .set_auth_url(AuthUrl::new("https://example.com/auth".to_string()).unwrap()) + .set_token_url(TokenUrl::new("https://example.com/token".to_string()).unwrap()); + + let token = client + .exchange_code(AuthorizationCode::new("ccc".to_string())) + .request(mock_http_client( + vec![ + (ACCEPT, "application/json"), + (CONTENT_TYPE, "application/x-www-form-urlencoded"), + (AUTHORIZATION, "Basic YWFhOmJiYg=="), + ], + "grant_type=authorization_code&code=ccc", + None, + HttpResponse { + status_code: StatusCode::BAD_REQUEST, + headers: vec![( + CONTENT_TYPE, + HeaderValue::from_str("application/json").unwrap(), + )] + .into_iter() + .collect(), + body: "{\"custom_error\": \"non-compliant oauth implementation ;-)\"}" + .to_string() + .into_bytes(), + }, + )); + + assert!(token.is_err()); + + match token.err().unwrap() { + RequestTokenError::ServerResponse(e) => { + assert_eq!("non-compliant oauth implementation ;-)", e.custom_error) + } + e => panic!("failed to correctly parse custom server error, got {:?}", e), + }; +} + +#[test] +fn test_extension_serializer() { + let mut token_response = ColorfulTokenResponse::new( + AccessToken::new("mysecret".to_string()), + ColorfulTokenType::Red, + ColorfulFields { + shape: Some("circle".to_string()), + height: 10, + }, + ); + token_response.set_expires_in(Some(&Duration::from_secs(3600))); + token_response.set_refresh_token(Some(RefreshToken::new("myothersecret".to_string()))); + let serialized = serde_json::to_string(&token_response).unwrap(); + assert_eq!( + "{\ + \"access_token\":\"mysecret\",\ + \"token_type\":\"red\",\ + \"expires_in\":3600,\ + \"refresh_token\":\"myothersecret\",\ + \"shape\":\"circle\",\ + \"height\":10\ + }", + serialized, + ); +} diff --git a/src/types.rs b/src/types.rs index a8229db..3e48dfa 100644 --- a/src/types.rs +++ b/src/types.rs @@ -602,3 +602,49 @@ new_secret_type![ #[derive(Clone, Deserialize, Serialize)] UserCode(String) ]; + +#[cfg(test)] +mod tests { + use crate::{ClientSecret, PkceCodeChallenge, PkceCodeVerifier}; + + #[test] + fn test_secret_redaction() { + let secret = ClientSecret::new("top_secret".to_string()); + assert_eq!("ClientSecret([redacted])", format!("{:?}", secret)); + } + + #[test] + #[should_panic] + fn test_code_verifier_too_short() { + PkceCodeChallenge::new_random_sha256_len(31); + } + + #[test] + #[should_panic] + fn test_code_verifier_too_long() { + PkceCodeChallenge::new_random_sha256_len(97); + } + + #[test] + fn test_code_verifier_min() { + let code = PkceCodeChallenge::new_random_sha256_len(32); + assert_eq!(code.1.secret().len(), 43); + } + + #[test] + fn test_code_verifier_max() { + let code = PkceCodeChallenge::new_random_sha256_len(96); + assert_eq!(code.1.secret().len(), 128); + } + + #[test] + fn test_code_verifier_challenge() { + // Example from https://tools.ietf.org/html/rfc7636#appendix-B + let code_verifier = + PkceCodeVerifier::new("dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk".to_string()); + assert_eq!( + PkceCodeChallenge::from_code_verifier_sha256(&code_verifier).as_str(), + "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM", + ); + } +}