Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chore: clean component server client tests #233

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions crates/mempool_infra/src/component_client.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
use thiserror::Error;
use tokio::sync::mpsc::{channel, Sender};
use tonic::transport::Error as TonicError;
use tonic::Status as TonicStatus;

use crate::component_definitions::ComponentRequestAndResponseSender;

Expand Down Expand Up @@ -36,4 +38,8 @@ where
pub enum ClientError {
#[error("Got an unexpected response type.")]
UnexpectedResponse,
#[error("Failed to connect to the server")]
ConnectionFailure(TonicError),
#[error("Failed to get a response from the server")]
ResponseFailure(TonicStatus),
}
19 changes: 19 additions & 0 deletions crates/mempool_infra/src/component_client_rpc.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
use std::net::IpAddr;

fn construct_url(ip_address: IpAddr, port: u16) -> String {
match ip_address {
IpAddr::V4(ip_address) => format!("http://{}:{}/", ip_address, port),
IpAddr::V6(ip_address) => format!("http://[{}]:{}/", ip_address, port),
}
}

pub struct ComponentClientRpc<Component> {
pub dst: String,
_component: std::marker::PhantomData<Component>,
}

impl<Component> ComponentClientRpc<Component> {
pub fn new(ip_address: IpAddr, port: u16) -> Self {
Self { dst: construct_url(ip_address, port), _component: Default::default() }
}
}
23 changes: 23 additions & 0 deletions crates/mempool_infra/src/component_server_rpc.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
use std::net::{IpAddr, SocketAddr};

use async_trait::async_trait;

#[async_trait]
pub trait ServerStart {
async fn start_server(self, address: SocketAddr);
}

pub struct ComponentServerRpc<Component> {
component: Option<Component>,
address: SocketAddr,
}

impl<Component: ServerStart> ComponentServerRpc<Component> {
pub fn new(component: Component, ip_address: IpAddr, port: u16) -> Self {
Self { component: Some(component), address: SocketAddr::new(ip_address, port) }
}

pub async fn start(&mut self) {
self.component.take().unwrap().start_server(self.address).await;
}
}
2 changes: 2 additions & 0 deletions crates/mempool_infra/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
pub mod component_client;
pub mod component_client_rpc;
pub mod component_definitions;
pub mod component_runner;
pub mod component_server;
pub mod component_server_rpc;
44 changes: 21 additions & 23 deletions crates/mempool_infra/tests/common/mod.rs
Original file line number Diff line number Diff line change
@@ -1,52 +1,50 @@
use async_trait::async_trait;
use starknet_mempool_infra::component_client::ClientError;

pub(crate) type ValueA = u32;
pub(crate) type ValueB = u8;

// TODO(Tsabary): add more messages / functions to the components.

pub type ClientResult<T> = Result<T, ClientError>;
pub type AClientResult = ClientResult<ValueA>;
pub type BClientResult = ClientResult<ValueB>;

#[async_trait]
pub(crate) trait ComponentATrait: Send + Sync {
async fn a_get_value(&self) -> ValueA;
pub(crate) trait AClient: Send + Sync {
async fn a_get_value(&self) -> AClientResult;
}

#[async_trait]
pub(crate) trait ComponentBTrait: Send + Sync {
async fn b_get_value(&self) -> ValueB;
pub(crate) trait BClient: Send + Sync {
async fn b_get_value(&self) -> BClientResult;
}

pub(crate) struct ComponentA {
b: Box<dyn ComponentBTrait>,
}

#[async_trait]
impl ComponentATrait for ComponentA {
async fn a_get_value(&self) -> ValueA {
let b_value = self.b.b_get_value().await;
b_value.into()
}
b: Box<dyn BClient>,
}

impl ComponentA {
pub fn new(b: Box<dyn ComponentBTrait>) -> Self {
pub fn new(b: Box<dyn BClient>) -> Self {
Self { b }
}

pub async fn a_get_value(&self) -> ValueA {
let b_value = self.b.b_get_value().await.unwrap();
b_value.into()
}
}

pub(crate) struct ComponentB {
value: ValueB,
_a: Box<dyn ComponentATrait>,
}

#[async_trait]
impl ComponentBTrait for ComponentB {
async fn b_get_value(&self) -> ValueB {
self.value
}
_a: Box<dyn AClient>,
}

impl ComponentB {
pub fn new(value: ValueB, a: Box<dyn ComponentATrait>) -> Self {
pub fn new(value: ValueB, a: Box<dyn AClient>) -> Self {
Self { value, _a: a }
}
pub fn b_get_value(&self) -> ValueB {
self.value
}
}
132 changes: 45 additions & 87 deletions crates/mempool_infra/tests/component_server_client_rpc_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,75 +9,47 @@ mod common;
use std::net::{IpAddr, SocketAddr};

use async_trait::async_trait;
use common::{ComponentATrait, ComponentBTrait};
use common::{AClient, AClientResult, BClient, BClientResult};
use component_a_service::remote_a_client::RemoteAClient;
use component_a_service::remote_a_server::{RemoteA, RemoteAServer};
use component_a_service::{AGetValueMessage, AGetValueReturnMessage};
use component_b_service::remote_b_client::RemoteBClient;
use component_b_service::remote_b_server::{RemoteB, RemoteBServer};
use component_b_service::{BGetValueMessage, BGetValueReturnMessage};
use starknet_mempool_infra::component_client::ClientError;
use starknet_mempool_infra::component_client_rpc::ComponentClientRpc;
use starknet_mempool_infra::component_server_rpc::{ComponentServerRpc, ServerStart};
use tokio::task;
use tonic::transport::Server;
use tonic::{Request, Response, Status};
use tonic::{Response, Status};

use crate::common::{ComponentA, ComponentB, ValueA, ValueB};

fn construct_url(ip_address: IpAddr, port: u16) -> String {
match ip_address {
IpAddr::V4(ip_address) => format!("http://{}:{}/", ip_address, port),
IpAddr::V6(ip_address) => format!("http://[{}]:{}/", ip_address, port),
}
}

struct ComponentAClientRpc {
dst: String,
}

impl ComponentAClientRpc {
fn new(ip_address: IpAddr, port: u16) -> Self {
Self { dst: construct_url(ip_address, port) }
}
}

#[async_trait]
impl ComponentATrait for ComponentAClientRpc {
async fn a_get_value(&self) -> ValueA {
let Ok(mut client) = RemoteAClient::connect(self.dst.clone()).await else {
panic!("Could not connect to server");
};

let Ok(response) = client.remote_a_get_value(Request::new(AGetValueMessage {})).await
else {
panic!("Could not get response from server");
};

response.get_ref().value
}
}

struct ComponentBClientRpc {
dst: String,
}

impl ComponentBClientRpc {
fn new(ip_address: IpAddr, port: u16) -> Self {
Self { dst: construct_url(ip_address, port) }
impl AClient for ComponentClientRpc<ComponentA> {
async fn a_get_value(&self) -> AClientResult {
let mut client = RemoteAClient::connect(self.dst.clone())
.await
.map_err(ClientError::ConnectionFailure)?;
let response = client
.remote_a_get_value(AGetValueMessage {})
.await
.map_err(ClientError::ResponseFailure)?;
Ok(response.into_inner().value)
}
}

#[async_trait]
impl ComponentBTrait for ComponentBClientRpc {
async fn b_get_value(&self) -> ValueB {
let Ok(mut client) = RemoteBClient::connect(self.dst.clone()).await else {
panic!("Could not connect to server");
};

let Ok(response) = client.remote_b_get_value(Request::new(BGetValueMessage {})).await
else {
panic!("Could not get response from server");
};

response.get_ref().value.try_into().unwrap()
impl BClient for ComponentClientRpc<ComponentB> {
async fn b_get_value(&self) -> BClientResult {
let mut client = RemoteBClient::connect(self.dst.clone())
.await
.map_err(ClientError::ConnectionFailure)?;
let response = client
.remote_b_get_value(BGetValueMessage {})
.await
.map_err(ClientError::ResponseFailure)?;
Ok(response.into_inner().value.try_into().unwrap())
}
}

Expand All @@ -91,51 +63,37 @@ impl RemoteA for ComponentA {
}
}

struct ComponentAServerRpc {
a: Option<ComponentA>,
address: SocketAddr,
}

impl ComponentAServerRpc {
fn new(a: ComponentA, ip_address: IpAddr, port: u16) -> Self {
Self { a: Some(a), address: SocketAddr::new(ip_address, port) }
}

async fn start(&mut self) {
let svc = RemoteAServer::new(self.a.take().unwrap());
Server::builder().add_service(svc).serve(self.address).await.unwrap();
}
}

#[async_trait]
impl RemoteB for ComponentB {
async fn remote_b_get_value(
&self,
_request: tonic::Request<BGetValueMessage>,
) -> Result<Response<BGetValueReturnMessage>, Status> {
Ok(Response::new(BGetValueReturnMessage { value: self.b_get_value().await.into() }))
Ok(Response::new(BGetValueReturnMessage { value: self.b_get_value().into() }))
}
}

struct ComponentBServerRpc {
b: Option<ComponentB>,
address: SocketAddr,
}

impl ComponentBServerRpc {
fn new(b: ComponentB, ip_address: IpAddr, port: u16) -> Self {
Self { b: Some(b), address: SocketAddr::new(ip_address, port) }
#[async_trait]
impl ServerStart for ComponentA {
async fn start_server(self, address: SocketAddr) {
let svc = RemoteAServer::new(self);
Server::builder().add_service(svc).serve(address).await.unwrap();
}
}

async fn start(&mut self) {
let svc = RemoteBServer::new(self.b.take().unwrap());
Server::builder().add_service(svc).serve(self.address).await.unwrap();
#[async_trait]
impl ServerStart for ComponentB {
async fn start_server(self, address: SocketAddr) {
let svc = RemoteBServer::new(self);
Server::builder().add_service(svc).serve(address).await.unwrap();
}
}

async fn verify_response(ip_address: IpAddr, port: u16, expected_value: ValueA) {
let a_client = ComponentAClientRpc::new(ip_address, port);
assert_eq!(a_client.a_get_value().await, expected_value);
let a_client = ComponentClientRpc::new(ip_address, port);

let returned_value = a_client.a_get_value().await.expect("Value should be returned");
assert_eq!(returned_value, expected_value);
}

#[tokio::test]
Expand All @@ -147,14 +105,14 @@ async fn test_setup() {
let a_port = 10000;
let b_port = 10001;

let a_client = ComponentAClientRpc::new(local_ip, a_port);
let b_client = ComponentBClientRpc::new(local_ip, b_port);
let a_client = ComponentClientRpc::new(local_ip, a_port);
let b_client = ComponentClientRpc::new(local_ip, b_port);

let component_a = ComponentA::new(Box::new(b_client));
let component_b = ComponentB::new(setup_value, Box::new(a_client));

let mut component_a_server = ComponentAServerRpc::new(component_a, local_ip, a_port);
let mut component_b_server = ComponentBServerRpc::new(component_b, local_ip, b_port);
let mut component_a_server = ComponentServerRpc::new(component_a, local_ip, a_port);
let mut component_b_server = ComponentServerRpc::new(component_b, local_ip, b_port);

task::spawn(async move {
component_a_server.start().await;
Expand Down
Loading
Loading