From 1278ba26e5727df0b5567ec95fb3ca07536b5aaf Mon Sep 17 00:00:00 2001 From: CHAMI Rachid Date: Tue, 16 Jul 2024 16:07:39 +0200 Subject: [PATCH] feat: add support for subtree root verification (#260) ## Overview Closes https://github.com/celestiaorg/nmt/issues/256 I added this change here so that we have a reference implementation of the algorithm that we will implement in Solidity. Also, adds a method to generate the subtree roots, which didn't exist before and will be needed during proof generation in Celestia-node. The codecov missing coverage complaints are for conditions that are checked twice in different contexts. So there is no way to bypass the first check to arrive at the second check. So, I guess they're fine. ## Summary by CodeRabbit - **New Features** - Added subtree root computation functionality to the Namespaced Merkle Tree (NMT). - Introduced new validation methods for subtree root inclusion in NMT. - **Tests** - Added comprehensive tests for subtree root computation and verification in the Namespaced Merkle Tree. - Introduced helper functions for enhanced verification capabilities. - Added edge case handling for various scenarios in the NMT proof system. --- nmt.go | 57 ++-- nmt_test.go | 137 ++++++++++ proof.go | 197 +++++++++++++- proof_test.go | 704 ++++++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 1079 insertions(+), 16 deletions(-) diff --git a/nmt.go b/nmt.go index e9c318a..f8f82f5 100644 --- a/nmt.go +++ b/nmt.go @@ -107,9 +107,9 @@ type NamespacedMerkleTree struct { // namespaceRanges can be used to efficiently look up the range for an // existing namespace without iterating through the leaves. The map key is - // the string representation of a namespace.ID and the leafRange indicates + // the string representation of a namespace.ID and the LeafRange indicates // the range of the leaves matching that namespace ID in the tree - namespaceRanges map[string]leafRange + namespaceRanges map[string]LeafRange // minNID is the minimum namespace ID of the leaves minNID namespace.ID // maxNID is the maximum namespace ID of the leaves @@ -151,7 +151,7 @@ func New(h hash.Hash, setters ...Option) *NamespacedMerkleTree { visit: opts.NodeVisitor, leaves: make([][]byte, 0, opts.InitialCapacity), leafHashes: make([][]byte, 0, opts.InitialCapacity), - namespaceRanges: make(map[string]leafRange), + namespaceRanges: make(map[string]LeafRange), minNID: bytes.Repeat([]byte{0xFF}, int(opts.NamespaceIDSize)), maxNID: bytes.Repeat([]byte{0x00}, int(opts.NamespaceIDSize)), } @@ -436,7 +436,7 @@ func (n *NamespacedMerkleTree) foundInRange(nID namespace.ID) (found bool, start // This is a faster version of this code snippet: // https://github.com/celestiaorg/celestiaorg-prototype/blob/2aeca6f55ad389b9d68034a0a7038f80a8d2982e/simpleblock.go#L106-L117 foundRng, found := n.namespaceRanges[string(nID)] - return found, foundRng.start, foundRng.end + return found, foundRng.Start, foundRng.End } // NamespaceSize returns the underlying namespace size. Note that all namespaced @@ -590,14 +590,14 @@ func (n *NamespacedMerkleTree) updateNamespaceRanges() { lastNsStr := string(lastPushed[:n.treeHasher.NamespaceSize()]) lastRange, found := n.namespaceRanges[lastNsStr] if !found { - n.namespaceRanges[lastNsStr] = leafRange{ - start: lastIndex, - end: lastIndex + 1, + n.namespaceRanges[lastNsStr] = LeafRange{ + Start: lastIndex, + End: lastIndex + 1, } } else { - n.namespaceRanges[lastNsStr] = leafRange{ - start: lastRange.start, - end: lastRange.end + 1, + n.namespaceRanges[lastNsStr] = LeafRange{ + Start: lastRange.Start, + End: lastRange.End + 1, } } } @@ -644,11 +644,38 @@ func (n *NamespacedMerkleTree) updateMinMaxID(id namespace.ID) { } } -type leafRange struct { - // start and end denote the indices of a leaf in the tree. start ranges from - // 0 up to the total number of leaves minus 1 end ranges from 1 up to the - // total number of leaves end is non-inclusive - start, end int +// ComputeSubtreeRoot takes a leaf range and returns the corresponding subtree root. +// Also, it requires the start and end range to correctly reference an inner node. +// The provided range, defined by start and end, is end-exclusive. +func (n *NamespacedMerkleTree) ComputeSubtreeRoot(start, end int) ([]byte, error) { + if start < 0 { + return nil, fmt.Errorf("start %d shouldn't be strictly negative", start) + } + if end <= start { + return nil, fmt.Errorf("end %d should be stricly bigger than start %d", end, start) + } + uStart, err := safeIntToUint(start) + if err != nil { + return nil, err + } + uEnd, err := safeIntToUint(end) + if err != nil { + return nil, err + } + // check if the provided range correctly references an inner node. + // calculates the ideal tree from the provided range, and verifies if it is the same as the range + if idealTreeRange := nextSubtreeSize(uint64(uStart), uint64(uEnd)); end-start != idealTreeRange { + return nil, fmt.Errorf("the provided range [%d, %d) does not construct a valid subtree root range", start, end) + } + return n.computeRoot(start, end) +} + +type LeafRange struct { + // Start and End denote the indices of a leaf in the tree. + // Start ranges from 0 up to the total number of leaves minus 1. + // End ranges from 1 up to the total number of leaves. + // End is non-inclusive + Start, End int } // MinNamespace extracts the minimum namespace ID from a given namespace hash, diff --git a/nmt_test.go b/nmt_test.go index 6e0565e..8307681 100644 --- a/nmt_test.go +++ b/nmt_test.go @@ -862,6 +862,20 @@ func exampleNMT(nidSize int, ignoreMaxNamespace bool, leavesNIDs ...byte) *Names return tree } +// exampleNMT2 Replica of exampleNMT except that it uses the namespace IDs in the +// leaves instead of the index. +func exampleNMT2(nidSize int, ignoreMaxNamespace bool, leavesNIDs ...byte) *NamespacedMerkleTree { + tree := New(sha256.New(), NamespaceIDSize(nidSize), IgnoreMaxNamespace(ignoreMaxNamespace)) + for _, nid := range leavesNIDs { + namespace := bytes.Repeat([]byte{nid}, nidSize) + d := append(namespace, []byte(fmt.Sprintf("leaf_%d", nid))...) + if err := tree.Push(d); err != nil { + panic(fmt.Sprintf("unexpected error: %v", err)) + } + } + return tree +} + func swap(slice [][]byte, i int, j int) { temp := slice[i] slice[i] = slice[j] @@ -1175,3 +1189,126 @@ func TestForcedOutOfOrderNamespacedMerkleTree(t *testing.T) { assert.NoError(t, err) } } + +func TestComputeSubtreeRoot(t *testing.T) { + n := exampleNMT2(1, true, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15) + tests := []struct { + start, end int + tree *NamespacedMerkleTree + expectedRoot []byte + expectError bool + }{ + { + start: 0, + end: 16, + tree: n, + expectedRoot: func() []byte { + root, err := n.Root() + require.NoError(t, err) + return root + }(), + }, + { + start: 0, + end: 8, + tree: n, + expectedRoot: func() []byte { + // because the root of the range [0,8) coincides with the root of this tree + root, err := exampleNMT2(1, true, 0, 1, 2, 3, 4, 5, 6, 7).Root() + require.NoError(t, err) + return root + }(), + }, + { + start: 8, + end: 16, + tree: n, + expectedRoot: func() []byte { + // because the root of the range [8,16) coincides with the root of this tree + root, err := exampleNMT2(1, true, 8, 9, 10, 11, 12, 13, 14, 15).Root() + require.NoError(t, err) + return root + }(), + }, + { + start: 8, + end: 12, + tree: n, + expectedRoot: func() []byte { + // because the root of the range [8,12) coincides with the root of this tree + root, err := exampleNMT2(1, true, 8, 9, 10, 11).Root() + require.NoError(t, err) + return root + }(), + }, + { + start: 4, + end: 8, + tree: n, + expectedRoot: func() []byte { + // because the root of the range [4,8) coincides with the root of this tree + root, err := exampleNMT2(1, true, 4, 5, 6, 7).Root() + require.NoError(t, err) + return root + }(), + }, + { + start: 4, + end: 6, + tree: n, + expectedRoot: func() []byte { + // because the root of the range [4,6) coincides with the root of this tree + root, err := exampleNMT2(1, true, 4, 5).Root() + require.NoError(t, err) + return root + }(), + }, + { + start: 4, + end: 5, + tree: n, + expectedRoot: func() []byte { + // because the root of the range [4,5) coincides with the root of this tree + root, err := exampleNMT2(1, true, 4).Root() + require.NoError(t, err) + return root + }(), + }, + { // doesn't correctly reference an inner node + start: 2, + end: 6, + tree: n, + expectError: true, + }, + { + start: -1, // invalid start + end: 4, + tree: n, + expectError: true, + }, + { + start: 4, + end: 4, // start == end + tree: n, + expectError: true, + }, + { + start: 5, // start >= end + end: 4, + tree: n, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(fmt.Sprintf("treeSize=%d,start=%d,end=%d", tt.tree.Size(), tt.start, tt.end), func(t *testing.T) { + root, err := tt.tree.ComputeSubtreeRoot(tt.start, tt.end) + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expectedRoot, root) + } + }) + } +} diff --git a/proof.go b/proof.go index 6e824e5..bde6723 100644 --- a/proof.go +++ b/proof.go @@ -9,7 +9,7 @@ import ( "math/bits" "github.com/celestiaorg/nmt/namespace" - pb "github.com/celestiaorg/nmt/pb" + "github.com/celestiaorg/nmt/pb" ) var ( @@ -465,6 +465,187 @@ func (proof Proof) VerifyInclusion(h hash.Hash, nid namespace.ID, leavesWithoutN return res } +// VerifySubtreeRootInclusion verifies that a set of subtree roots is included in +// an NMT. +// Warning: This method is Celestia specific! Using it without verifying +// the following assumptions, can return unexpected errors, false positive/negatives: +// - The subtree roots are created according to the ADR-013 +// https://github.com/celestiaorg/celestia-app/blob/main/docs/architecture/adr-013-non-interactive-default-rules-for-zero-padding.md +// - The tree's number of leaves is a power of two +// The subtreeWidth is also defined in ADR-013. +// More information on the algorithm used can be found in the ToLeafRanges() method docs. +func (proof Proof) VerifySubtreeRootInclusion(nth *NmtHasher, subtreeRoots [][]byte, subtreeWidth int, root []byte) (bool, error) { + // check that the proof range is valid + if proof.Start() < 0 || proof.Start() >= proof.End() { + return false, fmt.Errorf("proof range [proof.start=%d, proof.end=%d) is not valid: %w", proof.Start(), proof.End(), ErrInvalidRange) + } + + // check that the root is valid w.r.t the NMT hasher + if err := nth.ValidateNodeFormat(root); err != nil { + return false, fmt.Errorf("root does not match the NMT hasher's hash format: %w", err) + } + // check that all the proof.Notes() are valid w.r.t the NMT hasher + for _, node := range proof.Nodes() { + if err := nth.ValidateNodeFormat(node); err != nil { + return false, fmt.Errorf("proof nodes do not match the NMT hasher's hash format: %w", err) + } + } + // check that all the subtree roots are valid w.r.t the NMT hasher + for _, subtreeRoot := range subtreeRoots { + if err := nth.ValidateNodeFormat(subtreeRoot); err != nil { + return false, fmt.Errorf("inner nodes does not match the NMT hasher's hash format: %w", err) + } + } + + // get the subtree roots leaf ranges + ranges, err := ToLeafRanges(proof.Start(), proof.End(), subtreeWidth) + if err != nil { + return false, err + } + + // check whether the number of ranges matches the number of subtree roots. + // if not, make an early return. + if len(subtreeRoots) != len(ranges) { + return false, fmt.Errorf("number of subtree roots %d is different than the number of the expected leaf ranges %d", len(subtreeRoots), len(ranges)) + } + + var computeRoot func(start, end int) ([]byte, error) + // computeRoot can return error iff the HashNode function fails while calculating the root + computeRoot = func(start, end int) ([]byte, error) { + // if the current range does not overlap with the proof range, pop and + // return a proof node if present, else return nil because subtree + // doesn't exist + if end <= proof.Start() || start >= proof.End() { + return popIfNonEmpty(&proof.nodes), nil + } + + if len(ranges) == 0 { + return nil, fmt.Errorf(fmt.Sprintf("expected to have a subtree root for range [%d, %d)", start, end)) + } + + if ranges[0].Start == start && ranges[0].End == end { + ranges = ranges[1:] + return popIfNonEmpty(&subtreeRoots), nil + } + + // Recursively get left and right subtree + k := getSplitPoint(end - start) + left, err := computeRoot(start, start+k) + if err != nil { + return nil, fmt.Errorf("failed to compute subtree root [%d, %d): %w", start, start+k, err) + } + right, err := computeRoot(start+k, end) + if err != nil { + return nil, fmt.Errorf("failed to compute subtree root [%d, %d): %w", start+k, end, err) + } + + // only right leaf/subtree can be non-existent + if right == nil { + return left, nil + } + hash, err := nth.HashNode(left, right) + if err != nil { + return nil, fmt.Errorf("failed to hash node: %w", err) + } + return hash, nil + } + + // estimate the leaf size of the subtree containing the proof range + proofRangeSubtreeEstimate := getSplitPoint(proof.End()) * 2 + if proofRangeSubtreeEstimate < 1 { + proofRangeSubtreeEstimate = 1 + } + rootHash, err := computeRoot(0, proofRangeSubtreeEstimate) + if err != nil { + return false, fmt.Errorf("failed to compute root [%d, %d): %w", 0, proofRangeSubtreeEstimate, err) + } + for i := 0; i < len(proof.Nodes()); i++ { + rootHash, err = nth.HashNode(rootHash, proof.Nodes()[i]) + if err != nil { + return false, fmt.Errorf("failed to hash node: %w", err) + } + } + + return bytes.Equal(rootHash, root), nil +} + +// ToLeafRanges returns the leaf ranges corresponding to the provided subtree roots. +// The proof range defined by proofStart and proofEnd is end exclusive. +// It uses the subtree root width to calculate the maximum number of leaves a subtree root can +// commit to. +// The subtree root width is defined as per ADR-013: +// https://github.com/celestiaorg/celestia-app/blob/main/docs/architecture/adr-013-non-interactive-default-rules-for-zero-padding.md +// This method assumes: +// - The subtree roots are created according to the ADR-013 non-interactive defaults rules +// - The tree's number of leaves is a power of two +// The algorithm is as follows: +// - Let `d` be `y - x` (the range of the proof). +// - `i` is the index of the next subtree root. +// - While `d != 0`: +// - Let `z` be the largest power of 2 that fits in `d`; here we are finding the range for the next subtree root. +// - The range for the next subtree root is `[x, x + z)`, i.e., `S_i` is the subtree root of leaves at indices `[x, x + z)`. +// - `d = d - z` (move past the first subtree root and its range). +// - `i = i + 1`. +// - Go back to the loop condition. +// +// Note: This method is Celestia specific. +func ToLeafRanges(proofStart, proofEnd, subtreeWidth int) ([]LeafRange, error) { + if proofStart < 0 { + return nil, fmt.Errorf("proof start %d shouldn't be strictly negative", proofStart) + } + if proofEnd <= proofStart { + return nil, fmt.Errorf("proof end %d should be stricly bigger than proof start %d", proofEnd, proofStart) + } + if subtreeWidth <= 0 { + return nil, fmt.Errorf("subtree root width cannot be negative %d", subtreeWidth) + } + currentStart := proofStart + currentLeafRange := proofEnd - proofStart + var ranges []LeafRange + maximumLeafRange := subtreeWidth + for currentLeafRange != 0 { + nextRange, err := nextLeafRange(currentStart, proofEnd, maximumLeafRange) + if err != nil { + return nil, err + } + ranges = append(ranges, nextRange) + currentStart = nextRange.End + currentLeafRange = currentLeafRange - nextRange.End + nextRange.Start + } + return ranges, nil +} + +// nextLeafRange takes a proof start, proof end, and the maximum range a subtree +// root can cover, and returns the corresponding subtree root range. +// Check ToLeafRanges() for more information on the algorithm used. +// The subtreeWidth is calculated using SubTreeWidth() method +// in celestiaorg/go-square/inclusion package. +// The subtreeWidth is a power of two. +// Also, the LeafRange values, i.e., the range size, are all powers of two. +// Note: This method is Celestia specific. +func nextLeafRange(currentStart, currentEnd, subtreeWidth int) (LeafRange, error) { + currentLeafRange := currentEnd - currentStart + minimum := minInt(currentLeafRange, subtreeWidth) + uMinimum, err := safeIntToUint(minimum) + if err != nil { + return LeafRange{}, fmt.Errorf("failed to convert subtree root range to Uint %w", err) + } + currentRange, err := largestPowerOfTwo(uMinimum) + if err != nil { + return LeafRange{}, err + } + return LeafRange{Start: currentStart, End: currentStart + currentRange}, nil +} + +// largestPowerOfTwo calculates the largest power of two +// that is smaller than 'bound' +func largestPowerOfTwo(bound uint) (int, error) { + if bound == 0 { + return 0, fmt.Errorf("bound cannot be equal to 0") + } + return 1 << (bits.Len(bound) - 1), nil +} + // ProtoToProof creates a proof from its proto representation. func ProtoToProof(protoProof pb.Proof) Proof { if protoProof.Start == 0 && protoProof.End == 0 { @@ -512,3 +693,17 @@ func popIfNonEmpty(s *[][]byte) []byte { } return nil } + +func safeIntToUint(val int) (uint, error) { + if val < 0 { + return 0, fmt.Errorf("cannot convert a negative int %d to uint", val) + } + return uint(val), nil +} + +func minInt(val1, val2 int) int { + if val1 > val2 { + return val2 + } + return val1 +} diff --git a/proof_test.go b/proof_test.go index 235a403..b2ea98f 100644 --- a/proof_test.go +++ b/proof_test.go @@ -3,6 +3,7 @@ package nmt import ( "bytes" "crypto/sha256" + "fmt" "hash" "testing" @@ -1124,3 +1125,706 @@ func Test_ProtoToProof(t *testing.T) { }) } } + +func TestLargestPowerOfTwo(t *testing.T) { + tests := []struct { + bound uint + expected int + expectError bool + }{ + {bound: 1, expected: 1}, + {bound: 2, expected: 2}, + {bound: 3, expected: 2}, + {bound: 4, expected: 4}, + {bound: 5, expected: 4}, + {bound: 6, expected: 4}, + {bound: 7, expected: 4}, + {bound: 8, expected: 8}, + {bound: 0, expectError: true}, + } + + for _, tt := range tests { + t.Run(fmt.Sprintf("bound=%d", tt.bound), func(t *testing.T) { + result, err := largestPowerOfTwo(tt.bound) + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expected, result) + } + }) + } +} + +func TestToLeafRanges(t *testing.T) { + tests := []struct { + proofStart, proofEnd, subtreeWidth int + expectedRanges []LeafRange + expectError bool + }{ + { + proofStart: 0, + proofEnd: 8, + subtreeWidth: 1, + expectedRanges: []LeafRange{ + {Start: 0, End: 1}, + {Start: 1, End: 2}, + {Start: 2, End: 3}, + {Start: 3, End: 4}, + {Start: 4, End: 5}, + {Start: 5, End: 6}, + {Start: 6, End: 7}, + {Start: 7, End: 8}, + }, + }, + { + proofStart: 0, + proofEnd: 9, + subtreeWidth: 1, + expectedRanges: []LeafRange{ + {Start: 0, End: 1}, + {Start: 1, End: 2}, + {Start: 2, End: 3}, + {Start: 3, End: 4}, + {Start: 4, End: 5}, + {Start: 5, End: 6}, + {Start: 6, End: 7}, + {Start: 7, End: 8}, + {Start: 8, End: 9}, + }, + }, + { + proofStart: 0, + proofEnd: 16, + subtreeWidth: 1, + expectedRanges: []LeafRange{ + {Start: 0, End: 1}, + {Start: 1, End: 2}, + {Start: 2, End: 3}, + {Start: 3, End: 4}, + {Start: 4, End: 5}, + {Start: 5, End: 6}, + {Start: 6, End: 7}, + {Start: 7, End: 8}, + {Start: 8, End: 9}, + {Start: 9, End: 10}, + {Start: 10, End: 11}, + {Start: 11, End: 12}, + {Start: 12, End: 13}, + {Start: 13, End: 14}, + {Start: 14, End: 15}, + {Start: 15, End: 16}, + }, + }, + { + proofStart: 0, + proofEnd: 100, + subtreeWidth: 2, + expectedRanges: func() []LeafRange { + var ranges []LeafRange + for i := 0; i < 100; i = i + 2 { + ranges = append(ranges, LeafRange{i, i + 2}) + } + return ranges + }(), + }, + { + proofStart: 0, + proofEnd: 150, + subtreeWidth: 4, + expectedRanges: func() []LeafRange { + var ranges []LeafRange + for i := 0; i < 148; i = i + 4 { + ranges = append(ranges, LeafRange{i, i + 4}) + } + ranges = append(ranges, LeafRange{ + Start: 148, + End: 150, + }) + return ranges + }(), + }, + { + proofStart: 0, + proofEnd: 400, + subtreeWidth: 8, + expectedRanges: func() []LeafRange { + var ranges []LeafRange + for i := 0; i < 400; i = i + 8 { + ranges = append(ranges, LeafRange{i, i + 8}) + } + return ranges + }(), + }, + { + proofStart: -1, + proofEnd: 0, + subtreeWidth: -1, + expectedRanges: nil, + expectError: true, + }, + { + proofStart: 0, + proofEnd: -1, + subtreeWidth: -1, + expectedRanges: nil, + expectError: true, + }, + { + proofStart: 0, + proofEnd: 0, + subtreeWidth: 2, + expectedRanges: nil, + expectError: true, + }, + { + proofStart: 0, + proofEnd: 0, + subtreeWidth: -1, + expectedRanges: nil, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(fmt.Sprintf("proofStart=%d, proofEnd=%d, subtreeWidth=%d", tt.proofStart, tt.proofEnd, tt.subtreeWidth), func(t *testing.T) { + result, err := ToLeafRanges(tt.proofStart, tt.proofEnd, tt.subtreeWidth) + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.True(t, compareRanges(result, tt.expectedRanges)) + } + }) + } +} + +func compareRanges(a, b []LeafRange) bool { + if len(a) != len(b) { + return false + } + for i := range a { + if a[i] != b[i] { + return false + } + } + return true +} + +func TestNextLeafRange(t *testing.T) { + tests := []struct { + currentStart, currentEnd int + // the maximum leaf range == subtree width used in these tests do not follow ADR-013 + // they're just used to try different test cases + subtreeRootMaximumLeafRange int + expectedRange LeafRange + expectError bool + }{ + { + currentStart: 0, + currentEnd: 8, + subtreeRootMaximumLeafRange: 4, + expectedRange: LeafRange{Start: 0, End: 4}, + }, + { + currentStart: 4, + currentEnd: 10, + subtreeRootMaximumLeafRange: 8, + expectedRange: LeafRange{Start: 4, End: 8}, + }, + { + currentStart: 4, + currentEnd: 20, + subtreeRootMaximumLeafRange: 16, + expectedRange: LeafRange{Start: 4, End: 20}, + }, + { + currentStart: 4, + currentEnd: 20, + subtreeRootMaximumLeafRange: 1, + expectedRange: LeafRange{Start: 4, End: 5}, + }, + { + currentStart: 4, + currentEnd: 20, + subtreeRootMaximumLeafRange: 2, + expectedRange: LeafRange{Start: 4, End: 6}, + }, + { + currentStart: 4, + currentEnd: 20, + subtreeRootMaximumLeafRange: 4, + expectedRange: LeafRange{Start: 4, End: 8}, + }, + { + currentStart: 4, + currentEnd: 20, + subtreeRootMaximumLeafRange: 8, + expectedRange: LeafRange{Start: 4, End: 12}, + }, + { + currentStart: 0, + currentEnd: 1, + subtreeRootMaximumLeafRange: 1, + expectedRange: LeafRange{Start: 0, End: 1}, + }, + { + currentStart: 0, + currentEnd: 16, + subtreeRootMaximumLeafRange: 16, + expectedRange: LeafRange{Start: 0, End: 16}, + }, + { + currentStart: 0, + currentEnd: 0, + subtreeRootMaximumLeafRange: 4, + expectError: true, + }, + { + currentStart: 5, + currentEnd: 2, + subtreeRootMaximumLeafRange: 4, + expectError: true, + }, + { + currentStart: 5, + currentEnd: 2, + subtreeRootMaximumLeafRange: 0, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(fmt.Sprintf("currentStart=%d, currentEnd=%d, subtreeRootMaximumLeafRange=%d", tt.currentStart, tt.currentEnd, tt.subtreeRootMaximumLeafRange), func(t *testing.T) { + result, err := nextLeafRange(tt.currentStart, tt.currentEnd, tt.subtreeRootMaximumLeafRange) + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.expectedRange, result) + } + }) + } +} + +func TestSafeIntToUint(t *testing.T) { + tests := []struct { + input int + expectedUint uint + expectedError error + }{ + { + input: 10, + expectedUint: 10, + expectedError: nil, + }, + { + input: 0, + expectedUint: 0, + expectedError: nil, + }, + { + input: -5, + expectedUint: 0, + expectedError: fmt.Errorf("cannot convert a negative int %d to uint", -5), + }, + } + + for _, tt := range tests { + t.Run(fmt.Sprintf("input=%d", tt.input), func(t *testing.T) { + result, err := safeIntToUint(tt.input) + if (err != nil) != (tt.expectedError != nil) || (err != nil && err.Error() != tt.expectedError.Error()) { + t.Errorf("expected error %v, got %v", tt.expectedError, err) + } + if result != tt.expectedUint { + t.Errorf("expected uint %v, got %v", tt.expectedUint, result) + } + }) + } +} + +func TestMinInt(t *testing.T) { + tests := []struct { + val1, val2 int + expected int + }{ + { + val1: 10, + val2: 20, + expected: 10, + }, + { + val1: -5, + val2: 6, + expected: -5, + }, + { + val1: 5, + val2: -6, + expected: -6, + }, + { + val1: -5, + val2: -6, + expected: -6, + }, + { + val1: 0, + val2: 0, + expected: 0, + }, + } + + for _, tt := range tests { + t.Run(fmt.Sprintf("val1=%d, val2=%d", tt.val1, tt.val2), func(t *testing.T) { + result := minInt(tt.val1, tt.val2) + if result != tt.expected { + t.Errorf("expected %d, got %d", tt.expected, result) + } + }) + } +} + +func TestVerifySubtreeRootInclusion(t *testing.T) { + tree := exampleNMT(1, true, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15) + root, err := tree.Root() + require.NoError(t, err) + + nmthasher := tree.treeHasher + hasher := nmthasher.(*NmtHasher) + + tests := []struct { + proof Proof + subtreeRoots [][]byte + // the subtree widths used in these tests do not follow ADR-013 + // they're just used to try different test cases + subtreeWidth int + root []byte + validProof bool + expectError bool + }{ + { + proof: func() Proof { + p, err := tree.ProveRange(0, 8) + require.NoError(t, err) + return p + }(), + subtreeRoots: func() [][]byte { + subtreeRoot, err := tree.ComputeSubtreeRoot(0, 8) + require.NoError(t, err) + return [][]byte{subtreeRoot} + }(), + subtreeWidth: 8, + root: root, + validProof: true, + }, + { + proof: func() Proof { + p, err := tree.ProveRange(0, 1) + require.NoError(t, err) + return p + }(), + subtreeRoots: func() [][]byte { + subtreeRoot, err := tree.ComputeSubtreeRoot(0, 1) + require.NoError(t, err) + return [][]byte{subtreeRoot} + }(), + subtreeWidth: 8, + root: root, + validProof: true, + }, + { + proof: func() Proof { + p, err := tree.ProveRange(0, 2) + require.NoError(t, err) + return p + }(), + subtreeRoots: func() [][]byte { + subtreeRoot, err := tree.ComputeSubtreeRoot(0, 2) + require.NoError(t, err) + return [][]byte{subtreeRoot} + }(), + subtreeWidth: 8, + root: root, + validProof: true, + }, + { + proof: func() Proof { + p, err := tree.ProveRange(2, 4) + require.NoError(t, err) + return p + }(), + subtreeRoots: func() [][]byte { + subtreeRoot, err := tree.ComputeSubtreeRoot(2, 4) + require.NoError(t, err) + return [][]byte{subtreeRoot} + }(), + subtreeWidth: 8, + root: root, + validProof: true, + }, + { + proof: func() Proof { + p, err := tree.ProveRange(0, 8) + require.NoError(t, err) + return p + }(), + subtreeRoots: func() [][]byte { + subtreeRoot1, err := tree.ComputeSubtreeRoot(0, 4) + require.NoError(t, err) + subtreeRoot2, err := tree.ComputeSubtreeRoot(4, 8) + require.NoError(t, err) + return [][]byte{subtreeRoot1, subtreeRoot2} + }(), + subtreeWidth: 4, + root: root, + validProof: true, + }, + { + proof: func() Proof { + p, err := tree.ProveRange(0, 8) + require.NoError(t, err) + return p + }(), + subtreeRoots: func() [][]byte { + subtreeRoot1, err := tree.ComputeSubtreeRoot(0, 2) + require.NoError(t, err) + subtreeRoot2, err := tree.ComputeSubtreeRoot(2, 4) + require.NoError(t, err) + subtreeRoot3, err := tree.ComputeSubtreeRoot(4, 6) + require.NoError(t, err) + subtreeRoot4, err := tree.ComputeSubtreeRoot(6, 8) + require.NoError(t, err) + return [][]byte{subtreeRoot1, subtreeRoot2, subtreeRoot3, subtreeRoot4} + }(), + subtreeWidth: 2, + root: root, + validProof: true, + }, + { + proof: func() Proof { + p, err := tree.ProveRange(0, 8) + require.NoError(t, err) + return p + }(), + subtreeRoots: func() [][]byte { + subtreeRoot1, err := tree.ComputeSubtreeRoot(0, 1) + require.NoError(t, err) + subtreeRoot2, err := tree.ComputeSubtreeRoot(1, 2) + require.NoError(t, err) + subtreeRoot3, err := tree.ComputeSubtreeRoot(2, 3) + require.NoError(t, err) + subtreeRoot4, err := tree.ComputeSubtreeRoot(3, 4) + require.NoError(t, err) + subtreeRoot5, err := tree.ComputeSubtreeRoot(4, 5) + require.NoError(t, err) + subtreeRoot6, err := tree.ComputeSubtreeRoot(5, 6) + require.NoError(t, err) + subtreeRoot7, err := tree.ComputeSubtreeRoot(6, 7) + require.NoError(t, err) + subtreeRoot8, err := tree.ComputeSubtreeRoot(7, 8) + require.NoError(t, err) + return [][]byte{subtreeRoot1, subtreeRoot2, subtreeRoot3, subtreeRoot4, subtreeRoot5, subtreeRoot6, subtreeRoot7, subtreeRoot8} + }(), + subtreeWidth: 1, + root: root, + validProof: true, + }, + { + proof: func() Proof { + p, err := tree.ProveRange(4, 8) + require.NoError(t, err) + return p + }(), + subtreeRoots: func() [][]byte { + subtreeRoot, err := tree.ComputeSubtreeRoot(4, 8) + require.NoError(t, err) + return [][]byte{subtreeRoot} + }(), + subtreeWidth: 8, + root: root, + validProof: true, + }, + { + proof: func() Proof { + p, err := tree.ProveRange(12, 14) + require.NoError(t, err) + return p + }(), + subtreeRoots: func() [][]byte { + subtreeRoot, err := tree.ComputeSubtreeRoot(12, 14) + require.NoError(t, err) + return [][]byte{subtreeRoot} + }(), + subtreeWidth: 8, + root: root, + validProof: true, + }, + { + proof: func() Proof { + p, err := tree.ProveRange(14, 16) + require.NoError(t, err) + return p + }(), + subtreeRoots: func() [][]byte { + subtreeRoot, err := tree.ComputeSubtreeRoot(14, 16) + require.NoError(t, err) + return [][]byte{subtreeRoot} + }(), + subtreeWidth: 8, + root: root, + validProof: true, + }, + { + proof: func() Proof { + p, err := tree.ProveRange(14, 15) + require.NoError(t, err) + return p + }(), + subtreeRoots: func() [][]byte { + subtreeRoot, err := tree.ComputeSubtreeRoot(14, 15) + require.NoError(t, err) + return [][]byte{subtreeRoot} + }(), + subtreeWidth: 8, + root: root, + validProof: true, + }, + { + proof: func() Proof { + p, err := tree.ProveRange(15, 16) + require.NoError(t, err) + return p + }(), + subtreeRoots: func() [][]byte { + subtreeRoot, err := tree.ComputeSubtreeRoot(15, 16) + require.NoError(t, err) + return [][]byte{subtreeRoot} + }(), + subtreeWidth: 8, + root: root, + validProof: true, + }, + { + proof: func() Proof { + p, err := tree.ProveRange(15, 16) + require.NoError(t, err) + return p + }(), + subtreeRoots: func() [][]byte { + subtreeRoot, err := tree.ComputeSubtreeRoot(15, 16) + require.NoError(t, err) + return [][]byte{subtreeRoot} + }(), + subtreeWidth: -3, // invalid subtree root width + root: root, + expectError: true, + }, + { + proof: func() Proof { + p, err := tree.ProveRange(15, 16) + require.NoError(t, err) + return p + }(), + subtreeRoots: func() [][]byte { + subtreeRoot, err := tree.ComputeSubtreeRoot(15, 16) + require.NoError(t, err) + return [][]byte{subtreeRoot, subtreeRoot} // invalid number of subtree roots + }(), + subtreeWidth: 8, + root: root, + expectError: true, + }, + { + proof: func() Proof { + p, err := tree.ProveRange(15, 16) + require.NoError(t, err) + return p + }(), + subtreeRoots: func() [][]byte { + subtreeRoot, err := tree.ComputeSubtreeRoot(15, 16) + require.NoError(t, err) + return [][]byte{subtreeRoot} + }(), + subtreeWidth: 8, + root: []byte("random root"), // invalid root format + expectError: true, + }, + { + proof: Proof{start: -1}, // invalid start + subtreeRoots: func() [][]byte { + subtreeRoot, err := tree.ComputeSubtreeRoot(15, 16) + require.NoError(t, err) + return [][]byte{subtreeRoot} + }(), + subtreeWidth: 8, + root: root, + expectError: true, + }, + { + proof: Proof{end: 1, start: 2}, // invalid end + subtreeRoots: func() [][]byte { + subtreeRoot, err := tree.ComputeSubtreeRoot(15, 16) + require.NoError(t, err) + return [][]byte{subtreeRoot} + }(), + subtreeWidth: 8, + root: root, + expectError: true, + }, + { + proof: Proof{ + start: 0, + end: 4, + nodes: [][]byte{[]byte("invalid proof node")}, // invalid proof node + }, + subtreeRoots: func() [][]byte { + subtreeRoot, err := tree.ComputeSubtreeRoot(15, 16) + require.NoError(t, err) + return [][]byte{subtreeRoot} + }(), + subtreeWidth: 8, + root: root, + expectError: true, + }, + { + proof: func() Proof { + p, err := tree.ProveRange(15, 16) + require.NoError(t, err) + return p + }(), + subtreeRoots: [][]byte{[]byte("invalid subtree root")}, // invalid subtree root + subtreeWidth: 8, + root: root, + expectError: true, + }, + + { + proof: func() Proof { + p, err := tree.ProveRange(0, 8) + require.NoError(t, err) + return p + }(), + subtreeRoots: func() [][]byte { + subtreeRoot1, err := tree.ComputeSubtreeRoot(0, 4) + require.NoError(t, err) + return [][]byte{subtreeRoot1} // will error because it requires the subtree root of [4,8) too + }(), + subtreeWidth: 4, + root: root, + expectError: true, + }, + } + + for _, tt := range tests { + t.Run(fmt.Sprintf("proofStart=%d, proofEnd=%d, subTreeWidth=%d", tt.proof.Start(), tt.proof.End(), tt.subtreeWidth), func(t *testing.T) { + result, err := tt.proof.VerifySubtreeRootInclusion(hasher, tt.subtreeRoots, tt.subtreeWidth, tt.root) + if tt.expectError { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tt.validProof, result) + } + }) + } +}