Skip to content

Commit

Permalink
WIP.
Browse files Browse the repository at this point in the history
  • Loading branch information
ArniStarkware committed Jul 14, 2024
1 parent a5f6540 commit b504b24
Show file tree
Hide file tree
Showing 5 changed files with 128 additions and 48 deletions.
150 changes: 115 additions & 35 deletions crates/gateway/src/config.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
use std::collections::BTreeMap;
use std::fmt;
use std::net::IpAddr;

use blockifier::context::{BlockContext, ChainInfo, FeeTokenAddresses};
use blockifier::context::{BlockContext, ChainInfo as BlockifierChainInfo, FeeTokenAddresses};
use papyrus_config::dumping::{append_sub_config_name, ser_param, SerializeConfig};
use papyrus_config::{ParamPath, ParamPrivacyInput, SerializedParam};
use serde::de::{MapAccess, Visitor};
use serde::ser::SerializeStruct;
use serde::{Deserialize, Serialize};
use starknet_api::core::{ChainId, ContractAddress, Nonce};
use starknet_api::core::Nonce;
use starknet_types_core::felt::Felt;
use validator::Validate;

Expand Down Expand Up @@ -175,63 +178,140 @@ impl SerializeConfig for RpcStateReaderConfig {
}

// TODO(Arni): Remove this struct once Chain info supports Papyrus serialization.
#[derive(Clone, Debug, Serialize, Deserialize, Validate, PartialEq)]
pub struct ChainInfoConfig {
pub chain_id: ChainId,
pub strk_fee_token_address: ContractAddress,
pub eth_fee_token_address: ContractAddress,
}
#[derive(Clone, Debug, Default)]
pub struct ChainInfo(pub BlockifierChainInfo);

impl From<ChainInfoConfig> for ChainInfo {
fn from(chain_info: ChainInfoConfig) -> Self {
Self {
chain_id: chain_info.chain_id,
fee_token_addresses: FeeTokenAddresses {
strk_fee_token_address: chain_info.strk_fee_token_address,
eth_fee_token_address: chain_info.eth_fee_token_address,
},
}
impl ChainInfo {
pub fn create_for_testing() -> Self {
Self(BlockContext::create_for_testing().chain_info().clone())
}
}

impl From<ChainInfo> for ChainInfoConfig {
fn from(chain_info: ChainInfo) -> Self {
let FeeTokenAddresses { strk_fee_token_address, eth_fee_token_address } =
chain_info.fee_token_addresses;
Self { chain_id: chain_info.chain_id, strk_fee_token_address, eth_fee_token_address }
// TODO(Arni): Remove this once Chain info derives PartialEq.
impl PartialEq for ChainInfo {
fn eq(&self, other: &Self) -> bool {
fn eq(lhs: &BlockifierChainInfo, rhs: &BlockifierChainInfo) -> bool {
lhs.chain_id == rhs.chain_id
&& lhs.fee_token_addresses.strk_fee_token_address
== rhs.fee_token_addresses.strk_fee_token_address
&& lhs.fee_token_addresses.eth_fee_token_address
== rhs.fee_token_addresses.eth_fee_token_address
}
eq(&self.0, &other.0)
}
}

impl Default for ChainInfoConfig {
fn default() -> Self {
ChainInfo::default().into()
impl<'de> Deserialize<'de> for ChainInfo {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
struct ChainInfoVisitor;

impl<'de> Visitor<'de> for ChainInfoVisitor {
type Value = ChainInfo;

fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> std::fmt::Result {
formatter.write_str("struct ChainInfo")
}

fn visit_map<V>(self, mut map: V) -> Result<Self::Value, V::Error>
where
V: MapAccess<'de>,
{
let mut chain_id = None;
let mut strk_fee_token_address = None;
let mut eth_fee_token_address = None;

while let Some(key) = map.next_key()? {
match key {
"chain_id" => {
if chain_id.is_some() {
return Err(serde::de::Error::duplicate_field("chain_id"));
}
chain_id = Some(map.next_value()?);
}
"strk_fee_token_address" => {
if strk_fee_token_address.is_some() {
return Err(serde::de::Error::duplicate_field(
"strk_fee_token_address",
));
}
strk_fee_token_address = Some(map.next_value()?);
}
"eth_fee_token_address" => {
if eth_fee_token_address.is_some() {
return Err(serde::de::Error::duplicate_field(
"eth_fee_token_address",
));
}
eth_fee_token_address = Some(map.next_value()?);
}
_ => {
return Err(serde::de::Error::unknown_field(key, &["chain_id"]));
}
}
}

let chain_id =
chain_id.ok_or_else(|| serde::de::Error::missing_field("chain_id"))?;
let strk_fee_token_address = strk_fee_token_address
.ok_or_else(|| serde::de::Error::missing_field("strk_fee_token_address"))?;
let eth_fee_token_address = eth_fee_token_address
.ok_or_else(|| serde::de::Error::missing_field("eth_fee_token_address"))?;

Ok(ChainInfo(BlockifierChainInfo {
chain_id,
fee_token_addresses: FeeTokenAddresses {
strk_fee_token_address,
eth_fee_token_address,
},
}))
}
}

const FIELDS: &[&str] = &["chain_id", "strk_fee_token_address", "eth_fee_token_address"];
deserializer.deserialize_struct("ChainInfo", FIELDS, ChainInfoVisitor)
}
}

impl ChainInfoConfig {
pub fn create_for_testing() -> Self {
BlockContext::create_for_testing().chain_info().clone().into()
impl Serialize for ChainInfo {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let mut s = serializer.serialize_struct("ChainInfo", 3)?;
s.serialize_field("chain_id", &self.0.chain_id)?;
s.serialize_field(
"strk_fee_token_address",
&self.0.fee_token_addresses.strk_fee_token_address,
)?;
s.serialize_field(
"eth_fee_token_address",
&self.0.fee_token_addresses.eth_fee_token_address,
)?;
s.end()
}
}

impl SerializeConfig for ChainInfoConfig {
impl SerializeConfig for ChainInfo {
fn dump(&self) -> BTreeMap<ParamPath, SerializedParam> {
BTreeMap::from_iter([
ser_param(
"chain_id",
&self.chain_id,
&self.0.chain_id,
"The chain ID of the StarkNet chain.",
ParamPrivacyInput::Public,
),
ser_param(
"strk_fee_token_address",
&self.strk_fee_token_address,
&self.0.fee_token_addresses.strk_fee_token_address,
"Address of the STRK fee token.",
ParamPrivacyInput::Public,
),
ser_param(
"eth_fee_token_address",
&self.eth_fee_token_address,
&self.0.fee_token_addresses.eth_fee_token_address,
"Address of the ETH fee token.",
ParamPrivacyInput::Public,
),
Expand All @@ -244,7 +324,7 @@ pub struct StatefulTransactionValidatorConfig {
pub max_nonce_for_validation_skip: Nonce,
pub validate_max_n_steps: u32,
pub max_recursion_depth: usize,
pub chain_info: ChainInfoConfig,
pub chain_info: ChainInfo,
}

impl Default for StatefulTransactionValidatorConfig {
Expand All @@ -253,7 +333,7 @@ impl Default for StatefulTransactionValidatorConfig {
max_nonce_for_validation_skip: Nonce(Felt::ONE),
validate_max_n_steps: 1_000_000,
max_recursion_depth: 50,
chain_info: ChainInfoConfig::default(),
chain_info: ChainInfo::default(),
}
}
}
Expand Down Expand Up @@ -291,7 +371,7 @@ impl StatefulTransactionValidatorConfig {
max_nonce_for_validation_skip: Default::default(),
validate_max_n_steps: 1000000,
max_recursion_depth: 50,
chain_info: ChainInfoConfig::create_for_testing(),
chain_info: ChainInfo::create_for_testing(),
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions crates/gateway/src/gateway_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use axum::body::{Bytes, HttpBody};
use axum::extract::State;
use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
use blockifier::context::ChainInfo;
use blockifier::context::ChainInfo as BlockifierChainInfo;
use blockifier::test_utils::CairoVersion;
use mempool_test_utils::starknet_api_test_utils::{declare_tx, deploy_account_tx, invoke_tx};
use rstest::{fixture, rstest};
Expand Down Expand Up @@ -125,7 +125,7 @@ fn calculate_hash(
let account_tx = external_tx_to_account_tx(
external_tx,
optional_class_info,
&ChainInfo::create_for_testing().chain_id,
&BlockifierChainInfo::create_for_testing().chain_id,
)
.unwrap();
get_tx_hash(&account_tx)
Expand Down
4 changes: 2 additions & 2 deletions crates/gateway/src/stateful_transaction_validator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ impl StatefulTransactionValidator {
let account_tx = external_tx_to_account_tx(
external_tx,
optional_class_info,
&self.config.chain_info.chain_id,
&self.config.chain_info.0.chain_id,
)?;
let tx_hash = get_tx_hash(&account_tx);

Expand Down Expand Up @@ -68,7 +68,7 @@ impl StatefulTransactionValidator {
// able to read the block_hash of 10 blocks ago from papyrus.
let block_context = BlockContext::new(
block_info,
self.config.chain_info.clone().into(),
self.config.chain_info.0.clone(),
versioned_constants,
BouncerConfig::max(),
);
Expand Down
8 changes: 4 additions & 4 deletions crates/gateway/src/stateful_transaction_validator_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use starknet_api::transaction::TransactionHash;
use starknet_types_core::felt::Felt;

use crate::compilation::GatewayCompiler;
use crate::config::{GatewayCompilerConfig, StatefulTransactionValidatorConfig};
use crate::config::{ChainInfo, GatewayCompilerConfig, StatefulTransactionValidatorConfig};
use crate::errors::{StatefulTransactionValidatorError, StatefulTransactionValidatorResult};
use crate::state_reader_test_utils::{
local_test_state_reader_factory, local_test_state_reader_factory_for_deploy_account,
Expand All @@ -39,7 +39,7 @@ fn stateful_validator(block_context: BlockContext) -> StatefulTransactionValidat
max_nonce_for_validation_skip: Default::default(),
validate_max_n_steps: block_context.versioned_constants().validate_max_n_steps,
max_recursion_depth: block_context.versioned_constants().max_recursion_depth,
chain_info: block_context.chain_info().clone().into(),
chain_info: ChainInfo(block_context.chain_info().clone()),
},
}
}
Expand Down Expand Up @@ -118,7 +118,7 @@ fn test_instantiate_validator() {
max_nonce_for_validation_skip: Default::default(),
validate_max_n_steps: block_context.versioned_constants().validate_max_n_steps,
max_recursion_depth: block_context.versioned_constants().max_recursion_depth,
chain_info: block_context.chain_info().clone().into(),
chain_info: ChainInfo(block_context.chain_info().clone()),
},
};
let blockifier_validator = stateful_validator.instantiate_validator(&state_reader_factory);
Expand Down Expand Up @@ -161,7 +161,7 @@ fn test_skip_stateful_validation(
// To be sure that the validations were actually skipped, we check that the error came from
// the blockifier stateful validations, and not from the pre validations since those are
// executed also when skip_validate is true.
assert_matches!(result, Err(StatefulTransactionValidatorError::StatefulValidatorError(err))
assert_matches!(result, Err(StatefulTransactionValidatorError::StatefulValidatorError(err))
if !matches!(err, StatefulValidatorError::TransactionPreValidationError(_)));
}
}
Expand Down
10 changes: 5 additions & 5 deletions crates/tests-integration/src/state_reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::net::SocketAddr;
use std::sync::{Arc, OnceLock};

use blockifier::abi::abi_utils::get_fee_token_var_address;
use blockifier::context::{BlockContext, ChainInfo};
use blockifier::context::{BlockContext, ChainInfo as BlockifierChainInfo};
use blockifier::test_utils::contracts::FeatureContract;
use blockifier::test_utils::{
CairoVersion, BALANCE, CURRENT_BLOCK_TIMESTAMP, DEFAULT_ETH_L1_GAS_PRICE,
Expand Down Expand Up @@ -78,7 +78,7 @@ pub async fn spawn_test_rpc_state_reader(n_accounts: usize) -> SocketAddr {
}

fn initialize_papyrus_test_state(
chain_info: &ChainInfo,
chain_info: &BlockifierChainInfo,
initial_balances: u128,
contract_instances: &[(FeatureContract, usize)],
fund_additional_accounts: Vec<ContractAddress>,
Expand All @@ -97,7 +97,7 @@ fn initialize_papyrus_test_state(
}

fn prepare_state_diff(
chain_info: &ChainInfo,
chain_info: &BlockifierChainInfo,
contract_instances: &[(FeatureContract, usize)],
initial_balances: u128,
fund_accounts: Vec<ContractAddress>,
Expand Down Expand Up @@ -232,7 +232,7 @@ fn fund_feature_account_contract(
contract: &FeatureContract,
instance: u16,
initial_balances: u128,
chain_info: &ChainInfo,
chain_info: &BlockifierChainInfo,
) {
match contract {
FeatureContract::AccountWithLongValidate(_)
Expand All @@ -253,7 +253,7 @@ fn fund_account(
storage_diffs: &mut IndexMap<ContractAddress, IndexMap<StorageKey, Felt>>,
account_address: &ContractAddress,
initial_balances: u128,
chain_info: &ChainInfo,
chain_info: &BlockifierChainInfo,
) {
let key_value = indexmap! {
get_fee_token_var_address(*account_address) => felt!(initial_balances),
Expand Down

0 comments on commit b504b24

Please sign in to comment.