From e5f8ab918995d551b8bf4b02bbca0ab15b42e9e6 Mon Sep 17 00:00:00 2001 From: Yoav Gross Date: Wed, 27 Nov 2024 15:20:53 +0200 Subject: [PATCH] fix(blockifier): merge state diff with squash --- crates/blockifier/src/state/cached_state.rs | 51 ++++++--- .../blockifier/src/state/cached_state_test.rs | 101 +++++++++++------- .../src/transaction/account_transaction.rs | 13 +-- 3 files changed, 105 insertions(+), 60 deletions(-) diff --git a/crates/blockifier/src/state/cached_state.rs b/crates/blockifier/src/state/cached_state.rs index 356b7b8f893..71797c314c7 100644 --- a/crates/blockifier/src/state/cached_state.rs +++ b/crates/blockifier/src/state/cached_state.rs @@ -2,6 +2,7 @@ use std::cell::RefCell; use std::collections::{HashMap, HashSet}; use indexmap::IndexMap; +use itertools::Itertools; use starknet_api::abi::abi_utils::get_fee_token_var_address; use starknet_api::core::{ClassHash, CompiledClassHash, ContractAddress, Nonce}; use starknet_api::state::StorageKey; @@ -58,6 +59,11 @@ impl CachedState { self.to_state_diff() } + pub fn to_state_cache(&mut self) -> StateResult { + self.update_initial_values_of_write_only_access()?; + Ok(self.cache.borrow().clone()) + } + pub fn update_cache( &mut self, write_updates: &StateMaps, @@ -383,7 +389,7 @@ impl StateMaps { /// The tracked changes are needed for block state commitment. // Invariant: keys cannot be deleted from fields (only used internally by the cached state). -#[derive(Debug, Default, PartialEq, Eq)] +#[derive(Clone, Debug, Default, PartialEq, Eq)] pub struct StateCache { // Reader's cached information; initial values, read before any write operation (per cell). pub(crate) initial_reads: StateMaps, @@ -402,6 +408,35 @@ impl StateCache { StateChanges { state_maps, allocated_keys } } + /// Squashes the given state caches into a single one and returns the state diff. Note that the + /// order of the state caches is important. + pub fn squashed_state_diff( + state_caches: Vec, + comprehensive_state_diff: bool, + ) -> StateChanges { + // Backward compatibility. + if !comprehensive_state_diff { + return StateChanges::merge( + state_caches + .into_iter() + .map(|state_cache| state_cache.to_state_diff()) + .collect_vec(), + ); + } + + let mut squashed_state_cache = StateCache::default(); + + // Gives priority to early initial reads. + state_caches.iter().rev().for_each(|state_cache| { + squashed_state_cache.initial_reads.extend(&state_cache.initial_reads) + }); + // Gives priority to late writes. + state_caches + .iter() + .for_each(|state_cache| squashed_state_cache.writes.extend(&state_cache.writes)); + squashed_state_cache.to_state_diff() + } + fn declare_contract(&mut self, class_hash: ClassHash) { self.writes.declared_contracts.insert(class_hash, true); } @@ -680,18 +715,6 @@ impl StateChangesKeys { pub struct AllocatedKeys(HashSet); impl AllocatedKeys { - /// Extends the set of allocated keys with the allocated_keys of the given state changes. - /// Removes storage keys that are set back to zero. - pub fn update(&mut self, state_change: &StateChanges) { - self.0.extend(&state_change.allocated_keys.0); - // Remove keys that are set back to zero. - state_change.state_maps.storage.iter().for_each(|(k, v)| { - if v == &Felt::ZERO { - self.0.remove(k); - } - }); - } - pub fn len(&self) -> usize { self.0.len() } @@ -732,7 +755,7 @@ impl StateChanges { let mut merged_state_changes = Self::default(); for state_change in state_changes { merged_state_changes.state_maps.extend(&state_change.state_maps); - merged_state_changes.allocated_keys.update(&state_change); + merged_state_changes.allocated_keys.0.extend(&state_change.allocated_keys.0); } merged_state_changes } diff --git a/crates/blockifier/src/state/cached_state_test.rs b/crates/blockifier/src/state/cached_state_test.rs index 20f191a99cc..b76f40ef7a4 100644 --- a/crates/blockifier/src/state/cached_state_test.rs +++ b/crates/blockifier/src/state/cached_state_test.rs @@ -289,11 +289,11 @@ fn cached_state_state_diff_conversion() { assert_eq!(expected_state_diff, state.to_state_diff().unwrap().state_maps.into()); } -fn create_state_changes_for_test( +fn create_state_cache_for_test( state: &mut CachedState, sender_address: Option, fee_token_address: ContractAddress, -) -> StateChanges { +) -> StateCache { let contract_address = contract_address!(CONTRACT_ADDRESS); let contract_address2 = contract_address!("0x101"); let class_hash = class_hash!("0x10"); @@ -323,7 +323,7 @@ fn create_state_changes_for_test( let sender_balance_key = get_fee_token_var_address(sender_address); state.set_storage_at(fee_token_address, sender_balance_key, felt!("0x1999")).unwrap(); } - state.get_actual_state_changes().unwrap() + state.to_state_cache().unwrap() } #[rstest] @@ -333,7 +333,7 @@ fn test_from_state_changes_for_fee_charge( let mut state: CachedState = CachedState::default(); let fee_token_address = contract_address!("0x17"); let state_changes = - create_state_changes_for_test(&mut state, sender_address, fee_token_address); + create_state_cache_for_test(&mut state, sender_address, fee_token_address).to_state_diff(); let state_changes_count = state_changes.count_for_fee_charge(sender_address, fee_token_address); let n_expected_storage_updates = 1 + usize::from(sender_address.is_some()); let expected_state_changes_count = StateChangesCountForFee { @@ -350,37 +350,38 @@ fn test_from_state_changes_for_fee_charge( } #[rstest] -fn test_state_changes_merge( +fn test_state_cache_merge( #[values(Some(contract_address!("0x102")), None)] sender_address: Option, ) { // Create a transactional state containing the `create_state_changes_for_test` logic, get the - // state changes and then commit. + // state cache and then commit. let mut state: CachedState = CachedState::default(); let mut transactional_state = TransactionalState::create_transactional(&mut state); let block_context = BlockContext::create_for_testing(); let fee_token_address = block_context.chain_info.fee_token_addresses.eth_fee_token_address; - let state_changes1 = - create_state_changes_for_test(&mut transactional_state, sender_address, fee_token_address); + let state_cache1 = + create_state_cache_for_test(&mut transactional_state, sender_address, fee_token_address); transactional_state.commit(); // After performing `commit`, the transactional state is moved (into state). We need to create // a new transactional state that wraps `state` to continue. let mut transactional_state = TransactionalState::create_transactional(&mut state); - // Make sure that `get_actual_state_changes` on a newly created transactional state returns null - // state changes and that merging null state changes with non-null state changes results in the - // non-null state changes, no matter the order. - let state_changes2 = transactional_state.get_actual_state_changes().unwrap(); - assert_eq!(state_changes2, StateChanges::default()); + // Make sure that the state_changes of a newly created transactional state returns null + // state cache and that merging null state cache with non-null state cache results in the + // non-null state cache, no matter the order. + let state_cache2 = transactional_state.to_state_cache().unwrap(); + assert_eq!(state_cache2, StateCache::default()); assert_eq!( - StateChanges::merge(vec![state_changes1.clone(), state_changes2.clone()]), - state_changes1 + StateCache::squashed_state_diff(vec![state_cache1.clone(), state_cache2.clone()], true), + state_cache1.to_state_diff() ); assert_eq!( - StateChanges::merge(vec![state_changes2.clone(), state_changes1.clone()]), - state_changes1 + StateCache::squashed_state_diff(vec![state_cache2.clone(), state_cache1.clone()], true), + state_cache1.to_state_diff() ); - // Get the storage updates addresses and keys from the state_changes1, to overwrite. + // Get the storage updates addresses and keys from the state_cache1, to overwrite. + let state_changes1 = state_cache1.to_state_diff(); let mut storage_updates_keys = state_changes1.state_maps.storage.keys(); let &(contract_address, storage_key) = storage_updates_keys .find(|(contract_address, _)| contract_address == &contract_address!(CONTRACT_ADDRESS)) @@ -394,8 +395,8 @@ fn test_state_changes_merge( .set_storage_at(new_contract_address, storage_key, felt!("0x43210")) .unwrap(); transactional_state.increment_nonce(contract_address).unwrap(); - // Get the new state changes and then commit the transactional state. - let state_changes3 = transactional_state.get_actual_state_changes().unwrap(); + // Get the new state cache and then commit the transactional state. + let state_cache3 = transactional_state.to_state_cache().unwrap(); transactional_state.commit(); // Get the total state changes of the CachedState underlying all the temporary transactional @@ -403,15 +404,14 @@ fn test_state_changes_merge( // states, but only when done in the right order. let state_changes_final = state.get_actual_state_changes().unwrap(); assert_eq!( - StateChanges::merge(vec![ - state_changes1.clone(), - state_changes2.clone(), - state_changes3.clone() - ]), + StateCache::squashed_state_diff( + vec![state_cache1.clone(), state_cache2.clone(), state_cache3.clone()], + true, + ), state_changes_final ); assert_ne!( - StateChanges::merge(vec![state_changes3, state_changes1, state_changes2]), + StateCache::squashed_state_diff(vec![state_cache3, state_cache1, state_cache2], true), state_changes_final ); } @@ -422,33 +422,54 @@ fn test_state_changes_merge( #[case(true, vec![felt!("0x7")], true)] #[case(false, vec![felt!("0x7")], false)] #[case(true, vec![felt!("0x7"), felt!("0x0")], false)] -#[case(false, vec![felt!("0x0"), felt!("0x8")], true)] +#[case(false, vec![felt!("0x7"), felt!("0x1")], false)] +#[case(false, vec![felt!("0x0"), felt!("0x8")], false)] #[case(false, vec![felt!("0x0"), felt!("0x8"), felt!("0x0")], false)] -fn test_allocated_keys_commit_and_merge( +fn test_state_cache_commit_and_merge( #[case] is_base_empty: bool, #[case] storage_updates: Vec, #[case] charged: bool, + #[values(true, false)] comprehensive_state_diff: bool, ) { let contract_address = contract_address!(CONTRACT_ADDRESS); let storage_key = StorageKey::from(0x10_u16); // Set initial state let mut state: CachedState = CachedState::default(); + + let non_empty_base_value = felt!("0x1"); if !is_base_empty { - state.set_storage_at(contract_address, storage_key, felt!("0x1")).unwrap(); + state.set_storage_at(contract_address, storage_key, non_empty_base_value).unwrap(); } - let mut state_changes = vec![]; + let mut state_caches = vec![]; - for value in storage_updates { + for value in storage_updates.iter() { // In the end of the previous loop, state has moved into the transactional state. let mut transactional_state = TransactionalState::create_transactional(&mut state); // Update state and collect the state changes. - transactional_state.set_storage_at(contract_address, storage_key, value).unwrap(); - state_changes.push(transactional_state.get_actual_state_changes().unwrap()); + transactional_state.set_storage_at(contract_address, storage_key, *value).unwrap(); + state_caches.push(transactional_state.to_state_cache().unwrap()); transactional_state.commit(); } - let merged_changes = StateChanges::merge(state_changes); - assert_ne!(merged_changes.allocated_keys.is_empty(), charged); + let merged_changes = StateCache::squashed_state_diff(state_caches, comprehensive_state_diff); + if comprehensive_state_diff { + // The comprehensive_state_diff is needed for backward compatibility of versions before the + // allocated keys feature was inserted. + assert_ne!(merged_changes.allocated_keys.is_empty(), charged); + } + + // Test the storage diff. + let base_value = if is_base_empty { Felt::ZERO } else { non_empty_base_value }; + let last_value = storage_updates.last().unwrap(); + let expected_storage_diff = if (&base_value == last_value) && comprehensive_state_diff { + None + } else { + Some(last_value) + }; + assert_eq!( + merged_changes.state_maps.storage.get(&(contract_address, storage_key)), + expected_storage_diff, + ); } // Test that allocations in validate and execute phases are properly squashed. @@ -456,8 +477,7 @@ fn test_allocated_keys_commit_and_merge( #[case(false, felt!("0x7"), felt!("0x8"), false)] #[case(true, felt!("0x0"), felt!("0x8"), true)] #[case(true, felt!("0x7"), felt!("0x7"), true)] -// TODO: not charge in the following case. -#[case(false, felt!("0x0"), felt!("0x8"), true)] +#[case(false, felt!("0x0"), felt!("0x8"), false)] #[case(true, felt!("0x7"), felt!("0x0"), false)] fn test_allocated_keys_two_transactions( #[case] is_base_empty: bool, @@ -475,13 +495,14 @@ fn test_allocated_keys_two_transactions( let mut first_state = TransactionalState::create_transactional(&mut state); first_state.set_storage_at(contract_address, storage_key, validate_value).unwrap(); - let first_state_changes = first_state.get_actual_state_changes().unwrap(); + let first_state_changes = first_state.to_state_cache().unwrap(); let mut second_state = TransactionalState::create_transactional(&mut first_state); second_state.set_storage_at(contract_address, storage_key, execute_value).unwrap(); - let second_state_changes = second_state.get_actual_state_changes().unwrap(); + let second_state_changes = second_state.to_state_cache().unwrap(); - let merged_changes = StateChanges::merge(vec![first_state_changes, second_state_changes]); + let merged_changes = + StateCache::squashed_state_diff(vec![first_state_changes, second_state_changes], true); assert_ne!(merged_changes.allocated_keys.is_empty(), charged); } diff --git a/crates/blockifier/src/transaction/account_transaction.rs b/crates/blockifier/src/transaction/account_transaction.rs index ca61f75576c..b35fb6bae42 100644 --- a/crates/blockifier/src/transaction/account_transaction.rs +++ b/crates/blockifier/src/transaction/account_transaction.rs @@ -44,7 +44,7 @@ use crate::fee::fee_utils::{ use crate::fee::gas_usage::estimate_minimal_gas_vector; use crate::fee::receipt::TransactionReceipt; use crate::retdata; -use crate::state::cached_state::{StateChanges, TransactionalState}; +use crate::state::cached_state::{StateCache, TransactionalState}; use crate::state::state_api::{State, StateReader, UpdatableState}; use crate::transaction::errors::{ TransactionExecutionError, @@ -612,7 +612,8 @@ impl AccountTransaction { // Save the state changes resulting from running `validate_tx`, to be used later for // resource and fee calculation. - let validate_state_changes = state.get_actual_state_changes()?; + let validate_state_cache = state.to_state_cache()?; + let validate_state_changes = validate_state_cache.to_state_diff(); // Create copies of state and validate_resources for the execution. // Both will be rolled back if the execution is reverted or committed upon success. @@ -643,10 +644,10 @@ impl AccountTransaction { let tx_receipt = TransactionReceipt::from_account_tx( self, &tx_context, - &StateChanges::merge(vec![ - validate_state_changes, - execution_state.get_actual_state_changes()?, - ]), + &StateCache::squashed_state_diff( + vec![validate_state_cache, execution_state.to_state_cache()?], + tx_context.block_context.versioned_constants.comprehensive_state_diff, + ), CallInfo::summarize_many( validate_call_info.iter().chain(execute_call_info.iter()), &tx_context.block_context.versioned_constants,