diff --git a/crates/blockifier/src/state/cached_state.rs b/crates/blockifier/src/state/cached_state.rs index ac26d64bd3..dbb84243e7 100644 --- a/crates/blockifier/src/state/cached_state.rs +++ b/crates/blockifier/src/state/cached_state.rs @@ -679,14 +679,24 @@ impl StateChangesKeys { } } +/// Holds the set of allocated storage keys. +/// Ignores all but storage entry allocations - newly allocated contract addresses and +/// class hashes are paid for separately. #[cfg_attr(any(feature = "testing", test), derive(Clone))] #[derive(Debug, Default, Eq, PartialEq)] pub struct AllocatedKeys(HashSet); impl AllocatedKeys { + /// Extends the set of allocated keys with the allocated_keys of the given state changes. + /// Removes storage keys that are set back to zero. pub fn update(&mut self, state_change: &StateChanges) { self.0.extend(&state_change.allocated_keys.0); - // TODO: Remove keys that are set back to zero. + // Remove keys that are set back to zero. + state_change.state_maps.storage.iter().for_each(|(k, v)| { + if v == &Felt::ZERO { + self.0.remove(k); + } + }); } pub fn len(&self) -> usize { @@ -697,14 +707,19 @@ impl AllocatedKeys { self.0.is_empty() } - /// Collect entries that turn zero -> nonzero. + /// Collects entries that turn zero -> nonzero. pub fn from_storage_diff( - _updated_storage: &HashMap, - _base_storage: &HashMap, + updated_storage: &HashMap, + base_storage: &HashMap, ) -> Self { Self( - HashSet::new(), - // TODO: Calculate the difference between the updated_storage and the base_storage. + updated_storage + .iter() + .filter_map(|(k, v)| { + let base_value = base_storage.get(k).unwrap_or(&Felt::ZERO); + if *v != Felt::ZERO && *base_value == Felt::ZERO { Some(*k) } else { None } + }) + .collect(), ) } } diff --git a/crates/blockifier/src/state/cached_state_test.rs b/crates/blockifier/src/state/cached_state_test.rs index bdec817678..bcc91ab2f5 100644 --- a/crates/blockifier/src/state/cached_state_test.rs +++ b/crates/blockifier/src/state/cached_state_test.rs @@ -335,15 +335,16 @@ fn test_from_state_changes_for_fee_charge( let state_changes = create_state_changes_for_test(&mut state, sender_address, fee_token_address); let state_changes_count = state_changes.count_for_fee_charge(sender_address, fee_token_address); + let n_expected_storage_updates = 1 + usize::from(sender_address.is_some()); let expected_state_changes_count = StateChangesCountForFee { // 1 for storage update + 1 for sender balance update if sender is defined. state_changes_count: StateChangesCount { - n_storage_updates: 1 + usize::from(sender_address.is_some()), + n_storage_updates: n_expected_storage_updates, n_class_hash_updates: 1, n_compiled_class_hash_updates: 1, n_modified_contracts: 2, }, - n_allocated_keys: 0, + n_allocated_keys: n_expected_storage_updates, }; assert_eq!(state_changes_count, expected_state_changes_count); } @@ -415,6 +416,75 @@ fn test_state_changes_merge( ); } +// Test that `allocated_keys` collects zero -> nonzero updates, where we commit each update. +#[rstest] +#[case(false, vec![felt!("0x0")], false)] +#[case(true, vec![felt!("0x7")], true)] +#[case(false, vec![felt!("0x7")], false)] +#[case(true, vec![felt!("0x7"), felt!("0x0")], false)] +#[case(false, vec![felt!("0x0"), felt!("0x8")], true)] +#[case(false, vec![felt!("0x0"), felt!("0x8"), felt!("0x0")], false)] +fn test_allocated_keys_commit_and_merge( + #[case] is_base_empty: bool, + #[case] storage_updates: Vec, + #[case] charged: bool, +) { + let contract_address = contract_address!(CONTRACT_ADDRESS); + let storage_key = StorageKey::from(0x10_u16); + // Set initial state + let mut state: CachedState = CachedState::default(); + if !is_base_empty { + state.set_storage_at(contract_address, storage_key, felt!("0x1")).unwrap(); + } + let mut state_changes = vec![]; + + for value in storage_updates { + // In the end of the previous loop, state has moved into the transactional state. + let mut transactional_state = TransactionalState::create_transactional(&mut state); + // Update state and collect the state changes. + transactional_state.set_storage_at(contract_address, storage_key, value).unwrap(); + state_changes.push(transactional_state.get_actual_state_changes().unwrap()); + transactional_state.commit(); + } + + let merged_changes = StateChanges::merge(state_changes); + assert_ne!(merged_changes.allocated_keys.is_empty(), charged); +} + +// Test that allocations in validate and execute phases are properly squashed. +#[rstest] +#[case(false, felt!("0x7"), felt!("0x8"), false)] +#[case(true, felt!("0x0"), felt!("0x8"), true)] +#[case(true, felt!("0x7"), felt!("0x7"), true)] +// TODO: not charge in the following case. +#[case(false, felt!("0x0"), felt!("0x8"), true)] +#[case(true, felt!("0x7"), felt!("0x0"), false)] +fn test_allocated_keys_two_transactions( + #[case] is_base_empty: bool, + #[case] validate_value: Felt, + #[case] execute_value: Felt, + #[case] charged: bool, +) { + let contract_address = contract_address!(CONTRACT_ADDRESS); + let storage_key = StorageKey::from(0x10_u16); + // Set initial state + let mut state: CachedState = CachedState::default(); + if !is_base_empty { + state.set_storage_at(contract_address, storage_key, felt!("0x1")).unwrap(); + } + + let mut first_state = TransactionalState::create_transactional(&mut state); + first_state.set_storage_at(contract_address, storage_key, validate_value).unwrap(); + let first_state_changes = first_state.get_actual_state_changes().unwrap(); + + let mut second_state = TransactionalState::create_transactional(&mut first_state); + second_state.set_storage_at(contract_address, storage_key, execute_value).unwrap(); + let second_state_changes = second_state.get_actual_state_changes().unwrap(); + + let merged_changes = StateChanges::merge(vec![first_state_changes, second_state_changes]); + assert_ne!(merged_changes.allocated_keys.is_empty(), charged); +} + #[test] fn test_contract_cache_is_used() { // Initialize the global cache with a single class, and initialize an empty state with this diff --git a/crates/blockifier/src/transaction/account_transactions_test.rs b/crates/blockifier/src/transaction/account_transactions_test.rs index a73a74f19c..5870922b47 100644 --- a/crates/blockifier/src/transaction/account_transactions_test.rs +++ b/crates/blockifier/src/transaction/account_transactions_test.rs @@ -1396,7 +1396,9 @@ fn test_count_actual_storage_changes( n_modified_contracts: 2, ..Default::default() }, - n_allocated_keys: 0, + // Storage writing and sequencer fee update. The account balance storage change is not + // allocated in this transaction. + n_allocated_keys: 2, }; assert_eq!(expected_modified_contracts, state_changes_1.state_maps.get_modified_contracts()); @@ -1485,7 +1487,8 @@ fn test_count_actual_storage_changes( n_modified_contracts: 1, ..Default::default() }, - n_allocated_keys: 0, + // A storage allocation for the recipient. + n_allocated_keys: 1, }; assert_eq!( diff --git a/crates/blockifier/src/transaction/post_execution_test.rs b/crates/blockifier/src/transaction/post_execution_test.rs index 98956e0813..3ac147f0a3 100644 --- a/crates/blockifier/src/transaction/post_execution_test.rs +++ b/crates/blockifier/src/transaction/post_execution_test.rs @@ -289,7 +289,7 @@ fn test_revert_on_resource_overuse( // We need this kind of invocation, to be able to test the specific scenario: the resource // bounds must be enough to allow completion of the transaction, and yet must still fail // post-execution bounds check. - let execution_info_measure = run_invoke_tx( + let mut execution_info_measure = run_invoke_tx( &mut state, &block_context, invoke_tx_args! { @@ -337,6 +337,17 @@ fn test_revert_on_resource_overuse( .unwrap(); assert_eq!(execution_info_tight.revert_error, None); assert_eq!(execution_info_tight.receipt.fee, actual_fee); + // The only difference between the two executions should be the number of allocated keys, as the + // second execution writes to the same keys as the first. + let n_allocated_keys = &mut execution_info_measure + .receipt + .resources + .starknet_resources + .state + .state_changes_for_fee + .n_allocated_keys; + assert_eq!(n_allocated_keys, &usize::from(n_writes)); + *n_allocated_keys = 0; assert_eq!(execution_info_tight.receipt.resources, execution_info_measure.receipt.resources); // Re-run the same function with max bounds slightly below the actual usage, and verify it's