Skip to content

Commit

Permalink
feat: add a mempool proxy support for concurrent calls
Browse files Browse the repository at this point in the history
  • Loading branch information
uriel-starkware committed Apr 16, 2024
1 parent a268176 commit 06dff8d
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 37 deletions.
14 changes: 8 additions & 6 deletions crates/mempool_node/src/mempool.rs
Original file line number Diff line number Diff line change
@@ -1,30 +1,32 @@
use async_trait::async_trait;
use tokio::sync::Mutex;

pub type AddTransactionCallType = u32;
pub type AddTransactionReturnType = usize;

#[async_trait]
pub trait MempoolTrait {
async fn add_transaction(&mut self, tx: AddTransactionCallType) -> AddTransactionReturnType;
async fn add_transaction(&self, tx: AddTransactionCallType) -> AddTransactionReturnType;
}

#[derive(Default)]
pub struct Mempool {
transactions: Vec<u32>,
transactions: Mutex<Vec<u32>>,
}

impl Mempool {
pub fn new() -> Self {
Self {
transactions: vec![],
transactions: Mutex::new(vec![]),
}
}
}

#[async_trait]
impl MempoolTrait for Mempool {
async fn add_transaction(&mut self, tx: AddTransactionCallType) -> AddTransactionReturnType {
self.transactions.push(tx);
self.transactions.len()
async fn add_transaction(&self, tx: AddTransactionCallType) -> AddTransactionReturnType {
let mut guarded_transactions = self.transactions.lock().await;
guarded_transactions.push(tx);
guarded_transactions.len()
}
}
57 changes: 33 additions & 24 deletions crates/mempool_node/src/mempool_proxy.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use std::sync::Arc;

use crate::mempool::{AddTransactionCallType, AddTransactionReturnType, Mempool, MempoolTrait};
use async_trait::async_trait;
use std::sync::Arc;

use tokio::sync::mpsc::{channel, Receiver, Sender};
use tokio::sync::Mutex;
use tokio::sync::mpsc::{channel, Sender};
use tokio::task;

enum ProxyFunc {
Expand All @@ -14,48 +14,57 @@ enum ProxyRetValue {
AddTransaction(AddTransactionReturnType),
}

#[derive(Clone)]
pub struct MempoolProxy {
tx_call: Sender<ProxyFunc>,
rx_ret_value: Receiver<ProxyRetValue>,
tx_call: Sender<(ProxyFunc, Sender<ProxyRetValue>)>,
}

impl Default for MempoolProxy {
fn default() -> Self {
Self::new()
}
}

impl MempoolProxy {
pub fn new(mempool: Arc<Mutex<Mempool>>) -> Self {
let (tx_call, mut rx_call) = channel(32);
let (tx_ret_value, rx_ret_value) = channel(32);
pub fn new() -> Self {
let (tx_call, mut rx_call) = channel::<(ProxyFunc, Sender<ProxyRetValue>)>(32);

task::spawn(async move {
let mempool = Arc::new(Mempool::default());
while let Some(call) = rx_call.recv().await {
match call {
ProxyFunc::AddTransaction(tx) => {
let ret_value = mempool.lock().await.add_transaction(tx).await;
tx_ret_value
.send(ProxyRetValue::AddTransaction(ret_value))
.await
.expect("Sender of the func call is expecting a return value");
(ProxyFunc::AddTransaction(tx), tx_response) => {
let mempool = mempool.clone();
task::spawn(async move {
let ret_value = mempool.add_transaction(tx).await;
tx_response
.send(ProxyRetValue::AddTransaction(ret_value))
.await
.expect("Receiver should be listening.");
});
}
}
}
});

MempoolProxy {
tx_call,
rx_ret_value,
}
MempoolProxy { tx_call }
}
}

#[async_trait]
impl MempoolTrait for MempoolProxy {
async fn add_transaction(&mut self, tx: AddTransactionCallType) -> AddTransactionReturnType {
async fn add_transaction(&self, tx: AddTransactionCallType) -> AddTransactionReturnType {
let (tx_response, mut rx_response) = channel(32);
self.tx_call
.send(ProxyFunc::AddTransaction(tx))
.send((ProxyFunc::AddTransaction(tx), tx_response))
.await
.expect("Receiver is always listening in a dedicated task");
.expect("Receiver should be listening.");

match self.rx_ret_value.recv().await.expect(
"Receiver of the function call always returns a return value after sending a func call",
) {
match rx_response
.recv()
.await
.expect("Sender should be responding.")
{
ProxyRetValue::AddTransaction(ret_value) => ret_value,
}
}
Expand Down
73 changes: 66 additions & 7 deletions crates/mempool_node/src/mempool_proxy_test.rs
Original file line number Diff line number Diff line change
@@ -1,18 +1,77 @@
mod tests {

use std::sync::Arc;

use tokio::sync::Mutex;
use tokio::task::JoinSet;

use crate::{
mempool::{Mempool, MempoolTrait},
mempool::{AddTransactionCallType, AddTransactionReturnType, Mempool, MempoolTrait},
mempool_proxy::MempoolProxy,
};

#[tokio::test]
async fn test_proxy_add_transaction() {
let mempool = Arc::new(Mutex::new(Mempool::new()));
let mut proxy = MempoolProxy::new(mempool);
assert_eq!(proxy.add_transaction(1).await, 1);
assert_eq!(proxy.add_transaction(1).await, 2);
async fn test_mempool_simple_add_transaction() {
let mempool = Mempool::default();
let tx: AddTransactionCallType = 1;
let expected_result: AddTransactionReturnType = 1;
assert_eq!(mempool.add_transaction(tx).await, expected_result);
}

#[tokio::test]
async fn test_proxy_simple_add_transaction() {
let proxy = MempoolProxy::default();
let tx: AddTransactionCallType = 1;
let expected_result: AddTransactionReturnType = 1;
assert_eq!(proxy.add_transaction(tx).await, expected_result);
}

#[tokio::test]
async fn test_mempool_concurrent_add_transaction() {
let mempool = Arc::new(Mempool::default());

let mut tasks: JoinSet<_> = (0..5)
.map(|_| {
let mempool = mempool.clone();
async move {
let tx: AddTransactionCallType = 1;
mempool.add_transaction(tx).await
}
})
.collect();

let mut results: Vec<AddTransactionReturnType> = vec![];
while let Some(result) = tasks.join_next().await {
results.push(result.unwrap());
}

results.sort();

let expected_results: Vec<AddTransactionReturnType> = (1..=5).collect();
assert_eq!(results, expected_results);
}

#[tokio::test]
async fn test_proxy_concurrent_add_transaction() {
let proxy = MempoolProxy::default();

let mut tasks: JoinSet<_> = (0..5)
.map(|_| {
let proxy = proxy.clone();
async move {
let tx: AddTransactionCallType = 1;
proxy.add_transaction(tx).await
}
})
.collect();

let mut results: Vec<AddTransactionReturnType> = vec![];
while let Some(result) = tasks.join_next().await {
results.push(result.unwrap());
}

results.sort();

let expected_results: Vec<AddTransactionReturnType> = (1..=5).collect();
assert_eq!(results, expected_results);
}
}

0 comments on commit 06dff8d

Please sign in to comment.