From bc453f981c6ce2753543474a2492ecac03c36ba4 Mon Sep 17 00:00:00 2001 From: Filip Krawczyk Date: Mon, 3 Jun 2024 16:01:32 +0200 Subject: [PATCH] Change MMR Size type to u128 --- src/data_structures/mmr/mmr.cairo | 21 ++++++++------ src/data_structures/mmr/peaks.cairo | 9 +++--- src/data_structures/mmr/proof.cairo | 17 ++++++----- .../mmr/tests/test_peaks.cairo | 5 ++-- .../mmr/tests/test_utils.cairo | 2 +- src/data_structures/mmr/utils.cairo | 29 ++++++++++--------- src/encoding/tests/test_rlp.cairo | 2 +- 7 files changed, 45 insertions(+), 40 deletions(-) diff --git a/src/data_structures/mmr/mmr.cairo b/src/data_structures/mmr/mmr.cairo index 246a658..e962560 100644 --- a/src/data_structures/mmr/mmr.cairo +++ b/src/data_structures/mmr/mmr.cairo @@ -6,11 +6,14 @@ use cairo_lib::data_structures::mmr::utils::{ }; use cairo_lib::hashing::poseidon::PoseidonHasher; +type MmrElement = felt252; +type MmrSize = u128; + // @notice Merkle Mountatin Range struct #[derive(Drop, Clone, Serde, starknet::Store)] struct MMR { - root: felt252, - last_pos: usize + root: MmrElement, + last_pos: MmrSize } impl MMRDefault of Default { @@ -28,7 +31,7 @@ impl MMRImpl of MMRTrait { // @param last_pos The last position in the MMR // @return MMR with the given root and last_pos #[inline(always)] - fn new(root: felt252, last_pos: usize) -> MMR { + fn new(root: MmrElement, last_pos: MmrSize) -> MMR { MMR { root, last_pos } } @@ -36,9 +39,9 @@ impl MMRImpl of MMRTrait { // @param hash The hashed element to append // @param peaks The peaks of the MMR // @return Result with the new root and new peaks of the MMR - fn append(ref self: MMR, hash: felt252, peaks: Peaks) -> Result<(felt252, Peaks), felt252> { - let leaf_count = mmr_size_to_leaf_count(self.last_pos.into()); - let peaks_count = peaks.len(); + fn append(ref self: MMR, hash: MmrElement, peaks: Peaks) -> Result<(MmrElement, Peaks), felt252> { + let leaf_count = mmr_size_to_leaf_count(self.last_pos); + let peaks_count= peaks.len(); if leaf_count_to_peaks_count(leaf_count) != peaks_count.into() { return Result::Err('Invalid peaks count'); @@ -74,7 +77,7 @@ impl MMRImpl of MMRTrait { }; new_peaks.append(last_peak); - let new_root = compute_root(self.last_pos.into(), new_peaks.span()); + let new_root = compute_root(self.last_pos, new_peaks.span()); self.root = new_root; Result::Ok((new_root, new_peaks.span())) @@ -87,9 +90,9 @@ impl MMRImpl of MMRTrait { // @param proof The proof for the element // @return Result with true if the proof is valid, false otherwise fn verify_proof( - self: @MMR, index: usize, hash: felt252, peaks: Peaks, proof: Proof + self: @MMR, index: MmrSize, hash: MmrElement, peaks: Peaks, proof: Proof ) -> Result { - let leaf_count = mmr_size_to_leaf_count((*self.last_pos).into()); + let leaf_count = mmr_size_to_leaf_count(*self.last_pos); if leaf_count_to_peaks_count(leaf_count) != peaks.len().into() { return Result::Err('Invalid peaks count'); } diff --git a/src/data_structures/mmr/peaks.cairo b/src/data_structures/mmr/peaks.cairo index 1fbbae8..91b4b20 100644 --- a/src/data_structures/mmr/peaks.cairo +++ b/src/data_structures/mmr/peaks.cairo @@ -1,15 +1,16 @@ use cairo_lib::hashing::poseidon::PoseidonHasher; use cairo_lib::data_structures::mmr::utils::compute_root; +use cairo_lib::data_structures::mmr::mmr::{MmrSize, MmrElement}; use cairo_lib::utils::array::span_contains; // @notice Represents the peaks of the MMR -type Peaks = Span; +type Peaks = Span; #[generate_trait] impl PeaksImpl of PeaksTrait { // @notice Bags the peaks (hashing them together) // @return The bagged peaks - fn bag(self: Peaks) -> felt252 { + fn bag(self: Peaks) -> MmrElement { if self.is_empty() { return 0; } @@ -35,8 +36,8 @@ impl PeaksImpl of PeaksTrait { // @param last_pos The last position in the MMR // @param root The root of the MMR // @return True if the peaks are valid - fn valid(self: Peaks, last_pos: usize, root: felt252) -> bool { - let computed_root = compute_root(last_pos.into(), self); + fn valid(self: Peaks, last_pos: MmrSize, root: MmrElement) -> bool { + let computed_root = compute_root(last_pos, self); computed_root == root } } diff --git a/src/data_structures/mmr/proof.cairo b/src/data_structures/mmr/proof.cairo index 411eb3e..2f85dfe 100644 --- a/src/data_structures/mmr/proof.cairo +++ b/src/data_structures/mmr/proof.cairo @@ -1,4 +1,5 @@ use cairo_lib::hashing::poseidon::PoseidonHasher; +use cairo_lib::data_structures::mmr::mmr::{MmrSize, MmrElement}; use cairo_lib::data_structures::mmr::utils::get_height; use cairo_lib::utils::bitwise::{left_shift, bit_length}; use cairo_lib::utils::math::pow; @@ -13,24 +14,24 @@ impl ProofImpl of ProofTrait { // @param index Index of the element to start from // @param hash Hash of the element to start from // @return The root of the subtree - fn compute_peak(self: Proof, index: usize, hash: felt252) -> felt252 { + fn compute_peak(self: Proof, index: MmrSize, hash: MmrElement) -> felt252 { // calculate direction array // direction[i] - whether the i-th node from the root is a left or a right child of its // parent - let mut bits = bit_length(index); - if self.len() + 1 > bits { - bits = self.len() + 1; + let mut bits: MmrSize = bit_length(index); + if self.len().into() + 1 > bits { + bits = self.len().into() + 1; }; let mut direction: Array = ArrayTrait::new(); - let mut p: usize = 1; - let mut q: usize = pow(2, bits) - 1; + let mut p: MmrSize = 1; + let mut q: MmrSize = pow(2, bits) - 1; loop { if p >= q { break (); } - let m: usize = (p + q) / 2; + let m: MmrSize = (p + q) / 2; if index < m { q = m - 1; @@ -47,7 +48,7 @@ impl ProofImpl of ProofTrait { let mut current_hash = hash; let mut i: usize = 0; - let mut two_pow_i: usize = 2; + let mut two_pow_i: MmrSize = 2; loop { if i == self.len() { break current_hash; diff --git a/src/data_structures/mmr/tests/test_peaks.cairo b/src/data_structures/mmr/tests/test_peaks.cairo index d2c0484..5d4489f 100644 --- a/src/data_structures/mmr/tests/test_peaks.cairo +++ b/src/data_structures/mmr/tests/test_peaks.cairo @@ -48,8 +48,7 @@ fn test_valid() { let bag = peaks.span().bag(); let last_pos = 923048; - let last_pos_u32 = 923048_u32; - let root = PoseidonHasher::hash_double(last_pos, bag); + let root = PoseidonHasher::hash_double(last_pos.into(), bag); - assert(peaks.span().valid(last_pos_u32, root), 'Valid'); + assert(peaks.span().valid(last_pos, root), 'Valid'); } diff --git a/src/data_structures/mmr/tests/test_utils.cairo b/src/data_structures/mmr/tests/test_utils.cairo index 4a80033..dfac494 100644 --- a/src/data_structures/mmr/tests/test_utils.cairo +++ b/src/data_structures/mmr/tests/test_utils.cairo @@ -27,7 +27,7 @@ fn test_compute_root() { let bag = peaks.span().bag(); let last_pos = 923048; - let root = PoseidonHasher::hash_double(last_pos, bag); + let root = PoseidonHasher::hash_double(last_pos.into(), bag); let computed_root = compute_root(last_pos, peaks.span()); assert(root == computed_root, 'Roots not matching'); diff --git a/src/data_structures/mmr/utils.cairo b/src/data_structures/mmr/utils.cairo index ae4a218..a4369cf 100644 --- a/src/data_structures/mmr/utils.cairo +++ b/src/data_structures/mmr/utils.cairo @@ -1,12 +1,13 @@ use cairo_lib::utils::bitwise::{bit_length, left_shift}; use cairo_lib::utils::math::pow; use cairo_lib::hashing::poseidon::PoseidonHasher; +use cairo_lib::data_structures::mmr::mmr::{MmrSize, MmrElement}; use cairo_lib::data_structures::mmr::peaks::{Peaks, PeaksTrait}; // @notice Computes the height of a node in the MMR // @param index The index of the node // @return The height of the node -fn get_height(index: usize) -> usize { +fn get_height(index: MmrSize) -> MmrSize { let bits = bit_length(index); let ones = pow(2, bits) - 1; @@ -22,15 +23,15 @@ fn get_height(index: usize) -> usize { // @param last_pos The position of the last node in the MMR // @param peaks The peaks of the MMR // @return The root of the MMR -fn compute_root(last_pos: felt252, peaks: Peaks) -> felt252 { +fn compute_root(last_pos: MmrSize, peaks: Peaks) -> MmrElement { let bag = peaks.bag(); - PoseidonHasher::hash_double(last_pos, bag) + PoseidonHasher::hash_double(last_pos.into(), bag) } // @notice Count the number of bits set to 1 in an unsigned integer -// @param arg The usize (u32) unsigned integer +// @param arg The u128 unsigned integer // @return The number of bits set to 1 in n -fn count_ones(n: usize) -> usize { +fn count_ones(n: MmrSize) -> usize { let mut n = n; let mut count = 0; loop { @@ -47,8 +48,8 @@ fn count_ones(n: usize) -> usize { // @return The MMR index // Explanation of why this formula is correct // https://mmr.herodotus.dev/mmr-size-vs-leaf-count#leaf-count-to-mmr-size-algorithm -fn leaf_index_to_mmr_index(n: usize) -> usize { - 2 * n - 1 - count_ones(n - 1) +fn leaf_index_to_mmr_index(n: MmrSize) -> MmrSize { + 2 * n - 1 - count_ones(n - 1).into() } // @notice Convert a Merkle Mountain Range tree size to number of leaves @@ -56,7 +57,7 @@ fn leaf_index_to_mmr_index(n: usize) -> usize { // @result Number of leaves // Explanation of why this algorithm is correct // https://mmr.herodotus.dev/mmr-size-vs-leaf-count#mmr-size-to-leaf-count-algorithm -fn mmr_size_to_leaf_count(n: usize) -> usize { +fn mmr_size_to_leaf_count(n: MmrSize) -> MmrSize { let mut mmr_size = n; let bits = bit_length(mmr_size + 1); let mut mountain_leaf_count = pow(2, bits - 1); @@ -78,19 +79,19 @@ fn mmr_size_to_leaf_count(n: usize) -> usize { // @notice Convert a number of leaves to number of peaks // @param leaf_count Number of leaves // @return Number of peaks -fn leaf_count_to_peaks_count(leaf_count: usize) -> usize { +fn leaf_count_to_peaks_count(leaf_count: MmrSize) -> usize { count_ones(leaf_count) } // @notice Get the number of trailing ones in the binary representation of a number // @param n The number // @return Number of trailing ones -fn trailing_ones(n: usize) -> usize { +fn trailing_ones(n: MmrSize) -> usize { let mut n = n; let mut count = 0; loop { let (halfed, rem) = DivRem::div_rem( - n, TryInto::>::try_into(2).unwrap() + n, TryInto::>::try_into(2).unwrap() ); if rem == 0 { break count; @@ -104,12 +105,12 @@ fn trailing_ones(n: usize) -> usize { // @param elements_count The size of the MMR (number of elements in the MMR) // @param element_index The index of the element in the MMR // @return (peak index, peak height) -fn get_peak_info(elements_count: usize, element_index: usize) -> (usize, usize) { +fn get_peak_info(elements_count: MmrSize, element_index: MmrSize) -> (usize, usize) { let mut elements_count = elements_count; let mut element_index = element_index; - let mut mountain_height = bit_length(elements_count); - let mut mountain_elements_count = pow(2, mountain_height) - 1; + let mut mountain_height: usize = bit_length(elements_count).try_into().unwrap(); + let mut mountain_elements_count: MmrSize = pow(2, mountain_height.into()) - 1; let mut mountain_index = 0; loop { if mountain_elements_count <= elements_count { diff --git a/src/encoding/tests/test_rlp.cairo b/src/encoding/tests/test_rlp.cairo index b55b3c0..1397454 100644 --- a/src/encoding/tests/test_rlp.cairo +++ b/src/encoding/tests/test_rlp.cairo @@ -84,7 +84,7 @@ fn test_rlp_decode_list_lazy() { array![8, 11].span() ) .unwrap(); - let ((block_number, block_number_byte_len), (timestamp, timestamp_byte_len)) = + let ((block_number, _block_number_byte_len), (timestamp, _timestamp_byte_len)) = match decoded_rlp { RLPItem::Bytes(_) => panic_with_felt252('Invalid header rlp'), RLPItem::List(l) => { (*l.at(0), *l.at(1)) },