Skip to content

Commit

Permalink
chore(blockifier): share EntryPoint code between native and casm
Browse files Browse the repository at this point in the history
  • Loading branch information
Yoni-Starkware committed Oct 14, 2024
1 parent ad6c24a commit 10f1542
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 105 deletions.
171 changes: 70 additions & 101 deletions crates/blockifier/src/execution/contract_class.rs
Original file line number Diff line number Diff line change
Expand Up @@ -63,21 +63,6 @@ pub enum TrackedResource {
SierraGas, // AKA Sierra mode.
}

#[derive(Clone)]
pub enum Cairo1EntryPoint {
Casm(EntryPointV1),
Native(NativeEntryPoint),
}

impl Cairo1EntryPoint {
pub fn selector(&self) -> &EntryPointSelector {
match self {
Cairo1EntryPoint::Casm(ep) => &ep.selector,
Cairo1EntryPoint::Native(ep) => &ep.selector,
}
}
}

/// Represents a runnable Starknet contract class (meaning, the program is runnable by the VM).
#[derive(Clone, Debug, Eq, PartialEq, derive_more::From)]
pub enum ContractClass {
Expand All @@ -86,28 +71,6 @@ pub enum ContractClass {
V1Native(NativeContractClassV1),
}

pub fn get_entry_point(
contract_class: &ContractClass,
call: &CallEntryPoint,
) -> Result<Cairo1EntryPoint, PreExecutionError> {
call.verify_constructor()?;

let entry_points_of_same_type = contract_class.entry_points_of_same_type(call.entry_point_type);
let filtered_entry_points: Vec<_> = entry_points_of_same_type
.iter()
.filter(|ep| *ep.selector() == call.entry_point_selector)
.collect();

match &filtered_entry_points[..] {
[] => Err(PreExecutionError::EntryPointNotFound(call.entry_point_selector)),
[entry_point] => Ok((**entry_point).clone()),
_ => Err(PreExecutionError::DuplicatedEntryPointSelector {
selector: call.entry_point_selector,
typ: call.entry_point_type,
}),
}
}

impl TryFrom<RawContractClass> for ContractClass {
type Error = ProgramError;

Expand Down Expand Up @@ -144,23 +107,6 @@ impl ContractClass {
}
}

pub fn entry_points_of_same_type(
&self,
entry_point_type: EntryPointType,
) -> Vec<Cairo1EntryPoint> {
match self {
ContractClass::V0(_) => panic!("V0 contracts do not support entry points."),
ContractClass::V1(class) => class.entry_points_by_type[&entry_point_type]
.iter()
.map(|ep| Cairo1EntryPoint::Casm(ep.clone()))
.collect(),
ContractClass::V1Native(class) => class.entry_points_by_type[entry_point_type]
.iter()
.map(|ep| Cairo1EntryPoint::Native(ep.clone()))
.collect(),
}
}

pub fn get_visited_segments(
&self,
visited_pcs: &HashSet<usize>,
Expand Down Expand Up @@ -292,7 +238,7 @@ impl Deref for ContractClassV1 {

impl ContractClassV1 {
fn constructor_selector(&self) -> Option<EntryPointSelector> {
Some(self.0.entry_points_by_type[&EntryPointType::Constructor].first()?.selector)
self.0.entry_points_by_type.constructor.first().map(|ep| ep.selector)
}

pub fn bytecode_length(&self) -> usize {
Expand All @@ -307,10 +253,7 @@ impl ContractClassV1 {
&self,
call: &CallEntryPoint,
) -> Result<EntryPointV1, PreExecutionError> {
match get_entry_point(&ContractClass::V1(self.clone()), call)? {
Cairo1EntryPoint::Casm(entry_point) => Ok(entry_point),
Cairo1EntryPoint::Native(_) => panic!("Unexpected entry point type."),
}
self.entry_points_by_type.get_entry_point(call)
}

/// Returns whether this contract should run using Cairo steps or Sierra gas.
Expand Down Expand Up @@ -453,7 +396,7 @@ fn get_visited_segments(
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct ContractClassV1Inner {
pub program: Program,
pub entry_points_by_type: HashMap<EntryPointType, Vec<EntryPointV1>>,
pub entry_points_by_type: EntryPointsByType<EntryPointV1>,
pub hints: HashMap<String, Hint>,
pub compiler_version: CompilerVersion,
bytecode_segment_lengths: NestedIntList,
Expand All @@ -472,6 +415,12 @@ impl EntryPointV1 {
}
}

impl HasSelector for EntryPointV1 {
fn selector(&self) -> &EntryPointSelector {
&self.selector
}
}

impl TryFrom<CasmContractClass> for ContractClassV1 {
type Error = ProgramError;

Expand Down Expand Up @@ -513,19 +462,11 @@ impl TryFrom<CasmContractClass> for ContractClassV1 {
instruction_locations,
)?;

let mut entry_points_by_type = HashMap::new();
entry_points_by_type.insert(
EntryPointType::Constructor,
convert_entry_points_v1(&class.entry_points_by_type.constructor),
);
entry_points_by_type.insert(
EntryPointType::External,
convert_entry_points_v1(&class.entry_points_by_type.external),
);
entry_points_by_type.insert(
EntryPointType::L1Handler,
convert_entry_points_v1(&class.entry_points_by_type.l1_handler),
);
let entry_points_by_type = EntryPointsByType {
constructor: convert_entry_points_v1(&class.entry_points_by_type.constructor),
external: convert_entry_points_v1(&class.entry_points_by_type.external),
l1_handler: convert_entry_points_v1(&class.entry_points_by_type.l1_handler),
};
let bytecode_segment_lengths = class
.bytecode_segment_lengths
.unwrap_or_else(|| NestedIntList::Leaf(program.data_len()));
Expand Down Expand Up @@ -682,26 +623,23 @@ impl NativeContractClassV1 {

/// Returns an entry point into the natively compiled contract.
pub fn get_entry_point(&self, call: &CallEntryPoint) -> Result<FunctionId, PreExecutionError> {
match get_entry_point(&ContractClass::V1Native(self.clone()), call)? {
Cairo1EntryPoint::Native(entry_point) => Ok(entry_point.function_id),
Cairo1EntryPoint::Casm(_) => panic!("Unexpected entry point type."),
}
self.entry_points_by_type.get_entry_point(call).map(|ep| ep.function_id)
}
}

#[derive(Debug)]
pub struct NativeContractClassV1Inner {
pub executor: AotNativeExecutor,
entry_points_by_type: NativeContractEntryPoints,
// Storing the raw sierra program and entry points to be able to compare the contract class
entry_points_by_type: EntryPointsByType<NativeEntryPoint>,
// Storing the raw sierra program and entry points to be able to compare the contract class.
sierra_program: Vec<BigUintAsHex>,
}

impl NativeContractClassV1Inner {
fn new(executor: AotNativeExecutor, sierra_contract_class: SierraContractClass) -> Self {
NativeContractClassV1Inner {
executor,
entry_points_by_type: NativeContractEntryPoints::from(&sierra_contract_class),
entry_points_by_type: EntryPointsByType::from(&sierra_contract_class),
sierra_program: sierra_contract_class.sierra_program,
}
}
Expand All @@ -718,16 +656,49 @@ impl PartialEq for NativeContractClassV1Inner {

impl Eq for NativeContractClassV1Inner {}

#[derive(Debug, PartialEq)]
/// Modelled after [cairo_lang_starknet_classes::contract_class::ContractEntryPoints]
/// and enriched with information for the Cairo Native ABI.
struct NativeContractEntryPoints {
constructor: Vec<NativeEntryPoint>,
external: Vec<NativeEntryPoint>,
l1_handler: Vec<NativeEntryPoint>,
// TODO(Yoni): organize this file.
#[derive(Clone, Debug, Default, Eq, PartialEq)]
/// Modelled after [cairo_lang_starknet_classes::contract_class::ContractEntryPoints].
pub struct EntryPointsByType<EP: HasSelector> {
constructor: Vec<EP>,
external: Vec<EP>,
l1_handler: Vec<EP>,
}

impl<EP: Clone + HasSelector> EntryPointsByType<EP> {
pub fn get_entry_point(&self, call: &CallEntryPoint) -> Result<EP, PreExecutionError> {
call.verify_constructor()?;

let entry_points_of_same_type = &self[call.entry_point_type];
let filtered_entry_points: Vec<_> = entry_points_of_same_type
.iter()
.filter(|ep| *ep.selector() == call.entry_point_selector)
.collect();

match filtered_entry_points[..] {
[] => Err(PreExecutionError::EntryPointNotFound(call.entry_point_selector)),
[entry_point] => Ok(entry_point.clone()),
_ => Err(PreExecutionError::DuplicatedEntryPointSelector {
selector: call.entry_point_selector,
typ: call.entry_point_type,
}),
}
}
}

impl From<&SierraContractClass> for NativeContractEntryPoints {
impl<EP: HasSelector> Index<EntryPointType> for EntryPointsByType<EP> {
type Output = Vec<EP>;

fn index(&self, index: EntryPointType) -> &Self::Output {
match index {
EntryPointType::Constructor => &self.constructor,
EntryPointType::External => &self.external,
EntryPointType::L1Handler => &self.l1_handler,
}
}
}

impl From<&SierraContractClass> for EntryPointsByType<NativeEntryPoint> {
fn from(sierra_contract_class: &SierraContractClass) -> Self {
let program =
sierra_contract_class.extract_sierra_program().expect("Can't get sierra program.");
Expand All @@ -736,33 +707,25 @@ impl From<&SierraContractClass> for NativeContractEntryPoints {

let entry_points_by_type = &sierra_contract_class.entry_points_by_type;

NativeContractEntryPoints {
EntryPointsByType::<NativeEntryPoint> {
constructor: sierra_eps_to_native_eps(&func_ids, &entry_points_by_type.constructor),
external: sierra_eps_to_native_eps(&func_ids, &entry_points_by_type.external),
l1_handler: sierra_eps_to_native_eps(&func_ids, &entry_points_by_type.l1_handler),
}
}
}

impl Index<EntryPointType> for NativeContractEntryPoints {
type Output = Vec<NativeEntryPoint>;

fn index(&self, index: EntryPointType) -> &Self::Output {
match index {
EntryPointType::Constructor => &self.constructor,
EntryPointType::External => &self.external,
EntryPointType::L1Handler => &self.l1_handler,
}
}
}

fn sierra_eps_to_native_eps(
func_ids: &[&FunctionId],
sierra_eps: &[SierraContractEntryPoint],
) -> Vec<NativeEntryPoint> {
sierra_eps.iter().map(|sierra_ep| NativeEntryPoint::from(func_ids, sierra_ep)).collect()
}

pub trait HasSelector {
fn selector(&self) -> &EntryPointSelector;
}

#[derive(Clone, Debug, PartialEq)]
/// Provides a relation between a function in a contract and a compiled contract.
pub struct NativeEntryPoint {
Expand All @@ -781,3 +744,9 @@ impl NativeEntryPoint {
}
}
}

impl HasSelector for NativeEntryPoint {
fn selector(&self) -> &EntryPointSelector {
&self.selector
}
}
10 changes: 6 additions & 4 deletions crates/blockifier/src/test_utils/contracts.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ use strum_macros::EnumIter;
use crate::abi::abi_utils::selector_from_name;
use crate::abi::constants::CONSTRUCTOR_ENTRY_POINT_NAME;
use crate::execution::contract_class::{ContractClass, ContractClassV0, ContractClassV1};
use crate::execution::entry_point::CallEntryPoint;
use crate::test_utils::cairo_compile::{cairo0_compile, cairo1_compile};
use crate::test_utils::{get_raw_contract_class, CairoVersion};

Expand Down Expand Up @@ -329,10 +330,11 @@ impl FeatureContract {
ContractClass::V1(class) => {
class
.entry_points_by_type
.get(&entry_point_type)
.unwrap()
.iter()
.find(|ep| ep.selector == entry_point_selector)
.get_entry_point(&CallEntryPoint {
entry_point_type,
entry_point_selector,
..Default::default()
})
.unwrap()
.offset
}
Expand Down

0 comments on commit 10f1542

Please sign in to comment.