From 7194fb9402ada85a7ef0abda89fdf8e7aabc13b4 Mon Sep 17 00:00:00 2001 From: Ammar Arif Date: Fri, 10 May 2024 04:29:39 +0800 Subject: [PATCH] Refactor forked provider backend APIs (#1944) # Description ## Related issue ## Tests - [ ] Yes - [ ] No, because they aren't needed - [ ] No, because I need help ## Added to documentation? - [ ] README.md - [ ] [Dojo Book](https://github.com/dojoengine/book) - [ ] No documentation needed ## Checklist - [ ] I've formatted my code (`scripts/prettier.sh`, `scripts/rust_fmt.sh`, `scripts/cairo_fmt.sh`) - [ ] I've linted my code (`scripts/clippy.sh`, `scripts/docs.sh`) - [ ] I've commented my code - [ ] I've requested a review after addressing the comments --- bin/katana/src/args.rs | 2 +- crates/katana/storage/provider/src/error.rs | 2 +- .../provider/src/providers/fork/backend.rs | 601 ++++++++++-------- .../provider/src/providers/fork/mod.rs | 6 +- 4 files changed, 334 insertions(+), 277 deletions(-) diff --git a/bin/katana/src/args.rs b/bin/katana/src/args.rs index 3c3c99499c..5bde75658d 100644 --- a/bin/katana/src/args.rs +++ b/bin/katana/src/args.rs @@ -203,7 +203,7 @@ pub struct EnvironmentOptions { impl KatanaArgs { pub fn init_logging(&self) -> Result<(), Box> { - const DEFAULT_LOG_FILTER: &str = "info,executor=trace,forked_backend=trace,server=debug,\ + const DEFAULT_LOG_FILTER: &str = "info,executor=trace,forking::backend=trace,server=debug,\ katana_core=trace,blockifier=off,jsonrpsee_server=off,\ hyper=off,messaging=debug,node=error"; diff --git a/crates/katana/storage/provider/src/error.rs b/crates/katana/storage/provider/src/error.rs index a2b894ecf2..c02fda220d 100644 --- a/crates/katana/storage/provider/src/error.rs +++ b/crates/katana/storage/provider/src/error.rs @@ -106,7 +106,7 @@ pub enum ProviderError { /// [ForkedProvider](crate::providers::fork::ForkedProvider). #[cfg(feature = "fork")] #[error(transparent)] - ForkedBackend(#[from] crate::providers::fork::backend::ForkedBackendError), + ForkedBackend(#[from] crate::providers::fork::backend::BackendError), /// Any error that is not covered by the other variants. #[error("soemthing went wrong: {0}")] diff --git a/crates/katana/storage/provider/src/providers/fork/backend.rs b/crates/katana/storage/provider/src/providers/fork/backend.rs index 657279bde0..5fe359b8a1 100644 --- a/crates/katana/storage/provider/src/providers/fork/backend.rs +++ b/crates/katana/storage/provider/src/providers/fork/backend.rs @@ -1,11 +1,13 @@ use std::collections::VecDeque; use std::pin::Pin; -use std::sync::mpsc::{channel as oneshot, RecvError, Sender as OneshotSender}; +use std::sync::mpsc::{ + channel as oneshot, Receiver as OneshotReceiver, RecvError, Sender as OneshotSender, +}; use std::sync::Arc; use std::task::{Context, Poll}; -use std::thread; +use std::{io, thread}; -use futures::channel::mpsc::{channel, Receiver, SendError, Sender}; +use futures::channel::mpsc::{channel as async_channel, Receiver, SendError, Sender}; use futures::future::BoxFuture; use futures::stream::Stream; use futures::{Future, FutureExt}; @@ -18,10 +20,9 @@ use katana_primitives::conversion::rpc::{ }; use katana_primitives::FieldElement; use parking_lot::Mutex; -use starknet::core::types::{BlockId, ContractClass, StarknetError}; -use starknet::providers::jsonrpc::HttpTransport; -use starknet::providers::{JsonRpcClient, Provider, ProviderError as StarknetProviderError}; -use tracing::{error, trace}; +use starknet::core::types::{BlockId, ContractClass as RpcContractClass, StarknetError}; +use starknet::providers::{Provider, ProviderError as StarknetProviderError}; +use tracing::{error, info, trace}; use crate::error::ProviderError; use crate::providers::in_memory::cache::CacheStateDb; @@ -29,74 +30,164 @@ use crate::traits::contract::ContractClassProvider; use crate::traits::state::StateProvider; use crate::ProviderResult; -type GetNonceResult = Result; -type GetStorageResult = Result; -type GetClassHashAtResult = Result; -type GetClassAtResult = Result; +const LOG_TARGET: &str = "forking::backend"; + +type BackendResult = Result; -pub(crate) const LOG_TARGET: &str = "forked_backend"; +type GetNonceResult = BackendResult; +type GetStorageResult = BackendResult; +type GetClassHashAtResult = BackendResult; +type GetClassAtResult = BackendResult; #[derive(Debug, thiserror::Error)] -pub enum ForkedBackendError { - #[error("Failed to send request to the forked backend: {0}")] - Send(#[from] SendError), - #[error("Failed to receive result from the forked backend: {0}")] - Receive(#[from] RecvError), - #[error("Compute class hash error: {0}")] - ComputeClassHashError(String), - #[error("Failed to spawn forked backend thread: {0}")] - BackendThreadInit(#[from] std::io::Error), - #[error(transparent)] +pub enum BackendError { + #[error("failed to send request to backend: {0}")] + FailedSendRequest(#[from] SendError), + #[error("failed to receive result from backend: {0}")] + FailedReceiveResult(#[from] RecvError), + #[error("compute class hash error: {0}")] + ComputeClassHashError(anyhow::Error), + #[error("failed to spawn backend thread: {0}")] + BackendThreadInit(#[from] io::Error), + #[error("rpc provider error: {0}")] StarknetProvider(#[from] starknet::providers::ProviderError), } -/// The request types that is processed by [`Backend`]. +struct Request { + payload: P, + sender: OneshotSender>, +} + +/// The types of request that can be sent to [`Backend`]. /// -/// Each request is accompanied by the sender-half of a oneshot channel that will be used -/// to send the [`ProviderResult`] back to the backend client, [`ForkedBackend`], which sent the -/// requests. -pub enum BackendRequest { - GetClassAt(ClassHash, OneshotSender), - GetNonce(ContractAddress, OneshotSender), - GetClassHashAt(ContractAddress, OneshotSender), - GetStorage(ContractAddress, StorageKey, OneshotSender), +/// Each request consists of a payload and the sender half of a oneshot channel that will be used +/// to send the result back to the backend handle. +enum BackendRequest { + Nonce(Request), + Class(Request), + ClassHash(Request), + Storage(Request<(ContractAddress, StorageKey), StorageValue>), + // Test-only request kind for requesting the backend stats + #[cfg(test)] + Stats(OneshotSender), +} + +impl BackendRequest { + /// Create a new request for fetching the nonce of a contract. + fn nonce(address: ContractAddress) -> (BackendRequest, OneshotReceiver) { + let (sender, receiver) = oneshot(); + (BackendRequest::Nonce(Request { payload: address, sender }), receiver) + } + + /// Create a new request for fetching the class definitions of a contract. + fn class(hash: ClassHash) -> (BackendRequest, OneshotReceiver) { + let (sender, receiver) = oneshot(); + (BackendRequest::Class(Request { payload: hash, sender }), receiver) + } + + /// Create a new request for fetching the class hash of a contract. + fn class_hash( + address: ContractAddress, + ) -> (BackendRequest, OneshotReceiver) { + let (sender, receiver) = oneshot(); + (BackendRequest::ClassHash(Request { payload: address, sender }), receiver) + } + + /// Create a new request for fetching the storage value of a contract. + fn storage( + address: ContractAddress, + key: StorageKey, + ) -> (BackendRequest, OneshotReceiver) { + let (sender, receiver) = oneshot(); + (BackendRequest::Storage(Request { payload: (address, key), sender }), receiver) + } + + #[cfg(test)] + fn stats() -> (BackendRequest, OneshotReceiver) { + let (sender, receiver) = oneshot(); + (BackendRequest::Stats(sender), receiver) + } } type BackendRequestFuture = BoxFuture<'static, ()>; -/// The backend for the forked provider. It processes all requests from the [ForkedBackend]'s -/// and sends the ProviderResults back to it. +/// The backend for the forked provider. /// -/// It is responsible it fetching the data from the forked provider. -pub struct Backend { - provider: Arc>, +/// It is responsible for processing [requests](BackendRequest) to fetch data from the remote +/// provider. +pub struct Backend

{ + /// The Starknet RPC provider that will be used to fetch data from. + provider: Arc

, /// Requests that are currently being poll. pending_requests: Vec, /// Requests that are queued to be polled. queued_requests: VecDeque, - /// A channel for receiving requests from the [ForkedBackend]'s. + /// A channel for receiving requests from the [BackendHandle]s. incoming: Receiver, /// Pinned block id for all requests. block: BlockId, } -impl Backend { - /// This function is responsible for transforming the incoming request - /// into a future that will be polled until completion by the `BackendHandler`. - /// - /// Each request is accompanied by the sender-half of a oneshot channel that will be used - /// to send the ProviderResult back to the [ForkedBackend] which sent the requests. +impl

Backend

+where + P: Provider + Send + Sync + 'static, +{ + // TODO(kariy): create a `.start()` method start running the backend logic and let the users + // choose which thread to running it on instead of spawning the thread ourselves. + /// Create a new [Backend] with the given provider and block id, and returns a handle to it. The + /// backend will start processing requests immediately upon creation. + #[allow(clippy::new_ret_no_self)] + pub fn new(provider: P, block_id: BlockHashOrNumber) -> Result { + let (handle, backend) = Self::new_inner(provider, block_id); + + thread::Builder::new() + .name("forking-backend".into()) + .spawn(move || { + tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .expect("failed to create tokio runtime") + .block_on(backend); + }) + .map_err(BackendError::BackendThreadInit)?; + + info!(target: LOG_TARGET, "Forking backend started."); + + Ok(handle) + } + + fn new_inner(provider: P, block_id: BlockHashOrNumber) -> (BackendHandle, Backend

) { + let block = match block_id { + BlockHashOrNumber::Hash(hash) => BlockId::Hash(hash), + BlockHashOrNumber::Num(number) => BlockId::Number(number), + }; + + // Create async channel to receive requests from the handle. + let (tx, rx) = async_channel(100); + let backend = Backend { + block, + incoming: rx, + provider: Arc::new(provider), + pending_requests: Vec::new(), + queued_requests: VecDeque::new(), + }; + + (BackendHandle(Mutex::new(tx)), backend) + } + + /// This method is responsible for transforming the incoming request + /// sent from a [BackendHandle] into a RPC request to the remote network. fn handle_requests(&mut self, request: BackendRequest) { let block = self.block; let provider = self.provider.clone(); match request { - BackendRequest::GetNonce(contract_address, sender) => { + BackendRequest::Nonce(Request { payload, sender }) => { let fut = Box::pin(async move { let res = provider - .get_nonce(block, Into::::into(contract_address)) + .get_nonce(block, FieldElement::from(payload)) .await - .map_err(ForkedBackendError::StarknetProvider); + .map_err(BackendError::StarknetProvider); sender.send(res).expect("failed to send nonce result") }); @@ -104,12 +195,12 @@ impl Backend { self.pending_requests.push(fut); } - BackendRequest::GetStorage(contract_address, key, sender) => { + BackendRequest::Storage(Request { payload: (addr, key), sender }) => { let fut = Box::pin(async move { let res = provider - .get_storage_at(Into::::into(contract_address), key, block) + .get_storage_at(FieldElement::from(addr), key, block) .await - .map_err(ForkedBackendError::StarknetProvider); + .map_err(BackendError::StarknetProvider); sender.send(res).expect("failed to send storage result") }); @@ -117,12 +208,12 @@ impl Backend { self.pending_requests.push(fut); } - BackendRequest::GetClassHashAt(contract_address, sender) => { + BackendRequest::ClassHash(Request { payload, sender }) => { let fut = Box::pin(async move { let res = provider - .get_class_hash_at(block, Into::::into(contract_address)) + .get_class_hash_at(block, FieldElement::from(payload)) .await - .map_err(ForkedBackendError::StarknetProvider); + .map_err(BackendError::StarknetProvider); sender.send(res).expect("failed to send class hash result") }); @@ -130,23 +221,32 @@ impl Backend { self.pending_requests.push(fut); } - BackendRequest::GetClassAt(class_hash, sender) => { + BackendRequest::Class(Request { payload, sender }) => { let fut = Box::pin(async move { let res = provider - .get_class(block, class_hash) + .get_class(block, payload) .await - .map_err(ForkedBackendError::StarknetProvider); + .map_err(BackendError::StarknetProvider); sender.send(res).expect("failed to send class result") }); self.pending_requests.push(fut); } + + #[cfg(test)] + BackendRequest::Stats(sender) => { + let total_ongoing_request = self.pending_requests.len(); + sender.send(total_ongoing_request).expect("failed to send backend stats"); + } } } } -impl Future for Backend { +impl

Future for Backend

+where + P: Provider + Send + Sync + 'static, +{ type Output = (); fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { @@ -190,142 +290,80 @@ impl Future for Backend { } } -/// A thread safe handler to the [`Backend`]. This is the primary interface for sending -/// request to the backend thread to fetch data from the forked provider. -pub struct ForkedBackend(Mutex>); +/// A thread safe handler to [`Backend`]. +/// +/// This is the primary interface for sending request to the backend to fetch data from the remote +/// network. +pub struct BackendHandle(Mutex>); -impl Clone for ForkedBackend { +impl Clone for BackendHandle { fn clone(&self) -> Self { Self(Mutex::new(self.0.lock().clone())) } } -impl ForkedBackend { - /// Create a new [`ForkedBackend`] with a dedicated backend thread. - /// - /// This method will spawn a new thread that will run the [`Backend`]. - pub fn new_with_backend_thread( - provider: Arc>, - block_id: BlockHashOrNumber, - ) -> Result { - let (handler, backend) = Self::new(provider, block_id); - - thread::Builder::new().spawn(move || { - tokio::runtime::Builder::new_current_thread() - .enable_all() - .build() - .expect("failed to create tokio runtime") - .block_on(backend); - })?; - - trace!(target: LOG_TARGET, "Fork backend thread spawned."); - - Ok(handler) - } - - fn new( - provider: Arc>, - block_id: BlockHashOrNumber, - ) -> (Self, Backend) { - let block = match block_id { - BlockHashOrNumber::Hash(hash) => BlockId::Hash(hash), - BlockHashOrNumber::Num(number) => BlockId::Number(number), - }; - - let (sender, rx) = channel(1); - let backend = Backend { - incoming: rx, - provider, - block, - queued_requests: VecDeque::new(), - pending_requests: Vec::new(), - }; - - (Self(Mutex::new(sender)), backend) - } - - pub fn do_get_nonce( - &self, - contract_address: ContractAddress, - ) -> Result { - trace!(target: LOG_TARGET, contract_address = %contract_address, "Requesting nonce for contract address."); - let (sender, rx) = oneshot(); - self.0 - .lock() - .try_send(BackendRequest::GetNonce(contract_address, sender)) - .map_err(|e| e.into_send_error())?; +impl BackendHandle { + pub fn get_nonce(&self, address: ContractAddress) -> Result { + trace!(target: LOG_TARGET, %address, "Requesting contract nonce."); + let (req, rx) = BackendRequest::nonce(address); + self.request(req)?; rx.recv()? } - pub fn do_get_storage( + pub fn get_storage( &self, - contract_address: ContractAddress, + address: ContractAddress, key: StorageKey, - ) -> Result { - trace!( - target: LOG_TARGET, - contract_address = %contract_address, - key = %format!("{:#x}", key), - "Requesting storage." - ); - let (sender, rx) = oneshot(); - self.0 - .lock() - .try_send(BackendRequest::GetStorage(contract_address, key, sender)) - .map_err(|e| e.into_send_error())?; + ) -> Result { + trace!(target: LOG_TARGET, %address, key = %format!("{key:#x}"), "Requesting contract storage."); + let (req, rx) = BackendRequest::storage(address, key); + self.request(req)?; rx.recv()? } - pub fn do_get_class_hash_at( - &self, - contract_address: ContractAddress, - ) -> Result { - trace!(target: LOG_TARGET, contract_address = %contract_address, "Requesting class hash at address."); - let (sender, rx) = oneshot(); - self.0 - .lock() - .try_send(BackendRequest::GetClassHashAt(contract_address, sender)) - .map_err(|e| e.into_send_error())?; + pub fn get_class_hash_at(&self, address: ContractAddress) -> Result { + trace!(target: LOG_TARGET, %address, "Requesting contract class hash."); + let (req, rx) = BackendRequest::class_hash(address); + self.request(req)?; rx.recv()? } - pub fn do_get_class_at( - &self, - class_hash: ClassHash, - ) -> Result { - trace!( - target: LOG_TARGET, - class_hash = %format!("{:#x}", class_hash), - "Requesting class." - ); - let (sender, rx) = oneshot(); - self.0 - .lock() - .try_send(BackendRequest::GetClassAt(class_hash, sender)) - .map_err(|e| e.into_send_error())?; + pub fn get_class_at(&self, class_hash: ClassHash) -> Result { + trace!(target: LOG_TARGET, class_hash = %format!("{class_hash:#x}"), "Requesting class."); + let (req, rx) = BackendRequest::class(class_hash); + self.request(req)?; rx.recv()? } - pub fn do_get_compiled_class_hash( + pub fn get_compiled_class_hash( &self, class_hash: ClassHash, - ) -> Result { - trace!( - target: LOG_TARGET, - class_hash = %format!("{:#x}", class_hash), - "Requesting compiled class hash." - ); - let class = self.do_get_class_at(class_hash)?; + ) -> Result { + trace!(target: LOG_TARGET, class_hash = %format!("{class_hash:#x}"), "Requesting compiled class hash."); + let class = self.get_class_at(class_hash)?; // if its a legacy class, then we just return back the class hash // else if sierra class, then we have to compile it and compute the compiled class hash. match class { - starknet::core::types::ContractClass::Legacy(_) => Ok(class_hash), - starknet::core::types::ContractClass::Sierra(sierra_class) => { + RpcContractClass::Legacy(_) => Ok(class_hash), + RpcContractClass::Sierra(sierra_class) => { compiled_class_hash_from_flattened_sierra_class(&sierra_class) - .map_err(|e| ForkedBackendError::ComputeClassHashError(e.to_string())) + .map_err(BackendError::ComputeClassHashError) } } } + + /// Send a request to the backend thread. + fn request(&self, req: BackendRequest) -> Result<(), BackendError> { + self.0.lock().try_send(req).map_err(|e| e.into_send_error())?; + Ok(()) + } + + #[cfg(test)] + fn stats(&self) -> Result { + let (req, rx) = BackendRequest::stats(); + self.request(req)?; + Ok(rx.recv()?) + } } /// A shared cache that stores data fetched from the forked network. @@ -334,10 +372,10 @@ impl ForkedBackend { /// cache to avoid fetching it again. This is shared across multiple instances of /// [`ForkedStateDb`](super::state::ForkedStateDb). #[derive(Clone)] -pub struct SharedStateProvider(Arc>); +pub struct SharedStateProvider(Arc>); impl SharedStateProvider { - pub(crate) fn new_with_backend(backend: ForkedBackend) -> Self { + pub(crate) fn new_with_backend(backend: BackendHandle) -> Self { Self(Arc::new(CacheStateDb::new(backend))) } } @@ -371,17 +409,10 @@ impl StateProvider for SharedStateProvider { return Ok(nonce); } - if let Some(nonce) = handle_contract_or_class_not_found_err(self.0.do_get_nonce(address)) - .map_err(|e| { - error!( - target: LOG_TARGET, - contract_address = %address, - error = %e, - "Fetching nonce." - ); - e - })? - { + if let Some(nonce) = handle_not_found_err(self.0.get_nonce(address)).map_err(|error| { + error!(target: LOG_TARGET, %address, %error, "Fetching nonce."); + error + })? { self.0.contract_state.write().entry(address).or_default().nonce = nonce; Ok(Some(nonce)) } else { @@ -401,17 +432,10 @@ impl StateProvider for SharedStateProvider { } let value = - handle_contract_or_class_not_found_err(self.0.do_get_storage(address, storage_key)) - .map_err(|e| { - error!( - target: LOG_TARGET, - address = %address, - storage_key = %format!("{:#x}", storage_key), - error = %e, - "Fetching storage value." - ); - e - })?; + handle_not_found_err(self.0.get_storage(address, storage_key)).map_err(|error| { + error!(target: LOG_TARGET, %address, storage_key = %format!("{storage_key:#x}"), %error, "Fetching storage value."); + error + })?; self.0 .storage @@ -439,18 +463,12 @@ impl StateProvider for SharedStateProvider { return Ok(hash); } - if let Some(hash) = handle_contract_or_class_not_found_err( - self.0.do_get_class_hash_at(address), - ) - .map_err(|e| { - error!( - target: LOG_TARGET, - contract_address = %address, - error = %e, - "Fetching class hash." - ); - e - })? { + if let Some(hash) = + handle_not_found_err(self.0.get_class_hash_at(address)).map_err(|error| { + error!(target: LOG_TARGET, %address, %error, "Fetching class hash."); + error + })? + { self.0.contract_state.write().entry(address).or_default().class_hash = hash; Ok(Some(hash)) } else { @@ -465,16 +483,10 @@ impl ContractClassProvider for SharedStateProvider { return Ok(class.cloned()); } - let Some(class) = handle_contract_or_class_not_found_err(self.0.do_get_class_at(hash)) - .map_err(|e| { - error!( - target: LOG_TARGET, - hash = %format!("{:#x}", hash), - error = %e, - "Fetching sierra class." - ); - e - })? + let Some(class) = handle_not_found_err(self.0.get_class_at(hash)).map_err(|error| { + error!(target: LOG_TARGET, hash = %format!("{hash:#x}"), %error, "Fetching sierra class."); + error + })? else { return Ok(None); }; @@ -501,16 +513,10 @@ impl ContractClassProvider for SharedStateProvider { } if let Some(hash) = - handle_contract_or_class_not_found_err(self.0.do_get_compiled_class_hash(hash)) - .map_err(|e| { - error!( - target: LOG_TARGET, - hash = %format!("{:#x}", hash), - error = %e, - "Fetching compiled class hash." - ); - e - })? + handle_not_found_err(self.0.get_compiled_class_hash(hash)).map_err(|error| { + error!(target: LOG_TARGET, hash = %format!("{hash:#x}"), %error, "Fetching compiled class hash."); + error + })? { self.0.compiled_class_hashes.write().insert(hash, hash); Ok(Some(hash)) @@ -524,45 +530,29 @@ impl ContractClassProvider for SharedStateProvider { return Ok(Some(class.clone())); } - let Some(class) = handle_contract_or_class_not_found_err(self.0.do_get_class_at(hash)) - .map_err(|e| { - error!( - target: LOG_TARGET, - hash = %format!("{:#x}", hash), - error = %e, - "Fetching class." - ); - e - })? + let Some(class) = handle_not_found_err(self.0.get_class_at(hash)).map_err(|error| { + error!(target: LOG_TARGET, hash = %format!("{hash:#x}"), %error, "Fetching class."); + error + })? else { return Ok(None); }; let (class_hash, compiled_class_hash, casm, sierra) = match class { - ContractClass::Legacy(class) => { - let (_, compiled_class) = legacy_rpc_to_compiled_class(&class).map_err(|e| { - error!( - target: LOG_TARGET, - hash = %format!("{:#x}", hash), - error = %e, - "Parsing legacy class." - ); - ProviderError::ParsingError(e.to_string()) + RpcContractClass::Legacy(class) => { + let (_, compiled_class) = legacy_rpc_to_compiled_class(&class).map_err(|error| { + error!(target: LOG_TARGET, hash = %format!("{hash:#x}"), %error, "Parsing legacy class."); + ProviderError::ParsingError(error.to_string()) })?; (hash, hash, compiled_class, None) } - ContractClass::Sierra(sierra_class) => { + RpcContractClass::Sierra(sierra_class) => { let (_, compiled_class_hash, compiled_class) = - flattened_sierra_to_compiled_class(&sierra_class).map_err(|e| { - error!( - target: LOG_TARGET, - hash = %format!("{:#x}", hash), - error = %e, - "Parsing sierra class." - ); - ProviderError::ParsingError(e.to_string()) + flattened_sierra_to_compiled_class(&sierra_class).map_err(|error| { + error!(target: LOG_TARGET, hash = %format!("{hash:#x}"), %error, "Parsing sierra class."); + ProviderError::ParsingError(error.to_string()) })?; (hash, compiled_class_hash, compiled_class, Some(sierra_class)) @@ -591,13 +581,16 @@ impl ContractClassProvider for SharedStateProvider { } } -fn handle_contract_or_class_not_found_err( - result: Result, -) -> Result, ForkedBackendError> { +/// A helper function to convert a contract/class not found error returned by the RPC provider into +/// a `Option::None`. +/// +/// This is to follow the Katana's provider APIs convention where 'not found'/'non-existent' should +/// be represented as `Option::None`. +fn handle_not_found_err(result: Result) -> Result, BackendError> { match result { Ok(value) => Ok(Some(value)), - Err(ForkedBackendError::StarknetProvider(StarknetProviderError::StarknetError( + Err(BackendError::StarknetProvider(StarknetProviderError::StarknetError( StarknetError::ContractNotFound | StarknetError::ClassHashNotFound, ))) => Ok(None), @@ -607,9 +600,16 @@ fn handle_contract_or_class_not_found_err( #[cfg(test)] mod tests { + + use std::sync::mpsc::sync_channel; + use std::time::Duration; + use katana_primitives::block::BlockNumber; use katana_primitives::contract::GenericContractInfo; use starknet::macros::felt; + use starknet::providers::jsonrpc::HttpTransport; + use starknet::providers::JsonRpcClient; + use tokio::net::TcpListener; use url::Url; use super::*; @@ -622,32 +622,89 @@ mod tests { const ADDR_1_STORAGE_VALUE: StorageKey = felt!("0x8080"); const ADDR_1_CLASS_HASH: StorageKey = felt!("0x1"); - fn create_forked_backend(rpc_url: String, block_num: BlockNumber) -> (ForkedBackend, Backend) { - ForkedBackend::new( - Arc::new(JsonRpcClient::new(HttpTransport::new( - Url::parse(&rpc_url).expect("valid url"), - ))), - BlockHashOrNumber::Num(block_num), - ) + fn create_forked_backend(rpc_url: String, block_num: BlockNumber) -> BackendHandle { + let url = Url::parse(&rpc_url).expect("valid url"); + let provider = Arc::new(JsonRpcClient::new(HttpTransport::new(url))); + let block_id = BlockHashOrNumber::Num(block_num); + Backend::new(provider, block_id).unwrap() } - fn create_forked_backend_with_backend_thread( - rpc_url: String, - block_num: BlockNumber, - ) -> ForkedBackend { - ForkedBackend::new_with_backend_thread( - Arc::new(JsonRpcClient::new(HttpTransport::new( - Url::parse(&rpc_url).expect("valid url"), - ))), - BlockHashOrNumber::Num(block_num), - ) - .unwrap() + // Starts a TCP server that never close the connection. + fn start_tcp_server() { + use tokio::runtime::Builder; + + let (tx, rx) = sync_channel::<()>(1); + thread::spawn(move || { + Builder::new_current_thread().enable_all().build().unwrap().block_on(async move { + let listener = TcpListener::bind("127.0.0.1:8080").await.unwrap(); + let mut connections = Vec::new(); + + tx.send(()).unwrap(); + + loop { + let (socket, _) = listener.accept().await.unwrap(); + connections.push(socket); + } + }); + }); + + rx.recv().unwrap(); + } + + const ERROR_INIT_BACKEND: &str = "Failed to create backend"; + const ERROR_SEND_REQUEST: &str = "Failed to send request to backend"; + const ERROR_STATS: &str = "Failed to get stats"; + + #[test] + fn handle_incoming_requests() { + let url = Url::try_from("http://127.0.0.1:8080").unwrap(); + let provider = JsonRpcClient::new(HttpTransport::new(url)); + let block_id = BlockHashOrNumber::Num(1); + + // start a mock remote network + start_tcp_server(); + + // start backend + let handle = Backend::new(Arc::new(provider), block_id).expect(ERROR_INIT_BACKEND); + + // check no pending requests + let stats = handle.stats().expect(ERROR_STATS); + assert_eq!(stats, 0, "Backend should not have any ongoing requests."); + + // send requests to the backend + let h1 = handle.clone(); + thread::spawn(move || { + h1.get_nonce(felt!("0x1").into()).expect(ERROR_SEND_REQUEST); + }); + let h2 = handle.clone(); + thread::spawn(move || { + h2.get_class_at(felt!("0x1")).expect(ERROR_SEND_REQUEST); + }); + let h3 = handle.clone(); + thread::spawn(move || { + h3.get_compiled_class_hash(felt!("0x1")).expect(ERROR_SEND_REQUEST); + }); + let h4 = handle.clone(); + thread::spawn(move || { + h4.get_class_hash_at(felt!("0x1").into()).expect(ERROR_SEND_REQUEST); + }); + let h5 = handle.clone(); + thread::spawn(move || { + h5.get_storage(felt!("0x1").into(), felt!("0x1")).expect(ERROR_SEND_REQUEST); + }); + + // wait for the requests to be handled + thread::sleep(Duration::from_secs(1)); + + // check request are handled + let stats = handle.stats().expect(ERROR_STATS); + assert_eq!(stats, 5, "Backend should have 5 ongoing requests.") } #[test] fn get_from_cache_if_exist() { // setup - let (backend, _) = create_forked_backend(LOCAL_RPC_URL.into(), 1); + let backend = create_forked_backend(LOCAL_RPC_URL.into(), 1); let state_db = CacheStateDb::new(backend); state_db @@ -677,7 +734,7 @@ mod tests { #[test] fn fetch_from_fork_will_err_if_backend_thread_not_running() { - let (backend, _) = create_forked_backend(LOCAL_RPC_URL.into(), 1); + let backend = create_forked_backend(LOCAL_RPC_URL.into(), 1); let provider = SharedStateProvider(Arc::new(CacheStateDb::new(backend))); assert!(StateProvider::nonce(&provider, ADDR_1).is_err()) } @@ -694,7 +751,7 @@ mod tests { #[test] #[ignore] fn fetch_from_fork_if_not_in_cache() { - let backend = create_forked_backend_with_backend_thread(FORKED_URL.into(), 908622); + let backend = create_forked_backend(FORKED_URL.into(), 908622); let provider = SharedStateProvider(Arc::new(CacheStateDb::new(backend))); // fetch from remote diff --git a/crates/katana/storage/provider/src/providers/fork/mod.rs b/crates/katana/storage/provider/src/providers/fork/mod.rs index 16237cb7d3..3fc82df95b 100644 --- a/crates/katana/storage/provider/src/providers/fork/mod.rs +++ b/crates/katana/storage/provider/src/providers/fork/mod.rs @@ -20,7 +20,7 @@ use parking_lot::RwLock; use starknet::providers::jsonrpc::HttpTransport; use starknet::providers::JsonRpcClient; -use self::backend::{ForkedBackend, ForkedBackendError, SharedStateProvider}; +use self::backend::{Backend, BackendError, SharedStateProvider}; use self::state::ForkedStateDb; use super::in_memory::cache::{CacheDb, CacheStateDb}; use super::in_memory::state::HistoricalStates; @@ -49,8 +49,8 @@ impl ForkedProvider { pub fn new( provider: Arc>, block_id: BlockHashOrNumber, - ) -> Result { - let backend = ForkedBackend::new_with_backend_thread(provider, block_id)?; + ) -> Result { + let backend = Backend::new(provider, block_id)?; let shared_provider = SharedStateProvider::new_with_backend(backend); let storage = RwLock::new(CacheDb::new(()));