From c9c088f31dc50438281732a042c549bff98efc50 Mon Sep 17 00:00:00 2001 From: Adin Schmahmann Date: Thu, 28 May 2020 18:48:14 -0400 Subject: [PATCH 01/15] refactor: extract messaging components from IpfsDHT into its own struct. create a new struct that manages sending DHT messages that can be used independently from the DHT. --- dht.go | 87 +------------------------- dht_net.go | 72 +++++++++++++++------- dht_test.go | 8 +-- go.sum | 6 ++ lookup.go | 4 +- messages.go | 140 ++++++++++++++++++++++++++++++++++++++++++ records.go | 3 +- routing.go | 47 +++----------- subscriber_notifee.go | 17 +---- 9 files changed, 216 insertions(+), 168 deletions(-) create mode 100644 messages.go diff --git a/dht.go b/dht.go index df5a43148..80c012841 100644 --- a/dht.go +++ b/dht.go @@ -1,7 +1,6 @@ package dht import ( - "bytes" "context" "errors" "fmt" @@ -33,7 +32,6 @@ import ( goprocessctx "github.com/jbenet/goprocess/context" "github.com/multiformats/go-base32" ma "github.com/multiformats/go-multiaddr" - "github.com/multiformats/go-multihash" "go.opencensus.io/tag" "go.uber.org/zap" ) @@ -97,8 +95,7 @@ type IpfsDHT struct { ctx context.Context proc goprocess.Process - strmap map[peer.ID]*messageSender - smlk sync.Mutex + protoMessenger *ProtocolMessenger plk sync.Mutex @@ -190,6 +187,7 @@ func New(ctx context.Context, h host.Host, options ...Option) (*IpfsDHT, error) dht.disableFixLowPeers = cfg.disableFixLowPeers dht.Validator = cfg.validator + dht.protoMessenger = NewProtocolMessenger(dht.host, dht.protocols, dht.Validator) dht.testAddressUpdateProcessing = cfg.testAddressUpdateProcessing @@ -276,7 +274,6 @@ func makeDHT(ctx context.Context, h host.Host, cfg config) (*IpfsDHT, error) { selfKey: kb.ConvertPeerID(h.ID()), peerstore: h.Peerstore(), host: h, - strmap: make(map[peer.ID]*messageSender), birth: time.Now(), protocols: protocols, protocolsStrs: protocol.ConvertToStrings(protocols), @@ -530,67 +527,8 @@ func (dht *IpfsDHT) persistRTPeersInPeerStore() { } } -// putValueToPeer stores the given key/value pair at the peer 'p' -func (dht *IpfsDHT) putValueToPeer(ctx context.Context, p peer.ID, rec *recpb.Record) error { - pmes := pb.NewMessage(pb.Message_PUT_VALUE, rec.Key, 0) - pmes.Record = rec - rpmes, err := dht.sendRequest(ctx, p, pmes) - if err != nil { - logger.Debugw("failed to put value to peer", "to", p, "key", loggableRecordKeyBytes(rec.Key), "error", err) - return err - } - - if !bytes.Equal(rpmes.GetRecord().Value, pmes.GetRecord().Value) { - logger.Infow("value not put correctly", "put-message", pmes, "get-message", rpmes) - return errors.New("value not put correctly") - } - - return nil -} - var errInvalidRecord = errors.New("received invalid record") -// getValueOrPeers queries a particular peer p for the value for -// key. It returns either the value or a list of closer peers. -// NOTE: It will update the dht's peerstore with any new addresses -// it finds for the given peer. -func (dht *IpfsDHT) getValueOrPeers(ctx context.Context, p peer.ID, key string) (*recpb.Record, []*peer.AddrInfo, error) { - pmes, err := dht.getValueSingle(ctx, p, key) - if err != nil { - return nil, nil, err - } - - // Perhaps we were given closer peers - peers := pb.PBPeersToPeerInfos(pmes.GetCloserPeers()) - - if rec := pmes.GetRecord(); rec != nil { - // Success! We were given the value - logger.Debug("got value") - - // make sure record is valid. - err = dht.Validator.Validate(string(rec.GetKey()), rec.GetValue()) - if err != nil { - logger.Debug("received invalid record (discarded)") - // return a sentinal to signify an invalid record was received - err = errInvalidRecord - rec = new(recpb.Record) - } - return rec, peers, err - } - - if len(peers) > 0 { - return nil, peers, nil - } - - return nil, nil, routing.ErrNotFound -} - -// getValueSingle simply performs the get value RPC with the given parameters -func (dht *IpfsDHT) getValueSingle(ctx context.Context, p peer.ID, key string) (*pb.Message, error) { - pmes := pb.NewMessage(pb.Message_GET_VALUE, []byte(key), 0) - return dht.sendRequest(ctx, p, pmes) -} - // getLocal attempts to retrieve the value from the datastore func (dht *IpfsDHT) getLocal(key string) (*recpb.Record, error) { logger.Debugw("finding value in datastore", "key", loggableRecordKeyString(key)) @@ -719,17 +657,6 @@ func (dht *IpfsDHT) FindLocal(id peer.ID) peer.AddrInfo { } } -// findPeerSingle asks peer 'p' if they know where the peer with id 'id' is -func (dht *IpfsDHT) findPeerSingle(ctx context.Context, p peer.ID, id peer.ID) (*pb.Message, error) { - pmes := pb.NewMessage(pb.Message_FIND_NODE, []byte(id), 0) - return dht.sendRequest(ctx, p, pmes) -} - -func (dht *IpfsDHT) findProvidersSingle(ctx context.Context, p peer.ID, key multihash.Multihash) (*pb.Message, error) { - pmes := pb.NewMessage(pb.Message_GET_PROVIDERS, key, 0) - return dht.sendRequest(ctx, p, pmes) -} - // nearestPeersToQuery returns the routing tables closest peers. func (dht *IpfsDHT) nearestPeersToQuery(pmes *pb.Message, count int) []peer.ID { closer := dht.routingTable.NearestPeers(kb.ConvertKey(string(pmes.GetKey())), count) @@ -870,15 +797,7 @@ func (dht *IpfsDHT) Host() host.Host { // Ping sends a ping message to the passed peer and waits for a response. func (dht *IpfsDHT) Ping(ctx context.Context, p peer.ID) error { - req := pb.NewMessage(pb.Message_PING, nil, 0) - resp, err := dht.sendRequest(ctx, p, req) - if err != nil { - return fmt.Errorf("sending request: %w", err) - } - if resp.Type != pb.Message_PING { - return fmt.Errorf("got unexpected response type: %v", resp.Type) - } - return nil + return dht.protoMessenger.Ping(ctx, p) } // newContextWithLocalTags returns a new context.Context with the InstanceID and diff --git a/dht_net.go b/dht_net.go index 879be778f..3bf679e55 100644 --- a/dht_net.go +++ b/dht_net.go @@ -8,8 +8,10 @@ import ( "sync" "time" + "github.com/libp2p/go-libp2p-core/host" "github.com/libp2p/go-libp2p-core/network" "github.com/libp2p/go-libp2p-core/peer" + "github.com/libp2p/go-libp2p-core/protocol" "github.com/libp2p/go-msgio/protoio" "github.com/libp2p/go-libp2p-kad-dht/metrics" @@ -208,12 +210,38 @@ func (dht *IpfsDHT) handleNewMessage(s network.Stream) bool { } } +type messageManager struct { + host host.Host // the network services we need + strmap map[peer.ID]*messageSender + smlk sync.Mutex + protocols []protocol.ID +} + +func (m *messageManager) streamDisconnect(ctx context.Context, p peer.ID) { + m.smlk.Lock() + defer m.smlk.Unlock() + ms, ok := m.strmap[p] + if !ok { + return + } + delete(m.strmap, p) + + // Do this asynchronously as ms.lk can block for a while. + go func() { + if err := ms.lk.Lock(ctx); err != nil { + return + } + defer ms.lk.Unlock() + ms.invalidate() + }() +} + // sendRequest sends out a request, but also makes sure to // measure the RTT for latency measurements. -func (dht *IpfsDHT) sendRequest(ctx context.Context, p peer.ID, pmes *pb.Message) (*pb.Message, error) { +func (m *messageManager) sendRequest(ctx context.Context, p peer.ID, pmes *pb.Message) (*pb.Message, error) { ctx, _ = tag.New(ctx, metrics.UpsertMessageType(pmes)) - ms, err := dht.messageSenderForPeer(ctx, p) + ms, err := m.messageSenderForPeer(ctx, p) if err != nil { stats.Record(ctx, metrics.SentRequests.M(1), @@ -240,15 +268,15 @@ func (dht *IpfsDHT) sendRequest(ctx context.Context, p peer.ID, pmes *pb.Message metrics.SentBytes.M(int64(pmes.Size())), metrics.OutboundRequestLatency.M(float64(time.Since(start))/float64(time.Millisecond)), ) - dht.peerstore.RecordLatency(p, time.Since(start)) + m.host.Peerstore().RecordLatency(p, time.Since(start)) return rpmes, nil } // sendMessage sends out a message -func (dht *IpfsDHT) sendMessage(ctx context.Context, p peer.ID, pmes *pb.Message) error { +func (m *messageManager) sendMessage(ctx context.Context, p peer.ID, pmes *pb.Message) error { ctx, _ = tag.New(ctx, metrics.UpsertMessageType(pmes)) - ms, err := dht.messageSenderForPeer(ctx, p) + ms, err := m.messageSenderForPeer(ctx, p) if err != nil { stats.Record(ctx, metrics.SentMessages.M(1), @@ -274,22 +302,22 @@ func (dht *IpfsDHT) sendMessage(ctx context.Context, p peer.ID, pmes *pb.Message return nil } -func (dht *IpfsDHT) messageSenderForPeer(ctx context.Context, p peer.ID) (*messageSender, error) { - dht.smlk.Lock() - ms, ok := dht.strmap[p] +func (m *messageManager) messageSenderForPeer(ctx context.Context, p peer.ID) (*messageSender, error) { + m.smlk.Lock() + ms, ok := m.strmap[p] if ok { - dht.smlk.Unlock() + m.smlk.Unlock() return ms, nil } - ms = &messageSender{p: p, dht: dht, lk: newCtxMutex()} - dht.strmap[p] = ms - dht.smlk.Unlock() + ms = &messageSender{p: p, m: m, lk: newCtxMutex()} + m.strmap[p] = ms + m.smlk.Unlock() if err := ms.prepOrInvalidate(ctx); err != nil { - dht.smlk.Lock() - defer dht.smlk.Unlock() + m.smlk.Lock() + defer m.smlk.Unlock() - if msCur, ok := dht.strmap[p]; ok { + if msCur, ok := m.strmap[p]; ok { // Changed. Use the new one, old one is invalid and // not in the map so we can just throw it away. if ms != msCur { @@ -297,7 +325,7 @@ func (dht *IpfsDHT) messageSenderForPeer(ctx context.Context, p peer.ID) (*messa } // Not changed, remove the now invalid stream from the // map. - delete(dht.strmap, p) + delete(m.strmap, p) } // Invalid but not in map. Must have been removed by a disconnect. return nil, err @@ -307,11 +335,11 @@ func (dht *IpfsDHT) messageSenderForPeer(ctx context.Context, p peer.ID) (*messa } type messageSender struct { - s network.Stream - r msgio.ReadCloser - lk ctxMutex - p peer.ID - dht *IpfsDHT + s network.Stream + r msgio.ReadCloser + lk ctxMutex + p peer.ID + m *messageManager invalid bool singleMes int @@ -352,7 +380,7 @@ func (ms *messageSender) prep(ctx context.Context) error { // We only want to speak to peers using our primary protocols. We do not want to query any peer that only speaks // one of the secondary "server" protocols that we happen to support (e.g. older nodes that we can respond to for // backwards compatibility reasons). - nstr, err := ms.dht.host.NewStream(ctx, ms.p, ms.dht.protocols...) + nstr, err := ms.m.host.NewStream(ctx, ms.p, ms.m.protocols...) if err != nil { return err } diff --git a/dht_test.go b/dht_test.go index 85dcc22f6..53026e1f1 100644 --- a/dht_test.go +++ b/dht_test.go @@ -570,14 +570,14 @@ func TestInvalidMessageSenderTracking(t *testing.T) { defer dht.Close() foo := peer.ID("asdasd") - _, err := dht.messageSenderForPeer(ctx, foo) + _, err := dht.protoMessenger.m.messageSenderForPeer(ctx, foo) if err == nil { t.Fatal("that shouldnt have succeeded") } - dht.smlk.Lock() - mscnt := len(dht.strmap) - dht.smlk.Unlock() + dht.protoMessenger.m.smlk.Lock() + mscnt := len(dht.protoMessenger.m.strmap) + dht.protoMessenger.m.smlk.Unlock() if mscnt > 0 { t.Fatal("should have no message senders in map") diff --git a/go.sum b/go.sum index da973767b..83fe301f2 100644 --- a/go.sum +++ b/go.sum @@ -62,6 +62,7 @@ github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4er github.com/golang/mock v1.1.1/go.mod h1:oTYuIxOrZwtPieC+H1uAHpcLFnEyAGVDL/k47Jfbm0A= github.com/golang/protobuf v1.2.0/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.0/go.mod h1:Qd/q+1AKNOZr9uGQzbzCmRO6sUih6GTPZv6a1/R87v0= +github.com/golang/protobuf v1.3.1 h1:YF8+flBXS5eO826T4nzqPrxfhQThhXl0YzfuUPu4SBg= github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.3.3/go.mod h1:vzj43D7+SQXF/4pzW/hwtAqwc6iTitCiVSaWz5lYuqw= @@ -73,6 +74,7 @@ github.com/golang/protobuf v1.4.0 h1:oOuy+ugB+P/kBdUnG5QaMXSIyJ1q38wWSojYCb3z5VQ github.com/golang/protobuf v1.4.0/go.mod h1:jodUvKwWbYaEsadDk5Fwe5c77LiNKVO9IDvqG2KuDX0= github.com/golang/snappy v0.0.0-20180518054509-2e65f85255db/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/go-cmp v0.2.0/go.mod h1:oXzfMopK8JAjlY9xF4vHSVASa0yLyX7SntLO5aqRK0M= +github.com/google/go-cmp v0.3.0 h1:crn/baboCvb5fXaQ0IJ1SGTsTVrWpDsCWC8EGETZijY= github.com/google/go-cmp v0.3.0/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.3.1/go.mod h1:8QqcDgzrUqlUb/G2PQTWiueGozuR1884gddMywk6iLU= github.com/google/go-cmp v0.4.0 h1:xsAVV57WRhGj6kEIi8ReJzQlHHqcBYCElAvkovg3B/4= @@ -100,6 +102,7 @@ github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ github.com/hashicorp/golang-lru v0.5.4 h1:YDjusn29QI/Das2iO9M0BHnIbxPeyuCHsjMW+lJfyTc= github.com/hashicorp/golang-lru v0.5.4/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= github.com/hashicorp/hcl v1.0.0/go.mod h1:E5yfLk+7swimpb2L/Alb/PJmXilQ/rhwaUYs4T20WEQ= +github.com/hpcloud/tail v1.0.0 h1:nfCOvKYfkgYP8hkirhJocXT2+zOD8yUNjXaWfTlyFKI= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= github.com/huin/goupnp v1.0.0 h1:wg75sLpL6DZqwHQN6E1Cfk6mtfzS45z8OV+ic+DtHRo= github.com/huin/goupnp v1.0.0/go.mod h1:n9v9KO1tAxYH82qOn+UTIFQDmx5n1Zxd/ClZDMX7Bnc= @@ -488,6 +491,7 @@ github.com/nxadm/tail v1.4.4/go.mod h1:kenIhsEOeOJmVchQTgglprH7qJGnHDVpk1VPCcaMI github.com/onsi/ginkgo v1.6.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.7.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= github.com/onsi/ginkgo v1.8.0/go.mod h1:lLunBs/Ym6LB5Z9jYTR76FiuTmxDTDusOGeTQH+WWjE= +github.com/onsi/ginkgo v1.12.0 h1:Iw5WCbBcaAAd0fpRb1c9r5YCylv4XDoCSigm1zLevwU= github.com/onsi/ginkgo v1.12.0/go.mod h1:oUhWkIvk5aDxtKvDDuw8gItl8pKl42LzjC9KZE0HfGg= github.com/onsi/ginkgo v1.12.1 h1:mFwc4LvZ0xpSvDZ3E+k8Yte0hLOMxXUlP+yXtJqkYfQ= github.com/onsi/ginkgo v1.12.1/go.mod h1:zj2OWP4+oCPe1qIXoGWkgMRwljMUYCdkwsT2108oapk= @@ -652,6 +656,7 @@ golang.org/x/tools v0.0.0-20191108193012-7d206e10da11/go.mod h1:b+2E5dAYhXwXZwtn golang.org/x/tools v0.0.0-20191216052735-49a3e744a425 h1:VvQyQJN0tSuecqgcIxMWnnfG5kSmgy9KZR9sW3W5QeA= golang.org/x/tools v0.0.0-20191216052735-49a3e744a425/go.mod h1:TB2adYChydJhpapKDTa4BR/hXlZSLoq2Wpct/0txZ28= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898 h1:/atklqdjdhuosWIl6AIbOeHJjicWYPqR9bpxqxYG2pA= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -682,6 +687,7 @@ gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8 gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= +gopkg.in/fsnotify.v1 v1.4.7 h1:xOHLXZwVvI9hhs+cLKq5+I5onOuwQLhQwiu63xxlHs4= gopkg.in/fsnotify.v1 v1.4.7/go.mod h1:Tz8NjZHkW78fSQdbUxIjBTcgA1z1m8ZHf0WmKUhAMys= gopkg.in/src-d/go-cli.v0 v0.0.0-20181105080154-d492247bbc0d/go.mod h1:z+K8VcOYVYcSwSjGebuDL6176A1XskgbtNl64NSg+n8= gopkg.in/src-d/go-log.v1 v1.0.1/go.mod h1:GN34hKP0g305ysm2/hctJ0Y8nWP3zxXXJ8GFabTyABE= diff --git a/lookup.go b/lookup.go index 0168601a5..dff8bb244 100644 --- a/lookup.go +++ b/lookup.go @@ -8,7 +8,6 @@ import ( "github.com/libp2p/go-libp2p-core/peer" "github.com/libp2p/go-libp2p-core/routing" - pb "github.com/libp2p/go-libp2p-kad-dht/pb" kb "github.com/libp2p/go-libp2p-kbucket" ) @@ -30,12 +29,11 @@ func (dht *IpfsDHT) GetClosestPeers(ctx context.Context, key string) (<-chan pee ID: p, }) - pmes, err := dht.findPeerSingle(ctx, p, peer.ID(key)) + peers, err := dht.protoMessenger.GetClosestPeers(ctx, p, peer.ID(key)) if err != nil { logger.Debugf("error getting closer peers: %s", err) return nil, err } - peers := pb.PBPeersToPeerInfos(pmes.GetCloserPeers()) // For DHT query command routing.PublishQueryEvent(ctx, &routing.QueryEvent{ diff --git a/messages.go b/messages.go new file mode 100644 index 000000000..e9c25415e --- /dev/null +++ b/messages.go @@ -0,0 +1,140 @@ +package dht + +import ( + "bytes" + "context" + "errors" + "fmt" + + "github.com/libp2p/go-libp2p-core/host" + "github.com/libp2p/go-libp2p-core/peer" + "github.com/libp2p/go-libp2p-core/protocol" + "github.com/libp2p/go-libp2p-core/routing" + + record "github.com/libp2p/go-libp2p-record" + recpb "github.com/libp2p/go-libp2p-record/pb" + "github.com/multiformats/go-multihash" + + pb "github.com/libp2p/go-libp2p-kad-dht/pb" +) + +type ProtocolMessenger struct { + m *messageManager + validator record.Validator +} + +func NewProtocolMessenger(host host.Host, protocols []protocol.ID, validator record.Validator) *ProtocolMessenger { + return &ProtocolMessenger{ + m: &messageManager{ + host: host, + strmap: make(map[peer.ID]*messageSender), + protocols: protocols, + }, + validator: validator, + } +} + +// putValueToPeer stores the given key/value pair at the peer 'p' +func (pm *ProtocolMessenger) PutValue(ctx context.Context, p peer.ID, rec *recpb.Record) error { + pmes := pb.NewMessage(pb.Message_PUT_VALUE, rec.Key, 0) + pmes.Record = rec + rpmes, err := pm.m.sendRequest(ctx, p, pmes) + if err != nil { + logger.Debugw("failed to put value to peer", "to", p, "key", loggableRecordKeyBytes(rec.Key), "error", err) + return err + } + + if !bytes.Equal(rpmes.GetRecord().Value, pmes.GetRecord().Value) { + logger.Infow("value not put correctly", "put-message", pmes, "get-message", rpmes) + return errors.New("value not put correctly") + } + + return nil +} + +// GetValue queries a particular peer p for the value for +// key. It returns the value and a list of closer peers. +func (pm *ProtocolMessenger) GetValue(ctx context.Context, p peer.ID, key string) (*recpb.Record, []*peer.AddrInfo, error) { + pmes := pb.NewMessage(pb.Message_GET_VALUE, []byte(key), 0) + respMsg, err := pm.m.sendRequest(ctx, p, pmes) + if err != nil { + return nil, nil, err + } + + // Perhaps we were given closer peers + peers := pb.PBPeersToPeerInfos(respMsg.GetCloserPeers()) + + if rec := respMsg.GetRecord(); rec != nil { + // Success! We were given the value + logger.Debug("got value") + + // make sure record is valid. + err = pm.validator.Validate(string(rec.GetKey()), rec.GetValue()) + if err != nil { + logger.Debug("received invalid record (discarded)") + // return a sentinal to signify an invalid record was received + err = errInvalidRecord + rec = new(recpb.Record) + } + return rec, peers, err + } + + if len(peers) > 0 { + return nil, peers, nil + } + + return nil, nil, routing.ErrNotFound +} + +// findPeerSingle asks peer 'p' if they know where the peer with id 'id' is +func (pm *ProtocolMessenger) GetClosestPeers(ctx context.Context, p peer.ID, id peer.ID) ([]*peer.AddrInfo, error) { + pmes := pb.NewMessage(pb.Message_FIND_NODE, []byte(id), 0) + respMsg, err := pm.m.sendRequest(ctx, p, pmes) + if err != nil { + return nil, err + } + peers := pb.PBPeersToPeerInfos(respMsg.GetCloserPeers()) + return peers, nil +} + +func (pm *ProtocolMessenger) PutProvider(ctx context.Context, p peer.ID, key multihash.Multihash, host host.Host) error { + pi := peer.AddrInfo{ + ID: host.ID(), + Addrs: host.Addrs(), + } + + // // only share WAN-friendly addresses ?? + // pi.Addrs = addrutil.WANShareableAddrs(pi.Addrs) + if len(pi.Addrs) < 1 { + return fmt.Errorf("no known addresses for self, cannot put provider") + } + + pmes := pb.NewMessage(pb.Message_ADD_PROVIDER, key, 0) + pmes.ProviderPeers = pb.RawPeerInfosToPBPeers([]peer.AddrInfo{pi}) + + return pm.m.sendMessage(ctx, p, pmes) +} + +func (pm *ProtocolMessenger) GetProviders(ctx context.Context, p peer.ID, key multihash.Multihash) ([]*peer.AddrInfo, []*peer.AddrInfo, error) { + pmes := pb.NewMessage(pb.Message_GET_PROVIDERS, key, 0) + respMsg, err := pm.m.sendRequest(ctx, p, pmes) + if err != nil { + return nil, nil, err + } + provs := pb.PBPeersToPeerInfos(respMsg.GetProviderPeers()) + closerPeers := pb.PBPeersToPeerInfos(respMsg.GetCloserPeers()) + return provs, closerPeers, nil +} + +// Ping sends a ping message to the passed peer and waits for a response. +func (pm *ProtocolMessenger) Ping(ctx context.Context, p peer.ID) error { + req := pb.NewMessage(pb.Message_PING, nil, 0) + resp, err := pm.m.sendRequest(ctx, p, req) + if err != nil { + return fmt.Errorf("sending request: %w", err) + } + if resp.Type != pb.Message_PING { + return fmt.Errorf("got unexpected response type: %v", resp.Type) + } + return nil +} diff --git a/records.go b/records.go index adb28ce7d..bba505080 100644 --- a/records.go +++ b/records.go @@ -98,13 +98,12 @@ func (dht *IpfsDHT) getPublicKeyFromNode(ctx context.Context, p peer.ID) (ci.Pub // Get the key from the node itself pkkey := routing.KeyForPublicKey(p) - pmes, err := dht.getValueSingle(ctx, p, pkkey) + record, _, err := dht.protoMessenger.GetValue(ctx, p, pkkey) if err != nil { return nil, err } // node doesn't have key :( - record := pmes.GetRecord() if record == nil { return nil, fmt.Errorf("node %v not responding with its public key", p) } diff --git a/routing.go b/routing.go index 4d6ae990c..885dfe2e2 100644 --- a/routing.go +++ b/routing.go @@ -14,7 +14,6 @@ import ( "github.com/ipfs/go-cid" u "github.com/ipfs/go-ipfs-util" - pb "github.com/libp2p/go-libp2p-kad-dht/pb" "github.com/libp2p/go-libp2p-kad-dht/qpeerset" kb "github.com/libp2p/go-libp2p-kbucket" record "github.com/libp2p/go-libp2p-record" @@ -81,7 +80,7 @@ func (dht *IpfsDHT) PutValue(ctx context.Context, key string, value []byte, opts ID: p, }) - err := dht.putValueToPeer(ctx, p, rec) + err := dht.protoMessenger.PutValue(ctx, p, rec) if err != nil { logger.Debugf("failed putting value to peer: %s", err) } @@ -281,7 +280,7 @@ func (dht *IpfsDHT) updatePeerValues(ctx context.Context, key string, val []byte } ctx, cancel := context.WithTimeout(ctx, time.Second*30) defer cancel() - err := dht.putValueToPeer(ctx, p, fixupRec) + err := dht.protoMessenger.PutValue(ctx, p, fixupRec) if err != nil { logger.Debug("Error correcting DHT entry: ", err) } @@ -316,7 +315,7 @@ func (dht *IpfsDHT) getValues(ctx context.Context, key string, stopQuery chan st ID: p, }) - rec, peers, err := dht.getValueOrPeers(ctx, p, key) + rec, peers, err := dht.protoMessenger.GetValue(ctx, p, key) switch err { case routing.ErrNotFound: // in this case, they responded with nothing, @@ -444,18 +443,13 @@ func (dht *IpfsDHT) Provide(ctx context.Context, key cid.Cid, brdcst bool) (err return err } - mes, err := dht.makeProvRecord(keyMH) - if err != nil { - return err - } - wg := sync.WaitGroup{} for p := range peers { wg.Add(1) go func(p peer.ID) { defer wg.Done() logger.Debugf("putProvider(%s, %s)", loggableProviderRecordBytes(keyMH), p) - err := dht.sendMessage(ctx, p, mes) + err := dht.protoMessenger.PutProvider(ctx, p, keyMH, dht.host) if err != nil { logger.Debug(err) } @@ -467,22 +461,6 @@ func (dht *IpfsDHT) Provide(ctx context.Context, key cid.Cid, brdcst bool) (err } return ctx.Err() } -func (dht *IpfsDHT) makeProvRecord(key []byte) (*pb.Message, error) { - pi := peer.AddrInfo{ - ID: dht.self, - Addrs: dht.host.Addrs(), - } - - // // only share WAN-friendly addresses ?? - // pi.Addrs = addrutil.WANShareableAddrs(pi.Addrs) - if len(pi.Addrs) < 1 { - return nil, fmt.Errorf("no known addresses for self, cannot put provider") - } - - pmes := pb.NewMessage(pb.Message_ADD_PROVIDER, key, 0) - pmes.ProviderPeers = pb.RawPeerInfosToPBPeers([]peer.AddrInfo{pi}) - return pmes, nil -} // FindProviders searches until the context expires. func (dht *IpfsDHT) FindProviders(ctx context.Context, c cid.Cid) ([]peer.AddrInfo, error) { @@ -562,14 +540,12 @@ func (dht *IpfsDHT) findProvidersAsyncRoutine(ctx context.Context, key multihash ID: p, }) - pmes, err := dht.findProvidersSingle(ctx, p, key) + provs, closest, err := dht.protoMessenger.GetProviders(ctx, p, key) if err != nil { return nil, err } - logger.Debugf("%d provider entries", len(pmes.GetProviderPeers())) - provs := pb.PBPeersToPeerInfos(pmes.GetProviderPeers()) - logger.Debugf("%d provider entries decoded", len(provs)) + logger.Debugf("%d provider entries", len(provs)) // Add unique providers from request, up to 'count' for _, prov := range provs { @@ -591,17 +567,15 @@ func (dht *IpfsDHT) findProvidersAsyncRoutine(ctx context.Context, key multihash } // Give closer peers back to the query to be queried - closer := pmes.GetCloserPeers() - peers := pb.PBPeersToPeerInfos(closer) - logger.Debugf("got closer peers: %d %s", len(peers), peers) + logger.Debugf("got closer peers: %d %s", len(closest), closest) routing.PublishQueryEvent(ctx, &routing.QueryEvent{ Type: routing.PeerResponse, ID: p, - Responses: peers, + Responses: closest, }) - return peers, nil + return closest, nil }, func() bool { return !findAll && ps.Size() >= count @@ -634,12 +608,11 @@ func (dht *IpfsDHT) FindPeer(ctx context.Context, id peer.ID) (_ peer.AddrInfo, ID: p, }) - pmes, err := dht.findPeerSingle(ctx, p, id) + peers, err := dht.protoMessenger.GetClosestPeers(ctx, p, id) if err != nil { logger.Debugf("error getting closer peers: %s", err) return nil, err } - peers := pb.PBPeersToPeerInfos(pmes.GetCloserPeers()) // For DHT query command routing.PublishQueryEvent(ctx, &routing.QueryEvent{ diff --git a/subscriber_notifee.go b/subscriber_notifee.go index 8211d25de..d8be15fca 100644 --- a/subscriber_notifee.go +++ b/subscriber_notifee.go @@ -173,22 +173,7 @@ func (nn *subscriberNotifee) Disconnected(n network.Network, v network.Conn) { return } - dht.smlk.Lock() - defer dht.smlk.Unlock() - ms, ok := dht.strmap[p] - if !ok { - return - } - delete(dht.strmap, p) - - // Do this asynchronously as ms.lk can block for a while. - go func() { - if err := ms.lk.Lock(dht.Context()); err != nil { - return - } - defer ms.lk.Unlock() - ms.invalidate() - }() + dht.protoMessenger.m.streamDisconnect(dht.Context(), p) } func (nn *subscriberNotifee) Connected(network.Network, network.Conn) {} From a3c7c3dc1ee6d2d8f3124cebd4b52c949f8cfe99 Mon Sep 17 00:00:00 2001 From: Adin Schmahmann Date: Tue, 6 Oct 2020 18:21:11 -0400 Subject: [PATCH 02/15] docs: cleaned up comments --- messages.go | 27 ++++++++++++++++++++------- 1 file changed, 20 insertions(+), 7 deletions(-) diff --git a/messages.go b/messages.go index e9c25415e..8f6db182d 100644 --- a/messages.go +++ b/messages.go @@ -18,11 +18,19 @@ import ( pb "github.com/libp2p/go-libp2p-kad-dht/pb" ) +// ProtocolMessenger can be used for sending DHT messages to peers and processing their responses. +// This decouples the wire protocol format from both the DHT protocol implementation and from the implementation of the +// routing.Routing interface. +// +// TODO: This is still strongly coupled with the existing implementation of what happens when a peer actually sends a +// message on the wire (e.g. reusing streams, reusing connections, metrics tracking, etc.). type ProtocolMessenger struct { m *messageManager validator record.Validator } +// NewProtocolMessenger creates a new ProtocolMessenger that is used for sending DHT messages to peers and processing +// their responses. func NewProtocolMessenger(host host.Host, protocols []protocol.ID, validator record.Validator) *ProtocolMessenger { return &ProtocolMessenger{ m: &messageManager{ @@ -34,7 +42,7 @@ func NewProtocolMessenger(host host.Host, protocols []protocol.ID, validator rec } } -// putValueToPeer stores the given key/value pair at the peer 'p' +// PutValue asks a peer to store the given key/value pair. func (pm *ProtocolMessenger) PutValue(ctx context.Context, p peer.ID, rec *recpb.Record) error { pmes := pb.NewMessage(pb.Message_PUT_VALUE, rec.Key, 0) pmes.Record = rec @@ -52,8 +60,8 @@ func (pm *ProtocolMessenger) PutValue(ctx context.Context, p peer.ID, rec *recpb return nil } -// GetValue queries a particular peer p for the value for -// key. It returns the value and a list of closer peers. +// GetValue asks a peer for the value corresponding to the given key. Also returns the K closest peers to the key +// as described in GetClosestPeers. func (pm *ProtocolMessenger) GetValue(ctx context.Context, p peer.ID, key string) (*recpb.Record, []*peer.AddrInfo, error) { pmes := pb.NewMessage(pb.Message_GET_VALUE, []byte(key), 0) respMsg, err := pm.m.sendRequest(ctx, p, pmes) @@ -72,7 +80,7 @@ func (pm *ProtocolMessenger) GetValue(ctx context.Context, p peer.ID, key string err = pm.validator.Validate(string(rec.GetKey()), rec.GetValue()) if err != nil { logger.Debug("received invalid record (discarded)") - // return a sentinal to signify an invalid record was received + // return a sentinel to signify an invalid record was received err = errInvalidRecord rec = new(recpb.Record) } @@ -86,7 +94,9 @@ func (pm *ProtocolMessenger) GetValue(ctx context.Context, p peer.ID, key string return nil, nil, routing.ErrNotFound } -// findPeerSingle asks peer 'p' if they know where the peer with id 'id' is +// GetClosestPeers asks a peer to return the K (a DHT-wide parameter) DHT server peers closest in XOR space to the id +// Note: If the peer happens to know another peer whose peerID exactly matches the given id it will return that peer +// even if that peer is not a DHT server node. func (pm *ProtocolMessenger) GetClosestPeers(ctx context.Context, p peer.ID, id peer.ID) ([]*peer.AddrInfo, error) { pmes := pb.NewMessage(pb.Message_FIND_NODE, []byte(id), 0) respMsg, err := pm.m.sendRequest(ctx, p, pmes) @@ -97,14 +107,15 @@ func (pm *ProtocolMessenger) GetClosestPeers(ctx context.Context, p peer.ID, id return peers, nil } +// PutProvider asks a peer to store that we are a provider for the given key. func (pm *ProtocolMessenger) PutProvider(ctx context.Context, p peer.ID, key multihash.Multihash, host host.Host) error { pi := peer.AddrInfo{ ID: host.ID(), Addrs: host.Addrs(), } - // // only share WAN-friendly addresses ?? - // pi.Addrs = addrutil.WANShareableAddrs(pi.Addrs) + // TODO: We may want to limit the type of addresses in our provider records + // For example, in a WAN-only DHT prohibit sharing non-WAN addresses (e.g. 192.168.0.100) if len(pi.Addrs) < 1 { return fmt.Errorf("no known addresses for self, cannot put provider") } @@ -115,6 +126,8 @@ func (pm *ProtocolMessenger) PutProvider(ctx context.Context, p peer.ID, key mul return pm.m.sendMessage(ctx, p, pmes) } +// GetProviders asks a peer for the providers it knows of for a given key. Also returns the K closest peers to the key +// as described in GetClosestPeers. func (pm *ProtocolMessenger) GetProviders(ctx context.Context, p peer.ID, key multihash.Multihash) ([]*peer.AddrInfo, []*peer.AddrInfo, error) { pmes := pb.NewMessage(pb.Message_GET_PROVIDERS, key, 0) respMsg, err := pm.m.sendRequest(ctx, p, pmes) From 94d08a2b668ae3d6288c123c160d031c3baa9f06 Mon Sep 17 00:00:00 2001 From: Adin Schmahmann Date: Tue, 6 Oct 2020 19:09:37 -0400 Subject: [PATCH 03/15] refactor: decouple ProtocolMessenger from MessageSender implementation --- dht.go | 11 +++++++++- dht_net.go | 10 ++++----- dht_test.go | 8 +++---- messages.go | 51 +++++++++++++++++++++++++++++-------------- subscriber_notifee.go | 2 +- 5 files changed, 55 insertions(+), 27 deletions(-) diff --git a/dht.go b/dht.go index 80c012841..67848600d 100644 --- a/dht.go +++ b/dht.go @@ -96,6 +96,7 @@ type IpfsDHT struct { proc goprocess.Process protoMessenger *ProtocolMessenger + messageMgr *messageManager plk sync.Mutex @@ -187,7 +188,15 @@ func New(ctx context.Context, h host.Host, options ...Option) (*IpfsDHT, error) dht.disableFixLowPeers = cfg.disableFixLowPeers dht.Validator = cfg.validator - dht.protoMessenger = NewProtocolMessenger(dht.host, dht.protocols, dht.Validator) + dht.messageMgr = &messageManager{ + host: h, + strmap: make(map[peer.ID]*messageSender), + protocols: dht.protocols, + } + dht.protoMessenger, err = NewProtocolMessenger(dht.messageMgr, WithValidator(dht.Validator)) + if err != nil { + return nil, err + } dht.testAddressUpdateProcessing = cfg.testAddressUpdateProcessing diff --git a/dht_net.go b/dht_net.go index 3bf679e55..35c52bb06 100644 --- a/dht_net.go +++ b/dht_net.go @@ -212,8 +212,8 @@ func (dht *IpfsDHT) handleNewMessage(s network.Stream) bool { type messageManager struct { host host.Host // the network services we need - strmap map[peer.ID]*messageSender smlk sync.Mutex + strmap map[peer.ID]*messageSender protocols []protocol.ID } @@ -236,9 +236,9 @@ func (m *messageManager) streamDisconnect(ctx context.Context, p peer.ID) { }() } -// sendRequest sends out a request, but also makes sure to +// SendRequest sends out a request, but also makes sure to // measure the RTT for latency measurements. -func (m *messageManager) sendRequest(ctx context.Context, p peer.ID, pmes *pb.Message) (*pb.Message, error) { +func (m *messageManager) SendRequest(ctx context.Context, p peer.ID, pmes *pb.Message) (*pb.Message, error) { ctx, _ = tag.New(ctx, metrics.UpsertMessageType(pmes)) ms, err := m.messageSenderForPeer(ctx, p) @@ -272,8 +272,8 @@ func (m *messageManager) sendRequest(ctx context.Context, p peer.ID, pmes *pb.Me return rpmes, nil } -// sendMessage sends out a message -func (m *messageManager) sendMessage(ctx context.Context, p peer.ID, pmes *pb.Message) error { +// SendMessage sends out a message +func (m *messageManager) SendMessage(ctx context.Context, p peer.ID, pmes *pb.Message) error { ctx, _ = tag.New(ctx, metrics.UpsertMessageType(pmes)) ms, err := m.messageSenderForPeer(ctx, p) diff --git a/dht_test.go b/dht_test.go index 53026e1f1..37f5142f7 100644 --- a/dht_test.go +++ b/dht_test.go @@ -570,14 +570,14 @@ func TestInvalidMessageSenderTracking(t *testing.T) { defer dht.Close() foo := peer.ID("asdasd") - _, err := dht.protoMessenger.m.messageSenderForPeer(ctx, foo) + _, err := dht.messageMgr.messageSenderForPeer(ctx, foo) if err == nil { t.Fatal("that shouldnt have succeeded") } - dht.protoMessenger.m.smlk.Lock() - mscnt := len(dht.protoMessenger.m.strmap) - dht.protoMessenger.m.smlk.Unlock() + dht.messageMgr.smlk.Lock() + mscnt := len(dht.messageMgr.strmap) + dht.messageMgr.smlk.Unlock() if mscnt > 0 { t.Fatal("should have no message senders in map") diff --git a/messages.go b/messages.go index 8f6db182d..a0c3f4523 100644 --- a/messages.go +++ b/messages.go @@ -8,7 +8,6 @@ import ( "github.com/libp2p/go-libp2p-core/host" "github.com/libp2p/go-libp2p-core/peer" - "github.com/libp2p/go-libp2p-core/protocol" "github.com/libp2p/go-libp2p-core/routing" record "github.com/libp2p/go-libp2p-record" @@ -25,28 +24,48 @@ import ( // TODO: This is still strongly coupled with the existing implementation of what happens when a peer actually sends a // message on the wire (e.g. reusing streams, reusing connections, metrics tracking, etc.). type ProtocolMessenger struct { - m *messageManager + m MessageSender validator record.Validator } +type ProtocolMessengerOption func(*ProtocolMessenger) error + +func WithValidator(validator record.Validator) ProtocolMessengerOption { + return func(messenger *ProtocolMessenger) error { + messenger.validator = validator + return nil + } +} + // NewProtocolMessenger creates a new ProtocolMessenger that is used for sending DHT messages to peers and processing // their responses. -func NewProtocolMessenger(host host.Host, protocols []protocol.ID, validator record.Validator) *ProtocolMessenger { - return &ProtocolMessenger{ - m: &messageManager{ - host: host, - strmap: make(map[peer.ID]*messageSender), - protocols: protocols, - }, - validator: validator, +func NewProtocolMessenger(msgSender MessageSender, opts ...ProtocolMessengerOption) (*ProtocolMessenger, error) { + pm := &ProtocolMessenger{ + m: msgSender, + } + + for _, o := range opts { + if err := o(pm); err != nil { + return nil, err + } } + + return pm, nil +} + +// MessageSender handles sending wire protocol messages to a given peer +type MessageSender interface { + // SendRequest sends a peer a message and waits for its response + SendRequest(ctx context.Context, p peer.ID, pmes *pb.Message) (*pb.Message, error) + // SendMessage sends a peer a message without waiting on a response + SendMessage(ctx context.Context, p peer.ID, pmes *pb.Message) error } // PutValue asks a peer to store the given key/value pair. func (pm *ProtocolMessenger) PutValue(ctx context.Context, p peer.ID, rec *recpb.Record) error { pmes := pb.NewMessage(pb.Message_PUT_VALUE, rec.Key, 0) pmes.Record = rec - rpmes, err := pm.m.sendRequest(ctx, p, pmes) + rpmes, err := pm.m.SendRequest(ctx, p, pmes) if err != nil { logger.Debugw("failed to put value to peer", "to", p, "key", loggableRecordKeyBytes(rec.Key), "error", err) return err @@ -64,7 +83,7 @@ func (pm *ProtocolMessenger) PutValue(ctx context.Context, p peer.ID, rec *recpb // as described in GetClosestPeers. func (pm *ProtocolMessenger) GetValue(ctx context.Context, p peer.ID, key string) (*recpb.Record, []*peer.AddrInfo, error) { pmes := pb.NewMessage(pb.Message_GET_VALUE, []byte(key), 0) - respMsg, err := pm.m.sendRequest(ctx, p, pmes) + respMsg, err := pm.m.SendRequest(ctx, p, pmes) if err != nil { return nil, nil, err } @@ -99,7 +118,7 @@ func (pm *ProtocolMessenger) GetValue(ctx context.Context, p peer.ID, key string // even if that peer is not a DHT server node. func (pm *ProtocolMessenger) GetClosestPeers(ctx context.Context, p peer.ID, id peer.ID) ([]*peer.AddrInfo, error) { pmes := pb.NewMessage(pb.Message_FIND_NODE, []byte(id), 0) - respMsg, err := pm.m.sendRequest(ctx, p, pmes) + respMsg, err := pm.m.SendRequest(ctx, p, pmes) if err != nil { return nil, err } @@ -123,14 +142,14 @@ func (pm *ProtocolMessenger) PutProvider(ctx context.Context, p peer.ID, key mul pmes := pb.NewMessage(pb.Message_ADD_PROVIDER, key, 0) pmes.ProviderPeers = pb.RawPeerInfosToPBPeers([]peer.AddrInfo{pi}) - return pm.m.sendMessage(ctx, p, pmes) + return pm.m.SendMessage(ctx, p, pmes) } // GetProviders asks a peer for the providers it knows of for a given key. Also returns the K closest peers to the key // as described in GetClosestPeers. func (pm *ProtocolMessenger) GetProviders(ctx context.Context, p peer.ID, key multihash.Multihash) ([]*peer.AddrInfo, []*peer.AddrInfo, error) { pmes := pb.NewMessage(pb.Message_GET_PROVIDERS, key, 0) - respMsg, err := pm.m.sendRequest(ctx, p, pmes) + respMsg, err := pm.m.SendRequest(ctx, p, pmes) if err != nil { return nil, nil, err } @@ -142,7 +161,7 @@ func (pm *ProtocolMessenger) GetProviders(ctx context.Context, p peer.ID, key mu // Ping sends a ping message to the passed peer and waits for a response. func (pm *ProtocolMessenger) Ping(ctx context.Context, p peer.ID) error { req := pb.NewMessage(pb.Message_PING, nil, 0) - resp, err := pm.m.sendRequest(ctx, p, req) + resp, err := pm.m.SendRequest(ctx, p, req) if err != nil { return fmt.Errorf("sending request: %w", err) } diff --git a/subscriber_notifee.go b/subscriber_notifee.go index d8be15fca..a9fa8849e 100644 --- a/subscriber_notifee.go +++ b/subscriber_notifee.go @@ -173,7 +173,7 @@ func (nn *subscriberNotifee) Disconnected(n network.Network, v network.Conn) { return } - dht.protoMessenger.m.streamDisconnect(dht.Context(), p) + dht.messageMgr.streamDisconnect(dht.Context(), p) } func (nn *subscriberNotifee) Connected(network.Network, network.Conn) {} From 7b5446ab2f481b931dba89e14628f48856a99485 Mon Sep 17 00:00:00 2001 From: Adin Schmahmann Date: Tue, 6 Oct 2020 21:44:39 -0400 Subject: [PATCH 04/15] refactor: temporarily export loggable bytes array aliases before moving them into a new package --- dht.go | 8 ++++---- handlers.go | 8 ++++---- logging.go | 22 +++++++++++----------- messages.go | 2 +- routing.go | 14 +++++++------- 5 files changed, 27 insertions(+), 27 deletions(-) diff --git a/dht.go b/dht.go index 67848600d..c5b2dc219 100644 --- a/dht.go +++ b/dht.go @@ -540,17 +540,17 @@ var errInvalidRecord = errors.New("received invalid record") // getLocal attempts to retrieve the value from the datastore func (dht *IpfsDHT) getLocal(key string) (*recpb.Record, error) { - logger.Debugw("finding value in datastore", "key", loggableRecordKeyString(key)) + logger.Debugw("finding value in datastore", "key", LoggableRecordKeyString(key)) rec, err := dht.getRecordFromDatastore(mkDsKey(key)) if err != nil { - logger.Warnw("get local failed", "key", loggableRecordKeyString(key), "error", err) + logger.Warnw("get local failed", "key", LoggableRecordKeyString(key), "error", err) return nil, err } // Double check the key. Can't hurt. if rec != nil && string(rec.GetKey()) != key { - logger.Errorw("BUG: found a DHT record that didn't match it's key", "expected", loggableRecordKeyString(key), "got", rec.GetKey()) + logger.Errorw("BUG: found a DHT record that didn't match it's key", "expected", LoggableRecordKeyString(key), "got", rec.GetKey()) return nil, nil } @@ -561,7 +561,7 @@ func (dht *IpfsDHT) getLocal(key string) (*recpb.Record, error) { func (dht *IpfsDHT) putLocal(key string, rec *recpb.Record) error { data, err := proto.Marshal(rec) if err != nil { - logger.Warnw("failed to put marshal record for local put", "error", err, "key", loggableRecordKeyString(key)) + logger.Warnw("failed to put marshal record for local put", "error", err, "key", LoggableRecordKeyString(key)) return err } diff --git a/handlers.go b/handlers.go index 99bea1942..0aab5713a 100644 --- a/handlers.go +++ b/handlers.go @@ -167,7 +167,7 @@ func (dht *IpfsDHT) handlePutValue(ctx context.Context, p peer.ID, pmes *pb.Mess // Make sure the record is valid (not expired, valid signature etc) if err = dht.Validator.Validate(string(rec.GetKey()), rec.GetValue()); err != nil { - logger.Infow("bad dht record in PUT", "from", p, "key", loggableRecordKeyBytes(rec.GetKey()), "error", err) + logger.Infow("bad dht record in PUT", "from", p, "key", LoggableRecordKeyBytes(rec.GetKey()), "error", err) return nil, err } @@ -196,11 +196,11 @@ func (dht *IpfsDHT) handlePutValue(ctx context.Context, p peer.ID, pmes *pb.Mess recs := [][]byte{rec.GetValue(), existing.GetValue()} i, err := dht.Validator.Select(string(rec.GetKey()), recs) if err != nil { - logger.Warnw("dht record passed validation but failed select", "from", p, "key", loggableRecordKeyBytes(rec.GetKey()), "error", err) + logger.Warnw("dht record passed validation but failed select", "from", p, "key", LoggableRecordKeyBytes(rec.GetKey()), "error", err) return nil, err } if i != 0 { - logger.Infow("DHT record in PUT older than existing record (ignoring)", "peer", p, "key", loggableRecordKeyBytes(rec.GetKey())) + logger.Infow("DHT record in PUT older than existing record (ignoring)", "peer", p, "key", LoggableRecordKeyBytes(rec.GetKey())) return nil, errors.New("old record") } } @@ -344,7 +344,7 @@ func (dht *IpfsDHT) handleAddProvider(ctx context.Context, p peer.ID, pmes *pb.M return nil, fmt.Errorf("handleAddProvider key is empty") } - logger.Debugf("adding provider", "from", p, "key", loggableProviderRecordBytes(key)) + logger.Debugf("adding provider", "from", p, "key", LoggableProviderRecordBytes(key)) // add provider should use the address given in the message pinfos := pb.PBPeersToPeerInfos(pmes.GetProviderPeers()) diff --git a/logging.go b/logging.go index ffc337e3c..db1165c02 100644 --- a/logging.go +++ b/logging.go @@ -20,14 +20,14 @@ func multibaseB32Encode(k []byte) string { func tryFormatLoggableRecordKey(k string) (string, error) { if len(k) == 0 { - return "", fmt.Errorf("loggableRecordKey is empty") + return "", fmt.Errorf("LoggableRecordKey is empty") } var proto, cstr string if k[0] == '/' { // it's a path (probably) protoEnd := strings.IndexByte(k[1:], '/') if protoEnd < 0 { - return "", fmt.Errorf("loggableRecordKey starts with '/' but is not a path: %s", multibaseB32Encode([]byte(k))) + return "", fmt.Errorf("LoggableRecordKey starts with '/' but is not a path: %s", multibaseB32Encode([]byte(k))) } proto = k[1 : protoEnd+1] cstr = k[protoEnd+2:] @@ -36,12 +36,12 @@ func tryFormatLoggableRecordKey(k string) (string, error) { return fmt.Sprintf("/%s/%s", proto, encStr), nil } - return "", fmt.Errorf("loggableRecordKey is not a path: %s", multibaseB32Encode([]byte(cstr))) + return "", fmt.Errorf("LoggableRecordKey is not a path: %s", multibaseB32Encode([]byte(cstr))) } -type loggableRecordKeyString string +type LoggableRecordKeyString string -func (lk loggableRecordKeyString) String() string { +func (lk LoggableRecordKeyString) String() string { k := string(lk) newKey, err := tryFormatLoggableRecordKey(k) if err == nil { @@ -50,9 +50,9 @@ func (lk loggableRecordKeyString) String() string { return err.Error() } -type loggableRecordKeyBytes []byte +type LoggableRecordKeyBytes []byte -func (lk loggableRecordKeyBytes) String() string { +func (lk LoggableRecordKeyBytes) String() string { k := string(lk) newKey, err := tryFormatLoggableRecordKey(k) if err == nil { @@ -61,9 +61,9 @@ func (lk loggableRecordKeyBytes) String() string { return err.Error() } -type loggableProviderRecordBytes []byte +type LoggableProviderRecordBytes []byte -func (lk loggableProviderRecordBytes) String() string { +func (lk LoggableProviderRecordBytes) String() string { newKey, err := tryFormatLoggableProviderKey(lk) if err == nil { return newKey @@ -73,7 +73,7 @@ func (lk loggableProviderRecordBytes) String() string { func tryFormatLoggableProviderKey(k []byte) (string, error) { if len(k) == 0 { - return "", fmt.Errorf("loggableProviderKey is empty") + return "", fmt.Errorf("LoggableProviderKey is empty") } encodedKey := multibaseB32Encode(k) @@ -88,5 +88,5 @@ func tryFormatLoggableProviderKey(k []byte) (string, error) { return encodedKey, nil } - return "", fmt.Errorf("loggableProviderKey is not a Multihash or CID: %s", encodedKey) + return "", fmt.Errorf("LoggableProviderKey is not a Multihash or CID: %s", encodedKey) } diff --git a/messages.go b/messages.go index a0c3f4523..7b7deeae1 100644 --- a/messages.go +++ b/messages.go @@ -67,7 +67,7 @@ func (pm *ProtocolMessenger) PutValue(ctx context.Context, p peer.ID, rec *recpb pmes.Record = rec rpmes, err := pm.m.SendRequest(ctx, p, pmes) if err != nil { - logger.Debugw("failed to put value to peer", "to", p, "key", loggableRecordKeyBytes(rec.Key), "error", err) + logger.Debugw("failed to put value to peer", "to", p, "key", LoggableRecordKeyBytes(rec.Key), "error", err) return err } diff --git a/routing.go b/routing.go index 885dfe2e2..3813b4f9b 100644 --- a/routing.go +++ b/routing.go @@ -31,7 +31,7 @@ func (dht *IpfsDHT) PutValue(ctx context.Context, key string, value []byte, opts return routing.ErrNotSupported } - logger.Debugw("putting value", "key", loggableRecordKeyString(key)) + logger.Debugw("putting value", "key", LoggableRecordKeyString(key)) // don't even allow local users to put bad values. if err := dht.Validator.Validate(key, value); err != nil { @@ -127,7 +127,7 @@ func (dht *IpfsDHT) GetValue(ctx context.Context, key string, opts ...routing.Op if best == nil { return nil, routing.ErrNotFound } - logger.Debugf("GetValue %v %x", loggableRecordKeyString(key), best) + logger.Debugf("GetValue %v %x", LoggableRecordKeyString(key), best) return best, nil } @@ -246,7 +246,7 @@ loop: } sel, err := dht.Validator.Select(key, [][]byte{best, v.Val}) if err != nil { - logger.Warnw("failed to select best value", "key", loggableRecordKeyString(key), "error", err) + logger.Warnw("failed to select best value", "key", LoggableRecordKeyString(key), "error", err) continue } if sel != 1 { @@ -292,7 +292,7 @@ func (dht *IpfsDHT) getValues(ctx context.Context, key string, stopQuery chan st valCh := make(chan RecvdVal, 1) lookupResCh := make(chan *lookupWithFollowupResult, 1) - logger.Debugw("finding value", "key", loggableRecordKeyString(key)) + logger.Debugw("finding value", "key", LoggableRecordKeyString(key)) if rec, err := dht.getLocal(key); rec != nil && err == nil { select { @@ -398,7 +398,7 @@ func (dht *IpfsDHT) Provide(ctx context.Context, key cid.Cid, brdcst bool) (err return fmt.Errorf("invalid cid: undefined") } keyMH := key.Hash() - logger.Debugw("providing", "cid", key, "mh", loggableProviderRecordBytes(keyMH)) + logger.Debugw("providing", "cid", key, "mh", LoggableProviderRecordBytes(keyMH)) // add self locally dht.ProviderManager.AddProvider(ctx, keyMH, dht.self) @@ -448,7 +448,7 @@ func (dht *IpfsDHT) Provide(ctx context.Context, key cid.Cid, brdcst bool) (err wg.Add(1) go func(p peer.ID) { defer wg.Done() - logger.Debugf("putProvider(%s, %s)", loggableProviderRecordBytes(keyMH), p) + logger.Debugf("putProvider(%s, %s)", LoggableProviderRecordBytes(keyMH), p) err := dht.protoMessenger.PutProvider(ctx, p, keyMH, dht.host) if err != nil { logger.Debug(err) @@ -497,7 +497,7 @@ func (dht *IpfsDHT) FindProvidersAsync(ctx context.Context, key cid.Cid, count i keyMH := key.Hash() - logger.Debugw("finding providers", "cid", key, "mh", loggableProviderRecordBytes(keyMH)) + logger.Debugw("finding providers", "cid", key, "mh", LoggableProviderRecordBytes(keyMH)) go dht.findProvidersAsyncRoutine(ctx, keyMH, count, peerOut) return peerOut } From 7a80399a424ff3683a2d28fab56d64f8df6b479c Mon Sep 17 00:00:00 2001 From: Adin Schmahmann Date: Tue, 6 Oct 2020 21:52:39 -0400 Subject: [PATCH 05/15] refactor: move logging helpers to internal package --- dht.go | 9 +++++---- handlers.go | 9 +++++---- logging.go => internal/logging.go | 2 +- logging_test.go => internal/logging_test.go | 2 +- messages.go | 3 ++- routing.go | 15 ++++++++------- 6 files changed, 22 insertions(+), 18 deletions(-) rename logging.go => internal/logging.go (99%) rename logging_test.go => internal/logging_test.go (99%) diff --git a/dht.go b/dht.go index c5b2dc219..b8fde1332 100644 --- a/dht.go +++ b/dht.go @@ -16,6 +16,7 @@ import ( "github.com/libp2p/go-libp2p-core/protocol" "github.com/libp2p/go-libp2p-core/routing" + "github.com/libp2p/go-libp2p-kad-dht/internal" "github.com/libp2p/go-libp2p-kad-dht/metrics" pb "github.com/libp2p/go-libp2p-kad-dht/pb" "github.com/libp2p/go-libp2p-kad-dht/providers" @@ -540,17 +541,17 @@ var errInvalidRecord = errors.New("received invalid record") // getLocal attempts to retrieve the value from the datastore func (dht *IpfsDHT) getLocal(key string) (*recpb.Record, error) { - logger.Debugw("finding value in datastore", "key", LoggableRecordKeyString(key)) + logger.Debugw("finding value in datastore", "key", internal.LoggableRecordKeyString(key)) rec, err := dht.getRecordFromDatastore(mkDsKey(key)) if err != nil { - logger.Warnw("get local failed", "key", LoggableRecordKeyString(key), "error", err) + logger.Warnw("get local failed", "key", internal.LoggableRecordKeyString(key), "error", err) return nil, err } // Double check the key. Can't hurt. if rec != nil && string(rec.GetKey()) != key { - logger.Errorw("BUG: found a DHT record that didn't match it's key", "expected", LoggableRecordKeyString(key), "got", rec.GetKey()) + logger.Errorw("BUG: found a DHT record that didn't match it's key", "expected", internal.LoggableRecordKeyString(key), "got", rec.GetKey()) return nil, nil } @@ -561,7 +562,7 @@ func (dht *IpfsDHT) getLocal(key string) (*recpb.Record, error) { func (dht *IpfsDHT) putLocal(key string, rec *recpb.Record) error { data, err := proto.Marshal(rec) if err != nil { - logger.Warnw("failed to put marshal record for local put", "error", err, "key", LoggableRecordKeyString(key)) + logger.Warnw("failed to put marshal record for local put", "error", err, "key", internal.LoggableRecordKeyString(key)) return err } diff --git a/handlers.go b/handlers.go index 0aab5713a..5160232c0 100644 --- a/handlers.go +++ b/handlers.go @@ -14,6 +14,7 @@ import ( "github.com/gogo/protobuf/proto" ds "github.com/ipfs/go-datastore" u "github.com/ipfs/go-ipfs-util" + "github.com/libp2p/go-libp2p-kad-dht/internal" pb "github.com/libp2p/go-libp2p-kad-dht/pb" recpb "github.com/libp2p/go-libp2p-record/pb" "github.com/multiformats/go-base32" @@ -167,7 +168,7 @@ func (dht *IpfsDHT) handlePutValue(ctx context.Context, p peer.ID, pmes *pb.Mess // Make sure the record is valid (not expired, valid signature etc) if err = dht.Validator.Validate(string(rec.GetKey()), rec.GetValue()); err != nil { - logger.Infow("bad dht record in PUT", "from", p, "key", LoggableRecordKeyBytes(rec.GetKey()), "error", err) + logger.Infow("bad dht record in PUT", "from", p, "key", internal.LoggableRecordKeyBytes(rec.GetKey()), "error", err) return nil, err } @@ -196,11 +197,11 @@ func (dht *IpfsDHT) handlePutValue(ctx context.Context, p peer.ID, pmes *pb.Mess recs := [][]byte{rec.GetValue(), existing.GetValue()} i, err := dht.Validator.Select(string(rec.GetKey()), recs) if err != nil { - logger.Warnw("dht record passed validation but failed select", "from", p, "key", LoggableRecordKeyBytes(rec.GetKey()), "error", err) + logger.Warnw("dht record passed validation but failed select", "from", p, "key", internal.LoggableRecordKeyBytes(rec.GetKey()), "error", err) return nil, err } if i != 0 { - logger.Infow("DHT record in PUT older than existing record (ignoring)", "peer", p, "key", LoggableRecordKeyBytes(rec.GetKey())) + logger.Infow("DHT record in PUT older than existing record (ignoring)", "peer", p, "key", internal.LoggableRecordKeyBytes(rec.GetKey())) return nil, errors.New("old record") } } @@ -344,7 +345,7 @@ func (dht *IpfsDHT) handleAddProvider(ctx context.Context, p peer.ID, pmes *pb.M return nil, fmt.Errorf("handleAddProvider key is empty") } - logger.Debugf("adding provider", "from", p, "key", LoggableProviderRecordBytes(key)) + logger.Debugf("adding provider", "from", p, "key", internal.LoggableProviderRecordBytes(key)) // add provider should use the address given in the message pinfos := pb.PBPeersToPeerInfos(pmes.GetProviderPeers()) diff --git a/logging.go b/internal/logging.go similarity index 99% rename from logging.go rename to internal/logging.go index db1165c02..981f728cd 100644 --- a/logging.go +++ b/internal/logging.go @@ -1,4 +1,4 @@ -package dht +package internal import ( "fmt" diff --git a/logging_test.go b/internal/logging_test.go similarity index 99% rename from logging_test.go rename to internal/logging_test.go index 64e4f87ba..25bd033e8 100644 --- a/logging_test.go +++ b/internal/logging_test.go @@ -1,4 +1,4 @@ -package dht +package internal import ( "testing" diff --git a/messages.go b/messages.go index 7b7deeae1..219619942 100644 --- a/messages.go +++ b/messages.go @@ -14,6 +14,7 @@ import ( recpb "github.com/libp2p/go-libp2p-record/pb" "github.com/multiformats/go-multihash" + "github.com/libp2p/go-libp2p-kad-dht/internal" pb "github.com/libp2p/go-libp2p-kad-dht/pb" ) @@ -67,7 +68,7 @@ func (pm *ProtocolMessenger) PutValue(ctx context.Context, p peer.ID, rec *recpb pmes.Record = rec rpmes, err := pm.m.SendRequest(ctx, p, pmes) if err != nil { - logger.Debugw("failed to put value to peer", "to", p, "key", LoggableRecordKeyBytes(rec.Key), "error", err) + logger.Debugw("failed to put value to peer", "to", p, "key", internal.LoggableRecordKeyBytes(rec.Key), "error", err) return err } diff --git a/routing.go b/routing.go index 3813b4f9b..7812d7cf5 100644 --- a/routing.go +++ b/routing.go @@ -14,6 +14,7 @@ import ( "github.com/ipfs/go-cid" u "github.com/ipfs/go-ipfs-util" + "github.com/libp2p/go-libp2p-kad-dht/internal" "github.com/libp2p/go-libp2p-kad-dht/qpeerset" kb "github.com/libp2p/go-libp2p-kbucket" record "github.com/libp2p/go-libp2p-record" @@ -31,7 +32,7 @@ func (dht *IpfsDHT) PutValue(ctx context.Context, key string, value []byte, opts return routing.ErrNotSupported } - logger.Debugw("putting value", "key", LoggableRecordKeyString(key)) + logger.Debugw("putting value", "key", internal.LoggableRecordKeyString(key)) // don't even allow local users to put bad values. if err := dht.Validator.Validate(key, value); err != nil { @@ -127,7 +128,7 @@ func (dht *IpfsDHT) GetValue(ctx context.Context, key string, opts ...routing.Op if best == nil { return nil, routing.ErrNotFound } - logger.Debugf("GetValue %v %x", LoggableRecordKeyString(key), best) + logger.Debugf("GetValue %v %x", internal.LoggableRecordKeyString(key), best) return best, nil } @@ -246,7 +247,7 @@ loop: } sel, err := dht.Validator.Select(key, [][]byte{best, v.Val}) if err != nil { - logger.Warnw("failed to select best value", "key", LoggableRecordKeyString(key), "error", err) + logger.Warnw("failed to select best value", "key", internal.LoggableRecordKeyString(key), "error", err) continue } if sel != 1 { @@ -292,7 +293,7 @@ func (dht *IpfsDHT) getValues(ctx context.Context, key string, stopQuery chan st valCh := make(chan RecvdVal, 1) lookupResCh := make(chan *lookupWithFollowupResult, 1) - logger.Debugw("finding value", "key", LoggableRecordKeyString(key)) + logger.Debugw("finding value", "key", internal.LoggableRecordKeyString(key)) if rec, err := dht.getLocal(key); rec != nil && err == nil { select { @@ -398,7 +399,7 @@ func (dht *IpfsDHT) Provide(ctx context.Context, key cid.Cid, brdcst bool) (err return fmt.Errorf("invalid cid: undefined") } keyMH := key.Hash() - logger.Debugw("providing", "cid", key, "mh", LoggableProviderRecordBytes(keyMH)) + logger.Debugw("providing", "cid", key, "mh", internal.LoggableProviderRecordBytes(keyMH)) // add self locally dht.ProviderManager.AddProvider(ctx, keyMH, dht.self) @@ -448,7 +449,7 @@ func (dht *IpfsDHT) Provide(ctx context.Context, key cid.Cid, brdcst bool) (err wg.Add(1) go func(p peer.ID) { defer wg.Done() - logger.Debugf("putProvider(%s, %s)", LoggableProviderRecordBytes(keyMH), p) + logger.Debugf("putProvider(%s, %s)", internal.LoggableProviderRecordBytes(keyMH), p) err := dht.protoMessenger.PutProvider(ctx, p, keyMH, dht.host) if err != nil { logger.Debug(err) @@ -497,7 +498,7 @@ func (dht *IpfsDHT) FindProvidersAsync(ctx context.Context, key cid.Cid, count i keyMH := key.Hash() - logger.Debugw("finding providers", "cid", key, "mh", LoggableProviderRecordBytes(keyMH)) + logger.Debugw("finding providers", "cid", key, "mh", internal.LoggableProviderRecordBytes(keyMH)) go dht.findProvidersAsyncRoutine(ctx, keyMH, count, peerOut) return peerOut } From 52b4eb20b085f9126c54b0efc8c88214d7868d75 Mon Sep 17 00:00:00 2001 From: Adin Schmahmann Date: Tue, 6 Oct 2020 22:02:19 -0400 Subject: [PATCH 06/15] refactor: move invalid record error into internal package --- dht.go | 3 --- internal/errors.go | 5 +++++ messages.go | 2 +- routing.go | 2 +- 4 files changed, 7 insertions(+), 5 deletions(-) create mode 100644 internal/errors.go diff --git a/dht.go b/dht.go index b8fde1332..220cf4e80 100644 --- a/dht.go +++ b/dht.go @@ -2,7 +2,6 @@ package dht import ( "context" - "errors" "fmt" "math" "math/rand" @@ -537,8 +536,6 @@ func (dht *IpfsDHT) persistRTPeersInPeerStore() { } } -var errInvalidRecord = errors.New("received invalid record") - // getLocal attempts to retrieve the value from the datastore func (dht *IpfsDHT) getLocal(key string) (*recpb.Record, error) { logger.Debugw("finding value in datastore", "key", internal.LoggableRecordKeyString(key)) diff --git a/internal/errors.go b/internal/errors.go new file mode 100644 index 000000000..3c32a83dc --- /dev/null +++ b/internal/errors.go @@ -0,0 +1,5 @@ +package internal + +import "errors" + +var ErrInvalidRecord = errors.New("received invalid record") diff --git a/messages.go b/messages.go index 219619942..699e67525 100644 --- a/messages.go +++ b/messages.go @@ -101,7 +101,7 @@ func (pm *ProtocolMessenger) GetValue(ctx context.Context, p peer.ID, key string if err != nil { logger.Debug("received invalid record (discarded)") // return a sentinel to signify an invalid record was received - err = errInvalidRecord + err = internal.ErrInvalidRecord rec = new(recpb.Record) } return rec, peers, err diff --git a/routing.go b/routing.go index 7812d7cf5..d14e3845f 100644 --- a/routing.go +++ b/routing.go @@ -329,7 +329,7 @@ func (dht *IpfsDHT) getValues(ctx context.Context, key string, stopQuery chan st return nil, err default: return nil, err - case nil, errInvalidRecord: + case nil, internal.ErrInvalidRecord: // in either of these cases, we want to keep going } From f73b905eba722e79624c58a5b305f0a78c3ddf0e Mon Sep 17 00:00:00 2001 From: Adin Schmahmann Date: Tue, 6 Oct 2020 22:06:18 -0400 Subject: [PATCH 07/15] refactor: move ProtocolMessenger into new package --- dht.go | 5 +++-- messages.go => wire/messages.go | 5 ++++- 2 files changed, 7 insertions(+), 3 deletions(-) rename messages.go => wire/messages.go (98%) diff --git a/dht.go b/dht.go index 220cf4e80..7eb5dff00 100644 --- a/dht.go +++ b/dht.go @@ -20,6 +20,7 @@ import ( pb "github.com/libp2p/go-libp2p-kad-dht/pb" "github.com/libp2p/go-libp2p-kad-dht/providers" "github.com/libp2p/go-libp2p-kad-dht/rtrefresh" + "github.com/libp2p/go-libp2p-kad-dht/wire" kb "github.com/libp2p/go-libp2p-kbucket" "github.com/libp2p/go-libp2p-kbucket/peerdiversity" record "github.com/libp2p/go-libp2p-record" @@ -95,7 +96,7 @@ type IpfsDHT struct { ctx context.Context proc goprocess.Process - protoMessenger *ProtocolMessenger + protoMessenger *wire.ProtocolMessenger messageMgr *messageManager plk sync.Mutex @@ -193,7 +194,7 @@ func New(ctx context.Context, h host.Host, options ...Option) (*IpfsDHT, error) strmap: make(map[peer.ID]*messageSender), protocols: dht.protocols, } - dht.protoMessenger, err = NewProtocolMessenger(dht.messageMgr, WithValidator(dht.Validator)) + dht.protoMessenger, err = wire.NewProtocolMessenger(dht.messageMgr, wire.WithValidator(dht.Validator)) if err != nil { return nil, err } diff --git a/messages.go b/wire/messages.go similarity index 98% rename from messages.go rename to wire/messages.go index 699e67525..0bb13dba7 100644 --- a/messages.go +++ b/wire/messages.go @@ -1,4 +1,4 @@ -package dht +package wire import ( "bytes" @@ -10,6 +10,7 @@ import ( "github.com/libp2p/go-libp2p-core/peer" "github.com/libp2p/go-libp2p-core/routing" + logging "github.com/ipfs/go-log" record "github.com/libp2p/go-libp2p-record" recpb "github.com/libp2p/go-libp2p-record/pb" "github.com/multiformats/go-multihash" @@ -18,6 +19,8 @@ import ( pb "github.com/libp2p/go-libp2p-kad-dht/pb" ) +var logger = logging.Logger("dht") + // ProtocolMessenger can be used for sending DHT messages to peers and processing their responses. // This decouples the wire protocol format from both the DHT protocol implementation and from the implementation of the // routing.Routing interface. From 786c3c9efabc7104cb55d9cac93710817e263859 Mon Sep 17 00:00:00 2001 From: Adin Schmahmann Date: Tue, 6 Oct 2020 23:27:52 -0400 Subject: [PATCH 08/15] refactor: return nil instead of an empty record when processing an invalid record --- wire/messages.go | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/wire/messages.go b/wire/messages.go index 0bb13dba7..d17080063 100644 --- a/wire/messages.go +++ b/wire/messages.go @@ -104,8 +104,7 @@ func (pm *ProtocolMessenger) GetValue(ctx context.Context, p peer.ID, key string if err != nil { logger.Debug("received invalid record (discarded)") // return a sentinel to signify an invalid record was received - err = internal.ErrInvalidRecord - rec = new(recpb.Record) + return nil, peers, internal.ErrInvalidRecord } return rec, peers, err } From 4b3a573a87da99d93ac83de7863b22afa8d0d3c2 Mon Sep 17 00:00:00 2001 From: Adin Schmahmann Date: Mon, 12 Oct 2020 14:15:26 -0400 Subject: [PATCH 09/15] refactor: move message manager to its own file --- dht_net.go | 308 ------------------------------------------ message_manager.go | 324 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 324 insertions(+), 308 deletions(-) create mode 100644 message_manager.go diff --git a/dht_net.go b/dht_net.go index 35c52bb06..278216625 100644 --- a/dht_net.go +++ b/dht_net.go @@ -2,16 +2,12 @@ package dht import ( "bufio" - "context" "fmt" "io" "sync" "time" - "github.com/libp2p/go-libp2p-core/host" "github.com/libp2p/go-libp2p-core/network" - "github.com/libp2p/go-libp2p-core/peer" - "github.com/libp2p/go-libp2p-core/protocol" "github.com/libp2p/go-msgio/protoio" "github.com/libp2p/go-libp2p-kad-dht/metrics" @@ -209,307 +205,3 @@ func (dht *IpfsDHT) handleNewMessage(s network.Stream) bool { stats.Record(ctx, metrics.InboundRequestLatency.M(latencyMillis)) } } - -type messageManager struct { - host host.Host // the network services we need - smlk sync.Mutex - strmap map[peer.ID]*messageSender - protocols []protocol.ID -} - -func (m *messageManager) streamDisconnect(ctx context.Context, p peer.ID) { - m.smlk.Lock() - defer m.smlk.Unlock() - ms, ok := m.strmap[p] - if !ok { - return - } - delete(m.strmap, p) - - // Do this asynchronously as ms.lk can block for a while. - go func() { - if err := ms.lk.Lock(ctx); err != nil { - return - } - defer ms.lk.Unlock() - ms.invalidate() - }() -} - -// SendRequest sends out a request, but also makes sure to -// measure the RTT for latency measurements. -func (m *messageManager) SendRequest(ctx context.Context, p peer.ID, pmes *pb.Message) (*pb.Message, error) { - ctx, _ = tag.New(ctx, metrics.UpsertMessageType(pmes)) - - ms, err := m.messageSenderForPeer(ctx, p) - if err != nil { - stats.Record(ctx, - metrics.SentRequests.M(1), - metrics.SentRequestErrors.M(1), - ) - logger.Debugw("request failed to open message sender", "error", err, "to", p) - return nil, err - } - - start := time.Now() - - rpmes, err := ms.SendRequest(ctx, pmes) - if err != nil { - stats.Record(ctx, - metrics.SentRequests.M(1), - metrics.SentRequestErrors.M(1), - ) - logger.Debugw("request failed", "error", err, "to", p) - return nil, err - } - - stats.Record(ctx, - metrics.SentRequests.M(1), - metrics.SentBytes.M(int64(pmes.Size())), - metrics.OutboundRequestLatency.M(float64(time.Since(start))/float64(time.Millisecond)), - ) - m.host.Peerstore().RecordLatency(p, time.Since(start)) - return rpmes, nil -} - -// SendMessage sends out a message -func (m *messageManager) SendMessage(ctx context.Context, p peer.ID, pmes *pb.Message) error { - ctx, _ = tag.New(ctx, metrics.UpsertMessageType(pmes)) - - ms, err := m.messageSenderForPeer(ctx, p) - if err != nil { - stats.Record(ctx, - metrics.SentMessages.M(1), - metrics.SentMessageErrors.M(1), - ) - logger.Debugw("message failed to open message sender", "error", err, "to", p) - return err - } - - if err := ms.SendMessage(ctx, pmes); err != nil { - stats.Record(ctx, - metrics.SentMessages.M(1), - metrics.SentMessageErrors.M(1), - ) - logger.Debugw("message failed", "error", err, "to", p) - return err - } - - stats.Record(ctx, - metrics.SentMessages.M(1), - metrics.SentBytes.M(int64(pmes.Size())), - ) - return nil -} - -func (m *messageManager) messageSenderForPeer(ctx context.Context, p peer.ID) (*messageSender, error) { - m.smlk.Lock() - ms, ok := m.strmap[p] - if ok { - m.smlk.Unlock() - return ms, nil - } - ms = &messageSender{p: p, m: m, lk: newCtxMutex()} - m.strmap[p] = ms - m.smlk.Unlock() - - if err := ms.prepOrInvalidate(ctx); err != nil { - m.smlk.Lock() - defer m.smlk.Unlock() - - if msCur, ok := m.strmap[p]; ok { - // Changed. Use the new one, old one is invalid and - // not in the map so we can just throw it away. - if ms != msCur { - return msCur, nil - } - // Not changed, remove the now invalid stream from the - // map. - delete(m.strmap, p) - } - // Invalid but not in map. Must have been removed by a disconnect. - return nil, err - } - // All ready to go. - return ms, nil -} - -type messageSender struct { - s network.Stream - r msgio.ReadCloser - lk ctxMutex - p peer.ID - m *messageManager - - invalid bool - singleMes int -} - -// invalidate is called before this messageSender is removed from the strmap. -// It prevents the messageSender from being reused/reinitialized and then -// forgotten (leaving the stream open). -func (ms *messageSender) invalidate() { - ms.invalid = true - if ms.s != nil { - _ = ms.s.Reset() - ms.s = nil - } -} - -func (ms *messageSender) prepOrInvalidate(ctx context.Context) error { - if err := ms.lk.Lock(ctx); err != nil { - return err - } - defer ms.lk.Unlock() - - if err := ms.prep(ctx); err != nil { - ms.invalidate() - return err - } - return nil -} - -func (ms *messageSender) prep(ctx context.Context) error { - if ms.invalid { - return fmt.Errorf("message sender has been invalidated") - } - if ms.s != nil { - return nil - } - - // We only want to speak to peers using our primary protocols. We do not want to query any peer that only speaks - // one of the secondary "server" protocols that we happen to support (e.g. older nodes that we can respond to for - // backwards compatibility reasons). - nstr, err := ms.m.host.NewStream(ctx, ms.p, ms.m.protocols...) - if err != nil { - return err - } - - ms.r = msgio.NewVarintReaderSize(nstr, network.MessageSizeMax) - ms.s = nstr - - return nil -} - -// streamReuseTries is the number of times we will try to reuse a stream to a -// given peer before giving up and reverting to the old one-message-per-stream -// behaviour. -const streamReuseTries = 3 - -func (ms *messageSender) SendMessage(ctx context.Context, pmes *pb.Message) error { - if err := ms.lk.Lock(ctx); err != nil { - return err - } - defer ms.lk.Unlock() - - retry := false - for { - if err := ms.prep(ctx); err != nil { - return err - } - - if err := ms.writeMsg(pmes); err != nil { - _ = ms.s.Reset() - ms.s = nil - - if retry { - logger.Debugw("error writing message", "error", err) - return err - } - logger.Debugw("error writing message", "error", err, "retrying", true) - retry = true - continue - } - - var err error - if ms.singleMes > streamReuseTries { - err = ms.s.Close() - ms.s = nil - } else if retry { - ms.singleMes++ - } - - return err - } -} - -func (ms *messageSender) SendRequest(ctx context.Context, pmes *pb.Message) (*pb.Message, error) { - if err := ms.lk.Lock(ctx); err != nil { - return nil, err - } - defer ms.lk.Unlock() - - retry := false - for { - if err := ms.prep(ctx); err != nil { - return nil, err - } - - if err := ms.writeMsg(pmes); err != nil { - _ = ms.s.Reset() - ms.s = nil - - if retry { - logger.Debugw("error writing message", "error", err) - return nil, err - } - logger.Debugw("error writing message", "error", err, "retrying", true) - retry = true - continue - } - - mes := new(pb.Message) - if err := ms.ctxReadMsg(ctx, mes); err != nil { - _ = ms.s.Reset() - ms.s = nil - - if retry { - logger.Debugw("error reading message", "error", err) - return nil, err - } - logger.Debugw("error reading message", "error", err, "retrying", true) - retry = true - continue - } - - var err error - if ms.singleMes > streamReuseTries { - err = ms.s.Close() - ms.s = nil - } else if retry { - ms.singleMes++ - } - - return mes, err - } -} - -func (ms *messageSender) writeMsg(pmes *pb.Message) error { - return writeMsg(ms.s, pmes) -} - -func (ms *messageSender) ctxReadMsg(ctx context.Context, mes *pb.Message) error { - errc := make(chan error, 1) - go func(r msgio.ReadCloser) { - defer close(errc) - bytes, err := r.ReadMsg() - defer r.ReleaseMsg(bytes) - if err != nil { - errc <- err - return - } - errc <- mes.Unmarshal(bytes) - }(ms.r) - - t := time.NewTimer(dhtReadMessageTimeout) - defer t.Stop() - - select { - case err := <-errc: - return err - case <-ctx.Done(): - return ctx.Err() - case <-t.C: - return ErrReadTimeout - } -} diff --git a/message_manager.go b/message_manager.go new file mode 100644 index 000000000..74a099862 --- /dev/null +++ b/message_manager.go @@ -0,0 +1,324 @@ +package dht + +import ( + "context" + "fmt" + "sync" + "time" + + "github.com/libp2p/go-libp2p-core/host" + "github.com/libp2p/go-libp2p-core/network" + "github.com/libp2p/go-libp2p-core/peer" + "github.com/libp2p/go-libp2p-core/protocol" + + "github.com/libp2p/go-libp2p-kad-dht/metrics" + pb "github.com/libp2p/go-libp2p-kad-dht/pb" + + "github.com/libp2p/go-msgio" + "go.opencensus.io/stats" + "go.opencensus.io/tag" +) + +type messageManager struct { + host host.Host // the network services we need + smlk sync.Mutex + strmap map[peer.ID]*messageSender + protocols []protocol.ID +} + +func (m *messageManager) streamDisconnect(ctx context.Context, p peer.ID) { + m.smlk.Lock() + defer m.smlk.Unlock() + ms, ok := m.strmap[p] + if !ok { + return + } + delete(m.strmap, p) + + // Do this asynchronously as ms.lk can block for a while. + go func() { + if err := ms.lk.Lock(ctx); err != nil { + return + } + defer ms.lk.Unlock() + ms.invalidate() + }() +} + +// SendRequest sends out a request, but also makes sure to +// measure the RTT for latency measurements. +func (m *messageManager) SendRequest(ctx context.Context, p peer.ID, pmes *pb.Message) (*pb.Message, error) { + ctx, _ = tag.New(ctx, metrics.UpsertMessageType(pmes)) + + ms, err := m.messageSenderForPeer(ctx, p) + if err != nil { + stats.Record(ctx, + metrics.SentRequests.M(1), + metrics.SentRequestErrors.M(1), + ) + logger.Debugw("request failed to open message sender", "error", err, "to", p) + return nil, err + } + + start := time.Now() + + rpmes, err := ms.SendRequest(ctx, pmes) + if err != nil { + stats.Record(ctx, + metrics.SentRequests.M(1), + metrics.SentRequestErrors.M(1), + ) + logger.Debugw("request failed", "error", err, "to", p) + return nil, err + } + + stats.Record(ctx, + metrics.SentRequests.M(1), + metrics.SentBytes.M(int64(pmes.Size())), + metrics.OutboundRequestLatency.M(float64(time.Since(start))/float64(time.Millisecond)), + ) + m.host.Peerstore().RecordLatency(p, time.Since(start)) + return rpmes, nil +} + +// SendMessage sends out a message +func (m *messageManager) SendMessage(ctx context.Context, p peer.ID, pmes *pb.Message) error { + ctx, _ = tag.New(ctx, metrics.UpsertMessageType(pmes)) + + ms, err := m.messageSenderForPeer(ctx, p) + if err != nil { + stats.Record(ctx, + metrics.SentMessages.M(1), + metrics.SentMessageErrors.M(1), + ) + logger.Debugw("message failed to open message sender", "error", err, "to", p) + return err + } + + if err := ms.SendMessage(ctx, pmes); err != nil { + stats.Record(ctx, + metrics.SentMessages.M(1), + metrics.SentMessageErrors.M(1), + ) + logger.Debugw("message failed", "error", err, "to", p) + return err + } + + stats.Record(ctx, + metrics.SentMessages.M(1), + metrics.SentBytes.M(int64(pmes.Size())), + ) + return nil +} + +func (m *messageManager) messageSenderForPeer(ctx context.Context, p peer.ID) (*messageSender, error) { + m.smlk.Lock() + ms, ok := m.strmap[p] + if ok { + m.smlk.Unlock() + return ms, nil + } + ms = &messageSender{p: p, m: m, lk: newCtxMutex()} + m.strmap[p] = ms + m.smlk.Unlock() + + if err := ms.prepOrInvalidate(ctx); err != nil { + m.smlk.Lock() + defer m.smlk.Unlock() + + if msCur, ok := m.strmap[p]; ok { + // Changed. Use the new one, old one is invalid and + // not in the map so we can just throw it away. + if ms != msCur { + return msCur, nil + } + // Not changed, remove the now invalid stream from the + // map. + delete(m.strmap, p) + } + // Invalid but not in map. Must have been removed by a disconnect. + return nil, err + } + // All ready to go. + return ms, nil +} + +type messageSender struct { + s network.Stream + r msgio.ReadCloser + lk ctxMutex + p peer.ID + m *messageManager + + invalid bool + singleMes int +} + +// invalidate is called before this messageSender is removed from the strmap. +// It prevents the messageSender from being reused/reinitialized and then +// forgotten (leaving the stream open). +func (ms *messageSender) invalidate() { + ms.invalid = true + if ms.s != nil { + _ = ms.s.Reset() + ms.s = nil + } +} + +func (ms *messageSender) prepOrInvalidate(ctx context.Context) error { + if err := ms.lk.Lock(ctx); err != nil { + return err + } + defer ms.lk.Unlock() + + if err := ms.prep(ctx); err != nil { + ms.invalidate() + return err + } + return nil +} + +func (ms *messageSender) prep(ctx context.Context) error { + if ms.invalid { + return fmt.Errorf("message sender has been invalidated") + } + if ms.s != nil { + return nil + } + + // We only want to speak to peers using our primary protocols. We do not want to query any peer that only speaks + // one of the secondary "server" protocols that we happen to support (e.g. older nodes that we can respond to for + // backwards compatibility reasons). + nstr, err := ms.m.host.NewStream(ctx, ms.p, ms.m.protocols...) + if err != nil { + return err + } + + ms.r = msgio.NewVarintReaderSize(nstr, network.MessageSizeMax) + ms.s = nstr + + return nil +} + +// streamReuseTries is the number of times we will try to reuse a stream to a +// given peer before giving up and reverting to the old one-message-per-stream +// behaviour. +const streamReuseTries = 3 + +func (ms *messageSender) SendMessage(ctx context.Context, pmes *pb.Message) error { + if err := ms.lk.Lock(ctx); err != nil { + return err + } + defer ms.lk.Unlock() + + retry := false + for { + if err := ms.prep(ctx); err != nil { + return err + } + + if err := ms.writeMsg(pmes); err != nil { + _ = ms.s.Reset() + ms.s = nil + + if retry { + logger.Debugw("error writing message", "error", err) + return err + } + logger.Debugw("error writing message", "error", err, "retrying", true) + retry = true + continue + } + + var err error + if ms.singleMes > streamReuseTries { + err = ms.s.Close() + ms.s = nil + } else if retry { + ms.singleMes++ + } + + return err + } +} + +func (ms *messageSender) SendRequest(ctx context.Context, pmes *pb.Message) (*pb.Message, error) { + if err := ms.lk.Lock(ctx); err != nil { + return nil, err + } + defer ms.lk.Unlock() + + retry := false + for { + if err := ms.prep(ctx); err != nil { + return nil, err + } + + if err := ms.writeMsg(pmes); err != nil { + _ = ms.s.Reset() + ms.s = nil + + if retry { + logger.Debugw("error writing message", "error", err) + return nil, err + } + logger.Debugw("error writing message", "error", err, "retrying", true) + retry = true + continue + } + + mes := new(pb.Message) + if err := ms.ctxReadMsg(ctx, mes); err != nil { + _ = ms.s.Reset() + ms.s = nil + + if retry { + logger.Debugw("error reading message", "error", err) + return nil, err + } + logger.Debugw("error reading message", "error", err, "retrying", true) + retry = true + continue + } + + var err error + if ms.singleMes > streamReuseTries { + err = ms.s.Close() + ms.s = nil + } else if retry { + ms.singleMes++ + } + + return mes, err + } +} + +func (ms *messageSender) writeMsg(pmes *pb.Message) error { + return writeMsg(ms.s, pmes) +} + +func (ms *messageSender) ctxReadMsg(ctx context.Context, mes *pb.Message) error { + errc := make(chan error, 1) + go func(r msgio.ReadCloser) { + defer close(errc) + bytes, err := r.ReadMsg() + defer r.ReleaseMsg(bytes) + if err != nil { + errc <- err + return + } + errc <- mes.Unmarshal(bytes) + }(ms.r) + + t := time.NewTimer(dhtReadMessageTimeout) + defer t.Stop() + + select { + case err := <-errc: + return err + case <-ctx.Done(): + return ctx.Err() + case <-t.C: + return ErrReadTimeout + } +} From 99772569173eea1ecaf2e6dd0b6a69b5f04f8fae Mon Sep 17 00:00:00 2001 From: Adin Schmahmann Date: Mon, 12 Oct 2020 14:26:15 -0400 Subject: [PATCH 10/15] refactor: cleanup protocol messenger comments --- wire/messages.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/wire/messages.go b/wire/messages.go index d17080063..16e930c31 100644 --- a/wire/messages.go +++ b/wire/messages.go @@ -25,8 +25,8 @@ var logger = logging.Logger("dht") // This decouples the wire protocol format from both the DHT protocol implementation and from the implementation of the // routing.Routing interface. // -// TODO: This is still strongly coupled with the existing implementation of what happens when a peer actually sends a -// message on the wire (e.g. reusing streams, reusing connections, metrics tracking, etc.). +// Note: the ProtocolMessenger's MessageSender still needs to deal with some wire protocol details such as using +// varint-delineated protobufs type ProtocolMessenger struct { m MessageSender validator record.Validator From 138cb80ad5553b71e34660a34ce263019720b054 Mon Sep 17 00:00:00 2001 From: Adin Schmahmann Date: Mon, 12 Oct 2020 14:29:52 -0400 Subject: [PATCH 11/15] refactor: move protocol messenger to pb package --- dht.go | 5 ++-- wire/messages.go => pb/protocol_messenger.go | 31 ++++++++++---------- 2 files changed, 17 insertions(+), 19 deletions(-) rename wire/messages.go => pb/protocol_messenger.go (85%) diff --git a/dht.go b/dht.go index 7eb5dff00..cd391a76d 100644 --- a/dht.go +++ b/dht.go @@ -20,7 +20,6 @@ import ( pb "github.com/libp2p/go-libp2p-kad-dht/pb" "github.com/libp2p/go-libp2p-kad-dht/providers" "github.com/libp2p/go-libp2p-kad-dht/rtrefresh" - "github.com/libp2p/go-libp2p-kad-dht/wire" kb "github.com/libp2p/go-libp2p-kbucket" "github.com/libp2p/go-libp2p-kbucket/peerdiversity" record "github.com/libp2p/go-libp2p-record" @@ -96,7 +95,7 @@ type IpfsDHT struct { ctx context.Context proc goprocess.Process - protoMessenger *wire.ProtocolMessenger + protoMessenger *pb.ProtocolMessenger messageMgr *messageManager plk sync.Mutex @@ -194,7 +193,7 @@ func New(ctx context.Context, h host.Host, options ...Option) (*IpfsDHT, error) strmap: make(map[peer.ID]*messageSender), protocols: dht.protocols, } - dht.protoMessenger, err = wire.NewProtocolMessenger(dht.messageMgr, wire.WithValidator(dht.Validator)) + dht.protoMessenger, err = pb.NewProtocolMessenger(dht.messageMgr, pb.WithValidator(dht.Validator)) if err != nil { return nil, err } diff --git a/wire/messages.go b/pb/protocol_messenger.go similarity index 85% rename from wire/messages.go rename to pb/protocol_messenger.go index 16e930c31..97eb26477 100644 --- a/wire/messages.go +++ b/pb/protocol_messenger.go @@ -1,4 +1,4 @@ -package wire +package dht_pb import ( "bytes" @@ -16,7 +16,6 @@ import ( "github.com/multiformats/go-multihash" "github.com/libp2p/go-libp2p-kad-dht/internal" - pb "github.com/libp2p/go-libp2p-kad-dht/pb" ) var logger = logging.Logger("dht") @@ -60,14 +59,14 @@ func NewProtocolMessenger(msgSender MessageSender, opts ...ProtocolMessengerOpti // MessageSender handles sending wire protocol messages to a given peer type MessageSender interface { // SendRequest sends a peer a message and waits for its response - SendRequest(ctx context.Context, p peer.ID, pmes *pb.Message) (*pb.Message, error) + SendRequest(ctx context.Context, p peer.ID, pmes *Message) (*Message, error) // SendMessage sends a peer a message without waiting on a response - SendMessage(ctx context.Context, p peer.ID, pmes *pb.Message) error + SendMessage(ctx context.Context, p peer.ID, pmes *Message) error } // PutValue asks a peer to store the given key/value pair. func (pm *ProtocolMessenger) PutValue(ctx context.Context, p peer.ID, rec *recpb.Record) error { - pmes := pb.NewMessage(pb.Message_PUT_VALUE, rec.Key, 0) + pmes := NewMessage(Message_PUT_VALUE, rec.Key, 0) pmes.Record = rec rpmes, err := pm.m.SendRequest(ctx, p, pmes) if err != nil { @@ -86,14 +85,14 @@ func (pm *ProtocolMessenger) PutValue(ctx context.Context, p peer.ID, rec *recpb // GetValue asks a peer for the value corresponding to the given key. Also returns the K closest peers to the key // as described in GetClosestPeers. func (pm *ProtocolMessenger) GetValue(ctx context.Context, p peer.ID, key string) (*recpb.Record, []*peer.AddrInfo, error) { - pmes := pb.NewMessage(pb.Message_GET_VALUE, []byte(key), 0) + pmes := NewMessage(Message_GET_VALUE, []byte(key), 0) respMsg, err := pm.m.SendRequest(ctx, p, pmes) if err != nil { return nil, nil, err } // Perhaps we were given closer peers - peers := pb.PBPeersToPeerInfos(respMsg.GetCloserPeers()) + peers := PBPeersToPeerInfos(respMsg.GetCloserPeers()) if rec := respMsg.GetRecord(); rec != nil { // Success! We were given the value @@ -120,12 +119,12 @@ func (pm *ProtocolMessenger) GetValue(ctx context.Context, p peer.ID, key string // Note: If the peer happens to know another peer whose peerID exactly matches the given id it will return that peer // even if that peer is not a DHT server node. func (pm *ProtocolMessenger) GetClosestPeers(ctx context.Context, p peer.ID, id peer.ID) ([]*peer.AddrInfo, error) { - pmes := pb.NewMessage(pb.Message_FIND_NODE, []byte(id), 0) + pmes := NewMessage(Message_FIND_NODE, []byte(id), 0) respMsg, err := pm.m.SendRequest(ctx, p, pmes) if err != nil { return nil, err } - peers := pb.PBPeersToPeerInfos(respMsg.GetCloserPeers()) + peers := PBPeersToPeerInfos(respMsg.GetCloserPeers()) return peers, nil } @@ -142,8 +141,8 @@ func (pm *ProtocolMessenger) PutProvider(ctx context.Context, p peer.ID, key mul return fmt.Errorf("no known addresses for self, cannot put provider") } - pmes := pb.NewMessage(pb.Message_ADD_PROVIDER, key, 0) - pmes.ProviderPeers = pb.RawPeerInfosToPBPeers([]peer.AddrInfo{pi}) + pmes := NewMessage(Message_ADD_PROVIDER, key, 0) + pmes.ProviderPeers = RawPeerInfosToPBPeers([]peer.AddrInfo{pi}) return pm.m.SendMessage(ctx, p, pmes) } @@ -151,24 +150,24 @@ func (pm *ProtocolMessenger) PutProvider(ctx context.Context, p peer.ID, key mul // GetProviders asks a peer for the providers it knows of for a given key. Also returns the K closest peers to the key // as described in GetClosestPeers. func (pm *ProtocolMessenger) GetProviders(ctx context.Context, p peer.ID, key multihash.Multihash) ([]*peer.AddrInfo, []*peer.AddrInfo, error) { - pmes := pb.NewMessage(pb.Message_GET_PROVIDERS, key, 0) + pmes := NewMessage(Message_GET_PROVIDERS, key, 0) respMsg, err := pm.m.SendRequest(ctx, p, pmes) if err != nil { return nil, nil, err } - provs := pb.PBPeersToPeerInfos(respMsg.GetProviderPeers()) - closerPeers := pb.PBPeersToPeerInfos(respMsg.GetCloserPeers()) + provs := PBPeersToPeerInfos(respMsg.GetProviderPeers()) + closerPeers := PBPeersToPeerInfos(respMsg.GetCloserPeers()) return provs, closerPeers, nil } // Ping sends a ping message to the passed peer and waits for a response. func (pm *ProtocolMessenger) Ping(ctx context.Context, p peer.ID) error { - req := pb.NewMessage(pb.Message_PING, nil, 0) + req := NewMessage(Message_PING, nil, 0) resp, err := pm.m.SendRequest(ctx, p, req) if err != nil { return fmt.Errorf("sending request: %w", err) } - if resp.Type != pb.Message_PING { + if resp.Type != Message_PING { return fmt.Errorf("got unexpected response type: %v", resp.Type) } return nil From 2c313a81a57be98e237a9abf5a68ad232104b994 Mon Sep 17 00:00:00 2001 From: Adin Schmahmann Date: Mon, 4 Jan 2021 13:57:01 -0500 Subject: [PATCH 12/15] refactor: rename messageSender to peerMessageSender and added comments to clarify what peerMessageSender and messageManager do --- dht.go | 2 +- message_manager.go | 29 ++++++++++++++++------------- 2 files changed, 17 insertions(+), 14 deletions(-) diff --git a/dht.go b/dht.go index cd391a76d..b75be2c5e 100644 --- a/dht.go +++ b/dht.go @@ -190,7 +190,7 @@ func New(ctx context.Context, h host.Host, options ...Option) (*IpfsDHT, error) dht.Validator = cfg.validator dht.messageMgr = &messageManager{ host: h, - strmap: make(map[peer.ID]*messageSender), + strmap: make(map[peer.ID]*peerMessageSender), protocols: dht.protocols, } dht.protoMessenger, err = pb.NewProtocolMessenger(dht.messageMgr, pb.WithValidator(dht.Validator)) diff --git a/message_manager.go b/message_manager.go index 74a099862..edc96ee80 100644 --- a/message_manager.go +++ b/message_manager.go @@ -19,10 +19,12 @@ import ( "go.opencensus.io/tag" ) +// messageManager is responsible for sending requests and messages to peers efficiently, including reuse of streams. +// It also tracks metrics for sent requests and messages. type messageManager struct { host host.Host // the network services we need smlk sync.Mutex - strmap map[peer.ID]*messageSender + strmap map[peer.ID]*peerMessageSender protocols []protocol.ID } @@ -111,14 +113,14 @@ func (m *messageManager) SendMessage(ctx context.Context, p peer.ID, pmes *pb.Me return nil } -func (m *messageManager) messageSenderForPeer(ctx context.Context, p peer.ID) (*messageSender, error) { +func (m *messageManager) messageSenderForPeer(ctx context.Context, p peer.ID) (*peerMessageSender, error) { m.smlk.Lock() ms, ok := m.strmap[p] if ok { m.smlk.Unlock() return ms, nil } - ms = &messageSender{p: p, m: m, lk: newCtxMutex()} + ms = &peerMessageSender{p: p, m: m, lk: newCtxMutex()} m.strmap[p] = ms m.smlk.Unlock() @@ -143,7 +145,8 @@ func (m *messageManager) messageSenderForPeer(ctx context.Context, p peer.ID) (* return ms, nil } -type messageSender struct { +// peerMessageSender is responsible for sending requests and messages to a particular peer +type peerMessageSender struct { s network.Stream r msgio.ReadCloser lk ctxMutex @@ -154,10 +157,10 @@ type messageSender struct { singleMes int } -// invalidate is called before this messageSender is removed from the strmap. -// It prevents the messageSender from being reused/reinitialized and then +// invalidate is called before this peerMessageSender is removed from the strmap. +// It prevents the peerMessageSender from being reused/reinitialized and then // forgotten (leaving the stream open). -func (ms *messageSender) invalidate() { +func (ms *peerMessageSender) invalidate() { ms.invalid = true if ms.s != nil { _ = ms.s.Reset() @@ -165,7 +168,7 @@ func (ms *messageSender) invalidate() { } } -func (ms *messageSender) prepOrInvalidate(ctx context.Context) error { +func (ms *peerMessageSender) prepOrInvalidate(ctx context.Context) error { if err := ms.lk.Lock(ctx); err != nil { return err } @@ -178,7 +181,7 @@ func (ms *messageSender) prepOrInvalidate(ctx context.Context) error { return nil } -func (ms *messageSender) prep(ctx context.Context) error { +func (ms *peerMessageSender) prep(ctx context.Context) error { if ms.invalid { return fmt.Errorf("message sender has been invalidated") } @@ -205,7 +208,7 @@ func (ms *messageSender) prep(ctx context.Context) error { // behaviour. const streamReuseTries = 3 -func (ms *messageSender) SendMessage(ctx context.Context, pmes *pb.Message) error { +func (ms *peerMessageSender) SendMessage(ctx context.Context, pmes *pb.Message) error { if err := ms.lk.Lock(ctx); err != nil { return err } @@ -242,7 +245,7 @@ func (ms *messageSender) SendMessage(ctx context.Context, pmes *pb.Message) erro } } -func (ms *messageSender) SendRequest(ctx context.Context, pmes *pb.Message) (*pb.Message, error) { +func (ms *peerMessageSender) SendRequest(ctx context.Context, pmes *pb.Message) (*pb.Message, error) { if err := ms.lk.Lock(ctx); err != nil { return nil, err } @@ -293,11 +296,11 @@ func (ms *messageSender) SendRequest(ctx context.Context, pmes *pb.Message) (*pb } } -func (ms *messageSender) writeMsg(pmes *pb.Message) error { +func (ms *peerMessageSender) writeMsg(pmes *pb.Message) error { return writeMsg(ms.s, pmes) } -func (ms *messageSender) ctxReadMsg(ctx context.Context, mes *pb.Message) error { +func (ms *peerMessageSender) ctxReadMsg(ctx context.Context, mes *pb.Message) error { errc := make(chan error, 1) go func(r msgio.ReadCloser) { defer close(errc) From 9ebf3068576275018b646c05ba57e363d3c8b406 Mon Sep 17 00:00:00 2001 From: Adin Schmahmann Date: Mon, 4 Jan 2021 13:58:30 -0500 Subject: [PATCH 13/15] refactor: factor out reused string --- pb/protocol_messenger.go | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pb/protocol_messenger.go b/pb/protocol_messenger.go index 97eb26477..7524f59b9 100644 --- a/pb/protocol_messenger.go +++ b/pb/protocol_messenger.go @@ -75,8 +75,9 @@ func (pm *ProtocolMessenger) PutValue(ctx context.Context, p peer.ID, rec *recpb } if !bytes.Equal(rpmes.GetRecord().Value, pmes.GetRecord().Value) { - logger.Infow("value not put correctly", "put-message", pmes, "get-message", rpmes) - return errors.New("value not put correctly") + const errStr = "value not put correctly" + logger.Infow(errStr, "put-message", pmes, "get-message", rpmes) + return errors.New(errStr) } return nil From 6fca41f4917b13cb13cac6cedb7548462a7e927f Mon Sep 17 00:00:00 2001 From: Adin Schmahmann Date: Mon, 4 Jan 2021 14:33:16 -0500 Subject: [PATCH 14/15] comments: added comments clarifying the contract of getLocal --- dht.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/dht.go b/dht.go index b75be2c5e..24c9f4569 100644 --- a/dht.go +++ b/dht.go @@ -536,7 +536,10 @@ func (dht *IpfsDHT) persistRTPeersInPeerStore() { } } -// getLocal attempts to retrieve the value from the datastore +// getLocal attempts to retrieve the value from the datastore. +// +// returns nil, nil when either nothing is found or the value found doesn't properly validate. +// returns nil, some_error when there's a *datastore* error (i.e., something goes very wrong) func (dht *IpfsDHT) getLocal(key string) (*recpb.Record, error) { logger.Debugw("finding value in datastore", "key", internal.LoggableRecordKeyString(key)) From ebd2d69f90de6572232010a9e6bd25fdb2c5e8ee Mon Sep 17 00:00:00 2001 From: Adin Schmahmann Date: Mon, 4 Jan 2021 14:37:17 -0500 Subject: [PATCH 15/15] rename messageManager to messageSenderImpl and renamed dht.messageMgr to dht.msgSender --- dht.go | 6 +++--- dht_test.go | 8 ++++---- message_manager.go | 14 +++++++------- subscriber_notifee.go | 2 +- 4 files changed, 15 insertions(+), 15 deletions(-) diff --git a/dht.go b/dht.go index 24c9f4569..7c89ab56a 100644 --- a/dht.go +++ b/dht.go @@ -96,7 +96,7 @@ type IpfsDHT struct { proc goprocess.Process protoMessenger *pb.ProtocolMessenger - messageMgr *messageManager + msgSender *messageSenderImpl plk sync.Mutex @@ -188,12 +188,12 @@ func New(ctx context.Context, h host.Host, options ...Option) (*IpfsDHT, error) dht.disableFixLowPeers = cfg.disableFixLowPeers dht.Validator = cfg.validator - dht.messageMgr = &messageManager{ + dht.msgSender = &messageSenderImpl{ host: h, strmap: make(map[peer.ID]*peerMessageSender), protocols: dht.protocols, } - dht.protoMessenger, err = pb.NewProtocolMessenger(dht.messageMgr, pb.WithValidator(dht.Validator)) + dht.protoMessenger, err = pb.NewProtocolMessenger(dht.msgSender, pb.WithValidator(dht.Validator)) if err != nil { return nil, err } diff --git a/dht_test.go b/dht_test.go index 37f5142f7..74835c8e2 100644 --- a/dht_test.go +++ b/dht_test.go @@ -570,14 +570,14 @@ func TestInvalidMessageSenderTracking(t *testing.T) { defer dht.Close() foo := peer.ID("asdasd") - _, err := dht.messageMgr.messageSenderForPeer(ctx, foo) + _, err := dht.msgSender.messageSenderForPeer(ctx, foo) if err == nil { t.Fatal("that shouldnt have succeeded") } - dht.messageMgr.smlk.Lock() - mscnt := len(dht.messageMgr.strmap) - dht.messageMgr.smlk.Unlock() + dht.msgSender.smlk.Lock() + mscnt := len(dht.msgSender.strmap) + dht.msgSender.smlk.Unlock() if mscnt > 0 { t.Fatal("should have no message senders in map") diff --git a/message_manager.go b/message_manager.go index edc96ee80..8cc3e22e3 100644 --- a/message_manager.go +++ b/message_manager.go @@ -19,16 +19,16 @@ import ( "go.opencensus.io/tag" ) -// messageManager is responsible for sending requests and messages to peers efficiently, including reuse of streams. +// messageSenderImpl is responsible for sending requests and messages to peers efficiently, including reuse of streams. // It also tracks metrics for sent requests and messages. -type messageManager struct { +type messageSenderImpl struct { host host.Host // the network services we need smlk sync.Mutex strmap map[peer.ID]*peerMessageSender protocols []protocol.ID } -func (m *messageManager) streamDisconnect(ctx context.Context, p peer.ID) { +func (m *messageSenderImpl) streamDisconnect(ctx context.Context, p peer.ID) { m.smlk.Lock() defer m.smlk.Unlock() ms, ok := m.strmap[p] @@ -49,7 +49,7 @@ func (m *messageManager) streamDisconnect(ctx context.Context, p peer.ID) { // SendRequest sends out a request, but also makes sure to // measure the RTT for latency measurements. -func (m *messageManager) SendRequest(ctx context.Context, p peer.ID, pmes *pb.Message) (*pb.Message, error) { +func (m *messageSenderImpl) SendRequest(ctx context.Context, p peer.ID, pmes *pb.Message) (*pb.Message, error) { ctx, _ = tag.New(ctx, metrics.UpsertMessageType(pmes)) ms, err := m.messageSenderForPeer(ctx, p) @@ -84,7 +84,7 @@ func (m *messageManager) SendRequest(ctx context.Context, p peer.ID, pmes *pb.Me } // SendMessage sends out a message -func (m *messageManager) SendMessage(ctx context.Context, p peer.ID, pmes *pb.Message) error { +func (m *messageSenderImpl) SendMessage(ctx context.Context, p peer.ID, pmes *pb.Message) error { ctx, _ = tag.New(ctx, metrics.UpsertMessageType(pmes)) ms, err := m.messageSenderForPeer(ctx, p) @@ -113,7 +113,7 @@ func (m *messageManager) SendMessage(ctx context.Context, p peer.ID, pmes *pb.Me return nil } -func (m *messageManager) messageSenderForPeer(ctx context.Context, p peer.ID) (*peerMessageSender, error) { +func (m *messageSenderImpl) messageSenderForPeer(ctx context.Context, p peer.ID) (*peerMessageSender, error) { m.smlk.Lock() ms, ok := m.strmap[p] if ok { @@ -151,7 +151,7 @@ type peerMessageSender struct { r msgio.ReadCloser lk ctxMutex p peer.ID - m *messageManager + m *messageSenderImpl invalid bool singleMes int diff --git a/subscriber_notifee.go b/subscriber_notifee.go index a9fa8849e..7cc9018f7 100644 --- a/subscriber_notifee.go +++ b/subscriber_notifee.go @@ -173,7 +173,7 @@ func (nn *subscriberNotifee) Disconnected(n network.Network, v network.Conn) { return } - dht.messageMgr.streamDisconnect(dht.Context(), p) + dht.msgSender.streamDisconnect(dht.Context(), p) } func (nn *subscriberNotifee) Connected(network.Network, network.Conn) {}