From 490dab093672ffc52dd7aa5ee8a9251b6123cfc6 Mon Sep 17 00:00:00 2001 From: Ammar Arif Date: Thu, 21 Dec 2023 04:33:18 +0800 Subject: [PATCH] feat(katana-provider): implement a DB provider (#1299) * wip * wip * add test for de/compress contact class * wip * update * update * update test * impl state update provider + tests * update * increase code coverage --- Cargo.lock | 55 +- crates/katana/primitives/src/block.rs | 2 +- crates/katana/primitives/src/state.rs | 2 +- crates/katana/storage/db/Cargo.toml | 7 +- crates/katana/storage/db/benches/codec.rs | 2 +- .../katana/storage/db/src/codecs/postcard.rs | 11 +- crates/katana/storage/db/src/mdbx/cursor.rs | 13 +- crates/katana/storage/db/src/mdbx/mod.rs | 9 +- crates/katana/storage/db/src/models/class.rs | 500 ++++++++++ .../katana/storage/db/src/models/contract.rs | 486 +--------- crates/katana/storage/db/src/models/mod.rs | 1 + .../katana/storage/db/src/models/storage.rs | 103 ++- crates/katana/storage/db/src/tables.rs | 72 +- crates/katana/storage/provider/Cargo.toml | 1 + .../storage/provider/src/providers/db/mod.rs | 864 ++++++++++++++++++ .../provider/src/providers/db/state.rs | 337 +++++++ .../storage/provider/src/providers/mod.rs | 1 + .../provider/src/traits/transaction.rs | 5 + crates/katana/storage/provider/tests/block.rs | 141 ++- crates/katana/storage/provider/tests/class.rs | 23 + .../katana/storage/provider/tests/contract.rs | 25 +- .../katana/storage/provider/tests/fixtures.rs | 101 +- .../katana/storage/provider/tests/storage.rs | 23 + 23 files changed, 2225 insertions(+), 559 deletions(-) create mode 100644 crates/katana/storage/db/src/models/class.rs create mode 100644 crates/katana/storage/provider/src/providers/db/mod.rs create mode 100644 crates/katana/storage/provider/src/providers/db/state.rs diff --git a/Cargo.lock b/Cargo.lock index 6d66fe6cca..10d5c7b210 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -827,15 +827,6 @@ dependencies = [ "serde", ] -[[package]] -name = "bincode" -version = "1.3.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad" -dependencies = [ - "serde", -] - [[package]] name = "bincode" version = "2.0.0-rc.3" @@ -1727,7 +1718,7 @@ version = "0.8.2" source = "git+https://github.com/dojoengine/cairo-rs.git?rev=262b7eb4b11ab165a2a936a5f914e78aa732d4a2#262b7eb4b11ab165a2a936a5f914e78aa732d4a2" dependencies = [ "anyhow", - "bincode 2.0.0-rc.3", + "bincode", "bitvec", "cairo-felt", "generic-array", @@ -1888,9 +1879,9 @@ dependencies = [ [[package]] name = "clap-verbosity-flag" -version = "2.1.1" +version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3c90e95e5bd4e8ac34fa6f37c774b0c6f8ed06ea90c79931fd448fcf941a9767" +checksum = "e5fdbb015d790cfb378aca82caf9cc52a38be96a7eecdb92f31b4366a8afc019" dependencies = [ "clap", "log", @@ -2074,9 +2065,9 @@ dependencies = [ [[package]] name = "const-oid" -version = "0.9.6" +version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c2459377285ad874054d797f3ccebf984978aa39129f6eafde5cdc8315b612f8" +checksum = "28c122c3980598d243d63d9a704629a2d748d101f278052ff068be5a4423ab6f" [[package]] name = "const_format" @@ -2969,9 +2960,9 @@ checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" [[package]] name = "erased-serde" -version = "0.4.1" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4adbf0983fe06bd3a5c19c8477a637c2389feb0994eca7a59e3b961054aa7c0a" +checksum = "a3286168faae03a0e583f6fde17c02c8b8bba2dcc2061d0f7817066e5b0af706" dependencies = [ "serde", ] @@ -3344,9 +3335,9 @@ dependencies = [ [[package]] name = "eyre" -version = "0.6.11" +version = "0.6.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6267a1fa6f59179ea4afc8e50fd8612a3cc60bc858f786ff877a4a8cb042799" +checksum = "80f656be11ddf91bd709454d15d5bd896fbaf4cc3314e69349e4d1569f5b46cd" dependencies = [ "indenter", "once_cell", @@ -4286,9 +4277,9 @@ dependencies = [ [[package]] name = "gix-ref" -version = "0.39.1" +version = "0.39.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3b2069adc212cf7f3317ef55f6444abd06c50f28479dbbac5a86acf3b05cbbfe" +checksum = "1ac23ed741583c792f573c028785db683496a6dfcd672ec701ee54ba6a77e1ff" dependencies = [ "gix-actor", "gix-date", @@ -4756,11 +4747,11 @@ dependencies = [ [[package]] name = "home" -version = "0.5.9" +version = "0.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3d1354bf6b7235cb4a0576c2619fd4ed18183f689b12b006a0ee7329eeff9a5" +checksum = "5444c27eef6923071f7ebcc33e3444508466a76f7a2b93da00ed6e19f30c1ddb" dependencies = [ - "windows-sys 0.52.0", + "windows-sys 0.48.0", ] [[package]] @@ -5515,7 +5506,6 @@ name = "katana-db" version = "0.4.2" dependencies = [ "anyhow", - "bincode 1.3.3", "blockifier", "cairo-lang-starknet", "cairo-vm", @@ -5583,6 +5573,7 @@ dependencies = [ "rstest", "rstest_reuse", "starknet", + "tempfile", "thiserror", "tokio", "tracing", @@ -9363,18 +9354,18 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.51" +version = "1.0.50" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f11c217e1416d6f036b870f14e0413d480dbf28edbee1f877abaf0206af43bb7" +checksum = "f9a7210f5c9a7156bb50aa36aed4c95afb51df0df00713949448cf9e97d382d2" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.51" +version = "1.0.50" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01742297787513b79cf8e29d1056ede1313e2420b7b3b15d0a768b4921f549df" +checksum = "266b2e40bc00e5a6c09c3584011e08b06f123c00362c92b975ba9843aaaa14b8" dependencies = [ "proc-macro2", "quote", @@ -10975,18 +10966,18 @@ checksum = "09041cd90cf85f7f8b2df60c646f853b7f535ce68f85244eb6731cf89fa498ec" [[package]] name = "zerocopy" -version = "0.7.31" +version = "0.7.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1c4061bedbb353041c12f413700357bec76df2c7e2ca8e4df8bac24c6bf68e3d" +checksum = "306dca4455518f1f31635ec308b6b3e4eb1b11758cefafc782827d0aa7acb5c7" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.7.31" +version = "0.7.30" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3c129550b3e6de3fd0ba67ba5c81818f9805e58b8d7fee80a3a59d2c9fc601a" +checksum = "be912bf68235a88fbefd1b73415cb218405958d1655b2ece9035a19920bdf6ba" dependencies = [ "proc-macro2", "quote", diff --git a/crates/katana/primitives/src/block.rs b/crates/katana/primitives/src/block.rs index b451dbc3f1..f5f762f462 100644 --- a/crates/katana/primitives/src/block.rs +++ b/crates/katana/primitives/src/block.rs @@ -115,7 +115,7 @@ pub struct Block { } /// A block with only the transaction hashes. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq)] pub struct BlockWithTxHashes { pub header: Header, pub body: Vec, diff --git a/crates/katana/primitives/src/state.rs b/crates/katana/primitives/src/state.rs index ed95f1a682..65cfa8341a 100644 --- a/crates/katana/primitives/src/state.rs +++ b/crates/katana/primitives/src/state.rs @@ -8,7 +8,7 @@ use crate::contract::{ /// State updates. /// /// Represents all the state updates after performing some executions on a state. -#[derive(Debug, Default, Clone)] +#[derive(Debug, Default, Clone, PartialEq, Eq)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct StateUpdates { /// A mapping of contract addresses to their updated nonces. diff --git a/crates/katana/storage/db/Cargo.toml b/crates/katana/storage/db/Cargo.toml index 75870a70fd..87de9a249d 100644 --- a/crates/katana/storage/db/Cargo.toml +++ b/crates/katana/storage/db/Cargo.toml @@ -10,7 +10,6 @@ version.workspace = true katana-primitives = { path = "../../primitives" } anyhow.workspace = true -bincode = "1.3.3" page_size = "0.6.0" parking_lot.workspace = true serde.workspace = true @@ -23,7 +22,11 @@ cairo-vm.workspace = true starknet_api.workspace = true # codecs -postcard = { version = "1.0.8", optional = true, default-features = false, features = [ "use-std" ] } +[dependencies.postcard] +default-features = false +features = [ "use-std" ] +optional = true +version = "1.0.8" [dependencies.libmdbx] git = "https://github.com/paradigmxyz/reth.git" diff --git a/crates/katana/storage/db/benches/codec.rs b/crates/katana/storage/db/benches/codec.rs index e74344ec92..183667a021 100644 --- a/crates/katana/storage/db/benches/codec.rs +++ b/crates/katana/storage/db/benches/codec.rs @@ -2,7 +2,7 @@ use blockifier::execution::contract_class::ContractClassV1; use cairo_lang_starknet::casm_contract_class::CasmContractClass; use criterion::{black_box, criterion_group, criterion_main, Criterion}; use katana_db::codecs::{Compress, Decompress}; -use katana_db::models::contract::StoredContractClass; +use katana_db::models::class::StoredContractClass; use katana_primitives::contract::CompiledContractClass; fn compress_contract(contract: CompiledContractClass) -> Vec { diff --git a/crates/katana/storage/db/src/codecs/postcard.rs b/crates/katana/storage/db/src/codecs/postcard.rs index 0564dc3370..66a7ccddc7 100644 --- a/crates/katana/storage/db/src/codecs/postcard.rs +++ b/crates/katana/storage/db/src/codecs/postcard.rs @@ -1,4 +1,4 @@ -use katana_primitives::block::Header; +use katana_primitives::block::{BlockNumber, Header}; use katana_primitives::contract::{ContractAddress, GenericContractInfo, SierraClass}; use katana_primitives::receipt::Receipt; use katana_primitives::transaction::Tx; @@ -7,7 +7,8 @@ use katana_primitives::FieldElement; use super::{Compress, Decompress}; use crate::error::CodecError; use crate::models::block::StoredBlockBodyIndices; -use crate::models::contract::StoredContractClass; +use crate::models::class::StoredContractClass; +use crate::models::contract::ContractInfoChangeList; macro_rules! impl_compress_and_decompress_for_table_values { ($($name:ty),*) => { @@ -21,7 +22,7 @@ macro_rules! impl_compress_and_decompress_for_table_values { impl Decompress for $name { fn decompress>(bytes: B) -> Result { - postcard::from_bytes(bytes.as_ref()).map_err(|e| CodecError::Decode(e.to_string())) + postcard::from_bytes(bytes.as_ref()).map_err(|e| CodecError::Decompress(e.to_string())) } } )* @@ -36,7 +37,9 @@ impl_compress_and_decompress_for_table_values!( SierraClass, FieldElement, ContractAddress, + Vec, StoredContractClass, GenericContractInfo, - StoredBlockBodyIndices + StoredBlockBodyIndices, + ContractInfoChangeList ); diff --git a/crates/katana/storage/db/src/mdbx/cursor.rs b/crates/katana/storage/db/src/mdbx/cursor.rs index 943d92c3df..9cac3876e3 100644 --- a/crates/katana/storage/db/src/mdbx/cursor.rs +++ b/crates/katana/storage/db/src/mdbx/cursor.rs @@ -141,7 +141,7 @@ impl Cursor { &mut self, key: Option, subkey: Option, - ) -> Result, DatabaseError> { + ) -> Result>, DatabaseError> { let start = match (key, subkey) { (Some(key), Some(subkey)) => { // encode key and decode it after. @@ -154,10 +154,17 @@ impl Cursor { (Some(key), None) => { let key: Vec = key.encode().into(); - self.inner + + let Some(start) = self + .inner .set(key.as_ref()) .map_err(DatabaseError::Read)? .map(|val| decoder::((Cow::Owned(key), val))) + else { + return Ok(None); + }; + + Some(start) } (None, Some(subkey)) => { @@ -175,7 +182,7 @@ impl Cursor { (None, None) => self.first().transpose(), }; - Ok(DupWalker::new(self, start)) + Ok(Some(DupWalker::new(self, start))) } } diff --git a/crates/katana/storage/db/src/mdbx/mod.rs b/crates/katana/storage/db/src/mdbx/mod.rs index 76fc9b1cf3..3b31ba0507 100644 --- a/crates/katana/storage/db/src/mdbx/mod.rs +++ b/crates/katana/storage/db/src/mdbx/mod.rs @@ -112,18 +112,17 @@ impl DbEnv { #[cfg(any(test, feature = "test-utils"))] pub mod test_utils { use std::path::Path; - use std::sync::Arc; use super::{DbEnv, DbEnvKind}; const ERROR_DB_CREATION: &str = "Not able to create the mdbx file."; /// Create database for testing - pub fn create_test_db(kind: DbEnvKind) -> Arc { - Arc::new(create_test_db_with_path( + pub fn create_test_db(kind: DbEnvKind) -> DbEnv { + create_test_db_with_path( kind, &tempfile::TempDir::new().expect("Failed to create temp dir.").into_path(), - )) + ) } /// Create database for testing with specified path @@ -392,7 +391,7 @@ mod tests { { let tx = env.tx().expect(ERROR_INIT_TX); let mut cursor = tx.cursor::().unwrap(); - let mut walker = cursor.walk_dup(Some(key), Some(felt!("1"))).unwrap(); + let mut walker = cursor.walk_dup(Some(key), Some(felt!("1"))).unwrap().unwrap(); assert_eq!( (key, value11), diff --git a/crates/katana/storage/db/src/models/class.rs b/crates/katana/storage/db/src/models/class.rs new file mode 100644 index 0000000000..b30ae84ed6 --- /dev/null +++ b/crates/katana/storage/db/src/models/class.rs @@ -0,0 +1,500 @@ +//! Serializable without using custome functions + +use std::collections::HashMap; +use std::sync::Arc; + +use blockifier::execution::contract_class::{ + ContractClass, ContractClassV0, ContractClassV0Inner, ContractClassV1, ContractClassV1Inner, +}; +use cairo_vm::felt::Felt252; +use cairo_vm::hint_processor::hint_processor_definition::HintReference; +use cairo_vm::serde::deserialize_program::{ + ApTracking, Attribute, BuiltinName, FlowTrackingData, HintParams, Identifier, + InstructionLocation, Member, OffsetValue, +}; +use cairo_vm::types::program::{Program, SharedProgramData}; +use cairo_vm::types::relocatable::MaybeRelocatable; +use serde::{Deserialize, Serialize}; +use starknet_api::core::EntryPointSelector; +use starknet_api::deprecated_contract_class::{EntryPoint, EntryPointOffset, EntryPointType}; + +#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)] +pub enum StoredContractClass { + V0(StoredContractClassV0), + V1(StoredContractClassV1), +} + +#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)] +pub struct StoredContractClassV0 { + pub program: SerializableProgram, + pub entry_points_by_type: HashMap>, +} + +#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)] +pub struct StoredContractClassV1 { + pub program: SerializableProgram, + pub hints: HashMap>, + pub entry_points_by_type: HashMap>, +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct SerializableEntryPoint { + pub selector: EntryPointSelector, + pub offset: SerializableEntryPointOffset, +} + +impl From for SerializableEntryPoint { + fn from(value: EntryPoint) -> Self { + Self { selector: value.selector, offset: value.offset.into() } + } +} + +impl From for EntryPoint { + fn from(value: SerializableEntryPoint) -> Self { + Self { selector: value.selector, offset: value.offset.into() } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct SerializableEntryPointOffset(pub usize); + +impl From for SerializableEntryPointOffset { + fn from(value: EntryPointOffset) -> Self { + Self(value.0) + } +} + +impl From for EntryPointOffset { + fn from(value: SerializableEntryPointOffset) -> Self { + Self(value.0) + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct SerializableEntryPointV1 { + pub selector: EntryPointSelector, + pub offset: SerializableEntryPointOffset, + pub builtins: Vec, +} + +impl From for blockifier::execution::contract_class::EntryPointV1 { + fn from(value: SerializableEntryPointV1) -> Self { + blockifier::execution::contract_class::EntryPointV1 { + selector: value.selector, + offset: value.offset.into(), + builtins: value.builtins, + } + } +} + +impl From for SerializableEntryPointV1 { + fn from(value: blockifier::execution::contract_class::EntryPointV1) -> Self { + SerializableEntryPointV1 { + selector: value.selector, + offset: value.offset.into(), + builtins: value.builtins, + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct SerializableProgram { + pub shared_program_data: SerializableSharedProgramData, + pub constants: HashMap, + pub builtins: Vec, +} + +impl From for SerializableProgram { + fn from(value: Program) -> Self { + Self { + shared_program_data: value.shared_program_data.as_ref().clone().into(), + constants: value.constants, + builtins: value.builtins, + } + } +} + +impl From for Program { + fn from(value: SerializableProgram) -> Self { + Self { + shared_program_data: Arc::new(value.shared_program_data.into()), + constants: value.constants, + builtins: value.builtins, + } + } +} + +// Fields of `SerializableProgramData` must not rely on `deserialize_any` +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct SerializableSharedProgramData { + pub data: Vec, + pub hints: HashMap>, + pub main: Option, + pub start: Option, + pub end: Option, + pub error_message_attributes: Vec, + pub instruction_locations: Option>, + pub identifiers: HashMap, + pub reference_manager: Vec, +} + +impl From for SerializableSharedProgramData { + fn from(value: SharedProgramData) -> Self { + Self { + data: value.data, + hints: value + .hints + .into_iter() + .map(|(k, v)| (k, v.into_iter().map(|h| h.into()).collect())) + .collect(), + main: value.main, + start: value.start, + end: value.end, + error_message_attributes: value + .error_message_attributes + .into_iter() + .map(|a| a.into()) + .collect(), + instruction_locations: value.instruction_locations, + identifiers: value.identifiers.into_iter().map(|(k, v)| (k, v.into())).collect(), + reference_manager: value.reference_manager.into_iter().map(|r| r.into()).collect(), + } + } +} + +impl From for SharedProgramData { + fn from(value: SerializableSharedProgramData) -> Self { + Self { + data: value.data, + hints: value + .hints + .into_iter() + .map(|(k, v)| (k, v.into_iter().map(|h| h.into()).collect())) + .collect(), + main: value.main, + start: value.start, + end: value.end, + error_message_attributes: value + .error_message_attributes + .into_iter() + .map(|a| a.into()) + .collect(), + instruction_locations: value.instruction_locations, + identifiers: value.identifiers.into_iter().map(|(k, v)| (k, v.into())).collect(), + reference_manager: value.reference_manager.into_iter().map(|r| r.into()).collect(), + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct SerializableHintParams { + pub code: String, + pub accessible_scopes: Vec, + pub flow_tracking_data: SerializableFlowTrackingData, +} + +impl From for SerializableHintParams { + fn from(value: HintParams) -> Self { + Self { + code: value.code, + accessible_scopes: value.accessible_scopes, + flow_tracking_data: value.flow_tracking_data.into(), + } + } +} + +impl From for HintParams { + fn from(value: SerializableHintParams) -> Self { + Self { + code: value.code, + accessible_scopes: value.accessible_scopes, + flow_tracking_data: value.flow_tracking_data.into(), + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct SerializableIdentifier { + pub pc: Option, + pub type_: Option, + pub value: Option, + pub full_name: Option, + pub members: Option>, + pub cairo_type: Option, +} + +impl From for SerializableIdentifier { + fn from(value: Identifier) -> Self { + Self { + pc: value.pc, + type_: value.type_, + value: value.value, + full_name: value.full_name, + members: value.members, + cairo_type: value.cairo_type, + } + } +} + +impl From for Identifier { + fn from(value: SerializableIdentifier) -> Self { + Self { + pc: value.pc, + type_: value.type_, + value: value.value, + full_name: value.full_name, + members: value.members, + cairo_type: value.cairo_type, + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct SerializableHintReference { + pub offset1: OffsetValue, + pub offset2: OffsetValue, + pub dereference: bool, + pub ap_tracking_data: Option, + pub cairo_type: Option, +} + +impl From for SerializableHintReference { + fn from(value: HintReference) -> Self { + Self { + offset1: value.offset1, + offset2: value.offset2, + dereference: value.dereference, + ap_tracking_data: value.ap_tracking_data, + cairo_type: value.cairo_type, + } + } +} + +impl From for HintReference { + fn from(value: SerializableHintReference) -> Self { + Self { + offset1: value.offset1, + offset2: value.offset2, + dereference: value.dereference, + ap_tracking_data: value.ap_tracking_data, + cairo_type: value.cairo_type, + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct SerializableAttribute { + pub name: String, + pub start_pc: usize, + pub end_pc: usize, + pub value: String, + pub flow_tracking_data: Option, +} + +impl From for SerializableAttribute { + fn from(value: Attribute) -> Self { + Self { + name: value.name, + start_pc: value.start_pc, + end_pc: value.end_pc, + value: value.value, + flow_tracking_data: value.flow_tracking_data.map(|d| d.into()), + } + } +} + +impl From for Attribute { + fn from(value: SerializableAttribute) -> Self { + Self { + name: value.name, + start_pc: value.start_pc, + end_pc: value.end_pc, + value: value.value, + flow_tracking_data: value.flow_tracking_data.map(|d| d.into()), + } + } +} + +#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] +pub struct SerializableFlowTrackingData { + pub ap_tracking: ApTracking, + pub reference_ids: HashMap, +} + +impl From for SerializableFlowTrackingData { + fn from(value: FlowTrackingData) -> Self { + Self { ap_tracking: value.ap_tracking, reference_ids: value.reference_ids } + } +} + +impl From for FlowTrackingData { + fn from(value: SerializableFlowTrackingData) -> Self { + Self { ap_tracking: value.ap_tracking, reference_ids: value.reference_ids } + } +} + +impl From for ContractClass { + fn from(value: StoredContractClass) -> Self { + match value { + StoredContractClass::V0(v0) => { + ContractClass::V0(ContractClassV0(Arc::new(ContractClassV0Inner { + program: v0.program.into(), + entry_points_by_type: v0 + .entry_points_by_type + .into_iter() + .map(|(k, v)| (k, v.into_iter().map(|h| h.into()).collect())) + .collect(), + }))) + } + StoredContractClass::V1(v1) => { + ContractClass::V1(ContractClassV1(Arc::new(ContractClassV1Inner { + hints: v1 + .hints + .clone() + .into_iter() + .map(|(k, v)| (k, serde_json::from_slice(&v).expect("valid hint"))) + .collect(), + program: v1.program.into(), + entry_points_by_type: v1 + .entry_points_by_type + .into_iter() + .map(|(k, v)| { + ( + k, + v.into_iter() + .map(Into::into) + .collect::>(), + ) + }) + .collect::>(), + }))) + } + } + } +} + +impl From for StoredContractClass { + fn from(value: ContractClass) -> Self { + match value { + ContractClass::V0(v0) => { + let entry_points_by_type = v0 + .entry_points_by_type + .clone() + .into_iter() + .map(|(k, v)| (k, v.into_iter().map(SerializableEntryPoint::from).collect())) + .collect(); + + StoredContractClass::V0(StoredContractClassV0 { + program: v0.program.clone().into(), + entry_points_by_type, + }) + } + + ContractClass::V1(v1) => StoredContractClass::V1(StoredContractClassV1 { + program: v1.program.clone().into(), + entry_points_by_type: v1 + .entry_points_by_type + .clone() + .into_iter() + .map(|(k, v)| { + ( + k, + v.into_iter() + .map(Into::into) + .collect::>(), + ) + }) + .collect::>(), + hints: v1 + .hints + .clone() + .into_iter() + .map(|(k, v)| (k, serde_json::to_vec(&v).expect("valid hint"))) + .collect(), + }), + } + } +} + +#[cfg(test)] +mod tests { + use cairo_lang_starknet::casm_contract_class::CasmContractClass; + use katana_primitives::contract::CompiledContractClass; + use starknet_api::hash::StarkFelt; + use starknet_api::stark_felt; + + use super::*; + use crate::codecs::{Compress, Decompress}; + + #[test] + fn serialize_deserialize_legacy_entry_points() { + let non_serde = vec![ + EntryPoint { + offset: EntryPointOffset(0x25f), + selector: EntryPointSelector(stark_felt!( + "0x289da278a8dc833409cabfdad1581e8e7d40e42dcaed693fa4008dcdb4963b3" + )), + }, + EntryPoint { + offset: EntryPointOffset(0x1b2), + selector: EntryPointSelector(stark_felt!( + "0x29e211664c0b63c79638fbea474206ca74016b3e9a3dc4f9ac300ffd8bdf2cd" + )), + }, + EntryPoint { + offset: EntryPointOffset(0x285), + selector: EntryPointSelector(stark_felt!( + "0x36fcbf06cd96843058359e1a75928beacfac10727dab22a3972f0af8aa92895" + )), + }, + ]; + + // convert to serde and back + let serde: Vec = + non_serde.iter().map(|e| e.clone().into()).collect(); + + // convert to json + let json = serde_json::to_vec(&serde).unwrap(); + let serde: Vec = serde_json::from_slice(&json).unwrap(); + + let same_non_serde: Vec = serde.iter().map(|e| e.clone().into()).collect(); + + assert_eq!(non_serde, same_non_serde); + } + + #[test] + fn compress_and_decompress_contract_class() { + let class = + serde_json::from_slice(include_bytes!("../../benches/artifacts/dojo_world_240.json")) + .unwrap(); + + let class = CasmContractClass::from_contract_class(class, true).unwrap(); + let class = CompiledContractClass::V1(ContractClassV1::try_from(class).unwrap()); + + let compressed = StoredContractClass::from(class.clone()).compress(); + let decompressed = ::decompress(compressed).unwrap(); + + let actual_class = CompiledContractClass::from(decompressed); + + assert_eq!(class, actual_class); + } + + #[test] + fn compress_and_decompress_legacy_contract_class() { + let class: ContractClassV0 = serde_json::from_slice(include_bytes!( + "../../../../core/contracts/compiled/account.json" + )) + .unwrap(); + + let class = CompiledContractClass::V0(class); + + let compressed = StoredContractClass::from(class.clone()).compress(); + let decompressed = ::decompress(compressed).unwrap(); + + let actual_class = CompiledContractClass::from(decompressed); + + assert_eq!(class, actual_class); + } +} diff --git a/crates/katana/storage/db/src/models/contract.rs b/crates/katana/storage/db/src/models/contract.rs index 92894bb601..25c9818671 100644 --- a/crates/katana/storage/db/src/models/contract.rs +++ b/crates/katana/storage/db/src/models/contract.rs @@ -1,463 +1,65 @@ -//! Serializable without using custome functions - -use std::collections::HashMap; -use std::sync::Arc; - -use blockifier::execution::contract_class::{ - ContractClass, ContractClassV0, ContractClassV0Inner, ContractClassV1, ContractClassV1Inner, -}; -use cairo_vm::felt::Felt252; -use cairo_vm::hint_processor::hint_processor_definition::HintReference; -use cairo_vm::serde::deserialize_program::{ - ApTracking, Attribute, BuiltinName, FlowTrackingData, HintParams, Identifier, - InstructionLocation, Member, OffsetValue, -}; -use cairo_vm::types::program::{Program, SharedProgramData}; -use cairo_vm::types::relocatable::MaybeRelocatable; +use katana_primitives::block::BlockNumber; +use katana_primitives::contract::{ClassHash, ContractAddress, Nonce}; use serde::{Deserialize, Serialize}; -use starknet_api::core::EntryPointSelector; -use starknet_api::deprecated_contract_class::{EntryPoint, EntryPointOffset, EntryPointType}; - -#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)] -pub enum StoredContractClass { - V0(StoredContractClassV0), - V1(StoredContractClassV1), -} - -#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)] -pub struct StoredContractClassV0 { - pub program: SerializableProgram, - pub entry_points_by_type: HashMap>, -} - -#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)] -pub struct StoredContractClassV1 { - pub program: SerializableProgram, - pub hints: HashMap>, - pub entry_points_by_type: HashMap>, -} - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] -pub struct SerializableEntryPoint { - pub selector: EntryPointSelector, - pub offset: SerializableEntryPointOffset, -} - -impl From for SerializableEntryPoint { - fn from(value: EntryPoint) -> Self { - Self { selector: value.selector, offset: value.offset.into() } - } -} - -impl From for EntryPoint { - fn from(value: SerializableEntryPoint) -> Self { - Self { selector: value.selector, offset: value.offset.into() } - } -} - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] -pub struct SerializableEntryPointOffset(pub usize); - -impl From for SerializableEntryPointOffset { - fn from(value: EntryPointOffset) -> Self { - Self(value.0) - } -} - -impl From for EntryPointOffset { - fn from(value: SerializableEntryPointOffset) -> Self { - Self(value.0) - } -} - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] -pub struct SerializableEntryPointV1 { - pub selector: EntryPointSelector, - pub offset: SerializableEntryPointOffset, - pub builtins: Vec, -} - -impl From for blockifier::execution::contract_class::EntryPointV1 { - fn from(value: SerializableEntryPointV1) -> Self { - blockifier::execution::contract_class::EntryPointV1 { - selector: value.selector, - offset: value.offset.into(), - builtins: value.builtins, - } - } -} - -impl From for SerializableEntryPointV1 { - fn from(value: blockifier::execution::contract_class::EntryPointV1) -> Self { - SerializableEntryPointV1 { - selector: value.selector, - offset: value.offset.into(), - builtins: value.builtins, - } - } -} - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] -pub struct SerializableProgram { - pub shared_program_data: SerializableSharedProgramData, - pub constants: HashMap, - pub builtins: Vec, -} - -impl From for SerializableProgram { - fn from(value: Program) -> Self { - Self { - shared_program_data: value.shared_program_data.as_ref().clone().into(), - constants: value.constants, - builtins: value.builtins, - } - } -} - -impl From for Program { - fn from(value: SerializableProgram) -> Self { - Self { - shared_program_data: Arc::new(value.shared_program_data.into()), - constants: value.constants, - builtins: value.builtins, - } - } -} - -// Fields of `SerializableProgramData` must not rely on `deserialize_any` -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] -pub struct SerializableSharedProgramData { - pub data: Vec, - pub hints: HashMap>, - pub main: Option, - pub start: Option, - pub end: Option, - pub error_message_attributes: Vec, - pub instruction_locations: Option>, - pub identifiers: HashMap, - pub reference_manager: Vec, -} - -impl From for SerializableSharedProgramData { - fn from(value: SharedProgramData) -> Self { - Self { - data: value.data, - hints: value - .hints - .into_iter() - .map(|(k, v)| (k, v.into_iter().map(|h| h.into()).collect())) - .collect(), - main: value.main, - start: value.start, - end: value.end, - error_message_attributes: value - .error_message_attributes - .into_iter() - .map(|a| a.into()) - .collect(), - instruction_locations: value.instruction_locations, - identifiers: value.identifiers.into_iter().map(|(k, v)| (k, v.into())).collect(), - reference_manager: value.reference_manager.into_iter().map(|r| r.into()).collect(), - } - } -} - -impl From for SharedProgramData { - fn from(value: SerializableSharedProgramData) -> Self { - Self { - data: value.data, - hints: value - .hints - .into_iter() - .map(|(k, v)| (k, v.into_iter().map(|h| h.into()).collect())) - .collect(), - main: value.main, - start: value.start, - end: value.end, - error_message_attributes: value - .error_message_attributes - .into_iter() - .map(|a| a.into()) - .collect(), - instruction_locations: value.instruction_locations, - identifiers: value.identifiers.into_iter().map(|(k, v)| (k, v.into())).collect(), - reference_manager: value.reference_manager.into_iter().map(|r| r.into()).collect(), - } - } -} - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] -pub struct SerializableHintParams { - pub code: String, - pub accessible_scopes: Vec, - pub flow_tracking_data: SerializableFlowTrackingData, -} - -impl From for SerializableHintParams { - fn from(value: HintParams) -> Self { - Self { - code: value.code, - accessible_scopes: value.accessible_scopes, - flow_tracking_data: value.flow_tracking_data.into(), - } - } -} - -impl From for HintParams { - fn from(value: SerializableHintParams) -> Self { - Self { - code: value.code, - accessible_scopes: value.accessible_scopes, - flow_tracking_data: value.flow_tracking_data.into(), - } - } -} - -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] -pub struct SerializableIdentifier { - pub pc: Option, - pub type_: Option, - pub value: Option, - pub full_name: Option, - pub members: Option>, - pub cairo_type: Option, -} - -impl From for SerializableIdentifier { - fn from(value: Identifier) -> Self { - Self { - pc: value.pc, - type_: value.type_, - value: value.value, - full_name: value.full_name, - members: value.members, - cairo_type: value.cairo_type, - } - } -} - -impl From for Identifier { - fn from(value: SerializableIdentifier) -> Self { - Self { - pc: value.pc, - type_: value.type_, - value: value.value, - full_name: value.full_name, - members: value.members, - cairo_type: value.cairo_type, - } - } -} -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] -pub struct SerializableHintReference { - pub offset1: OffsetValue, - pub offset2: OffsetValue, - pub dereference: bool, - pub ap_tracking_data: Option, - pub cairo_type: Option, -} +use crate::codecs::{Compress, Decode, Decompress, Encode}; -impl From for SerializableHintReference { - fn from(value: HintReference) -> Self { - Self { - offset1: value.offset1, - offset2: value.offset2, - dereference: value.dereference, - ap_tracking_data: value.ap_tracking_data, - cairo_type: value.cairo_type, - } - } -} +pub type BlockList = Vec; -impl From for HintReference { - fn from(value: SerializableHintReference) -> Self { - Self { - offset1: value.offset1, - offset2: value.offset2, - dereference: value.dereference, - ap_tracking_data: value.ap_tracking_data, - cairo_type: value.cairo_type, - } - } +#[derive(Debug, Default, Serialize, Deserialize)] +pub struct ContractInfoChangeList { + pub class_change_list: BlockList, + pub nonce_change_list: BlockList, } -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] -pub struct SerializableAttribute { - pub name: String, - pub start_pc: usize, - pub end_pc: usize, - pub value: String, - pub flow_tracking_data: Option, +#[derive(Debug)] +pub struct ContractClassChange { + pub contract_address: ContractAddress, + /// The updated class hash of `contract_address`. + pub class_hash: ClassHash, } -impl From for SerializableAttribute { - fn from(value: Attribute) -> Self { - Self { - name: value.name, - start_pc: value.start_pc, - end_pc: value.end_pc, - value: value.value, - flow_tracking_data: value.flow_tracking_data.map(|d| d.into()), - } +impl Compress for ContractClassChange { + type Compressed = Vec; + fn compress(self) -> Self::Compressed { + let mut buf = Vec::new(); + buf.extend_from_slice(self.contract_address.encode().as_ref()); + buf.extend_from_slice(self.class_hash.compress().as_ref()); + buf } } -impl From for Attribute { - fn from(value: SerializableAttribute) -> Self { - Self { - name: value.name, - start_pc: value.start_pc, - end_pc: value.end_pc, - value: value.value, - flow_tracking_data: value.flow_tracking_data.map(|d| d.into()), - } +impl Decompress for ContractClassChange { + fn decompress>(bytes: B) -> Result { + let bytes = bytes.as_ref(); + let contract_address = ContractAddress::decode(&bytes[0..32])?; + let class_hash = ClassHash::decompress(&bytes[32..])?; + Ok(Self { contract_address, class_hash }) } } -#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)] -pub struct SerializableFlowTrackingData { - pub ap_tracking: ApTracking, - pub reference_ids: HashMap, +#[derive(Debug)] +pub struct ContractNonceChange { + pub contract_address: ContractAddress, + /// The updated nonce value of `contract_address`. + pub nonce: Nonce, } -impl From for SerializableFlowTrackingData { - fn from(value: FlowTrackingData) -> Self { - Self { ap_tracking: value.ap_tracking, reference_ids: value.reference_ids } +impl Compress for ContractNonceChange { + type Compressed = Vec; + fn compress(self) -> Self::Compressed { + let mut buf = Vec::new(); + buf.extend_from_slice(&self.contract_address.encode()); + buf.extend_from_slice(&self.nonce.compress()); + buf } } -impl From for FlowTrackingData { - fn from(value: SerializableFlowTrackingData) -> Self { - Self { ap_tracking: value.ap_tracking, reference_ids: value.reference_ids } - } -} - -impl From for ContractClass { - fn from(value: StoredContractClass) -> Self { - match value { - StoredContractClass::V0(v0) => { - ContractClass::V0(ContractClassV0(Arc::new(ContractClassV0Inner { - program: v0.program.into(), - entry_points_by_type: v0 - .entry_points_by_type - .into_iter() - .map(|(k, v)| (k, v.into_iter().map(|h| h.into()).collect())) - .collect(), - }))) - } - StoredContractClass::V1(v1) => { - ContractClass::V1(ContractClassV1(Arc::new(ContractClassV1Inner { - hints: v1 - .hints - .clone() - .into_iter() - .map(|(k, v)| (k, serde_json::from_slice(&v).expect("valid hint"))) - .collect(), - program: v1.program.into(), - entry_points_by_type: v1 - .entry_points_by_type - .into_iter() - .map(|(k, v)| { - ( - k, - v.into_iter() - .map(Into::into) - .collect::>(), - ) - }) - .collect::>(), - }))) - } - } - } -} - -impl From for StoredContractClass { - fn from(value: ContractClass) -> Self { - match value { - ContractClass::V0(v0) => { - let entry_points_by_type = v0 - .entry_points_by_type - .clone() - .into_iter() - .map(|(k, v)| (k, v.into_iter().map(SerializableEntryPoint::from).collect())) - .collect(); - - StoredContractClass::V0(StoredContractClassV0 { - program: v0.program.clone().into(), - entry_points_by_type, - }) - } - - ContractClass::V1(v1) => StoredContractClass::V1(StoredContractClassV1 { - program: v1.program.clone().into(), - entry_points_by_type: v1 - .entry_points_by_type - .clone() - .into_iter() - .map(|(k, v)| { - ( - k, - v.into_iter() - .map(Into::into) - .collect::>(), - ) - }) - .collect::>(), - hints: v1 - .hints - .clone() - .into_iter() - .map(|(k, v)| (k, serde_json::to_vec(&v).expect("valid hint"))) - .collect(), - }), - } - } -} - -#[cfg(test)] -mod tests { - use starknet_api::hash::StarkFelt; - use starknet_api::stark_felt; - - use super::*; - - #[test] - fn serialize_deserialize_legacy_entry_points() { - let non_serde = vec![ - EntryPoint { - offset: EntryPointOffset(0x25f), - selector: EntryPointSelector(stark_felt!( - "0x289da278a8dc833409cabfdad1581e8e7d40e42dcaed693fa4008dcdb4963b3" - )), - }, - EntryPoint { - offset: EntryPointOffset(0x1b2), - selector: EntryPointSelector(stark_felt!( - "0x29e211664c0b63c79638fbea474206ca74016b3e9a3dc4f9ac300ffd8bdf2cd" - )), - }, - EntryPoint { - offset: EntryPointOffset(0x285), - selector: EntryPointSelector(stark_felt!( - "0x36fcbf06cd96843058359e1a75928beacfac10727dab22a3972f0af8aa92895" - )), - }, - ]; - - // convert to serde and back - let serde: Vec = - non_serde.iter().map(|e| e.clone().into()).collect(); - - // convert to json - let json = serde_json::to_vec(&serde).unwrap(); - let serde: Vec = serde_json::from_slice(&json).unwrap(); - - let same_non_serde: Vec = serde.iter().map(|e| e.clone().into()).collect(); - - assert_eq!(non_serde, same_non_serde); +impl Decompress for ContractNonceChange { + fn decompress>(bytes: B) -> Result { + let bytes = bytes.as_ref(); + let contract_address = ContractAddress::decode(&bytes[0..32])?; + let nonce = Nonce::decompress(&bytes[32..])?; + Ok(Self { contract_address, nonce }) } } diff --git a/crates/katana/storage/db/src/models/mod.rs b/crates/katana/storage/db/src/models/mod.rs index 656ed6c7a2..66150ed28b 100644 --- a/crates/katana/storage/db/src/models/mod.rs +++ b/crates/katana/storage/db/src/models/mod.rs @@ -1,3 +1,4 @@ pub mod block; +pub mod class; pub mod contract; pub mod storage; diff --git a/crates/katana/storage/db/src/models/storage.rs b/crates/katana/storage/db/src/models/storage.rs index c4350421b9..6b1c3da54d 100644 --- a/crates/katana/storage/db/src/models/storage.rs +++ b/crates/katana/storage/db/src/models/storage.rs @@ -1,6 +1,7 @@ -use katana_primitives::contract::{StorageKey, StorageValue}; +use katana_primitives::block::BlockNumber; +use katana_primitives::contract::{ContractAddress, StorageKey, StorageValue}; -use crate::codecs::{Compress, Decompress}; +use crate::codecs::{Compress, Decode, Decompress, Encode}; use crate::error::CodecError; /// Represents a contract storage entry. @@ -33,3 +34,101 @@ impl Decompress for StorageEntry { Ok(Self { key, value }) } } + +#[derive(Debug)] +pub struct StorageEntryChangeList { + pub key: StorageKey, + pub block_list: Vec, +} + +impl Compress for StorageEntryChangeList { + type Compressed = Vec; + fn compress(self) -> Self::Compressed { + let mut buf = Vec::new(); + buf.extend_from_slice(&self.key.encode()); + buf.extend_from_slice(&self.block_list.compress()); + buf + } +} + +impl Decompress for StorageEntryChangeList { + fn decompress>(bytes: B) -> Result { + let bytes = bytes.as_ref(); + let key = StorageKey::decode(&bytes[0..32])?; + let blocks = Vec::::decompress(&bytes[32..])?; + Ok(Self { key, block_list: blocks }) + } +} + +#[derive(Debug, Clone, Default, PartialEq, Eq)] +pub struct ContractStorageKey { + pub contract_address: ContractAddress, + pub key: StorageKey, +} + +impl Encode for ContractStorageKey { + type Encoded = [u8; 64]; + fn encode(self) -> Self::Encoded { + let mut buf = [0u8; 64]; + buf[0..32].copy_from_slice(&self.contract_address.encode()); + buf[32..64].copy_from_slice(&self.key.encode()); + buf + } +} + +impl Decode for ContractStorageKey { + fn decode>(bytes: B) -> Result { + let bytes = bytes.as_ref(); + let contract_address = ContractAddress::decode(&bytes[0..32])?; + let key = StorageKey::decode(&bytes[32..])?; + Ok(Self { contract_address, key }) + } +} + +#[derive(Debug, Clone, Default, PartialEq, Eq)] +pub struct ContractStorageEntry { + pub key: ContractStorageKey, + pub value: StorageValue, +} + +impl Compress for ContractStorageEntry { + type Compressed = Vec; + fn compress(self) -> Self::Compressed { + let mut buf = Vec::with_capacity(64); + buf.extend_from_slice(self.key.encode().as_ref()); + buf.extend_from_slice(self.value.compress().as_ref()); + buf + } +} + +impl Decompress for ContractStorageEntry { + fn decompress>(bytes: B) -> Result { + let bytes = bytes.as_ref(); + let key = ContractStorageKey::decode(&bytes[0..64])?; + let value = StorageValue::decompress(&bytes[64..])?; + Ok(Self { key, value }) + } +} + +#[cfg(test)] +mod tests { + use starknet::macros::felt; + + use crate::codecs::{Compress, Decompress}; + + #[test] + fn compress_and_decompress_account_entry() { + let account_storage_entry = super::ContractStorageEntry { + key: super::ContractStorageKey { + contract_address: felt!("0x1234").into(), + key: felt!("0x111"), + }, + value: felt!("0x99"), + }; + + let compressed = account_storage_entry.clone().compress(); + let actual_value = super::ContractStorageEntry::decompress(compressed).unwrap(); + + assert_eq!(account_storage_entry, actual_value); + } +} diff --git a/crates/katana/storage/db/src/tables.rs b/crates/katana/storage/db/src/tables.rs index 6ec43b55c8..df6facb002 100644 --- a/crates/katana/storage/db/src/tables.rs +++ b/crates/katana/storage/db/src/tables.rs @@ -7,8 +7,11 @@ use katana_primitives::transaction::{Tx, TxHash, TxNumber}; use crate::codecs::{Compress, Decode, Decompress, Encode}; use crate::models::block::StoredBlockBodyIndices; -use crate::models::contract::StoredContractClass; -use crate::models::storage::StorageEntry; +use crate::models::class::StoredContractClass; +use crate::models::contract::{ContractClassChange, ContractInfoChangeList, ContractNonceChange}; +use crate::models::storage::{ + ContractStorageEntry, ContractStorageKey, StorageEntry, StorageEntryChangeList, +}; pub trait Key: Encode + Decode + Clone + std::fmt::Debug {} pub trait Value: Compress + Decompress + std::fmt::Debug {} @@ -43,7 +46,7 @@ pub enum TableType { DupSort, } -pub const NUM_TABLES: usize = 17; +pub const NUM_TABLES: usize = 22; /// Macro to declare `libmdbx` tables. #[macro_export] @@ -158,9 +161,14 @@ define_tables_enum! {[ (CompiledContractClasses, TableType::Table), (SierraClasses, TableType::Table), (ContractInfo, TableType::Table), - (ContractDeployments, TableType::DupSort), + (ContractStorage, TableType::DupSort), + (ClassDeclarationBlock, TableType::Table), (ClassDeclarations, TableType::DupSort), - (ContractStorage, TableType::DupSort) + (ContractInfoChangeSet, TableType::Table), + (NonceChanges, TableType::DupSort), + (ContractClassChanges, TableType::DupSort), + (StorageChanges, TableType::DupSort), + (StorageChangeSet, TableType::DupSort) ]} tables! { @@ -195,8 +203,58 @@ tables! { ContractInfo: (ContractAddress) => GenericContractInfo, /// Store contract storage ContractStorage: (ContractAddress, StorageKey) => StorageEntry, + + + /// Stores the block number where the class hash was declared. + ClassDeclarationBlock: (ClassHash) => BlockNumber, /// Stores the list of class hashes according to the block number it was declared in. ClassDeclarations: (BlockNumber, ClassHash) => ClassHash, - /// Store the list of contracts deployed in a block according to its block number. - ContractDeployments: (BlockNumber, ContractAddress) => ContractAddress + + /// Generic contract info change set. + /// + /// Stores the list of blocks where the contract info (nonce / class hash) has changed. + ContractInfoChangeSet: (ContractAddress) => ContractInfoChangeList, + + /// Contract nonce changes by block. + NonceChanges: (BlockNumber, ContractAddress) => ContractNonceChange, + /// Contract class hash changes by block. + ContractClassChanges: (BlockNumber, ContractAddress) => ContractClassChange, + + /// storage change set + StorageChangeSet: (ContractAddress, StorageKey) => StorageEntryChangeList, + /// Account storage change set + StorageChanges: (BlockNumber, ContractStorageKey) => ContractStorageEntry + +} + +#[cfg(test)] +mod tests { + #[test] + fn test_tables() { + use super::*; + + assert_eq!(Tables::ALL.len(), NUM_TABLES); + assert_eq!(Tables::ALL[0].name(), Headers::NAME); + assert_eq!(Tables::ALL[1].name(), BlockHashes::NAME); + assert_eq!(Tables::ALL[2].name(), BlockNumbers::NAME); + assert_eq!(Tables::ALL[3].name(), BlockBodyIndices::NAME); + assert_eq!(Tables::ALL[4].name(), BlockStatusses::NAME); + assert_eq!(Tables::ALL[5].name(), TxNumbers::NAME); + assert_eq!(Tables::ALL[6].name(), TxBlocks::NAME); + assert_eq!(Tables::ALL[7].name(), TxHashes::NAME); + assert_eq!(Tables::ALL[8].name(), Transactions::NAME); + assert_eq!(Tables::ALL[9].name(), Receipts::NAME); + assert_eq!(Tables::ALL[10].name(), CompiledClassHashes::NAME); + assert_eq!(Tables::ALL[11].name(), CompiledContractClasses::NAME); + assert_eq!(Tables::ALL[12].name(), SierraClasses::NAME); + assert_eq!(Tables::ALL[13].name(), ContractInfo::NAME); + assert_eq!(Tables::ALL[14].name(), ContractStorage::NAME); + assert_eq!(Tables::ALL[15].name(), ClassDeclarationBlock::NAME); + assert_eq!(Tables::ALL[16].name(), ClassDeclarations::NAME); + assert_eq!(Tables::ALL[17].name(), ContractInfoChangeSet::NAME); + assert_eq!(Tables::ALL[18].name(), NonceChanges::NAME); + assert_eq!(Tables::ALL[19].name(), ContractClassChanges::NAME); + assert_eq!(Tables::ALL[20].name(), StorageChanges::NAME); + assert_eq!(Tables::ALL[21].name(), StorageChangeSet::NAME); + } } diff --git a/crates/katana/storage/provider/Cargo.toml b/crates/katana/storage/provider/Cargo.toml index af9d492a08..5b4e4145da 100644 --- a/crates/katana/storage/provider/Cargo.toml +++ b/crates/katana/storage/provider/Cargo.toml @@ -33,4 +33,5 @@ rand = "0.8.5" rstest = "0.18.2" rstest_reuse = "0.6.0" starknet.workspace = true +tempfile = "3.8.1" url.workspace = true diff --git a/crates/katana/storage/provider/src/providers/db/mod.rs b/crates/katana/storage/provider/src/providers/db/mod.rs new file mode 100644 index 0000000000..14568e8bd2 --- /dev/null +++ b/crates/katana/storage/provider/src/providers/db/mod.rs @@ -0,0 +1,864 @@ +pub mod state; + +use std::collections::HashMap; +use std::fmt::Debug; +use std::ops::{Range, RangeInclusive}; + +use anyhow::Result; +use katana_db::error::DatabaseError; +use katana_db::mdbx::{self, DbEnv}; +use katana_db::models::block::StoredBlockBodyIndices; +use katana_db::models::contract::{ + ContractClassChange, ContractInfoChangeList, ContractNonceChange, +}; +use katana_db::models::storage::{ + ContractStorageEntry, ContractStorageKey, StorageEntry, StorageEntryChangeList, +}; +use katana_db::tables::{ + BlockBodyIndices, BlockHashes, BlockNumbers, BlockStatusses, ClassDeclarationBlock, + ClassDeclarations, CompiledClassHashes, CompiledContractClasses, ContractClassChanges, + ContractInfo, ContractInfoChangeSet, ContractStorage, DupSort, Headers, NonceChanges, Receipts, + SierraClasses, StorageChangeSet, StorageChanges, Table, Transactions, TxBlocks, TxHashes, + TxNumbers, +}; +use katana_db::utils::KeyValue; +use katana_primitives::block::{ + Block, BlockHash, BlockHashOrNumber, BlockNumber, BlockWithTxHashes, FinalityStatus, Header, + SealedBlockWithStatus, +}; +use katana_primitives::contract::{ + ClassHash, CompiledClassHash, ContractAddress, GenericContractInfo, Nonce, StorageKey, + StorageValue, +}; +use katana_primitives::receipt::Receipt; +use katana_primitives::state::{StateUpdates, StateUpdatesWithDeclaredClasses}; +use katana_primitives::transaction::{TxHash, TxNumber, TxWithHash}; +use katana_primitives::FieldElement; + +use crate::traits::block::{ + BlockHashProvider, BlockNumberProvider, BlockProvider, BlockStatusProvider, BlockWriter, + HeaderProvider, +}; +use crate::traits::state::{StateFactoryProvider, StateProvider, StateRootProvider}; +use crate::traits::state_update::StateUpdateProvider; +use crate::traits::transaction::{ + ReceiptProvider, TransactionProvider, TransactionStatusProvider, TransactionsProviderExt, +}; + +/// A provider implementation that uses a database as a backend. +#[derive(Debug)] +pub struct DbProvider(DbEnv); + +impl DbProvider { + /// Creates a new [`DbProvider`] from the given [`DbEnv`]. + pub fn new(db: DbEnv) -> Self { + Self(db) + } +} + +impl StateFactoryProvider for DbProvider { + fn latest(&self) -> Result> { + Ok(Box::new(self::state::LatestStateProvider::new(self.0.tx()?))) + } + + fn historical(&self, block_id: BlockHashOrNumber) -> Result>> { + let block_number = match block_id { + BlockHashOrNumber::Num(num) => { + let latest_num = self.latest_number()?; + + match num.cmp(&latest_num) { + std::cmp::Ordering::Less => Some(num), + std::cmp::Ordering::Greater => return Ok(None), + std::cmp::Ordering::Equal => return self.latest().map(Some), + } + } + + BlockHashOrNumber::Hash(hash) => self.block_number_by_hash(hash)?, + }; + + let Some(num) = block_number else { return Ok(None) }; + + Ok(Some(Box::new(self::state::HistoricalStateProvider::new(self.0.tx()?, num)))) + } +} + +impl BlockNumberProvider for DbProvider { + fn block_number_by_hash(&self, hash: BlockHash) -> Result> { + let db_tx = self.0.tx()?; + let block_num = db_tx.get::(hash)?; + db_tx.commit()?; + Ok(block_num) + } + + fn latest_number(&self) -> Result { + let db_tx = self.0.tx()?; + let total_blocks = db_tx.entries::()? as u64; + db_tx.commit()?; + Ok(if total_blocks == 0 { 0 } else { total_blocks - 1 }) + } +} + +impl BlockHashProvider for DbProvider { + fn latest_hash(&self) -> Result { + let db_tx = self.0.tx()?; + let total_blocks = db_tx.entries::()? as u64; + let latest_block = if total_blocks == 0 { 0 } else { total_blocks - 1 }; + let latest_hash = db_tx.get::(latest_block)?.expect("block hash should exist"); + db_tx.commit()?; + Ok(latest_hash) + } + + fn block_hash_by_num(&self, num: BlockNumber) -> Result> { + let db_tx = self.0.tx()?; + let block_hash = db_tx.get::(num)?; + db_tx.commit()?; + Ok(block_hash) + } +} + +impl HeaderProvider for DbProvider { + fn header(&self, id: BlockHashOrNumber) -> Result> { + let db_tx = self.0.tx()?; + + let num = match id { + BlockHashOrNumber::Num(num) => Some(num), + BlockHashOrNumber::Hash(hash) => db_tx.get::(hash)?, + }; + + if let Some(num) = num { + let header = db_tx.get::(num)?.expect("should exist"); + db_tx.commit()?; + Ok(Some(header)) + } else { + Ok(None) + } + } +} + +impl BlockProvider for DbProvider { + fn block_body_indices(&self, id: BlockHashOrNumber) -> Result> { + let db_tx = self.0.tx()?; + + let block_num = match id { + BlockHashOrNumber::Num(num) => Some(num), + BlockHashOrNumber::Hash(hash) => db_tx.get::(hash)?, + }; + + if let Some(num) = block_num { + let indices = db_tx.get::(num)?; + db_tx.commit()?; + Ok(indices) + } else { + Ok(None) + } + } + + fn block(&self, id: BlockHashOrNumber) -> Result> { + let db_tx = self.0.tx()?; + + if let Some(header) = self.header(id)? { + let body = self.transactions_by_block(id)?.expect("should exist"); + db_tx.commit()?; + Ok(Some(Block { header, body })) + } else { + Ok(None) + } + } + + fn block_with_tx_hashes(&self, id: BlockHashOrNumber) -> Result> { + let db_tx = self.0.tx()?; + + let block_num = match id { + BlockHashOrNumber::Num(num) => Some(num), + BlockHashOrNumber::Hash(hash) => db_tx.get::(hash)?, + }; + + let Some(block_num) = block_num else { return Ok(None) }; + + if let Some(header) = db_tx.get::(block_num)? { + let body_indices = db_tx.get::(block_num)?.expect("should exist"); + let body = self.transaction_hashes_in_range(Range::from(body_indices))?; + let block = BlockWithTxHashes { header, body }; + + db_tx.commit()?; + + Ok(Some(block)) + } else { + Ok(None) + } + } + + fn blocks_in_range(&self, range: RangeInclusive) -> Result> { + let db_tx = self.0.tx()?; + + let total = range.end() - range.start() + 1; + let mut blocks = Vec::with_capacity(total as usize); + + for num in range { + if let Some(header) = db_tx.get::(num)? { + let body_indices = db_tx.get::(num)?.expect("should exist"); + let body = self.transaction_in_range(Range::from(body_indices))?; + blocks.push(Block { header, body }) + } + } + + db_tx.commit()?; + Ok(blocks) + } +} + +impl BlockStatusProvider for DbProvider { + fn block_status(&self, id: BlockHashOrNumber) -> Result> { + let db_tx = self.0.tx()?; + + let block_num = match id { + BlockHashOrNumber::Num(num) => Some(num), + BlockHashOrNumber::Hash(hash) => self.block_number_by_hash(hash)?, + }; + + if let Some(block_num) = block_num { + let status = db_tx.get::(block_num)?.expect("should exist"); + db_tx.commit()?; + Ok(Some(status)) + } else { + Ok(None) + } + } +} + +impl StateRootProvider for DbProvider { + fn state_root(&self, block_id: BlockHashOrNumber) -> Result> { + let db_tx = self.0.tx()?; + + let block_num = match block_id { + BlockHashOrNumber::Num(num) => Some(num), + BlockHashOrNumber::Hash(hash) => db_tx.get::(hash)?, + }; + + if let Some(block_num) = block_num { + let header = db_tx.get::(block_num)?; + db_tx.commit()?; + Ok(header.map(|h| h.state_root)) + } else { + Ok(None) + } + } +} + +impl StateUpdateProvider for DbProvider { + fn state_update(&self, block_id: BlockHashOrNumber) -> Result> { + // A helper function that iterates over all entries in a dupsort table and collects the + // results into `V`. If `key` is not found, `V::default()` is returned. + fn dup_entries( + db_tx: &mdbx::tx::TxRO, + key: ::Key, + f: impl FnMut(Result, DatabaseError>) -> Result, + ) -> Result + where + Tb: DupSort + Debug, + V: FromIterator + Default, + { + Ok(db_tx + .cursor::()? + .walk_dup(Some(key), None)? + .map(|walker| walker.map(f).collect::>()) + .transpose()? + .unwrap_or_default()) + } + + let db_tx = self.0.tx()?; + let block_num = self.block_number_by_id(block_id)?; + + if let Some(block_num) = block_num { + let nonce_updates = dup_entries::, _>( + &db_tx, + block_num, + |entry| { + let (_, ContractNonceChange { contract_address, nonce }) = entry?; + Ok((contract_address, nonce)) + }, + )?; + + let contract_updates = dup_entries::< + ContractClassChanges, + HashMap, + _, + >(&db_tx, block_num, |entry| { + let (_, ContractClassChange { contract_address, class_hash }) = entry?; + Ok((contract_address, class_hash)) + })?; + + let declared_classes = dup_entries::< + ClassDeclarations, + HashMap, + _, + >(&db_tx, block_num, |entry| { + let (_, class_hash) = entry?; + let compiled_hash = + db_tx.get::(class_hash)?.expect("qed; must exist"); + Ok((class_hash, compiled_hash)) + })?; + + let storage_updates = { + let entries = dup_entries::< + StorageChanges, + Vec<(ContractAddress, (StorageKey, StorageValue))>, + _, + >(&db_tx, block_num, |entry| { + let (_, ContractStorageEntry { key, value }) = entry?; + Ok::<_, DatabaseError>((key.contract_address, (key.key, value))) + })?; + + let mut map: HashMap<_, HashMap> = HashMap::new(); + + entries.into_iter().for_each(|(addr, (key, value))| { + map.entry(addr).or_default().insert(key, value); + }); + + map + }; + + db_tx.commit()?; + Ok(Some(StateUpdates { + nonce_updates, + storage_updates, + contract_updates, + declared_classes, + })) + } else { + Ok(None) + } + } +} + +impl TransactionProvider for DbProvider { + fn transaction_by_hash(&self, hash: TxHash) -> Result> { + let db_tx = self.0.tx()?; + + if let Some(num) = db_tx.get::(hash)? { + let transaction = db_tx.get::(num)?.expect("transaction should exist"); + let transaction = TxWithHash { hash, transaction }; + db_tx.commit()?; + + Ok(Some(transaction)) + } else { + Ok(None) + } + } + + fn transactions_by_block( + &self, + block_id: BlockHashOrNumber, + ) -> Result>> { + if let Some(indices) = self.block_body_indices(block_id)? { + Ok(Some(self.transaction_in_range(Range::from(indices))?)) + } else { + Ok(None) + } + } + + fn transaction_in_range(&self, range: Range) -> Result> { + let db_tx = self.0.tx()?; + + let total = range.end - range.start; + let mut transactions = Vec::with_capacity(total as usize); + + for i in range { + if let Some(transaction) = db_tx.get::(i)? { + let hash = db_tx.get::(i)?.expect("should exist"); + transactions.push(TxWithHash { hash, transaction }); + }; + } + + db_tx.commit()?; + Ok(transactions) + } + + fn transaction_block_num_and_hash( + &self, + hash: TxHash, + ) -> Result> { + let db_tx = self.0.tx()?; + if let Some(num) = db_tx.get::(hash)? { + let block_num = db_tx.get::(num)?.expect("should exist"); + let block_hash = db_tx.get::(block_num)?.expect("should exist"); + db_tx.commit()?; + Ok(Some((block_num, block_hash))) + } else { + Ok(None) + } + } + + fn transaction_by_block_and_idx( + &self, + block_id: BlockHashOrNumber, + idx: u64, + ) -> Result> { + let db_tx = self.0.tx()?; + + match self.block_body_indices(block_id)? { + // make sure the requested idx is within the range of the block tx count + Some(indices) if idx < indices.tx_count => { + let num = indices.tx_offset + idx; + let hash = db_tx.get::(num)?.expect("should exist"); + let transaction = db_tx.get::(num)?.expect("should exist"); + let transaction = TxWithHash { hash, transaction }; + db_tx.commit()?; + Ok(Some(transaction)) + } + + _ => Ok(None), + } + } + + fn transaction_count_by_block(&self, block_id: BlockHashOrNumber) -> Result> { + let db_tx = self.0.tx()?; + if let Some(indices) = self.block_body_indices(block_id)? { + db_tx.commit()?; + Ok(Some(indices.tx_count)) + } else { + Ok(None) + } + } +} + +impl TransactionsProviderExt for DbProvider { + fn transaction_hashes_in_range(&self, range: Range) -> Result> { + let db_tx = self.0.tx()?; + + let total = range.end - range.start; + let mut hashes = Vec::with_capacity(total as usize); + + for i in range { + if let Some(hash) = db_tx.get::(i)? { + hashes.push(hash); + } + } + + db_tx.commit()?; + Ok(hashes) + } +} + +impl TransactionStatusProvider for DbProvider { + fn transaction_status(&self, hash: TxHash) -> Result> { + let db_tx = self.0.tx()?; + if let Some(tx_num) = db_tx.get::(hash)? { + let block_num = db_tx.get::(tx_num)?.expect("should exist"); + let status = db_tx.get::(block_num)?.expect("should exist"); + db_tx.commit()?; + Ok(Some(status)) + } else { + Ok(None) + } + } +} + +impl ReceiptProvider for DbProvider { + fn receipt_by_hash(&self, hash: TxHash) -> Result> { + let db_tx = self.0.tx()?; + if let Some(num) = db_tx.get::(hash)? { + let receipt = db_tx.get::(num)?.expect("should exist"); + db_tx.commit()?; + Ok(Some(receipt)) + } else { + Ok(None) + } + } + + fn receipts_by_block(&self, block_id: BlockHashOrNumber) -> Result>> { + if let Some(indices) = self.block_body_indices(block_id)? { + let db_tx = self.0.tx()?; + let mut receipts = Vec::with_capacity(indices.tx_count as usize); + + let range = indices.tx_offset..indices.tx_offset + indices.tx_count; + for i in range { + if let Some(receipt) = db_tx.get::(i)? { + receipts.push(receipt); + } + } + + db_tx.commit()?; + Ok(Some(receipts)) + } else { + Ok(None) + } + } +} + +impl BlockWriter for DbProvider { + fn insert_block_with_states_and_receipts( + &self, + block: SealedBlockWithStatus, + states: StateUpdatesWithDeclaredClasses, + receipts: Vec, + ) -> Result<()> { + self.0.update(move |db_tx| -> Result<()> { + let block_hash = block.block.header.hash; + let block_number = block.block.header.header.number; + + let block_header = block.block.header.header; + let transactions = block.block.body; + + let tx_count = transactions.len() as u64; + let tx_offset = db_tx.entries::()? as u64; + let block_body_indices = StoredBlockBodyIndices { tx_offset, tx_count }; + + db_tx.put::(block_number, block_hash)?; + db_tx.put::(block_hash, block_number)?; + db_tx.put::(block_number, block.status)?; + + db_tx.put::(block_number, block_header)?; + db_tx.put::(block_number, block_body_indices)?; + + for (i, (transaction, receipt)) in transactions.into_iter().zip(receipts).enumerate() { + let tx_number = tx_offset + i as u64; + let tx_hash = transaction.hash; + + db_tx.put::(tx_number, tx_hash)?; + db_tx.put::(tx_hash, tx_number)?; + db_tx.put::(tx_number, block_number)?; + db_tx.put::(tx_number, transaction.transaction)?; + db_tx.put::(tx_number, receipt)?; + } + + // insert classes + + for (class_hash, compiled_hash) in states.state_updates.declared_classes { + db_tx.put::(class_hash, compiled_hash)?; + + db_tx.put::(class_hash, block_number)?; + db_tx.put::(block_number, class_hash)? + } + + for (hash, compiled_class) in states.declared_compiled_classes { + db_tx.put::(hash, compiled_class.into())?; + } + + for (class_hash, sierra_class) in states.declared_sierra_classes { + db_tx.put::(class_hash, sierra_class)?; + } + + // insert storage changes + { + let mut storage_cursor = db_tx.cursor::()?; + for (addr, entries) in states.state_updates.storage_updates { + let entries = + entries.into_iter().map(|(key, value)| StorageEntry { key, value }); + + for entry in entries { + match storage_cursor.seek_by_key_subkey(addr, entry.key)? { + Some(current) if current.key == entry.key => { + storage_cursor.delete_current()?; + } + + _ => {} + } + + let mut change_set_cursor = db_tx.cursor::()?; + let new_block_list = + match change_set_cursor.seek_by_key_subkey(addr, entry.key)? { + Some(StorageEntryChangeList { mut block_list, key }) + if key == entry.key => + { + change_set_cursor.delete_current()?; + + block_list.push(block_number); + block_list.sort(); + block_list + } + + _ => { + vec![block_number] + } + }; + + change_set_cursor.upsert( + addr, + StorageEntryChangeList { key: entry.key, block_list: new_block_list }, + )?; + storage_cursor.upsert(addr, entry)?; + + let storage_change_sharded_key = + ContractStorageKey { contract_address: addr, key: entry.key }; + + db_tx.put::( + block_number, + ContractStorageEntry { + key: storage_change_sharded_key, + value: entry.value, + }, + )?; + } + } + } + + // update contract info + + for (addr, class_hash) in states.state_updates.contract_updates { + let value = if let Some(info) = db_tx.get::(addr)? { + GenericContractInfo { class_hash, ..info } + } else { + GenericContractInfo { class_hash, ..Default::default() } + }; + + let new_change_set = + if let Some(mut change_set) = db_tx.get::(addr)? { + change_set.class_change_list.push(block_number); + change_set.class_change_list.sort(); + change_set + } else { + ContractInfoChangeList { + class_change_list: vec![block_number], + ..Default::default() + } + }; + + db_tx.put::(addr, value)?; + + let class_change_key = ContractClassChange { contract_address: addr, class_hash }; + db_tx.put::(block_number, class_change_key)?; + db_tx.put::(addr, new_change_set)?; + } + + for (addr, nonce) in states.state_updates.nonce_updates { + let value = if let Some(info) = db_tx.get::(addr)? { + GenericContractInfo { nonce, ..info } + } else { + GenericContractInfo { nonce, ..Default::default() } + }; + + let new_change_set = + if let Some(mut change_set) = db_tx.get::(addr)? { + change_set.nonce_change_list.push(block_number); + change_set.nonce_change_list.sort(); + change_set + } else { + ContractInfoChangeList { + nonce_change_list: vec![block_number], + ..Default::default() + } + }; + + db_tx.put::(addr, value)?; + + let nonce_change_key = ContractNonceChange { contract_address: addr, nonce }; + db_tx.put::(block_number, nonce_change_key)?; + db_tx.put::(addr, new_change_set)?; + } + + Ok(()) + })? + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashMap; + + use katana_db::mdbx::DbEnvKind; + use katana_primitives::block::{ + Block, BlockHashOrNumber, FinalityStatus, Header, SealedBlockWithStatus, + }; + use katana_primitives::contract::ContractAddress; + use katana_primitives::receipt::Receipt; + use katana_primitives::state::{StateUpdates, StateUpdatesWithDeclaredClasses}; + use katana_primitives::transaction::{Tx, TxHash, TxWithHash}; + use starknet::macros::felt; + + use super::DbProvider; + use crate::traits::block::{ + BlockHashProvider, BlockNumberProvider, BlockProvider, BlockStatusProvider, BlockWriter, + }; + use crate::traits::state::StateFactoryProvider; + use crate::traits::transaction::TransactionProvider; + + fn create_dummy_block() -> SealedBlockWithStatus { + let header = Header { parent_hash: 199u8.into(), number: 0, ..Default::default() }; + let block = Block { + header, + body: vec![TxWithHash { + hash: 24u8.into(), + transaction: Tx::Invoke(Default::default()), + }], + } + .seal(); + SealedBlockWithStatus { block, status: FinalityStatus::AcceptedOnL2 } + } + + fn create_dummy_state_updates() -> StateUpdatesWithDeclaredClasses { + StateUpdatesWithDeclaredClasses { + state_updates: StateUpdates { + nonce_updates: HashMap::from([ + (ContractAddress::from(felt!("1")), felt!("1")), + (ContractAddress::from(felt!("2")), felt!("2")), + ]), + contract_updates: HashMap::from([ + (ContractAddress::from(felt!("1")), felt!("3")), + (ContractAddress::from(felt!("2")), felt!("4")), + ]), + declared_classes: HashMap::from([ + (felt!("3"), felt!("89")), + (felt!("4"), felt!("90")), + ]), + storage_updates: HashMap::from([( + ContractAddress::from(felt!("1")), + HashMap::from([(felt!("1"), felt!("1")), (felt!("2"), felt!("2"))]), + )]), + }, + ..Default::default() + } + } + + fn create_dummy_state_updates_2() -> StateUpdatesWithDeclaredClasses { + StateUpdatesWithDeclaredClasses { + state_updates: StateUpdates { + nonce_updates: HashMap::from([ + (ContractAddress::from(felt!("1")), felt!("5")), + (ContractAddress::from(felt!("2")), felt!("6")), + ]), + contract_updates: HashMap::from([ + (ContractAddress::from(felt!("1")), felt!("77")), + (ContractAddress::from(felt!("2")), felt!("66")), + ]), + storage_updates: HashMap::from([( + ContractAddress::from(felt!("1")), + HashMap::from([(felt!("1"), felt!("100")), (felt!("2"), felt!("200"))]), + )]), + ..Default::default() + }, + ..Default::default() + } + } + + fn create_db_provider() -> DbProvider { + DbProvider(katana_db::mdbx::test_utils::create_test_db(DbEnvKind::RW)) + } + + #[test] + fn insert_block() { + let provider = create_db_provider(); + + let block = create_dummy_block(); + let state_updates = create_dummy_state_updates(); + + // insert block + BlockWriter::insert_block_with_states_and_receipts( + &provider, + block.clone(), + state_updates, + vec![Receipt::Invoke(Default::default())], + ) + .expect("failed to insert block"); + + // get values + + let block_id: BlockHashOrNumber = block.block.header.hash.into(); + + let latest_number = provider.latest_number().unwrap(); + let latest_hash = provider.latest_hash().unwrap(); + + let actual_block = provider.block(block_id).unwrap().unwrap(); + let tx_count = provider.transaction_count_by_block(block_id).unwrap().unwrap(); + let block_status = provider.block_status(block_id).unwrap().unwrap(); + let body_indices = provider.block_body_indices(block_id).unwrap().unwrap(); + + let tx_hash: TxHash = 24u8.into(); + let tx = provider.transaction_by_hash(tx_hash).unwrap().unwrap(); + + let state_prov = StateFactoryProvider::latest(&provider).unwrap(); + + let nonce1 = state_prov.nonce(ContractAddress::from(felt!("1"))).unwrap().unwrap(); + let nonce2 = state_prov.nonce(ContractAddress::from(felt!("2"))).unwrap().unwrap(); + + let class_hash1 = state_prov.class_hash_of_contract(felt!("1").into()).unwrap().unwrap(); + let class_hash2 = state_prov.class_hash_of_contract(felt!("2").into()).unwrap().unwrap(); + + let compiled_hash1 = + state_prov.compiled_class_hash_of_class_hash(class_hash1).unwrap().unwrap(); + let compiled_hash2 = + state_prov.compiled_class_hash_of_class_hash(class_hash2).unwrap().unwrap(); + + let storage1 = + state_prov.storage(ContractAddress::from(felt!("1")), felt!("1")).unwrap().unwrap(); + let storage2 = + state_prov.storage(ContractAddress::from(felt!("1")), felt!("2")).unwrap().unwrap(); + + // assert values are populated correctly + + assert_eq!(tx_hash, tx.hash); + assert_eq!(tx.transaction, Tx::Invoke(Default::default())); + + assert_eq!(tx_count, 1); + assert_eq!(body_indices.tx_offset, 0); + assert_eq!(body_indices.tx_count, tx_count); + + assert_eq!(block_status, FinalityStatus::AcceptedOnL2); + assert_eq!(block.block.header.hash, latest_hash); + assert_eq!(block.block.body.len() as u64, tx_count); + assert_eq!(block.block.header.header.number, latest_number); + assert_eq!(block.block.unseal(), actual_block); + + assert_eq!(nonce1, felt!("1")); + assert_eq!(nonce2, felt!("2")); + assert_eq!(class_hash1, felt!("3")); + assert_eq!(class_hash2, felt!("4")); + + assert_eq!(compiled_hash1, felt!("89")); + assert_eq!(compiled_hash2, felt!("90")); + + assert_eq!(storage1, felt!("1")); + assert_eq!(storage2, felt!("2")); + } + + #[test] + fn storage_updated_correctly() { + let provider = create_db_provider(); + + let block = create_dummy_block(); + let state_updates1 = create_dummy_state_updates(); + let state_updates2 = create_dummy_state_updates_2(); + + // insert block + BlockWriter::insert_block_with_states_and_receipts( + &provider, + block.clone(), + state_updates1, + vec![Receipt::Invoke(Default::default())], + ) + .expect("failed to insert block"); + + // insert another block + BlockWriter::insert_block_with_states_and_receipts( + &provider, + block, + state_updates2, + vec![Receipt::Invoke(Default::default())], + ) + .expect("failed to insert block"); + + // assert storage is updated correctly + + let state_prov = StateFactoryProvider::latest(&provider).unwrap(); + + let nonce1 = state_prov.nonce(ContractAddress::from(felt!("1"))).unwrap().unwrap(); + let nonce2 = state_prov.nonce(ContractAddress::from(felt!("2"))).unwrap().unwrap(); + + let class_hash1 = state_prov.class_hash_of_contract(felt!("1").into()).unwrap().unwrap(); + let class_hash2 = state_prov.class_hash_of_contract(felt!("2").into()).unwrap().unwrap(); + + let storage1 = + state_prov.storage(ContractAddress::from(felt!("1")), felt!("1")).unwrap().unwrap(); + let storage2 = + state_prov.storage(ContractAddress::from(felt!("1")), felt!("2")).unwrap().unwrap(); + + assert_eq!(nonce1, felt!("5")); + assert_eq!(nonce2, felt!("6")); + + assert_eq!(class_hash1, felt!("77")); + assert_eq!(class_hash2, felt!("66")); + + assert_eq!(storage1, felt!("100")); + assert_eq!(storage2, felt!("200")); + } +} diff --git a/crates/katana/storage/provider/src/providers/db/state.rs b/crates/katana/storage/provider/src/providers/db/state.rs new file mode 100644 index 0000000000..0230ade4e5 --- /dev/null +++ b/crates/katana/storage/provider/src/providers/db/state.rs @@ -0,0 +1,337 @@ +use std::cmp::Ordering; + +use anyhow::Result; +use katana_db::mdbx::{self}; +use katana_db::models::contract::{ + ContractClassChange, ContractInfoChangeList, ContractNonceChange, +}; +use katana_db::models::storage::{ContractStorageEntry, ContractStorageKey, StorageEntry}; +use katana_db::tables::{ + ClassDeclarationBlock, CompiledClassHashes, CompiledContractClasses, ContractClassChanges, + ContractInfo, ContractInfoChangeSet, ContractStorage, NonceChanges, SierraClasses, + StorageChangeSet, StorageChanges, +}; +use katana_primitives::block::BlockNumber; +use katana_primitives::contract::{ + ClassHash, CompiledClassHash, CompiledContractClass, ContractAddress, GenericContractInfo, + Nonce, SierraClass, StorageKey, StorageValue, +}; + +use super::DbProvider; +use crate::traits::contract::{ContractClassProvider, ContractClassWriter}; +use crate::traits::state::{StateProvider, StateWriter}; + +impl StateWriter for DbProvider { + fn set_nonce(&self, address: ContractAddress, nonce: Nonce) -> Result<()> { + self.0.update(move |db_tx| -> Result<()> { + let value = if let Some(info) = db_tx.get::(address)? { + GenericContractInfo { nonce, ..info } + } else { + GenericContractInfo { nonce, ..Default::default() } + }; + db_tx.put::(address, value)?; + Ok(()) + })? + } + + fn set_storage( + &self, + address: ContractAddress, + storage_key: StorageKey, + storage_value: StorageValue, + ) -> Result<()> { + self.0.update(move |db_tx| -> Result<()> { + let mut cursor = db_tx.cursor::()?; + let entry = cursor.seek_by_key_subkey(address, storage_key)?; + + match entry { + Some(entry) if entry.key == storage_key => { + cursor.delete_current()?; + } + _ => {} + } + + cursor.upsert(address, StorageEntry { key: storage_key, value: storage_value })?; + Ok(()) + })? + } + + fn set_class_hash_of_contract( + &self, + address: ContractAddress, + class_hash: ClassHash, + ) -> Result<()> { + self.0.update(move |db_tx| -> Result<()> { + let value = if let Some(info) = db_tx.get::(address)? { + GenericContractInfo { class_hash, ..info } + } else { + GenericContractInfo { class_hash, ..Default::default() } + }; + db_tx.put::(address, value)?; + Ok(()) + })? + } +} + +impl ContractClassWriter for DbProvider { + fn set_class(&self, hash: ClassHash, class: CompiledContractClass) -> Result<()> { + self.0.update(move |db_tx| -> Result<()> { + db_tx.put::(hash, class.into())?; + Ok(()) + })? + } + + fn set_compiled_class_hash_of_class_hash( + &self, + hash: ClassHash, + compiled_hash: CompiledClassHash, + ) -> Result<()> { + self.0.update(move |db_tx| -> Result<()> { + db_tx.put::(hash, compiled_hash)?; + Ok(()) + })? + } + + fn set_sierra_class(&self, hash: ClassHash, sierra: SierraClass) -> Result<()> { + self.0.update(move |db_tx| -> Result<()> { + db_tx.put::(hash, sierra)?; + Ok(()) + })? + } +} + +/// A state provider that provides the latest states from the database. +pub(super) struct LatestStateProvider(mdbx::tx::TxRO); + +impl LatestStateProvider { + pub fn new(tx: mdbx::tx::TxRO) -> Self { + Self(tx) + } +} + +impl ContractClassProvider for LatestStateProvider { + fn class(&self, hash: ClassHash) -> Result> { + let class = self.0.get::(hash)?; + Ok(class.map(CompiledContractClass::from)) + } + + fn compiled_class_hash_of_class_hash( + &self, + hash: ClassHash, + ) -> Result> { + let hash = self.0.get::(hash)?; + Ok(hash) + } + + fn sierra_class(&self, hash: ClassHash) -> Result> { + let class = self.0.get::(hash)?; + Ok(class) + } +} + +impl StateProvider for LatestStateProvider { + fn nonce(&self, address: ContractAddress) -> Result> { + let info = self.0.get::(address)?; + Ok(info.map(|info| info.nonce)) + } + + fn class_hash_of_contract( + &self, + address: ContractAddress, + ) -> Result> { + let info = self.0.get::(address)?; + Ok(info.map(|info| info.class_hash)) + } + + fn storage( + &self, + address: ContractAddress, + storage_key: StorageKey, + ) -> Result> { + let mut cursor = self.0.cursor::()?; + let entry = cursor.seek_by_key_subkey(address, storage_key)?; + match entry { + Some(entry) if entry.key == storage_key => Ok(Some(entry.value)), + _ => Ok(None), + } + } +} + +/// A historical state provider. +pub(super) struct HistoricalStateProvider { + /// The database transaction used to read the database. + tx: mdbx::tx::TxRO, + /// The block number of the state. + block_number: u64, +} + +impl HistoricalStateProvider { + pub fn new(tx: mdbx::tx::TxRO, block_number: u64) -> Self { + Self { tx, block_number } + } + + // This looks ugly but it works and I will most likely forget how it works + // if I don't document it. But im lazy. + fn recent_block_change_relative_to_pinned_block_num( + block_number: BlockNumber, + block_list: &[BlockNumber], + ) -> Option { + if block_list.first().is_some_and(|num| block_number < *num) { + return None; + } + + // if the pinned block number is smaller than the first block number in the list, + // then that means there is no change happening before the pinned block number. + let pos = { + if let Some(pos) = block_list.last().and_then(|num| { + if block_number >= *num { Some(block_list.len() - 1) } else { None } + }) { + Some(pos) + } else { + block_list.iter().enumerate().find_map(|(i, num)| match block_number.cmp(num) { + Ordering::Equal => Some(i), + Ordering::Greater => None, + Ordering::Less => { + if i == 0 || block_number == 0 { + None + } else { + Some(i - 1) + } + } + }) + } + }?; + + block_list.get(pos).copied() + } +} + +impl ContractClassProvider for HistoricalStateProvider { + fn compiled_class_hash_of_class_hash( + &self, + hash: ClassHash, + ) -> Result> { + // check that the requested class hash was declared before the pinned block number + if !self.tx.get::(hash)?.is_some_and(|num| num <= self.block_number) + { + return Ok(None); + }; + + Ok(self.tx.get::(hash)?) + } + + fn class(&self, hash: ClassHash) -> Result> { + self.compiled_class_hash_of_class_hash(hash).and_then(|_| { + let contract = self.tx.get::(hash)?; + Ok(contract.map(CompiledContractClass::from)) + }) + } + + fn sierra_class(&self, hash: ClassHash) -> Result> { + self.compiled_class_hash_of_class_hash(hash) + .and_then(|_| self.tx.get::(hash).map_err(|e| e.into())) + } +} + +impl StateProvider for HistoricalStateProvider { + fn nonce(&self, address: ContractAddress) -> Result> { + let change_list = self.tx.get::(address)?; + + if let Some(num) = change_list.and_then(|entry| { + Self::recent_block_change_relative_to_pinned_block_num( + self.block_number, + &entry.nonce_change_list, + ) + }) { + let mut cursor = self.tx.cursor::()?; + let ContractNonceChange { contract_address, nonce } = cursor + .seek_by_key_subkey(num, address)? + .expect("if block number is in the block set, change entry must exist"); + + if contract_address == address { + return Ok(Some(nonce)); + } + } + + Ok(None) + } + + fn class_hash_of_contract(&self, address: ContractAddress) -> Result> { + let change_list: Option = + self.tx.get::(address)?; + + if let Some(num) = change_list.and_then(|entry| { + Self::recent_block_change_relative_to_pinned_block_num( + self.block_number, + &entry.class_change_list, + ) + }) { + let mut cursor = self.tx.cursor::()?; + let ContractClassChange { contract_address, class_hash } = cursor + .seek_by_key_subkey(num, address)? + .expect("if block number is in the block set, change entry must exist"); + + if contract_address == address { + return Ok(Some(class_hash)); + } + } + + Ok(None) + } + + fn storage( + &self, + address: ContractAddress, + storage_key: StorageKey, + ) -> Result> { + let mut cursor = self.tx.cursor::()?; + + if let Some(num) = cursor.seek_by_key_subkey(address, storage_key)?.and_then(|entry| { + Self::recent_block_change_relative_to_pinned_block_num( + self.block_number, + &entry.block_list, + ) + }) { + let mut cursor = self.tx.cursor::()?; + let sharded_key = ContractStorageKey { contract_address: address, key: storage_key }; + + let ContractStorageEntry { key, value } = cursor + .seek_by_key_subkey(num, sharded_key)? + .expect("if block number is in the block set, change entry must exist"); + + if key.contract_address == address && key.key == storage_key { + return Ok(Some(value)); + } + } + + Ok(None) + } +} + +#[cfg(test)] +mod tests { + use super::HistoricalStateProvider; + + const BLOCK_LIST: [u64; 5] = [1, 2, 5, 6, 10]; + + #[rstest::rstest] + #[case(0, None)] + #[case(1, Some(1))] + #[case(3, Some(2))] + #[case(5, Some(5))] + #[case(9, Some(6))] + #[case(10, Some(10))] + #[case(11, Some(10))] + fn position_of_most_recent_block_in_block_list( + #[case] block_num: u64, + #[case] expected_block_num: Option, + ) { + assert_eq!( + HistoricalStateProvider::recent_block_change_relative_to_pinned_block_num( + block_num, + &BLOCK_LIST, + ), + expected_block_num + ); + } +} diff --git a/crates/katana/storage/provider/src/providers/mod.rs b/crates/katana/storage/provider/src/providers/mod.rs index 0c2dc27d17..d565e76a9a 100644 --- a/crates/katana/storage/provider/src/providers/mod.rs +++ b/crates/katana/storage/provider/src/providers/mod.rs @@ -1,3 +1,4 @@ +pub mod db; #[cfg(feature = "fork")] pub mod fork; #[cfg(feature = "in-memory")] diff --git a/crates/katana/storage/provider/src/traits/transaction.rs b/crates/katana/storage/provider/src/traits/transaction.rs index 8023d86568..c7dbee6be1 100644 --- a/crates/katana/storage/provider/src/traits/transaction.rs +++ b/crates/katana/storage/provider/src/traits/transaction.rs @@ -29,6 +29,11 @@ pub trait TransactionProvider: Send + Sync { &self, hash: TxHash, ) -> Result>; + + /// Retrieves all the transactions at the given range. + fn transaction_in_range(&self, _range: Range) -> Result> { + todo!() + } } #[auto_impl::auto_impl(&, Box, Arc)] diff --git a/crates/katana/storage/provider/tests/block.rs b/crates/katana/storage/provider/tests/block.rs index 03b8a4cecc..e62efa89e6 100644 --- a/crates/katana/storage/provider/tests/block.rs +++ b/crates/katana/storage/provider/tests/block.rs @@ -1,15 +1,29 @@ -use katana_primitives::block::BlockHashOrNumber; +use anyhow::Result; +use katana_primitives::block::{ + Block, BlockHashOrNumber, BlockNumber, BlockWithTxHashes, FinalityStatus, +}; +use katana_primitives::state::StateUpdates; +use katana_provider::providers::db::DbProvider; use katana_provider::providers::fork::ForkedProvider; use katana_provider::providers::in_memory::InMemoryProvider; -use katana_provider::traits::block::{BlockProvider, BlockWriter}; -use katana_provider::traits::transaction::{ReceiptProvider, TransactionProvider}; +use katana_provider::traits::block::{ + BlockHashProvider, BlockProvider, BlockStatusProvider, BlockWriter, +}; +use katana_provider::traits::state::StateRootProvider; +use katana_provider::traits::state_update::StateUpdateProvider; +use katana_provider::traits::transaction::{ + ReceiptProvider, TransactionProvider, TransactionStatusProvider, +}; use katana_provider::BlockchainProvider; use rstest_reuse::{self, *}; mod fixtures; mod utils; -use fixtures::{fork_provider, in_memory_provider}; +use fixtures::{ + db_provider, fork_provider, fork_provider_with_spawned_fork_network, in_memory_provider, + mock_state_updates, provider_with_states, +}; use utils::generate_dummy_blocks_and_receipts; #[template] @@ -24,7 +38,7 @@ fn insert_block_cases(#[case] block_count: u64) {} fn insert_block_with_in_memory_provider( #[from(in_memory_provider)] provider: BlockchainProvider, #[case] block_count: u64, -) -> anyhow::Result<()> { +) -> Result<()> { insert_block_test_impl(provider, block_count) } @@ -32,13 +46,25 @@ fn insert_block_with_in_memory_provider( fn insert_block_with_fork_provider( #[from(fork_provider)] provider: BlockchainProvider, #[case] block_count: u64, -) -> anyhow::Result<()> { +) -> Result<()> { insert_block_test_impl(provider, block_count) } -fn insert_block_test_impl(provider: BlockchainProvider, count: u64) -> anyhow::Result<()> +#[apply(insert_block_cases)] +fn insert_block_with_db_provider( + #[from(db_provider)] provider: BlockchainProvider, + #[case] block_count: u64, +) -> Result<()> { + insert_block_test_impl(provider, block_count) +} + +fn insert_block_test_impl(provider: BlockchainProvider, count: u64) -> Result<()> where - Db: BlockProvider + BlockWriter + ReceiptProvider, + Db: BlockProvider + + BlockWriter + + ReceiptProvider + + StateRootProvider + + TransactionStatusProvider, { let blocks = generate_dummy_blocks_and_receipts(count); @@ -50,23 +76,120 @@ where )?; } + let actual_blocks_in_range = provider.blocks_in_range(0..=count)?; + + assert_eq!(actual_blocks_in_range.len(), count as usize); + assert_eq!( + actual_blocks_in_range, + blocks.clone().into_iter().map(|b| b.0.block.unseal()).collect::>() + ); + for (block, receipts) in blocks { let block_id = BlockHashOrNumber::Hash(block.block.header.hash); + + let expected_block_num = block.block.header.header.number; + let expected_block_hash = block.block.header.hash; let expected_block = block.block.unseal(); + let actual_block_hash = provider.block_hash_by_num(expected_block_num)?; + let actual_block = provider.block(block_id)?; let actual_block_txs = provider.transactions_by_block(block_id)?; - let actual_block_tx_count = provider.transaction_count_by_block(block_id)?; + let actual_status = provider.block_status(block_id)?; + let actual_state_root = provider.state_root(block_id)?; + let actual_block_tx_count = provider.transaction_count_by_block(block_id)?; let actual_receipts = provider.receipts_by_block(block_id)?; + let expected_block_with_tx_hashes = BlockWithTxHashes { + header: expected_block.header.clone(), + body: expected_block.body.clone().into_iter().map(|t| t.hash).collect(), + }; + + let actual_block_with_tx_hashes = provider.block_with_tx_hashes(block_id)?; + + assert_eq!(actual_status, Some(FinalityStatus::AcceptedOnL2)); + assert_eq!(actual_block_with_tx_hashes, Some(expected_block_with_tx_hashes)); + + for (idx, tx) in expected_block.body.iter().enumerate() { + let actual_receipt = provider.receipt_by_hash(tx.hash)?; + let actual_tx = provider.transaction_by_hash(tx.hash)?; + let actual_tx_status = provider.transaction_status(tx.hash)?; + let actual_tx_block_num_hash = provider.transaction_block_num_and_hash(tx.hash)?; + let actual_tx_by_block_idx = + provider.transaction_by_block_and_idx(block_id, idx as u64)?; + + assert_eq!(actual_tx_block_num_hash, Some((expected_block_num, expected_block_hash))); + assert_eq!(actual_tx_status, Some(FinalityStatus::AcceptedOnL2)); + assert_eq!(actual_receipt, Some(receipts[idx].clone())); + assert_eq!(actual_tx_by_block_idx, Some(tx.clone())); + assert_eq!(actual_tx, Some(tx.clone())); + } + assert_eq!(actual_receipts.as_ref().map(|r| r.len()), Some(expected_block.body.len())); assert_eq!(actual_receipts, Some(receipts)); assert_eq!(actual_block_tx_count, Some(expected_block.body.len() as u64)); + assert_eq!(actual_state_root, Some(expected_block.header.state_root)); assert_eq!(actual_block_txs, Some(expected_block.body.clone())); + assert_eq!(actual_block_hash, Some(expected_block_hash)); assert_eq!(actual_block, Some(expected_block)); } Ok(()) } + +#[template] +#[rstest::rstest] +#[case::state_update_at_block_1(1, mock_state_updates().0)] +#[case::state_update_at_block_2(2, mock_state_updates().1)] +#[case::state_update_at_block_3(3, StateUpdates::default())] +#[case::state_update_at_block_5(5, mock_state_updates().2)] +fn test_read_state_update( + #[from(provider_with_states)] provider: BlockchainProvider, + #[case] block_num: BlockNumber, + #[case] expected_state_update: StateUpdates, +) { +} + +#[apply(test_read_state_update)] +fn test_read_state_update_with_in_memory_provider( + #[with(in_memory_provider())] provider: BlockchainProvider, + #[case] block_num: BlockNumber, + #[case] expected_state_update: StateUpdates, +) -> Result<()> { + test_read_state_update_impl(provider, block_num, expected_state_update) +} + +#[apply(test_read_state_update)] +fn test_read_state_update_with_fork_provider( + #[with(fork_provider_with_spawned_fork_network::default())] provider: BlockchainProvider< + ForkedProvider, + >, + #[case] block_num: BlockNumber, + #[case] expected_state_update: StateUpdates, +) -> Result<()> { + test_read_state_update_impl(provider, block_num, expected_state_update) +} + +#[apply(test_read_state_update)] +fn test_read_state_update_with_db_provider( + #[with(db_provider())] provider: BlockchainProvider, + #[case] block_num: BlockNumber, + #[case] expected_state_update: StateUpdates, +) -> Result<()> { + test_read_state_update_impl(provider, block_num, expected_state_update) +} + +fn test_read_state_update_impl( + provider: BlockchainProvider, + block_num: BlockNumber, + expected_state_update: StateUpdates, +) -> Result<()> +where + Db: StateUpdateProvider, +{ + let actual_state_update = provider.state_update(BlockHashOrNumber::from(block_num))?; + assert_eq!(actual_state_update, Some(expected_state_update)); + Ok(()) +} diff --git a/crates/katana/storage/provider/tests/class.rs b/crates/katana/storage/provider/tests/class.rs index 68a2959843..58f399fc0f 100644 --- a/crates/katana/storage/provider/tests/class.rs +++ b/crates/katana/storage/provider/tests/class.rs @@ -23,7 +23,10 @@ fn assert_state_provider_class( } mod latest { + use katana_provider::providers::db::DbProvider; + use super::*; + use crate::fixtures::db_provider; fn assert_latest_class( provider: BlockchainProvider, @@ -65,10 +68,21 @@ mod latest { ) -> Result<()> { assert_latest_class(provider, expected_class) } + + #[apply(test_latest_class_read)] + fn read_class_from_db_provider( + #[with(db_provider())] provider: BlockchainProvider, + #[case] expected_class: Vec<(ClassHash, Option)>, + ) -> Result<()> { + assert_latest_class(provider, expected_class) + } } mod historical { + use katana_provider::providers::db::DbProvider; + use super::*; + use crate::fixtures::db_provider; fn assert_historical_class( provider: BlockchainProvider, @@ -143,4 +157,13 @@ mod historical { ) -> Result<()> { assert_historical_class(provider, block_num, expected_class) } + + #[apply(test_historical_class_read)] + fn read_class_from_db_provider( + #[with(db_provider())] provider: BlockchainProvider, + #[case] block_num: BlockNumber, + #[case] expected_class: Vec<(ClassHash, Option)>, + ) -> Result<()> { + assert_historical_class(provider, block_num, expected_class) + } } diff --git a/crates/katana/storage/provider/tests/contract.rs b/crates/katana/storage/provider/tests/contract.rs index 3b54fdff78..57a6f428fd 100644 --- a/crates/katana/storage/provider/tests/contract.rs +++ b/crates/katana/storage/provider/tests/contract.rs @@ -1,7 +1,9 @@ mod fixtures; use anyhow::Result; -use fixtures::{fork_provider_with_spawned_fork_network, in_memory_provider, provider_with_states}; +use fixtures::{ + db_provider, fork_provider_with_spawned_fork_network, in_memory_provider, provider_with_states, +}; use katana_primitives::block::{BlockHashOrNumber, BlockNumber}; use katana_primitives::contract::{ClassHash, ContractAddress, Nonce}; use katana_provider::providers::fork::ForkedProvider; @@ -27,6 +29,8 @@ fn assert_state_provider_contract_info( } mod latest { + use katana_provider::providers::db::DbProvider; + use super::*; fn assert_latest_contract_info( @@ -68,9 +72,19 @@ mod latest { ) -> Result<()> { assert_latest_contract_info(provider, expected_contract_info) } + + #[apply(test_latest_contract_info_read)] + fn read_storage_from_db_provider( + #[with(db_provider())] provider: BlockchainProvider, + #[case] expected_contract_info: Vec<(ContractAddress, Option, Option)>, + ) -> Result<()> { + assert_latest_contract_info(provider, expected_contract_info) + } } mod historical { + use katana_provider::providers::db::DbProvider; + use super::*; fn assert_historical_contract_info( @@ -142,4 +156,13 @@ mod historical { ) -> Result<()> { assert_historical_contract_info(provider, block_num, expected_contract_info) } + + #[apply(test_historical_storage_read)] + fn read_storage_from_db_provider( + #[with(db_provider())] provider: BlockchainProvider, + #[case] block_num: BlockNumber, + #[case] expected_contract_info: Vec<(ContractAddress, Option, Option)>, + ) -> Result<()> { + assert_historical_contract_info(provider, block_num, expected_contract_info) + } } diff --git a/crates/katana/storage/provider/tests/fixtures.rs b/crates/katana/storage/provider/tests/fixtures.rs index 4021ab3672..3f0a685001 100644 --- a/crates/katana/storage/provider/tests/fixtures.rs +++ b/crates/katana/storage/provider/tests/fixtures.rs @@ -1,11 +1,13 @@ use std::collections::HashMap; use std::sync::Arc; +use katana_db::mdbx; use katana_primitives::block::{ BlockHashOrNumber, FinalityStatus, Header, SealedBlock, SealedBlockWithStatus, SealedHeader, }; use katana_primitives::contract::ContractAddress; use katana_primitives::state::{StateUpdates, StateUpdatesWithDeclaredClasses}; +use katana_provider::providers::db::DbProvider; use katana_provider::providers::fork::ForkedProvider; use katana_provider::providers::in_memory::InMemoryProvider; use katana_provider::traits::block::BlockWriter; @@ -50,13 +52,13 @@ pub fn fork_provider_with_spawned_fork_network( } #[rstest::fixture] -#[default(BlockchainProvider)] -pub fn provider_with_states( - #[default(in_memory_provider())] provider: BlockchainProvider, -) -> BlockchainProvider -where - Db: BlockWriter + StateFactoryProvider, -{ +pub fn db_provider() -> BlockchainProvider { + let env = mdbx::test_utils::create_test_db(mdbx::DbEnvKind::RW); + BlockchainProvider::new(DbProvider::new(env)) +} + +#[rstest::fixture] +pub fn mock_state_updates() -> (StateUpdates, StateUpdates, StateUpdates) { let address_1 = ContractAddress::from(felt!("1")); let address_2 = ContractAddress::from(felt!("2")); @@ -69,54 +71,55 @@ where let class_hash_3 = felt!("33"); let compiled_class_hash_3 = felt!("3000"); - let state_update_at_block_1 = StateUpdatesWithDeclaredClasses { - state_updates: StateUpdates { - nonce_updates: HashMap::from([(address_1, 1u8.into()), (address_2, 1u8.into())]), - storage_updates: HashMap::from([ - ( - address_1, - HashMap::from([(1u8.into(), 100u32.into()), (2u8.into(), 101u32.into())]), - ), - ( - address_2, - HashMap::from([(1u8.into(), 200u32.into()), (2u8.into(), 201u32.into())]), - ), - ]), - declared_classes: HashMap::from([(class_hash_1, compiled_class_hash_1)]), - contract_updates: HashMap::from([(address_1, class_hash_1), (address_2, class_hash_1)]), - }, - ..Default::default() + let state_update_1 = StateUpdates { + nonce_updates: HashMap::from([(address_1, 1u8.into()), (address_2, 1u8.into())]), + storage_updates: HashMap::from([ + (address_1, HashMap::from([(1u8.into(), 100u32.into()), (2u8.into(), 101u32.into())])), + (address_2, HashMap::from([(1u8.into(), 200u32.into()), (2u8.into(), 201u32.into())])), + ]), + declared_classes: HashMap::from([(class_hash_1, compiled_class_hash_1)]), + contract_updates: HashMap::from([(address_1, class_hash_1), (address_2, class_hash_1)]), }; - let state_update_at_block_2 = StateUpdatesWithDeclaredClasses { - state_updates: StateUpdates { - nonce_updates: HashMap::from([(address_1, 2u8.into())]), - storage_updates: HashMap::from([( - address_1, - HashMap::from([(felt!("1"), felt!("111")), (felt!("2"), felt!("222"))]), - )]), - declared_classes: HashMap::from([(class_hash_2, compiled_class_hash_2)]), - contract_updates: HashMap::from([(address_2, class_hash_2)]), - }, - ..Default::default() + let state_update_2 = StateUpdates { + nonce_updates: HashMap::from([(address_1, 2u8.into())]), + storage_updates: HashMap::from([( + address_1, + HashMap::from([(felt!("1"), felt!("111")), (felt!("2"), felt!("222"))]), + )]), + declared_classes: HashMap::from([(class_hash_2, compiled_class_hash_2)]), + contract_updates: HashMap::from([(address_2, class_hash_2)]), }; - let state_update_at_block_5 = StateUpdatesWithDeclaredClasses { - state_updates: StateUpdates { - nonce_updates: HashMap::from([(address_1, 3u8.into()), (address_2, 2u8.into())]), - storage_updates: HashMap::from([ - (address_1, HashMap::from([(3u8.into(), 77u32.into())])), - ( - address_2, - HashMap::from([(1u8.into(), 12u32.into()), (2u8.into(), 13u32.into())]), - ), - ]), - contract_updates: HashMap::from([(address_1, class_hash_2), (address_2, class_hash_3)]), - declared_classes: HashMap::from([(class_hash_3, compiled_class_hash_3)]), - }, - ..Default::default() + let state_update_3 = StateUpdates { + nonce_updates: HashMap::from([(address_1, 3u8.into()), (address_2, 2u8.into())]), + storage_updates: HashMap::from([ + (address_1, HashMap::from([(3u8.into(), 77u32.into())])), + (address_2, HashMap::from([(1u8.into(), 12u32.into()), (2u8.into(), 13u32.into())])), + ]), + contract_updates: HashMap::from([(address_1, class_hash_2), (address_2, class_hash_3)]), + declared_classes: HashMap::from([(class_hash_3, compiled_class_hash_3)]), }; + (state_update_1, state_update_2, state_update_3) +} + +#[rstest::fixture] +#[default(BlockchainProvider)] +pub fn provider_with_states( + #[default(in_memory_provider())] provider: BlockchainProvider, + #[from(mock_state_updates)] state_updates: (StateUpdates, StateUpdates, StateUpdates), +) -> BlockchainProvider +where + Db: BlockWriter + StateFactoryProvider, +{ + let state_update_at_block_1 = + StateUpdatesWithDeclaredClasses { state_updates: state_updates.0, ..Default::default() }; + let state_update_at_block_2 = + StateUpdatesWithDeclaredClasses { state_updates: state_updates.1, ..Default::default() }; + let state_update_at_block_5 = + StateUpdatesWithDeclaredClasses { state_updates: state_updates.2, ..Default::default() }; + // Fill provider with states. for i in 0..=5 { diff --git a/crates/katana/storage/provider/tests/storage.rs b/crates/katana/storage/provider/tests/storage.rs index 4f12e8f5d8..07afc2e457 100644 --- a/crates/katana/storage/provider/tests/storage.rs +++ b/crates/katana/storage/provider/tests/storage.rs @@ -23,7 +23,10 @@ fn assert_state_provider_storage( } mod latest { + use katana_provider::providers::db::DbProvider; + use super::*; + use crate::fixtures::db_provider; fn assert_latest_storage_value( provider: BlockchainProvider, @@ -67,10 +70,21 @@ mod latest { ) -> Result<()> { assert_latest_storage_value(provider, expected_storage_entry) } + + #[apply(test_latest_storage_read)] + fn read_storage_from_db_provider( + #[with(db_provider())] provider: BlockchainProvider, + #[case] expected_storage_entry: Vec<(ContractAddress, StorageKey, Option)>, + ) -> Result<()> { + assert_latest_storage_value(provider, expected_storage_entry) + } } mod historical { + use katana_provider::providers::db::DbProvider; + use super::*; + use crate::fixtures::db_provider; fn assert_historical_storage_value( provider: BlockchainProvider, @@ -150,4 +164,13 @@ mod historical { ) -> Result<()> { assert_historical_storage_value(provider, block_num, expected_storage_entry) } + + #[apply(test_historical_storage_read)] + fn read_storage_from_db_provider( + #[with(db_provider())] provider: BlockchainProvider, + #[case] block_num: BlockNumber, + #[case] expected_storage_entry: Vec<(ContractAddress, StorageKey, Option)>, + ) -> Result<()> { + assert_historical_storage_value(provider, block_num, expected_storage_entry) + } }