Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: make generic component server http #337

Merged
merged 1 commit into from
Jul 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
80 changes: 79 additions & 1 deletion crates/mempool_infra/src/component_server.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,19 @@
use std::marker::PhantomData;
use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;

use bincode::{deserialize, serialize};
use hyper::body::to_bytes;
use hyper::header::CONTENT_TYPE;
use hyper::service::{make_service_fn, service_fn};
use hyper::{Body, Request as HyperRequest, Response as HyperResponse, Server};
use serde::{Deserialize, Serialize};
use tokio::sync::mpsc::Receiver;
use tokio::sync::Mutex;

use crate::component_definitions::{ComponentRequestAndResponseSender, ComponentRequestHandler};
use crate::component_definitions::{
ComponentRequestAndResponseSender, ComponentRequestHandler, APPLICATION_OCTET_STREAM,
};

pub struct ComponentServer<Component, Request, Response>
where
Expand Down Expand Up @@ -36,3 +49,68 @@ where
}
}
}

pub struct ComponentServerHttp<Component, Request, Response>
where
Component: ComponentRequestHandler<Request, Response> + Send + 'static,
Request: for<'a> Deserialize<'a> + Send + 'static,
Response: Serialize + 'static,
{
socket: SocketAddr,
component: Arc<Mutex<Component>>,
_req: PhantomData<Request>,
_res: PhantomData<Response>,
}

impl<Component, Request, Response> ComponentServerHttp<Component, Request, Response>
where
Component: ComponentRequestHandler<Request, Response> + Send + 'static,
Request: for<'a> Deserialize<'a> + Send + 'static,
Response: Serialize + 'static,
{
pub fn new(component: Component, ip_address: IpAddr, port: u16) -> Self {
Self {
component: Arc::new(Mutex::new(component)),
socket: SocketAddr::new(ip_address, port),
_req: PhantomData,
_res: PhantomData,
}
}

pub async fn start(&mut self) {
let make_svc = make_service_fn(|_conn| {
let component = Arc::clone(&self.component);
async {
Ok::<_, hyper::Error>(service_fn(move |req| {
Self::handler(req, Arc::clone(&component))
}))
}
});

Server::bind(&self.socket.clone()).serve(make_svc).await.unwrap();
}

async fn handler(
http_request: HyperRequest<Body>,
component: Arc<Mutex<Component>>,
) -> Result<HyperResponse<Body>, hyper::Error> {
let body_bytes = to_bytes(http_request.into_body()).await?;
let component_request: Request =
deserialize(&body_bytes).expect("Request deserialization should succeed");

// Acquire the lock for component computation, release afterwards.
let component_response;
{
let mut component_guard = component.lock().await;
component_response = component_guard.handle_request(component_request).await;
}
let http_response = HyperResponse::builder()
.header(CONTENT_TYPE, APPLICATION_OCTET_STREAM)
.body(Body::from(
serialize(&component_response).expect("Response serialization should succeed"),
))
.expect("Response builidng should succeed");

Ok(http_response)
}
}
123 changes: 12 additions & 111 deletions crates/mempool_infra/tests/component_server_client_http_test.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,13 @@
mod common;

use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;
use std::net::IpAddr;

use async_trait::async_trait;
use common::{ComponentAClientTrait, ComponentBClientTrait, ResultA, ResultB};
use hyper::header::CONTENT_TYPE;
use hyper::service::{make_service_fn, service_fn};
use hyper::{Body, Request, Response, Server};
use serde::{Deserialize, Serialize};
use starknet_mempool_infra::component_client::ComponentClientHttp;
use starknet_mempool_infra::component_definitions::ComponentRequestHandler;
use tokio::sync::Mutex;
use starknet_mempool_infra::component_server::ComponentServerHttp;
use tokio::task;

use crate::common::{ComponentA, ComponentB, ValueA, ValueB};
Expand Down Expand Up @@ -46,59 +42,6 @@ impl ComponentRequestHandler<ComponentARequest, ComponentAResponse> for Componen
}
}

struct ComponentAServerHttp {
socket: SocketAddr,
component: Arc<Mutex<ComponentA>>,
}

impl ComponentAServerHttp {
pub fn new(component: ComponentA, ip_address: IpAddr, port: u16) -> Self {
Self {
component: Arc::new(Mutex::new(component)),
socket: SocketAddr::new(ip_address, port),
}
}

pub async fn start(&mut self) {
let make_svc = make_service_fn(|_conn| {
let component = Arc::clone(&self.component);
async {
Ok::<_, hyper::Error>(service_fn(move |req| {
Self::handler(req, Arc::clone(&component))
}))
}
});

Server::bind(&self.socket.clone()).serve(make_svc).await.unwrap();
}

async fn handler(
http_request: Request<Body>,
component: Arc<Mutex<ComponentA>>,
) -> Result<Response<Body>, hyper::Error> {
let body_bytes = hyper::body::to_bytes(http_request.into_body()).await?;
let component_request: ComponentARequest =
bincode::deserialize(&body_bytes).expect("Request deserialization should succeed");

// Scoping is for releasing lock early (otherwise, component is locked until end of
// function)
let component_response;
{
let mut component_guard = component.lock().await;
component_response = component_guard.handle_request(component_request).await;
}
let http_response = Response::builder()
.header(CONTENT_TYPE, "application/octet-stream")
.body(Body::from(
bincode::serialize(&component_response)
.expect("Response serialization should succeed"),
))
.expect("Response builidng should succeed");

Ok(http_response)
}
}

// Todo(uriel): Move to common
#[derive(Serialize, Deserialize, Debug)]
pub enum ComponentBRequest {
Expand Down Expand Up @@ -129,56 +72,6 @@ impl ComponentRequestHandler<ComponentBRequest, ComponentBResponse> for Componen
}
}

struct ComponentBServerHttp {
socket: SocketAddr,
component: Arc<Mutex<ComponentB>>,
}

impl ComponentBServerHttp {
pub fn new(component: ComponentB, ip_address: IpAddr, port: u16) -> Self {
Self {
component: Arc::new(Mutex::new(component)),
socket: SocketAddr::new(ip_address, port),
}
}

pub async fn start(&mut self) {
let make_svc = make_service_fn(|_conn| {
let component = Arc::clone(&self.component);
async {
Ok::<_, hyper::Error>(service_fn(move |req| {
Self::handler(req, Arc::clone(&component))
}))
}
});

Server::bind(&self.socket.clone()).serve(make_svc).await.unwrap();
}

async fn handler(
http_request: Request<Body>,
component: Arc<Mutex<ComponentB>>,
) -> Result<Response<Body>, hyper::Error> {
let body_bytes = hyper::body::to_bytes(http_request.into_body()).await?;
let component_request: ComponentBRequest =
bincode::deserialize(&body_bytes).expect("Request deserialization should succeed");

// Scoping is for releasing lock early (otherwise, component is locked until end of
// function)
let component_response;
{
let mut component_guard = component.lock().await;
component_response = component_guard.handle_request(component_request).await;
}
let http_response = Response::builder()
.header(CONTENT_TYPE, "application/octet-stream")
.body(Body::from(bincode::serialize(&component_response).unwrap()))
.expect("Response builidng should succeed");

Ok(http_response)
}
}

async fn verify_response(ip_address: IpAddr, port: u16, expected_value: ValueA) {
let a_client = ComponentClientHttp::new(ip_address, port);
assert_eq!(a_client.a_get_value().await.unwrap(), expected_value);
Expand All @@ -201,8 +94,16 @@ async fn test_setup() {
let component_a = ComponentA::new(Box::new(b_client));
let component_b = ComponentB::new(setup_value, Box::new(a_client));

let mut component_a_server = ComponentAServerHttp::new(component_a, local_ip, a_port);
let mut component_b_server = ComponentBServerHttp::new(component_b, local_ip, b_port);
let mut component_a_server = ComponentServerHttp::<
ComponentA,
ComponentARequest,
ComponentAResponse,
>::new(component_a, local_ip, a_port);
let mut component_b_server = ComponentServerHttp::<
ComponentB,
ComponentBRequest,
ComponentBResponse,
>::new(component_b, local_ip, b_port);

task::spawn(async move {
component_a_server.start().await;
Expand Down
Loading