From 0fec0270a2e59090e99c98b4d6d4ae426117cf0a Mon Sep 17 00:00:00 2001 From: Kariy Date: Tue, 28 Nov 2023 11:10:42 +0800 Subject: [PATCH] refactor(katana-provider): add contract class provider --- crates/katana/storage/provider/src/lib.rs | 27 ++-- .../provider/src/providers/fork/backend.rs | 139 +++++++++--------- .../provider/src/providers/fork/mod.rs | 7 +- .../provider/src/providers/fork/state.rs | 95 +++++++----- .../provider/src/providers/in_memory/cache.rs | 4 +- .../provider/src/providers/in_memory/mod.rs | 16 +- .../provider/src/providers/in_memory/state.rs | 77 ++++++---- .../storage/provider/src/traits/contract.rs | 41 +++++- .../storage/provider/src/traits/state.rs | 47 +++--- 9 files changed, 277 insertions(+), 176 deletions(-) diff --git a/crates/katana/storage/provider/src/lib.rs b/crates/katana/storage/provider/src/lib.rs index 476d1e02ac..36aad5bebc 100644 --- a/crates/katana/storage/provider/src/lib.rs +++ b/crates/katana/storage/provider/src/lib.rs @@ -9,13 +9,14 @@ use katana_primitives::contract::{ SierraClass, StorageKey, StorageValue, }; use katana_primitives::transaction::{Receipt, Tx, TxHash, TxNumber}; +use traits::contract::ContractClassProvider; pub mod providers; pub mod traits; use crate::traits::block::{BlockHashProvider, BlockNumberProvider, BlockProvider, HeaderProvider}; -use crate::traits::contract::ContractProvider; -use crate::traits::state::{StateFactoryProvider, StateProvider, StateProviderExt}; +use crate::traits::contract::ContractInfoProvider; +use crate::traits::state::{StateFactoryProvider, StateProvider}; use crate::traits::state_update::StateUpdateProvider; use crate::traits::transaction::{ReceiptProvider, TransactionProvider, TransactionsProviderExt}; @@ -128,10 +129,6 @@ impl StateProvider for BlockchainProvider where Db: StateProvider, { - fn class(&self, hash: ClassHash) -> Result> { - self.provider.class(hash) - } - fn class_hash_of_contract(&self, address: ContractAddress) -> Result> { self.provider.class_hash_of_contract(address) } @@ -150,19 +147,23 @@ where ) -> Result> { self.provider.storage(address, storage_key) } +} +impl ContractClassProvider for BlockchainProvider +where + Db: ContractClassProvider, +{ fn compiled_class_hash_of_class_hash( &self, hash: ClassHash, ) -> Result> { self.provider.compiled_class_hash_of_class_hash(hash) } -} -impl StateProviderExt for BlockchainProvider -where - Db: StateProviderExt, -{ + fn class(&self, hash: ClassHash) -> Result> { + self.provider.class(hash) + } + fn sierra_class(&self, hash: ClassHash) -> Result> { self.provider.sierra_class(hash) } @@ -190,9 +191,9 @@ where } } -impl ContractProvider for BlockchainProvider +impl ContractInfoProvider for BlockchainProvider where - Db: ContractProvider, + Db: ContractInfoProvider, { fn contract(&self, address: ContractAddress) -> Result> { self.provider.contract(address) diff --git a/crates/katana/storage/provider/src/providers/fork/backend.rs b/crates/katana/storage/provider/src/providers/fork/backend.rs index 68e1775064..ec27ff7fe2 100644 --- a/crates/katana/storage/provider/src/providers/fork/backend.rs +++ b/crates/katana/storage/provider/src/providers/fork/backend.rs @@ -12,8 +12,8 @@ use futures::stream::Stream; use futures::{Future, FutureExt}; use katana_primitives::block::BlockHashOrNumber; use katana_primitives::contract::{ - ClassHash, CompiledClassHash, CompiledContractClass, ContractAddress, Nonce, SierraClass, - StorageKey, StorageValue, + ClassHash, CompiledClassHash, CompiledContractClass, ContractAddress, GenericContractInfo, + Nonce, SierraClass, StorageKey, StorageValue, }; use katana_primitives::conversion::rpc::{ compiled_class_hash_from_flattened_sierra_class, legacy_rpc_to_inner_class, rpc_to_inner_class, @@ -26,7 +26,8 @@ use starknet::providers::{JsonRpcClient, Provider, ProviderError}; use tracing::trace; use crate::providers::in_memory::cache::CacheStateDb; -use crate::traits::state::{StateProvider, StateProviderExt}; +use crate::traits::contract::{ContractClassProvider, ContractInfoProvider}; +use crate::traits::state::StateProvider; type GetNonceResult = Result; type GetStorageResult = Result; @@ -316,56 +317,26 @@ impl SharedStateProvider { } } -impl StateProvider for SharedStateProvider { - fn nonce(&self, address: ContractAddress) -> Result> { - if let Some(nonce) = self.0.contract_state.read().get(&address).map(|c| c.nonce) { - return Ok(Some(nonce)); +impl ContractInfoProvider for SharedStateProvider { + fn contract(&self, address: ContractAddress) -> Result> { + if let Some(info) = self.0.contract_state.read().get(&address).cloned() { + return Ok(Some(info)); } let nonce = self.0.do_get_nonce(address).unwrap(); - self.0.contract_state.write().entry(address).or_default().nonce = nonce; - - Ok(Some(nonce)) - } - - fn class(&self, hash: ClassHash) -> Result> { - if let Some(class) = - self.0.shared_contract_classes.compiled_classes.read().get(&hash).cloned() - { - return Ok(Some(class)); - } - - let class = self.0.do_get_class_at(hash).unwrap(); - let (class_hash, compiled_class_hash, casm, sierra) = match class { - ContractClass::Legacy(class) => { - let (_, compiled_class) = legacy_rpc_to_inner_class(&class)?; - (hash, hash, compiled_class, None) - } - ContractClass::Sierra(sierra_class) => { - let (_, compiled_class_hash, compiled_class) = rpc_to_inner_class(&sierra_class)?; - (hash, compiled_class_hash, compiled_class, Some(sierra_class)) - } - }; - - self.0.compiled_class_hashes.write().insert(class_hash, compiled_class_hash); + let class_hash = self.0.do_get_class_hash_at(address).unwrap(); + let info = GenericContractInfo { nonce, class_hash }; - self.0 - .shared_contract_classes - .compiled_classes - .write() - .entry(class_hash) - .or_insert(casm.clone()); + self.0.contract_state.write().insert(address, info.clone()); - if let Some(sierra) = sierra { - self.0 - .shared_contract_classes - .sierra_classes - .write() - .entry(class_hash) - .or_insert(sierra); - } + Ok(Some(info)) + } +} - Ok(Some(casm)) +impl StateProvider for SharedStateProvider { + fn nonce(&self, address: ContractAddress) -> Result> { + let nonce = ContractInfoProvider::contract(&self, address)?.map(|i| i.nonce); + Ok(nonce) } fn storage( @@ -384,14 +355,30 @@ impl StateProvider for SharedStateProvider { } fn class_hash_of_contract(&self, address: ContractAddress) -> Result> { - if let Some(hash) = self.0.contract_state.read().get(&address).map(|c| c.class_hash) { - return Ok(Some(hash)); + let hash = ContractInfoProvider::contract(&self, address)?.map(|i| i.class_hash); + Ok(hash) + } +} + +impl ContractClassProvider for SharedStateProvider { + fn sierra_class(&self, hash: ClassHash) -> Result> { + if let class @ Some(_) = self.0.shared_contract_classes.sierra_classes.read().get(&hash) { + return Ok(class.cloned()); } - let class_hash = self.0.do_get_class_hash_at(address).unwrap(); - self.0.contract_state.write().entry(address).or_default().class_hash = class_hash; + let class = self.0.do_get_class_at(hash).unwrap(); + match class { + starknet::core::types::ContractClass::Legacy(_) => Ok(None), + starknet::core::types::ContractClass::Sierra(sierra_class) => { + self.0 + .shared_contract_classes + .sierra_classes + .write() + .insert(hash, sierra_class.clone()); - Ok(Some(class_hash)) + Ok(Some(sierra_class)) + } + } } fn compiled_class_hash_of_class_hash( @@ -407,27 +394,45 @@ impl StateProvider for SharedStateProvider { Ok(Some(hash)) } -} -impl StateProviderExt for SharedStateProvider { - fn sierra_class(&self, hash: ClassHash) -> Result> { - if let class @ Some(_) = self.0.shared_contract_classes.sierra_classes.read().get(&hash) { - return Ok(class.cloned()); + fn class(&self, hash: ClassHash) -> Result> { + if let Some(class) = + self.0.shared_contract_classes.compiled_classes.read().get(&hash).cloned() + { + return Ok(Some(class)); } let class = self.0.do_get_class_at(hash).unwrap(); - match class { - starknet::core::types::ContractClass::Legacy(_) => Ok(None), - starknet::core::types::ContractClass::Sierra(sierra_class) => { - self.0 - .shared_contract_classes - .sierra_classes - .write() - .insert(hash, sierra_class.clone()); - - Ok(Some(sierra_class)) + let (class_hash, compiled_class_hash, casm, sierra) = match class { + ContractClass::Legacy(class) => { + let (_, compiled_class) = legacy_rpc_to_inner_class(&class)?; + (hash, hash, compiled_class, None) + } + ContractClass::Sierra(sierra_class) => { + let (_, compiled_class_hash, compiled_class) = rpc_to_inner_class(&sierra_class)?; + (hash, compiled_class_hash, compiled_class, Some(sierra_class)) } + }; + + self.0.compiled_class_hashes.write().insert(class_hash, compiled_class_hash); + + self.0 + .shared_contract_classes + .compiled_classes + .write() + .entry(class_hash) + .or_insert(casm.clone()); + + if let Some(sierra) = sierra { + self.0 + .shared_contract_classes + .sierra_classes + .write() + .entry(class_hash) + .or_insert(sierra); } + + Ok(Some(casm)) } } diff --git a/crates/katana/storage/provider/src/providers/fork/mod.rs b/crates/katana/storage/provider/src/providers/fork/mod.rs index 70c8f116e9..838dfed6a7 100644 --- a/crates/katana/storage/provider/src/providers/fork/mod.rs +++ b/crates/katana/storage/provider/src/providers/fork/mod.rs @@ -20,7 +20,7 @@ use self::state::ForkedStateDb; use super::in_memory::cache::{CacheDb, CacheStateDb}; use super::in_memory::state::HistoricalStates; use crate::traits::block::{BlockHashProvider, BlockNumberProvider, BlockProvider, HeaderProvider}; -use crate::traits::contract::ContractProvider; +use crate::traits::contract::ContractInfoProvider; use crate::traits::state::{StateFactoryProvider, StateProvider}; use crate::traits::state_update::StateUpdateProvider; use crate::traits::transaction::{ReceiptProvider, TransactionProvider, TransactionsProviderExt}; @@ -100,8 +100,9 @@ impl BlockProvider for ForkedProvider { }; let body = self.transactions_by_block(id)?.unwrap_or_default(); + let status = self.storage.block_statusses.get(&header.number).cloned().expect("must have"); - Ok(Some(Block { header, body })) + Ok(Some(Block { header, body, status })) } fn blocks_in_range(&self, range: RangeInclusive) -> Result> { @@ -209,7 +210,7 @@ impl ReceiptProvider for ForkedProvider { } } -impl ContractProvider for ForkedProvider { +impl ContractInfoProvider for ForkedProvider { fn contract(&self, address: ContractAddress) -> Result> { let contract = self.state.contract_state.read().get(&address).cloned(); Ok(contract) diff --git a/crates/katana/storage/provider/src/providers/fork/state.rs b/crates/katana/storage/provider/src/providers/fork/state.rs index e6aa8faf44..ced68bee25 100644 --- a/crates/katana/storage/provider/src/providers/fork/state.rs +++ b/crates/katana/storage/provider/src/providers/fork/state.rs @@ -2,14 +2,15 @@ use std::sync::Arc; use anyhow::Result; use katana_primitives::contract::{ - ClassHash, CompiledClassHash, CompiledContractClass, ContractAddress, Nonce, SierraClass, - StorageKey, StorageValue, + ClassHash, CompiledClassHash, CompiledContractClass, ContractAddress, GenericContractInfo, + Nonce, SierraClass, StorageKey, StorageValue, }; use super::backend::SharedStateProvider; use crate::providers::in_memory::cache::CacheStateDb; use crate::providers::in_memory::state::StateSnapshot; -use crate::traits::state::{StateProvider, StateProviderExt}; +use crate::traits::contract::{ContractClassProvider, ContractInfoProvider}; +use crate::traits::state::StateProvider; pub type ForkedStateDb = CacheStateDb; pub type ForkedSnapshot = StateSnapshot; @@ -23,14 +24,16 @@ impl ForkedStateDb { } } -impl StateProvider for ForkedStateDb { - fn class(&self, hash: ClassHash) -> Result> { - if let class @ Some(_) = self.shared_contract_classes.compiled_classes.read().get(&hash) { - return Ok(class.cloned()); +impl ContractInfoProvider for ForkedStateDb { + fn contract(&self, address: ContractAddress) -> Result> { + if let info @ Some(_) = self.contract_state.read().get(&address).cloned() { + return Ok(info); } - StateProvider::class(&self.db, hash) + ContractInfoProvider::contract(&self.db, address) } +} +impl StateProvider for ForkedStateDb { fn class_hash_of_contract(&self, address: ContractAddress) -> Result> { if let hash @ Some(_) = self.contract_state.read().get(&address).map(|i| i.class_hash) { return Ok(hash); @@ -55,6 +58,15 @@ impl StateProvider for ForkedStateDb { } StateProvider::storage(&self.db, address, storage_key) } +} + +impl ContractClassProvider for CacheStateDb { + fn sierra_class(&self, hash: ClassHash) -> Result> { + if let class @ Some(_) = self.shared_contract_classes.sierra_classes.read().get(&hash) { + return Ok(class.cloned()); + } + ContractClassProvider::sierra_class(&self.db, hash) + } fn compiled_class_hash_of_class_hash( &self, @@ -63,21 +75,25 @@ impl StateProvider for ForkedStateDb { if let hash @ Some(_) = self.compiled_class_hashes.read().get(&hash) { return Ok(hash.cloned()); } - StateProvider::compiled_class_hash_of_class_hash(&self.db, hash) + ContractClassProvider::compiled_class_hash_of_class_hash(&self.db, hash) } -} -impl StateProviderExt for CacheStateDb { - fn sierra_class(&self, hash: ClassHash) -> Result> { - if let class @ Some(_) = self.shared_contract_classes.sierra_classes.read().get(&hash) { + fn class(&self, hash: ClassHash) -> Result> { + if let class @ Some(_) = self.shared_contract_classes.compiled_classes.read().get(&hash) { return Ok(class.cloned()); } - StateProviderExt::sierra_class(&self.db, hash) + ContractClassProvider::class(&self.db, hash) } } pub(super) struct LatestStateProvider(pub(super) Arc); +impl ContractInfoProvider for LatestStateProvider { + fn contract(&self, address: ContractAddress) -> Result> { + ContractInfoProvider::contract(&self.0, address) + } +} + impl StateProvider for LatestStateProvider { fn nonce(&self, address: ContractAddress) -> Result> { StateProvider::nonce(&self.0, address) @@ -91,25 +107,34 @@ impl StateProvider for LatestStateProvider { StateProvider::storage(&self.0, address, storage_key) } - fn class(&self, hash: ClassHash) -> Result> { - StateProvider::class(&self.0, hash) - } - fn class_hash_of_contract(&self, address: ContractAddress) -> Result> { StateProvider::class_hash_of_contract(&self.0, address) } +} + +impl ContractClassProvider for LatestStateProvider { + fn sierra_class(&self, hash: ClassHash) -> Result> { + ContractClassProvider::sierra_class(&self.0, hash) + } + + fn class(&self, hash: ClassHash) -> Result> { + ContractClassProvider::class(&self.0, hash) + } fn compiled_class_hash_of_class_hash( &self, hash: ClassHash, ) -> Result> { - StateProvider::compiled_class_hash_of_class_hash(&self.0, hash) + ContractClassProvider::compiled_class_hash_of_class_hash(&self.0, hash) } } -impl StateProviderExt for LatestStateProvider { - fn sierra_class(&self, hash: ClassHash) -> Result> { - StateProviderExt::sierra_class(&self.0, hash) +impl ContractInfoProvider for ForkedSnapshot { + fn contract(&self, address: ContractAddress) -> Result> { + if let info @ Some(_) = self.inner.contract_state.get(&address).cloned() { + return Ok(info); + } + ContractInfoProvider::contract(&self.inner.db, address) } } @@ -132,13 +157,6 @@ impl StateProvider for ForkedSnapshot { StateProvider::storage(&self.inner.db, address, storage_key) } - fn class(&self, hash: ClassHash) -> Result> { - if let class @ Some(_) = self.classes.compiled_classes.read().get(&hash).cloned() { - return Ok(class); - } - StateProvider::class(&self.inner.db, hash) - } - fn class_hash_of_contract(&self, address: ContractAddress) -> Result> { if let class_hash @ Some(_) = self.inner.contract_state.get(&address).map(|info| info.class_hash) @@ -147,6 +165,15 @@ impl StateProvider for ForkedSnapshot { } StateProvider::class_hash_of_contract(&self.inner.db, address) } +} + +impl ContractClassProvider for ForkedSnapshot { + fn sierra_class(&self, hash: ClassHash) -> Result> { + if let class @ Some(_) = self.classes.sierra_classes.read().get(&hash).cloned() { + return Ok(class); + } + ContractClassProvider::sierra_class(&self.inner.db, hash) + } fn compiled_class_hash_of_class_hash( &self, @@ -155,15 +182,13 @@ impl StateProvider for ForkedSnapshot { if let hash @ Some(_) = self.inner.compiled_class_hashes.get(&hash).cloned() { return Ok(hash); } - StateProvider::compiled_class_hash_of_class_hash(&self.inner.db, hash) + ContractClassProvider::compiled_class_hash_of_class_hash(&self.inner.db, hash) } -} -impl StateProviderExt for ForkedSnapshot { - fn sierra_class(&self, hash: ClassHash) -> Result> { - if let class @ Some(_) = self.classes.sierra_classes.read().get(&hash).cloned() { + fn class(&self, hash: ClassHash) -> Result> { + if let class @ Some(_) = self.classes.compiled_classes.read().get(&hash).cloned() { return Ok(class); } - StateProviderExt::sierra_class(&self.inner.db, hash) + ContractClassProvider::class(&self.inner.db, hash) } } diff --git a/crates/katana/storage/provider/src/providers/in_memory/cache.rs b/crates/katana/storage/provider/src/providers/in_memory/cache.rs index dccfe740e9..4889c0e64a 100644 --- a/crates/katana/storage/provider/src/providers/in_memory/cache.rs +++ b/crates/katana/storage/provider/src/providers/in_memory/cache.rs @@ -2,7 +2,7 @@ use std::collections::HashMap; use std::sync::Arc; use katana_db::models::block::StoredBlockBodyIndices; -use katana_primitives::block::{BlockHash, BlockNumber, Header, StateUpdate}; +use katana_primitives::block::{BlockHash, BlockNumber, BlockStatus, Header, StateUpdate}; use katana_primitives::contract::{ ClassHash, CompiledClassHash, CompiledContractClass, ContractAddress, GenericContractInfo, SierraClass, StorageKey, StorageValue, @@ -43,6 +43,7 @@ pub struct CacheDb { pub(crate) block_headers: HashMap, pub(crate) block_hashes: HashMap, pub(crate) block_numbers: HashMap, + pub(crate) block_statusses: HashMap, pub(crate) block_body_indices: HashMap, pub(crate) latest_block_hash: BlockHash, pub(crate) latest_block_number: BlockNumber, @@ -75,6 +76,7 @@ impl CacheDb { block_hashes: HashMap::new(), block_headers: HashMap::new(), block_numbers: HashMap::new(), + block_statusses: HashMap::new(), transaction_hashes: HashMap::new(), block_body_indices: HashMap::new(), transaction_numbers: HashMap::new(), diff --git a/crates/katana/storage/provider/src/providers/in_memory/mod.rs b/crates/katana/storage/provider/src/providers/in_memory/mod.rs index 92f4a49c7b..8c0791ddec 100644 --- a/crates/katana/storage/provider/src/providers/in_memory/mod.rs +++ b/crates/katana/storage/provider/src/providers/in_memory/mod.rs @@ -16,7 +16,7 @@ use parking_lot::RwLock; use self::cache::CacheDb; use self::state::{HistoricalStates, InMemoryStateDb, LatestStateProvider}; use crate::traits::block::{BlockHashProvider, BlockNumberProvider, BlockProvider, HeaderProvider}; -use crate::traits::contract::ContractProvider; +use crate::traits::contract::ContractInfoProvider; use crate::traits::state::{StateFactoryProvider, StateProvider}; use crate::traits::state_update::StateUpdateProvider; use crate::traits::transaction::{ReceiptProvider, TransactionProvider, TransactionsProviderExt}; @@ -27,6 +27,15 @@ pub struct InMemoryProvider { historical_states: RwLock, } +impl InMemoryProvider { + pub fn new() -> Self { + let storage = CacheDb::new(()); + let state = Arc::new(InMemoryStateDb::new(())); + let historical_states = RwLock::new(HistoricalStates::default()); + Self { storage, state, historical_states } + } +} + impl BlockHashProvider for InMemoryProvider { fn latest_hash(&self) -> Result { Ok(self.storage.latest_block_hash) @@ -82,8 +91,9 @@ impl BlockProvider for InMemoryProvider { }; let body = self.transactions_by_block(id)?.unwrap_or_default(); + let status = self.storage.block_statusses.get(&header.number).cloned().expect("must have"); - Ok(Some(Block { header, body })) + Ok(Some(Block { header, body, status })) } fn blocks_in_range(&self, range: RangeInclusive) -> Result> { @@ -191,7 +201,7 @@ impl ReceiptProvider for InMemoryProvider { } } -impl ContractProvider for InMemoryProvider { +impl ContractInfoProvider for InMemoryProvider { fn contract(&self, address: ContractAddress) -> Result> { let contract = self.state.contract_state.read().get(&address).cloned(); Ok(contract) diff --git a/crates/katana/storage/provider/src/providers/in_memory/state.rs b/crates/katana/storage/provider/src/providers/in_memory/state.rs index 42324d9cbd..1dc3e1b195 100644 --- a/crates/katana/storage/provider/src/providers/in_memory/state.rs +++ b/crates/katana/storage/provider/src/providers/in_memory/state.rs @@ -4,12 +4,13 @@ use std::sync::Arc; use anyhow::Result; use katana_primitives::block::BlockNumber; use katana_primitives::contract::{ - ClassHash, CompiledClassHash, CompiledContractClass, ContractAddress, Nonce, SierraClass, - StorageKey, StorageValue, + ClassHash, CompiledClassHash, CompiledContractClass, ContractAddress, GenericContractInfo, + Nonce, SierraClass, StorageKey, StorageValue, }; use super::cache::{CacheSnapshotWithoutClasses, CacheStateDb, SharedContractClasses}; -use crate::traits::state::{StateProvider, StateProviderExt}; +use crate::traits::contract::{ContractClassProvider, ContractInfoProvider}; +use crate::traits::state::StateProvider; pub struct StateSnapshot { pub(crate) classes: Arc, @@ -24,7 +25,7 @@ const MIN_HISTORY_LIMIT: usize = 10; /// It should store at N - 1 states, where N is the latest block number. pub struct HistoricalStates { /// The states at a certain block based on the block number - states: HashMap>, + states: HashMap>, /// How many states to store at most in_memory_limit: usize, /// minimum amount of states we keep in memory @@ -44,7 +45,7 @@ impl HistoricalStates { } /// Returns the state for the given `block_hash` if present - pub fn get(&self, block_num: &BlockNumber) -> Option<&Arc> { + pub fn get(&self, block_num: &BlockNumber) -> Option<&Arc> { self.states.get(block_num) } @@ -56,7 +57,7 @@ impl HistoricalStates { /// Since we keep a snapshot of the entire state as history, the size of the state will increase /// with the transactions processed. To counter this, we gradually decrease the cache limit with /// the number of states/blocks until we reached the `min_limit`. - pub fn insert(&mut self, block_num: BlockNumber, state: Box) { + pub fn insert(&mut self, block_num: BlockNumber, state: Box) { if self.present.len() >= self.in_memory_limit { // once we hit the max limit we gradually decrease it self.in_memory_limit = @@ -114,9 +115,16 @@ impl InMemoryStateDb { } } +impl ContractInfoProvider for InMemorySnapshot { + fn contract(&self, address: ContractAddress) -> Result> { + let info = self.inner.contract_state.get(&address).cloned(); + Ok(info) + } +} + impl StateProvider for InMemorySnapshot { fn nonce(&self, address: ContractAddress) -> Result> { - let nonce = self.inner.contract_state.get(&address).map(|info| info.nonce); + let nonce = ContractInfoProvider::contract(&self, address)?.map(|i| i.nonce); Ok(nonce) } @@ -129,14 +137,21 @@ impl StateProvider for InMemorySnapshot { Ok(value) } - fn class(&self, hash: ClassHash) -> Result> { - let class = self.classes.compiled_classes.read().get(&hash).cloned(); + fn class_hash_of_contract(&self, address: ContractAddress) -> Result> { + let class_hash = ContractInfoProvider::contract(&self, address)?.map(|i| i.class_hash); + Ok(class_hash) + } +} + +impl ContractClassProvider for InMemorySnapshot { + fn sierra_class(&self, hash: ClassHash) -> Result> { + let class = self.classes.sierra_classes.read().get(&hash).cloned(); Ok(class) } - fn class_hash_of_contract(&self, address: ContractAddress) -> Result> { - let class_hash = self.inner.contract_state.get(&address).map(|info| info.class_hash); - Ok(class_hash) + fn class(&self, hash: ClassHash) -> Result> { + let class = self.classes.compiled_classes.read().get(&hash).cloned(); + Ok(class) } fn compiled_class_hash_of_class_hash( @@ -148,18 +163,18 @@ impl StateProvider for InMemorySnapshot { } } -impl StateProviderExt for InMemorySnapshot { - fn sierra_class(&self, hash: ClassHash) -> Result> { - let class = self.classes.sierra_classes.read().get(&hash).cloned(); - Ok(class) +pub(super) struct LatestStateProvider(pub(super) Arc); + +impl ContractInfoProvider for LatestStateProvider { + fn contract(&self, address: ContractAddress) -> Result> { + let info = self.0.contract_state.read().get(&address).cloned(); + Ok(info) } } -pub(super) struct LatestStateProvider(pub(super) Arc); - impl StateProvider for LatestStateProvider { fn nonce(&self, address: ContractAddress) -> Result> { - let nonce = self.0.contract_state.read().get(&address).map(|info| info.nonce); + let nonce = ContractInfoProvider::contract(&self, address)?.map(|i| i.nonce); Ok(nonce) } @@ -172,14 +187,21 @@ impl StateProvider for LatestStateProvider { Ok(value) } - fn class(&self, hash: ClassHash) -> Result> { - let class = self.0.shared_contract_classes.compiled_classes.read().get(&hash).cloned(); + fn class_hash_of_contract(&self, address: ContractAddress) -> Result> { + let class_hash = ContractInfoProvider::contract(&self, address)?.map(|i| i.class_hash); + Ok(class_hash) + } +} + +impl ContractClassProvider for LatestStateProvider { + fn sierra_class(&self, hash: ClassHash) -> Result> { + let class = self.0.shared_contract_classes.sierra_classes.read().get(&hash).cloned(); Ok(class) } - fn class_hash_of_contract(&self, address: ContractAddress) -> Result> { - let class_hash = self.0.contract_state.read().get(&address).map(|info| info.class_hash); - Ok(class_hash) + fn class(&self, hash: ClassHash) -> Result> { + let class = self.0.shared_contract_classes.compiled_classes.read().get(&hash).cloned(); + Ok(class) } fn compiled_class_hash_of_class_hash( @@ -191,13 +213,6 @@ impl StateProvider for LatestStateProvider { } } -impl StateProviderExt for LatestStateProvider { - fn sierra_class(&self, hash: ClassHash) -> Result> { - let class = self.0.shared_contract_classes.sierra_classes.read().get(&hash).cloned(); - Ok(class) - } -} - #[cfg(test)] mod tests { use katana_primitives::block::BlockHashOrNumber; diff --git a/crates/katana/storage/provider/src/traits/contract.rs b/crates/katana/storage/provider/src/traits/contract.rs index 6fad0bf64d..c3f5694936 100644 --- a/crates/katana/storage/provider/src/traits/contract.rs +++ b/crates/katana/storage/provider/src/traits/contract.rs @@ -1,7 +1,44 @@ use anyhow::Result; -use katana_primitives::contract::{ContractAddress, GenericContractInfo}; +use katana_primitives::contract::{ + ClassHash, CompiledClassHash, CompiledContractClass, ContractAddress, GenericContractInfo, + SierraClass, +}; -pub trait ContractProvider: Send + Sync { +#[auto_impl::auto_impl(&, Box, Arc)] +pub trait ContractInfoProvider: Send + Sync { /// Returns the contract information given its address. fn contract(&self, address: ContractAddress) -> Result>; } + +/// A provider trait for retrieving contract class related information. +#[auto_impl::auto_impl(&, Box, Arc)] +pub trait ContractClassProvider: Send + Sync { + /// Returns the compiled class hash for the given class hash. + fn compiled_class_hash_of_class_hash( + &self, + hash: ClassHash, + ) -> Result>; + + /// Returns the compiled class definition of a contract class given its class hash. + fn class(&self, hash: ClassHash) -> Result>; + + /// Retrieves the Sierra class definition of a contract class given its class hash. + fn sierra_class(&self, hash: ClassHash) -> Result>; +} + +// TEMP: added mainly for compatibility reason following the path of least resistance. +#[auto_impl::auto_impl(&, Box, Arc)] +pub trait ContractClassWriter: ContractClassProvider + Send + Sync { + /// Returns the compiled class hash for the given class hash. + fn set_compiled_class_hash_of_class_hash( + &self, + hash: ClassHash, + compiled_hash: CompiledClassHash, + ) -> Result<()>; + + /// Returns the compiled class definition of a contract class given its class hash. + fn set_class(&self, hash: ClassHash, class: CompiledContractClass) -> Result<()>; + + /// Retrieves the Sierra class definition of a contract class given its class hash. + fn set_sierra_class(&self, hash: ClassHash, sierra: SierraClass) -> Result<()>; +} diff --git a/crates/katana/storage/provider/src/traits/state.rs b/crates/katana/storage/provider/src/traits/state.rs index 27f62a55da..d6e726a3a1 100644 --- a/crates/katana/storage/provider/src/traits/state.rs +++ b/crates/katana/storage/provider/src/traits/state.rs @@ -1,15 +1,11 @@ use anyhow::Result; use katana_primitives::block::BlockHashOrNumber; -use katana_primitives::contract::{ - ClassHash, CompiledClassHash, CompiledContractClass, ContractAddress, Nonce, SierraClass, - StorageKey, StorageValue, -}; +use katana_primitives::contract::{ClassHash, ContractAddress, Nonce, StorageKey, StorageValue}; -#[auto_impl::auto_impl(&, Box, Arc)] -pub trait StateProvider: Send + Sync { - /// Returns the compiled class definition of a contract class given its class hash. - fn class(&self, hash: ClassHash) -> Result>; +use super::contract::{ContractClassProvider, ContractClassWriter, ContractInfoProvider}; +#[auto_impl::auto_impl(&, Box, Arc)] +pub trait StateProvider: ContractInfoProvider + ContractClassProvider + Send + Sync { /// Returns the nonce of a contract. fn nonce(&self, address: ContractAddress) -> Result>; @@ -22,19 +18,6 @@ pub trait StateProvider: Send + Sync { /// Returns the class hash of a contract. fn class_hash_of_contract(&self, address: ContractAddress) -> Result>; - - /// Returns the compiled class hash for the given class hash. - fn compiled_class_hash_of_class_hash( - &self, - hash: ClassHash, - ) -> Result>; -} - -/// An extension of the `StateProvider` trait which provides additional methods. -#[auto_impl::auto_impl(&, Box, Arc)] -pub trait StateProviderExt: StateProvider + Send + Sync { - /// Retrieves the Sierra class definition of a contract class given its class hash. - fn sierra_class(&self, hash: ClassHash) -> Result>; } /// A state factory provider is a provider which can create state providers for @@ -47,3 +30,25 @@ pub trait StateFactoryProvider { /// Returns a state provider for retrieving historical state at the given block. fn historical(&self, block_id: BlockHashOrNumber) -> Result>>; } + +// TEMP: added mainly for compatibility reason following the path of least resistance. +#[auto_impl::auto_impl(&, Box, Arc)] +pub trait StateWriter: StateProvider + ContractClassWriter + Send + Sync { + /// Sets the nonce of a contract. + fn set_nonce(&self, address: ContractAddress, nonce: Nonce) -> Result<()>; + + /// Sets the value of a contract storage. + fn set_storage( + &self, + address: ContractAddress, + storage_key: StorageKey, + storage_value: StorageValue, + ) -> Result<()>; + + /// Sets the class hash of a contract. + fn set_class_hash_of_contract( + &self, + address: ContractAddress, + class_hash: ClassHash, + ) -> Result<()>; +}