diff --git a/crates/gateway/src/gateway_test.rs b/crates/gateway/src/gateway_test.rs index adbeada26..7dda3151b 100644 --- a/crates/gateway/src/gateway_test.rs +++ b/crates/gateway/src/gateway_test.rs @@ -10,7 +10,7 @@ use starknet_api::external_transaction::ExternalTransaction; use starknet_api::transaction::TransactionHash; use starknet_mempool::mempool::{create_mempool_server, Mempool}; use starknet_mempool_types::mempool_types::{ - MempoolClient, MempoolClientImpl, MempoolRequestAndResponseSender, + MempoolClient, MempoolClientImpl, MempoolRequestWithResponder, }; use tokio::sync::mpsc::channel; use tokio::task; @@ -52,7 +52,7 @@ async fn test_add_tx() { // TODO(Tsabary): wrap creation of channels in dedicated functions, take channel capacity from // config. let (tx_mempool, rx_mempool) = - channel::(MEMPOOL_INVOCATIONS_QUEUE_SIZE); + channel::(MEMPOOL_INVOCATIONS_QUEUE_SIZE); let mut mempool_server = create_mempool_server(mempool, rx_mempool); task::spawn(async move { mempool_server.start().await; diff --git a/crates/mempool/src/mempool.rs b/crates/mempool/src/mempool.rs index be4e464ec..05d041df4 100644 --- a/crates/mempool/src/mempool.rs +++ b/crates/mempool/src/mempool.rs @@ -8,7 +8,7 @@ use starknet_mempool_infra::component_definitions::ComponentRequestHandler; use starknet_mempool_infra::component_server::ComponentServer; use starknet_mempool_types::errors::MempoolError; use starknet_mempool_types::mempool_types::{ - Account, AccountState, MempoolInput, MempoolRequest, MempoolRequestAndResponseSender, + Account, AccountState, MempoolInput, MempoolRequest, MempoolRequestWithResponder, MempoolResponse, MempoolResult, ThinTransaction, }; use tokio::sync::mpsc::Receiver; @@ -138,7 +138,7 @@ type MempoolCommunicationServer = pub fn create_mempool_server( mempool: Mempool, - rx_mempool: Receiver, + rx_mempool: Receiver, ) -> MempoolCommunicationServer { let mempool_communication_wrapper = MempoolCommunicationWrapper::new(mempool); ComponentServer::new(mempool_communication_wrapper, rx_mempool) diff --git a/crates/mempool_infra/src/component_client.rs b/crates/mempool_infra/src/component_client.rs index 6c6a6bb29..555b03600 100644 --- a/crates/mempool_infra/src/component_client.rs +++ b/crates/mempool_infra/src/component_client.rs @@ -1,5 +1,7 @@ use thiserror::Error; use tokio::sync::mpsc::{channel, Sender}; +use tonic::transport::Error as TonicError; +use tonic::Status as TonicStatus; use crate::component_definitions::ComponentRequestAndResponseSender; @@ -36,4 +38,8 @@ where pub enum ClientError { #[error("Got an unexpected response type.")] UnexpectedResponse, + #[error("Failed to connect to the server")] + ConnectionFailure(TonicError), + #[error("Failed to get a response from the server")] + ResponseFailure(TonicStatus), } diff --git a/crates/mempool_infra/src/component_server.rs b/crates/mempool_infra/src/component_server.rs index 0a726100e..da405b851 100644 --- a/crates/mempool_infra/src/component_server.rs +++ b/crates/mempool_infra/src/component_server.rs @@ -1,6 +1,6 @@ use tokio::sync::mpsc::Receiver; -use crate::component_definitions::{ComponentRequestAndResponseSender, ComponentRequestHandler}; +use crate::component_definitions::{ComponentRequestHandler, ComponentRequestAndResponseSender}; pub struct ComponentServer where diff --git a/crates/mempool_infra/src/lib.rs b/crates/mempool_infra/src/lib.rs index 6f843ec3c..7aeef30d6 100644 --- a/crates/mempool_infra/src/lib.rs +++ b/crates/mempool_infra/src/lib.rs @@ -1,4 +1,6 @@ pub mod component_client; +pub mod component_client_rpc; pub mod component_definitions; pub mod component_runner; pub mod component_server; +pub mod component_server_rpc; diff --git a/crates/mempool_infra/tests/common/mod.rs b/crates/mempool_infra/tests/common/mod.rs index d41531b53..f40bfa6ca 100644 --- a/crates/mempool_infra/tests/common/mod.rs +++ b/crates/mempool_infra/tests/common/mod.rs @@ -1,4 +1,5 @@ use async_trait::async_trait; +use starknet_mempool_infra::component_client::ClientError; pub(crate) type ValueA = u32; pub(crate) type ValueB = u8; @@ -6,47 +7,40 @@ pub(crate) type ValueB = u8; // TODO(Tsabary): add more messages / functions to the components. #[async_trait] -pub(crate) trait ComponentATrait: Send + Sync { - async fn a_get_value(&self) -> ValueA; +pub(crate) trait AClientTrait: Send + Sync { + async fn a_get_value(&self) -> Result; } #[async_trait] -pub(crate) trait ComponentBTrait: Send + Sync { - async fn b_get_value(&self) -> ValueB; +pub(crate) trait BClientTrait: Send + Sync { + async fn b_get_value(&self) -> Result; } pub(crate) struct ComponentA { - b: Box, -} - -#[async_trait] -impl ComponentATrait for ComponentA { - async fn a_get_value(&self) -> ValueA { - let b_value = self.b.b_get_value().await; - b_value.into() - } + b: Box, } impl ComponentA { - pub fn new(b: Box) -> Self { + pub fn new(b: Box) -> Self { Self { b } } + + pub async fn a_get_value(&self) -> ValueA { + let b_value = self.b.b_get_value().await.unwrap(); + b_value.into() + } } pub(crate) struct ComponentB { value: ValueB, - _a: Box, -} - -#[async_trait] -impl ComponentBTrait for ComponentB { - async fn b_get_value(&self) -> ValueB { - self.value - } + _a: Box, } impl ComponentB { - pub fn new(value: ValueB, a: Box) -> Self { + pub fn new(value: ValueB, a: Box) -> Self { Self { value, _a: a } } + pub async fn b_get_value(&self) -> ValueB { + self.value + } } diff --git a/crates/mempool_infra/tests/component_server_client_rpc_test.rs b/crates/mempool_infra/tests/component_server_client_rpc_test.rs index c255ceaa4..dcca67bce 100644 --- a/crates/mempool_infra/tests/component_server_client_rpc_test.rs +++ b/crates/mempool_infra/tests/component_server_client_rpc_test.rs @@ -6,78 +6,56 @@ mod component_b_service { } mod common; -use std::net::{IpAddr, SocketAddr}; +use std::net::IpAddr; use async_trait::async_trait; -use common::{ComponentATrait, ComponentBTrait}; +use common::{AClientTrait, BClientTrait}; use component_a_service::remote_a_client::RemoteAClient; use component_a_service::remote_a_server::{RemoteA, RemoteAServer}; use component_a_service::{AGetValueMessage, AGetValueReturnMessage}; use component_b_service::remote_b_client::RemoteBClient; use component_b_service::remote_b_server::{RemoteB, RemoteBServer}; use component_b_service::{BGetValueMessage, BGetValueReturnMessage}; +use starknet_mempool_infra::component_client::ClientError; +use starknet_mempool_infra::component_client_rpc::ComponentClientRpc; +use starknet_mempool_infra::component_server_rpc::ComponentServerRpc; use tokio::task; use tonic::transport::Server; -use tonic::{Request, Response, Status}; +use tonic::{Response, Status}; use crate::common::{ComponentA, ComponentB, ValueA, ValueB}; -fn construct_url(ip_address: IpAddr, port: u16) -> String { - match ip_address { - IpAddr::V4(ip_address) => format!("http://{}:{}/", ip_address, port), - IpAddr::V6(ip_address) => format!("http://[{}]:{}/", ip_address, port), - } -} - -struct ComponentAClientRpc { - dst: String, -} - -impl ComponentAClientRpc { - fn new(ip_address: IpAddr, port: u16) -> Self { - Self { dst: construct_url(ip_address, port) } - } -} - #[async_trait] -impl ComponentATrait for ComponentAClientRpc { - async fn a_get_value(&self) -> ValueA { - let Ok(mut client) = RemoteAClient::connect(self.dst.clone()).await else { - panic!("Could not connect to server"); +impl AClientTrait for ComponentClientRpc { + async fn a_get_value(&self) -> Result { + let mut client = match RemoteAClient::connect(self.dst.clone()).await { + Ok(client) => client, + Err(e) => return Err(ClientError::ConnectionFailure(e)), }; - let Ok(response) = client.remote_a_get_value(Request::new(AGetValueMessage {})).await - else { - panic!("Could not get response from server"); + let response = match client.remote_a_get_value(AGetValueMessage {}).await { + Ok(response) => response, + Err(e) => return Err(ClientError::ResponseFailure(e)), }; - response.get_ref().value - } -} - -struct ComponentBClientRpc { - dst: String, -} - -impl ComponentBClientRpc { - fn new(ip_address: IpAddr, port: u16) -> Self { - Self { dst: construct_url(ip_address, port) } + Ok(response.into_inner().value) } } #[async_trait] -impl ComponentBTrait for ComponentBClientRpc { - async fn b_get_value(&self) -> ValueB { - let Ok(mut client) = RemoteBClient::connect(self.dst.clone()).await else { - panic!("Could not connect to server"); +impl BClientTrait for ComponentClientRpc { + async fn b_get_value(&self) -> Result { + let mut client = match RemoteBClient::connect(self.dst.clone()).await { + Ok(client) => client, + Err(e) => return Err(ClientError::ConnectionFailure(e)), }; - let Ok(response) = client.remote_b_get_value(Request::new(BGetValueMessage {})).await - else { - panic!("Could not get response from server"); + let response = match client.remote_b_get_value(BGetValueMessage {}).await { + Ok(response) => response, + Err(e) => return Err(ClientError::ResponseFailure(e)), }; - response.get_ref().value.try_into().unwrap() + Ok(response.into_inner().value.try_into().unwrap()) } } @@ -91,18 +69,15 @@ impl RemoteA for ComponentA { } } -struct ComponentAServerRpc { - a: Option, - address: SocketAddr, +#[async_trait] +pub trait ServerStart { + async fn start(&mut self); } -impl ComponentAServerRpc { - fn new(a: ComponentA, ip_address: IpAddr, port: u16) -> Self { - Self { a: Some(a), address: SocketAddr::new(ip_address, port) } - } - +#[async_trait] +impl ServerStart for ComponentServerRpc { async fn start(&mut self) { - let svc = RemoteAServer::new(self.a.take().unwrap()); + let svc = RemoteAServer::new(self.component.take().unwrap()); Server::builder().add_service(svc).serve(self.address).await.unwrap(); } } @@ -117,25 +92,19 @@ impl RemoteB for ComponentB { } } -struct ComponentBServerRpc { - b: Option, - address: SocketAddr, -} - -impl ComponentBServerRpc { - fn new(b: ComponentB, ip_address: IpAddr, port: u16) -> Self { - Self { b: Some(b), address: SocketAddr::new(ip_address, port) } - } - +#[async_trait] +impl ServerStart for ComponentServerRpc { async fn start(&mut self) { - let svc = RemoteBServer::new(self.b.take().unwrap()); + let svc = RemoteBServer::new(self.component.take().unwrap()); Server::builder().add_service(svc).serve(self.address).await.unwrap(); } } async fn verify_response(ip_address: IpAddr, port: u16, expected_value: ValueA) { - let a_client = ComponentAClientRpc::new(ip_address, port); - assert_eq!(a_client.a_get_value().await, expected_value); + let a_client = ComponentClientRpc::::new(ip_address, port); + + let returned_value = a_client.a_get_value().await.expect("Value should be returned"); + assert_eq!(returned_value, expected_value); } #[tokio::test] @@ -147,14 +116,16 @@ async fn test_setup() { let a_port = 10000; let b_port = 10001; - let a_client = ComponentAClientRpc::new(local_ip, a_port); - let b_client = ComponentBClientRpc::new(local_ip, b_port); + let a_client = ComponentClientRpc::::new(local_ip, a_port); + let b_client = ComponentClientRpc::::new(local_ip, b_port); let component_a = ComponentA::new(Box::new(b_client)); let component_b = ComponentB::new(setup_value, Box::new(a_client)); - let mut component_a_server = ComponentAServerRpc::new(component_a, local_ip, a_port); - let mut component_b_server = ComponentBServerRpc::new(component_b, local_ip, b_port); + let mut component_a_server = + ComponentServerRpc::::new(component_a, local_ip, a_port); + let mut component_b_server = + ComponentServerRpc::::new(component_b, local_ip, b_port); task::spawn(async move { component_a_server.start().await; diff --git a/crates/mempool_infra/tests/component_server_client_test.rs b/crates/mempool_infra/tests/component_server_client_test.rs index 9ec0fb9eb..21f051692 100644 --- a/crates/mempool_infra/tests/component_server_client_test.rs +++ b/crates/mempool_infra/tests/component_server_client_test.rs @@ -1,8 +1,8 @@ mod common; use async_trait::async_trait; -use common::{ComponentATrait, ComponentBTrait}; -use starknet_mempool_infra::component_client::ComponentClient; +use common::{AClientTrait, BClientTrait}; +use starknet_mempool_infra::component_client::{ClientError, ComponentClient}; use starknet_mempool_infra::component_definitions::{ ComponentRequestAndResponseSender, ComponentRequestHandler, }; @@ -14,81 +14,70 @@ use crate::common::{ComponentA, ComponentB, ValueA, ValueB}; // TODO(Tsabary): send messages from component b to component a. -pub enum ComponentARequest { +pub enum RequestA { AGetValue, } -pub enum ComponentAResponse { +pub enum ResponseA { Value(ValueA), } #[async_trait] -impl ComponentATrait for ComponentClient { - async fn a_get_value(&self) -> ValueA { - let res = self.send(ComponentARequest::AGetValue).await; +impl AClientTrait for ComponentClient { + async fn a_get_value(&self) -> Result { + let res = self.send(RequestA::AGetValue).await; match res { - ComponentAResponse::Value(value) => value, + ResponseA::Value(value) => Ok(value), } } } #[async_trait] -impl ComponentRequestHandler for ComponentA { - async fn handle_request(&mut self, request: ComponentARequest) -> ComponentAResponse { +impl ComponentRequestHandler for ComponentA { + async fn handle_request(&mut self, request: RequestA) -> ResponseA { match request { - ComponentARequest::AGetValue => ComponentAResponse::Value(self.a_get_value().await), + RequestA::AGetValue => ResponseA::Value(self.a_get_value().await), } } } #[derive(Copy, Clone, Debug, PartialEq, Eq)] -pub enum ComponentBRequest { +pub enum RequestB { BGetValue, } #[derive(Copy, Clone, Debug, PartialEq, Eq)] -pub enum ComponentBResponse { +pub enum ResponseB { Value(ValueB), } #[async_trait] -impl ComponentBTrait for ComponentClient { - async fn b_get_value(&self) -> ValueB { - let res = self.send(ComponentBRequest::BGetValue).await; +impl BClientTrait for ComponentClient { + async fn b_get_value(&self) -> Result { + let res = self.send(RequestB::BGetValue).await; match res { - ComponentBResponse::Value(value) => value, + ResponseB::Value(value) => Ok(value), } } } #[async_trait] -impl ComponentRequestHandler for ComponentB { - async fn handle_request(&mut self, request: ComponentBRequest) -> ComponentBResponse { +impl ComponentRequestHandler for ComponentB { + async fn handle_request(&mut self, request: RequestB) -> ResponseB { match request { - ComponentBRequest::BGetValue => ComponentBResponse::Value(self.b_get_value().await), + RequestB::BGetValue => ResponseB::Value(self.b_get_value().await), } } } async fn verify_response( - tx_a: Sender>, + tx_a: Sender>, expected_value: ValueA, ) { - let (tx_a_main, mut rx_a_main) = channel::(1); + let a_client = ComponentClient::new(tx_a); - let request_and_res_tx: ComponentRequestAndResponseSender< - ComponentARequest, - ComponentAResponse, - > = ComponentRequestAndResponseSender { request: ComponentARequest::AGetValue, tx: tx_a_main }; - - tx_a.send(request_and_res_tx).await.unwrap(); - - let res = rx_a_main.recv().await.unwrap(); - match res { - ComponentAResponse::Value(value) => { - assert_eq!(value, expected_value); - } - } + let returned_value = a_client.a_get_value().await.expect("Value should be returned"); + assert_eq!(returned_value, expected_value); } #[tokio::test] @@ -96,10 +85,8 @@ async fn test_setup() { let setup_value: ValueB = 30; let expected_value: ValueA = setup_value.into(); - let (tx_a, rx_a) = - channel::>(32); - let (tx_b, rx_b) = - channel::>(32); + let (tx_a, rx_a) = channel::>(32); + let (tx_b, rx_b) = channel::>(32); let a_client = ComponentClient::new(tx_a.clone()); let b_client = ComponentClient::new(tx_b.clone()); diff --git a/crates/mempool_types/src/mempool_types.rs b/crates/mempool_types/src/mempool_types.rs index cd9fbabc1..c17b1ed49 100644 --- a/crates/mempool_types/src/mempool_types.rs +++ b/crates/mempool_types/src/mempool_types.rs @@ -68,7 +68,7 @@ pub enum MempoolResponse { } pub type MempoolClientImpl = ComponentClient; -pub type MempoolRequestAndResponseSender = +pub type MempoolRequestWithResponder = ComponentRequestAndResponseSender; #[async_trait] diff --git a/crates/tests-integration/tests/end_to_end_test.rs b/crates/tests-integration/tests/end_to_end_test.rs index a540cd60a..a87f58f7c 100644 --- a/crates/tests-integration/tests/end_to_end_test.rs +++ b/crates/tests-integration/tests/end_to_end_test.rs @@ -14,7 +14,7 @@ use starknet_gateway::state_reader_test_utils::rpc_test_state_reader_factory; use starknet_mempool::mempool::{create_mempool_server, Mempool}; use starknet_mempool_integration_tests::integration_test_utils::GatewayClient; use starknet_mempool_types::mempool_types::{ - MempoolClient, MempoolClientImpl, MempoolRequestAndResponseSender, + MempoolClient, MempoolClientImpl, MempoolRequestWithResponder, }; use tokio::sync::mpsc::channel; use tokio::task; @@ -59,7 +59,7 @@ async fn test_end_to_end() { // TODO(Tsabary): wrap creation of channels in dedicated functions, take channel capacity from // config. let (tx_mempool, rx_mempool) = - channel::(MEMPOOL_INVOCATIONS_QUEUE_SIZE); + channel::(MEMPOOL_INVOCATIONS_QUEUE_SIZE); let mempool = Mempool::empty(); let mut mempool_server = create_mempool_server(mempool, rx_mempool);