Skip to content

Commit

Permalink
chore: clean component server client tests
Browse files Browse the repository at this point in the history
  • Loading branch information
uriel-starkware committed Jun 16, 2024
1 parent 8090c8b commit b513b8c
Show file tree
Hide file tree
Showing 7 changed files with 140 additions and 141 deletions.
6 changes: 6 additions & 0 deletions crates/mempool_infra/src/component_client.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use thiserror::Error;
use tokio::sync::mpsc::{channel, Sender};
use tonic::transport::Error as TonicError;
use tonic::Status as TonicStatus;

use crate::component_definitions::ComponentRequestAndResponseSender;

Expand Down Expand Up @@ -36,4 +38,8 @@ where
pub enum ClientError {
#[error("Got an unexpected response type.")]
UnexpectedResponse,
#[error("Failed to connect to the server")]
ConnectionFailure(TonicError),
#[error("Failed to get a response from the server")]
ResponseFailure(TonicStatus),
}
19 changes: 19 additions & 0 deletions crates/mempool_infra/src/component_client_rpc.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
use std::net::IpAddr;

fn construct_url(ip_address: IpAddr, port: u16) -> String {
match ip_address {
IpAddr::V4(ip_address) => format!("http://{}:{}/", ip_address, port),
IpAddr::V6(ip_address) => format!("http://[{}]:{}/", ip_address, port),
}
}

pub struct ComponentClientRpc<Component> {
pub dst: String,
_component: std::marker::PhantomData<Component>,
}

impl<Component> ComponentClientRpc<Component> {
pub fn new(ip_address: IpAddr, port: u16) -> Self {
Self { dst: construct_url(ip_address, port), _component: Default::default() }
}
}
23 changes: 23 additions & 0 deletions crates/mempool_infra/src/component_server_rpc.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
use std::net::{IpAddr, SocketAddr};

use async_trait::async_trait;

#[async_trait]
pub trait ServerStart {
async fn start_server(self, address: SocketAddr);
}

pub struct ComponentServerRpc<Component> {
component: Option<Component>,
address: SocketAddr,
}

impl<Component: ServerStart> ComponentServerRpc<Component> {
pub fn new(component: Component, ip_address: IpAddr, port: u16) -> Self {
Self { component: Some(component), address: SocketAddr::new(ip_address, port) }
}

pub async fn start(&mut self) {
self.component.take().unwrap().start_server(self.address).await;
}
}
2 changes: 2 additions & 0 deletions crates/mempool_infra/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
pub mod component_client;
pub mod component_client_rpc;
pub mod component_definitions;
pub mod component_runner;
pub mod component_server;
pub mod component_server_rpc;
42 changes: 19 additions & 23 deletions crates/mempool_infra/tests/common/mod.rs
Original file line number Diff line number Diff line change
@@ -1,52 +1,48 @@
use async_trait::async_trait;
use starknet_mempool_infra::component_client::ClientError;

pub(crate) type ValueA = u32;
pub(crate) type ValueB = u8;

// TODO(Tsabary): add more messages / functions to the components.

pub type ClientResult<T> = Result<T, ClientError>;

#[async_trait]
pub(crate) trait ComponentATrait: Send + Sync {
async fn a_get_value(&self) -> ValueA;
pub(crate) trait AClientTrait: Send + Sync {
async fn a_get_value(&self) -> ClientResult<ValueA>;
}

#[async_trait]
pub(crate) trait ComponentBTrait: Send + Sync {
async fn b_get_value(&self) -> ValueB;
pub(crate) trait BClientTrait: Send + Sync {
async fn b_get_value(&self) -> ClientResult<ValueB>;
}

pub(crate) struct ComponentA {
b: Box<dyn ComponentBTrait>,
}

#[async_trait]
impl ComponentATrait for ComponentA {
async fn a_get_value(&self) -> ValueA {
let b_value = self.b.b_get_value().await;
b_value.into()
}
b: Box<dyn BClientTrait>,
}

impl ComponentA {
pub fn new(b: Box<dyn ComponentBTrait>) -> Self {
pub fn new(b: Box<dyn BClientTrait>) -> Self {
Self { b }
}

pub async fn a_get_value(&self) -> ValueA {
let b_value = self.b.b_get_value().await.unwrap();
b_value.into()
}
}

pub(crate) struct ComponentB {
value: ValueB,
_a: Box<dyn ComponentATrait>,
}

#[async_trait]
impl ComponentBTrait for ComponentB {
async fn b_get_value(&self) -> ValueB {
self.value
}
_a: Box<dyn AClientTrait>,
}

impl ComponentB {
pub fn new(value: ValueB, a: Box<dyn ComponentATrait>) -> Self {
pub fn new(value: ValueB, a: Box<dyn AClientTrait>) -> Self {
Self { value, _a: a }
}
pub async fn b_get_value(&self) -> ValueB {
self.value
}
}
124 changes: 45 additions & 79 deletions crates/mempool_infra/tests/component_server_client_rpc_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,75 +9,53 @@ mod common;
use std::net::{IpAddr, SocketAddr};

use async_trait::async_trait;
use common::{ComponentATrait, ComponentBTrait};
use common::{AClientTrait, BClientTrait};
use component_a_service::remote_a_client::RemoteAClient;
use component_a_service::remote_a_server::{RemoteA, RemoteAServer};
use component_a_service::{AGetValueMessage, AGetValueReturnMessage};
use component_b_service::remote_b_client::RemoteBClient;
use component_b_service::remote_b_server::{RemoteB, RemoteBServer};
use component_b_service::{BGetValueMessage, BGetValueReturnMessage};
use starknet_mempool_infra::component_client::ClientError;
use starknet_mempool_infra::component_client_rpc::ComponentClientRpc;
use starknet_mempool_infra::component_server_rpc::{ComponentServerRpc, ServerStart};
use tokio::task;
use tonic::transport::Server;
use tonic::{Request, Response, Status};
use tonic::{Response, Status};

use crate::common::{ComponentA, ComponentB, ValueA, ValueB};

fn construct_url(ip_address: IpAddr, port: u16) -> String {
match ip_address {
IpAddr::V4(ip_address) => format!("http://{}:{}/", ip_address, port),
IpAddr::V6(ip_address) => format!("http://[{}]:{}/", ip_address, port),
}
}

struct ComponentAClientRpc {
dst: String,
}

impl ComponentAClientRpc {
fn new(ip_address: IpAddr, port: u16) -> Self {
Self { dst: construct_url(ip_address, port) }
}
}
use crate::common::{ClientResult, ComponentA, ComponentB, ValueA, ValueB};

#[async_trait]
impl ComponentATrait for ComponentAClientRpc {
async fn a_get_value(&self) -> ValueA {
let Ok(mut client) = RemoteAClient::connect(self.dst.clone()).await else {
panic!("Could not connect to server");
impl AClientTrait for ComponentClientRpc<ComponentA> {
async fn a_get_value(&self) -> ClientResult<ValueA> {
let mut client = match RemoteAClient::connect(self.dst.clone()).await {
Ok(client) => client,
Err(e) => return Err(ClientError::ConnectionFailure(e)),
};

let Ok(response) = client.remote_a_get_value(Request::new(AGetValueMessage {})).await
else {
panic!("Could not get response from server");
let response = match client.remote_a_get_value(AGetValueMessage {}).await {
Ok(response) => response,
Err(e) => return Err(ClientError::ResponseFailure(e)),
};

response.get_ref().value
}
}

struct ComponentBClientRpc {
dst: String,
}

impl ComponentBClientRpc {
fn new(ip_address: IpAddr, port: u16) -> Self {
Self { dst: construct_url(ip_address, port) }
Ok(response.into_inner().value)
}
}

#[async_trait]
impl ComponentBTrait for ComponentBClientRpc {
async fn b_get_value(&self) -> ValueB {
let Ok(mut client) = RemoteBClient::connect(self.dst.clone()).await else {
panic!("Could not connect to server");
impl BClientTrait for ComponentClientRpc<ComponentB> {
async fn b_get_value(&self) -> ClientResult<ValueB> {
let mut client = match RemoteBClient::connect(self.dst.clone()).await {
Ok(client) => client,
Err(e) => return Err(ClientError::ConnectionFailure(e)),
};

let Ok(response) = client.remote_b_get_value(Request::new(BGetValueMessage {})).await
else {
panic!("Could not get response from server");
let response = match client.remote_b_get_value(BGetValueMessage {}).await {
Ok(response) => response,
Err(e) => return Err(ClientError::ResponseFailure(e)),
};

response.get_ref().value.try_into().unwrap()
Ok(response.into_inner().value.try_into().unwrap())
}
}

Expand All @@ -91,22 +69,6 @@ impl RemoteA for ComponentA {
}
}

struct ComponentAServerRpc {
a: Option<ComponentA>,
address: SocketAddr,
}

impl ComponentAServerRpc {
fn new(a: ComponentA, ip_address: IpAddr, port: u16) -> Self {
Self { a: Some(a), address: SocketAddr::new(ip_address, port) }
}

async fn start(&mut self) {
let svc = RemoteAServer::new(self.a.take().unwrap());
Server::builder().add_service(svc).serve(self.address).await.unwrap();
}
}

#[async_trait]
impl RemoteB for ComponentB {
async fn remote_b_get_value(
Expand All @@ -117,25 +79,27 @@ impl RemoteB for ComponentB {
}
}

struct ComponentBServerRpc {
b: Option<ComponentB>,
address: SocketAddr,
}

impl ComponentBServerRpc {
fn new(b: ComponentB, ip_address: IpAddr, port: u16) -> Self {
Self { b: Some(b), address: SocketAddr::new(ip_address, port) }
#[async_trait]
impl ServerStart for ComponentA {
async fn start_server(self, address: SocketAddr) {
let svc = RemoteAServer::new(self);
Server::builder().add_service(svc).serve(address).await.unwrap();
}
}

async fn start(&mut self) {
let svc = RemoteBServer::new(self.b.take().unwrap());
Server::builder().add_service(svc).serve(self.address).await.unwrap();
#[async_trait]
impl ServerStart for ComponentB {
async fn start_server(self, address: SocketAddr) {
let svc = RemoteBServer::new(self);
Server::builder().add_service(svc).serve(address).await.unwrap();
}
}

async fn verify_response(ip_address: IpAddr, port: u16, expected_value: ValueA) {
let a_client = ComponentAClientRpc::new(ip_address, port);
assert_eq!(a_client.a_get_value().await, expected_value);
let a_client = ComponentClientRpc::<ComponentA>::new(ip_address, port);

let returned_value = a_client.a_get_value().await.expect("Value should be returned");
assert_eq!(returned_value, expected_value);
}

#[tokio::test]
Expand All @@ -147,14 +111,16 @@ async fn test_setup() {
let a_port = 10000;
let b_port = 10001;

let a_client = ComponentAClientRpc::new(local_ip, a_port);
let b_client = ComponentBClientRpc::new(local_ip, b_port);
let a_client = ComponentClientRpc::<ComponentA>::new(local_ip, a_port);
let b_client = ComponentClientRpc::<ComponentB>::new(local_ip, b_port);

let component_a = ComponentA::new(Box::new(b_client));
let component_b = ComponentB::new(setup_value, Box::new(a_client));

let mut component_a_server = ComponentAServerRpc::new(component_a, local_ip, a_port);
let mut component_b_server = ComponentBServerRpc::new(component_b, local_ip, b_port);
let mut component_a_server =
ComponentServerRpc::<ComponentA>::new(component_a, local_ip, a_port);
let mut component_b_server =
ComponentServerRpc::<ComponentB>::new(component_b, local_ip, b_port);

task::spawn(async move {
component_a_server.start().await;
Expand Down
Loading

0 comments on commit b513b8c

Please sign in to comment.