diff --git a/examples/github.rs b/examples/github.rs index e94b363..29f9a6f 100644 --- a/examples/github.rs +++ b/examples/github.rs @@ -14,8 +14,7 @@ //! use oauth2::basic::BasicClient; -// Alternatively, this can be `oauth2::curl::http_client` or a custom client. -use oauth2::reqwest::http_client; +use oauth2::reqwest::reqwest; use oauth2::{ AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, RedirectUrl, Scope, TokenResponse, TokenUrl, @@ -50,6 +49,12 @@ fn main() { RedirectUrl::new("http://localhost:8080".to_string()).expect("Invalid redirect URL"), ); + let http_client = reqwest::blocking::ClientBuilder::new() + // Following redirects opens the client up to SSRF vulnerabilities. + .redirect(reqwest::redirect::Policy::none()) + .build() + .expect("Client should build"); + // Generate the authorization URL to which we'll redirect the user. let (authorize_url, csrf_state) = client .authorize_url(CsrfToken::new_random) @@ -108,7 +113,7 @@ fn main() { ); // Exchange the code with a token. - let token_res = client.exchange_code(code).request(http_client); + let token_res = client.exchange_code(code).request(&http_client); println!("Github returned the following token:\n{:?}\n", token_res); diff --git a/examples/github_async.rs b/examples/github_async.rs index e41e9ce..b66863b 100644 --- a/examples/github_async.rs +++ b/examples/github_async.rs @@ -14,7 +14,7 @@ //! use oauth2::basic::BasicClient; -use oauth2::reqwest::async_http_client; +use oauth2::reqwest::reqwest; use oauth2::{ AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, RedirectUrl, Scope, TokenResponse, TokenUrl, @@ -50,6 +50,12 @@ async fn main() { RedirectUrl::new("http://localhost:8080".to_string()).expect("Invalid redirect URL"), ); + let http_client = reqwest::ClientBuilder::new() + // Following redirects opens the client up to SSRF vulnerabilities. + .redirect(reqwest::redirect::Policy::none()) + .build() + .expect("Client should build"); + // Generate the authorization URL to which we'll redirect the user. let (authorize_url, csrf_state) = client .authorize_url(CsrfToken::new_random) @@ -108,10 +114,7 @@ async fn main() { ); // Exchange the code with a token. - let token_res = client - .exchange_code(code) - .request_async(async_http_client) - .await; + let token_res = client.exchange_code(code).request_async(&http_client).await; println!("Github returned the following token:\n{:?}\n", token_res); diff --git a/examples/google.rs b/examples/google.rs index f27ce8b..04bb506 100644 --- a/examples/google.rs +++ b/examples/google.rs @@ -13,9 +13,8 @@ //! ...and follow the instructions. //! +use oauth2::reqwest::reqwest; use oauth2::{basic::BasicClient, StandardRevocableToken, TokenResponse}; -// Alternatively, this can be oauth2::curl::http_client or a custom. -use oauth2::reqwest::http_client; use oauth2::{ AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, PkceCodeChallenge, RedirectUrl, RevocationUrl, Scope, TokenUrl, @@ -55,6 +54,12 @@ fn main() { .expect("Invalid revocation endpoint URL"), ); + let http_client = reqwest::blocking::ClientBuilder::new() + // Following redirects opens the client up to SSRF vulnerabilities. + .redirect(reqwest::redirect::Policy::none()) + .build() + .expect("Client should build"); + // Google supports Proof Key for Code Exchange (PKCE - https://oauth.net/2/pkce/). // Create a PKCE code verifier and SHA-256 encode it as a code challenge. let (pkce_code_challenge, pkce_code_verifier) = PkceCodeChallenge::new_random_sha256(); @@ -125,7 +130,7 @@ fn main() { let token_response = client .exchange_code(code) .set_pkce_verifier(pkce_code_verifier) - .request(http_client); + .request(&http_client); println!( "Google returned the following token:\n{:?}\n", @@ -142,6 +147,6 @@ fn main() { client .revoke_token(token_to_revoke) .unwrap() - .request(http_client) + .request(&http_client) .expect("Failed to revoke token"); } diff --git a/examples/google_devicecode.rs b/examples/google_devicecode.rs index 6c4c4e5..0924192 100644 --- a/examples/google_devicecode.rs +++ b/examples/google_devicecode.rs @@ -14,8 +14,7 @@ //! use oauth2::basic::BasicClient; -// Alternatively, this can be oauth2::curl::http_client or a custom. -use oauth2::reqwest::http_client; +use oauth2::reqwest::reqwest; use oauth2::{ AuthType, AuthUrl, ClientId, ClientSecret, DeviceAuthorizationResponse, DeviceAuthorizationUrl, ExtraDeviceAuthorizationFields, Scope, TokenUrl, @@ -58,11 +57,17 @@ fn main() { .set_device_authorization_url(device_auth_url) .set_auth_type(AuthType::RequestBody); + let http_client = reqwest::blocking::ClientBuilder::new() + // Following redirects opens the client up to SSRF vulnerabilities. + .redirect(reqwest::redirect::Policy::none()) + .build() + .expect("Client should build"); + // Request the set of codes from the Device Authorization endpoint. let details: StoringDeviceAuthorizationResponse = device_client .exchange_device_code() .add_scope(Scope::new("profile".to_string())) - .request(http_client) + .request(&http_client) .expect("Failed to request codes from device auth endpoint"); // Display the URL and user-code. @@ -75,7 +80,7 @@ fn main() { // Now poll for the token let token = device_client .exchange_device_access_token(&details) - .request(http_client, std::thread::sleep, None) + .request(&http_client, std::thread::sleep, None) .expect("Failed to get token"); println!("Google returned the following token:\n{:?}\n", token); diff --git a/examples/letterboxd.rs b/examples/letterboxd.rs index 087132c..5252282 100644 --- a/examples/letterboxd.rs +++ b/examples/letterboxd.rs @@ -17,7 +17,7 @@ use hex::ToHex; use hmac::{Hmac, Mac}; use oauth2::{ basic::BasicClient, AuthType, AuthUrl, ClientId, ClientSecret, HttpRequest, HttpResponse, - ResourceOwnerPassword, ResourceOwnerUsername, TokenUrl, + ResourceOwnerPassword, ResourceOwnerUsername, SyncHttpClient, TokenUrl, }; use sha2::Sha256; use url::Url; @@ -63,7 +63,7 @@ fn main() -> Result<(), anyhow::Error> { let token_result = client .set_auth_type(AuthType::RequestBody) .exchange_password(&letterboxd_username, &letterboxd_password) - .request(|request| http_client.execute(request))?; + .request(&|request| http_client.execute(request))?; println!("{:?}", token_result); @@ -77,6 +77,7 @@ fn main() -> Result<(), anyhow::Error> { struct SigningHttpClient { client_id: ClientId, client_secret: ClientSecret, + inner: reqwest::blocking::Client, } impl SigningHttpClient { @@ -84,14 +85,26 @@ impl SigningHttpClient { Self { client_id, client_secret, + inner: reqwest::blocking::ClientBuilder::new() + // Following redirects opens the client up to SSRF vulnerabilities. + .redirect(reqwest::redirect::Policy::none()) + .build() + .expect("Client should build"), } } /// Signs the request before calling `oauth2::reqwest::http_client`. fn execute(&self, mut request: HttpRequest) -> Result { - let signed_url = self.sign_url(request.url, &request.method, &request.body); - request.url = signed_url; - oauth2::reqwest::http_client(request) + let signed_url = self.sign_url( + Url::parse(&request.uri().to_string()).expect("http::Uri should be a valid url::Url"), + request.method(), + request.body(), + ); + *request.uri_mut() = signed_url + .as_str() + .try_into() + .expect("url::Url should be a valid http::Uri"); + self.inner.call(request) } /// Signs the request based on a random and unique nonce, timestamp, and diff --git a/examples/microsoft_devicecode_common_user.rs b/examples/microsoft_devicecode_common_user.rs index 69a2ee4..29f43c8 100644 --- a/examples/microsoft_devicecode_common_user.rs +++ b/examples/microsoft_devicecode_common_user.rs @@ -1,5 +1,4 @@ use oauth2::basic::BasicClient; -use oauth2::reqwest::async_http_client; use oauth2::{ AuthUrl, ClientId, DeviceAuthorizationUrl, Scope, StandardDeviceAuthorizationResponse, TokenUrl, }; @@ -19,10 +18,16 @@ async fn main() -> Result<(), Box> { "https://login.microsoftonline.com/common/oauth2/v2.0/devicecode".to_string(), )?); + let http_client = reqwest::ClientBuilder::new() + // Following redirects opens the client up to SSRF vulnerabilities. + .redirect(reqwest::redirect::Policy::none()) + .build() + .expect("Client should build"); + let details: StandardDeviceAuthorizationResponse = client .exchange_device_code() .add_scope(Scope::new("read".to_string())) - .request_async(async_http_client) + .request_async(&http_client) .await?; eprintln!( @@ -33,7 +38,7 @@ async fn main() -> Result<(), Box> { let token_result = client .exchange_device_access_token(&details) - .request_async(async_http_client, tokio::time::sleep, None) + .request_async(&http_client, tokio::time::sleep, None) .await; eprintln!("Token:{:?}", token_result); diff --git a/examples/microsoft_devicecode_tenant_user.rs b/examples/microsoft_devicecode_tenant_user.rs index b65a275..42e3052 100644 --- a/examples/microsoft_devicecode_tenant_user.rs +++ b/examples/microsoft_devicecode_tenant_user.rs @@ -1,5 +1,5 @@ use oauth2::basic::BasicClient; -use oauth2::reqwest::async_http_client; +use oauth2::reqwest::reqwest; use oauth2::StandardDeviceAuthorizationResponse; use oauth2::{AuthUrl, ClientId, DeviceAuthorizationUrl, Scope, TokenUrl}; @@ -25,10 +25,16 @@ async fn main() -> Result<(), Box> { TENANT_ID ))?); + let http_client = reqwest::ClientBuilder::new() + // Following redirects opens the client up to SSRF vulnerabilities. + .redirect(reqwest::redirect::Policy::none()) + .build() + .expect("Client should build"); + let details: StandardDeviceAuthorizationResponse = client .exchange_device_code() .add_scope(Scope::new("read".to_string())) - .request_async(async_http_client) + .request_async(&http_client) .await?; eprintln!( @@ -39,7 +45,7 @@ async fn main() -> Result<(), Box> { let token_result = client .exchange_device_access_token(&details) - .request_async(async_http_client, tokio::time::sleep, None) + .request_async(&http_client, tokio::time::sleep, None) .await; eprintln!("Token:{:?}", token_result); diff --git a/examples/msgraph.rs b/examples/msgraph.rs index 690d672..bfba120 100644 --- a/examples/msgraph.rs +++ b/examples/msgraph.rs @@ -21,8 +21,7 @@ //! use oauth2::basic::BasicClient; -// Alternatively, this can be `oauth2::curl::http_client` or a custom client. -use oauth2::reqwest::http_client; +use oauth2::reqwest::reqwest; use oauth2::{ AuthType, AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, PkceCodeChallenge, RedirectUrl, Scope, TokenUrl, @@ -63,6 +62,12 @@ fn main() { .expect("Invalid redirect URL"), ); + let http_client = reqwest::blocking::ClientBuilder::new() + // Following redirects opens the client up to SSRF vulnerabilities. + .redirect(reqwest::redirect::Policy::none()) + .build() + .expect("Client should build"); + // Microsoft Graph supports Proof Key for Code Exchange (PKCE - https://oauth.net/2/pkce/). // Create a PKCE code verifier and SHA-256 encode it as a code challenge. let (pkce_code_challenge, pkce_code_verifier) = PkceCodeChallenge::new_random_sha256(); @@ -131,7 +136,7 @@ fn main() { .exchange_code(code) // Send the PKCE code verifier in the token request .set_pkce_verifier(pkce_code_verifier) - .request(http_client); + .request(&http_client); println!("MS Graph returned the following token:\n{:?}\n", token); } diff --git a/examples/wunderlist.rs b/examples/wunderlist.rs index 53bba58..dd98c67 100644 --- a/examples/wunderlist.rs +++ b/examples/wunderlist.rs @@ -19,14 +19,13 @@ use oauth2::basic::{ BasicTokenType, }; use oauth2::helpers; -use oauth2::{StandardRevocableToken, TokenType}; -// Alternatively, this can be `oauth2::curl::http_client` or a custom client. -use oauth2::reqwest::http_client; +use oauth2::reqwest::reqwest; use oauth2::{ AccessToken, AuthUrl, AuthorizationCode, Client, ClientId, ClientSecret, CsrfToken, EmptyExtraTokenFields, ExtraTokenFields, RedirectUrl, RefreshToken, Scope, TokenResponse, TokenUrl, }; +use oauth2::{StandardRevocableToken, TokenType}; use serde::{Deserialize, Serialize}; use url::Url; @@ -157,6 +156,12 @@ fn main() { RedirectUrl::new("http://localhost:8080".to_string()).expect("Invalid redirect URL"), ); + let http_client = reqwest::blocking::ClientBuilder::new() + // Following redirects opens the client up to SSRF vulnerabilities. + .redirect(reqwest::redirect::Policy::none()) + .build() + .expect("Client should build"); + // Generate the authorization URL to which we'll redirect the user. let (authorize_url, csrf_state) = client.authorize_url(CsrfToken::new_random).url(); @@ -217,7 +222,7 @@ fn main() { .exchange_code(code) .add_extra_param("client_id", client_id_str) .add_extra_param("client_secret", client_secret_str) - .request(http_client); + .request(&http_client); println!( "Wunderlist returned the following token:\n{:?}\n", diff --git a/src/client.rs b/src/client.rs index adc0163..bec0b4c 100644 --- a/src/client.rs +++ b/src/client.rs @@ -41,7 +41,9 @@ use std::sync::Arc; /// # use thiserror::Error; /// # use http::status::StatusCode; /// # use http::header::{HeaderValue, CONTENT_TYPE}; +/// # use http::Response; /// # use oauth2::{*, basic::*}; +/// # /// # let client = BasicClient::new(ClientId::new("aaa".to_string())) /// # .set_client_secret(ClientSecret::new("bbb".to_string())) /// # .set_auth_uri(AuthUrl::new("https://example.com/auth".to_string()).unwrap()) @@ -55,25 +57,23 @@ use std::sync::Arc; /// # } /// # /// # let http_client = |_| -> Result { -/// # Ok(HttpResponse { -/// # status_code: StatusCode::BAD_REQUEST, -/// # headers: vec![( -/// # CONTENT_TYPE, -/// # HeaderValue::from_str("application/json").unwrap(), -/// # )] -/// # .into_iter() -/// # .collect(), -/// # body: "{\"error\": \"unsupported_token_type\", \"error_description\": \"stuff happened\", \ -/// # \"error_uri\": \"https://errors\"}" +/// # Ok(Response::builder() +/// # .status(StatusCode::BAD_REQUEST) +/// # .header(CONTENT_TYPE, HeaderValue::from_str("application/json").unwrap()) +/// # .body( +/// # r#"{"error": "unsupported_token_type", +/// # "error_description": "stuff happened", +/// # "error_uri": "https://errors"}"# /// # .to_string() /// # .into_bytes(), -/// # }) +/// # ) +/// # .unwrap()) /// # }; /// # /// let res = client /// .revoke_token(AccessToken::new("some token".to_string()).into()) /// .unwrap() -/// .request(http_client); +/// .request(&http_client); /// /// assert!(matches!(res, Err( /// RequestTokenError::ServerResponse(err)) if matches!(err.error(), diff --git a/src/curl.rs b/src/curl.rs index 11fe5d1..ad8afc3 100644 --- a/src/curl.rs +++ b/src/curl.rs @@ -1,7 +1,7 @@ -use crate::{HttpRequest, HttpResponse}; +use crate::{HttpRequest, HttpResponse, SyncHttpClient}; use curl::easy::Easy; -use http::header::{HeaderMap, HeaderValue, CONTENT_TYPE}; +use http::header::{HeaderValue, CONTENT_TYPE}; use http::method::Method; use http::status::StatusCode; @@ -21,62 +21,65 @@ pub enum Error { Other(String), } -/// Synchronous HTTP client. -pub fn http_client(request: HttpRequest) -> Result { - let mut easy = Easy::new(); - easy.url(&request.url.to_string()[..])?; +/// A synchronous HTTP client using [`curl`]. +pub struct CurlHttpClient; +impl SyncHttpClient for CurlHttpClient { + type Error = Error; - let mut headers = curl::easy::List::new(); - for (name, value) in &request.headers { - headers.append(&format!( - "{}: {}", - name, - // TODO: Unnecessary fallibility, curl uses a CString under the hood - value.to_str().map_err(|_| Error::Other(format!( - "invalid {} header value {:?}", + fn call(&self, request: HttpRequest) -> Result { + let mut easy = Easy::new(); + easy.url(&request.uri().to_string()[..])?; + + let mut headers = curl::easy::List::new(); + for (name, value) in request.headers() { + headers.append(&format!( + "{}: {}", name, - value.as_bytes() - )))? - ))? - } + // TODO: Unnecessary fallibility, curl uses a CString under the hood + value.to_str().map_err(|_| Error::Other(format!( + "invalid {} header value {:?}", + name, + value.as_bytes() + )))? + ))? + } - easy.http_headers(headers)?; + easy.http_headers(headers)?; - if let Method::POST = request.method { - easy.post(true)?; - easy.post_field_size(request.body.len() as u64)?; - } else { - assert_eq!(request.method, Method::GET); - } + if let Method::POST = *request.method() { + easy.post(true)?; + easy.post_field_size(request.body().len() as u64)?; + } else { + assert_eq!(*request.method(), Method::GET); + } - let mut form_slice = &request.body[..]; - let mut data = Vec::new(); - { - let mut transfer = easy.transfer(); + let mut form_slice = &request.body()[..]; + let mut data = Vec::new(); + { + let mut transfer = easy.transfer(); - transfer.read_function(|buf| Ok(form_slice.read(buf).unwrap_or(0)))?; + transfer.read_function(|buf| Ok(form_slice.read(buf).unwrap_or(0)))?; - transfer.write_function(|new_data| { - data.extend_from_slice(new_data); - Ok(new_data.len()) - })?; + transfer.write_function(|new_data| { + data.extend_from_slice(new_data); + Ok(new_data.len()) + })?; - transfer.perform()?; - } + transfer.perform()?; + } - let status_code = easy.response_code()? as u16; + let mut builder = http::Response::builder() + .status(StatusCode::from_u16(easy.response_code()? as u16).map_err(http::Error::from)?); - Ok(HttpResponse { - status_code: StatusCode::from_u16(status_code).map_err(http::Error::from)?, - headers: easy + if let Some(content_type) = easy .content_type()? - .map(|content_type| HeaderValue::from_str(content_type).map_err(http::Error::from)) - .transpose()? - .map_or_else(HeaderMap::new, |content_type| { - vec![(CONTENT_TYPE, content_type)] - .into_iter() - .collect::() - }), - body: data, - }) + .map(HeaderValue::from_str) + .transpose() + .map_err(http::Error::from)? + { + builder = builder.header(CONTENT_TYPE, content_type); + } + + builder.body(data).map_err(Error::Http) + } } diff --git a/src/devicecode.rs b/src/devicecode.rs index 7b373cc..caf086f 100644 --- a/src/devicecode.rs +++ b/src/devicecode.rs @@ -2,9 +2,10 @@ use crate::basic::BasicErrorResponseType; use crate::endpoint::{endpoint_request, endpoint_response}; use crate::types::VerificationUriComplete; use crate::{ - AuthType, ClientId, ClientSecret, DeviceAuthorizationUrl, DeviceCode, EndUserVerificationUrl, - ErrorResponse, ErrorResponseType, HttpRequest, HttpResponse, RequestTokenError, Scope, - StandardErrorResponse, TokenResponse, TokenType, TokenUrl, UserCode, + AsyncHttpClient, AuthType, ClientId, ClientSecret, DeviceAuthorizationUrl, DeviceCode, + EndUserVerificationUrl, ErrorResponse, ErrorResponseType, HttpRequest, HttpResponse, + RequestTokenError, Scope, StandardErrorResponse, SyncHttpClient, TokenRequestFuture, + TokenResponse, TokenType, TokenUrl, UserCode, }; use chrono::{DateTime, Utc}; @@ -17,9 +18,14 @@ use std::fmt::Error as FormatterError; use std::fmt::{Debug, Display, Formatter}; use std::future::Future; use std::marker::PhantomData; +use std::pin::Pin; use std::sync::Arc; use std::time::Duration; +/// Future returned by [`DeviceAuthorizationRequest::request_async`]. +pub type DeviceAuthorizationRequestFuture<'c, C, EF, TE> = + TokenRequestFuture<'c, >::Error, TE, DeviceAuthorizationResponse>; + /// The request for a set of verification codes from the authorization server. /// /// See . @@ -78,7 +84,10 @@ where self } - fn prepare_request(self) -> HttpRequest { + fn prepare_request(self) -> Result> + where + RE: Error + 'static, + { endpoint_request( self.auth_type, self.client_id, @@ -89,34 +98,32 @@ where self.device_authorization_url.url(), vec![], ) + .map_err(|err| RequestTokenError::Other(format!("failed to prepare request: {err}"))) } /// Synchronously sends the request to the authorization server and awaits a response. - pub fn request( + pub fn request( self, - http_client: F, - ) -> Result, RequestTokenError> + http_client: &C, + ) -> Result, RequestTokenError<::Error, TE>> where - F: FnOnce(HttpRequest) -> Result, - RE: Error + 'static, + C: SyncHttpClient, EF: ExtraDeviceAuthorizationFields, { - endpoint_response(http_client(self.prepare_request())?) + endpoint_response(http_client.call(self.prepare_request()?)?) } /// Asynchronously sends the request to the authorization server and returns a Future. - pub async fn request_async( + pub fn request_async<'c, C, EF>( self, - http_client: C, - ) -> Result, RequestTokenError> + http_client: &'c C, + ) -> Pin>> where - C: FnOnce(HttpRequest) -> F, - F: Future>, - RE: Error + 'static, + Self: 'c, + C: AsyncHttpClient<'c>, EF: ExtraDeviceAuthorizationFields, { - let http_response = http_client(self.prepare_request()).await?; - endpoint_response(http_response) + Box::pin(async move { endpoint_response(http_client.call(self.prepare_request()?).await?) }) } } @@ -189,16 +196,15 @@ where /// Synchronously polls the authorization server for a response, waiting /// using a user defined sleep function. - pub fn request( + pub fn request( self, - http_client: F, + http_client: &C, sleep_fn: S, timeout: Option, - ) -> Result> + ) -> Result::Error, DeviceCodeErrorResponse>> where - F: Fn(HttpRequest) -> Result, + C: SyncHttpClient, S: Fn(Duration), - RE: Error + 'static, { // Get the request timeout and starting interval let timeout_dt = self.compute_timeout(timeout)?; @@ -217,7 +223,7 @@ where )); } - match self.process_response(http_client(self.prepare_request()), interval) { + match self.process_response(http_client.call(self.prepare_request()?), interval) { DeviceAccessTokenPollResult::ContinueWithNewPollInterval(new_interval) => { interval = new_interval } @@ -230,49 +236,58 @@ where } /// Asynchronously sends the request to the authorization server and awaits a response. - pub async fn request_async( + pub fn request_async<'c, C, S, SF>( self, - http_client: C, + http_client: &'c C, sleep_fn: S, timeout: Option, - ) -> Result> + ) -> Pin< + Box>::Error, DeviceCodeErrorResponse, TR>>, + > where - C: Fn(HttpRequest) -> F, - F: Future>, - S: Fn(Duration) -> SF, + Self: 'c, + C: AsyncHttpClient<'c>, + S: Fn(Duration) -> SF + 'c, SF: Future, - RE: Error + 'static, { - // Get the request timeout and starting interval - let timeout_dt = self.compute_timeout(timeout)?; - let mut interval = self.dev_auth_resp.interval(); - - // Loop while requesting a token. - loop { - let now = (*self.time_fn)(); - if now > timeout_dt { - break Err(RequestTokenError::ServerResponse( - DeviceCodeErrorResponse::new( - DeviceCodeErrorResponseType::ExpiredToken, - Some(String::from("This device code has expired.")), - None, - ), - )); - } + Box::pin(async move { + // Get the request timeout and starting interval + let timeout_dt = self.compute_timeout(timeout)?; + let mut interval = self.dev_auth_resp.interval(); + + // Loop while requesting a token. + loop { + let now = (*self.time_fn)(); + if now > timeout_dt { + break Err(RequestTokenError::ServerResponse( + DeviceCodeErrorResponse::new( + DeviceCodeErrorResponseType::ExpiredToken, + Some(String::from("This device code has expired.")), + None, + ), + )); + } - match self.process_response(http_client(self.prepare_request()).await, interval) { - DeviceAccessTokenPollResult::ContinueWithNewPollInterval(new_interval) => { - interval = new_interval + match self + .process_response(http_client.call(self.prepare_request()?).await, interval) + { + DeviceAccessTokenPollResult::ContinueWithNewPollInterval(new_interval) => { + interval = new_interval + } + DeviceAccessTokenPollResult::Done(res, _) => break res, } - DeviceAccessTokenPollResult::Done(res, _) => break res, - } - // Sleep here using the provided sleep function. - sleep_fn(interval).await; - } + // Sleep here using the provided sleep function. + sleep_fn(interval).await; + } + }) } - fn prepare_request(&self) -> HttpRequest { + fn prepare_request(&self) -> Result> + where + RE: Error + 'static, + TE: ErrorResponse + 'static, + { endpoint_request( self.auth_type, self.client_id, @@ -286,6 +301,7 @@ where ("device_code", self.dev_auth_resp.device_code().secret()), ], ) + .map_err(|err| RequestTokenError::Other(format!("failed to prepare request: {err}"))) } fn process_response( @@ -357,7 +373,7 @@ where let timeout_dur = timeout.unwrap_or_else(|| self.dev_auth_resp.expires_in()); let chrono_timeout = chrono::Duration::from_std(timeout_dur).map_err(|e| { RequestTokenError::Other(format!( - "Failed to convert `{:?}` to `chrono::Duration`: {}", + "failed to convert `{:?}` to `chrono::Duration`: {}", timeout_dur, e )) })?; @@ -365,7 +381,7 @@ where // Calculate the DateTime at which the request times out. let timeout_dt = (*self.time_fn)() .checked_add_signed(chrono_timeout) - .ok_or_else(|| RequestTokenError::Other("Failed to calculate timeout".to_string()))?; + .ok_or_else(|| RequestTokenError::Other("failed to calculate timeout".to_string()))?; Ok(timeout_dt) } @@ -580,13 +596,13 @@ mod tests { use crate::basic::BasicTokenType; use crate::tests::{mock_http_client, mock_http_client_success_fail, new_client}; use crate::{ - DeviceAuthorizationUrl, DeviceCodeErrorResponse, DeviceCodeErrorResponseType, HttpResponse, + DeviceAuthorizationUrl, DeviceCodeErrorResponse, DeviceCodeErrorResponseType, RequestTokenError, Scope, StandardDeviceAuthorizationResponse, TokenResponse, }; use chrono::{DateTime, Utc}; use http::header::{ACCEPT, AUTHORIZATION, CONTENT_TYPE}; - use http::{HeaderValue, StatusCode}; + use http::{HeaderValue, Response, StatusCode}; use std::time::Duration; @@ -611,7 +627,7 @@ mod tests { .exchange_device_code() .add_extra_param("foo", "bar") .add_scope(Scope::new("openid".to_string())) - .request(mock_http_client( + .request(&mock_http_client( vec![ (ACCEPT, "application/json"), (CONTENT_TYPE, "application/x-www-form-urlencoded"), @@ -619,16 +635,14 @@ mod tests { ], "scope=openid&foo=bar", Some(device_auth_url.url().to_owned()), - HttpResponse { - status_code: StatusCode::OK, - headers: vec![( + Response::builder() + .status(StatusCode::OK) + .header( CONTENT_TYPE, HeaderValue::from_str("application/json").unwrap(), - )] - .into_iter() - .collect(), - body: body.into_bytes(), - }, + ) + .body(body.into_bytes()) + .unwrap(), )) .unwrap() } @@ -653,49 +667,46 @@ mod tests { let token = new_client() .exchange_device_access_token(&details) .set_time_fn(mock_time_fn()) - .request(mock_http_client_success_fail( - None, - vec![ - (ACCEPT, "application/json"), - (CONTENT_TYPE, "application/x-www-form-urlencoded"), - (AUTHORIZATION, "Basic YWFhOmJiYg=="), - ], - "grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Adevice_code&device_code=12345", - HttpResponse { - status_code: StatusCode::from_u16(400).unwrap(), - headers: vec![( - CONTENT_TYPE, - HeaderValue::from_str("application/json").unwrap(), - )] - .into_iter() - .collect(), - body: "{\ + .request( + &mock_http_client_success_fail( + None, + vec![ + (ACCEPT, "application/json"), + (CONTENT_TYPE, "application/x-www-form-urlencoded"), + (AUTHORIZATION, "Basic YWFhOmJiYg=="), + ], + "grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Adevice_code&device_code=12345", + Response::builder() + .status(StatusCode::BAD_REQUEST) + .header( + CONTENT_TYPE, + HeaderValue::from_str("application/json").unwrap(), + ) + .body("{\ \"error\": \"authorization_pending\", \ \"error_description\": \"Still waiting for user\"\ }" - .to_string() - .into_bytes(), - }, - 5, - HttpResponse { - status_code: StatusCode::OK, - headers: vec![( - CONTENT_TYPE, - HeaderValue::from_str("application/json").unwrap(), - )] - .into_iter() - .collect(), - body: "{\ + .to_string() + .into_bytes()) + .unwrap(), + 5, + Response::builder() + .status(StatusCode::OK) + .header( + CONTENT_TYPE, + HeaderValue::from_str("application/json").unwrap(), + ) + .body("{\ \"access_token\": \"12/34\", \ \"token_type\": \"bearer\", \ \"scope\": \"openid\"\ }" - .to_string() - .into_bytes(), - }, - ), - mock_sleep_fn, - None) + .to_string() + .into_bytes()) + .unwrap(), + ), + mock_sleep_fn, + None) .unwrap(); assert_eq!("12/34", token.access_token().secret()); @@ -728,49 +739,46 @@ mod tests { let token = new_client() .exchange_device_access_token(&details) .set_time_fn(mock_time_fn()) - .request(mock_http_client_success_fail( - None, - vec![ - (ACCEPT, "application/json"), - (CONTENT_TYPE, "application/x-www-form-urlencoded"), - (AUTHORIZATION, "Basic YWFhOmJiYg=="), - ], - "grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Adevice_code&device_code=12345", - HttpResponse { - status_code: StatusCode::from_u16(400).unwrap(), - headers: vec![( - CONTENT_TYPE, - HeaderValue::from_str("application/json").unwrap(), - )] - .into_iter() - .collect(), - body: "{\ + .request( + &mock_http_client_success_fail( + None, + vec![ + (ACCEPT, "application/json"), + (CONTENT_TYPE, "application/x-www-form-urlencoded"), + (AUTHORIZATION, "Basic YWFhOmJiYg=="), + ], + "grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Adevice_code&device_code=12345", + Response::builder() + .status(StatusCode::BAD_REQUEST) + .header( + CONTENT_TYPE, + HeaderValue::from_str("application/json").unwrap(), + ) + .body("{\ \"error\": \"slow_down\", \ \"error_description\": \"Woah there partner\"\ }" - .to_string() - .into_bytes(), - }, - 5, - HttpResponse { - status_code: StatusCode::OK, - headers: vec![( - CONTENT_TYPE, - HeaderValue::from_str("application/json").unwrap(), - )] - .into_iter() - .collect(), - body: "{\ + .to_string() + .into_bytes()) + .unwrap(), + 5, + Response::builder() + .status(StatusCode::OK) + .header( + CONTENT_TYPE, + HeaderValue::from_str("application/json").unwrap(), + ) + .body("{\ \"access_token\": \"12/34\", \ \"token_type\": \"bearer\", \ \"scope\": \"openid\"\ }" - .to_string() - .into_bytes(), - }, - ), - mock_sleep_fn, - None) + .to_string() + .into_bytes()) + .unwrap(), + ), + mock_sleep_fn, + None) .unwrap(); assert_eq!("12/34", token.access_token().secret()); @@ -827,33 +835,32 @@ mod tests { let token = new_client() .exchange_device_access_token(&details) .set_time_fn(mock_time_fn()) - .request(mock_http_client( - vec![ - (ACCEPT, "application/json"), - (CONTENT_TYPE, "application/x-www-form-urlencoded"), - (AUTHORIZATION, "Basic YWFhOmJiYg=="), - ], - "grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Adevice_code&device_code=12345", - None, - HttpResponse { - status_code: StatusCode::OK, - headers: vec![( - CONTENT_TYPE, - HeaderValue::from_str("application/json").unwrap(), - )] - .into_iter() - .collect(), - body: "{\ + .request( + &mock_http_client( + vec![ + (ACCEPT, "application/json"), + (CONTENT_TYPE, "application/x-www-form-urlencoded"), + (AUTHORIZATION, "Basic YWFhOmJiYg=="), + ], + "grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Adevice_code&device_code=12345", + None, + Response::builder() + .status(StatusCode::OK) + .header( + CONTENT_TYPE, + HeaderValue::from_str("application/json").unwrap(), + ) + .body("{\ \"access_token\": \"12/34\", \ \"token_type\": \"bearer\", \ \"scope\": \"openid\"\ }" - .to_string() - .into_bytes(), - }, - ), - mock_sleep_fn, - None) + .to_string() + .into_bytes()) + .unwrap(), + ), + mock_sleep_fn, + None) .unwrap(); assert_eq!("12/34", token.access_token().secret()); @@ -886,32 +893,31 @@ mod tests { let token = new_client() .exchange_device_access_token(&details) .set_time_fn(mock_time_fn()) - .request(mock_http_client( - vec![ - (ACCEPT, "application/json"), - (CONTENT_TYPE, "application/x-www-form-urlencoded"), - (AUTHORIZATION, "Basic YWFhOmJiYg=="), - ], - "grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Adevice_code&device_code=12345", - None, - HttpResponse { - status_code: StatusCode::from_u16(400).unwrap(), - headers: vec![( - CONTENT_TYPE, - HeaderValue::from_str("application/json").unwrap(), - )] - .into_iter() - .collect(), - body: "{\ + .request( + &mock_http_client( + vec![ + (ACCEPT, "application/json"), + (CONTENT_TYPE, "application/x-www-form-urlencoded"), + (AUTHORIZATION, "Basic YWFhOmJiYg=="), + ], + "grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Adevice_code&device_code=12345", + None, + Response::builder() + .status(StatusCode::BAD_REQUEST) + .header( + CONTENT_TYPE, + HeaderValue::from_str("application/json").unwrap(), + ) + .body("{\ \"error\": \"authorization_pending\", \ \"error_description\": \"Still waiting for user\"\ }" - .to_string() - .into_bytes(), - }, - ), - mock_sleep_fn, - None) + .to_string() + .into_bytes()) + .unwrap(), + ), + mock_sleep_fn, + None) .err() .unwrap(); match token { @@ -947,32 +953,31 @@ mod tests { let token = new_client() .exchange_device_access_token(&details) .set_time_fn(mock_time_fn()) - .request(mock_http_client( - vec![ - (ACCEPT, "application/json"), - (CONTENT_TYPE, "application/x-www-form-urlencoded"), - (AUTHORIZATION, "Basic YWFhOmJiYg=="), - ], - "grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Adevice_code&device_code=12345", - None, - HttpResponse { - status_code: StatusCode::from_u16(400).unwrap(), - headers: vec![( - CONTENT_TYPE, - HeaderValue::from_str("application/json").unwrap(), - )] - .into_iter() - .collect(), - body: "{\ + .request( + &mock_http_client( + vec![ + (ACCEPT, "application/json"), + (CONTENT_TYPE, "application/x-www-form-urlencoded"), + (AUTHORIZATION, "Basic YWFhOmJiYg=="), + ], + "grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Adevice_code&device_code=12345", + None, + Response::builder() + .status(StatusCode::BAD_REQUEST) + .header( + CONTENT_TYPE, + HeaderValue::from_str("application/json").unwrap(), + ) + .body("{\ \"error\": \"access_denied\", \ \"error_description\": \"Access Denied\"\ }" - .to_string() - .into_bytes(), - }, - ), - mock_sleep_fn, - None) + .to_string() + .into_bytes()) + .unwrap(), + ), + mock_sleep_fn, + None) .err() .unwrap(); match token { @@ -1003,32 +1008,31 @@ mod tests { let token = new_client() .exchange_device_access_token(&details) .set_time_fn(mock_time_fn()) - .request(mock_http_client( - vec![ - (ACCEPT, "application/json"), - (CONTENT_TYPE, "application/x-www-form-urlencoded"), - (AUTHORIZATION, "Basic YWFhOmJiYg=="), - ], - "grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Adevice_code&device_code=12345", - None, - HttpResponse { - status_code: StatusCode::from_u16(400).unwrap(), - headers: vec![( - CONTENT_TYPE, - HeaderValue::from_str("application/json").unwrap(), - )] - .into_iter() - .collect(), - body: "{\ + .request( + &mock_http_client( + vec![ + (ACCEPT, "application/json"), + (CONTENT_TYPE, "application/x-www-form-urlencoded"), + (AUTHORIZATION, "Basic YWFhOmJiYg=="), + ], + "grant_type=urn%3Aietf%3Aparams%3Aoauth%3Agrant-type%3Adevice_code&device_code=12345", + None, + Response::builder() + .status(StatusCode::BAD_REQUEST) + .header( + CONTENT_TYPE, + HeaderValue::from_str("application/json").unwrap(), + ) + .body("{\ \"error\": \"expired_token\", \ \"error_description\": \"Token has expired\"\ }" - .to_string() - .into_bytes(), - }, - ), - mock_sleep_fn, - None) + .to_string() + .into_bytes()) + .unwrap(), + ), + mock_sleep_fn, + None) .err() .unwrap(); match token { diff --git a/src/endpoint.rs b/src/endpoint.rs index 3956973..84f8768 100644 --- a/src/endpoint.rs +++ b/src/endpoint.rs @@ -4,37 +4,70 @@ use crate::{ }; use http::header::{ACCEPT, AUTHORIZATION, CONTENT_TYPE}; -use http::{HeaderMap, HeaderValue, StatusCode}; +use http::{HeaderValue, StatusCode}; use serde::de::DeserializeOwned; use url::{form_urlencoded, Url}; use std::borrow::Cow; use std::error::Error; +use std::future::Future; +use std::pin::Pin; /// An HTTP request. -#[derive(Clone, Debug)] -pub struct HttpRequest { - // These are all owned values so that the request can safely be passed between - // threads. - /// URL to which the HTTP request is being made. - pub url: Url, - /// HTTP request method for this request. - pub method: http::method::Method, - /// HTTP request headers to send. - pub headers: HeaderMap, - /// HTTP request body (typically for POST requests only). - pub body: Vec, -} +pub type HttpRequest = http::Request>; /// An HTTP response. -#[derive(Clone, Debug)] -pub struct HttpResponse { - /// HTTP status code returned by the server. - pub status_code: StatusCode, - /// HTTP response headers returned by the server. - pub headers: HeaderMap, - /// HTTP response body returned by the server. - pub body: Vec, +pub type HttpResponse = http::Response>; + +/// An asynchronous (future-based) HTTP client. +pub trait AsyncHttpClient<'c> { + /// Error type returned by HTTP client. + type Error: Error + 'static; + + /// Perform a single HTTP request. + fn call( + &'c self, + request: HttpRequest, + ) -> Pin> + 'c>>; +} +impl<'c, E, F, T> AsyncHttpClient<'c> for T +where + E: Error + 'static, + F: Future> + 'c, + // We can't implement this for FnOnce because the device code flow requires clients to support + // multiple calls. + T: Fn(HttpRequest) -> F, +{ + type Error = E; + + fn call( + &'c self, + request: HttpRequest, + ) -> Pin> + 'c>> { + Box::pin(self(request)) + } +} + +/// A synchronous (blocking) HTTP client. +pub trait SyncHttpClient { + /// Error type returned by HTTP client. + type Error: Error + 'static; + + /// Perform a single HTTP request. + fn call(&self, request: HttpRequest) -> Result; +} +impl SyncHttpClient for T +where + E: Error + 'static, + // We can't implement this for FnOnce because the device code flow requires clients to support + // multiple calls. + T: Fn(HttpRequest) -> Result, +{ + type Error = E; + + fn call(&self, request: HttpRequest) -> Result { + self(request) + } } #[allow(clippy::too_many_arguments)] @@ -47,13 +80,15 @@ pub(crate) fn endpoint_request<'a>( scopes: Option<&'a Vec>>, url: &'a Url, params: Vec<(&'a str, &'a str)>, -) -> HttpRequest { - let mut headers = HeaderMap::new(); - headers.append(ACCEPT, HeaderValue::from_static(CONTENT_TYPE_JSON)); - headers.append( - CONTENT_TYPE, - HeaderValue::from_static(CONTENT_TYPE_FORMENCODED), - ); +) -> Result { + let mut builder = http::Request::builder() + .uri(url.to_string()) + .method(http::Method::POST) + .header(ACCEPT, HeaderValue::from_static(CONTENT_TYPE_JSON)) + .header( + CONTENT_TYPE, + HeaderValue::from_static(CONTENT_TYPE_FORMENCODED), + ); let scopes_opt = scopes.and_then(|scopes| { if !scopes.is_empty() { @@ -88,7 +123,7 @@ pub(crate) fn endpoint_request<'a>( form_urlencoded::byte_serialize(secret.secret().as_bytes()).collect(); let b64_credential = base64::encode(format!("{}:{}", &urlencoded_id, urlencoded_secret)); - headers.append( + builder = builder.header( AUTHORIZATION, HeaderValue::from_str(&format!("Basic {}", &b64_credential)).unwrap(), ); @@ -118,19 +153,14 @@ pub(crate) fn endpoint_request<'a>( .finish() .into_bytes(); - HttpRequest { - url: url.to_owned(), - method: http::method::Method::POST, - headers, - body, - } + builder.body(body) } pub(crate) fn endpoint_response( http_response: HttpResponse, ) -> Result> where - RE: Error + 'static, + RE: Error, TE: ErrorResponse, DO: DeserializeOwned, { @@ -138,7 +168,7 @@ where check_response_body(&http_response)?; - let response_body = http_response.body.as_slice(); + let response_body = http_response.body().as_slice(); serde_path_to_error::deserialize(&mut serde_json::Deserializer::from_slice(response_body)) .map_err(|e| RequestTokenError::Parse(e, response_body.to_vec())) } @@ -160,11 +190,11 @@ where RE: Error + 'static, TE: ErrorResponse, { - if http_response.status_code != StatusCode::OK { - let reason = http_response.body.as_slice(); + if http_response.status() != StatusCode::OK { + let reason = http_response.body().as_slice(); if reason.is_empty() { Err(RequestTokenError::Other( - "Server returned empty error response".to_string(), + "server returned empty error response".to_string(), )) } else { let error = match serde_path_to_error::deserialize::<_, TE>( @@ -189,7 +219,7 @@ where { // Validate that the response Content-Type is JSON. http_response - .headers + .headers() .get(CONTENT_TYPE) .map_or(Ok(()), |content_type| // Section 3.1.1.1 of RFC 7231 indicates that media types are case-insensitive and @@ -199,7 +229,7 @@ where Err( RequestTokenError::Other( format!( - "Unexpected response Content-Type: {:?}, should be `{}`", + "unexpected response Content-Type: {:?}, should be `{}`", content_type, CONTENT_TYPE_JSON ) @@ -210,11 +240,44 @@ where } )?; - if http_response.body.is_empty() { + if http_response.body().is_empty() { return Err(RequestTokenError::Other( - "Server returned empty response body".to_string(), + "server returned empty response body".to_string(), )); } Ok(()) } + +#[cfg(test)] +mod tests { + use crate::tests::{clone_response, new_client, FakeError}; + use crate::{AuthorizationCode, TokenResponse}; + + use http::{Response, StatusCode}; + + #[tokio::test] + async fn test_async_client_closure() { + let client = new_client(); + + let http_response = Response::builder() + .status(StatusCode::OK) + .body( + "{\"access_token\": \"12/34\", \"token_type\": \"BEARER\"}" + .to_string() + .into_bytes(), + ) + .unwrap(); + + let token = client + .exchange_code(AuthorizationCode::new("ccc".to_string())) + // NB: This tests that the closure doesn't require a static lifetime. + .request_async(&|_| async { + Ok(clone_response(&http_response)) as Result<_, FakeError> + }) + .await + .unwrap(); + + assert_eq!("12/34", token.access_token().secret()); + } +} diff --git a/src/introspection.rs b/src/introspection.rs index a7f743b..8f4e7d2 100644 --- a/src/introspection.rs +++ b/src/introspection.rs @@ -1,7 +1,8 @@ use crate::endpoint::{endpoint_request, endpoint_response}; use crate::{ - AccessToken, AuthType, ClientId, ClientSecret, ErrorResponse, ExtraTokenFields, HttpRequest, - HttpResponse, IntrospectionUrl, RequestTokenError, Scope, TokenType, + AccessToken, AsyncHttpClient, AuthType, ClientId, ClientSecret, ErrorResponse, + ExtraTokenFields, HttpRequest, IntrospectionUrl, RequestTokenError, Scope, SyncHttpClient, + TokenRequestFuture, TokenType, }; use chrono::serde::ts_seconds_option; @@ -12,8 +13,8 @@ use serde::{Deserialize, Serialize}; use std::borrow::Cow; use std::error::Error; use std::fmt::Debug; -use std::future::Future; use std::marker::PhantomData; +use std::pin::Pin; /// A request to introspect an access token. /// @@ -85,7 +86,10 @@ where self } - fn prepare_request(self) -> HttpRequest { + fn prepare_request(self) -> Result> + where + RE: Error + 'static, + { let mut params: Vec<(&str, &str)> = vec![("token", self.token.secret())]; if let Some(ref token_type_hint) = self.token_type_hint { params.push(("token_type_hint", token_type_hint)); @@ -101,29 +105,30 @@ where self.introspection_url.url(), params, ) + .map_err(|err| RequestTokenError::Other(format!("failed to prepare request: {err}"))) } /// Synchronously sends the request to the authorization server and awaits a response. - pub fn request(self, http_client: F) -> Result> + pub fn request( + self, + http_client: &C, + ) -> Result::Error, TE>> where - F: FnOnce(HttpRequest) -> Result, - RE: Error + 'static, + C: SyncHttpClient, { - endpoint_response(http_client(self.prepare_request())?) + endpoint_response(http_client.call(self.prepare_request()?)?) } /// Asynchronously sends the request to the authorization server and returns a Future. - pub async fn request_async( + pub fn request_async<'c, C>( self, - http_client: C, - ) -> Result> + http_client: &'c C, + ) -> Pin>::Error, TE, TIR>>> where - C: FnOnce(HttpRequest) -> F, - F: Future>, - RE: Error + 'static, + Self: 'c, + C: AsyncHttpClient<'c>, { - let http_response = http_client(self.prepare_request()).await?; - endpoint_response(http_response) + Box::pin(async move { endpoint_response(http_client.call(self.prepare_request()?).await?) }) } } @@ -393,13 +398,11 @@ where mod tests { use crate::basic::BasicTokenType; use crate::tests::{mock_http_client, new_client}; - use crate::{ - AccessToken, AuthType, ClientId, HttpResponse, IntrospectionUrl, RedirectUrl, Scope, - }; + use crate::{AccessToken, AuthType, ClientId, IntrospectionUrl, RedirectUrl, Scope}; use chrono::{TimeZone, Utc}; use http::header::{ACCEPT, AUTHORIZATION, CONTENT_TYPE}; - use http::{HeaderValue, StatusCode}; + use http::{HeaderValue, Response, StatusCode}; #[test] fn test_token_introspection_successful_with_basic_auth_minimal_response() { @@ -413,7 +416,7 @@ mod tests { let introspection_response = client .introspect(&AccessToken::new("access_token_123".to_string())) .unwrap() - .request(mock_http_client( + .request(&mock_http_client( vec![ (ACCEPT, "application/json"), (CONTENT_TYPE, "application/x-www-form-urlencoded"), @@ -421,20 +424,20 @@ mod tests { ], "token=access_token_123", Some("https://introspection/url".parse().unwrap()), - HttpResponse { - status_code: StatusCode::OK, - headers: vec![( + Response::builder() + .status(StatusCode::OK) + .header( CONTENT_TYPE, HeaderValue::from_str("application/json").unwrap(), - )] - .into_iter() - .collect(), - body: "{\ + ) + .body( + "{\ \"active\": true\ }" - .to_string() - .into_bytes(), - }, + .to_string() + .into_bytes(), + ) + .unwrap(), )) .unwrap(); @@ -465,7 +468,7 @@ mod tests { .introspect(&AccessToken::new("access_token_123".to_string())) .unwrap() .set_token_type_hint("access_token") - .request(mock_http_client( + .request(&mock_http_client( vec![ (ACCEPT, "application/json"), (CONTENT_TYPE, "application/x-www-form-urlencoded"), @@ -473,31 +476,31 @@ mod tests { ], "token=access_token_123&token_type_hint=access_token", Some("https://introspection/url".parse().unwrap()), - HttpResponse { - status_code: StatusCode::OK, - headers: vec![( + Response::builder() + .status(StatusCode::OK) + .header( CONTENT_TYPE, HeaderValue::from_str("application/json").unwrap(), - )] - .into_iter() - .collect(), - body: r#"{ - "active": true, - "scope": "email profile", - "client_id": "aaa", - "username": "demo", - "token_type": "bearer", - "exp": 1604073517, - "iat": 1604073217, - "nbf": 1604073317, - "sub": "demo", - "aud": "demo", - "iss": "http://127.0.0.1:8080/auth/realms/test-realm", - "jti": "be1b7da2-fc18-47b3-bdf1-7a4f50bcf53f" - }"# - .to_string() - .into_bytes(), - }, + ) + .body( + r#"{ + "active": true, + "scope": "email profile", + "client_id": "aaa", + "username": "demo", + "token_type": "bearer", + "exp": 1604073517, + "iat": 1604073217, + "nbf": 1604073317, + "sub": "demo", + "aud": "demo", + "iss": "http://127.0.0.1:8080/auth/realms/test-realm", + "jti": "be1b7da2-fc18-47b3-bdf1-7a4f50bcf53f" + }"# + .to_string() + .into_bytes(), + ) + .unwrap(), )) .unwrap(); diff --git a/src/lib.rs b/src/lib.rs index d00f098..b9aa121 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -22,29 +22,44 @@ //! * **Synchronous (blocking)** //! * **Asynchronous** //! +//! ## Security Warning +//! +//! To prevent +//! [SSRF](https://cheatsheetseries.owasp.org/cheatsheets/Server_Side_Request_Forgery_Prevention_Cheat_Sheet.html) +//! vulnerabilities, be sure to configure the HTTP client **not to follow redirects**. For example, +//! use [`redirect::Policy::none`](::reqwest::redirect::Policy::none) when using +//! [`reqwest`](::reqwest), or [`redirects(0)`](::ureq::AgentBuilder::redirects) when using +//! [`ureq`](::ureq). +//! +//! ## HTTP Clients +//! //! For the HTTP client modes described above, the following HTTP client implementations can be //! used: -//! * **[`reqwest`]** +//! * **[`reqwest`](::reqwest)** //! //! The `reqwest` HTTP client supports both the synchronous and asynchronous modes and is enabled //! by default. //! -//! Synchronous client: [`reqwest::http_client`] +//! Synchronous client: [`reqwest::blocking::Client`](::reqwest::blocking::Client) (requires the +//! `reqwest-blocking` feature flag) //! -//! Asynchronous client: [`reqwest::async_http_client`] +//! Asynchronous client: [`reqwest::Client`](::reqwest::Client) (requires either the +//! `reqwest` or `reqwest-blocking` feature flags) //! -//! * **[`curl`]** +//! * **[`curl`](::curl)** //! //! The `curl` HTTP client only supports the synchronous HTTP client mode and can be enabled in //! `Cargo.toml` via the `curl` feature flag. //! -//! Synchronous client: [`curl::http_client`] +//! Synchronous client: [`oauth2::curl::CurlHttpClient`](crate::curl::CurlHttpClient) //! -//! * **[`ureq`]** +//! * **[`ureq`](::ureq)** //! //! The `ureq` HTTP client is a simple HTTP client with minimal dependencies. It only supports //! the synchronous HTTP client mode and can be enabled in `Cargo.toml` via the `ureq` feature -//! flag. +//! flag. +//! +//! Synchronous client: [`ureq::Agent`](::ureq::Agent) //! //! * **Custom** //! @@ -57,18 +72,21 @@ //! oauth2 = { version = "...", default-features = false } //! ``` //! -//! Synchronous HTTP clients should implement the following trait: +//! Synchronous HTTP clients should implement the [`SyncHttpClient`] trait, which is +//! automatically implemented for any function/closure that implements: //! ```rust,ignore -//! FnOnce(HttpRequest) -> Result -//! where RE: std::error::Error + 'static +//! Fn(HttpRequest) -> Result +//! where +//! E: std::error::Error + 'static //! ``` //! -//! Asynchronous HTTP clients should implement the following trait: +//! Asynchronous HTTP clients should implement the [`AsyncHttpClient`] trait, which is +//! automatically implemented for any function/closure that implements: //! ```rust,ignore -//! FnOnce(HttpRequest) -> F +//! Fn(HttpRequest) -> F //! where -//! F: Future>, -//! RE: std::error::Error + 'static +//! E: std::error::Error + 'static, +//! F: Future>, //! ``` //! //! # Comparing secrets securely @@ -110,9 +128,10 @@ //! TokenUrl //! }; //! use oauth2::basic::BasicClient; -//! use oauth2::reqwest::http_client; +//! use oauth2::reqwest::reqwest; //! use url::Url; //! +//! # #[cfg(feature = "reqwest-blocking")] //! # fn err_wrapper() -> Result<(), anyhow::Error> { //! // Create an OAuth2 client by specifying the client ID, client secret, authorization URL and //! // token URL. @@ -144,13 +163,19 @@ //! // authorization code. For security reasons, your code should verify that the `state` //! // parameter returned by the server matches `csrf_token`. //! +//! let http_client = reqwest::blocking::ClientBuilder::new() +//! // Following redirects opens the client up to SSRF vulnerabilities. +//! .redirect(reqwest::redirect::Policy::none()) +//! .build() +//! .expect("Client should build"); +//! //! // Now you can trade it for an access token. //! let token_result = //! client //! .exchange_code(AuthorizationCode::new("some authorization code".to_string())) //! // Set the PKCE code verifier. //! .set_pkce_verifier(pkce_verifier) -//! .request(http_client)?; +//! .request(&http_client)?; //! //! // Unwrapping token_result will either produce a Token or a RequestTokenError. //! # Ok(()) @@ -175,11 +200,9 @@ //! TokenUrl //! }; //! use oauth2::basic::BasicClient; -//! # #[cfg(feature = "reqwest")] -//! use oauth2::reqwest::async_http_client; +//! use oauth2::reqwest::reqwest; //! use url::Url; //! -//! # #[cfg(feature = "reqwest")] //! # async fn err_wrapper() -> Result<(), anyhow::Error> { //! // Create an OAuth2 client by specifying the client ID, client secret, authorization URL and //! // token URL. @@ -211,12 +234,18 @@ //! // authorization code. For security reasons, your code should verify that the `state` //! // parameter returned by the server matches `csrf_token`. //! +//! let http_client = reqwest::ClientBuilder::new() +//! // Following redirects opens the client up to SSRF vulnerabilities. +//! .redirect(reqwest::redirect::Policy::none()) +//! .build() +//! .expect("Client should build"); +//! //! // Now you can trade it for an access token. //! let token_result = client //! .exchange_code(AuthorizationCode::new("some authorization code".to_string())) //! // Set the PKCE code verifier. //! .set_pkce_verifier(pkce_verifier) -//! .request_async(async_http_client) +//! .request_async(&http_client) //! .await?; //! //! // Unwrapping token_result will either produce a Token or a RequestTokenError. @@ -286,15 +315,22 @@ //! TokenUrl //! }; //! use oauth2::basic::BasicClient; -//! use oauth2::reqwest::http_client; +//! use oauth2::reqwest::reqwest; //! use url::Url; //! +//! # #[cfg(feature = "reqwest-blocking")] //! # fn err_wrapper() -> Result<(), anyhow::Error> { //! let client = BasicClient::new(ClientId::new("client_id".to_string())) //! .set_client_secret(ClientSecret::new("client_secret".to_string())) //! .set_auth_uri(AuthUrl::new("http://authorize".to_string())?) //! .set_token_uri(TokenUrl::new("http://token".to_string())?); //! +//! let http_client = reqwest::blocking::ClientBuilder::new() +//! // Following redirects opens the client up to SSRF vulnerabilities. +//! .redirect(reqwest::redirect::Policy::none()) +//! .build() +//! .expect("Client should build"); +//! //! let token_result = //! client //! .exchange_password( @@ -302,7 +338,7 @@ //! &ResourceOwnerPassword::new("pass".to_string()) //! ) //! .add_scope(Scope::new("read".to_string())) -//! .request(http_client)?; +//! .request(&http_client)?; //! # Ok(()) //! # } //! ``` @@ -324,19 +360,26 @@ //! TokenUrl //! }; //! use oauth2::basic::BasicClient; -//! use oauth2::reqwest::http_client; +//! use oauth2::reqwest::reqwest; //! use url::Url; //! +//! # #[cfg(feature = "reqwest-blocking")] //! # fn err_wrapper() -> Result<(), anyhow::Error> { //! let client = BasicClient::new(ClientId::new("client_id".to_string())) //! .set_client_secret(ClientSecret::new("client_secret".to_string())) //! .set_auth_uri(AuthUrl::new("http://authorize".to_string())?) //! .set_token_uri(TokenUrl::new("http://token".to_string())?); //! +//! let http_client = reqwest::blocking::ClientBuilder::new() +//! // Following redirects opens the client up to SSRF vulnerabilities. +//! .redirect(reqwest::redirect::Policy::none()) +//! .build() +//! .expect("Client should build"); +//! //! let token_result = client //! .exchange_client_credentials() //! .add_scope(Scope::new("read".to_string())) -//! .request(http_client)?; +//! .request(&http_client)?; //! # Ok(()) //! # } //! ``` @@ -363,9 +406,10 @@ //! TokenUrl //! }; //! use oauth2::basic::BasicClient; -//! use oauth2::reqwest::http_client; +//! use oauth2::reqwest::reqwest; //! use url::Url; //! +//! # #[cfg(feature = "reqwest-blocking")] //! # fn err_wrapper() -> Result<(), anyhow::Error> { //! let device_auth_url = DeviceAuthorizationUrl::new("http://deviceauth".to_string())?; //! let client = BasicClient::new(ClientId::new("client_id".to_string())) @@ -374,10 +418,16 @@ //! .set_token_uri(TokenUrl::new("http://token".to_string())?) //! .set_device_authorization_url(device_auth_url); //! +//! let http_client = reqwest::blocking::ClientBuilder::new() +//! // Following redirects opens the client up to SSRF vulnerabilities. +//! .redirect(reqwest::redirect::Policy::none()) +//! .build() +//! .expect("Client should build"); +//! //! let details: StandardDeviceAuthorizationResponse = client //! .exchange_device_code() //! .add_scope(Scope::new("read".to_string())) -//! .request(http_client)?; +//! .request(&http_client)?; //! //! println!( //! "Open this URL in your browser:\n{}\nand enter the code: {}", @@ -388,7 +438,7 @@ //! let token_result = //! client //! .exchange_device_access_token(&details) -//! .request(http_client, std::thread::sleep, None)?; +//! .request(&http_client, std::thread::sleep, None)?; //! //! # Ok(()) //! # } @@ -462,11 +512,12 @@ pub mod ureq; pub use crate::client::Client; pub use crate::code::AuthorizationRequest; pub use crate::devicecode::{ - DeviceAccessTokenRequest, DeviceAuthorizationRequest, DeviceAuthorizationResponse, - DeviceCodeErrorResponse, DeviceCodeErrorResponseType, EmptyExtraDeviceAuthorizationFields, - ExtraDeviceAuthorizationFields, StandardDeviceAuthorizationResponse, + DeviceAccessTokenRequest, DeviceAuthorizationRequest, DeviceAuthorizationRequestFuture, + DeviceAuthorizationResponse, DeviceCodeErrorResponse, DeviceCodeErrorResponseType, + EmptyExtraDeviceAuthorizationFields, ExtraDeviceAuthorizationFields, + StandardDeviceAuthorizationResponse, }; -pub use crate::endpoint::{HttpRequest, HttpResponse}; +pub use crate::endpoint::{AsyncHttpClient, HttpRequest, HttpResponse, SyncHttpClient}; pub use crate::error::{ ErrorResponse, ErrorResponseType, RequestTokenError, StandardErrorResponse, }; @@ -478,7 +529,8 @@ pub use crate::revocation::{ }; pub use crate::token::{ ClientCredentialsTokenRequest, CodeTokenRequest, EmptyExtraTokenFields, ExtraTokenFields, - PasswordTokenRequest, RefreshTokenRequest, StandardTokenResponse, TokenResponse, TokenType, + PasswordTokenRequest, RefreshTokenRequest, StandardTokenResponse, TokenRequestFuture, + TokenResponse, TokenType, }; pub use crate::types::{ AccessToken, AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, diff --git a/src/reqwest.rs b/src/reqwest.rs index f2a32dc..8fe8912 100644 --- a/src/reqwest.rs +++ b/src/reqwest.rs @@ -1,5 +1,12 @@ +use crate::{AsyncHttpClient, HttpRequest, HttpResponse}; + use thiserror::Error; +use std::future::Future; +use std::pin::Pin; + +pub use reqwest; + /// Error type returned by failed reqwest HTTP requests. #[non_exhaustive] #[derive(Debug, Error)] @@ -15,85 +22,56 @@ pub enum Error { Io(#[from] std::io::Error), } -#[cfg(all(feature = "reqwest-blocking", not(target_arch = "wasm32")))] -pub use blocking::http_client; - -pub use async_client::async_http_client; - -#[cfg(all(feature = "reqwest-blocking", not(target_arch = "wasm32")))] -mod blocking { - use crate::reqwest::Error; - use crate::{HttpRequest, HttpResponse}; - - pub use reqwest; - use reqwest::blocking; - use reqwest::redirect::Policy as RedirectPolicy; - - use std::io::Read; +impl<'c> AsyncHttpClient<'c> for reqwest::Client { + type Error = Error; - /// Synchronous HTTP client. - pub fn http_client(request: HttpRequest) -> Result { - let client = blocking::Client::builder() - // Following redirects opens the client up to SSRF vulnerabilities. - .redirect(RedirectPolicy::none()) - .build()?; + fn call( + &'c self, + request: HttpRequest, + ) -> Pin> + 'c>> { + Box::pin(async move { + let response = self.execute(request.try_into()?).await?; - let mut request_builder = client - .request(request.method, request.url.as_str()) - .body(request.body); + // This should be simpler once https://github.com/seanmonstar/reqwest/pull/2060 is + // merged. + let mut builder = http::Response::builder().status(response.status()); - for (name, value) in &request.headers { - request_builder = request_builder.header(name.as_str(), value.as_bytes()); - } - let mut response = client.execute(request_builder.build()?)?; + #[cfg(not(target_arch = "wasm32"))] + { + builder = builder.version(response.version()); + } - let mut body = Vec::new(); - response.read_to_end(&mut body)?; + for (name, value) in response.headers().iter() { + builder = builder.header(name, value); + } - Ok(HttpResponse { - status_code: response.status(), - headers: response.headers().to_owned(), - body, + builder + .body(response.bytes().await?.to_vec()) + .map_err(Error::Http) }) } } -mod async_client { - use crate::reqwest::Error; - use crate::{HttpRequest, HttpResponse}; - - pub use reqwest; - - /// Asynchronous HTTP client. - pub async fn async_http_client(request: HttpRequest) -> Result { - let client = { - let builder = reqwest::Client::builder(); +#[cfg(all(feature = "reqwest-blocking", not(target_arch = "wasm32")))] +impl crate::SyncHttpClient for reqwest::blocking::Client { + type Error = Error; - // Following redirects opens the client up to SSRF vulnerabilities. - // but this is not possible to prevent on wasm targets - #[cfg(not(target_arch = "wasm32"))] - let builder = builder.redirect(reqwest::redirect::Policy::none()); + fn call(&self, request: HttpRequest) -> Result { + let mut response = self.execute(request.try_into()?)?; - builder.build()? - }; + // This should be simpler once https://github.com/seanmonstar/reqwest/pull/2060 is + // merged. + let mut builder = http::Response::builder() + .status(response.status()) + .version(response.version()); - let mut request_builder = client - .request(request.method, request.url.as_str()) - .body(request.body); - for (name, value) in &request.headers { - request_builder = request_builder.header(name.as_str(), value.as_bytes()); + for (name, value) in response.headers().iter() { + builder = builder.header(name, value); } - let request = request_builder.build()?; - let response = client.execute(request).await?; + let mut body = Vec::new(); + ::read_to_end(&mut response, &mut body)?; - let status_code = response.status(); - let headers = response.headers().to_owned(); - let chunks = response.bytes().await?; - Ok(HttpResponse { - status_code, - headers, - body: chunks.to_vec(), - }) + builder.body(body).map_err(Error::Http) } } diff --git a/src/revocation.rs b/src/revocation.rs index f0854c3..6ab95ee 100644 --- a/src/revocation.rs +++ b/src/revocation.rs @@ -1,8 +1,9 @@ use crate::basic::BasicErrorResponseType; -use crate::endpoint::{endpoint_request, endpoint_response_status_only}; +use crate::endpoint::{endpoint_request, endpoint_response, endpoint_response_status_only}; use crate::{ - AccessToken, AuthType, ClientId, ClientSecret, ErrorResponse, ErrorResponseType, HttpRequest, - HttpResponse, RefreshToken, RequestTokenError, RevocationUrl, + AccessToken, AsyncHttpClient, AuthType, ClientId, ClientSecret, ErrorResponse, + ErrorResponseType, HttpRequest, RefreshToken, RequestTokenError, RevocationUrl, SyncHttpClient, + TokenRequestFuture, }; use serde::{Deserialize, Serialize}; @@ -11,8 +12,8 @@ use std::borrow::Cow; use std::error::Error; use std::fmt::Error as FormatterError; use std::fmt::{Debug, Display, Formatter}; -use std::future::Future; use std::marker::PhantomData; +use std::pin::Pin; /// A revocable token. /// @@ -42,16 +43,48 @@ pub trait RevocableToken { /// if issued to the client, must be supported by the server, otherwise fallback to access token (which may or may not /// be supported by the server). /// -/// ```ignore +/// ```rust +/// # use http::{Response, StatusCode}; +/// # use oauth2::{ +/// # AccessToken, AuthUrl, ClientId, EmptyExtraTokenFields, HttpResponse, RequestTokenError, +/// # RevocationUrl, StandardRevocableToken, StandardTokenResponse, TokenResponse, TokenUrl, +/// # }; +/// # use oauth2::basic::{BasicClient, BasicRequestTokenError, BasicTokenResponse, BasicTokenType}; +/// # +/// # fn err_wrapper() -> Result<(), anyhow::Error> { +/// # +/// # let token_response = BasicTokenResponse::new( +/// # AccessToken::new("access".to_string()), +/// # BasicTokenType::Bearer, +/// # EmptyExtraTokenFields {}, +/// # ); +/// # +/// # let http_client = |_| -> Result> { +/// # Ok(Response::builder() +/// # .status(StatusCode::OK) +/// # .body(Vec::new()) +/// # .unwrap()) +/// # }; +/// # +/// let client = BasicClient::new(ClientId::new("aaa".to_string())) +/// .set_auth_uri(AuthUrl::new("https://example.com/auth".to_string()).unwrap()) +/// .set_token_uri(TokenUrl::new("https://example.com/token".to_string()).unwrap()) +/// // Be sure to set a revocation URL. +/// .set_revocation_url(RevocationUrl::new("https://revocation/url".to_string()).unwrap()); +/// +/// // ... +/// /// let token_to_revoke: StandardRevocableToken = match token_response.refresh_token() { /// Some(token) => token.into(), /// None => token_response.access_token().into(), /// }; /// /// client -/// .revoke_token(token_to_revoke) -/// .request(http_client) -/// .unwrap(); +/// .revoke_token(token_to_revoke)? +/// .request(&http_client) +/// # .unwrap(); +/// # Ok(()) +/// # } /// ``` /// /// [`revoke_token()`]: crate::Client::revoke_token() @@ -154,7 +187,10 @@ where self } - fn prepare_request(self) -> HttpRequest { + fn prepare_request(self) -> Result> + where + RE: Error + 'static, + { let mut params: Vec<(&str, &str)> = vec![("token", self.token.secret())]; if let Some(type_hint) = self.token.type_hint() { params.push(("token_type_hint", type_hint)); @@ -170,6 +206,7 @@ where self.revocation_url.url(), params, ) + .map_err(|err| RequestTokenError::Other(format!("failed to prepare request: {err}"))) } /// Synchronously sends the request to the authorization server and awaits a response. @@ -179,29 +216,29 @@ where /// /// Error [`UnsupportedTokenType`](RevocationErrorResponseType::UnsupportedTokenType) will be returned if the /// type of token type given is not supported by the server. - pub fn request(self, http_client: F) -> Result<(), RequestTokenError> + pub fn request( + self, + http_client: &C, + ) -> Result<(), RequestTokenError<::Error, TE>> where - F: FnOnce(HttpRequest) -> Result, - RE: Error + 'static, + C: SyncHttpClient, { // From https://tools.ietf.org/html/rfc7009#section-2.2: // "The content of the response body is ignored by the client as all // necessary information is conveyed in the response code." - endpoint_response_status_only(http_client(self.prepare_request())?) + endpoint_response_status_only(http_client.call(self.prepare_request()?)?) } /// Asynchronously sends the request to the authorization server and returns a Future. - pub async fn request_async( + pub fn request_async<'c, C>( self, - http_client: C, - ) -> Result<(), RequestTokenError> + http_client: &'c C, + ) -> Pin>::Error, TE, ()>>> where - C: FnOnce(HttpRequest) -> F, - F: Future>, - RE: Error + 'static, + Self: 'c, + C: AsyncHttpClient<'c>, { - let http_response = http_client(self.prepare_request()).await?; - endpoint_response_status_only(http_response) + Box::pin(async move { endpoint_response(http_client.call(self.prepare_request()?).await?) }) } } @@ -272,12 +309,12 @@ mod tests { use crate::tests::colorful_extension::{ColorfulClient, ColorfulRevocableToken}; use crate::tests::{mock_http_client, new_client}; use crate::{ - AccessToken, AuthUrl, ClientId, ClientSecret, HttpResponse, RefreshToken, - RequestTokenError, RevocationErrorResponseType, RevocationUrl, TokenUrl, + AccessToken, AuthUrl, ClientId, ClientSecret, RefreshToken, RequestTokenError, + RevocationErrorResponseType, RevocationUrl, TokenUrl, }; use http::header::{ACCEPT, AUTHORIZATION, CONTENT_TYPE}; - use http::{HeaderValue, StatusCode}; + use http::{HeaderValue, Response, StatusCode}; #[test] fn test_token_revocation_with_non_https_url() { @@ -301,7 +338,7 @@ mod tests { let revocation_response = client .revoke_token(AccessToken::new("access_token_123".to_string()).into()).unwrap() - .request(mock_http_client( + .request(&mock_http_client( vec![ (ACCEPT, "application/json"), (CONTENT_TYPE, "application/x-www-form-urlencoded"), @@ -309,19 +346,21 @@ mod tests { ], "token=access_token_123&token_type_hint=access_token", Some("https://revocation/url".parse().unwrap()), - HttpResponse { - status_code: StatusCode::BAD_REQUEST, - headers: vec![( - CONTENT_TYPE, - HeaderValue::from_str("application/json").unwrap(), - )] - .into_iter() - .collect(), - body: "{\"error\": \"unsupported_token_type\", \"error_description\": \"stuff happened\", \ - \"error_uri\": \"https://errors\"}" - .to_string() - .into_bytes(), - }, + Response::builder() + .status(StatusCode::BAD_REQUEST) + .header( + CONTENT_TYPE, + HeaderValue::from_str("application/json").unwrap(), + ) + .body( + "{\ + \"error\": \"unsupported_token_type\", \"error_description\": \"stuff happened\", \ + \"error_uri\": \"https://errors\"\ + }" + .to_string() + .into_bytes(), + ) + .unwrap(), )); assert!(matches!( @@ -343,7 +382,7 @@ mod tests { client .revoke_token(AccessToken::new("access_token_123".to_string()).into()) .unwrap() - .request(mock_http_client( + .request(&mock_http_client( vec![ (ACCEPT, "application/json"), (CONTENT_TYPE, "application/x-www-form-urlencoded"), @@ -351,16 +390,14 @@ mod tests { ], "token=access_token_123&token_type_hint=access_token", Some("https://revocation/url".parse().unwrap()), - HttpResponse { - status_code: StatusCode::OK, - headers: vec![( + Response::builder() + .status(StatusCode::OK) + .header( CONTENT_TYPE, HeaderValue::from_str("application/json").unwrap(), - )] - .into_iter() - .collect(), - body: b"{}".to_vec(), - }, + ) + .body(b"{}".to_vec()) + .unwrap(), )) .unwrap(); } @@ -373,7 +410,7 @@ mod tests { client .revoke_token(AccessToken::new("access_token_123".to_string()).into()) .unwrap() - .request(mock_http_client( + .request(&mock_http_client( vec![ (ACCEPT, "application/json"), (CONTENT_TYPE, "application/x-www-form-urlencoded"), @@ -381,11 +418,10 @@ mod tests { ], "token=access_token_123&token_type_hint=access_token", Some("https://revocation/url".parse().unwrap()), - HttpResponse { - status_code: StatusCode::OK, - headers: vec![].into_iter().collect(), - body: vec![], - }, + Response::builder() + .status(StatusCode::OK) + .body(vec![]) + .unwrap(), )) .unwrap(); } @@ -398,7 +434,7 @@ mod tests { client .revoke_token(AccessToken::new("access_token_123".to_string()).into()) .unwrap() - .request(mock_http_client( + .request(&mock_http_client( vec![ (ACCEPT, "application/json"), (CONTENT_TYPE, "application/x-www-form-urlencoded"), @@ -406,16 +442,14 @@ mod tests { ], "token=access_token_123&token_type_hint=access_token", Some("https://revocation/url".parse().unwrap()), - HttpResponse { - status_code: StatusCode::OK, - headers: vec![( + Response::builder() + .status(StatusCode::OK) + .header( CONTENT_TYPE, HeaderValue::from_str("application/octet-stream").unwrap(), - )] - .into_iter() - .collect(), - body: vec![1, 2, 3], - }, + ) + .body(vec![1, 2, 3]) + .unwrap(), )) .unwrap(); } @@ -428,7 +462,7 @@ mod tests { client .revoke_token(RefreshToken::new("refresh_token_123".to_string()).into()) .unwrap() - .request(mock_http_client( + .request(&mock_http_client( vec![ (ACCEPT, "application/json"), (CONTENT_TYPE, "application/x-www-form-urlencoded"), @@ -436,16 +470,14 @@ mod tests { ], "token=refresh_token_123&token_type_hint=refresh_token", Some("https://revocation/url".parse().unwrap()), - HttpResponse { - status_code: StatusCode::OK, - headers: vec![( + Response::builder() + .status(StatusCode::OK) + .header( CONTENT_TYPE, HeaderValue::from_str("application/json").unwrap(), - )] - .into_iter() - .collect(), - body: b"{}".to_vec(), - }, + ) + .body(b"{}".to_vec()) + .unwrap(), )) .unwrap(); } @@ -463,7 +495,7 @@ mod tests { "colorful_token_123".to_string(), )) .unwrap() - .request(mock_http_client( + .request(&mock_http_client( vec![ (ACCEPT, "application/json"), (CONTENT_TYPE, "application/x-www-form-urlencoded"), @@ -471,16 +503,14 @@ mod tests { ], "token=colorful_token_123&token_type_hint=red_token", Some("https://revocation/url".parse().unwrap()), - HttpResponse { - status_code: StatusCode::OK, - headers: vec![( + Response::builder() + .status(StatusCode::OK) + .header( CONTENT_TYPE, HeaderValue::from_str("application/json").unwrap(), - )] - .into_iter() - .collect(), - body: b"{}".to_vec(), - }, + ) + .body(b"{}".to_vec()) + .unwrap(), )) .unwrap(); } diff --git a/src/tests.rs b/src/tests.rs index 7efb0d2..c76a7ee 100644 --- a/src/tests.rs +++ b/src/tests.rs @@ -26,6 +26,17 @@ pub(crate) fn new_client() -> BasicClient { .set_client_secret(ClientSecret::new("bbb".to_string())) } +// FIXME: just clone `response` directly once we update `http` to 1.0, which implements `Clone`. +pub(crate) fn clone_response(response: &HttpResponse) -> HttpResponse { + let mut response_copy = http::Response::builder() + .status(response.status()) + .version(response.version()); + for (name, value) in response.headers() { + response_copy = response_copy.header(name, value); + } + response_copy.body(response.body().to_owned()).unwrap() +} + pub(crate) fn mock_http_client( request_headers: Vec<(HeaderName, &'static str)>, request_body: &'static str, @@ -34,21 +45,24 @@ pub(crate) fn mock_http_client( ) -> impl Fn(HttpRequest) -> Result { move |request: HttpRequest| { assert_eq!( - &request.url, + &Url::parse(&request.uri().to_string()).unwrap(), request_url .as_ref() .unwrap_or(&Url::parse("https://example.com/token").unwrap()) ); assert_eq!( - request.headers, - request_headers + request.headers(), + &request_headers .iter() .map(|(name, value)| (name.clone(), HeaderValue::from_str(value).unwrap())) .collect(), ); - assert_eq!(&String::from_utf8(request.body).unwrap(), request_body); + assert_eq!( + &String::from_utf8(request.body().to_owned()).unwrap(), + request_body + ); - Ok(response.clone()) + Ok(clone_response(&response)) } } @@ -179,27 +193,31 @@ pub(crate) fn mock_http_client_success_fail( num_failures: usize, success_response: HttpResponse, ) -> impl Fn(HttpRequest) -> Result { - let responses: Vec = std::iter::repeat(failure_response) - .take(num_failures) - .chain(std::iter::once(success_response)) - .collect(); + let responses: Vec = + std::iter::from_fn(|| Some(clone_response(&failure_response))) + .take(num_failures) + .chain(std::iter::once(success_response)) + .collect(); let sync_responses = std::sync::Mutex::new(responses); move |request: HttpRequest| { assert_eq!( - &request.url, + &Url::parse(&request.uri().to_string()).unwrap(), request_url .as_ref() .unwrap_or(&Url::parse("https://example.com/token").unwrap()) ); assert_eq!( - request.headers, - request_headers + request.headers(), + &request_headers .iter() .map(|(name, value)| (name.clone(), HeaderValue::from_str(value).unwrap())) .collect(), ); - assert_eq!(&String::from_utf8(request.body).unwrap(), request_body); + assert_eq!( + &String::from_utf8(request.body().to_owned()).unwrap(), + request_body + ); { let mut rsp_vec = sync_responses.lock().unwrap(); @@ -314,6 +332,6 @@ fn test_send_sync_impl() { #[cfg(feature = "curl")] is_sync_and_send::(); - #[cfg(feature = "reqwest")] + #[cfg(any(feature = "reqwest", feature = "reqwest-blocking"))] is_sync_and_send::(); } diff --git a/src/token/mod.rs b/src/token/mod.rs index 5bc93ec..b521a02 100644 --- a/src/token/mod.rs +++ b/src/token/mod.rs @@ -1,8 +1,8 @@ use crate::endpoint::{endpoint_request, endpoint_response}; use crate::{ - AccessToken, AuthType, AuthorizationCode, ClientId, ClientSecret, ErrorResponse, HttpRequest, - HttpResponse, PkceCodeVerifier, RedirectUrl, RefreshToken, RequestTokenError, - ResourceOwnerPassword, ResourceOwnerUsername, Scope, TokenUrl, + AccessToken, AsyncHttpClient, AuthType, AuthorizationCode, ClientId, ClientSecret, + ErrorResponse, HttpRequest, PkceCodeVerifier, RedirectUrl, RefreshToken, RequestTokenError, + ResourceOwnerPassword, ResourceOwnerUsername, Scope, SyncHttpClient, TokenUrl, }; use serde::de::DeserializeOwned; @@ -13,11 +13,16 @@ use std::error::Error; use std::fmt::Debug; use std::future::Future; use std::marker::PhantomData; +use std::pin::Pin; use std::time::Duration; #[cfg(test)] mod tests; +/// Future returned by `request_async` methods. +pub type TokenRequestFuture<'c, RE, TE, TR> = + dyn Future>> + 'c; + /// A request to exchange an authorization code for an access token. /// /// See . @@ -82,7 +87,10 @@ where self } - fn prepare_request(self) -> HttpRequest { + fn prepare_request(self) -> Result> + where + RE: Error + 'static, + { let mut params = vec![ ("grant_type", "authorization_code"), ("code", self.code.secret()), @@ -101,29 +109,30 @@ where self.token_url.url(), params, ) + .map_err(|err| RequestTokenError::Other(format!("failed to prepare request: {err}"))) } /// Synchronously sends the request to the authorization server and awaits a response. - pub fn request(self, http_client: F) -> Result> + pub fn request( + self, + http_client: &C, + ) -> Result::Error, TE>> where - F: FnOnce(HttpRequest) -> Result, - RE: Error + 'static, + C: SyncHttpClient, { - endpoint_response(http_client(self.prepare_request())?) + endpoint_response(http_client.call(self.prepare_request()?)?) } /// Asynchronously sends the request to the authorization server and returns a Future. - pub async fn request_async( + pub fn request_async<'c, C>( self, - http_client: C, - ) -> Result> + http_client: &'c C, + ) -> Pin>::Error, TE, TR>>> where - C: FnOnce(HttpRequest) -> F, - F: Future>, - RE: Error + 'static, + Self: 'c, + C: AsyncHttpClient<'c>, { - let http_response = http_client(self.prepare_request()).await?; - endpoint_response(http_response) + Box::pin(async move { endpoint_response(http_client.call(self.prepare_request()?).await?) }) } } @@ -190,33 +199,32 @@ where } /// Synchronously sends the request to the authorization server and awaits a response. - pub fn request(self, http_client: F) -> Result> + pub fn request( + self, + http_client: &C, + ) -> Result::Error, TE>> where - F: FnOnce(HttpRequest) -> Result, - RE: Error + 'static, + C: SyncHttpClient, { - endpoint_response(http_client(self.prepare_request()?)?) + endpoint_response(http_client.call(self.prepare_request()?)?) } /// Asynchronously sends the request to the authorization server and awaits a response. - pub async fn request_async( + pub fn request_async<'c, C>( self, - http_client: C, - ) -> Result> + http_client: &'c C, + ) -> Pin>::Error, TE, TR>>> where - C: FnOnce(HttpRequest) -> F, - F: Future>, - RE: Error + 'static, + Self: 'c, + C: AsyncHttpClient<'c>, { - let http_request = self.prepare_request()?; - let http_response = http_client(http_request).await?; - endpoint_response(http_response) + Box::pin(async move { endpoint_response(http_client.call(self.prepare_request()?).await?) }) } fn prepare_request(&self) -> Result> where RE: Error + 'static, { - Ok(endpoint_request( + endpoint_request( self.auth_type, self.client_id, self.client_secret, @@ -228,7 +236,8 @@ where ("grant_type", "refresh_token"), ("refresh_token", self.refresh_token.secret()), ], - )) + ) + .map_err(|err| RequestTokenError::Other(format!("failed to prepare request: {err}"))) } } @@ -296,34 +305,33 @@ where } /// Synchronously sends the request to the authorization server and awaits a response. - pub fn request(self, http_client: F) -> Result> + pub fn request( + self, + http_client: &C, + ) -> Result::Error, TE>> where - F: FnOnce(HttpRequest) -> Result, - RE: Error + 'static, + C: SyncHttpClient, { - endpoint_response(http_client(self.prepare_request()?)?) + endpoint_response(http_client.call(self.prepare_request()?)?) } /// Asynchronously sends the request to the authorization server and awaits a response. - pub async fn request_async( + pub fn request_async<'c, C>( self, - http_client: C, - ) -> Result> + http_client: &'c C, + ) -> Pin>::Error, TE, TR>>> where - C: FnOnce(HttpRequest) -> F, - F: Future>, - RE: Error + 'static, + Self: 'c, + C: AsyncHttpClient<'c>, { - let http_request = self.prepare_request()?; - let http_response = http_client(http_request).await?; - endpoint_response(http_response) + Box::pin(async move { endpoint_response(http_client.call(self.prepare_request()?).await?) }) } fn prepare_request(&self) -> Result> where RE: Error + 'static, { - Ok(endpoint_request( + endpoint_request( self.auth_type, self.client_id, self.client_secret, @@ -336,7 +344,8 @@ where ("username", self.username), ("password", self.password.secret()), ], - )) + ) + .map_err(|err| RequestTokenError::Other(format!("failed to prepare request: {err}"))) } } @@ -402,34 +411,33 @@ where } /// Synchronously sends the request to the authorization server and awaits a response. - pub fn request(self, http_client: F) -> Result> + pub fn request( + self, + http_client: &C, + ) -> Result::Error, TE>> where - F: FnOnce(HttpRequest) -> Result, - RE: Error + 'static, + C: SyncHttpClient, { - endpoint_response(http_client(self.prepare_request()?)?) + endpoint_response(http_client.call(self.prepare_request()?)?) } /// Asynchronously sends the request to the authorization server and awaits a response. - pub async fn request_async( + pub fn request_async<'c, C>( self, - http_client: C, - ) -> Result> + http_client: &'c C, + ) -> Pin>::Error, TE, TR>>> where - C: FnOnce(HttpRequest) -> F, - F: Future>, - RE: Error + 'static, + Self: 'c, + C: AsyncHttpClient<'c>, { - let http_request = self.prepare_request()?; - let http_response = http_client(http_request).await?; - endpoint_response(http_response) + Box::pin(async move { endpoint_response(http_client.call(self.prepare_request()?).await?) }) } fn prepare_request(&self) -> Result> where RE: Error + 'static, { - Ok(endpoint_request( + endpoint_request( self.auth_type, self.client_id, self.client_secret, @@ -438,7 +446,8 @@ where Some(&self.scopes), self.token_url.url(), vec![("grant_type", "client_credentials")], - )) + ) + .map_err(|err| RequestTokenError::Other(format!("failed to prepare request: {err}"))) } } diff --git a/src/token/tests.rs b/src/token/tests.rs index 5152b93..9255e16 100644 --- a/src/token/tests.rs +++ b/src/token/tests.rs @@ -9,13 +9,13 @@ use crate::tests::{mock_http_client, new_client, FakeError}; use crate::token::tests::custom_errors::CustomErrorClient; use crate::{ AccessToken, AuthType, AuthUrl, AuthorizationCode, ClientId, ClientSecret, ExtraTokenFields, - HttpResponse, PkceCodeVerifier, RedirectUrl, RefreshToken, RequestTokenError, - ResourceOwnerPassword, ResourceOwnerUsername, Scope, StandardErrorResponse, - StandardTokenResponse, TokenResponse, TokenType, TokenUrl, + PkceCodeVerifier, RedirectUrl, RefreshToken, RequestTokenError, ResourceOwnerPassword, + ResourceOwnerUsername, Scope, StandardErrorResponse, StandardTokenResponse, TokenResponse, + TokenType, TokenUrl, }; use http::header::{ACCEPT, AUTHORIZATION, CONTENT_TYPE}; -use http::{HeaderMap, HeaderValue, StatusCode}; +use http::{HeaderValue, Response, StatusCode}; use std::borrow::Cow; use std::time::Duration; @@ -46,7 +46,7 @@ fn test_exchange_code_successful_with_minimal_json_response() { let token = client .exchange_code(AuthorizationCode::new("ccc".to_string())) - .request(mock_http_client( + .request(&mock_http_client( vec![ (ACCEPT, "application/json"), (CONTENT_TYPE, "application/x-www-form-urlencoded"), @@ -54,13 +54,14 @@ fn test_exchange_code_successful_with_minimal_json_response() { ], "grant_type=authorization_code&code=ccc", None, - HttpResponse { - status_code: StatusCode::OK, - headers: HeaderMap::new(), - body: "{\"access_token\": \"12/34\", \"token_type\": \"BEARER\"}" - .to_string() - .into_bytes(), - }, + Response::builder() + .status(StatusCode::OK) + .body( + "{\"access_token\": \"12/34\", \"token_type\": \"BEARER\"}" + .to_string() + .into_bytes(), + ) + .unwrap(), )) .unwrap(); @@ -85,31 +86,31 @@ fn test_exchange_code_successful_with_complete_json_response() { let client = new_client().set_auth_type(AuthType::RequestBody); let token = client .exchange_code(AuthorizationCode::new("ccc".to_string())) - .request(mock_http_client( + .request(&mock_http_client( vec![ (ACCEPT, "application/json"), (CONTENT_TYPE, "application/x-www-form-urlencoded"), ], "grant_type=authorization_code&code=ccc&client_id=aaa&client_secret=bbb", None, - HttpResponse { - status_code: StatusCode::OK, - headers: vec![( + Response::builder() + .status(StatusCode::OK) + .header( CONTENT_TYPE, HeaderValue::from_str("application/json").unwrap(), - )] - .into_iter() - .collect(), - body: "{\ + ) + .body( + "{\ \"access_token\": \"12/34\", \ \"token_type\": \"bearer\", \ \"scope\": \"read write\", \ \"expires_in\": 3600, \ \"refresh_token\": \"foobar\"\ }" - .to_string() - .into_bytes(), - }, + .to_string() + .into_bytes(), + ) + .unwrap(), )) .unwrap(); @@ -148,7 +149,7 @@ fn test_exchange_client_credentials_with_basic_auth() { let token = client .exchange_client_credentials() - .request(mock_http_client( + .request(&mock_http_client( vec![ (ACCEPT, "application/json"), (CONTENT_TYPE, "application/x-www-form-urlencoded"), @@ -156,17 +157,18 @@ fn test_exchange_client_credentials_with_basic_auth() { ], "grant_type=client_credentials", None, - HttpResponse { - status_code: StatusCode::OK, - headers: HeaderMap::new(), - body: "{\ + Response::builder() + .status(StatusCode::OK) + .body( + "{\ \"access_token\": \"12/34\", \ \"token_type\": \"bearer\", \ \"scope\": \"read write\"\ }" - .to_string() - .into_bytes(), - }, + .to_string() + .into_bytes(), + ) + .unwrap(), )) .unwrap(); @@ -192,24 +194,25 @@ fn test_exchange_client_credentials_with_basic_auth_but_no_client_secret() { let token = client .exchange_client_credentials() - .request(mock_http_client( + .request(&mock_http_client( vec![ (ACCEPT, "application/json"), (CONTENT_TYPE, "application/x-www-form-urlencoded"), ], "grant_type=client_credentials&client_id=aaa%2F%3B%26", None, - HttpResponse { - status_code: StatusCode::OK, - headers: HeaderMap::new(), - body: "{\ + Response::builder() + .status(StatusCode::OK) + .body( + "{\ \"access_token\": \"12/34\", \ \"token_type\": \"bearer\", \ \"scope\": \"read write\"\ }" - .to_string() - .into_bytes(), - }, + .to_string() + .into_bytes(), + ) + .unwrap(), )) .unwrap(); @@ -233,29 +236,29 @@ fn test_exchange_client_credentials_with_body_auth_and_scope() { .exchange_client_credentials() .add_scope(Scope::new("read".to_string())) .add_scope(Scope::new("write".to_string())) - .request(mock_http_client( + .request(&mock_http_client( vec![ (ACCEPT, "application/json"), (CONTENT_TYPE, "application/x-www-form-urlencoded"), ], "grant_type=client_credentials&scope=read+write&client_id=aaa&client_secret=bbb", None, - HttpResponse { - status_code: StatusCode::OK, - headers: vec![( + Response::builder() + .status(StatusCode::OK) + .header( CONTENT_TYPE, HeaderValue::from_str("APPLICATION/jSoN").unwrap(), - )] - .into_iter() - .collect(), - body: "{\ + ) + .body( + "{\ \"access_token\": \"12/34\", \ \"token_type\": \"bearer\", \ \"scope\": \"read write\"\ }" - .to_string() - .into_bytes(), - }, + .to_string() + .into_bytes(), + ) + .unwrap(), )) .unwrap(); @@ -277,7 +280,7 @@ fn test_exchange_refresh_token_with_basic_auth() { let client = new_client().set_auth_type(AuthType::BasicAuth); let token = client .exchange_refresh_token(&RefreshToken::new("ccc".to_string())) - .request(mock_http_client( + .request(&mock_http_client( vec![ (ACCEPT, "application/json"), (CONTENT_TYPE, "application/x-www-form-urlencoded"), @@ -285,16 +288,17 @@ fn test_exchange_refresh_token_with_basic_auth() { ], "grant_type=refresh_token&refresh_token=ccc", None, - HttpResponse { - status_code: StatusCode::OK, - headers: HeaderMap::new(), - body: "{\"access_token\": \"12/34\", \ + Response::builder() + .status(StatusCode::OK) + .body( + "{\"access_token\": \"12/34\", \ \"token_type\": \"bearer\", \ \"scope\": \"read write\"\ }" - .to_string() - .into_bytes(), - }, + .to_string() + .into_bytes(), + ) + .unwrap(), )) .unwrap(); @@ -316,7 +320,7 @@ fn test_exchange_refresh_token_with_json_response() { let client = new_client(); let token = client .exchange_refresh_token(&RefreshToken::new("ccc".to_string())) - .request(mock_http_client( + .request(&mock_http_client( vec![ (ACCEPT, "application/json"), (CONTENT_TYPE, "application/x-www-form-urlencoded"), @@ -324,17 +328,18 @@ fn test_exchange_refresh_token_with_json_response() { ], "grant_type=refresh_token&refresh_token=ccc", None, - HttpResponse { - status_code: StatusCode::OK, - headers: HeaderMap::new(), - body: "{\ + Response::builder() + .status(StatusCode::OK) + .body( + "{\ \"access_token\": \"12/34\", \ \"token_type\": \"bearer\", \ \"scope\": \"read write\"\ }" - .to_string() - .into_bytes(), - }, + .to_string() + .into_bytes(), + ) + .unwrap(), )) .unwrap(); @@ -361,7 +366,7 @@ fn test_exchange_password_with_json_response() { ) .add_scope(Scope::new("read".to_string())) .add_scope(Scope::new("write".to_string())) - .request(mock_http_client( + .request(&mock_http_client( vec![ (ACCEPT, "application/json"), (CONTENT_TYPE, "application/x-www-form-urlencoded"), @@ -369,22 +374,22 @@ fn test_exchange_password_with_json_response() { ], "grant_type=password&username=user&password=pass&scope=read+write", None, - HttpResponse { - status_code: StatusCode::OK, - headers: vec![( + Response::builder() + .status(StatusCode::OK) + .header( CONTENT_TYPE, HeaderValue::from_str("application/json").unwrap(), - )] - .into_iter() - .collect(), - body: "{\ + ) + .body( + "{\ \"access_token\": \"12/34\", \ \"token_type\": \"bearer\", \ \"scope\": \"read write\"\ }" - .to_string() - .into_bytes(), - }, + .to_string() + .into_bytes(), + ) + .unwrap(), )) .unwrap(); @@ -409,7 +414,7 @@ fn test_exchange_code_successful_with_redirect_url() { let token = client .exchange_code(AuthorizationCode::new("ccc".to_string())) - .request(mock_http_client( + .request(&mock_http_client( vec![ (ACCEPT, "application/json"), (CONTENT_TYPE, "application/x-www-form-urlencoded"), @@ -417,22 +422,22 @@ fn test_exchange_code_successful_with_redirect_url() { "grant_type=authorization_code&code=ccc&client_id=aaa&client_secret=bbb&\ redirect_uri=https%3A%2F%2Fredirect%2Fhere", None, - HttpResponse { - status_code: StatusCode::OK, - headers: vec![( + Response::builder() + .status(StatusCode::OK) + .header( CONTENT_TYPE, HeaderValue::from_str("application/json").unwrap(), - )] - .into_iter() - .collect(), - body: "{\ + ) + .body( + "{\ \"access_token\": \"12/34\", \ \"token_type\": \"bearer\", \ \"scope\": \"read write\"\ }" - .to_string() - .into_bytes(), - }, + .to_string() + .into_bytes(), + ) + .unwrap(), )) .unwrap(); @@ -460,7 +465,7 @@ fn test_exchange_code_successful_with_redirect_url_override() { .set_redirect_uri(Cow::Owned( RedirectUrl::new("https://redirect/alternative".to_string()).unwrap(), )) - .request(mock_http_client( + .request(&mock_http_client( vec![ (ACCEPT, "application/json"), (CONTENT_TYPE, "application/x-www-form-urlencoded"), @@ -468,22 +473,22 @@ fn test_exchange_code_successful_with_redirect_url_override() { "grant_type=authorization_code&code=ccc&client_id=aaa&client_secret=bbb&\ redirect_uri=https%3A%2F%2Fredirect%2Falternative", None, - HttpResponse { - status_code: StatusCode::OK, - headers: vec![( + Response::builder() + .status(StatusCode::OK) + .header( CONTENT_TYPE, HeaderValue::from_str("application/json").unwrap(), - )] - .into_iter() - .collect(), - body: "{\ + ) + .body( + "{\ \"access_token\": \"12/34\", \ \"token_type\": \"bearer\", \ \"scope\": \"read write\"\ }" - .to_string() - .into_bytes(), - }, + .to_string() + .into_bytes(), + ) + .unwrap(), )) .unwrap(); @@ -508,7 +513,7 @@ fn test_exchange_code_successful_with_basic_auth() { let token = client .exchange_code(AuthorizationCode::new("ccc".to_string())) - .request(mock_http_client( + .request(&mock_http_client( vec![ (ACCEPT, "application/json"), (CONTENT_TYPE, "application/x-www-form-urlencoded"), @@ -516,22 +521,22 @@ fn test_exchange_code_successful_with_basic_auth() { ], "grant_type=authorization_code&code=ccc&redirect_uri=https%3A%2F%2Fredirect%2Fhere", None, - HttpResponse { - status_code: StatusCode::OK, - headers: vec![( + Response::builder() + .status(StatusCode::OK) + .header( CONTENT_TYPE, HeaderValue::from_str("application/json").unwrap(), - )] - .into_iter() - .collect(), - body: "{\ + ) + .body( + "{\ \"access_token\": \"12/34\", \ \"token_type\": \"bearer\", \ \"scope\": \"read write\"\ }" - .to_string() - .into_bytes(), - }, + .to_string() + .into_bytes(), + ) + .unwrap(), )) .unwrap(); @@ -560,7 +565,7 @@ fn test_exchange_code_successful_with_pkce_and_extension() { "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk".to_string(), )) .add_extra_param("foo", "bar") - .request(mock_http_client( + .request(&mock_http_client( vec![ (ACCEPT, "application/json"), (CONTENT_TYPE, "application/x-www-form-urlencoded"), @@ -572,22 +577,22 @@ fn test_exchange_code_successful_with_pkce_and_extension() { &redirect_uri=https%3A%2F%2Fredirect%2Fhere\ &foo=bar", None, - HttpResponse { - status_code: StatusCode::OK, - headers: vec![( + Response::builder() + .status(StatusCode::OK) + .header( CONTENT_TYPE, HeaderValue::from_str("application/json").unwrap(), - )] - .into_iter() - .collect(), - body: "{\ + ) + .body( + "{\ \"access_token\": \"12/34\", \ \"token_type\": \"bearer\", \ \"scope\": \"read write\"\ }" - .to_string() - .into_bytes(), - }, + .to_string() + .into_bytes(), + ) + .unwrap(), )) .unwrap(); @@ -613,7 +618,7 @@ fn test_exchange_refresh_token_successful_with_extension() { let token = client .exchange_refresh_token(&RefreshToken::new("ccc".to_string())) .add_extra_param("foo", "bar") - .request(mock_http_client( + .request(&mock_http_client( vec![ (ACCEPT, "application/json"), (CONTENT_TYPE, "application/x-www-form-urlencoded"), @@ -621,22 +626,22 @@ fn test_exchange_refresh_token_successful_with_extension() { ], "grant_type=refresh_token&refresh_token=ccc&foo=bar", None, - HttpResponse { - status_code: StatusCode::OK, - headers: vec![( + Response::builder() + .status(StatusCode::OK) + .header( CONTENT_TYPE, HeaderValue::from_str("application/json").unwrap(), - )] - .into_iter() - .collect(), - body: "{\ + ) + .body( + "{\ \"access_token\": \"12/34\", \ \"token_type\": \"bearer\", \ \"scope\": \"read write\"\ }" - .to_string() - .into_bytes(), - }, + .to_string() + .into_bytes(), + ) + .unwrap(), )) .unwrap(); @@ -658,7 +663,7 @@ fn test_exchange_code_with_simple_json_error() { let client = new_client(); let token = client .exchange_code(AuthorizationCode::new("ccc".to_string())) - .request(mock_http_client( + .request(&mock_http_client( vec![ (ACCEPT, "application/json"), (CONTENT_TYPE, "application/x-www-form-urlencoded"), @@ -666,21 +671,21 @@ fn test_exchange_code_with_simple_json_error() { ], "grant_type=authorization_code&code=ccc", None, - HttpResponse { - status_code: StatusCode::BAD_REQUEST, - headers: vec![( + Response::builder() + .status(StatusCode::BAD_REQUEST) + .header( CONTENT_TYPE, HeaderValue::from_str("application/json").unwrap(), - )] - .into_iter() - .collect(), - body: "{\ + ) + .body( + "{\ \"error\": \"invalid_request\", \ \"error_description\": \"stuff happened\"\ }" - .to_string() - .into_bytes(), - }, + .to_string() + .into_bytes(), + ) + .unwrap(), )); assert!(token.is_err()); @@ -748,7 +753,7 @@ fn test_exchange_code_with_json_parse_error() { let client = new_client(); let token = client .exchange_code(AuthorizationCode::new("ccc".to_string())) - .request(mock_http_client( + .request(&mock_http_client( vec![ (ACCEPT, "application/json"), (CONTENT_TYPE, "application/x-www-form-urlencoded"), @@ -756,16 +761,14 @@ fn test_exchange_code_with_json_parse_error() { ], "grant_type=authorization_code&code=ccc", None, - HttpResponse { - status_code: StatusCode::OK, - headers: vec![( + Response::builder() + .status(StatusCode::OK) + .header( CONTENT_TYPE, HeaderValue::from_str("application/json").unwrap(), - )] - .into_iter() - .collect(), - body: "broken json".to_string().into_bytes(), - }, + ) + .body("broken json".to_string().into_bytes()) + .unwrap(), )); assert!(token.is_err()); @@ -789,7 +792,7 @@ fn test_exchange_code_with_unexpected_content_type() { let client = new_client(); let token = client .exchange_code(AuthorizationCode::new("ccc".to_string())) - .request(mock_http_client( + .request(&mock_http_client( vec![ (ACCEPT, "application/json"), (CONTENT_TYPE, "application/x-www-form-urlencoded"), @@ -797,13 +800,11 @@ fn test_exchange_code_with_unexpected_content_type() { ], "grant_type=authorization_code&code=ccc", None, - HttpResponse { - status_code: StatusCode::OK, - headers: vec![(CONTENT_TYPE, HeaderValue::from_str("text/plain").unwrap())] - .into_iter() - .collect(), - body: "broken json".to_string().into_bytes(), - }, + Response::builder() + .status(StatusCode::OK) + .header(CONTENT_TYPE, HeaderValue::from_str("text/plain").unwrap()) + .body("broken json".to_string().into_bytes()) + .unwrap(), )); assert!(token.is_err()); @@ -811,7 +812,7 @@ fn test_exchange_code_with_unexpected_content_type() { match token.err().unwrap() { RequestTokenError::Other(error_str) => { assert_eq!( - "Unexpected response Content-Type: \"text/plain\", should be `application/json`", + "unexpected response Content-Type: \"text/plain\", should be `application/json`", error_str ); } @@ -827,25 +828,25 @@ fn test_exchange_code_with_invalid_token_type() { let token = client .exchange_code(AuthorizationCode::new("ccc".to_string())) - .request(mock_http_client( + .request(&mock_http_client( vec![ (ACCEPT, "application/json"), (CONTENT_TYPE, "application/x-www-form-urlencoded"), ], "grant_type=authorization_code&code=ccc&client_id=aaa", None, - HttpResponse { - status_code: StatusCode::OK, - headers: vec![( + Response::builder() + .status(StatusCode::OK) + .header( CONTENT_TYPE, HeaderValue::from_str("application/json").unwrap(), - )] - .into_iter() - .collect(), - body: "{\"access_token\": \"12/34\", \"token_type\": 123}" - .to_string() - .into_bytes(), - }, + ) + .body( + "{\"access_token\": \"12/34\", \"token_type\": 123}" + .to_string() + .into_bytes(), + ) + .unwrap(), )); assert!(token.is_err()); @@ -869,7 +870,7 @@ fn test_exchange_code_with_400_status_code() { let client = new_client(); let token_err = client .exchange_code(AuthorizationCode::new("ccc".to_string())) - .request(mock_http_client( + .request(&mock_http_client( vec![ (ACCEPT, "application/json"), (CONTENT_TYPE, "application/x-www-form-urlencoded"), @@ -877,16 +878,14 @@ fn test_exchange_code_with_400_status_code() { ], "grant_type=authorization_code&code=ccc", None, - HttpResponse { - status_code: StatusCode::BAD_REQUEST, - headers: vec![( + Response::builder() + .status(StatusCode::BAD_REQUEST) + .header( CONTENT_TYPE, HeaderValue::from_str("application/json").unwrap(), - )] - .into_iter() - .collect(), - body: body.to_string().into_bytes(), - }, + ) + .body(body.to_string().into_bytes()) + .unwrap(), )) .err() .unwrap(); @@ -921,7 +920,7 @@ fn test_exchange_code_fails_gracefully_on_transport_error() { let token = client .exchange_code(AuthorizationCode::new("ccc".to_string())) - .request(|_| Err(FakeError::Err)); + .request(&|_| Err(FakeError::Err)); assert!(token.is_err()); @@ -940,7 +939,7 @@ fn test_extension_successful_with_minimal_json_response() { let token = client .exchange_code(AuthorizationCode::new("ccc".to_string())) - .request(mock_http_client( + .request(&mock_http_client( vec![ (ACCEPT, "application/json"), (CONTENT_TYPE, "application/x-www-form-urlencoded"), @@ -948,18 +947,18 @@ fn test_extension_successful_with_minimal_json_response() { ], "grant_type=authorization_code&code=ccc", None, - HttpResponse { - status_code: StatusCode::OK, - headers: vec![( + Response::builder() + .status(StatusCode::OK) + .header( CONTENT_TYPE, HeaderValue::from_str("application/json").unwrap(), - )] - .into_iter() - .collect(), - body: "{\"access_token\": \"12/34\", \"token_type\": \"green\", \"height\": 10}" - .to_string() - .into_bytes(), - }, + ) + .body( + "{\"access_token\": \"12/34\", \"token_type\": \"green\", \"height\": 10}" + .to_string() + .into_bytes(), + ) + .unwrap(), )) .unwrap(); @@ -992,22 +991,21 @@ fn test_extension_successful_with_complete_json_response() { let token = client .exchange_code(AuthorizationCode::new("ccc".to_string())) - .request(mock_http_client( + .request(&mock_http_client( vec![ (ACCEPT, "application/json"), (CONTENT_TYPE, "application/x-www-form-urlencoded"), ], "grant_type=authorization_code&code=ccc&client_id=aaa&client_secret=bbb", None, - HttpResponse { - status_code: StatusCode::OK, - headers: vec![( + Response::builder() + .status(StatusCode::OK) + .header( CONTENT_TYPE, HeaderValue::from_str("application/json").unwrap(), - )] - .into_iter() - .collect(), - body: "{\ + ) + .body( + "{\ \"access_token\": \"12/34\", \ \"token_type\": \"red\", \ \"scope\": \"read write\", \ @@ -1016,9 +1014,10 @@ fn test_extension_successful_with_complete_json_response() { \"shape\": \"round\", \ \"height\": 12\ }" - .to_string() - .into_bytes(), - }, + .to_string() + .into_bytes(), + ) + .unwrap(), )) .unwrap(); @@ -1059,7 +1058,7 @@ fn test_extension_with_simple_json_error() { let token = client .exchange_code(AuthorizationCode::new("ccc".to_string())) - .request(mock_http_client( + .request(&mock_http_client( vec![ (ACCEPT, "application/json"), (CONTENT_TYPE, "application/x-www-form-urlencoded"), @@ -1067,19 +1066,19 @@ fn test_extension_with_simple_json_error() { ], "grant_type=authorization_code&code=ccc", None, - HttpResponse { - status_code: StatusCode::BAD_REQUEST, - headers: vec![( + Response::builder() + .status(StatusCode::BAD_REQUEST) + .header( CONTENT_TYPE, HeaderValue::from_str("application/json").unwrap(), - )] - .into_iter() - .collect(), - body: "{\"error\": \"too_light\", \"error_description\": \"stuff happened\", \ + ) + .body( + "{\"error\": \"too_light\", \"error_description\": \"stuff happened\", \ \"error_uri\": \"https://errors\"}" - .to_string() - .into_bytes(), - }, + .to_string() + .into_bytes(), + ) + .unwrap(), )); assert!(token.is_err()); @@ -1184,7 +1183,7 @@ fn test_extension_with_custom_json_error() { let token = client .exchange_code(AuthorizationCode::new("ccc".to_string())) - .request(mock_http_client( + .request(&mock_http_client( vec![ (ACCEPT, "application/json"), (CONTENT_TYPE, "application/x-www-form-urlencoded"), @@ -1192,18 +1191,18 @@ fn test_extension_with_custom_json_error() { ], "grant_type=authorization_code&code=ccc", None, - HttpResponse { - status_code: StatusCode::BAD_REQUEST, - headers: vec![( + Response::builder() + .status(StatusCode::BAD_REQUEST) + .header( CONTENT_TYPE, HeaderValue::from_str("application/json").unwrap(), - )] - .into_iter() - .collect(), - body: "{\"custom_error\": \"non-compliant oauth implementation ;-)\"}" - .to_string() - .into_bytes(), - }, + ) + .body( + "{\"custom_error\": \"non-compliant oauth implementation ;-)\"}" + .to_string() + .into_bytes(), + ) + .unwrap(), )); assert!(token.is_err()); diff --git a/src/ureq.rs b/src/ureq.rs index 3e76fb9..21f8b87 100644 --- a/src/ureq.rs +++ b/src/ureq.rs @@ -1,11 +1,13 @@ use crate::{HttpRequest, HttpResponse}; use http::{ - header::{HeaderMap, HeaderValue, CONTENT_TYPE}, + header::{HeaderValue, CONTENT_TYPE}, method::Method, status::StatusCode, }; +use std::io::Read; + /// Error type returned by failed ureq HTTP requests. #[derive(Debug, thiserror::Error)] pub enum Error { @@ -24,16 +26,18 @@ pub enum Error { Ureq(#[from] Box), } -/// Synchronous HTTP client for ureq. -pub fn http_client(request: HttpRequest) -> Result { - let mut req = if request.method == Method::POST { - ureq::post(request.url.as_ref()) - } else { - ureq::get(request.url.as_ref()) - }; +impl crate::SyncHttpClient for ureq::Agent { + type Error = Error; + + fn call(&self, request: HttpRequest) -> Result { + let mut req = if *request.method() == Method::POST { + self.post(&request.uri().to_string()) + } else { + debug_assert_eq!(*request.method(), Method::GET); + self.get(&request.uri().to_string()) + }; - for (name, value) in request.headers { - if let Some(name) = name { + for (name, value) in request.headers() { req = req.set( name.as_ref(), // TODO: In newer `ureq` it should be easier to convert arbitrary byte sequences @@ -47,23 +51,29 @@ pub fn http_client(request: HttpRequest) -> Result { })?, ); } - } - let response = if let Method::POST = request.method { - req.send_bytes(&request.body) - } else { - req.call() - } - .map_err(Box::new)?; + let response = if let Method::POST = *request.method() { + req.send_bytes(request.body()) + } else { + req.call() + } + .map_err(Box::new)?; - Ok(HttpResponse { - status_code: StatusCode::from_u16(response.status()).map_err(http::Error::from)?, - headers: vec![( - CONTENT_TYPE, - HeaderValue::from_str(response.content_type()).map_err(http::Error::from)?, - )] - .into_iter() - .collect::(), - body: response.into_string()?.as_bytes().into(), - }) + let mut builder = http::Response::builder() + .status(StatusCode::from_u16(response.status()).map_err(http::Error::from)?); + + if let Some(content_type) = response + .header(CONTENT_TYPE.as_str()) + .map(HeaderValue::from_str) + .transpose() + .map_err(http::Error::from)? + { + builder = builder.header(CONTENT_TYPE, content_type); + } + + let mut body = Vec::new(); + response.into_reader().read_to_end(&mut body)?; + + builder.body(body).map_err(Error::Http) + } }