diff --git a/go.mod b/go.mod index 8db03b5..36862ed 100644 --- a/go.mod +++ b/go.mod @@ -1,6 +1,6 @@ module github.com/libp2p/go-libp2p-xor -go 1.17 +go 1.20 require github.com/libp2p/go-libp2p-kbucket v0.3.1 diff --git a/kademlia/bucket.go b/kademlia/bucket.go index 8f4180f..fa69e97 100644 --- a/kademlia/bucket.go +++ b/kademlia/bucket.go @@ -7,7 +7,7 @@ import ( // BucketAtDepth returns the bucket in the routing table at a given depth. // A bucket at depth D holds contacts that share a prefix of exactly D bits with node. -func BucketAtDepth(node key.Key, table *trie.Trie, depth int) *trie.Trie { +func BucketAtDepth[T any](node key.Key, table *trie.Trie[T], depth int) *trie.Trie[T] { dir := node.BitAt(depth) if table.IsLeaf() { return nil @@ -21,11 +21,11 @@ func BucketAtDepth(node key.Key, table *trie.Trie, depth int) *trie.Trie { } // ClosestN will return the count closest keys to the given key. -func ClosestN(node key.Key, table *trie.Trie, count int) []key.Key { +func ClosestN[T any](node key.Key, table *trie.Trie[T], count int) []key.Key { return closestAtDepth(node, table, 0, count, make([]key.Key, 0, count)) } -func closestAtDepth(node key.Key, table *trie.Trie, depth int, count int, found []key.Key) []key.Key { +func closestAtDepth[T any](node key.Key, table *trie.Trie[T], depth int, count int, found []key.Key) []key.Key { // If we've already found enough peers, abort. if count == len(found) { return found diff --git a/kademlia/bucket_test.go b/kademlia/bucket_test.go index a078568..6940563 100644 --- a/kademlia/bucket_test.go +++ b/kademlia/bucket_test.go @@ -8,8 +8,8 @@ import ( "github.com/libp2p/go-libp2p-xor/trie" ) -func randomTrie(count int, keySizeByte int) *trie.Trie { - t := trie.New() +func randomTrie[T any](count int, keySizeByte int) *trie.Trie[T] { + t := trie.New[T]() for i := 0; i < count; i++ { t.Add(randomKey(keySizeByte)) } @@ -18,7 +18,7 @@ func randomTrie(count int, keySizeByte int) *trie.Trie { func TestClosestN(t *testing.T) { keySizeByte := 16 - root := randomTrie(100, keySizeByte) + root := randomTrie[any](100, keySizeByte) all := root.List() for count := 0; count <= 100; count += 10 { target := randomKey(keySizeByte) @@ -50,7 +50,7 @@ var _x int func BenchmarkClosestN(b *testing.B) { keySizeByte := 16 - root := randomTrie(100000, keySizeByte) + root := randomTrie[any](100000, keySizeByte) count := 20 target := randomKey(keySizeByte) b.ResetTimer() @@ -61,7 +61,7 @@ func BenchmarkClosestN(b *testing.B) { func BenchmarkClosestTrivial(b *testing.B) { keySizeByte := 16 - root := randomTrie(100000, keySizeByte) + root := randomTrie[any](100000, keySizeByte) keys := root.List() count := 20 target := randomKey(keySizeByte) diff --git a/kademlia/health.go b/kademlia/health.go index e524c2c..0c16c8f 100644 --- a/kademlia/health.go +++ b/kademlia/health.go @@ -70,7 +70,7 @@ type Table struct { // AllTablesHealth computes health reports for a network of nodes, whose routing contacts are given. func AllTablesHealth(tables []*Table) (report []*TableHealthReport) { // Construct global network view trie - knownNodes := trie.New() + knownNodes := trie.New[any]() for _, table := range tables { knownNodes.Add(table.Node) } @@ -82,7 +82,7 @@ func AllTablesHealth(tables []*Table) (report []*TableHealthReport) { } func TableHealthFromSets(node key.Key, nodeContacts []key.Key, knownNodes []key.Key) *TableHealthReport { - knownNodesTrie := trie.New() + knownNodesTrie := trie.New[any]() for _, k := range knownNodes { knownNodesTrie.Add(k) } @@ -91,9 +91,9 @@ func TableHealthFromSets(node key.Key, nodeContacts []key.Key, knownNodes []key. // TableHealth computes the health report for a node, // given its routing contacts and a list of all known nodes in the network currently. -func TableHealth(node key.Key, nodeContacts []key.Key, knownNodes *trie.Trie) *TableHealthReport { +func TableHealth[T any](node key.Key, nodeContacts []key.Key, knownNodes *trie.Trie[T]) *TableHealthReport { // Reconstruct the node's routing table as a trie - nodeTable := trie.New() + nodeTable := trie.New[T]() nodeTable.Add(node) for _, u := range nodeContacts { nodeTable.Add(u) @@ -110,13 +110,13 @@ func TableHealth(node key.Key, nodeContacts []key.Key, knownNodes *trie.Trie) *T // BucketHealth computes the health report for each bucket in a node's routing table, // given the node's routing table and a list of all known nodes in the network currently. -func BucketHealth(node key.Key, nodeTable, knownNodes *trie.Trie) []*BucketHealthReport { +func BucketHealth[T any](node key.Key, nodeTable, knownNodes *trie.Trie[T]) []*BucketHealthReport { r := walkBucketHealth(0, node, nodeTable, knownNodes) sort.Sort(sortedBucketHealthReport(r)) return r } -func walkBucketHealth(depth int, node key.Key, nodeTable, knownNodes *trie.Trie) []*BucketHealthReport { +func walkBucketHealth[T any](depth int, node key.Key, nodeTable, knownNodes *trie.Trie[T]) []*BucketHealthReport { if nodeTable.IsLeaf() { return nil } else { @@ -156,7 +156,7 @@ func walkBucketHealth(depth int, node key.Key, nodeTable, knownNodes *trie.Trie) } } -func bucketReportFromTries(depth int, actualBucket, maxBucket *trie.Trie) *BucketHealthReport { +func bucketReportFromTries[T any](depth int, actualBucket, maxBucket *trie.Trie[T]) *BucketHealthReport { actualKnown := trie.IntersectAtDepth(depth, actualBucket, maxBucket) actualKnownSize := actualKnown.Size() return &BucketHealthReport{ diff --git a/trie/add.go b/trie/add.go index a6933e7..f314c7d 100644 --- a/trie/add.go +++ b/trie/add.go @@ -5,11 +5,11 @@ import ( ) // Add adds the key q to the trie. Add mutates the trie. -func (trie *Trie) Add(q key.Key) (insertedDepth int, insertedOK bool) { +func (trie *Trie[T]) Add(q key.Key) (insertedDepth int, insertedOK bool) { return trie.AddAtDepth(0, q) } -func (trie *Trie) AddAtDepth(depth int, q key.Key) (insertedDepth int, insertedOK bool) { +func (trie *Trie[T]) AddAtDepth(depth int, q key.Key) (insertedDepth int, insertedOK bool) { switch { case trie.IsEmptyLeaf(): trie.Key = q @@ -21,7 +21,7 @@ func (trie *Trie) AddAtDepth(depth int, q key.Key) (insertedDepth int, insertedO p := trie.Key trie.Key = nil // both branches are nil - trie.Branch[0], trie.Branch[1] = &Trie{}, &Trie{} + trie.Branch[0], trie.Branch[1] = &Trie[T]{}, &Trie[T]{} trie.Branch[p.BitAt(depth)].Key = p return trie.Branch[q.BitAt(depth)].AddAtDepth(depth+1, q) } @@ -32,40 +32,40 @@ func (trie *Trie) AddAtDepth(depth int, q key.Key) (insertedDepth int, insertedO // Add adds the key q to trie, returning a new trie. // Add is immutable/non-destructive: The original trie remains unchanged. -func Add(trie *Trie, q key.Key) *Trie { +func Add[T any](trie *Trie[T], q key.Key) *Trie[T] { return AddAtDepth(0, trie, q) } -func AddAtDepth(depth int, trie *Trie, q key.Key) *Trie { +func AddAtDepth[T any](depth int, trie *Trie[T], q key.Key) *Trie[T] { switch { case trie.IsEmptyLeaf(): - return &Trie{Key: q} + return &Trie[T]{Key: q} case trie.IsNonEmptyLeaf(): if key.Equal(trie.Key, q) { return trie } else { - return trieForTwo(depth, trie.Key, q) + return trieForTwo[T](depth, trie.Key, q) } default: dir := q.BitAt(depth) - s := &Trie{} + s := &Trie[T]{} s.Branch[dir] = AddAtDepth(depth+1, trie.Branch[dir], q) s.Branch[1-dir] = trie.Branch[1-dir] return s } } -func trieForTwo(depth int, p, q key.Key) *Trie { +func trieForTwo[T any](depth int, p, q key.Key) *Trie[T] { pDir, qDir := p.BitAt(depth), q.BitAt(depth) if qDir == pDir { - s := &Trie{} - s.Branch[pDir] = trieForTwo(depth+1, p, q) - s.Branch[1-pDir] = &Trie{} + s := &Trie[T]{} + s.Branch[pDir] = trieForTwo[T](depth+1, p, q) + s.Branch[1-pDir] = &Trie[T]{} return s } else { - s := &Trie{} - s.Branch[pDir] = &Trie{Key: p} - s.Branch[qDir] = &Trie{Key: q} + s := &Trie[T]{} + s.Branch[pDir] = &Trie[T]{Key: p} + s.Branch[qDir] = &Trie[T]{Key: q} return s } } diff --git a/trie/add_test.go b/trie/add_test.go index bf514bb..1bfbba9 100644 --- a/trie/add_test.go +++ b/trie/add_test.go @@ -10,8 +10,8 @@ import ( // Verify mutable and immutable add do the same thing. func TestMutableAndImmutableAddSame(t *testing.T) { for _, s := range append(testAddSamples, randomTestAddSamples(100)...) { - mut := New() - immut := New() + mut := New[any]() + immut := New[any]() for _, k := range s.Keys { mut.Add(k) immut = Add(immut, k) @@ -30,7 +30,7 @@ func TestMutableAndImmutableAddSame(t *testing.T) { func TestAddIsOrderIndependent(t *testing.T) { for _, s := range append(testAddSamples, randomTestAddSamples(100)...) { - base := New() + base := New[any]() for _, k := range s.Keys { base.Add(k) } @@ -39,7 +39,7 @@ func TestAddIsOrderIndependent(t *testing.T) { } for j := 0; j < 100; j++ { perm := rand.Perm(len(s.Keys)) - reordered := New() + reordered := New[any]() for i := range s.Keys { reordered.Add(s.Keys[perm[i]]) } diff --git a/trie/check.go b/trie/check.go index 04c5482..73c3456 100644 --- a/trie/check.go +++ b/trie/check.go @@ -11,11 +11,11 @@ type InvariantDiscrepancy struct { } // CheckInvariant panics of the trie does not meet its invariant. -func (trie *Trie) CheckInvariant() *InvariantDiscrepancy { +func (trie *Trie[T]) CheckInvariant() *InvariantDiscrepancy { return trie.checkInvariant(0, nil) } -func (trie *Trie) checkInvariant(depth int, pathSoFar *triePath) *InvariantDiscrepancy { +func (trie *Trie[T]) checkInvariant(depth int, pathSoFar *triePath) *InvariantDiscrepancy { switch { case trie.IsEmptyLeaf(): return nil diff --git a/trie/equal.go b/trie/equal.go index a9919c9..f75fb48 100644 --- a/trie/equal.go +++ b/trie/equal.go @@ -4,7 +4,7 @@ import ( "github.com/libp2p/go-libp2p-xor/key" ) -func Equal(p, q *Trie) bool { +func Equal[T any](p, q *Trie[T]) bool { switch { case p.IsLeaf() && q.IsLeaf(): return key.Equal(p.Key, q.Key) diff --git a/trie/find.go b/trie/find.go index d96375e..970f48c 100644 --- a/trie/find.go +++ b/trie/find.go @@ -7,11 +7,11 @@ import ( // Find looks for the key q in the trie. // It returns the depth of the leaf reached along the path of q, regardless of whether q was found in that leaf. // It also returns a boolean flag indicating whether the key was found. -func (trie *Trie) Find(q key.Key) (reachedDepth int, found bool) { +func (trie *Trie[T]) Find(q key.Key) (reachedDepth int, found bool) { return trie.FindAtDepth(0, q) } -func (trie *Trie) FindAtDepth(depth int, q key.Key) (reachedDepth int, found bool) { +func (trie *Trie[T]) FindAtDepth(depth int, q key.Key) (reachedDepth int, found bool) { switch { case trie.IsEmptyLeaf(): return depth, false diff --git a/trie/intersect.go b/trie/intersect.go index c30ab7d..9aaf134 100644 --- a/trie/intersect.go +++ b/trie/intersect.go @@ -25,37 +25,37 @@ func keyIsIn(q key.Key, s []key.Key) bool { // Intersect computes the intersection of the keys in p and q. // p and q must be non-nil. The returned trie is never nil. -func Intersect(p, q *Trie) *Trie { +func Intersect[T any](p, q *Trie[T]) *Trie[T] { return IntersectAtDepth(0, p, q) } -func IntersectAtDepth(depth int, p, q *Trie) *Trie { +func IntersectAtDepth[T any](depth int, p, q *Trie[T]) *Trie[T] { switch { case p.IsLeaf() && q.IsLeaf(): if p.IsEmpty() || q.IsEmpty() { - return &Trie{} // empty set + return &Trie[T]{} // empty set } else { if key.Equal(p.Key, q.Key) { - return &Trie{Key: p.Key} // singleton + return &Trie[T]{Key: p.Key} // singleton } else { - return &Trie{} // empty set + return &Trie[T]{} // empty set } } case p.IsLeaf() && !q.IsLeaf(): if p.IsEmpty() { - return &Trie{} // empty set + return &Trie[T]{} // empty set } else { if _, found := q.FindAtDepth(depth, p.Key); found { - return &Trie{Key: p.Key} + return &Trie[T]{Key: p.Key} } else { - return &Trie{} // empty set + return &Trie[T]{} // empty set } } case !p.IsLeaf() && q.IsLeaf(): return IntersectAtDepth(depth, q, p) case !p.IsLeaf() && !q.IsLeaf(): - disjointUnion := &Trie{ - Branch: [2]*Trie{ + disjointUnion := &Trie[T]{ + Branch: [2]*Trie[T]{ IntersectAtDepth(depth+1, p.Branch[0], q.Branch[0]), IntersectAtDepth(depth+1, p.Branch[1], q.Branch[1]), }, diff --git a/trie/intersect_test.go b/trie/intersect_test.go index 29d7265..088cbf4 100644 --- a/trie/intersect_test.go +++ b/trie/intersect_test.go @@ -28,7 +28,7 @@ func TestIntersectFromJSON(t *testing.T) { } func testIntersect(t *testing.T, sample *testSetSample) { - left, right, expected := New(), New(), New() + left, right, expected := New[any](), New[any](), New[any]() for _, l := range sample.LeftKeys { left.Add(l) } @@ -149,19 +149,19 @@ var testJSONSamples = []string{ func TestIntersectTriesFromJSON(t *testing.T) { for _, json := range testIntersectJSONTries { - s := testIntersectTrieFromJSON(json) + s := testIntersectTrieFromJSON[any](json) testIntersectTries(t, s) } } -func testIntersectTries(t *testing.T, sample *testIntersectTrie) { +func testIntersectTries[T any](t *testing.T, sample *testIntersectTrie[T]) { if d := sample.LeftTrie.CheckInvariant(); d != nil { t.Fatalf("left trie invariant discrepancy: %v", d) } if d := sample.RightTrie.CheckInvariant(); d != nil { t.Fatalf("right trie invariant discrepancy: %v", d) } - expected := New() + expected := New[T]() for _, s := range setIntersect(sample.LeftTrie.List(), sample.RightTrie.List()) { expected.Add(s) } @@ -175,13 +175,13 @@ func testIntersectTries(t *testing.T, sample *testIntersectTrie) { } } -type testIntersectTrie struct { - LeftTrie *Trie - RightTrie *Trie +type testIntersectTrie[T any] struct { + LeftTrie *Trie[T] + RightTrie *Trie[T] } -func testIntersectTrieFromJSON(srcJSON string) *testIntersectTrie { - s := &testIntersectTrie{} +func testIntersectTrieFromJSON[T any](srcJSON string) *testIntersectTrie[T] { + s := &testIntersectTrie[T]{} if err := json.Unmarshal([]byte(srcJSON), s); err != nil { panic(err) } diff --git a/trie/list.go b/trie/list.go index b8c355e..6c1a79d 100644 --- a/trie/list.go +++ b/trie/list.go @@ -5,7 +5,7 @@ import ( ) // List returns a list of all keys in the trie. -func (trie *Trie) List() []key.Key { +func (trie *Trie[T]) List() []key.Key { switch { case trie.IsEmptyLeaf(): return nil diff --git a/trie/remove.go b/trie/remove.go index 6290b21..b2a9d8d 100644 --- a/trie/remove.go +++ b/trie/remove.go @@ -6,11 +6,11 @@ import ( // Remove removes the key q from the trie. Remove mutates the trie. // TODO: Also implement an immutable version of Remove. -func (trie *Trie) Remove(q key.Key) (removedDepth int, removed bool) { +func (trie *Trie[T]) Remove(q key.Key) (removedDepth int, removed bool) { return trie.RemoveAtDepth(0, q) } -func (trie *Trie) RemoveAtDepth(depth int, q key.Key) (reachedDepth int, removed bool) { +func (trie *Trie[T]) RemoveAtDepth(depth int, q key.Key) (reachedDepth int, removed bool) { switch { case trie.IsEmptyLeaf(): return depth, false @@ -27,25 +27,25 @@ func (trie *Trie) RemoveAtDepth(depth int, q key.Key) (reachedDepth int, removed } } -func Remove(trie *Trie, q key.Key) *Trie { +func Remove[T any](trie *Trie[T], q key.Key) *Trie[T] { return RemoveAtDepth(0, trie, q) } -func RemoveAtDepth(depth int, trie *Trie, q key.Key) *Trie { +func RemoveAtDepth[T any](depth int, trie *Trie[T], q key.Key) *Trie[T] { switch { case trie.IsEmptyLeaf(): return trie case trie.IsNonEmptyLeaf() && !key.Equal(trie.Key, q): return trie case trie.IsNonEmptyLeaf() && key.Equal(trie.Key, q): - return &Trie{} + return &Trie[T]{} default: dir := q.BitAt(depth) afterDelete := RemoveAtDepth(depth+1, trie.Branch[dir], q) if afterDelete == trie.Branch[dir] { return trie } - copy := &Trie{} + copy := &Trie[T]{} copy.Branch[dir] = afterDelete copy.Branch[1-dir] = trie.Branch[1-dir] copy.shrink() diff --git a/trie/remove_test.go b/trie/remove_test.go index d322c78..bfe7aec 100644 --- a/trie/remove_test.go +++ b/trie/remove_test.go @@ -8,7 +8,7 @@ import ( func TestImmutableRemoveIsImmutable(t *testing.T) { for _, keySet := range testAddSamples { - trie := FromKeys(keySet.Keys) + trie := FromKeys[any](keySet.Keys) for _, key := range keySet.Keys { updated := Remove(trie, key) if Equal(trie, updated) { @@ -21,12 +21,12 @@ func TestImmutableRemoveIsImmutable(t *testing.T) { func TestMutableAndImmutableRemoveSame(t *testing.T) { for _, keySet := range append(testAddSamples, randomTestAddSamples(100)...) { - mut := FromKeys(keySet.Keys) - immut := FromKeys(keySet.Keys) + mut := FromKeys[any](keySet.Keys) + immut := FromKeys[any](keySet.Keys) - for _, key := range keySet.Keys { - mut.Remove(key) - immut = Remove(immut, key) + for _, k := range keySet.Keys { + mut.Remove(k) + immut = Remove(immut, k) if d := mut.CheckInvariant(); d != nil { t.Fatalf("mutable trie invariant discrepancy: %v", d) } @@ -42,8 +42,8 @@ func TestMutableAndImmutableRemoveSame(t *testing.T) { func TestRemoveIsOrderIndependent(t *testing.T) { for _, keySet := range append(testAddSamples, randomTestAddSamples(100)...) { - mut := FromKeys(keySet.Keys) - immut := FromKeys(keySet.Keys) + mut := FromKeys[any](keySet.Keys) + immut := FromKeys[any](keySet.Keys) for j := 0; j < 100; j++ { perm := rand.Perm(len(keySet.Keys)) @@ -63,7 +63,7 @@ func TestRemoveIsOrderIndependent(t *testing.T) { } func TestRemoveReturnsOriginalWhenNoKeyRemoved(t *testing.T) { - trie := FromKeys(testAddSamples[0].Keys) + trie := FromKeys[any](testAddSamples[0].Keys) result := Remove(trie, key.ByteKey(2)) if trie != result { diff --git a/trie/trie.go b/trie/trie.go index f58d847..cf5d70a 100644 --- a/trie/trie.go +++ b/trie/trie.go @@ -11,41 +11,42 @@ import ( // (1) Either both branches are nil, or both are non-nil. // (2) If branches are non-nil, key must be nil. // (3) If both branches are leaves, then they are both non-empty (have keys). -type Trie struct { - Branch [2]*Trie +type Trie[T any] struct { + Branch [2]*Trie[T] Key key.Key + Data T } -func New() *Trie { - return &Trie{} +func New[T any]() *Trie[T] { + return &Trie[T]{} } -func FromKeys(k []key.Key) *Trie { - t := New() +func FromKeys[T any](k []key.Key) *Trie[T] { + t := New[T]() for _, k := range k { t.Add(k) } return t } -func FromKeysAtDepth(depth int, k []key.Key) *Trie { - t := New() +func FromKeysAtDepth[T any](depth int, k []key.Key) *Trie[T] { + t := New[T]() for _, k := range k { t.AddAtDepth(depth, k) } return t } -func (trie *Trie) String() string { +func (trie *Trie[T]) String() string { b, _ := json.Marshal(trie) return string(b) } -func (trie *Trie) Depth() int { +func (trie *Trie[T]) Depth() int { return trie.DepthAtDepth(0) } -func (trie *Trie) DepthAtDepth(depth int) int { +func (trie *Trie[T]) DepthAtDepth(depth int) int { if trie.Branch[0] == nil && trie.Branch[1] == nil { return depth } else { @@ -62,11 +63,11 @@ func max(x, y int) int { // Size returns the number of keys added to the trie. // In other words, it returns the number of non-empty leaves in the trie. -func (trie *Trie) Size() int { +func (trie *Trie[T]) Size() int { return trie.SizeAtDepth(0) } -func (trie *Trie) SizeAtDepth(depth int) int { +func (trie *Trie[T]) SizeAtDepth(depth int) int { if trie.Branch[0] == nil && trie.Branch[1] == nil { if trie.IsEmpty() { return 0 @@ -78,34 +79,34 @@ func (trie *Trie) SizeAtDepth(depth int) int { } } -func (trie *Trie) IsEmpty() bool { +func (trie *Trie[T]) IsEmpty() bool { return trie.Key == nil } -func (trie *Trie) IsLeaf() bool { +func (trie *Trie[T]) IsLeaf() bool { return trie.Branch[0] == nil && trie.Branch[1] == nil } -func (trie *Trie) IsEmptyLeaf() bool { +func (trie *Trie[T]) IsEmptyLeaf() bool { return trie.IsEmpty() && trie.IsLeaf() } -func (trie *Trie) IsNonEmptyLeaf() bool { +func (trie *Trie[T]) IsNonEmptyLeaf() bool { return !trie.IsEmpty() && trie.IsLeaf() } -func (trie *Trie) Copy() *Trie { +func (trie *Trie[T]) Copy() *Trie[T] { if trie.IsLeaf() { - return &Trie{Key: trie.Key} + return &Trie[T]{Key: trie.Key} } - return &Trie{Branch: [2]*Trie{ + return &Trie[T]{Branch: [2]*Trie[T]{ trie.Branch[0].Copy(), trie.Branch[1].Copy(), }} } -func (trie *Trie) shrink() { +func (trie *Trie[T]) shrink() { b0, b1 := trie.Branch[0], trie.Branch[1] switch { case b0.IsEmptyLeaf() && b1.IsEmptyLeaf(): diff --git a/trie/trie_test.go b/trie/trie_test.go index 743a0e2..9c6201e 100644 --- a/trie/trie_test.go +++ b/trie/trie_test.go @@ -7,12 +7,12 @@ import ( ) func TestInsertRemove(t *testing.T) { - r := New() + r := New[any]() testSeq(r, t) testSeq(r, t) } -func testSeq(r *Trie, t *testing.T) { +func testSeq[T any](r *Trie[T], t *testing.T) { for _, s := range testInsertSeq { depth, _ := r.Add(key.BytesKey(s.key)) if depth != s.insertedDepth { @@ -29,7 +29,7 @@ func testSeq(r *Trie, t *testing.T) { func TestCopy(t *testing.T) { for _, sample := range testAddSamples { - trie := FromKeys(sample.Keys) + trie := FromKeys[any](sample.Keys) copy := trie.Copy() if d := copy.CheckInvariant(); d != nil { t.Fatalf("trie invariant discrepancy: %v", d) diff --git a/trie/union.go b/trie/union.go index fbeaa27..42514bb 100644 --- a/trie/union.go +++ b/trie/union.go @@ -14,22 +14,22 @@ func UnionKeySlices(left, right []key.Key) []key.Key { return result } -func Union(left, right *Trie) *Trie { +func Union[T any](left, right *Trie[T]) *Trie[T] { return UnionAtDepth(0, left, right) } -func UnionAtDepth(depth int, left, right *Trie) *Trie { +func UnionAtDepth[T any](depth int, left, right *Trie[T]) *Trie[T] { switch { case left.IsLeaf() && right.IsLeaf(): switch { case left.IsEmpty() && right.IsEmpty(): - return &Trie{} + return &Trie[T]{} case !left.IsEmpty() && right.IsEmpty(): - return &Trie{Key: left.Key} + return &Trie[T]{Key: left.Key} case left.IsEmpty() && !right.IsEmpty(): - return &Trie{Key: right.Key} + return &Trie[T]{Key: right.Key} case !left.IsEmpty() && !right.IsEmpty(): - u := &Trie{} + u := &Trie[T]{} u.AddAtDepth(depth, left.Key) u.AddAtDepth(depth, right.Key) return u @@ -39,7 +39,7 @@ func UnionAtDepth(depth int, left, right *Trie) *Trie { case left.IsLeaf() && !right.IsLeaf(): return unionTrieAndLeaf(depth, right, left) case !left.IsLeaf() && !right.IsLeaf(): - return &Trie{Branch: [2]*Trie{ + return &Trie[T]{Branch: [2]*Trie[T]{ UnionAtDepth(depth+1, left.Branch[0], right.Branch[0]), UnionAtDepth(depth+1, left.Branch[1], right.Branch[1]), }} @@ -47,12 +47,12 @@ func UnionAtDepth(depth int, left, right *Trie) *Trie { panic("unreachable") } -func unionTrieAndLeaf(depth int, trie, leaf *Trie) *Trie { +func unionTrieAndLeaf[T any](depth int, trie, leaf *Trie[T]) *Trie[T] { if leaf.IsEmpty() { return trie.Copy() } else { dir := leaf.Key.BitAt(depth) - copy := &Trie{} + copy := &Trie[T]{} copy.Branch[dir] = UnionAtDepth(depth+1, trie.Branch[dir], leaf) copy.Branch[1-dir] = trie.Branch[1-dir].Copy() return copy diff --git a/trie/union_test.go b/trie/union_test.go index d0a152c..883c4e7 100644 --- a/trie/union_test.go +++ b/trie/union_test.go @@ -24,20 +24,20 @@ func TestUnionFromJSON(t *testing.T) { } func testUnion(t *testing.T, sample *testSetSample) { - left := FromKeys(sample.LeftKeys) - right := FromKeys(sample.RightKeys) + left := FromKeys[any](sample.LeftKeys) + right := FromKeys[any](sample.RightKeys) trie := Union(left, right) - expected := FromKeys(UnionKeySlices(sample.LeftKeys, sample.RightKeys)) + expected := FromKeys[any](UnionKeySlices(sample.LeftKeys, sample.RightKeys)) if !Equal(trie, expected) { t.Errorf("union does not have expected values") } - nodesMap := trieNodes(left, make(map[*Trie]bool)) + nodesMap := trieNodes(left, make(map[*Trie[any]]bool)) nodesMap = trieNodes(right, nodesMap) testTrieNotSameReference(t, nodesMap, trie) } -func testTrieNotSameReference(t *testing.T, nodesMap map[*Trie]bool, union *Trie) { +func testTrieNotSameReference[T any](t *testing.T, nodesMap map[*Trie[T]]bool, union *Trie[T]) { if union == nil { return } @@ -48,7 +48,7 @@ func testTrieNotSameReference(t *testing.T, nodesMap map[*Trie]bool, union *Trie testTrieNotSameReference(t, nodesMap, union.Branch[1]) } -func trieNodes(trie *Trie, nodesMap map[*Trie]bool) map[*Trie]bool { +func trieNodes[T any](trie *Trie[T], nodesMap map[*Trie[T]]bool) map[*Trie[T]]bool { if trie == nil { return nodesMap }