Skip to content

Commit

Permalink
refactor(katana): change stored compiled class (#1559)
Browse files Browse the repository at this point in the history
Both `blockifier` and `starknet_in_rust` use their own proprietary contract class types. So need to have a type where the types from those two crates can be derived from.
  • Loading branch information
kariy committed Mar 7, 2024
1 parent 8bfd5ad commit e4e982a
Show file tree
Hide file tree
Showing 37 changed files with 395 additions and 984 deletions.
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions crates/katana/core/src/backend/contract.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use blockifier::execution::contract_class::ContractClassV0;
use katana_primitives::contract::DeprecatedCompiledClass;
use starknet::core::types::FlattenedSierraClass;

pub enum StarknetContract {
Legacy(ContractClassV0),
Legacy(DeprecatedCompiledClass),
Sierra(FlattenedSierraClass),
}
6 changes: 3 additions & 3 deletions crates/katana/core/src/sequencer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use katana_executor::blockifier::PendingState;
use katana_primitives::block::{BlockHash, BlockHashOrNumber, BlockIdOrTag, BlockNumber};
use katana_primitives::chain::ChainId;
use katana_primitives::contract::{
ClassHash, CompiledContractClass, ContractAddress, Nonce, StorageKey, StorageValue,
ClassHash, CompiledClass, ContractAddress, Nonce, StorageKey, StorageValue,
};
use katana_primitives::event::{ContinuationToken, ContinuationTokenError};
use katana_primitives::receipt::Event;
Expand Down Expand Up @@ -244,8 +244,8 @@ impl KatanaSequencer {
};

match class {
CompiledContractClass::V0(class) => Ok(Some(StarknetContract::Legacy(class))),
CompiledContractClass::V1(_) => {
CompiledClass::Deprecated(class) => Ok(Some(StarknetContract::Legacy(class))),
CompiledClass::Class(_) => {
let class = ContractClassProvider::sierra_class(&state, class_hash)?
.map(StarknetContract::Sierra);
Ok(class)
Expand Down
2 changes: 1 addition & 1 deletion crates/katana/core/src/service/block_producer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ impl IntervalBlockProducer {
trace!(target: "miner", "created new block: {}", outcome.block_number);

backend.update_block_env(&mut block_env);
pending_state.reset_state(StateRefDb(new_state), block_env, cfg_env);
pending_state.reset_state(new_state, block_env, cfg_env);

Ok(outcome)
}
Expand Down
4 changes: 2 additions & 2 deletions crates/katana/executor/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ version.workspace = true
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html

[dependencies]
katana-primitives = { path = "../primitives" }
katana-provider = { path = "../storage/provider" }
katana-primitives.workspace = true
katana-provider.workspace = true

anyhow.workspace = true
convert_case.workspace = true
Expand Down
27 changes: 13 additions & 14 deletions crates/katana/executor/src/blockifier/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,9 +11,8 @@ use blockifier::transaction::objects::TransactionExecutionInfo;
use blockifier::transaction::transaction_execution::Transaction;
use blockifier::transaction::transactions::ExecutableTransaction;
use katana_primitives::env::{BlockEnv, CfgEnv};
use katana_primitives::transaction::{
DeclareTxWithClass, ExecutableTx, ExecutableTxWithHash, TxWithHash,
};
use katana_primitives::transaction::{ExecutableTx, ExecutableTxWithHash, TxWithHash};
use katana_provider::traits::state::StateProvider;
use parking_lot::RwLock;
use tracing::{trace, warn};

Expand Down Expand Up @@ -145,13 +144,9 @@ fn execute_tx(
charge_fee: bool,
validate: bool,
) -> TxExecutionResult {
let sierra = if let ExecutableTx::Declare(DeclareTxWithClass {
transaction,
sierra_class: Some(sierra_class),
..
}) = tx.as_ref()
{
Some((transaction.class_hash(), sierra_class.clone()))
let class_declaration_params = if let ExecutableTx::Declare(tx) = tx.as_ref() {
let class_hash = tx.class_hash();
Some((class_hash, tx.compiled_class.clone(), tx.sierra_class.clone()))
} else {
None
};
Expand All @@ -166,8 +161,12 @@ fn execute_tx(
};

if res.is_ok() {
if let Some((class_hash, sierra_class)) = sierra {
state.sierra_class_mut().insert(class_hash, sierra_class);
if let Some((class_hash, compiled_class, sierra_class)) = class_declaration_params {
state.class_cache.write().compiled.insert(class_hash, compiled_class);

if let Some(sierra_class) = sierra_class {
state.class_cache.write().sierra.insert(class_hash, sierra_class);
}
}
}

Expand Down Expand Up @@ -198,9 +197,9 @@ impl PendingState {
}
}

pub fn reset_state(&self, state: StateRefDb, block_env: BlockEnv, cfg_env: CfgEnv) {
pub fn reset_state(&self, state: Box<dyn StateProvider>, block_env: BlockEnv, cfg_env: CfgEnv) {
*self.block_envs.write() = (block_env, cfg_env);
self.state.reset_with_new_state(state);
self.state.reset_with_new_state(StateRefDb(state));
}

pub fn add_executed_txs(&self, transactions: Vec<(TxWithHash, TxExecutionResult)>) {
Expand Down
123 changes: 58 additions & 65 deletions crates/katana/executor/src/blockifier/state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@ use std::collections::HashMap;

use blockifier::state::cached_state::{CachedState, GlobalContractCache};
use blockifier::state::errors::StateError;
use blockifier::state::state_api::StateReader;
use katana_primitives::contract::FlattenedSierraClass;
use blockifier::state::state_api::{StateReader, StateResult};
use katana_primitives::contract::{CompiledClass, FlattenedSierraClass};
use katana_primitives::conversion::blockifier::to_class;
use katana_primitives::FieldElement;
use katana_provider::traits::contract::ContractClassProvider;
use katana_provider::traits::state::StateProvider;
Expand All @@ -14,6 +15,12 @@ use starknet_api::hash::StarkHash;
use starknet_api::patricia_key;
use starknet_api::state::StorageKey;

mod primitives {
pub use katana_primitives::contract::{
ClassHash, CompiledClassHash, ContractAddress, Nonce, StorageKey, StorageValue,
};
}

/// A state db only provide read access.
///
/// This type implements the [`StateReader`] trait so that it can be used as a with [`CachedState`].
Expand All @@ -25,58 +32,55 @@ impl StateRefDb {
}
}

impl ContractClassProvider for StateRefDb {
fn class(
impl StateProvider for StateRefDb {
fn class_hash_of_contract(
&self,
hash: katana_primitives::contract::ClassHash,
) -> ProviderResult<Option<katana_primitives::contract::CompiledContractClass>> {
self.0.class(hash)
address: primitives::ContractAddress,
) -> ProviderResult<Option<primitives::ClassHash>> {
self.0.class_hash_of_contract(address)
}

fn compiled_class_hash_of_class_hash(
fn nonce(
&self,
hash: katana_primitives::contract::ClassHash,
) -> ProviderResult<Option<katana_primitives::contract::CompiledClassHash>> {
self.0.compiled_class_hash_of_class_hash(hash)
address: primitives::ContractAddress,
) -> ProviderResult<Option<primitives::Nonce>> {
self.0.nonce(address)
}

fn sierra_class(
fn storage(
&self,
hash: katana_primitives::contract::ClassHash,
) -> ProviderResult<Option<FlattenedSierraClass>> {
self.0.sierra_class(hash)
address: primitives::ContractAddress,
storage_key: primitives::StorageKey,
) -> ProviderResult<Option<primitives::StorageValue>> {
self.0.storage(address, storage_key)
}
}

impl StateProvider for StateRefDb {
fn nonce(
&self,
address: katana_primitives::contract::ContractAddress,
) -> ProviderResult<Option<katana_primitives::contract::Nonce>> {
self.0.nonce(address)
impl ContractClassProvider for StateRefDb {
fn class(&self, hash: primitives::ClassHash) -> ProviderResult<Option<CompiledClass>> {
self.0.class(hash)
}

fn class_hash_of_contract(
fn compiled_class_hash_of_class_hash(
&self,
address: katana_primitives::contract::ContractAddress,
) -> ProviderResult<Option<katana_primitives::contract::ClassHash>> {
self.0.class_hash_of_contract(address)
hash: primitives::ClassHash,
) -> ProviderResult<Option<primitives::CompiledClassHash>> {
self.0.compiled_class_hash_of_class_hash(hash)
}

fn storage(
fn sierra_class(
&self,
address: katana_primitives::contract::ContractAddress,
storage_key: katana_primitives::contract::StorageKey,
) -> ProviderResult<Option<katana_primitives::contract::StorageValue>> {
self.0.storage(address, storage_key)
hash: primitives::ClassHash,
) -> ProviderResult<Option<FlattenedSierraClass>> {
self.0.sierra_class(hash)
}
}

impl StateReader for StateRefDb {
fn get_nonce_at(
&mut self,
contract_address: starknet_api::core::ContractAddress,
) -> blockifier::state::state_api::StateResult<Nonce> {
) -> StateResult<Nonce> {
StateProvider::nonce(&self.0, contract_address.into())
.map(|n| Nonce(n.unwrap_or_default().into()))
.map_err(|e| StateError::StateReadError(e.to_string()))
Expand All @@ -86,7 +90,7 @@ impl StateReader for StateRefDb {
&mut self,
contract_address: starknet_api::core::ContractAddress,
key: starknet_api::state::StorageKey,
) -> blockifier::state::state_api::StateResult<starknet_api::hash::StarkFelt> {
) -> StateResult<starknet_api::hash::StarkFelt> {
StateProvider::storage(&self.0, contract_address.into(), (*key.0.key()).into())
.map(|v| v.unwrap_or_default().into())
.map_err(|e| StateError::StateReadError(e.to_string()))
Expand All @@ -95,7 +99,7 @@ impl StateReader for StateRefDb {
fn get_class_hash_at(
&mut self,
contract_address: starknet_api::core::ContractAddress,
) -> blockifier::state::state_api::StateResult<starknet_api::core::ClassHash> {
) -> StateResult<starknet_api::core::ClassHash> {
StateProvider::class_hash_of_contract(&self.0, contract_address.into())
.map(|v| ClassHash(v.unwrap_or_default().into()))
.map_err(|e| StateError::StateReadError(e.to_string()))
Expand All @@ -104,7 +108,7 @@ impl StateReader for StateRefDb {
fn get_compiled_class_hash(
&mut self,
class_hash: starknet_api::core::ClassHash,
) -> blockifier::state::state_api::StateResult<starknet_api::core::CompiledClassHash> {
) -> StateResult<starknet_api::core::CompiledClassHash> {
if let Some(hash) =
ContractClassProvider::compiled_class_hash_of_class_hash(&self.0, class_hash.0.into())
.map_err(|e| StateError::StateReadError(e.to_string()))?
Expand All @@ -118,71 +122,60 @@ impl StateReader for StateRefDb {
fn get_compiled_contract_class(
&mut self,
class_hash: starknet_api::core::ClassHash,
) -> blockifier::state::state_api::StateResult<
blockifier::execution::contract_class::ContractClass,
> {
) -> StateResult<blockifier::execution::contract_class::ContractClass> {
if let Some(class) = ContractClassProvider::class(&self.0, class_hash.0.into())
.map_err(|e| StateError::StateReadError(e.to_string()))?
{
Ok(class)
to_class(class).map_err(|e| StateError::StateReadError(e.to_string()))
} else {
Err(StateError::UndeclaredClassHash(class_hash))
}
}
}

#[derive(Default)]
pub struct ClassCache {
pub(crate) compiled: HashMap<primitives::ClassHash, CompiledClass>,
pub(crate) sierra: HashMap<primitives::ClassHash, FlattenedSierraClass>,
}

pub struct CachedStateWrapper {
inner: Mutex<CachedState<StateRefDb>>,
sierra_class: RwLock<HashMap<katana_primitives::contract::ClassHash, FlattenedSierraClass>>,
pub(crate) class_cache: RwLock<ClassCache>,
}

impl CachedStateWrapper {
pub fn new(db: StateRefDb) -> Self {
Self {
sierra_class: Default::default(),
class_cache: RwLock::new(ClassCache::default()),
inner: Mutex::new(CachedState::new(db, GlobalContractCache::default())),
}
}

pub(super) fn reset_with_new_state(&self, db: StateRefDb) {
*self.inner() = CachedState::new(db, GlobalContractCache::default());
self.sierra_class_mut().clear();
let mut lock = self.class_cache.write();
lock.compiled.clear();
lock.sierra.clear();
}

pub fn inner(
&self,
) -> parking_lot::lock_api::MutexGuard<'_, RawMutex, CachedState<StateRefDb>> {
self.inner.lock()
}

pub fn sierra_class(
&self,
) -> parking_lot::RwLockReadGuard<
'_,
HashMap<katana_primitives::contract::ClassHash, FlattenedSierraClass>,
> {
self.sierra_class.read()
}

pub fn sierra_class_mut(
&self,
) -> parking_lot::RwLockWriteGuard<
'_,
HashMap<katana_primitives::contract::ClassHash, FlattenedSierraClass>,
> {
self.sierra_class.write()
}
}

impl ContractClassProvider for CachedStateWrapper {
fn class(
&self,
hash: katana_primitives::contract::ClassHash,
) -> ProviderResult<Option<katana_primitives::contract::CompiledContractClass>> {
let Ok(class) = self.inner().get_compiled_contract_class(ClassHash(hash.into())) else {
return Ok(None);
};
Ok(Some(class))
) -> ProviderResult<Option<CompiledClass>> {
if let res @ Some(_) = self.class_cache.read().compiled.get(&hash).cloned() {
Ok(res)
} else {
self.inner().state.class(hash)
}
}

fn compiled_class_hash_of_class_hash(
Expand All @@ -199,7 +192,7 @@ impl ContractClassProvider for CachedStateWrapper {
&self,
hash: katana_primitives::contract::ClassHash,
) -> ProviderResult<Option<FlattenedSierraClass>> {
if let Some(class) = self.sierra_class().get(&hash) {
if let Some(class) = self.class_cache.read().sierra.get(&hash) {
Ok(Some(class.clone()))
} else {
self.inner.lock().state.0.sierra_class(hash)
Expand Down
9 changes: 7 additions & 2 deletions crates/katana/executor/src/blockifier/transactions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use ::blockifier::transaction::transaction_execution::Transaction;
use ::blockifier::transaction::transactions::{DeployAccountTransaction, InvokeTransaction};
use blockifier::transaction::account_transaction::AccountTransaction;
use blockifier::transaction::transactions::{DeclareTransaction, L1HandlerTransaction};
use katana_primitives::conversion::blockifier::to_class;
use katana_primitives::transaction::{
DeclareTx, DeployAccountTx, ExecutableTx, ExecutableTxWithHash, InvokeTx,
};
Expand Down Expand Up @@ -191,8 +192,12 @@ impl From<ExecutableTxWithHash> for BlockifierTx {
}
};

let tx = DeclareTransaction::new(tx, TransactionHash(hash.into()), contract_class)
.expect("class mismatch");
let tx = DeclareTransaction::new(
tx,
TransactionHash(hash.into()),
to_class(contract_class).unwrap(),
)
.expect("class mismatch");
Transaction::AccountTransaction(AccountTransaction::Declare(tx))
}

Expand Down
Loading

0 comments on commit e4e982a

Please sign in to comment.