Skip to content

Commit

Permalink
feat(blockifier): add TrackingResource to CallInfo and use to update …
Browse files Browse the repository at this point in the history
…gas (#928)
  • Loading branch information
TzahiTaub authored Sep 29, 2024
1 parent 3ea49d6 commit b67bf97
Show file tree
Hide file tree
Showing 8 changed files with 87 additions and 17 deletions.
2 changes: 2 additions & 0 deletions crates/blockifier/src/execution/call_info.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use starknet_api::state::StorageKey;
use starknet_api::transaction::{EventContent, L2ToL1Payload};
use starknet_types_core::felt::Felt;

use crate::execution::contract_class::TrackedResource;
use crate::execution::entry_point::CallEntryPoint;
use crate::state::cached_state::StorageEntry;
use crate::utils::u128_from_usize;
Expand Down Expand Up @@ -102,6 +103,7 @@ pub struct CallInfo {
pub execution: CallExecution,
pub resources: ExecutionResources,
pub inner_calls: Vec<CallInfo>,
pub tracked_resource: TrackedResource,

// Additional information gathered during execution.
pub storage_read_values: Vec<Felt>,
Expand Down
25 changes: 13 additions & 12 deletions crates/blockifier/src/execution/contract_class.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,8 +53,9 @@ pub mod test;
pub type ContractClassResult<T> = Result<T, ContractClassError>;

/// The resource used to run a contract function.
#[cfg_attr(feature = "transaction_serde", derive(serde::Deserialize))]
#[derive(Clone, Copy, Default, Debug, Eq, PartialEq, Serialize)]
pub enum TrackingResource {
pub enum TrackedResource {
#[default]
CairoSteps, // AKA VM mode.
SierraGas, // AKA Sierra mode.
Expand Down Expand Up @@ -118,11 +119,11 @@ impl ContractClass {
}

/// Returns whether this contract should run using Cairo steps or Sierra gas.
pub fn tracking_resource(&self, min_sierra_version: &CompilerVersion) -> TrackingResource {
pub fn tracked_resource(&self, min_sierra_version: &CompilerVersion) -> TrackedResource {
match self {
ContractClass::V0(_) => TrackingResource::CairoSteps,
ContractClass::V0(_) => TrackedResource::CairoSteps,
ContractClass::V1(contract_class) => {
contract_class.tracking_resource(min_sierra_version)
contract_class.tracked_resource(min_sierra_version)
}
}
}
Expand Down Expand Up @@ -177,8 +178,8 @@ impl ContractClassV0 {
}
}

pub fn tracking_resource(&self) -> TrackingResource {
TrackingResource::CairoSteps
pub fn tracked_resource(&self) -> TrackedResource {
TrackedResource::CairoSteps
}

pub fn try_from_json_string(raw_contract_class: &str) -> Result<ContractClassV0, ProgramError> {
Expand Down Expand Up @@ -260,11 +261,11 @@ impl ContractClassV1 {
}

/// Returns whether this contract should run using Cairo steps or Sierra gas.
pub fn tracking_resource(&self, min_sierra_version: &CompilerVersion) -> TrackingResource {
pub fn tracked_resource(&self, min_sierra_version: &CompilerVersion) -> TrackedResource {
if *min_sierra_version <= self.compiler_version {
TrackingResource::SierraGas
TrackedResource::SierraGas
} else {
TrackingResource::CairoSteps
TrackedResource::CairoSteps
}
}

Expand Down Expand Up @@ -422,7 +423,7 @@ impl TryFrom<CasmContractClass> for ContractClassV1 {
type Error = ProgramError;

fn try_from(class: CasmContractClass) -> Result<Self, Self::Error> {
try_from_casm_contrcat_class_internal(
try_from_casm_contract_class_internal(
&class.bytecode,
&class.hints,
&class.entry_points_by_type,
Expand All @@ -436,7 +437,7 @@ impl TryFrom<&CasmContractClass> for ContractClassV1 {
type Error = ProgramError;

fn try_from(class: &CasmContractClass) -> Result<Self, Self::Error> {
try_from_casm_contrcat_class_internal(
try_from_casm_contract_class_internal(
&class.bytecode,
&class.hints,
&class.entry_points_by_type,
Expand All @@ -459,7 +460,7 @@ pub fn deserialize_program<'de, D: Deserializer<'de>>(

// V1 utilities.

fn try_from_casm_contrcat_class_internal(
fn try_from_casm_contract_class_internal(
bytecode: &[BigUintAsHex],
casm_class_hints: &[(usize, Vec<Hint>)],
casm_class_entry_points_by_type: &CasmContractEntryPoints,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use super::execution_utils::SEGMENT_ARENA_BUILTIN_SIZE;
use crate::abi::abi_utils::selector_from_name;
use crate::abi::constants::{CONSTRUCTOR_ENTRY_POINT_NAME, DEFAULT_ENTRY_POINT_SELECTOR};
use crate::execution::call_info::{CallExecution, CallInfo};
use crate::execution::contract_class::ContractClassV0;
use crate::execution::contract_class::{ContractClassV0, TrackedResource};
use crate::execution::deprecated_syscalls::hint_processor::DeprecatedSyscallHintProcessor;
use crate::execution::entry_point::{
CallEntryPoint,
Expand Down Expand Up @@ -280,6 +280,7 @@ pub fn finalize_execution(
},
resources: full_call_resources.filter_unused_builtins(),
inner_calls: syscall_handler.inner_calls,
tracked_resource: TrackedResource::CairoSteps,
storage_read_values: syscall_handler.read_values,
accessed_storage_keys: syscall_handler.accessed_keys,
})
Expand Down
8 changes: 6 additions & 2 deletions crates/blockifier/src/execution/entry_point.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,13 @@ use crate::abi::constants;
use crate::context::{BlockContext, TransactionContext};
use crate::execution::call_info::CallInfo;
use crate::execution::common_hints::ExecutionMode;
use crate::execution::contract_class::TrackedResource;
use crate::execution::errors::{
ConstructorEntryPointExecutionError,
EntryPointExecutionError,
PreExecutionError,
};
use crate::execution::execution_utils::execute_entry_point_call;
use crate::execution::execution_utils::execute_entry_point_call_wrapper;
use crate::state::state_api::State;
use crate::transaction::objects::{HasRelatedFeeType, TransactionInfo};
use crate::transaction::transaction_types::TransactionType;
Expand Down Expand Up @@ -102,7 +103,7 @@ impl CallEntryPoint {
self.class_hash = Some(class_hash);
let contract_class = state.get_compiled_contract_class(class_hash)?;

execute_entry_point_call(self, contract_class, state, resources, context)
execute_entry_point_call_wrapper(self, contract_class, state, resources, context)
}
}

Expand Down Expand Up @@ -130,6 +131,8 @@ pub struct EntryPointExecutionContext {

// The execution mode affects the behavior of the hint processor.
pub execution_mode: ExecutionMode,
// The call stack of tracked resources from the first entry point to the current.
pub tracked_resource_stack: Vec<TrackedResource>,
}

impl EntryPointExecutionContext {
Expand All @@ -146,6 +149,7 @@ impl EntryPointExecutionContext {
tx_context: tx_context.clone(),
current_recursion_depth: Default::default(),
execution_mode: mode,
tracked_resource_stack: vec![],
}
}

Expand Down
7 changes: 6 additions & 1 deletion crates/blockifier/src/execution/entry_point_execution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use starknet_api::felt;
use starknet_types_core::felt::Felt;

use crate::execution::call_info::{CallExecution, CallInfo, Retdata};
use crate::execution::contract_class::{ContractClassV1, EntryPointV1};
use crate::execution::contract_class::{ContractClassV1, EntryPointV1, TrackedResource};
use crate::execution::entry_point::{
CallEntryPoint,
EntryPointExecutionContext,
Expand Down Expand Up @@ -63,6 +63,8 @@ pub fn execute_entry_point_call(
"Class hash must not be None when executing an entry point.".into(),
))?;

let tracked_resource =
*context.tracked_resource_stack.last().expect("Unexpected empty tracked resource.");
let VmExecutionContext {
mut runner,
mut syscall_handler,
Expand Down Expand Up @@ -103,6 +105,7 @@ pub fn execute_entry_point_call(
previous_resources,
n_total_args,
program_extra_data_length,
tracked_resource,
)?;
if call_info.execution.failed {
return Err(EntryPointExecutionError::ExecutionFailed {
Expand Down Expand Up @@ -373,6 +376,7 @@ pub fn finalize_execution(
previous_resources: ExecutionResources,
n_total_args: usize,
program_extra_data_length: usize,
tracked_resource: TrackedResource,
) -> Result<CallInfo, PostExecutionError> {
// Close memory holes in segments (OS code touches those memory cells, we simulate it).
let program_start_ptr = runner
Expand Down Expand Up @@ -421,6 +425,7 @@ pub fn finalize_execution(
},
resources: full_call_resources.filter_unused_builtins(),
inner_calls: syscall_handler.inner_calls,
tracked_resource,
storage_read_values: syscall_handler.read_values,
accessed_storage_keys: syscall_handler.accessed_keys,
})
Expand Down
27 changes: 26 additions & 1 deletion crates/blockifier/src/execution/execution_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ use starknet_types_core::felt::Felt;
use super::entry_point::ConstructorEntryPointExecutionResult;
use super::errors::ConstructorEntryPointExecutionError;
use crate::execution::call_info::{CallInfo, Retdata};
use crate::execution::contract_class::ContractClass;
use crate::execution::contract_class::{ContractClass, TrackedResource};
use crate::execution::entry_point::{
execute_constructor_entry_point,
CallEntryPoint,
Expand All @@ -43,6 +43,31 @@ pub type Args = Vec<CairoArg>;

pub const SEGMENT_ARENA_BUILTIN_SIZE: usize = 3;

/// A wrapper for execute_entry_point_call that performs pre and post-processing.
pub fn execute_entry_point_call_wrapper(
call: CallEntryPoint,
contract_class: ContractClass,
state: &mut dyn State,
resources: &mut ExecutionResources,
context: &mut EntryPointExecutionContext,
) -> EntryPointExecutionResult<CallInfo> {
let tracked_resource = contract_class
.tracked_resource(&context.versioned_constants().min_compiler_version_for_sierra_gas);
// Note: no return statements (explicit or implicit) should be added between the push and the
// pop commands.

// Once we ran with CairoSteps, we will continue to run using it for all nested calls.
if context.tracked_resource_stack.last().is_some_and(|x| *x == TrackedResource::CairoSteps) {
context.tracked_resource_stack.push(TrackedResource::CairoSteps);
} else {
context.tracked_resource_stack.push(tracked_resource);
}

let res = execute_entry_point_call(call, contract_class, state, resources, context);
context.tracked_resource_stack.pop();
res
}

/// Executes a specific call to a contract entry point and returns its output.
pub fn execute_entry_point_call(
call: CallEntryPoint,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ use crate::test_utils::{
CairoVersion,
BALANCE,
};
use crate::versioned_constants::VersionedConstants;

#[test_case(FeatureContract::TestContract(CairoVersion::Cairo1), REQUIRED_GAS_LIBRARY_CALL_TEST; "VM")]
fn test_library_call(test_contract: FeatureContract, expected_gas: u64) {
Expand Down Expand Up @@ -142,6 +143,12 @@ fn test_nested_library_call(test_contract: FeatureContract, expected_gas: u64) {
n_memory_holes: 0,
builtin_instance_counter: HashMap::from([(BuiltinName::range_check, 7)]),
};

// The default VersionedConstants is used in the execute_directly call bellow.
let tracked_resource = test_contract.get_class().tracked_resource(
&VersionedConstants::create_for_testing().min_compiler_version_for_sierra_gas,
);

let nested_storage_call_info = CallInfo {
call: nested_storage_entry_point,
execution: CallExecution {
Expand All @@ -150,6 +157,7 @@ fn test_nested_library_call(test_contract: FeatureContract, expected_gas: u64) {
..CallExecution::default()
},
resources: storage_entry_point_resources.clone(),
tracked_resource,
storage_read_values: vec![felt!(value + 1)],
accessed_storage_keys: HashSet::from([StorageKey(patricia_key!(key + 1))]),
..Default::default()
Expand All @@ -161,6 +169,7 @@ fn test_nested_library_call(test_contract: FeatureContract, expected_gas: u64) {
n_memory_holes: 0,
builtin_instance_counter: HashMap::from([(BuiltinName::range_check, 15)]),
};

let library_call_info = CallInfo {
call: library_entry_point,
execution: CallExecution {
Expand All @@ -170,6 +179,7 @@ fn test_nested_library_call(test_contract: FeatureContract, expected_gas: u64) {
},
resources: library_call_resources,
inner_calls: vec![nested_storage_call_info],
tracked_resource,
..Default::default()
};

Expand All @@ -183,6 +193,7 @@ fn test_nested_library_call(test_contract: FeatureContract, expected_gas: u64) {
resources: storage_entry_point_resources,
storage_read_values: vec![felt!(value)],
accessed_storage_keys: HashSet::from([StorageKey(patricia_key!(key))]),
tracked_resource,
..Default::default()
};

Expand All @@ -201,6 +212,7 @@ fn test_nested_library_call(test_contract: FeatureContract, expected_gas: u64) {
},
resources: main_call_resources,
inner_calls: vec![library_call_info, storage_call_info],
tracked_resource,
..Default::default()
};

Expand Down
20 changes: 20 additions & 0 deletions crates/blockifier/src/transaction/transactions_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ use crate::execution::call_info::{
OrderedL2ToL1Message,
Retdata,
};
use crate::execution::contract_class::TrackedResource;
use crate::execution::entry_point::{CallEntryPoint, CallType};
use crate::execution::errors::{ConstructorEntryPointExecutionError, EntryPointExecutionError};
use crate::execution::syscalls::hint_processor::EmitEventError;
Expand Down Expand Up @@ -173,6 +174,7 @@ fn expected_validate_call_info(
calldata: Calldata,
storage_address: ContractAddress,
cairo_version: CairoVersion,
tracked_resource: TrackedResource,
) -> Option<CallInfo> {
let retdata = match cairo_version {
CairoVersion::Cairo0 => Retdata::default(),
Expand Down Expand Up @@ -218,6 +220,7 @@ fn expected_validate_call_info(
// The account contract we use for testing has trivial `validate` functions.
resources,
execution: CallExecution { retdata, gas_consumed, ..Default::default() },
tracked_resource,
..Default::default()
})
}
Expand Down Expand Up @@ -444,6 +447,10 @@ fn test_invoke_tx(

let actual_execution_info = account_tx.execute(state, block_context, true, true).unwrap();

let tracked_resource = account_contract
.get_class()
.tracked_resource(&versioned_constants.min_compiler_version_for_sierra_gas);

// Build expected validate call info.
let expected_account_class_hash = account_contract.get_class_hash();
let expected_validate_call_info = expected_validate_call_info(
Expand All @@ -453,6 +460,7 @@ fn test_invoke_tx(
calldata,
sender_address,
account_cairo_version,
tracked_resource,
);

// Build expected execute call info.
Expand Down Expand Up @@ -488,6 +496,7 @@ fn test_invoke_tx(
resources: ExecutionResources { n_steps: 23, n_memory_holes: 0, ..Default::default() },
..Default::default()
}],
tracked_resource,
..Default::default()
});

Expand Down Expand Up @@ -1255,6 +1264,7 @@ fn declare_validate_callinfo(
declared_class_hash: ClassHash,
account_class_hash: ClassHash,
account_address: ContractAddress,
tracked_resource: TrackedResource,
) -> Option<CallInfo> {
// V0 transactions do not run validate.
if version == TransactionVersion::ZERO {
Expand All @@ -1267,6 +1277,7 @@ fn declare_validate_callinfo(
calldata![declared_class_hash.0],
account_address,
declared_contract_version,
tracked_resource,
)
}
}
Expand Down Expand Up @@ -1360,6 +1371,9 @@ fn test_declare_tx(
class_hash,
account.get_class_hash(),
sender_address,
account
.get_class()
.tracked_resource(&versioned_constants.min_compiler_version_for_sierra_gas),
);

// Build expected fee transfer call info.
Expand Down Expand Up @@ -1512,6 +1526,9 @@ fn test_deploy_account_tx(
Calldata(validate_calldata.into()),
deployed_account_address,
cairo_version,
account
.get_class()
.tracked_resource(&versioned_constants.min_compiler_version_for_sierra_gas),
);

// Build expected execute call info.
Expand Down Expand Up @@ -2032,6 +2049,9 @@ fn test_l1_handler(#[values(false, true)] use_kzg_da: bool) {
builtin_instance_counter: HashMap::from([(BuiltinName::range_check, 6)]),
},
accessed_storage_keys: HashSet::from_iter(vec![accessed_storage_key]),
tracked_resource: test_contract
.get_class()
.tracked_resource(&versioned_constants.min_compiler_version_for_sierra_gas),
..Default::default()
};

Expand Down

0 comments on commit b67bf97

Please sign in to comment.