From 71ceabc7be8e6331bed03d862a2c739abfc0ae7a Mon Sep 17 00:00:00 2001 From: Uriel Korach Date: Mon, 1 Jul 2024 15:57:35 +0300 Subject: [PATCH] feat: make generic component server http --- crates/mempool_infra/src/component_server.rs | 75 +++++++++++- .../component_server_client_http_test.rs | 115 +----------------- 2 files changed, 78 insertions(+), 112 deletions(-) diff --git a/crates/mempool_infra/src/component_server.rs b/crates/mempool_infra/src/component_server.rs index 0a726100..a30a173c 100644 --- a/crates/mempool_infra/src/component_server.rs +++ b/crates/mempool_infra/src/component_server.rs @@ -1,6 +1,19 @@ +use std::marker::PhantomData; +use std::net::{IpAddr, SocketAddr}; +use std::sync::Arc; + +use bincode::{deserialize, serialize}; +use hyper::body::to_bytes; +use hyper::header::CONTENT_TYPE; +use hyper::service::{make_service_fn, service_fn}; +use hyper::{Body, Request as HyperRequest, Response as HyperResponse, Server}; +use serde::{Deserialize, Serialize}; use tokio::sync::mpsc::Receiver; +use tokio::sync::Mutex; -use crate::component_definitions::{ComponentRequestAndResponseSender, ComponentRequestHandler}; +use crate::component_definitions::{ + ComponentRequestAndResponseSender, ComponentRequestHandler, APPLICATION_OCTET_STREAM, +}; pub struct ComponentServer where @@ -36,3 +49,63 @@ where } } } + +pub struct ComponentServerHttp { + socket: SocketAddr, + component: Arc>, + _req: PhantomData, + _res: PhantomData, +} + +impl ComponentServerHttp +where + Component: ComponentRequestHandler + Send + 'static, + Request: for<'a> Deserialize<'a> + Send + 'static, + Response: Serialize + 'static, +{ + pub fn new(component: Component, ip_address: IpAddr, port: u16) -> Self { + Self { + component: Arc::new(Mutex::new(component)), + socket: SocketAddr::new(ip_address, port), + _req: PhantomData, + _res: PhantomData, + } + } + + pub async fn start(&mut self) { + let make_svc = make_service_fn(|_conn| { + let component = Arc::clone(&self.component); + async { + Ok::<_, hyper::Error>(service_fn(move |req| { + Self::handler(req, Arc::clone(&component)) + })) + } + }); + + Server::bind(&self.socket.clone()).serve(make_svc).await.unwrap(); + } + + async fn handler( + http_request: HyperRequest, + component: Arc>, + ) -> Result, hyper::Error> { + let body_bytes = to_bytes(http_request.into_body()).await?; + let component_request: Request = + deserialize(&body_bytes).expect("Request deserialization should succeed"); + + // Acquire the lock for component computation, release afterwards + let component_response; + { + let mut component_guard = component.lock().await; + component_response = component_guard.handle_request(component_request).await; + } + let http_response = HyperResponse::builder() + .header(CONTENT_TYPE, APPLICATION_OCTET_STREAM) + .body(Body::from( + serialize(&component_response).expect("Response serialization should succeed"), + )) + .expect("Response builidng should succeed"); + + Ok(http_response) + } +} diff --git a/crates/mempool_infra/tests/component_server_client_http_test.rs b/crates/mempool_infra/tests/component_server_client_http_test.rs index 2d8b8a00..6de58dbb 100644 --- a/crates/mempool_infra/tests/component_server_client_http_test.rs +++ b/crates/mempool_infra/tests/component_server_client_http_test.rs @@ -1,17 +1,13 @@ mod common; -use std::net::{IpAddr, SocketAddr}; -use std::sync::Arc; +use std::net::IpAddr; use async_trait::async_trait; use common::{ComponentAClientTrait, ComponentBClientTrait, ResultA, ResultB}; -use hyper::header::CONTENT_TYPE; -use hyper::service::{make_service_fn, service_fn}; -use hyper::{Body, Request, Response, Server}; use serde::{Deserialize, Serialize}; use starknet_mempool_infra::component_client::ComponentClientHttp; use starknet_mempool_infra::component_definitions::ComponentRequestHandler; -use tokio::sync::Mutex; +use starknet_mempool_infra::component_server::ComponentServerHttp; use tokio::task; use crate::common::{ComponentA, ComponentB, ValueA, ValueB}; @@ -46,59 +42,6 @@ impl ComponentRequestHandler for Componen } } -struct ComponentAServerHttp { - socket: SocketAddr, - component: Arc>, -} - -impl ComponentAServerHttp { - pub fn new(component: ComponentA, ip_address: IpAddr, port: u16) -> Self { - Self { - component: Arc::new(Mutex::new(component)), - socket: SocketAddr::new(ip_address, port), - } - } - - pub async fn start(&mut self) { - let make_svc = make_service_fn(|_conn| { - let component = Arc::clone(&self.component); - async { - Ok::<_, hyper::Error>(service_fn(move |req| { - Self::handler(req, Arc::clone(&component)) - })) - } - }); - - Server::bind(&self.socket.clone()).serve(make_svc).await.unwrap(); - } - - async fn handler( - http_request: Request, - component: Arc>, - ) -> Result, hyper::Error> { - let body_bytes = hyper::body::to_bytes(http_request.into_body()).await?; - let component_request: ComponentARequest = - bincode::deserialize(&body_bytes).expect("Request deserialization should succeed"); - - // Scoping is for releasing lock early (otherwise, component is locked until end of - // function) - let component_response; - { - let mut component_guard = component.lock().await; - component_response = component_guard.handle_request(component_request).await; - } - let http_response = Response::builder() - .header(CONTENT_TYPE, "application/octet-stream") - .body(Body::from( - bincode::serialize(&component_response) - .expect("Response serialization should succeed"), - )) - .expect("Response builidng should succeed"); - - Ok(http_response) - } -} - // Todo(uriel): Move to common #[derive(Serialize, Deserialize, Debug)] pub enum ComponentBRequest { @@ -129,56 +72,6 @@ impl ComponentRequestHandler for Componen } } -struct ComponentBServerHttp { - socket: SocketAddr, - component: Arc>, -} - -impl ComponentBServerHttp { - pub fn new(component: ComponentB, ip_address: IpAddr, port: u16) -> Self { - Self { - component: Arc::new(Mutex::new(component)), - socket: SocketAddr::new(ip_address, port), - } - } - - pub async fn start(&mut self) { - let make_svc = make_service_fn(|_conn| { - let component = Arc::clone(&self.component); - async { - Ok::<_, hyper::Error>(service_fn(move |req| { - Self::handler(req, Arc::clone(&component)) - })) - } - }); - - Server::bind(&self.socket.clone()).serve(make_svc).await.unwrap(); - } - - async fn handler( - http_request: Request, - component: Arc>, - ) -> Result, hyper::Error> { - let body_bytes = hyper::body::to_bytes(http_request.into_body()).await?; - let component_request: ComponentBRequest = - bincode::deserialize(&body_bytes).expect("Request deserialization should succeed"); - - // Scoping is for releasing lock early (otherwise, component is locked until end of - // function) - let component_response; - { - let mut component_guard = component.lock().await; - component_response = component_guard.handle_request(component_request).await; - } - let http_response = Response::builder() - .header(CONTENT_TYPE, "application/octet-stream") - .body(Body::from(bincode::serialize(&component_response).unwrap())) - .expect("Response builidng should succeed"); - - Ok(http_response) - } -} - async fn verify_response(ip_address: IpAddr, port: u16, expected_value: ValueA) { let a_client = ComponentClientHttp::new(ip_address, port); assert_eq!(a_client.a_get_value().await.unwrap(), expected_value); @@ -199,8 +92,8 @@ async fn test_setup() { let component_a = ComponentA::new(Box::new(b_client)); let component_b = ComponentB::new(setup_value, Box::new(a_client)); - let mut component_a_server = ComponentAServerHttp::new(component_a, local_ip, a_port); - let mut component_b_server = ComponentBServerHttp::new(component_b, local_ip, b_port); + let mut component_a_server = ComponentServerHttp::new(component_a, local_ip, a_port); + let mut component_b_server = ComponentServerHttp::new(component_b, local_ip, b_port); task::spawn(async move { component_a_server.start().await;