Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP. #462

Closed
wants to merge 1 commit into from
Closed

WIP. #462

Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
150 changes: 115 additions & 35 deletions crates/gateway/src/config.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
use std::collections::BTreeMap;
use std::fmt;
use std::net::IpAddr;

use blockifier::context::{BlockContext, ChainInfo, FeeTokenAddresses};
use blockifier::context::{BlockContext, ChainInfo as BlockifierChainInfo, FeeTokenAddresses};
use papyrus_config::dumping::{append_sub_config_name, ser_param, SerializeConfig};
use papyrus_config::{ParamPath, ParamPrivacyInput, SerializedParam};
use serde::de::{MapAccess, Visitor};
use serde::ser::SerializeStruct;
use serde::{Deserialize, Serialize};
use starknet_api::core::{ChainId, ContractAddress, Nonce};
use starknet_api::core::Nonce;
use starknet_types_core::felt::Felt;
use validator::Validate;

Expand Down Expand Up @@ -175,63 +178,140 @@ impl SerializeConfig for RpcStateReaderConfig {
}

// TODO(Arni): Remove this struct once Chain info supports Papyrus serialization.
#[derive(Clone, Debug, Serialize, Deserialize, Validate, PartialEq)]
pub struct ChainInfoConfig {
pub chain_id: ChainId,
pub strk_fee_token_address: ContractAddress,
pub eth_fee_token_address: ContractAddress,
}
#[derive(Clone, Debug, Default)]
pub struct ChainInfo(pub BlockifierChainInfo);

impl From<ChainInfoConfig> for ChainInfo {
fn from(chain_info: ChainInfoConfig) -> Self {
Self {
chain_id: chain_info.chain_id,
fee_token_addresses: FeeTokenAddresses {
strk_fee_token_address: chain_info.strk_fee_token_address,
eth_fee_token_address: chain_info.eth_fee_token_address,
},
}
impl ChainInfo {
pub fn create_for_testing() -> Self {
Self(BlockContext::create_for_testing().chain_info().clone())
}
}

impl From<ChainInfo> for ChainInfoConfig {
fn from(chain_info: ChainInfo) -> Self {
let FeeTokenAddresses { strk_fee_token_address, eth_fee_token_address } =
chain_info.fee_token_addresses;
Self { chain_id: chain_info.chain_id, strk_fee_token_address, eth_fee_token_address }
// TODO(Arni): Remove this once Chain info derives PartialEq.
impl PartialEq for ChainInfo {
fn eq(&self, other: &Self) -> bool {
fn eq(lhs: &BlockifierChainInfo, rhs: &BlockifierChainInfo) -> bool {
lhs.chain_id == rhs.chain_id
&& lhs.fee_token_addresses.strk_fee_token_address
== rhs.fee_token_addresses.strk_fee_token_address
&& lhs.fee_token_addresses.eth_fee_token_address
== rhs.fee_token_addresses.eth_fee_token_address
}
eq(&self.0, &other.0)
}
}

impl Default for ChainInfoConfig {
fn default() -> Self {
ChainInfo::default().into()
impl<'de> Deserialize<'de> for ChainInfo {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
struct ChainInfoVisitor;

impl<'de> Visitor<'de> for ChainInfoVisitor {
type Value = ChainInfo;

fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> std::fmt::Result {
formatter.write_str("struct ChainInfo")
}

fn visit_map<V>(self, mut map: V) -> Result<Self::Value, V::Error>
where
V: MapAccess<'de>,
{
let mut chain_id = None;
let mut strk_fee_token_address = None;
let mut eth_fee_token_address = None;

while let Some(key) = map.next_key()? {
match key {
"chain_id" => {
if chain_id.is_some() {
return Err(serde::de::Error::duplicate_field("chain_id"));
}
chain_id = Some(map.next_value()?);
}
"strk_fee_token_address" => {
if strk_fee_token_address.is_some() {
return Err(serde::de::Error::duplicate_field(
"strk_fee_token_address",
));
}
strk_fee_token_address = Some(map.next_value()?);
}
"eth_fee_token_address" => {
if eth_fee_token_address.is_some() {
return Err(serde::de::Error::duplicate_field(
"eth_fee_token_address",
));
}
eth_fee_token_address = Some(map.next_value()?);
}
_ => {
return Err(serde::de::Error::unknown_field(key, &["chain_id"]));
}
}
}

let chain_id =
chain_id.ok_or_else(|| serde::de::Error::missing_field("chain_id"))?;
let strk_fee_token_address = strk_fee_token_address
.ok_or_else(|| serde::de::Error::missing_field("strk_fee_token_address"))?;
let eth_fee_token_address = eth_fee_token_address
.ok_or_else(|| serde::de::Error::missing_field("eth_fee_token_address"))?;

Ok(ChainInfo(BlockifierChainInfo {
chain_id,
fee_token_addresses: FeeTokenAddresses {
strk_fee_token_address,
eth_fee_token_address,
},
}))
}
}

const FIELDS: &[&str] = &["chain_id", "strk_fee_token_address", "eth_fee_token_address"];
deserializer.deserialize_struct("ChainInfo", FIELDS, ChainInfoVisitor)
}
}

impl ChainInfoConfig {
pub fn create_for_testing() -> Self {
BlockContext::create_for_testing().chain_info().clone().into()
impl Serialize for ChainInfo {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let mut s = serializer.serialize_struct("ChainInfo", 3)?;
s.serialize_field("chain_id", &self.0.chain_id)?;
s.serialize_field(
"strk_fee_token_address",
&self.0.fee_token_addresses.strk_fee_token_address,
)?;
s.serialize_field(
"eth_fee_token_address",
&self.0.fee_token_addresses.eth_fee_token_address,
)?;
s.end()
}
}

impl SerializeConfig for ChainInfoConfig {
impl SerializeConfig for ChainInfo {
fn dump(&self) -> BTreeMap<ParamPath, SerializedParam> {
BTreeMap::from_iter([
ser_param(
"chain_id",
&self.chain_id,
&self.0.chain_id,
"The chain ID of the StarkNet chain.",
ParamPrivacyInput::Public,
),
ser_param(
"strk_fee_token_address",
&self.strk_fee_token_address,
&self.0.fee_token_addresses.strk_fee_token_address,
"Address of the STRK fee token.",
ParamPrivacyInput::Public,
),
ser_param(
"eth_fee_token_address",
&self.eth_fee_token_address,
&self.0.fee_token_addresses.eth_fee_token_address,
"Address of the ETH fee token.",
ParamPrivacyInput::Public,
),
Expand All @@ -244,7 +324,7 @@ pub struct StatefulTransactionValidatorConfig {
pub max_nonce_for_validation_skip: Nonce,
pub validate_max_n_steps: u32,
pub max_recursion_depth: usize,
pub chain_info: ChainInfoConfig,
pub chain_info: ChainInfo,
}

impl Default for StatefulTransactionValidatorConfig {
Expand All @@ -253,7 +333,7 @@ impl Default for StatefulTransactionValidatorConfig {
max_nonce_for_validation_skip: Nonce(Felt::ONE),
validate_max_n_steps: 1_000_000,
max_recursion_depth: 50,
chain_info: ChainInfoConfig::default(),
chain_info: ChainInfo::default(),
}
}
}
Expand Down Expand Up @@ -291,7 +371,7 @@ impl StatefulTransactionValidatorConfig {
max_nonce_for_validation_skip: Default::default(),
validate_max_n_steps: 1000000,
max_recursion_depth: 50,
chain_info: ChainInfoConfig::create_for_testing(),
chain_info: ChainInfo::create_for_testing(),
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions crates/gateway/src/gateway_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ use axum::body::{Bytes, HttpBody};
use axum::extract::State;
use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
use blockifier::context::ChainInfo;
use blockifier::context::ChainInfo as BlockifierChainInfo;
use blockifier::test_utils::CairoVersion;
use mempool_test_utils::starknet_api_test_utils::{declare_tx, deploy_account_tx, invoke_tx};
use rstest::{fixture, rstest};
Expand Down Expand Up @@ -125,7 +125,7 @@ fn calculate_hash(
let account_tx = external_tx_to_account_tx(
external_tx,
optional_class_info,
&ChainInfo::create_for_testing().chain_id,
&BlockifierChainInfo::create_for_testing().chain_id,
)
.unwrap();
get_tx_hash(&account_tx)
Expand Down
4 changes: 2 additions & 2 deletions crates/gateway/src/stateful_transaction_validator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ impl StatefulTransactionValidator {
let account_tx = external_tx_to_account_tx(
external_tx,
optional_class_info,
&self.config.chain_info.chain_id,
&self.config.chain_info.0.chain_id,
)?;
let tx_hash = get_tx_hash(&account_tx);

Expand Down Expand Up @@ -68,7 +68,7 @@ impl StatefulTransactionValidator {
// able to read the block_hash of 10 blocks ago from papyrus.
let block_context = BlockContext::new(
block_info,
self.config.chain_info.clone().into(),
self.config.chain_info.0.clone(),
versioned_constants,
BouncerConfig::max(),
);
Expand Down
8 changes: 4 additions & 4 deletions crates/gateway/src/stateful_transaction_validator_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use starknet_api::transaction::TransactionHash;
use starknet_types_core::felt::Felt;

use crate::compilation::GatewayCompiler;
use crate::config::{GatewayCompilerConfig, StatefulTransactionValidatorConfig};
use crate::config::{ChainInfo, GatewayCompilerConfig, StatefulTransactionValidatorConfig};
use crate::errors::{StatefulTransactionValidatorError, StatefulTransactionValidatorResult};
use crate::state_reader_test_utils::{
local_test_state_reader_factory, local_test_state_reader_factory_for_deploy_account,
Expand All @@ -39,7 +39,7 @@ fn stateful_validator(block_context: BlockContext) -> StatefulTransactionValidat
max_nonce_for_validation_skip: Default::default(),
validate_max_n_steps: block_context.versioned_constants().validate_max_n_steps,
max_recursion_depth: block_context.versioned_constants().max_recursion_depth,
chain_info: block_context.chain_info().clone().into(),
chain_info: ChainInfo(block_context.chain_info().clone()),
},
}
}
Expand Down Expand Up @@ -118,7 +118,7 @@ fn test_instantiate_validator() {
max_nonce_for_validation_skip: Default::default(),
validate_max_n_steps: block_context.versioned_constants().validate_max_n_steps,
max_recursion_depth: block_context.versioned_constants().max_recursion_depth,
chain_info: block_context.chain_info().clone().into(),
chain_info: ChainInfo(block_context.chain_info().clone()),
},
};
let blockifier_validator = stateful_validator.instantiate_validator(&state_reader_factory);
Expand Down Expand Up @@ -161,7 +161,7 @@ fn test_skip_stateful_validation(
// To be sure that the validations were actually skipped, we check that the error came from
// the blockifier stateful validations, and not from the pre validations since those are
// executed also when skip_validate is true.
assert_matches!(result, Err(StatefulTransactionValidatorError::StatefulValidatorError(err))
assert_matches!(result, Err(StatefulTransactionValidatorError::StatefulValidatorError(err))
if !matches!(err, StatefulValidatorError::TransactionPreValidationError(_)));
}
}
Expand Down
10 changes: 5 additions & 5 deletions crates/tests-integration/src/state_reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::net::SocketAddr;
use std::sync::{Arc, OnceLock};

use blockifier::abi::abi_utils::get_fee_token_var_address;
use blockifier::context::{BlockContext, ChainInfo};
use blockifier::context::{BlockContext, ChainInfo as BlockifierChainInfo};
use blockifier::test_utils::contracts::FeatureContract;
use blockifier::test_utils::{
CairoVersion, BALANCE, CURRENT_BLOCK_TIMESTAMP, DEFAULT_ETH_L1_GAS_PRICE,
Expand Down Expand Up @@ -78,7 +78,7 @@ pub async fn spawn_test_rpc_state_reader(n_accounts: usize) -> SocketAddr {
}

fn initialize_papyrus_test_state(
chain_info: &ChainInfo,
chain_info: &BlockifierChainInfo,
initial_balances: u128,
contract_instances: &[(FeatureContract, usize)],
fund_additional_accounts: Vec<ContractAddress>,
Expand All @@ -97,7 +97,7 @@ fn initialize_papyrus_test_state(
}

fn prepare_state_diff(
chain_info: &ChainInfo,
chain_info: &BlockifierChainInfo,
contract_instances: &[(FeatureContract, usize)],
initial_balances: u128,
fund_accounts: Vec<ContractAddress>,
Expand Down Expand Up @@ -232,7 +232,7 @@ fn fund_feature_account_contract(
contract: &FeatureContract,
instance: u16,
initial_balances: u128,
chain_info: &ChainInfo,
chain_info: &BlockifierChainInfo,
) {
match contract {
FeatureContract::AccountWithLongValidate(_)
Expand All @@ -253,7 +253,7 @@ fn fund_account(
storage_diffs: &mut IndexMap<ContractAddress, IndexMap<StorageKey, Felt>>,
account_address: &ContractAddress,
initial_balances: u128,
chain_info: &ChainInfo,
chain_info: &BlockifierChainInfo,
) {
let key_value = indexmap! {
get_fee_token_var_address(*account_address) => felt!(initial_balances),
Expand Down
Loading