diff --git a/crates/mempool_node/src/mempool.rs b/crates/mempool_node/src/mempool.rs index 42f54b12..37d985c0 100644 --- a/crates/mempool_node/src/mempool.rs +++ b/crates/mempool_node/src/mempool.rs @@ -1,30 +1,32 @@ use async_trait::async_trait; +use tokio::sync::Mutex; pub type AddTransactionCallType = u32; pub type AddTransactionReturnType = usize; #[async_trait] pub trait MempoolTrait { - async fn add_transaction(&mut self, tx: AddTransactionCallType) -> AddTransactionReturnType; + async fn add_transaction(&self, tx: AddTransactionCallType) -> AddTransactionReturnType; } #[derive(Default)] pub struct Mempool { - transactions: Vec, + transactions: Mutex>, } impl Mempool { pub fn new() -> Self { Self { - transactions: vec![], + transactions: Mutex::new(vec![]), } } } #[async_trait] impl MempoolTrait for Mempool { - async fn add_transaction(&mut self, tx: AddTransactionCallType) -> AddTransactionReturnType { - self.transactions.push(tx); - self.transactions.len() + async fn add_transaction(&self, tx: AddTransactionCallType) -> AddTransactionReturnType { + let mut guarded_transactions = self.transactions.lock().await; + guarded_transactions.push(tx); + guarded_transactions.len() } } diff --git a/crates/mempool_node/src/mempool_proxy.rs b/crates/mempool_node/src/mempool_proxy.rs index 60666d3d..c8d6726d 100644 --- a/crates/mempool_node/src/mempool_proxy.rs +++ b/crates/mempool_node/src/mempool_proxy.rs @@ -1,9 +1,9 @@ +use std::sync::Arc; + use crate::mempool::{AddTransactionCallType, AddTransactionReturnType, Mempool, MempoolTrait}; use async_trait::async_trait; -use std::sync::Arc; -use tokio::sync::mpsc::{channel, Receiver, Sender}; -use tokio::sync::Mutex; +use tokio::sync::mpsc::{channel, Sender}; use tokio::task; enum ProxyFunc { @@ -14,48 +14,57 @@ enum ProxyRetValue { AddTransaction(AddTransactionReturnType), } +#[derive(Clone)] pub struct MempoolProxy { - tx_call: Sender, - rx_ret_value: Receiver, + tx_call: Sender<(ProxyFunc, Sender)>, +} + +impl Default for MempoolProxy { + fn default() -> Self { + Self::new() + } } impl MempoolProxy { - pub fn new(mempool: Arc>) -> Self { - let (tx_call, mut rx_call) = channel(32); - let (tx_ret_value, rx_ret_value) = channel(32); + pub fn new() -> Self { + let (tx_call, mut rx_call) = channel::<(ProxyFunc, Sender)>(32); task::spawn(async move { + let mempool = Arc::new(Mempool::default()); while let Some(call) = rx_call.recv().await { match call { - ProxyFunc::AddTransaction(tx) => { - let ret_value = mempool.lock().await.add_transaction(tx).await; - tx_ret_value - .send(ProxyRetValue::AddTransaction(ret_value)) - .await - .expect("Sender of the func call is expecting a return value"); + (ProxyFunc::AddTransaction(tx), tx_response) => { + let mempool = mempool.clone(); + task::spawn(async move { + let ret_value = mempool.add_transaction(tx).await; + tx_response + .send(ProxyRetValue::AddTransaction(ret_value)) + .await + .expect("Receiver should be listening."); + }); } } } }); - MempoolProxy { - tx_call, - rx_ret_value, - } + MempoolProxy { tx_call } } } #[async_trait] impl MempoolTrait for MempoolProxy { - async fn add_transaction(&mut self, tx: AddTransactionCallType) -> AddTransactionReturnType { + async fn add_transaction(&self, tx: AddTransactionCallType) -> AddTransactionReturnType { + let (tx_response, mut rx_response) = channel(32); self.tx_call - .send(ProxyFunc::AddTransaction(tx)) + .send((ProxyFunc::AddTransaction(tx), tx_response)) .await - .expect("Receiver is always listening in a dedicated task"); + .expect("Receiver should be listening."); - match self.rx_ret_value.recv().await.expect( - "Receiver of the function call always returns a return value after sending a func call", - ) { + match rx_response + .recv() + .await + .expect("Sender should be responding.") + { ProxyRetValue::AddTransaction(ret_value) => ret_value, } } diff --git a/crates/mempool_node/src/mempool_proxy_test.rs b/crates/mempool_node/src/mempool_proxy_test.rs index 2441b842..5751a54b 100644 --- a/crates/mempool_node/src/mempool_proxy_test.rs +++ b/crates/mempool_node/src/mempool_proxy_test.rs @@ -1,18 +1,77 @@ mod tests { + use std::sync::Arc; - use tokio::sync::Mutex; + use tokio::task::JoinSet; use crate::{ - mempool::{Mempool, MempoolTrait}, + mempool::{AddTransactionCallType, AddTransactionReturnType, Mempool, MempoolTrait}, mempool_proxy::MempoolProxy, }; #[tokio::test] - async fn test_proxy_add_transaction() { - let mempool = Arc::new(Mutex::new(Mempool::new())); - let mut proxy = MempoolProxy::new(mempool); - assert_eq!(proxy.add_transaction(1).await, 1); - assert_eq!(proxy.add_transaction(1).await, 2); + async fn test_mempool_simple_add_transaction() { + let mempool = Mempool::default(); + let tx: AddTransactionCallType = 1; + let expected_result: AddTransactionReturnType = 1; + assert_eq!(mempool.add_transaction(tx).await, expected_result); + } + + #[tokio::test] + async fn test_proxy_simple_add_transaction() { + let proxy = MempoolProxy::default(); + let tx: AddTransactionCallType = 1; + let expected_result: AddTransactionReturnType = 1; + assert_eq!(proxy.add_transaction(tx).await, expected_result); + } + + #[tokio::test] + async fn test_mempool_concurrent_add_transaction() { + let mempool = Arc::new(Mempool::default()); + + let mut tasks: JoinSet<_> = (0..5) + .map(|_| { + let mempool = mempool.clone(); + async move { + let tx: AddTransactionCallType = 1; + mempool.add_transaction(tx).await + } + }) + .collect(); + + let mut results: Vec = vec![]; + while let Some(result) = tasks.join_next().await { + results.push(result.unwrap()); + } + + results.sort(); + + let expected_results: Vec = (1..=5).collect(); + assert_eq!(results, expected_results); + } + + #[tokio::test] + async fn test_proxy_concurrent_add_transaction() { + let proxy = MempoolProxy::default(); + + let mut tasks: JoinSet<_> = (0..5) + .map(|_| { + let proxy = proxy.clone(); + async move { + let tx: AddTransactionCallType = 1; + proxy.add_transaction(tx).await + } + }) + .collect(); + + let mut results: Vec = vec![]; + while let Some(result) = tasks.join_next().await { + results.push(result.unwrap()); + } + + results.sort(); + + let expected_results: Vec = (1..=5).collect(); + assert_eq!(results, expected_results); } }