Skip to content

Commit

Permalink
Merge pull request #7 from HerodotusDev/non-inclusion
Browse files Browse the repository at this point in the history
feat: add support for verifying non-inclusion proofs
  • Loading branch information
petscheit authored Aug 19, 2024
2 parents f4db981 + e977e86 commit f632dea
Show file tree
Hide file tree
Showing 3 changed files with 177 additions and 119 deletions.
279 changes: 165 additions & 114 deletions lib/mpt.cairo
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from starkware.cairo.common.uint256 import Uint256, uint256_reverse_endian
from starkware.cairo.common.cairo_builtins import BitwiseBuiltin, KeccakBuiltin
from starkware.cairo.common.builtin_keccak.keccak import keccak
from starkware.cairo.common.alloc import alloc
from starkware.cairo.common.registers import get_fp_and_pc
from lib.rlp_little import (
extract_byte_at_pos,
Expand Down Expand Up @@ -34,6 +35,8 @@ from lib.utils import (
// returns:
// - the value of the proof as a felt* array of little endian 8 bytes chunks.
// - the total length in bytes of the value.
// If the proof passed is a non inclusion proof for the given key,
// returns (value=rlp, value_len=-1).
func verify_mpt_proof{range_check_ptr, bitwise_ptr: BitwiseBuiltin*, keccak_ptr: KeccakBuiltin*}(
mpt_proof: felt**,
mpt_proof_bytes_len: felt*,
Expand Down Expand Up @@ -379,7 +382,10 @@ func decode_node_list_lazy{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}(
local bitwise_ptr_f: BitwiseBuiltin*;
local n_nibbles_already_checked_f;
local pow2_array_f: felt*;

local key_checked: felt;
if (first_item_type != 0) {
// First item is a long string.
tempvar n_nibbles_in_first_item = 2 * (first_item_len - 1) + odd;
// %{ print(f"n_nibbles_in_first_item : {ids.n_nibbles_in_first_item}") %}
// Extract the key or key_end. start offset + 1 (item prefix) + 1 (key prefix) - odd (1 if to include prefix's byte in case the nibbles are odd).
Expand All @@ -400,13 +406,8 @@ func decode_node_list_lazy{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}(
pow2_array,
);
// %{
// print(f"nibbles already checked: {ids.n_nibbles_already_checked}")
// if ids.extracted_key_subset_len == 1:
// print(f"Extracted key subset : {hex(memory[ids.extracted_key_subset])}")
// %}
// If the first item is not a single byte, verify subset in key.
assert_subset_in_key_be(
// Check if the extracted key is contained is contained in the full key.
let (contains_subkey) = assert_subset_in_key_be(
key_subset=extracted_key_subset,
key_subset_len=extracted_key_subset_len,
key_subset_nibble_len=n_nibbles_in_first_item,
Expand All @@ -417,14 +418,14 @@ func decode_node_list_lazy{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}(
cut_nibble=odd,
pow2_array=pow2_array,
);
assert key_checked = contains_subkey;
assert range_check_ptr_f = range_check_ptr;
assert bitwise_ptr_f = bitwise_ptr;
assert n_nibbles_already_checked_f = n_nibbles_already_checked +
n_nibbles_in_first_item;
assert pow2_array_f = pow2_array;
} else {
// if the first item is a single byte

if (odd != 0) {
// If the first item has an odd number of nibbles, since there are two nibbles in one byte, the second nibble needs to be checked
let key_nibble = extract_nibble_from_key_be(
Expand All @@ -435,7 +436,11 @@ func decode_node_list_lazy{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}(
pow2_array,
);
let (_, first_item_nibble) = felt_divmod(first_item_prefix, 2 ** 4);
assert key_nibble = first_item_nibble;
if (key_nibble == first_item_nibble) {
assert key_checked = 1;
} else {
assert key_checked = 0;
}
assert range_check_ptr_f = range_check_ptr;
assert bitwise_ptr_f = bitwise_ptr;
assert n_nibbles_already_checked_f = n_nibbles_already_checked + 1;
Expand All @@ -446,67 +451,118 @@ func decode_node_list_lazy{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}(
assert bitwise_ptr_f = bitwise_ptr;
assert n_nibbles_already_checked_f = n_nibbles_already_checked;
assert pow2_array_f = pow2_array;
assert key_checked = 1;
}
}
let range_check_ptr = range_check_ptr_f;
let bitwise_ptr = bitwise_ptr_f;
let pow2_array = pow2_array_f;
let n_nibbles_already_checked = n_nibbles_already_checked_f;

// Extract the hash or value.
let (second_item_value_starts_word, second_item_value_start_offset) = felt_divmod(
second_item_value_starts_at_byte, 8
);
if (last_node != 0) {
// Extract value
let (value, value_len) = extract_n_bytes_from_le_64_chunks_array(
rlp,
second_item_value_starts_word,
second_item_value_start_offset,
second_item_bytes_len,
pow2_array,
);
if (key_checked == 0) {
// Key does not match, we have a non-inclusion. Return empty value.
// Make sure nibbles checked will pass. We encode non-inclusion result as (-1) length.
return (
n_nibbles_already_checked=n_nibbles_already_checked,
item_of_interest=value,
item_of_interest_len=second_item_bytes_len,
n_nibbles_already_checked=key_be_leading_zeroes_nibbles + key_be_nibbles,
item_of_interest=rlp,
item_of_interest_len=-1,
);
} else {
// Extract hash (32 bytes)
// %{ print(f"Extracting hash in leaf/node case)") %}
assert second_item_bytes_len = 32;
let (local hash_le: Uint256) = extract_le_hash_from_le_64_chunks_array(
rlp, second_item_value_starts_word, second_item_value_start_offset, pow2_array
);
return (
n_nibbles_already_checked=n_nibbles_already_checked,
item_of_interest=cast(&hash_le, felt*),
item_of_interest_len=32,
// Key match and is included, return actual value.
// Extract value or hash.
let (second_item_value_starts_word, second_item_value_start_offset) = felt_divmod(
second_item_value_starts_at_byte, 8
);
if (last_node != 0) {
// Extract value
let (value, value_len) = extract_n_bytes_from_le_64_chunks_array(
rlp,
second_item_value_starts_word,
second_item_value_start_offset,
second_item_bytes_len,
pow2_array,
);
return (
n_nibbles_already_checked=n_nibbles_already_checked,
item_of_interest=value,
item_of_interest_len=second_item_bytes_len,
);
} else {
// Extract hash (32 bytes)
// %{ print(f"Extracting hash in leaf/node case)") %}
assert second_item_bytes_len = 32;
let (local hash_le: Uint256) = extract_le_hash_from_le_64_chunks_array(
rlp, second_item_value_starts_word, second_item_value_start_offset, pow2_array
);
return (
n_nibbles_already_checked=n_nibbles_already_checked,
item_of_interest=cast(&hash_le, felt*),
item_of_interest_len=32,
);
}
}
} else {
// Node has more than 2 items : it's a branch.
if (last_node != 0) {
// %{ print(f"Branch case, last node : yes") %}
%{ print(f"Branch case, last node : yes") %}

// Branch is the last node in the proof. We need to extract the last item (17th).
// Key should already be fully checked at this point.
let (third_item_start_word, third_item_start_offset) = felt_divmod(
third_item_starts_at_byte, 8
);
let (
last_item_start_word, last_item_start_offset
) = jump_branch_node_till_element_at_index(
rlp, 0, 16, third_item_start_word, third_item_start_offset, pow2_array
);
tempvar last_item_bytes_len = bytes_len - (
last_item_start_word * 8 + last_item_start_offset
);
let (last_item: felt*, last_item_len: felt) = extract_n_bytes_from_le_64_chunks_array(
rlp, last_item_start_word, last_item_start_offset, last_item_bytes_len, pow2_array
);
// Branch is the last node in the proof.
// For an inclusion, proof, key should already be fully checked at this point.
// For a non inclusion proof (key hasn't been already checked despite last node), the item at the next nibble index should be empty.
if (key_be_leading_zeroes_nibbles + key_be_nibbles != n_nibbles_already_checked) {
let next_key_nibble = extract_nibble_from_key_be(
key_be,
key_be_nibbles,
key_be_leading_zeroes_nibbles,
n_nibbles_already_checked,
pow2_array,
);
let (
item_of_interest_start_word: felt, item_of_interest_start_offset: felt
) = get_branch_value_precomputed_offsets_1_2_3(
rlp,
next_key_nibble,
first_item_start_offset,
second_item_value_starts_at_byte,
third_item_starts_at_byte,
pow2_array,
);
let should_be_empty_prefix = extract_byte_at_pos(
rlp[item_of_interest_start_word], item_of_interest_start_offset, pow2_array
);
return (n_nibbles_already_checked, last_item, last_item_bytes_len);
assert should_be_empty_prefix = 0x80;
return (
n_nibbles_already_checked + key_be_nibbles + key_be_leading_zeroes_nibbles,
rlp,
-1,
);
} else {
let (third_item_start_word, third_item_start_offset) = felt_divmod(
third_item_starts_at_byte, 8
);
let (
last_item_start_word, last_item_start_offset
) = jump_branch_node_till_element_at_index(
rlp, 2, 16, third_item_start_word, third_item_start_offset, pow2_array
); // we start jumping of the 3rd item (index 2) to the 17th item (index 16)
tempvar last_item_bytes_len = bytes_len - (
last_item_start_word * 8 + last_item_start_offset
);

let (
last_item: felt*, last_item_len: felt
) = extract_n_bytes_from_le_64_chunks_array(
rlp,
last_item_start_word,
last_item_start_offset,
last_item_bytes_len,
pow2_array,
);

return (n_nibbles_already_checked, last_item, last_item_bytes_len);
}
} else {
// %{ print(f"Branch case, last node : no") %}
// Branch is not the last node in the proof. We need to extract the hash corresponding to the next nibble of the key.
Expand All @@ -520,67 +576,17 @@ func decode_node_list_lazy{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}(
pow2_array,
);
// %{ print(f"Next Key nibble {ids.next_key_nibble}") %}
local item_of_interest_start_word: felt;
local item_of_interest_start_offset: felt;
local range_check_ptr_f;
local bitwise_ptr_f: BitwiseBuiltin*;
if (next_key_nibble == 0) {
// Store coordinates of the first item's value.
// %{ print(f"\t Branch case, key index = 0") %}
assert item_of_interest_start_word = 0;
assert item_of_interest_start_offset = first_item_start_offset + 1;
assert range_check_ptr_f = range_check_ptr;
assert bitwise_ptr_f = bitwise_ptr;
} else {
if (next_key_nibble == 1) {
// Store coordinates of the second item's value.
// %{ print(f"\t Branch case, key index = 1") %}
let (
second_item_value_start_word, second_item_value_start_offset
) = felt_divmod_8(second_item_value_starts_at_byte);
assert item_of_interest_start_word = second_item_value_start_word;
assert item_of_interest_start_offset = second_item_value_start_offset;
assert range_check_ptr_f = range_check_ptr;
assert bitwise_ptr_f = bitwise_ptr;
} else {
if (next_key_nibble == 2) {
// Store coordinates of the third item's value.
// %{ print(f"\t Branch case, key index = 2") %}
let (
third_item_value_start_word, third_item_value_start_offset
) = felt_divmod_8(third_item_starts_at_byte + 1);
assert item_of_interest_start_word = third_item_value_start_word;
assert item_of_interest_start_offset = third_item_value_start_offset;
assert range_check_ptr_f = range_check_ptr;
assert bitwise_ptr_f = bitwise_ptr;
} else {
// Store coordinates of the item's value at index next_key_nibble != (0, 1, 2).
// %{ print(f"\t Branch case, key index {ids.next_key_nibble}") %}
let (third_item_start_word, third_item_start_offset) = felt_divmod(
third_item_starts_at_byte, 8
);
let (
item_start_word, item_start_offset
) = jump_branch_node_till_element_at_index(
rlp=rlp,
item_start_index=2,
target_index=next_key_nibble,
prefix_start_word=third_item_start_word,
prefix_start_offset=third_item_start_offset,
pow2_array=pow2_array,
);
let (item_value_start_word, item_value_start_offset) = felt_divmod(
item_start_word * 8 + item_start_offset + 1, 8
);
assert item_of_interest_start_word = item_value_start_word;
assert item_of_interest_start_offset = item_value_start_offset;
assert range_check_ptr_f = range_check_ptr;
assert bitwise_ptr_f = bitwise_ptr;
}
}
}
let range_check_ptr = range_check_ptr_f;
let bitwise_ptr = bitwise_ptr_f;
let (
local item_of_interest_start_word: felt, local item_of_interest_start_offset: felt
) = get_branch_value_precomputed_offsets_1_2_3(
rlp,
next_key_nibble,
first_item_start_offset,
second_item_value_starts_at_byte,
third_item_starts_at_byte,
pow2_array,
);

// Extract the hash at the correct coordinates.

let (local hash_le: Uint256) = extract_le_hash_from_le_64_chunks_array(
Expand All @@ -593,6 +599,51 @@ func decode_node_list_lazy{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}(
}
}

func get_branch_value_precomputed_offsets_1_2_3{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}(
rlp: felt*,
next_key_nibble: felt,
first_item_start_offset: felt,
second_item_value_starts_at_byte: felt,
third_item_starts_at_byte: felt,
pow2_array: felt*,
) -> (item_of_interest_start_word: felt, item_of_interest_start_offset: felt) {
if (next_key_nibble == 0) {
// Store coordinates of the first item's value.
// %{ print(f"\t Branch case, key index = 0") %}
return (0, first_item_start_offset + 1);
} else {
if (next_key_nibble == 1) {
// Store coordinates of the second item's value.
// %{ print(f"\t Branch case, key index = 1") %}
let (q, r) = felt_divmod_8(second_item_value_starts_at_byte);
return (q, r);
} else {
if (next_key_nibble == 2) {
// Store coordinates of the third item's value.
// %{ print(f"\t Branch case, key index = 2") %}
let (q, r) = felt_divmod_8(third_item_starts_at_byte + 1);
return (q, r);
} else {
// Store coordinates of the item's value at index next_key_nibble != (0, 1, 2).
// %{ print(f"\t Branch case, key index {ids.next_key_nibble}") %}
let (third_item_start_word, third_item_start_offset) = felt_divmod(
third_item_starts_at_byte, 8
);
let (item_start_word, item_start_offset) = jump_branch_node_till_element_at_index(
rlp=rlp,
item_start_index=2,
target_index=next_key_nibble,
prefix_start_word=third_item_start_word,
prefix_start_offset=third_item_start_offset,
pow2_array=pow2_array,
);
let (q, r) = felt_divmod(item_start_word * 8 + item_start_offset + 1, 8);
return (q, r);
}
}
}
}
// Jumps on a branch until index i is reached.
// params:
// - rlp: the branch node as an array of little endian 8 bytes chunks.
Expand Down
14 changes: 10 additions & 4 deletions lib/rlp_little.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ func assert_subset_in_key_be{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}(
n_nibbles_already_checked: felt,
cut_nibble: felt,
pow2_array: felt*,
) {
) -> (res: felt) {
alloc_locals;

// Get the little endian 256 bit number from the extracted 64 bit le words array :
Expand Down Expand Up @@ -373,9 +373,15 @@ func assert_subset_in_key_be{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}(
// %{ print(f"Key subset final: {hex(ids.key_subset_be_final.low + ids.key_subset_be_final.high*2**128)}") %}

// %{ print(f"key subset expect: {hex(ids.key_shifted.low + ids.key_shifted.high*2**128)}") %}
assert key_subset_be_final.low = key_shifted.low;
assert key_subset_be_final.high = key_shifted.high;
return ();
if (key_subset_be_final.low == key_shifted.low) {
if (key_subset_be_final.high == key_shifted.high) {
return (1,);
} else {
return (0,);
}
} else {
return (0,);
}
}

// From a 256 bit key in big endian of the form :
Expand Down
Loading

0 comments on commit f632dea

Please sign in to comment.