Skip to content

Commit

Permalink
refactor: change try from casm contract class to be by ref to avoid c…
Browse files Browse the repository at this point in the history
…lones
  • Loading branch information
ArniStarkware committed Aug 29, 2024
1 parent 5f0331e commit 82db51b
Show file tree
Hide file tree
Showing 12 changed files with 39 additions and 47 deletions.
42 changes: 20 additions & 22 deletions crates/blockifier/src/execution/contract_class.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,10 @@ pub enum ContractClass {
V1(ContractClassV1),
}

impl TryFrom<CasmContractClass> for ContractClass {
impl TryFrom<&CasmContractClass> for ContractClass {
type Error = ProgramError;

fn try_from(contract_class: CasmContractClass) -> Result<Self, Self::Error> {
fn try_from(contract_class: &CasmContractClass) -> Result<Self, Self::Error> {
Ok(ContractClass::V1(contract_class.try_into()?))
}
}
Expand Down Expand Up @@ -241,7 +241,7 @@ impl ContractClassV1 {

pub fn try_from_json_string(raw_contract_class: &str) -> Result<ContractClassV1, ProgramError> {
let casm_contract_class: CasmContractClass = serde_json::from_str(raw_contract_class)?;
let contract_class: ContractClassV1 = casm_contract_class.try_into()?;
let contract_class: ContractClassV1 = (&casm_contract_class).try_into()?;

Ok(contract_class)
}
Expand Down Expand Up @@ -370,15 +370,12 @@ impl EntryPointV1 {
}
}

impl TryFrom<CasmContractClass> for ContractClassV1 {
impl TryFrom<&CasmContractClass> for ContractClassV1 {
type Error = ProgramError;

fn try_from(class: CasmContractClass) -> Result<Self, Self::Error> {
let data: Vec<MaybeRelocatable> = class
.bytecode
.into_iter()
.map(|x| MaybeRelocatable::from(Felt::from(x.value)))
.collect();
fn try_from(class: &CasmContractClass) -> Result<Self, Self::Error> {
let data: Vec<MaybeRelocatable> =
class.bytecode.iter().map(|x| MaybeRelocatable::from(Felt::from(&x.value))).collect();

let mut hints: HashMap<usize, Vec<HintParams>> = HashMap::new();
for (i, hint_list) in class.hints.iter() {
Expand Down Expand Up @@ -417,19 +414,20 @@ impl TryFrom<CasmContractClass> for ContractClassV1 {
let mut entry_points_by_type = HashMap::new();
entry_points_by_type.insert(
EntryPointType::Constructor,
convert_entry_points_v1(class.entry_points_by_type.constructor),
convert_entry_points_v1(&class.entry_points_by_type.constructor),
);
entry_points_by_type.insert(
EntryPointType::External,
convert_entry_points_v1(class.entry_points_by_type.external),
convert_entry_points_v1(&class.entry_points_by_type.external),
);
entry_points_by_type.insert(
EntryPointType::L1Handler,
convert_entry_points_v1(class.entry_points_by_type.l1_handler),
convert_entry_points_v1(&class.entry_points_by_type.l1_handler),
);

let bytecode_segment_lengths = class
.bytecode_segment_lengths
.clone()
.unwrap_or_else(|| NestedIntList::Leaf(program.data_len()));

Ok(Self(Arc::new(ContractClassV1Inner {
Expand Down Expand Up @@ -466,16 +464,16 @@ fn hint_to_hint_params(hint: &cairo_lang_casm::hints::Hint) -> Result<HintParams
})
}

fn convert_entry_points_v1(external: Vec<CasmContractEntryPoint>) -> Vec<EntryPointV1> {
fn convert_entry_points_v1(external: &[CasmContractEntryPoint]) -> Vec<EntryPointV1> {
external
.into_iter()
.iter()
.map(|ep| EntryPointV1 {
selector: EntryPointSelector(Felt::from(ep.selector)),
selector: EntryPointSelector(Felt::from(&ep.selector)),
offset: EntryPointOffset(ep.offset),
builtins: ep
.builtins
.into_iter()
.map(|builtin| BuiltinName::from_str(&builtin).expect("Unrecognized builtin."))
.iter()
.map(|builtin| BuiltinName::from_str(builtin).expect("Unrecognized builtin."))
.collect(),
})
.collect()
Expand All @@ -489,10 +487,10 @@ pub struct ClassInfo {
abi_length: usize,
}

impl TryFrom<starknet_api::contract_class::ClassInfo> for ClassInfo {
impl TryFrom<&starknet_api::contract_class::ClassInfo> for ClassInfo {
type Error = ProgramError;

fn try_from(class_info: starknet_api::contract_class::ClassInfo) -> Result<Self, Self::Error> {
fn try_from(class_info: &starknet_api::contract_class::ClassInfo) -> Result<Self, Self::Error> {
let starknet_api::contract_class::ClassInfo {
casm_contract_class,
sierra_program_length,
Expand All @@ -501,8 +499,8 @@ impl TryFrom<starknet_api::contract_class::ClassInfo> for ClassInfo {

Ok(Self {
contract_class: casm_contract_class.try_into()?,
sierra_program_length,
abi_length,
sierra_program_length: *sierra_program_length,
abi_length: *abi_length,
})
}
}
Expand Down
8 changes: 4 additions & 4 deletions crates/blockifier/src/transaction/account_transaction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,24 +77,24 @@ pub enum AccountTransaction {
Invoke(InvokeTransaction),
}

impl TryFrom<starknet_api::executable_transaction::Transaction> for AccountTransaction {
impl TryFrom<&starknet_api::executable_transaction::Transaction> for AccountTransaction {
type Error = TransactionExecutionError;

fn try_from(
executable_transaction: starknet_api::executable_transaction::Transaction,
executable_transaction: &starknet_api::executable_transaction::Transaction,
) -> Result<Self, Self::Error> {
match executable_transaction {
starknet_api::executable_transaction::Transaction::Declare(declare_tx) => {
Ok(Self::Declare(declare_tx.try_into()?))
}
starknet_api::executable_transaction::Transaction::DeployAccount(deploy_account_tx) => {
Ok(Self::DeployAccount(DeployAccountTransaction {
tx: deploy_account_tx,
tx: deploy_account_tx.clone(),
only_query: false,
}))
}
starknet_api::executable_transaction::Transaction::Invoke(invoke_tx) => {
Ok(Self::Invoke(InvokeTransaction { tx: invoke_tx, only_query: false }))
Ok(Self::Invoke(InvokeTransaction { tx: invoke_tx.clone(), only_query: false }))
}
}
}
Expand Down
8 changes: 4 additions & 4 deletions crates/blockifier/src/transaction/transactions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,11 +135,11 @@ pub struct DeclareTransaction {
pub class_info: ClassInfo,
}

impl TryFrom<starknet_api::executable_transaction::DeclareTransaction> for DeclareTransaction {
impl TryFrom<&starknet_api::executable_transaction::DeclareTransaction> for DeclareTransaction {
type Error = TransactionExecutionError;

fn try_from(
declare_tx: starknet_api::executable_transaction::DeclareTransaction,
declare_tx: &starknet_api::executable_transaction::DeclareTransaction,
) -> Result<Self, Self::Error> {
Self::new_from_executable_tx(declare_tx, false)
}
Expand Down Expand Up @@ -174,14 +174,14 @@ impl DeclareTransaction {
}

fn new_from_executable_tx(
declare_tx: starknet_api::executable_transaction::DeclareTransaction,
declare_tx: &starknet_api::executable_transaction::DeclareTransaction,
only_query: bool,
) -> Result<Self, TransactionExecutionError> {
let starknet_api::executable_transaction::DeclareTransaction { tx, tx_hash, class_info } =
declare_tx;
let class_info: ClassInfo = class_info.try_into()?;

Self::create(tx, tx_hash, class_info, only_query)
Self::create(tx.clone(), *tx_hash, class_info, only_query)
}

implement_inner_tx_getter_calls!(
Expand Down
2 changes: 1 addition & 1 deletion crates/gateway/src/rpc_state_reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ impl BlockifierStateReader for RpcStateReader {
serde_json::from_value(result).map_err(serde_err_to_state_err)?;
match contract_class {
CompiledContractClass::V1(contract_class_v1) => Ok(ContractClass::V1(
ContractClassV1::try_from(contract_class_v1).map_err(StateError::ProgramError)?,
ContractClassV1::try_from(&contract_class_v1).map_err(StateError::ProgramError)?,
)),
CompiledContractClass::V0(contract_class_v0) => Ok(ContractClass::V0(
ContractClassV0::try_from(contract_class_v0).map_err(StateError::ProgramError)?,
Expand Down
2 changes: 1 addition & 1 deletion crates/gateway/src/rpc_state_reader_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ async fn test_get_compiled_contract_class() {
.await
.unwrap()
.unwrap();
assert_eq!(result, ContractClass::V1(CasmContractClass::default().try_into().unwrap()));
assert_eq!(result, ContractClass::try_from(&CasmContractClass::default()).unwrap());
mock.assert_async().await;
}

Expand Down
6 changes: 1 addition & 5 deletions crates/gateway/src/stateful_transaction_validator.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,7 @@ impl StatefulTransactionValidator {
executable_tx: &ExecutableTransaction,
mut validator: V,
) -> StatefulTransactionValidatorResult<ValidateInfo> {
let account_tx = AccountTransaction::try_from(
// TODO(Arni): create a try_from for &ExecutableTransaction.
executable_tx.clone(),
)
.map_err(|error| {
let account_tx = AccountTransaction::try_from(executable_tx).map_err(|error| {
error!("Failed to convert executable transaction into account transaction: {}", error);
GatewaySpecError::UnexpectedError { data: "Internal server error".to_owned() }
})?;
Expand Down
2 changes: 1 addition & 1 deletion crates/native_blockifier/src/py_block_executor_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ use crate::test_utils::MockStorage;
fn global_contract_cache_update() {
// Initialize executor and set a contract class on the state.
let casm = CasmContractClass::default();
let contract_class = ContractClass::V1(ContractClassV1::try_from(casm.clone()).unwrap());
let contract_class = ContractClass::V1(ContractClassV1::try_from(&casm).unwrap());
let class_hash = class_hash!("0x1");

let temp_storage_path = tempfile::tempdir().unwrap().into_path();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ impl PapyrusReader {
inconsistent.",
);

return Ok(ContractClass::V1(ContractClassV1::try_from(casm_contract_class)?));
return Ok(ContractClass::V1(ContractClassV1::try_from(&casm_contract_class)?));
}

let v0_contract_class = self
Expand Down
2 changes: 1 addition & 1 deletion crates/papyrus_execution/src/execution_utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ pub(crate) fn get_contract_class(
return Err(ExecutionUtilsError::CasmTableNotSynced);
};
return Ok(Some(BlockifierContractClass::V1(
ContractClassV1::try_from(casm).map_err(ExecutionUtilsError::ProgramError)?,
ContractClassV1::try_from(&casm).map_err(ExecutionUtilsError::ProgramError)?,
)));
}
None => {}
Expand Down
4 changes: 2 additions & 2 deletions crates/papyrus_execution/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -794,7 +794,7 @@ fn to_blockifier_tx(
only_query,
) => {
let class_v1 = BlockifierContractClass::V1(
compiled_class.try_into().map_err(BlockifierError::new)?,
(&compiled_class).try_into().map_err(BlockifierError::new)?,
);
let class_info =
ClassInfo::new(&class_v1, sierra_program_length, abi_length).map_err(|err| {
Expand All @@ -821,7 +821,7 @@ fn to_blockifier_tx(
only_query,
) => {
let class_v1 = BlockifierContractClass::V1(
compiled_class.try_into().map_err(BlockifierError::new)?,
(&compiled_class).try_into().map_err(BlockifierError::new)?,
);
let class_info =
ClassInfo::new(&class_v1, sierra_program_length, abi_length).map_err(|err| {
Expand Down
2 changes: 1 addition & 1 deletion crates/papyrus_execution/src/state_reader.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ impl BlockifierStateReader for ExecutionStateReader {
.and_then(|pending_data| pending_data.classes.get_compiled_class(class_hash))
{
return Ok(BlockifierContractClass::V1(
ContractClassV1::try_from(pending_casm).map_err(StateError::ProgramError)?,
ContractClassV1::try_from(&pending_casm).map_err(StateError::ProgramError)?,
));
}
if let Some(ApiContractClass::DeprecatedContractClass(pending_deprecated_class)) = self
Expand Down
6 changes: 2 additions & 4 deletions crates/papyrus_execution/src/state_reader_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,7 @@ fn read_state() {
// The class is not used in the execution, so it can be default.
let class0 = ContractClass::default();
let casm0 = get_test_casm();
let blockifier_casm0 =
BlockifierContractClass::V1(ContractClassV1::try_from(casm0.clone()).unwrap());
let blockifier_casm0 = BlockifierContractClass::V1(ContractClassV1::try_from(&casm0).unwrap());
let compiled_class_hash0 = CompiledClassHash(StarkHash::default());

let class_hash1 = ClassHash(1u128.into());
Expand All @@ -64,8 +63,7 @@ fn read_state() {
let compiled_class_hash2 = CompiledClassHash(StarkHash::TWO);
let mut casm1 = get_test_casm();
casm1.bytecode[0] = BigUintAsHex { value: 12345u32.into() };
let blockifier_casm1 =
BlockifierContractClass::V1(ContractClassV1::try_from(casm1.clone()).unwrap());
let blockifier_casm1 = BlockifierContractClass::V1(ContractClassV1::try_from(&casm1).unwrap());
let nonce1 = Nonce(felt!(2_u128));
let class_hash3 = ClassHash(567_u128.into());
let class_hash4 = ClassHash(89_u128.into());
Expand Down

0 comments on commit 82db51b

Please sign in to comment.