From 7644b4d55892efadba6feb8d0448ba02efb52569 Mon Sep 17 00:00:00 2001 From: Arni Hod Date: Thu, 11 Jul 2024 10:43:50 +0300 Subject: [PATCH] refactor: gateway compiler handle declare tx --- Cargo.lock | 1 + crates/gateway/src/compilation.rs | 100 ++++++++++-------- crates/gateway/src/compilation_test.rs | 20 ++-- crates/gateway/src/gateway.rs | 2 +- .../stateful_transaction_validator_test.rs | 2 +- crates/mempool_test_utils/Cargo.toml | 1 + .../src/starknet_api_test_utils.rs | 16 ++- crates/starknet_sierra_compile/src/utils.rs | 10 +- 8 files changed, 87 insertions(+), 65 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 0f175579d..b1f78beb6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3421,6 +3421,7 @@ version = "0.0.0" dependencies = [ "assert_matches", "blockifier 0.8.0-rc.0 (git+https://github.com/starkware-libs/blockifier.git?rev=32191d41)", + "rstest", "serde_json", "starknet-types-core", "starknet_api", diff --git a/crates/gateway/src/compilation.rs b/crates/gateway/src/compilation.rs index 18dd58197..e1c4ea213 100644 --- a/crates/gateway/src/compilation.rs +++ b/crates/gateway/src/compilation.rs @@ -5,6 +5,7 @@ use blockifier::execution::contract_class::{ClassInfo, ContractClass, ContractCl use cairo_lang_starknet_classes::casm_contract_class::{ CasmContractClass, CasmContractEntryPoints, }; +use cairo_lang_starknet_classes::contract_class::ContractClass as CairoLangContractClass; use starknet_api::core::CompiledClassHash; use starknet_api::rpc_transaction::RPCDeclareTransaction; use starknet_sierra_compile::compile::compile_sierra_to_casm; @@ -19,6 +20,7 @@ use crate::utils::is_subsequence; #[path = "compilation_test.rs"] mod compilation_test; +// TODO(Arni): Pass the compiler with dependancy injection. #[derive(Clone)] pub struct GatewayCompiler { #[allow(dead_code)] @@ -29,64 +31,56 @@ impl GatewayCompiler { /// Formats the contract class for compilation, compiles it, and returns the compiled contract /// class wrapped in a [`ClassInfo`]. /// Assumes the contract class is of a Sierra program which is compiled to Casm. - pub fn compile_contract_class( + pub fn process_declare_tx( &self, declare_tx: &RPCDeclareTransaction, ) -> GatewayResult { let RPCDeclareTransaction::V3(tx) = declare_tx; - let starknet_api_contract_class = &tx.contract_class; - let cairo_lang_contract_class = - into_contract_class_for_compilation(starknet_api_contract_class); + let rpc_contract_class = &tx.contract_class; + let cairo_lang_contract_class = into_contract_class_for_compilation(rpc_contract_class); - // Compile Sierra to Casm. + let casm_contract_class = self.compile(cairo_lang_contract_class)?; + + validate_compiled_class_hash(&casm_contract_class, &tx.compiled_class_hash)?; + validate_casm_class(&casm_contract_class)?; + + Ok(ClassInfo::new( + &ContractClass::V1(ContractClassV1::try_from(casm_contract_class)?), + rpc_contract_class.sierra_program.len(), + rpc_contract_class.abi.len(), + )?) + } + + // TODO(Arni): Pass the compilation args from the config. + fn compile( + &self, + cairo_lang_contract_class: CairoLangContractClass, + ) -> Result { let catch_unwind_result = panic::catch_unwind(|| compile_sierra_to_casm(cairo_lang_contract_class)); - let casm_contract_class = match catch_unwind_result { - Ok(compilation_result) => compilation_result?, - Err(_) => { - // TODO(Arni): Log the panic. - return Err(GatewayError::CompilationError(CompilationUtilError::CompilationPanic)); - } - }; - self.validate_casm_class(&casm_contract_class)?; + let casm_contract_class = + catch_unwind_result.map_err(|_| CompilationUtilError::CompilationPanic)??; - let hash_result = CompiledClassHash(casm_contract_class.compiled_class_hash()); - if hash_result != tx.compiled_class_hash { - return Err(GatewayError::CompiledClassHashMismatch { - supplied: tx.compiled_class_hash, - hash_result, - }); - } - - // Convert Casm contract class to Starknet contract class directly. - let blockifier_contract_class = - ContractClass::V1(ContractClassV1::try_from(casm_contract_class)?); - let class_info = ClassInfo::new( - &blockifier_contract_class, - starknet_api_contract_class.sierra_program.len(), - starknet_api_contract_class.abi.len(), - )?; - Ok(class_info) + Ok(casm_contract_class) } +} - // TODO(Arni): Add test. - fn validate_casm_class(&self, contract_class: &CasmContractClass) -> Result<(), GatewayError> { - let CasmContractEntryPoints { external, l1_handler, constructor } = - &contract_class.entry_points_by_type; - let entry_points_iterator = - external.iter().chain(l1_handler.iter()).chain(constructor.iter()); +// TODO(Arni): Add test. +fn validate_casm_class(contract_class: &CasmContractClass) -> Result<(), GatewayError> { + let CasmContractEntryPoints { external, l1_handler, constructor } = + &contract_class.entry_points_by_type; + let entry_points_iterator = external.iter().chain(l1_handler.iter()).chain(constructor.iter()); - for entry_point in entry_points_iterator { - let builtins = &entry_point.builtins; - if !is_subsequence(builtins, supported_builtins()) { - return Err(GatewayError::UnsupportedBuiltins { - builtins: builtins.clone(), - supported_builtins: supported_builtins().to_vec(), - }); - } + for entry_point in entry_points_iterator { + let builtins = &entry_point.builtins; + if !is_subsequence(builtins, supported_builtins()) { + return Err(GatewayError::UnsupportedBuiltins { + builtins: builtins.clone(), + supported_builtins: supported_builtins().to_vec(), + }); } - Ok(()) } + Ok(()) } // TODO(Arni): Add to a config. @@ -101,3 +95,19 @@ fn supported_builtins() -> &'static Vec { SUPPORTED_BUILTIN_NAMES.iter().map(|builtin| builtin.to_string()).collect::>() }) } + +/// Validates that the compiled class hash of the compiled contract class matches the supplied +/// compiled class hash. +fn validate_compiled_class_hash( + casm_contract_class: &CasmContractClass, + supplied_compiled_class_hash: &CompiledClassHash, +) -> Result<(), GatewayError> { + let compiled_class_hash = CompiledClassHash(casm_contract_class.compiled_class_hash()); + if compiled_class_hash != *supplied_compiled_class_hash { + return Err(GatewayError::CompiledClassHashMismatch { + supplied: *supplied_compiled_class_hash, + hash_result: compiled_class_hash, + }); + } + Ok(()) +} diff --git a/crates/gateway/src/compilation_test.rs b/crates/gateway/src/compilation_test.rs index 22e81d10d..913f71857 100644 --- a/crates/gateway/src/compilation_test.rs +++ b/crates/gateway/src/compilation_test.rs @@ -15,23 +15,23 @@ fn gateway_compiler() -> GatewayCompiler { GatewayCompiler { config: Default::default() } } +// TODO(Arni): Redesign this test once the compiler is passed with dependancy injection. #[rstest] -fn test_compile_contract_class_compiled_class_hash_missmatch(gateway_compiler: GatewayCompiler) { +fn test_compile_contract_class_compiled_class_hash_mismatch(gateway_compiler: GatewayCompiler) { let mut tx = assert_matches!( declare_tx(), RPCTransaction::Declare(RPCDeclareTransaction::V3(tx)) => tx ); - let expected_hash_result = tx.compiled_class_hash; - let supplied_hash = CompiledClassHash::default(); - - tx.compiled_class_hash = supplied_hash; + let expected_hash = tx.compiled_class_hash; + let wrong_supplied_hash = CompiledClassHash::default(); + tx.compiled_class_hash = wrong_supplied_hash; let declare_tx = RPCDeclareTransaction::V3(tx); - let result = gateway_compiler.compile_contract_class(&declare_tx); + let result = gateway_compiler.process_declare_tx(&declare_tx); assert_matches!( result.unwrap_err(), GatewayError::CompiledClassHashMismatch { supplied, hash_result } - if supplied == supplied_hash && hash_result == expected_hash_result + if supplied == wrong_supplied_hash && hash_result == expected_hash ); } @@ -45,7 +45,7 @@ fn test_compile_contract_class_bad_sierra(gateway_compiler: GatewayCompiler) { tx.contract_class.sierra_program = tx.contract_class.sierra_program[..100].to_vec(); let declare_tx = RPCDeclareTransaction::V3(tx); - let result = gateway_compiler.compile_contract_class(&declare_tx); + let result = gateway_compiler.process_declare_tx(&declare_tx); assert_matches!( result.unwrap_err(), GatewayError::CompilationError(CompilationUtilError::AllowedLibfuncsError( @@ -55,7 +55,7 @@ fn test_compile_contract_class_bad_sierra(gateway_compiler: GatewayCompiler) { } #[rstest] -fn test_compile_contract_class(gateway_compiler: GatewayCompiler) { +fn test_process_declare_tx_success(gateway_compiler: GatewayCompiler) { let declare_tx = assert_matches!( declare_tx(), RPCTransaction::Declare(declare_tx) => declare_tx @@ -63,7 +63,7 @@ fn test_compile_contract_class(gateway_compiler: GatewayCompiler) { let RPCDeclareTransaction::V3(declare_tx_v3) = &declare_tx; let contract_class = &declare_tx_v3.contract_class; - let class_info = gateway_compiler.compile_contract_class(&declare_tx).unwrap(); + let class_info = gateway_compiler.process_declare_tx(&declare_tx).unwrap(); assert_matches!(class_info.contract_class(), ContractClass::V1(_)); assert_eq!(class_info.sierra_program_length(), contract_class.sierra_program.len()); assert_eq!(class_info.abi_length(), contract_class.abi.len()); diff --git a/crates/gateway/src/gateway.rs b/crates/gateway/src/gateway.rs index 9d1a53fdb..61c1b4764 100644 --- a/crates/gateway/src/gateway.rs +++ b/crates/gateway/src/gateway.rs @@ -128,7 +128,7 @@ fn process_tx( // Compile Sierra to Casm. let optional_class_info = match &tx { RPCTransaction::Declare(declare_tx) => { - Some(gateway_compiler.compile_contract_class(declare_tx)?) + Some(gateway_compiler.process_declare_tx(declare_tx)?) } _ => None, }; diff --git a/crates/gateway/src/stateful_transaction_validator_test.rs b/crates/gateway/src/stateful_transaction_validator_test.rs index 6a27fd350..09fc59fef 100644 --- a/crates/gateway/src/stateful_transaction_validator_test.rs +++ b/crates/gateway/src/stateful_transaction_validator_test.rs @@ -97,7 +97,7 @@ fn test_stateful_tx_validator( let optional_class_info = match &external_tx { RPCTransaction::Declare(declare_tx) => Some( GatewayCompiler { config: GatewayCompilerConfig {} } - .compile_contract_class(declare_tx) + .process_declare_tx(declare_tx) .unwrap(), ), _ => None, diff --git a/crates/mempool_test_utils/Cargo.toml b/crates/mempool_test_utils/Cargo.toml index 4385ec6cd..cb8a0d170 100644 --- a/crates/mempool_test_utils/Cargo.toml +++ b/crates/mempool_test_utils/Cargo.toml @@ -10,6 +10,7 @@ license.workspace = true [dependencies] assert_matches.workspace = true blockifier = { workspace = true, features = ["testing"] } +rstest.workspace = true starknet-types-core.workspace = true starknet_api.workspace = true serde_json.workspace = true diff --git a/crates/mempool_test_utils/src/starknet_api_test_utils.rs b/crates/mempool_test_utils/src/starknet_api_test_utils.rs index 187de758e..4f5cbce21 100644 --- a/crates/mempool_test_utils/src/starknet_api_test_utils.rs +++ b/crates/mempool_test_utils/src/starknet_api_test_utils.rs @@ -90,11 +90,21 @@ pub fn executable_resource_bounds_mapping() -> ResourceBoundsMapping { ) } -pub fn declare_tx() -> RPCTransaction { +/// Get the contract class used for testing. +pub fn contract_class() -> ContractClass { env::set_current_dir(get_absolute_path(TEST_FILES_FOLDER)).expect("Couldn't set working dir."); let json_file_path = Path::new(CONTRACT_CLASS_FILE); - let contract_class = serde_json::from_reader(File::open(json_file_path).unwrap()).unwrap(); - let compiled_class_hash = CompiledClassHash(felt!(COMPILED_CLASS_HASH_OF_CONTRACT_CLASS)); + serde_json::from_reader(File::open(json_file_path).unwrap()).unwrap() +} + +/// Get the compiled class hash corresponding to the contract class used for testing. +pub fn compiled_class_hash() -> CompiledClassHash { + CompiledClassHash(felt!(COMPILED_CLASS_HASH_OF_CONTRACT_CLASS)) +} + +pub fn declare_tx() -> RPCTransaction { + let contract_class = contract_class(); + let compiled_class_hash = compiled_class_hash(); let account_contract = FeatureContract::AccountWithoutValidations(CairoVersion::Cairo1); let account_address = account_contract.get_instance_address(0); diff --git a/crates/starknet_sierra_compile/src/utils.rs b/crates/starknet_sierra_compile/src/utils.rs index 717eaf176..5ccdcf9fa 100644 --- a/crates/starknet_sierra_compile/src/utils.rs +++ b/crates/starknet_sierra_compile/src/utils.rs @@ -6,7 +6,7 @@ use cairo_lang_starknet_classes::contract_class::{ }; use cairo_lang_utils::bigint::BigUintAsHex; use starknet_api::rpc_transaction::{ - ContractClass as StarknetApiContractClass, EntryPointByType as StarknetApiEntryPointByType, + ContractClass as RpcContractClass, EntryPointByType as StarknetApiEntryPointByType, }; use starknet_api::state::EntryPoint as StarknetApiEntryPoint; use starknet_types_core::felt::Felt; @@ -14,17 +14,17 @@ use starknet_types_core::felt::Felt; /// Retruns a [`CairoLangContractClass`] struct ready for Sierra to Casm compilation. Note the `abi` /// field is None as it is not relevant for the compilation. pub fn into_contract_class_for_compilation( - starknet_api_contract_class: &StarknetApiContractClass, + rpc_contract_class: &RpcContractClass, ) -> CairoLangContractClass { let sierra_program = - starknet_api_contract_class.sierra_program.iter().map(felt_to_big_uint_as_hex).collect(); + rpc_contract_class.sierra_program.iter().map(felt_to_big_uint_as_hex).collect(); let entry_points_by_type = - into_cairo_lang_contract_entry_points(&starknet_api_contract_class.entry_points_by_type); + into_cairo_lang_contract_entry_points(&rpc_contract_class.entry_points_by_type); CairoLangContractClass { sierra_program, sierra_program_debug_info: None, - contract_class_version: starknet_api_contract_class.contract_class_version.clone(), + contract_class_version: rpc_contract_class.contract_class_version.clone(), entry_points_by_type, abi: None, }