From 76dae5ebf84990c0a379055e9508a76b78841b55 Mon Sep 17 00:00:00 2001 From: Ammar Arif Date: Wed, 28 Aug 2024 09:17:40 -0700 Subject: [PATCH] hotfix(katana): make sure validator state is synced with block producer (#2353) * wip * wip * wip * fix blockifier patch * fix * clippy * fix * fmt * fmt * fix --- Cargo.lock | 4 +- .../katana/core/src/service/block_producer.rs | 191 +++++++++++------- crates/katana/core/src/service/mod.rs | 3 - crates/katana/executor/Cargo.toml | 2 +- crates/katana/node/src/lib.rs | 19 +- crates/katana/pool/src/pool.rs | 5 +- crates/katana/pool/src/validation/stateful.rs | 56 +++-- crates/katana/rpc/rpc/src/starknet/mod.rs | 8 +- crates/katana/rpc/rpc/src/starknet/write.rs | 3 +- crates/katana/rpc/rpc/tests/starknet.rs | 42 +++- 10 files changed, 212 insertions(+), 121 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 977bf39b0b..fcc6187899 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1859,7 +1859,7 @@ dependencies = [ [[package]] name = "blockifier" version = "0.8.0-dev.2" -source = "git+https://github.com/dojoengine/blockifier?branch=cairo-2.7-new#42b2b5e28fd47bdfa0d807109360c41a92edafe4" +source = "git+https://github.com/dojoengine/blockifier?branch=cairo-2.7-newer#19b99d35e0fb459305fa588d669fd6c3b39c7fea" dependencies = [ "anyhow", "ark-ec", @@ -6707,7 +6707,7 @@ dependencies = [ "httpdate", "itoa", "pin-project-lite", - "socket2 0.5.7", + "socket2 0.4.10", "tokio", "tower-service", "tracing", diff --git a/crates/katana/core/src/service/block_producer.rs b/crates/katana/core/src/service/block_producer.rs index b8a47b3524..940ded8ac5 100644 --- a/crates/katana/core/src/service/block_producer.rs +++ b/crates/katana/core/src/service/block_producer.rs @@ -20,7 +20,8 @@ use katana_provider::traits::block::{BlockHashProvider, BlockNumberProvider}; use katana_provider::traits::env::BlockEnvProvider; use katana_provider::traits::state::StateFactoryProvider; use katana_tasks::{BlockingTaskPool, BlockingTaskResult}; -use parking_lot::RwLock; +use parking_lot::lock_api::RawMutex; +use parking_lot::{Mutex, RwLock}; use tokio::time::{interval_at, Instant, Interval}; use tracing::{error, info, trace, warn}; @@ -60,7 +61,7 @@ pub struct TxWithOutcome { type ServiceFuture = Pin> + Send + Sync>>; type BlockProductionResult = Result; -type BlockProductionFuture = ServiceFuture; +type BlockProductionFuture = ServiceFuture>; type TxExecutionResult = Result, BlockProductionError>; type TxExecutionFuture = ServiceFuture; @@ -74,31 +75,27 @@ type BlockProductionWithTxnsFuture = pub struct BlockProducer { /// The inner mode of mining. pub producer: RwLock>, - /// validator used in the tx pool - // the validator needs to always be built against the state of the block producer, so - // im putting here for now until we find a better way to handle this. - validator: TxValidator, } impl BlockProducer { /// Creates a block producer that mines a new block every `interval` milliseconds. pub fn interval(backend: Arc>, interval: u64) -> Self { - let (prod, validator) = IntervalBlockProducer::new(backend, Some(interval)); - Self { producer: BlockProducerMode::Interval(prod).into(), validator } + let prod = IntervalBlockProducer::new(backend, Some(interval)); + Self { producer: BlockProducerMode::Interval(prod).into() } } /// Creates a new block producer that will only be possible to mine by calling the /// `katana_generateBlock` RPC method. 
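All three constructors in this file now follow the same shape: `IntervalBlockProducer::new` and `InstantBlockProducer::new` return only the producer, the producer owns its `TxValidator`, and the pool obtains a handle through the `validator()` accessor added further down. A minimal sketch of that ownership model, with simplified stand-in types (the real `TxValidator` shares its state behind an `Arc`, so the clone handed out is cheap):

```rust
use std::sync::{Arc, Mutex};

#[derive(Clone)]
struct TxValidator(Arc<Mutex<u64>>); // stand-in for the real validator's shared state

struct IntervalBlockProducer {
    validator: TxValidator,
}

impl IntervalBlockProducer {
    fn new() -> Self {
        // the producer is created together with the validator it owns
        Self { validator: TxValidator(Arc::new(Mutex::new(0))) }
    }
}

struct BlockProducer {
    producer: IntervalBlockProducer,
}

impl BlockProducer {
    fn interval() -> Self {
        Self { producer: IntervalBlockProducer::new() }
    }

    // callers (the tx pool, the RPC layer) fetch a shared handle instead of owning a copy
    fn validator(&self) -> TxValidator {
        self.producer.validator.clone()
    }
}

fn main() {
    let producer = BlockProducer::interval();
    let validator = producer.validator(); // cheap Arc clone sharing one state
    *validator.0.lock().unwrap() += 1;
}
```

Keeping the validator inside each producer mode removes the separate copy that previously had to be re-synced through `update_validator` after every block.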
 pub fn on_demand(backend: Arc<Backend<EF>>) -> Self {
-        let (prod, validator) = IntervalBlockProducer::new(backend, None);
-        Self { producer: BlockProducerMode::Interval(prod).into(), validator }
+        let prod = IntervalBlockProducer::new(backend, None);
+        Self { producer: BlockProducerMode::Interval(prod).into() }
     }
 
     /// Creates a block producer that mines a new block as soon as there are ready transactions in
     /// the transactions pool.
     pub fn instant(backend: Arc<Backend<EF>>) -> Self {
-        let (prod, validator) = InstantBlockProducer::new(backend);
-        Self { producer: BlockProducerMode::Instant(prod).into(), validator }
+        let prod = InstantBlockProducer::new(backend);
+        Self { producer: BlockProducerMode::Instant(prod).into() }
     }
 
     pub(super) fn queue(&self, transactions: Vec<ExecutableTxWithHash>) {
@@ -109,6 +106,14 @@ impl<EF: ExecutorFactory> BlockProducer<EF> {
         }
     }
 
+    pub fn validator(&self) -> TxValidator {
+        let mode = self.producer.read();
+        match &*mode {
+            BlockProducerMode::Instant(pd) => pd.validator.clone(),
+            BlockProducerMode::Interval(pd) => pd.validator.clone(),
+        }
+    }
+
     /// Returns `true` if the block producer is running in _interval_ mode. Otherwise, `false`.
     pub fn is_interval_mining(&self) -> bool {
         matches!(*self.producer.read(), BlockProducerMode::Interval(_))
     }
@@ -129,37 +134,6 @@ impl<EF: ExecutorFactory> BlockProducer<EF> {
         }
     }
 
-    pub fn validator(&self) -> &TxValidator {
-        &self.validator
-    }
-
-    pub fn update_validator(&self) -> Result<(), ProviderError> {
-        let mut mode = self.producer.write();
-
-        match &mut *mode {
-            BlockProducerMode::Instant(pd) => {
-                let provider = pd.backend.blockchain.provider();
-                let state = provider.latest()?;
-
-                let latest_num = provider.latest_number()?;
-                let block_env = provider.block_env_at(latest_num.into())?.expect("latest");
-
-                self.validator.update(state, &block_env)
-            }
-
-            BlockProducerMode::Interval(pd) => {
-                let pending_state = pd.executor.0.read();
-
-                let state = pending_state.state();
-                let block_env = pending_state.block_env();
-
-                self.validator.update(state, &block_env)
-            }
-        };
-
-        Ok(())
-    }
-
     pub(super) fn poll_next(&self, cx: &mut Context<'_>) -> Poll<Option<BlockProductionResult>> {
         let mut mode = self.producer.write();
         match &mut *mode {
@@ -212,10 +186,17 @@ pub struct IntervalBlockProducer<EF: ExecutorFactory> {
     ongoing_execution: Option<TxExecutionFuture>,
     /// Listeners notified when a new executed tx is added.
     tx_execution_listeners: RwLock<Vec<Sender<Vec<TxWithOutcome>>>>,
+
+    permit: Arc<Mutex<()>>,
+
+    /// validator used in the tx pool
+    // the validator needs to always be built against the state of the block producer, so
+    // I'm putting it here for now until we find a better way to handle this.
+    validator: TxValidator,
 }
 
 impl<EF: ExecutorFactory> IntervalBlockProducer<EF> {
-    pub fn new(backend: Arc<Backend<EF>>, interval: Option<u64>) -> (Self, TxValidator) {
+    pub fn new(backend: Arc<Backend<EF>>, interval: Option<u64>) -> Self {
         let interval = interval.map(|time| {
             let duration = Duration::from_millis(time);
             let mut interval = interval_at(Instant::now() + duration, duration);
@@ -232,13 +213,18 @@ impl<EF: ExecutorFactory> IntervalBlockProducer<EF> {
         let state = provider.latest().unwrap();
         let executor = backend.executor_factory.with_state_and_block_env(state, block_env.clone());
 
+        let permit = Arc::new(Mutex::new(()));
+
         // -- build the validator using the same state and envs as the executor
         let state = executor.state();
         let cfg = backend.executor_factory.cfg();
         let flags = backend.executor_factory.execution_flags();
-        let validator = TxValidator::new(state, flags.clone(), cfg.clone(), &block_env);
+        let validator =
+            TxValidator::new(state, flags.clone(), cfg.clone(), &block_env, permit.clone());
 
-        let producer = Self {
+        Self {
+            validator,
+            permit,
             backend,
             interval,
             ongoing_mining: None,
@@ -247,15 +233,13 @@ impl<EF: ExecutorFactory> IntervalBlockProducer<EF> {
             executor: PendingExecutor::new(executor),
             tx_execution_listeners: RwLock::new(vec![]),
             blocking_task_spawner: BlockingTaskPool::new().unwrap(),
-        };
-
-        (producer, validator)
+        }
     }
 
     /// Creates a new [IntervalBlockProducer] with no `interval`. This mode will not produce blocks
     /// for every fixed interval, although it will still execute all queued transactions and
     /// keep hold of the pending state.
-    pub fn new_no_mining(backend: Arc<Backend<EF>>) -> (Self, TxValidator) {
+    pub fn new_no_mining(backend: Arc<Backend<EF>>) -> Self {
         Self::new(backend, None)
     }
 
@@ -265,11 +249,24 @@ impl<EF: ExecutorFactory> IntervalBlockProducer<EF> {
 
     /// Force mine a new block. It will only be able to mine if there is no ongoing mining process.
     pub fn force_mine(&mut self) {
-        match Self::do_mine(self.executor.clone(), self.backend.clone()) {
+        match Self::do_mine(self.permit.clone(), self.executor.clone(), self.backend.clone()) {
             Ok(outcome) => {
                 info!(target: LOG_TARGET, block_number = %outcome.block_number, "Force mined block.");
                 self.executor =
                     self.create_new_executor_for_next_block().expect("fail to create executor");
+
+                // update pool validator state here ---------
+
+                let provider = self.backend.blockchain.provider();
+                let state = self.executor.0.read().state();
+                let num = provider.latest_number().unwrap();
+                let block_env = provider.block_env_at(num.into()).unwrap().unwrap();
+
+                self.validator.update(state, &block_env);
+
+                // -------------------------------------------
+
+                unsafe { self.permit.raw().unlock() };
             }
             Err(e) => {
                 error!(target: LOG_TARGET, error = %e, "On force mine.");
@@ -278,9 +275,11 @@
     }
 
     fn do_mine(
+        permit: Arc<Mutex<()>>,
         executor: PendingExecutor,
         backend: Arc<Backend<EF>>,
     ) -> Result<MinedBlockOutcome, BlockProductionError> {
+        unsafe { permit.raw() }.lock();
         let executor = &mut executor.write();
 
         trace!(target: LOG_TARGET, "Creating new block.");
@@ -373,7 +372,7 @@ impl<EF: ExecutorFactory> Stream for IntervalBlockProducer<EF> {
     // mined block outcome and the new state
-    type Item = BlockProductionResult;
+    type Item = Result<MinedBlockOutcome, BlockProductionError>;
 
     fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
         let pin = self.get_mut();
 
         if let Some(interval) = &mut pin.interval {
             // mine block if the interval is over
             if interval.poll_tick(cx).is_ready() && pin.ongoing_mining.is_none() {
-                let executor = pin.executor.clone();
-                let backend = pin.backend.clone();
-                let fut = pin.blocking_task_spawner.spawn(|| Self::do_mine(executor,
backend)); - pin.ongoing_mining = Some(Box::pin(fut)); + pin.ongoing_mining = Some(Box::pin({ + let executor = pin.executor.clone(); + let backend = pin.backend.clone(); + let permit = pin.permit.clone(); + + pin.blocking_task_spawner.spawn(|| Self::do_mine(permit, executor, backend)) + })); } } @@ -439,7 +441,19 @@ impl Stream for IntervalBlockProducer { Ok(outcome) => { match pin.create_new_executor_for_next_block() { Ok(executor) => { + // update pool validator state here --------- + + let provider = pin.backend.blockchain.provider(); + let state = executor.0.read().state(); + let num = provider.latest_number()?; + let block_env = provider.block_env_at(num.into()).unwrap().unwrap(); + + pin.validator.update(state, &block_env); + + // ------------------------------------------- + pin.executor = executor; + unsafe { pin.permit.raw().unlock() }; } Err(e) => return Poll::Ready(Some(Err(e))), @@ -450,7 +464,7 @@ impl Stream for IntervalBlockProducer { Err(_) => { return Poll::Ready(Some(Err( - BlockProductionError::BlockMiningTaskCancelled, + BlockProductionError::ExecutionTaskCancelled, ))); } } @@ -475,12 +489,21 @@ pub struct InstantBlockProducer { blocking_task_pool: BlockingTaskPool, /// Listeners notified when a new executed tx is added. tx_execution_listeners: RwLock>>>, + + permit: Arc>, + + /// validator used in the tx pool + // the validator needs to always be built against the state of the block producer, so + // im putting here for now until we find a better way to handle this. + validator: TxValidator, } impl InstantBlockProducer { - pub fn new(backend: Arc>) -> (Self, TxValidator) { + pub fn new(backend: Arc>) -> Self { let provider = backend.blockchain.provider(); + let permit = Arc::new(Mutex::new(())); + let latest_num = provider.latest_number().expect("latest block num"); let block_env = provider .block_env_at(latest_num.into()) @@ -490,34 +513,46 @@ impl InstantBlockProducer { let state = provider.latest().expect("latest state"); let cfg = backend.executor_factory.cfg(); let flags = backend.executor_factory.execution_flags(); - let validator = TxValidator::new(state, flags.clone(), cfg.clone(), &block_env); + let validator = + TxValidator::new(state, flags.clone(), cfg.clone(), &block_env, permit.clone()); - let producer = Self { + Self { + permit, backend, + validator, block_mining: None, queued: VecDeque::default(), blocking_task_pool: BlockingTaskPool::new().unwrap(), tx_execution_listeners: RwLock::new(vec![]), - }; - - (producer, validator) + } } pub fn force_mine(&mut self) { if self.block_mining.is_none() { - let txs = self.queued.pop_front().unwrap_or_default(); - let _ = Self::do_mine(self.backend.clone(), txs); + let txs = std::mem::take(&mut self.queued); + let _ = Self::do_mine( + self.validator.clone(), + self.permit.clone(), + self.backend.clone(), + txs, + ); } else { trace!(target: LOG_TARGET, "Unable to force mine while a mining process is running.") } } fn do_mine( + validator: TxValidator, + permit: Arc>, backend: Arc>, - transactions: Vec, + transactions: VecDeque>, ) -> Result<(MinedBlockOutcome, Vec), BlockProductionError> { + let _permit = permit.lock(); + trace!(target: LOG_TARGET, "Creating new block."); + let transactions = transactions.into_iter().flatten().collect::>(); + let provider = backend.blockchain.provider(); let latest_num = provider.latest_number()?; @@ -558,7 +593,15 @@ impl InstantBlockProducer { let outcome = backend.do_mine_block(&block_env, execution_output)?; - // update pool validator state here + // update pool validator state 
here --------- + + let provider = backend.blockchain.provider(); + let state = provider.latest()?; + let latest_num = provider.latest_number()?; + let block_env = provider.block_env_at(latest_num.into())?.expect("latest"); + validator.update(state, &block_env); + + // ------------------------------------------- trace!(target: LOG_TARGET, block_number = %outcome.block_number, "Created new block."); @@ -607,12 +650,16 @@ impl Stream for InstantBlockProducer { let pin = self.get_mut(); if !pin.queued.is_empty() && pin.block_mining.is_none() { - let transactions = pin.queued.pop_front().expect("not empty; qed"); - let backend = pin.backend.clone(); + pin.block_mining = Some(Box::pin({ + // take everything that is already in the queue + let transactions = std::mem::take(&mut pin.queued); + let validator = pin.validator.clone(); + let backend = pin.backend.clone(); + let permit = pin.permit.clone(); - pin.block_mining = Some(Box::pin( - pin.blocking_task_pool.spawn(|| Self::do_mine(backend, transactions)), - )); + pin.blocking_task_pool + .spawn(|| Self::do_mine(validator, permit, backend, transactions)) + })); } // poll the mining future diff --git a/crates/katana/core/src/service/mod.rs b/crates/katana/core/src/service/mod.rs index 3a9cbe541b..0dce5669cd 100644 --- a/crates/katana/core/src/service/mod.rs +++ b/crates/katana/core/src/service/mod.rs @@ -44,7 +44,6 @@ pub struct NodeService { pub(crate) messaging: Option>, /// Metrics for recording the service operations metrics: ServiceMetrics, - // validator: StatefulValidator } impl NodeService { @@ -100,8 +99,6 @@ impl Future for NodeService { let steps_used = outcome.stats.cairo_steps_used; metrics.l1_gas_processed_total.increment(gas_used as u64); metrics.cairo_steps_processed_total.increment(steps_used as u64); - - pin.block_producer.update_validator().expect("failed to update validator"); } Err(err) => { diff --git a/crates/katana/executor/Cargo.toml b/crates/katana/executor/Cargo.toml index 28b23c070a..914c8c7a5f 100644 --- a/crates/katana/executor/Cargo.toml +++ b/crates/katana/executor/Cargo.toml @@ -15,7 +15,7 @@ starknet = { workspace = true, optional = true } thiserror.workspace = true tracing.workspace = true -blockifier = { git = "https://github.com/dojoengine/blockifier", branch = "cairo-2.7-new", features = [ "testing" ], optional = true } +blockifier = { git = "https://github.com/dojoengine/blockifier", branch = "cairo-2.7-newer", features = [ "testing" ], optional = true } katana-cairo = { workspace = true, optional = true } [dev-dependencies] diff --git a/crates/katana/node/src/lib.rs b/crates/katana/node/src/lib.rs index 4454474d01..7b7a73c716 100644 --- a/crates/katana/node/src/lib.rs +++ b/crates/katana/node/src/lib.rs @@ -24,6 +24,7 @@ use katana_core::service::{NodeService, TransactionMiner}; use katana_executor::implementation::blockifier::BlockifierFactory; use katana_executor::{ExecutorFactory, SimulationFlag}; use katana_pool::ordering::FiFo; +use katana_pool::validation::stateful::TxValidator; use katana_pool::{TransactionPool, TxPool}; use katana_primitives::block::FinalityStatus; use katana_primitives::env::{CfgEnv, FeeTokenAddressses}; @@ -167,8 +168,8 @@ pub async fn start( // --- build transaction pool and miner - let validator = block_producer.validator().clone(); - let pool = TxPool::new(validator, FiFo::new()); + let validator = block_producer.validator(); + let pool = TxPool::new(validator.clone(), FiFo::new()); let miner = TransactionMiner::new(pool.add_listener()); // --- build metrics service @@ -212,7 
+213,7 @@ pub async fn start( // --- spawn rpc server - let node_components = (pool, backend.clone(), block_producer); + let node_components = (pool, backend.clone(), block_producer, validator); let rpc_handle = spawn(node_components, server_config).await?; Ok((rpc_handle, backend)) @@ -220,10 +221,10 @@ pub async fn start( // Moved from `katana_rpc` crate pub async fn spawn( - node_components: (TxPool, Arc>, Arc>), + node_components: (TxPool, Arc>, Arc>, TxValidator), config: ServerConfig, ) -> Result { - let (pool, backend, block_producer) = node_components; + let (pool, backend, block_producer, validator) = node_components; let mut methods = RpcModule::new(()); methods.register_method("health", |_, _| Ok(serde_json::json!({ "health": true })))?; @@ -232,8 +233,12 @@ pub async fn spawn( match api { ApiKind::Starknet => { // TODO: merge these into a single logic. - let server = - StarknetApi::new(backend.clone(), pool.clone(), block_producer.clone()); + let server = StarknetApi::new( + backend.clone(), + pool.clone(), + block_producer.clone(), + validator.clone(), + ); methods.merge(StarknetApiServer::into_rpc(server.clone()))?; methods.merge(StarknetWriteApiServer::into_rpc(server.clone()))?; methods.merge(StarknetTraceApiServer::into_rpc(server))?; diff --git a/crates/katana/pool/src/pool.rs b/crates/katana/pool/src/pool.rs index ba7b7d8ed8..3f105cc65a 100644 --- a/crates/katana/pool/src/pool.rs +++ b/crates/katana/pool/src/pool.rs @@ -114,14 +114,15 @@ where Ok(hash) } - ValidationOutcome::Invalid { tx, error } => { - warn!(hash = format!("{:#x}", tx.hash()), "Invalid transaction."); + ValidationOutcome::Invalid { error, .. } => { + warn!(hash = format!("{hash:#x}"), "Invalid transaction."); Err(PoolError::InvalidTransaction(Box::new(error))) } // return as error for now but ideally we should kept the tx in a separate // queue and revalidate it when the parent tx is added to the pool ValidationOutcome::Dependent { tx, tx_nonce, current_nonce } => { + info!(hash = format!("{hash:#x}"), "Dependent transaction."); let err = InvalidTransactionError::InvalidNonce { address: tx.sender(), current_nonce, diff --git a/crates/katana/pool/src/validation/stateful.rs b/crates/katana/pool/src/validation/stateful.rs index 1b57bd376a..7bd19e956b 100644 --- a/crates/katana/pool/src/validation/stateful.rs +++ b/crates/katana/pool/src/validation/stateful.rs @@ -24,9 +24,14 @@ use crate::tx::PoolTransaction; #[allow(missing_debug_implementations)] #[derive(Clone)] pub struct TxValidator { + inner: Arc, +} + +struct Inner { cfg_env: CfgEnv, execution_flags: SimulationFlag, - validator: Arc>, + validator: Mutex, + permit: Arc>, } impl TxValidator { @@ -35,16 +40,28 @@ impl TxValidator { execution_flags: SimulationFlag, cfg_env: CfgEnv, block_env: &BlockEnv, + permit: Arc>, ) -> Self { - let inner = StatefulValidatorAdapter::new(state, block_env, &cfg_env); - Self { cfg_env, execution_flags, validator: Arc::new(Mutex::new(inner)) } + let validator = StatefulValidatorAdapter::new(state, block_env, &cfg_env); + Self { + inner: Arc::new(Inner { + permit, + cfg_env, + execution_flags, + validator: Mutex::new(validator), + }), + } } /// Reset the state of the validator with the given params. This method is used to update the /// validator's state with a new state and block env after a block is mined. 
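The `update` rework below, together with the `permit` field added above, defines the synchronization contract this patch is about: the producer acquires the shared permit for the entire mining span and releases it only after the validator has been moved onto the post-block state, so no transaction can be validated against a snapshot that mining is about to replace. A minimal single-process sketch of that handoff, using simplified stand-in types (the real code shares a `parking_lot::Mutex<()>` and the interval producer releases its side with a raw unlock from the polling task):

```rust
use std::sync::{Arc, Mutex};

// Stand-ins for the real types: the "state" here is just a block number.
#[derive(Clone)]
struct Validator {
    permit: Arc<Mutex<()>>,
    state: Arc<Mutex<u64>>,
}

impl Validator {
    fn validate(&self) {
        // blocks while a block is being mined, so we never see a stale snapshot
        let _permit = self.permit.lock().unwrap();
        println!("validating against state #{}", *self.state.lock().unwrap());
    }

    fn update(&self, post_block_state: u64) {
        *self.state.lock().unwrap() = post_block_state;
    }
}

fn mine_block(permit: &Arc<Mutex<()>>, validator: &Validator, next: u64) {
    // hold the permit across execution *and* the validator update
    let _permit = permit.lock().unwrap();
    // ... execute transactions and seal the block here ...
    validator.update(next); // swap in the post-block state before releasing
}

fn main() {
    let permit = Arc::new(Mutex::new(()));
    let validator = Validator { permit: permit.clone(), state: Arc::new(Mutex::new(0)) };

    validator.validate(); // sees state #0
    mine_block(&permit, &validator, 1);
    validator.validate(); // guaranteed to see state #1
}
```

The instant producer follows exactly this shape (`do_mine` holds `_permit` across `validator.update`); the interval producer spreads the same critical section across threads, which is why it reaches for `parking_lot`'s raw lock and unlock.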
-    pub fn update(&self, state: Box<dyn StateProvider>, block_env: &BlockEnv) {
-        let updated = StatefulValidatorAdapter::new(state, block_env, &self.cfg_env);
-        *self.validator.lock() = updated;
+    pub fn update(&self, new_state: Box<dyn StateProvider>, block_env: &BlockEnv) {
+        let mut validator = self.inner.validator.lock();
+
+        let mut state = validator.inner.tx_executor.block_state.take().unwrap();
+        state.state = StateProviderDb::new(new_state);
+
+        *validator = StatefulValidatorAdapter::new_inner(state, block_env, &self.inner.cfg_env);
     }
 
     // NOTE:
@@ -54,7 +71,7 @@
     // safety is not guaranteed by TransactionExecutor itself.
     pub fn get_nonce(&self, address: ContractAddress) -> Nonce {
         let address = to_blk_address(address);
-        let nonce = self.validator.lock().inner.get_nonce(address).expect("state err");
+        let nonce = self.inner.validator.lock().inner.get_nonce(address).expect("state err");
         nonce.0
     }
 }
@@ -65,23 +82,19 @@ struct StatefulValidatorAdapter {
 }
 
 impl StatefulValidatorAdapter {
-    fn new(
-        state: Box<dyn StateProvider>,
-        block_env: &BlockEnv,
-        cfg_env: &CfgEnv,
-    ) -> StatefulValidatorAdapter {
-        let inner = Self::new_inner(state, block_env, cfg_env);
-        Self { inner }
+    fn new(state: Box<dyn StateProvider>, block_env: &BlockEnv, cfg_env: &CfgEnv) -> Self {
+        let state = CachedState::new(StateProviderDb::new(state));
+        Self::new_inner(state, block_env, cfg_env)
     }
 
     fn new_inner(
-        state: Box<dyn StateProvider>,
+        state: CachedState<StateProviderDb<'static>>,
         block_env: &BlockEnv,
         cfg_env: &CfgEnv,
-    ) -> StatefulValidator<CachedState<StateProviderDb>> {
-        let state = CachedState::new(StateProviderDb::new(state));
+    ) -> Self {
         let block_context = block_context_from_envs(block_env, cfg_env);
-        StatefulValidator::create(state, block_context)
+        let inner = StatefulValidator::create(state, block_context, Default::default());
+        Self { inner }
     }
 
     /// Used only in the [`Validator::validate`] trait
@@ -125,7 +138,7 @@ impl Validator for TxValidator {
     type Transaction = ExecutableTxWithHash;
 
     fn validate(&self, tx: Self::Transaction) -> ValidationResult<Self::Transaction> {
-        let this = &mut *self.validator.lock();
+        let _permit = self.inner.permit.lock();
+        let this = &mut *self.inner.validator.lock();
 
         // Check if validation of an invoke transaction should be skipped due to deploy_account not
         // being processed yet.
This feature is used to improve UX for users sending @@ -145,8 +159,8 @@ impl Validator for TxValidator { StatefulValidatorAdapter::validate( this, tx, - self.execution_flags.skip_validate || skip_validate, - self.execution_flags.skip_fee_transfer, + self.inner.execution_flags.skip_validate || skip_validate, + self.inner.execution_flags.skip_fee_transfer, ) } } diff --git a/crates/katana/rpc/rpc/src/starknet/mod.rs b/crates/katana/rpc/rpc/src/starknet/mod.rs index acdcd43f56..c7f1ca036d 100644 --- a/crates/katana/rpc/rpc/src/starknet/mod.rs +++ b/crates/katana/rpc/rpc/src/starknet/mod.rs @@ -13,6 +13,7 @@ use anyhow::Result; use katana_core::backend::Backend; use katana_core::service::block_producer::{BlockProducer, BlockProducerMode, PendingExecutor}; use katana_executor::{ExecutionResult, ExecutorFactory}; +use katana_pool::validation::stateful::TxValidator; use katana_pool::TxPool; use katana_primitives::block::{ BlockHash, BlockHashOrNumber, BlockIdOrTag, BlockNumber, BlockTag, FinalityStatus, @@ -53,6 +54,7 @@ impl Clone for StarknetApi { } struct Inner { + validator: TxValidator, pool: TxPool, backend: Arc>, block_producer: Arc>, @@ -64,11 +66,12 @@ impl StarknetApi { backend: Arc>, pool: TxPool, block_producer: Arc>, + validator: TxValidator, ) -> Self { let blocking_task_pool = BlockingTaskPool::new().expect("failed to create blocking task pool"); - let inner = Inner { pool, backend, block_producer, blocking_task_pool }; + let inner = Inner { pool, backend, block_producer, blocking_task_pool, validator }; Self { inner: Arc::new(inner) } } @@ -296,8 +299,7 @@ impl StarknetApi { // TODO: this is a temporary solution, we should have a better way to handle this. // perhaps a pending/pool state provider that implements all the state provider traits. 
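Concretely, the nonce lookup below branches on the block id: the `Pending` tag consults the pool validator, whose cached execution state already reflects validated-but-unmined transactions, while any other block id resolves against committed state. A toy model of that split, with hypothetical types and values:

```rust
// Toy model of the pending-vs-committed nonce split (hypothetical types).
enum BlockId {
    Pending,
    Number(u64),
}

struct PoolValidator;
impl PoolValidator {
    // the real validator reads this from its cached execution state, which
    // tracks nonces already consumed by not-yet-mined transactions
    fn get_nonce(&self, _address: u64) -> u64 {
        7
    }
}

struct CommittedState;
impl CommittedState {
    fn nonce_at(&self, _block: u64, _address: u64) -> u64 {
        5
    }
}

fn nonce(id: BlockId, addr: u64, validator: &PoolValidator, state: &CommittedState) -> u64 {
    match id {
        BlockId::Pending => validator.get_nonce(addr), // speculative view
        BlockId::Number(n) => state.nonce_at(n, addr), // committed view
    }
}

fn main() {
    let (v, s) = (PoolValidator, CommittedState);
    assert_eq!(nonce(BlockId::Pending, 1, &v, &s), 7);
    assert_eq!(nonce(BlockId::Number(3), 1, &v, &s), 5);
}
```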
if let BlockIdOrTag::Tag(BlockTag::Pending) = block_id { - let validator = this.inner.block_producer.validator(); - let pool_nonce = validator.get_nonce(contract_address); + let pool_nonce = this.inner.validator.get_nonce(contract_address); return Ok(pool_nonce); } diff --git a/crates/katana/rpc/rpc/src/starknet/write.rs b/crates/katana/rpc/rpc/src/starknet/write.rs index e30e213b1a..158b0212a8 100644 --- a/crates/katana/rpc/rpc/src/starknet/write.rs +++ b/crates/katana/rpc/rpc/src/starknet/write.rs @@ -23,7 +23,8 @@ impl StarknetApi { let tx = tx.into_tx_with_chain_id(this.inner.backend.chain_id); let tx = ExecutableTxWithHash::new(ExecutableTx::Invoke(tx)); - let hash = this.inner.pool.add_transaction(tx)?; + let hash = + this.inner.pool.add_transaction(tx).inspect_err(|e| println!("Error: {:?}", e))?; Ok(hash.into()) }) diff --git a/crates/katana/rpc/rpc/tests/starknet.rs b/crates/katana/rpc/rpc/tests/starknet.rs index fd4f51f37b..786d76a9a9 100644 --- a/crates/katana/rpc/rpc/tests/starknet.rs +++ b/crates/katana/rpc/rpc/tests/starknet.rs @@ -27,6 +27,7 @@ use starknet::core::utils::{get_contract_address, get_selector_from_name}; use starknet::macros::felt; use starknet::providers::{Provider, ProviderError}; use starknet::signers::{LocalWallet, SigningKey}; +use tokio::sync::Mutex; mod common; @@ -230,8 +231,8 @@ async fn estimate_fee() -> Result<()> { } #[rstest::rstest] -#[tokio::test] -async fn rapid_transactions_submissions( +#[tokio::test(flavor = "multi_thread")] +async fn concurrent_transactions_submissions( #[values(None, Some(1000))] block_time: Option, ) -> Result<()> { // setup test sequencer with the given configuration @@ -240,25 +241,48 @@ async fn rapid_transactions_submissions( let sequencer = TestSequencer::start(sequencer_config, starknet_config).await; let provider = sequencer.provider(); - let account = sequencer.account(); + let account = Arc::new(sequencer.account()); // setup test contract to interact with. 
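The reworked test below spawns many tasks that share a single nonce counter behind a tokio `Mutex` and submit transfers concurrently; holding the lock across the send keeps the assigned nonce and the actual submission order consistent. A stripped-down version of the same pattern (hypothetical `send_transfer` stand-in; assumes the tokio crate with the `rt-multi-thread` and `macros` features):

```rust
use std::sync::Arc;
use tokio::sync::Mutex;

// stand-in for contract.transfer(...).nonce(n).send(): pretend the tx hash
// is derived from the nonce
async fn send_transfer(nonce: u64) -> u64 {
    nonce
}

#[tokio::main(flavor = "multi_thread")]
async fn main() {
    const N: usize = 100;
    let nonce = Arc::new(Mutex::new(0u64));
    let hashes = Arc::new(Mutex::new(Vec::with_capacity(N)));

    let mut handles = Vec::with_capacity(N);
    for _ in 0..N {
        let nonce = nonce.clone();
        let hashes = hashes.clone();
        handles.push(tokio::spawn(async move {
            // hold the lock across the send so the nonce we picked is the
            // nonce that actually goes out
            let mut nonce = nonce.lock().await;
            let hash = send_transfer(*nonce).await;
            hashes.lock().await.push(hash);
            *nonce += 1;
        }));
    }

    for handle in handles {
        handle.await.unwrap();
    }
    assert_eq!(hashes.lock().await.len(), N);
}
```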
abigen_legacy!(Contract, "crates/katana/rpc/rpc/tests/test_data/erc20.json"); - let contract = Contract::new(DEFAULT_FEE_TOKEN_ADDRESS.into(), &account); // function call params let recipient = Felt::ONE; let amount = Uint256 { low: Felt::ONE, high: Felt::ZERO }; - const N: usize = 10; - let mut txs = IndexSet::with_capacity(N); + let initial_nonce = + provider.get_nonce(BlockId::Tag(BlockTag::Pending), sequencer.account().address()).await?; + + const N: usize = 100; + let nonce = Arc::new(Mutex::new(initial_nonce)); + let txs = Arc::new(Mutex::new(IndexSet::with_capacity(N))); + + let mut handles = Vec::with_capacity(N); for _ in 0..N { - let res = contract.transfer(&recipient, &amount).send().await?; - txs.insert(res.transaction_hash); + let txs = txs.clone(); + let nonce = nonce.clone(); + let amount = amount.clone(); + let account = account.clone(); + + let handle = tokio::spawn(async move { + let mut nonce = nonce.lock().await; + let contract = Contract::new(DEFAULT_FEE_TOKEN_ADDRESS.into(), account); + let res = contract.transfer(&recipient, &amount).nonce(*nonce).send().await.unwrap(); + txs.lock().await.insert(res.transaction_hash); + *nonce += Felt::ONE; + }); + + handles.push(handle); + } + + // wait for all txs to be submitted + for handle in handles { + handle.await?; } // Wait only for the last transaction to be accepted + let txs = txs.lock().await; let last_tx = txs.last().unwrap(); dojo_utils::TransactionWaiter::new(*last_tx, &provider).await?; @@ -266,7 +290,7 @@ async fn rapid_transactions_submissions( assert_eq!(txs.len(), N); // check the status of each txs - for hash in txs { + for hash in txs.iter() { let receipt = provider.get_transaction_receipt(hash).await?; assert_eq!(receipt.receipt.execution_result(), &ExecutionResult::Succeeded); assert_eq!(receipt.receipt.finality_status(), &TransactionFinalityStatus::AcceptedOnL2);
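Waiting only on the last hash is sound here because the nonces are strictly increasing: the sequencer cannot mine transfer k+1 before transfer k, so acceptance of the final transaction implies every earlier one was already included. An illustrative wait loop in the spirit of `dojo_utils::TransactionWaiter` (hypothetical `poll_status` stand-in, not the crate's actual API):

```rust
use std::time::{Duration, Instant};

#[derive(Debug, PartialEq)]
enum Status {
    Pending,
    AcceptedOnL2,
}

// stand-in for provider.get_transaction_receipt(hash): accept after 50ms
fn poll_status(elapsed: Duration) -> Status {
    if elapsed > Duration::from_millis(50) { Status::AcceptedOnL2 } else { Status::Pending }
}

fn wait_for_tx(timeout: Duration) -> Result<Status, &'static str> {
    let start = Instant::now();
    loop {
        match poll_status(start.elapsed()) {
            Status::AcceptedOnL2 => return Ok(Status::AcceptedOnL2),
            Status::Pending if start.elapsed() > timeout => return Err("timed out"),
            Status::Pending => std::thread::sleep(Duration::from_millis(10)),
        }
    }
}

fn main() {
    assert_eq!(wait_for_tx(Duration::from_secs(1)), Ok(Status::AcceptedOnL2));
}
```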