diff --git a/Cargo.lock b/Cargo.lock index a47052fbd3..d72d2d236b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -7706,6 +7706,7 @@ dependencies = [ "blockifier", "indexmap 2.6.0", "papyrus_storage", + "rstest", "starknet-types-core", "starknet_api", ] diff --git a/crates/blockifier/src/blockifier/config.rs b/crates/blockifier/src/blockifier/config.rs index e49fb7fbac..0aad8c1ee1 100644 --- a/crates/blockifier/src/blockifier/config.rs +++ b/crates/blockifier/src/blockifier/config.rs @@ -84,6 +84,13 @@ impl Default for ContractClassManagerConfig { } } +impl ContractClassManagerConfig { + #[cfg(any(test, feature = "testing", feature = "native_blockifier"))] + pub fn create_for_testing(run_cairo_native: bool, wait_on_native_compilation: bool) -> Self { + Self { run_cairo_native, wait_on_native_compilation, ..Default::default() } + } +} + impl SerializeConfig for ContractClassManagerConfig { fn dump(&self) -> BTreeMap { BTreeMap::from_iter([ diff --git a/crates/blockifier/src/state/contract_class_manager.rs b/crates/blockifier/src/state/contract_class_manager.rs index 8f73e2a335..0da8330ea8 100644 --- a/crates/blockifier/src/state/contract_class_manager.rs +++ b/crates/blockifier/src/state/contract_class_manager.rs @@ -27,7 +27,6 @@ use crate::execution::native::contract_class::NativeCompiledClassV1; #[cfg(feature = "cairo_native")] use crate::state::global_cache::CachedCairoNative; use crate::state::global_cache::{CachedCasm, ContractCaches}; - pub const DEFAULT_COMPILATION_REQUEST_CHANNEL_SIZE: usize = 1000; /// Represents a request to compile a sierra contract class to a native compiled class. @@ -152,6 +151,11 @@ impl ContractClassManager { self.contract_caches.set_casm(class_hash, compiled_class); } + #[cfg(all(feature = "cairo_native", feature = "testing"))] + pub fn set_native(&self, class_hash: ClassHash, compiled_class: NativeCompiledClassV1) { + self.contract_caches.set_native(class_hash, CachedCairoNative::Compiled(compiled_class)); + } + #[cfg(feature = "cairo_native")] pub fn run_cairo_native(&self) -> bool { self.config.run_cairo_native diff --git a/crates/papyrus_state_reader/Cargo.toml b/crates/papyrus_state_reader/Cargo.toml index 5b279784cd..1d5d219e5b 100644 --- a/crates/papyrus_state_reader/Cargo.toml +++ b/crates/papyrus_state_reader/Cargo.toml @@ -7,6 +7,7 @@ license.workspace = true [features] cairo_native = ["blockifier/cairo_native"] +testing = ["blockifier/testing", "rstest"] [lints] workspace = true @@ -14,6 +15,7 @@ workspace = true [dependencies] blockifier.workspace = true papyrus_storage.workspace = true +rstest = { workspace = true, optional = true } starknet-types-core.workspace = true starknet_api.workspace = true @@ -22,3 +24,4 @@ assert_matches.workspace = true blockifier = { workspace = true, features = ["testing"] } indexmap.workspace = true papyrus_storage = { workspace = true, features = ["testing"] } +rstest.workspace = true diff --git a/crates/papyrus_state_reader/src/papyrus_state_test.rs b/crates/papyrus_state_reader/src/papyrus_state_test.rs index e3e006a496..3ef94d1d94 100644 --- a/crates/papyrus_state_reader/src/papyrus_state_test.rs +++ b/crates/papyrus_state_reader/src/papyrus_state_test.rs @@ -1,20 +1,29 @@ +use core::panic; +use std::sync::Arc; + use assert_matches::assert_matches; use blockifier::blockifier::config::ContractClassManagerConfig; use blockifier::execution::call_info::CallExecution; +#[cfg(feature = "cairo_native")] +use blockifier::execution::contract_class::RunnableCompiledClass; use blockifier::execution::entry_point::CallEntryPoint; use blockifier::retdata; use blockifier::state::cached_state::CachedState; use blockifier::state::contract_class_manager::ContractClassManager; +use blockifier::state::global_cache::CachedCasm; use blockifier::state::state_api::StateReader; use blockifier::test_utils::contracts::FeatureContract; -use blockifier::test_utils::{trivial_external_entry_point_new, CairoVersion}; +use blockifier::test_utils::{trivial_external_entry_point_new, CairoVersion, RunnableCairo1}; use indexmap::IndexMap; use papyrus_storage::class::ClassStorageWriter; +use papyrus_storage::compiled_class::CasmStorageWriter; use papyrus_storage::state::StateStorageWriter; +use rstest::rstest; use starknet_api::abi::abi_utils::selector_from_name; use starknet_api::block::BlockNumber; use starknet_api::contract_class::ContractClass; -use starknet_api::state::{StateDiff, StorageKey}; +use starknet_api::core::ClassHash; +use starknet_api::state::{StateDiff, StorageKey, ThinStateDiff}; use starknet_api::{calldata, felt}; use crate::papyrus_state::PapyrusReader; @@ -76,3 +85,167 @@ fn test_entry_point_with_papyrus_state() -> papyrus_storage::StorageResult<()> { Ok(()) } + +fn build_papyrus_state_reader_and_declare_contract( + class_hash: ClassHash, + contract: FeatureContract, + contract_manager_config: ContractClassManagerConfig, +) -> PapyrusReader { + let ((storage_reader, mut storage_writer), _) = papyrus_storage::test_utils::get_test_storage(); + let test_compiled_class_hash = contract.get_compiled_class_hash(); + let block_number = BlockNumber::default(); + + // Hack to declare the contract in the storage. + match contract.get_class() { + ContractClass::V1((casm_class, _)) => { + let thin_state_diff = ThinStateDiff { + declared_classes: IndexMap::from([(class_hash, test_compiled_class_hash)]), + ..Default::default() + }; + storage_writer + .begin_rw_txn() + .unwrap() + .append_state_diff(block_number, thin_state_diff) + .unwrap() + .append_classes(block_number, &[(class_hash, &contract.get_sierra())], &[]) + .unwrap() + .append_casm(&class_hash, &casm_class) + .unwrap() + .commit() + .unwrap(); + } + + ContractClass::V0(deprecated_contract_class) => { + let thin_state_diff = ThinStateDiff { + deprecated_declared_classes: vec![class_hash], + ..Default::default() + }; + storage_writer + .begin_rw_txn() + .unwrap() + .append_state_diff(block_number, thin_state_diff) + .unwrap() + .append_classes(block_number, &[], &[(class_hash, &deprecated_contract_class)]) + .unwrap() + .commit() + .unwrap(); + } + } + + PapyrusReader::new( + storage_reader, + BlockNumber(1), + ContractClassManager::start(contract_manager_config), + ) +} + +#[rstest] +#[case::dont_run_cairo_native(false, false)] +#[cfg_attr(feature = "cairo_native", case::run_cairo_native_without_waiting(true, false))] +#[cfg_attr(feature = "cairo_native", case::run_cairo_native_and_wait(true, true))] +fn test_get_compiled_class( + #[values(CairoVersion::Cairo0, CairoVersion::Cairo1(RunnableCairo1::Casm))] + cairo_version: CairoVersion, + #[values(true, false)] is_cached: bool, + #[case] run_cairo_native: bool, + #[case] wait_on_native_compilation: bool, +) { + // Sanity check + if !run_cairo_native { + assert!(!wait_on_native_compilation); + } + #[cfg(not(feature = "cairo_native"))] + assert!(!run_cairo_native); + + // We store the sierra with the casm only when the casm is cairo1 and the native flag is enabled + let cached_with_sierra = run_cairo_native && matches!(cairo_version, CairoVersion::Cairo1(_)); + // We don't need a native contract because we only use the contract to get the casm amd sierra + // classes. + let test_contract = FeatureContract::TestContract(cairo_version); + let test_class_hash = test_contract.get_class_hash(); + let contract_manager_config = ContractClassManagerConfig::create_for_testing( + run_cairo_native, + wait_on_native_compilation, + ); + + let papyrus_reader = build_papyrus_state_reader_and_declare_contract( + test_class_hash, + test_contract, + contract_manager_config, + ); + + if is_cached { + // Simulate the scenario where the classes are already in the cache. + // Create a cached casm and store it in the cache. + let casm_cashed = match cached_with_sierra { + true => CachedCasm::WithSierra( + test_contract.get_runnable_class(), + Arc::new(test_contract.get_sierra()), + ), + false => CachedCasm::WithoutSierra(test_contract.get_runnable_class()), + }; + papyrus_reader.contract_class_manager.set_casm(test_class_hash, casm_cashed); + } + + let compiled_class = papyrus_reader.get_compiled_class(test_class_hash).unwrap(); + + if cached_with_sierra { + // TODO: Test that a compilation request was sent. + if wait_on_native_compilation { + #[cfg(feature = "cairo_native")] + assert_matches!( + compiled_class, + RunnableCompiledClass::V1Native(_), + "We should have waited to the native class." + ); + } else { + assert_matches!( + compiled_class, + RunnableCompiledClass::V1(_), + "`get_compiled_class` should return Cario1 casm" + ); + } + } else if run_cairo_native { + assert_matches!( + compiled_class, + RunnableCompiledClass::V0(_), + "`get_compiled_class` should return Cario0 casm" + ); + } else { + assert_eq!( + compiled_class, + test_contract.get_runnable_class(), + "`get_compiled_class` should return the casm" + ); + } + // Check that the casm cached type is as expected. + let cached_casm = papyrus_reader.contract_class_manager.get_casm(&test_class_hash); + if cached_with_sierra { + assert_matches!(cached_casm, Some(CachedCasm::WithSierra(_, _))); + } else { + assert_matches!(cached_casm, Some(CachedCasm::WithoutSierra(_))); + } +} + +#[cfg(feature = "cairo_native")] +#[test] +fn test_get_compiled_class_when_native_is_cached() { + let ((storage_reader, _), _) = papyrus_storage::test_utils::get_test_storage(); + let test_contract = FeatureContract::TestContract(CairoVersion::Cairo1(RunnableCairo1::Native)); + let test_class_hash = test_contract.get_class_hash(); + let contract_manager_config = ContractClassManagerConfig::create_for_testing(true, true); + let papyrus_reader = PapyrusReader::new( + storage_reader, + BlockNumber::default(), + ContractClassManager::start(contract_manager_config), + ); + if let RunnableCompiledClass::V1Native(native_compiled_class) = + test_contract.get_runnable_class() + { + papyrus_reader.contract_class_manager.set_native(test_class_hash, native_compiled_class); + } else { + panic!("Expected NativeCompiledClassV1"); + } + let compiled_class = papyrus_reader.get_compiled_class(test_class_hash).unwrap(); + assert_matches!(compiled_class, RunnableCompiledClass::V1Native(_)); +}