Skip to content

Commit

Permalink
Add GetNode method to kad.RoutingTable (#108)
Browse files Browse the repository at this point in the history
Both routing table implementations already had a Find method. This
change changes the name of the method to GetNode and adjusts the
signature to match the style of other methods of kad.RoutingTable
  • Loading branch information
iand authored Aug 23, 2023
1 parent 44238be commit 9b9e606
Show file tree
Hide file tree
Showing 8 changed files with 59 additions and 75 deletions.
12 changes: 6 additions & 6 deletions coord/coordinator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -373,9 +373,9 @@ func TestIncludeNode(t *testing.T) {
candidate := nodes[3] // not in nodes[0] routing table

// the routing table should not contain the node yet
foundNode, err := rts[0].Find(ctx, candidate.ID().Key())
require.NoError(t, err)
require.Nil(t, foundNode)
foundNode, found := rts[0].GetNode(candidate.ID().Key())
require.False(t, found)
require.Zero(t, foundNode)

self := nodes[0].ID()
c, err := NewCoordinator[key.Key8, kadtest.StrAddr](self, eps[0], findNodeFn, rts[0], scheds[0], ccfg)
Expand All @@ -399,7 +399,7 @@ func TestIncludeNode(t *testing.T) {
require.Equal(t, candidate.ID(), tev.NodeInfo.ID())

// the routing table should contain the node
foundNode, err = rts[0].Find(ctx, candidate.ID().Key())
require.NoError(t, err)
require.NotNil(t, foundNode)
foundNode, found = rts[0].GetNode(candidate.ID().Key())
require.True(t, found)
require.NotZero(t, foundNode)
}
5 changes: 5 additions & 0 deletions kad/kad.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,11 @@ type RoutingTable[K Key[K], N NodeID[K]] interface {
// contain at maximum the given number of entries, but also possibly less
// if the number exceeds the number of nodes in the routing table.
NearestNodes(K, int) []N

// GetNode returns the node identified by the supplied Kademlia key or a zero
// value if the node is not present in the routing table. The boolean second
// return value indicates whether the node was found in the table.
GetNode(K) (N, bool)
}

// NodeID is a generic node identifier and not equal to a Kademlia key. Some
Expand Down
11 changes: 2 additions & 9 deletions routing/include.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,15 +112,8 @@ func (b *Include[K, A]) Advance(ctx context.Context, ev IncludeEvent) IncludeSta
}

// Ignore if node already in routing table
// TODO: promote this interface (or something similar) to kad.RoutingTable
if rtf, ok := b.rt.(interface {
Find(context.Context, kad.NodeID[K]) (kad.NodeInfo[K, A], error)
}); ok {
n, _ := rtf.Find(ctx, tev.NodeInfo.ID())
if n != nil {
// node already in routing table
break
}
if _, exists := b.rt.GetNode(tev.NodeInfo.ID().Key()); exists {
break
}

// TODO: potentially time out a check and make room in the queue
Expand Down
12 changes: 6 additions & 6 deletions routing/include_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -263,8 +263,8 @@ func TestIncludeMessageResponse(t *testing.T) {
require.Equal(t, kadtest.NewID(key.Key8(4)), st.NodeInfo.ID())

// the routing table should contain the node
foundNode, err := rt.Find(ctx, key.Key8(4))
require.NoError(t, err)
foundNode, found := rt.GetNode(key.Key8(4))
require.True(t, found)
require.NotNil(t, foundNode)

require.True(t, key.Equal(foundNode.Key(), key.Key8(4)))
Expand Down Expand Up @@ -306,8 +306,8 @@ func TestIncludeMessageResponseInvalid(t *testing.T) {
require.IsType(t, &StateIncludeIdle{}, state)

// the routing table should not contain the node
foundNode, err := rt.Find(ctx, key.Key8(4))
require.NoError(t, err)
foundNode, found := rt.GetNode(key.Key8(4))
require.False(t, found)
require.Nil(t, foundNode)
}

Expand Down Expand Up @@ -344,7 +344,7 @@ func TestIncludeMessageFailure(t *testing.T) {
require.IsType(t, &StateIncludeIdle{}, state)

// the routing table should not contain the node
foundNode, err := rt.Find(ctx, key.Key8(4))
require.NoError(t, err)
foundNode, found := rt.GetNode(key.Key8(4))
require.False(t, found)
require.Nil(t, foundNode)
}
16 changes: 4 additions & 12 deletions routing/simplert/table.go
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
package simplert

import (
"context"
"sort"

"github.com/plprobelab/go-kademlia/internal/kadtest"

"github.com/plprobelab/go-kademlia/kad"
"github.com/plprobelab/go-kademlia/key"
"github.com/plprobelab/go-kademlia/util"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
)

type peerInfo[K kad.Key[K], N kad.NodeID[K]] struct {
Expand Down Expand Up @@ -174,19 +170,15 @@ func (rt *SimpleRT[K, N]) RemoveKey(kadId K) bool {
return false
}

func (rt *SimpleRT[K, N]) Find(ctx context.Context, kadId K) (N, error) {
_, span := util.StartSpan(ctx, "routing.simple.find", trace.WithAttributes(
attribute.String("KadID", key.HexString(kadId)),
))
defer span.End()

func (rt *SimpleRT[K, N]) GetNode(kadId K) (N, bool) {
bid, _ := rt.BucketIdForKey(kadId)
for _, p := range rt.buckets[bid] {
if key.Equal(kadId, p.kadId) {
return p.id, nil
return p.id, true
}
}
return *new(N), nil // TODO: can't return nil
var zero N
return zero, false
}

// TODO: not exactly working as expected
Expand Down
23 changes: 10 additions & 13 deletions routing/simplert/table_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package simplert

import (
"context"
"fmt"
"testing"

Expand Down Expand Up @@ -117,7 +116,7 @@ func TestAddPeer(t *testing.T) {
require.False(t, success)
}

func TestRemovePeer(t *testing.T) {
func TestRemoveKey(t *testing.T) {
p := kt.NewID(key0) // irrelevant

rt := New[key.Key256](kt.NewID(key0), 2)
Expand All @@ -128,29 +127,27 @@ func TestRemovePeer(t *testing.T) {
require.True(t, success)
}

func TestFindPeer(t *testing.T) {
ctx := context.Background()
func TestGetNode(t *testing.T) {
p := kt.NewID(key0)

rt := New[key.Key256](kt.NewID(key0), 2)
success := rt.addPeer(key1, p)
require.True(t, success)

peerid, err := rt.Find(ctx, key1)
require.NoError(t, err)
peerid, found := rt.GetNode(key1)
require.True(t, found)
require.Equal(t, p, peerid)

peerid, err = rt.Find(ctx, key2)
require.NoError(t, err)
require.Nil(t, peerid)
peerid, found = rt.GetNode(key2)
require.False(t, found)
require.Zero(t, peerid)

success = rt.RemoveKey(key1)
require.NoError(t, err)
require.True(t, success)

peerid, err = rt.Find(ctx, key1)
require.NoError(t, err)
require.Nil(t, peerid)
peerid, found = rt.GetNode(key1)
require.False(t, found)
require.Zero(t, peerid)
}

func TestNearestPeers(t *testing.T) {
Expand Down
11 changes: 5 additions & 6 deletions routing/triert/table.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package triert

import (
"context"
"fmt"

"github.com/plprobelab/go-kademlia/internal/kadtest"
Expand Down Expand Up @@ -80,13 +79,13 @@ func (rt *TrieRT[K, N]) NearestNodes(target K, n int) []N {
return nodes
}

func (rt *TrieRT[K, N]) Find(ctx context.Context, kk K) (kad.NodeID[K], error) {
func (rt *TrieRT[K, N]) GetNode(kk K) (N, bool) {
found, node := trie.Find(rt.keys, kk)
if found {
return node, nil
if !found {
var zero N
return zero, false
}

return nil, nil
return node, true
}

// Size returns the number of peers contained in the table.
Expand Down
44 changes: 21 additions & 23 deletions routing/triert/table_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package triert

import (
"context"
"math/rand"
"testing"

Expand Down Expand Up @@ -110,25 +109,25 @@ func TestRemovePeer(t *testing.T) {
})
}

func TestFindPeer(t *testing.T) {
func TestGetNode(t *testing.T) {
t.Run("known peer", func(t *testing.T) {
rt, err := New[key.Key32](node0, nil)
require.NoError(t, err)
success := rt.AddNode(node1)
require.True(t, success)

want := node1
got, err := rt.Find(context.Background(), key1)
require.NoError(t, err)
got, found := rt.GetNode(key1)
require.True(t, found)
require.Equal(t, want, got)
})

t.Run("unknown peer", func(t *testing.T) {
rt, err := New[key.Key32](node0, nil)
require.NoError(t, err)
got, err := rt.Find(context.Background(), key2)
require.NoError(t, err)
require.Nil(t, got)
got, found := rt.GetNode(key2)
require.False(t, found)
require.Zero(t, got)
})

t.Run("removed peer", func(t *testing.T) {
Expand All @@ -138,16 +137,16 @@ func TestFindPeer(t *testing.T) {
require.True(t, success)

want := node1
got, err := rt.Find(context.Background(), key1)
require.NoError(t, err)
got, found := rt.GetNode(key1)
require.True(t, found)
require.Equal(t, want, got)

success = rt.RemoveKey(key1)
require.True(t, success)

got, err = rt.Find(context.Background(), key1)
require.NoError(t, err)
require.Nil(t, got)
got, found = rt.GetNode(key1)
require.False(t, found)
require.Zero(t, got)
})
}

Expand Down Expand Up @@ -287,7 +286,6 @@ func TestCplSize(t *testing.T) {
}

func TestKeyFilter(t *testing.T) {
ctx := context.Background()
cfg := DefaultConfig[key.Key32, node[key.Key32]]()
cfg.KeyFilter = func(rt *TrieRT[key.Key32, node[key.Key32]], kk key.Key32) bool {
return !key.Equal(kk, key2) // don't allow key2 to be added
Expand All @@ -300,18 +298,18 @@ func TestKeyFilter(t *testing.T) {
require.NoError(t, err)
require.False(t, success)

got, err := rt.Find(ctx, key2)
require.NoError(t, err)
require.Nil(t, got)
got, found := rt.GetNode(key2)
require.False(t, found)
require.Zero(t, got)

// can add other key
success = rt.AddNode(node1)
require.NoError(t, err)
require.True(t, success)

want := node1
got, err = rt.Find(ctx, key1)
require.NoError(t, err)
got, found = rt.GetNode(key1)
require.True(t, found)
require.Equal(t, want, got)
}

Expand Down Expand Up @@ -382,7 +380,7 @@ func benchmarkFindPositive(n int) func(b *testing.B) {
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
rt.Find(context.Background(), keys[i%len(keys)])
rt.GetNode(keys[i%len(keys)])
}
}
}
Expand All @@ -404,7 +402,7 @@ func benchmarkFindNegative(n int) func(b *testing.B) {
unknown := make([]key.Key32, n)
for i := 0; i < n; i++ {
kk := kadtest.RandomKey()
if found, _ := rt.Find(context.Background(), kk); found != nil {
if found, _ := rt.GetNode(kk); found != nil {
continue
}
unknown[i] = kk
Expand All @@ -413,7 +411,7 @@ func benchmarkFindNegative(n int) func(b *testing.B) {
b.ResetTimer()
b.ReportAllocs()
for i := 0; i < b.N; i++ {
rt.Find(context.Background(), unknown[i%len(unknown)])
rt.GetNode(unknown[i%len(unknown)])
}
}
}
Expand Down Expand Up @@ -459,8 +457,8 @@ func benchmarkChurn(n int) func(b *testing.B) {
b.ReportAllocs()
for i := 0; i < b.N; i++ {
node := universe[i%len(universe)]
found, _ := rt.Find(context.Background(), node.Key())
if found == nil {
_, found := rt.GetNode(node.Key())
if !found {
// add new peer
rt.AddNode(universe[i%len(universe)])
} else {
Expand Down

0 comments on commit 9b9e606

Please sign in to comment.