Skip to content

Commit

Permalink
feat: mock blockifier (#369)
Browse files Browse the repository at this point in the history
  • Loading branch information
yair-starkware authored Jul 23, 2024
1 parent d2a047f commit 7f2356e
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 58 deletions.
20 changes: 1 addition & 19 deletions crates/gateway/src/state_reader_test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,10 @@ use blockifier::state::errors::StateError;
use blockifier::state::state_api::{StateReader as BlockifierStateReader, StateResult};
use blockifier::test_utils::contracts::FeatureContract;
use blockifier::test_utils::dict_state_reader::DictStateReader;
use blockifier::test_utils::initial_test_state::{fund_account, test_state};
use blockifier::test_utils::initial_test_state::test_state;
use blockifier::test_utils::{CairoVersion, BALANCE};
use mempool_test_utils::starknet_api_test_utils::deployed_account_contract_address;
use starknet_api::block::BlockNumber;
use starknet_api::core::{ClassHash, CompiledClassHash, ContractAddress, Nonce};
use starknet_api::rpc_transaction::RPCTransaction;
use starknet_api::state::StorageKey;
use starknet_types_core::felt::Felt;

Expand Down Expand Up @@ -90,19 +88,3 @@ pub fn local_test_state_reader_factory(
},
}
}

pub fn local_test_state_reader_factory_for_deploy_account(
deploy_tx: &RPCTransaction,
) -> TestStateReaderFactory {
let mut state_reader_factory = local_test_state_reader_factory(CairoVersion::Cairo1, false);

// Fund the deployed_account_address.
let deployed_account_address = deployed_account_contract_address(deploy_tx);
fund_account(
BlockContext::create_for_testing().chain_info(),
deployed_account_address,
BALANCE,
&mut state_reader_factory.state_reader.blockifier_state_reader,
);
state_reader_factory
}
40 changes: 34 additions & 6 deletions crates/gateway/src/stateful_transaction_validator.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
use blockifier::blockifier::block::BlockInfo;
use blockifier::blockifier::stateful_validator::StatefulValidator;
use blockifier::blockifier::stateful_validator::{StatefulValidator, StatefulValidatorResult};
use blockifier::bouncer::BouncerConfig;
use blockifier::context::BlockContext;
use blockifier::execution::contract_class::ClassInfo;
use blockifier::state::cached_state::CachedState;
use blockifier::transaction::account_transaction::AccountTransaction;
use blockifier::versioned_constants::VersionedConstants;
use starknet_api::core::Nonce;
#[cfg(test)]
use mockall::automock;
use starknet_api::core::{ContractAddress, Nonce};
use starknet_api::rpc_transaction::{RPCInvokeTransaction, RPCTransaction};
use starknet_api::transaction::TransactionHash;
use starknet_types_core::felt::Felt;
Expand All @@ -25,23 +28,48 @@ pub struct StatefulTransactionValidator {

type BlockifierStatefulValidator = StatefulValidator<Box<dyn MempoolStateReader>>;

// TODO(yair): move the trait to Blockifier.
#[cfg_attr(test, automock)]
pub trait StatefulTransactionValidatorTrait {
fn validate(
&mut self,
account_tx: AccountTransaction,
skip_validate: bool,
) -> StatefulTransactionValidatorResult<()>;

fn get_nonce(&mut self, account_address: ContractAddress) -> StatefulValidatorResult<Nonce>;
}

impl StatefulTransactionValidatorTrait for BlockifierStatefulValidator {
fn validate(
&mut self,
account_tx: AccountTransaction,
skip_validate: bool,
) -> StatefulTransactionValidatorResult<()> {
Ok(self.perform_validations(account_tx, skip_validate)?)
}

fn get_nonce(&mut self, account_address: ContractAddress) -> StatefulValidatorResult<Nonce> {
self.get_nonce(account_address)
}
}

impl StatefulTransactionValidator {
pub fn run_validate(
pub fn run_validate<V: StatefulTransactionValidatorTrait>(
&self,
external_tx: &RPCTransaction,
optional_class_info: Option<ClassInfo>,
mut validator: BlockifierStatefulValidator,
mut validator: V,
) -> StatefulTransactionValidatorResult<TransactionHash> {
let account_tx = external_tx_to_account_tx(
external_tx,
optional_class_info,
&self.config.chain_info.chain_id,
)?;
let tx_hash = get_tx_hash(&account_tx);

let account_nonce = validator.get_nonce(get_sender_address(external_tx))?;
let skip_validate = skip_stateful_validations(external_tx, account_nonce);
validator.perform_validations(account_tx, skip_validate)?;
validator.validate(account_tx, skip_validate)?;
Ok(tx_hash)
}

Expand Down
47 changes: 14 additions & 33 deletions crates/gateway/src/stateful_transaction_validator_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ use blockifier::test_utils::CairoVersion;
use blockifier::transaction::errors::{TransactionFeeError, TransactionPreValidationError};
use mempool_test_utils::invoke_tx_args;
use mempool_test_utils::starknet_api_test_utils::{
declare_tx, deploy_account_tx, external_invoke_tx, invoke_tx, TEST_SENDER_ADDRESS,
VALID_L1_GAS_MAX_AMOUNT, VALID_L1_GAS_MAX_PRICE_PER_UNIT,
deploy_account_tx, external_invoke_tx, invoke_tx, TEST_SENDER_ADDRESS, VALID_L1_GAS_MAX_AMOUNT,
VALID_L1_GAS_MAX_PRICE_PER_UNIT,
};
use num_bigint::BigUint;
use pretty_assertions::assert_eq;
Expand All @@ -22,10 +22,11 @@ use crate::compilation::GatewayCompiler;
use crate::config::{GatewayCompilerConfig, StatefulTransactionValidatorConfig};
use crate::errors::{StatefulTransactionValidatorError, StatefulTransactionValidatorResult};
use crate::state_reader_test_utils::{
local_test_state_reader_factory, local_test_state_reader_factory_for_deploy_account,
TestStateReader, TestStateReaderFactory,
local_test_state_reader_factory, TestStateReader, TestStateReaderFactory,
};
use crate::stateful_transaction_validator::{
MockStatefulTransactionValidatorTrait, StatefulTransactionValidator,
};
use crate::stateful_transaction_validator::StatefulTransactionValidator;

#[fixture]
fn block_context() -> BlockContext {
Expand All @@ -45,37 +46,14 @@ fn stateful_validator(block_context: BlockContext) -> StatefulTransactionValidat
}

#[rstest]
#[case::valid_invoke_tx_cairo1(
#[case::valid_tx(
invoke_tx(CairoVersion::Cairo1),
local_test_state_reader_factory(CairoVersion::Cairo1, false),
Ok(TransactionHash(felt!(
"0x152b8dd0c30e95fa3a4ee7a9398fcfc46fb00c048b4fdcfa9958c64d65899b8"
)))
)]
#[case::valid_invoke_tx_cairo0(
invoke_tx(CairoVersion::Cairo0),
local_test_state_reader_factory(CairoVersion::Cairo0, false),
Ok(TransactionHash(felt!(
"0x39650ba8d14d8534957a415db496a7eea9e10a4cb06b018d4d24d0537bcc943"
)))
)]
#[case::valid_deploy_account_tx(
deploy_account_tx(),
local_test_state_reader_factory_for_deploy_account(&external_tx),
Ok(TransactionHash(felt!(
"0xe9ad58949803159d16d295ff8536ed89ac2dd0b7168c461648a7a2ff44ead2"
)))
)]
#[case::valid_declare_tx(
declare_tx(),
local_test_state_reader_factory(CairoVersion::Cairo1, false),
Ok(TransactionHash(felt!(
"0x157c517d0bd6fe177dd4f13b47bc3050aceae12609338ccd44a0eff1a3ce7c9"
)))
)]
#[case::invalid_tx(
invoke_tx(CairoVersion::Cairo1),
local_test_state_reader_factory(CairoVersion::Cairo1, true),
Err(StatefulTransactionValidatorError::StatefulValidatorError(
StatefulValidatorError::TransactionPreValidationError(
TransactionPreValidationError::TransactionFeeError(
Expand All @@ -90,7 +68,6 @@ fn stateful_validator(block_context: BlockContext) -> StatefulTransactionValidat
)]
fn test_stateful_tx_validator(
#[case] external_tx: RPCTransaction,
#[case] state_reader_factory: TestStateReaderFactory,
#[case] expected_result: StatefulTransactionValidatorResult<TransactionHash>,
stateful_validator: StatefulTransactionValidator,
) {
Expand All @@ -103,10 +80,14 @@ fn test_stateful_tx_validator(
_ => None,
};

let validator = stateful_validator.instantiate_validator(&state_reader_factory).unwrap();
let expected_result_msg = format!("{:?}", expected_result);

let mut mock_validator = MockStatefulTransactionValidatorTrait::new();
mock_validator.expect_validate().return_once(|_, _| expected_result.map(|_| ()));
mock_validator.expect_get_nonce().returning(|_| Ok(Nonce(Felt::ZERO)));

let result = stateful_validator.run_validate(&external_tx, optional_class_info, validator);
assert_eq!(format!("{:?}", result), format!("{:?}", expected_result));
let result = stateful_validator.run_validate(&external_tx, optional_class_info, mock_validator);
assert_eq!(format!("{:?}", result), expected_result_msg);
}

#[test]
Expand Down

0 comments on commit 7f2356e

Please sign in to comment.