From 7479fd6791a8c62e2ebb2b5f8477deadbd887676 Mon Sep 17 00:00:00 2001 From: Ayelet Zilber Date: Wed, 10 Jul 2024 10:52:06 +0300 Subject: [PATCH] test(mempool): add function to assert equality and implement Eq,PartialEq for related structs --- crates/mempool/src/mempool_test.rs | 42 +++++++++++++------------ crates/mempool/src/transaction_pool.rs | 4 +-- crates/mempool/src/transaction_queue.rs | 9 +++--- 3 files changed, 28 insertions(+), 27 deletions(-) diff --git a/crates/mempool/src/mempool_test.rs b/crates/mempool/src/mempool_test.rs index 8e26170be..ce37b73a7 100644 --- a/crates/mempool/src/mempool_test.rs +++ b/crates/mempool/src/mempool_test.rs @@ -1,3 +1,4 @@ +use std::cmp::Reverse; use std::collections::HashMap; use assert_matches::assert_matches; @@ -33,6 +34,11 @@ impl MempoolState { let tx_queue: TransactionQueue = queue_txs.into_iter().collect(); MempoolState { tx_pool, tx_queue } } + + fn assert_eq_mempool_state(&self, mempool: &Mempool) { + assert_eq!(self.tx_pool, mempool.tx_pool); + assert_eq!(self.tx_queue, mempool.tx_queue); + } } impl From for Mempool { @@ -139,35 +145,31 @@ fn assert_eq_mempool_queue(mempool: &Mempool, expected_queue: &[ThinTransaction] #[case::test_get_more_than_all_eligible_txs(5)] #[case::test_get_less_than_all_eligible_txs(2)] fn test_get_txs(#[case] requested_txs: usize) { - // TODO(Ayelet): Avoid cloning the transactions in the test. - let add_tx_inputs = [ - add_tx_input!(tip: 50, tx_hash: 1), - add_tx_input!(tip: 100, tx_hash: 2, sender_address: "0x1"), - add_tx_input!(tip: 10, tx_hash: 3, sender_address: "0x2"), - ]; - let tx_references_iterator = - add_tx_inputs.iter().map(|input| TransactionReference::new(&input.tx)); - let txs_iterator = add_tx_inputs.iter().map(|input| input.tx.clone()); + let tx1 = add_tx_input!(tip: 50, tx_hash: 1).tx; + let tx2 = add_tx_input!(tip: 100, tx_hash: 2, sender_address: "0x1").tx; + let tx3 = add_tx_input!(tip: 10, tx_hash: 3, sender_address: "0x2").tx; + + let mut tx_inputs = vec![tx1, tx2, tx3]; + let tx_references_iterator = tx_inputs.iter().map(TransactionReference::new); + let txs_iterator = tx_inputs.iter().cloned(); let mut mempool: Mempool = MempoolState::new(txs_iterator, tx_references_iterator).into(); let txs = mempool.get_txs(requested_txs).unwrap(); - let sorted_txs = [ - add_tx_inputs[1].tx.clone(), // tip 100 - add_tx_inputs[0].tx.clone(), // tip 50 - add_tx_inputs[2].tx.clone(), // tip 10 - ]; + tx_inputs.sort_by_key(|tx| Reverse(tx.tip)); - // This ensures we do not exceed the number of transactions available in the mempool. - let max_requested_txs = requested_txs.min(add_tx_inputs.len()); + // Ensure we do not exceed the number of transactions available in the mempool. + let max_requested_txs = requested_txs.min(tx_inputs.len()); - // checks that the returned transactions are the ones with the highest priority. - let (expected_queue, remaining_txs) = sorted_txs.split_at(max_requested_txs); + // Check that the returned transactions are the ones with the highest priority. + let (expected_queue, remaining_txs) = tx_inputs.split_at(max_requested_txs); assert_eq!(txs, expected_queue); - // checks that the transactions that were not returned are still in the mempool. - assert_eq_mempool_queue(&mempool, remaining_txs); + // Check that the transactions that were not returned are still in the mempool. + let remaining_tx_references = remaining_txs.iter().map(TransactionReference::new); + let mempool_state = MempoolState::new(remaining_txs.to_vec(), remaining_tx_references); + mempool_state.assert_eq_mempool_state(&mempool); } #[rstest] diff --git a/crates/mempool/src/transaction_pool.rs b/crates/mempool/src/transaction_pool.rs index 6a5811bcb..fa21c30a3 100644 --- a/crates/mempool/src/transaction_pool.rs +++ b/crates/mempool/src/transaction_pool.rs @@ -13,7 +13,7 @@ type HashToTransaction = HashMap; /// Invariant: both data structures are consistent regarding the existence of transactions: /// A transaction appears in one if and only if it appears in the other. /// No duplicate transactions appear in the pool. -#[derive(Debug, Default)] +#[derive(Debug, Default, Eq, PartialEq)] pub struct TransactionPool { // Holds the complete transaction objects; it should be the sole entity that does so. tx_pool: HashToTransaction, @@ -79,7 +79,7 @@ impl TransactionPool { } } -#[derive(Debug, Default)] +#[derive(Debug, Default, Eq, PartialEq)] struct AccountTransactionIndex(HashMap>); impl AccountTransactionIndex { diff --git a/crates/mempool/src/transaction_queue.rs b/crates/mempool/src/transaction_queue.rs index ddcbb9ab2..b7d7412eb 100644 --- a/crates/mempool/src/transaction_queue.rs +++ b/crates/mempool/src/transaction_queue.rs @@ -5,11 +5,10 @@ use starknet_api::core::{ContractAddress, Nonce}; use starknet_api::transaction::TransactionHash; use crate::mempool::TransactionReference; -// Assumption: for the MVP only one transaction from the same contract class can be in the mempool -// at a time. When this changes, saving the transactions themselves on the queu might no longer be -// appropriate, because we'll also need to stores transactions without indexing them. For example, -// transactions with future nonces will need to be stored, and potentially indexed on block commits. -#[derive(Debug, Default)] + +// Note: the derived comparison functionality considers the order guaranteed by the data structures +// used. +#[derive(Debug, Default, Eq, PartialEq)] pub struct TransactionQueue { // Priority queue of transactions with associated priority. queue: BTreeSet,