diff --git a/crates/gateway/src/errors.rs b/crates/gateway/src/errors.rs index 4ff5cfe5..fb465f67 100644 --- a/crates/gateway/src/errors.rs +++ b/crates/gateway/src/errors.rs @@ -5,7 +5,8 @@ use blockifier::execution::errors::ContractClassError; use blockifier::state::errors::StateError; use blockifier::transaction::errors::TransactionExecutionError; use cairo_vm::types::errors::program_errors::ProgramError; -use starknet_api::block::BlockNumber; +use serde_json::{Error as SerdeError, Value}; +use starknet_api::block::{BlockNumber, GasPrice}; use starknet_api::transaction::{Resource, ResourceBounds}; use starknet_api::StarknetApiError; use thiserror::Error; @@ -102,3 +103,42 @@ pub enum GatewayRunError { #[error(transparent)] ServerStartupError(#[from] hyper::Error), } + +#[derive(Debug, Error)] +pub enum RPCStateReaderError { + #[error("Block not found for request {0}")] + BlockNotFound(Value), + #[error("Class hash not found for request {0}")] + ClassHashNotFound(Value), + #[error("Failed to parse gas price {:?}", 0)] + GasPriceParsingFailure(GasPrice), + #[error("Contract address not found for request {0}")] + ContractAddressNotFound(Value), + #[error(transparent)] + ReqwestError(#[from] reqwest::Error), + #[error("RPC error: {0}")] + RPCError(StatusCode), + #[error("Unexpected error code: {0}")] + UnexpectedErrorCode(u16), +} + +pub type RPCStateReaderResult = Result; + +impl From for StateError { + fn from(err: RPCStateReaderError) -> Self { + match err { + RPCStateReaderError::ClassHashNotFound(request) => { + match serde_json::from_value(request["params"]["class_hash"].clone()) { + Ok(class_hash) => StateError::UndeclaredClassHash(class_hash), + Err(e) => serde_err_to_state_err(e), + } + } + _ => StateError::StateReadError(err.to_string()), + } + } +} + +// Converts a serde error to the error type of the state reader. +pub fn serde_err_to_state_err(err: SerdeError) -> StateError { + StateError::StateReadError(format!("Failed to parse rpc result {:?}", err.to_string())) +} diff --git a/crates/gateway/src/rpc_objects.rs b/crates/gateway/src/rpc_objects.rs index d713f208..f6295ab2 100644 --- a/crates/gateway/src/rpc_objects.rs +++ b/crates/gateway/src/rpc_objects.rs @@ -1,7 +1,6 @@ use std::num::NonZeroU128; use blockifier::blockifier::block::{BlockInfo, GasPrices}; -use blockifier::state::errors::StateError; use serde::{Deserialize, Serialize}; use serde_json::Value; use starknet_api::block::{BlockHash, BlockNumber, BlockTimestamp, GasPrice}; @@ -9,6 +8,8 @@ use starknet_api::core::{ClassHash, ContractAddress, GlobalRoot}; use starknet_api::data_availability::L1DataAvailabilityMode; use starknet_api::state::StorageKey; +use crate::errors::RPCStateReaderError; + // Starknet Spec error codes: // TODO(yael 30/4/2024): consider turning these into an enum. pub const RPC_ERROR_CONTRACT_ADDRESS_NOT_FOUND: u16 = 20; @@ -78,7 +79,7 @@ pub struct BlockHeader { } impl TryInto for BlockHeader { - type Error = StateError; + type Error = RPCStateReaderError; fn try_into(self) -> Result { Ok(BlockInfo { block_number: self.block_number, @@ -95,9 +96,8 @@ impl TryInto for BlockHeader { } } -fn parse_gas_price(gas_price: GasPrice) -> Result { - NonZeroU128::new(gas_price.0) - .ok_or(StateError::StateReadError("Couldn't parse gas_price".to_string())) +fn parse_gas_price(gas_price: GasPrice) -> Result { + NonZeroU128::new(gas_price.0).ok_or(RPCStateReaderError::GasPriceParsingFailure(gas_price)) } #[derive(Serialize, Deserialize, Debug)] diff --git a/crates/gateway/src/rpc_state_reader.rs b/crates/gateway/src/rpc_state_reader.rs index f367a12f..5ddcc74b 100644 --- a/crates/gateway/src/rpc_state_reader.rs +++ b/crates/gateway/src/rpc_state_reader.rs @@ -4,9 +4,8 @@ use blockifier::state::errors::StateError; use blockifier::state::state_api::{StateReader as BlockifierStateReader, StateResult}; use cairo_lang_starknet_classes::casm_contract_class::CasmContractClass; use reqwest::blocking::Client as BlockingClient; -use reqwest::Error as ReqwestError; use serde::{Deserialize, Serialize}; -use serde_json::{json, Error as SerdeError, Value}; +use serde_json::{json, Value}; use starknet_api::block::BlockNumber; use starknet_api::core::{ClassHash, CompiledClassHash, ContractAddress, Nonce}; use starknet_api::deprecated_contract_class::ContractClass as StarknetApiDeprecatedContractClass; @@ -14,6 +13,7 @@ use starknet_api::hash::StarkFelt; use starknet_api::state::StorageKey; use crate::config::RpcStateReaderConfig; +use crate::errors::{serde_err_to_state_err, RPCStateReaderError, RPCStateReaderResult}; use crate::rpc_objects::{ BlockHeader, BlockId, GetBlockWithTxHashesParams, GetClassHashAtParams, GetCompiledContractClassParams, GetNonceParams, GetStorageAtParams, RpcResponse, @@ -39,7 +39,7 @@ impl RpcStateReader { &self, method: &str, params: impl Serialize, - ) -> Result { + ) -> RPCStateReaderResult { let request_body = json!({ "jsonrpc": self.config.json_rpc_version, "id": 0, @@ -52,45 +52,32 @@ impl RpcStateReader { .post(self.config.url.clone()) .header("Content-Type", "application/json") .json(&request_body) - .send() - .map_err(reqwest_err_to_state_err)?; + .send()?; if !response.status().is_success() { - return Err(StateError::StateReadError(format!( - "RPC ERROR, code {}", - response.status() - ))); + return Err(RPCStateReaderError::RPCError(response.status())); } - let rpc_response: RpcResponse = - response.json::().map_err(reqwest_err_to_state_err)?; + let rpc_response: RpcResponse = response.json::()?; match rpc_response { RpcResponse::Success(rpc_success_response) => Ok(rpc_success_response.result), RpcResponse::Error(rpc_error_response) => match rpc_error_response.error.code { - RPC_ERROR_BLOCK_NOT_FOUND => Err(StateError::StateReadError(format!( - "Block not found, request: {}", - request_body - ))), - RPC_ERROR_CONTRACT_ADDRESS_NOT_FOUND => Err(StateError::StateReadError(format!( - "Contract address not found, request: {}", - request_body - ))), - RPC_CLASS_HASH_NOT_FOUND => Err(StateError::StateReadError(format!( - "Class hash not found, request: {}", - request_body - ))), - _ => Err(StateError::StateReadError(format!( - "Unexpected error code {}", - rpc_error_response.error.code - ))), + RPC_ERROR_BLOCK_NOT_FOUND => Err(RPCStateReaderError::BlockNotFound(request_body)), + RPC_ERROR_CONTRACT_ADDRESS_NOT_FOUND => { + Err(RPCStateReaderError::ContractAddressNotFound(request_body)) + } + RPC_CLASS_HASH_NOT_FOUND => { + Err(RPCStateReaderError::ClassHashNotFound(request_body)) + } + _ => Err(RPCStateReaderError::UnexpectedErrorCode(rpc_error_response.error.code)), }, } } } impl MempoolStateReader for RpcStateReader { - fn get_block_info(&self) -> Result { + fn get_block_info(&self) -> StateResult { let get_block_params = GetBlockWithTxHashesParams { block_id: self.block_id }; // The response from the rpc is a full block but we only deserialize the header. @@ -120,9 +107,15 @@ impl BlockifierStateReader for RpcStateReader { fn get_nonce_at(&self, contract_address: ContractAddress) -> StateResult { let get_nonce_params = GetNonceParams { block_id: self.block_id, contract_address }; - let result = self.send_rpc_request("starknet_getNonce", get_nonce_params)?; - let nonce: Nonce = serde_json::from_value(result).map_err(serde_err_to_state_err)?; - Ok(nonce) + let result = self.send_rpc_request("starknet_getNonce", get_nonce_params); + match result { + Ok(value) => { + let nonce: Nonce = serde_json::from_value(value).map_err(serde_err_to_state_err)?; + Ok(nonce) + } + Err(RPCStateReaderError::ContractAddressNotFound(_)) => Ok(Nonce::default()), + Err(e) => Err(e)?, + } } fn get_compiled_contract_class(&self, class_hash: ClassHash) -> StateResult { @@ -147,10 +140,16 @@ impl BlockifierStateReader for RpcStateReader { let get_class_hash_at_params = GetClassHashAtParams { contract_address, block_id: self.block_id }; - let result = self.send_rpc_request("starknet_getClassHashAt", get_class_hash_at_params)?; - let class_hash: ClassHash = - serde_json::from_value(result).map_err(serde_err_to_state_err)?; - Ok(class_hash) + let result = self.send_rpc_request("starknet_getClassHashAt", get_class_hash_at_params); + match result { + Ok(value) => { + let class_hash: ClassHash = + serde_json::from_value(value).map_err(serde_err_to_state_err)?; + Ok(class_hash) + } + Err(RPCStateReaderError::ContractAddressNotFound(_)) => Ok(ClassHash::default()), + Err(e) => Err(e)?, + } } fn get_compiled_class_hash(&self, _class_hash: ClassHash) -> StateResult { @@ -165,16 +164,6 @@ pub enum CompiledContractClass { V1(CasmContractClass), } -// Converts a serder error to the error type of the state reader. -fn serde_err_to_state_err(err: SerdeError) -> StateError { - StateError::StateReadError(format!("Failed to parse rpc result {:?}", err.to_string())) -} - -// Converts a reqwest error to the error type of the state reader. -fn reqwest_err_to_state_err(err: ReqwestError) -> StateError { - StateError::StateReadError(format!("Rpc request failed with error {:?}", err.to_string())) -} - pub struct RpcStateReaderFactory { pub config: RpcStateReaderConfig, }