diff --git a/crates/mempool_infra/src/component_client.rs b/crates/mempool_infra/src/component_client.rs index 75d3f44ce..9453ea65e 100644 --- a/crates/mempool_infra/src/component_client.rs +++ b/crates/mempool_infra/src/component_client.rs @@ -1,7 +1,14 @@ +use std::marker::PhantomData; +use std::net::IpAddr; + +use bincode::{deserialize, serialize}; +use hyper::header::CONTENT_TYPE; +use hyper::{Body, Client, Request as HyperRequest, Uri}; +use serde::{Deserialize, Serialize}; use thiserror::Error; use tokio::sync::mpsc::{channel, Sender}; -use crate::component_definitions::ComponentRequestAndResponseSender; +use crate::component_definitions::{ComponentRequestAndResponseSender, APPLICATION_OCTET_STREAM}; pub struct ComponentClient where @@ -43,6 +50,47 @@ where } } +pub struct ComponentClientHttp +where + Request: Serialize, + Response: for<'a> Deserialize<'a>, +{ + uri: Uri, + _req: PhantomData, + _res: PhantomData, +} + +impl ComponentClientHttp +where + Request: Serialize, + Response: for<'a> Deserialize<'a>, +{ + pub fn new(ip_address: IpAddr, port: u16) -> Self { + let uri = match ip_address { + IpAddr::V4(ip_address) => format!("http://{}:{}/", ip_address, port).parse().unwrap(), + IpAddr::V6(ip_address) => format!("http://[{}]:{}/", ip_address, port).parse().unwrap(), + }; + Self { uri, _req: PhantomData, _res: PhantomData } + } + + pub async fn send(&self, component_request: Request) -> Response { + let http_request = HyperRequest::post(self.uri.clone()) + .header(CONTENT_TYPE, APPLICATION_OCTET_STREAM) + .body(Body::from( + serialize(&component_request).expect("Request serialization should succeed"), + )) + .expect("Request builidng should succeed"); + + // Todo(uriel): Add configuration to control number of retries + let http_response = + Client::new().request(http_request).await.expect("Could not connect to server"); + let body_bytes = hyper::body::to_bytes(http_response.into_body()) + .await + .expect("Could not get response from server"); + deserialize(&body_bytes).expect("Response deserialization should succeed") + } +} + #[derive(Debug, Error)] pub enum ClientError { #[error("Got an unexpected response type.")] diff --git a/crates/mempool_infra/src/component_definitions.rs b/crates/mempool_infra/src/component_definitions.rs index dd12d59ff..71c91cf0d 100644 --- a/crates/mempool_infra/src/component_definitions.rs +++ b/crates/mempool_infra/src/component_definitions.rs @@ -14,3 +14,5 @@ where pub request: Request, pub tx: Sender, } + +pub const APPLICATION_OCTET_STREAM: &str = "application/octet-stream"; diff --git a/crates/mempool_infra/tests/component_server_client_http_test.rs b/crates/mempool_infra/tests/component_server_client_http_test.rs index 3ce96ceb2..70a6b6eef 100644 --- a/crates/mempool_infra/tests/component_server_client_http_test.rs +++ b/crates/mempool_infra/tests/component_server_client_http_test.rs @@ -7,8 +7,9 @@ use async_trait::async_trait; use common::{ComponentAClientTrait, ComponentBClientTrait, ResultA, ResultB}; use hyper::header::CONTENT_TYPE; use hyper::service::{make_service_fn, service_fn}; -use hyper::{Body, Client, Request, Response, Server, Uri}; +use hyper::{Body, Request, Response, Server}; use serde::{Deserialize, Serialize}; +use starknet_mempool_infra::component_client::ComponentClientHttp; use starknet_mempool_infra::component_definitions::ComponentRequestHandler; use tokio::sync::Mutex; use tokio::task; @@ -27,41 +28,10 @@ pub enum ComponentAResponse { Value(ValueA), } -// Todo(uriel): Make generic - ComponentClientHttp -struct ComponentAClientHttp { - uri: Uri, -} - -impl ComponentAClientHttp { - pub fn new(ip_address: IpAddr, port: u16) -> Self { - let uri = match ip_address { - IpAddr::V4(ip_address) => format!("http://{}:{}/", ip_address, port).parse().unwrap(), - IpAddr::V6(ip_address) => format!("http://[{}]:{}/", ip_address, port).parse().unwrap(), - }; - Self { uri } - } -} - -// Todo(uriel): Change the component trait to client specific and make it return result #[async_trait] -impl ComponentAClientTrait for ComponentAClientHttp { +impl ComponentAClientTrait for ComponentClientHttp { async fn a_get_value(&self) -> ResultA { - let component_request = ComponentARequest::AGetValue; - let http_request = Request::post(self.uri.clone()) - .header("Content-Type", "application/octet-stream") - .body(Body::from( - bincode::serialize(&component_request) - .expect("Request serialization should succeed"), - )) - .expect("Request builidng should succeed"); - - // Todo(uriel): Add configuration to control number of retries - let http_response = - Client::new().request(http_request).await.expect("Could not connect to server"); - let body_bytes = hyper::body::to_bytes(http_response.into_body()) - .await - .expect("Could not get response from server"); - match bincode::deserialize(&body_bytes).expect("Response deserialization should succeed") { + match self.send(ComponentARequest::AGetValue).await { ComponentAResponse::Value(value) => Ok(value), } } @@ -141,41 +111,10 @@ pub enum ComponentBResponse { Value(ValueB), } -// Todo(uriel): Make generic - ComponentClientHttp -struct ComponentBClientHttp { - uri: Uri, -} - -impl ComponentBClientHttp { - pub fn new(ip_address: IpAddr, port: u16) -> Self { - let uri = match ip_address { - IpAddr::V4(ip_address) => format!("http://{}:{}/", ip_address, port).parse().unwrap(), - IpAddr::V6(ip_address) => format!("http://[{}]:{}/", ip_address, port).parse().unwrap(), - }; - Self { uri } - } -} - -// Todo(uriel): Change the component trait to client specific and make it return result #[async_trait] -impl ComponentBClientTrait for ComponentBClientHttp { +impl ComponentBClientTrait for ComponentClientHttp { async fn b_get_value(&self) -> ResultB { - let component_request = ComponentBRequest::BGetValue; - let http_request = Request::post(self.uri.clone()) - .header("Content-Type", "application/octet-stream") - .body(Body::from( - bincode::serialize(&component_request) - .expect("Request serialization should succeed"), - )) - .expect("Request builidng should succeed"); - - // Todo(uriel): Add configuration to control number of retries - let http_response = - Client::new().request(http_request).await.expect("Could not connect to server"); - let body_bytes = hyper::body::to_bytes(http_response.into_body()) - .await - .expect("Could not get response from server"); - match bincode::deserialize(&body_bytes).expect("Response deserialization should succeed") { + match self.send(ComponentBRequest::BGetValue).await { ComponentBResponse::Value(value) => Ok(value), } } @@ -241,7 +180,7 @@ impl ComponentBServerHttp { } async fn verify_response(ip_address: IpAddr, port: u16, expected_value: ValueA) { - let a_client = ComponentAClientHttp::new(ip_address, port); + let a_client = ComponentClientHttp::new(ip_address, port); assert_eq!(a_client.a_get_value().await.unwrap(), expected_value); } @@ -254,8 +193,10 @@ async fn test_setup() { let a_port = 10000; let b_port = 10001; - let a_client = ComponentAClientHttp::new(local_ip, a_port); - let b_client = ComponentBClientHttp::new(local_ip, b_port); + let a_client = + ComponentClientHttp::::new(local_ip, a_port); + let b_client = + ComponentClientHttp::::new(local_ip, b_port); let component_a = ComponentA::new(Box::new(b_client)); let component_b = ComponentB::new(setup_value, Box::new(a_client));