diff --git a/crates/mempool_infra/src/component_client.rs b/crates/mempool_infra/src/component_client.rs index 6c6a6bb2..555b0360 100644 --- a/crates/mempool_infra/src/component_client.rs +++ b/crates/mempool_infra/src/component_client.rs @@ -1,5 +1,7 @@ use thiserror::Error; use tokio::sync::mpsc::{channel, Sender}; +use tonic::transport::Error as TonicError; +use tonic::Status as TonicStatus; use crate::component_definitions::ComponentRequestAndResponseSender; @@ -36,4 +38,8 @@ where pub enum ClientError { #[error("Got an unexpected response type.")] UnexpectedResponse, + #[error("Failed to connect to the server")] + ConnectionFailure(TonicError), + #[error("Failed to get a response from the server")] + ResponseFailure(TonicStatus), } diff --git a/crates/mempool_infra/src/component_client_rpc.rs b/crates/mempool_infra/src/component_client_rpc.rs new file mode 100644 index 00000000..8ba10e07 --- /dev/null +++ b/crates/mempool_infra/src/component_client_rpc.rs @@ -0,0 +1,19 @@ +use std::net::IpAddr; + +fn construct_url(ip_address: IpAddr, port: u16) -> String { + match ip_address { + IpAddr::V4(ip_address) => format!("http://{}:{}/", ip_address, port), + IpAddr::V6(ip_address) => format!("http://[{}]:{}/", ip_address, port), + } +} + +pub struct ComponentClientRpc { + pub dst: String, + _component: std::marker::PhantomData, +} + +impl ComponentClientRpc { + pub fn new(ip_address: IpAddr, port: u16) -> Self { + Self { dst: construct_url(ip_address, port), _component: Default::default() } + } +} diff --git a/crates/mempool_infra/src/component_server_rpc.rs b/crates/mempool_infra/src/component_server_rpc.rs new file mode 100644 index 00000000..6763f8e3 --- /dev/null +++ b/crates/mempool_infra/src/component_server_rpc.rs @@ -0,0 +1,23 @@ +use std::net::{IpAddr, SocketAddr}; + +use async_trait::async_trait; + +#[async_trait] +pub trait ServerStart { + async fn start_server(self, address: SocketAddr); +} + +pub struct ComponentServerRpc { + component: Option, + address: SocketAddr, +} + +impl ComponentServerRpc { + pub fn new(component: Component, ip_address: IpAddr, port: u16) -> Self { + Self { component: Some(component), address: SocketAddr::new(ip_address, port) } + } + + pub async fn start(&mut self) { + self.component.take().unwrap().start_server(self.address).await; + } +} diff --git a/crates/mempool_infra/src/lib.rs b/crates/mempool_infra/src/lib.rs index 6f843ec3..7aeef30d 100644 --- a/crates/mempool_infra/src/lib.rs +++ b/crates/mempool_infra/src/lib.rs @@ -1,4 +1,6 @@ pub mod component_client; +pub mod component_client_rpc; pub mod component_definitions; pub mod component_runner; pub mod component_server; +pub mod component_server_rpc; diff --git a/crates/mempool_infra/tests/common/mod.rs b/crates/mempool_infra/tests/common/mod.rs index d41531b5..ea0cb5eb 100644 --- a/crates/mempool_infra/tests/common/mod.rs +++ b/crates/mempool_infra/tests/common/mod.rs @@ -1,52 +1,50 @@ use async_trait::async_trait; +use starknet_mempool_infra::component_client::ClientError; pub(crate) type ValueA = u32; pub(crate) type ValueB = u8; // TODO(Tsabary): add more messages / functions to the components. +pub type ClientResult = Result; +pub type AClientResult = ClientResult; +pub type BClientResult = ClientResult; + #[async_trait] -pub(crate) trait ComponentATrait: Send + Sync { - async fn a_get_value(&self) -> ValueA; +pub(crate) trait AClient: Send + Sync { + async fn a_get_value(&self) -> AClientResult; } #[async_trait] -pub(crate) trait ComponentBTrait: Send + Sync { - async fn b_get_value(&self) -> ValueB; +pub(crate) trait BClient: Send + Sync { + async fn b_get_value(&self) -> BClientResult; } pub(crate) struct ComponentA { - b: Box, -} - -#[async_trait] -impl ComponentATrait for ComponentA { - async fn a_get_value(&self) -> ValueA { - let b_value = self.b.b_get_value().await; - b_value.into() - } + b: Box, } impl ComponentA { - pub fn new(b: Box) -> Self { + pub fn new(b: Box) -> Self { Self { b } } + + pub async fn a_get_value(&self) -> ValueA { + let b_value = self.b.b_get_value().await.unwrap(); + b_value.into() + } } pub(crate) struct ComponentB { value: ValueB, - _a: Box, -} - -#[async_trait] -impl ComponentBTrait for ComponentB { - async fn b_get_value(&self) -> ValueB { - self.value - } + _a: Box, } impl ComponentB { - pub fn new(value: ValueB, a: Box) -> Self { + pub fn new(value: ValueB, a: Box) -> Self { Self { value, _a: a } } + pub fn b_get_value(&self) -> ValueB { + self.value + } } diff --git a/crates/mempool_infra/tests/component_server_client_rpc_test.rs b/crates/mempool_infra/tests/component_server_client_rpc_test.rs index c255ceaa..517622f3 100644 --- a/crates/mempool_infra/tests/component_server_client_rpc_test.rs +++ b/crates/mempool_infra/tests/component_server_client_rpc_test.rs @@ -9,75 +9,47 @@ mod common; use std::net::{IpAddr, SocketAddr}; use async_trait::async_trait; -use common::{ComponentATrait, ComponentBTrait}; +use common::{AClient, AClientResult, BClient, BClientResult}; use component_a_service::remote_a_client::RemoteAClient; use component_a_service::remote_a_server::{RemoteA, RemoteAServer}; use component_a_service::{AGetValueMessage, AGetValueReturnMessage}; use component_b_service::remote_b_client::RemoteBClient; use component_b_service::remote_b_server::{RemoteB, RemoteBServer}; use component_b_service::{BGetValueMessage, BGetValueReturnMessage}; +use starknet_mempool_infra::component_client::ClientError; +use starknet_mempool_infra::component_client_rpc::ComponentClientRpc; +use starknet_mempool_infra::component_server_rpc::{ComponentServerRpc, ServerStart}; use tokio::task; use tonic::transport::Server; -use tonic::{Request, Response, Status}; +use tonic::{Response, Status}; use crate::common::{ComponentA, ComponentB, ValueA, ValueB}; -fn construct_url(ip_address: IpAddr, port: u16) -> String { - match ip_address { - IpAddr::V4(ip_address) => format!("http://{}:{}/", ip_address, port), - IpAddr::V6(ip_address) => format!("http://[{}]:{}/", ip_address, port), - } -} - -struct ComponentAClientRpc { - dst: String, -} - -impl ComponentAClientRpc { - fn new(ip_address: IpAddr, port: u16) -> Self { - Self { dst: construct_url(ip_address, port) } - } -} - #[async_trait] -impl ComponentATrait for ComponentAClientRpc { - async fn a_get_value(&self) -> ValueA { - let Ok(mut client) = RemoteAClient::connect(self.dst.clone()).await else { - panic!("Could not connect to server"); - }; - - let Ok(response) = client.remote_a_get_value(Request::new(AGetValueMessage {})).await - else { - panic!("Could not get response from server"); - }; - - response.get_ref().value - } -} - -struct ComponentBClientRpc { - dst: String, -} - -impl ComponentBClientRpc { - fn new(ip_address: IpAddr, port: u16) -> Self { - Self { dst: construct_url(ip_address, port) } +impl AClient for ComponentClientRpc { + async fn a_get_value(&self) -> AClientResult { + let mut client = RemoteAClient::connect(self.dst.clone()) + .await + .map_err(ClientError::ConnectionFailure)?; + let response = client + .remote_a_get_value(AGetValueMessage {}) + .await + .map_err(ClientError::ResponseFailure)?; + Ok(response.into_inner().value) } } #[async_trait] -impl ComponentBTrait for ComponentBClientRpc { - async fn b_get_value(&self) -> ValueB { - let Ok(mut client) = RemoteBClient::connect(self.dst.clone()).await else { - panic!("Could not connect to server"); - }; - - let Ok(response) = client.remote_b_get_value(Request::new(BGetValueMessage {})).await - else { - panic!("Could not get response from server"); - }; - - response.get_ref().value.try_into().unwrap() +impl BClient for ComponentClientRpc { + async fn b_get_value(&self) -> BClientResult { + let mut client = RemoteBClient::connect(self.dst.clone()) + .await + .map_err(ClientError::ConnectionFailure)?; + let response = client + .remote_b_get_value(BGetValueMessage {}) + .await + .map_err(ClientError::ResponseFailure)?; + Ok(response.into_inner().value.try_into().unwrap()) } } @@ -91,51 +63,37 @@ impl RemoteA for ComponentA { } } -struct ComponentAServerRpc { - a: Option, - address: SocketAddr, -} - -impl ComponentAServerRpc { - fn new(a: ComponentA, ip_address: IpAddr, port: u16) -> Self { - Self { a: Some(a), address: SocketAddr::new(ip_address, port) } - } - - async fn start(&mut self) { - let svc = RemoteAServer::new(self.a.take().unwrap()); - Server::builder().add_service(svc).serve(self.address).await.unwrap(); - } -} - #[async_trait] impl RemoteB for ComponentB { async fn remote_b_get_value( &self, _request: tonic::Request, ) -> Result, Status> { - Ok(Response::new(BGetValueReturnMessage { value: self.b_get_value().await.into() })) + Ok(Response::new(BGetValueReturnMessage { value: self.b_get_value().into() })) } } -struct ComponentBServerRpc { - b: Option, - address: SocketAddr, -} - -impl ComponentBServerRpc { - fn new(b: ComponentB, ip_address: IpAddr, port: u16) -> Self { - Self { b: Some(b), address: SocketAddr::new(ip_address, port) } +#[async_trait] +impl ServerStart for ComponentA { + async fn start_server(self, address: SocketAddr) { + let svc = RemoteAServer::new(self); + Server::builder().add_service(svc).serve(address).await.unwrap(); } +} - async fn start(&mut self) { - let svc = RemoteBServer::new(self.b.take().unwrap()); - Server::builder().add_service(svc).serve(self.address).await.unwrap(); +#[async_trait] +impl ServerStart for ComponentB { + async fn start_server(self, address: SocketAddr) { + let svc = RemoteBServer::new(self); + Server::builder().add_service(svc).serve(address).await.unwrap(); } } async fn verify_response(ip_address: IpAddr, port: u16, expected_value: ValueA) { - let a_client = ComponentAClientRpc::new(ip_address, port); - assert_eq!(a_client.a_get_value().await, expected_value); + let a_client = ComponentClientRpc::new(ip_address, port); + + let returned_value = a_client.a_get_value().await.expect("Value should be returned"); + assert_eq!(returned_value, expected_value); } #[tokio::test] @@ -147,14 +105,14 @@ async fn test_setup() { let a_port = 10000; let b_port = 10001; - let a_client = ComponentAClientRpc::new(local_ip, a_port); - let b_client = ComponentBClientRpc::new(local_ip, b_port); + let a_client = ComponentClientRpc::new(local_ip, a_port); + let b_client = ComponentClientRpc::new(local_ip, b_port); let component_a = ComponentA::new(Box::new(b_client)); let component_b = ComponentB::new(setup_value, Box::new(a_client)); - let mut component_a_server = ComponentAServerRpc::new(component_a, local_ip, a_port); - let mut component_b_server = ComponentBServerRpc::new(component_b, local_ip, b_port); + let mut component_a_server = ComponentServerRpc::new(component_a, local_ip, a_port); + let mut component_b_server = ComponentServerRpc::new(component_b, local_ip, b_port); task::spawn(async move { component_a_server.start().await; diff --git a/crates/mempool_infra/tests/component_server_client_test.rs b/crates/mempool_infra/tests/component_server_client_test.rs index 9ec0fb9e..6a8364be 100644 --- a/crates/mempool_infra/tests/component_server_client_test.rs +++ b/crates/mempool_infra/tests/component_server_client_test.rs @@ -1,7 +1,7 @@ mod common; use async_trait::async_trait; -use common::{ComponentATrait, ComponentBTrait}; +use common::{AClient, AClientResult, BClient, BClientResult}; use starknet_mempool_infra::component_client::ComponentClient; use starknet_mempool_infra::component_definitions::{ ComponentRequestAndResponseSender, ComponentRequestHandler, @@ -14,81 +14,70 @@ use crate::common::{ComponentA, ComponentB, ValueA, ValueB}; // TODO(Tsabary): send messages from component b to component a. -pub enum ComponentARequest { +pub enum RequestA { AGetValue, } -pub enum ComponentAResponse { +pub enum ResponseA { Value(ValueA), } #[async_trait] -impl ComponentATrait for ComponentClient { - async fn a_get_value(&self) -> ValueA { - let res = self.send(ComponentARequest::AGetValue).await; +impl AClient for ComponentClient { + async fn a_get_value(&self) -> AClientResult { + let res = self.send(RequestA::AGetValue).await; match res { - ComponentAResponse::Value(value) => value, + ResponseA::Value(value) => Ok(value), } } } #[async_trait] -impl ComponentRequestHandler for ComponentA { - async fn handle_request(&mut self, request: ComponentARequest) -> ComponentAResponse { +impl ComponentRequestHandler for ComponentA { + async fn handle_request(&mut self, request: RequestA) -> ResponseA { match request { - ComponentARequest::AGetValue => ComponentAResponse::Value(self.a_get_value().await), + RequestA::AGetValue => ResponseA::Value(self.a_get_value().await), } } } #[derive(Copy, Clone, Debug, PartialEq, Eq)] -pub enum ComponentBRequest { +pub enum RequestB { BGetValue, } #[derive(Copy, Clone, Debug, PartialEq, Eq)] -pub enum ComponentBResponse { +pub enum ResponseB { Value(ValueB), } #[async_trait] -impl ComponentBTrait for ComponentClient { - async fn b_get_value(&self) -> ValueB { - let res = self.send(ComponentBRequest::BGetValue).await; +impl BClient for ComponentClient { + async fn b_get_value(&self) -> BClientResult { + let res = self.send(RequestB::BGetValue).await; match res { - ComponentBResponse::Value(value) => value, + ResponseB::Value(value) => Ok(value), } } } #[async_trait] -impl ComponentRequestHandler for ComponentB { - async fn handle_request(&mut self, request: ComponentBRequest) -> ComponentBResponse { +impl ComponentRequestHandler for ComponentB { + async fn handle_request(&mut self, request: RequestB) -> ResponseB { match request { - ComponentBRequest::BGetValue => ComponentBResponse::Value(self.b_get_value().await), + RequestB::BGetValue => ResponseB::Value(self.b_get_value()), } } } async fn verify_response( - tx_a: Sender>, + tx_a: Sender>, expected_value: ValueA, ) { - let (tx_a_main, mut rx_a_main) = channel::(1); + let a_client = ComponentClient::new(tx_a); - let request_and_res_tx: ComponentRequestAndResponseSender< - ComponentARequest, - ComponentAResponse, - > = ComponentRequestAndResponseSender { request: ComponentARequest::AGetValue, tx: tx_a_main }; - - tx_a.send(request_and_res_tx).await.unwrap(); - - let res = rx_a_main.recv().await.unwrap(); - match res { - ComponentAResponse::Value(value) => { - assert_eq!(value, expected_value); - } - } + let returned_value = a_client.a_get_value().await.expect("Value should be returned"); + assert_eq!(returned_value, expected_value); } #[tokio::test] @@ -96,10 +85,8 @@ async fn test_setup() { let setup_value: ValueB = 30; let expected_value: ValueA = setup_value.into(); - let (tx_a, rx_a) = - channel::>(32); - let (tx_b, rx_b) = - channel::>(32); + let (tx_a, rx_a) = channel::>(32); + let (tx_b, rx_b) = channel::>(32); let a_client = ComponentClient::new(tx_a.clone()); let b_client = ComponentClient::new(tx_b.clone());