Skip to content

Commit

Permalink
Add AsyncHttpClient and SyncHttpClient traits
Browse files Browse the repository at this point in the history
Previously, this crate defined its own `HttpRequest`/`HttpResponse`
types and used `Fn` traits to define the HTTP client interfaces. This
change replaces the `HttpRequest` and `HttpResponse` structs with type
aliases to use `http::Request` and `http::Response`,
respectively (#236). It also replaces the `Fn`-based interface with new
`AsyncHttpClient` and `SyncHttpClient` traits. The corresponding
`http_client` and `async_http_client` implementations (for `reqwest`,
`curl`, and `ureq`) are replaced with trait implementations on stateful
clients, where available. For example, `AsyncHttpClient` is now
implemented for `reqwest::Client`, which allows connection reuse between
requests. For convenience, these traits are also implemented for `Fn`
types to support custom clients without requiring an explicit trait
implementation.

BREAKING CHANGES:
 - Each `request()` method now accepts a reference to a type
   implementing `SyncHttpClient`, rather than an owned function type.
 - Each `request_async()` method now accepts a reference to a type
   implementing `AsyncHttpClient`, rather than an owned function type.
   They now return `Pin<Box<dyn Future>>`` instead of being declared as
   `async fn`s.
 - `HttpRequest` is now a type alias to `http::Request`.
 - `HttpResponse` is now a type alias to `http::Response`.
 - `curl::http_client` has been replaced with `curl::CurlHttpClient`.
 - `reqwest::async_http_client` has been replaced with the
   `AsyncHttpClient` trait being implemented for `reqwest::Client`.
 - `reqwest::sync_http_client` has been replaced with the
   `SyncHttpClient` trait being implemented for
   `reqwest::blocking::Client`.
 - `ureq::http_client` has been replaced with the `SyncHttpClient` trait
   being implemented for `ureq::Agent`.

Resolves #163.
Resolves #236.
Resolves #253.
  • Loading branch information
ramosbugs committed Feb 27, 2024
1 parent 4d55c26 commit 23b952b
Show file tree
Hide file tree
Showing 21 changed files with 1,154 additions and 933 deletions.
11 changes: 8 additions & 3 deletions examples/github.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@
//!
use oauth2::basic::BasicClient;
// Alternatively, this can be `oauth2::curl::http_client` or a custom client.
use oauth2::reqwest::http_client;
use oauth2::reqwest::reqwest;
use oauth2::{
AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, RedirectUrl, Scope,
TokenResponse, TokenUrl,
Expand Down Expand Up @@ -50,6 +49,12 @@ fn main() {
RedirectUrl::new("http://localhost:8080".to_string()).expect("Invalid redirect URL"),
);

let http_client = reqwest::blocking::ClientBuilder::new()
// Following redirects opens the client up to SSRF vulnerabilities.
.redirect(reqwest::redirect::Policy::none())
.build()
.expect("Client should build");

// Generate the authorization URL to which we'll redirect the user.
let (authorize_url, csrf_state) = client
.authorize_url(CsrfToken::new_random)
Expand Down Expand Up @@ -108,7 +113,7 @@ fn main() {
);

// Exchange the code with a token.
let token_res = client.exchange_code(code).request(http_client);
let token_res = client.exchange_code(code).request(&http_client);

println!("Github returned the following token:\n{:?}\n", token_res);

Expand Down
13 changes: 8 additions & 5 deletions examples/github_async.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
//!
use oauth2::basic::BasicClient;
use oauth2::reqwest::async_http_client;
use oauth2::reqwest::reqwest;
use oauth2::{
AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, RedirectUrl, Scope,
TokenResponse, TokenUrl,
Expand Down Expand Up @@ -50,6 +50,12 @@ async fn main() {
RedirectUrl::new("http://localhost:8080".to_string()).expect("Invalid redirect URL"),
);

let http_client = reqwest::ClientBuilder::new()
// Following redirects opens the client up to SSRF vulnerabilities.
.redirect(reqwest::redirect::Policy::none())
.build()
.expect("Client should build");

// Generate the authorization URL to which we'll redirect the user.
let (authorize_url, csrf_state) = client
.authorize_url(CsrfToken::new_random)
Expand Down Expand Up @@ -108,10 +114,7 @@ async fn main() {
);

// Exchange the code with a token.
let token_res = client
.exchange_code(code)
.request_async(async_http_client)
.await;
let token_res = client.exchange_code(code).request_async(&http_client).await;

println!("Github returned the following token:\n{:?}\n", token_res);

Expand Down
13 changes: 9 additions & 4 deletions examples/google.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,8 @@
//! ...and follow the instructions.
//!
use oauth2::reqwest::reqwest;
use oauth2::{basic::BasicClient, StandardRevocableToken, TokenResponse};
// Alternatively, this can be oauth2::curl::http_client or a custom.
use oauth2::reqwest::http_client;
use oauth2::{
AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, PkceCodeChallenge, RedirectUrl,
RevocationUrl, Scope, TokenUrl,
Expand Down Expand Up @@ -55,6 +54,12 @@ fn main() {
.expect("Invalid revocation endpoint URL"),
);

let http_client = reqwest::blocking::ClientBuilder::new()
// Following redirects opens the client up to SSRF vulnerabilities.
.redirect(reqwest::redirect::Policy::none())
.build()
.expect("Client should build");

// Google supports Proof Key for Code Exchange (PKCE - https://oauth.net/2/pkce/).
// Create a PKCE code verifier and SHA-256 encode it as a code challenge.
let (pkce_code_challenge, pkce_code_verifier) = PkceCodeChallenge::new_random_sha256();
Expand Down Expand Up @@ -125,7 +130,7 @@ fn main() {
let token_response = client
.exchange_code(code)
.set_pkce_verifier(pkce_code_verifier)
.request(http_client);
.request(&http_client);

println!(
"Google returned the following token:\n{:?}\n",
Expand All @@ -142,6 +147,6 @@ fn main() {
client
.revoke_token(token_to_revoke)
.unwrap()
.request(http_client)
.request(&http_client)
.expect("Failed to revoke token");
}
13 changes: 9 additions & 4 deletions examples/google_devicecode.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,7 @@
//!
use oauth2::basic::BasicClient;
// Alternatively, this can be oauth2::curl::http_client or a custom.
use oauth2::reqwest::http_client;
use oauth2::reqwest::reqwest;
use oauth2::{
AuthType, AuthUrl, ClientId, ClientSecret, DeviceAuthorizationResponse, DeviceAuthorizationUrl,
ExtraDeviceAuthorizationFields, Scope, TokenUrl,
Expand Down Expand Up @@ -58,11 +57,17 @@ fn main() {
.set_device_authorization_url(device_auth_url)
.set_auth_type(AuthType::RequestBody);

let http_client = reqwest::blocking::ClientBuilder::new()
// Following redirects opens the client up to SSRF vulnerabilities.
.redirect(reqwest::redirect::Policy::none())
.build()
.expect("Client should build");

// Request the set of codes from the Device Authorization endpoint.
let details: StoringDeviceAuthorizationResponse = device_client
.exchange_device_code()
.add_scope(Scope::new("profile".to_string()))
.request(http_client)
.request(&http_client)
.expect("Failed to request codes from device auth endpoint");

// Display the URL and user-code.
Expand All @@ -75,7 +80,7 @@ fn main() {
// Now poll for the token
let token = device_client
.exchange_device_access_token(&details)
.request(http_client, std::thread::sleep, None)
.request(&http_client, std::thread::sleep, None)
.expect("Failed to get token");

println!("Google returned the following token:\n{:?}\n", token);
Expand Down
23 changes: 18 additions & 5 deletions examples/letterboxd.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ use hex::ToHex;
use hmac::{Hmac, Mac};
use oauth2::{
basic::BasicClient, AuthType, AuthUrl, ClientId, ClientSecret, HttpRequest, HttpResponse,
ResourceOwnerPassword, ResourceOwnerUsername, TokenUrl,
ResourceOwnerPassword, ResourceOwnerUsername, SyncHttpClient, TokenUrl,
};
use sha2::Sha256;
use url::Url;
Expand Down Expand Up @@ -63,7 +63,7 @@ fn main() -> Result<(), anyhow::Error> {
let token_result = client
.set_auth_type(AuthType::RequestBody)
.exchange_password(&letterboxd_username, &letterboxd_password)
.request(|request| http_client.execute(request))?;
.request(&|request| http_client.execute(request))?;

println!("{:?}", token_result);

Expand All @@ -77,21 +77,34 @@ fn main() -> Result<(), anyhow::Error> {
struct SigningHttpClient {
client_id: ClientId,
client_secret: ClientSecret,
inner: reqwest::blocking::Client,
}

impl SigningHttpClient {
fn new(client_id: ClientId, client_secret: ClientSecret) -> Self {
Self {
client_id,
client_secret,
inner: reqwest::blocking::ClientBuilder::new()
// Following redirects opens the client up to SSRF vulnerabilities.
.redirect(reqwest::redirect::Policy::none())
.build()
.expect("Client should build"),
}
}

/// Signs the request before calling `oauth2::reqwest::http_client`.
fn execute(&self, mut request: HttpRequest) -> Result<HttpResponse, impl std::error::Error> {
let signed_url = self.sign_url(request.url, &request.method, &request.body);
request.url = signed_url;
oauth2::reqwest::http_client(request)
let signed_url = self.sign_url(
Url::parse(&request.uri().to_string()).expect("http::Uri should be a valid url::Url"),
request.method(),
request.body(),
);
*request.uri_mut() = signed_url
.as_str()
.try_into()
.expect("url::Url should be a valid http::Uri");
self.inner.call(request)
}

/// Signs the request based on a random and unique nonce, timestamp, and
Expand Down
11 changes: 8 additions & 3 deletions examples/microsoft_devicecode_common_user.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use oauth2::basic::BasicClient;
use oauth2::reqwest::async_http_client;
use oauth2::{
AuthUrl, ClientId, DeviceAuthorizationUrl, Scope, StandardDeviceAuthorizationResponse, TokenUrl,
};
Expand All @@ -19,10 +18,16 @@ async fn main() -> Result<(), Box<dyn Error>> {
"https://login.microsoftonline.com/common/oauth2/v2.0/devicecode".to_string(),
)?);

let http_client = reqwest::ClientBuilder::new()
// Following redirects opens the client up to SSRF vulnerabilities.
.redirect(reqwest::redirect::Policy::none())
.build()
.expect("Client should build");

let details: StandardDeviceAuthorizationResponse = client
.exchange_device_code()
.add_scope(Scope::new("read".to_string()))
.request_async(async_http_client)
.request_async(&http_client)
.await?;

eprintln!(
Expand All @@ -33,7 +38,7 @@ async fn main() -> Result<(), Box<dyn Error>> {

let token_result = client
.exchange_device_access_token(&details)
.request_async(async_http_client, tokio::time::sleep, None)
.request_async(&http_client, tokio::time::sleep, None)
.await;

eprintln!("Token:{:?}", token_result);
Expand Down
12 changes: 9 additions & 3 deletions examples/microsoft_devicecode_tenant_user.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use oauth2::basic::BasicClient;
use oauth2::reqwest::async_http_client;
use oauth2::reqwest::reqwest;
use oauth2::StandardDeviceAuthorizationResponse;
use oauth2::{AuthUrl, ClientId, DeviceAuthorizationUrl, Scope, TokenUrl};

Expand All @@ -25,10 +25,16 @@ async fn main() -> Result<(), Box<dyn Error>> {
TENANT_ID
))?);

let http_client = reqwest::ClientBuilder::new()
// Following redirects opens the client up to SSRF vulnerabilities.
.redirect(reqwest::redirect::Policy::none())
.build()
.expect("Client should build");

let details: StandardDeviceAuthorizationResponse = client
.exchange_device_code()
.add_scope(Scope::new("read".to_string()))
.request_async(async_http_client)
.request_async(&http_client)
.await?;

eprintln!(
Expand All @@ -39,7 +45,7 @@ async fn main() -> Result<(), Box<dyn Error>> {

let token_result = client
.exchange_device_access_token(&details)
.request_async(async_http_client, tokio::time::sleep, None)
.request_async(&http_client, tokio::time::sleep, None)
.await;

eprintln!("Token:{:?}", token_result);
Expand Down
11 changes: 8 additions & 3 deletions examples/msgraph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@
//!
use oauth2::basic::BasicClient;
// Alternatively, this can be `oauth2::curl::http_client` or a custom client.
use oauth2::reqwest::http_client;
use oauth2::reqwest::reqwest;
use oauth2::{
AuthType, AuthUrl, AuthorizationCode, ClientId, ClientSecret, CsrfToken, PkceCodeChallenge,
RedirectUrl, Scope, TokenUrl,
Expand Down Expand Up @@ -63,6 +62,12 @@ fn main() {
.expect("Invalid redirect URL"),
);

let http_client = reqwest::blocking::ClientBuilder::new()
// Following redirects opens the client up to SSRF vulnerabilities.
.redirect(reqwest::redirect::Policy::none())
.build()
.expect("Client should build");

// Microsoft Graph supports Proof Key for Code Exchange (PKCE - https://oauth.net/2/pkce/).
// Create a PKCE code verifier and SHA-256 encode it as a code challenge.
let (pkce_code_challenge, pkce_code_verifier) = PkceCodeChallenge::new_random_sha256();
Expand Down Expand Up @@ -131,7 +136,7 @@ fn main() {
.exchange_code(code)
// Send the PKCE code verifier in the token request
.set_pkce_verifier(pkce_code_verifier)
.request(http_client);
.request(&http_client);

println!("MS Graph returned the following token:\n{:?}\n", token);
}
13 changes: 9 additions & 4 deletions examples/wunderlist.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,13 @@ use oauth2::basic::{
BasicTokenType,
};
use oauth2::helpers;
use oauth2::{StandardRevocableToken, TokenType};
// Alternatively, this can be `oauth2::curl::http_client` or a custom client.
use oauth2::reqwest::http_client;
use oauth2::reqwest::reqwest;
use oauth2::{
AccessToken, AuthUrl, AuthorizationCode, Client, ClientId, ClientSecret, CsrfToken,
EmptyExtraTokenFields, ExtraTokenFields, RedirectUrl, RefreshToken, Scope, TokenResponse,
TokenUrl,
};
use oauth2::{StandardRevocableToken, TokenType};
use serde::{Deserialize, Serialize};
use url::Url;

Expand Down Expand Up @@ -157,6 +156,12 @@ fn main() {
RedirectUrl::new("http://localhost:8080".to_string()).expect("Invalid redirect URL"),
);

let http_client = reqwest::blocking::ClientBuilder::new()
// Following redirects opens the client up to SSRF vulnerabilities.
.redirect(reqwest::redirect::Policy::none())
.build()
.expect("Client should build");

// Generate the authorization URL to which we'll redirect the user.
let (authorize_url, csrf_state) = client.authorize_url(CsrfToken::new_random).url();

Expand Down Expand Up @@ -217,7 +222,7 @@ fn main() {
.exchange_code(code)
.add_extra_param("client_id", client_id_str)
.add_extra_param("client_secret", client_secret_str)
.request(http_client);
.request(&http_client);

println!(
"Wunderlist returned the following token:\n{:?}\n",
Expand Down
24 changes: 12 additions & 12 deletions src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,9 @@ use std::sync::Arc;
/// # use thiserror::Error;
/// # use http::status::StatusCode;
/// # use http::header::{HeaderValue, CONTENT_TYPE};
/// # use http::Response;
/// # use oauth2::{*, basic::*};
/// #
/// # let client = BasicClient::new(ClientId::new("aaa".to_string()))
/// # .set_client_secret(ClientSecret::new("bbb".to_string()))
/// # .set_auth_uri(AuthUrl::new("https://example.com/auth".to_string()).unwrap())
Expand All @@ -55,25 +57,23 @@ use std::sync::Arc;
/// # }
/// #
/// # let http_client = |_| -> Result<HttpResponse, FakeError> {
/// # Ok(HttpResponse {
/// # status_code: StatusCode::BAD_REQUEST,
/// # headers: vec![(
/// # CONTENT_TYPE,
/// # HeaderValue::from_str("application/json").unwrap(),
/// # )]
/// # .into_iter()
/// # .collect(),
/// # body: "{\"error\": \"unsupported_token_type\", \"error_description\": \"stuff happened\", \
/// # \"error_uri\": \"https://errors\"}"
/// # Ok(Response::builder()
/// # .status(StatusCode::BAD_REQUEST)
/// # .header(CONTENT_TYPE, HeaderValue::from_str("application/json").unwrap())
/// # .body(
/// # r#"{"error": "unsupported_token_type",
/// # "error_description": "stuff happened",
/// # "error_uri": "https://errors"}"#
/// # .to_string()
/// # .into_bytes(),
/// # })
/// # )
/// # .unwrap())
/// # };
/// #
/// let res = client
/// .revoke_token(AccessToken::new("some token".to_string()).into())
/// .unwrap()
/// .request(http_client);
/// .request(&http_client);
///
/// assert!(matches!(res, Err(
/// RequestTokenError::ServerResponse(err)) if matches!(err.error(),
Expand Down
Loading

0 comments on commit 23b952b

Please sign in to comment.