Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Replace map_err() conversions with a From call via the Try operator #239

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 33 additions & 48 deletions src/curl.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ use super::{HttpRequest, HttpResponse};
pub enum Error {
/// Error returned by curl crate.
#[error("curl request failed")]
Curl(#[source] curl::Error),
Curl(#[from] curl::Error),
/// Non-curl HTTP error.
#[error("HTTP error")]
Http(#[source] http::Error),
Http(#[from] http::Error),
/// Other error.
#[error("Other error: {}", _0)]
Other(String),
Expand All @@ -28,34 +28,27 @@ pub enum Error {
///
pub fn http_client(request: HttpRequest) -> Result<HttpResponse, Error> {
let mut easy = Easy::new();
easy.url(&request.url.to_string()[..])
.map_err(Error::Curl)?;
easy.url(&request.url.to_string()[..])?;

let mut headers = curl::easy::List::new();
request
.headers
.iter()
.map(|(name, value)| {
headers
.append(&format!(
"{}: {}",
name,
value.to_str().map_err(|_| Error::Other(format!(
"invalid {} header value {:?}",
name,
value.as_bytes()
)))?
))
.map_err(Error::Curl)
})
.collect::<Result<_, _>>()?;
for (name, value) in &request.headers {
headers.append(&format!(
"{}: {}",
name,
// TODO: Unnecessary fallibility, curl uses a CString under the hood
value.to_str().map_err(|_| Error::Other(format!(
"invalid {} header value {:?}",
name,
value.as_bytes()
)))?
))?
}

easy.http_headers(headers).map_err(Error::Curl)?;
easy.http_headers(headers)?;

if let Method::POST = request.method {
easy.post(true).map_err(Error::Curl)?;
easy.post_field_size(request.body.len() as u64)
.map_err(Error::Curl)?;
easy.post(true)?;
easy.post_field_size(request.body.len() as u64)?;
} else {
assert_eq!(request.method, Method::GET);
}
Expand All @@ -65,37 +58,29 @@ pub fn http_client(request: HttpRequest) -> Result<HttpResponse, Error> {
{
let mut transfer = easy.transfer();

transfer
.read_function(|buf| Ok(form_slice.read(buf).unwrap_or(0)))
.map_err(Error::Curl)?;
transfer.read_function(|buf| Ok(form_slice.read(buf).unwrap_or(0)))?;

transfer
.write_function(|new_data| {
data.extend_from_slice(new_data);
Ok(new_data.len())
})
.map_err(Error::Curl)?;
transfer.write_function(|new_data| {
data.extend_from_slice(new_data);
Ok(new_data.len())
})?;

transfer.perform().map_err(Error::Curl)?;
transfer.perform()?;
}

let status_code = easy.response_code().map_err(Error::Curl)? as u16;
let status_code = easy.response_code()? as u16;

Ok(HttpResponse {
status_code: StatusCode::from_u16(status_code).map_err(|err| Error::Http(err.into()))?,
status_code: StatusCode::from_u16(status_code).map_err(http::Error::from)?,
headers: easy
.content_type()
.map_err(Error::Curl)?
.map(|content_type| {
Ok(vec![(
CONTENT_TYPE,
HeaderValue::from_str(content_type).map_err(|err| Error::Http(err.into()))?,
)]
.into_iter()
.collect::<HeaderMap>())
})
.content_type()?
.map(|content_type| HeaderValue::from_str(content_type).map_err(http::Error::from))
.transpose()?
.unwrap_or_else(HeaderMap::new),
.map_or_else(HeaderMap::new, |content_type| {
vec![(CONTENT_TYPE, content_type)]
.into_iter()
.collect::<HeaderMap>()
}),
body: data,
})
}
66 changes: 21 additions & 45 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1309,9 +1309,7 @@ where
F: FnOnce(HttpRequest) -> Result<HttpResponse, RE>,
RE: Error + 'static,
{
http_client(self.prepare_request()?)
.map_err(RequestTokenError::Request)
.and_then(endpoint_response)
endpoint_response(http_client(self.prepare_request()?)?)
}

///
Expand All @@ -1327,9 +1325,7 @@ where
RE: Error + 'static,
{
let http_request = self.prepare_request()?;
let http_response = http_client(http_request)
.await
.map_err(RequestTokenError::Request)?;
let http_response = http_client(http_request).await?;
endpoint_response(http_response)
}
}
Expand Down Expand Up @@ -1412,9 +1408,7 @@ where
F: FnOnce(HttpRequest) -> Result<HttpResponse, RE>,
RE: Error + 'static,
{
http_client(self.prepare_request()?)
.map_err(RequestTokenError::Request)
.and_then(endpoint_response)
endpoint_response(http_client(self.prepare_request()?)?)
}
///
/// Asynchronously sends the request to the authorization server and awaits a response.
Expand All @@ -1429,9 +1423,7 @@ where
RE: Error + 'static,
{
let http_request = self.prepare_request()?;
let http_response = http_client(http_request)
.await
.map_err(RequestTokenError::Request)?;
let http_response = http_client(http_request).await?;
endpoint_response(http_response)
}

Expand Down Expand Up @@ -1536,9 +1528,7 @@ where
F: FnOnce(HttpRequest) -> Result<HttpResponse, RE>,
RE: Error + 'static,
{
http_client(self.prepare_request()?)
.map_err(RequestTokenError::Request)
.and_then(endpoint_response)
endpoint_response(http_client(self.prepare_request()?)?)
}

///
Expand All @@ -1554,9 +1544,7 @@ where
RE: Error + 'static,
{
let http_request = self.prepare_request()?;
let http_response = http_client(http_request)
.await
.map_err(RequestTokenError::Request)?;
let http_response = http_client(http_request).await?;
endpoint_response(http_response)
}

Expand Down Expand Up @@ -1660,9 +1648,7 @@ where
F: FnOnce(HttpRequest) -> Result<HttpResponse, RE>,
RE: Error + 'static,
{
http_client(self.prepare_request()?)
.map_err(RequestTokenError::Request)
.and_then(endpoint_response)
endpoint_response(http_client(self.prepare_request()?)?)
}

///
Expand All @@ -1678,9 +1664,7 @@ where
RE: Error + 'static,
{
let http_request = self.prepare_request()?;
let http_response = http_client(http_request)
.await
.map_err(RequestTokenError::Request)?;
let http_response = http_client(http_request).await?;
endpoint_response(http_response)
}

Expand Down Expand Up @@ -1810,9 +1794,7 @@ where
F: FnOnce(HttpRequest) -> Result<HttpResponse, RE>,
RE: Error + 'static,
{
http_client(self.prepare_request()?)
.map_err(RequestTokenError::Request)
.and_then(endpoint_response)
endpoint_response(http_client(self.prepare_request()?)?)
}

///
Expand All @@ -1828,9 +1810,7 @@ where
RE: Error + 'static,
{
let http_request = self.prepare_request()?;
let http_response = http_client(http_request)
.await
.map_err(RequestTokenError::Request)?;
let http_response = http_client(http_request).await?;
endpoint_response(http_response)
}
}
Expand Down Expand Up @@ -1923,9 +1903,7 @@ where
// From https://tools.ietf.org/html/rfc7009#section-2.2:
// "The content of the response body is ignored by the client as all
// necessary information is conveyed in the response code."
http_client(self.prepare_request()?)
.map_err(RequestTokenError::Request)
.and_then(endpoint_response_status_only)
endpoint_response_status_only(http_client(self.prepare_request()?)?)
}

///
Expand All @@ -1941,9 +1919,7 @@ where
RE: Error + 'static,
{
let http_request = self.prepare_request()?;
let http_response = http_client(http_request)
.await
.map_err(RequestTokenError::Request)?;
let http_response = http_client(http_request).await?;
endpoint_response_status_only(http_response)
}
}
Expand Down Expand Up @@ -2224,9 +2200,7 @@ where
RE: Error + 'static,
EF: ExtraDeviceAuthorizationFields,
{
http_client(self.prepare_request()?)
.map_err(RequestTokenError::Request)
.and_then(endpoint_response)
endpoint_response(http_client(self.prepare_request()?)?)
}

///
Expand All @@ -2243,9 +2217,7 @@ where
EF: ExtraDeviceAuthorizationFields,
{
let http_request = self.prepare_request()?;
let http_response = http_client(http_request)
.await
.map_err(RequestTokenError::Request)?;
let http_response = http_client(http_request).await?;
endpoint_response(http_response)
}
}
Expand Down Expand Up @@ -2504,8 +2476,12 @@ where
// use that, otherwise use the value given by the device authorization
// response.
let timeout_dur = timeout.unwrap_or_else(|| self.dev_auth_resp.expires_in());
let chrono_timeout = chrono::Duration::from_std(timeout_dur)
.map_err(|_| RequestTokenError::Other("Failed to convert duration".to_string()))?;
let chrono_timeout = chrono::Duration::from_std(timeout_dur).map_err(|e| {
RequestTokenError::Other(format!(
"Failed to convert `{:?}` to `chrono::Duration`: {}",
timeout_dur, e
))
})?;

// Calculate the DateTime at which the request times out.
let timeout_dt = (*self.time_fn)()
Expand Down Expand Up @@ -3179,7 +3155,7 @@ where
/// connectivity failed).
///
#[error("Request failed")]
Request(#[source] RE),
Request(#[from] RE),
///
/// Failed to parse server response. Parse errors may occur while parsing either successful
/// or error responses.
Expand Down
16 changes: 8 additions & 8 deletions src/ureq.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,18 @@ pub enum Error {
/// Synchronous HTTP client for ureq.
///
pub fn http_client(request: HttpRequest) -> Result<HttpResponse, Error> {
let mut req = if let Method::POST = request.method {
ureq::post(&request.url.to_string())
let mut req = if request.method == Method::POST {
ureq::post(request.url.as_ref())
} else {
ureq::get(&request.url.to_string())
ureq::get(request.url.as_ref())
};

for (name, value) in request.headers {
if let Some(name) = name {
req = req.set(
&name.to_string(),
name.as_ref(),
// TODO: In newer `ureq` it should be easier to convert arbitrary byte sequences
// without unnecessary UTF-8 fallibility here.
value.to_str().map_err(|_| {
Error::Other(format!(
"invalid {} header value {:?}",
Expand All @@ -59,12 +61,10 @@ pub fn http_client(request: HttpRequest) -> Result<HttpResponse, Error> {
.map_err(Box::new)?;

Ok(HttpResponse {
status_code: StatusCode::from_u16(response.status())
.map_err(|err| Error::Http(err.into()))?,
status_code: StatusCode::from_u16(response.status()).map_err(http::Error::from)?,
headers: vec![(
CONTENT_TYPE,
HeaderValue::from_str(response.content_type())
.map_err(|err| Error::Http(err.into()))?,
HeaderValue::from_str(response.content_type()).map_err(http::Error::from)?,
)]
.into_iter()
.collect::<HeaderMap>(),
Expand Down
Loading