diff --git a/crates/mempool_infra/src/component_server/definitions.rs b/crates/mempool_infra/src/component_server/definitions.rs index 7aa54bb3..3f266566 100644 --- a/crates/mempool_infra/src/component_server/definitions.rs +++ b/crates/mempool_infra/src/component_server/definitions.rs @@ -1,6 +1,8 @@ use async_trait::async_trait; +use tokio::sync::mpsc::Receiver; use tracing::{error, info}; +use crate::component_definitions::{ComponentRequestAndResponseSender, ComponentRequestHandler}; use crate::component_runner::ComponentStarter; #[async_trait] @@ -20,3 +22,21 @@ where info!("ComponentServer::start() completed."); true } + +pub async fn request_response_loop( + rx: &mut Receiver>, + component: &mut Component, +) where + Component: ComponentRequestHandler + Send + Sync, + Request: Send + Sync, + Response: Send + Sync, +{ + while let Some(request_and_res_tx) = rx.recv().await { + let request = request_and_res_tx.request; + let tx = request_and_res_tx.tx; + + let res = component.handle_request(request).await; + + tx.send(res).await.expect("Response connection should be open."); + } +} diff --git a/crates/mempool_infra/src/component_server/local_component_server.rs b/crates/mempool_infra/src/component_server/local_component_server.rs index 19c0906a..746cbd44 100644 --- a/crates/mempool_infra/src/component_server/local_component_server.rs +++ b/crates/mempool_infra/src/component_server/local_component_server.rs @@ -1,7 +1,8 @@ use async_trait::async_trait; use tokio::sync::mpsc::Receiver; +use tracing::error; -use super::definitions::{start_component, ComponentServerStarter}; +use super::definitions::{request_response_loop, start_component, ComponentServerStarter}; use crate::component_definitions::{ComponentRequestAndResponseSender, ComponentRequestHandler}; use crate::component_runner::ComponentStarter; @@ -137,14 +138,56 @@ where { async fn start(&mut self) { if start_component(&mut self.component).await { - while let Some(request_and_res_tx) = self.rx.recv().await { - let request = request_and_res_tx.request; - let tx = request_and_res_tx.tx; + request_response_loop(&mut self.rx, &mut self.component).await; + } + } +} + +pub struct LocalActiveComponentServer +where + Component: ComponentRequestHandler + ComponentStarter + Clone + Send + Sync, + Request: Send + Sync, + Response: Send + Sync, +{ + component: Component, + rx: Receiver>, +} + +impl LocalActiveComponentServer +where + Component: ComponentRequestHandler + ComponentStarter + Clone + Send + Sync, + Request: Send + Sync, + Response: Send + Sync, +{ + pub fn new( + component: Component, + rx: Receiver>, + ) -> Self { + Self { component, rx } + } +} - let res = self.component.handle_request(request).await; +#[async_trait] +impl ComponentServerStarter + for LocalActiveComponentServer +where + Component: ComponentRequestHandler + ComponentStarter + Clone + Send + Sync, + Request: Send + Sync, + Response: Send + Sync, +{ + async fn start(&mut self) { + let mut component = self.component.clone(); + let component_future = async move { component.start().await }; + let request_response_future = request_response_loop(&mut self.rx, &mut self.component); - tx.send(res).await.expect("Response connection should be open."); + tokio::select! { + _res = component_future => { + error!("Component stopped."); } - } + _res = request_response_future => { + error!("Server stopped."); + } + }; + error!("Server ended with unexpected Ok."); } } diff --git a/crates/mempool_infra/tests/active_component_server_client_test.rs b/crates/mempool_infra/tests/active_component_server_client_test.rs new file mode 100644 index 00000000..01f8d4a9 --- /dev/null +++ b/crates/mempool_infra/tests/active_component_server_client_test.rs @@ -0,0 +1,187 @@ +use std::future::pending; +use std::sync::Arc; + +use async_trait::async_trait; +use serde::{Deserialize, Serialize}; +use starknet_mempool_infra::component_client::definitions::{ClientError, ClientResult}; +use starknet_mempool_infra::component_client::local_component_client::LocalComponentClient; +use starknet_mempool_infra::component_definitions::{ + ComponentRequestAndResponseSender, ComponentRequestHandler, +}; +use starknet_mempool_infra::component_runner::{ComponentStartError, ComponentStarter}; +use starknet_mempool_infra::component_server::definitions::ComponentServerStarter; +use starknet_mempool_infra::component_server::empty_component_server::EmptyServer; +use starknet_mempool_infra::component_server::local_component_server::LocalActiveComponentServer; +use tokio::sync::mpsc::{channel, Sender}; +use tokio::sync::{Mutex, Barrier}; +use tokio::task; + +#[derive(Debug, Clone)] +struct ComponentC { + counter: Arc>, + max_iterations: usize, + barrier: Arc, +} + +impl ComponentC { + pub fn new(init_counter_value: usize, max_iterations: usize, barrier: Arc) -> Self { + Self { + counter: Arc::new(Mutex::new(init_counter_value)), + max_iterations, + barrier, + } + } + + pub async fn c_get_counter(&self) -> usize { + *self.counter.lock().await + } + + pub async fn c_increment_counter(&self) { + *self.counter.lock().await += 1; + } +} + +#[async_trait] +impl ComponentStarter for ComponentC { + async fn start(&mut self) -> Result<(), ComponentStartError> { + for _ in 0..self.max_iterations { + self.c_increment_counter().await; + } + let val = self.c_get_counter().await; + assert!(val >= self.max_iterations); + self.barrier.wait().await; + + // Mimicking real start function that should not return. + let () = pending().await; + Ok(()) + } +} + +#[derive(Serialize, Deserialize, Debug)] +pub enum ComponentCRequest { + CIncCounter, + CGetCounter, +} + +#[derive(Serialize, Deserialize, Debug)] +pub enum ComponentCResponse { + CIncCounter, + CGetCounter(usize), +} + +#[async_trait] +trait ComponentCClientTrait: Send + Sync { + async fn c_inc_counter(&self) -> ClientResult<()>; + async fn c_get_counter(&self) -> ClientResult; +} + +struct ComponentD { + c: Box, + max_iterations: usize, + barrier: Arc, +} + +impl ComponentD { + pub fn new(c: Box, max_iterations: usize, barrier: Arc) -> Self { + Self { c, max_iterations, barrier } + } + + pub async fn d_increment_counter(&self) { + self.c.c_inc_counter().await.unwrap() + } + + pub async fn d_get_counter(&self) -> usize { + self.c.c_get_counter().await.unwrap() + } +} + +#[async_trait] +impl ComponentStarter for ComponentD { + async fn start(&mut self) -> Result<(), ComponentStartError> { + for _ in 0..self.max_iterations { + self.d_increment_counter().await; + } + let val = self.d_get_counter().await; + assert!(val >= self.max_iterations); + self.barrier.wait().await; + + // Mimicking real start function that should not return. + let () = pending().await; + Ok(()) + } +} + +#[async_trait] +impl ComponentCClientTrait for LocalComponentClient { + async fn c_inc_counter(&self) -> ClientResult<()> { + let res = self.send(ComponentCRequest::CIncCounter).await; + match res { + ComponentCResponse::CIncCounter => Ok(()), + _ => Err(ClientError::UnexpectedResponse), + } + } + + async fn c_get_counter(&self) -> ClientResult { + let res = self.send(ComponentCRequest::CGetCounter).await; + match res { + ComponentCResponse::CGetCounter(counter) => Ok(counter), + _ => Err(ClientError::UnexpectedResponse), + } + } +} + +#[async_trait] +impl ComponentRequestHandler for ComponentC { + async fn handle_request(&mut self, request: ComponentCRequest) -> ComponentCResponse { + match request { + ComponentCRequest::CGetCounter => { + ComponentCResponse::CGetCounter(self.c_get_counter().await) + } + ComponentCRequest::CIncCounter => { + self.c_increment_counter().await; + ComponentCResponse::CIncCounter + } + } + } +} + +async fn wait_and_verify_response( + tx_c: Sender>, + expected_counter_value: usize, + barrier: Arc, +) { + let c_client = LocalComponentClient::new(tx_c); + + barrier.wait().await; + assert_eq!(c_client.c_get_counter().await.unwrap(), expected_counter_value); +} + +#[tokio::test] +async fn test_setup_c_d() { + let init_counter_value: usize = 0; + let max_iterations: usize = 1024; + let expected_counter_value = max_iterations * 2; + + let (tx_c, rx_c) = + channel::>(32); + + let c_client = LocalComponentClient::new(tx_c.clone()); + + let barrier = Arc::new(Barrier::new(3)); + let component_c = ComponentC::new(init_counter_value, max_iterations, barrier.clone()); + let component_d = ComponentD::new(Box::new(c_client), max_iterations, barrier.clone()); + + let mut component_c_server = LocalActiveComponentServer::new(component_c, rx_c); + let mut component_d_server = EmptyServer::new(component_d); + + task::spawn(async move { + component_c_server.start().await; + }); + + task::spawn(async move { + component_d_server.start().await; + }); + + // Wait for the components to finish incrementing of the ComponentC::counter and verify it. + wait_and_verify_response(tx_c.clone(), expected_counter_value, barrier).await; +}