diff --git a/crates/gateway/src/config.rs b/crates/gateway/src/config.rs index 49970fa7a..9a00042dc 100644 --- a/crates/gateway/src/config.rs +++ b/crates/gateway/src/config.rs @@ -1,11 +1,14 @@ use std::collections::BTreeMap; +use std::fmt; use std::net::IpAddr; -use blockifier::context::{BlockContext, ChainInfo, FeeTokenAddresses}; +use blockifier::context::{BlockContext, ChainInfo as BlockifierChainInfo, FeeTokenAddresses}; use papyrus_config::dumping::{append_sub_config_name, ser_param, SerializeConfig}; use papyrus_config::{ParamPath, ParamPrivacyInput, SerializedParam}; +use serde::de::{MapAccess, Visitor}; +use serde::ser::SerializeStruct; use serde::{Deserialize, Serialize}; -use starknet_api::core::{ChainId, ContractAddress, Nonce}; +use starknet_api::core::Nonce; use starknet_types_core::felt::Felt; use validator::Validate; @@ -175,63 +178,140 @@ impl SerializeConfig for RpcStateReaderConfig { } // TODO(Arni): Remove this struct once Chain info supports Papyrus serialization. -#[derive(Clone, Debug, Serialize, Deserialize, Validate, PartialEq)] -pub struct ChainInfoConfig { - pub chain_id: ChainId, - pub strk_fee_token_address: ContractAddress, - pub eth_fee_token_address: ContractAddress, -} +#[derive(Clone, Debug, Default)] +pub struct ChainInfo(pub BlockifierChainInfo); -impl From for ChainInfo { - fn from(chain_info: ChainInfoConfig) -> Self { - Self { - chain_id: chain_info.chain_id, - fee_token_addresses: FeeTokenAddresses { - strk_fee_token_address: chain_info.strk_fee_token_address, - eth_fee_token_address: chain_info.eth_fee_token_address, - }, - } +impl ChainInfo { + pub fn create_for_testing() -> Self { + Self(BlockContext::create_for_testing().chain_info().clone()) } } -impl From for ChainInfoConfig { - fn from(chain_info: ChainInfo) -> Self { - let FeeTokenAddresses { strk_fee_token_address, eth_fee_token_address } = - chain_info.fee_token_addresses; - Self { chain_id: chain_info.chain_id, strk_fee_token_address, eth_fee_token_address } +// TODO(Arni): Remove this once Chain info derives PartialEq. +impl PartialEq for ChainInfo { + fn eq(&self, other: &Self) -> bool { + fn eq(lhs: &BlockifierChainInfo, rhs: &BlockifierChainInfo) -> bool { + lhs.chain_id == rhs.chain_id + && lhs.fee_token_addresses.strk_fee_token_address + == rhs.fee_token_addresses.strk_fee_token_address + && lhs.fee_token_addresses.eth_fee_token_address + == rhs.fee_token_addresses.eth_fee_token_address + } + eq(&self.0, &other.0) } } -impl Default for ChainInfoConfig { - fn default() -> Self { - ChainInfo::default().into() +impl<'de> Deserialize<'de> for ChainInfo { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + struct ChainInfoVisitor; + + impl<'de> Visitor<'de> for ChainInfoVisitor { + type Value = ChainInfo; + + fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> std::fmt::Result { + formatter.write_str("struct ChainInfo") + } + + fn visit_map(self, mut map: V) -> Result + where + V: MapAccess<'de>, + { + let mut chain_id = None; + let mut strk_fee_token_address = None; + let mut eth_fee_token_address = None; + + while let Some(key) = map.next_key()? { + match key { + "chain_id" => { + if chain_id.is_some() { + return Err(serde::de::Error::duplicate_field("chain_id")); + } + chain_id = Some(map.next_value()?); + } + "strk_fee_token_address" => { + if strk_fee_token_address.is_some() { + return Err(serde::de::Error::duplicate_field( + "strk_fee_token_address", + )); + } + strk_fee_token_address = Some(map.next_value()?); + } + "eth_fee_token_address" => { + if eth_fee_token_address.is_some() { + return Err(serde::de::Error::duplicate_field( + "eth_fee_token_address", + )); + } + eth_fee_token_address = Some(map.next_value()?); + } + _ => { + return Err(serde::de::Error::unknown_field(key, &["chain_id"])); + } + } + } + + let chain_id = + chain_id.ok_or_else(|| serde::de::Error::missing_field("chain_id"))?; + let strk_fee_token_address = strk_fee_token_address + .ok_or_else(|| serde::de::Error::missing_field("strk_fee_token_address"))?; + let eth_fee_token_address = eth_fee_token_address + .ok_or_else(|| serde::de::Error::missing_field("eth_fee_token_address"))?; + + Ok(ChainInfo(BlockifierChainInfo { + chain_id, + fee_token_addresses: FeeTokenAddresses { + strk_fee_token_address, + eth_fee_token_address, + }, + })) + } + } + + const FIELDS: &[&str] = &["chain_id", "strk_fee_token_address", "eth_fee_token_address"]; + deserializer.deserialize_struct("ChainInfo", FIELDS, ChainInfoVisitor) } } -impl ChainInfoConfig { - pub fn create_for_testing() -> Self { - BlockContext::create_for_testing().chain_info().clone().into() +impl Serialize for ChainInfo { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + let mut s = serializer.serialize_struct("ChainInfo", 3)?; + s.serialize_field("chain_id", &self.0.chain_id)?; + s.serialize_field( + "strk_fee_token_address", + &self.0.fee_token_addresses.strk_fee_token_address, + )?; + s.serialize_field( + "eth_fee_token_address", + &self.0.fee_token_addresses.eth_fee_token_address, + )?; + s.end() } } -impl SerializeConfig for ChainInfoConfig { +impl SerializeConfig for ChainInfo { fn dump(&self) -> BTreeMap { BTreeMap::from_iter([ ser_param( "chain_id", - &self.chain_id, + &self.0.chain_id, "The chain ID of the StarkNet chain.", ParamPrivacyInput::Public, ), ser_param( "strk_fee_token_address", - &self.strk_fee_token_address, + &self.0.fee_token_addresses.strk_fee_token_address, "Address of the STRK fee token.", ParamPrivacyInput::Public, ), ser_param( "eth_fee_token_address", - &self.eth_fee_token_address, + &self.0.fee_token_addresses.eth_fee_token_address, "Address of the ETH fee token.", ParamPrivacyInput::Public, ), @@ -244,7 +324,7 @@ pub struct StatefulTransactionValidatorConfig { pub max_nonce_for_validation_skip: Nonce, pub validate_max_n_steps: u32, pub max_recursion_depth: usize, - pub chain_info: ChainInfoConfig, + pub chain_info: ChainInfo, } impl Default for StatefulTransactionValidatorConfig { @@ -253,7 +333,7 @@ impl Default for StatefulTransactionValidatorConfig { max_nonce_for_validation_skip: Nonce(Felt::ONE), validate_max_n_steps: 1_000_000, max_recursion_depth: 50, - chain_info: ChainInfoConfig::default(), + chain_info: ChainInfo::default(), } } } @@ -291,7 +371,7 @@ impl StatefulTransactionValidatorConfig { max_nonce_for_validation_skip: Default::default(), validate_max_n_steps: 1000000, max_recursion_depth: 50, - chain_info: ChainInfoConfig::create_for_testing(), + chain_info: ChainInfo::create_for_testing(), } } } diff --git a/crates/gateway/src/gateway_test.rs b/crates/gateway/src/gateway_test.rs index 5216c8797..db66d0b57 100644 --- a/crates/gateway/src/gateway_test.rs +++ b/crates/gateway/src/gateway_test.rs @@ -4,7 +4,7 @@ use axum::body::{Bytes, HttpBody}; use axum::extract::State; use axum::http::StatusCode; use axum::response::{IntoResponse, Response}; -use blockifier::context::ChainInfo; +use blockifier::context::ChainInfo as BlockifierChainInfo; use blockifier::test_utils::CairoVersion; use mempool_test_utils::starknet_api_test_utils::{declare_tx, deploy_account_tx, invoke_tx}; use rstest::{fixture, rstest}; @@ -125,7 +125,7 @@ fn calculate_hash( let account_tx = external_tx_to_account_tx( external_tx, optional_class_info, - &ChainInfo::create_for_testing().chain_id, + &BlockifierChainInfo::create_for_testing().chain_id, ) .unwrap(); get_tx_hash(&account_tx) diff --git a/crates/gateway/src/stateful_transaction_validator.rs b/crates/gateway/src/stateful_transaction_validator.rs index a14018664..b41dad174 100644 --- a/crates/gateway/src/stateful_transaction_validator.rs +++ b/crates/gateway/src/stateful_transaction_validator.rs @@ -35,7 +35,7 @@ impl StatefulTransactionValidator { let account_tx = external_tx_to_account_tx( external_tx, optional_class_info, - &self.config.chain_info.chain_id, + &self.config.chain_info.0.chain_id, )?; let tx_hash = get_tx_hash(&account_tx); @@ -68,7 +68,7 @@ impl StatefulTransactionValidator { // able to read the block_hash of 10 blocks ago from papyrus. let block_context = BlockContext::new( block_info, - self.config.chain_info.clone().into(), + self.config.chain_info.0.clone(), versioned_constants, BouncerConfig::max(), ); diff --git a/crates/gateway/src/stateful_transaction_validator_test.rs b/crates/gateway/src/stateful_transaction_validator_test.rs index bd1216075..e45515738 100644 --- a/crates/gateway/src/stateful_transaction_validator_test.rs +++ b/crates/gateway/src/stateful_transaction_validator_test.rs @@ -19,7 +19,7 @@ use starknet_api::transaction::TransactionHash; use starknet_types_core::felt::Felt; use crate::compilation::GatewayCompiler; -use crate::config::{GatewayCompilerConfig, StatefulTransactionValidatorConfig}; +use crate::config::{ChainInfo, GatewayCompilerConfig, StatefulTransactionValidatorConfig}; use crate::errors::{StatefulTransactionValidatorError, StatefulTransactionValidatorResult}; use crate::state_reader_test_utils::{ local_test_state_reader_factory, local_test_state_reader_factory_for_deploy_account, @@ -39,7 +39,7 @@ fn stateful_validator(block_context: BlockContext) -> StatefulTransactionValidat max_nonce_for_validation_skip: Default::default(), validate_max_n_steps: block_context.versioned_constants().validate_max_n_steps, max_recursion_depth: block_context.versioned_constants().max_recursion_depth, - chain_info: block_context.chain_info().clone().into(), + chain_info: ChainInfo(block_context.chain_info().clone()), }, } } @@ -118,7 +118,7 @@ fn test_instantiate_validator() { max_nonce_for_validation_skip: Default::default(), validate_max_n_steps: block_context.versioned_constants().validate_max_n_steps, max_recursion_depth: block_context.versioned_constants().max_recursion_depth, - chain_info: block_context.chain_info().clone().into(), + chain_info: ChainInfo(block_context.chain_info().clone()), }, }; let blockifier_validator = stateful_validator.instantiate_validator(&state_reader_factory); @@ -161,7 +161,7 @@ fn test_skip_stateful_validation( // To be sure that the validations were actually skipped, we check that the error came from // the blockifier stateful validations, and not from the pre validations since those are // executed also when skip_validate is true. - assert_matches!(result, Err(StatefulTransactionValidatorError::StatefulValidatorError(err)) + assert_matches!(result, Err(StatefulTransactionValidatorError::StatefulValidatorError(err)) if !matches!(err, StatefulValidatorError::TransactionPreValidationError(_))); } } diff --git a/crates/tests-integration/src/state_reader.rs b/crates/tests-integration/src/state_reader.rs index 4b3841845..5ed84b7cb 100644 --- a/crates/tests-integration/src/state_reader.rs +++ b/crates/tests-integration/src/state_reader.rs @@ -2,7 +2,7 @@ use std::net::SocketAddr; use std::sync::{Arc, OnceLock}; use blockifier::abi::abi_utils::get_fee_token_var_address; -use blockifier::context::{BlockContext, ChainInfo}; +use blockifier::context::{BlockContext, ChainInfo as BlockifierChainInfo}; use blockifier::test_utils::contracts::FeatureContract; use blockifier::test_utils::{ CairoVersion, BALANCE, CURRENT_BLOCK_TIMESTAMP, DEFAULT_ETH_L1_GAS_PRICE, @@ -78,7 +78,7 @@ pub async fn spawn_test_rpc_state_reader(n_accounts: usize) -> SocketAddr { } fn initialize_papyrus_test_state( - chain_info: &ChainInfo, + chain_info: &BlockifierChainInfo, initial_balances: u128, contract_instances: &[(FeatureContract, usize)], fund_additional_accounts: Vec, @@ -97,7 +97,7 @@ fn initialize_papyrus_test_state( } fn prepare_state_diff( - chain_info: &ChainInfo, + chain_info: &BlockifierChainInfo, contract_instances: &[(FeatureContract, usize)], initial_balances: u128, fund_accounts: Vec, @@ -232,7 +232,7 @@ fn fund_feature_account_contract( contract: &FeatureContract, instance: u16, initial_balances: u128, - chain_info: &ChainInfo, + chain_info: &BlockifierChainInfo, ) { match contract { FeatureContract::AccountWithLongValidate(_) @@ -253,7 +253,7 @@ fn fund_account( storage_diffs: &mut IndexMap>, account_address: &ContractAddress, initial_balances: u128, - chain_info: &ChainInfo, + chain_info: &BlockifierChainInfo, ) { let key_value = indexmap! { get_fee_token_var_address(*account_address) => felt!(initial_balances),