Skip to content

Commit

Permalink
refactor: sort component servers into dedicated files (#523)
Browse files Browse the repository at this point in the history
  • Loading branch information
uriel-starkware authored Jul 23, 2024
1 parent bd34515 commit b4aabe3
Show file tree
Hide file tree
Showing 10 changed files with 161 additions and 143 deletions.
4 changes: 3 additions & 1 deletion crates/gateway/src/communication.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use starknet_mempool_infra::component_server::{create_empty_server, EmptyServer};
use starknet_mempool_infra::component_server::empty_component_server::{
create_empty_server, EmptyServer,
};

use crate::gateway::Gateway;

Expand Down
2 changes: 1 addition & 1 deletion crates/mempool/src/communication.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use async_trait::async_trait;
use starknet_mempool_infra::component_definitions::ComponentRequestHandler;
use starknet_mempool_infra::component_runner::ComponentStarter;
use starknet_mempool_infra::component_server::LocalComponentServer;
use starknet_mempool_infra::component_server::local_component_server::LocalComponentServer;
use starknet_mempool_types::communication::{
MempoolRequest, MempoolRequestAndResponseSender, MempoolResponse,
};
Expand Down
22 changes: 22 additions & 0 deletions crates/mempool_infra/src/component_server/definitions.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
use async_trait::async_trait;
use tracing::{error, info};

use crate::component_runner::ComponentStarter;

#[async_trait]
pub trait ComponentServerStarter: Send + Sync {
async fn start(&mut self);
}

pub async fn start_component<Component>(component: &mut Component) -> bool
where
Component: ComponentStarter + Sync + Send,
{
if let Err(err) = component.start().await {
error!("ComponentServer::start() failed: {:?}", err);
return false;
}

info!("ComponentServer::start() completed.");
true
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
use async_trait::async_trait;

use super::definitions::{start_component, ComponentServerStarter};
use crate::component_runner::ComponentStarter;

pub struct EmptyServer<T: ComponentStarter + Send + Sync> {
component: T,
}

impl<T: ComponentStarter + Send + Sync> EmptyServer<T> {
pub fn new(component: T) -> Self {
Self { component }
}
}

#[async_trait]
impl<T: ComponentStarter + Send + Sync> ComponentServerStarter for EmptyServer<T> {
async fn start(&mut self) {
start_component(&mut self.component).await;
}
}

pub fn create_empty_server<T: ComponentStarter + Send + Sync>(component: T) -> EmptyServer<T> {
EmptyServer::new(component)
}
Original file line number Diff line number Diff line change
@@ -1,22 +1,8 @@
use std::marker::PhantomData;
use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;

use async_trait::async_trait;
use bincode::{deserialize, serialize};
use hyper::body::to_bytes;
use hyper::header::CONTENT_TYPE;
use hyper::service::{make_service_fn, service_fn};
use hyper::{Body, Request as HyperRequest, Response as HyperResponse, Server, StatusCode};
use serde::{Deserialize, Serialize};
use tokio::sync::mpsc::Receiver;
use tokio::sync::Mutex;
use tracing::{error, info};

use crate::component_definitions::{
ComponentRequestAndResponseSender, ComponentRequestHandler, ServerError,
APPLICATION_OCTET_STREAM,
};
use super::definitions::{start_component, ComponentServerStarter};
use crate::component_definitions::{ComponentRequestAndResponseSender, ComponentRequestHandler};
use crate::component_runner::ComponentStarter;

/// The `LocalComponentServer` struct is a generic server that handles requests and responses for a
Expand Down Expand Up @@ -53,9 +39,8 @@ use crate::component_runner::ComponentStarter;
/// use crate::starknet_mempool_infra::component_definitions::{
/// ComponentRequestAndResponseSender, ComponentRequestHandler,
/// };
/// use crate::starknet_mempool_infra::component_server::{
/// ComponentServerStarter, LocalComponentServer,
/// };
/// use crate::starknet_mempool_infra::component_server::local_component_server::LocalComponentServer;
/// use crate::starknet_mempool_infra::component_server::definitions::ComponentServerStarter;
///
/// // Define your component
/// struct MyComponent {}
Expand Down Expand Up @@ -142,11 +127,6 @@ where
}
}

#[async_trait]
pub trait ComponentServerStarter: Send + Sync {
async fn start(&mut self);
}

#[async_trait]
impl<Component, Request, Response> ComponentServerStarter
for LocalComponentServer<Component, Request, Response>
Expand All @@ -168,117 +148,3 @@ where
}
}
}

pub async fn start_component<Component>(component: &mut Component) -> bool
where
Component: ComponentStarter + Sync + Send,
{
if let Err(err) = component.start().await {
error!("ComponentServer::start() failed: {:?}", err);
return false;
}

info!("ComponentServer::start() completed.");
true
}

pub struct RemoteComponentServer<Component, Request, Response>
where
Component: ComponentRequestHandler<Request, Response> + Send + 'static,
Request: for<'a> Deserialize<'a> + Send + 'static,
Response: Serialize + 'static,
{
socket: SocketAddr,
component: Arc<Mutex<Component>>,
_req: PhantomData<Request>,
_res: PhantomData<Response>,
}

impl<Component, Request, Response> RemoteComponentServer<Component, Request, Response>
where
Component: ComponentRequestHandler<Request, Response> + Send + 'static,
Request: for<'a> Deserialize<'a> + Send + 'static,
Response: Serialize + 'static,
{
pub fn new(component: Component, ip_address: IpAddr, port: u16) -> Self {
Self {
component: Arc::new(Mutex::new(component)),
socket: SocketAddr::new(ip_address, port),
_req: PhantomData,
_res: PhantomData,
}
}

async fn handler(
http_request: HyperRequest<Body>,
component: Arc<Mutex<Component>>,
) -> Result<HyperResponse<Body>, hyper::Error> {
let body_bytes = to_bytes(http_request.into_body()).await?;
let http_response = match deserialize(&body_bytes) {
Ok(component_request) => {
// Acquire the lock for component computation, release afterwards.
let component_response =
{ component.lock().await.handle_request(component_request).await };
HyperResponse::builder()
.status(StatusCode::OK)
.header(CONTENT_TYPE, APPLICATION_OCTET_STREAM)
.body(Body::from(
serialize(&component_response)
.expect("Response serialization should succeed"),
))
}
Err(error) => {
let server_error = ServerError::RequestDeserializationFailure(error.to_string());
HyperResponse::builder().status(StatusCode::BAD_REQUEST).body(Body::from(
serialize(&server_error).expect("Server error serialization should succeed"),
))
}
}
.expect("Response building should succeed");

Ok(http_response)
}
}

#[async_trait]
impl<Component, Request, Response> ComponentServerStarter
for RemoteComponentServer<Component, Request, Response>
where
Component: ComponentRequestHandler<Request, Response> + Send + 'static,
Request: for<'a> Deserialize<'a> + Send + Sync + 'static,
Response: Serialize + Send + Sync + 'static,
{
async fn start(&mut self) {
let make_svc = make_service_fn(|_conn| {
let component = Arc::clone(&self.component);
async {
Ok::<_, hyper::Error>(service_fn(move |req| {
Self::handler(req, Arc::clone(&component))
}))
}
});

Server::bind(&self.socket.clone()).serve(make_svc).await.unwrap();
}
}

pub struct EmptyServer<T: ComponentStarter + Send + Sync> {
component: T,
}

impl<T: ComponentStarter + Send + Sync> EmptyServer<T> {
pub fn new(component: T) -> Self {
Self { component }
}
}

#[async_trait]
impl<T: ComponentStarter + Send + Sync> ComponentServerStarter for EmptyServer<T> {
async fn start(&mut self) {
start_component(&mut self.component).await;
}
}

pub fn create_empty_server<T: ComponentStarter + Send + Sync>(component: T) -> EmptyServer<T> {
EmptyServer::new(component)
}
4 changes: 4 additions & 0 deletions crates/mempool_infra/src/component_server/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
pub mod definitions;
pub mod empty_component_server;
pub mod local_component_server;
pub mod remote_component_server;
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
use std::marker::PhantomData;
use std::net::{IpAddr, SocketAddr};
use std::sync::Arc;

use async_trait::async_trait;
use bincode::{deserialize, serialize};
use hyper::body::to_bytes;
use hyper::header::CONTENT_TYPE;
use hyper::service::{make_service_fn, service_fn};
use hyper::{Body, Request as HyperRequest, Response as HyperResponse, Server, StatusCode};
use serde::{Deserialize, Serialize};
use tokio::sync::Mutex;

use super::definitions::ComponentServerStarter;
use crate::component_definitions::{
ComponentRequestHandler, ServerError, APPLICATION_OCTET_STREAM,
};

pub struct RemoteComponentServer<Component, Request, Response>
where
Component: ComponentRequestHandler<Request, Response> + Send + 'static,
Request: for<'a> Deserialize<'a> + Send + 'static,
Response: Serialize + 'static,
{
socket: SocketAddr,
component: Arc<Mutex<Component>>,
_req: PhantomData<Request>,
_res: PhantomData<Response>,
}

impl<Component, Request, Response> RemoteComponentServer<Component, Request, Response>
where
Component: ComponentRequestHandler<Request, Response> + Send + 'static,
Request: for<'a> Deserialize<'a> + Send + 'static,
Response: Serialize + 'static,
{
pub fn new(component: Component, ip_address: IpAddr, port: u16) -> Self {
Self {
component: Arc::new(Mutex::new(component)),
socket: SocketAddr::new(ip_address, port),
_req: PhantomData,
_res: PhantomData,
}
}

async fn handler(
http_request: HyperRequest<Body>,
component: Arc<Mutex<Component>>,
) -> Result<HyperResponse<Body>, hyper::Error> {
let body_bytes = to_bytes(http_request.into_body()).await?;
let http_response = match deserialize(&body_bytes) {
Ok(component_request) => {
// Acquire the lock for component computation, release afterwards.
let component_response =
{ component.lock().await.handle_request(component_request).await };
HyperResponse::builder()
.status(StatusCode::OK)
.header(CONTENT_TYPE, APPLICATION_OCTET_STREAM)
.body(Body::from(
serialize(&component_response)
.expect("Response serialization should succeed"),
))
}
Err(error) => {
let server_error = ServerError::RequestDeserializationFailure(error.to_string());
HyperResponse::builder().status(StatusCode::BAD_REQUEST).body(Body::from(
serialize(&server_error).expect("Server error serialization should succeed"),
))
}
}
.expect("Response building should succeed");

Ok(http_response)
}
}

#[async_trait]
impl<Component, Request, Response> ComponentServerStarter
for RemoteComponentServer<Component, Request, Response>
where
Component: ComponentRequestHandler<Request, Response> + Send + 'static,
Request: for<'a> Deserialize<'a> + Send + Sync + 'static,
Response: Serialize + Send + Sync + 'static,
{
async fn start(&mut self) {
let make_svc = make_service_fn(|_conn| {
let component = Arc::clone(&self.component);
async {
Ok::<_, hyper::Error>(service_fn(move |req| {
Self::handler(req, Arc::clone(&component))
}))
}
});

Server::bind(&self.socket.clone()).serve(make_svc).await.unwrap();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ use starknet_mempool_infra::component_client::remote_component_client::RemoteCom
use starknet_mempool_infra::component_definitions::{
ComponentRequestHandler, ServerError, APPLICATION_OCTET_STREAM,
};
use starknet_mempool_infra::component_server::{ComponentServerStarter, RemoteComponentServer};
use starknet_mempool_infra::component_server::definitions::ComponentServerStarter;
use starknet_mempool_infra::component_server::remote_component_server::RemoteComponentServer;
use tokio::task;

type ComponentAClient = RemoteComponentClient<ComponentARequest, ComponentAResponse>;
Expand Down
3 changes: 2 additions & 1 deletion crates/mempool_infra/tests/component_server_client_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@ use starknet_mempool_infra::component_client::local_component_client::LocalCompo
use starknet_mempool_infra::component_definitions::{
ComponentRequestAndResponseSender, ComponentRequestHandler,
};
use starknet_mempool_infra::component_server::{ComponentServerStarter, LocalComponentServer};
use starknet_mempool_infra::component_server::definitions::ComponentServerStarter;
use starknet_mempool_infra::component_server::local_component_server::LocalComponentServer;
use tokio::sync::mpsc::channel;
use tokio::task;

Expand Down
2 changes: 1 addition & 1 deletion crates/mempool_node/src/servers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use std::pin::Pin;
use futures::{Future, FutureExt};
use starknet_gateway::communication::{create_gateway_server, GatewayServer};
use starknet_mempool::communication::{create_mempool_server, MempoolServer};
use starknet_mempool_infra::component_server::ComponentServerStarter;
use starknet_mempool_infra::component_server::definitions::ComponentServerStarter;
use tracing::error;

use crate::communication::MempoolNodeCommunication;
Expand Down

0 comments on commit b4aabe3

Please sign in to comment.