diff --git a/lib/mpt.cairo b/lib/mpt.cairo index 5e3f950..1b2a684 100644 --- a/lib/mpt.cairo +++ b/lib/mpt.cairo @@ -42,6 +42,7 @@ func verify_mpt_proof{range_check_ptr, bitwise_ptr: BitwiseBuiltin*, keccak_ptr: // Last node : item of interest is the value. // Check that the hash of the last node is the expected one. // Check that the final accumulated key is the expected one. + // Check the number of bytes in the key is equal to the number of bytes checked in the key. let (node_hash: Uint256) = keccak(mpt_proof[node_index], mpt_proof_bytes_len[node_index]); %{ print(f"node_hash : {hex(ids.node_hash.low + 2**128*ids.node_hash.high)}") %} %{ print(f"hash_to_assert : {hex(ids.hash_to_assert.low + 2**128*ids.hash_to_assert.high)}") %} @@ -56,7 +57,34 @@ func verify_mpt_proof{range_check_ptr, bitwise_ptr: BitwiseBuiltin*, keccak_ptr: key_little=key_little, n_nibbles_already_checked=n_nibbles_already_checked, ); + local key_bits; + with pow2_array { + if (key_little.high != 0) { + let key_bit_high = get_felt_bitlength(key_little.high); + assert key_bits = 128 + key_bit_high; + } else { + let key_bit_low = get_felt_bitlength(key_little.low); + assert key_bits = key_bit_low; + } + } + local n_bytes_in_key; + let (n_bytes_in_key_tmp, rem) = felt_divmod_8(key_bits); + if (n_bytes_in_key_tmp == 0) { + assert n_bytes_in_key = 1; + } else { + assert n_bytes_in_key = n_bytes_in_key_tmp; + assert rem = 0; + } + local n_bytes_checked; + let (n_bytes_checked_tmp, rem) = felt_divmod(n_nibbles_checked, 2); + if (rem != 0) { + assert n_bytes_checked = n_bytes_checked_tmp + 1; + } else { + assert n_bytes_checked = n_bytes_checked_tmp; + } + + assert n_bytes_in_key = n_bytes_checked; return (item_of_interest, item_of_interest_len); } else { // Not last node : item of interest is the hash of the next node. @@ -284,11 +312,11 @@ func decode_node_list_lazy{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}( // Ensure first_item_type is either 0 or 1. assert (first_item_type - 1) * (first_item_type) = 0; - let first_item_prefix = extract_nibble_at_byte_pos( + let first_item_key_prefix = extract_nibble_at_byte_pos( rlp[0], first_item_start_offset + first_item_type, 0, pow2_array ); %{ - prefix = ids.first_item_prefix + prefix = ids.first_item_key_prefix if prefix == 0: print("First item is an extension node, even number of nibbles") elif prefix == 1: @@ -301,10 +329,10 @@ func decode_node_list_lazy{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}( raise Exception(f"Unknown prefix {prefix} for MPT node with 2 items") %} local odd: felt; - if (first_item_prefix == 0) { + if (first_item_key_prefix == 0) { assert odd = 0; } else { - if (first_item_prefix == 2) { + if (first_item_key_prefix == 2) { assert odd = 0; } else { // 1 & 3 case. @@ -328,6 +356,10 @@ func decode_node_list_lazy{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}( ); %{ print(f"nibbles already checked: {ids.n_nibbles_already_checked}") %} + local range_check_ptr_f; + local bitwise_ptr_f: BitwiseBuiltin*; + local n_nibbles_already_checked_f; + local pow2_array_f: felt*; if (first_item_type != 0) { // If the first item is not a single byte, verify subset in key. assert_subset_in_key( @@ -339,36 +371,36 @@ func decode_node_list_lazy{range_check_ptr, bitwise_ptr: BitwiseBuiltin*}( cut_nibble=odd, pow2_array=pow2_array, ); - tempvar range_check_ptr = range_check_ptr; - tempvar bitwise_ptr = bitwise_ptr; - tempvar pow2_array = pow2_array; + assert range_check_ptr_f = range_check_ptr; + assert bitwise_ptr_f = bitwise_ptr; + assert n_nibbles_already_checked_f = n_nibbles_already_checked; + assert pow2_array_f = pow2_array; } else { - // if the first item is a single byte, skip subset verification and assert n_nibbles_already_checked == n_nibbles_in_key - local key_bits; - with pow2_array { - if (key_little.high != 0) { - let key_bit_high = get_felt_bitlength(key_little.high); - assert key_bits = 128 + key_bit_high; - } else { - let key_bit_low = get_felt_bitlength(key_little.low); - assert key_bits = key_bit_low; - } - } - local n_nibbles_in_key: felt; // <=> ceil(key_bits/4) - let (n_nibbles_in_key_tmp, remainder) = felt_divmod(key_bits, 4); - if (remainder != 0) { - assert n_nibbles_in_key = n_nibbles_in_key_tmp + 1; + // if the first item is a single byte + + if (odd != 0) { + // If the first item has an odd number of nibbles, since there are two nibbles in one byte, the second nibble needs to be checked + let key_nibble = extract_nibble_from_key( + key_little, n_nibbles_already_checked, pow2_array + ); + let (_, first_item_nibble) = felt_divmod(first_item_prefix, 2 ** 4); + assert key_nibble = first_item_nibble; + assert range_check_ptr_f = range_check_ptr; + assert bitwise_ptr_f = bitwise_ptr; + assert n_nibbles_already_checked_f = n_nibbles_already_checked + 1; + assert pow2_array_f = pow2_array; } else { - assert n_nibbles_in_key = n_nibbles_in_key_tmp; + // If the first item has en even number of nibbles, since there are two nibbles in one byte, there is nothing to check. + assert range_check_ptr_f = range_check_ptr; + assert bitwise_ptr_f = bitwise_ptr; + assert n_nibbles_already_checked_f = n_nibbles_already_checked; + assert pow2_array_f = pow2_array; } - assert n_nibbles_in_key = n_nibbles_already_checked; - tempvar range_check_ptr = range_check_ptr; - tempvar bitwise_ptr = bitwise_ptr; - tempvar pow2_array = pow2_array; } - let range_check_ptr = range_check_ptr; - let bitwise_ptr = bitwise_ptr; - let pow2_array = pow2_array; + let range_check_ptr = range_check_ptr_f; + let bitwise_ptr = bitwise_ptr_f; + let pow2_array = pow2_array_f; + let n_nibbles_already_checked = n_nibbles_already_checked_f; // Extract the hash or value.