Skip to content

Commit

Permalink
refactor(katana-db): main database trait (#2173)
Browse files Browse the repository at this point in the history
  • Loading branch information
kariy authored Jul 11, 2024
1 parent f8776ba commit 600cfca
Show file tree
Hide file tree
Showing 4 changed files with 121 additions and 54 deletions.
45 changes: 45 additions & 0 deletions crates/katana/storage/db/src/abstraction/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,50 @@
mod cursor;
mod transaction;

use std::fmt::Debug;

pub use cursor::*;
pub use transaction::*;

use crate::error::DatabaseError;

/// Main persistent database trait. The database implementation must be transactional.
pub trait Database: Send + Sync {
/// Read-Only transaction
type Tx: DbTx + Send + Sync + Debug + 'static;
/// Read-Write transaction
type TxMut: DbTxMut + Send + Sync + Debug + 'static;

/// Create and begin read-only transaction.
#[track_caller]
fn tx(&self) -> Result<Self::Tx, DatabaseError>;

/// Create and begin read-write transaction, should return error if the database is unable to
/// create the transaction e.g, not opened with read-write permission.
#[track_caller]
fn tx_mut(&self) -> Result<Self::TxMut, DatabaseError>;

/// Takes a function and passes a read-only transaction into it, making sure it's closed in the
/// end of the execution.
fn view<T, F>(&self, f: F) -> Result<T, DatabaseError>
where
F: FnOnce(&Self::Tx) -> T,
{
let tx = self.tx()?;
let res = f(&tx);
tx.commit()?;
Ok(res)
}

/// Takes a function and passes a write-read transaction into it, making sure it's committed in
/// the end of the execution.
fn update<T, F>(&self, f: F) -> Result<T, DatabaseError>
where
F: FnOnce(&Self::TxMut) -> T,
{
let tx = self.tx_mut()?;
let res = f(&tx);
tx.commit()?;
Ok(res)
}
}
25 changes: 14 additions & 11 deletions crates/katana/storage/db/src/mdbx/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ pub use libmdbx;
use libmdbx::{DatabaseFlags, EnvironmentFlags, Geometry, Mode, PageSize, SyncMode, RO, RW};

use self::tx::Tx;
use crate::abstraction::DbTx;
use crate::abstraction::{Database, DbTx};
use crate::error::DatabaseError;
use crate::tables::{TableType, Tables};
use crate::utils;
Expand Down Expand Up @@ -88,16 +88,6 @@ impl DbEnv {
Ok(())
}

/// Begin a read-only transaction.
pub fn tx(&self) -> Result<Tx<RO>, DatabaseError> {
Ok(Tx::new(self.0.begin_ro_txn().map_err(DatabaseError::CreateROTx)?))
}

/// Begin a read-write transaction.
pub fn tx_mut(&self) -> Result<Tx<RW>, DatabaseError> {
Ok(Tx::new(self.0.begin_rw_txn().map_err(DatabaseError::CreateRWTx)?))
}

/// Takes a function and passes a read-write transaction into it, making sure it's always
/// committed in the end of the execution.
pub fn update<T, F>(&self, f: F) -> Result<T, DatabaseError>
Expand All @@ -111,6 +101,19 @@ impl DbEnv {
}
}

impl Database for DbEnv {
type Tx = Tx<RO>;
type TxMut = Tx<RW>;

fn tx(&self) -> Result<Self::Tx, DatabaseError> {
Ok(Tx::new(self.0.begin_ro_txn().map_err(DatabaseError::CreateROTx)?))
}

fn tx_mut(&self) -> Result<Self::TxMut, DatabaseError> {
Ok(Tx::new(self.0.begin_rw_txn().map_err(DatabaseError::CreateRWTx)?))
}
}

#[cfg(any(test, feature = "test-utils"))]
pub mod test_utils {
use std::path::Path;
Expand Down
54 changes: 30 additions & 24 deletions crates/katana/storage/provider/src/providers/db/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@ use std::collections::HashMap;
use std::fmt::Debug;
use std::ops::{Range, RangeInclusive};

use katana_db::abstraction::{DbCursor, DbCursorMut, DbDupSortCursor, DbTx, DbTxMut};
use katana_db::abstraction::{Database, DbCursor, DbCursorMut, DbDupSortCursor, DbTx, DbTxMut};
use katana_db::error::DatabaseError;
use katana_db::mdbx::{self, DbEnv};
use katana_db::mdbx::DbEnv;
use katana_db::models::block::StoredBlockBodyIndices;
use katana_db::models::contract::{
ContractClassChange, ContractInfoChangeList, ContractNonceChange,
Expand Down Expand Up @@ -45,17 +45,18 @@ use crate::traits::transaction::{
use crate::ProviderResult;

/// A provider implementation that uses a persistent database as the backend.
// TODO: remove the default generic type
#[derive(Debug)]
pub struct DbProvider(DbEnv);
pub struct DbProvider<Db: Database = DbEnv>(Db);

impl DbProvider {
impl<Db: Database> DbProvider<Db> {
/// Creates a new [`DbProvider`] from the given [`DbEnv`].
pub fn new(db: DbEnv) -> Self {
pub fn new(db: Db) -> Self {
Self(db)
}
}

impl StateFactoryProvider for DbProvider {
impl<Db: Database> StateFactoryProvider for DbProvider<Db> {
fn latest(&self) -> ProviderResult<Box<dyn StateProvider>> {
Ok(Box::new(self::state::LatestStateProvider::new(self.0.tx()?)))
}
Expand Down Expand Up @@ -84,7 +85,7 @@ impl StateFactoryProvider for DbProvider {
}
}

impl BlockNumberProvider for DbProvider {
impl<Db: Database> BlockNumberProvider for DbProvider<Db> {
fn block_number_by_hash(&self, hash: BlockHash) -> ProviderResult<Option<BlockNumber>> {
let db_tx = self.0.tx()?;
let block_num = db_tx.get::<tables::BlockNumbers>(hash)?;
Expand All @@ -101,7 +102,7 @@ impl BlockNumberProvider for DbProvider {
}
}

impl BlockHashProvider for DbProvider {
impl<Db: Database> BlockHashProvider for DbProvider<Db> {
fn latest_hash(&self) -> ProviderResult<BlockHash> {
let latest_block = self.latest_number()?;
let db_tx = self.0.tx()?;
Expand All @@ -118,7 +119,7 @@ impl BlockHashProvider for DbProvider {
}
}

impl HeaderProvider for DbProvider {
impl<Db: Database> HeaderProvider for DbProvider<Db> {
fn header(&self, id: BlockHashOrNumber) -> ProviderResult<Option<Header>> {
let db_tx = self.0.tx()?;

Expand All @@ -138,7 +139,7 @@ impl HeaderProvider for DbProvider {
}
}

impl BlockProvider for DbProvider {
impl<Db: Database> BlockProvider for DbProvider<Db> {
fn block_body_indices(
&self,
id: BlockHashOrNumber,
Expand Down Expand Up @@ -222,7 +223,7 @@ impl BlockProvider for DbProvider {
}
}

impl BlockStatusProvider for DbProvider {
impl<Db: Database> BlockStatusProvider for DbProvider<Db> {
fn block_status(&self, id: BlockHashOrNumber) -> ProviderResult<Option<FinalityStatus>> {
let db_tx = self.0.tx()?;

Expand All @@ -243,7 +244,7 @@ impl BlockStatusProvider for DbProvider {
}
}

impl StateRootProvider for DbProvider {
impl<Db: Database> StateRootProvider for DbProvider<Db> {
fn state_root(&self, block_id: BlockHashOrNumber) -> ProviderResult<Option<FieldElement>> {
let db_tx = self.0.tx()?;

Expand All @@ -262,21 +263,22 @@ impl StateRootProvider for DbProvider {
}
}

impl StateUpdateProvider for DbProvider {
impl<Db: Database> StateUpdateProvider for DbProvider<Db> {
fn state_update(&self, block_id: BlockHashOrNumber) -> ProviderResult<Option<StateUpdates>> {
// A helper function that iterates over all entries in a dupsort table and collects the
// results into `V`. If `key` is not found, `V::default()` is returned.
fn dup_entries<Tb, V, T>(
db_tx: &mdbx::tx::TxRO,
fn dup_entries<Db, Tb, V, T>(
db_tx: &<Db as Database>::Tx,
key: <Tb as Table>::Key,
f: impl FnMut(Result<KeyValue<Tb>, DatabaseError>) -> ProviderResult<T>,
) -> ProviderResult<V>
where
Db: Database,
Tb: DupSort + Debug,
V: FromIterator<T> + Default,
{
Ok(db_tx
.cursor::<Tb>()?
.cursor_dup::<Tb>()?
.walk_dup(Some(key), None)?
.map(|walker| walker.map(f).collect::<ProviderResult<V>>())
.transpose()?
Expand All @@ -288,6 +290,7 @@ impl StateUpdateProvider for DbProvider {

if let Some(block_num) = block_num {
let nonce_updates = dup_entries::<
Db,
tables::NonceChangeHistory,
HashMap<ContractAddress, Nonce>,
_,
Expand All @@ -297,6 +300,7 @@ impl StateUpdateProvider for DbProvider {
})?;

let contract_updates = dup_entries::<
Db,
tables::ClassChangeHistory,
HashMap<ContractAddress, ClassHash>,
_,
Expand All @@ -306,6 +310,7 @@ impl StateUpdateProvider for DbProvider {
})?;

let declared_classes = dup_entries::<
Db,
tables::ClassDeclarations,
HashMap<ClassHash, CompiledClassHash>,
_,
Expand All @@ -321,6 +326,7 @@ impl StateUpdateProvider for DbProvider {

let storage_updates = {
let entries = dup_entries::<
Db,
tables::StorageChangeHistory,
Vec<(ContractAddress, (StorageKey, StorageValue))>,
_,
Expand Down Expand Up @@ -351,7 +357,7 @@ impl StateUpdateProvider for DbProvider {
}
}

impl TransactionProvider for DbProvider {
impl<Db: Database> TransactionProvider for DbProvider<Db> {
fn transaction_by_hash(&self, hash: TxHash) -> ProviderResult<Option<TxWithHash>> {
let db_tx = self.0.tx()?;

Expand Down Expand Up @@ -456,7 +462,7 @@ impl TransactionProvider for DbProvider {
}
}

impl TransactionsProviderExt for DbProvider {
impl<Db: Database> TransactionsProviderExt for DbProvider<Db> {
fn transaction_hashes_in_range(&self, range: Range<TxNumber>) -> ProviderResult<Vec<TxHash>> {
let db_tx = self.0.tx()?;

Expand All @@ -474,7 +480,7 @@ impl TransactionsProviderExt for DbProvider {
}
}

impl TransactionStatusProvider for DbProvider {
impl<Db: Database> TransactionStatusProvider for DbProvider<Db> {
fn transaction_status(&self, hash: TxHash) -> ProviderResult<Option<FinalityStatus>> {
let db_tx = self.0.tx()?;
if let Some(tx_num) = db_tx.get::<tables::TxNumbers>(hash)? {
Expand All @@ -492,7 +498,7 @@ impl TransactionStatusProvider for DbProvider {
}
}

impl TransactionTraceProvider for DbProvider {
impl<Db: Database> TransactionTraceProvider for DbProvider<Db> {
fn transaction_execution(&self, hash: TxHash) -> ProviderResult<Option<TxExecInfo>> {
let db_tx = self.0.tx()?;
if let Some(num) = db_tx.get::<tables::TxNumbers>(hash)? {
Expand Down Expand Up @@ -539,7 +545,7 @@ impl TransactionTraceProvider for DbProvider {
}
}

impl ReceiptProvider for DbProvider {
impl<Db: Database> ReceiptProvider for DbProvider<Db> {
fn receipt_by_hash(&self, hash: TxHash) -> ProviderResult<Option<Receipt>> {
let db_tx = self.0.tx()?;
if let Some(num) = db_tx.get::<tables::TxNumbers>(hash)? {
Expand Down Expand Up @@ -576,7 +582,7 @@ impl ReceiptProvider for DbProvider {
}
}

impl BlockEnvProvider for DbProvider {
impl<Db: Database> BlockEnvProvider for DbProvider<Db> {
fn block_env_at(&self, block_id: BlockHashOrNumber) -> ProviderResult<Option<BlockEnv>> {
let Some(header) = self.header(block_id)? else { return Ok(None) };

Expand All @@ -589,7 +595,7 @@ impl BlockEnvProvider for DbProvider {
}
}

impl BlockWriter for DbProvider {
impl<Db: Database> BlockWriter for DbProvider<Db> {
fn insert_block_with_states_and_receipts(
&self,
block: SealedBlockWithStatus,
Expand Down Expand Up @@ -652,7 +658,7 @@ impl BlockWriter for DbProvider {

// insert storage changes
{
let mut storage_cursor = db_tx.cursor::<tables::ContractStorage>()?;
let mut storage_cursor = db_tx.cursor_dup_mut::<tables::ContractStorage>()?;
for (addr, entries) in states.state_updates.storage_updates {
let entries =
entries.into_iter().map(|(key, value)| StorageEntry { key, value });
Expand Down
Loading

0 comments on commit 600cfca

Please sign in to comment.