From 678dac5c282e6ea68b094d437983f178d9239fc7 Mon Sep 17 00:00:00 2001 From: Ayelet Zilber Date: Thu, 25 Apr 2024 16:18:43 +0300 Subject: [PATCH] feat: implement AddressPriorityQueue in Mempool --- crates/gateway/src/gateway_test.rs | 11 +- crates/mempool/src/mempool.rs | 91 ++++++++-------- crates/mempool/src/mempool_test.rs | 156 ++++++++++++++++++--------- crates/mempool/src/priority_queue.rs | 14 ++- 4 files changed, 168 insertions(+), 104 deletions(-) diff --git a/crates/gateway/src/gateway_test.rs b/crates/gateway/src/gateway_test.rs index 6725b797c..298e2e0e3 100644 --- a/crates/gateway/src/gateway_test.rs +++ b/crates/gateway/src/gateway_test.rs @@ -6,6 +6,7 @@ use axum::extract::State; use axum::http::StatusCode; use axum::response::{IntoResponse, Response}; use blockifier::context::ChainInfo; +use rstest::fixture; use starknet_api::external_transaction::ExternalTransaction; use starknet_api::transaction::TransactionHash; use starknet_mempool::mempool::{create_mempool_server, Mempool}; @@ -25,6 +26,11 @@ use crate::utils::{external_tx_to_account_tx, get_tx_hash}; const MEMPOOL_INVOCATIONS_QUEUE_SIZE: usize = 32; +#[fixture] +fn mempool() -> Mempool { + Mempool::empty() +} + pub fn app_state(mempool_client: Arc) -> AppState { AppState { stateless_tx_validator: StatelessTransactionValidator { @@ -46,14 +52,11 @@ pub fn app_state(mempool_client: Arc) -> AppState { // TODO(Ayelet): add test cases for declare and deploy account transactions. #[tokio::test] async fn test_add_tx() { - // TODO: Add fixture. - - let mempool = Mempool::new([]); // TODO(Tsabary): wrap creation of channels in dedicated functions, take channel capacity from // config. let (tx_mempool, rx_mempool) = channel::(MEMPOOL_INVOCATIONS_QUEUE_SIZE); - let mut mempool_server = create_mempool_server(mempool, rx_mempool); + let mut mempool_server = create_mempool_server(mempool(), rx_mempool); task::spawn(async move { mempool_server.start().await; }); diff --git a/crates/mempool/src/mempool.rs b/crates/mempool/src/mempool.rs index fc40a78de..aa39a3a61 100644 --- a/crates/mempool/src/mempool.rs +++ b/crates/mempool/src/mempool.rs @@ -1,4 +1,3 @@ -use std::collections::hash_map::Entry::{Occupied, Vacant}; use std::collections::HashMap; use async_trait::async_trait; @@ -8,66 +7,57 @@ use starknet_mempool_infra::component_definitions::ComponentRequestHandler; use starknet_mempool_infra::component_server::ComponentServer; use starknet_mempool_types::errors::MempoolError; use starknet_mempool_types::mempool_types::{ - Account, AccountState, MempoolInput, MempoolRequest, MempoolRequestAndResponseSender, - MempoolResponse, MempoolResult, ThinTransaction, + AccountState, MempoolInput, MempoolRequest, MempoolRequestAndResponseSender, MempoolResponse, + MempoolResult, ThinTransaction, }; use tokio::sync::mpsc::Receiver; -use crate::priority_queue::TransactionPriorityQueue; +use crate::priority_queue::{AddressStore, TransactionPriorityQueue}; #[cfg(test)] #[path = "mempool_test.rs"] pub mod mempool_test; -#[derive(Debug)] +#[derive(Default)] pub struct Mempool { // TODO: add docstring explaining visibility and coupling of the fields. txs_queue: TransactionPriorityQueue, - state: HashMap, + address_to_store: HashMap, } impl Mempool { - pub fn new(inputs: impl IntoIterator) -> Self { - let mut mempool = - Mempool { txs_queue: TransactionPriorityQueue::default(), state: HashMap::default() }; - - mempool.txs_queue = TransactionPriorityQueue::from( - inputs - .into_iter() - .map(|input| { - // Attempts to insert a key-value pair into the mempool's state. Returns `None` - // if the key was not present, otherwise returns the old value while updating - // the new value. - let prev_value = - mempool.state.insert(input.account.sender_address, input.account.state); - assert!( - prev_value.is_none(), - "Sender address: {:?} already exists in the mempool. Can't add {:?} to \ - the mempool.", - input.account.sender_address, - input.tx - ); - input.tx - }) - .collect::>(), - ); - - mempool + pub fn new(inputs: impl IntoIterator) -> MempoolResult { + let mut mempool = Mempool::default(); + + for input in inputs { + mempool.insert_tx(input.tx)?; + } + + Ok(mempool) } pub fn empty() -> Self { - Mempool::new([]) + Mempool::default() } /// Retrieves up to `n_txs` transactions with the highest priority from the mempool. /// Transactions are guaranteed to be unique across calls until `commit_block` is invoked. // TODO: the last part about commit_block is incorrect if we delete txs in get_txs and then push - // back. TODO: Consider renaming to `pop_txs` to be more consistent with the standard - // library. + // back. TODO: Consider renaming to `pop_txs` to be more consistent with the standard library. + // TODO: If `n_txs` is greater than the number of transactions in `txs_queue`, it will also + // check and add transactions from `address_to_store`. pub fn get_txs(&mut self, n_txs: usize) -> MempoolResult> { let txs = self.txs_queue.pop_last_chunk(n_txs); for tx in &txs { - self.state.remove(&tx.sender_address); + if let Some(address_queue) = self.address_to_store.get_mut(&tx.sender_address) { + address_queue.pop_front(); + + if address_queue.is_empty() { + self.address_to_store.remove(&tx.sender_address); + } else if let Some(next_tx) = address_queue.top() { + self.txs_queue.push(next_tx.clone()); + } + } } Ok(txs) @@ -76,15 +66,9 @@ impl Mempool { /// Adds a new transaction to the mempool. /// TODO: support fee escalation and transactions with future nonces. /// TODO: change input type to `MempoolInput`. - pub fn add_tx(&mut self, tx: ThinTransaction, account: Account) -> MempoolResult<()> { - match self.state.entry(account.sender_address) { - Occupied(_) => Err(MempoolError::DuplicateTransaction { tx_hash: tx.tx_hash }), - Vacant(entry) => { - entry.insert(account.state); - self.txs_queue.push(tx); - Ok(()) - } - } + pub fn add_tx(&mut self, tx: ThinTransaction) -> MempoolResult<()> { + self.insert_tx(tx)?; + Ok(()) } /// Update the mempool's internal state according to the committed block's transactions. @@ -99,6 +83,21 @@ impl Mempool { ) -> MempoolResult<()> { todo!() } + + fn insert_tx(&mut self, tx: ThinTransaction) -> MempoolResult<()> { + let address_queue = self.address_to_store.entry(tx.sender_address).or_default(); + + if address_queue.contains(&tx) { + return Err(MempoolError::DuplicateTransaction { tx_hash: tx.tx_hash }); + } + + address_queue.push(tx.clone()); + if address_queue.len() == 1 { + self.txs_queue.push(tx); + } + + Ok(()) + } } /// Wraps the mempool to enable inbound async communication from other components. @@ -112,7 +111,7 @@ impl MempoolCommunicationWrapper { } fn add_tx(&mut self, mempool_input: MempoolInput) -> MempoolResult<()> { - self.mempool.add_tx(mempool_input.tx, mempool_input.account) + self.mempool.add_tx(mempool_input.tx) } fn get_txs(&mut self, n_txs: usize) -> MempoolResult> { diff --git a/crates/mempool/src/mempool_test.rs b/crates/mempool/src/mempool_test.rs index 8d81b9c0b..5fd73db94 100644 --- a/crates/mempool/src/mempool_test.rs +++ b/crates/mempool/src/mempool_test.rs @@ -7,9 +7,10 @@ use starknet_api::hash::{StarkFelt, StarkHash}; use starknet_api::transaction::{Tip, TransactionHash}; use starknet_api::{contract_address, patricia_key}; use starknet_mempool_types::errors::MempoolError; -use starknet_mempool_types::mempool_types::ThinTransaction; +use starknet_mempool_types::mempool_types::{Account, MempoolInput, ThinTransaction}; -use crate::mempool::{Account, Mempool, MempoolInput}; +use crate::mempool::Mempool; +use crate::priority_queue::PrioritizedTransaction; /// Creates a valid input for mempool's `add_tx` with optional default value for /// `sender_address`. @@ -18,8 +19,6 @@ use crate::mempool::{Account, Mempool, MempoolInput}; /// 2. add_tx_input!(tip, tx_hash, address) /// 3. add_tx_input!(tip, tx_hash) // TODO: Return MempoolInput once it's used in `add_tx`. -// TODO: remove unused macro_rules warning when the macro is used. -#[allow(unused_macro_rules)] macro_rules! add_tx_input { // Pattern for all four arguments ($tip:expr, $tx_hash:expr, $sender_address:expr, $nonce:expr) => {{ @@ -44,7 +43,50 @@ macro_rules! add_tx_input { #[fixture] fn mempool() -> Mempool { - Mempool::new([]) + Mempool::empty() +} + +#[test] +fn test_new_with_duplicate_tx() { + let (tx, account) = add_tx_input!(Tip(0), TransactionHash(StarkFelt::ONE)); + let same_tx = tx.clone(); + + let inputs = vec![MempoolInput { tx, account }, MempoolInput { tx: same_tx, account }]; + + assert!(matches!( + Mempool::new(inputs), + Err(MempoolError::DuplicateTransaction { tx_hash: TransactionHash(StarkFelt::ONE) }) + )); +} + +#[test] +fn test_new_success() { + let (tx0, account0) = + add_tx_input!(Tip(50), TransactionHash(StarkFelt::ZERO), contract_address!("0x0")); + let (tx1, account1) = + add_tx_input!(Tip(60), TransactionHash(StarkFelt::ONE), contract_address!("0x1")); + let (tx3, _) = add_tx_input!( + Tip(80), + TransactionHash(StarkFelt::THREE), + contract_address!("0x0"), + Nonce(StarkFelt::ONE) + ); + + let inputs = vec![ + MempoolInput { tx: tx0.clone(), account: account0 }, + MempoolInput { tx: tx1.clone(), account: account1 }, + MempoolInput { tx: tx3.clone(), account: account0 }, + ]; + + let mempool = Mempool::new(inputs).unwrap(); + + assert!(mempool.address_to_store.get(&account0.sender_address).unwrap().contains(&tx0)); + assert!(mempool.address_to_store.get(&account1.sender_address).unwrap().contains(&tx1)); + assert!(mempool.address_to_store.get(&account0.sender_address).unwrap().contains(&tx3)); + + assert!(mempool.txs_queue.contains(&PrioritizedTransaction(tx0))); + assert!(mempool.txs_queue.contains(&PrioritizedTransaction(tx1))); + assert!(!mempool.txs_queue.contains(&PrioritizedTransaction(tx3))); } #[rstest] @@ -57,24 +99,37 @@ fn test_get_txs(#[case] requested_txs: usize) { add_tx_input!(Tip(100), TransactionHash(StarkFelt::TWO), contract_address!("0x1")); let (tx_tip_10_address_2, account3) = add_tx_input!(Tip(10), TransactionHash(StarkFelt::THREE), contract_address!("0x2")); + let (tx2_address_0, _) = add_tx_input!( + Tip(50), + TransactionHash(StarkFelt::ZERO), + contract_address!("0x0"), + Nonce(StarkFelt::ONE) + ); - let mut mempool = Mempool::new([ + let inputs = [ MempoolInput { tx: tx_tip_50_address_0.clone(), account: account1 }, MempoolInput { tx: tx_tip_100_address_1.clone(), account: account2 }, MempoolInput { tx: tx_tip_10_address_2.clone(), account: account3 }, - ]); + MempoolInput { tx: tx2_address_0.clone(), account: account1 }, + ]; - let expected_addresses = - vec![contract_address!("0x0"), contract_address!("0x1"), contract_address!("0x2")]; - // checks that the transactions were added to the mempool. - for address in &expected_addresses { - assert!(mempool.state.contains_key(address)); - } + let mut mempool = Mempool::new(inputs).unwrap(); + + assert!(mempool.txs_queue.contains(&PrioritizedTransaction(tx_tip_50_address_0.clone()))); + assert!(mempool.txs_queue.contains(&PrioritizedTransaction(tx_tip_100_address_1.clone()))); + assert!(mempool.txs_queue.contains(&PrioritizedTransaction(tx_tip_10_address_2.clone()))); + assert!(!mempool.txs_queue.contains(&PrioritizedTransaction(tx2_address_0.clone()))); let sorted_txs = vec![tx_tip_100_address_1, tx_tip_50_address_0, tx_tip_10_address_2]; let txs = mempool.get_txs(requested_txs).unwrap(); + // check that the account1's queue and the mempool's txs_queue are updated. + assert!( + mempool.address_to_store.get(&account1.sender_address).unwrap().contains(&tx2_address_0) + ); + assert!(mempool.txs_queue.contains(&PrioritizedTransaction(tx2_address_0))); + // This ensures we do not exceed the priority queue's limit of 3 transactions. let max_requested_txs = requested_txs.min(3); @@ -83,25 +138,10 @@ fn test_get_txs(#[case] requested_txs: usize) { assert_eq!(txs, sorted_txs[..max_requested_txs].to_vec()); // checks that the transactions that were not returned are still in the mempool. - let actual_addresses: Vec<&ContractAddress> = mempool.state.keys().collect(); - let expected_remaining_addresses: Vec<&ContractAddress> = - expected_addresses[max_requested_txs..].iter().collect(); - assert_eq!(actual_addresses, expected_remaining_addresses,); -} - -#[rstest] -#[should_panic(expected = "Sender address: \ - ContractAddress(PatriciaKey(StarkFelt(\"\ - 0x0000000000000000000000000000000000000000000000000000000000000000\"\ - ))) already exists in the mempool. Can't add")] -fn test_mempool_initialization_with_duplicate_sender_addresses() { - let (tx, account) = add_tx_input!(Tip(50), TransactionHash(StarkFelt::ONE)); - let same_tx = tx.clone(); - - let inputs = vec![MempoolInput { tx, account }, MempoolInput { tx: same_tx, account }]; - - // This call should panic because of duplicate sender addresses - let _mempool = Mempool::new(inputs.into_iter()); + let expected_remaining_txs: Vec<&ThinTransaction> = txs[max_requested_txs..].iter().collect(); + for tx in expected_remaining_txs { + assert!(mempool.txs_queue.contains(&PrioritizedTransaction(tx.clone()))); + } } #[rstest] @@ -111,15 +151,29 @@ fn test_add_tx(mut mempool: Mempool) { add_tx_input!(Tip(100), TransactionHash(StarkFelt::TWO), contract_address!("0x1")); let (tx_tip_80_address_2, account3) = add_tx_input!(Tip(80), TransactionHash(StarkFelt::THREE), contract_address!("0x2")); + let (tx2_address_0, _) = add_tx_input!( + Tip(50), + TransactionHash(StarkFelt::ZERO), + contract_address!("0x0"), + Nonce(StarkFelt::ONE) + ); + + assert_matches!(mempool.add_tx(tx_tip_50_address_0.clone()), Ok(())); + assert_matches!(mempool.add_tx(tx_tip_100_address_1.clone()), Ok(())); + assert_matches!(mempool.add_tx(tx_tip_80_address_2.clone()), Ok(())); + assert_matches!(mempool.add_tx(tx2_address_0.clone()), Ok(())); + + assert_eq!(mempool.txs_queue.len(), 3); + + let account_0_queue = mempool.address_to_store.get(&account1.sender_address).unwrap(); + assert_eq!(&tx_tip_50_address_0, account_0_queue.0.first().unwrap()); + assert_eq!(&tx2_address_0, account_0_queue.0.last().unwrap()); - assert_matches!(mempool.add_tx(tx_tip_50_address_0.clone(), account1), Ok(())); - assert_matches!(mempool.add_tx(tx_tip_100_address_1.clone(), account2), Ok(())); - assert_matches!(mempool.add_tx(tx_tip_80_address_2.clone(), account3), Ok(())); + let account_1_queue = mempool.address_to_store.get(&account2.sender_address).unwrap(); + assert_eq!(&tx_tip_100_address_1, account_1_queue.0.first().unwrap()); - assert_eq!(mempool.state.len(), 3); - mempool.state.contains_key(&account1.sender_address); - mempool.state.contains_key(&account2.sender_address); - mempool.state.contains_key(&account3.sender_address); + let account_2_queue = mempool.address_to_store.get(&account3.sender_address).unwrap(); + assert_eq!(&tx_tip_80_address_2, account_2_queue.0.first().unwrap()); check_mempool_txs_eq( &mempool, @@ -128,13 +182,13 @@ fn test_add_tx(mut mempool: Mempool) { } #[rstest] -fn test_add_same_tx(mut mempool: Mempool) { - let (tx, account) = add_tx_input!(Tip(50), TransactionHash(StarkFelt::ONE)); +fn test_add_tx_with_duplicate_tx(mut mempool: Mempool) { + let (tx, _account) = add_tx_input!(Tip(50), TransactionHash(StarkFelt::ONE)); let same_tx = tx.clone(); - assert_matches!(mempool.add_tx(tx.clone(), account), Ok(())); + assert_matches!(mempool.add_tx(tx.clone()), Ok(())); assert_matches!( - mempool.add_tx(same_tx, account), + mempool.add_tx(same_tx), Err(MempoolError::DuplicateTransaction { tx_hash: TransactionHash(StarkFelt::ONE) }) ); // Assert that the original tx remains in the pool after the failed attempt. @@ -156,15 +210,15 @@ fn check_mempool_txs_eq(mempool: &Mempool, expected_txs: &[ThinTransaction]) { #[rstest] fn test_add_tx_with_identical_tip_succeeds(mut mempool: Mempool) { - let (tx1, account1) = add_tx_input!(Tip(1), TransactionHash(StarkFelt::TWO)); + let (tx1, _account1) = add_tx_input!(Tip(1), TransactionHash(StarkFelt::TWO)); // Create a transaction with identical tip, it should be allowed through since the priority // queue tie-breaks identical tips by other tx-unique identifiers (for example tx hash). - let (tx2, account2) = + let (tx2, _account2) = add_tx_input!(Tip(1), TransactionHash(StarkFelt::ONE), contract_address!("0x1")); - assert!(mempool.add_tx(tx1.clone(), account1).is_ok()); - assert!(mempool.add_tx(tx2.clone(), account2).is_ok()); + assert!(mempool.add_tx(tx1.clone()).is_ok()); + assert!(mempool.add_tx(tx2.clone()).is_ok()); // TODO: currently hash comparison tie-breaks the two. Once more robust tie-breaks are added // replace this assertion with a dedicated test. @@ -173,14 +227,14 @@ fn test_add_tx_with_identical_tip_succeeds(mut mempool: Mempool) { #[rstest] fn test_tip_priority_over_tx_hash(mut mempool: Mempool) { - let (tx_big_tip_small_hash, account1) = add_tx_input!(Tip(2), TransactionHash(StarkFelt::ONE)); + let (tx_big_tip_small_hash, _account1) = add_tx_input!(Tip(2), TransactionHash(StarkFelt::ONE)); // Create a transaction with identical tip, it should be allowed through since the priority // queue tie-breaks identical tips by other tx-unique identifiers (for example tx hash). - let (tx_small_tip_big_hash, account2) = + let (tx_small_tip_big_hash, _account2) = add_tx_input!(Tip(1), TransactionHash(StarkFelt::TWO), contract_address!("0x1")); - assert!(mempool.add_tx(tx_big_tip_small_hash.clone(), account1).is_ok()); - assert!(mempool.add_tx(tx_small_tip_big_hash.clone(), account2).is_ok()); + assert!(mempool.add_tx(tx_big_tip_small_hash.clone()).is_ok()); + assert!(mempool.add_tx(tx_small_tip_big_hash.clone()).is_ok()); check_mempool_txs_eq(&mempool, &[tx_small_tip_big_hash, tx_big_tip_small_hash]) } diff --git a/crates/mempool/src/priority_queue.rs b/crates/mempool/src/priority_queue.rs index cec62d0c8..6c1b8894d 100644 --- a/crates/mempool/src/priority_queue.rs +++ b/crates/mempool/src/priority_queue.rs @@ -2,8 +2,9 @@ use std::cmp::Ordering; use std::collections::BTreeSet; use starknet_mempool_types::mempool_types::ThinTransaction; + // Assumption: for the MVP only one transaction from the same contract class can be in the mempool -// at a time. When this changes, saving the transactions themselves on the queu might no longer be +// at a time. When this changes, saving the transactions themselves on the queue might no longer be // appropriate, because we'll also need to stores transactions without indexing them. For example, // transactions with future nonces will need to be stored, and potentially indexed on block commits. #[derive(Clone, Debug, Default, derive_more::Deref, derive_more::DerefMut)] @@ -61,12 +62,15 @@ impl PartialOrd for PrioritizedTransaction { // TODO: remove when is used. #[allow(dead_code)] +#[derive(Debug, Default)] // Assumption: there are no gaps, and the transactions are received in order. -pub struct AddressPriorityQueue(pub Vec); +// TODO: support fee escalation +// TODO: support transactions with future nonces. +pub struct AddressStore(pub Vec); // TODO: remove when is used. #[allow(dead_code)] -impl AddressPriorityQueue { +impl AddressStore { pub fn push(&mut self, tx: ThinTransaction) { self.0.push(tx); } @@ -86,4 +90,8 @@ impl AddressPriorityQueue { pub fn contains(&self, tx: &ThinTransaction) -> bool { self.0.contains(tx) } + + pub fn len(&self) -> usize { + self.0.len() + } }