diff --git a/crates/mempool_infra/src/component_client.rs b/crates/mempool_infra/src/component_client.rs deleted file mode 100644 index db6ccc8ee..000000000 --- a/crates/mempool_infra/src/component_client.rs +++ /dev/null @@ -1,198 +0,0 @@ -use std::marker::PhantomData; -use std::net::IpAddr; - -use bincode::{deserialize, serialize, ErrorKind}; -use hyper::body::to_bytes; -use hyper::header::CONTENT_TYPE; -use hyper::{ - Body, Client, Error as HyperError, Request as HyperRequest, Response as HyperResponse, - StatusCode, Uri, -}; -use serde::{Deserialize, Serialize}; -use thiserror::Error; -use tokio::sync::mpsc::{channel, Sender}; - -use crate::component_definitions::{ - ComponentRequestAndResponseSender, ServerError, APPLICATION_OCTET_STREAM, -}; - -/// The `ComponentClient` struct is a generic client for sending component requests and receiving -/// responses asynchronously. -/// -/// # Type Parameters -/// - `Request`: The type of the request. This type must implement both `Send` and `Sync` traits. -/// - `Response`: The type of the response. This type must implement both `Send` and `Sync` traits. -/// -/// # Fields -/// - `tx`: An asynchronous sender channel for transmitting -/// `ComponentRequestAndResponseSender` messages. -/// -/// # Example -/// ```rust -/// // Example usage of the ComponentClient -/// use tokio::sync::mpsc::Sender; -/// -/// use crate::starknet_mempool_infra::component_client::ComponentClient; -/// use crate::starknet_mempool_infra::component_definitions::ComponentRequestAndResponseSender; -/// -/// // Define your request and response types -/// struct MyRequest { -/// pub content: String, -/// } -/// -/// struct MyResponse { -/// content: String, -/// } -/// -/// #[tokio::main] -/// async fn main() { -/// // Create a channel for sending requests and receiving responses -/// let (tx, _rx) = tokio::sync::mpsc::channel::< -/// ComponentRequestAndResponseSender, -/// >(100); -/// -/// // Instantiate the client. -/// let client = ComponentClient::new(tx); -/// -/// // Instantiate a request. -/// let request = MyRequest { content: "Hello, world!".to_string() }; -/// -/// // Send the request; typically, the client should await for a response. -/// client.send(request); -/// } -/// ``` -/// -/// # Notes -/// - The `ComponentClient` struct is designed to work in an asynchronous environment, utilizing -/// Tokio's async runtime and channels. -pub struct ComponentClient -where - Request: Send + Sync, - Response: Send + Sync, -{ - tx: Sender>, -} - -impl ComponentClient -where - Request: Send + Sync, - Response: Send + Sync, -{ - pub fn new(tx: Sender>) -> Self { - Self { tx } - } - - // TODO(Tsabary, 1/5/2024): Consider implementation for messages without expected responses. - - pub async fn send(&self, request: Request) -> Response { - let (res_tx, mut res_rx) = channel::(1); - let request_and_res_tx = ComponentRequestAndResponseSender { request, tx: res_tx }; - self.tx.send(request_and_res_tx).await.expect("Outbound connection should be open."); - - res_rx.recv().await.expect("Inbound connection should be open.") - } -} - -// Can't derive because derive forces the generics to also be `Clone`, which we prefer not to do -// since it'll require transactions to be cloneable. -impl Clone for ComponentClient -where - Request: Send + Sync, - Response: Send + Sync, -{ - fn clone(&self) -> Self { - Self { tx: self.tx.clone() } - } -} - -pub struct ComponentClientHttp -where - Request: Serialize, - Response: for<'a> Deserialize<'a>, -{ - uri: Uri, - client: Client, - _req: PhantomData, - _res: PhantomData, -} - -impl ComponentClientHttp -where - Request: Serialize, - Response: for<'a> Deserialize<'a>, -{ - pub fn new(ip_address: IpAddr, port: u16) -> Self { - let uri = match ip_address { - IpAddr::V4(ip_address) => format!("http://{}:{}/", ip_address, port).parse().unwrap(), - IpAddr::V6(ip_address) => format!("http://[{}]:{}/", ip_address, port).parse().unwrap(), - }; - // TODO(Tsabary): Add a configuration for the maximum number of idle connections. - // TODO(Tsabary): Add a configuration for "keep-alive" time of idle connections. - let client = - Client::builder().http2_only(true).pool_max_idle_per_host(usize::MAX).build_http(); - Self { uri, client, _req: PhantomData, _res: PhantomData } - } - - pub async fn send(&self, component_request: Request) -> ClientResult { - let http_request = HyperRequest::post(self.uri.clone()) - .header(CONTENT_TYPE, APPLICATION_OCTET_STREAM) - .body(Body::from( - serialize(&component_request).expect("Request serialization should succeed"), - )) - .expect("Request building should succeed"); - - // Todo(uriel): Add configuration for controlling the number of retries. - let http_response = - self.client.request(http_request).await.map_err(ClientError::CommunicationFailure)?; - - match http_response.status() { - StatusCode::OK => get_response_body(http_response).await, - status_code => Err(ClientError::ResponseError( - status_code, - get_response_body(http_response).await?, - )), - } - } -} - -async fn get_response_body(response: HyperResponse) -> Result -where - T: for<'a> Deserialize<'a>, -{ - let body_bytes = - to_bytes(response.into_body()).await.map_err(ClientError::ResponseParsingFailure)?; - deserialize(&body_bytes).map_err(ClientError::ResponseDeserializationFailure) -} - -// Can't derive because derive forces the generics to also be `Clone`, which we prefer not to do -// since it'll require the generic Request and Response types to be cloneable. -impl Clone for ComponentClientHttp -where - Request: Serialize, - Response: for<'a> Deserialize<'a>, -{ - fn clone(&self) -> Self { - Self { - uri: self.uri.clone(), - client: self.client.clone(), - _req: PhantomData, - _res: PhantomData, - } - } -} - -#[derive(Debug, Error)] -pub enum ClientError { - #[error("Communication error: {0}")] - CommunicationFailure(HyperError), - #[error("Could not deserialize server response: {0}")] - ResponseDeserializationFailure(Box), - #[error("Could not parse the response: {0}")] - ResponseParsingFailure(HyperError), - #[error("Got status code: {0}, with server error: {1}")] - ResponseError(StatusCode, ServerError), - #[error("Got an unexpected response type.")] - UnexpectedResponse, -} - -pub type ClientResult = Result; diff --git a/crates/mempool_infra/src/component_client/local_component_client.rs b/crates/mempool_infra/src/component_client/local_component_client.rs new file mode 100644 index 000000000..2330f9fc8 --- /dev/null +++ b/crates/mempool_infra/src/component_client/local_component_client.rs @@ -0,0 +1,92 @@ +use tokio::sync::mpsc::{channel, Sender}; + +use crate::component_definitions::ComponentRequestAndResponseSender; + +/// The `LocalComponentClient` struct is a generic client for sending component requests and +/// receiving responses asynchronously. +/// +/// # Type Parameters +/// - `Request`: The type of the request. This type must implement both `Send` and `Sync` traits. +/// - `Response`: The type of the response. This type must implement both `Send` and `Sync` traits. +/// +/// # Fields +/// - `tx`: An asynchronous sender channel for transmitting +/// `ComponentRequestAndResponseSender` messages. +/// +/// # Example +/// ```rust +/// // Example usage of the LocalComponentClient +/// use tokio::sync::mpsc::Sender; +/// +/// use crate::starknet_mempool_infra::component_client::local_component_client::LocalComponentClient; +/// use crate::starknet_mempool_infra::component_definitions::ComponentRequestAndResponseSender; +/// +/// // Define your request and response types +/// struct MyRequest { +/// pub content: String, +/// } +/// +/// struct MyResponse { +/// content: String, +/// } +/// +/// #[tokio::main] +/// async fn main() { +/// // Create a channel for sending requests and receiving responses +/// let (tx, _rx) = tokio::sync::mpsc::channel::< +/// ComponentRequestAndResponseSender, +/// >(100); +/// +/// // Instantiate the client. +/// let client = LocalComponentClient::new(tx); +/// +/// // Instantiate a request. +/// let request = MyRequest { content: "Hello, world!".to_string() }; +/// +/// // Send the request; typically, the client should await for a response. +/// client.send(request); +/// } +/// ``` +/// +/// # Notes +/// - The `LocalComponentClient` struct is designed to work in an asynchronous environment, +/// utilizing Tokio's async runtime and channels. +pub struct LocalComponentClient +where + Request: Send + Sync, + Response: Send + Sync, +{ + tx: Sender>, +} + +impl LocalComponentClient +where + Request: Send + Sync, + Response: Send + Sync, +{ + pub fn new(tx: Sender>) -> Self { + Self { tx } + } + + // TODO(Tsabary, 1/5/2024): Consider implementation for messages without expected responses. + + pub async fn send(&self, request: Request) -> Response { + let (res_tx, mut res_rx) = channel::(1); + let request_and_res_tx = ComponentRequestAndResponseSender { request, tx: res_tx }; + self.tx.send(request_and_res_tx).await.expect("Outbound connection should be open."); + + res_rx.recv().await.expect("Inbound connection should be open.") + } +} + +// Can't derive because derive forces the generics to also be `Clone`, which we prefer not to do +// since it'll require transactions to be cloneable. +impl Clone for LocalComponentClient +where + Request: Send + Sync, + Response: Send + Sync, +{ + fn clone(&self) -> Self { + Self { tx: self.tx.clone() } + } +} diff --git a/crates/mempool_infra/src/component_client/mod.rs b/crates/mempool_infra/src/component_client/mod.rs new file mode 100644 index 000000000..40b3fca56 --- /dev/null +++ b/crates/mempool_infra/src/component_client/mod.rs @@ -0,0 +1,3 @@ +pub mod definitions; +pub mod local_component_client; +pub mod remote_component_client; diff --git a/crates/mempool_infra/src/component_client/remote_component_client.rs b/crates/mempool_infra/src/component_client/remote_component_client.rs new file mode 100644 index 000000000..c908582c0 --- /dev/null +++ b/crates/mempool_infra/src/component_client/remote_component_client.rs @@ -0,0 +1,87 @@ +use std::marker::PhantomData; +use std::net::IpAddr; + +use bincode::{deserialize, serialize}; +use hyper::body::to_bytes; +use hyper::header::CONTENT_TYPE; +use hyper::{Body, Client, Request as HyperRequest, Response as HyperResponse, StatusCode, Uri}; +use serde::{Deserialize, Serialize}; + +use super::definitions::{ClientError, ClientResult}; +use crate::component_definitions::APPLICATION_OCTET_STREAM; + +pub struct RemoteComponentClient +where + Request: Serialize, + Response: for<'a> Deserialize<'a>, +{ + uri: Uri, + client: Client, + _req: PhantomData, + _res: PhantomData, +} + +impl RemoteComponentClient +where + Request: Serialize, + Response: for<'a> Deserialize<'a>, +{ + pub fn new(ip_address: IpAddr, port: u16) -> Self { + let uri = match ip_address { + IpAddr::V4(ip_address) => format!("http://{}:{}/", ip_address, port).parse().unwrap(), + IpAddr::V6(ip_address) => format!("http://[{}]:{}/", ip_address, port).parse().unwrap(), + }; + // TODO(Tsabary): Add a configuration for the maximum number of idle connections. + // TODO(Tsabary): Add a configuration for "keep-alive" time of idle connections. + let client = + Client::builder().http2_only(true).pool_max_idle_per_host(usize::MAX).build_http(); + Self { uri, client, _req: PhantomData, _res: PhantomData } + } + + pub async fn send(&self, component_request: Request) -> ClientResult { + let http_request = HyperRequest::post(self.uri.clone()) + .header(CONTENT_TYPE, APPLICATION_OCTET_STREAM) + .body(Body::from( + serialize(&component_request).expect("Request serialization should succeed"), + )) + .expect("Request building should succeed"); + + // Todo(uriel): Add configuration for controlling the number of retries. + let http_response = + self.client.request(http_request).await.map_err(ClientError::CommunicationFailure)?; + + match http_response.status() { + StatusCode::OK => get_response_body(http_response).await, + status_code => Err(ClientError::ResponseError( + status_code, + get_response_body(http_response).await?, + )), + } + } +} + +async fn get_response_body(response: HyperResponse) -> Result +where + T: for<'a> Deserialize<'a>, +{ + let body_bytes = + to_bytes(response.into_body()).await.map_err(ClientError::ResponseParsingFailure)?; + deserialize(&body_bytes).map_err(ClientError::ResponseDeserializationFailure) +} + +// Can't derive because derive forces the generics to also be `Clone`, which we prefer not to do +// since it'll require the generic Request and Response types to be cloneable. +impl Clone for RemoteComponentClient +where + Request: Serialize, + Response: for<'a> Deserialize<'a>, +{ + fn clone(&self) -> Self { + Self { + uri: self.uri.clone(), + client: self.client.clone(), + _req: PhantomData, + _res: PhantomData, + } + } +} diff --git a/crates/mempool_infra/tests/common/mod.rs b/crates/mempool_infra/tests/common/mod.rs index 18a6a945d..57b5fabb4 100644 --- a/crates/mempool_infra/tests/common/mod.rs +++ b/crates/mempool_infra/tests/common/mod.rs @@ -1,6 +1,6 @@ use async_trait::async_trait; use serde::{Deserialize, Serialize}; -use starknet_mempool_infra::component_client::ClientResult; +use starknet_mempool_infra::component_client::definitions::ClientResult; use starknet_mempool_infra::component_runner::ComponentStarter; pub(crate) type ValueA = u32; diff --git a/crates/mempool_infra/tests/component_server_client_http_test.rs b/crates/mempool_infra/tests/component_server_client_http_test.rs index c03564a16..345104dc7 100644 --- a/crates/mempool_infra/tests/component_server_client_http_test.rs +++ b/crates/mempool_infra/tests/component_server_client_http_test.rs @@ -14,15 +14,16 @@ use hyper::service::{make_service_fn, service_fn}; use hyper::{Body, Client, Request, Response, Server, StatusCode, Uri}; use rstest::rstest; use serde::Serialize; -use starknet_mempool_infra::component_client::{ClientError, ClientResult, ComponentClientHttp}; +use starknet_mempool_infra::component_client::definitions::{ClientError, ClientResult}; +use starknet_mempool_infra::component_client::remote_component_client::RemoteComponentClient; use starknet_mempool_infra::component_definitions::{ ComponentRequestHandler, ServerError, APPLICATION_OCTET_STREAM, }; use starknet_mempool_infra::component_server::{ComponentServerHttp, ComponentServerStarter}; use tokio::task; -type ComponentAClient = ComponentClientHttp; -type ComponentBClient = ComponentClientHttp; +type ComponentAClient = RemoteComponentClient; +type ComponentBClient = RemoteComponentClient; use crate::common::{ComponentA, ComponentB, ValueA, ValueB}; @@ -42,7 +43,7 @@ const DESERIALIZE_REQ_ERROR_MESSAGE: &str = "Could not deserialize client reques const DESERIALIZE_RES_ERROR_MESSAGE: &str = "Could not deserialize server response"; #[async_trait] -impl ComponentAClientTrait for ComponentClientHttp { +impl ComponentAClientTrait for RemoteComponentClient { async fn a_get_value(&self) -> ResultA { match self.send(ComponentARequest::AGetValue).await? { ComponentAResponse::AGetValue(value) => Ok(value), @@ -60,7 +61,7 @@ impl ComponentRequestHandler for Componen } #[async_trait] -impl ComponentBClientTrait for ComponentClientHttp { +impl ComponentBClientTrait for RemoteComponentClient { async fn b_get_value(&self) -> ResultB { match self.send(ComponentBRequest::BGetValue).await? { ComponentBResponse::BGetValue(value) => Ok(value), diff --git a/crates/mempool_infra/tests/component_server_client_test.rs b/crates/mempool_infra/tests/component_server_client_test.rs index 2be900f70..bc6c9e8c8 100644 --- a/crates/mempool_infra/tests/component_server_client_test.rs +++ b/crates/mempool_infra/tests/component_server_client_test.rs @@ -5,7 +5,8 @@ use common::{ ComponentAClientTrait, ComponentARequest, ComponentAResponse, ComponentBClientTrait, ComponentBRequest, ComponentBResponse, ResultA, ResultB, }; -use starknet_mempool_infra::component_client::{ClientError, ClientResult, ComponentClient}; +use starknet_mempool_infra::component_client::definitions::{ClientError, ClientResult}; +use starknet_mempool_infra::component_client::local_component_client::LocalComponentClient; use starknet_mempool_infra::component_definitions::{ ComponentRequestAndResponseSender, ComponentRequestHandler, }; @@ -15,13 +16,13 @@ use tokio::task; use crate::common::{ComponentA, ComponentB, ValueA, ValueB}; -type ComponentAClient = ComponentClient; -type ComponentBClient = ComponentClient; +type ComponentAClient = LocalComponentClient; +type ComponentBClient = LocalComponentClient; // TODO(Tsabary): send messages from component b to component a. #[async_trait] -impl ComponentAClientTrait for ComponentClient { +impl ComponentAClientTrait for LocalComponentClient { async fn a_get_value(&self) -> ResultA { let res = self.send(ComponentARequest::AGetValue).await; match res { @@ -40,7 +41,7 @@ impl ComponentRequestHandler for Componen } #[async_trait] -impl ComponentBClientTrait for ComponentClient { +impl ComponentBClientTrait for LocalComponentClient { async fn b_get_value(&self) -> ResultB { let res = self.send(ComponentBRequest::BGetValue).await; match res { diff --git a/crates/mempool_types/src/communication.rs b/crates/mempool_types/src/communication.rs index 69e56b611..d19405a8b 100644 --- a/crates/mempool_types/src/communication.rs +++ b/crates/mempool_types/src/communication.rs @@ -3,14 +3,15 @@ use std::sync::Arc; use async_trait::async_trait; use mockall::predicate::*; use mockall::*; -use starknet_mempool_infra::component_client::{ClientError, ComponentClient}; +use starknet_mempool_infra::component_client::definitions::ClientError; +use starknet_mempool_infra::component_client::local_component_client::LocalComponentClient; use starknet_mempool_infra::component_definitions::ComponentRequestAndResponseSender; use thiserror::Error; use crate::errors::MempoolError; use crate::mempool_types::{MempoolInput, ThinTransaction}; -pub type MempoolClientImpl = ComponentClient; +pub type MempoolClientImpl = LocalComponentClient; pub type MempoolResult = Result; pub type MempoolClientResult = Result; pub type MempoolRequestAndResponseSender =