diff --git a/Cargo.lock b/Cargo.lock index b28d950ab4..d044fc69cc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5611,6 +5611,7 @@ dependencies = [ "katana-provider", "katana-rpc-types", "katana-rpc-types-builder", + "katana-tasks", "serde", "serde_json", "serde_with", @@ -5661,6 +5662,16 @@ dependencies = [ "url", ] +[[package]] +name = "katana-tasks" +version = "0.5.0" +dependencies = [ + "futures", + "rayon", + "thiserror", + "tokio", +] + [[package]] name = "keccak" version = "0.1.4" @@ -7555,9 +7566,9 @@ checksum = "c707298afce11da2efef2f600116fa93ffa7a032b5d7b628aa17711ec81383ca" [[package]] name = "reqwest" -version = "0.11.22" +version = "0.11.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "046cd98826c46c2ac8ddecae268eb5c2e58628688a5fc7a2643704a73faba95b" +checksum = "37b1ae8d9ac08420c66222fb9096fc5de435c3c48542bc5336c51892cffafb41" dependencies = [ "async-compression", "base64 0.21.5", diff --git a/crates/katana/core/src/sequencer.rs b/crates/katana/core/src/sequencer.rs index 1f2b85dbde..16c3ca9b7e 100644 --- a/crates/katana/core/src/sequencer.rs +++ b/crates/katana/core/src/sequencer.rs @@ -267,8 +267,9 @@ impl KatanaSequencer { self.backend.chain_id } - pub fn block_number(&self) -> BlockNumber { - BlockNumberProvider::latest_number(&self.backend.blockchain.provider()).unwrap() + pub fn block_number(&self) -> SequencerResult { + let num = BlockNumberProvider::latest_number(&self.backend.blockchain.provider())?; + Ok(num) } pub fn block_tx_count(&self, block_id: BlockIdOrTag) -> SequencerResult> { @@ -300,7 +301,7 @@ impl KatanaSequencer { Ok(count) } - pub async fn nonce_at( + pub fn nonce_at( &self, block_id: BlockIdOrTag, contract_address: ContractAddress, @@ -352,7 +353,7 @@ impl KatanaSequencer { Ok(tx) } - pub async fn events( + pub fn events( &self, from_block: BlockIdOrTag, to_block: BlockIdOrTag, diff --git a/crates/katana/rpc/Cargo.toml b/crates/katana/rpc/Cargo.toml index 33d21f9570..91450e7971 100644 --- a/crates/katana/rpc/Cargo.toml +++ b/crates/katana/rpc/Cargo.toml @@ -12,6 +12,7 @@ katana-primitives = { path = "../primitives" } katana-provider = { path = "../storage/provider" } katana-rpc-types = { path = "rpc-types" } katana-rpc-types-builder = { path = "rpc-types-builder" } +katana-tasks = { path = "../tasks" } anyhow.workspace = true cairo-lang-starknet = "2.3.1" diff --git a/crates/katana/rpc/src/starknet.rs b/crates/katana/rpc/src/starknet.rs index 737139d262..42e6f9984c 100644 --- a/crates/katana/rpc/src/starknet.rs +++ b/crates/katana/rpc/src/starknet.rs @@ -29,23 +29,59 @@ use katana_rpc_types::transaction::{ }; use katana_rpc_types::{ContractClass, FeeEstimate, FeltAsHex, FunctionCall}; use katana_rpc_types_builder::ReceiptBuilder; +use katana_tasks::{BlockingTaskPool, TokioTaskSpawner}; use starknet::core::types::{BlockTag, TransactionExecutionStatus, TransactionStatus}; use crate::api::starknet::{StarknetApiError, StarknetApiServer}; +#[derive(Clone)] pub struct StarknetApi { + inner: Arc, +} + +struct StarknetApiInner { sequencer: Arc, + blocking_task_pool: BlockingTaskPool, } impl StarknetApi { pub fn new(sequencer: Arc) -> Self { - Self { sequencer } + let blocking_task_pool = + BlockingTaskPool::new().expect("failed to create blocking task pool"); + Self { inner: Arc::new(StarknetApiInner { sequencer, blocking_task_pool }) } + } + + async fn on_cpu_blocking_task(&self, func: F) -> T + where + F: FnOnce(Self) -> T + Send + 'static, + T: Send + 'static, + { + let this = self.clone(); + self.inner.blocking_task_pool.spawn(move || func(this)).await.unwrap() + } + + async fn on_io_blocking_task(&self, func: F) -> T + where + F: FnOnce(Self) -> T + Send + 'static, + T: Send + 'static, + { + let this = self.clone(); + // create oneshot tokio channel + let (sender, receiver) = tokio::sync::oneshot::channel::(); + let _ = TokioTaskSpawner::new() + .unwrap() + .spawn_blocking(move || { + let res = func(this); + let _ = sender.send(res); + }) + .await; + receiver.await.unwrap() } } #[async_trait] impl StarknetApiServer for StarknetApi { async fn chain_id(&self) -> Result { - Ok(FieldElement::from(self.sequencer.chain_id()).into()) + Ok(FieldElement::from(self.inner.sequencer.chain_id()).into()) } async fn nonce( @@ -53,36 +89,51 @@ impl StarknetApiServer for StarknetApi { block_id: BlockIdOrTag, contract_address: FieldElement, ) -> Result { - let nonce = self - .sequencer - .nonce_at(block_id, contract_address.into()) - .await - .map_err(StarknetApiError::from)? - .ok_or(StarknetApiError::ContractNotFound)?; - - Ok(nonce.into()) + self.on_io_blocking_task(move |this| { + let nonce = this + .inner + .sequencer + .nonce_at(block_id, contract_address.into()) + .map_err(StarknetApiError::from)? + .ok_or(StarknetApiError::ContractNotFound)?; + Ok(nonce.into()) + }) + .await } async fn block_number(&self) -> Result { - Ok(self.sequencer.block_number()) + self.on_io_blocking_task(move |this| { + let block_number = + this.inner.sequencer.block_number().map_err(StarknetApiError::from)?; + Ok(block_number) + }) + .await } async fn transaction_by_hash(&self, transaction_hash: FieldElement) -> Result { - let tx = self - .sequencer - .transaction(&transaction_hash) - .map_err(StarknetApiError::from)? - .ok_or(StarknetApiError::TxnHashNotFound)?; - Ok(tx.into()) + self.on_io_blocking_task(move |this| { + let tx = this + .inner + .sequencer + .transaction(&transaction_hash) + .map_err(StarknetApiError::from)? + .ok_or(StarknetApiError::TxnHashNotFound)?; + Ok(tx.into()) + }) + .await } async fn block_transaction_count(&self, block_id: BlockIdOrTag) -> Result { - let count = self - .sequencer - .block_tx_count(block_id) - .map_err(StarknetApiError::from)? - .ok_or(StarknetApiError::BlockNotFound)?; - Ok(count) + self.on_io_blocking_task(move |this| { + let count = this + .inner + .sequencer + .block_tx_count(block_id) + .map_err(StarknetApiError::from)? + .ok_or(StarknetApiError::BlockNotFound)?; + Ok(count) + }) + .await } async fn class_at( @@ -91,17 +142,22 @@ impl StarknetApiServer for StarknetApi { contract_address: FieldElement, ) -> Result { let class_hash = self - .sequencer - .class_hash_at(block_id, contract_address.into()) - .map_err(StarknetApiError::from)? - .ok_or(StarknetApiError::ContractNotFound)?; - + .on_io_blocking_task(move |this| { + this.inner + .sequencer + .class_hash_at(block_id, contract_address.into()) + .map_err(StarknetApiError::from)? + .ok_or(StarknetApiError::ContractNotFound) + }) + .await?; self.class(block_id, class_hash).await } async fn block_hash_and_number(&self) -> Result { - let hash_and_num_pair = - self.sequencer.block_hash_and_number().map_err(StarknetApiError::from)?; + let hash_and_num_pair = self + .on_io_blocking_task(move |this| this.inner.sequencer.block_hash_and_number()) + .await + .map_err(StarknetApiError::from)?; Ok(hash_and_num_pair.into()) } @@ -109,51 +165,53 @@ impl StarknetApiServer for StarknetApi { &self, block_id: BlockIdOrTag, ) -> Result { - let provider = self.sequencer.backend.blockchain.provider(); - - if BlockIdOrTag::Tag(BlockTag::Pending) == block_id { - if let Some(pending_state) = self.sequencer.pending_state() { - let block_env = pending_state.block_envs.read().0.clone(); - let latest_hash = - BlockHashProvider::latest_hash(provider).map_err(StarknetApiError::from)?; - - let gas_prices = GasPrices { - eth: block_env.l1_gas_prices.eth, - strk: block_env.l1_gas_prices.strk, - }; + self.on_io_blocking_task(move |this| { + let provider = this.inner.sequencer.backend.blockchain.provider(); + + if BlockIdOrTag::Tag(BlockTag::Pending) == block_id { + if let Some(pending_state) = this.inner.sequencer.pending_state() { + let block_env = pending_state.block_envs.read().0.clone(); + let latest_hash = + BlockHashProvider::latest_hash(provider).map_err(StarknetApiError::from)?; + + let gas_prices = GasPrices { + eth: block_env.l1_gas_prices.eth, + strk: block_env.l1_gas_prices.strk, + }; + + let header = PartialHeader { + gas_prices, + parent_hash: latest_hash, + version: CURRENT_STARKNET_VERSION, + timestamp: block_env.timestamp, + sequencer_address: block_env.sequencer_address, + }; + + let transactions = pending_state + .executed_txs + .read() + .iter() + .map(|(tx, _)| tx.hash) + .collect::>(); - let header = PartialHeader { - gas_prices, - parent_hash: latest_hash, - version: CURRENT_STARKNET_VERSION, - timestamp: block_env.timestamp, - sequencer_address: block_env.sequencer_address, - }; + return Ok(MaybePendingBlockWithTxHashes::Pending( + PendingBlockWithTxHashes::new(header, transactions), + )); + } + } - let transactions = pending_state - .executed_txs - .read() - .iter() - .map(|(tx, _)| tx.hash) - .collect::>(); + let block_num = BlockIdReader::convert_block_id(provider, block_id) + .map_err(StarknetApiError::from)? + .map(BlockHashOrNumber::Num) + .ok_or(StarknetApiError::BlockNotFound)?; - return Ok(MaybePendingBlockWithTxHashes::Pending(PendingBlockWithTxHashes::new( - header, - transactions, - ))); - } - } - - let block_num = BlockIdReader::convert_block_id(provider, block_id) - .map_err(StarknetApiError::from)? - .map(BlockHashOrNumber::Num) - .ok_or(StarknetApiError::BlockNotFound)?; - - katana_rpc_types_builder::BlockBuilder::new(block_num, provider) - .build_with_tx_hash() - .map_err(StarknetApiError::from)? - .map(MaybePendingBlockWithTxHashes::Block) - .ok_or(Error::from(StarknetApiError::BlockNotFound)) + katana_rpc_types_builder::BlockBuilder::new(block_num, provider) + .build_with_tx_hash() + .map_err(StarknetApiError::from)? + .map(MaybePendingBlockWithTxHashes::Block) + .ok_or(Error::from(StarknetApiError::BlockNotFound)) + }) + .await } async fn transaction_by_block_id_and_index( @@ -161,133 +219,145 @@ impl StarknetApiServer for StarknetApi { block_id: BlockIdOrTag, index: u64, ) -> Result { - // TEMP: have to handle pending tag independently for now - let tx = if BlockIdOrTag::Tag(BlockTag::Pending) == block_id { - let Some(pending_state) = self.sequencer.pending_state() else { - return Err(StarknetApiError::BlockNotFound.into()); - }; + self.on_io_blocking_task(move |this| { + // TEMP: have to handle pending tag independently for now + let tx = if BlockIdOrTag::Tag(BlockTag::Pending) == block_id { + let Some(pending_state) = this.inner.sequencer.pending_state() else { + return Err(StarknetApiError::BlockNotFound.into()); + }; - let pending_txs = pending_state.executed_txs.read(); - pending_txs.iter().nth(index as usize).map(|(tx, _)| tx.clone()) - } else { - let provider = &self.sequencer.backend.blockchain.provider(); + let pending_txs = pending_state.executed_txs.read(); + pending_txs.iter().nth(index as usize).map(|(tx, _)| tx.clone()) + } else { + let provider = &this.inner.sequencer.backend.blockchain.provider(); - let block_num = BlockIdReader::convert_block_id(provider, block_id) - .map_err(StarknetApiError::from)? - .map(BlockHashOrNumber::Num) - .ok_or(StarknetApiError::BlockNotFound)?; + let block_num = BlockIdReader::convert_block_id(provider, block_id) + .map_err(StarknetApiError::from)? + .map(BlockHashOrNumber::Num) + .ok_or(StarknetApiError::BlockNotFound)?; - TransactionProvider::transaction_by_block_and_idx(provider, block_num, index) - .map_err(StarknetApiError::from)? - }; + TransactionProvider::transaction_by_block_and_idx(provider, block_num, index) + .map_err(StarknetApiError::from)? + }; - Ok(tx.ok_or(StarknetApiError::InvalidTxnIndex)?.into()) + Ok(tx.ok_or(StarknetApiError::InvalidTxnIndex)?.into()) + }) + .await } async fn block_with_txs( &self, block_id: BlockIdOrTag, ) -> Result { - let provider = self.sequencer.backend.blockchain.provider(); - - if BlockIdOrTag::Tag(BlockTag::Pending) == block_id { - if let Some(pending_state) = self.sequencer.pending_state() { - let block_env = pending_state.block_envs.read().0.clone(); - let latest_hash = - BlockHashProvider::latest_hash(provider).map_err(StarknetApiError::from)?; - - let gas_prices = GasPrices { - eth: block_env.l1_gas_prices.eth, - strk: block_env.l1_gas_prices.strk, - }; - - let header = PartialHeader { - gas_prices, - parent_hash: latest_hash, - version: CURRENT_STARKNET_VERSION, - timestamp: block_env.timestamp, - sequencer_address: block_env.sequencer_address, - }; + self.on_io_blocking_task(move |this| { + let provider = this.inner.sequencer.backend.blockchain.provider(); + + if BlockIdOrTag::Tag(BlockTag::Pending) == block_id { + if let Some(pending_state) = this.inner.sequencer.pending_state() { + let block_env = pending_state.block_envs.read().0.clone(); + let latest_hash = + BlockHashProvider::latest_hash(provider).map_err(StarknetApiError::from)?; + + let gas_prices = GasPrices { + eth: block_env.l1_gas_prices.eth, + strk: block_env.l1_gas_prices.strk, + }; + + let header = PartialHeader { + gas_prices, + parent_hash: latest_hash, + version: CURRENT_STARKNET_VERSION, + timestamp: block_env.timestamp, + sequencer_address: block_env.sequencer_address, + }; + + let transactions = pending_state + .executed_txs + .read() + .iter() + .map(|(tx, _)| tx.clone()) + .collect::>(); + + return Ok(MaybePendingBlockWithTxs::Pending(PendingBlockWithTxs::new( + header, + transactions, + ))); + } + } - let transactions = pending_state - .executed_txs - .read() - .iter() - .map(|(tx, _)| tx.clone()) - .collect::>(); + let block_num = BlockIdReader::convert_block_id(provider, block_id) + .map_err(|e| StarknetApiError::UnexpectedError { reason: e.to_string() })? + .map(BlockHashOrNumber::Num) + .ok_or(StarknetApiError::BlockNotFound)?; - return Ok(MaybePendingBlockWithTxs::Pending(PendingBlockWithTxs::new( - header, - transactions, - ))); - } - } - - let block_num = BlockIdReader::convert_block_id(provider, block_id) - .map_err(|e| StarknetApiError::UnexpectedError { reason: e.to_string() })? - .map(BlockHashOrNumber::Num) - .ok_or(StarknetApiError::BlockNotFound)?; - - katana_rpc_types_builder::BlockBuilder::new(block_num, provider) - .build() - .map_err(|e| StarknetApiError::UnexpectedError { reason: e.to_string() })? - .map(MaybePendingBlockWithTxs::Block) - .ok_or(Error::from(StarknetApiError::BlockNotFound)) + katana_rpc_types_builder::BlockBuilder::new(block_num, provider) + .build() + .map_err(|e| StarknetApiError::UnexpectedError { reason: e.to_string() })? + .map(MaybePendingBlockWithTxs::Block) + .ok_or(Error::from(StarknetApiError::BlockNotFound)) + }) + .await } async fn state_update(&self, block_id: BlockIdOrTag) -> Result { - let provider = self.sequencer.backend.blockchain.provider(); + self.on_io_blocking_task(move |this| { + let provider = this.inner.sequencer.backend.blockchain.provider(); - let block_id = match block_id { - BlockIdOrTag::Number(num) => BlockHashOrNumber::Num(num), - BlockIdOrTag::Hash(hash) => BlockHashOrNumber::Hash(hash), + let block_id = match block_id { + BlockIdOrTag::Number(num) => BlockHashOrNumber::Num(num), + BlockIdOrTag::Hash(hash) => BlockHashOrNumber::Hash(hash), - BlockIdOrTag::Tag(BlockTag::Latest) => BlockNumberProvider::latest_number(provider) - .map(BlockHashOrNumber::Num) - .map_err(|_| StarknetApiError::BlockNotFound)?, + BlockIdOrTag::Tag(BlockTag::Latest) => BlockNumberProvider::latest_number(provider) + .map(BlockHashOrNumber::Num) + .map_err(|_| StarknetApiError::BlockNotFound)?, - BlockIdOrTag::Tag(BlockTag::Pending) => { - return Err(StarknetApiError::BlockNotFound.into()); - } - }; + BlockIdOrTag::Tag(BlockTag::Pending) => { + return Err(StarknetApiError::BlockNotFound.into()); + } + }; - katana_rpc_types_builder::StateUpdateBuilder::new(block_id, provider) - .build() - .map_err(|e| StarknetApiError::UnexpectedError { reason: e.to_string() })? - .ok_or(Error::from(StarknetApiError::BlockNotFound)) + katana_rpc_types_builder::StateUpdateBuilder::new(block_id, provider) + .build() + .map_err(|e| StarknetApiError::UnexpectedError { reason: e.to_string() })? + .ok_or(Error::from(StarknetApiError::BlockNotFound)) + }) + .await } async fn transaction_receipt( &self, transaction_hash: FieldElement, ) -> Result { - let provider = self.sequencer.backend.blockchain.provider(); - let receipt = ReceiptBuilder::new(transaction_hash, provider) - .build() - .map_err(|e| StarknetApiError::UnexpectedError { reason: e.to_string() })?; - - match receipt { - Some(receipt) => Ok(MaybePendingTxReceipt::Receipt(receipt)), - - None => { - let pending_receipt = self.sequencer.pending_state().and_then(|s| { - s.executed_txs - .read() - .iter() - .find(|(tx, _)| tx.hash == transaction_hash) - .map(|(_, rct)| rct.receipt.clone()) - }); - - let Some(pending_receipt) = pending_receipt else { - return Err(StarknetApiError::TxnHashNotFound.into()); - }; - - Ok(MaybePendingTxReceipt::Pending(PendingTxReceipt::new( - transaction_hash, - pending_receipt, - ))) + self.on_io_blocking_task(move |this| { + let provider = this.inner.sequencer.backend.blockchain.provider(); + let receipt = ReceiptBuilder::new(transaction_hash, provider) + .build() + .map_err(|e| StarknetApiError::UnexpectedError { reason: e.to_string() })?; + + match receipt { + Some(receipt) => Ok(MaybePendingTxReceipt::Receipt(receipt)), + + None => { + let pending_receipt = this.inner.sequencer.pending_state().and_then(|s| { + s.executed_txs + .read() + .iter() + .find(|(tx, _)| tx.hash == transaction_hash) + .map(|(_, rct)| rct.receipt.clone()) + }); + + let Some(pending_receipt) = pending_receipt else { + return Err(StarknetApiError::TxnHashNotFound.into()); + }; + + Ok(MaybePendingTxReceipt::Pending(PendingTxReceipt::new( + transaction_hash, + pending_receipt, + ))) + } } - } + }) + .await } async fn class_hash_at( @@ -295,13 +365,16 @@ impl StarknetApiServer for StarknetApi { block_id: BlockIdOrTag, contract_address: FieldElement, ) -> Result { - let hash = self - .sequencer - .class_hash_at(block_id, contract_address.into()) - .map_err(StarknetApiError::from)? - .ok_or(StarknetApiError::ContractNotFound)?; - - Ok(hash.into()) + self.on_io_blocking_task(move |this| { + let hash = this + .inner + .sequencer + .class_hash_at(block_id, contract_address.into()) + .map_err(StarknetApiError::from)? + .ok_or(StarknetApiError::ContractNotFound)?; + Ok(hash.into()) + }) + .await } async fn class( @@ -309,40 +382,48 @@ impl StarknetApiServer for StarknetApi { block_id: BlockIdOrTag, class_hash: FieldElement, ) -> Result { - let class = self.sequencer.class(block_id, class_hash).map_err(StarknetApiError::from)?; - let Some(class) = class else { return Err(StarknetApiError::ClassHashNotFound.into()) }; - - match class { - StarknetContract::Legacy(c) => { - let contract = legacy_inner_to_rpc_class(c) - .map_err(|e| StarknetApiError::UnexpectedError { reason: e.to_string() })?; - Ok(contract) + self.on_io_blocking_task(move |this| { + let class = + this.inner.sequencer.class(block_id, class_hash).map_err(StarknetApiError::from)?; + let Some(class) = class else { return Err(StarknetApiError::ClassHashNotFound.into()) }; + + match class { + StarknetContract::Legacy(c) => { + let contract = legacy_inner_to_rpc_class(c) + .map_err(|e| StarknetApiError::UnexpectedError { reason: e.to_string() })?; + Ok(contract) + } + StarknetContract::Sierra(c) => Ok(ContractClass::Sierra(c)), } - StarknetContract::Sierra(c) => Ok(ContractClass::Sierra(c)), - } + }) + .await } async fn events(&self, filter: EventFilterWithPage) -> Result { - let from_block = filter.event_filter.from_block.unwrap_or(BlockIdOrTag::Number(0)); - let to_block = filter.event_filter.to_block.unwrap_or(BlockIdOrTag::Tag(BlockTag::Latest)); - - let keys = filter.event_filter.keys; - let keys = keys.filter(|keys| !(keys.len() == 1 && keys.is_empty())); - - let events = self - .sequencer - .events( - from_block, - to_block, - filter.event_filter.address.map(|f| f.into()), - keys, - filter.result_page_request.continuation_token, - filter.result_page_request.chunk_size, - ) - .await - .map_err(StarknetApiError::from)?; - - Ok(events) + self.on_io_blocking_task(move |this| { + let from_block = filter.event_filter.from_block.unwrap_or(BlockIdOrTag::Number(0)); + let to_block = + filter.event_filter.to_block.unwrap_or(BlockIdOrTag::Tag(BlockTag::Latest)); + + let keys = filter.event_filter.keys; + let keys = keys.filter(|keys| !(keys.len() == 1 && keys.is_empty())); + + let events = this + .inner + .sequencer + .events( + from_block, + to_block, + filter.event_filter.address.map(|f| f.into()), + keys, + filter.result_page_request.continuation_token, + filter.result_page_request.chunk_size, + ) + .map_err(StarknetApiError::from)?; + + Ok(events) + }) + .await } async fn call( @@ -350,15 +431,18 @@ impl StarknetApiServer for StarknetApi { request: FunctionCall, block_id: BlockIdOrTag, ) -> Result, Error> { - let request = EntryPointCall { - calldata: request.calldata, - contract_address: request.contract_address.into(), - entry_point_selector: request.entry_point_selector, - }; - - let res = self.sequencer.call(request, block_id).map_err(StarknetApiError::from)?; + self.on_io_blocking_task(move |this| { + let request = EntryPointCall { + calldata: request.calldata, + contract_address: request.contract_address.into(), + entry_point_selector: request.entry_point_selector, + }; - Ok(res.into_iter().map(|v| v.into()).collect()) + let res = + this.inner.sequencer.call(request, block_id).map_err(StarknetApiError::from)?; + Ok(res.into_iter().map(|v| v.into()).collect()) + }) + .await } async fn storage_at( @@ -367,33 +451,40 @@ impl StarknetApiServer for StarknetApi { key: FieldElement, block_id: BlockIdOrTag, ) -> Result { - let value = self - .sequencer - .storage_at(contract_address.into(), key, block_id) - .map_err(StarknetApiError::from)?; - - Ok(value.into()) + self.on_io_blocking_task(move |this| { + let value = this + .inner + .sequencer + .storage_at(contract_address.into(), key, block_id) + .map_err(StarknetApiError::from)?; + + Ok(value.into()) + }) + .await } async fn add_deploy_account_transaction( &self, deploy_account_transaction: BroadcastedDeployAccountTx, ) -> Result { - if deploy_account_transaction.is_query { - return Err(StarknetApiError::UnsupportedTransactionVersion.into()); - } + self.on_io_blocking_task(move |this| { + if deploy_account_transaction.is_query { + return Err(StarknetApiError::UnsupportedTransactionVersion.into()); + } - let chain_id = self.sequencer.chain_id(); + let chain_id = this.inner.sequencer.chain_id(); - let tx = deploy_account_transaction.into_tx_with_chain_id(chain_id); - let contract_address = tx.contract_address; + let tx = deploy_account_transaction.into_tx_with_chain_id(chain_id); + let contract_address = tx.contract_address; - let tx = ExecutableTxWithHash::new(ExecutableTx::DeployAccount(tx)); - let tx_hash = tx.hash; + let tx = ExecutableTxWithHash::new(ExecutableTx::DeployAccount(tx)); + let tx_hash = tx.hash; - self.sequencer.add_transaction_to_pool(tx); + this.inner.sequencer.add_transaction_to_pool(tx); - Ok((tx_hash, contract_address).into()) + Ok((tx_hash, contract_address).into()) + }) + .await } async fn estimate_fee( @@ -401,38 +492,44 @@ impl StarknetApiServer for StarknetApi { request: Vec, block_id: BlockIdOrTag, ) -> Result, Error> { - let chain_id = self.sequencer.chain_id(); - - let transactions = request - .into_iter() - .map(|tx| { - let tx = match tx { - BroadcastedTx::Invoke(tx) => { - let tx = tx.into_tx_with_chain_id(chain_id); - ExecutableTxWithHash::new_query(ExecutableTx::Invoke(tx)) - } - - BroadcastedTx::DeployAccount(tx) => { - let tx = tx.into_tx_with_chain_id(chain_id); - ExecutableTxWithHash::new_query(ExecutableTx::DeployAccount(tx)) - } - - BroadcastedTx::Declare(tx) => { - let tx = tx - .try_into_tx_with_chain_id(chain_id) - .map_err(|_| StarknetApiError::InvalidContractClass)?; - ExecutableTxWithHash::new_query(ExecutableTx::Declare(tx)) - } - }; - - Result::::Ok(tx) - }) - .collect::, _>>()?; - - let res = - self.sequencer.estimate_fee(transactions, block_id).map_err(StarknetApiError::from)?; - - Ok(res) + self.on_cpu_blocking_task(move |this| { + let chain_id = this.inner.sequencer.chain_id(); + + let transactions = request + .into_iter() + .map(|tx| { + let tx = match tx { + BroadcastedTx::Invoke(tx) => { + let tx = tx.into_tx_with_chain_id(chain_id); + ExecutableTxWithHash::new_query(ExecutableTx::Invoke(tx)) + } + + BroadcastedTx::DeployAccount(tx) => { + let tx = tx.into_tx_with_chain_id(chain_id); + ExecutableTxWithHash::new_query(ExecutableTx::DeployAccount(tx)) + } + + BroadcastedTx::Declare(tx) => { + let tx = tx + .try_into_tx_with_chain_id(chain_id) + .map_err(|_| StarknetApiError::InvalidContractClass)?; + ExecutableTxWithHash::new_query(ExecutableTx::Declare(tx)) + } + }; + + Result::::Ok(tx) + }) + .collect::, _>>()?; + + let res = this + .inner + .sequencer + .estimate_fee(transactions, block_id) + .map_err(StarknetApiError::from)?; + + Ok(res) + }) + .await } async fn estimate_message_fee( @@ -440,129 +537,143 @@ impl StarknetApiServer for StarknetApi { message: MsgFromL1, block_id: BlockIdOrTag, ) -> Result { - let chain_id = self.sequencer.chain_id(); + self.on_cpu_blocking_task(move |this| { + let chain_id = this.inner.sequencer.chain_id(); - let tx = message.into_tx_with_chain_id(chain_id); - let hash = tx.calculate_hash(); - let tx: ExecutableTxWithHash = ExecutableTxWithHash { hash, transaction: tx.into() }; + let tx = message.into_tx_with_chain_id(chain_id); + let hash = tx.calculate_hash(); + let tx: ExecutableTxWithHash = ExecutableTxWithHash { hash, transaction: tx.into() }; - let res = self - .sequencer - .estimate_fee(vec![tx], block_id) - .map_err(StarknetApiError::from)? - .pop() - .expect("should have estimate result"); + let res = this + .inner + .sequencer + .estimate_fee(vec![tx], block_id) + .map_err(StarknetApiError::from)? + .pop() + .expect("should have estimate result"); - Ok(res) + Ok(res) + }) + .await } async fn add_declare_transaction( &self, declare_transaction: BroadcastedDeclareTx, ) -> Result { - if declare_transaction.is_query() { - return Err(StarknetApiError::UnsupportedTransactionVersion.into()); - } + self.on_io_blocking_task(move |this| { + if declare_transaction.is_query() { + return Err(StarknetApiError::UnsupportedTransactionVersion.into()); + } - let chain_id = self.sequencer.chain_id(); + let chain_id = this.inner.sequencer.chain_id(); - // // validate compiled class hash - // let is_valid = declare_transaction - // .validate_compiled_class_hash() - // .map_err(|_| StarknetApiError::InvalidContractClass)?; + // // validate compiled class hash + // let is_valid = declare_transaction + // .validate_compiled_class_hash() + // .map_err(|_| StarknetApiError::InvalidContractClass)?; - // if !is_valid { - // return Err(StarknetApiError::CompiledClassHashMismatch.into()); - // } + // if !is_valid { + // return Err(StarknetApiError::CompiledClassHashMismatch.into()); + // } - let tx = declare_transaction - .try_into_tx_with_chain_id(chain_id) - .map_err(|_| StarknetApiError::InvalidContractClass)?; + let tx = declare_transaction + .try_into_tx_with_chain_id(chain_id) + .map_err(|_| StarknetApiError::InvalidContractClass)?; - let class_hash = tx.class_hash(); - let tx = ExecutableTxWithHash::new(ExecutableTx::Declare(tx)); - let tx_hash = tx.hash; + let class_hash = tx.class_hash(); + let tx = ExecutableTxWithHash::new(ExecutableTx::Declare(tx)); + let tx_hash = tx.hash; - self.sequencer.add_transaction_to_pool(tx); + this.inner.sequencer.add_transaction_to_pool(tx); - Ok((tx_hash, class_hash).into()) + Ok((tx_hash, class_hash).into()) + }) + .await } async fn add_invoke_transaction( &self, invoke_transaction: BroadcastedInvokeTx, ) -> Result { - if invoke_transaction.is_query { - return Err(StarknetApiError::UnsupportedTransactionVersion.into()); - } + self.on_io_blocking_task(move |this| { + if invoke_transaction.is_query { + return Err(StarknetApiError::UnsupportedTransactionVersion.into()); + } - let chain_id = self.sequencer.chain_id(); + let chain_id = this.inner.sequencer.chain_id(); - let tx = invoke_transaction.into_tx_with_chain_id(chain_id); - let tx = ExecutableTxWithHash::new(ExecutableTx::Invoke(tx)); - let tx_hash = tx.hash; + let tx = invoke_transaction.into_tx_with_chain_id(chain_id); + let tx = ExecutableTxWithHash::new(ExecutableTx::Invoke(tx)); + let tx_hash = tx.hash; - self.sequencer.add_transaction_to_pool(tx); + this.inner.sequencer.add_transaction_to_pool(tx); - Ok(tx_hash.into()) + Ok(tx_hash.into()) + }) + .await } async fn transaction_status( &self, transaction_hash: TxHash, ) -> Result { - let provider = self.sequencer.backend.blockchain.provider(); + self.on_io_blocking_task(move |this| { + let provider = this.inner.sequencer.backend.blockchain.provider(); + + let tx_status = + TransactionStatusProvider::transaction_status(provider, transaction_hash) + .map_err(StarknetApiError::from)?; + + if let Some(status) = tx_status { + if let Some(receipt) = ReceiptProvider::receipt_by_hash(provider, transaction_hash) + .map_err(StarknetApiError::from)? + { + let execution_status = if receipt.is_reverted() { + TransactionExecutionStatus::Reverted + } else { + TransactionExecutionStatus::Succeeded + }; + + return Ok(match status { + FinalityStatus::AcceptedOnL1 => { + TransactionStatus::AcceptedOnL1(execution_status) + } + FinalityStatus::AcceptedOnL2 => { + TransactionStatus::AcceptedOnL2(execution_status) + } + }); + } + } - let tx_status = TransactionStatusProvider::transaction_status(provider, transaction_hash) - .map_err(StarknetApiError::from)?; + let pending_state = this.inner.sequencer.pending_state(); + let state = pending_state.ok_or(StarknetApiError::TxnHashNotFound)?; + let executed_txs = state.executed_txs.read(); - if let Some(status) = tx_status { - if let Some(receipt) = ReceiptProvider::receipt_by_hash(provider, transaction_hash) - .map_err(StarknetApiError::from)? + // attemps to find in the valid transactions list first (executed_txs) + // if not found, then search in the rejected transactions list (rejected_txs) + if let Some(is_reverted) = executed_txs + .iter() + .find(|(tx, _)| tx.hash == transaction_hash) + .map(|(_, rct)| rct.receipt.is_reverted()) { - let execution_status = if receipt.is_reverted() { + let exec_status = if is_reverted { TransactionExecutionStatus::Reverted } else { TransactionExecutionStatus::Succeeded }; - return Ok(match status { - FinalityStatus::AcceptedOnL1 => { - TransactionStatus::AcceptedOnL1(execution_status) - } - FinalityStatus::AcceptedOnL2 => { - TransactionStatus::AcceptedOnL2(execution_status) - } - }); - } - } - - let pending_state = self.sequencer.pending_state(); - let state = pending_state.ok_or(StarknetApiError::TxnHashNotFound)?; - let executed_txs = state.executed_txs.read(); - - // attemps to find in the valid transactions list first (executed_txs) - // if not found, then search in the rejected transactions list (rejected_txs) - if let Some(is_reverted) = executed_txs - .iter() - .find(|(tx, _)| tx.hash == transaction_hash) - .map(|(_, rct)| rct.receipt.is_reverted()) - { - let exec_status = if is_reverted { - TransactionExecutionStatus::Reverted + Ok(TransactionStatus::AcceptedOnL2(exec_status)) } else { - TransactionExecutionStatus::Succeeded - }; + let rejected_txs = state.rejected_txs.read(); - Ok(TransactionStatus::AcceptedOnL2(exec_status)) - } else { - let rejected_txs = state.rejected_txs.read(); - - rejected_txs - .iter() - .find(|(tx, _)| tx.hash == transaction_hash) - .map(|_| TransactionStatus::Rejected) - .ok_or(Error::from(StarknetApiError::TxnHashNotFound)) - } + rejected_txs + .iter() + .find(|(tx, _)| tx.hash == transaction_hash) + .map(|_| TransactionStatus::Rejected) + .ok_or(Error::from(StarknetApiError::TxnHashNotFound)) + } + }) + .await } } diff --git a/crates/katana/tasks/Cargo.toml b/crates/katana/tasks/Cargo.toml new file mode 100644 index 0000000000..fd03a40729 --- /dev/null +++ b/crates/katana/tasks/Cargo.toml @@ -0,0 +1,12 @@ +[package] +edition.workspace = true +name = "katana-tasks" +version.workspace = true + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +futures.workspace = true +rayon.workspace = true +thiserror.workspace = true +tokio.workspace = true diff --git a/crates/katana/tasks/src/lib.rs b/crates/katana/tasks/src/lib.rs new file mode 100644 index 0000000000..cbec89247d --- /dev/null +++ b/crates/katana/tasks/src/lib.rs @@ -0,0 +1,111 @@ +use std::any::Any; +use std::future::Future; +use std::panic::{self, AssertUnwindSafe}; +use std::pin::Pin; +use std::sync::Arc; +use std::task::Poll; + +use futures::channel::oneshot; +use rayon::{ThreadPoolBuildError, ThreadPoolBuilder}; +use tokio::runtime::Handle; +use tokio::task::JoinHandle; + +/// This `struct` is created by the [TokioTaskSpawner::new] method. +#[derive(Debug, thiserror::Error)] +#[error("Failed to initialize task spawner: {0}")] +pub struct TaskSpawnerInitError(tokio::runtime::TryCurrentError); + +/// A task spawner for spawning tasks on a tokio runtime. This is simple wrapper around a tokio's +/// runtime [Handle] to easily spawn tasks on the runtime. +/// +/// For running expensive CPU-bound tasks, use [BlockingTaskPool] instead. +#[derive(Clone)] +pub struct TokioTaskSpawner { + /// Handle to the tokio runtime. + tokio_handle: Handle, +} + +impl TokioTaskSpawner { + /// Creates a new [TokioTaskSpawner] over the currently running tokio runtime. + /// + /// ## Errors + /// + /// Returns an error if no tokio runtime has been started. + pub fn new() -> Result { + Ok(Self { tokio_handle: Handle::try_current().map_err(TaskSpawnerInitError)? }) + } + + /// Creates a new [TokioTaskSpawner] with the given tokio runtime [Handle]. + pub fn new_with_handle(tokio_handle: Handle) -> Self { + Self { tokio_handle } + } +} + +impl TokioTaskSpawner { + pub fn spawn(&self, future: F) -> JoinHandle + where + F: Future + Send + 'static, + F::Output: Send + 'static, + { + self.tokio_handle.spawn(future) + } + + pub fn spawn_blocking(&self, func: F) -> JoinHandle + where + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, + { + self.tokio_handle.spawn_blocking(func) + } +} + +type BlockingTaskResult = Result>; + +#[derive(Debug)] +#[must_use = "BlockingTaskHandle does nothing unless polled"] +pub struct BlockingTaskHandle(oneshot::Receiver>); + +impl Future for BlockingTaskHandle { + type Output = BlockingTaskResult; + + fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { + match Pin::new(&mut self.get_mut().0).poll(cx) { + Poll::Ready(Ok(res)) => Poll::Ready(res), + Poll::Ready(Err(_)) => panic!("blocking task cancelled"), + Poll::Pending => Poll::Pending, + } + } +} + +/// This is mainly for expensive CPU-bound tasks. For spawing blocking IO-bound tasks, use +/// [TokioTaskSpawner::spawn_blocking] instead. +#[derive(Debug, Clone)] +pub struct BlockingTaskPool { + pool: Arc, +} + +impl BlockingTaskPool { + pub fn build() -> ThreadPoolBuilder { + ThreadPoolBuilder::new().thread_name(|i| format!("blocking-thread-pool-{i}")) + } + + pub fn new() -> Result { + Self::build().build().map(|pool| Self { pool: Arc::new(pool) }) + } + + pub fn new_with_pool(rayon_pool: rayon::ThreadPool) -> Self { + Self { pool: Arc::new(rayon_pool) } + } + + pub fn spawn(&self, func: F) -> BlockingTaskHandle + where + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, + { + let (tx, rx) = oneshot::channel(); + self.pool.spawn(move || { + let _ = tx.send(panic::catch_unwind(AssertUnwindSafe(func))); + }); + BlockingTaskHandle(rx) + } +}