Skip to content

Commit

Permalink
test(blockifier): function to build calldata for recursive call contr…
Browse files Browse the repository at this point in the history
…act calls (#1449)
  • Loading branch information
TzahiTaub authored Oct 27, 2024
1 parent ea50d56 commit f9bf674
Show file tree
Hide file tree
Showing 3 changed files with 105 additions and 46 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,14 @@ use crate::retdata;
use crate::state::state_api::StateReader;
use crate::test_utils::contracts::FeatureContract;
use crate::test_utils::initial_test_state::test_state;
use crate::test_utils::{create_calldata, trivial_external_entry_point_new, CairoVersion, BALANCE};
use crate::test_utils::syscall::build_recurse_calldata;
use crate::test_utils::{
create_calldata,
trivial_external_entry_point_new,
CairoVersion,
CompilerBasedVersion,
BALANCE,
};

#[test]
fn test_call_contract_that_panics() {
Expand Down Expand Up @@ -79,7 +86,7 @@ fn test_call_contract(
inner_contract.get_instance_address(0),
"test_storage_read_write",
&[
felt!(405_u16), // Calldata: address.
felt!(405_u16), // Calldata: storage address.
felt!(48_u8), // Calldata: value.
],
);
Expand All @@ -99,42 +106,44 @@ fn test_call_contract(
);
}

/// Cairo0 / Cairo1 calls to Cairo0 / Cairo1.
/// Cairo0 / Old Cairo1 / Cairo1 calls to Cairo0 / Old Cairo1/ Cairo1.
#[rstest]
fn test_track_resources(
#[values(CairoVersion::Cairo0, CairoVersion::Cairo1)] outer_version: CairoVersion,
#[values(CairoVersion::Cairo0, CairoVersion::Cairo1)] inner_version: CairoVersion,
fn test_tracked_resources(
#[values(
CompilerBasedVersion::CairoVersion(CairoVersion::Cairo0),
CompilerBasedVersion::OldCairo1,
CompilerBasedVersion::CairoVersion(CairoVersion::Cairo1)
)]
outer_version: CompilerBasedVersion,
#[values(
CompilerBasedVersion::CairoVersion(CairoVersion::Cairo0),
CompilerBasedVersion::OldCairo1,
CompilerBasedVersion::CairoVersion(CairoVersion::Cairo1)
)]
inner_version: CompilerBasedVersion,
) {
let outer_contract = FeatureContract::TestContract(outer_version);
let inner_contract = FeatureContract::TestContract(inner_version);
let outer_contract = outer_version.get_test_contract();
let inner_contract = inner_version.get_test_contract();
let chain_info = &ChainInfo::create_for_testing();
let mut state = test_state(chain_info, BALANCE, &[(outer_contract, 1), (inner_contract, 1)]);

let outer_entry_point_selector = selector_from_name("test_call_contract");
let calldata = create_calldata(
inner_contract.get_instance_address(0),
"test_storage_read_write",
&[
felt!(405_u16), // Calldata: address.
felt!(48_u8), // Calldata: value.
],
);
let calldata = build_recurse_calldata(&[inner_version]);
let entry_point_call = CallEntryPoint {
entry_point_selector: outer_entry_point_selector,
calldata,
..trivial_external_entry_point_new(outer_contract)
};

let execution = entry_point_call.execute_directly(&mut state).unwrap();
let expected_outer_resource = match outer_version {
CairoVersion::Cairo0 => TrackedResource::CairoSteps,
CairoVersion::Cairo1 => TrackedResource::SierraGas,
};
let expected_outer_resource = outer_version.own_tracked_resource();
assert_eq!(execution.tracked_resource, expected_outer_resource);

let expected_inner_resource = match (outer_version, inner_version) {
(CairoVersion::Cairo1, CairoVersion::Cairo1) => TrackedResource::SierraGas,
_ => TrackedResource::CairoSteps,
let expected_inner_resource = if expected_outer_resource == inner_version.own_tracked_resource()
{
expected_outer_resource
} else {
TrackedResource::CairoSteps
};
assert_eq!(execution.inner_calls.first().unwrap().tracked_resource, expected_inner_resource);
}
Expand All @@ -143,37 +152,26 @@ fn test_track_resources(
/// 1) Cairo-Steps contract that calls Sierra-Gas (nested) contract.
/// 2) Sierra-Gas contract.
#[rstest]
fn test_track_resources_nested(
fn test_tracked_resources_nested(
#[values(
FeatureContract::TestContract(CairoVersion::Cairo0),
FeatureContract::CairoStepsTestContract
CompilerBasedVersion::CairoVersion(CairoVersion::Cairo0),
CompilerBasedVersion::OldCairo1
)]
cairo_steps_contract: FeatureContract,
cairo_steps_contract_version: CompilerBasedVersion,
) {
let cairo_steps_contract = cairo_steps_contract_version.get_test_contract();
let sierra_gas_contract = FeatureContract::TestContract(CairoVersion::Cairo1);
let chain_info = &ChainInfo::create_for_testing();
let mut state =
test_state(chain_info, BALANCE, &[(sierra_gas_contract, 1), (cairo_steps_contract, 1)]);

let first_calldata = create_calldata(
cairo_steps_contract.get_instance_address(0),
"test_call_contract",
&[
sierra_gas_contract.get_instance_address(0).into(),
selector_from_name("test_storage_read_write").0,
felt!(2_u8), // Calldata length
felt!(405_u16), // Calldata: address.
felt!(48_u8), // Calldata: value.
],
);
let second_calldata = create_calldata(
sierra_gas_contract.get_instance_address(0),
"test_storage_read_write",
&[
felt!(406_u16), // Calldata: address.
felt!(49_u8), // Calldata: value.
],
);
let first_calldata = build_recurse_calldata(&[
cairo_steps_contract_version,
CompilerBasedVersion::CairoVersion(CairoVersion::Cairo1),
]);

let second_calldata =
build_recurse_calldata(&[CompilerBasedVersion::CairoVersion(CairoVersion::Cairo1)]);

let concated_calldata_felts = [first_calldata.0, second_calldata.0]
.into_iter()
Expand Down
28 changes: 28 additions & 0 deletions crates/blockifier/src/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ pub mod initial_test_state;
pub mod invoke;
pub mod prices;
pub mod struct_impls;
pub mod syscall;
pub mod transfers_generator;
use std::collections::HashMap;
use std::fs;
Expand All @@ -30,6 +31,7 @@ use starknet_types_core::felt::Felt;

use crate::abi::abi_utils::{get_fee_token_var_address, selector_from_name};
use crate::execution::call_info::ExecutionSummary;
use crate::execution::contract_class::TrackedResource;
use crate::execution::deprecated_syscalls::hint_processor::SyscallCounter;
use crate::execution::entry_point::CallEntryPoint;
use crate::execution::syscalls::SyscallSelector;
Expand Down Expand Up @@ -90,6 +92,32 @@ impl CairoVersion {
}
}

#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum CompilerBasedVersion {
CairoVersion(CairoVersion),
OldCairo1,
}

impl CompilerBasedVersion {
pub fn get_test_contract(&self) -> FeatureContract {
match self {
Self::CairoVersion(version) => FeatureContract::TestContract(*version),
Self::OldCairo1 => FeatureContract::CairoStepsTestContract,
}
}

/// Returns the context-free tracked resource of this contract (does not take caller contract
/// into account).
pub fn own_tracked_resource(&self) -> TrackedResource {
match self {
Self::CairoVersion(CairoVersion::Cairo0) | Self::OldCairo1 => {
TrackedResource::CairoSteps
}
Self::CairoVersion(CairoVersion::Cairo1) => TrackedResource::SierraGas,
}
}
}

// Storage keys.
pub fn test_erc20_sequencer_balance_key() -> StorageKey {
get_fee_token_var_address(contract_address!(TEST_SEQUENCER_ADDRESS))
Expand Down
33 changes: 33 additions & 0 deletions crates/blockifier/src/test_utils/syscall.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
use starknet_api::felt;
use starknet_api::transaction::Calldata;

use crate::test_utils::{create_calldata, CompilerBasedVersion};

/// Returns the calldata for N recursive call contract syscalls, where N is the length of versions.
/// versions determines the cairo version of the called contract in each recursive call. Final call
/// is a simple local contract call (test_storage_read_write).
/// The first element in the returned value is the calldata for a call from a contract of the first
/// element in versions, to the a contract of the second element, etc.
pub fn build_recurse_calldata(versions: &[CompilerBasedVersion]) -> Calldata {
if versions.is_empty() {
return Calldata(vec![].into());
}
let last_version = versions.last().unwrap();
let mut calldata = create_calldata(
last_version.get_test_contract().get_instance_address(0),
"test_storage_read_write",
&[
felt!(123_u16), // Calldata: address.
felt!(45_u8), // Calldata: value.
],
);

for version in versions[..versions.len() - 1].iter().rev() {
calldata = create_calldata(
version.get_test_contract().get_instance_address(0),
"test_call_contract",
&calldata.0,
);
}
calldata
}

0 comments on commit f9bf674

Please sign in to comment.