Skip to content

Commit

Permalink
refactor: return an error from server in case request deserialization…
Browse files Browse the repository at this point in the history
… fails
  • Loading branch information
uriel-starkware committed Jul 8, 2024
1 parent 2c7b230 commit 6477df2
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 24 deletions.
36 changes: 28 additions & 8 deletions crates/mempool_infra/src/component_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,17 @@ use std::net::IpAddr;
use bincode::{deserialize, serialize, ErrorKind};
use hyper::body::to_bytes;
use hyper::header::CONTENT_TYPE;
use hyper::{Body, Client, Error as HyperError, Request as HyperRequest, Uri};
use hyper::{
Body, Client, Error as HyperError, Request as HyperRequest, Response as HyperResponse,
StatusCode, Uri,
};
use serde::{Deserialize, Serialize};
use thiserror::Error;
use tokio::sync::mpsc::{channel, Sender};

use crate::component_definitions::{ComponentRequestAndResponseSender, APPLICATION_OCTET_STREAM};
use crate::component_definitions::{
ComponentRequestAndResponseSender, ServerError, APPLICATION_OCTET_STREAM,
};

pub struct ComponentClient<Request, Response>
where
Expand Down Expand Up @@ -90,13 +95,26 @@ where
// Todo(uriel): Add configuration for controlling the number of retries.
let http_response =
self.client.request(http_request).await.map_err(ClientError::CommunicationFailure)?;
let body_bytes = to_bytes(http_response.into_body())
.await
.map_err(ClientError::ResponseParsingFailure)?;
deserialize(&body_bytes).map_err(ClientError::ResponseDeserializationFailure)

match http_response.status() {
StatusCode::OK => get_response_body(http_response).await,
status_code => Err(ClientError::ResponseError(
status_code,
get_response_body(http_response).await?,
)),
}
}
}

async fn get_response_body<T>(response: HyperResponse<Body>) -> Result<T, ClientError>
where
T: for<'a> Deserialize<'a>,
{
let body_bytes =
to_bytes(response.into_body()).await.map_err(ClientError::ResponseParsingFailure)?;
deserialize(&body_bytes).map_err(ClientError::ResponseDeserializationFailure)
}

// Can't derive because derive forces the generics to also be `Clone`, which we prefer not to do
// since it'll require the generic Request and Response types to be cloneable.
impl<Request, Response> Clone for ComponentClientHttp<Request, Response>
Expand All @@ -118,12 +136,14 @@ where
pub enum ClientError {
#[error("Communication error: {0}")]
CommunicationFailure(HyperError),
#[error("Could not deserialize server response: {0}")]
ResponseDeserializationFailure(Box<ErrorKind>),
#[error("Could not parse the response: {0}")]
ResponseParsingFailure(HyperError),
#[error("Got status code: {0}, with server error: {1}")]
ResponseError(StatusCode, ServerError),
#[error("Got an unexpected response type.")]
UnexpectedResponse,
#[error("Could not deserialize server response: {0}")]
ResponseDeserializationFailure(Box<ErrorKind>),
}

pub type ClientResult<T> = Result<T, ClientError>;
8 changes: 8 additions & 0 deletions crates/mempool_infra/src/component_definitions.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use tokio::sync::mpsc::Sender;

#[async_trait]
Expand All @@ -16,3 +18,9 @@ where
}

pub const APPLICATION_OCTET_STREAM: &str = "application/octet-stream";

#[derive(Debug, Error, Deserialize, Serialize)]
pub enum ServerError {
#[error("Could not deserialize client request: {0}")]
RequestDeserializationFailure(String),
}
39 changes: 23 additions & 16 deletions crates/mempool_infra/src/component_server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@ use bincode::{deserialize, serialize};
use hyper::body::to_bytes;
use hyper::header::CONTENT_TYPE;
use hyper::service::{make_service_fn, service_fn};
use hyper::{Body, Request as HyperRequest, Response as HyperResponse, Server};
use hyper::{Body, Request as HyperRequest, Response as HyperResponse, Server, StatusCode};
use serde::{Deserialize, Serialize};
use tokio::sync::mpsc::Receiver;
use tokio::sync::Mutex;

use crate::component_definitions::{
ComponentRequestAndResponseSender, ComponentRequestHandler, APPLICATION_OCTET_STREAM,
ComponentRequestAndResponseSender, ComponentRequestHandler, ServerError,
APPLICATION_OCTET_STREAM,
};
use crate::component_runner::ComponentRunner;

Expand Down Expand Up @@ -111,21 +112,27 @@ where
component: Arc<Mutex<Component>>,
) -> Result<HyperResponse<Body>, hyper::Error> {
let body_bytes = to_bytes(http_request.into_body()).await?;
let component_request: Request =
deserialize(&body_bytes).expect("Request deserialization should succeed");

// Acquire the lock for component computation, release afterwards.
let component_response;
{
let mut component_guard = component.lock().await;
component_response = component_guard.handle_request(component_request).await;
let http_response = match deserialize(&body_bytes) {
Ok(component_request) => {
// Acquire the lock for component computation, release afterwards.
let component_response =
{ component.lock().await.handle_request(component_request).await };
HyperResponse::builder()
.status(StatusCode::OK)
.header(CONTENT_TYPE, APPLICATION_OCTET_STREAM)
.body(Body::from(
serialize(&component_response)
.expect("Response serialization should succeed"),
))
}
Err(error) => {
let server_error = ServerError::RequestDeserializationFailure(error.to_string());
HyperResponse::builder().status(StatusCode::BAD_REQUEST).body(Body::from(
serialize(&server_error).expect("Server error serialization should succeed"),
))
}
}
let http_response = HyperResponse::builder()
.header(CONTENT_TYPE, APPLICATION_OCTET_STREAM)
.body(Body::from(
serialize(&component_response).expect("Response serialization should succeed"),
))
.expect("Response builidng should succeed");
.expect("Response building should succeed");

Ok(http_response)
}
Expand Down

0 comments on commit 6477df2

Please sign in to comment.