Skip to content

Commit

Permalink
refactor: update priority queue to use ThinPriorityTransaction
Browse files Browse the repository at this point in the history
  • Loading branch information
MohammadNassar1 committed Jun 11, 2024
1 parent c43438c commit eb6791e
Show file tree
Hide file tree
Showing 4 changed files with 142 additions and 72 deletions.
60 changes: 27 additions & 33 deletions crates/mempool/src/mempool.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ pub mod mempool_test;
pub struct Mempool {
// TODO: add docstring explaining visibility and coupling of the fields.
txs_queue: TransactionPriorityQueue,
// All transactions currently held in the mempool.
tx_store: TransactionStore,
state: HashMap<ContractAddress, AccountState>,
}
Expand All @@ -35,35 +36,25 @@ impl Mempool {
state: HashMap::default(),
};

mempool.txs_queue = TransactionPriorityQueue::from(
inputs
.into_iter()
.map(|input| {
// Attempts to insert a key-value pair into the mempool's state. Returns `None`
// if the key was not present, otherwise returns the old value while updating
// the new value.
let prev_value =
mempool.state.insert(input.account.sender_address, input.account.state);
assert!(
prev_value.is_none(),
"Sender address: {:?} already exists in the mempool. Can't add {:?} to \
the mempool.",
input.account.sender_address,
input.tx
);

// Insert the transaction into the tx_store.
let res = mempool.tx_store.push(input.tx.clone());
assert!(
res.is_ok(),
"Transaction: {:?} already exists in the mempool.",
input.tx.tx_hash
);

input.tx
})
.collect::<Vec<ThinTransaction>>(),
);
for MempoolInput { tx, account: Account { sender_address, state } } in inputs.into_iter() {
// Attempts to insert a key-value pair into the mempool's state. Returns `None`
// if the key was not present, otherwise returns the old value while updating
// the new value.
let existing_account_state = mempool.state.insert(sender_address, state);
assert!(
existing_account_state.is_none(),
"Sender address: {:?} already exists in the mempool. Can't add {:?} to the \
mempool.",
sender_address,
tx
);

// Insert the transaction into the tx_store.
let res = mempool.tx_store.push(tx.clone());
assert!(res.is_ok(), "Transaction: {:?} already exists in the mempool.", tx.tx_hash);

mempool.txs_queue.push(tx.clone().into());
}

mempool
}
Expand All @@ -78,10 +69,13 @@ impl Mempool {
// back. TODO: Consider renaming to `pop_txs` to be more consistent with the standard
// library.
pub fn get_txs(&mut self, n_txs: usize) -> MempoolResult<Vec<ThinTransaction>> {
let txs = self.txs_queue.pop_last_chunk(n_txs);
for tx in &txs {
let pq_txs = self.txs_queue.pop_last_chunk(n_txs);

let mut txs: Vec<ThinTransaction> = Vec::default();
for pq_tx in &pq_txs {
let tx = self.tx_store.remove(&pq_tx.tx_hash)?;
self.state.remove(&tx.sender_address);
self.tx_store.remove(&tx.tx_hash)?;
txs.push(tx);
}

Ok(txs)
Expand All @@ -96,7 +90,7 @@ impl Mempool {
Vacant(entry) => {
entry.insert(account.state);
// TODO(Mohammad): use `handle_tx`.
self.txs_queue.push(tx.clone());
self.txs_queue.push(tx.clone().into());
self.tx_store.push(tx)?;

Ok(())
Expand Down
23 changes: 14 additions & 9 deletions crates/mempool/src/mempool_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use starknet_mempool_types::errors::MempoolError;
use starknet_mempool_types::mempool_types::ThinTransaction;

use crate::mempool::{Account, Mempool, MempoolInput};
use crate::priority_queue::ThinPriorityTransaction;

/// Creates a valid input for mempool's `add_tx` with optional default value for
/// `sender_address`.
Expand Down Expand Up @@ -120,10 +121,10 @@ fn test_add_tx(mut mempool: Mempool) {
mempool.state.contains_key(&account2.sender_address);
mempool.state.contains_key(&account3.sender_address);

check_mempool_txs_eq(
assert!(check_mempool_txs_eq(
&mempool,
&[tx_tip_50_address_0, tx_tip_80_address_2, tx_tip_100_address_1],
)
&[tx_tip_50_address_0, tx_tip_80_address_2, tx_tip_100_address_1]
));
}

#[rstest]
Expand All @@ -137,15 +138,19 @@ fn test_add_same_tx(mut mempool: Mempool) {
Err(MempoolError::DuplicateTransaction { tx_hash: TransactionHash(StarkFelt::ONE) })
);
// Assert that the original tx remains in the pool after the failed attempt.
check_mempool_txs_eq(&mempool, &[tx])
assert!(check_mempool_txs_eq(&mempool, &[tx]));
}

// Asserts that the transactions in the mempool are in ascending order as per the expected
// transactions.
fn check_mempool_txs_eq(mempool: &Mempool, expected_txs: &[ThinTransaction]) {
fn check_mempool_txs_eq(mempool: &Mempool, expected_txs: &[ThinTransaction]) -> bool {
let mempool_txs = mempool.txs_queue.iter();
// Deref the inner mempool tx type.
expected_txs.iter().zip(mempool_txs).all(|(a, b)| *a == **b);

// Convert and compare transactions
expected_txs.iter().zip(mempool_txs).all(|(expected, actual)| {
let expected_converted: ThinPriorityTransaction = expected.clone().into();
expected_converted == *actual
})
}

#[rstest]
Expand All @@ -162,7 +167,7 @@ fn test_add_tx_with_identical_tip_succeeds(mut mempool: Mempool) {

// TODO: currently hash comparison tie-breaks the two. Once more robust tie-breaks are added
// replace this assertion with a dedicated test.
check_mempool_txs_eq(&mempool, &[tx2, tx1]);
assert!(check_mempool_txs_eq(&mempool, &[tx2, tx1]));
}

#[rstest]
Expand All @@ -176,5 +181,5 @@ fn test_tip_priority_over_tx_hash(mut mempool: Mempool) {

assert!(mempool.add_tx(tx_big_tip_small_hash.clone(), account1).is_ok());
assert!(mempool.add_tx(tx_small_tip_big_hash.clone(), account2).is_ok());
check_mempool_txs_eq(&mempool, &[tx_big_tip_small_hash, tx_small_tip_big_hash])
assert!(check_mempool_txs_eq(&mempool, &[tx_small_tip_big_hash, tx_big_tip_small_hash]));
}
71 changes: 50 additions & 21 deletions crates/mempool/src/priority_queue.rs
Original file line number Diff line number Diff line change
@@ -1,64 +1,93 @@
use std::cmp::Ordering;
use std::collections::BTreeSet;
use std::collections::btree_set::Iter;
use std::collections::{BTreeSet, HashMap};

use starknet_api::core::{ContractAddress, Nonce};
use starknet_api::transaction::{Tip, TransactionHash};
use starknet_mempool_types::mempool_types::ThinTransaction;
// Assumption: for the MVP only one transaction from the same contract class can be in the mempool
// at a time. When this changes, saving the transactions themselves on the queu might no longer be
// appropriate, because we'll also need to stores transactions without indexing them. For example,
// transactions with future nonces will need to be stored, and potentially indexed on block commits.
#[derive(Clone, Debug, Default, derive_more::Deref, derive_more::DerefMut)]
pub struct TransactionPriorityQueue(BTreeSet<PrioritizedTransaction>);
#[derive(Clone, Debug, Default)]
pub struct TransactionPriorityQueue {
// Priority queue of transactions with associated priority.
queue: BTreeSet<ThinPriorityTransaction>,
// FIX: Set of account addresses for efficient existence checks.
address_to_nonce: HashMap<ContractAddress, Nonce>,
}

impl TransactionPriorityQueue {
/// Adds a transaction to the mempool, ensuring unique keys.
/// Panics: if given a duplicate tx.
pub fn push(&mut self, tx: ThinTransaction) {
let mempool_tx = PrioritizedTransaction(tx);
assert!(self.insert(mempool_tx), "Keys should be unique; duplicates are checked prior.");
pub fn push(&mut self, tx: ThinPriorityTransaction) {
self.address_to_nonce.insert(tx.address, tx.nonce);
assert!(self.queue.insert(tx), "Keys should be unique; duplicates are checked prior.");
}

// TODO(gilad): remove collect
pub fn pop_last_chunk(&mut self, n_txs: usize) -> Vec<ThinTransaction> {
(0..n_txs).filter_map(|_| self.pop_last().map(|tx| tx.0)).collect()
pub fn pop_last_chunk(&mut self, n_txs: usize) -> Vec<ThinPriorityTransaction> {
let txs: Vec<ThinPriorityTransaction> =
(0..n_txs).filter_map(|_| self.queue.pop_last()).collect();
for tx in txs.iter() {
self.address_to_nonce.remove(&tx.address);
}
txs
}

pub fn iter(&self) -> Iter<'_, ThinPriorityTransaction> {
self.queue.iter()
}
}

impl From<Vec<ThinTransaction>> for TransactionPriorityQueue {
fn from(transactions: Vec<ThinTransaction>) -> Self {
TransactionPriorityQueue(BTreeSet::from_iter(
transactions.into_iter().map(PrioritizedTransaction),
))
pub fn get_nonce(&self, address: &ContractAddress) -> Option<&Nonce> {
self.address_to_nonce.get(address)
}
}

#[derive(Clone, Debug, derive_more::Deref, derive_more::From)]
pub struct PrioritizedTransaction(pub ThinTransaction);
#[derive(Clone, Debug, Default)]
pub struct ThinPriorityTransaction {
pub address: ContractAddress,
pub nonce: Nonce,
pub tx_hash: TransactionHash,
pub tip: Tip,
}

/// Compare transactions based only on their tip, a uint, using the Eq trait. It ensures that two
/// tips are either exactly equal or not.
impl PartialEq for PrioritizedTransaction {
fn eq(&self, other: &PrioritizedTransaction) -> bool {
impl PartialEq for ThinPriorityTransaction {
fn eq(&self, other: &ThinPriorityTransaction) -> bool {
self.tip == other.tip && self.tx_hash == other.tx_hash
}
}

/// Marks this struct as capable of strict equality comparisons, signaling to the compiler it
/// adheres to equality semantics.
// Note: this depends on the implementation of `PartialEq`, see its docstring.
impl Eq for PrioritizedTransaction {}
impl Eq for ThinPriorityTransaction {}

impl Ord for PrioritizedTransaction {
impl Ord for ThinPriorityTransaction {
fn cmp(&self, other: &Self) -> Ordering {
self.tip.cmp(&other.tip).then_with(|| self.tx_hash.cmp(&other.tx_hash))
}
}

impl PartialOrd for PrioritizedTransaction {
impl PartialOrd for ThinPriorityTransaction {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}

impl From<ThinTransaction> for ThinPriorityTransaction {
fn from(tx: ThinTransaction) -> Self {
ThinPriorityTransaction {
address: tx.sender_address,
nonce: tx.nonce,
tx_hash: tx.tx_hash,
tip: tx.tip,
}
}
}

// TODO: remove when is used.
#[allow(dead_code)]
// Assumption: there are no gaps, and the transactions are received in order.
Expand Down
60 changes: 51 additions & 9 deletions crates/mempool/src/transaction_store.rs
Original file line number Diff line number Diff line change
@@ -1,35 +1,77 @@
use std::collections::hash_map::Entry;
use std::collections::HashMap;
use std::collections::{BTreeMap, HashMap};

use starknet_api::core::{ContractAddress, Nonce};
use starknet_api::transaction::TransactionHash;
use starknet_mempool_types::errors::MempoolError;
use starknet_mempool_types::mempool_types::ThinTransaction;

use crate::priority_queue::ThinPriorityTransaction;

// All transactions currently held in the mempool.
#[derive(Clone, Debug, Default)]
pub struct TransactionStore {
// All transactions currently held in the mempool.
store: HashMap<TransactionHash, ThinTransaction>,
// Transactions organized by account address, sorted by ascending nonce values.
txs_by_account: HashMap<ContractAddress, BTreeMap<Nonce, ThinPriorityTransaction>>,
}

impl TransactionStore {
// Insert transaction into the store, ensuring no duplicates
pub fn push(&mut self, tx: ThinTransaction) -> Result<(), MempoolError> {
match self.store.entry(tx.tx_hash) {
Entry::Occupied(_) => {
// TODO: Allow overriding a previous transaction if needed.
Err(MempoolError::DuplicateTransaction { tx_hash: tx.tx_hash })
if let Entry::Occupied(_) = self.store.entry(tx.tx_hash) {
return Err(MempoolError::DuplicateTransaction { tx_hash: tx.tx_hash });
} else {
self.store.insert(tx.tx_hash, tx.clone());
}

match self.txs_by_account.entry(tx.sender_address) {
Entry::Occupied(mut entry) => {
let txs_by_account = entry.get_mut();
if txs_by_account.contains_key(&tx.nonce) {
// Remove the transaction from the store if duplicate nonce found
self.store.remove(&tx.tx_hash);
return Err(MempoolError::DuplicateTransaction { tx_hash: tx.tx_hash });
}
txs_by_account.insert(tx.nonce, tx.into());
}
Entry::Vacant(entry) => {
entry.insert(tx);
Ok(())
let mut txs_by_account = BTreeMap::new();
txs_by_account.insert(tx.nonce, tx.into());
entry.insert(txs_by_account);
}
}

Ok(())
}

pub fn remove(&mut self, tx_hash: &TransactionHash) -> Result<ThinTransaction, MempoolError> {
self.store.remove(tx_hash).ok_or(MempoolError::TransactionNotFound { tx_hash: *tx_hash })
// Remove the transaction from the store
let tx = self.store.remove(tx_hash);

if tx.is_none() {
return Err(MempoolError::TransactionNotFound { tx_hash: *tx_hash });
}
let tx = tx.unwrap();

if let Entry::Occupied(mut entry) = self.txs_by_account.entry(tx.sender_address) {
let txs_by_account = entry.get_mut();
txs_by_account.remove(&tx.nonce);

if txs_by_account.is_empty() {
entry.remove();
}
Ok(tx)
} else {
Err(MempoolError::TransactionNotFound { tx_hash: tx.tx_hash })
}
}

pub fn get(&self, tx_hash: &TransactionHash) -> Result<&ThinTransaction, MempoolError> {
self.store.get(tx_hash).ok_or(MempoolError::TransactionNotFound { tx_hash: *tx_hash })
match self.store.get(tx_hash) {
Some(tx) => Ok(tx),
None => Err(MempoolError::TransactionNotFound { tx_hash: *tx_hash }),
}
}
}

0 comments on commit eb6791e

Please sign in to comment.