diff --git a/Cargo.lock b/Cargo.lock index 3ffc836e..d0dcf0bb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1723,6 +1723,12 @@ dependencies = [ "winapi", ] +[[package]] +name = "downcast" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1435fa1053d8b2fbbe9be7e97eca7f33d37b28409959813daefc1446a14247f1" + [[package]] name = "dunce" version = "1.0.4" @@ -2234,6 +2240,12 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "fragile" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c2141d6d6c8512188a7891b4b01590a45f6dac67afb4f255c4124dbb86d4eaa" + [[package]] name = "fs2" version = "0.4.3" @@ -3428,6 +3440,33 @@ dependencies = [ "windows-sys 0.48.0", ] +[[package]] +name = "mockall" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "43766c2b5203b10de348ffe19f7e54564b64f3d6018ff7648d1e2d6d3a0f0a48" +dependencies = [ + "cfg-if", + "downcast", + "fragile", + "lazy_static", + "mockall_derive", + "predicates", + "predicates-tree", +] + +[[package]] +name = "mockall_derive" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "af7cbce79ec385a1d4f54baa90a76401eb15d9cab93685f62e7e9f942aa00ae2" +dependencies = [ + "cfg-if", + "proc-macro2", + "quote", + "syn 2.0.68", +] + [[package]] name = "native-tls" version = "0.2.12" @@ -4172,6 +4211,32 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "925383efa346730478fb4838dbe9137d2a47675ad789c546d150a6e1dd4ab31c" +[[package]] +name = "predicates" +version = "3.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "68b87bfd4605926cdfefc1c3b5f8fe560e3feca9d5552cf68c466d3d8236c7e8" +dependencies = [ + "anstyle", + "predicates-core", +] + +[[package]] +name = "predicates-core" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b794032607612e7abeb4db69adb4e33590fa6cf1149e95fd7cb00e634b92f174" + +[[package]] +name = "predicates-tree" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "368ba315fb8c5052ab692e68a0eefec6ec57b23a36959c14496f0b0df2c0cecf" +dependencies = [ + "predicates-core", + "termtree", +] + [[package]] name = "pretty_assertions" version = "1.4.0" @@ -5302,6 +5367,7 @@ dependencies = [ "cairo-lang-starknet-classes", "cairo-vm", "hyper", + "mockall", "num-bigint", "num-traits 0.2.19", "papyrus_config", @@ -5643,6 +5709,12 @@ dependencies = [ "winapi", ] +[[package]] +name = "termtree" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3369f5ac52d5eb6ab48c6b4ffdc8efbcad6b89c765749064ba298f2c68a16a76" + [[package]] name = "test_utils" version = "0.0.0" diff --git a/Cargo.toml b/Cargo.toml index 0396867d..20ed7cdc 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -51,6 +51,7 @@ hyper = { version = "0.14", features = ["client", "server", "http1", "http2", "t indexmap = "2.1.0" itertools = "0.13.0" lazy_static = "1.4.0" +mockall = "0.12.1" num-traits = "0.2" num-bigint = { version = "0.4.5", default-features = false } # TODO(YaelD, 28/5/2024): The special Papyrus version is needed in order to be aligned with the diff --git a/crates/gateway/Cargo.toml b/crates/gateway/Cargo.toml index ce583b1c..4e4d841e 100644 --- a/crates/gateway/Cargo.toml +++ b/crates/gateway/Cargo.toml @@ -18,6 +18,7 @@ blockifier= { workspace = true , features = ["testing"] } cairo-lang-starknet-classes.workspace = true cairo-vm.workspace = true hyper.workspace = true +mockall.workspace = true num-traits.workspace = true papyrus_config.workspace = true papyrus_rpc.workspace = true diff --git a/crates/gateway/src/gateway.rs b/crates/gateway/src/gateway.rs index cd51fb77..b9432381 100644 --- a/crates/gateway/src/gateway.rs +++ b/crates/gateway/src/gateway.rs @@ -125,8 +125,8 @@ fn process_tx( }; // TODO(Yael, 19/5/2024): pass the relevant deploy_account_hash. - let tx_hash = - stateful_tx_validator.run_validate(state_reader_factory, &tx, optional_class_info, None)?; + let validator = stateful_tx_validator.instantiate_validator(state_reader_factory)?; + let tx_hash = stateful_tx_validator.run_validate(&tx, optional_class_info, None, validator)?; // TODO(Arni): Add the Sierra and the Casm to the mempool input. Ok(MempoolInput { diff --git a/crates/gateway/src/stateful_transaction_validator.rs b/crates/gateway/src/stateful_transaction_validator.rs index bce61a69..e3686192 100644 --- a/crates/gateway/src/stateful_transaction_validator.rs +++ b/crates/gateway/src/stateful_transaction_validator.rs @@ -1,10 +1,13 @@ use blockifier::blockifier::block::BlockInfo; -use blockifier::blockifier::stateful_validator::StatefulValidator as BlockifierStatefulValidator; +use blockifier::blockifier::stateful_validator::StatefulValidator as GenericBlockifierStatefulValidator; use blockifier::bouncer::BouncerConfig; use blockifier::context::BlockContext; use blockifier::execution::contract_class::ClassInfo; use blockifier::state::cached_state::CachedState; +use blockifier::transaction::account_transaction::AccountTransaction; use blockifier::versioned_constants::VersionedConstants; +use mockall::predicate::*; +use mockall::*; use starknet_api::rpc_transaction::RPCTransaction; use starknet_api::transaction::TransactionHash; @@ -21,14 +24,50 @@ pub struct StatefulTransactionValidator { pub config: StatefulTransactionValidatorConfig, } +type BlockifierStatefulValidator = GenericBlockifierStatefulValidator>; + +#[automock] +pub trait StatefulTransactionValidatorTrait { + fn perform_validations( + &mut self, + account_tx: AccountTransaction, + deploy_account_tx_hash: Option, + ) -> StatefulTransactionValidatorResult<()>; +} + + +impl StatefulTransactionValidatorTrait for BlockifierStatefulValidator { + fn perform_validations( + &mut self, + account_tx: AccountTransaction, + deploy_account_tx_hash: Option, + ) -> StatefulTransactionValidatorResult<()> { + Ok(self.perform_validations(account_tx, deploy_account_tx_hash)?) + } +} + impl StatefulTransactionValidator { - pub fn run_validate( + pub fn run_validate( &self, - state_reader_factory: &dyn StateReaderFactory, external_tx: &RPCTransaction, optional_class_info: Option, deploy_account_tx_hash: Option, + mut validator: TStatefulTransactionValidator, ) -> StatefulTransactionValidatorResult { + let account_tx = external_tx_to_account_tx( + external_tx, + optional_class_info, + &self.config.chain_info.chain_id, + )?; + let tx_hash = get_tx_hash(&account_tx); + validator.perform_validations(account_tx, deploy_account_tx_hash)?; + Ok(tx_hash) + } + + pub fn instantiate_validator( + &self, + state_reader_factory: &dyn StateReaderFactory, + ) -> StatefulTransactionValidatorResult { // TODO(yael 6/5/2024): consider storing the block_info as part of the // StatefulTransactionValidator and update it only once a new block is created. let latest_block_info = get_latest_block_info(state_reader_factory)?; @@ -53,19 +92,11 @@ impl StatefulTransactionValidator { BouncerConfig::max(), ); - let mut validator = BlockifierStatefulValidator::create( + Ok(BlockifierStatefulValidator::create( state, block_context, self.config.max_nonce_for_validation_skip, - ); - let account_tx = external_tx_to_account_tx( - external_tx, - optional_class_info, - &self.config.chain_info.chain_id, - )?; - let tx_hash = get_tx_hash(&account_tx); - validator.perform_validations(account_tx, deploy_account_tx_hash)?; - Ok(tx_hash) + )) } } diff --git a/crates/gateway/src/stateful_transaction_validator_test.rs b/crates/gateway/src/stateful_transaction_validator_test.rs index 8eebdcf6..a71f1418 100644 --- a/crates/gateway/src/stateful_transaction_validator_test.rs +++ b/crates/gateway/src/stateful_transaction_validator_test.rs @@ -8,51 +8,26 @@ use starknet_api::felt; use starknet_api::rpc_transaction::RPCTransaction; use starknet_api::transaction::TransactionHash; use test_utils::starknet_api_test_utils::{ - declare_tx, deploy_account_tx, invoke_tx, VALID_L1_GAS_MAX_AMOUNT, - VALID_L1_GAS_MAX_PRICE_PER_UNIT, + invoke_tx, VALID_L1_GAS_MAX_AMOUNT, VALID_L1_GAS_MAX_PRICE_PER_UNIT, }; use crate::compilation::compile_contract_class; use crate::config::StatefulTransactionValidatorConfig; use crate::errors::{StatefulTransactionValidatorError, StatefulTransactionValidatorResult}; -use crate::state_reader_test_utils::{ - local_test_state_reader_factory, local_test_state_reader_factory_for_deploy_account, - TestStateReaderFactory, +use crate::state_reader_test_utils::local_test_state_reader_factory; +use crate::stateful_transaction_validator::{ + MockStatefulTransactionValidatorTrait, StatefulTransactionValidator, }; -use crate::stateful_transaction_validator::StatefulTransactionValidator; #[rstest] -#[case::valid_invoke_tx_cairo1( +#[case::valid_tx( invoke_tx(CairoVersion::Cairo1), - local_test_state_reader_factory(CairoVersion::Cairo1, false), Ok(TransactionHash(felt!( "0x007d70505b4487a4e1c1a4b4e4342cb5aa9e73b86d031891170c45a57ad8b4e6" ))) )] -#[case::valid_invoke_tx_cairo0( - invoke_tx(CairoVersion::Cairo0), - local_test_state_reader_factory(CairoVersion::Cairo0, false), - Ok(TransactionHash(felt!( - "0x032e3a969a64027f15ce2b526d8dff47d47524c58ff0363f93ce4cbe7c280861" - ))) -)] -#[case::valid_deploy_account_tx( - deploy_account_tx(), - local_test_state_reader_factory_for_deploy_account(&external_tx), - Ok(TransactionHash(felt!( - "0x013287740b37dc112391de4ef0f7cd7aeca323537ca2a78a1108c6aee5a55d70" - ))) -)] -#[case::valid_declare_tx( - declare_tx(), - local_test_state_reader_factory(CairoVersion::Cairo1, false), - Ok(TransactionHash(felt!( - "0x02da54b89e00d2e201f8e3ed2bcc715a69e89aefdce88aff2d2facb8dec55c0a" - ))) -)] #[case::invalid_tx( invoke_tx(CairoVersion::Cairo1), - local_test_state_reader_factory(CairoVersion::Cairo1, true), Err(StatefulTransactionValidatorError::StatefulValidatorError( StatefulValidatorError::TransactionPreValidationError( TransactionPreValidationError::TransactionFeeError( @@ -67,7 +42,6 @@ use crate::stateful_transaction_validator::StatefulTransactionValidator; )] fn test_stateful_tx_validator( #[case] external_tx: RPCTransaction, - #[case] state_reader_factory: TestStateReaderFactory, #[case] expected_result: StatefulTransactionValidatorResult, ) { let block_context = &BlockContext::create_for_testing(); @@ -84,11 +58,31 @@ fn test_stateful_tx_validator( _ => None, }; - let result = stateful_validator.run_validate( - &state_reader_factory, - &external_tx, - optional_class_info, - None, - ); - assert_eq!(format!("{:?}", result), format!("{:?}", expected_result)); + let expected_result_msg = format!("{:?}", expected_result); + + let mut mock_validator = MockStatefulTransactionValidatorTrait::new(); + mock_validator.expect_perform_validations().return_once(|_, _| match expected_result { + Ok(..) => Ok(()), + Err(e) => Err(e), + }); + + let result = + stateful_validator.run_validate(&external_tx, optional_class_info, None, mock_validator); + assert_eq!(format!("{:?}", result), expected_result_msg); +} + +#[test] +fn test_instantiate_validator() { + let state_reader_factory = local_test_state_reader_factory(CairoVersion::Cairo1, false); + let block_context = &BlockContext::create_for_testing(); + let stateful_validator = StatefulTransactionValidator { + config: StatefulTransactionValidatorConfig { + max_nonce_for_validation_skip: Default::default(), + validate_max_n_steps: block_context.versioned_constants().validate_max_n_steps, + max_recursion_depth: block_context.versioned_constants().max_recursion_depth, + chain_info: block_context.chain_info().clone().into(), + }, + }; + let blockifier_validator = stateful_validator.instantiate_validator(&state_reader_factory); + assert!(blockifier_validator.is_ok()); }