From ac99c1c085608dd66e0dfabd0b962f440115ff39 Mon Sep 17 00:00:00 2001 From: Mohammad Nassar Date: Thu, 30 May 2024 16:52:35 +0300 Subject: [PATCH] refactor: update priority queue to use ThinPriorityTransaction --- crates/mempool/src/mempool.rs | 56 +++++++++++-------------- crates/mempool/src/mempool_test.rs | 7 ++-- crates/mempool/src/priority_queue.rs | 43 ++++++++++--------- crates/mempool/src/transaction_store.rs | 33 +++++++-------- 4 files changed, 67 insertions(+), 72 deletions(-) diff --git a/crates/mempool/src/mempool.rs b/crates/mempool/src/mempool.rs index b501ce1e..0a6c917b 100644 --- a/crates/mempool/src/mempool.rs +++ b/crates/mempool/src/mempool.rs @@ -42,30 +42,23 @@ impl Mempool { batcher_network, }; - mempool.txs_queue = TransactionPriorityQueue::from( - inputs - .into_iter() - .map(|input| { - // Attempts to insert a key-value pair into the mempool's state. Returns `None` - // if the key was not present, otherwise returns the old value while updating - // the new value. - let prev_value = - mempool.state.insert(input.account.address, input.account.state); - assert!( - prev_value.is_none(), - "Sender address: {:?} already exists in the mempool. Can't add {:?} to \ - the mempool.", - input.account.address, - input.tx - ); - - // Insert the transaction into the tx_store. - mempool.tx_store.push(input.tx.clone()); - - input.tx - }) - .collect::>(), - ); + for input in inputs.into_iter() { + // Attempts to insert a key-value pair into the mempool's state. Returns `None` + // if the key was not present, otherwise returns the old value while updating + // the new value. + let prev_value = mempool.state.insert(input.account.address, input.account.state); + assert!( + prev_value.is_none(), + "Sender address: {:?} already exists in the mempool. Can't add {:?} to the \ + mempool.", + input.account.address, + input.tx + ); + + // Insert the transaction into the tx_store. + mempool.tx_store.push(input.tx.clone()); + mempool.txs_queue.push(input.tx.clone().into()); + } mempool } @@ -83,10 +76,13 @@ impl Mempool { // back. TODO: Consider renaming to `pop_txs` to be more consistent with the standard // library. pub fn get_txs(&mut self, n_txs: usize) -> MempoolResult> { - let txs = self.txs_queue.pop_last_chunk(n_txs); - for tx in &txs { + let pq_txs = self.txs_queue.pop_last_chunk(n_txs); + + let mut txs: Vec = Vec::default(); + for pq_tx in &pq_txs { + let tx = self.tx_store.remove(&pq_tx.tx_hash).unwrap(); self.state.remove(&tx.sender_address); - self.tx_store.remove(&tx.sender_address, &tx.nonce); + txs.push(tx); } Ok(txs) @@ -100,10 +96,8 @@ impl Mempool { Occupied(_) => Err(MempoolError::DuplicateTransaction { tx_hash: tx.tx_hash }), Vacant(entry) => { entry.insert(account.state); - self.txs_queue.push(tx.clone()); - - // Insert the transaction into the tx_store - self.tx_store.entry(account.address).or_default().insert(tx.nonce, tx); + self.txs_queue.push(tx.clone().into()); + self.tx_store.push(tx); Ok(()) } diff --git a/crates/mempool/src/mempool_test.rs b/crates/mempool/src/mempool_test.rs index 45980911..83e70d95 100644 --- a/crates/mempool/src/mempool_test.rs +++ b/crates/mempool/src/mempool_test.rs @@ -13,7 +13,6 @@ use starknet_mempool_types::utils::create_thin_tx_for_testing; use tokio::sync::mpsc::channel; use crate::mempool::{Account, Mempool, MempoolInput}; -use crate::priority_queue::PrioritizedTransaction; fn create_for_testing(inputs: impl IntoIterator) -> Mempool { let (_, rx_gateway_to_mempool) = channel::(1); @@ -117,9 +116,9 @@ fn test_add_tx(mut mempool: Mempool) { mempool.state.contains_key(&account2.address); mempool.state.contains_key(&account3.address); - assert_eq!(mempool.txs_queue.pop_last().unwrap(), PrioritizedTransaction(tx_tip_100_address_1)); - assert_eq!(mempool.txs_queue.pop_last().unwrap(), PrioritizedTransaction(tx_tip_80_address_2)); - assert_eq!(mempool.txs_queue.pop_last().unwrap(), PrioritizedTransaction(tx_tip_50_address_0)); + assert_eq!(mempool.txs_queue.pop_last().unwrap(), tx_tip_100_address_1.into(),); + assert_eq!(mempool.txs_queue.pop_last().unwrap(), tx_tip_80_address_2.into()); + assert_eq!(mempool.txs_queue.pop_last().unwrap(), tx_tip_50_address_0.into()); } #[rstest] diff --git a/crates/mempool/src/priority_queue.rs b/crates/mempool/src/priority_queue.rs index 967cb1a8..05271acf 100644 --- a/crates/mempool/src/priority_queue.rs +++ b/crates/mempool/src/priority_queue.rs @@ -1,41 +1,38 @@ use std::cmp::Ordering; use std::collections::BTreeSet; +use starknet_api::core::Nonce; +use starknet_api::transaction::{Tip, TransactionHash}; use starknet_mempool_types::mempool_types::ThinTransaction; // Assumption: for the MVP only one transaction from the same contract class can be in the mempool // at a time. When this changes, saving the transactions themselves on the queu might no longer be // appropriate, because we'll also need to stores transactions without indexing them. For example, // transactions with future nonces will need to be stored, and potentially indexed on block commits. #[derive(Clone, Debug, Default, derive_more::Deref, derive_more::DerefMut)] -pub struct TransactionPriorityQueue(BTreeSet); +pub struct TransactionPriorityQueue(BTreeSet); impl TransactionPriorityQueue { - pub fn push(&mut self, tx: ThinTransaction) { - let mempool_tx = PrioritizedTransaction(tx); - self.insert(mempool_tx); + pub fn push(&mut self, tx: ThinPriorityTransaction) { + self.insert(tx); } // TODO(gilad): remove collect - pub fn pop_last_chunk(&mut self, n_txs: usize) -> Vec { - (0..n_txs).filter_map(|_| self.pop_last().map(|tx| tx.0)).collect() + pub fn pop_last_chunk(&mut self, n_txs: usize) -> Vec { + (0..n_txs).filter_map(|_| self.pop_last()).collect() } } -impl From> for TransactionPriorityQueue { - fn from(transactions: Vec) -> Self { - TransactionPriorityQueue(BTreeSet::from_iter( - transactions.into_iter().map(PrioritizedTransaction), - )) - } +#[derive(Clone, Debug, Default)] +pub struct ThinPriorityTransaction { + pub nonce: Nonce, + pub tx_hash: TransactionHash, + pub tip: Tip, } -#[derive(Clone, Debug, derive_more::Deref, derive_more::From)] -pub struct PrioritizedTransaction(pub ThinTransaction); - /// Compare transactions based only on their tip, a uint, using the Eq trait. It ensures that two /// tips are either exactly equal or not. -impl PartialEq for PrioritizedTransaction { - fn eq(&self, other: &PrioritizedTransaction) -> bool { +impl PartialEq for ThinPriorityTransaction { + fn eq(&self, other: &ThinPriorityTransaction) -> bool { self.tip == other.tip } } @@ -43,16 +40,22 @@ impl PartialEq for PrioritizedTransaction { /// Marks this struct as capable of strict equality comparisons, signaling to the compiler it /// adheres to equality semantics. // Note: this depends on the implementation of `PartialEq`, see its docstring. -impl Eq for PrioritizedTransaction {} +impl Eq for ThinPriorityTransaction {} -impl Ord for PrioritizedTransaction { +impl Ord for ThinPriorityTransaction { fn cmp(&self, other: &Self) -> Ordering { self.tip.cmp(&other.tip) } } -impl PartialOrd for PrioritizedTransaction { +impl PartialOrd for ThinPriorityTransaction { fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } } + +impl From for ThinPriorityTransaction { + fn from(tx: ThinTransaction) -> Self { + ThinPriorityTransaction { nonce: tx.nonce, tx_hash: tx.tx_hash, tip: tx.tip } + } +} diff --git a/crates/mempool/src/transaction_store.rs b/crates/mempool/src/transaction_store.rs index 1ab15b3f..376ec273 100644 --- a/crates/mempool/src/transaction_store.rs +++ b/crates/mempool/src/transaction_store.rs @@ -1,34 +1,33 @@ use std::collections::{BTreeMap, HashMap}; use starknet_api::core::{ContractAddress, Nonce}; +use starknet_api::transaction::TransactionHash; use starknet_mempool_types::mempool_types::ThinTransaction; -#[derive(Clone, Debug, Default, derive_more::Deref, derive_more::DerefMut)] -pub struct TransactionStore(HashMap>); +#[derive(Clone, Debug, Default)] +pub struct TransactionStore { + store: HashMap>, + tx_hash_2_tx: HashMap, +} impl TransactionStore { pub fn push(&mut self, tx: ThinTransaction) { - self.entry(tx.sender_address).or_default().insert(tx.nonce, tx.clone()); + self.store.entry(tx.sender_address).or_default().insert(tx.nonce, tx.clone()); + self.tx_hash_2_tx.insert(tx.tx_hash, (tx.sender_address, tx.nonce)); } - pub fn remove( - &mut self, - sender_address: &ContractAddress, - nonce: &Nonce, - ) -> Option { - if let Some(tree) = self.0.get_mut(sender_address) { - return tree.remove(nonce); + pub fn remove(&mut self, tx_hash: &TransactionHash) -> Option { + let (address, nonce) = self.tx_hash_2_tx.remove(tx_hash).unwrap(); + if let Some(tree_map) = self.store.get_mut(&address) { + return tree_map.remove(&nonce); } None } - pub fn get( - &mut self, - sender_address: &ContractAddress, - nonce: &Nonce, - ) -> Option<&ThinTransaction> { - if let Some(tree) = self.0.get(sender_address) { - return tree.get(nonce); + pub fn get(&mut self, tx_hash: &TransactionHash) -> Option<&ThinTransaction> { + let (address, nonce) = self.tx_hash_2_tx.get(tx_hash).unwrap(); + if let Some(tree_map) = self.store.get(address) { + return tree_map.get(nonce); } None }