Skip to content

Commit

Permalink
Improve trie test coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
iand committed Jul 5, 2023
1 parent 9adb356 commit 8ea6166
Show file tree
Hide file tree
Showing 5 changed files with 346 additions and 155 deletions.
53 changes: 53 additions & 0 deletions key/keyutil/keyutil.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
package keyutil

import (
"crypto/rand"
"encoding/binary"
"strconv"

"github.com/plprobelab/go-kademlia/key"
)

// Random returns a KadKey of length l populated with random data.
func Random(l int) key.KadKey {
buf := make([]byte, l)
if _, err := rand.Read(buf); err != nil {
panic("RandomWithPrefix: failed to read enough entropy for key")
}
return buf
}

// RandomWithPrefix returns a KadKey of length l having a prefix equal to the bit pattern held in s.
func RandomWithPrefix(s string, l int) key.KadKey {
kk := Random(l)
if s == "" {
return kk
}

bits := len(s)
if bits > 64 {
panic("RandomWithPrefix: prefix too long")
}
n, err := strconv.ParseInt(s, 2, 64)
if err != nil {
panic("RandomWithPrefix: " + err.Error())
}
prefix := uint64(n) << (64 - bits)

size := l
if size < 8 {
size = 8
}

buf := make([]byte, size)
if _, err := rand.Read(buf); err != nil {
panic("RandomWithPrefix: failed to read enough entropy for key")
}

lead := binary.BigEndian.Uint64(buf)
lead <<= bits
lead >>= bits
lead |= prefix
binary.BigEndian.PutUint64(buf, lead)
return key.KadKey(buf[:l])
}
2 changes: 2 additions & 0 deletions key/trie/README.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
# Trie

This package contains an implementation of a XOR Trie.

The trie is to be treated as immutable. Mutator functions such as Add and Remove return copies of the trie.
108 changes: 53 additions & 55 deletions key/trie/trie.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,42 +25,21 @@ func New[T any]() *Trie[T] {
return &Trie[T]{}
}

// Depth returns the maximum depth of the Trie.
func (tr *Trie[T]) Depth() int {
return tr.DepthAtDepth(0)
}

// Depth returns the maximum depth at or beyond depth d.
func (tr *Trie[T]) DepthAtDepth(d int) int {
if tr.IsLeaf() {
return d
} else {
return max(tr.Branch[0].DepthAtDepth(d+1), tr.Branch[1].DepthAtDepth(d+1))
}
}

func max(x, y int) int {
if x > y {
return x
}
return y
}

// Size returns the number of keys added to the trie.
func (tr *Trie[T]) Size() int {
return tr.SizeAtDepth(0)
return tr.sizeAtDepth(0)
}

// Size returns the number of keys added to the trie at or beyond depth d.
func (tr *Trie[T]) SizeAtDepth(d int) int {
func (tr *Trie[T]) sizeAtDepth(d int) int {
if tr.IsLeaf() {
if !tr.HasKey() {
return 0
} else {
return 1
}
} else {
return tr.Branch[0].SizeAtDepth(d+1) + tr.Branch[1].SizeAtDepth(d+1)
return tr.Branch[0].sizeAtDepth(d+1) + tr.Branch[1].sizeAtDepth(d+1)
}
}

Expand All @@ -74,10 +53,12 @@ func (tr *Trie[T]) IsLeaf() bool {
return tr.Branch[0] == nil && tr.Branch[1] == nil
}

// IsEmptyLeaf reports whether the Trie is a leaf node without branches that also has no key.
func (tr *Trie[T]) IsEmptyLeaf() bool {
return !tr.HasKey() && tr.IsLeaf()
}

// IsEmptyLeaf reports whether the Trie is a leaf node without branches but has a key.
func (tr *Trie[T]) IsNonEmptyLeaf() bool {
return tr.HasKey() && tr.IsLeaf()
}
Expand Down Expand Up @@ -107,44 +88,57 @@ func (tr *Trie[T]) shrink() {
}
}

func (tr *Trie[T]) firstNonEmptyLeaf() *Trie[T] {
if tr.IsLeaf() {
if tr.HasKey() {
return tr
}
return nil
}
f := tr.Branch[0].firstNonEmptyLeaf()
if f != nil {
return f
}
return tr.Branch[1].firstNonEmptyLeaf()
}

// Add adds the key to trie, returning a new trie.
// Add is immutable/non-destructive: The original trie remains unchanged.
// Add is immutable/non-destructive: the original trie remains unchanged.
func Add[T any](tr *Trie[T], kk key.KadKey, data T) (*Trie[T], error) {
return AddAtDepth(0, tr, kk, data)
f := tr.firstNonEmptyLeaf()
if f != nil {
if f.Key.Size() != kk.Size() {
return tr, ErrMismatchedKeyLength
}
}
return addAtDepth(0, tr, kk, data), nil
}

func AddAtDepth[T any](depth int, tr *Trie[T], kk key.KadKey, data T) (*Trie[T], error) {
func addAtDepth[T any](depth int, tr *Trie[T], kk key.KadKey, data T) *Trie[T] {
switch {
case tr.IsEmptyLeaf():
return &Trie[T]{Key: kk, Data: data}, nil
return &Trie[T]{Key: kk, Data: data}
case tr.IsNonEmptyLeaf():
if tr.Key.Size() != kk.Size() {
return nil, ErrMismatchedKeyLength
}
eq, _ := tr.Key.Equal(kk)
if eq {
return tr, nil
return tr
}
return trieForTwo[T](depth, tr.Key, tr.Data, kk, data), nil
return trieForTwo(depth, tr.Key, tr.Data, kk, data)

default:
dir := kk.BitAt(depth)
s := &Trie[T]{}
b, err := AddAtDepth(depth+1, tr.Branch[dir], kk, data)
if err != nil {
return nil, err
}
s.Branch[dir] = b
s.Branch[dir] = addAtDepth(depth+1, tr.Branch[dir], kk, data)
s.Branch[1-dir] = tr.Branch[1-dir]
return s, nil
return s
}
}

func trieForTwo[T any](depth int, p key.KadKey, pdata T, q key.KadKey, qdata T) *Trie[T] {
pDir, qDir := p.BitAt(depth), q.BitAt(depth)
if qDir == pDir {
s := &Trie[T]{}
s.Branch[pDir] = trieForTwo[T](depth+1, p, pdata, q, qdata)
s.Branch[pDir] = trieForTwo(depth+1, p, pdata, q, qdata)
s.Branch[1-pDir] = &Trie[T]{}
return s
} else {
Expand All @@ -155,40 +149,44 @@ func trieForTwo[T any](depth int, p key.KadKey, pdata T, q key.KadKey, qdata T)
}
}

// Remove is immutable/non-destructive: The original trie remains unchanged.
func Remove[T any](tr *Trie[T], q key.KadKey) (*Trie[T], error) {
return RemoveAtDepth(0, tr, q)
// Remove removes the key from the trie.
// Remove is immutable/non-destructive: the original trie remains unchanged.
// If the key did not exist in the trie then the original trie is returned.
func Remove[T any](tr *Trie[T], kk key.KadKey) (*Trie[T], error) {
f := tr.firstNonEmptyLeaf()
if f != nil {
if f.Key.Size() != kk.Size() {
return tr, ErrMismatchedKeyLength
}
}
return removeAtDepth(0, tr, kk), nil
}

func RemoveAtDepth[T any](depth int, tr *Trie[T], kk key.KadKey) (*Trie[T], error) {
func removeAtDepth[T any](depth int, tr *Trie[T], kk key.KadKey) *Trie[T] {
switch {
case tr.IsEmptyLeaf():
return tr, nil
return tr
case tr.IsNonEmptyLeaf():
if tr.Key.Size() != kk.Size() {
return nil, ErrMismatchedKeyLength
return nil
}
eq, _ := tr.Key.Equal(kk)
if !eq {
return tr, nil
return tr
}
return &Trie[T]{}, nil
return &Trie[T]{}

default:
dir := kk.BitAt(depth)
b, err := RemoveAtDepth(depth+1, tr.Branch[dir], kk)
if err != nil {
return nil, err
}
afterDelete := b
afterDelete := removeAtDepth(depth+1, tr.Branch[dir], kk)
if afterDelete == tr.Branch[dir] {
return tr, nil
return tr
}
copy := &Trie[T]{}
copy.Branch[dir] = afterDelete
copy.Branch[1-dir] = tr.Branch[1-dir]
copy.shrink()
return copy, nil
return copy
}
}

Expand Down
Loading

0 comments on commit 8ea6166

Please sign in to comment.