From d09a6906460d7437ef91146ca33349a083ef3c8c Mon Sep 17 00:00:00 2001 From: Uriel Korach Date: Mon, 22 Jul 2024 17:13:23 +0300 Subject: [PATCH] test: add b-set-value function to component communications tests (local/remote) --- crates/mempool_infra/tests/common/mod.rs | 7 ++++ .../component_server_client_http_test.rs | 27 ++++++++++++-- .../tests/component_server_client_test.rs | 37 ++++++++++++++----- 3 files changed, 59 insertions(+), 12 deletions(-) diff --git a/crates/mempool_infra/tests/common/mod.rs b/crates/mempool_infra/tests/common/mod.rs index 97c97c93..18a6a945 100644 --- a/crates/mempool_infra/tests/common/mod.rs +++ b/crates/mempool_infra/tests/common/mod.rs @@ -24,11 +24,13 @@ pub enum ComponentAResponse { #[derive(Serialize, Deserialize, Debug)] pub enum ComponentBRequest { BGetValue, + BSetValue(ValueB), } #[derive(Serialize, Deserialize, Debug)] pub enum ComponentBResponse { BGetValue(ValueB), + BSetValue, } #[async_trait] @@ -39,6 +41,7 @@ pub(crate) trait ComponentAClientTrait: Send + Sync { #[async_trait] pub(crate) trait ComponentBClientTrait: Send + Sync { async fn b_get_value(&self) -> ResultB; + async fn b_set_value(&self, value: ValueB) -> ClientResult<()>; } pub(crate) struct ComponentA { @@ -72,6 +75,10 @@ impl ComponentB { pub fn b_get_value(&self) -> ValueB { self.value } + + pub fn b_set_value(&mut self, value: ValueB) { + self.value = value; + } } #[async_trait] diff --git a/crates/mempool_infra/tests/component_server_client_http_test.rs b/crates/mempool_infra/tests/component_server_client_http_test.rs index 3469a415..c03564a1 100644 --- a/crates/mempool_infra/tests/component_server_client_http_test.rs +++ b/crates/mempool_infra/tests/component_server_client_http_test.rs @@ -14,7 +14,7 @@ use hyper::service::{make_service_fn, service_fn}; use hyper::{Body, Client, Request, Response, Server, StatusCode, Uri}; use rstest::rstest; use serde::Serialize; -use starknet_mempool_infra::component_client::{ClientError, ComponentClientHttp}; +use starknet_mempool_infra::component_client::{ClientError, ClientResult, ComponentClientHttp}; use starknet_mempool_infra::component_definitions::{ ComponentRequestHandler, ServerError, APPLICATION_OCTET_STREAM, }; @@ -64,6 +64,14 @@ impl ComponentBClientTrait for ComponentClientHttp ResultB { match self.send(ComponentBRequest::BGetValue).await? { ComponentBResponse::BGetValue(value) => Ok(value), + _ => Err(ClientError::UnexpectedResponse), + } + } + + async fn b_set_value(&self, value: ValueB) -> ClientResult<()> { + match self.send(ComponentBRequest::BSetValue(value)).await? { + ComponentBResponse::BSetValue => Ok(()), + _ => Err(ClientError::UnexpectedResponse), } } } @@ -73,12 +81,24 @@ impl ComponentRequestHandler for Componen async fn handle_request(&mut self, request: ComponentBRequest) -> ComponentBResponse { match request { ComponentBRequest::BGetValue => ComponentBResponse::BGetValue(self.b_get_value()), + ComponentBRequest::BSetValue(value) => { + self.b_set_value(value); + ComponentBResponse::BSetValue + } } } } -async fn verify_response(a_client: ComponentAClient, expected_value: ValueA) { +async fn verify_response( + a_client: ComponentAClient, + b_client: ComponentBClient, + expected_value: ValueA, +) { assert_eq!(a_client.a_get_value().await.unwrap(), expected_value); + let new_expected_value: ValueB = 222; + + assert!(b_client.b_set_value(new_expected_value).await.is_ok()); + assert_eq!(a_client.a_get_value().await.unwrap(), new_expected_value.into()); } async fn verify_error( @@ -164,7 +184,8 @@ async fn test_proper_setup() { let setup_value: ValueB = 90; setup_for_tests(setup_value, A_PORT_TEST_SETUP, B_PORT_TEST_SETUP).await; let a_client = ComponentAClient::new(LOCAL_IP, A_PORT_TEST_SETUP); - verify_response(a_client, setup_value.into()).await; + let b_client = ComponentBClient::new(LOCAL_IP, B_PORT_TEST_SETUP); + verify_response(a_client, b_client, setup_value.into()).await; } #[tokio::test] diff --git a/crates/mempool_infra/tests/component_server_client_test.rs b/crates/mempool_infra/tests/component_server_client_test.rs index 37a16017..2be900f7 100644 --- a/crates/mempool_infra/tests/component_server_client_test.rs +++ b/crates/mempool_infra/tests/component_server_client_test.rs @@ -5,16 +5,19 @@ use common::{ ComponentAClientTrait, ComponentARequest, ComponentAResponse, ComponentBClientTrait, ComponentBRequest, ComponentBResponse, ResultA, ResultB, }; -use starknet_mempool_infra::component_client::ComponentClient; +use starknet_mempool_infra::component_client::{ClientError, ClientResult, ComponentClient}; use starknet_mempool_infra::component_definitions::{ ComponentRequestAndResponseSender, ComponentRequestHandler, }; use starknet_mempool_infra::component_server::{ComponentServer, ComponentServerStarter}; -use tokio::sync::mpsc::{channel, Sender}; +use tokio::sync::mpsc::channel; use tokio::task; use crate::common::{ComponentA, ComponentB, ValueA, ValueB}; +type ComponentAClient = ComponentClient; +type ComponentBClient = ComponentClient; + // TODO(Tsabary): send messages from component b to component a. #[async_trait] @@ -42,6 +45,14 @@ impl ComponentBClientTrait for ComponentClient Ok(value), + _ => Err(ClientError::UnexpectedResponse), + } + } + + async fn b_set_value(&self, value: ValueB) -> ClientResult<()> { + match self.send(ComponentBRequest::BSetValue(value)).await { + ComponentBResponse::BSetValue => Ok(()), + _ => Err(ClientError::UnexpectedResponse), } } } @@ -51,16 +62,24 @@ impl ComponentRequestHandler for Componen async fn handle_request(&mut self, request: ComponentBRequest) -> ComponentBResponse { match request { ComponentBRequest::BGetValue => ComponentBResponse::BGetValue(self.b_get_value()), + ComponentBRequest::BSetValue(value) => { + self.b_set_value(value); + ComponentBResponse::BSetValue + } } } } async fn verify_response( - tx_a: Sender>, + a_client: ComponentAClient, + b_client: ComponentBClient, expected_value: ValueA, ) { - let a_client = ComponentClient::new(tx_a); assert_eq!(a_client.a_get_value().await.unwrap(), expected_value); + let new_expected_value: ValueB = 222; + + assert!(b_client.b_set_value(new_expected_value).await.is_ok()); + assert_eq!(a_client.a_get_value().await.unwrap(), new_expected_value.into()); } #[tokio::test] @@ -73,11 +92,11 @@ async fn test_setup() { let (tx_b, rx_b) = channel::>(32); - let a_client = ComponentClient::new(tx_a.clone()); - let b_client = ComponentClient::new(tx_b.clone()); + let a_client = ComponentAClient::new(tx_a.clone()); + let b_client = ComponentBClient::new(tx_b.clone()); - let component_a = ComponentA::new(Box::new(b_client)); - let component_b = ComponentB::new(setup_value, Box::new(a_client)); + let component_a = ComponentA::new(Box::new(b_client.clone())); + let component_b = ComponentB::new(setup_value, Box::new(a_client.clone())); let mut component_a_server = ComponentServer::new(component_a, rx_a); let mut component_b_server = ComponentServer::new(component_b, rx_b); @@ -90,5 +109,5 @@ async fn test_setup() { component_b_server.start().await; }); - verify_response(tx_a.clone(), expected_value).await; + verify_response(a_client, b_client, expected_value).await; }