diff --git a/Cargo.lock b/Cargo.lock index b916467c27..7e853bda17 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8273,6 +8273,7 @@ name = "katana-pool" version = "1.0.0-rc.1" dependencies = [ "futures", + "futures-util", "katana-executor", "katana-primitives", "katana-provider", diff --git a/crates/katana/core/src/service/mod.rs b/crates/katana/core/src/service/mod.rs index db7a3fe565..8bdae8bb7f 100644 --- a/crates/katana/core/src/service/mod.rs +++ b/crates/katana/core/src/service/mod.rs @@ -1,19 +1,14 @@ -// TODO: remove the messaging feature flag -// TODO: move the tasks to a separate module - use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; use block_producer::BlockProductionError; -use futures::channel::mpsc::Receiver; -use futures::stream::{Fuse, Stream, StreamExt}; +use futures::stream::StreamExt; use katana_executor::ExecutorFactory; use katana_pool::ordering::PoolOrd; use katana_pool::pending::PendingTransactions; use katana_pool::{TransactionPool, TxPool}; use katana_primitives::transaction::ExecutableTxWithHash; -use katana_primitives::Felt; use tracing::{error, info}; use self::block_producer::BlockProducer; @@ -114,11 +109,6 @@ pub struct TransactionMiner where O: PoolOrd, { - /// stores whether there are pending transacions (if known) - has_pending_txs: Option, - /// Receives hashes of transactions that are ready from the pool - rx: Fuse>, - pending_txs: PendingTransactions, } @@ -126,41 +116,21 @@ impl TransactionMiner where O: PoolOrd, { - pub fn new( - pending_txs: PendingTransactions, - rx: Receiver, - ) -> Self { - Self { pending_txs, rx: rx.fuse(), has_pending_txs: None } + pub fn new(pending_txs: PendingTransactions) -> Self { + Self { pending_txs } } - fn poll( - &mut self, - // pool: &TxPool, - cx: &mut Context<'_>, - ) -> Poll> { - // drain the notification stream - while let Poll::Ready(Some(_)) = Pin::new(&mut self.rx).poll_next(cx) { - self.has_pending_txs = Some(true); - } - - if self.has_pending_txs == Some(false) { - return Poll::Pending; - } - + fn poll(&mut self, cx: &mut Context<'_>) -> Poll> { let mut transactions = Vec::new(); + while let Poll::Ready(Some(tx)) = self.pending_txs.poll_next_unpin(cx) { transactions.push(tx.tx.as_ref().clone()); } - // take all the transactions from the pool - // let transactions = - // pool.take_transactions().map(|tx| tx.tx.as_ref().clone()).collect::>(); - if transactions.is_empty() { return Poll::Pending; } - self.has_pending_txs = Some(false); Poll::Ready(transactions) } } diff --git a/crates/katana/pipeline/src/stage/sequencing.rs b/crates/katana/pipeline/src/stage/sequencing.rs index da8ab6518c..c988ecab57 100644 --- a/crates/katana/pipeline/src/stage/sequencing.rs +++ b/crates/katana/pipeline/src/stage/sequencing.rs @@ -53,11 +53,10 @@ impl Sequencing { } fn run_block_production(&self) -> TaskHandle> { - let pool = self.pool.clone(); - let miner = TransactionMiner::new(pool.pending_transactions(), pool.add_listener()); + // Create a new transaction miner with a subscription to the pool's pending transactions. + let miner = TransactionMiner::new(self.pool.pending_transactions()); let block_producer = self.block_producer.clone(); - - let service = BlockProductionTask::new(pool, miner, block_producer); + let service = BlockProductionTask::new(self.pool.clone(), miner, block_producer); self.task_spawner.build_task().name("Block production").spawn(service) } } diff --git a/crates/katana/pool/Cargo.toml b/crates/katana/pool/Cargo.toml index b6785de8ad..207b18cbc6 100644 --- a/crates/katana/pool/Cargo.toml +++ b/crates/katana/pool/Cargo.toml @@ -17,4 +17,5 @@ tokio = { workspace = true, features = [ "sync" ] } tracing.workspace = true [dev-dependencies] +futures-util.workspace = true rand.workspace = true diff --git a/crates/katana/pool/src/lib.rs b/crates/katana/pool/src/lib.rs index 811d97d57b..f61d458104 100644 --- a/crates/katana/pool/src/lib.rs +++ b/crates/katana/pool/src/lib.rs @@ -47,6 +47,8 @@ pub trait TransactionPool { /// Add a new transaction to the pool. fn add_transaction(&self, tx: Self::Transaction) -> PoolResult; + /// Returns a [`Stream`](futures::Stream) which yields pending transactions - transactions that + /// can be executed - from the pool. fn pending_transactions(&self) -> PendingTransactions; /// Check if the pool contains a transaction with the given hash. @@ -57,6 +59,7 @@ pub trait TransactionPool { fn add_listener(&self) -> Receiver; + /// Removes a list of transactions from the pool according to their hashes. fn remove_transactions(&self, hashes: &[TxHash]); /// Get the total number of transactions in the pool. diff --git a/crates/katana/pool/src/ordering.rs b/crates/katana/pool/src/ordering.rs index 97dcd1ba74..5ed88c496c 100644 --- a/crates/katana/pool/src/ordering.rs +++ b/crates/katana/pool/src/ordering.rs @@ -125,6 +125,8 @@ impl Default for TipOrdering { #[cfg(test)] mod tests { + use futures::executor; + use crate::ordering::{self, FiFo}; use crate::pool::test_utils::*; use crate::tx::PoolTransaction; @@ -145,10 +147,10 @@ mod tests { }); // Get pending transactions - let pendings = pool.pending_transactions().collect::>(); + let pendings = executor::block_on_stream(pool.pending_transactions()); // Assert that the transactions are in the order they were added (first to last) - pendings.iter().zip(txs).for_each(|(pending, tx)| { + pendings.into_iter().zip(txs).for_each(|(pending, tx)| { assert_eq!(pending.tx.as_ref(), &tx); }); } @@ -177,7 +179,7 @@ mod tests { }); // Get pending transactions - let pending = pool.pending_transactions().collect::>(); + let pending = executor::block_on_stream(pool.pending_transactions()).collect::>(); assert_eq!(pending.len(), txs.len()); // Assert that the transactions are ordered by tip (highest to lowest) diff --git a/crates/katana/pool/src/pending.rs b/crates/katana/pool/src/pending.rs index 9913c3472d..82b6c76c68 100644 --- a/crates/katana/pool/src/pending.rs +++ b/crates/katana/pool/src/pending.rs @@ -5,15 +5,18 @@ use std::task::{Context, Poll}; use futures::{Stream, StreamExt}; use crate::ordering::PoolOrd; -use crate::subscription::PoolSubscription; +use crate::subscription::Subscription; use crate::tx::{PendingTx, PoolTransaction}; -/// an iterator that yields transactions from the pool that can be included in a block, sorted by +/// An iterator that yields transactions from the pool that can be included in a block, sorted by /// by its priority. #[derive(Debug)] pub struct PendingTransactions { + /// Iterator over all the pending transactions at the time of the creation of this struct. pub(crate) all: IntoIter>, - pub(crate) subscription: PoolSubscription, + /// Subscription to the pool to get notified when new transactions are added. This is used to + /// wait on the new transactions after exhausting the `all` iterator. + pub(crate) subscription: Subscription, } impl Stream for PendingTransactions @@ -25,7 +28,6 @@ where fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.get_mut(); - if let Some(tx) = this.all.next() { Poll::Ready(Some(tx)) } else { @@ -34,14 +36,104 @@ where } } -impl Iterator for PendingTransactions -where - T: PoolTransaction, - O: PoolOrd, -{ - type Item = PendingTx; +#[cfg(test)] +mod tests { + + use std::sync::atomic::{AtomicBool, Ordering}; + use std::sync::Arc; + + use futures::StreamExt; + use tokio::task::yield_now; + + use crate::pool::test_utils::PoolTx; + use crate::pool::Pool; + use crate::validation::NoopValidator; + use crate::{ordering, PoolTransaction, TransactionPool}; + + #[tokio::test] + async fn pending_transactions() { + let pool = Pool::new(NoopValidator::::new(), ordering::FiFo::new()); + + let first_batch = [ + PoolTx::new(), + PoolTx::new(), + PoolTx::new(), + PoolTx::new(), + PoolTx::new(), + PoolTx::new(), + ]; + + for tx in &first_batch { + pool.add_transaction(tx.clone()).expect("failed to add tx"); + } + + let mut pendings = pool.pending_transactions(); + + // exhaust all the first batch transactions + for expected in &first_batch { + let actual = pendings.next().await.map(|t| t.tx).unwrap(); + assert_eq!(expected, actual.as_ref()); + } - fn next(&mut self) -> Option { - self.all.next() + let second_batch = [ + PoolTx::new(), + PoolTx::new(), + PoolTx::new(), + PoolTx::new(), + PoolTx::new(), + PoolTx::new(), + ]; + + for tx in &second_batch { + pool.add_transaction(tx.clone()).expect("failed to add tx"); + } + + // exhaust all the first batch transactions + for expected in &second_batch { + let actual = pendings.next().await.map(|t| t.tx).unwrap(); + assert_eq!(expected, actual.as_ref()); + } + + // Check that all the added transaction is still in the pool because we haven't removed it + // yet. + let all = [first_batch, second_batch].concat(); + for tx in all { + assert!(pool.contains(tx.hash())); + } + } + + #[tokio::test(flavor = "multi_thread")] + async fn subscription_stream_wakeup() { + let pool = Pool::new(NoopValidator::::new(), ordering::FiFo::new()); + let mut pending = pool.pending_transactions(); + + // Spawn a task that will add a transaction after a delay + let pool_clone = pool.clone(); + + let txs = [PoolTx::new(), PoolTx::new(), PoolTx::new()]; + let txs_clone = txs.clone(); + + let has_polled_once = Arc::new(AtomicBool::new(false)); + let has_polled_once_clone = has_polled_once.clone(); + + tokio::spawn(async move { + while !has_polled_once_clone.load(Ordering::SeqCst) { + yield_now().await; + } + + for tx in txs_clone { + pool_clone.add_transaction(tx).expect("failed to add tx"); + } + }); + + // Check that first poll_next returns Pending because no pending transaction has been added + // to the pool yet + assert!(futures_util::poll!(pending.next()).is_pending()); + has_polled_once.store(true, Ordering::SeqCst); + + for tx in txs { + let received = pending.next().await.unwrap(); + assert_eq!(&tx, received.tx.as_ref()); + } } } diff --git a/crates/katana/pool/src/pool.rs b/crates/katana/pool/src/pool.rs index b2b9b2a51a..92e095677e 100644 --- a/crates/katana/pool/src/pool.rs +++ b/crates/katana/pool/src/pool.rs @@ -5,12 +5,12 @@ use std::sync::Arc; use futures::channel::mpsc::{channel, Receiver, Sender}; use katana_primitives::transaction::TxHash; use parking_lot::RwLock; -use tokio::sync::Notify; +use tokio::sync::mpsc; use tracing::{error, info, warn}; use crate::ordering::PoolOrd; use crate::pending::PendingTransactions; -use crate::subscription::PoolSubscription; +use crate::subscription::Subscription; use crate::tx::{PendingTx, PoolTransaction, TxId}; use crate::validation::error::InvalidTransactionError; use crate::validation::{ValidationOutcome, Validator}; @@ -35,7 +35,7 @@ struct Inner { listeners: RwLock>>, /// subscribers for incoming txs - subscribers: RwLock>>, + subscribers: RwLock>>>, /// the tx validator validator: V, @@ -90,24 +90,36 @@ where } } + fn notify_subscribers(&self, tx: PendingTx) { + let mut subscribers = self.inner.subscribers.write(); + // this is basically a retain but with mut reference + for n in (0..subscribers.len()).rev() { + let sender = subscribers.swap_remove(n); + let retain = match sender.send(tx.clone()) { + Ok(()) => true, + Err(error) => { + warn!(%error, "Subscription channel closed"); + false + } + }; + + if retain { + subscribers.push(sender) + } + } + } + // notify both listener and subscribers fn notify(&self, tx: PendingTx) { self.notify_listener(tx.tx.hash()); self.notify_subscribers(tx); } - fn notify_subscribers(&self, tx: PendingTx) { - let subscribers = self.inner.subscribers.read(); - for subscriber in subscribers.iter() { - subscriber.broadcast(tx.clone()); - } - } - - fn subscribe(&self) -> PoolSubscription { - let notify = Arc::new(Notify::new()); - let subscription = PoolSubscription { notify, txs: Default::default() }; - self.inner.subscribers.write().push(subscription.clone()); - subscription + fn subscribe(&self) -> Subscription { + let (tx, rx) = mpsc::unbounded_channel(); + let subscriber = Subscription::new(rx); + self.inner.subscribers.write().push(tx); + subscriber } } @@ -307,7 +319,9 @@ pub(crate) mod test_utils { #[cfg(test)] mod tests { + use futures::executor; use katana_primitives::contract::{ContractAddress, Nonce}; + use katana_primitives::transaction::TxHash; use katana_primitives::Felt; use super::test_utils::*; @@ -356,7 +370,7 @@ mod tests { assert!(txs.iter().all(|tx| pool.get(tx.hash()).is_some())); // noop validator should consider all txs as valid - let pendings = pool.pending_transactions().collect::>(); + let pendings = executor::block_on_stream(pool.pending_transactions()).collect::>(); assert_eq!(pendings.len(), txs.len()); // bcs we're using fcfs, the order should be the same as the order of the txs submission @@ -411,6 +425,41 @@ mod tests { assert_eq!(counter, txs.len()); } + #[test] + fn remove_transactions() { + let pool = TestPool::test(); + + let txs = [ + PoolTx::new(), + PoolTx::new(), + PoolTx::new(), + PoolTx::new(), + PoolTx::new(), + PoolTx::new(), + PoolTx::new(), + PoolTx::new(), + ]; + + // start adding txs to the pool + txs.iter().for_each(|tx| { + let _ = pool.add_transaction(tx.clone()); + }); + + // first check that the transaction are indeed in the pool + txs.iter().for_each(|tx| { + assert!(pool.contains(tx.hash())); + }); + + // remove the transactions + let hashes = txs.iter().map(|t| t.hash()).collect::>(); + pool.remove_transactions(&hashes); + + // check that the transaction are no longer in the pool + txs.iter().for_each(|tx| { + assert!(!pool.contains(tx.hash())); + }); + } + #[test] #[ignore = "Txs dependency management not fully implemented yet"] fn dependent_txs_linear_insertion() { @@ -429,7 +478,7 @@ mod tests { }); // Get pending transactions - let pending = pool.pending_transactions().collect::>(); + let pending = executor::block_on_stream(pool.pending_transactions()).collect::>(); // Check that the number of pending transactions matches the number of added transactions assert_eq!(pending.len(), total as usize); diff --git a/crates/katana/pool/src/subscription.rs b/crates/katana/pool/src/subscription.rs index 31b45e4937..465cffba97 100644 --- a/crates/katana/pool/src/subscription.rs +++ b/crates/katana/pool/src/subscription.rs @@ -1,40 +1,31 @@ use std::collections::BTreeSet; -use std::future::Future; -use std::pin::{pin, Pin}; -use std::sync::Arc; +use std::pin::Pin; use std::task::{Context, Poll}; use futures::Stream; -use parking_lot::RwLock; -use tokio::sync::Notify; +use parking_lot::Mutex; +use tokio::sync::mpsc; use crate::ordering::PoolOrd; use crate::tx::{PendingTx, PoolTransaction}; #[derive(Debug)] -pub struct PoolSubscription { - pub(crate) txs: Arc>>>, - pub(crate) notify: Arc, +pub struct Subscription { + txs: Mutex>>, + receiver: mpsc::UnboundedReceiver>, } -impl Clone for PoolSubscription { - fn clone(&self) -> Self { - Self { txs: self.txs.clone(), notify: self.notify.clone() } - } -} - -impl PoolSubscription +impl Subscription where T: PoolTransaction, O: PoolOrd, { - pub(crate) fn broadcast(&self, tx: PendingTx) { - self.notify.notify_waiters(); - self.txs.write().insert(tx); + pub(crate) fn new(receiver: mpsc::UnboundedReceiver>) -> Self { + Self { txs: Default::default(), receiver } } } -impl Stream for PoolSubscription +impl Stream for Subscription where T: PoolTransaction, O: PoolOrd, @@ -43,17 +34,34 @@ where fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let this = self.get_mut(); + let mut txs = this.txs.lock(); + // In the event where a lot of transactions have been sent to the receiver channel and this + // stream hasn't been iterated since, the next call to `.next()` of this Stream will + // require to drain the channel and insert all the transactions into the btree set. If there + // are a lot of transactions to insert, it would take a while and might block the + // runtime. loop { - if let Some(tx) = this.txs.write().pop_first() { + if let Some(tx) = txs.pop_first() { return Poll::Ready(Some(tx)); } - if pin!(this.notify.notified()).poll(cx).is_pending() { - break; + // Check the channel if there are new transactions available. + match this.receiver.poll_recv(cx) { + // insert the new transactions into the btree set to make sure they are ordered + // according to the pool's ordering. + Poll::Ready(Some(tx)) => { + txs.insert(tx); + + // Check if there are more transactions available in the channel. + while let Poll::Ready(Some(tx)) = this.receiver.poll_recv(cx) { + txs.insert(tx); + } + } + + Poll::Ready(None) => return Poll::Ready(None), + Poll::Pending => return Poll::Pending, } } - - Poll::Pending } }