Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test(blockifier): function to build calldata for recursive call_contract calls #1449

Merged
merged 1 commit into from
Oct 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,14 @@ use crate::retdata;
use crate::state::state_api::StateReader;
use crate::test_utils::contracts::FeatureContract;
use crate::test_utils::initial_test_state::test_state;
use crate::test_utils::{create_calldata, trivial_external_entry_point_new, CairoVersion, BALANCE};
use crate::test_utils::syscall::build_recurse_calldata;
use crate::test_utils::{
create_calldata,
trivial_external_entry_point_new,
CairoVersion,
CompilerBasedVersion,
BALANCE,
};

#[test]
fn test_call_contract_that_panics() {
Expand Down Expand Up @@ -79,7 +86,7 @@ fn test_call_contract(
inner_contract.get_instance_address(0),
"test_storage_read_write",
&[
felt!(405_u16), // Calldata: address.
felt!(405_u16), // Calldata: storage address.
felt!(48_u8), // Calldata: value.
],
);
Expand All @@ -99,42 +106,44 @@ fn test_call_contract(
);
}

/// Cairo0 / Cairo1 calls to Cairo0 / Cairo1.
/// Cairo0 / Old Cairo1 / Cairo1 calls to Cairo0 / Old Cairo1/ Cairo1.
#[rstest]
fn test_track_resources(
#[values(CairoVersion::Cairo0, CairoVersion::Cairo1)] outer_version: CairoVersion,
#[values(CairoVersion::Cairo0, CairoVersion::Cairo1)] inner_version: CairoVersion,
fn test_tracked_resources(
#[values(
CompilerBasedVersion::CairoVersion(CairoVersion::Cairo0),
CompilerBasedVersion::OldCairo1,
CompilerBasedVersion::CairoVersion(CairoVersion::Cairo1)
)]
outer_version: CompilerBasedVersion,
#[values(
CompilerBasedVersion::CairoVersion(CairoVersion::Cairo0),
CompilerBasedVersion::OldCairo1,
CompilerBasedVersion::CairoVersion(CairoVersion::Cairo1)
)]
inner_version: CompilerBasedVersion,
) {
let outer_contract = FeatureContract::TestContract(outer_version);
let inner_contract = FeatureContract::TestContract(inner_version);
let outer_contract = outer_version.get_test_contract();
let inner_contract = inner_version.get_test_contract();
let chain_info = &ChainInfo::create_for_testing();
let mut state = test_state(chain_info, BALANCE, &[(outer_contract, 1), (inner_contract, 1)]);

let outer_entry_point_selector = selector_from_name("test_call_contract");
let calldata = create_calldata(
inner_contract.get_instance_address(0),
"test_storage_read_write",
&[
felt!(405_u16), // Calldata: address.
felt!(48_u8), // Calldata: value.
],
);
let calldata = build_recurse_calldata(&[inner_version]);
let entry_point_call = CallEntryPoint {
entry_point_selector: outer_entry_point_selector,
calldata,
..trivial_external_entry_point_new(outer_contract)
};

let execution = entry_point_call.execute_directly(&mut state).unwrap();
let expected_outer_resource = match outer_version {
CairoVersion::Cairo0 => TrackedResource::CairoSteps,
CairoVersion::Cairo1 => TrackedResource::SierraGas,
};
let expected_outer_resource = outer_version.own_tracked_resource();
assert_eq!(execution.tracked_resource, expected_outer_resource);

let expected_inner_resource = match (outer_version, inner_version) {
(CairoVersion::Cairo1, CairoVersion::Cairo1) => TrackedResource::SierraGas,
_ => TrackedResource::CairoSteps,
let expected_inner_resource = if expected_outer_resource == inner_version.own_tracked_resource()
{
expected_outer_resource
} else {
TrackedResource::CairoSteps
};
assert_eq!(execution.inner_calls.first().unwrap().tracked_resource, expected_inner_resource);
}
Expand All @@ -143,37 +152,26 @@ fn test_track_resources(
/// 1) Cairo-Steps contract that calls Sierra-Gas (nested) contract.
/// 2) Sierra-Gas contract.
#[rstest]
fn test_track_resources_nested(
fn test_tracked_resources_nested(
#[values(
FeatureContract::TestContract(CairoVersion::Cairo0),
FeatureContract::CairoStepsTestContract
CompilerBasedVersion::CairoVersion(CairoVersion::Cairo0),
CompilerBasedVersion::OldCairo1
)]
cairo_steps_contract: FeatureContract,
cairo_steps_contract_version: CompilerBasedVersion,
) {
let cairo_steps_contract = cairo_steps_contract_version.get_test_contract();
let sierra_gas_contract = FeatureContract::TestContract(CairoVersion::Cairo1);
let chain_info = &ChainInfo::create_for_testing();
let mut state =
test_state(chain_info, BALANCE, &[(sierra_gas_contract, 1), (cairo_steps_contract, 1)]);

let first_calldata = create_calldata(
cairo_steps_contract.get_instance_address(0),
"test_call_contract",
&[
sierra_gas_contract.get_instance_address(0).into(),
selector_from_name("test_storage_read_write").0,
felt!(2_u8), // Calldata length
felt!(405_u16), // Calldata: address.
felt!(48_u8), // Calldata: value.
],
);
let second_calldata = create_calldata(
sierra_gas_contract.get_instance_address(0),
"test_storage_read_write",
&[
felt!(406_u16), // Calldata: address.
felt!(49_u8), // Calldata: value.
],
);
let first_calldata = build_recurse_calldata(&[
cairo_steps_contract_version,
CompilerBasedVersion::CairoVersion(CairoVersion::Cairo1),
]);

let second_calldata =
build_recurse_calldata(&[CompilerBasedVersion::CairoVersion(CairoVersion::Cairo1)]);

let concated_calldata_felts = [first_calldata.0, second_calldata.0]
.into_iter()
Expand Down
28 changes: 28 additions & 0 deletions crates/blockifier/src/test_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ pub mod initial_test_state;
pub mod invoke;
pub mod prices;
pub mod struct_impls;
pub mod syscall;
pub mod transfers_generator;
use std::collections::HashMap;
use std::fs;
Expand All @@ -30,6 +31,7 @@ use starknet_types_core::felt::Felt;

use crate::abi::abi_utils::{get_fee_token_var_address, selector_from_name};
use crate::execution::call_info::ExecutionSummary;
use crate::execution::contract_class::TrackedResource;
use crate::execution::deprecated_syscalls::hint_processor::SyscallCounter;
use crate::execution::entry_point::CallEntryPoint;
use crate::execution::syscalls::SyscallSelector;
Expand Down Expand Up @@ -90,6 +92,32 @@ impl CairoVersion {
}
}

#[derive(Clone, Copy, PartialEq, Eq, Debug)]
pub enum CompilerBasedVersion {
CairoVersion(CairoVersion),
OldCairo1,
}

impl CompilerBasedVersion {
pub fn get_test_contract(&self) -> FeatureContract {
match self {
Self::CairoVersion(version) => FeatureContract::TestContract(*version),
Self::OldCairo1 => FeatureContract::CairoStepsTestContract,
}
}

/// Returns the context-free tracked resource of this contract (does not take caller contract
/// into account).
pub fn own_tracked_resource(&self) -> TrackedResource {
match self {
Self::CairoVersion(CairoVersion::Cairo0) | Self::OldCairo1 => {
TrackedResource::CairoSteps
}
Self::CairoVersion(CairoVersion::Cairo1) => TrackedResource::SierraGas,
}
}
}

// Storage keys.
pub fn test_erc20_sequencer_balance_key() -> StorageKey {
get_fee_token_var_address(contract_address!(TEST_SEQUENCER_ADDRESS))
Expand Down
33 changes: 33 additions & 0 deletions crates/blockifier/src/test_utils/syscall.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
use starknet_api::felt;
use starknet_api::transaction::Calldata;

use crate::test_utils::{create_calldata, CompilerBasedVersion};

/// Returns the calldata for N recursive call contract syscalls, where N is the length of versions.
/// versions determines the cairo version of the called contract in each recursive call. Final call
/// is a simple local contract call (test_storage_read_write).
/// The first element in the returned value is the calldata for a call from a contract of the first
/// element in versions, to the a contract of the second element, etc.
pub fn build_recurse_calldata(versions: &[CompilerBasedVersion]) -> Calldata {
if versions.is_empty() {
return Calldata(vec![].into());
}
let last_version = versions.last().unwrap();
let mut calldata = create_calldata(
last_version.get_test_contract().get_instance_address(0),
"test_storage_read_write",
&[
felt!(123_u16), // Calldata: address.
felt!(45_u8), // Calldata: value.
],
);

for version in versions[..versions.len() - 1].iter().rev() {
calldata = create_calldata(
version.get_test_contract().get_instance_address(0),
"test_call_contract",
&calldata.0,
);
}
calldata
}
Loading