diff --git a/crates/mempool/src/mempool.rs b/crates/mempool/src/mempool.rs index dac3acb2e..3e8c9bfe2 100644 --- a/crates/mempool/src/mempool.rs +++ b/crates/mempool/src/mempool.rs @@ -52,7 +52,7 @@ impl Mempool { pub fn get_txs(&mut self, n_txs: usize) -> MempoolResult> { let mut eligible_txs: Vec = Vec::with_capacity(n_txs); for tx_hash in self.tx_queue.pop_last_chunk(n_txs) { - let tx = self.tx_pool.remove(tx_hash)?; + let tx = self.tx_pool.get(tx_hash)?.clone(); assert!(!self.staging.contains(&tx_hash)); self.staging.push(tx_hash); eligible_txs.push(tx); @@ -84,11 +84,31 @@ impl Mempool { // push back. pub fn commit_block( &mut self, - _block_number: u64, - _txs_in_block: &[TransactionHash], + txs_in_block: &[TransactionHash], _state_changes: HashMap, ) -> MempoolResult<()> { - todo!() + let mut counter = 0; + for &tx_hash in txs_in_block { + if self.staging.contains(&tx_hash) { + counter += 1; + self.tx_pool.remove(tx_hash)?; + } + } + // It pops the first `counter` hashes from staging area. + // Since transactions maintain their order after being processed by the Mempool, the + // transactions to be included in the block should be the first ones in the staging area. + self.staging.drain(0..counter); + + // Re-insert transaction to PQ. + for &tx_hash in self.staging.iter() { + let tx = self.tx_pool.get(tx_hash)?; + self.tx_queue.insert(TransactionReference::new(tx)); + } + + // Cleanin the `StagingArea`. + self.staging = Vec::default(); + + Ok(()) } fn insert_tx(&mut self, input: MempoolInput) -> MempoolResult<()> { diff --git a/crates/mempool/src/mempool_test.rs b/crates/mempool/src/mempool_test.rs index 027ae3cf4..082b3bcb7 100644 --- a/crates/mempool/src/mempool_test.rs +++ b/crates/mempool/src/mempool_test.rs @@ -1,3 +1,5 @@ +use std::collections::HashMap; + use assert_matches::assert_matches; use itertools::zip_eq; use pretty_assertions::assert_eq; @@ -47,7 +49,7 @@ fn mempool() -> Mempool { // Asserts that the transactions in the mempool are in ascending order as per the expected // transactions. #[track_caller] -fn check_mempool_txs_eq(mempool: &Mempool, expected_txs: &[ThinTransaction]) { +fn check_mempool_pq_txs_eq(mempool: &Mempool, expected_txs: &[ThinTransaction]) { let mempool_txs = mempool.tx_queue.iter(); let expected_txs = expected_txs.iter().map(TransactionReference::new); @@ -58,6 +60,14 @@ fn check_mempool_txs_eq(mempool: &Mempool, expected_txs: &[ThinTransaction]) { ); } +#[track_caller] +fn check_mempool_txs_eq(mempool_txs: &[ThinTransaction], expected_txs: &[ThinTransaction]) { + assert!( + zip_eq(expected_txs, mempool_txs) + .all(|(expected_tx, mempool_tx)| expected_tx == mempool_tx) + ); +} + #[rstest] #[case(3)] // Requesting exactly the number of transactions in the queue #[case(5)] // Requesting more transactions than are in the queue @@ -89,7 +99,7 @@ fn test_get_txs(#[case] requested_txs: usize) { assert_eq!(txs, expected_txs); // checks that the transactions that were not returned are still in the mempool. - check_mempool_txs_eq(&mempool, remaining_txs); + check_mempool_pq_txs_eq(&mempool, remaining_txs); } #[rstest] @@ -106,7 +116,7 @@ fn test_add_tx(mut mempool: Mempool) { let expected_txs = &[input_tip_50_address_0.tx, input_tip_80_address_2.tx, input_tip_100_address_1.tx]; - check_mempool_txs_eq(&mempool, expected_txs) + check_mempool_pq_txs_eq(&mempool, expected_txs) } #[rstest] @@ -121,7 +131,30 @@ fn test_add_same_tx(mut mempool: Mempool) { Err(MempoolError::DuplicateTransaction { .. }) ); // Assert that the original tx remains in the pool after the failed attempt. - check_mempool_txs_eq(&mempool, &[same_input.tx]) + check_mempool_pq_txs_eq(&mempool, &[same_input.tx]) +} + +#[rstest] +fn test_commit_block() { + let tx_tip_50_address_0 = add_tx_input!(Tip(50), TransactionHash(StarkFelt::ONE)); + let tx_tip_100_address_1 = + add_tx_input!(Tip(100), TransactionHash(StarkFelt::TWO), contract_address!("0x1")); + + let mut mempool = + Mempool::new([tx_tip_50_address_0.clone(), tx_tip_100_address_1.clone()]).unwrap(); + + let sorted_txs = vec![tx_tip_100_address_1.tx, tx_tip_50_address_0.tx]; + + let txs = mempool.get_txs(2).unwrap(); + + // checks that the returned transactions are the ones with the highest priority. + assert_eq!(txs.len(), 2); + check_mempool_txs_eq(txs.as_slice(), &sorted_txs[..2]); + + mempool.commit_block(&[TransactionHash(StarkFelt::TWO)], HashMap::default()).unwrap(); + + let _txs = mempool.get_txs(1).unwrap(); + mempool.commit_block(&[TransactionHash(StarkFelt::ONE)], HashMap::default()).unwrap(); } #[rstest] @@ -137,7 +170,7 @@ fn test_add_tx_with_identical_tip_succeeds(mut mempool: Mempool) { // TODO: currently hash comparison tie-breaks the two. Once more robust tie-breaks are added // replace this assertion with a dedicated test. - check_mempool_txs_eq(&mempool, &[input2.tx, input1.tx]); + check_mempool_pq_txs_eq(&mempool, &[input2.tx, input1.tx]); } #[rstest] @@ -151,5 +184,5 @@ fn test_tip_priority_over_tx_hash(mut mempool: Mempool) { assert_eq!(mempool.add_tx(input_big_tip_small_hash.clone()), Ok(())); assert_eq!(mempool.add_tx(input_small_tip_big_hash.clone()), Ok(())); - check_mempool_txs_eq(&mempool, &[input_small_tip_big_hash.tx, input_big_tip_small_hash.tx]) + check_mempool_pq_txs_eq(&mempool, &[input_small_tip_big_hash.tx, input_big_tip_small_hash.tx]) }