Skip to content

Commit

Permalink
feat: merkle tree verifier implementation to support all numbers of l…
Browse files Browse the repository at this point in the history
…eaves (#253)

## Overview

Closes #249

The implementation is taken from:
https://github.com/celestiaorg/celestia-core/blob/0498541b8db00c7fefa918d906877ef2ee0a3710/crypto/merkle/proof.go#L166-L197

## Checklist

<!-- 
Please complete the checklist to ensure that the PR is ready to be
reviewed.

IMPORTANT:
PRs should be left in Draft until the below checklist is completed.
-->

- [ ] New and updated code has appropriate documentation
- [ ] New and updated code has new and/or updated testing
- [ ] Required CI checks are passing
- [ ] Visual proof for any user facing features like CLI or
documentation updates
- [ ] Linked issues closed with keywords


<!-- This is an auto-generated comment: release notes by coderabbit.ai
-->
## Summary by CodeRabbit

The existing bullet-point list is still valid based on the provided
information. No changes are required.
<!-- end of auto-generated comment: release notes by coderabbit.ai -->
  • Loading branch information
rach-id authored Nov 9, 2023
1 parent 5b34595 commit 4c8099e
Show file tree
Hide file tree
Showing 5 changed files with 248 additions and 85 deletions.
31 changes: 31 additions & 0 deletions src/lib/tree/Utils.sol
Original file line number Diff line number Diff line change
Expand Up @@ -42,3 +42,34 @@ function pathLengthFromKey(uint256 key, uint256 numLeaves) pure returns (uint256
return 1 + pathLengthFromKey(key - numLeavesLeftSubTree, numLeaves - numLeavesLeftSubTree);
}
}

/// @notice Returns the minimum number of bits required to represent `x`; the
/// result is 0 for `x` == 0.
/// @param x Number.
function _bitsLen(uint256 x) pure returns (uint256) {
uint256 count = 0;

while (x != 0) {
count++;
x >>= 1;
}

return count;
}

/// @notice Returns the largest power of 2 less than `x`.
/// @param x Number.
function _getSplitPoint(uint256 x) pure returns (uint256) {
// Note: since `x` is always an unsigned int * 2, the only way for this
// to be violated is if the input == 0. Since the input is the end
// index exclusive, an input of 0 is guaranteed to be invalid (it would
// be a proof of inclusion of nothing, which is vacuous).
require(x >= 1);

uint256 bitLen = _bitsLen(x);
uint256 k = 1 << (bitLen - 1);
if (k == x) {
k >>= 1;
}
return k;
}
100 changes: 47 additions & 53 deletions src/lib/tree/binary/BinaryMerkleTree.sol
Original file line number Diff line number Diff line change
Expand Up @@ -42,65 +42,59 @@ library BinaryMerkleTree {
}
}

uint256 height = 1;
uint256 stableEnd = proof.key;
bytes32 computedHash = computeRootHash(proof.key, proof.numLeaves, digest, proof.sideNodes);

// While the current subtree (of height 'height') is complete, determine
// the position of the next sibling using the complete subtree algorithm.
// 'stableEnd' tells us the ending index of the last full subtree. It gets
// initialized to 'key' because the first full subtree was the
// subtree of height 1, created above (and had an ending index of
// 'key').

while (true) {
// Determine if the subtree is complete. This is accomplished by
// rounding down the key to the nearest 1 << 'height', adding 1
// << 'height', and comparing the result to the number of leaves in the
// Merkle tree.

uint256 subTreeStartIndex = (proof.key / (1 << height)) * (1 << height);
uint256 subTreeEndIndex = subTreeStartIndex + (1 << height) - 1;

// If the Merkle tree does not have a leaf at index
// 'subTreeEndIndex', then the subtree of the current height is not
// a complete subtree.
if (subTreeEndIndex >= proof.numLeaves) {
break;
}
stableEnd = subTreeEndIndex;

// Determine if the key is in the first or the second half of
// the subtree.
if (proof.sideNodes.length <= height - 1) {
return false;
}
if (proof.key - subTreeStartIndex < (1 << (height - 1))) {
digest = nodeDigest(digest, proof.sideNodes[height - 1]);
} else {
digest = nodeDigest(proof.sideNodes[height - 1], digest);
}
return (computedHash == root);
}

height += 1;
/// @notice Use the leafHash and innerHashes to get the root merkle hash.
/// If the length of the innerHashes slice isn't exactly correct, the result is nil.
/// Recursive impl.
function computeRootHash(uint256 key, uint256 numLeaves, bytes32 leafHash, bytes32[] memory sideNodes)
private
pure
returns (bytes32)
{
if (numLeaves == 0) {
revert("cannot call computeRootHash with 0 number of leaves");
}

// Determine if the next hash belongs to an orphan that was elevated. This
// is the case IFF 'stableEnd' (the last index of the largest full subtree)
// is equal to the number of leaves in the Merkle tree.
if (stableEnd != proof.numLeaves - 1) {
if (proof.sideNodes.length <= height - 1) {
return false;
if (numLeaves == 1) {
if (sideNodes.length != 0) {
revert("unexpected inner hashes");
}
digest = nodeDigest(digest, proof.sideNodes[height - 1]);
height += 1;
return leafHash;
}

// All remaining elements in the proof set will belong to a left sibling\
// i.e proof sideNodes are hashed in "from the left"
while (height - 1 < proof.sideNodes.length) {
digest = nodeDigest(proof.sideNodes[height - 1], digest);
height += 1;
if (sideNodes.length == 0) {
revert("expected at least one inner hash");
}
uint256 numLeft = _getSplitPoint(numLeaves);
bytes32[] memory sideNodesLeft = slice(sideNodes, 0, sideNodes.length - 1);
if (key < numLeft) {
bytes32 leftHash = computeRootHash(key, numLeft, leafHash, sideNodesLeft);
return nodeDigest(leftHash, sideNodes[sideNodes.length - 1]);
}
bytes32 rightHash = computeRootHash(key - numLeft, numLeaves - numLeft, leafHash, sideNodesLeft);
return nodeDigest(sideNodes[sideNodes.length - 1], rightHash);
}

return (digest == root);
/// @notice creates a slice of bytes32 from the data slice of bytes32 containing the elements
/// that correspond to the provided range.
/// It selects a half-open range which includes the begin element, but excludes the end one.
/// @param _data The slice that we want to select data from.
/// @param _begin The beginning of the range (inclusive).
/// @param _end The ending of the range (exclusive).
/// @return _ the sliced data.
function slice(bytes32[] memory _data, uint256 _begin, uint256 _end) internal pure returns (bytes32[] memory) {
if (_begin > _end) {
revert("Invalid range: _begin is greater than _end");
}
if (_begin > _data.length || _end > _data.length) {
revert("Invalid range: _begin or _end are out of bounds");
}
bytes32[] memory out = new bytes32[](_end-_begin);
for (uint256 i = _begin; i < _end; i++) {
out[i - _begin] = _data[i];
}
return out;
}
}
169 changes: 169 additions & 0 deletions src/lib/tree/binary/test/BinaryMerkleTree.t.sol
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
pragma solidity ^0.8.22;

import "ds-test/test.sol";
import "forge-std/Vm.sol";

import "../BinaryMerkleProof.sol";
import "../BinaryMerkleTree.sol";
Expand Down Expand Up @@ -40,6 +41,8 @@ import "../BinaryMerkleTree.sol";
*/

contract BinaryMerkleProofTest is DSTest {
Vm private constant vm = Vm(address(uint160(uint256(keccak256("hevm cheat code")))));

function setUp() external {}

function testVerifyNone() external {
Expand Down Expand Up @@ -101,6 +104,36 @@ contract BinaryMerkleProofTest is DSTest {
assertTrue(isValid);
}

function testVerifyLeafTwoOfEight() external {
bytes32 root = 0xc1ad6548cb4c7663110df219ec8b36ca63b01158956f4be31a38a88d0c7f7071;
bytes32[] memory sideNodes = new bytes32[](3);
sideNodes[0] = 0xb413f47d13ee2fe6c845b2ee141af81de858df4ec549a58b7970bb96645bc8d2;
sideNodes[1] = 0x78850a5ab36238b076dd99fd258c70d523168704247988a94caa8c9ccd056b8d;
sideNodes[2] = 0x4301a067262bbb18b4919742326f6f6d706099f9c0e8b0f2db7b88f204b2cf09;

uint256 key = 1;
uint256 numLeaves = 8;
BinaryMerkleProof memory proof = BinaryMerkleProof(sideNodes, key, numLeaves);
bytes memory data = hex"02";
bool isValid = BinaryMerkleTree.verify(root, proof, data);
assertTrue(isValid);
}

function testVerifyLeafThreeOfEight() external {
bytes32 root = 0xc1ad6548cb4c7663110df219ec8b36ca63b01158956f4be31a38a88d0c7f7071;
bytes32[] memory sideNodes = new bytes32[](3);
sideNodes[0] = 0x4f35212d12f9ad2036492c95f1fe79baf4ec7bd9bef3dffa7579f2293ff546a4;
sideNodes[1] = 0x6bcf0e2e93e0a18e22789aee965e6553f4fbe93f0acfc4a705d691c8311c4965;
sideNodes[2] = 0x4301a067262bbb18b4919742326f6f6d706099f9c0e8b0f2db7b88f204b2cf09;

uint256 key = 2;
uint256 numLeaves = 8;
BinaryMerkleProof memory proof = BinaryMerkleProof(sideNodes, key, numLeaves);
bytes memory data = hex"03";
bool isValid = BinaryMerkleTree.verify(root, proof, data);
assertTrue(isValid);
}

function testVerifyLeafSevenOfEight() external {
bytes32 root = 0xc1ad6548cb4c7663110df219ec8b36ca63b01158956f4be31a38a88d0c7f7071;
bytes32[] memory sideNodes = new bytes32[](3);
Expand Down Expand Up @@ -130,4 +163,140 @@ contract BinaryMerkleProofTest is DSTest {
bool isValid = BinaryMerkleTree.verify(root, proof, data);
assertTrue(isValid);
}

// Test vectors:
// 0x00
// 0x01
// 0x02
// 0x03
// 0x04
function testVerifyProofOfFiveLeaves() external {
bytes32 root = 0xb855b42d6c30f5b087e05266783fbd6e394f7b926013ccaa67700a8b0c5a596f;
bytes32[] memory sideNodes = new bytes32[](3);
sideNodes[0] = 0x96a296d224f285c67bee93c30f8a309157f0daa35dc5b87e410b78630a09cfc7;
sideNodes[1] = 0x52c56b473e5246933e7852989cd9feba3b38f078742b93afff1e65ed46797825;
sideNodes[2] = 0x4f35212d12f9ad2036492c95f1fe79baf4ec7bd9bef3dffa7579f2293ff546a4;

uint256 key = 1;
uint256 numLeaves = 5;
BinaryMerkleProof memory proof = BinaryMerkleProof(sideNodes, key, numLeaves);
bytes memory data = bytes(hex"01");
bool isValid = BinaryMerkleTree.verify(root, proof, data);
assertTrue(isValid);
}

function testVerifyInvalidProofRoot() external {
// correct root: 0xb855b42d6c30f5b087e05266783fbd6e394f7b926013ccaa67700a8b0c5a596f;
bytes32 root = 0xc855b42d6c30f5b087e05266783fbd6e394f7b926013ccaa67700a8b0c5a596f;
bytes32[] memory sideNodes = new bytes32[](3);
sideNodes[0] = 0x96a296d224f285c67bee93c30f8a309157f0daa35dc5b87e410b78630a09cfc7;
sideNodes[1] = 0x52c56b473e5246933e7852989cd9feba3b38f078742b93afff1e65ed46797825;
sideNodes[2] = 0x4f35212d12f9ad2036492c95f1fe79baf4ec7bd9bef3dffa7579f2293ff546a4;

uint256 key = 1;
uint256 numLeaves = 5;
BinaryMerkleProof memory proof = BinaryMerkleProof(sideNodes, key, numLeaves);
bytes memory data = bytes(hex"01");
bool isValid = BinaryMerkleTree.verify(root, proof, data);
assertTrue(!isValid);
}

function testVerifyInvalidProofKey() external {
bytes32 root = 0xb855b42d6c30f5b087e05266783fbd6e394f7b926013ccaa67700a8b0c5a596f;
bytes32[] memory sideNodes = new bytes32[](3);
sideNodes[0] = 0x96a296d224f285c67bee93c30f8a309157f0daa35dc5b87e410b78630a09cfc7;
sideNodes[1] = 0x52c56b473e5246933e7852989cd9feba3b38f078742b93afff1e65ed46797825;
sideNodes[2] = 0x4f35212d12f9ad2036492c95f1fe79baf4ec7bd9bef3dffa7579f2293ff546a4;

// correct key: 1
uint256 key = 2;
uint256 numLeaves = 5;
BinaryMerkleProof memory proof = BinaryMerkleProof(sideNodes, key, numLeaves);
bytes memory data = bytes(hex"01");
bool isValid = BinaryMerkleTree.verify(root, proof, data);
assertTrue(!isValid);
}

function testVerifyInvalidProofNumberOfLeaves() external {
bytes32 root = 0xb855b42d6c30f5b087e05266783fbd6e394f7b926013ccaa67700a8b0c5a596f;
bytes32[] memory sideNodes = new bytes32[](3);
sideNodes[0] = 0x96a296d224f285c67bee93c30f8a309157f0daa35dc5b87e410b78630a09cfc7;
sideNodes[1] = 0x52c56b473e5246933e7852989cd9feba3b38f078742b93afff1e65ed46797825;
sideNodes[2] = 0x4f35212d12f9ad2036492c95f1fe79baf4ec7bd9bef3dffa7579f2293ff546a4;

uint256 key = 1;
// correct numLeaves: 5
uint256 numLeaves = 200;
BinaryMerkleProof memory proof = BinaryMerkleProof(sideNodes, key, numLeaves);
bytes memory data = bytes(hex"01");
bool isValid = BinaryMerkleTree.verify(root, proof, data);
assertTrue(!isValid);
}

function testVerifyInvalidProofSideNodes() external {
bytes32 root = 0xb855b42d6c30f5b087e05266783fbd6e394f7b926013ccaa67700a8b0c5a596f;
bytes32[] memory sideNodes = new bytes32[](3);
sideNodes[0] = 0x96a296d224f285c67bee93c30f8a309157f0daa35dc5b87e410b78630a09cfc7;
sideNodes[1] = 0x52c56b473e5246933e7852989cd9feba3b38f078742b93afff1e65ed46797825;
// correct side node: 0x4f35212d12f9ad2036492c95f1fe79baf4ec7bd9bef3dffa7579f2293ff546a4;
sideNodes[2] = 0x5f35212d12f9ad2036492c95f1fe79baf4ec7bd9bef3dffa7579f2293ff546a4;

uint256 key = 1;
uint256 numLeaves = 5;
BinaryMerkleProof memory proof = BinaryMerkleProof(sideNodes, key, numLeaves);
bytes memory data = bytes(hex"01");
bool isValid = BinaryMerkleTree.verify(root, proof, data);
assertTrue(!isValid);
}

function testVerifyInvalidProofData() external {
bytes32 root = 0xb855b42d6c30f5b087e05266783fbd6e394f7b926013ccaa67700a8b0c5a596f;
bytes32[] memory sideNodes = new bytes32[](3);
sideNodes[0] = 0x96a296d224f285c67bee93c30f8a309157f0daa35dc5b87e410b78630a09cfc7;
sideNodes[1] = 0x52c56b473e5246933e7852989cd9feba3b38f078742b93afff1e65ed46797825;
sideNodes[2] = 0x4f35212d12f9ad2036492c95f1fe79baf4ec7bd9bef3dffa7579f2293ff546a4;

uint256 key = 1;
uint256 numLeaves = 5;
BinaryMerkleProof memory proof = BinaryMerkleProof(sideNodes, key, numLeaves);
// correct data: 01
bytes memory data = bytes(hex"012345");
bool isValid = BinaryMerkleTree.verify(root, proof, data);
assertTrue(!isValid);
}

function testValidSlice() public {
bytes32[] memory data = new bytes32[](4);
data[0] = "a";
data[1] = "b";
data[2] = "c";
data[3] = "d";

bytes32[] memory result = BinaryMerkleTree.slice(data, 1, 3);

assertEq(result[0], data[1]);
assertEq(result[1], data[2]);
}

function testInvalidSliceBeginEnd() public {
bytes32[] memory data = new bytes32[](4);
data[0] = "a";
data[1] = "b";
data[2] = "c";
data[3] = "d";

vm.expectRevert("Invalid range: _begin is greater than _end");
BinaryMerkleTree.slice(data, 2, 1);
}

function testOutOfBoundsSlice() public {
bytes32[] memory data = new bytes32[](4);
data[0] = "a";
data[1] = "b";
data[2] = "c";
data[3] = "d";

vm.expectRevert("Invalid range: _begin or _end are out of bounds");
BinaryMerkleTree.slice(data, 2, 5);
}
}
31 changes: 0 additions & 31 deletions src/lib/tree/namespace/NamespaceMerkleTree.sol
Original file line number Diff line number Diff line change
Expand Up @@ -218,37 +218,6 @@ library NamespaceMerkleTree {
return count;
}

/// @notice Returns the minimum number of bits required to represent `x`; the
/// result is 0 for `x` == 0.
/// @param x Number.
function _bitsLen(uint256 x) private pure returns (uint256) {
uint256 count = 0;

while (x != 0) {
count++;
x >>= 1;
}

return count;
}

/// @notice Returns the largest power of 2 less than `x`.
/// @param x Number.
function _getSplitPoint(uint256 x) private pure returns (uint256) {
// Note: since `x` is always an unsigned int * 2, the only way for this
// to be violated is if the input == 0. Since the input is the end
// index exclusive, an input of 0 is guaranteed to be invalid (it would
// be a proof of inclusion of nothing, which is vacuous).
require(x >= 1);

uint256 bitLen = _bitsLen(x);
uint256 k = 1 << (bitLen - 1);
if (k == x) {
k >>= 1;
}
return k;
}

/// @notice Computes the NMT root recursively.
/// @param proof Namespace Merkle multiproof for the leaves.
/// @param leafNodes Leaf nodes for which inclusion is proven.
Expand Down
Loading

0 comments on commit 4c8099e

Please sign in to comment.