Skip to content

Commit

Permalink
Refactor ProofToPath and BuildTrie
Browse files Browse the repository at this point in the history
verify proof passes all test cases

range proof but has issues

prove non-membership

add non-membership tests

doesnt work

some fixes

more tests

doesn't work fully

pass 4 key trie D

pass trie c and d

pass all 4 keys

tidy up tests

fix shiftright bug

pass 9 keys non-existent

proof to path pass all test cases

tidy up tests

storageNodeSet's Put will merge with existing one
  • Loading branch information
weiihann committed Nov 26, 2024
1 parent b37ef1a commit 7ec3c1a
Show file tree
Hide file tree
Showing 8 changed files with 1,464 additions and 1,577 deletions.
94 changes: 55 additions & 39 deletions core/trie/key.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,14 @@ package trie
import (
"bytes"
"encoding/hex"
"errors"
"fmt"
"math/big"

"github.com/NethermindEth/juno/core/felt"
)

var NilKey = &Key{len: 0, bitset: [32]byte{}}

type Key struct {
len uint8
bitset [32]byte
Expand All @@ -25,23 +26,7 @@ func NewKey(length uint8, keyBytes []byte) Key {
}

func (k *Key) SubKey(n uint8) (*Key, error) {
if n > k.len {
return nil, errors.New(fmt.Sprint("cannot subtract key of length %i from key of length %i", n, k.len))
}

newKey := &Key{len: n}
copy(newKey.bitset[:], k.bitset[len(k.bitset)-int((k.len+7)/8):]) //nolint:mnd

// Shift right by the number of bits that are not needed
shift := k.len - n
for i := len(newKey.bitset) - 1; i >= 0; i-- {
newKey.bitset[i] >>= shift
if i > 0 {
newKey.bitset[i] |= newKey.bitset[i-1] << (8 - shift)
}
}

return newKey, nil
panic("TODO(weiihann): not used")

Check warning on line 29 in core/trie/key.go

View check run for this annotation

Codecov / codecov/patch

core/trie/key.go#L29

Added line #L29 was not covered by tests
}

func (k *Key) bytesNeeded() uint {
Expand Down Expand Up @@ -96,31 +81,47 @@ func (k *Key) Equal(other *Key) bool {
return k.len == other.len && k.bitset == other.bitset
}

func (k *Key) Test(bit uint8) bool {
// IsBitSet returns whether the bit at the given position is 1.
// Position 0 represents the least significant (rightmost) bit.
func (k *Key) IsBitSet(position uint8) bool {
const LSB = uint8(0x1)
byteIdx := bit / 8
byteIdx := position / 8
byteAtIdx := k.bitset[len(k.bitset)-int(byteIdx)-1]
bitIdx := bit % 8
bitIdx := position % 8
return ((byteAtIdx >> bitIdx) & LSB) != 0
}

func (k *Key) String() string {
return fmt.Sprintf("(%d) %s", k.len, hex.EncodeToString(k.bitset[:]))
}

// DeleteLSB right shifts and shortens the key
func (k *Key) DeleteLSB(n uint8) {
// ShiftRight removes n least significant bits from the key by performing a right shift
// operation and reducing the key length. For example, if the key contains bits
// "1111 0000" (length=8) and n=4, the result will be "1111" (length=4).
//
// The operation is destructive - it modifies the key in place.
func (k *Key) ShiftRight(n uint8) {
if k.len < n {
panic("deleting more bits than there are")
}

if n == 0 {
return
}

var bigInt big.Int
bigInt.SetBytes(k.bitset[:])
bigInt.Rsh(&bigInt, uint(n))
bigInt.FillBytes(k.bitset[:])
k.len -= n
}

func (k *Key) MostSignificantBits(n uint8) (*Key, error) {
if n > k.len {
return nil, fmt.Errorf("cannot get more bits than the key length")
}

Check warning on line 118 in core/trie/key.go

View check run for this annotation

Codecov / codecov/patch

core/trie/key.go#L117-L118

Added lines #L117 - L118 were not covered by tests

keyCopy := k.Copy()
keyCopy.ShiftRight(k.len - n)
return &keyCopy, nil
}

// Truncate truncates key to `length` bits by clearing the remaining upper bits
func (k *Key) Truncate(length uint8) {
k.len = length
Expand All @@ -136,20 +137,35 @@ func (k *Key) Truncate(length uint8) {
}
}

func (k *Key) RemoveLastBit() {
if k.len == 0 {
return
}
func (k *Key) String() string {
return fmt.Sprintf("(%d) %s", k.len, hex.EncodeToString(k.bitset[:]))

Check warning on line 141 in core/trie/key.go

View check run for this annotation

Codecov / codecov/patch

core/trie/key.go#L140-L141

Added lines #L140 - L141 were not covered by tests
}

k.len--
// Copy returns a deep copy of the key
func (k *Key) Copy() Key {
newKey := Key{len: k.len}
copy(newKey.bitset[:], k.bitset[:])
return newKey
}

unusedBytes := k.unusedBytes()
clear(unusedBytes)
// findCommonKey finds the set of common MSB bits in two key bitsets.
func findCommonKey(longerKey, shorterKey *Key) (Key, bool) {
divergentBit := findDivergentBit(longerKey, shorterKey)
commonKey := *shorterKey
commonKey.ShiftRight(shorterKey.Len() - divergentBit + 1)
return commonKey, divergentBit == shorterKey.Len()+1
}

// clear upper bits on the last used byte
inUseBytes := k.inUseBytes()
unusedBitsCount := 8 - (k.len % 8)
if unusedBitsCount != 8 && len(inUseBytes) > 0 {
inUseBytes[0] = (inUseBytes[0] << unusedBitsCount) >> unusedBitsCount
func findDivergentBit(longerKey, shorterKey *Key) uint8 {
divergentBit := uint8(0)
for divergentBit <= shorterKey.Len() &&
longerKey.IsBitSet(longerKey.Len()-divergentBit) == shorterKey.IsBitSet(shorterKey.Len()-divergentBit) {
divergentBit++
}
return divergentBit
}

func isSubset(longerKey, shorterKey *Key) bool {
divergentBit := findDivergentBit(longerKey, shorterKey)
return divergentBit == shorterKey.Len()+1
}
118 changes: 70 additions & 48 deletions core/trie/key_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,11 +71,74 @@ func BenchmarkKeyEncoding(b *testing.B) {
func TestKeyTest(t *testing.T) {
key := trie.NewKey(44, []byte{0x10, 0x02})
for i := 0; i < int(key.Len()); i++ {
assert.Equal(t, i == 1 || i == 12, key.Test(uint8(i)), i)
assert.Equal(t, i == 1 || i == 12, key.IsBitSet(uint8(i)), i)
}
}

func TestDeleteLSB(t *testing.T) {
func TestIsBitSet(t *testing.T) {
tests := map[string]struct {
key trie.Key
position uint8
expected bool
}{
"single byte, LSB set": {
key: trie.NewKey(8, []byte{0x01}),
position: 0,
expected: true,
},
"single byte, MSB set": {
key: trie.NewKey(8, []byte{0x80}),
position: 7,
expected: true,
},
"single byte, middle bit set": {
key: trie.NewKey(8, []byte{0x10}),
position: 4,
expected: true,
},
"single byte, bit not set": {
key: trie.NewKey(8, []byte{0xFE}),
position: 0,
expected: false,
},
"multiple bytes, LSB set": {
key: trie.NewKey(16, []byte{0x00, 0x02}),
position: 1,
expected: true,
},
"multiple bytes, MSB set": {
key: trie.NewKey(16, []byte{0x01, 0x00}),
position: 8,
expected: true,
},
"multiple bytes, no bits set": {
key: trie.NewKey(16, []byte{0x00, 0x00}),
position: 7,
expected: false,
},
"check all bits in pattern": {
key: trie.NewKey(8, []byte{0xA5}), // 10100101
position: 0,
expected: true,
},
}

// Additional test for 0xA5 pattern
key := trie.NewKey(8, []byte{0xA5}) // 10100101
expectedBits := []bool{true, false, true, false, false, true, false, true}
for i, expected := range expectedBits {
assert.Equal(t, expected, key.IsBitSet(uint8(i)), "bit %d in 0xA5", i)
}

for name, tc := range tests {
t.Run(name, func(t *testing.T) {
result := tc.key.IsBitSet(tc.position)
assert.Equal(t, tc.expected, result)
})
}
}

func TestShiftRight(t *testing.T) {
key := trie.NewKey(16, []byte{0xF3, 0x04})

tests := map[string]struct {
Expand All @@ -98,57 +161,16 @@ func TestDeleteLSB(t *testing.T) {
shiftAmount: 9,
expectedKey: trie.NewKey(7, []byte{0x79}),
},
}

for desc, test := range tests {
t.Run(desc, func(t *testing.T) {
copyKey := key
copyKey.DeleteLSB(test.shiftAmount)
assert.Equal(t, test.expectedKey, copyKey)
})
}
}

func TestTruncate(t *testing.T) {
tests := map[string]struct {
key trie.Key
newLen uint8
expectedKey trie.Key
}{
"truncate to 12 bits": {
key: trie.NewKey(16, []byte{0xF3, 0x14}),
newLen: 12,
expectedKey: trie.NewKey(12, []byte{0x03, 0x14}),
},
"truncate to 9 bits": {
key: trie.NewKey(16, []byte{0xF3, 0x14}),
newLen: 9,
expectedKey: trie.NewKey(9, []byte{0x01, 0x14}),
},
"truncate to 3 bits": {
key: trie.NewKey(16, []byte{0xF3, 0x14}),
newLen: 3,
expectedKey: trie.NewKey(3, []byte{0x04}),
},
"truncate to multiple of 8": {
key: trie.NewKey(251, []uint8{
0x7, 0x40, 0x33, 0x8c, 0xbc, 0x9, 0xeb, 0xf, 0xb7, 0xab,
0xc5, 0x20, 0x35, 0xc6, 0x4d, 0x4e, 0xa5, 0x78, 0x18, 0x9e, 0xd6, 0x37, 0x47, 0x91, 0xd0,
0x6e, 0x44, 0x1e, 0xf7, 0x7f, 0xf, 0x5f,
}),
newLen: 248,
expectedKey: trie.NewKey(248, []uint8{
0x0, 0x40, 0x33, 0x8c, 0xbc, 0x9, 0xeb, 0xf, 0xb7, 0xab,
0xc5, 0x20, 0x35, 0xc6, 0x4d, 0x4e, 0xa5, 0x78, 0x18, 0x9e, 0xd6, 0x37, 0x47, 0x91, 0xd0,
0x6e, 0x44, 0x1e, 0xf7, 0x7f, 0xf, 0x5f,
}),
"delete all bits": {
shiftAmount: 16,
expectedKey: trie.NewKey(0, []byte{}),
},
}

for desc, test := range tests {
t.Run(desc, func(t *testing.T) {
copyKey := test.key
copyKey.Truncate(test.newLen)
copyKey := key
copyKey.ShiftRight(test.shiftAmount)
assert.Equal(t, test.expectedKey, copyKey)
})
}
Expand Down
54 changes: 54 additions & 0 deletions core/trie/node.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package trie
import (
"bytes"
"errors"
"fmt"

"github.com/NethermindEth/juno/core/felt"
)
Expand Down Expand Up @@ -138,3 +139,56 @@ func (n *Node) UnmarshalBinary(data []byte) error {
n.RightHash.SetBytes(data[:felt.Bytes])
return nil
}

func (n *Node) String() string {
return fmt.Sprintf("Node{Value: %s, Left: %s, Right: %s, LeftHash: %s, RightHash: %s}", n.Value, n.Left, n.Right, n.LeftHash, n.RightHash)

Check warning on line 144 in core/trie/node.go

View check run for this annotation

Codecov / codecov/patch

core/trie/node.go#L143-L144

Added lines #L143 - L144 were not covered by tests
}

func (n *Node) Merge(other *Node) error {

Check failure on line 147 in core/trie/node.go

View workflow job for this annotation

GitHub Actions / lint

cyclomatic complexity 21 of func `(*Node).Merge` is high (> 15) (gocyclo)
// Compare Value if both exist
if n.Value != nil && other.Value != nil {
if !n.Value.Equal(other.Value) {
return fmt.Errorf("conflicting Values: %v != %v", n.Value, other.Value)
}
} else if other.Value != nil {
n.Value = other.Value
}

Check warning on line 155 in core/trie/node.go

View check run for this annotation

Codecov / codecov/patch

core/trie/node.go#L147-L155

Added lines #L147 - L155 were not covered by tests

// Compare Left if both exist
if n.Left != nil && other.Left != nil {
if !n.Left.Equal(other.Left) {
return fmt.Errorf("conflicting Left keys: %v != %v", n.Left, other.Left)
}
} else if other.Left != nil {
n.Left = other.Left
}

Check warning on line 164 in core/trie/node.go

View check run for this annotation

Codecov / codecov/patch

core/trie/node.go#L158-L164

Added lines #L158 - L164 were not covered by tests

// Compare Right if both exist
if n.Right != nil && other.Right != nil {
if !n.Right.Equal(other.Right) {
return fmt.Errorf("conflicting Right keys: %v != %v", n.Right, other.Right)
}
} else if other.Right != nil {
n.Right = other.Right
}

Check warning on line 173 in core/trie/node.go

View check run for this annotation

Codecov / codecov/patch

core/trie/node.go#L167-L173

Added lines #L167 - L173 were not covered by tests

// Compare LeftHash if both exist
if n.LeftHash != nil && other.LeftHash != nil {
if !n.LeftHash.Equal(other.LeftHash) {
return fmt.Errorf("conflicting LeftHash: %v != %v", n.LeftHash, other.LeftHash)
}
} else if other.LeftHash != nil {
n.LeftHash = other.LeftHash
}

Check warning on line 182 in core/trie/node.go

View check run for this annotation

Codecov / codecov/patch

core/trie/node.go#L176-L182

Added lines #L176 - L182 were not covered by tests

// Compare RightHash if both exist
if n.RightHash != nil && other.RightHash != nil {
if !n.RightHash.Equal(other.RightHash) {
return fmt.Errorf("conflicting RightHash: %v != %v", n.RightHash, other.RightHash)
}
} else if other.RightHash != nil {
n.RightHash = other.RightHash
}

Check warning on line 191 in core/trie/node.go

View check run for this annotation

Codecov / codecov/patch

core/trie/node.go#L185-L191

Added lines #L185 - L191 were not covered by tests

return nil

Check warning on line 193 in core/trie/node.go

View check run for this annotation

Codecov / codecov/patch

core/trie/node.go#L193

Added line #L193 was not covered by tests
}
Loading

0 comments on commit 7ec3c1a

Please sign in to comment.