Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: adding server for active components #526

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 20 additions & 0 deletions crates/mempool_infra/src/component_server/definitions.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use async_trait::async_trait;
use tokio::sync::mpsc::Receiver;
use tracing::{error, info};

use crate::component_definitions::{ComponentRequestAndResponseSender, ComponentRequestHandler};
use crate::component_runner::ComponentStarter;

#[async_trait]
Expand All @@ -20,3 +22,21 @@ where
info!("ComponentServer::start() completed.");
true
}

pub async fn request_response_loop<Request, Response, Component>(
rx: &mut Receiver<ComponentRequestAndResponseSender<Request, Response>>,
component: &mut Component,
) where
Component: ComponentRequestHandler<Request, Response> + Send + Sync,
Request: Send + Sync,
Response: Send + Sync,
{
while let Some(request_and_res_tx) = rx.recv().await {
let request = request_and_res_tx.request;
let tx = request_and_res_tx.tx;

let res = component.handle_request(request).await;

tx.send(res).await.expect("Response connection should be open.");
}
}
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
use async_trait::async_trait;
use tokio::sync::mpsc::Receiver;
use tracing::error;

use super::definitions::{start_component, ComponentServerStarter};
use super::definitions::{request_response_loop, start_component, ComponentServerStarter};
use crate::component_definitions::{ComponentRequestAndResponseSender, ComponentRequestHandler};
use crate::component_runner::ComponentStarter;

Expand Down Expand Up @@ -137,14 +138,56 @@ where
{
async fn start(&mut self) {
if start_component(&mut self.component).await {
while let Some(request_and_res_tx) = self.rx.recv().await {
let request = request_and_res_tx.request;
let tx = request_and_res_tx.tx;
request_response_loop(&mut self.rx, &mut self.component).await;
}
}
}

pub struct LocalActiveComponentServer<Component, Request, Response>
where
Component: ComponentRequestHandler<Request, Response> + ComponentStarter + Clone + Send + Sync,
Request: Send + Sync,
Response: Send + Sync,
{
component: Component,
rx: Receiver<ComponentRequestAndResponseSender<Request, Response>>,
}

impl<Component, Request, Response> LocalActiveComponentServer<Component, Request, Response>
where
Component: ComponentRequestHandler<Request, Response> + ComponentStarter + Clone + Send + Sync,
Request: Send + Sync,
Response: Send + Sync,
{
pub fn new(
component: Component,
rx: Receiver<ComponentRequestAndResponseSender<Request, Response>>,
) -> Self {
Self { component, rx }
}
}

let res = self.component.handle_request(request).await;
#[async_trait]
impl<Component, Request, Response> ComponentServerStarter
for LocalActiveComponentServer<Component, Request, Response>
where
Component: ComponentRequestHandler<Request, Response> + ComponentStarter + Clone + Send + Sync,
Request: Send + Sync,
Response: Send + Sync,
{
async fn start(&mut self) {
let mut component = self.component.clone();
let component_future = async move { component.start().await };
let request_response_future = request_response_loop(&mut self.rx, &mut self.component);

tx.send(res).await.expect("Response connection should be open.");
tokio::select! {
_res = component_future => {
error!("Component stopped.");
}
}
_res = request_response_future => {
error!("Server stopped.");
}
};
error!("Server ended with unexpected Ok.");
}
}
187 changes: 187 additions & 0 deletions crates/mempool_infra/tests/active_component_server_client_test.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
use std::future::pending;
use std::sync::Arc;

use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use starknet_mempool_infra::component_client::definitions::{ClientError, ClientResult};
use starknet_mempool_infra::component_client::local_component_client::LocalComponentClient;
use starknet_mempool_infra::component_definitions::{
ComponentRequestAndResponseSender, ComponentRequestHandler,
};
use starknet_mempool_infra::component_runner::{ComponentStartError, ComponentStarter};
use starknet_mempool_infra::component_server::definitions::ComponentServerStarter;
use starknet_mempool_infra::component_server::empty_component_server::EmptyServer;
use starknet_mempool_infra::component_server::local_component_server::LocalActiveComponentServer;
use tokio::sync::mpsc::{channel, Sender};
use tokio::sync::{Mutex, Barrier};
use tokio::task;

#[derive(Debug, Clone)]
struct ComponentC {
counter: Arc<Mutex<usize>>,
max_iterations: usize,
barrier: Arc<Barrier>,
}

impl ComponentC {
pub fn new(init_counter_value: usize, max_iterations: usize, barrier: Arc<Barrier>) -> Self {
Self {
counter: Arc::new(Mutex::new(init_counter_value)),
max_iterations,
barrier,
}
}

pub async fn c_get_counter(&self) -> usize {
*self.counter.lock().await
}

pub async fn c_increment_counter(&self) {
*self.counter.lock().await += 1;
}
}

#[async_trait]
impl ComponentStarter for ComponentC {
async fn start(&mut self) -> Result<(), ComponentStartError> {
for _ in 0..self.max_iterations {
self.c_increment_counter().await;
}
let val = self.c_get_counter().await;
assert!(val >= self.max_iterations);
self.barrier.wait().await;

// Mimicking real start function that should not return.
let () = pending().await;
Ok(())
}
}

#[derive(Serialize, Deserialize, Debug)]
pub enum ComponentCRequest {
CIncCounter,
CGetCounter,
}

#[derive(Serialize, Deserialize, Debug)]
pub enum ComponentCResponse {
CIncCounter,
CGetCounter(usize),
}

#[async_trait]
trait ComponentCClientTrait: Send + Sync {
async fn c_inc_counter(&self) -> ClientResult<()>;
async fn c_get_counter(&self) -> ClientResult<usize>;
}

struct ComponentD {
c: Box<dyn ComponentCClientTrait>,
max_iterations: usize,
barrier: Arc<Barrier>,
}

impl ComponentD {
pub fn new(c: Box<dyn ComponentCClientTrait>, max_iterations: usize, barrier: Arc<Barrier>) -> Self {
Self { c, max_iterations, barrier }
}

pub async fn d_increment_counter(&self) {
self.c.c_inc_counter().await.unwrap()
}

pub async fn d_get_counter(&self) -> usize {
self.c.c_get_counter().await.unwrap()
}
}

#[async_trait]
impl ComponentStarter for ComponentD {
async fn start(&mut self) -> Result<(), ComponentStartError> {
for _ in 0..self.max_iterations {
self.d_increment_counter().await;
}
let val = self.d_get_counter().await;
assert!(val >= self.max_iterations);
self.barrier.wait().await;

// Mimicking real start function that should not return.
let () = pending().await;
Ok(())
}
}

#[async_trait]
impl ComponentCClientTrait for LocalComponentClient<ComponentCRequest, ComponentCResponse> {
async fn c_inc_counter(&self) -> ClientResult<()> {
let res = self.send(ComponentCRequest::CIncCounter).await;
match res {
ComponentCResponse::CIncCounter => Ok(()),
_ => Err(ClientError::UnexpectedResponse),
}
}

async fn c_get_counter(&self) -> ClientResult<usize> {
let res = self.send(ComponentCRequest::CGetCounter).await;
match res {
ComponentCResponse::CGetCounter(counter) => Ok(counter),
_ => Err(ClientError::UnexpectedResponse),
}
}
}

#[async_trait]
impl ComponentRequestHandler<ComponentCRequest, ComponentCResponse> for ComponentC {
async fn handle_request(&mut self, request: ComponentCRequest) -> ComponentCResponse {
match request {
ComponentCRequest::CGetCounter => {
ComponentCResponse::CGetCounter(self.c_get_counter().await)
}
ComponentCRequest::CIncCounter => {
self.c_increment_counter().await;
ComponentCResponse::CIncCounter
}
}
}
}

async fn wait_and_verify_response(
tx_c: Sender<ComponentRequestAndResponseSender<ComponentCRequest, ComponentCResponse>>,
expected_counter_value: usize,
barrier: Arc<Barrier>,
) {
let c_client = LocalComponentClient::new(tx_c);

barrier.wait().await;
assert_eq!(c_client.c_get_counter().await.unwrap(), expected_counter_value);
}

#[tokio::test]
async fn test_setup_c_d() {
let init_counter_value: usize = 0;
let max_iterations: usize = 1024;
let expected_counter_value = max_iterations * 2;

let (tx_c, rx_c) =
channel::<ComponentRequestAndResponseSender<ComponentCRequest, ComponentCResponse>>(32);

let c_client = LocalComponentClient::new(tx_c.clone());

let barrier = Arc::new(Barrier::new(3));
let component_c = ComponentC::new(init_counter_value, max_iterations, barrier.clone());
let component_d = ComponentD::new(Box::new(c_client), max_iterations, barrier.clone());

let mut component_c_server = LocalActiveComponentServer::new(component_c, rx_c);
let mut component_d_server = EmptyServer::new(component_d);

task::spawn(async move {
component_c_server.start().await;
});

task::spawn(async move {
component_d_server.start().await;
});

// Wait for the components to finish incrementing of the ComponentC::counter and verify it.
wait_and_verify_response(tx_c.clone(), expected_counter_value, barrier).await;
}
Loading