Skip to content

Commit

Permalink
test(blockifier): function to build calldata for recursive call contr…
Browse files Browse the repository at this point in the history
…act calls
  • Loading branch information
TzahiTaub committed Oct 15, 2024
1 parent 7baca1a commit 7160c7a
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 34 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,14 @@ use crate::retdata;
use crate::state::state_api::StateReader;
use crate::test_utils::contracts::FeatureContract;
use crate::test_utils::initial_test_state::test_state;
use crate::test_utils::{create_calldata, trivial_external_entry_point_new, CairoVersion, BALANCE};
use crate::test_utils::syscall::build_recurse_calldata;
use crate::test_utils::{
create_calldata,
trivial_external_entry_point_new,
CairoVersion,
CompilerBasedVersion,
BALANCE,
};

#[test]
fn test_call_contract_that_panics() {
Expand Down Expand Up @@ -76,7 +83,7 @@ fn test_call_contract(
inner_contract.get_instance_address(0),
"test_storage_read_write",
&[
felt!(405_u16), // Calldata: address.
felt!(405_u16), // Calldata: storage address.
felt!(48_u8), // Calldata: value.
],
);
Expand All @@ -98,7 +105,7 @@ fn test_call_contract(

/// Cairo0 / Cairo1 calls to Cairo0 / Cairo1.
#[rstest]
fn test_track_resources(
fn test_tracked_resources(
#[values(CairoVersion::Cairo0, CairoVersion::Cairo1)] outer_version: CairoVersion,
#[values(CairoVersion::Cairo0, CairoVersion::Cairo1)] inner_version: CairoVersion,
) {
Expand All @@ -108,14 +115,7 @@ fn test_track_resources(
let mut state = test_state(chain_info, BALANCE, &[(outer_contract, 1), (inner_contract, 1)]);

let outer_entry_point_selector = selector_from_name("test_call_contract");
let calldata = create_calldata(
inner_contract.get_instance_address(0),
"test_storage_read_write",
&[
felt!(405_u16), // Calldata: address.
felt!(48_u8), // Calldata: value.
],
);
let calldata = build_recurse_calldata(&[inner_version.into()]);
let entry_point_call = CallEntryPoint {
entry_point_selector: outer_entry_point_selector,
calldata,
Expand All @@ -140,37 +140,23 @@ fn test_track_resources(
/// 1) Cairo-Steps contract that calls Sierra-Gas (nested) contract.
/// 2) Sierra-Gas contract.
#[rstest]
fn test_track_resources_nested(
fn test_tracked_resources_nested(
#[values(
FeatureContract::TestContract(CairoVersion::Cairo0),
FeatureContract::CairoStepsTestContract
CompilerBasedVersion::CairoVersion(CairoVersion::Cairo0),
CompilerBasedVersion::OldCairo1
)]
cairo_steps_contract: FeatureContract,
cairo_steps_contract_version: CompilerBasedVersion,
) {
let cairo_steps_contract = cairo_steps_contract_version.get_test_contract();
let sierra_gas_contract = FeatureContract::TestContract(CairoVersion::Cairo1);
let chain_info = &ChainInfo::create_for_testing();
let mut state =
test_state(chain_info, BALANCE, &[(sierra_gas_contract, 1), (cairo_steps_contract, 1)]);

let first_calldata = create_calldata(
cairo_steps_contract.get_instance_address(0),
"test_call_contract",
&[
sierra_gas_contract.get_instance_address(0).into(),
selector_from_name("test_storage_read_write").0,
felt!(2_u8), // Calldata length
felt!(405_u16), // Calldata: address.
felt!(48_u8), // Calldata: value.
],
);
let second_calldata = create_calldata(
sierra_gas_contract.get_instance_address(0),
"test_storage_read_write",
&[
felt!(406_u16), // Calldata: address.
felt!(49_u8), // Calldata: value.
],
);
let first_calldata =
build_recurse_calldata(&[cairo_steps_contract_version, CairoVersion::Cairo1.into()]);

let second_calldata = build_recurse_calldata(&[CairoVersion::Cairo1.into()]);

let concated_calldata_felts = [first_calldata.0, second_calldata.0]
.into_iter()
Expand Down
19 changes: 19 additions & 0 deletions crates/blockifier/src/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ pub mod initial_test_state;
pub mod invoke;
pub mod prices;
pub mod struct_impls;
pub mod syscall;
pub mod transfers_generator;
use std::collections::HashMap;
use std::fs;
Expand Down Expand Up @@ -90,6 +91,24 @@ impl CairoVersion {
}
}

pub enum CompilerBasedVersion {
CairoVersion(CairoVersion),
OldCairo1,
}

impl From<CairoVersion> for CompilerBasedVersion {
fn from(version: CairoVersion) -> Self {
Self::CairoVersion(version)
}
}
impl CompilerBasedVersion {
pub fn get_test_contract(&self) -> FeatureContract {
match self {
Self::CairoVersion(version) => FeatureContract::TestContract(*version),
Self::OldCairo1 => FeatureContract::CairoStepsTestContract,
}
}
}
// Storage keys.
pub fn test_erc20_sequencer_balance_key() -> StorageKey {
get_fee_token_var_address(contract_address!(TEST_SEQUENCER_ADDRESS))
Expand Down
33 changes: 33 additions & 0 deletions crates/blockifier/src/test_utils/syscall.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
use starknet_api::felt;
use starknet_api::transaction::Calldata;

use crate::test_utils::{create_calldata, CompilerBasedVersion};

/// Returns the calldata for N recursive call contract syscalls, where N is the length of versions.
/// versions determines the cairo version of the called contract in each recursive call. Final call
/// is a simple local contract call (test_storage_read_write).
/// The first element in the returned value is the calldata for a call from a contract of the first
/// element in versions, to the a contract of the second element, etc.
pub fn build_recurse_calldata(versions: &[CompilerBasedVersion]) -> Calldata {
if versions.is_empty() {
return Calldata(vec![].into());
}
let last_version = versions.last().unwrap();
let mut calldata = create_calldata(
last_version.get_test_contract().get_instance_address(0),
"test_storage_read_write",
&[
felt!(123_u16), // Calldata: address.
felt!(45_u8), // Calldata: value.
],
);

for version in versions[..versions.len() - 1].iter().rev() {
calldata = create_calldata(
version.get_test_contract().get_instance_address(0),
"test_call_contract",
&calldata.0,
);
}
return calldata;
}

0 comments on commit 7160c7a

Please sign in to comment.