From cb70f0d14bdc58ccb0cd0fe4878889b8b1b0d3b6 Mon Sep 17 00:00:00 2001 From: Aviv Greenburg Date: Sun, 15 Dec 2024 19:35:18 +0200 Subject: [PATCH] chore(blockifier): create unit tests for contract_class_manager --- .../src/state/contract_class_manager.rs | 4 +- .../src/state/contract_class_manager_test.rs | 279 ++++++++++++++++++ crates/blockifier/src/state/global_cache.rs | 2 +- crates/blockifier/src/test_utils/contracts.rs | 33 ++- 4 files changed, 312 insertions(+), 6 deletions(-) create mode 100644 crates/blockifier/src/state/contract_class_manager_test.rs diff --git a/crates/blockifier/src/state/contract_class_manager.rs b/crates/blockifier/src/state/contract_class_manager.rs index 8f73e2a335..171afdfd9f 100644 --- a/crates/blockifier/src/state/contract_class_manager.rs +++ b/crates/blockifier/src/state/contract_class_manager.rs @@ -30,6 +30,9 @@ use crate::state::global_cache::{CachedCasm, ContractCaches}; pub const DEFAULT_COMPILATION_REQUEST_CHANNEL_SIZE: usize = 1000; +#[cfg(test)] +#[path = "contract_class_manager_test.rs"] +mod contract_class_manager_test; /// Represents a request to compile a sierra contract class to a native compiled class. /// /// # Fields: @@ -65,7 +68,6 @@ impl ContractClassManager { /// 2. `config.run_cairo_native` is `false`. /// 3. `config.wait_on_native_compilation` is `true`. pub fn start(config: ContractClassManagerConfig) -> ContractClassManager { - // TODO(Avi, 15/12/2024): Add the size of the channel to the config. let contract_caches = ContractCaches::new(config.contract_cache_size); #[cfg(not(feature = "cairo_native"))] return ContractClassManager { contract_caches }; diff --git a/crates/blockifier/src/state/contract_class_manager_test.rs b/crates/blockifier/src/state/contract_class_manager_test.rs new file mode 100644 index 0000000000..55cfabf06c --- /dev/null +++ b/crates/blockifier/src/state/contract_class_manager_test.rs @@ -0,0 +1,279 @@ +#[cfg(feature = "cairo_native")] +use std::sync::mpsc::sync_channel; +#[cfg(feature = "cairo_native")] +use std::sync::Arc; + +#[cfg(feature = "cairo_native")] +use rstest::rstest; +#[cfg(feature = "cairo_native")] +use starknet_sierra_compile::command_line_compiler::CommandLineCompiler; +#[cfg(feature = "cairo_native")] +use starknet_sierra_compile::config::SierraToCasmCompilationConfig; + +#[cfg(feature = "cairo_native")] +use crate::blockifier::config::ContractClassManagerConfig; +// use crate::concurrency::test_utils::class_hash; +#[cfg(feature = "cairo_native")] +use crate::execution::contract_class::RunnableCompiledClass; +#[cfg(all(test, feature = "cairo_native"))] +use crate::execution::native::contract_class::NativeCompiledClassV1; +#[cfg(feature = "cairo_native")] +use crate::state::contract_class_manager::process_compilation_request; +#[cfg(all(test, feature = "cairo_native"))] +use crate::state::contract_class_manager::CompilationRequest; +#[cfg(feature = "cairo_native")] +use crate::state::contract_class_manager::ContractClassManager; +#[cfg(feature = "cairo_native")] +use crate::state::global_cache::{CachedCairoNative, ContractCaches}; +#[cfg(feature = "cairo_native")] +use crate::test_utils::contracts::FeatureContract; +#[cfg(feature = "cairo_native")] +use crate::test_utils::{CairoVersion, RunnableCairo1}; + +#[cfg(feature = "cairo_native")] +type TestRequestWithNative = (CompilationRequest, NativeCompiledClassV1); +#[cfg(feature = "cairo_native")] +const TEST_CHANNEL_SIZE: usize = 10; + +#[cfg(feature = "cairo_native")] +#[rstest] +fn test_start( + #[values(true, false)] run_cairo_native: bool, + #[values(true, false)] wait_on_native_compilation: bool, +) { + let config = ContractClassManagerConfig { + run_cairo_native, + wait_on_native_compilation, + ..Default::default() + }; + let manager = ContractClassManager::start(config.clone()); + // TODO(AvivG): any constraints on initial caches? should start empty? + // TODO(AvivG): any checks for not cairo_native? #[cfg(not(feature = "cairo_native"))] + + assert_eq!(manager.config, config); + if !run_cairo_native | wait_on_native_compilation { + assert!(manager.sender.is_none(), "Sender should be None"); + } else { + assert!(manager.sender.is_some(), "Sender should be Some"); + } + if !run_cairo_native | !wait_on_native_compilation { + assert!(manager.compiler.is_none(), "Compiler should be None"); + } else { + // TODO(AvivG): any constraints on initial compiler? + assert!(manager.compiler.is_some(), "Compiler should be Some"); + } +} + +#[cfg(feature = "cairo_native")] +#[test] +#[should_panic] +fn test_send_compilation_request_channel_disconnected() { + let config = ContractClassManagerConfig { run_cairo_native: true, ..Default::default() }; + let contract_caches = ContractCaches::new(config.contract_cache_size); + let (sender, receiver) = sync_channel(config.channel_size); + drop(receiver); + let manager = + ContractClassManager { config, contract_caches, sender: Some(sender), compiler: None }; + + let request = create_test_request(); + // TODO(AvivG): add massage: Expected panic when sending request with disconnected channel + manager.send_compilation_request(request); +} + +#[cfg(feature = "cairo_native")] +#[test] +fn test_send_compilation_request_channel_full() { + let config = ContractClassManagerConfig { + run_cairo_native: true, + channel_size: 1, + ..Default::default() + }; + let manager = ContractClassManager::start(config); + let request = create_test_request(); + let second_request = create_test_request(); + + // Fill the channel (it can only hold 1 message) + manager.send_compilation_request(request); + // Should log an error without panicking + manager.send_compilation_request(second_request); + // TODO(AvivG): how to add massage? : "Should not panic when channel is full."; +} + +#[cfg(feature = "cairo_native")] +#[test] +fn test_send_compilation_request_wait_on_native() { + let config = ContractClassManagerConfig { + run_cairo_native: true, + wait_on_native_compilation: true, + ..Default::default() + }; + let manager = ContractClassManager::start(config); + let (request, native) = create_test_request_with_native(); + let class_hash = request.0; + manager.send_compilation_request(request); + assert_eq!( + manager.get_native(&class_hash), + Some(CachedCairoNative::Compiled(native)), + "Cached Native class should match the expected result" + ); +} + +#[cfg(feature = "cairo_native")] +#[test] +#[should_panic] +fn test_send_compilation_request_run_cairo_native_false() { + let config = ContractClassManagerConfig { + run_cairo_native: false, + wait_on_native_compilation: true, + ..Default::default() + }; + let manager = ContractClassManager::start(config); + let request = create_test_request(); + manager.send_compilation_request(request); + // TODO (AvivG): add massage: Expected panic when sending request with run_cairo_native false +} + +#[cfg(feature = "cairo_native")] +#[rstest] +#[case::success(create_test_request_with_native(), CachedCairoNative::Compiled(create_test_request_with_native().1))] +#[case::failure(create_faulty_test_request(), CachedCairoNative::CompilationFailed)] +fn test_process_compilation_request( + #[case] request_w_native: TestRequestWithNative, + #[case] expected_cache: CachedCairoNative, +) { + // let manager = create_test_contract_class_manager(TEST_CHANNEL_SIZE); + let config = ContractClassManagerConfig { + run_cairo_native: true, + channel_size: TEST_CHANNEL_SIZE, + wait_on_native_compilation: true, + ..Default::default() + }; + let manager = ContractClassManager::start(config); + let (request, _native) = request_w_native; + let compiler_config = SierraToCasmCompilationConfig::default(); + let compiler = Arc::new(CommandLineCompiler::new(compiler_config)); + process_compilation_request(manager.contract_caches.clone(), compiler.clone(), request.clone()); + + let cached_native = manager.get_native(&request.0); + assert_eq!( + cached_native, + Some(expected_cache), + "Cached Native class should match the expected result" + ); +} + +#[cfg(feature = "cairo_native")] +#[rstest] +#[case(false, "Sender should be None when native compilation is disabled")] +#[case(true, "Sender should be Some when native compilation is enabled")] +fn test_sender_with_native_compilation(#[case] run_cairo_native: bool, #[case] message: &str) { + let config = ContractClassManagerConfig { run_cairo_native, ..Default::default() }; + let manager = ContractClassManager::start(config); + + if run_cairo_native { + assert!(manager.sender.is_some(), "{}", message); + assert!( + manager.sender.as_ref().unwrap().try_send(create_test_request()).is_ok(), + "Sender should be able to send a request successfully" + ); + } else { + assert!(manager.sender.is_none(), "{}", message); + } +} + +#[cfg(all(test, feature = "cairo_native"))] +fn create_faulty_test_request() -> TestRequestWithNative { + let test_contract = FeatureContract::TestContract(CairoVersion::Cairo1(RunnableCairo1::Native)); + create_faulty_request(test_contract) +} + +#[cfg(all(test, feature = "cairo_native"))] +fn create_faulty_request(test_contract: FeatureContract) -> TestRequestWithNative { + let class_hash = test_contract.get_class_hash(); + let casm = test_contract.get_casm(); + let mut sierra = test_contract.get_sierra(); + // Truncate the sierra program to trigger an error. + sierra.sierra_program = sierra.sierra_program[..100].to_vec(); + + let request = (class_hash, Arc::new(sierra), casm); + + (request, get_native(test_contract)) +} + +#[cfg(feature = "cairo_native")] +fn create_test_request_from_contract(test_contract: FeatureContract) -> CompilationRequest { + let class_hash = test_contract.get_class_hash(); + let sierra = Arc::new(test_contract.get_sierra()); + let casm = test_contract.get_casm(); + + (class_hash, sierra, casm) +} + +#[cfg(all(test, feature = "cairo_native"))] +fn create_test_request() -> CompilationRequest { + // Question (AvivG): are we interested in testing other contracts? + let test_contract = FeatureContract::TestContract(CairoVersion::Cairo1(RunnableCairo1::Native)); + create_test_request_from_contract(test_contract) +} + +#[cfg(all(test, feature = "cairo_native"))] +fn get_native(test_contract: FeatureContract) -> NativeCompiledClassV1 { + match test_contract.get_runnable_class() { + RunnableCompiledClass::V1Native(native) => native, + _ => panic!("Expected NativeCompiledClassV1"), + } +} + +#[cfg(all(test, feature = "cairo_native"))] +fn create_test_request_with_native() -> TestRequestWithNative { + let test_contract = FeatureContract::TestContract(CairoVersion::Cairo1(RunnableCairo1::Native)); + let request = create_test_request_from_contract(test_contract); + let native = get_native(test_contract); + + (request, native) +} + +// TODO (AvivG): Add tests for: +// getters? +// setters? +// clear +// process_compilation_request +// run_compilation_worker + +// #[test] +// fn test_get_casm() { +// let config = ContractClassManagerConfig { +// run_cairo_native: false, +// ..Default::default() +// }; +// let manager = ContractClassManager::start(config); +// let class_hash = ClassHash::default(); +// assert!(manager.get_casm(&class_hash).is_none()); +// } + +// #[test] +// fn test_set_and_get_casm() { +// let config = ContractClassManagerConfig { +// run_cairo_native: false, +// ..Default::default() +// }; +// let manager = ContractClassManager::start(config); +// let class_hash = ClassHash::default(); +// let compiled_class = CachedCasm::default(); +// manager.set_casm(class_hash, compiled_class.clone()); +// assert_eq!(manager.get_casm(&class_hash), Some(compiled_class)); +// } + +// #[test] +// fn test_clear_cache() { +// let config = ContractClassManagerConfig { +// run_cairo_native: false, +// ..Default::default() +// }; +// let mut manager = ContractClassManager::start(config); +// let class_hash = ClassHash::default(); +// let compiled_class = CachedCasm::default(); +// manager.set_casm(class_hash, compiled_class); +// manager.clear(); +// assert!(manager.get_casm(&class_hash).is_none()); +// } diff --git a/crates/blockifier/src/state/global_cache.rs b/crates/blockifier/src/state/global_cache.rs index dd5c5f8a64..20780805c0 100644 --- a/crates/blockifier/src/state/global_cache.rs +++ b/crates/blockifier/src/state/global_cache.rs @@ -30,7 +30,7 @@ impl CachedCasm { } #[cfg(feature = "cairo_native")] -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq)] pub enum CachedCairoNative { Compiled(NativeCompiledClassV1), CompilationFailed, diff --git a/crates/blockifier/src/test_utils/contracts.rs b/crates/blockifier/src/test_utils/contracts.rs index 6bec15ea0f..52d87a046d 100644 --- a/crates/blockifier/src/test_utils/contracts.rs +++ b/crates/blockifier/src/test_utils/contracts.rs @@ -17,7 +17,7 @@ use starknet_types_core::felt::Felt; use strum::IntoEnumIterator; use strum_macros::EnumIter; -use crate::execution::contract_class::RunnableCompiledClass; +use crate::execution::contract_class::{CompiledClassV1, RunnableCompiledClass}; use crate::execution::entry_point::CallEntryPoint; #[cfg(feature = "cairo_native")] use crate::execution::native::contract_class::NativeCompiledClassV1; @@ -180,6 +180,28 @@ impl FeatureContract { } } + pub fn get_casm(&self) -> CompiledClassV1 { + // Question (AvivG) : what is the desired behaviour? + if *self == Self::ERC20(self.cairo_version()) { + todo!("ERC20 cannot be tested with Native") + }; + match self.cairo_version() { + CairoVersion::Cairo0 => { + panic!("Casm contracts are only available for Cairo1."); + } + CairoVersion::Cairo1(_) => { + let compiled_path = format!( + "feature_contracts/cairo{}/{}{}.json", + "1/compiled", + self.get_non_erc20_base_name(), + ".casm" + ); + let contact_class = CasmContractClass::from_file(&compiled_path); + (contact_class, self.get_sierra_version()).try_into().unwrap() + } + } + } + pub fn get_runnable_class(&self) -> RunnableCompiledClass { #[cfg(feature = "cairo_native")] if CairoVersion::Cairo1(RunnableCairo1::Native) == self.cairo_version() { @@ -199,10 +221,13 @@ impl FeatureContract { get_raw_contract_class(&self.get_sierra_path()) } - pub fn get_sierra(&self) -> SierraContractClass { + pub fn get_contract_class(&self) -> CairoLangContractClass { let raw_sierra = self.get_raw_sierra(); - let cairo_contract_class: CairoLangContractClass = - serde_json::from_str(&raw_sierra).unwrap(); + serde_json::from_str(&raw_sierra).unwrap() + } + + pub fn get_sierra(&self) -> SierraContractClass { + let cairo_contract_class = self.get_contract_class(); SierraContractClass::from(cairo_contract_class) }