Skip to content

Commit

Permalink
fix(katana): invalid trie path conversion (#2844)
Browse files Browse the repository at this point in the history
fix trie path rpc conversion
  • Loading branch information
kariy authored Dec 25, 2024
1 parent 3832eca commit 3691369
Show file tree
Hide file tree
Showing 6 changed files with 215 additions and 87 deletions.
85 changes: 62 additions & 23 deletions crates/katana/rpc/rpc-types/src/trie.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use std::ops::{Deref, DerefMut};

use katana_primitives::contract::StorageKey;
use katana_primitives::hash::StarkHash;
use katana_primitives::{ContractAddress, Felt};
use katana_trie::bonsai::BitSlice;
use katana_trie::{MultiProof, Path, ProofNode};
use katana_trie::bitvec::view::BitView;
use katana_trie::{BitVec, MultiProof, Path, ProofNode};
use serde::{Deserialize, Serialize};

#[derive(Debug, Default, Serialize, Deserialize)]
Expand All @@ -24,7 +25,7 @@ pub struct GlobalRoots {
}

/// Node in the Merkle-Patricia trie.
#[derive(Debug, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum MerkleNode {
/// Represents a path to the highest non-zero descendant node.
Expand All @@ -47,6 +48,21 @@ pub enum MerkleNode {
},
}

impl MerkleNode {
// Taken from `bonsai-trie`: https://github.com/madara-alliance/bonsai-trie/blob/bfc6ad47b3cb8b75b1326bf630ca16e581f194c5/src/trie/merkle_node.rs#L234-L248
pub fn compute_hash<Hash: StarkHash>(&self) -> Felt {
match self {
Self::Binary { left, right } => Hash::hash(left, right),
Self::Edge { child, path, length } => {
let mut length_bytes = [0u8; 32];
length_bytes[31] = *length;
let length = Felt::from_bytes_be(&length_bytes);
Hash::hash(child, path) + length
}
}
}
}

/// The response type for `starknet_getStorageProof` method.
///
/// The requested storage proofs. Note that if a requested leaf has the default value, the path to
Expand Down Expand Up @@ -142,40 +158,63 @@ impl From<MerkleNode> for ProofNode {
fn from(value: MerkleNode) -> Self {
match value {
MerkleNode::Binary { left, right } => Self::Binary { left, right },
MerkleNode::Edge { path, child, .. } => Self::Edge { child, path: felt_to_path(path) },
MerkleNode::Edge { path, child, length } => {
Self::Edge { child, path: felt_to_path(path, length) }
}
}
}
}

fn felt_to_path(felt: Felt) -> Path {
Path(BitSlice::from_slice(&felt.to_bytes_be())[5..].to_bitvec())
fn felt_to_path(felt: Felt, length: u8) -> Path {
let length = length as usize;
let mut bits = BitVec::new();

// This function converts a Felt to a Path by preserving leading zeros
// that are semantically important in the Merkle tree path representation.
//
// Example:
// For a path "0000100" (length=7):
// - As an integer/hex: 0x4 (leading zeros get truncated)
// - As a Path: [0,0,0,0,1,0,0] (leading zeros preserved)
//
// We need to preserve these leading zeros because in a Merkle tree path:
// - Each bit represents a direction (left=0, right=1)
// - The position/index of each bit matters for the path traversal
// - "0000100" and "100" would represent different paths in the tree
for bit in &felt.to_bits_be()[256 - length..] {
bits.push(*bit);
}

Path(bits)
}

fn path_to_felt(path: Path) -> Felt {
let mut arr = [0u8; 32];
let slice = &mut BitSlice::from_slice_mut(&mut arr)[5..];
slice[..path.len()].copy_from_bitslice(&path);
Felt::from_bytes_be(&arr)
let mut bytes = [0u8; 32];
bytes.view_bits_mut()[256 - path.len()..].copy_from_bitslice(&path);
Felt::from_bytes_be(&bytes)
}

#[cfg(test)]
mod tests {
use katana_primitives::felt;
use katana_trie::BitVec;

use super::*;

// This test is assuming that the `path` field in `MerkleNode::Edge` is already a valid trie
// path value.
// Test cases taken from `bonsai-trie` crate
#[rstest::rstest]
#[case(felt!("0x1234567890abcdef"))]
#[case(felt!("0xdeadbeef"))]
#[case(Felt::MAX)]
#[case(Felt::ZERO)]
fn test_path_felt_roundtrip(#[case] path_in_felt: Felt) {
let initial_path = felt_to_path(path_in_felt);

let converted_felt = path_to_felt(initial_path.clone());
let path = felt_to_path(converted_felt);
assert_eq!(initial_path, path);
#[case(&[0b10101010, 0b10101010])]
#[case(&[])]
#[case(&[0b10101010])]
#[case(&[0b00000000])]
#[case(&[0b11111111])]
#[case(&[0b11111111, 0b00000000, 0b10101010, 0b10101010, 0b11111111, 0b00000000, 0b10101010, 0b10101010, 0b11111111, 0b00000000, 0b10101010, 0b10101010])]
fn path_felt_rt(#[case] input: &[u8]) {
let path = Path(BitVec::from_slice(input));

let converted_felt = path_to_felt(path.clone());
let converted_path = felt_to_path(converted_felt, path.len() as u8);

assert_eq!(path, converted_path);
assert_eq!(path.len(), converted_path.len());
}
}
20 changes: 11 additions & 9 deletions crates/katana/rpc/rpc/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,6 @@ repository.workspace = true
version.workspace = true

[dependencies]
anyhow.workspace = true
dojo-metrics.workspace = true
futures.workspace = true
http.workspace = true
jsonrpsee = { workspace = true, features = [ "server" ] }
katana-core.workspace = true
katana-executor.workspace = true
katana-pool.workspace = true
Expand All @@ -21,6 +16,12 @@ katana-rpc-api.workspace = true
katana-rpc-types.workspace = true
katana-rpc-types-builder.workspace = true
katana-tasks.workspace = true

anyhow.workspace = true
dojo-metrics.workspace = true
futures.workspace = true
http.workspace = true
jsonrpsee = { workspace = true, features = [ "server" ] }
metrics.workspace = true
serde_json.workspace = true
starknet.workspace = true
Expand All @@ -32,6 +33,11 @@ tracing.workspace = true
url.workspace = true

[dev-dependencies]
katana-cairo.workspace = true
katana-node.workspace = true
katana-rpc-api = { workspace = true, features = [ "client" ] }
katana-trie.workspace = true

alloy = { git = "https://github.com/alloy-rs/alloy", features = [ "contract", "network", "node-bindings", "provider-http", "providers", "signer-local" ] }
alloy-primitives = { workspace = true, features = [ "serde" ] }
assert_matches.workspace = true
Expand All @@ -40,10 +46,6 @@ dojo-test-utils.workspace = true
dojo-utils.workspace = true
indexmap.workspace = true
jsonrpsee = { workspace = true, features = [ "client" ] }
katana-cairo.workspace = true
katana-node.workspace = true
katana-rpc-api = { workspace = true, features = [ "client" ] }
katana-trie.workspace = true
num-traits.workspace = true
rand.workspace = true
rstest.workspace = true
Expand Down
132 changes: 96 additions & 36 deletions crates/katana/rpc/rpc/tests/proofs.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,15 @@ use jsonrpsee::http_client::HttpClientBuilder;
use katana_node::config::rpc::DEFAULT_RPC_MAX_PROOF_KEYS;
use katana_node::config::SequencingConfig;
use katana_primitives::block::BlockIdOrTag;
use katana_primitives::class::ClassHash;
use katana_primitives::hash::StarkHash;
use katana_primitives::{hash, Felt};
use katana_primitives::class::{ClassHash, CompiledClassHash};
use katana_primitives::Felt;
use katana_rpc_api::starknet::StarknetApiClient;
use katana_rpc_types::trie::GetStorageProofResponse;
use katana_trie::bitvec::view::AsBits;
use katana_trie::bonsai::BitVec;
use katana_trie::MultiProof;
use starknet::accounts::Account;
use katana_trie::{compute_classes_trie_value, ClassesMultiProof, MultiProof};
use starknet::accounts::{Account, ConnectedAccount, SingleOwnerAccount};
use starknet::core::types::BlockTag;
use starknet::macros::short_string;
use starknet::providers::jsonrpc::HttpTransport;
use starknet::providers::JsonRpcClient;
use starknet::signers::LocalWallet;

mod common;

Expand Down Expand Up @@ -74,14 +72,96 @@ async fn proofs_limit() {

#[tokio::test]
async fn classes_proofs() {
let sequencer =
TestSequencer::start(get_default_test_config(SequencingConfig::default())).await;
let cfg = get_default_test_config(SequencingConfig::default());

let provider = sequencer.provider();
let sequencer = TestSequencer::start(cfg).await;
let account = sequencer.account();

let path: PathBuf = PathBuf::from("tests/test_data/cairo1_contract.json");
let (contract, compiled_class_hash) = common::prepare_contract_declaration_params(&path)
let (class_hash1, compiled_class_hash1) =
declare(&account, "tests/test_data/cairo1_contract.json").await;
let (class_hash2, compiled_class_hash2) =
declare(&account, "tests/test_data/cairo_l1_msg_contract.json").await;
let (class_hash3, compiled_class_hash3) =
declare(&account, "tests/test_data/test_sierra_contract.json").await;

// We need to use the jsonrpsee client because `starknet-rs` doesn't yet support RPC 0.8.0
let client = HttpClientBuilder::default().build(sequencer.url()).unwrap();

{
let class_hash = class_hash1;
let trie_entry = compute_classes_trie_value(compiled_class_hash1);

let proofs = client
.get_storage_proof(BlockIdOrTag::Number(1), Some(vec![class_hash]), None, None)
.await
.expect("failed to get storage proof");

let results = ClassesMultiProof::from(MultiProof::from(proofs.classes_proof.nodes))
.verify(proofs.global_roots.classes_tree_root, vec![class_hash]);

assert_eq!(vec![trie_entry], results);
}

{
let class_hash = class_hash2;
let trie_entry = compute_classes_trie_value(compiled_class_hash2);

let proofs = client
.get_storage_proof(BlockIdOrTag::Number(2), Some(vec![class_hash]), None, None)
.await
.expect("failed to get storage proof");

let results = ClassesMultiProof::from(MultiProof::from(proofs.classes_proof.nodes))
.verify(proofs.global_roots.classes_tree_root, vec![class_hash]);

assert_eq!(vec![trie_entry], results);
}

{
let class_hash = class_hash3;
let trie_entry = compute_classes_trie_value(compiled_class_hash3);

let proofs = client
.get_storage_proof(BlockIdOrTag::Number(3), Some(vec![class_hash]), None, None)
.await
.expect("failed to get storage proof");

let results = ClassesMultiProof::from(MultiProof::from(proofs.classes_proof.nodes))
.verify(proofs.global_roots.classes_tree_root, vec![class_hash]);

assert_eq!(vec![trie_entry], results);
}

{
let class_hashes = vec![class_hash1, class_hash2, class_hash3];
let trie_entries = vec![
compute_classes_trie_value(compiled_class_hash1),
compute_classes_trie_value(compiled_class_hash2),
compute_classes_trie_value(compiled_class_hash3),
];

let proofs = client
.get_storage_proof(
BlockIdOrTag::Tag(BlockTag::Latest),
Some(class_hashes.clone()),
None,
None,
)
.await
.expect("failed to get storage proof");

let results = ClassesMultiProof::from(MultiProof::from(proofs.classes_proof.nodes))
.verify(proofs.global_roots.classes_tree_root, class_hashes.clone());

assert_eq!(trie_entries, results);
}
}

async fn declare(
account: &SingleOwnerAccount<JsonRpcClient<HttpTransport>, LocalWallet>,
path: impl Into<PathBuf>,
) -> (ClassHash, CompiledClassHash) {
let (contract, compiled_class_hash) = common::prepare_contract_declaration_params(&path.into())
.expect("failed to prepare class declaration params");

let class_hash = contract.class_hash();
Expand All @@ -91,29 +171,9 @@ async fn classes_proofs() {
.await
.expect("failed to send declare tx");

dojo_utils::TransactionWaiter::new(res.transaction_hash, &provider)
dojo_utils::TransactionWaiter::new(res.transaction_hash, account.provider())
.await
.expect("failed to wait on tx");

// We need to use the jsonrpsee client because `starknet-rs` doesn't yet support RPC 0.8.0
let client = HttpClientBuilder::default().build(sequencer.url()).unwrap();

let GetStorageProofResponse { global_roots, classes_proof, .. } = client
.get_storage_proof(BlockIdOrTag::Tag(BlockTag::Latest), Some(vec![class_hash]), None, None)
.await
.expect("failed to get storage proof");

let key: BitVec = class_hash.to_bytes_be().as_bits()[5..].to_owned();
let value =
hash::Poseidon::hash(&short_string!("CONTRACT_CLASS_LEAF_V0"), &compiled_class_hash);

let classes_proof = MultiProof::from(classes_proof.nodes);

// the returned data is the list of values corresponds to the [key]
let results = classes_proof
.verify_proof::<hash::Pedersen>(global_roots.classes_tree_root, [key], 251)
.collect::<Result<Vec<_>, _>>()
.expect("failed to verify proofs");

assert_eq!(vec![value], results);
(class_hash, compiled_class_hash)
}
31 changes: 28 additions & 3 deletions crates/katana/trie/src/classes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,32 @@ use starknet_types_core::hash::{Poseidon, StarkHash};

use crate::id::CommitId;

#[derive(Debug)]
pub struct ClassesMultiProof(pub MultiProof);

impl ClassesMultiProof {
// TODO: maybe perform results check in this method as well. make it accept the compiled class
// hashes
pub fn verify(&self, root: Felt, class_hashes: Vec<ClassHash>) -> Vec<Felt> {
crate::verify_proof::<Pedersen>(&self.0, root, class_hashes)
}
}

impl From<MultiProof> for ClassesMultiProof {
fn from(value: MultiProof) -> Self {
Self(value)
}
}

#[derive(Debug)]
pub struct ClassesTrie<DB: BonsaiDatabase> {
trie: crate::BonsaiTrie<DB, Pedersen>,
}

/////////////////////////////////////////////////////
// ClassesTrie implementations
/////////////////////////////////////////////////////

impl<DB: BonsaiDatabase> ClassesTrie<DB> {
const BONSAI_IDENTIFIER: &'static [u8] = b"classes";

Expand All @@ -34,13 +55,17 @@ where
DB: BonsaiDatabase + BonsaiPersistentDatabase<CommitId>,
{
pub fn insert(&mut self, hash: ClassHash, compiled_hash: CompiledClassHash) {
// https://docs.starknet.io/architecture-and-concepts/network-architecture/starknet-state/#classes_trie
const CONTRACT_CLASS_LEAF_V0: Felt = short_string!("CONTRACT_CLASS_LEAF_V0");
let value = Poseidon::hash(&CONTRACT_CLASS_LEAF_V0, &compiled_hash);
let value = compute_classes_trie_value(compiled_hash);
self.trie.insert(Self::BONSAI_IDENTIFIER, hash, value)
}

pub fn commit(&mut self, block: BlockNumber) {
self.trie.commit(block.into())
}
}

pub fn compute_classes_trie_value(compiled_class_hash: CompiledClassHash) -> Felt {
// https://docs.starknet.io/architecture-and-concepts/network-architecture/starknet-state/#classes_trie
const CONTRACT_CLASS_LEAF_V0: Felt = short_string!("CONTRACT_CLASS_LEAF_V0");
Poseidon::hash(&CONTRACT_CLASS_LEAF_V0, &compiled_class_hash)
}
Loading

0 comments on commit 3691369

Please sign in to comment.