diff --git a/crates/blockifier/src/state/contract_class_manager.rs b/crates/blockifier/src/state/contract_class_manager.rs index ebbb7102cd7..a0fd7a60e1e 100644 --- a/crates/blockifier/src/state/contract_class_manager.rs +++ b/crates/blockifier/src/state/contract_class_manager.rs @@ -30,6 +30,9 @@ use crate::state::global_cache::{CachedCasm, ContractCaches}; pub const DEFAULT_COMPILATION_REQUEST_CHANNEL_SIZE: usize = 1000; +#[cfg(all(test, feature = "cairo_native"))] +#[path = "contract_class_manager_test.rs"] +mod contract_class_manager_test; /// Represents a request to compile a sierra contract class to a native compiled class. /// /// # Fields: @@ -57,52 +60,52 @@ pub struct ContractClassManager { } impl ContractClassManager { - /// Creates a new contract class manager and spawns a thread that listens for compilation - /// requests and processes them (a.k.a. the compilation worker). - /// Returns the contract class manager. - /// NOTE: the compilation worker is not spawned if one of the following conditions is met: - /// 1. The feature `cairo_native` is not enabled. - /// 2. `config.run_cairo_native` is `false`. - /// 3. `config.wait_on_native_compilation` is `true`. - pub fn start(config: ContractClassManagerConfig) -> ContractClassManager { - // TODO(Avi, 15/12/2024): Add the size of the channel to the config. - let contract_caches = ContractCaches::new(config.contract_cache_size); - #[cfg(not(feature = "cairo_native"))] - return ContractClassManager { contract_caches }; - #[cfg(feature = "cairo_native")] - { - if !config.run_cairo_native { - // Native compilation is disabled - no need to start the compilation worker. - return ContractClassManager { - config, - contract_caches, - sender: None, - compiler: None, - }; - } - - let compiler_config = SierraToCasmCompilationConfig::default(); - let compiler = Arc::new(CommandLineCompiler::new(compiler_config)); - if config.wait_on_native_compilation { - // Compilation requests are processed synchronously. No need to start the worker. - return ContractClassManager { - config, - contract_caches, - sender: None, - compiler: Some(compiler), - }; - } + /// Creates a new contract class manager and spawns a thread that listens for compilation + /// requests and processes them (a.k.a. the compilation worker). + /// Returns the contract class manager. + /// NOTE: the compilation worker is not spawned if one of the following conditions is met: + /// 1. The feature `cairo_native` is not enabled. + /// 2. `config.run_cairo_native` is `false`. + /// 3. `config.wait_on_native_compilation` is `true`. + pub fn start(config: ContractClassManagerConfig) -> ContractClassManager { + // TODO(Avi, 15/12/2024): Add the size of the channel to the config. + let contract_caches = ContractCaches::new(config.contract_cache_size); + #[cfg(not(feature = "cairo_native"))] + return ContractClassManager { contract_caches }; + #[cfg(feature = "cairo_native")] + { + if !config.run_cairo_native { + // Native compilation is disabled - no need to start the compilation worker. + return ContractClassManager { + config, + contract_caches, + sender: None, + compiler: None, + }; + } + + let compiler_config = SierraToCasmCompilationConfig::default(); + let compiler = Arc::new(CommandLineCompiler::new(compiler_config)); + if config.wait_on_native_compilation { + // Compilation requests are processed synchronously. No need to start the worker. + return ContractClassManager { + config, + contract_caches, + sender: None, + compiler: Some(compiler), + }; + } let (sender, receiver) = sync_channel(config.channel_size); - std::thread::spawn({ - let contract_caches = contract_caches.clone(); - move || run_compilation_worker(contract_caches, receiver, compiler) - }); + std::thread::spawn({ + let contract_caches = contract_caches.clone(); + move || run_compilation_worker(contract_caches, receiver, compiler) + }); - ContractClassManager { config, contract_caches, sender: Some(sender), compiler: None } + ContractClassManager { config, contract_caches, sender: Some(sender), compiler: None } + } } - } /// Sends a compilation request. Two cases: /// 1. If `config.wait_on_native_compilation` is `false`, sends the request to the compilation diff --git a/crates/blockifier/src/state/contract_class_manager_test.rs b/crates/blockifier/src/state/contract_class_manager_test.rs new file mode 100644 index 00000000000..e33821a4fd7 --- /dev/null +++ b/crates/blockifier/src/state/contract_class_manager_test.rs @@ -0,0 +1,202 @@ +// use std::sync::mpsc::sync_channel; +use std::sync::Arc; + +use crate::blockifier::config::ContractClassManagerConfig; +use crate::execution::contract_class::RunnableCompiledClass; +use crate::execution::native::contract_class::NativeCompiledClassV1; +use crate::state::contract_class_manager::{CompilationRequest, ContractClassManager}; +use crate::state::global_cache::{CachedCairoNative, ContractCaches}; +use crate::test_utils::contracts::FeatureContract; +use crate::test_utils::{CairoVersion, RunnableCairo1}; + +const TEST_CHANNEL_SIZE: usize = 10; + +fn create_faulty_test_request() -> CompilationRequest { + let test_contract = FeatureContract::TestContract(CairoVersion::Cairo1(RunnableCairo1::Native)); + + let mut contract_class = test_contract.get_contract_class(); + // Truncate the sierra program to trigger an error. + contract_class.sierra_program = contract_class.sierra_program[..100].to_vec(); + + let sierra = contract_class.into(); + let class_hash = test_contract.get_class_hash(); + let casm = test_contract.get_casm(); + + (class_hash, Arc::new(sierra), casm) +} + +fn create_test_contract_class_manager(channel_size: usize) -> ContractClassManager { + let config = + ContractClassManagerConfig { run_cairo_native: true, channel_size, ..Default::default() }; + + ContractClassManager::start(config) +} + +fn create_test_request_from_contract(test_contract: FeatureContract) -> CompilationRequest { + let class_hash = test_contract.get_class_hash(); + let sierra = Arc::new(test_contract.get_sierra()); + let casm = test_contract.get_casm(); + + (class_hash, sierra, casm) +} + +fn create_test_request() -> CompilationRequest { + // Question (AvivG): are we interested in testing other contracts? + let test_contract = FeatureContract::TestContract(CairoVersion::Cairo1(RunnableCairo1::Native)); + create_test_request_from_contract(test_contract) +} + +fn create_test_request_with_native() -> (CompilationRequest, NativeCompiledClassV1) { + let test_contract = FeatureContract::TestContract(CairoVersion::Cairo1(RunnableCairo1::Native)); + let request = create_test_request_from_contract(test_contract); + let native = match test_contract.get_runnable_class() { + RunnableCompiledClass::V1Native(native) => native, + _ => panic!("Expected NativeCompiledClassV1"), + }; + + (request, native) +} + +#[test] +fn test_sender_with_native_compilation_disabled() { + let config = ContractClassManagerConfig { run_cairo_native: false, ..Default::default() }; + let manager = ContractClassManager::start(config); + assert!(manager.sender.is_none(), "Sender should be None when native compilation is disabled"); +} + +#[test] +fn test_sender_with_native_compilation_enabled() { + let config = ContractClassManagerConfig { run_cairo_native: true, ..Default::default() }; + let manager = ContractClassManager::start(config); + assert!(manager.sender.is_some()); + + assert!( + manager.sender.as_ref().unwrap().try_send(create_test_request()).is_ok(), + "Sender should be able to send a request successfully" + ); +} + +#[test] +fn test_send_request_channel_disconnected() { + let config = ContractClassManagerConfig { run_cairo_native: true, ..Default::default() }; + let contract_caches = ContractCaches::new(config.contract_cache_size); + let manager = ContractClassManager { config, contract_caches, sender: None, compiler: None }; + + let request = create_test_request(); + let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { + manager.send_compilation_request(request); + })); + + assert!(result.is_err(), "Expected panic when sending request with disconnected channel"); +} + +#[test] +fn test_send_compilation_request_channel_full() { + let config = ContractClassManagerConfig { + run_cairo_native: true, + channel_size: 1, + ..Default::default() + }; + let manager = ContractClassManager::start(config); + let request = create_test_request(); + let second_request = create_test_request(); + + // Fill the channel (it can only hold 1 message) + manager.send_compilation_request(request); + let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| { + manager.send_compilation_request(second_request); + })); + + assert!(result.is_ok(), "Should not panic when channel is full."); +} + +#[test] +fn test_run_compilation_worker_success() { + let manager = create_test_contract_class_manager(TEST_CHANNEL_SIZE); + let (request, native) = create_test_request_with_native(); + + manager.sender.as_ref().unwrap().send(request.clone()).unwrap(); + + // Wait for the worker to process the request + // Question (AvivG): better to have a loop and try to get native every X mil sec? + std::thread::sleep(std::time::Duration::from_millis(50000)); + + let cached_native = manager.get_native(&request.0); + + assert!(cached_native.is_some(), "Native compiled class should exist in the cache"); + + match cached_native.unwrap() { + CachedCairoNative::Compiled(cached_class) => { + assert_eq!(cached_class, native, "Cached class should match the expected native class"); + } + CachedCairoNative::CompilationFailed => { + panic!("Expected CachedCairoNative::Compiled variant") + } + }; +} + +#[test] +fn test_run_compilation_worker_failure() { + let manager = create_test_contract_class_manager(TEST_CHANNEL_SIZE); + + let request = create_faulty_test_request(); + + manager.sender.as_ref().unwrap().send(request.clone()).unwrap(); + + // Wait for the worker to process the request + // Question (AvivG): better to have a loop and try to get native every X mil sec? + std::thread::sleep(std::time::Duration::from_millis(5000)); + + // Check if the compilation-failed variant was added to the cache + let cached_native = manager.get_native(&request.0); + assert_eq!( + cached_native, + Some(CachedCairoNative::CompilationFailed), + "Native compiled class should indicate compilation failure" + ); +} + +#[test] +fn test_channel_receiver_down_when_sender_dropped() { + // TODO (AvivG). +} + +// TODO (AvivG):test compilation logs. + +// #[test] +// fn test_get_casm() { +// let config = ContractClassManagerConfig { +// run_cairo_native: false, +// ..Default::default() +// }; +// let manager = ContractClassManager::start(config); +// let class_hash = ClassHash::default(); +// assert!(manager.get_casm(&class_hash).is_none()); +// } + +// #[test] +// fn test_set_and_get_casm() { +// let config = ContractClassManagerConfig { +// run_cairo_native: false, +// ..Default::default() +// }; +// let manager = ContractClassManager::start(config); +// let class_hash = ClassHash::default(); +// let compiled_class = CachedCasm::default(); +// manager.set_casm(class_hash, compiled_class.clone()); +// assert_eq!(manager.get_casm(&class_hash), Some(compiled_class)); +// } + +// #[test] +// fn test_clear_cache() { +// let config = ContractClassManagerConfig { +// run_cairo_native: false, +// ..Default::default() +// }; +// let mut manager = ContractClassManager::start(config); +// let class_hash = ClassHash::default(); +// let compiled_class = CachedCasm::default(); +// manager.set_casm(class_hash, compiled_class); +// manager.clear(); +// assert!(manager.get_casm(&class_hash).is_none()); +// } diff --git a/crates/blockifier/src/state/global_cache.rs b/crates/blockifier/src/state/global_cache.rs index 6bd63b15828..5ec8c335bed 100644 --- a/crates/blockifier/src/state/global_cache.rs +++ b/crates/blockifier/src/state/global_cache.rs @@ -23,7 +23,7 @@ pub enum CachedCasm { } #[cfg(feature = "cairo_native")] -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq)] pub enum CachedCairoNative { Compiled(NativeCompiledClassV1), CompilationFailed, diff --git a/crates/blockifier/src/test_utils/contracts.rs b/crates/blockifier/src/test_utils/contracts.rs index 6bec15ea0f1..e8f38110079 100644 --- a/crates/blockifier/src/test_utils/contracts.rs +++ b/crates/blockifier/src/test_utils/contracts.rs @@ -17,7 +17,10 @@ use starknet_types_core::felt::Felt; use strum::IntoEnumIterator; use strum_macros::EnumIter; -use crate::execution::contract_class::RunnableCompiledClass; +use crate::execution::contract_class::{ + CompiledClassV1, + RunnableCompiledClass, +}; use crate::execution::entry_point::CallEntryPoint; #[cfg(feature = "cairo_native")] use crate::execution::native::contract_class::NativeCompiledClassV1; @@ -180,6 +183,28 @@ impl FeatureContract { } } + pub fn get_casm(&self) -> CompiledClassV1 { + // Question (AvivG) : what is the desired behaviour? + if *self == Self::ERC20(self.cairo_version()) { + todo!("ERC20 cannot be tested with Native") + }; + match self.cairo_version() { + CairoVersion::Cairo0 => { + panic!("Casm contracts are only available for Cairo1."); + } + CairoVersion::Cairo1(_) => { + let compiled_path = format!( + "feature_contracts/cairo{}/{}{}.json", + "1/compiled", + self.get_non_erc20_base_name(), + ".casm" + ); + let contact_class = CasmContractClass::from_file(&compiled_path); + (contact_class, self.get_sierra_version()).try_into().unwrap() + } + } + } + pub fn get_runnable_class(&self) -> RunnableCompiledClass { #[cfg(feature = "cairo_native")] if CairoVersion::Cairo1(RunnableCairo1::Native) == self.cairo_version() { @@ -199,10 +224,13 @@ impl FeatureContract { get_raw_contract_class(&self.get_sierra_path()) } - pub fn get_sierra(&self) -> SierraContractClass { + pub fn get_contract_class(&self) -> CairoLangContractClass { let raw_sierra = self.get_raw_sierra(); - let cairo_contract_class: CairoLangContractClass = - serde_json::from_str(&raw_sierra).unwrap(); + serde_json::from_str(&raw_sierra).unwrap() + } + + pub fn get_sierra(&self) -> SierraContractClass { + let cairo_contract_class = self.get_contract_class(); SierraContractClass::from(cairo_contract_class) } diff --git a/crates/native_blockifier/src/py_objects.rs b/crates/native_blockifier/src/py_objects.rs index 9cc2eb109d9..3d0e53314ae 100644 --- a/crates/native_blockifier/src/py_objects.rs +++ b/crates/native_blockifier/src/py_objects.rs @@ -180,6 +180,8 @@ impl Default for PyContractClassManagerConfig { } } + + impl From for ContractClassManagerConfig { fn from(py_contract_class_manager_config: PyContractClassManagerConfig) -> Self { ContractClassManagerConfig {