diff --git a/crates/blockifier/src/state/stateful_compression.rs b/crates/blockifier/src/state/stateful_compression.rs index e810b62de1..59d9ab963c 100644 --- a/crates/blockifier/src/state/stateful_compression.rs +++ b/crates/blockifier/src/state/stateful_compression.rs @@ -146,24 +146,33 @@ pub fn compress( ) -> CompressionResult { let alias_compressor = AliasCompressor { state, alias_contract_address }; - let mut nonces = HashMap::new(); - for (contract_address, nonce) in state_diff.nonces.iter() { - nonces.insert(alias_compressor.compress_address(contract_address)?, *nonce); - } - let mut class_hashes = HashMap::new(); - for (contract_address, class_hash) in state_diff.class_hashes.iter() { - class_hashes.insert(alias_compressor.compress_address(contract_address)?, *class_hash); - } - let mut storage = HashMap::new(); - for ((contract_address, key), value) in state_diff.storage.iter() { - storage.insert( - ( - alias_compressor.compress_address(contract_address)?, - alias_compressor.compress_storage_key(key, contract_address)?, - ), - *value, - ); - } + let nonces = state_diff + .nonces + .iter() + .map(|(contract_address, nonce)| { + Ok((alias_compressor.compress_address(contract_address)?, *nonce)) + }) + .collect::>()?; + let class_hashes = state_diff + .class_hashes + .iter() + .map(|(contract_address, class_hash)| { + Ok((alias_compressor.compress_address(contract_address)?, *class_hash)) + }) + .collect::>()?; + let storage = state_diff + .storage + .iter() + .map(|((contract_address, key), value)| { + Ok(( + ( + alias_compressor.compress_address(contract_address)?, + alias_compressor.compress_storage_key(key, contract_address)?, + ), + *value, + )) + }) + .collect::>()?; Ok(StateMaps { nonces, class_hashes, storage, ..state_diff.clone() }) } @@ -174,7 +183,6 @@ struct AliasCompressor<'a, S: StateReader> { alias_contract_address: ContractAddress, } -#[allow(dead_code)] impl AliasCompressor<'_, S> { fn compress_address( &self, diff --git a/crates/blockifier/src/state/stateful_compression_test.rs b/crates/blockifier/src/state/stateful_compression_test.rs index 916a0e393c..72898f112b 100644 --- a/crates/blockifier/src/state/stateful_compression_test.rs +++ b/crates/blockifier/src/state/stateful_compression_test.rs @@ -1,19 +1,23 @@ -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use std::sync::LazyLock; use assert_matches::assert_matches; use rstest::rstest; use starknet_api::core::{ClassHash, CompiledClassHash, ContractAddress, Nonce, PatriciaKey}; +use starknet_api::felt; use starknet_api::state::StorageKey; use starknet_types_core::felt::Felt; use super::{ compress, state_diff_with_alias_allocation, + Alias, + AliasKey, AliasUpdater, ALIAS_COUNTER_STORAGE_KEY, INITIAL_AVAILABLE_ALIAS, MAX_NON_COMPRESSED_CONTRACT_ADDRESS, + MIN_VALUE_FOR_ALIAS_ALLOC, }; use crate::state::cached_state::{CachedState, StateMaps, StorageEntry}; use crate::state::state_api::{State, StateReader}; @@ -23,6 +27,90 @@ use crate::test_utils::dict_state_reader::DictStateReader; static ALIAS_CONTRACT_ADDRESS: LazyLock = LazyLock::new(|| ContractAddress(PatriciaKey::try_from(Felt::TWO).unwrap())); +/// Decompresses the state diff by replacing the aliases with addresses and storage keys. +fn decompress( + state_diff: &StateMaps, + state: &S, + alias_contract_address: ContractAddress, + alias_keys: HashSet, +) -> StateMaps { + let alias_decompressor = AliasDecompressorUtil::new(state, alias_contract_address, alias_keys); + + let mut nonces = HashMap::new(); + for (alias_contract_address, nonce) in state_diff.nonces.iter() { + nonces.insert(alias_decompressor.decompress_address(alias_contract_address), *nonce); + } + let mut class_hashes = HashMap::new(); + for (alias_contract_address, class_hash) in state_diff.class_hashes.iter() { + class_hashes + .insert(alias_decompressor.decompress_address(alias_contract_address), *class_hash); + } + let mut storage = HashMap::new(); + for ((alias_contract_address, alias_storage_key), value) in state_diff.storage.iter() { + let contract_address = alias_decompressor.decompress_address(alias_contract_address); + storage.insert( + ( + contract_address, + alias_decompressor.decompress_storage_key(alias_storage_key, &contract_address), + ), + *value, + ); + } + + StateMaps { nonces, class_hashes, storage, ..state_diff.clone() } +} + +/// Replaces aliases with the original contact addresses and storage keys. +struct AliasDecompressorUtil { + reversed_alias_mapping: HashMap, +} + +impl AliasDecompressorUtil { + fn new( + state: &S, + alias_contract_address: ContractAddress, + alias_keys: HashSet, + ) -> Self { + let mut reversed_alias_mapping = HashMap::new(); + for alias_key in alias_keys.into_iter() { + reversed_alias_mapping.insert( + state.get_storage_at(alias_contract_address, alias_key).unwrap(), + alias_key, + ); + } + Self { reversed_alias_mapping } + } + + fn decompress_address(&self, contract_address_alias: &ContractAddress) -> ContractAddress { + if contract_address_alias.0 >= MIN_VALUE_FOR_ALIAS_ALLOC { + ContractAddress::try_from( + *self.restore_alias_key(Felt::from(*contract_address_alias)).key(), + ) + .unwrap() + } else { + *contract_address_alias + } + } + + fn decompress_storage_key( + &self, + storage_key_alias: &StorageKey, + contact_address: &ContractAddress, + ) -> StorageKey { + if storage_key_alias.0 >= MIN_VALUE_FOR_ALIAS_ALLOC + && contact_address > &MAX_NON_COMPRESSED_CONTRACT_ADDRESS + { + self.restore_alias_key(*storage_key_alias.0) + } else { + *storage_key_alias + } + } + + fn restore_alias_key(&self, alias: Alias) -> AliasKey { + *self.reversed_alias_mapping.get(&alias).unwrap() + } +} + fn insert_to_alias_contract( storage: &mut HashMap, key: StorageKey, @@ -304,4 +392,9 @@ fn test_compression() { let compressed_state_diff = compress(&state_diff, &state_reader, *ALIAS_CONTRACT_ADDRESS).unwrap(); assert_eq!(compressed_state_diff, expected_compressed_state_diff); + + let alias_keys = state_reader.storage_view.keys().map(|(_, key)| *key).collect(); + let decompressed_state_diff = + decompress(&compressed_state_diff, &state_reader, *ALIAS_CONTRACT_ADDRESS, alias_keys); + assert_eq!(decompressed_state_diff, state_diff); }