diff --git a/crates/mempool_infra/src/component_client.rs b/crates/mempool_infra/src/component_client.rs index bed4beea..6ed0b00a 100644 --- a/crates/mempool_infra/src/component_client.rs +++ b/crates/mempool_infra/src/component_client.rs @@ -4,12 +4,17 @@ use std::net::IpAddr; use bincode::{deserialize, serialize, ErrorKind}; use hyper::body::to_bytes; use hyper::header::CONTENT_TYPE; -use hyper::{Body, Client, Error as HyperError, Request as HyperRequest, Uri}; +use hyper::{ + Body, Client, Error as HyperError, Request as HyperRequest, Response as HyperResponse, + StatusCode, Uri, +}; use serde::{Deserialize, Serialize}; use thiserror::Error; use tokio::sync::mpsc::{channel, Sender}; -use crate::component_definitions::{ComponentRequestAndResponseSender, APPLICATION_OCTET_STREAM}; +use crate::component_definitions::{ + ComponentRequestAndResponseSender, ServerError, APPLICATION_OCTET_STREAM, +}; pub struct ComponentClient where @@ -90,13 +95,26 @@ where // Todo(uriel): Add configuration for controlling the number of retries. let http_response = self.client.request(http_request).await.map_err(ClientError::CommunicationFailure)?; - let body_bytes = to_bytes(http_response.into_body()) - .await - .map_err(ClientError::ResponseParsingFailure)?; - deserialize(&body_bytes).map_err(ClientError::ResponseDeserializationFailure) + + match http_response.status() { + StatusCode::OK => get_response_body(http_response).await, + status_code => Err(ClientError::ResponseError( + status_code, + get_response_body(http_response).await?, + )), + } } } +async fn get_response_body(response: HyperResponse) -> Result +where + T: for<'a> Deserialize<'a>, +{ + let body_bytes = + to_bytes(response.into_body()).await.map_err(ClientError::ResponseParsingFailure)?; + deserialize(&body_bytes).map_err(ClientError::ResponseDeserializationFailure) +} + // Can't derive because derive forces the generics to also be `Clone`, which we prefer not to do // since it'll require the generic Request and Response types to be cloneable. impl Clone for ComponentClientHttp @@ -118,12 +136,14 @@ where pub enum ClientError { #[error("Communication error: {0}")] CommunicationFailure(HyperError), + #[error("Could not deserialize server response: {0}")] + ResponseDeserializationFailure(Box), #[error("Could not parse the response: {0}")] ResponseParsingFailure(HyperError), + #[error("Got status code: {0}, with server error: {1}")] + ResponseError(StatusCode, ServerError), #[error("Got an unexpected response type.")] UnexpectedResponse, - #[error("Could not deserialize server response: {0}")] - ResponseDeserializationFailure(Box), } pub type ClientResult = Result; diff --git a/crates/mempool_infra/src/component_definitions.rs b/crates/mempool_infra/src/component_definitions.rs index 71c91cf0..a0fcb755 100644 --- a/crates/mempool_infra/src/component_definitions.rs +++ b/crates/mempool_infra/src/component_definitions.rs @@ -1,4 +1,6 @@ use async_trait::async_trait; +use serde::{Deserialize, Serialize}; +use thiserror::Error; use tokio::sync::mpsc::Sender; #[async_trait] @@ -16,3 +18,9 @@ where } pub const APPLICATION_OCTET_STREAM: &str = "application/octet-stream"; + +#[derive(Debug, Error, Deserialize, Serialize)] +pub enum ServerError { + #[error("Could not deserialize client request: {0}")] + RequestDeserializationFailure(String), +} diff --git a/crates/mempool_infra/src/component_server.rs b/crates/mempool_infra/src/component_server.rs index 7ec15c75..94f6d8ef 100644 --- a/crates/mempool_infra/src/component_server.rs +++ b/crates/mempool_infra/src/component_server.rs @@ -7,13 +7,14 @@ use bincode::{deserialize, serialize}; use hyper::body::to_bytes; use hyper::header::CONTENT_TYPE; use hyper::service::{make_service_fn, service_fn}; -use hyper::{Body, Request as HyperRequest, Response as HyperResponse, Server}; +use hyper::{Body, Request as HyperRequest, Response as HyperResponse, Server, StatusCode}; use serde::{Deserialize, Serialize}; use tokio::sync::mpsc::Receiver; use tokio::sync::Mutex; use crate::component_definitions::{ - ComponentRequestAndResponseSender, ComponentRequestHandler, APPLICATION_OCTET_STREAM, + ComponentRequestAndResponseSender, ComponentRequestHandler, ServerError, + APPLICATION_OCTET_STREAM, }; use crate::component_runner::ComponentRunner; @@ -111,21 +112,27 @@ where component: Arc>, ) -> Result, hyper::Error> { let body_bytes = to_bytes(http_request.into_body()).await?; - let component_request: Request = - deserialize(&body_bytes).expect("Request deserialization should succeed"); - - // Acquire the lock for component computation, release afterwards. - let component_response; - { - let mut component_guard = component.lock().await; - component_response = component_guard.handle_request(component_request).await; + let http_response = match deserialize(&body_bytes) { + Ok(component_request) => { + // Acquire the lock for component computation, release afterwards. + let component_response = + { component.lock().await.handle_request(component_request).await }; + HyperResponse::builder() + .status(StatusCode::OK) + .header(CONTENT_TYPE, APPLICATION_OCTET_STREAM) + .body(Body::from( + serialize(&component_response) + .expect("Response serialization should succeed"), + )) + } + Err(error) => { + let server_error = ServerError::RequestDeserializationFailure(error.to_string()); + HyperResponse::builder().status(StatusCode::BAD_REQUEST).body(Body::from( + serialize(&server_error).expect("Server error serialization should succeed"), + )) + } } - let http_response = HyperResponse::builder() - .header(CONTENT_TYPE, APPLICATION_OCTET_STREAM) - .body(Body::from( - serialize(&component_response).expect("Response serialization should succeed"), - )) - .expect("Response builidng should succeed"); + .expect("Response building should succeed"); Ok(http_response) }