Skip to content

Commit

Permalink
refactor(blockifier): rename contract class to runnable contract class (
Browse files Browse the repository at this point in the history
  • Loading branch information
noaov1 authored Nov 3, 2024
1 parent 9d8f82b commit 148bc1e
Show file tree
Hide file tree
Showing 35 changed files with 225 additions and 167 deletions.
15 changes: 11 additions & 4 deletions crates/batcher/src/papyrus_state.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
// TODO(yael 22/9/2024): This module is copied from native_blockifier, need to how to share it
// between the crates.
use blockifier::execution::contract_class::{ContractClass, ContractClassV0, ContractClassV1};
use blockifier::execution::contract_class::{
ContractClassV0,
ContractClassV1,
RunnableContractClass,
};
use blockifier::state::errors::StateError;
use blockifier::state::global_cache::GlobalContractCache;
use blockifier::state::state_api::{StateReader, StateResult};
Expand Down Expand Up @@ -40,7 +44,7 @@ impl PapyrusReader {
fn get_compiled_contract_class_inner(
&self,
class_hash: ClassHash,
) -> StateResult<ContractClass> {
) -> StateResult<RunnableContractClass> {
let state_number = StateNumber(self.latest_block);
let class_declaration_block_number = self
.reader()?
Expand All @@ -60,7 +64,7 @@ impl PapyrusReader {
inconsistent.",
);

return Ok(ContractClass::V1(ContractClassV1::try_from(casm_contract_class)?));
return Ok(RunnableContractClass::V1(ContractClassV1::try_from(casm_contract_class)?));
}

let v0_contract_class = self
Expand Down Expand Up @@ -118,7 +122,10 @@ impl StateReader for PapyrusReader {
}
}

fn get_compiled_contract_class(&self, class_hash: ClassHash) -> StateResult<ContractClass> {
fn get_compiled_contract_class(
&self,
class_hash: ClassHash,
) -> StateResult<RunnableContractClass> {
// Assumption: the global cache is cleared upon reverted blocks.
let contract_class = self.global_class_hash_to_class.get(&class_hash);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ fn test_declare(
version: tx_version,
resource_bounds: l1_resource_bounds(0_u8.into(), DEFAULT_STRK_L1_GAS_PRICE.into()),
},
calculate_class_info_for_testing(declared_contract.get_class()),
calculate_class_info_for_testing(declared_contract.get_runnable_class()),
)
.into();
tx_executor_test_body(state, block_context, tx, expected_bouncer_weights);
Expand Down
9 changes: 6 additions & 3 deletions crates/blockifier/src/concurrency/versioned_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use starknet_types_core::felt::Felt;

use crate::concurrency::versioned_storage::VersionedStorage;
use crate::concurrency::TxIndex;
use crate::execution::contract_class::ContractClass;
use crate::execution::contract_class::RunnableContractClass;
use crate::state::cached_state::{ContractClassMapping, StateMaps};
use crate::state::errors::StateError;
use crate::state::state_api::{StateReader, StateResult, UpdatableState};
Expand All @@ -34,7 +34,7 @@ pub struct VersionedState<S: StateReader> {
// the compiled contract classes mapping. Each key with value false, sohuld not apprear
// in the compiled contract classes mapping.
declared_contracts: VersionedStorage<ClassHash, bool>,
compiled_contract_classes: VersionedStorage<ClassHash, ContractClass>,
compiled_contract_classes: VersionedStorage<ClassHash, RunnableContractClass>,
}

impl<S: StateReader> VersionedState<S> {
Expand Down Expand Up @@ -336,7 +336,10 @@ impl<S: StateReader> StateReader for VersionedStateProxy<S> {
}
}

fn get_compiled_contract_class(&self, class_hash: ClassHash) -> StateResult<ContractClass> {
fn get_compiled_contract_class(
&self,
class_hash: ClassHash,
) -> StateResult<RunnableContractClass> {
let mut state = self.state();
match state.compiled_contract_classes.read(self.tx_index, class_hash) {
Some(value) => Ok(value),
Expand Down
21 changes: 13 additions & 8 deletions crates/blockifier/src/concurrency/versioned_state_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ fn test_versioned_state_proxy() {
let class_hash = class_hash!(27_u8);
let another_class_hash = class_hash!(28_u8);
let compiled_class_hash = compiled_class_hash!(29_u8);
let contract_class = test_contract.get_class();
let contract_class = test_contract.get_runnable_class();

// Create the versioned state
let cached_state = CachedState::from(DictStateReader {
Expand Down Expand Up @@ -118,7 +118,8 @@ fn test_versioned_state_proxy() {
let class_hash_v7 = class_hash!(28_u8);
let class_hash_v10 = class_hash!(29_u8);
let compiled_class_hash_v18 = compiled_class_hash!(30_u8);
let contract_class_v11 = FeatureContract::TestContract(CairoVersion::Cairo1).get_class();
let contract_class_v11 =
FeatureContract::TestContract(CairoVersion::Cairo1).get_runnable_class();

versioned_state_proxys[3].state().apply_writes(
3,
Expand Down Expand Up @@ -404,7 +405,8 @@ fn test_false_validate_reads_declared_contracts(
..Default::default()
};
let version_state_proxy = safe_versioned_state.pin_version(0);
let compiled_contract_calss = FeatureContract::TestContract(CairoVersion::Cairo1).get_class();
let compiled_contract_calss =
FeatureContract::TestContract(CairoVersion::Cairo1).get_runnable_class();
let class_hash_to_class = HashMap::from([(class_hash!(1_u8), compiled_contract_calss)]);
version_state_proxy.state().apply_writes(0, &tx_0_writes, &class_hash_to_class);
assert!(!safe_versioned_state.pin_version(1).validate_reads(&tx_1_reads));
Expand All @@ -429,7 +431,7 @@ fn test_apply_writes(
assert_eq!(transactional_states[0].cache.borrow().writes.class_hashes.len(), 1);

// Transaction 0 contract class.
let contract_class_0 = FeatureContract::TestContract(CairoVersion::Cairo1).get_class();
let contract_class_0 = FeatureContract::TestContract(CairoVersion::Cairo1).get_runnable_class();
assert!(transactional_states[0].class_hash_to_class.borrow().is_empty());
transactional_states[0].set_contract_class(class_hash, contract_class_0.clone()).unwrap();
assert_eq!(transactional_states[0].class_hash_to_class.borrow().len(), 1);
Expand Down Expand Up @@ -509,7 +511,10 @@ fn test_delete_writes(
}
// Modify the `class_hash_to_class` member of the CachedState.
tx_state
.set_contract_class(feature_contract.get_class_hash(), feature_contract.get_class())
.set_contract_class(
feature_contract.get_class_hash(),
feature_contract.get_runnable_class(),
)
.unwrap();
safe_versioned_state.pin_version(i).apply_writes(
&tx_state.cache.borrow().writes,
Expand Down Expand Up @@ -568,7 +573,7 @@ fn test_delete_writes_completeness(
declared_contracts: HashMap::from([(feature_contract.get_class_hash(), true)]),
};
let class_hash_to_class_writes =
HashMap::from([(feature_contract.get_class_hash(), feature_contract.get_class())]);
HashMap::from([(feature_contract.get_class_hash(), feature_contract.get_runnable_class())]);

let tx_index = 0;
let mut versioned_state_proxy = safe_versioned_state.pin_version(tx_index);
Expand Down Expand Up @@ -631,9 +636,9 @@ fn test_versioned_proxy_state_flow(
transactional_states[3].set_class_hash_at(contract_address, class_hash_3).unwrap();

// Clients contract class values.
let contract_class_0 = FeatureContract::TestContract(CairoVersion::Cairo0).get_class();
let contract_class_0 = FeatureContract::TestContract(CairoVersion::Cairo0).get_runnable_class();
let contract_class_2 =
FeatureContract::AccountWithLongValidate(CairoVersion::Cairo1).get_class();
FeatureContract::AccountWithLongValidate(CairoVersion::Cairo1).get_runnable_class();

transactional_states[0].set_contract_class(class_hash, contract_class_0).unwrap();
transactional_states[2].set_contract_class(class_hash, contract_class_2.clone()).unwrap();
Expand Down
2 changes: 1 addition & 1 deletion crates/blockifier/src/concurrency/worker_logic_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -552,7 +552,7 @@ fn test_deploy_before_declare(
let account_address_1 = account_contract.get_instance_address(1);
let test_contract = FeatureContract::TestContract(cairo_version);
let test_class_hash = test_contract.get_class_hash();
let test_class_info = calculate_class_info_for_testing(test_contract.get_class());
let test_class_info = calculate_class_info_for_testing(test_contract.get_runnable_class());
let test_compiled_class_hash = test_contract.get_compiled_class_hash();
let declare_tx = declare_tx(
declare_tx_args! {
Expand Down
64 changes: 29 additions & 35 deletions crates/blockifier/src/execution/contract_class.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use itertools::Itertools;
use semver::Version;
use serde::de::Error as DeserializationError;
use serde::{Deserialize, Deserializer, Serialize};
use starknet_api::contract_class::{ContractClass as RawContractClass, EntryPointType};
use starknet_api::contract_class::{ContractClass, EntryPointType};
use starknet_api::core::EntryPointSelector;
use starknet_api::deprecated_contract_class::{
ContractClass as DeprecatedContractClass,
Expand Down Expand Up @@ -62,46 +62,42 @@ pub enum TrackedResource {

/// Represents a runnable Starknet contract class (meaning, the program is runnable by the VM).
#[derive(Clone, Debug, Eq, PartialEq, derive_more::From)]
pub enum ContractClass {
pub enum RunnableContractClass {
V0(ContractClassV0),
V1(ContractClassV1),
#[cfg(feature = "cairo_native")]
V1Native(NativeContractClassV1),
}

impl TryFrom<RawContractClass> for ContractClass {
impl TryFrom<ContractClass> for RunnableContractClass {
type Error = ProgramError;

fn try_from(raw_contract_class: RawContractClass) -> Result<Self, Self::Error> {
let contract_class: ContractClass = match raw_contract_class {
RawContractClass::V0(raw_contract_class) => {
ContractClass::V0(raw_contract_class.try_into()?)
}
RawContractClass::V1(raw_contract_class) => {
ContractClass::V1(raw_contract_class.try_into()?)
}
fn try_from(raw_contract_class: ContractClass) -> Result<Self, Self::Error> {
let contract_class: Self = match raw_contract_class {
ContractClass::V0(raw_contract_class) => Self::V0(raw_contract_class.try_into()?),
ContractClass::V1(raw_contract_class) => Self::V1(raw_contract_class.try_into()?),
};

Ok(contract_class)
}
}

impl ContractClass {
impl RunnableContractClass {
pub fn constructor_selector(&self) -> Option<EntryPointSelector> {
match self {
ContractClass::V0(class) => class.constructor_selector(),
ContractClass::V1(class) => class.constructor_selector(),
Self::V0(class) => class.constructor_selector(),
Self::V1(class) => class.constructor_selector(),
#[cfg(feature = "cairo_native")]
ContractClass::V1Native(class) => class.constructor_selector(),
Self::V1Native(class) => class.constructor_selector(),
}
}

pub fn estimate_casm_hash_computation_resources(&self) -> ExecutionResources {
match self {
ContractClass::V0(class) => class.estimate_casm_hash_computation_resources(),
ContractClass::V1(class) => class.estimate_casm_hash_computation_resources(),
Self::V0(class) => class.estimate_casm_hash_computation_resources(),
Self::V1(class) => class.estimate_casm_hash_computation_resources(),
#[cfg(feature = "cairo_native")]
ContractClass::V1Native(_) => {
Self::V1Native(_) => {
todo!("Use casm to estimate casm hash computation resources")
}
}
Expand All @@ -112,23 +108,23 @@ impl ContractClass {
visited_pcs: &HashSet<usize>,
) -> Result<Vec<usize>, TransactionExecutionError> {
match self {
ContractClass::V0(_) => {
Self::V0(_) => {
panic!("get_visited_segments is not supported for v0 contracts.")
}
ContractClass::V1(class) => class.get_visited_segments(visited_pcs),
Self::V1(class) => class.get_visited_segments(visited_pcs),
#[cfg(feature = "cairo_native")]
ContractClass::V1Native(_) => {
Self::V1Native(_) => {
panic!("get_visited_segments is not supported for native contracts.")
}
}
}

pub fn bytecode_length(&self) -> usize {
match self {
ContractClass::V0(class) => class.bytecode_length(),
ContractClass::V1(class) => class.bytecode_length(),
Self::V0(class) => class.bytecode_length(),
Self::V1(class) => class.bytecode_length(),
#[cfg(feature = "cairo_native")]
ContractClass::V1Native(_) => {
Self::V1Native(_) => {
todo!("implement bytecode_length for native contracts.")
}
}
Expand All @@ -137,12 +133,10 @@ impl ContractClass {
/// Returns whether this contract should run using Cairo steps or Sierra gas.
pub fn tracked_resource(&self, min_sierra_version: &CompilerVersion) -> TrackedResource {
match self {
ContractClass::V0(_) => TrackedResource::CairoSteps,
ContractClass::V1(contract_class) => {
contract_class.tracked_resource(min_sierra_version)
}
Self::V0(_) => TrackedResource::CairoSteps,
Self::V1(contract_class) => contract_class.tracked_resource(min_sierra_version),
#[cfg(feature = "cairo_native")]
ContractClass::V1Native(_) => TrackedResource::SierraGas,
Self::V1Native(_) => TrackedResource::SierraGas,
}
}
}
Expand Down Expand Up @@ -530,7 +524,7 @@ fn convert_entry_points_v1(external: &[CasmContractEntryPoint]) -> Vec<EntryPoin
#[derive(Clone, Debug)]
// TODO(Ayelet,10/02/2024): Change to bytes.
pub struct ClassInfo {
contract_class: ContractClass,
contract_class: RunnableContractClass,
sierra_program_length: usize,
abi_length: usize,
}
Expand All @@ -554,7 +548,7 @@ impl ClassInfo {
self.contract_class.bytecode_length()
}

pub fn contract_class(&self) -> ContractClass {
pub fn contract_class(&self) -> RunnableContractClass {
self.contract_class.clone()
}

Expand All @@ -574,15 +568,15 @@ impl ClassInfo {
}

pub fn new(
contract_class: &ContractClass,
contract_class: &RunnableContractClass,
sierra_program_length: usize,
abi_length: usize,
) -> ContractClassResult<Self> {
let (contract_class_version, condition) = match contract_class {
ContractClass::V0(_) => (0, sierra_program_length == 0),
ContractClass::V1(_) => (1, sierra_program_length > 0),
RunnableContractClass::V0(_) => (0, sierra_program_length == 0),
RunnableContractClass::V1(_) => (1, sierra_program_length > 0),
#[cfg(feature = "cairo_native")]
ContractClass::V1Native(_) => (1, sierra_program_length > 0),
RunnableContractClass::V1Native(_) => (1, sierra_program_length > 0),
};

if condition {
Expand Down
26 changes: 14 additions & 12 deletions crates/blockifier/src/execution/execution_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ use super::errors::{
};
use super::syscalls::hint_processor::ENTRYPOINT_NOT_FOUND_ERROR;
use crate::execution::call_info::{CallInfo, Retdata};
use crate::execution::contract_class::{ContractClass, TrackedResource};
use crate::execution::contract_class::{RunnableContractClass, TrackedResource};
use crate::execution::entry_point::{
execute_constructor_entry_point,
CallEntryPoint,
Expand All @@ -54,7 +54,7 @@ pub const SEGMENT_ARENA_BUILTIN_SIZE: usize = 3;
/// A wrapper for execute_entry_point_call that performs pre and post-processing.
pub fn execute_entry_point_call_wrapper(
mut call: CallEntryPoint,
contract_class: ContractClass,
contract_class: RunnableContractClass,
state: &mut dyn State,
resources: &mut ExecutionResources,
context: &mut EntryPointExecutionContext,
Expand Down Expand Up @@ -118,13 +118,13 @@ pub fn execute_entry_point_call_wrapper(
/// Executes a specific call to a contract entry point and returns its output.
pub fn execute_entry_point_call(
call: CallEntryPoint,
contract_class: ContractClass,
contract_class: RunnableContractClass,
state: &mut dyn State,
resources: &mut ExecutionResources,
context: &mut EntryPointExecutionContext,
) -> EntryPointExecutionResult<CallInfo> {
match contract_class {
ContractClass::V0(contract_class) => {
RunnableContractClass::V0(contract_class) => {
deprecated_entry_point_execution::execute_entry_point_call(
call,
contract_class,
Expand All @@ -133,15 +133,17 @@ pub fn execute_entry_point_call(
context,
)
}
ContractClass::V1(contract_class) => entry_point_execution::execute_entry_point_call(
call,
contract_class,
state,
resources,
context,
),
RunnableContractClass::V1(contract_class) => {
entry_point_execution::execute_entry_point_call(
call,
contract_class,
state,
resources,
context,
)
}
#[cfg(feature = "cairo_native")]
ContractClass::V1Native(contract_class) => {
RunnableContractClass::V1Native(contract_class) => {
if context.tracked_resource_stack.last() == Some(&TrackedResource::CairoSteps) {
// We cannot run native with cairo steps as the tracked resources (it's a vm
// resouorce).
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ fn test_nested_library_call(test_contract: FeatureContract, expected_gas: u64) {
};

// The default VersionedConstants is used in the execute_directly call bellow.
let tracked_resource = test_contract.get_class().tracked_resource(
let tracked_resource = test_contract.get_runnable_class().tracked_resource(
&VersionedConstants::create_for_testing().min_compiler_version_for_sierra_gas,
);

Expand Down
2 changes: 1 addition & 1 deletion crates/blockifier/src/fee/receipt_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ fn test_calculate_tx_gas_usage_basic<'a>(

// Declare.
for cairo_version in [CairoVersion::Cairo0, CairoVersion::Cairo1] {
let empty_contract = FeatureContract::Empty(cairo_version).get_class();
let empty_contract = FeatureContract::Empty(cairo_version).get_runnable_class();
let class_info = calculate_class_info_for_testing(empty_contract);
let declare_tx_starknet_resources = StarknetResources::new(
0,
Expand Down
Loading

0 comments on commit 148bc1e

Please sign in to comment.