diff --git a/crates/mempool_infra/src/component_client.rs b/crates/mempool_infra/src/component_client.rs index 9453ea65..3e7a8015 100644 --- a/crates/mempool_infra/src/component_client.rs +++ b/crates/mempool_infra/src/component_client.rs @@ -2,6 +2,7 @@ use std::marker::PhantomData; use std::net::IpAddr; use bincode::{deserialize, serialize}; +use hyper::body::to_bytes; use hyper::header::CONTENT_TYPE; use hyper::{Body, Client, Request as HyperRequest, Uri}; use serde::{Deserialize, Serialize}; @@ -73,26 +74,45 @@ where Self { uri, _req: PhantomData, _res: PhantomData } } - pub async fn send(&self, component_request: Request) -> Response { + pub async fn send(&self, component_request: Request) -> ClientResult { let http_request = HyperRequest::post(self.uri.clone()) .header(CONTENT_TYPE, APPLICATION_OCTET_STREAM) .body(Body::from( serialize(&component_request).expect("Request serialization should succeed"), )) - .expect("Request builidng should succeed"); + .expect("Request building should succeed"); // Todo(uriel): Add configuration to control number of retries - let http_response = - Client::new().request(http_request).await.expect("Could not connect to server"); - let body_bytes = hyper::body::to_bytes(http_response.into_body()) + let http_response = Client::new() + .request(http_request) .await - .expect("Could not get response from server"); - deserialize(&body_bytes).expect("Response deserialization should succeed") + .map_err(|_e| ClientError::CommunicationFailure)?; // Todo(uriel): To be split into multiple errors + let body_bytes = to_bytes(http_response.into_body()) + .await + .map_err(|e| ClientError::BodyExtractionFailure(e.to_string()))?; + Ok(deserialize(&body_bytes).expect("Response deserialization should succeed")) + } +} + +// Can't derive because derive forces the generics to also be `Clone`, which we prefer not to do +// since it'll require the generic Request and Response types to be cloneable. +impl Clone for ComponentClientHttp +where + Request: Serialize, + Response: for<'a> Deserialize<'a>, +{ + fn clone(&self) -> Self { + Self { uri: self.uri.clone(), _req: PhantomData, _res: PhantomData } } } -#[derive(Debug, Error)] +#[derive(Debug, Error, PartialEq)] pub enum ClientError { + // Todo(uriel): Split this error into more fine grained errors and add a dedicated test + #[error("Could not connect to server")] + CommunicationFailure, + #[error("Could not extract body from HTTP response: {0}")] + BodyExtractionFailure(String), #[error("Got an unexpected response type.")] UnexpectedResponse, } diff --git a/crates/mempool_infra/tests/component_server_client_http_test.rs b/crates/mempool_infra/tests/component_server_client_http_test.rs index 83724659..ff7e021d 100644 --- a/crates/mempool_infra/tests/component_server_client_http_test.rs +++ b/crates/mempool_infra/tests/component_server_client_http_test.rs @@ -1,11 +1,9 @@ mod common; -use std::net::IpAddr; - use async_trait::async_trait; use common::{ComponentAClientTrait, ComponentBClientTrait, ResultA, ResultB}; use serde::{Deserialize, Serialize}; -use starknet_mempool_infra::component_client::ComponentClientHttp; +use starknet_mempool_infra::component_client::{ClientError, ComponentClientHttp}; use starknet_mempool_infra::component_definitions::ComponentRequestHandler; use starknet_mempool_infra::component_server::ComponentServerHttp; use tokio::task; @@ -27,7 +25,7 @@ pub enum ComponentAResponse { #[async_trait] impl ComponentAClientTrait for ComponentClientHttp { async fn a_get_value(&self) -> ResultA { - match self.send(ComponentARequest::AGetValue).await { + match self.send(ComponentARequest::AGetValue).await? { ComponentAResponse::Value(value) => Ok(value), } } @@ -57,7 +55,7 @@ pub enum ComponentBResponse { #[async_trait] impl ComponentBClientTrait for ComponentClientHttp { async fn b_get_value(&self) -> ResultB { - match self.send(ComponentBRequest::BGetValue).await { + match self.send(ComponentBRequest::BGetValue).await? { ComponentBResponse::Value(value) => Ok(value), } } @@ -72,11 +70,20 @@ impl ComponentRequestHandler for Componen } } -async fn verify_response(ip_address: IpAddr, port: u16, expected_value: ValueA) { - let a_client = ComponentClientHttp::new(ip_address, port); +async fn verify_response( + a_client: ComponentClientHttp, + expected_value: ValueA, +) { assert_eq!(a_client.a_get_value().await.unwrap(), expected_value); } +async fn verify_error( + a_client: ComponentClientHttp, + expected_error: ClientError, +) { + assert_eq!(a_client.a_get_value().await, Err(expected_error)); +} + #[tokio::test] async fn test_setup() { let setup_value: ValueB = 90; @@ -91,8 +98,10 @@ async fn test_setup() { let b_client = ComponentClientHttp::::new(local_ip, b_port); + verify_error(a_client.clone(), ClientError::CommunicationFailure).await; + let component_a = ComponentA::new(Box::new(b_client)); - let component_b = ComponentB::new(setup_value, Box::new(a_client)); + let component_b = ComponentB::new(setup_value, Box::new(a_client.clone())); let mut component_a_server = ComponentServerHttp::< ComponentA, @@ -116,5 +125,5 @@ async fn test_setup() { // Todo(uriel): Get rid of this task::yield_now().await; - verify_response(local_ip, a_port, expected_value).await; + verify_response(a_client.clone(), expected_value).await; }