From 7160c7ab237345acce046b8908cfa487c2f4b9b9 Mon Sep 17 00:00:00 2001 From: Tzahi Taub Date: Tue, 15 Oct 2024 17:08:26 +0300 Subject: [PATCH] test(blockifier): function to build calldata for recursive call contract calls --- .../syscalls/syscall_tests/call_contract.rs | 54 +++++++------------ crates/blockifier/src/test_utils.rs | 19 +++++++ crates/blockifier/src/test_utils/syscall.rs | 33 ++++++++++++ 3 files changed, 72 insertions(+), 34 deletions(-) create mode 100644 crates/blockifier/src/test_utils/syscall.rs diff --git a/crates/blockifier/src/execution/syscalls/syscall_tests/call_contract.rs b/crates/blockifier/src/execution/syscalls/syscall_tests/call_contract.rs index 9e3b573c98..a6ad1395dc 100644 --- a/crates/blockifier/src/execution/syscalls/syscall_tests/call_contract.rs +++ b/crates/blockifier/src/execution/syscalls/syscall_tests/call_contract.rs @@ -19,7 +19,14 @@ use crate::retdata; use crate::state::state_api::StateReader; use crate::test_utils::contracts::FeatureContract; use crate::test_utils::initial_test_state::test_state; -use crate::test_utils::{create_calldata, trivial_external_entry_point_new, CairoVersion, BALANCE}; +use crate::test_utils::syscall::build_recurse_calldata; +use crate::test_utils::{ + create_calldata, + trivial_external_entry_point_new, + CairoVersion, + CompilerBasedVersion, + BALANCE, +}; #[test] fn test_call_contract_that_panics() { @@ -76,7 +83,7 @@ fn test_call_contract( inner_contract.get_instance_address(0), "test_storage_read_write", &[ - felt!(405_u16), // Calldata: address. + felt!(405_u16), // Calldata: storage address. felt!(48_u8), // Calldata: value. ], ); @@ -98,7 +105,7 @@ fn test_call_contract( /// Cairo0 / Cairo1 calls to Cairo0 / Cairo1. #[rstest] -fn test_track_resources( +fn test_tracked_resources( #[values(CairoVersion::Cairo0, CairoVersion::Cairo1)] outer_version: CairoVersion, #[values(CairoVersion::Cairo0, CairoVersion::Cairo1)] inner_version: CairoVersion, ) { @@ -108,14 +115,7 @@ fn test_track_resources( let mut state = test_state(chain_info, BALANCE, &[(outer_contract, 1), (inner_contract, 1)]); let outer_entry_point_selector = selector_from_name("test_call_contract"); - let calldata = create_calldata( - inner_contract.get_instance_address(0), - "test_storage_read_write", - &[ - felt!(405_u16), // Calldata: address. - felt!(48_u8), // Calldata: value. - ], - ); + let calldata = build_recurse_calldata(&[inner_version.into()]); let entry_point_call = CallEntryPoint { entry_point_selector: outer_entry_point_selector, calldata, @@ -140,37 +140,23 @@ fn test_track_resources( /// 1) Cairo-Steps contract that calls Sierra-Gas (nested) contract. /// 2) Sierra-Gas contract. #[rstest] -fn test_track_resources_nested( +fn test_tracked_resources_nested( #[values( - FeatureContract::TestContract(CairoVersion::Cairo0), - FeatureContract::CairoStepsTestContract + CompilerBasedVersion::CairoVersion(CairoVersion::Cairo0), + CompilerBasedVersion::OldCairo1 )] - cairo_steps_contract: FeatureContract, + cairo_steps_contract_version: CompilerBasedVersion, ) { + let cairo_steps_contract = cairo_steps_contract_version.get_test_contract(); let sierra_gas_contract = FeatureContract::TestContract(CairoVersion::Cairo1); let chain_info = &ChainInfo::create_for_testing(); let mut state = test_state(chain_info, BALANCE, &[(sierra_gas_contract, 1), (cairo_steps_contract, 1)]); - let first_calldata = create_calldata( - cairo_steps_contract.get_instance_address(0), - "test_call_contract", - &[ - sierra_gas_contract.get_instance_address(0).into(), - selector_from_name("test_storage_read_write").0, - felt!(2_u8), // Calldata length - felt!(405_u16), // Calldata: address. - felt!(48_u8), // Calldata: value. - ], - ); - let second_calldata = create_calldata( - sierra_gas_contract.get_instance_address(0), - "test_storage_read_write", - &[ - felt!(406_u16), // Calldata: address. - felt!(49_u8), // Calldata: value. - ], - ); + let first_calldata = + build_recurse_calldata(&[cairo_steps_contract_version, CairoVersion::Cairo1.into()]); + + let second_calldata = build_recurse_calldata(&[CairoVersion::Cairo1.into()]); let concated_calldata_felts = [first_calldata.0, second_calldata.0] .into_iter() diff --git a/crates/blockifier/src/test_utils.rs b/crates/blockifier/src/test_utils.rs index 111b425dcf..f85ba00b2c 100644 --- a/crates/blockifier/src/test_utils.rs +++ b/crates/blockifier/src/test_utils.rs @@ -7,6 +7,7 @@ pub mod initial_test_state; pub mod invoke; pub mod prices; pub mod struct_impls; +pub mod syscall; pub mod transfers_generator; use std::collections::HashMap; use std::fs; @@ -90,6 +91,24 @@ impl CairoVersion { } } +pub enum CompilerBasedVersion { + CairoVersion(CairoVersion), + OldCairo1, +} + +impl From for CompilerBasedVersion { + fn from(version: CairoVersion) -> Self { + Self::CairoVersion(version) + } +} +impl CompilerBasedVersion { + pub fn get_test_contract(&self) -> FeatureContract { + match self { + Self::CairoVersion(version) => FeatureContract::TestContract(*version), + Self::OldCairo1 => FeatureContract::CairoStepsTestContract, + } + } +} // Storage keys. pub fn test_erc20_sequencer_balance_key() -> StorageKey { get_fee_token_var_address(contract_address!(TEST_SEQUENCER_ADDRESS)) diff --git a/crates/blockifier/src/test_utils/syscall.rs b/crates/blockifier/src/test_utils/syscall.rs new file mode 100644 index 0000000000..cbf8c16dc9 --- /dev/null +++ b/crates/blockifier/src/test_utils/syscall.rs @@ -0,0 +1,33 @@ +use starknet_api::felt; +use starknet_api::transaction::Calldata; + +use crate::test_utils::{create_calldata, CompilerBasedVersion}; + +/// Returns the calldata for N recursive call contract syscalls, where N is the length of versions. +/// versions determines the cairo version of the called contract in each recursive call. Final call +/// is a simple local contract call (test_storage_read_write). +/// The first element in the returned value is the calldata for a call from a contract of the first +/// element in versions, to the a contract of the second element, etc. +pub fn build_recurse_calldata(versions: &[CompilerBasedVersion]) -> Calldata { + if versions.is_empty() { + return Calldata(vec![].into()); + } + let last_version = versions.last().unwrap(); + let mut calldata = create_calldata( + last_version.get_test_contract().get_instance_address(0), + "test_storage_read_write", + &[ + felt!(123_u16), // Calldata: address. + felt!(45_u8), // Calldata: value. + ], + ); + + for version in versions[..versions.len() - 1].iter().rev() { + calldata = create_calldata( + version.get_test_contract().get_instance_address(0), + "test_call_contract", + &calldata.0, + ); + } + return calldata; +}