From 9b9e606066c9f19a2ecbeac1c81f69a1303b27c3 Mon Sep 17 00:00:00 2001 From: Ian Davis <18375+iand@users.noreply.github.com> Date: Wed, 23 Aug 2023 12:45:13 +0100 Subject: [PATCH] Add GetNode method to kad.RoutingTable (#108) Both routing table implementations already had a Find method. This change changes the name of the method to GetNode and adjusts the signature to match the style of other methods of kad.RoutingTable --- coord/coordinator_test.go | 12 +++++----- kad/kad.go | 5 ++++ routing/include.go | 11 ++------- routing/include_test.go | 12 +++++----- routing/simplert/table.go | 16 ++++--------- routing/simplert/table_test.go | 23 ++++++++---------- routing/triert/table.go | 11 ++++----- routing/triert/table_test.go | 44 ++++++++++++++++------------------ 8 files changed, 59 insertions(+), 75 deletions(-) diff --git a/coord/coordinator_test.go b/coord/coordinator_test.go index cf9f68a..4781593 100644 --- a/coord/coordinator_test.go +++ b/coord/coordinator_test.go @@ -373,9 +373,9 @@ func TestIncludeNode(t *testing.T) { candidate := nodes[3] // not in nodes[0] routing table // the routing table should not contain the node yet - foundNode, err := rts[0].Find(ctx, candidate.ID().Key()) - require.NoError(t, err) - require.Nil(t, foundNode) + foundNode, found := rts[0].GetNode(candidate.ID().Key()) + require.False(t, found) + require.Zero(t, foundNode) self := nodes[0].ID() c, err := NewCoordinator[key.Key8, kadtest.StrAddr](self, eps[0], findNodeFn, rts[0], scheds[0], ccfg) @@ -399,7 +399,7 @@ func TestIncludeNode(t *testing.T) { require.Equal(t, candidate.ID(), tev.NodeInfo.ID()) // the routing table should contain the node - foundNode, err = rts[0].Find(ctx, candidate.ID().Key()) - require.NoError(t, err) - require.NotNil(t, foundNode) + foundNode, found = rts[0].GetNode(candidate.ID().Key()) + require.True(t, found) + require.NotZero(t, foundNode) } diff --git a/kad/kad.go b/kad/kad.go index 0d319bd..815df3a 100644 --- a/kad/kad.go +++ b/kad/kad.go @@ -68,6 +68,11 @@ type RoutingTable[K Key[K], N NodeID[K]] interface { // contain at maximum the given number of entries, but also possibly less // if the number exceeds the number of nodes in the routing table. NearestNodes(K, int) []N + + // GetNode returns the node identified by the supplied Kademlia key or a zero + // value if the node is not present in the routing table. The boolean second + // return value indicates whether the node was found in the table. + GetNode(K) (N, bool) } // NodeID is a generic node identifier and not equal to a Kademlia key. Some diff --git a/routing/include.go b/routing/include.go index d6fec3f..cab122a 100644 --- a/routing/include.go +++ b/routing/include.go @@ -112,15 +112,8 @@ func (b *Include[K, A]) Advance(ctx context.Context, ev IncludeEvent) IncludeSta } // Ignore if node already in routing table - // TODO: promote this interface (or something similar) to kad.RoutingTable - if rtf, ok := b.rt.(interface { - Find(context.Context, kad.NodeID[K]) (kad.NodeInfo[K, A], error) - }); ok { - n, _ := rtf.Find(ctx, tev.NodeInfo.ID()) - if n != nil { - // node already in routing table - break - } + if _, exists := b.rt.GetNode(tev.NodeInfo.ID().Key()); exists { + break } // TODO: potentially time out a check and make room in the queue diff --git a/routing/include_test.go b/routing/include_test.go index 4b00311..dc2c88c 100644 --- a/routing/include_test.go +++ b/routing/include_test.go @@ -263,8 +263,8 @@ func TestIncludeMessageResponse(t *testing.T) { require.Equal(t, kadtest.NewID(key.Key8(4)), st.NodeInfo.ID()) // the routing table should contain the node - foundNode, err := rt.Find(ctx, key.Key8(4)) - require.NoError(t, err) + foundNode, found := rt.GetNode(key.Key8(4)) + require.True(t, found) require.NotNil(t, foundNode) require.True(t, key.Equal(foundNode.Key(), key.Key8(4))) @@ -306,8 +306,8 @@ func TestIncludeMessageResponseInvalid(t *testing.T) { require.IsType(t, &StateIncludeIdle{}, state) // the routing table should not contain the node - foundNode, err := rt.Find(ctx, key.Key8(4)) - require.NoError(t, err) + foundNode, found := rt.GetNode(key.Key8(4)) + require.False(t, found) require.Nil(t, foundNode) } @@ -344,7 +344,7 @@ func TestIncludeMessageFailure(t *testing.T) { require.IsType(t, &StateIncludeIdle{}, state) // the routing table should not contain the node - foundNode, err := rt.Find(ctx, key.Key8(4)) - require.NoError(t, err) + foundNode, found := rt.GetNode(key.Key8(4)) + require.False(t, found) require.Nil(t, foundNode) } diff --git a/routing/simplert/table.go b/routing/simplert/table.go index ac4d769..fc43433 100644 --- a/routing/simplert/table.go +++ b/routing/simplert/table.go @@ -1,16 +1,12 @@ package simplert import ( - "context" "sort" "github.com/plprobelab/go-kademlia/internal/kadtest" "github.com/plprobelab/go-kademlia/kad" "github.com/plprobelab/go-kademlia/key" - "github.com/plprobelab/go-kademlia/util" - "go.opentelemetry.io/otel/attribute" - "go.opentelemetry.io/otel/trace" ) type peerInfo[K kad.Key[K], N kad.NodeID[K]] struct { @@ -174,19 +170,15 @@ func (rt *SimpleRT[K, N]) RemoveKey(kadId K) bool { return false } -func (rt *SimpleRT[K, N]) Find(ctx context.Context, kadId K) (N, error) { - _, span := util.StartSpan(ctx, "routing.simple.find", trace.WithAttributes( - attribute.String("KadID", key.HexString(kadId)), - )) - defer span.End() - +func (rt *SimpleRT[K, N]) GetNode(kadId K) (N, bool) { bid, _ := rt.BucketIdForKey(kadId) for _, p := range rt.buckets[bid] { if key.Equal(kadId, p.kadId) { - return p.id, nil + return p.id, true } } - return *new(N), nil // TODO: can't return nil + var zero N + return zero, false } // TODO: not exactly working as expected diff --git a/routing/simplert/table_test.go b/routing/simplert/table_test.go index 0140e2a..9eb0f2e 100644 --- a/routing/simplert/table_test.go +++ b/routing/simplert/table_test.go @@ -1,7 +1,6 @@ package simplert import ( - "context" "fmt" "testing" @@ -117,7 +116,7 @@ func TestAddPeer(t *testing.T) { require.False(t, success) } -func TestRemovePeer(t *testing.T) { +func TestRemoveKey(t *testing.T) { p := kt.NewID(key0) // irrelevant rt := New[key.Key256](kt.NewID(key0), 2) @@ -128,29 +127,27 @@ func TestRemovePeer(t *testing.T) { require.True(t, success) } -func TestFindPeer(t *testing.T) { - ctx := context.Background() +func TestGetNode(t *testing.T) { p := kt.NewID(key0) rt := New[key.Key256](kt.NewID(key0), 2) success := rt.addPeer(key1, p) require.True(t, success) - peerid, err := rt.Find(ctx, key1) - require.NoError(t, err) + peerid, found := rt.GetNode(key1) + require.True(t, found) require.Equal(t, p, peerid) - peerid, err = rt.Find(ctx, key2) - require.NoError(t, err) - require.Nil(t, peerid) + peerid, found = rt.GetNode(key2) + require.False(t, found) + require.Zero(t, peerid) success = rt.RemoveKey(key1) - require.NoError(t, err) require.True(t, success) - peerid, err = rt.Find(ctx, key1) - require.NoError(t, err) - require.Nil(t, peerid) + peerid, found = rt.GetNode(key1) + require.False(t, found) + require.Zero(t, peerid) } func TestNearestPeers(t *testing.T) { diff --git a/routing/triert/table.go b/routing/triert/table.go index a7379b6..15e92fc 100644 --- a/routing/triert/table.go +++ b/routing/triert/table.go @@ -1,7 +1,6 @@ package triert import ( - "context" "fmt" "github.com/plprobelab/go-kademlia/internal/kadtest" @@ -80,13 +79,13 @@ func (rt *TrieRT[K, N]) NearestNodes(target K, n int) []N { return nodes } -func (rt *TrieRT[K, N]) Find(ctx context.Context, kk K) (kad.NodeID[K], error) { +func (rt *TrieRT[K, N]) GetNode(kk K) (N, bool) { found, node := trie.Find(rt.keys, kk) - if found { - return node, nil + if !found { + var zero N + return zero, false } - - return nil, nil + return node, true } // Size returns the number of peers contained in the table. diff --git a/routing/triert/table_test.go b/routing/triert/table_test.go index ae61392..2931e50 100644 --- a/routing/triert/table_test.go +++ b/routing/triert/table_test.go @@ -1,7 +1,6 @@ package triert import ( - "context" "math/rand" "testing" @@ -110,7 +109,7 @@ func TestRemovePeer(t *testing.T) { }) } -func TestFindPeer(t *testing.T) { +func TestGetNode(t *testing.T) { t.Run("known peer", func(t *testing.T) { rt, err := New[key.Key32](node0, nil) require.NoError(t, err) @@ -118,17 +117,17 @@ func TestFindPeer(t *testing.T) { require.True(t, success) want := node1 - got, err := rt.Find(context.Background(), key1) - require.NoError(t, err) + got, found := rt.GetNode(key1) + require.True(t, found) require.Equal(t, want, got) }) t.Run("unknown peer", func(t *testing.T) { rt, err := New[key.Key32](node0, nil) require.NoError(t, err) - got, err := rt.Find(context.Background(), key2) - require.NoError(t, err) - require.Nil(t, got) + got, found := rt.GetNode(key2) + require.False(t, found) + require.Zero(t, got) }) t.Run("removed peer", func(t *testing.T) { @@ -138,16 +137,16 @@ func TestFindPeer(t *testing.T) { require.True(t, success) want := node1 - got, err := rt.Find(context.Background(), key1) - require.NoError(t, err) + got, found := rt.GetNode(key1) + require.True(t, found) require.Equal(t, want, got) success = rt.RemoveKey(key1) require.True(t, success) - got, err = rt.Find(context.Background(), key1) - require.NoError(t, err) - require.Nil(t, got) + got, found = rt.GetNode(key1) + require.False(t, found) + require.Zero(t, got) }) } @@ -287,7 +286,6 @@ func TestCplSize(t *testing.T) { } func TestKeyFilter(t *testing.T) { - ctx := context.Background() cfg := DefaultConfig[key.Key32, node[key.Key32]]() cfg.KeyFilter = func(rt *TrieRT[key.Key32, node[key.Key32]], kk key.Key32) bool { return !key.Equal(kk, key2) // don't allow key2 to be added @@ -300,9 +298,9 @@ func TestKeyFilter(t *testing.T) { require.NoError(t, err) require.False(t, success) - got, err := rt.Find(ctx, key2) - require.NoError(t, err) - require.Nil(t, got) + got, found := rt.GetNode(key2) + require.False(t, found) + require.Zero(t, got) // can add other key success = rt.AddNode(node1) @@ -310,8 +308,8 @@ func TestKeyFilter(t *testing.T) { require.True(t, success) want := node1 - got, err = rt.Find(ctx, key1) - require.NoError(t, err) + got, found = rt.GetNode(key1) + require.True(t, found) require.Equal(t, want, got) } @@ -382,7 +380,7 @@ func benchmarkFindPositive(n int) func(b *testing.B) { b.ResetTimer() b.ReportAllocs() for i := 0; i < b.N; i++ { - rt.Find(context.Background(), keys[i%len(keys)]) + rt.GetNode(keys[i%len(keys)]) } } } @@ -404,7 +402,7 @@ func benchmarkFindNegative(n int) func(b *testing.B) { unknown := make([]key.Key32, n) for i := 0; i < n; i++ { kk := kadtest.RandomKey() - if found, _ := rt.Find(context.Background(), kk); found != nil { + if found, _ := rt.GetNode(kk); found != nil { continue } unknown[i] = kk @@ -413,7 +411,7 @@ func benchmarkFindNegative(n int) func(b *testing.B) { b.ResetTimer() b.ReportAllocs() for i := 0; i < b.N; i++ { - rt.Find(context.Background(), unknown[i%len(unknown)]) + rt.GetNode(unknown[i%len(unknown)]) } } } @@ -459,8 +457,8 @@ func benchmarkChurn(n int) func(b *testing.B) { b.ReportAllocs() for i := 0; i < b.N; i++ { node := universe[i%len(universe)] - found, _ := rt.Find(context.Background(), node.Key()) - if found == nil { + _, found := rt.GetNode(node.Key()) + if !found { // add new peer rt.AddNode(universe[i%len(universe)]) } else {