From 10f1542a5c25ee086dd19caba2ed61c7e33c985f Mon Sep 17 00:00:00 2001 From: Yonatan Iluz Date: Sun, 13 Oct 2024 22:18:16 +0300 Subject: [PATCH] chore(blockifier): share EntryPoint code between native and casm --- .../src/execution/contract_class.rs | 171 +++++++----------- crates/blockifier/src/test_utils/contracts.rs | 10 +- 2 files changed, 76 insertions(+), 105 deletions(-) diff --git a/crates/blockifier/src/execution/contract_class.rs b/crates/blockifier/src/execution/contract_class.rs index 4b75d12c57..48ec3d607e 100644 --- a/crates/blockifier/src/execution/contract_class.rs +++ b/crates/blockifier/src/execution/contract_class.rs @@ -63,21 +63,6 @@ pub enum TrackedResource { SierraGas, // AKA Sierra mode. } -#[derive(Clone)] -pub enum Cairo1EntryPoint { - Casm(EntryPointV1), - Native(NativeEntryPoint), -} - -impl Cairo1EntryPoint { - pub fn selector(&self) -> &EntryPointSelector { - match self { - Cairo1EntryPoint::Casm(ep) => &ep.selector, - Cairo1EntryPoint::Native(ep) => &ep.selector, - } - } -} - /// Represents a runnable Starknet contract class (meaning, the program is runnable by the VM). #[derive(Clone, Debug, Eq, PartialEq, derive_more::From)] pub enum ContractClass { @@ -86,28 +71,6 @@ pub enum ContractClass { V1Native(NativeContractClassV1), } -pub fn get_entry_point( - contract_class: &ContractClass, - call: &CallEntryPoint, -) -> Result { - call.verify_constructor()?; - - let entry_points_of_same_type = contract_class.entry_points_of_same_type(call.entry_point_type); - let filtered_entry_points: Vec<_> = entry_points_of_same_type - .iter() - .filter(|ep| *ep.selector() == call.entry_point_selector) - .collect(); - - match &filtered_entry_points[..] { - [] => Err(PreExecutionError::EntryPointNotFound(call.entry_point_selector)), - [entry_point] => Ok((**entry_point).clone()), - _ => Err(PreExecutionError::DuplicatedEntryPointSelector { - selector: call.entry_point_selector, - typ: call.entry_point_type, - }), - } -} - impl TryFrom for ContractClass { type Error = ProgramError; @@ -144,23 +107,6 @@ impl ContractClass { } } - pub fn entry_points_of_same_type( - &self, - entry_point_type: EntryPointType, - ) -> Vec { - match self { - ContractClass::V0(_) => panic!("V0 contracts do not support entry points."), - ContractClass::V1(class) => class.entry_points_by_type[&entry_point_type] - .iter() - .map(|ep| Cairo1EntryPoint::Casm(ep.clone())) - .collect(), - ContractClass::V1Native(class) => class.entry_points_by_type[entry_point_type] - .iter() - .map(|ep| Cairo1EntryPoint::Native(ep.clone())) - .collect(), - } - } - pub fn get_visited_segments( &self, visited_pcs: &HashSet, @@ -292,7 +238,7 @@ impl Deref for ContractClassV1 { impl ContractClassV1 { fn constructor_selector(&self) -> Option { - Some(self.0.entry_points_by_type[&EntryPointType::Constructor].first()?.selector) + self.0.entry_points_by_type.constructor.first().map(|ep| ep.selector) } pub fn bytecode_length(&self) -> usize { @@ -307,10 +253,7 @@ impl ContractClassV1 { &self, call: &CallEntryPoint, ) -> Result { - match get_entry_point(&ContractClass::V1(self.clone()), call)? { - Cairo1EntryPoint::Casm(entry_point) => Ok(entry_point), - Cairo1EntryPoint::Native(_) => panic!("Unexpected entry point type."), - } + self.entry_points_by_type.get_entry_point(call) } /// Returns whether this contract should run using Cairo steps or Sierra gas. @@ -453,7 +396,7 @@ fn get_visited_segments( #[derive(Clone, Debug, Eq, PartialEq)] pub struct ContractClassV1Inner { pub program: Program, - pub entry_points_by_type: HashMap>, + pub entry_points_by_type: EntryPointsByType, pub hints: HashMap, pub compiler_version: CompilerVersion, bytecode_segment_lengths: NestedIntList, @@ -472,6 +415,12 @@ impl EntryPointV1 { } } +impl HasSelector for EntryPointV1 { + fn selector(&self) -> &EntryPointSelector { + &self.selector + } +} + impl TryFrom for ContractClassV1 { type Error = ProgramError; @@ -513,19 +462,11 @@ impl TryFrom for ContractClassV1 { instruction_locations, )?; - let mut entry_points_by_type = HashMap::new(); - entry_points_by_type.insert( - EntryPointType::Constructor, - convert_entry_points_v1(&class.entry_points_by_type.constructor), - ); - entry_points_by_type.insert( - EntryPointType::External, - convert_entry_points_v1(&class.entry_points_by_type.external), - ); - entry_points_by_type.insert( - EntryPointType::L1Handler, - convert_entry_points_v1(&class.entry_points_by_type.l1_handler), - ); + let entry_points_by_type = EntryPointsByType { + constructor: convert_entry_points_v1(&class.entry_points_by_type.constructor), + external: convert_entry_points_v1(&class.entry_points_by_type.external), + l1_handler: convert_entry_points_v1(&class.entry_points_by_type.l1_handler), + }; let bytecode_segment_lengths = class .bytecode_segment_lengths .unwrap_or_else(|| NestedIntList::Leaf(program.data_len())); @@ -682,18 +623,15 @@ impl NativeContractClassV1 { /// Returns an entry point into the natively compiled contract. pub fn get_entry_point(&self, call: &CallEntryPoint) -> Result { - match get_entry_point(&ContractClass::V1Native(self.clone()), call)? { - Cairo1EntryPoint::Native(entry_point) => Ok(entry_point.function_id), - Cairo1EntryPoint::Casm(_) => panic!("Unexpected entry point type."), - } + self.entry_points_by_type.get_entry_point(call).map(|ep| ep.function_id) } } #[derive(Debug)] pub struct NativeContractClassV1Inner { pub executor: AotNativeExecutor, - entry_points_by_type: NativeContractEntryPoints, - // Storing the raw sierra program and entry points to be able to compare the contract class + entry_points_by_type: EntryPointsByType, + // Storing the raw sierra program and entry points to be able to compare the contract class. sierra_program: Vec, } @@ -701,7 +639,7 @@ impl NativeContractClassV1Inner { fn new(executor: AotNativeExecutor, sierra_contract_class: SierraContractClass) -> Self { NativeContractClassV1Inner { executor, - entry_points_by_type: NativeContractEntryPoints::from(&sierra_contract_class), + entry_points_by_type: EntryPointsByType::from(&sierra_contract_class), sierra_program: sierra_contract_class.sierra_program, } } @@ -718,16 +656,49 @@ impl PartialEq for NativeContractClassV1Inner { impl Eq for NativeContractClassV1Inner {} -#[derive(Debug, PartialEq)] -/// Modelled after [cairo_lang_starknet_classes::contract_class::ContractEntryPoints] -/// and enriched with information for the Cairo Native ABI. -struct NativeContractEntryPoints { - constructor: Vec, - external: Vec, - l1_handler: Vec, +// TODO(Yoni): organize this file. +#[derive(Clone, Debug, Default, Eq, PartialEq)] +/// Modelled after [cairo_lang_starknet_classes::contract_class::ContractEntryPoints]. +pub struct EntryPointsByType { + constructor: Vec, + external: Vec, + l1_handler: Vec, +} + +impl EntryPointsByType { + pub fn get_entry_point(&self, call: &CallEntryPoint) -> Result { + call.verify_constructor()?; + + let entry_points_of_same_type = &self[call.entry_point_type]; + let filtered_entry_points: Vec<_> = entry_points_of_same_type + .iter() + .filter(|ep| *ep.selector() == call.entry_point_selector) + .collect(); + + match filtered_entry_points[..] { + [] => Err(PreExecutionError::EntryPointNotFound(call.entry_point_selector)), + [entry_point] => Ok(entry_point.clone()), + _ => Err(PreExecutionError::DuplicatedEntryPointSelector { + selector: call.entry_point_selector, + typ: call.entry_point_type, + }), + } + } } -impl From<&SierraContractClass> for NativeContractEntryPoints { +impl Index for EntryPointsByType { + type Output = Vec; + + fn index(&self, index: EntryPointType) -> &Self::Output { + match index { + EntryPointType::Constructor => &self.constructor, + EntryPointType::External => &self.external, + EntryPointType::L1Handler => &self.l1_handler, + } + } +} + +impl From<&SierraContractClass> for EntryPointsByType { fn from(sierra_contract_class: &SierraContractClass) -> Self { let program = sierra_contract_class.extract_sierra_program().expect("Can't get sierra program."); @@ -736,7 +707,7 @@ impl From<&SierraContractClass> for NativeContractEntryPoints { let entry_points_by_type = &sierra_contract_class.entry_points_by_type; - NativeContractEntryPoints { + EntryPointsByType:: { constructor: sierra_eps_to_native_eps(&func_ids, &entry_points_by_type.constructor), external: sierra_eps_to_native_eps(&func_ids, &entry_points_by_type.external), l1_handler: sierra_eps_to_native_eps(&func_ids, &entry_points_by_type.l1_handler), @@ -744,18 +715,6 @@ impl From<&SierraContractClass> for NativeContractEntryPoints { } } -impl Index for NativeContractEntryPoints { - type Output = Vec; - - fn index(&self, index: EntryPointType) -> &Self::Output { - match index { - EntryPointType::Constructor => &self.constructor, - EntryPointType::External => &self.external, - EntryPointType::L1Handler => &self.l1_handler, - } - } -} - fn sierra_eps_to_native_eps( func_ids: &[&FunctionId], sierra_eps: &[SierraContractEntryPoint], @@ -763,6 +722,10 @@ fn sierra_eps_to_native_eps( sierra_eps.iter().map(|sierra_ep| NativeEntryPoint::from(func_ids, sierra_ep)).collect() } +pub trait HasSelector { + fn selector(&self) -> &EntryPointSelector; +} + #[derive(Clone, Debug, PartialEq)] /// Provides a relation between a function in a contract and a compiled contract. pub struct NativeEntryPoint { @@ -781,3 +744,9 @@ impl NativeEntryPoint { } } } + +impl HasSelector for NativeEntryPoint { + fn selector(&self) -> &EntryPointSelector { + &self.selector + } +} diff --git a/crates/blockifier/src/test_utils/contracts.rs b/crates/blockifier/src/test_utils/contracts.rs index cb76b26968..11b95acfc4 100644 --- a/crates/blockifier/src/test_utils/contracts.rs +++ b/crates/blockifier/src/test_utils/contracts.rs @@ -18,6 +18,7 @@ use strum_macros::EnumIter; use crate::abi::abi_utils::selector_from_name; use crate::abi::constants::CONSTRUCTOR_ENTRY_POINT_NAME; use crate::execution::contract_class::{ContractClass, ContractClassV0, ContractClassV1}; +use crate::execution::entry_point::CallEntryPoint; use crate::test_utils::cairo_compile::{cairo0_compile, cairo1_compile}; use crate::test_utils::{get_raw_contract_class, CairoVersion}; @@ -329,10 +330,11 @@ impl FeatureContract { ContractClass::V1(class) => { class .entry_points_by_type - .get(&entry_point_type) - .unwrap() - .iter() - .find(|ep| ep.selector == entry_point_selector) + .get_entry_point(&CallEntryPoint { + entry_point_type, + entry_point_selector, + ..Default::default() + }) .unwrap() .offset }