Skip to content

Commit

Permalink
fix(blockifier): merge state diff with squash (#2310)
Browse files Browse the repository at this point in the history
  • Loading branch information
yoavGrs authored Dec 5, 2024
1 parent 17d5d3b commit 61340f7
Show file tree
Hide file tree
Showing 3 changed files with 116 additions and 78 deletions.
70 changes: 45 additions & 25 deletions crates/blockifier/src/state/cached_state.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::cell::RefCell;
use std::cell::{Ref, RefCell};
use std::collections::{HashMap, HashSet};

use indexmap::IndexMap;
Expand Down Expand Up @@ -58,6 +58,11 @@ impl<S: StateReader> CachedState<S> {
self.to_state_diff()
}

pub fn borrow_updated_state_cache(&mut self) -> StateResult<Ref<'_, StateCache>> {
self.update_initial_values_of_write_only_access()?;
Ok(self.cache.borrow())
}

pub fn update_cache(
&mut self,
write_updates: &StateMaps,
Expand Down Expand Up @@ -383,7 +388,7 @@ impl StateMaps {
/// The tracked changes are needed for block state commitment.
// Invariant: keys cannot be deleted from fields (only used internally by the cached state).
#[derive(Debug, Default, PartialEq, Eq)]
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct StateCache {
// Reader's cached information; initial values, read before any write operation (per cell).
pub(crate) initial_reads: StateMaps,
Expand All @@ -402,6 +407,44 @@ impl StateCache {
StateChanges { state_maps, allocated_keys }
}

/// Squashes the given state caches into a single one and returns the state diff. Note that the
/// order of the state caches is important.
pub fn squash_state_caches(state_caches: Vec<&Self>) -> Self {
let mut squashed_state_cache = StateCache::default();

// Gives priority to early initial reads.
state_caches.iter().rev().for_each(|state_cache| {
squashed_state_cache.initial_reads.extend(&state_cache.initial_reads)
});
// Gives priority to late writes.
state_caches
.iter()
.for_each(|state_cache| squashed_state_cache.writes.extend(&state_cache.writes));
squashed_state_cache
}

/// Squashes the given state caches into a single one and returns the state diff. Note that the
/// order of the state caches is important.
/// If 'comprehensive_state_diff' is false, opposite updates may not be canceled out. Used for
/// backward compatibility.
pub fn squash_state_diff(
state_caches: Vec<&Self>,
comprehensive_state_diff: bool,
) -> StateChanges {
if comprehensive_state_diff {
return Self::squash_state_caches(state_caches).to_state_diff();
}

// Backward compatibility.
let mut merged_state_changes = StateChanges::default();
for state_cache in state_caches {
let state_change = state_cache.to_state_diff();
merged_state_changes.state_maps.extend(&state_change.state_maps);
merged_state_changes.allocated_keys.0.extend(&state_change.allocated_keys.0);
}
merged_state_changes
}

fn declare_contract(&mut self, class_hash: ClassHash) {
self.writes.declared_contracts.insert(class_hash, true);
}
Expand Down Expand Up @@ -680,18 +723,6 @@ impl StateChangesKeys {
pub struct AllocatedKeys(HashSet<StorageEntry>);

impl AllocatedKeys {
/// Extends the set of allocated keys with the allocated_keys of the given state changes.
/// Removes storage keys that are set back to zero.
pub fn update(&mut self, state_change: &StateChanges) {
self.0.extend(&state_change.allocated_keys.0);
// Remove keys that are set back to zero.
state_change.state_maps.storage.iter().for_each(|(k, v)| {
if v == &Felt::ZERO {
self.0.remove(k);
}
});
}

pub fn len(&self) -> usize {
self.0.len()
}
Expand Down Expand Up @@ -726,17 +757,6 @@ pub struct StateChanges {
}

impl StateChanges {
/// Merges the given state changes into a single one. Note that the order of the state changes
/// is important. The state changes are merged in the order they appear in the given vector.
pub fn merge(state_changes: Vec<Self>) -> Self {
let mut merged_state_changes = Self::default();
for state_change in state_changes {
merged_state_changes.state_maps.extend(&state_change.state_maps);
merged_state_changes.allocated_keys.update(&state_change);
}
merged_state_changes
}

pub fn count_for_fee_charge(
&self,
sender_address: Option<ContractAddress>,
Expand Down
107 changes: 61 additions & 46 deletions crates/blockifier/src/state/cached_state_test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -289,11 +289,11 @@ fn cached_state_state_diff_conversion() {
assert_eq!(expected_state_diff, state.to_state_diff().unwrap().state_maps.into());
}

fn create_state_changes_for_test<S: StateReader>(
fn create_state_cache_for_test<S: StateReader>(
state: &mut CachedState<S>,
sender_address: Option<ContractAddress>,
fee_token_address: ContractAddress,
) -> StateChanges {
) -> StateCache {
let contract_address = contract_address!(CONTRACT_ADDRESS);
let contract_address2 = contract_address!("0x101");
let class_hash = class_hash!("0x10");
Expand Down Expand Up @@ -323,7 +323,7 @@ fn create_state_changes_for_test<S: StateReader>(
let sender_balance_key = get_fee_token_var_address(sender_address);
state.set_storage_at(fee_token_address, sender_balance_key, felt!("0x1999")).unwrap();
}
state.get_actual_state_changes().unwrap()
state.borrow_updated_state_cache().unwrap().clone()
}

#[rstest]
Expand All @@ -333,7 +333,7 @@ fn test_from_state_changes_for_fee_charge(
let mut state: CachedState<DictStateReader> = CachedState::default();
let fee_token_address = contract_address!("0x17");
let state_changes =
create_state_changes_for_test(&mut state, sender_address, fee_token_address);
create_state_cache_for_test(&mut state, sender_address, fee_token_address).to_state_diff();
let state_changes_count = state_changes.count_for_fee_charge(sender_address, fee_token_address);
let n_expected_storage_updates = 1 + usize::from(sender_address.is_some());
let expected_state_changes_count = StateChangesCountForFee {
Expand All @@ -350,37 +350,32 @@ fn test_from_state_changes_for_fee_charge(
}

#[rstest]
fn test_state_changes_merge(
fn test_state_cache_merge(
#[values(Some(contract_address!("0x102")), None)] sender_address: Option<ContractAddress>,
) {
// Create a transactional state containing the `create_state_changes_for_test` logic, get the
// state changes and then commit.
// state cache and then commit.
let mut state: CachedState<DictStateReader> = CachedState::default();
let mut transactional_state = TransactionalState::create_transactional(&mut state);
let block_context = BlockContext::create_for_testing();
let fee_token_address = block_context.chain_info.fee_token_addresses.eth_fee_token_address;
let state_changes1 =
create_state_changes_for_test(&mut transactional_state, sender_address, fee_token_address);
let state_cache1 =
create_state_cache_for_test(&mut transactional_state, sender_address, fee_token_address);
transactional_state.commit();

// After performing `commit`, the transactional state is moved (into state). We need to create
// a new transactional state that wraps `state` to continue.
let mut transactional_state = TransactionalState::create_transactional(&mut state);
// Make sure that `get_actual_state_changes` on a newly created transactional state returns null
// state changes and that merging null state changes with non-null state changes results in the
// non-null state changes, no matter the order.
let state_changes2 = transactional_state.get_actual_state_changes().unwrap();
assert_eq!(state_changes2, StateChanges::default());
assert_eq!(
StateChanges::merge(vec![state_changes1.clone(), state_changes2.clone()]),
state_changes1
);
assert_eq!(
StateChanges::merge(vec![state_changes2.clone(), state_changes1.clone()]),
state_changes1
);

// Get the storage updates addresses and keys from the state_changes1, to overwrite.
// Make sure that the state_changes of a newly created transactional state returns null
// state cache and that merging null state cache with non-null state cache results in the
// non-null state cache, no matter the order.
let state_cache2 = transactional_state.borrow_updated_state_cache().unwrap().clone();
assert_eq!(state_cache2, StateCache::default());
assert_eq!(StateCache::squash_state_caches(vec![&state_cache1, &state_cache2]), state_cache1);
assert_eq!(StateCache::squash_state_caches(vec![&state_cache2, &state_cache1]), state_cache1);

// Get the storage updates addresses and keys from the state_cache1, to overwrite.
let state_changes1 = state_cache1.to_state_diff();
let mut storage_updates_keys = state_changes1.state_maps.storage.keys();
let &(contract_address, storage_key) = storage_updates_keys
.find(|(contract_address, _)| contract_address == &contract_address!(CONTRACT_ADDRESS))
Expand All @@ -394,24 +389,22 @@ fn test_state_changes_merge(
.set_storage_at(new_contract_address, storage_key, felt!("0x43210"))
.unwrap();
transactional_state.increment_nonce(contract_address).unwrap();
// Get the new state changes and then commit the transactional state.
let state_changes3 = transactional_state.get_actual_state_changes().unwrap();
// Get the new state cache and then commit the transactional state.
let state_cache3 = transactional_state.borrow_updated_state_cache().unwrap().clone();
transactional_state.commit();

// Get the total state changes of the CachedState underlying all the temporary transactional
// states. We expect the state_changes to match the merged state_changes of the transactional
// states, but only when done in the right order.
let state_changes_final = state.get_actual_state_changes().unwrap();
assert_eq!(
StateChanges::merge(vec![
state_changes1.clone(),
state_changes2.clone(),
state_changes3.clone()
]),
StateCache::squash_state_caches(vec![&state_cache1, &state_cache2, &state_cache3])
.to_state_diff(),
state_changes_final
);
assert_ne!(
StateChanges::merge(vec![state_changes3, state_changes1, state_changes2]),
StateCache::squash_state_caches(vec![&state_cache3, &state_cache1, &state_cache2])
.to_state_diff(),
state_changes_final
);
}
Expand All @@ -422,42 +415,63 @@ fn test_state_changes_merge(
#[case(true, vec![felt!("0x7")], true)]
#[case(false, vec![felt!("0x7")], false)]
#[case(true, vec![felt!("0x7"), felt!("0x0")], false)]
#[case(false, vec![felt!("0x0"), felt!("0x8")], true)]
#[case(false, vec![felt!("0x7"), felt!("0x1")], false)]
#[case(false, vec![felt!("0x0"), felt!("0x8")], false)]
#[case(false, vec![felt!("0x0"), felt!("0x8"), felt!("0x0")], false)]
fn test_allocated_keys_commit_and_merge(
fn test_state_cache_commit_and_merge(
#[case] is_base_empty: bool,
#[case] storage_updates: Vec<Felt>,
#[case] charged: bool,
#[values(true, false)] comprehensive_state_diff: bool,
) {
let contract_address = contract_address!(CONTRACT_ADDRESS);
let storage_key = StorageKey::from(0x10_u16);
// Set initial state
let mut state: CachedState<DictStateReader> = CachedState::default();

let non_empty_base_value = felt!("0x1");
if !is_base_empty {
state.set_storage_at(contract_address, storage_key, felt!("0x1")).unwrap();
state.set_storage_at(contract_address, storage_key, non_empty_base_value).unwrap();
}
let mut state_changes = vec![];
let mut state_caches = vec![];

for value in storage_updates {
for value in storage_updates.iter() {
// In the end of the previous loop, state has moved into the transactional state.
let mut transactional_state = TransactionalState::create_transactional(&mut state);
// Update state and collect the state changes.
transactional_state.set_storage_at(contract_address, storage_key, value).unwrap();
state_changes.push(transactional_state.get_actual_state_changes().unwrap());
transactional_state.set_storage_at(contract_address, storage_key, *value).unwrap();
state_caches.push(transactional_state.borrow_updated_state_cache().unwrap().clone());
transactional_state.commit();
}

let merged_changes = StateChanges::merge(state_changes);
assert_ne!(merged_changes.allocated_keys.is_empty(), charged);
let merged_changes =
StateCache::squash_state_diff(state_caches.iter().collect(), comprehensive_state_diff);
if comprehensive_state_diff {
// The comprehensive_state_diff is needed for backward compatibility of versions before the
// allocated keys feature was inserted.
assert_ne!(merged_changes.allocated_keys.is_empty(), charged);
}

// Test the storage diff.
let base_value = if is_base_empty { Felt::ZERO } else { non_empty_base_value };
let last_value = storage_updates.last().unwrap();
let expected_storage_diff = if (&base_value == last_value) && comprehensive_state_diff {
None
} else {
Some(last_value)
};
assert_eq!(
merged_changes.state_maps.storage.get(&(contract_address, storage_key)),
expected_storage_diff,
);
}

// Test that allocations in validate and execute phases are properly squashed.
#[rstest]
#[case(false, felt!("0x7"), felt!("0x8"), false)]
#[case(true, felt!("0x0"), felt!("0x8"), true)]
#[case(true, felt!("0x7"), felt!("0x7"), true)]
// TODO: not charge in the following case.
#[case(false, felt!("0x0"), felt!("0x8"), true)]
#[case(false, felt!("0x0"), felt!("0x8"), false)]
#[case(true, felt!("0x7"), felt!("0x0"), false)]
fn test_allocated_keys_two_transactions(
#[case] is_base_empty: bool,
Expand All @@ -475,14 +489,15 @@ fn test_allocated_keys_two_transactions(

let mut first_state = TransactionalState::create_transactional(&mut state);
first_state.set_storage_at(contract_address, storage_key, validate_value).unwrap();
let first_state_changes = first_state.get_actual_state_changes().unwrap();
let first_state_changes = first_state.borrow_updated_state_cache().unwrap().clone();

let mut second_state = TransactionalState::create_transactional(&mut first_state);
second_state.set_storage_at(contract_address, storage_key, execute_value).unwrap();
let second_state_changes = second_state.get_actual_state_changes().unwrap();
let second_state_changes = second_state.borrow_updated_state_cache().unwrap().clone();

let merged_changes = StateChanges::merge(vec![first_state_changes, second_state_changes]);
assert_ne!(merged_changes.allocated_keys.is_empty(), charged);
let merged_changes =
StateCache::squash_state_caches(vec![&first_state_changes, &second_state_changes]);
assert_ne!(merged_changes.to_state_diff().allocated_keys.is_empty(), charged);
}

#[test]
Expand Down
17 changes: 10 additions & 7 deletions crates/blockifier/src/transaction/account_transaction.rs
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ use crate::fee::fee_utils::{
use crate::fee::gas_usage::estimate_minimal_gas_vector;
use crate::fee::receipt::TransactionReceipt;
use crate::retdata;
use crate::state::cached_state::{StateChanges, TransactionalState};
use crate::state::cached_state::{StateCache, TransactionalState};
use crate::state::state_api::{State, StateReader, UpdatableState};
use crate::transaction::errors::{
TransactionExecutionError,
Expand Down Expand Up @@ -575,7 +575,7 @@ impl AccountTransaction {

// Save the state changes resulting from running `validate_tx`, to be used later for
// resource and fee calculation.
let validate_state_changes = state.get_actual_state_changes()?;
let validate_state_cache = state.borrow_updated_state_cache()?.clone();

// Create copies of state and validate_resources for the execution.
// Both will be rolled back if the execution is reverted or committed upon success.
Expand All @@ -591,7 +591,7 @@ impl AccountTransaction {
let revert_receipt = TransactionReceipt::from_account_tx(
self,
&tx_context,
&validate_state_changes,
&validate_state_cache.to_state_diff(),
CallInfo::summarize_many(
validate_call_info.iter(),
&tx_context.block_context.versioned_constants,
Expand All @@ -606,10 +606,13 @@ impl AccountTransaction {
let tx_receipt = TransactionReceipt::from_account_tx(
self,
&tx_context,
&StateChanges::merge(vec![
validate_state_changes,
execution_state.get_actual_state_changes()?,
]),
&StateCache::squash_state_diff(
vec![
&validate_state_cache,
&execution_state.borrow_updated_state_cache()?.clone(),
],
tx_context.block_context.versioned_constants.comprehensive_state_diff,
),
CallInfo::summarize_many(
validate_call_info.iter().chain(execute_call_info.iter()),
&tx_context.block_context.versioned_constants,
Expand Down

0 comments on commit 61340f7

Please sign in to comment.