Skip to content

Commit

Permalink
chore: change internal component client http to handle errors in fail…
Browse files Browse the repository at this point in the history
…ure and propogate it
  • Loading branch information
uriel-starkware committed Jul 1, 2024
1 parent 6734b2d commit f41901b
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 16 deletions.
29 changes: 22 additions & 7 deletions crates/mempool_infra/src/component_client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use std::marker::PhantomData;
use std::net::IpAddr;

use bincode::{deserialize, serialize};
use hyper::body::to_bytes;
use hyper::header::CONTENT_TYPE;
use hyper::{Body, Client, Request as HyperRequest, Uri};
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -73,7 +74,7 @@ where
Self { uri, _req: PhantomData, _res: PhantomData }
}

pub async fn send(&self, component_request: Request) -> Response {
pub async fn send(&self, component_request: Request) -> ClientResult<Response> {
let http_request = HyperRequest::post(self.uri.clone())
.header(CONTENT_TYPE, APPLICATION_OCTET_STREAM)
.body(Body::from(
Expand All @@ -82,17 +83,31 @@ where
.expect("Request builidng should succeed");

// Todo(uriel): Add configuration to control number of retries
let http_response =
Client::new().request(http_request).await.expect("Could not connect to server");
let body_bytes = hyper::body::to_bytes(http_response.into_body())
let http_response = Client::new()
.request(http_request)
.await
.expect("Could not get response from server");
deserialize(&body_bytes).expect("Response deserialization should succeed")
.map_err(|_e| ClientError::ConnectionFailure)?;
let body_bytes = to_bytes(http_response.into_body()).await.expect("Body should exist");
Ok(deserialize(&body_bytes).expect("Response deserialization should succeed"))
}
}

#[derive(Debug, Error)]
// Can't derive because derive forces the generics to also be `Clone`, which we prefer not to do
// since it'll require transactions to be cloneable.
impl<Request, Response> Clone for ComponentClientHttp<Request, Response>
where
Request: Serialize,
Response: for<'a> Deserialize<'a>,
{
fn clone(&self) -> Self {
Self { uri: self.uri.clone(), _req: PhantomData, _res: PhantomData }
}
}

#[derive(Debug, Error, PartialEq)]
pub enum ClientError {
#[error("Could not conenct to server")]
ConnectionFailure,
#[error("Got an unexpected response type.")]
UnexpectedResponse,
}
Expand Down
27 changes: 18 additions & 9 deletions crates/mempool_infra/tests/component_server_client_http_test.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
mod common;

use std::net::IpAddr;

use async_trait::async_trait;
use common::{ComponentAClientTrait, ComponentBClientTrait, ResultA, ResultB};
use serde::{Deserialize, Serialize};
use starknet_mempool_infra::component_client::ComponentClientHttp;
use starknet_mempool_infra::component_client::{ClientError, ComponentClientHttp};
use starknet_mempool_infra::component_definitions::ComponentRequestHandler;
use starknet_mempool_infra::component_server::ComponentServerHttp;
use tokio::task;
Expand All @@ -27,7 +25,7 @@ pub enum ComponentAResponse {
#[async_trait]
impl ComponentAClientTrait for ComponentClientHttp<ComponentARequest, ComponentAResponse> {
async fn a_get_value(&self) -> ResultA {
match self.send(ComponentARequest::AGetValue).await {
match self.send(ComponentARequest::AGetValue).await? {
ComponentAResponse::Value(value) => Ok(value),
}
}
Expand Down Expand Up @@ -57,7 +55,7 @@ pub enum ComponentBResponse {
#[async_trait]
impl ComponentBClientTrait for ComponentClientHttp<ComponentBRequest, ComponentBResponse> {
async fn b_get_value(&self) -> ResultB {
match self.send(ComponentBRequest::BGetValue).await {
match self.send(ComponentBRequest::BGetValue).await? {
ComponentBResponse::Value(value) => Ok(value),
}
}
Expand All @@ -72,11 +70,20 @@ impl ComponentRequestHandler<ComponentBRequest, ComponentBResponse> for Componen
}
}

async fn verify_response(ip_address: IpAddr, port: u16, expected_value: ValueA) {
let a_client = ComponentClientHttp::new(ip_address, port);
async fn verify_response(
a_client: &ComponentClientHttp<ComponentARequest, ComponentAResponse>,
expected_value: ValueA,
) {
assert_eq!(a_client.a_get_value().await.unwrap(), expected_value);
}

async fn verify_error(
a_client: &ComponentClientHttp<ComponentARequest, ComponentAResponse>,
expected_error: ClientError,
) {
assert_eq!(a_client.a_get_value().await, Err(expected_error));
}

#[tokio::test]
async fn test_setup() {
let setup_value: ValueB = 90;
Expand All @@ -91,8 +98,10 @@ async fn test_setup() {
let b_client =
ComponentClientHttp::<ComponentBRequest, ComponentBResponse>::new(local_ip, b_port);

verify_error(&a_client, ClientError::ConnectionFailure).await;

let component_a = ComponentA::new(Box::new(b_client));
let component_b = ComponentB::new(setup_value, Box::new(a_client));
let component_b = ComponentB::new(setup_value, Box::new(a_client.clone()));

let mut component_a_server = ComponentServerHttp::<
ComponentA,
Expand All @@ -116,5 +125,5 @@ async fn test_setup() {
// Todo(uriel): Get rid of this
task::yield_now().await;

verify_response(local_ip, a_port, expected_value).await;
verify_response(&a_client, expected_value).await;
}

0 comments on commit f41901b

Please sign in to comment.