Skip to content

Commit

Permalink
Add broadcast state machine for storing records in the DHT (#930)
Browse files Browse the repository at this point in the history
Co-authored-by: Ian Davis <[email protected]>
  • Loading branch information
dennis-tra and iand authored Sep 22, 2023
1 parent 2da54ab commit 74ffa67
Show file tree
Hide file tree
Showing 37 changed files with 1,992 additions and 251 deletions.
9 changes: 0 additions & 9 deletions v2/dht.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package dht

import (
"context"
"crypto/sha256"
"fmt"
"io"
"sync"
Expand All @@ -13,7 +12,6 @@ import (
"github.com/libp2p/go-libp2p/core/host"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/plprobelab/go-kademlia/key"
"golang.org/x/exp/slog"

"github.com/libp2p/go-libp2p-kad-dht/v2/internal/coord"
Expand Down Expand Up @@ -339,13 +337,6 @@ func (d *DHT) AddAddresses(ctx context.Context, ais []peer.AddrInfo, ttl time.Du
return d.kad.AddNodes(ctx, ids)
}

// newSHA256Key computes the SHA-256 digest of data and wraps the 32 digest
// bytes in a key built by [key.NewKey256], returned as a [kadt.Key].
func newSHA256Key(data []byte) kadt.Key {
	digest := sha256.Sum256(data)
	// Sum256 returns a fixed-size array; slice it to pass the bytes on.
	return key.NewKey256(digest[:])
}

// typedBackend returns the backend at the given namespace. It is cast to the
// provided type. If the namespace doesn't exist or the type cast failed, this
// function returns an error. Can't be a method on [DHT] because of the generic
Expand Down
4 changes: 2 additions & 2 deletions v2/dht_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/libp2p/go-libp2p-kad-dht/v2/internal/coord"
"github.com/libp2p/go-libp2p-kad-dht/v2/internal/coord/coordt"
"github.com/libp2p/go-libp2p-kad-dht/v2/internal/kadtest"
"github.com/libp2p/go-libp2p-kad-dht/v2/kadt"
)
Expand Down Expand Up @@ -82,7 +82,7 @@ func TestAddAddresses(t *testing.T) {

// local routing table should not contain the node
_, err := local.kad.GetNode(ctx, kadt.PeerID(remote.host.ID()))
require.ErrorIs(t, err, coord.ErrNodeNotFound)
require.ErrorIs(t, err, coordt.ErrNodeNotFound)

remoteAddrInfo := peer.AddrInfo{
ID: remote.host.ID(),
Expand Down
4 changes: 2 additions & 2 deletions v2/go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ require (
github.com/libp2p/go-msgio v0.3.0
github.com/multiformats/go-base32 v0.1.0
github.com/multiformats/go-multiaddr v0.11.0
github.com/multiformats/go-multihash v0.2.3 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/plprobelab/go-kademlia v0.0.0-20230913171354-443ec1f56080
github.com/prometheus/client_golang v1.16.0 // indirect
github.com/stretchr/testify v1.8.4
Expand Down Expand Up @@ -84,14 +86,12 @@ require (
github.com/multiformats/go-multiaddr-fmt v0.1.0 // indirect
github.com/multiformats/go-multibase v0.2.0 // indirect
github.com/multiformats/go-multicodec v0.9.0 // indirect
github.com/multiformats/go-multihash v0.2.3 // indirect
github.com/multiformats/go-multistream v0.4.1 // indirect
github.com/multiformats/go-varint v0.0.7 // indirect
github.com/nxadm/tail v1.4.8 // indirect
github.com/onsi/ginkgo/v2 v2.11.0 // indirect
github.com/opencontainers/runtime-spec v1.1.0 // indirect
github.com/pbnjay/memory v0.0.0-20210728143218-7b4eea64cf58 // indirect
github.com/pkg/errors v0.9.1 // indirect
github.com/pmezard/go-difflib v1.0.0 // indirect
github.com/polydawn/refmt v0.89.0 // indirect
github.com/prometheus/client_model v0.4.0 // indirect
Expand Down
4 changes: 2 additions & 2 deletions v2/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ func (d *DHT) handleGetValue(ctx context.Context, remote peer.ID, req *pb.Messag
resp := &pb.Message{
Type: pb.Message_GET_VALUE,
Key: req.GetKey(),
CloserPeers: d.closerPeers(ctx, remote, newSHA256Key(req.GetKey())),
CloserPeers: d.closerPeers(ctx, remote, kadt.NewKey(req.GetKey())),
}

ns, path, err := record.SplitKey(k) // get namespace (prefix of the key)
Expand Down Expand Up @@ -226,7 +226,7 @@ func (d *DHT) handleGetProviders(ctx context.Context, remote peer.ID, req *pb.Me
resp := &pb.Message{
Type: pb.Message_GET_PROVIDERS,
Key: k,
CloserPeers: d.closerPeers(ctx, remote, newSHA256Key(k)),
CloserPeers: d.closerPeers(ctx, remote, kadt.NewKey(k)),
ProviderPeers: pbProviders,
}

Expand Down
4 changes: 0 additions & 4 deletions v2/internal/coord/behaviour.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,10 +37,6 @@ type Behaviour[I BehaviourEvent, O BehaviourEvent] interface {
Perform(ctx context.Context) (O, bool)
}

// SM is a generic state machine interface parameterised by an event type E
// and a state type S.
type SM[E any, S any] interface {
	// Advance takes a context and an event of type E and returns a state of
	// type S.
	Advance(context.Context, E) S
}

// WorkQueueFunc is the signature of a function that receives a context and an
// event of type E and returns a bool.
// NOTE(review): the meaning of the returned bool is not visible in this
// excerpt — confirm against WorkQueue before relying on it.
type WorkQueueFunc[E BehaviourEvent] func(context.Context, E) bool

// WorkQueue is buffered queue of work to be performed.
Expand Down
230 changes: 230 additions & 0 deletions v2/internal/coord/brdcst.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,230 @@
package coord

import (
"context"
"sync"

"go.opentelemetry.io/otel/trace"
"golang.org/x/exp/slog"

"github.com/libp2p/go-libp2p-kad-dht/v2/internal/coord/brdcst"
"github.com/libp2p/go-libp2p-kad-dht/v2/internal/coord/coordt"
"github.com/libp2p/go-libp2p-kad-dht/v2/kadt"
"github.com/libp2p/go-libp2p-kad-dht/v2/pb"
"github.com/libp2p/go-libp2p-kad-dht/v2/tele"
)

// PooledBroadcastBehaviour connects a broadcast pool state machine to the
// behaviour event system: Notify translates inbound behaviour events into
// pool events, and Perform hands out the events that result from advancing
// the pool.
type PooledBroadcastBehaviour struct {
	// pool is the state machine that is advanced with broadcast pool events
	// and reports the next pool state.
	pool coordt.StateMachine[brdcst.PoolEvent, brdcst.PoolState]
	// waiters maps a query ID to the notifier registered when the broadcast
	// was started. It is only accessed while pendingMu is held.
	waiters map[coordt.QueryID]NotifyCloser[BehaviourEvent]

	// pendingMu is held for the full duration of Notify and Perform.
	pendingMu sync.Mutex
	// pending is a FIFO queue of events waiting to be returned by Perform.
	pending []BehaviourEvent
	// ready is a channel with capacity one that receives a non-blocking
	// signal whenever pending is non-empty.
	ready chan struct{}

	logger *slog.Logger
	tracer trace.Tracer
}

// Compile-time check that *PooledBroadcastBehaviour implements Behaviour.
var _ Behaviour[BehaviourEvent, BehaviourEvent] = (*PooledBroadcastBehaviour)(nil)

// NewPooledBroadcastBehaviour builds a PooledBroadcastBehaviour around the
// given broadcast pool. The logger is tagged with the attribute
// behaviour=pooledBroadcast, the waiter registry starts out empty and the
// ready channel is created with a capacity of one.
func NewPooledBroadcastBehaviour(brdcstPool *brdcst.Pool[kadt.Key, kadt.PeerID, *pb.Message], logger *slog.Logger, tracer trace.Tracer) *PooledBroadcastBehaviour {
	return &PooledBroadcastBehaviour{
		pool:    brdcstPool,
		waiters: map[coordt.QueryID]NotifyCloser[BehaviourEvent]{},
		ready:   make(chan struct{}, 1),
		logger:  logger.With("behaviour", "pooledBroadcast"),
		tracer:  tracer,
	}
}

// Ready exposes the receive side of the channel that is signalled whenever
// the behaviour has queued events waiting to be collected through Perform.
func (b *PooledBroadcastBehaviour) Ready() <-chan struct{} {
	var signal <-chan struct{} = b.ready
	return signal
}

// Notify translates an inbound behaviour event into the matching broadcast
// pool event, advances the pool with it and queues any resulting outbound
// event for Perform. Along the way it:
//
//   - registers the waiter supplied with an EventStartBroadcast,
//   - queues an EventAddNode for every closer node reported by a successful
//     response,
//   - queues an EventNotifyNonConnectivity for every failed request, and
//   - forwards progress to the waiter registered for the query, if any.
//
// Events of a type that has no translation are logged and ignored; the pool
// is not advanced for them. The whole method runs with pendingMu held.
func (b *PooledBroadcastBehaviour) Notify(ctx context.Context, ev BehaviourEvent) {
	ctx, span := b.tracer.Start(ctx, "PooledBroadcastBehaviour.Notify")
	defer span.End()

	b.pendingMu.Lock()
	defer b.pendingMu.Unlock()

	var cmd brdcst.PoolEvent
	switch ev := ev.(type) {
	case *EventStartBroadcast:
		cmd = &brdcst.EventPoolStartBroadcast[kadt.Key, kadt.PeerID, *pb.Message]{
			QueryID: ev.QueryID,
			Target:  ev.Target,
			Message: ev.Message,
			Seed:    ev.Seed,
			Config:  ev.Config,
		}
		if ev.Notify != nil {
			b.waiters[ev.QueryID] = ev.Notify
		}

	case *EventGetCloserNodesSuccess:
		for _, info := range ev.CloserNodes {
			b.pending = append(b.pending, &EventAddNode{
				NodeID: info,
			})
		}

		if waiter, ok := b.waiters[ev.QueryID]; ok {
			waiter.Notify(ctx, &EventQueryProgressed{
				NodeID:  ev.To,
				QueryID: ev.QueryID,
			})
		}

		cmd = &brdcst.EventPoolGetCloserNodesSuccess[kadt.Key, kadt.PeerID]{
			NodeID:      ev.To,
			QueryID:     ev.QueryID,
			Target:      ev.Target,
			CloserNodes: ev.CloserNodes,
		}

	case *EventGetCloserNodesFailure:
		// queue an event that will notify the routing behaviour of a failed node
		b.pending = append(b.pending, &EventNotifyNonConnectivity{
			ev.To,
		})

		cmd = &brdcst.EventPoolGetCloserNodesFailure[kadt.Key, kadt.PeerID]{
			NodeID:  ev.To,
			QueryID: ev.QueryID,
			Target:  ev.Target,
			Error:   ev.Err,
		}

	case *EventSendMessageSuccess:
		for _, info := range ev.CloserNodes {
			b.pending = append(b.pending, &EventAddNode{
				NodeID: info,
			})
		}

		if waiter, ok := b.waiters[ev.QueryID]; ok {
			waiter.Notify(ctx, &EventQueryProgressed{
				NodeID:   ev.To,
				QueryID:  ev.QueryID,
				Response: ev.Response,
			})
		}

		// TODO: How do we know it's a StoreRecord response?
		cmd = &brdcst.EventPoolStoreRecordSuccess[kadt.Key, kadt.PeerID, *pb.Message]{
			QueryID:  ev.QueryID,
			NodeID:   ev.To,
			Request:  ev.Request,
			Response: ev.Response,
		}

	case *EventSendMessageFailure:
		// queue an event that will notify the routing behaviour of a failed node
		b.pending = append(b.pending, &EventNotifyNonConnectivity{
			ev.To,
		})

		// TODO: How do we know it's a StoreRecord response?
		cmd = &brdcst.EventPoolStoreRecordFailure[kadt.Key, kadt.PeerID, *pb.Message]{
			NodeID:  ev.To,
			QueryID: ev.QueryID,
			Request: ev.Request,
			Error:   ev.Err,
		}

	case *EventStopQuery:
		cmd = &brdcst.EventPoolStopBroadcast{
			QueryID: ev.QueryID,
		}

	default:
		// There is no pool event for this input, so cmd would stay a nil
		// interface. Do not hand that to the state machine; report the
		// event and leave the pool and the pending queue untouched.
		b.logger.Warn("ignoring unexpected event", "event", ev)
		return
	}

	// attempt to advance the broadcast pool
	if next, ok := b.advancePool(ctx, cmd); ok {
		b.pending = append(b.pending, next)
	}

	// wake up the consumer without blocking if there is work queued
	if len(b.pending) > 0 {
		select {
		case b.ready <- struct{}{}:
		default:
		}
	}
}

// Perform returns the next outbound event of this behaviour. Events that are
// already queued are handed out first, in FIFO order; once the queue is
// empty the pool is polled for further work. The boolean result is false
// when neither the queue nor the pool produced an event.
func (b *PooledBroadcastBehaviour) Perform(ctx context.Context) (BehaviourEvent, bool) {
	ctx, span := b.tracer.Start(ctx, "PooledBroadcastBehaviour.Perform")
	defer span.End()

	// Holding the lock blocks Notify until Perform has returned.
	b.pendingMu.Lock()
	defer b.pendingMu.Unlock()

	for {
		if len(b.pending) == 0 {
			// Nothing queued: ask the pool whether it has work to do.
			polled, produced := b.advancePool(ctx, &brdcst.EventPoolPoll{})
			if produced {
				return polled, true
			}

			// Give up unless events were queued in the meantime.
			if len(b.pending) == 0 {
				return nil, false
			}
			continue
		}

		// Pop the head of the queue.
		head := b.pending[0]
		b.pending = b.pending[1:]

		// More events remain, so make sure the consumer gets woken up again.
		if len(b.pending) != 0 {
			select {
			case b.ready <- struct{}{}:
			default:
			}
		}
		return head, true
	}
}

// advancePool advances the broadcast pool with ev and maps the resulting
// pool state to an outbound behaviour event:
//
//   - StatePoolFindCloser yields an EventOutboundGetCloserNodes,
//   - StatePoolStoreRecord yields an EventOutboundSendMessage,
//   - StatePoolBroadcastFinished notifies and closes the waiter registered
//     for the query and removes it from the waiter registry,
//   - StatePoolIdle and any other state yield nothing.
//
// The returned bool is true only when an outbound event is returned. The
// method reads and mutates b.waiters; both callers in this file (Notify and
// Perform) hold pendingMu while calling it.
func (b *PooledBroadcastBehaviour) advancePool(ctx context.Context, ev brdcst.PoolEvent) (out BehaviourEvent, term bool) {
	ctx, span := b.tracer.Start(ctx, "PooledBroadcastBehaviour.advancePool", trace.WithAttributes(tele.AttrInEvent(ev)))
	defer func() {
		// out is a named result so the span can record what was returned.
		span.SetAttributes(tele.AttrOutEvent(out))
		span.End()
	}()

	pstate := b.pool.Advance(ctx, ev)
	switch st := pstate.(type) {
	case *brdcst.StatePoolIdle:
		// nothing to do
	case *brdcst.StatePoolFindCloser[kadt.Key, kadt.PeerID]:
		return &EventOutboundGetCloserNodes{
			QueryID: st.QueryID,
			To:      st.NodeID,
			Target:  st.Target,
			Notify:  b,
		}, true
	case *brdcst.StatePoolStoreRecord[kadt.Key, kadt.PeerID, *pb.Message]:
		return &EventOutboundSendMessage{
			QueryID: st.QueryID,
			To:      st.NodeID,
			Message: st.Message,
			Notify:  b,
		}, true
	case *brdcst.StatePoolBroadcastFinished[kadt.Key, kadt.PeerID]:
		if waiter, ok := b.waiters[st.QueryID]; ok {
			// The broadcast is over: forget the waiter so the registry does
			// not grow with every finished broadcast and so a late response
			// for this query is never forwarded to a waiter that has
			// already been closed.
			delete(b.waiters, st.QueryID)

			waiter.Notify(ctx, &EventBroadcastFinished{
				QueryID:   st.QueryID,
				Contacted: st.Contacted,
				Errors:    st.Errors,
			})
			waiter.Close()
		}
	}

	return nil, false
}
Loading

0 comments on commit 74ffa67

Please sign in to comment.