Skip to content

Commit

Permalink
feat: implement AddressPriorityQueue in Mempool
Browse files Browse the repository at this point in the history
  • Loading branch information
ayeletstarkware committed Jun 16, 2024
1 parent 7aad112 commit 678dac5
Show file tree
Hide file tree
Showing 4 changed files with 168 additions and 104 deletions.
11 changes: 7 additions & 4 deletions crates/gateway/src/gateway_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ use axum::extract::State;
use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
use blockifier::context::ChainInfo;
use rstest::fixture;
use starknet_api::external_transaction::ExternalTransaction;
use starknet_api::transaction::TransactionHash;
use starknet_mempool::mempool::{create_mempool_server, Mempool};
Expand All @@ -25,6 +26,11 @@ use crate::utils::{external_tx_to_account_tx, get_tx_hash};

const MEMPOOL_INVOCATIONS_QUEUE_SIZE: usize = 32;

#[fixture]
fn mempool() -> Mempool {
Mempool::empty()
}

pub fn app_state(mempool_client: Arc<dyn MempoolClient>) -> AppState {
AppState {
stateless_tx_validator: StatelessTransactionValidator {
Expand All @@ -46,14 +52,11 @@ pub fn app_state(mempool_client: Arc<dyn MempoolClient>) -> AppState {
// TODO(Ayelet): add test cases for declare and deploy account transactions.
#[tokio::test]
async fn test_add_tx() {
// TODO: Add fixture.

let mempool = Mempool::new([]);
// TODO(Tsabary): wrap creation of channels in dedicated functions, take channel capacity from
// config.
let (tx_mempool, rx_mempool) =
channel::<MempoolRequestAndResponseSender>(MEMPOOL_INVOCATIONS_QUEUE_SIZE);
let mut mempool_server = create_mempool_server(mempool, rx_mempool);
let mut mempool_server = create_mempool_server(mempool(), rx_mempool);
task::spawn(async move {
mempool_server.start().await;
});
Expand Down
91 changes: 45 additions & 46 deletions crates/mempool/src/mempool.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
use std::collections::hash_map::Entry::{Occupied, Vacant};
use std::collections::HashMap;

use async_trait::async_trait;
Expand All @@ -8,66 +7,57 @@ use starknet_mempool_infra::component_definitions::ComponentRequestHandler;
use starknet_mempool_infra::component_server::ComponentServer;
use starknet_mempool_types::errors::MempoolError;
use starknet_mempool_types::mempool_types::{
Account, AccountState, MempoolInput, MempoolRequest, MempoolRequestAndResponseSender,
MempoolResponse, MempoolResult, ThinTransaction,
AccountState, MempoolInput, MempoolRequest, MempoolRequestAndResponseSender, MempoolResponse,
MempoolResult, ThinTransaction,
};
use tokio::sync::mpsc::Receiver;

use crate::priority_queue::TransactionPriorityQueue;
use crate::priority_queue::{AddressStore, TransactionPriorityQueue};

#[cfg(test)]
#[path = "mempool_test.rs"]
pub mod mempool_test;

#[derive(Debug)]
#[derive(Default)]
pub struct Mempool {
// TODO: add docstring explaining visibility and coupling of the fields.
txs_queue: TransactionPriorityQueue,
state: HashMap<ContractAddress, AccountState>,
address_to_store: HashMap<ContractAddress, AddressStore>,
}

impl Mempool {
pub fn new(inputs: impl IntoIterator<Item = MempoolInput>) -> Self {
let mut mempool =
Mempool { txs_queue: TransactionPriorityQueue::default(), state: HashMap::default() };

mempool.txs_queue = TransactionPriorityQueue::from(
inputs
.into_iter()
.map(|input| {
// Attempts to insert a key-value pair into the mempool's state. Returns `None`
// if the key was not present, otherwise returns the old value while updating
// the new value.
let prev_value =
mempool.state.insert(input.account.sender_address, input.account.state);
assert!(
prev_value.is_none(),
"Sender address: {:?} already exists in the mempool. Can't add {:?} to \
the mempool.",
input.account.sender_address,
input.tx
);
input.tx
})
.collect::<Vec<ThinTransaction>>(),
);

mempool
pub fn new(inputs: impl IntoIterator<Item = MempoolInput>) -> MempoolResult<Self> {
let mut mempool = Mempool::default();

for input in inputs {
mempool.insert_tx(input.tx)?;
}

Ok(mempool)
}

pub fn empty() -> Self {
Mempool::new([])
Mempool::default()
}

/// Retrieves up to `n_txs` transactions with the highest priority from the mempool.
/// Transactions are guaranteed to be unique across calls until `commit_block` is invoked.
// TODO: the last part about commit_block is incorrect if we delete txs in get_txs and then push
// back. TODO: Consider renaming to `pop_txs` to be more consistent with the standard
// library.
// back. TODO: Consider renaming to `pop_txs` to be more consistent with the standard library.
// TODO: If `n_txs` is greater than the number of transactions in `txs_queue`, it will also
// check and add transactions from `address_to_store`.
pub fn get_txs(&mut self, n_txs: usize) -> MempoolResult<Vec<ThinTransaction>> {
let txs = self.txs_queue.pop_last_chunk(n_txs);
for tx in &txs {
self.state.remove(&tx.sender_address);
if let Some(address_queue) = self.address_to_store.get_mut(&tx.sender_address) {
address_queue.pop_front();

if address_queue.is_empty() {
self.address_to_store.remove(&tx.sender_address);
} else if let Some(next_tx) = address_queue.top() {
self.txs_queue.push(next_tx.clone());
}
}
}

Ok(txs)
Expand All @@ -76,15 +66,9 @@ impl Mempool {
/// Adds a new transaction to the mempool.
/// TODO: support fee escalation and transactions with future nonces.
/// TODO: change input type to `MempoolInput`.
pub fn add_tx(&mut self, tx: ThinTransaction, account: Account) -> MempoolResult<()> {
match self.state.entry(account.sender_address) {
Occupied(_) => Err(MempoolError::DuplicateTransaction { tx_hash: tx.tx_hash }),
Vacant(entry) => {
entry.insert(account.state);
self.txs_queue.push(tx);
Ok(())
}
}
pub fn add_tx(&mut self, tx: ThinTransaction) -> MempoolResult<()> {
self.insert_tx(tx)?;
Ok(())
}

/// Update the mempool's internal state according to the committed block's transactions.
Expand All @@ -99,6 +83,21 @@ impl Mempool {
) -> MempoolResult<()> {
todo!()
}

fn insert_tx(&mut self, tx: ThinTransaction) -> MempoolResult<()> {
let address_queue = self.address_to_store.entry(tx.sender_address).or_default();

if address_queue.contains(&tx) {
return Err(MempoolError::DuplicateTransaction { tx_hash: tx.tx_hash });
}

address_queue.push(tx.clone());
if address_queue.len() == 1 {
self.txs_queue.push(tx);
}

Ok(())
}
}

/// Wraps the mempool to enable inbound async communication from other components.
Expand All @@ -112,7 +111,7 @@ impl MempoolCommunicationWrapper {
}

fn add_tx(&mut self, mempool_input: MempoolInput) -> MempoolResult<()> {
self.mempool.add_tx(mempool_input.tx, mempool_input.account)
self.mempool.add_tx(mempool_input.tx)
}

fn get_txs(&mut self, n_txs: usize) -> MempoolResult<Vec<ThinTransaction>> {
Expand Down
Loading

0 comments on commit 678dac5

Please sign in to comment.