From bcefa055cb52f857f28c7a506280f232db2c7de7 Mon Sep 17 00:00:00 2001 From: Arni Hod Date: Sun, 25 Aug 2024 10:16:43 +0300 Subject: [PATCH] refactor: change try from casm contract class to be by ref to avoid clones --- .../src/execution/contract_class.rs | 42 +++++++++---------- .../src/transaction/account_transaction.rs | 8 ++-- .../src/transaction/transactions.rs | 8 ++-- crates/gateway/src/rpc_state_reader.rs | 2 +- crates/gateway/src/rpc_state_reader_test.rs | 2 +- .../src/stateful_transaction_validator.rs | 6 +-- .../src/py_block_executor_test.rs | 2 +- .../src/state_readers/papyrus_state.rs | 2 +- .../papyrus_execution/src/execution_utils.rs | 2 +- crates/papyrus_execution/src/lib.rs | 4 +- crates/papyrus_execution/src/state_reader.rs | 2 +- .../src/state_reader_test.rs | 6 +-- 12 files changed, 39 insertions(+), 47 deletions(-) diff --git a/crates/blockifier/src/execution/contract_class.rs b/crates/blockifier/src/execution/contract_class.rs index 26cc6ee46a7..67a03b246a5 100644 --- a/crates/blockifier/src/execution/contract_class.rs +++ b/crates/blockifier/src/execution/contract_class.rs @@ -52,10 +52,10 @@ pub enum ContractClass { V1(ContractClassV1), } -impl TryFrom for ContractClass { +impl TryFrom<&CasmContractClass> for ContractClass { type Error = ProgramError; - fn try_from(contract_class: CasmContractClass) -> Result { + fn try_from(contract_class: &CasmContractClass) -> Result { Ok(ContractClass::V1(contract_class.try_into()?)) } } @@ -241,7 +241,7 @@ impl ContractClassV1 { pub fn try_from_json_string(raw_contract_class: &str) -> Result { let casm_contract_class: CasmContractClass = serde_json::from_str(raw_contract_class)?; - let contract_class: ContractClassV1 = casm_contract_class.try_into()?; + let contract_class: ContractClassV1 = (&casm_contract_class).try_into()?; Ok(contract_class) } @@ -370,15 +370,12 @@ impl EntryPointV1 { } } -impl TryFrom for ContractClassV1 { +impl TryFrom<&CasmContractClass> for ContractClassV1 { type Error = ProgramError; - fn try_from(class: CasmContractClass) -> Result { - let data: Vec = class - .bytecode - .into_iter() - .map(|x| MaybeRelocatable::from(Felt::from(x.value))) - .collect(); + fn try_from(class: &CasmContractClass) -> Result { + let data: Vec = + class.bytecode.iter().map(|x| MaybeRelocatable::from(Felt::from(&x.value))).collect(); let mut hints: HashMap> = HashMap::new(); for (i, hint_list) in class.hints.iter() { @@ -417,19 +414,20 @@ impl TryFrom for ContractClassV1 { let mut entry_points_by_type = HashMap::new(); entry_points_by_type.insert( EntryPointType::Constructor, - convert_entry_points_v1(class.entry_points_by_type.constructor), + convert_entry_points_v1(&class.entry_points_by_type.constructor), ); entry_points_by_type.insert( EntryPointType::External, - convert_entry_points_v1(class.entry_points_by_type.external), + convert_entry_points_v1(&class.entry_points_by_type.external), ); entry_points_by_type.insert( EntryPointType::L1Handler, - convert_entry_points_v1(class.entry_points_by_type.l1_handler), + convert_entry_points_v1(&class.entry_points_by_type.l1_handler), ); let bytecode_segment_lengths = class .bytecode_segment_lengths + .clone() .unwrap_or_else(|| NestedIntList::Leaf(program.data_len())); Ok(Self(Arc::new(ContractClassV1Inner { @@ -466,16 +464,16 @@ fn hint_to_hint_params(hint: &cairo_lang_casm::hints::Hint) -> Result) -> Vec { +fn convert_entry_points_v1(external: &[CasmContractEntryPoint]) -> Vec { external - .into_iter() + .iter() .map(|ep| EntryPointV1 { - selector: EntryPointSelector(Felt::from(ep.selector)), + selector: EntryPointSelector(Felt::from(&ep.selector)), offset: EntryPointOffset(ep.offset), builtins: ep .builtins - .into_iter() - .map(|builtin| BuiltinName::from_str(&builtin).expect("Unrecognized builtin.")) + .iter() + .map(|builtin| BuiltinName::from_str(builtin).expect("Unrecognized builtin.")) .collect(), }) .collect() @@ -489,10 +487,10 @@ pub struct ClassInfo { abi_length: usize, } -impl TryFrom for ClassInfo { +impl TryFrom<&starknet_api::contract_class::ClassInfo> for ClassInfo { type Error = ProgramError; - fn try_from(class_info: starknet_api::contract_class::ClassInfo) -> Result { + fn try_from(class_info: &starknet_api::contract_class::ClassInfo) -> Result { let starknet_api::contract_class::ClassInfo { casm_contract_class, sierra_program_length, @@ -501,8 +499,8 @@ impl TryFrom for ClassInfo { Ok(Self { contract_class: casm_contract_class.try_into()?, - sierra_program_length, - abi_length, + sierra_program_length: *sierra_program_length, + abi_length: *abi_length, }) } } diff --git a/crates/blockifier/src/transaction/account_transaction.rs b/crates/blockifier/src/transaction/account_transaction.rs index 28cd17aa68a..dbe4546589d 100644 --- a/crates/blockifier/src/transaction/account_transaction.rs +++ b/crates/blockifier/src/transaction/account_transaction.rs @@ -77,11 +77,11 @@ pub enum AccountTransaction { Invoke(InvokeTransaction), } -impl TryFrom for AccountTransaction { +impl TryFrom<&starknet_api::executable_transaction::Transaction> for AccountTransaction { type Error = TransactionExecutionError; fn try_from( - executable_transaction: starknet_api::executable_transaction::Transaction, + executable_transaction: &starknet_api::executable_transaction::Transaction, ) -> Result { match executable_transaction { starknet_api::executable_transaction::Transaction::Declare(declare_tx) => { @@ -89,12 +89,12 @@ impl TryFrom for AccountTrans } starknet_api::executable_transaction::Transaction::DeployAccount(deploy_account_tx) => { Ok(Self::DeployAccount(DeployAccountTransaction { - tx: deploy_account_tx, + tx: deploy_account_tx.clone(), only_query: false, })) } starknet_api::executable_transaction::Transaction::Invoke(invoke_tx) => { - Ok(Self::Invoke(InvokeTransaction { tx: invoke_tx, only_query: false })) + Ok(Self::Invoke(InvokeTransaction { tx: invoke_tx.clone(), only_query: false })) } } } diff --git a/crates/blockifier/src/transaction/transactions.rs b/crates/blockifier/src/transaction/transactions.rs index b75965e8e71..fac8f1a5db2 100644 --- a/crates/blockifier/src/transaction/transactions.rs +++ b/crates/blockifier/src/transaction/transactions.rs @@ -135,11 +135,11 @@ pub struct DeclareTransaction { pub class_info: ClassInfo, } -impl TryFrom for DeclareTransaction { +impl TryFrom<&starknet_api::executable_transaction::DeclareTransaction> for DeclareTransaction { type Error = TransactionExecutionError; fn try_from( - declare_tx: starknet_api::executable_transaction::DeclareTransaction, + declare_tx: &starknet_api::executable_transaction::DeclareTransaction, ) -> Result { Self::new_from_executable_tx(declare_tx, false) } @@ -174,14 +174,14 @@ impl DeclareTransaction { } fn new_from_executable_tx( - declare_tx: starknet_api::executable_transaction::DeclareTransaction, + declare_tx: &starknet_api::executable_transaction::DeclareTransaction, only_query: bool, ) -> Result { let starknet_api::executable_transaction::DeclareTransaction { tx, tx_hash, class_info } = declare_tx; let class_info: ClassInfo = class_info.try_into()?; - Self::create(tx, tx_hash, class_info, only_query) + Self::create(tx.clone(), *tx_hash, class_info, only_query) } implement_inner_tx_getter_calls!( diff --git a/crates/gateway/src/rpc_state_reader.rs b/crates/gateway/src/rpc_state_reader.rs index 60fec240bdf..7b9fbe88952 100644 --- a/crates/gateway/src/rpc_state_reader.rs +++ b/crates/gateway/src/rpc_state_reader.rs @@ -135,7 +135,7 @@ impl BlockifierStateReader for RpcStateReader { serde_json::from_value(result).map_err(serde_err_to_state_err)?; match contract_class { CompiledContractClass::V1(contract_class_v1) => Ok(ContractClass::V1( - ContractClassV1::try_from(contract_class_v1).map_err(StateError::ProgramError)?, + ContractClassV1::try_from(&contract_class_v1).map_err(StateError::ProgramError)?, )), CompiledContractClass::V0(contract_class_v0) => Ok(ContractClass::V0( ContractClassV0::try_from(contract_class_v0).map_err(StateError::ProgramError)?, diff --git a/crates/gateway/src/rpc_state_reader_test.rs b/crates/gateway/src/rpc_state_reader_test.rs index af344104a08..828eed89eff 100644 --- a/crates/gateway/src/rpc_state_reader_test.rs +++ b/crates/gateway/src/rpc_state_reader_test.rs @@ -176,7 +176,7 @@ async fn test_get_compiled_contract_class() { .await .unwrap() .unwrap(); - assert_eq!(result, ContractClass::V1(CasmContractClass::default().try_into().unwrap())); + assert_eq!(result, ContractClass::try_from(&CasmContractClass::default()).unwrap()); mock.assert_async().await; } diff --git a/crates/gateway/src/stateful_transaction_validator.rs b/crates/gateway/src/stateful_transaction_validator.rs index 7ee942bcd5e..88ab66fcace 100644 --- a/crates/gateway/src/stateful_transaction_validator.rs +++ b/crates/gateway/src/stateful_transaction_validator.rs @@ -74,11 +74,7 @@ impl StatefulTransactionValidator { executable_tx: &ExecutableTransaction, mut validator: V, ) -> StatefulTransactionValidatorResult { - let account_tx = AccountTransaction::try_from( - // TODO(Arni): create a try_from for &ExecutableTransaction. - executable_tx.clone(), - ) - .map_err(|error| { + let account_tx = AccountTransaction::try_from(executable_tx).map_err(|error| { error!("Failed to convert executable transaction into account transaction: {}", error); GatewaySpecError::UnexpectedError { data: "Internal server error".to_owned() } })?; diff --git a/crates/native_blockifier/src/py_block_executor_test.rs b/crates/native_blockifier/src/py_block_executor_test.rs index b3cbca87c04..70e4dc8cbf9 100644 --- a/crates/native_blockifier/src/py_block_executor_test.rs +++ b/crates/native_blockifier/src/py_block_executor_test.rs @@ -20,7 +20,7 @@ use crate::test_utils::MockStorage; fn global_contract_cache_update() { // Initialize executor and set a contract class on the state. let casm = CasmContractClass::default(); - let contract_class = ContractClass::V1(ContractClassV1::try_from(casm.clone()).unwrap()); + let contract_class = ContractClass::V1(ContractClassV1::try_from(&casm).unwrap()); let class_hash = class_hash!("0x1"); let temp_storage_path = tempfile::tempdir().unwrap().into_path(); diff --git a/crates/native_blockifier/src/state_readers/papyrus_state.rs b/crates/native_blockifier/src/state_readers/papyrus_state.rs index 463d242b0cb..1bd8c36c931 100644 --- a/crates/native_blockifier/src/state_readers/papyrus_state.rs +++ b/crates/native_blockifier/src/state_readers/papyrus_state.rs @@ -63,7 +63,7 @@ impl PapyrusReader { inconsistent.", ); - return Ok(ContractClass::V1(ContractClassV1::try_from(casm_contract_class)?)); + return Ok(ContractClass::V1(ContractClassV1::try_from(&casm_contract_class)?)); } let v0_contract_class = self diff --git a/crates/papyrus_execution/src/execution_utils.rs b/crates/papyrus_execution/src/execution_utils.rs index 928728acabf..0ff0e25acc1 100644 --- a/crates/papyrus_execution/src/execution_utils.rs +++ b/crates/papyrus_execution/src/execution_utils.rs @@ -67,7 +67,7 @@ pub(crate) fn get_contract_class( return Err(ExecutionUtilsError::CasmTableNotSynced); }; return Ok(Some(BlockifierContractClass::V1( - ContractClassV1::try_from(casm).map_err(ExecutionUtilsError::ProgramError)?, + ContractClassV1::try_from(&casm).map_err(ExecutionUtilsError::ProgramError)?, ))); } None => {} diff --git a/crates/papyrus_execution/src/lib.rs b/crates/papyrus_execution/src/lib.rs index e4502626ffe..f6b1c30fa66 100644 --- a/crates/papyrus_execution/src/lib.rs +++ b/crates/papyrus_execution/src/lib.rs @@ -794,7 +794,7 @@ fn to_blockifier_tx( only_query, ) => { let class_v1 = BlockifierContractClass::V1( - compiled_class.try_into().map_err(BlockifierError::new)?, + (&compiled_class).try_into().map_err(BlockifierError::new)?, ); let class_info = ClassInfo::new(&class_v1, sierra_program_length, abi_length).map_err(|err| { @@ -821,7 +821,7 @@ fn to_blockifier_tx( only_query, ) => { let class_v1 = BlockifierContractClass::V1( - compiled_class.try_into().map_err(BlockifierError::new)?, + (&compiled_class).try_into().map_err(BlockifierError::new)?, ); let class_info = ClassInfo::new(&class_v1, sierra_program_length, abi_length).map_err(|err| { diff --git a/crates/papyrus_execution/src/state_reader.rs b/crates/papyrus_execution/src/state_reader.rs index a15df3c3c83..a103e6d4d23 100644 --- a/crates/papyrus_execution/src/state_reader.rs +++ b/crates/papyrus_execution/src/state_reader.rs @@ -85,7 +85,7 @@ impl BlockifierStateReader for ExecutionStateReader { .and_then(|pending_data| pending_data.classes.get_compiled_class(class_hash)) { return Ok(BlockifierContractClass::V1( - ContractClassV1::try_from(pending_casm).map_err(StateError::ProgramError)?, + ContractClassV1::try_from(&pending_casm).map_err(StateError::ProgramError)?, )); } if let Some(ApiContractClass::DeprecatedContractClass(pending_deprecated_class)) = self diff --git a/crates/papyrus_execution/src/state_reader_test.rs b/crates/papyrus_execution/src/state_reader_test.rs index 332eec48c91..72225a64eeb 100644 --- a/crates/papyrus_execution/src/state_reader_test.rs +++ b/crates/papyrus_execution/src/state_reader_test.rs @@ -49,8 +49,7 @@ fn read_state() { // The class is not used in the execution, so it can be default. let class0 = ContractClass::default(); let casm0 = get_test_casm(); - let blockifier_casm0 = - BlockifierContractClass::V1(ContractClassV1::try_from(casm0.clone()).unwrap()); + let blockifier_casm0 = BlockifierContractClass::V1(ContractClassV1::try_from(&casm0).unwrap()); let compiled_class_hash0 = CompiledClassHash(StarkHash::default()); let class_hash1 = ClassHash(1u128.into()); @@ -64,8 +63,7 @@ fn read_state() { let compiled_class_hash2 = CompiledClassHash(StarkHash::TWO); let mut casm1 = get_test_casm(); casm1.bytecode[0] = BigUintAsHex { value: 12345u32.into() }; - let blockifier_casm1 = - BlockifierContractClass::V1(ContractClassV1::try_from(casm1.clone()).unwrap()); + let blockifier_casm1 = BlockifierContractClass::V1(ContractClassV1::try_from(&casm1).unwrap()); let nonce1 = Nonce(felt!(2_u128)); let class_hash3 = ClassHash(567_u128.into()); let class_hash4 = ClassHash(89_u128.into());