Skip to content

Commit

Permalink
chore: clean component server client tests
Browse files Browse the repository at this point in the history
  • Loading branch information
uriel-starkware committed Jun 10, 2024
1 parent dbc0b15 commit 7541c67
Show file tree
Hide file tree
Showing 10 changed files with 102 additions and 142 deletions.
4 changes: 2 additions & 2 deletions crates/gateway/src/gateway_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use starknet_api::external_transaction::ExternalTransaction;
use starknet_api::transaction::TransactionHash;
use starknet_mempool::mempool::{create_mempool_server, Mempool};
use starknet_mempool_types::mempool_types::{
MempoolClient, MempoolClientImpl, MempoolRequestAndResponseSender,
MempoolClient, MempoolClientImpl, MempoolRequestWithResponder,
};
use tokio::sync::mpsc::channel;
use tokio::task;
Expand Down Expand Up @@ -52,7 +52,7 @@ async fn test_add_tx() {
// TODO(Tsabary): wrap creation of channels in dedicated functions, take channel capacity from
// config.
let (tx_mempool, rx_mempool) =
channel::<MempoolRequestAndResponseSender>(MEMPOOL_INVOCATIONS_QUEUE_SIZE);
channel::<MempoolRequestWithResponder>(MEMPOOL_INVOCATIONS_QUEUE_SIZE);
let mut mempool_server = create_mempool_server(mempool, rx_mempool);
task::spawn(async move {
mempool_server.start().await;
Expand Down
4 changes: 2 additions & 2 deletions crates/mempool/src/mempool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use starknet_mempool_infra::component_definitions::ComponentRequestHandler;
use starknet_mempool_infra::component_server::ComponentServer;
use starknet_mempool_types::errors::MempoolError;
use starknet_mempool_types::mempool_types::{
Account, AccountState, MempoolInput, MempoolRequest, MempoolRequestAndResponseSender,
Account, AccountState, MempoolInput, MempoolRequest, MempoolRequestWithResponder,
MempoolResponse, MempoolResult, ThinTransaction,
};
use tokio::sync::mpsc::Receiver;
Expand Down Expand Up @@ -138,7 +138,7 @@ type MempoolCommunicationServer =

pub fn create_mempool_server(
mempool: Mempool,
rx_mempool: Receiver<MempoolRequestAndResponseSender>,
rx_mempool: Receiver<MempoolRequestWithResponder>,
) -> MempoolCommunicationServer {
let mempool_communication_wrapper = MempoolCommunicationWrapper::new(mempool);
ComponentServer::new(mempool_communication_wrapper, rx_mempool)
Expand Down
6 changes: 6 additions & 0 deletions crates/mempool_infra/src/component_client.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use thiserror::Error;
use tokio::sync::mpsc::{channel, Sender};
use tonic::transport::Error as TonicError;
use tonic::Status as TonicStatus;

use crate::component_definitions::ComponentRequestAndResponseSender;

Expand Down Expand Up @@ -36,4 +38,8 @@ where
pub enum ClientError {
#[error("Got an unexpected response type.")]
UnexpectedResponse,
#[error("Failed to connect to the server")]
ConnectionFailure(TonicError),
#[error("Failed to get a response from the server")]
ResponseFailure(TonicStatus),
}
2 changes: 1 addition & 1 deletion crates/mempool_infra/src/component_server.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use tokio::sync::mpsc::Receiver;

use crate::component_definitions::{ComponentRequestAndResponseSender, ComponentRequestHandler};
use crate::component_definitions::{ComponentRequestHandler, ComponentRequestAndResponseSender};

pub struct ComponentServer<Component, Request, Response>
where
Expand Down
2 changes: 2 additions & 0 deletions crates/mempool_infra/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
pub mod component_client;
pub mod component_client_rpc;
pub mod component_definitions;
pub mod component_runner;
pub mod component_server;
pub mod component_server_rpc;
40 changes: 17 additions & 23 deletions crates/mempool_infra/tests/common/mod.rs
Original file line number Diff line number Diff line change
@@ -1,52 +1,46 @@
use async_trait::async_trait;
use starknet_mempool_infra::component_client::ClientError;

pub(crate) type ValueA = u32;
pub(crate) type ValueB = u8;

// TODO(Tsabary): add more messages / functions to the components.

#[async_trait]
pub(crate) trait ComponentATrait: Send + Sync {
async fn a_get_value(&self) -> ValueA;
pub(crate) trait AClientTrait: Send + Sync {
async fn a_get_value(&self) -> Result<ValueA, ClientError>;
}

#[async_trait]
pub(crate) trait ComponentBTrait: Send + Sync {
async fn b_get_value(&self) -> ValueB;
pub(crate) trait BClientTrait: Send + Sync {
async fn b_get_value(&self) -> Result<ValueB, ClientError>;
}

pub(crate) struct ComponentA {
b: Box<dyn ComponentBTrait>,
}

#[async_trait]
impl ComponentATrait for ComponentA {
async fn a_get_value(&self) -> ValueA {
let b_value = self.b.b_get_value().await;
b_value.into()
}
b: Box<dyn BClientTrait>,
}

impl ComponentA {
pub fn new(b: Box<dyn ComponentBTrait>) -> Self {
pub fn new(b: Box<dyn BClientTrait>) -> Self {
Self { b }
}

pub async fn a_get_value(&self) -> ValueA {
let b_value = self.b.b_get_value().await.unwrap();
b_value.into()
}
}

pub(crate) struct ComponentB {
value: ValueB,
_a: Box<dyn ComponentATrait>,
}

#[async_trait]
impl ComponentBTrait for ComponentB {
async fn b_get_value(&self) -> ValueB {
self.value
}
_a: Box<dyn AClientTrait>,
}

impl ComponentB {
pub fn new(value: ValueB, a: Box<dyn ComponentATrait>) -> Self {
pub fn new(value: ValueB, a: Box<dyn AClientTrait>) -> Self {
Self { value, _a: a }
}
pub async fn b_get_value(&self) -> ValueB {
self.value
}
}
115 changes: 43 additions & 72 deletions crates/mempool_infra/tests/component_server_client_rpc_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,78 +6,56 @@ mod component_b_service {
}
mod common;

use std::net::{IpAddr, SocketAddr};
use std::net::IpAddr;

use async_trait::async_trait;
use common::{ComponentATrait, ComponentBTrait};
use common::{AClientTrait, BClientTrait};
use component_a_service::remote_a_client::RemoteAClient;
use component_a_service::remote_a_server::{RemoteA, RemoteAServer};
use component_a_service::{AGetValueMessage, AGetValueReturnMessage};
use component_b_service::remote_b_client::RemoteBClient;
use component_b_service::remote_b_server::{RemoteB, RemoteBServer};
use component_b_service::{BGetValueMessage, BGetValueReturnMessage};
use starknet_mempool_infra::component_client::ClientError;
use starknet_mempool_infra::component_client_rpc::ComponentClientRpc;
use starknet_mempool_infra::component_server_rpc::ComponentServerRpc;
use tokio::task;
use tonic::transport::Server;
use tonic::{Request, Response, Status};
use tonic::{Response, Status};

use crate::common::{ComponentA, ComponentB, ValueA, ValueB};

fn construct_url(ip_address: IpAddr, port: u16) -> String {
match ip_address {
IpAddr::V4(ip_address) => format!("http://{}:{}/", ip_address, port),
IpAddr::V6(ip_address) => format!("http://[{}]:{}/", ip_address, port),
}
}

struct ComponentAClientRpc {
dst: String,
}

impl ComponentAClientRpc {
fn new(ip_address: IpAddr, port: u16) -> Self {
Self { dst: construct_url(ip_address, port) }
}
}

#[async_trait]
impl ComponentATrait for ComponentAClientRpc {
async fn a_get_value(&self) -> ValueA {
let Ok(mut client) = RemoteAClient::connect(self.dst.clone()).await else {
panic!("Could not connect to server");
impl AClientTrait for ComponentClientRpc<ComponentA> {
async fn a_get_value(&self) -> Result<ValueA, ClientError> {
let mut client = match RemoteAClient::connect(self.dst.clone()).await {
Ok(client) => client,
Err(e) => return Err(ClientError::ConnectionFailure(e)),
};

let Ok(response) = client.remote_a_get_value(Request::new(AGetValueMessage {})).await
else {
panic!("Could not get response from server");
let response = match client.remote_a_get_value(AGetValueMessage {}).await {
Ok(response) => response,
Err(e) => return Err(ClientError::ResponseFailure(e)),
};

response.get_ref().value
}
}

struct ComponentBClientRpc {
dst: String,
}

impl ComponentBClientRpc {
fn new(ip_address: IpAddr, port: u16) -> Self {
Self { dst: construct_url(ip_address, port) }
Ok(response.into_inner().value)
}
}

#[async_trait]
impl ComponentBTrait for ComponentBClientRpc {
async fn b_get_value(&self) -> ValueB {
let Ok(mut client) = RemoteBClient::connect(self.dst.clone()).await else {
panic!("Could not connect to server");
impl BClientTrait for ComponentClientRpc<ComponentB> {
async fn b_get_value(&self) -> Result<ValueB, ClientError> {
let mut client = match RemoteBClient::connect(self.dst.clone()).await {
Ok(client) => client,
Err(e) => return Err(ClientError::ConnectionFailure(e)),
};

let Ok(response) = client.remote_b_get_value(Request::new(BGetValueMessage {})).await
else {
panic!("Could not get response from server");
let response = match client.remote_b_get_value(BGetValueMessage {}).await {
Ok(response) => response,
Err(e) => return Err(ClientError::ResponseFailure(e)),
};

response.get_ref().value.try_into().unwrap()
Ok(response.into_inner().value.try_into().unwrap())
}
}

Expand All @@ -91,18 +69,15 @@ impl RemoteA for ComponentA {
}
}

struct ComponentAServerRpc {
a: Option<ComponentA>,
address: SocketAddr,
#[async_trait]
pub trait ServerStart {
async fn start(&mut self);
}

impl ComponentAServerRpc {
fn new(a: ComponentA, ip_address: IpAddr, port: u16) -> Self {
Self { a: Some(a), address: SocketAddr::new(ip_address, port) }
}

#[async_trait]
impl ServerStart for ComponentServerRpc<ComponentA> {
async fn start(&mut self) {
let svc = RemoteAServer::new(self.a.take().unwrap());
let svc = RemoteAServer::new(self.component.take().unwrap());
Server::builder().add_service(svc).serve(self.address).await.unwrap();
}
}
Expand All @@ -117,25 +92,19 @@ impl RemoteB for ComponentB {
}
}

struct ComponentBServerRpc {
b: Option<ComponentB>,
address: SocketAddr,
}

impl ComponentBServerRpc {
fn new(b: ComponentB, ip_address: IpAddr, port: u16) -> Self {
Self { b: Some(b), address: SocketAddr::new(ip_address, port) }
}

#[async_trait]
impl ServerStart for ComponentServerRpc<ComponentB> {
async fn start(&mut self) {
let svc = RemoteBServer::new(self.b.take().unwrap());
let svc = RemoteBServer::new(self.component.take().unwrap());
Server::builder().add_service(svc).serve(self.address).await.unwrap();
}
}

async fn verify_response(ip_address: IpAddr, port: u16, expected_value: ValueA) {
let a_client = ComponentAClientRpc::new(ip_address, port);
assert_eq!(a_client.a_get_value().await, expected_value);
let a_client = ComponentClientRpc::<ComponentA>::new(ip_address, port);

let returned_value = a_client.a_get_value().await.expect("Value should be returned");
assert_eq!(returned_value, expected_value);
}

#[tokio::test]
Expand All @@ -147,14 +116,16 @@ async fn test_setup() {
let a_port = 10000;
let b_port = 10001;

let a_client = ComponentAClientRpc::new(local_ip, a_port);
let b_client = ComponentBClientRpc::new(local_ip, b_port);
let a_client = ComponentClientRpc::<ComponentA>::new(local_ip, a_port);
let b_client = ComponentClientRpc::<ComponentB>::new(local_ip, b_port);

let component_a = ComponentA::new(Box::new(b_client));
let component_b = ComponentB::new(setup_value, Box::new(a_client));

let mut component_a_server = ComponentAServerRpc::new(component_a, local_ip, a_port);
let mut component_b_server = ComponentBServerRpc::new(component_b, local_ip, b_port);
let mut component_a_server =
ComponentServerRpc::<ComponentA>::new(component_a, local_ip, a_port);
let mut component_b_server =
ComponentServerRpc::<ComponentB>::new(component_b, local_ip, b_port);

task::spawn(async move {
component_a_server.start().await;
Expand Down
Loading

0 comments on commit 7541c67

Please sign in to comment.