diff --git a/crates/katana/storage/db/src/abstraction/mod.rs b/crates/katana/storage/db/src/abstraction/mod.rs index f08f0896ad..8f0f353e13 100644 --- a/crates/katana/storage/db/src/abstraction/mod.rs +++ b/crates/katana/storage/db/src/abstraction/mod.rs @@ -1,5 +1,50 @@ mod cursor; mod transaction; +use std::fmt::Debug; + pub use cursor::*; pub use transaction::*; + +use crate::error::DatabaseError; + +/// Main persistent database trait. The database implementation must be transactional. +pub trait Database: Send + Sync { + /// Read-Only transaction + type Tx: DbTx + Send + Sync + Debug + 'static; + /// Read-Write transaction + type TxMut: DbTxMut + Send + Sync + Debug + 'static; + + /// Create and begin read-only transaction. + #[track_caller] + fn tx(&self) -> Result; + + /// Create and begin read-write transaction, should return error if the database is unable to + /// create the transaction e.g, not opened with read-write permission. + #[track_caller] + fn tx_mut(&self) -> Result; + + /// Takes a function and passes a read-only transaction into it, making sure it's closed in the + /// end of the execution. + fn view(&self, f: F) -> Result + where + F: FnOnce(&Self::Tx) -> T, + { + let tx = self.tx()?; + let res = f(&tx); + tx.commit()?; + Ok(res) + } + + /// Takes a function and passes a write-read transaction into it, making sure it's committed in + /// the end of the execution. + fn update(&self, f: F) -> Result + where + F: FnOnce(&Self::TxMut) -> T, + { + let tx = self.tx_mut()?; + let res = f(&tx); + tx.commit()?; + Ok(res) + } +} diff --git a/crates/katana/storage/db/src/mdbx/mod.rs b/crates/katana/storage/db/src/mdbx/mod.rs index 019aa3a5bc..6770f2ecf7 100644 --- a/crates/katana/storage/db/src/mdbx/mod.rs +++ b/crates/katana/storage/db/src/mdbx/mod.rs @@ -11,7 +11,7 @@ pub use libmdbx; use libmdbx::{DatabaseFlags, EnvironmentFlags, Geometry, Mode, PageSize, SyncMode, RO, RW}; use self::tx::Tx; -use crate::abstraction::DbTx; +use crate::abstraction::{Database, DbTx}; use crate::error::DatabaseError; use crate::tables::{TableType, Tables}; use crate::utils; @@ -88,16 +88,6 @@ impl DbEnv { Ok(()) } - /// Begin a read-only transaction. - pub fn tx(&self) -> Result, DatabaseError> { - Ok(Tx::new(self.0.begin_ro_txn().map_err(DatabaseError::CreateROTx)?)) - } - - /// Begin a read-write transaction. - pub fn tx_mut(&self) -> Result, DatabaseError> { - Ok(Tx::new(self.0.begin_rw_txn().map_err(DatabaseError::CreateRWTx)?)) - } - /// Takes a function and passes a read-write transaction into it, making sure it's always /// committed in the end of the execution. pub fn update(&self, f: F) -> Result @@ -111,6 +101,19 @@ impl DbEnv { } } +impl Database for DbEnv { + type Tx = Tx; + type TxMut = Tx; + + fn tx(&self) -> Result { + Ok(Tx::new(self.0.begin_ro_txn().map_err(DatabaseError::CreateROTx)?)) + } + + fn tx_mut(&self) -> Result { + Ok(Tx::new(self.0.begin_rw_txn().map_err(DatabaseError::CreateRWTx)?)) + } +} + #[cfg(any(test, feature = "test-utils"))] pub mod test_utils { use std::path::Path; diff --git a/crates/katana/storage/provider/src/providers/db/mod.rs b/crates/katana/storage/provider/src/providers/db/mod.rs index f3c18e6f1f..16acb00476 100644 --- a/crates/katana/storage/provider/src/providers/db/mod.rs +++ b/crates/katana/storage/provider/src/providers/db/mod.rs @@ -4,9 +4,9 @@ use std::collections::HashMap; use std::fmt::Debug; use std::ops::{Range, RangeInclusive}; -use katana_db::abstraction::{DbCursor, DbCursorMut, DbDupSortCursor, DbTx, DbTxMut}; +use katana_db::abstraction::{Database, DbCursor, DbCursorMut, DbDupSortCursor, DbTx, DbTxMut}; use katana_db::error::DatabaseError; -use katana_db::mdbx::{self, DbEnv}; +use katana_db::mdbx::DbEnv; use katana_db::models::block::StoredBlockBodyIndices; use katana_db::models::contract::{ ContractClassChange, ContractInfoChangeList, ContractNonceChange, @@ -45,17 +45,18 @@ use crate::traits::transaction::{ use crate::ProviderResult; /// A provider implementation that uses a persistent database as the backend. +// TODO: remove the default generic type #[derive(Debug)] -pub struct DbProvider(DbEnv); +pub struct DbProvider(Db); -impl DbProvider { +impl DbProvider { /// Creates a new [`DbProvider`] from the given [`DbEnv`]. - pub fn new(db: DbEnv) -> Self { + pub fn new(db: Db) -> Self { Self(db) } } -impl StateFactoryProvider for DbProvider { +impl StateFactoryProvider for DbProvider { fn latest(&self) -> ProviderResult> { Ok(Box::new(self::state::LatestStateProvider::new(self.0.tx()?))) } @@ -84,7 +85,7 @@ impl StateFactoryProvider for DbProvider { } } -impl BlockNumberProvider for DbProvider { +impl BlockNumberProvider for DbProvider { fn block_number_by_hash(&self, hash: BlockHash) -> ProviderResult> { let db_tx = self.0.tx()?; let block_num = db_tx.get::(hash)?; @@ -101,7 +102,7 @@ impl BlockNumberProvider for DbProvider { } } -impl BlockHashProvider for DbProvider { +impl BlockHashProvider for DbProvider { fn latest_hash(&self) -> ProviderResult { let latest_block = self.latest_number()?; let db_tx = self.0.tx()?; @@ -118,7 +119,7 @@ impl BlockHashProvider for DbProvider { } } -impl HeaderProvider for DbProvider { +impl HeaderProvider for DbProvider { fn header(&self, id: BlockHashOrNumber) -> ProviderResult> { let db_tx = self.0.tx()?; @@ -138,7 +139,7 @@ impl HeaderProvider for DbProvider { } } -impl BlockProvider for DbProvider { +impl BlockProvider for DbProvider { fn block_body_indices( &self, id: BlockHashOrNumber, @@ -222,7 +223,7 @@ impl BlockProvider for DbProvider { } } -impl BlockStatusProvider for DbProvider { +impl BlockStatusProvider for DbProvider { fn block_status(&self, id: BlockHashOrNumber) -> ProviderResult> { let db_tx = self.0.tx()?; @@ -243,7 +244,7 @@ impl BlockStatusProvider for DbProvider { } } -impl StateRootProvider for DbProvider { +impl StateRootProvider for DbProvider { fn state_root(&self, block_id: BlockHashOrNumber) -> ProviderResult> { let db_tx = self.0.tx()?; @@ -262,21 +263,22 @@ impl StateRootProvider for DbProvider { } } -impl StateUpdateProvider for DbProvider { +impl StateUpdateProvider for DbProvider { fn state_update(&self, block_id: BlockHashOrNumber) -> ProviderResult> { // A helper function that iterates over all entries in a dupsort table and collects the // results into `V`. If `key` is not found, `V::default()` is returned. - fn dup_entries( - db_tx: &mdbx::tx::TxRO, + fn dup_entries( + db_tx: &::Tx, key: ::Key, f: impl FnMut(Result, DatabaseError>) -> ProviderResult, ) -> ProviderResult where + Db: Database, Tb: DupSort + Debug, V: FromIterator + Default, { Ok(db_tx - .cursor::()? + .cursor_dup::()? .walk_dup(Some(key), None)? .map(|walker| walker.map(f).collect::>()) .transpose()? @@ -288,6 +290,7 @@ impl StateUpdateProvider for DbProvider { if let Some(block_num) = block_num { let nonce_updates = dup_entries::< + Db, tables::NonceChangeHistory, HashMap, _, @@ -297,6 +300,7 @@ impl StateUpdateProvider for DbProvider { })?; let contract_updates = dup_entries::< + Db, tables::ClassChangeHistory, HashMap, _, @@ -306,6 +310,7 @@ impl StateUpdateProvider for DbProvider { })?; let declared_classes = dup_entries::< + Db, tables::ClassDeclarations, HashMap, _, @@ -321,6 +326,7 @@ impl StateUpdateProvider for DbProvider { let storage_updates = { let entries = dup_entries::< + Db, tables::StorageChangeHistory, Vec<(ContractAddress, (StorageKey, StorageValue))>, _, @@ -351,7 +357,7 @@ impl StateUpdateProvider for DbProvider { } } -impl TransactionProvider for DbProvider { +impl TransactionProvider for DbProvider { fn transaction_by_hash(&self, hash: TxHash) -> ProviderResult> { let db_tx = self.0.tx()?; @@ -456,7 +462,7 @@ impl TransactionProvider for DbProvider { } } -impl TransactionsProviderExt for DbProvider { +impl TransactionsProviderExt for DbProvider { fn transaction_hashes_in_range(&self, range: Range) -> ProviderResult> { let db_tx = self.0.tx()?; @@ -474,7 +480,7 @@ impl TransactionsProviderExt for DbProvider { } } -impl TransactionStatusProvider for DbProvider { +impl TransactionStatusProvider for DbProvider { fn transaction_status(&self, hash: TxHash) -> ProviderResult> { let db_tx = self.0.tx()?; if let Some(tx_num) = db_tx.get::(hash)? { @@ -492,7 +498,7 @@ impl TransactionStatusProvider for DbProvider { } } -impl TransactionTraceProvider for DbProvider { +impl TransactionTraceProvider for DbProvider { fn transaction_execution(&self, hash: TxHash) -> ProviderResult> { let db_tx = self.0.tx()?; if let Some(num) = db_tx.get::(hash)? { @@ -539,7 +545,7 @@ impl TransactionTraceProvider for DbProvider { } } -impl ReceiptProvider for DbProvider { +impl ReceiptProvider for DbProvider { fn receipt_by_hash(&self, hash: TxHash) -> ProviderResult> { let db_tx = self.0.tx()?; if let Some(num) = db_tx.get::(hash)? { @@ -576,7 +582,7 @@ impl ReceiptProvider for DbProvider { } } -impl BlockEnvProvider for DbProvider { +impl BlockEnvProvider for DbProvider { fn block_env_at(&self, block_id: BlockHashOrNumber) -> ProviderResult> { let Some(header) = self.header(block_id)? else { return Ok(None) }; @@ -589,7 +595,7 @@ impl BlockEnvProvider for DbProvider { } } -impl BlockWriter for DbProvider { +impl BlockWriter for DbProvider { fn insert_block_with_states_and_receipts( &self, block: SealedBlockWithStatus, @@ -652,7 +658,7 @@ impl BlockWriter for DbProvider { // insert storage changes { - let mut storage_cursor = db_tx.cursor::()?; + let mut storage_cursor = db_tx.cursor_dup_mut::()?; for (addr, entries) in states.state_updates.storage_updates { let entries = entries.into_iter().map(|(key, value)| StorageEntry { key, value }); diff --git a/crates/katana/storage/provider/src/providers/db/state.rs b/crates/katana/storage/provider/src/providers/db/state.rs index 507a74c101..75de28092b 100644 --- a/crates/katana/storage/provider/src/providers/db/state.rs +++ b/crates/katana/storage/provider/src/providers/db/state.rs @@ -1,5 +1,6 @@ -use katana_db::abstraction::{DbCursorMut, DbDupSortCursor, DbTx, DbTxMut}; -use katana_db::mdbx::{self}; +use core::fmt; + +use katana_db::abstraction::{Database, DbCursorMut, DbDupSortCursor, DbTx, DbTxMut}; use katana_db::models::contract::ContractInfoChangeList; use katana_db::models::list::BlockList; use katana_db::models::storage::{ContractStorageKey, StorageEntry}; @@ -16,7 +17,7 @@ use crate::traits::contract::{ContractClassProvider, ContractClassWriter}; use crate::traits::state::{StateProvider, StateWriter}; use crate::ProviderResult; -impl StateWriter for DbProvider { +impl StateWriter for DbProvider { fn set_nonce(&self, address: ContractAddress, nonce: Nonce) -> ProviderResult<()> { self.0.update(move |db_tx| -> ProviderResult<()> { let value = if let Some(info) = db_tx.get::(address)? { @@ -36,7 +37,7 @@ impl StateWriter for DbProvider { storage_value: StorageValue, ) -> ProviderResult<()> { self.0.update(move |db_tx| -> ProviderResult<()> { - let mut cursor = db_tx.cursor::()?; + let mut cursor = db_tx.cursor_dup_mut::()?; let entry = cursor.seek_by_key_subkey(address, storage_key)?; match entry { @@ -101,15 +102,18 @@ impl ContractClassWriter for DbProvider { /// A state provider that provides the latest states from the database. #[derive(Debug)] -pub(super) struct LatestStateProvider(mdbx::tx::TxRO); +pub(super) struct LatestStateProvider(Tx); -impl LatestStateProvider { - pub fn new(tx: mdbx::tx::TxRO) -> Self { +impl LatestStateProvider { + pub fn new(tx: Tx) -> Self { Self(tx) } } -impl ContractClassProvider for LatestStateProvider { +impl ContractClassProvider for LatestStateProvider +where + Tx: DbTx + Send + Sync, +{ fn class(&self, hash: ClassHash) -> ProviderResult> { let class = self.0.get::(hash)?; Ok(class) @@ -129,7 +133,10 @@ impl ContractClassProvider for LatestStateProvider { } } -impl StateProvider for LatestStateProvider { +impl StateProvider for LatestStateProvider +where + Tx: DbTx + fmt::Debug + Send + Sync, +{ fn nonce(&self, address: ContractAddress) -> ProviderResult> { let info = self.0.get::(address)?; Ok(info.map(|info| info.nonce)) @@ -148,7 +155,7 @@ impl StateProvider for LatestStateProvider { address: ContractAddress, storage_key: StorageKey, ) -> ProviderResult> { - let mut cursor = self.0.cursor::()?; + let mut cursor = self.0.cursor_dup::()?; let entry = cursor.seek_by_key_subkey(address, storage_key)?; match entry { Some(entry) if entry.key == storage_key => Ok(Some(entry.value)), @@ -159,20 +166,23 @@ impl StateProvider for LatestStateProvider { /// A historical state provider. #[derive(Debug)] -pub(super) struct HistoricalStateProvider { +pub(super) struct HistoricalStateProvider { /// The database transaction used to read the database. - tx: mdbx::tx::TxRO, + tx: Tx, /// The block number of the state. block_number: u64, } -impl HistoricalStateProvider { - pub fn new(tx: mdbx::tx::TxRO, block_number: u64) -> Self { +impl HistoricalStateProvider { + pub fn new(tx: Tx, block_number: u64) -> Self { Self { tx, block_number } } } -impl ContractClassProvider for HistoricalStateProvider { +impl ContractClassProvider for HistoricalStateProvider +where + Tx: DbTx + fmt::Debug + Send + Sync, +{ fn compiled_class_hash_of_class_hash( &self, hash: ClassHash, @@ -207,14 +217,17 @@ impl ContractClassProvider for HistoricalStateProvider { } } -impl StateProvider for HistoricalStateProvider { +impl StateProvider for HistoricalStateProvider +where + Tx: DbTx + fmt::Debug + Send + Sync, +{ fn nonce(&self, address: ContractAddress) -> ProviderResult> { let change_list = self.tx.get::(address)?; if let Some(num) = change_list .and_then(|entry| recent_change_from_block(self.block_number, &entry.nonce_change_list)) { - let mut cursor = self.tx.cursor::()?; + let mut cursor = self.tx.cursor_dup::()?; let entry = cursor.seek_by_key_subkey(num, address)?.ok_or( ProviderError::MissingContractNonceChangeEntry { block: num, @@ -240,7 +253,7 @@ impl StateProvider for HistoricalStateProvider { if let Some(num) = change_list .and_then(|entry| recent_change_from_block(self.block_number, &entry.class_change_list)) { - let mut cursor = self.tx.cursor::()?; + let mut cursor = self.tx.cursor_dup::()?; let entry = cursor.seek_by_key_subkey(num, address)?.ok_or( ProviderError::MissingContractClassChangeEntry { block: num, @@ -267,7 +280,7 @@ impl StateProvider for HistoricalStateProvider { if let Some(num) = block_list.and_then(|list| recent_change_from_block(self.block_number, &list)) { - let mut cursor = self.tx.cursor::()?; + let mut cursor = self.tx.cursor_dup::()?; let entry = cursor.seek_by_key_subkey(num, key)?.ok_or( ProviderError::MissingStorageChangeEntry { block: num,