From 8df36171e9f9ffc51265ca74aa56ff4e7ca71eb3 Mon Sep 17 00:00:00 2001 From: Uriel Korach Date: Tue, 25 Jun 2024 14:00:59 +0300 Subject: [PATCH] chore: change component trait to component client trait and return result instead of value --- crates/mempool_infra/src/component_client.rs | 2 + crates/mempool_infra/tests/common/mod.rs | 45 +++++++++---------- .../component_server_client_http_test.rs | 18 ++++---- .../tests/component_server_client_rpc_test.rs | 18 ++++---- .../tests/component_server_client_test.rs | 16 +++---- 5 files changed, 50 insertions(+), 49 deletions(-) diff --git a/crates/mempool_infra/src/component_client.rs b/crates/mempool_infra/src/component_client.rs index 9baeb075..75d3f44c 100644 --- a/crates/mempool_infra/src/component_client.rs +++ b/crates/mempool_infra/src/component_client.rs @@ -48,3 +48,5 @@ pub enum ClientError { #[error("Got an unexpected response type.")] UnexpectedResponse, } + +pub type ClientResult = Result; diff --git a/crates/mempool_infra/tests/common/mod.rs b/crates/mempool_infra/tests/common/mod.rs index d41531b5..343f6d90 100644 --- a/crates/mempool_infra/tests/common/mod.rs +++ b/crates/mempool_infra/tests/common/mod.rs @@ -1,52 +1,51 @@ use async_trait::async_trait; +use starknet_mempool_infra::component_client::ClientResult; pub(crate) type ValueA = u32; pub(crate) type ValueB = u8; +pub(crate) type ResultA = ClientResult; +pub(crate) type ResultB = ClientResult; + // TODO(Tsabary): add more messages / functions to the components. #[async_trait] -pub(crate) trait ComponentATrait: Send + Sync { - async fn a_get_value(&self) -> ValueA; +#[allow(dead_code)] // Used in integration tests, which are compiled as part of a different crate. +pub(crate) trait ComponentAClientTrait: Send + Sync { + async fn a_get_value(&self) -> ResultA; } #[async_trait] -pub(crate) trait ComponentBTrait: Send + Sync { - async fn b_get_value(&self) -> ValueB; +pub(crate) trait ComponentBClientTrait: Send + Sync { + async fn b_get_value(&self) -> ResultB; } pub(crate) struct ComponentA { - b: Box, -} - -#[async_trait] -impl ComponentATrait for ComponentA { - async fn a_get_value(&self) -> ValueA { - let b_value = self.b.b_get_value().await; - b_value.into() - } + b: Box, } impl ComponentA { - pub fn new(b: Box) -> Self { + pub fn new(b: Box) -> Self { Self { b } } + + pub async fn a_get_value(&self) -> ValueA { + let b_value = self.b.b_get_value().await.unwrap(); + b_value.into() + } } pub(crate) struct ComponentB { value: ValueB, - _a: Box, -} - -#[async_trait] -impl ComponentBTrait for ComponentB { - async fn b_get_value(&self) -> ValueB { - self.value - } + _a: Box, } impl ComponentB { - pub fn new(value: ValueB, a: Box) -> Self { + pub fn new(value: ValueB, a: Box) -> Self { Self { value, _a: a } } + + pub fn b_get_value(&self) -> ValueB { + self.value + } } diff --git a/crates/mempool_infra/tests/component_server_client_http_test.rs b/crates/mempool_infra/tests/component_server_client_http_test.rs index 11b7c654..3ce96ceb 100644 --- a/crates/mempool_infra/tests/component_server_client_http_test.rs +++ b/crates/mempool_infra/tests/component_server_client_http_test.rs @@ -4,7 +4,7 @@ use std::net::{IpAddr, SocketAddr}; use std::sync::Arc; use async_trait::async_trait; -use common::{ComponentATrait, ComponentBTrait}; +use common::{ComponentAClientTrait, ComponentBClientTrait, ResultA, ResultB}; use hyper::header::CONTENT_TYPE; use hyper::service::{make_service_fn, service_fn}; use hyper::{Body, Client, Request, Response, Server, Uri}; @@ -44,8 +44,8 @@ impl ComponentAClientHttp { // Todo(uriel): Change the component trait to client specific and make it return result #[async_trait] -impl ComponentATrait for ComponentAClientHttp { - async fn a_get_value(&self) -> ValueA { +impl ComponentAClientTrait for ComponentAClientHttp { + async fn a_get_value(&self) -> ResultA { let component_request = ComponentARequest::AGetValue; let http_request = Request::post(self.uri.clone()) .header("Content-Type", "application/octet-stream") @@ -62,7 +62,7 @@ impl ComponentATrait for ComponentAClientHttp { .await .expect("Could not get response from server"); match bincode::deserialize(&body_bytes).expect("Response deserialization should succeed") { - ComponentAResponse::Value(value) => value, + ComponentAResponse::Value(value) => Ok(value), } } } @@ -158,8 +158,8 @@ impl ComponentBClientHttp { // Todo(uriel): Change the component trait to client specific and make it return result #[async_trait] -impl ComponentBTrait for ComponentBClientHttp { - async fn b_get_value(&self) -> ValueB { +impl ComponentBClientTrait for ComponentBClientHttp { + async fn b_get_value(&self) -> ResultB { let component_request = ComponentBRequest::BGetValue; let http_request = Request::post(self.uri.clone()) .header("Content-Type", "application/octet-stream") @@ -176,7 +176,7 @@ impl ComponentBTrait for ComponentBClientHttp { .await .expect("Could not get response from server"); match bincode::deserialize(&body_bytes).expect("Response deserialization should succeed") { - ComponentBResponse::Value(value) => value, + ComponentBResponse::Value(value) => Ok(value), } } } @@ -185,7 +185,7 @@ impl ComponentBTrait for ComponentBClientHttp { impl ComponentRequestHandler for ComponentB { async fn handle_request(&mut self, request: ComponentBRequest) -> ComponentBResponse { match request { - ComponentBRequest::BGetValue => ComponentBResponse::Value(self.b_get_value().await), + ComponentBRequest::BGetValue => ComponentBResponse::Value(self.b_get_value()), } } } @@ -242,7 +242,7 @@ impl ComponentBServerHttp { async fn verify_response(ip_address: IpAddr, port: u16, expected_value: ValueA) { let a_client = ComponentAClientHttp::new(ip_address, port); - assert_eq!(a_client.a_get_value().await, expected_value); + assert_eq!(a_client.a_get_value().await.unwrap(), expected_value); } #[tokio::test] diff --git a/crates/mempool_infra/tests/component_server_client_rpc_test.rs b/crates/mempool_infra/tests/component_server_client_rpc_test.rs index c255ceaa..e6c584ea 100644 --- a/crates/mempool_infra/tests/component_server_client_rpc_test.rs +++ b/crates/mempool_infra/tests/component_server_client_rpc_test.rs @@ -9,7 +9,7 @@ mod common; use std::net::{IpAddr, SocketAddr}; use async_trait::async_trait; -use common::{ComponentATrait, ComponentBTrait}; +use common::{ComponentAClientTrait, ComponentBClientTrait, ResultA, ResultB}; use component_a_service::remote_a_client::RemoteAClient; use component_a_service::remote_a_server::{RemoteA, RemoteAServer}; use component_a_service::{AGetValueMessage, AGetValueReturnMessage}; @@ -40,8 +40,8 @@ impl ComponentAClientRpc { } #[async_trait] -impl ComponentATrait for ComponentAClientRpc { - async fn a_get_value(&self) -> ValueA { +impl ComponentAClientTrait for ComponentAClientRpc { + async fn a_get_value(&self) -> ResultA { let Ok(mut client) = RemoteAClient::connect(self.dst.clone()).await else { panic!("Could not connect to server"); }; @@ -51,7 +51,7 @@ impl ComponentATrait for ComponentAClientRpc { panic!("Could not get response from server"); }; - response.get_ref().value + Ok(response.get_ref().value) } } @@ -66,8 +66,8 @@ impl ComponentBClientRpc { } #[async_trait] -impl ComponentBTrait for ComponentBClientRpc { - async fn b_get_value(&self) -> ValueB { +impl ComponentBClientTrait for ComponentBClientRpc { + async fn b_get_value(&self) -> ResultB { let Ok(mut client) = RemoteBClient::connect(self.dst.clone()).await else { panic!("Could not connect to server"); }; @@ -77,7 +77,7 @@ impl ComponentBTrait for ComponentBClientRpc { panic!("Could not get response from server"); }; - response.get_ref().value.try_into().unwrap() + Ok(response.get_ref().value.try_into().unwrap()) } } @@ -113,7 +113,7 @@ impl RemoteB for ComponentB { &self, _request: tonic::Request, ) -> Result, Status> { - Ok(Response::new(BGetValueReturnMessage { value: self.b_get_value().await.into() })) + Ok(Response::new(BGetValueReturnMessage { value: self.b_get_value().into() })) } } @@ -135,7 +135,7 @@ impl ComponentBServerRpc { async fn verify_response(ip_address: IpAddr, port: u16, expected_value: ValueA) { let a_client = ComponentAClientRpc::new(ip_address, port); - assert_eq!(a_client.a_get_value().await, expected_value); + assert_eq!(a_client.a_get_value().await.unwrap(), expected_value); } #[tokio::test] diff --git a/crates/mempool_infra/tests/component_server_client_test.rs b/crates/mempool_infra/tests/component_server_client_test.rs index 9ec0fb9e..df66c5e6 100644 --- a/crates/mempool_infra/tests/component_server_client_test.rs +++ b/crates/mempool_infra/tests/component_server_client_test.rs @@ -1,7 +1,7 @@ mod common; use async_trait::async_trait; -use common::{ComponentATrait, ComponentBTrait}; +use common::{ComponentAClientTrait, ComponentBClientTrait, ResultA, ResultB}; use starknet_mempool_infra::component_client::ComponentClient; use starknet_mempool_infra::component_definitions::{ ComponentRequestAndResponseSender, ComponentRequestHandler, @@ -23,11 +23,11 @@ pub enum ComponentAResponse { } #[async_trait] -impl ComponentATrait for ComponentClient { - async fn a_get_value(&self) -> ValueA { +impl ComponentAClientTrait for ComponentClient { + async fn a_get_value(&self) -> ResultA { let res = self.send(ComponentARequest::AGetValue).await; match res { - ComponentAResponse::Value(value) => value, + ComponentAResponse::Value(value) => Ok(value), } } } @@ -52,11 +52,11 @@ pub enum ComponentBResponse { } #[async_trait] -impl ComponentBTrait for ComponentClient { - async fn b_get_value(&self) -> ValueB { +impl ComponentBClientTrait for ComponentClient { + async fn b_get_value(&self) -> ResultB { let res = self.send(ComponentBRequest::BGetValue).await; match res { - ComponentBResponse::Value(value) => value, + ComponentBResponse::Value(value) => Ok(value), } } } @@ -65,7 +65,7 @@ impl ComponentBTrait for ComponentClient impl ComponentRequestHandler for ComponentB { async fn handle_request(&mut self, request: ComponentBRequest) -> ComponentBResponse { match request { - ComponentBRequest::BGetValue => ComponentBResponse::Value(self.b_get_value().await), + ComponentBRequest::BGetValue => ComponentBResponse::Value(self.b_get_value()), } } }