From ddc08524378e8d1d95b4b3c8958e52faff6a3aa6 Mon Sep 17 00:00:00 2001 From: Kariy Date: Thu, 18 Jan 2024 00:32:05 +0900 Subject: [PATCH 1/5] wip --- crates/katana/rpc/src/starknet.rs | 16 ++++++++-- crates/katana/tasks/Cargo.toml | 10 ++++++ crates/katana/tasks/src/lib.rs | 51 +++++++++++++++++++++++++++++++ 3 files changed, 75 insertions(+), 2 deletions(-) create mode 100644 crates/katana/tasks/Cargo.toml create mode 100644 crates/katana/tasks/src/lib.rs diff --git a/crates/katana/rpc/src/starknet.rs b/crates/katana/rpc/src/starknet.rs index 737139d262..3603f24aa2 100644 --- a/crates/katana/rpc/src/starknet.rs +++ b/crates/katana/rpc/src/starknet.rs @@ -29,10 +29,12 @@ use katana_rpc_types::transaction::{ }; use katana_rpc_types::{ContractClass, FeeEstimate, FeltAsHex, FunctionCall}; use katana_rpc_types_builder::ReceiptBuilder; +use katana_tasks::TokioTaskSpawner; use starknet::core::types::{BlockTag, TransactionExecutionStatus, TransactionStatus}; use crate::api::starknet::{StarknetApiError, StarknetApiServer}; +#[derive(Clone)] pub struct StarknetApi { sequencer: Arc, } @@ -41,6 +43,15 @@ impl StarknetApi { pub fn new(sequencer: Arc) -> Self { Self { sequencer } } + + async fn on_blocking_task(&self, func: F) -> T + where + F: FnOnce(Self) -> T + Send + 'static, + T: Send + 'static, + { + let this = self.clone(); + TokioTaskSpawner::new().unwrap().spawn_blocking(move || func(this)).await.unwrap() + } } #[async_trait] impl StarknetApiServer for StarknetApi { @@ -54,8 +65,9 @@ impl StarknetApiServer for StarknetApi { contract_address: FieldElement, ) -> Result { let nonce = self - .sequencer - .nonce_at(block_id, contract_address.into()) + .on_blocking_task(move |this| { + this.sequencer.nonce_at(block_id, contract_address.into()) + }) .await .map_err(StarknetApiError::from)? .ok_or(StarknetApiError::ContractNotFound)?; diff --git a/crates/katana/tasks/Cargo.toml b/crates/katana/tasks/Cargo.toml new file mode 100644 index 0000000000..f0b7df590f --- /dev/null +++ b/crates/katana/tasks/Cargo.toml @@ -0,0 +1,10 @@ +[package] +edition.workspace = true +name = "katana-tasks" +version.workspace = true + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +thiserror.workspace = true +tokio.workspace = true diff --git a/crates/katana/tasks/src/lib.rs b/crates/katana/tasks/src/lib.rs new file mode 100644 index 0000000000..1cf96aa550 --- /dev/null +++ b/crates/katana/tasks/src/lib.rs @@ -0,0 +1,51 @@ +use std::future::Future; + +use tokio::runtime::Handle; +use tokio::task::JoinHandle; + +/// This `struct` is created by the [TokioTaskSpawner::new] method. +#[derive(Debug, thiserror::Error)] +#[error("Failed to initialize task spawner: {0}")] +pub struct TaskSpawnerInitError(tokio::runtime::TryCurrentError); + +/// A task spawner for spawning tasks on a tokio runtime. This is simple wrapper around a tokio's +/// runtime [Handle] to easily spawn tasks on the runtime. +#[derive(Clone)] +pub struct TokioTaskSpawner { + /// Handle to the tokio runtime. + tokio_handle: Handle, +} + +impl TokioTaskSpawner { + /// Creates a new [TokioTaskSpawner] over the currently running tokio runtime. + /// + /// ## Errors + /// + /// Returns an error if no tokio runtime has been started. + pub fn new() -> Result { + Ok(Self { tokio_handle: Handle::try_current().map_err(TaskSpawnerInitError)? }) + } + + /// Creates a new [TokioTaskSpawner] with the given tokio runtime [Handle]. + pub fn new_with_handle(tokio_handle: Handle) -> Self { + Self { tokio_handle } + } +} + +impl TokioTaskSpawner { + pub fn spawn(&self, future: F) -> JoinHandle + where + F: Future + Send + 'static, + F::Output: Send + 'static, + { + self.tokio_handle.spawn(future) + } + + pub fn spawn_blocking(&self, func: F) -> JoinHandle + where + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, + { + self.tokio_handle.spawn_blocking(func) + } +} From 5db4d9d2bb04587b9a8f7c74187eff3a185fbb12 Mon Sep 17 00:00:00 2001 From: Kariy Date: Thu, 18 Jan 2024 00:46:02 +0900 Subject: [PATCH 2/5] update --- Cargo.lock | 9 +++++++++ crates/katana/core/src/sequencer.rs | 2 +- crates/katana/rpc/Cargo.toml | 1 + 3 files changed, 11 insertions(+), 1 deletion(-) diff --git a/Cargo.lock b/Cargo.lock index b28d950ab4..cc32ce3064 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5611,6 +5611,7 @@ dependencies = [ "katana-provider", "katana-rpc-types", "katana-rpc-types-builder", + "katana-tasks", "serde", "serde_json", "serde_with", @@ -5661,6 +5662,14 @@ dependencies = [ "url", ] +[[package]] +name = "katana-tasks" +version = "0.5.0" +dependencies = [ + "thiserror", + "tokio", +] + [[package]] name = "keccak" version = "0.1.4" diff --git a/crates/katana/core/src/sequencer.rs b/crates/katana/core/src/sequencer.rs index 1f2b85dbde..6038b0ca3e 100644 --- a/crates/katana/core/src/sequencer.rs +++ b/crates/katana/core/src/sequencer.rs @@ -300,7 +300,7 @@ impl KatanaSequencer { Ok(count) } - pub async fn nonce_at( + pub fn nonce_at( &self, block_id: BlockIdOrTag, contract_address: ContractAddress, diff --git a/crates/katana/rpc/Cargo.toml b/crates/katana/rpc/Cargo.toml index 33d21f9570..91450e7971 100644 --- a/crates/katana/rpc/Cargo.toml +++ b/crates/katana/rpc/Cargo.toml @@ -12,6 +12,7 @@ katana-primitives = { path = "../primitives" } katana-provider = { path = "../storage/provider" } katana-rpc-types = { path = "rpc-types" } katana-rpc-types-builder = { path = "rpc-types-builder" } +katana-tasks = { path = "../tasks" } anyhow.workspace = true cairo-lang-starknet = "2.3.1" From 4f5b394f82a87752f743a2aea971509a378c2441 Mon Sep 17 00:00:00 2001 From: Kariy Date: Thu, 18 Jan 2024 13:30:25 +0900 Subject: [PATCH 3/5] wip --- crates/katana/core/src/sequencer.rs | 7 +- crates/katana/rpc/src/starknet.rs | 757 +++++++++++++++------------- 2 files changed, 414 insertions(+), 350 deletions(-) diff --git a/crates/katana/core/src/sequencer.rs b/crates/katana/core/src/sequencer.rs index 6038b0ca3e..16c3ca9b7e 100644 --- a/crates/katana/core/src/sequencer.rs +++ b/crates/katana/core/src/sequencer.rs @@ -267,8 +267,9 @@ impl KatanaSequencer { self.backend.chain_id } - pub fn block_number(&self) -> BlockNumber { - BlockNumberProvider::latest_number(&self.backend.blockchain.provider()).unwrap() + pub fn block_number(&self) -> SequencerResult { + let num = BlockNumberProvider::latest_number(&self.backend.blockchain.provider())?; + Ok(num) } pub fn block_tx_count(&self, block_id: BlockIdOrTag) -> SequencerResult> { @@ -352,7 +353,7 @@ impl KatanaSequencer { Ok(tx) } - pub async fn events( + pub fn events( &self, from_block: BlockIdOrTag, to_block: BlockIdOrTag, diff --git a/crates/katana/rpc/src/starknet.rs b/crates/katana/rpc/src/starknet.rs index 3603f24aa2..3e94d1f367 100644 --- a/crates/katana/rpc/src/starknet.rs +++ b/crates/katana/rpc/src/starknet.rs @@ -64,37 +64,47 @@ impl StarknetApiServer for StarknetApi { block_id: BlockIdOrTag, contract_address: FieldElement, ) -> Result { - let nonce = self - .on_blocking_task(move |this| { - this.sequencer.nonce_at(block_id, contract_address.into()) - }) - .await - .map_err(StarknetApiError::from)? - .ok_or(StarknetApiError::ContractNotFound)?; - - Ok(nonce.into()) + self.on_blocking_task(move |this| { + let nonce = this + .sequencer + .nonce_at(block_id, contract_address.into()) + .map_err(StarknetApiError::from)? + .ok_or(StarknetApiError::ContractNotFound)?; + Ok(nonce.into()) + }) + .await } async fn block_number(&self) -> Result { - Ok(self.sequencer.block_number()) + self.on_blocking_task(move |this| { + let block_number = this.sequencer.block_number().map_err(StarknetApiError::from)?; + Ok(block_number) + }) + .await } async fn transaction_by_hash(&self, transaction_hash: FieldElement) -> Result { - let tx = self - .sequencer - .transaction(&transaction_hash) - .map_err(StarknetApiError::from)? - .ok_or(StarknetApiError::TxnHashNotFound)?; - Ok(tx.into()) + self.on_blocking_task(move |this| { + let tx = this + .sequencer + .transaction(&transaction_hash) + .map_err(StarknetApiError::from)? + .ok_or(StarknetApiError::TxnHashNotFound)?; + Ok(tx.into()) + }) + .await } async fn block_transaction_count(&self, block_id: BlockIdOrTag) -> Result { - let count = self - .sequencer - .block_tx_count(block_id) - .map_err(StarknetApiError::from)? - .ok_or(StarknetApiError::BlockNotFound)?; - Ok(count) + self.on_blocking_task(move |this| { + let count = this + .sequencer + .block_tx_count(block_id) + .map_err(StarknetApiError::from)? + .ok_or(StarknetApiError::BlockNotFound)?; + Ok(count) + }) + .await } async fn class_at( @@ -103,17 +113,21 @@ impl StarknetApiServer for StarknetApi { contract_address: FieldElement, ) -> Result { let class_hash = self - .sequencer - .class_hash_at(block_id, contract_address.into()) - .map_err(StarknetApiError::from)? - .ok_or(StarknetApiError::ContractNotFound)?; - + .on_blocking_task(move |this| { + this.sequencer + .class_hash_at(block_id, contract_address.into()) + .map_err(StarknetApiError::from)? + .ok_or(StarknetApiError::ContractNotFound) + }) + .await?; self.class(block_id, class_hash).await } async fn block_hash_and_number(&self) -> Result { - let hash_and_num_pair = - self.sequencer.block_hash_and_number().map_err(StarknetApiError::from)?; + let hash_and_num_pair = self + .on_blocking_task(move |this| this.sequencer.block_hash_and_number()) + .await + .map_err(StarknetApiError::from)?; Ok(hash_and_num_pair.into()) } @@ -121,51 +135,53 @@ impl StarknetApiServer for StarknetApi { &self, block_id: BlockIdOrTag, ) -> Result { - let provider = self.sequencer.backend.blockchain.provider(); - - if BlockIdOrTag::Tag(BlockTag::Pending) == block_id { - if let Some(pending_state) = self.sequencer.pending_state() { - let block_env = pending_state.block_envs.read().0.clone(); - let latest_hash = - BlockHashProvider::latest_hash(provider).map_err(StarknetApiError::from)?; - - let gas_prices = GasPrices { - eth: block_env.l1_gas_prices.eth, - strk: block_env.l1_gas_prices.strk, - }; + self.on_blocking_task(move |this| { + let provider = this.sequencer.backend.blockchain.provider(); + + if BlockIdOrTag::Tag(BlockTag::Pending) == block_id { + if let Some(pending_state) = this.sequencer.pending_state() { + let block_env = pending_state.block_envs.read().0.clone(); + let latest_hash = + BlockHashProvider::latest_hash(provider).map_err(StarknetApiError::from)?; + + let gas_prices = GasPrices { + eth: block_env.l1_gas_prices.eth, + strk: block_env.l1_gas_prices.strk, + }; + + let header = PartialHeader { + gas_prices, + parent_hash: latest_hash, + version: CURRENT_STARKNET_VERSION, + timestamp: block_env.timestamp, + sequencer_address: block_env.sequencer_address, + }; + + let transactions = pending_state + .executed_txs + .read() + .iter() + .map(|(tx, _)| tx.hash) + .collect::>(); - let header = PartialHeader { - gas_prices, - parent_hash: latest_hash, - version: CURRENT_STARKNET_VERSION, - timestamp: block_env.timestamp, - sequencer_address: block_env.sequencer_address, - }; + return Ok(MaybePendingBlockWithTxHashes::Pending( + PendingBlockWithTxHashes::new(header, transactions), + )); + } + } - let transactions = pending_state - .executed_txs - .read() - .iter() - .map(|(tx, _)| tx.hash) - .collect::>(); + let block_num = BlockIdReader::convert_block_id(provider, block_id) + .map_err(StarknetApiError::from)? + .map(BlockHashOrNumber::Num) + .ok_or(StarknetApiError::BlockNotFound)?; - return Ok(MaybePendingBlockWithTxHashes::Pending(PendingBlockWithTxHashes::new( - header, - transactions, - ))); - } - } - - let block_num = BlockIdReader::convert_block_id(provider, block_id) - .map_err(StarknetApiError::from)? - .map(BlockHashOrNumber::Num) - .ok_or(StarknetApiError::BlockNotFound)?; - - katana_rpc_types_builder::BlockBuilder::new(block_num, provider) - .build_with_tx_hash() - .map_err(StarknetApiError::from)? - .map(MaybePendingBlockWithTxHashes::Block) - .ok_or(Error::from(StarknetApiError::BlockNotFound)) + katana_rpc_types_builder::BlockBuilder::new(block_num, provider) + .build_with_tx_hash() + .map_err(StarknetApiError::from)? + .map(MaybePendingBlockWithTxHashes::Block) + .ok_or(Error::from(StarknetApiError::BlockNotFound)) + }) + .await } async fn transaction_by_block_id_and_index( @@ -173,133 +189,145 @@ impl StarknetApiServer for StarknetApi { block_id: BlockIdOrTag, index: u64, ) -> Result { - // TEMP: have to handle pending tag independently for now - let tx = if BlockIdOrTag::Tag(BlockTag::Pending) == block_id { - let Some(pending_state) = self.sequencer.pending_state() else { - return Err(StarknetApiError::BlockNotFound.into()); - }; + self.on_blocking_task(move |this| { + // TEMP: have to handle pending tag independently for now + let tx = if BlockIdOrTag::Tag(BlockTag::Pending) == block_id { + let Some(pending_state) = this.sequencer.pending_state() else { + return Err(StarknetApiError::BlockNotFound.into()); + }; - let pending_txs = pending_state.executed_txs.read(); - pending_txs.iter().nth(index as usize).map(|(tx, _)| tx.clone()) - } else { - let provider = &self.sequencer.backend.blockchain.provider(); + let pending_txs = pending_state.executed_txs.read(); + pending_txs.iter().nth(index as usize).map(|(tx, _)| tx.clone()) + } else { + let provider = &this.sequencer.backend.blockchain.provider(); - let block_num = BlockIdReader::convert_block_id(provider, block_id) - .map_err(StarknetApiError::from)? - .map(BlockHashOrNumber::Num) - .ok_or(StarknetApiError::BlockNotFound)?; + let block_num = BlockIdReader::convert_block_id(provider, block_id) + .map_err(StarknetApiError::from)? + .map(BlockHashOrNumber::Num) + .ok_or(StarknetApiError::BlockNotFound)?; - TransactionProvider::transaction_by_block_and_idx(provider, block_num, index) - .map_err(StarknetApiError::from)? - }; + TransactionProvider::transaction_by_block_and_idx(provider, block_num, index) + .map_err(StarknetApiError::from)? + }; - Ok(tx.ok_or(StarknetApiError::InvalidTxnIndex)?.into()) + Ok(tx.ok_or(StarknetApiError::InvalidTxnIndex)?.into()) + }) + .await } async fn block_with_txs( &self, block_id: BlockIdOrTag, ) -> Result { - let provider = self.sequencer.backend.blockchain.provider(); - - if BlockIdOrTag::Tag(BlockTag::Pending) == block_id { - if let Some(pending_state) = self.sequencer.pending_state() { - let block_env = pending_state.block_envs.read().0.clone(); - let latest_hash = - BlockHashProvider::latest_hash(provider).map_err(StarknetApiError::from)?; - - let gas_prices = GasPrices { - eth: block_env.l1_gas_prices.eth, - strk: block_env.l1_gas_prices.strk, - }; - - let header = PartialHeader { - gas_prices, - parent_hash: latest_hash, - version: CURRENT_STARKNET_VERSION, - timestamp: block_env.timestamp, - sequencer_address: block_env.sequencer_address, - }; + self.on_blocking_task(move |this| { + let provider = this.sequencer.backend.blockchain.provider(); + + if BlockIdOrTag::Tag(BlockTag::Pending) == block_id { + if let Some(pending_state) = this.sequencer.pending_state() { + let block_env = pending_state.block_envs.read().0.clone(); + let latest_hash = + BlockHashProvider::latest_hash(provider).map_err(StarknetApiError::from)?; + + let gas_prices = GasPrices { + eth: block_env.l1_gas_prices.eth, + strk: block_env.l1_gas_prices.strk, + }; + + let header = PartialHeader { + gas_prices, + parent_hash: latest_hash, + version: CURRENT_STARKNET_VERSION, + timestamp: block_env.timestamp, + sequencer_address: block_env.sequencer_address, + }; + + let transactions = pending_state + .executed_txs + .read() + .iter() + .map(|(tx, _)| tx.clone()) + .collect::>(); + + return Ok(MaybePendingBlockWithTxs::Pending(PendingBlockWithTxs::new( + header, + transactions, + ))); + } + } - let transactions = pending_state - .executed_txs - .read() - .iter() - .map(|(tx, _)| tx.clone()) - .collect::>(); + let block_num = BlockIdReader::convert_block_id(provider, block_id) + .map_err(|e| StarknetApiError::UnexpectedError { reason: e.to_string() })? + .map(BlockHashOrNumber::Num) + .ok_or(StarknetApiError::BlockNotFound)?; - return Ok(MaybePendingBlockWithTxs::Pending(PendingBlockWithTxs::new( - header, - transactions, - ))); - } - } - - let block_num = BlockIdReader::convert_block_id(provider, block_id) - .map_err(|e| StarknetApiError::UnexpectedError { reason: e.to_string() })? - .map(BlockHashOrNumber::Num) - .ok_or(StarknetApiError::BlockNotFound)?; - - katana_rpc_types_builder::BlockBuilder::new(block_num, provider) - .build() - .map_err(|e| StarknetApiError::UnexpectedError { reason: e.to_string() })? - .map(MaybePendingBlockWithTxs::Block) - .ok_or(Error::from(StarknetApiError::BlockNotFound)) + katana_rpc_types_builder::BlockBuilder::new(block_num, provider) + .build() + .map_err(|e| StarknetApiError::UnexpectedError { reason: e.to_string() })? + .map(MaybePendingBlockWithTxs::Block) + .ok_or(Error::from(StarknetApiError::BlockNotFound)) + }) + .await } async fn state_update(&self, block_id: BlockIdOrTag) -> Result { - let provider = self.sequencer.backend.blockchain.provider(); + self.on_blocking_task(move |this| { + let provider = this.sequencer.backend.blockchain.provider(); - let block_id = match block_id { - BlockIdOrTag::Number(num) => BlockHashOrNumber::Num(num), - BlockIdOrTag::Hash(hash) => BlockHashOrNumber::Hash(hash), + let block_id = match block_id { + BlockIdOrTag::Number(num) => BlockHashOrNumber::Num(num), + BlockIdOrTag::Hash(hash) => BlockHashOrNumber::Hash(hash), - BlockIdOrTag::Tag(BlockTag::Latest) => BlockNumberProvider::latest_number(provider) - .map(BlockHashOrNumber::Num) - .map_err(|_| StarknetApiError::BlockNotFound)?, + BlockIdOrTag::Tag(BlockTag::Latest) => BlockNumberProvider::latest_number(provider) + .map(BlockHashOrNumber::Num) + .map_err(|_| StarknetApiError::BlockNotFound)?, - BlockIdOrTag::Tag(BlockTag::Pending) => { - return Err(StarknetApiError::BlockNotFound.into()); - } - }; + BlockIdOrTag::Tag(BlockTag::Pending) => { + return Err(StarknetApiError::BlockNotFound.into()); + } + }; - katana_rpc_types_builder::StateUpdateBuilder::new(block_id, provider) - .build() - .map_err(|e| StarknetApiError::UnexpectedError { reason: e.to_string() })? - .ok_or(Error::from(StarknetApiError::BlockNotFound)) + katana_rpc_types_builder::StateUpdateBuilder::new(block_id, provider) + .build() + .map_err(|e| StarknetApiError::UnexpectedError { reason: e.to_string() })? + .ok_or(Error::from(StarknetApiError::BlockNotFound)) + }) + .await } async fn transaction_receipt( &self, transaction_hash: FieldElement, ) -> Result { - let provider = self.sequencer.backend.blockchain.provider(); - let receipt = ReceiptBuilder::new(transaction_hash, provider) - .build() - .map_err(|e| StarknetApiError::UnexpectedError { reason: e.to_string() })?; - - match receipt { - Some(receipt) => Ok(MaybePendingTxReceipt::Receipt(receipt)), - - None => { - let pending_receipt = self.sequencer.pending_state().and_then(|s| { - s.executed_txs - .read() - .iter() - .find(|(tx, _)| tx.hash == transaction_hash) - .map(|(_, rct)| rct.receipt.clone()) - }); - - let Some(pending_receipt) = pending_receipt else { - return Err(StarknetApiError::TxnHashNotFound.into()); - }; - - Ok(MaybePendingTxReceipt::Pending(PendingTxReceipt::new( - transaction_hash, - pending_receipt, - ))) + self.on_blocking_task(move |this| { + let provider = this.sequencer.backend.blockchain.provider(); + let receipt = ReceiptBuilder::new(transaction_hash, provider) + .build() + .map_err(|e| StarknetApiError::UnexpectedError { reason: e.to_string() })?; + + match receipt { + Some(receipt) => Ok(MaybePendingTxReceipt::Receipt(receipt)), + + None => { + let pending_receipt = this.sequencer.pending_state().and_then(|s| { + s.executed_txs + .read() + .iter() + .find(|(tx, _)| tx.hash == transaction_hash) + .map(|(_, rct)| rct.receipt.clone()) + }); + + let Some(pending_receipt) = pending_receipt else { + return Err(StarknetApiError::TxnHashNotFound.into()); + }; + + Ok(MaybePendingTxReceipt::Pending(PendingTxReceipt::new( + transaction_hash, + pending_receipt, + ))) + } } - } + }) + .await } async fn class_hash_at( @@ -307,13 +335,15 @@ impl StarknetApiServer for StarknetApi { block_id: BlockIdOrTag, contract_address: FieldElement, ) -> Result { - let hash = self - .sequencer - .class_hash_at(block_id, contract_address.into()) - .map_err(StarknetApiError::from)? - .ok_or(StarknetApiError::ContractNotFound)?; - - Ok(hash.into()) + self.on_blocking_task(move |this| { + let hash = this + .sequencer + .class_hash_at(block_id, contract_address.into()) + .map_err(StarknetApiError::from)? + .ok_or(StarknetApiError::ContractNotFound)?; + Ok(hash.into()) + }) + .await } async fn class( @@ -321,40 +351,47 @@ impl StarknetApiServer for StarknetApi { block_id: BlockIdOrTag, class_hash: FieldElement, ) -> Result { - let class = self.sequencer.class(block_id, class_hash).map_err(StarknetApiError::from)?; - let Some(class) = class else { return Err(StarknetApiError::ClassHashNotFound.into()) }; - - match class { - StarknetContract::Legacy(c) => { - let contract = legacy_inner_to_rpc_class(c) - .map_err(|e| StarknetApiError::UnexpectedError { reason: e.to_string() })?; - Ok(contract) + self.on_blocking_task(move |this| { + let class = + this.sequencer.class(block_id, class_hash).map_err(StarknetApiError::from)?; + let Some(class) = class else { return Err(StarknetApiError::ClassHashNotFound.into()) }; + + match class { + StarknetContract::Legacy(c) => { + let contract = legacy_inner_to_rpc_class(c) + .map_err(|e| StarknetApiError::UnexpectedError { reason: e.to_string() })?; + Ok(contract) + } + StarknetContract::Sierra(c) => Ok(ContractClass::Sierra(c)), } - StarknetContract::Sierra(c) => Ok(ContractClass::Sierra(c)), - } + }) + .await } async fn events(&self, filter: EventFilterWithPage) -> Result { - let from_block = filter.event_filter.from_block.unwrap_or(BlockIdOrTag::Number(0)); - let to_block = filter.event_filter.to_block.unwrap_or(BlockIdOrTag::Tag(BlockTag::Latest)); - - let keys = filter.event_filter.keys; - let keys = keys.filter(|keys| !(keys.len() == 1 && keys.is_empty())); - - let events = self - .sequencer - .events( - from_block, - to_block, - filter.event_filter.address.map(|f| f.into()), - keys, - filter.result_page_request.continuation_token, - filter.result_page_request.chunk_size, - ) - .await - .map_err(StarknetApiError::from)?; - - Ok(events) + self.on_blocking_task(move |this| { + let from_block = filter.event_filter.from_block.unwrap_or(BlockIdOrTag::Number(0)); + let to_block = + filter.event_filter.to_block.unwrap_or(BlockIdOrTag::Tag(BlockTag::Latest)); + + let keys = filter.event_filter.keys; + let keys = keys.filter(|keys| !(keys.len() == 1 && keys.is_empty())); + + let events = this + .sequencer + .events( + from_block, + to_block, + filter.event_filter.address.map(|f| f.into()), + keys, + filter.result_page_request.continuation_token, + filter.result_page_request.chunk_size, + ) + .map_err(StarknetApiError::from)?; + + Ok(events) + }) + .await } async fn call( @@ -362,15 +399,17 @@ impl StarknetApiServer for StarknetApi { request: FunctionCall, block_id: BlockIdOrTag, ) -> Result, Error> { - let request = EntryPointCall { - calldata: request.calldata, - contract_address: request.contract_address.into(), - entry_point_selector: request.entry_point_selector, - }; - - let res = self.sequencer.call(request, block_id).map_err(StarknetApiError::from)?; + self.on_blocking_task(move |this| { + let request = EntryPointCall { + calldata: request.calldata, + contract_address: request.contract_address.into(), + entry_point_selector: request.entry_point_selector, + }; - Ok(res.into_iter().map(|v| v.into()).collect()) + let res = this.sequencer.call(request, block_id).map_err(StarknetApiError::from)?; + Ok(res.into_iter().map(|v| v.into()).collect()) + }) + .await } async fn storage_at( @@ -379,33 +418,39 @@ impl StarknetApiServer for StarknetApi { key: FieldElement, block_id: BlockIdOrTag, ) -> Result { - let value = self - .sequencer - .storage_at(contract_address.into(), key, block_id) - .map_err(StarknetApiError::from)?; - - Ok(value.into()) + self.on_blocking_task(move |this| { + let value = this + .sequencer + .storage_at(contract_address.into(), key, block_id) + .map_err(StarknetApiError::from)?; + + Ok(value.into()) + }) + .await } async fn add_deploy_account_transaction( &self, deploy_account_transaction: BroadcastedDeployAccountTx, ) -> Result { - if deploy_account_transaction.is_query { - return Err(StarknetApiError::UnsupportedTransactionVersion.into()); - } + self.on_blocking_task(move |this| { + if deploy_account_transaction.is_query { + return Err(StarknetApiError::UnsupportedTransactionVersion.into()); + } - let chain_id = self.sequencer.chain_id(); + let chain_id = this.sequencer.chain_id(); - let tx = deploy_account_transaction.into_tx_with_chain_id(chain_id); - let contract_address = tx.contract_address; + let tx = deploy_account_transaction.into_tx_with_chain_id(chain_id); + let contract_address = tx.contract_address; - let tx = ExecutableTxWithHash::new(ExecutableTx::DeployAccount(tx)); - let tx_hash = tx.hash; + let tx = ExecutableTxWithHash::new(ExecutableTx::DeployAccount(tx)); + let tx_hash = tx.hash; - self.sequencer.add_transaction_to_pool(tx); + this.sequencer.add_transaction_to_pool(tx); - Ok((tx_hash, contract_address).into()) + Ok((tx_hash, contract_address).into()) + }) + .await } async fn estimate_fee( @@ -413,38 +458,43 @@ impl StarknetApiServer for StarknetApi { request: Vec, block_id: BlockIdOrTag, ) -> Result, Error> { - let chain_id = self.sequencer.chain_id(); - - let transactions = request - .into_iter() - .map(|tx| { - let tx = match tx { - BroadcastedTx::Invoke(tx) => { - let tx = tx.into_tx_with_chain_id(chain_id); - ExecutableTxWithHash::new_query(ExecutableTx::Invoke(tx)) - } - - BroadcastedTx::DeployAccount(tx) => { - let tx = tx.into_tx_with_chain_id(chain_id); - ExecutableTxWithHash::new_query(ExecutableTx::DeployAccount(tx)) - } - - BroadcastedTx::Declare(tx) => { - let tx = tx - .try_into_tx_with_chain_id(chain_id) - .map_err(|_| StarknetApiError::InvalidContractClass)?; - ExecutableTxWithHash::new_query(ExecutableTx::Declare(tx)) - } - }; - - Result::::Ok(tx) - }) - .collect::, _>>()?; - - let res = - self.sequencer.estimate_fee(transactions, block_id).map_err(StarknetApiError::from)?; - - Ok(res) + self.on_blocking_task(move |this| { + let chain_id = this.sequencer.chain_id(); + + let transactions = request + .into_iter() + .map(|tx| { + let tx = match tx { + BroadcastedTx::Invoke(tx) => { + let tx = tx.into_tx_with_chain_id(chain_id); + ExecutableTxWithHash::new_query(ExecutableTx::Invoke(tx)) + } + + BroadcastedTx::DeployAccount(tx) => { + let tx = tx.into_tx_with_chain_id(chain_id); + ExecutableTxWithHash::new_query(ExecutableTx::DeployAccount(tx)) + } + + BroadcastedTx::Declare(tx) => { + let tx = tx + .try_into_tx_with_chain_id(chain_id) + .map_err(|_| StarknetApiError::InvalidContractClass)?; + ExecutableTxWithHash::new_query(ExecutableTx::Declare(tx)) + } + }; + + Result::::Ok(tx) + }) + .collect::, _>>()?; + + let res = this + .sequencer + .estimate_fee(transactions, block_id) + .map_err(StarknetApiError::from)?; + + Ok(res) + }) + .await } async fn estimate_message_fee( @@ -452,129 +502,142 @@ impl StarknetApiServer for StarknetApi { message: MsgFromL1, block_id: BlockIdOrTag, ) -> Result { - let chain_id = self.sequencer.chain_id(); + self.on_blocking_task(move |this| { + let chain_id = this.sequencer.chain_id(); - let tx = message.into_tx_with_chain_id(chain_id); - let hash = tx.calculate_hash(); - let tx: ExecutableTxWithHash = ExecutableTxWithHash { hash, transaction: tx.into() }; + let tx = message.into_tx_with_chain_id(chain_id); + let hash = tx.calculate_hash(); + let tx: ExecutableTxWithHash = ExecutableTxWithHash { hash, transaction: tx.into() }; - let res = self - .sequencer - .estimate_fee(vec![tx], block_id) - .map_err(StarknetApiError::from)? - .pop() - .expect("should have estimate result"); + let res = this + .sequencer + .estimate_fee(vec![tx], block_id) + .map_err(StarknetApiError::from)? + .pop() + .expect("should have estimate result"); - Ok(res) + Ok(res) + }) + .await } async fn add_declare_transaction( &self, declare_transaction: BroadcastedDeclareTx, ) -> Result { - if declare_transaction.is_query() { - return Err(StarknetApiError::UnsupportedTransactionVersion.into()); - } + self.on_blocking_task(move |this| { + if declare_transaction.is_query() { + return Err(StarknetApiError::UnsupportedTransactionVersion.into()); + } - let chain_id = self.sequencer.chain_id(); + let chain_id = this.sequencer.chain_id(); - // // validate compiled class hash - // let is_valid = declare_transaction - // .validate_compiled_class_hash() - // .map_err(|_| StarknetApiError::InvalidContractClass)?; + // // validate compiled class hash + // let is_valid = declare_transaction + // .validate_compiled_class_hash() + // .map_err(|_| StarknetApiError::InvalidContractClass)?; - // if !is_valid { - // return Err(StarknetApiError::CompiledClassHashMismatch.into()); - // } + // if !is_valid { + // return Err(StarknetApiError::CompiledClassHashMismatch.into()); + // } - let tx = declare_transaction - .try_into_tx_with_chain_id(chain_id) - .map_err(|_| StarknetApiError::InvalidContractClass)?; + let tx = declare_transaction + .try_into_tx_with_chain_id(chain_id) + .map_err(|_| StarknetApiError::InvalidContractClass)?; - let class_hash = tx.class_hash(); - let tx = ExecutableTxWithHash::new(ExecutableTx::Declare(tx)); - let tx_hash = tx.hash; + let class_hash = tx.class_hash(); + let tx = ExecutableTxWithHash::new(ExecutableTx::Declare(tx)); + let tx_hash = tx.hash; - self.sequencer.add_transaction_to_pool(tx); + this.sequencer.add_transaction_to_pool(tx); - Ok((tx_hash, class_hash).into()) + Ok((tx_hash, class_hash).into()) + }) + .await } async fn add_invoke_transaction( &self, invoke_transaction: BroadcastedInvokeTx, ) -> Result { - if invoke_transaction.is_query { - return Err(StarknetApiError::UnsupportedTransactionVersion.into()); - } + self.on_blocking_task(move |this| { + if invoke_transaction.is_query { + return Err(StarknetApiError::UnsupportedTransactionVersion.into()); + } - let chain_id = self.sequencer.chain_id(); + let chain_id = this.sequencer.chain_id(); - let tx = invoke_transaction.into_tx_with_chain_id(chain_id); - let tx = ExecutableTxWithHash::new(ExecutableTx::Invoke(tx)); - let tx_hash = tx.hash; + let tx = invoke_transaction.into_tx_with_chain_id(chain_id); + let tx = ExecutableTxWithHash::new(ExecutableTx::Invoke(tx)); + let tx_hash = tx.hash; - self.sequencer.add_transaction_to_pool(tx); + this.sequencer.add_transaction_to_pool(tx); - Ok(tx_hash.into()) + Ok(tx_hash.into()) + }) + .await } async fn transaction_status( &self, transaction_hash: TxHash, ) -> Result { - let provider = self.sequencer.backend.blockchain.provider(); + self.on_blocking_task(move |this| { + let provider = this.sequencer.backend.blockchain.provider(); + + let tx_status = + TransactionStatusProvider::transaction_status(provider, transaction_hash) + .map_err(StarknetApiError::from)?; + + if let Some(status) = tx_status { + if let Some(receipt) = ReceiptProvider::receipt_by_hash(provider, transaction_hash) + .map_err(StarknetApiError::from)? + { + let execution_status = if receipt.is_reverted() { + TransactionExecutionStatus::Reverted + } else { + TransactionExecutionStatus::Succeeded + }; + + return Ok(match status { + FinalityStatus::AcceptedOnL1 => { + TransactionStatus::AcceptedOnL1(execution_status) + } + FinalityStatus::AcceptedOnL2 => { + TransactionStatus::AcceptedOnL2(execution_status) + } + }); + } + } - let tx_status = TransactionStatusProvider::transaction_status(provider, transaction_hash) - .map_err(StarknetApiError::from)?; + let pending_state = this.sequencer.pending_state(); + let state = pending_state.ok_or(StarknetApiError::TxnHashNotFound)?; + let executed_txs = state.executed_txs.read(); - if let Some(status) = tx_status { - if let Some(receipt) = ReceiptProvider::receipt_by_hash(provider, transaction_hash) - .map_err(StarknetApiError::from)? + // attemps to find in the valid transactions list first (executed_txs) + // if not found, then search in the rejected transactions list (rejected_txs) + if let Some(is_reverted) = executed_txs + .iter() + .find(|(tx, _)| tx.hash == transaction_hash) + .map(|(_, rct)| rct.receipt.is_reverted()) { - let execution_status = if receipt.is_reverted() { + let exec_status = if is_reverted { TransactionExecutionStatus::Reverted } else { TransactionExecutionStatus::Succeeded }; - return Ok(match status { - FinalityStatus::AcceptedOnL1 => { - TransactionStatus::AcceptedOnL1(execution_status) - } - FinalityStatus::AcceptedOnL2 => { - TransactionStatus::AcceptedOnL2(execution_status) - } - }); - } - } - - let pending_state = self.sequencer.pending_state(); - let state = pending_state.ok_or(StarknetApiError::TxnHashNotFound)?; - let executed_txs = state.executed_txs.read(); - - // attemps to find in the valid transactions list first (executed_txs) - // if not found, then search in the rejected transactions list (rejected_txs) - if let Some(is_reverted) = executed_txs - .iter() - .find(|(tx, _)| tx.hash == transaction_hash) - .map(|(_, rct)| rct.receipt.is_reverted()) - { - let exec_status = if is_reverted { - TransactionExecutionStatus::Reverted + Ok(TransactionStatus::AcceptedOnL2(exec_status)) } else { - TransactionExecutionStatus::Succeeded - }; + let rejected_txs = state.rejected_txs.read(); - Ok(TransactionStatus::AcceptedOnL2(exec_status)) - } else { - let rejected_txs = state.rejected_txs.read(); - - rejected_txs - .iter() - .find(|(tx, _)| tx.hash == transaction_hash) - .map(|_| TransactionStatus::Rejected) - .ok_or(Error::from(StarknetApiError::TxnHashNotFound)) - } + rejected_txs + .iter() + .find(|(tx, _)| tx.hash == transaction_hash) + .map(|_| TransactionStatus::Rejected) + .ok_or(Error::from(StarknetApiError::TxnHashNotFound)) + } + }) + .await } } From 5dd757b40a7bb0347abdb5c2e6018d7fa13b5eb5 Mon Sep 17 00:00:00 2001 From: Kariy Date: Thu, 18 Jan 2024 16:26:01 +0900 Subject: [PATCH 4/5] wip --- Cargo.lock | 6 ++- crates/katana/core/src/sequencer.rs | 5 ++ .../katana/rpc/rpc-types/src/transaction.rs | 2 + crates/katana/rpc/src/starknet.rs | 24 +++++++++- crates/katana/tasks/Cargo.toml | 2 + crates/katana/tasks/src/lib.rs | 48 +++++++++++++++++++ 6 files changed, 84 insertions(+), 3 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index cc32ce3064..d044fc69cc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5666,6 +5666,8 @@ dependencies = [ name = "katana-tasks" version = "0.5.0" dependencies = [ + "futures", + "rayon", "thiserror", "tokio", ] @@ -7564,9 +7566,9 @@ checksum = "c707298afce11da2efef2f600116fa93ffa7a032b5d7b628aa17711ec81383ca" [[package]] name = "reqwest" -version = "0.11.22" +version = "0.11.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "046cd98826c46c2ac8ddecae268eb5c2e58628688a5fc7a2643704a73faba95b" +checksum = "37b1ae8d9ac08420c66222fb9096fc5de435c3c48542bc5336c51892cffafb41" dependencies = [ "async-compression", "base64 0.21.5", diff --git a/crates/katana/core/src/sequencer.rs b/crates/katana/core/src/sequencer.rs index 16c3ca9b7e..29729b7057 100644 --- a/crates/katana/core/src/sequencer.rs +++ b/crates/katana/core/src/sequencer.rs @@ -193,12 +193,17 @@ impl KatanaSequencer { transactions: Vec, block_id: BlockIdOrTag, ) -> SequencerResult> { + println!("ohayo 3"); let state = self.state(&block_id)?; + println!("ohayo 4"); + let block_context = self .block_execution_context_at(block_id)? .ok_or_else(|| SequencerError::BlockNotFound(block_id))?; + println!("ohayo 5"); + katana_executor::blockifier::utils::estimate_fee( transactions.into_iter(), block_context, diff --git a/crates/katana/rpc/rpc-types/src/transaction.rs b/crates/katana/rpc/rpc-types/src/transaction.rs index 00aab586bc..8bec14c625 100644 --- a/crates/katana/rpc/rpc-types/src/transaction.rs +++ b/crates/katana/rpc/rpc-types/src/transaction.rs @@ -79,9 +79,11 @@ impl BroadcastedDeclareTx { } BroadcastedDeclareTransaction::V2(tx) => { + println!("hi"); // TODO: avoid computing the class hash again let (class_hash, _, compiled_class) = flattened_sierra_to_compiled_class(&tx.contract_class)?; + println!("hi 1"); Ok(DeclareTxWithClass { compiled_class, diff --git a/crates/katana/rpc/src/starknet.rs b/crates/katana/rpc/src/starknet.rs index 3e94d1f367..5145fcd329 100644 --- a/crates/katana/rpc/src/starknet.rs +++ b/crates/katana/rpc/src/starknet.rs @@ -50,7 +50,16 @@ impl StarknetApi { T: Send + 'static, { let this = self.clone(); - TokioTaskSpawner::new().unwrap().spawn_blocking(move || func(this)).await.unwrap() + // create oneshot tokio channel + let (sender, receiver) = tokio::sync::oneshot::channel::(); + let _ = TokioTaskSpawner::new() + .unwrap() + .spawn_blocking(move || { + let res = func(this); + let _ = sender.send(res); + }) + .await; + receiver.await.unwrap() } } #[async_trait] @@ -461,21 +470,26 @@ impl StarknetApiServer for StarknetApi { self.on_blocking_task(move |this| { let chain_id = this.sequencer.chain_id(); + println!("ohayo"); + let transactions = request .into_iter() .map(|tx| { let tx = match tx { BroadcastedTx::Invoke(tx) => { + println!("red"); let tx = tx.into_tx_with_chain_id(chain_id); ExecutableTxWithHash::new_query(ExecutableTx::Invoke(tx)) } BroadcastedTx::DeployAccount(tx) => { + println!("blue"); let tx = tx.into_tx_with_chain_id(chain_id); ExecutableTxWithHash::new_query(ExecutableTx::DeployAccount(tx)) } BroadcastedTx::Declare(tx) => { + println!("purple"); let tx = tx .try_into_tx_with_chain_id(chain_id) .map_err(|_| StarknetApiError::InvalidContractClass)?; @@ -487,14 +501,22 @@ impl StarknetApiServer for StarknetApi { }) .collect::, _>>()?; + println!("ohayo 2"); + let res = this .sequencer .estimate_fee(transactions, block_id) .map_err(StarknetApiError::from)?; + println!("ohayo 6"); + Ok(res) }) .await + .map_err(|e| { + println!("got error {:?}", e); + e + }) } async fn estimate_message_fee( diff --git a/crates/katana/tasks/Cargo.toml b/crates/katana/tasks/Cargo.toml index f0b7df590f..fd03a40729 100644 --- a/crates/katana/tasks/Cargo.toml +++ b/crates/katana/tasks/Cargo.toml @@ -6,5 +6,7 @@ version.workspace = true # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +futures.workspace = true +rayon.workspace = true thiserror.workspace = true tokio.workspace = true diff --git a/crates/katana/tasks/src/lib.rs b/crates/katana/tasks/src/lib.rs index 1cf96aa550..1788c1a4ec 100644 --- a/crates/katana/tasks/src/lib.rs +++ b/crates/katana/tasks/src/lib.rs @@ -1,5 +1,9 @@ use std::future::Future; +use std::pin::Pin; +use std::task::Poll; +use futures::channel::oneshot; +use rayon::ThreadPoolBuilder; use tokio::runtime::Handle; use tokio::task::JoinHandle; @@ -49,3 +53,47 @@ impl TokioTaskSpawner { self.tokio_handle.spawn_blocking(func) } } + +#[derive(Debug)] +#[must_use = "BlockingTaskHandle does nothing unless polled"] +pub struct BlockingTaskHandle(oneshot::Receiver); + +impl Future for BlockingTaskHandle { + type Output = T; + fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { + match Pin::new(&mut self.get_mut().0).poll(cx) { + Poll::Ready(Ok(res)) => Poll::Ready(res), + Poll::Ready(Err(_)) => panic!("blocking task cancelled"), + Poll::Pending => Poll::Pending, + } + } +} + +/// For expensive CPU-bound computations. For spawing blocking IO-bound tasks, use +/// [TokioTaskSpawner::spawn_blocking]. +pub struct BlockingTaskPool { + pool: rayon::ThreadPool, +} + +impl BlockingTaskPool { + pub fn new() -> Self { + Self { pool: Self::build().build().unwrap() } + } + + pub fn build() -> ThreadPoolBuilder { + ThreadPoolBuilder::new().thread_name(|i| format!("blocking-thread-pool-{i}")) + } + + pub fn spawn(&self, func: F) -> BlockingTaskHandle + where + F: FnOnce() -> T + Send + 'static, + T: Send + 'static, + { + let (tx, rx) = oneshot::channel::(); + self.pool.spawn(move || { + let res = func(); + let _ = tx.send(res); + }); + BlockingTaskHandle(rx) + } +} From ebfcbb46e6fb7be875121d125c3e70d4be0efeaf Mon Sep 17 00:00:00 2001 From: Kariy Date: Thu, 18 Jan 2024 18:13:28 +0900 Subject: [PATCH 5/5] update --- crates/katana/core/src/sequencer.rs | 5 - .../katana/rpc/rpc-types/src/transaction.rs | 2 - crates/katana/rpc/src/starknet.rs | 138 ++++++++++-------- crates/katana/tasks/src/lib.rs | 44 ++++-- 4 files changed, 104 insertions(+), 85 deletions(-) diff --git a/crates/katana/core/src/sequencer.rs b/crates/katana/core/src/sequencer.rs index 29729b7057..16c3ca9b7e 100644 --- a/crates/katana/core/src/sequencer.rs +++ b/crates/katana/core/src/sequencer.rs @@ -193,17 +193,12 @@ impl KatanaSequencer { transactions: Vec, block_id: BlockIdOrTag, ) -> SequencerResult> { - println!("ohayo 3"); let state = self.state(&block_id)?; - println!("ohayo 4"); - let block_context = self .block_execution_context_at(block_id)? .ok_or_else(|| SequencerError::BlockNotFound(block_id))?; - println!("ohayo 5"); - katana_executor::blockifier::utils::estimate_fee( transactions.into_iter(), block_context, diff --git a/crates/katana/rpc/rpc-types/src/transaction.rs b/crates/katana/rpc/rpc-types/src/transaction.rs index 8bec14c625..00aab586bc 100644 --- a/crates/katana/rpc/rpc-types/src/transaction.rs +++ b/crates/katana/rpc/rpc-types/src/transaction.rs @@ -79,11 +79,9 @@ impl BroadcastedDeclareTx { } BroadcastedDeclareTransaction::V2(tx) => { - println!("hi"); // TODO: avoid computing the class hash again let (class_hash, _, compiled_class) = flattened_sierra_to_compiled_class(&tx.contract_class)?; - println!("hi 1"); Ok(DeclareTxWithClass { compiled_class, diff --git a/crates/katana/rpc/src/starknet.rs b/crates/katana/rpc/src/starknet.rs index 5145fcd329..42e6f9984c 100644 --- a/crates/katana/rpc/src/starknet.rs +++ b/crates/katana/rpc/src/starknet.rs @@ -29,22 +29,38 @@ use katana_rpc_types::transaction::{ }; use katana_rpc_types::{ContractClass, FeeEstimate, FeltAsHex, FunctionCall}; use katana_rpc_types_builder::ReceiptBuilder; -use katana_tasks::TokioTaskSpawner; +use katana_tasks::{BlockingTaskPool, TokioTaskSpawner}; use starknet::core::types::{BlockTag, TransactionExecutionStatus, TransactionStatus}; use crate::api::starknet::{StarknetApiError, StarknetApiServer}; #[derive(Clone)] pub struct StarknetApi { + inner: Arc, +} + +struct StarknetApiInner { sequencer: Arc, + blocking_task_pool: BlockingTaskPool, } impl StarknetApi { pub fn new(sequencer: Arc) -> Self { - Self { sequencer } + let blocking_task_pool = + BlockingTaskPool::new().expect("failed to create blocking task pool"); + Self { inner: Arc::new(StarknetApiInner { sequencer, blocking_task_pool }) } } - async fn on_blocking_task(&self, func: F) -> T + async fn on_cpu_blocking_task(&self, func: F) -> T + where + F: FnOnce(Self) -> T + Send + 'static, + T: Send + 'static, + { + let this = self.clone(); + self.inner.blocking_task_pool.spawn(move || func(this)).await.unwrap() + } + + async fn on_io_blocking_task(&self, func: F) -> T where F: FnOnce(Self) -> T + Send + 'static, T: Send + 'static, @@ -65,7 +81,7 @@ impl StarknetApi { #[async_trait] impl StarknetApiServer for StarknetApi { async fn chain_id(&self) -> Result { - Ok(FieldElement::from(self.sequencer.chain_id()).into()) + Ok(FieldElement::from(self.inner.sequencer.chain_id()).into()) } async fn nonce( @@ -73,8 +89,9 @@ impl StarknetApiServer for StarknetApi { block_id: BlockIdOrTag, contract_address: FieldElement, ) -> Result { - self.on_blocking_task(move |this| { + self.on_io_blocking_task(move |this| { let nonce = this + .inner .sequencer .nonce_at(block_id, contract_address.into()) .map_err(StarknetApiError::from)? @@ -85,16 +102,18 @@ impl StarknetApiServer for StarknetApi { } async fn block_number(&self) -> Result { - self.on_blocking_task(move |this| { - let block_number = this.sequencer.block_number().map_err(StarknetApiError::from)?; + self.on_io_blocking_task(move |this| { + let block_number = + this.inner.sequencer.block_number().map_err(StarknetApiError::from)?; Ok(block_number) }) .await } async fn transaction_by_hash(&self, transaction_hash: FieldElement) -> Result { - self.on_blocking_task(move |this| { + self.on_io_blocking_task(move |this| { let tx = this + .inner .sequencer .transaction(&transaction_hash) .map_err(StarknetApiError::from)? @@ -105,8 +124,9 @@ impl StarknetApiServer for StarknetApi { } async fn block_transaction_count(&self, block_id: BlockIdOrTag) -> Result { - self.on_blocking_task(move |this| { + self.on_io_blocking_task(move |this| { let count = this + .inner .sequencer .block_tx_count(block_id) .map_err(StarknetApiError::from)? @@ -122,8 +142,9 @@ impl StarknetApiServer for StarknetApi { contract_address: FieldElement, ) -> Result { let class_hash = self - .on_blocking_task(move |this| { - this.sequencer + .on_io_blocking_task(move |this| { + this.inner + .sequencer .class_hash_at(block_id, contract_address.into()) .map_err(StarknetApiError::from)? .ok_or(StarknetApiError::ContractNotFound) @@ -134,7 +155,7 @@ impl StarknetApiServer for StarknetApi { async fn block_hash_and_number(&self) -> Result { let hash_and_num_pair = self - .on_blocking_task(move |this| this.sequencer.block_hash_and_number()) + .on_io_blocking_task(move |this| this.inner.sequencer.block_hash_and_number()) .await .map_err(StarknetApiError::from)?; Ok(hash_and_num_pair.into()) @@ -144,11 +165,11 @@ impl StarknetApiServer for StarknetApi { &self, block_id: BlockIdOrTag, ) -> Result { - self.on_blocking_task(move |this| { - let provider = this.sequencer.backend.blockchain.provider(); + self.on_io_blocking_task(move |this| { + let provider = this.inner.sequencer.backend.blockchain.provider(); if BlockIdOrTag::Tag(BlockTag::Pending) == block_id { - if let Some(pending_state) = this.sequencer.pending_state() { + if let Some(pending_state) = this.inner.sequencer.pending_state() { let block_env = pending_state.block_envs.read().0.clone(); let latest_hash = BlockHashProvider::latest_hash(provider).map_err(StarknetApiError::from)?; @@ -198,17 +219,17 @@ impl StarknetApiServer for StarknetApi { block_id: BlockIdOrTag, index: u64, ) -> Result { - self.on_blocking_task(move |this| { + self.on_io_blocking_task(move |this| { // TEMP: have to handle pending tag independently for now let tx = if BlockIdOrTag::Tag(BlockTag::Pending) == block_id { - let Some(pending_state) = this.sequencer.pending_state() else { + let Some(pending_state) = this.inner.sequencer.pending_state() else { return Err(StarknetApiError::BlockNotFound.into()); }; let pending_txs = pending_state.executed_txs.read(); pending_txs.iter().nth(index as usize).map(|(tx, _)| tx.clone()) } else { - let provider = &this.sequencer.backend.blockchain.provider(); + let provider = &this.inner.sequencer.backend.blockchain.provider(); let block_num = BlockIdReader::convert_block_id(provider, block_id) .map_err(StarknetApiError::from)? @@ -228,11 +249,11 @@ impl StarknetApiServer for StarknetApi { &self, block_id: BlockIdOrTag, ) -> Result { - self.on_blocking_task(move |this| { - let provider = this.sequencer.backend.blockchain.provider(); + self.on_io_blocking_task(move |this| { + let provider = this.inner.sequencer.backend.blockchain.provider(); if BlockIdOrTag::Tag(BlockTag::Pending) == block_id { - if let Some(pending_state) = this.sequencer.pending_state() { + if let Some(pending_state) = this.inner.sequencer.pending_state() { let block_env = pending_state.block_envs.read().0.clone(); let latest_hash = BlockHashProvider::latest_hash(provider).map_err(StarknetApiError::from)?; @@ -279,8 +300,8 @@ impl StarknetApiServer for StarknetApi { } async fn state_update(&self, block_id: BlockIdOrTag) -> Result { - self.on_blocking_task(move |this| { - let provider = this.sequencer.backend.blockchain.provider(); + self.on_io_blocking_task(move |this| { + let provider = this.inner.sequencer.backend.blockchain.provider(); let block_id = match block_id { BlockIdOrTag::Number(num) => BlockHashOrNumber::Num(num), @@ -307,8 +328,8 @@ impl StarknetApiServer for StarknetApi { &self, transaction_hash: FieldElement, ) -> Result { - self.on_blocking_task(move |this| { - let provider = this.sequencer.backend.blockchain.provider(); + self.on_io_blocking_task(move |this| { + let provider = this.inner.sequencer.backend.blockchain.provider(); let receipt = ReceiptBuilder::new(transaction_hash, provider) .build() .map_err(|e| StarknetApiError::UnexpectedError { reason: e.to_string() })?; @@ -317,7 +338,7 @@ impl StarknetApiServer for StarknetApi { Some(receipt) => Ok(MaybePendingTxReceipt::Receipt(receipt)), None => { - let pending_receipt = this.sequencer.pending_state().and_then(|s| { + let pending_receipt = this.inner.sequencer.pending_state().and_then(|s| { s.executed_txs .read() .iter() @@ -344,8 +365,9 @@ impl StarknetApiServer for StarknetApi { block_id: BlockIdOrTag, contract_address: FieldElement, ) -> Result { - self.on_blocking_task(move |this| { + self.on_io_blocking_task(move |this| { let hash = this + .inner .sequencer .class_hash_at(block_id, contract_address.into()) .map_err(StarknetApiError::from)? @@ -360,9 +382,9 @@ impl StarknetApiServer for StarknetApi { block_id: BlockIdOrTag, class_hash: FieldElement, ) -> Result { - self.on_blocking_task(move |this| { + self.on_io_blocking_task(move |this| { let class = - this.sequencer.class(block_id, class_hash).map_err(StarknetApiError::from)?; + this.inner.sequencer.class(block_id, class_hash).map_err(StarknetApiError::from)?; let Some(class) = class else { return Err(StarknetApiError::ClassHashNotFound.into()) }; match class { @@ -378,7 +400,7 @@ impl StarknetApiServer for StarknetApi { } async fn events(&self, filter: EventFilterWithPage) -> Result { - self.on_blocking_task(move |this| { + self.on_io_blocking_task(move |this| { let from_block = filter.event_filter.from_block.unwrap_or(BlockIdOrTag::Number(0)); let to_block = filter.event_filter.to_block.unwrap_or(BlockIdOrTag::Tag(BlockTag::Latest)); @@ -387,6 +409,7 @@ impl StarknetApiServer for StarknetApi { let keys = keys.filter(|keys| !(keys.len() == 1 && keys.is_empty())); let events = this + .inner .sequencer .events( from_block, @@ -408,14 +431,15 @@ impl StarknetApiServer for StarknetApi { request: FunctionCall, block_id: BlockIdOrTag, ) -> Result, Error> { - self.on_blocking_task(move |this| { + self.on_io_blocking_task(move |this| { let request = EntryPointCall { calldata: request.calldata, contract_address: request.contract_address.into(), entry_point_selector: request.entry_point_selector, }; - let res = this.sequencer.call(request, block_id).map_err(StarknetApiError::from)?; + let res = + this.inner.sequencer.call(request, block_id).map_err(StarknetApiError::from)?; Ok(res.into_iter().map(|v| v.into()).collect()) }) .await @@ -427,8 +451,9 @@ impl StarknetApiServer for StarknetApi { key: FieldElement, block_id: BlockIdOrTag, ) -> Result { - self.on_blocking_task(move |this| { + self.on_io_blocking_task(move |this| { let value = this + .inner .sequencer .storage_at(contract_address.into(), key, block_id) .map_err(StarknetApiError::from)?; @@ -442,12 +467,12 @@ impl StarknetApiServer for StarknetApi { &self, deploy_account_transaction: BroadcastedDeployAccountTx, ) -> Result { - self.on_blocking_task(move |this| { + self.on_io_blocking_task(move |this| { if deploy_account_transaction.is_query { return Err(StarknetApiError::UnsupportedTransactionVersion.into()); } - let chain_id = this.sequencer.chain_id(); + let chain_id = this.inner.sequencer.chain_id(); let tx = deploy_account_transaction.into_tx_with_chain_id(chain_id); let contract_address = tx.contract_address; @@ -455,7 +480,7 @@ impl StarknetApiServer for StarknetApi { let tx = ExecutableTxWithHash::new(ExecutableTx::DeployAccount(tx)); let tx_hash = tx.hash; - this.sequencer.add_transaction_to_pool(tx); + this.inner.sequencer.add_transaction_to_pool(tx); Ok((tx_hash, contract_address).into()) }) @@ -467,29 +492,24 @@ impl StarknetApiServer for StarknetApi { request: Vec, block_id: BlockIdOrTag, ) -> Result, Error> { - self.on_blocking_task(move |this| { - let chain_id = this.sequencer.chain_id(); - - println!("ohayo"); + self.on_cpu_blocking_task(move |this| { + let chain_id = this.inner.sequencer.chain_id(); let transactions = request .into_iter() .map(|tx| { let tx = match tx { BroadcastedTx::Invoke(tx) => { - println!("red"); let tx = tx.into_tx_with_chain_id(chain_id); ExecutableTxWithHash::new_query(ExecutableTx::Invoke(tx)) } BroadcastedTx::DeployAccount(tx) => { - println!("blue"); let tx = tx.into_tx_with_chain_id(chain_id); ExecutableTxWithHash::new_query(ExecutableTx::DeployAccount(tx)) } BroadcastedTx::Declare(tx) => { - println!("purple"); let tx = tx .try_into_tx_with_chain_id(chain_id) .map_err(|_| StarknetApiError::InvalidContractClass)?; @@ -501,22 +521,15 @@ impl StarknetApiServer for StarknetApi { }) .collect::, _>>()?; - println!("ohayo 2"); - let res = this + .inner .sequencer .estimate_fee(transactions, block_id) .map_err(StarknetApiError::from)?; - println!("ohayo 6"); - Ok(res) }) .await - .map_err(|e| { - println!("got error {:?}", e); - e - }) } async fn estimate_message_fee( @@ -524,14 +537,15 @@ impl StarknetApiServer for StarknetApi { message: MsgFromL1, block_id: BlockIdOrTag, ) -> Result { - self.on_blocking_task(move |this| { - let chain_id = this.sequencer.chain_id(); + self.on_cpu_blocking_task(move |this| { + let chain_id = this.inner.sequencer.chain_id(); let tx = message.into_tx_with_chain_id(chain_id); let hash = tx.calculate_hash(); let tx: ExecutableTxWithHash = ExecutableTxWithHash { hash, transaction: tx.into() }; let res = this + .inner .sequencer .estimate_fee(vec![tx], block_id) .map_err(StarknetApiError::from)? @@ -547,12 +561,12 @@ impl StarknetApiServer for StarknetApi { &self, declare_transaction: BroadcastedDeclareTx, ) -> Result { - self.on_blocking_task(move |this| { + self.on_io_blocking_task(move |this| { if declare_transaction.is_query() { return Err(StarknetApiError::UnsupportedTransactionVersion.into()); } - let chain_id = this.sequencer.chain_id(); + let chain_id = this.inner.sequencer.chain_id(); // // validate compiled class hash // let is_valid = declare_transaction @@ -571,7 +585,7 @@ impl StarknetApiServer for StarknetApi { let tx = ExecutableTxWithHash::new(ExecutableTx::Declare(tx)); let tx_hash = tx.hash; - this.sequencer.add_transaction_to_pool(tx); + this.inner.sequencer.add_transaction_to_pool(tx); Ok((tx_hash, class_hash).into()) }) @@ -582,18 +596,18 @@ impl StarknetApiServer for StarknetApi { &self, invoke_transaction: BroadcastedInvokeTx, ) -> Result { - self.on_blocking_task(move |this| { + self.on_io_blocking_task(move |this| { if invoke_transaction.is_query { return Err(StarknetApiError::UnsupportedTransactionVersion.into()); } - let chain_id = this.sequencer.chain_id(); + let chain_id = this.inner.sequencer.chain_id(); let tx = invoke_transaction.into_tx_with_chain_id(chain_id); let tx = ExecutableTxWithHash::new(ExecutableTx::Invoke(tx)); let tx_hash = tx.hash; - this.sequencer.add_transaction_to_pool(tx); + this.inner.sequencer.add_transaction_to_pool(tx); Ok(tx_hash.into()) }) @@ -604,8 +618,8 @@ impl StarknetApiServer for StarknetApi { &self, transaction_hash: TxHash, ) -> Result { - self.on_blocking_task(move |this| { - let provider = this.sequencer.backend.blockchain.provider(); + self.on_io_blocking_task(move |this| { + let provider = this.inner.sequencer.backend.blockchain.provider(); let tx_status = TransactionStatusProvider::transaction_status(provider, transaction_hash) @@ -632,7 +646,7 @@ impl StarknetApiServer for StarknetApi { } } - let pending_state = this.sequencer.pending_state(); + let pending_state = this.inner.sequencer.pending_state(); let state = pending_state.ok_or(StarknetApiError::TxnHashNotFound)?; let executed_txs = state.executed_txs.read(); diff --git a/crates/katana/tasks/src/lib.rs b/crates/katana/tasks/src/lib.rs index 1788c1a4ec..cbec89247d 100644 --- a/crates/katana/tasks/src/lib.rs +++ b/crates/katana/tasks/src/lib.rs @@ -1,9 +1,12 @@ +use std::any::Any; use std::future::Future; +use std::panic::{self, AssertUnwindSafe}; use std::pin::Pin; +use std::sync::Arc; use std::task::Poll; use futures::channel::oneshot; -use rayon::ThreadPoolBuilder; +use rayon::{ThreadPoolBuildError, ThreadPoolBuilder}; use tokio::runtime::Handle; use tokio::task::JoinHandle; @@ -14,6 +17,8 @@ pub struct TaskSpawnerInitError(tokio::runtime::TryCurrentError); /// A task spawner for spawning tasks on a tokio runtime. This is simple wrapper around a tokio's /// runtime [Handle] to easily spawn tasks on the runtime. +/// +/// For running expensive CPU-bound tasks, use [BlockingTaskPool] instead. #[derive(Clone)] pub struct TokioTaskSpawner { /// Handle to the tokio runtime. @@ -54,12 +59,15 @@ impl TokioTaskSpawner { } } +type BlockingTaskResult = Result>; + #[derive(Debug)] #[must_use = "BlockingTaskHandle does nothing unless polled"] -pub struct BlockingTaskHandle(oneshot::Receiver); +pub struct BlockingTaskHandle(oneshot::Receiver>); impl Future for BlockingTaskHandle { - type Output = T; + type Output = BlockingTaskResult; + fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll { match Pin::new(&mut self.get_mut().0).poll(cx) { Poll::Ready(Ok(res)) => Poll::Ready(res), @@ -69,30 +77,34 @@ impl Future for BlockingTaskHandle { } } -/// For expensive CPU-bound computations. For spawing blocking IO-bound tasks, use -/// [TokioTaskSpawner::spawn_blocking]. +/// This is mainly for expensive CPU-bound tasks. For spawing blocking IO-bound tasks, use +/// [TokioTaskSpawner::spawn_blocking] instead. +#[derive(Debug, Clone)] pub struct BlockingTaskPool { - pool: rayon::ThreadPool, + pool: Arc, } impl BlockingTaskPool { - pub fn new() -> Self { - Self { pool: Self::build().build().unwrap() } - } - pub fn build() -> ThreadPoolBuilder { ThreadPoolBuilder::new().thread_name(|i| format!("blocking-thread-pool-{i}")) } - pub fn spawn(&self, func: F) -> BlockingTaskHandle + pub fn new() -> Result { + Self::build().build().map(|pool| Self { pool: Arc::new(pool) }) + } + + pub fn new_with_pool(rayon_pool: rayon::ThreadPool) -> Self { + Self { pool: Arc::new(rayon_pool) } + } + + pub fn spawn(&self, func: F) -> BlockingTaskHandle where - F: FnOnce() -> T + Send + 'static, - T: Send + 'static, + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, { - let (tx, rx) = oneshot::channel::(); + let (tx, rx) = oneshot::channel(); self.pool.spawn(move || { - let res = func(); - let _ = tx.send(res); + let _ = tx.send(panic::catch_unwind(AssertUnwindSafe(func))); }); BlockingTaskHandle(rx) }