From 9b9e606066c9f19a2ecbeac1c81f69a1303b27c3 Mon Sep 17 00:00:00 2001
From: Ian Davis <18375+iand@users.noreply.github.com>
Date: Wed, 23 Aug 2023 12:45:13 +0100
Subject: [PATCH] Add GetNode method to kad.RoutingTable (#108)

Both routing table implementations already had a Find method. This
change changes the name of the method to GetNode and adjusts the
signature to match the style of other methods of kad.RoutingTable
---
 coord/coordinator_test.go      | 12 +++++-----
 kad/kad.go                     |  5 ++++
 routing/include.go             | 11 ++-------
 routing/include_test.go        | 12 +++++-----
 routing/simplert/table.go      | 16 ++++---------
 routing/simplert/table_test.go | 23 ++++++++----------
 routing/triert/table.go        | 11 ++++-----
 routing/triert/table_test.go   | 44 ++++++++++++++++------------------
 8 files changed, 59 insertions(+), 75 deletions(-)

diff --git a/coord/coordinator_test.go b/coord/coordinator_test.go
index cf9f68a..4781593 100644
--- a/coord/coordinator_test.go
+++ b/coord/coordinator_test.go
@@ -373,9 +373,9 @@ func TestIncludeNode(t *testing.T) {
 	candidate := nodes[3] // not in nodes[0] routing table
 
 	// the routing table should not contain the node yet
-	foundNode, err := rts[0].Find(ctx, candidate.ID().Key())
-	require.NoError(t, err)
-	require.Nil(t, foundNode)
+	foundNode, found := rts[0].GetNode(candidate.ID().Key())
+	require.False(t, found)
+	require.Zero(t, foundNode)
 
 	self := nodes[0].ID()
 	c, err := NewCoordinator[key.Key8, kadtest.StrAddr](self, eps[0], findNodeFn, rts[0], scheds[0], ccfg)
@@ -399,7 +399,7 @@ func TestIncludeNode(t *testing.T) {
 	require.Equal(t, candidate.ID(), tev.NodeInfo.ID())
 
 	// the routing table should contain the node
-	foundNode, err = rts[0].Find(ctx, candidate.ID().Key())
-	require.NoError(t, err)
-	require.NotNil(t, foundNode)
+	foundNode, found = rts[0].GetNode(candidate.ID().Key())
+	require.True(t, found)
+	require.NotZero(t, foundNode)
 }
diff --git a/kad/kad.go b/kad/kad.go
index 0d319bd..815df3a 100644
--- a/kad/kad.go
+++ b/kad/kad.go
@@ -68,6 +68,11 @@ type RoutingTable[K Key[K], N NodeID[K]] interface {
 	// contain at maximum the given number of entries, but also possibly less
 	// if the number exceeds the number of nodes in the routing table.
 	NearestNodes(K, int) []N
+
+	// GetNode returns the node identified by the supplied Kademlia key or a zero
+	// value if the node is not present in the routing table. The boolean second
+	// return value indicates whether the node was found in the table.
+	GetNode(K) (N, bool)
 }
 
 // NodeID is a generic node identifier and not equal to a Kademlia key. Some
diff --git a/routing/include.go b/routing/include.go
index d6fec3f..cab122a 100644
--- a/routing/include.go
+++ b/routing/include.go
@@ -112,15 +112,8 @@ func (b *Include[K, A]) Advance(ctx context.Context, ev IncludeEvent) IncludeSta
 		}
 
 		// Ignore if node already in routing table
-		// TODO: promote this interface (or something similar) to kad.RoutingTable
-		if rtf, ok := b.rt.(interface {
-			Find(context.Context, kad.NodeID[K]) (kad.NodeInfo[K, A], error)
-		}); ok {
-			n, _ := rtf.Find(ctx, tev.NodeInfo.ID())
-			if n != nil {
-				// node already in routing table
-				break
-			}
+		if _, exists := b.rt.GetNode(tev.NodeInfo.ID().Key()); exists {
+			break
 		}
 
 		// TODO: potentially time out a check and make room in the queue
diff --git a/routing/include_test.go b/routing/include_test.go
index 4b00311..dc2c88c 100644
--- a/routing/include_test.go
+++ b/routing/include_test.go
@@ -263,8 +263,8 @@ func TestIncludeMessageResponse(t *testing.T) {
 	require.Equal(t, kadtest.NewID(key.Key8(4)), st.NodeInfo.ID())
 
 	// the routing table should contain the node
-	foundNode, err := rt.Find(ctx, key.Key8(4))
-	require.NoError(t, err)
+	foundNode, found := rt.GetNode(key.Key8(4))
+	require.True(t, found)
 	require.NotNil(t, foundNode)
 
 	require.True(t, key.Equal(foundNode.Key(), key.Key8(4)))
@@ -306,8 +306,8 @@ func TestIncludeMessageResponseInvalid(t *testing.T) {
 	require.IsType(t, &StateIncludeIdle{}, state)
 
 	// the routing table should not contain the node
-	foundNode, err := rt.Find(ctx, key.Key8(4))
-	require.NoError(t, err)
+	foundNode, found := rt.GetNode(key.Key8(4))
+	require.False(t, found)
 	require.Nil(t, foundNode)
 }
 
@@ -344,7 +344,7 @@ func TestIncludeMessageFailure(t *testing.T) {
 	require.IsType(t, &StateIncludeIdle{}, state)
 
 	// the routing table should not contain the node
-	foundNode, err := rt.Find(ctx, key.Key8(4))
-	require.NoError(t, err)
+	foundNode, found := rt.GetNode(key.Key8(4))
+	require.False(t, found)
 	require.Nil(t, foundNode)
 }
diff --git a/routing/simplert/table.go b/routing/simplert/table.go
index ac4d769..fc43433 100644
--- a/routing/simplert/table.go
+++ b/routing/simplert/table.go
@@ -1,16 +1,12 @@
 package simplert
 
 import (
-	"context"
 	"sort"
 
 	"github.com/plprobelab/go-kademlia/internal/kadtest"
 
 	"github.com/plprobelab/go-kademlia/kad"
 	"github.com/plprobelab/go-kademlia/key"
-	"github.com/plprobelab/go-kademlia/util"
-	"go.opentelemetry.io/otel/attribute"
-	"go.opentelemetry.io/otel/trace"
 )
 
 type peerInfo[K kad.Key[K], N kad.NodeID[K]] struct {
@@ -174,19 +170,15 @@ func (rt *SimpleRT[K, N]) RemoveKey(kadId K) bool {
 	return false
 }
 
-func (rt *SimpleRT[K, N]) Find(ctx context.Context, kadId K) (N, error) {
-	_, span := util.StartSpan(ctx, "routing.simple.find", trace.WithAttributes(
-		attribute.String("KadID", key.HexString(kadId)),
-	))
-	defer span.End()
-
+func (rt *SimpleRT[K, N]) GetNode(kadId K) (N, bool) {
 	bid, _ := rt.BucketIdForKey(kadId)
 	for _, p := range rt.buckets[bid] {
 		if key.Equal(kadId, p.kadId) {
-			return p.id, nil
+			return p.id, true
 		}
 	}
-	return *new(N), nil // TODO: can't return nil
+	var zero N
+	return zero, false
 }
 
 // TODO: not exactly working as expected
diff --git a/routing/simplert/table_test.go b/routing/simplert/table_test.go
index 0140e2a..9eb0f2e 100644
--- a/routing/simplert/table_test.go
+++ b/routing/simplert/table_test.go
@@ -1,7 +1,6 @@
 package simplert
 
 import (
-	"context"
 	"fmt"
 	"testing"
 
@@ -117,7 +116,7 @@ func TestAddPeer(t *testing.T) {
 	require.False(t, success)
 }
 
-func TestRemovePeer(t *testing.T) {
+func TestRemoveKey(t *testing.T) {
 	p := kt.NewID(key0) // irrelevant
 
 	rt := New[key.Key256](kt.NewID(key0), 2)
@@ -128,29 +127,27 @@ func TestRemovePeer(t *testing.T) {
 	require.True(t, success)
 }
 
-func TestFindPeer(t *testing.T) {
-	ctx := context.Background()
+func TestGetNode(t *testing.T) {
 	p := kt.NewID(key0)
 
 	rt := New[key.Key256](kt.NewID(key0), 2)
 	success := rt.addPeer(key1, p)
 	require.True(t, success)
 
-	peerid, err := rt.Find(ctx, key1)
-	require.NoError(t, err)
+	peerid, found := rt.GetNode(key1)
+	require.True(t, found)
 	require.Equal(t, p, peerid)
 
-	peerid, err = rt.Find(ctx, key2)
-	require.NoError(t, err)
-	require.Nil(t, peerid)
+	peerid, found = rt.GetNode(key2)
+	require.False(t, found)
+	require.Zero(t, peerid)
 
 	success = rt.RemoveKey(key1)
-	require.NoError(t, err)
 	require.True(t, success)
 
-	peerid, err = rt.Find(ctx, key1)
-	require.NoError(t, err)
-	require.Nil(t, peerid)
+	peerid, found = rt.GetNode(key1)
+	require.False(t, found)
+	require.Zero(t, peerid)
 }
 
 func TestNearestPeers(t *testing.T) {
diff --git a/routing/triert/table.go b/routing/triert/table.go
index a7379b6..15e92fc 100644
--- a/routing/triert/table.go
+++ b/routing/triert/table.go
@@ -1,7 +1,6 @@
 package triert
 
 import (
-	"context"
 	"fmt"
 
 	"github.com/plprobelab/go-kademlia/internal/kadtest"
@@ -80,13 +79,13 @@ func (rt *TrieRT[K, N]) NearestNodes(target K, n int) []N {
 	return nodes
 }
 
-func (rt *TrieRT[K, N]) Find(ctx context.Context, kk K) (kad.NodeID[K], error) {
+func (rt *TrieRT[K, N]) GetNode(kk K) (N, bool) {
 	found, node := trie.Find(rt.keys, kk)
-	if found {
-		return node, nil
+	if !found {
+		var zero N
+		return zero, false
 	}
-
-	return nil, nil
+	return node, true
 }
 
 // Size returns the number of peers contained in the table.
diff --git a/routing/triert/table_test.go b/routing/triert/table_test.go
index ae61392..2931e50 100644
--- a/routing/triert/table_test.go
+++ b/routing/triert/table_test.go
@@ -1,7 +1,6 @@
 package triert
 
 import (
-	"context"
 	"math/rand"
 	"testing"
 
@@ -110,7 +109,7 @@ func TestRemovePeer(t *testing.T) {
 	})
 }
 
-func TestFindPeer(t *testing.T) {
+func TestGetNode(t *testing.T) {
 	t.Run("known peer", func(t *testing.T) {
 		rt, err := New[key.Key32](node0, nil)
 		require.NoError(t, err)
@@ -118,17 +117,17 @@ func TestFindPeer(t *testing.T) {
 		require.True(t, success)
 
 		want := node1
-		got, err := rt.Find(context.Background(), key1)
-		require.NoError(t, err)
+		got, found := rt.GetNode(key1)
+		require.True(t, found)
 		require.Equal(t, want, got)
 	})
 
 	t.Run("unknown peer", func(t *testing.T) {
 		rt, err := New[key.Key32](node0, nil)
 		require.NoError(t, err)
-		got, err := rt.Find(context.Background(), key2)
-		require.NoError(t, err)
-		require.Nil(t, got)
+		got, found := rt.GetNode(key2)
+		require.False(t, found)
+		require.Zero(t, got)
 	})
 
 	t.Run("removed peer", func(t *testing.T) {
@@ -138,16 +137,16 @@ func TestFindPeer(t *testing.T) {
 		require.True(t, success)
 
 		want := node1
-		got, err := rt.Find(context.Background(), key1)
-		require.NoError(t, err)
+		got, found := rt.GetNode(key1)
+		require.True(t, found)
 		require.Equal(t, want, got)
 
 		success = rt.RemoveKey(key1)
 		require.True(t, success)
 
-		got, err = rt.Find(context.Background(), key1)
-		require.NoError(t, err)
-		require.Nil(t, got)
+		got, found = rt.GetNode(key1)
+		require.False(t, found)
+		require.Zero(t, got)
 	})
 }
 
@@ -287,7 +286,6 @@ func TestCplSize(t *testing.T) {
 }
 
 func TestKeyFilter(t *testing.T) {
-	ctx := context.Background()
 	cfg := DefaultConfig[key.Key32, node[key.Key32]]()
 	cfg.KeyFilter = func(rt *TrieRT[key.Key32, node[key.Key32]], kk key.Key32) bool {
 		return !key.Equal(kk, key2) // don't allow key2 to be added
@@ -300,9 +298,9 @@ func TestKeyFilter(t *testing.T) {
 	require.NoError(t, err)
 	require.False(t, success)
 
-	got, err := rt.Find(ctx, key2)
-	require.NoError(t, err)
-	require.Nil(t, got)
+	got, found := rt.GetNode(key2)
+	require.False(t, found)
+	require.Zero(t, got)
 
 	// can add other key
 	success = rt.AddNode(node1)
@@ -310,8 +308,8 @@ func TestKeyFilter(t *testing.T) {
 	require.True(t, success)
 
 	want := node1
-	got, err = rt.Find(ctx, key1)
-	require.NoError(t, err)
+	got, found = rt.GetNode(key1)
+	require.True(t, found)
 	require.Equal(t, want, got)
 }
 
@@ -382,7 +380,7 @@ func benchmarkFindPositive(n int) func(b *testing.B) {
 		b.ResetTimer()
 		b.ReportAllocs()
 		for i := 0; i < b.N; i++ {
-			rt.Find(context.Background(), keys[i%len(keys)])
+			rt.GetNode(keys[i%len(keys)])
 		}
 	}
 }
@@ -404,7 +402,7 @@ func benchmarkFindNegative(n int) func(b *testing.B) {
 		unknown := make([]key.Key32, n)
 		for i := 0; i < n; i++ {
 			kk := kadtest.RandomKey()
-			if found, _ := rt.Find(context.Background(), kk); found != nil {
+			if found, _ := rt.GetNode(kk); found != nil {
 				continue
 			}
 			unknown[i] = kk
@@ -413,7 +411,7 @@ func benchmarkFindNegative(n int) func(b *testing.B) {
 		b.ResetTimer()
 		b.ReportAllocs()
 		for i := 0; i < b.N; i++ {
-			rt.Find(context.Background(), unknown[i%len(unknown)])
+			rt.GetNode(unknown[i%len(unknown)])
 		}
 	}
 }
@@ -459,8 +457,8 @@ func benchmarkChurn(n int) func(b *testing.B) {
 		b.ReportAllocs()
 		for i := 0; i < b.N; i++ {
 			node := universe[i%len(universe)]
-			found, _ := rt.Find(context.Background(), node.Key())
-			if found == nil {
+			_, found := rt.GetNode(node.Key())
+			if !found {
 				// add new peer
 				rt.AddNode(universe[i%len(universe)])
 			} else {