From f41901b2ebc86c54ac567147ac13313ff151df78 Mon Sep 17 00:00:00 2001 From: Uriel Korach Date: Mon, 1 Jul 2024 17:43:47 +0300 Subject: [PATCH] chore: change internal component client http to handle errors in failure and propogate it --- crates/mempool_infra/src/component_client.rs | 29 ++++++++++++++----- .../component_server_client_http_test.rs | 27 +++++++++++------ 2 files changed, 40 insertions(+), 16 deletions(-) diff --git a/crates/mempool_infra/src/component_client.rs b/crates/mempool_infra/src/component_client.rs index 9453ea65e..dbd51d9b5 100644 --- a/crates/mempool_infra/src/component_client.rs +++ b/crates/mempool_infra/src/component_client.rs @@ -2,6 +2,7 @@ use std::marker::PhantomData; use std::net::IpAddr; use bincode::{deserialize, serialize}; +use hyper::body::to_bytes; use hyper::header::CONTENT_TYPE; use hyper::{Body, Client, Request as HyperRequest, Uri}; use serde::{Deserialize, Serialize}; @@ -73,7 +74,7 @@ where Self { uri, _req: PhantomData, _res: PhantomData } } - pub async fn send(&self, component_request: Request) -> Response { + pub async fn send(&self, component_request: Request) -> ClientResult { let http_request = HyperRequest::post(self.uri.clone()) .header(CONTENT_TYPE, APPLICATION_OCTET_STREAM) .body(Body::from( @@ -82,17 +83,31 @@ where .expect("Request builidng should succeed"); // Todo(uriel): Add configuration to control number of retries - let http_response = - Client::new().request(http_request).await.expect("Could not connect to server"); - let body_bytes = hyper::body::to_bytes(http_response.into_body()) + let http_response = Client::new() + .request(http_request) .await - .expect("Could not get response from server"); - deserialize(&body_bytes).expect("Response deserialization should succeed") + .map_err(|_e| ClientError::ConnectionFailure)?; + let body_bytes = to_bytes(http_response.into_body()).await.expect("Body should exist"); + Ok(deserialize(&body_bytes).expect("Response deserialization should succeed")) } } -#[derive(Debug, Error)] +// Can't derive because derive forces the generics to also be `Clone`, which we prefer not to do +// since it'll require transactions to be cloneable. +impl Clone for ComponentClientHttp +where + Request: Serialize, + Response: for<'a> Deserialize<'a>, +{ + fn clone(&self) -> Self { + Self { uri: self.uri.clone(), _req: PhantomData, _res: PhantomData } + } +} + +#[derive(Debug, Error, PartialEq)] pub enum ClientError { + #[error("Could not conenct to server")] + ConnectionFailure, #[error("Got an unexpected response type.")] UnexpectedResponse, } diff --git a/crates/mempool_infra/tests/component_server_client_http_test.rs b/crates/mempool_infra/tests/component_server_client_http_test.rs index 837246590..8aefbe389 100644 --- a/crates/mempool_infra/tests/component_server_client_http_test.rs +++ b/crates/mempool_infra/tests/component_server_client_http_test.rs @@ -1,11 +1,9 @@ mod common; -use std::net::IpAddr; - use async_trait::async_trait; use common::{ComponentAClientTrait, ComponentBClientTrait, ResultA, ResultB}; use serde::{Deserialize, Serialize}; -use starknet_mempool_infra::component_client::ComponentClientHttp; +use starknet_mempool_infra::component_client::{ClientError, ComponentClientHttp}; use starknet_mempool_infra::component_definitions::ComponentRequestHandler; use starknet_mempool_infra::component_server::ComponentServerHttp; use tokio::task; @@ -27,7 +25,7 @@ pub enum ComponentAResponse { #[async_trait] impl ComponentAClientTrait for ComponentClientHttp { async fn a_get_value(&self) -> ResultA { - match self.send(ComponentARequest::AGetValue).await { + match self.send(ComponentARequest::AGetValue).await? { ComponentAResponse::Value(value) => Ok(value), } } @@ -57,7 +55,7 @@ pub enum ComponentBResponse { #[async_trait] impl ComponentBClientTrait for ComponentClientHttp { async fn b_get_value(&self) -> ResultB { - match self.send(ComponentBRequest::BGetValue).await { + match self.send(ComponentBRequest::BGetValue).await? { ComponentBResponse::Value(value) => Ok(value), } } @@ -72,11 +70,20 @@ impl ComponentRequestHandler for Componen } } -async fn verify_response(ip_address: IpAddr, port: u16, expected_value: ValueA) { - let a_client = ComponentClientHttp::new(ip_address, port); +async fn verify_response( + a_client: &ComponentClientHttp, + expected_value: ValueA, +) { assert_eq!(a_client.a_get_value().await.unwrap(), expected_value); } +async fn verify_error( + a_client: &ComponentClientHttp, + expected_error: ClientError, +) { + assert_eq!(a_client.a_get_value().await, Err(expected_error)); +} + #[tokio::test] async fn test_setup() { let setup_value: ValueB = 90; @@ -91,8 +98,10 @@ async fn test_setup() { let b_client = ComponentClientHttp::::new(local_ip, b_port); + verify_error(&a_client, ClientError::ConnectionFailure).await; + let component_a = ComponentA::new(Box::new(b_client)); - let component_b = ComponentB::new(setup_value, Box::new(a_client)); + let component_b = ComponentB::new(setup_value, Box::new(a_client.clone())); let mut component_a_server = ComponentServerHttp::< ComponentA, @@ -116,5 +125,5 @@ async fn test_setup() { // Todo(uriel): Get rid of this task::yield_now().await; - verify_response(local_ip, a_port, expected_value).await; + verify_response(&a_client, expected_value).await; }