diff --git a/.travis.yml b/.travis.yml index a55583a703..8c0af291a3 100644 --- a/.travis.yml +++ b/.travis.yml @@ -16,7 +16,7 @@ jobs: os: linux arch: amd64 dist: bionic - go: 1.21.x + go: 1.22.x env: - docker services: @@ -33,7 +33,7 @@ jobs: os: linux arch: arm64 dist: bionic - go: 1.21.x + go: 1.22.x env: - docker services: @@ -51,7 +51,7 @@ jobs: os: linux dist: bionic sudo: required - go: 1.21.x + go: 1.22.x env: - azure-linux git: @@ -85,7 +85,7 @@ jobs: if: type = push os: osx osx_image: xcode14.2 - go: 1.21.x + go: 1.22.x env: - azure-osx git: @@ -101,7 +101,7 @@ jobs: os: linux arch: amd64 dist: bionic - go: 1.21.x + go: 1.22.x script: - travis_wait 30 go run build/ci.go test $TEST_PACKAGES @@ -110,14 +110,14 @@ jobs: os: linux arch: arm64 dist: bionic - go: 1.20.x + go: 1.21.x script: - travis_wait 30 go run build/ci.go test $TEST_PACKAGES - stage: build os: linux dist: bionic - go: 1.20.x + go: 1.21.x script: - travis_wait 30 go run build/ci.go test $TEST_PACKAGES @@ -126,7 +126,7 @@ jobs: if: type = cron || (type = push && tag ~= /^v[0-9]/) os: linux dist: bionic - go: 1.21.x + go: 1.22.x env: - ubuntu-ppa git: @@ -149,7 +149,7 @@ jobs: if: type = cron os: linux dist: bionic - go: 1.21.x + go: 1.22.x env: - azure-purge git: @@ -162,7 +162,7 @@ jobs: if: type = cron os: linux dist: bionic - go: 1.21.x + go: 1.22.x script: - travis_wait 30 go run build/ci.go test -race $TEST_PACKAGES diff --git a/Dockerfile b/Dockerfile index ed69a04789..ffd89905a7 100644 --- a/Dockerfile +++ b/Dockerfile @@ -4,7 +4,7 @@ ARG VERSION="" ARG BUILDNUM="" # Build Geth in a stock Go builder container -FROM golang:1.21-alpine as builder +FROM golang:1.22-alpine as builder RUN apk add --no-cache gcc musl-dev linux-headers git diff --git a/Dockerfile.alltools b/Dockerfile.alltools index c317da25fa..db256f5316 100644 --- a/Dockerfile.alltools +++ b/Dockerfile.alltools @@ -4,7 +4,7 @@ ARG VERSION="" ARG BUILDNUM="" # Build Geth in a stock Go builder container -FROM golang:1.21-alpine as builder +FROM golang:1.22-alpine as builder RUN apk add --no-cache gcc musl-dev linux-headers git diff --git a/accounts/scwallet/securechannel.go b/accounts/scwallet/securechannel.go index bbd8b22647..b3a7be8df0 100644 --- a/accounts/scwallet/securechannel.go +++ b/accounts/scwallet/securechannel.go @@ -20,7 +20,6 @@ import ( "bytes" "crypto/aes" "crypto/cipher" - "crypto/elliptic" "crypto/rand" "crypto/sha256" "crypto/sha512" @@ -72,11 +71,11 @@ func NewSecureChannelSession(card *pcsc.Card, keyData []byte) (*SecureChannelSes if err != nil { return nil, fmt.Errorf("could not unmarshal public key from card: %v", err) } - secret, _ := key.Curve.ScalarMult(cardPublic.X, cardPublic.Y, key.D.Bytes()) + secret, _ := crypto.S256().ScalarMult(cardPublic.X, cardPublic.Y, key.D.Bytes()) return &SecureChannelSession{ card: card, secret: secret.Bytes(), - publicKey: elliptic.Marshal(crypto.S256(), key.PublicKey.X, key.PublicKey.Y), + publicKey: crypto.FromECDSAPub(&key.PublicKey), }, nil } diff --git a/build/checksums.txt b/build/checksums.txt index 03a53946df..f92f739a2f 100644 --- a/build/checksums.txt +++ b/build/checksums.txt @@ -5,22 +5,22 @@ # https://github.com/ethereum/execution-spec-tests/releases/download/v2.1.0/ ca89c76851b0900bfcc3cbb9a26cbece1f3d7c64a3bed38723e914713290df6c fixtures_develop.tar.gz -# version:golang 1.21.6 +# version:golang 1.22.1 # https://go.dev/dl/ -124926a62e45f78daabbaedb9c011d97633186a33c238ffc1e25320c02046248 go1.21.6.src.tar.gz -31d6ecca09010ab351e51343a5af81d678902061fee871f912bdd5ef4d778850 go1.21.6.darwin-amd64.tar.gz -0ff541fb37c38e5e5c5bcecc8f4f43c5ffd5e3a6c33a5d3e4003ded66fcfb331 go1.21.6.darwin-arm64.tar.gz -a1d1a149b34bf0f53965a237682c6da1140acabb131bf0e597240e4a140b0e5e go1.21.6.freebsd-386.tar.gz -de59e1217e4398b1522eed8dddabab2fa1b97aecbdca3af08e34832b4f0e3f81 go1.21.6.freebsd-amd64.tar.gz -05d09041b5a1193c14e4b2db3f7fcc649b236c567f5eb93305c537851b72dd95 go1.21.6.linux-386.tar.gz -3f934f40ac360b9c01f616a9aa1796d227d8b0328bf64cb045c7b8c4ee9caea4 go1.21.6.linux-amd64.tar.gz -e2e8aa88e1b5170a0d495d7d9c766af2b2b6c6925a8f8956d834ad6b4cacbd9a go1.21.6.linux-arm64.tar.gz -6a8eda6cc6a799ff25e74ce0c13fdc1a76c0983a0bb07c789a2a3454bf6ec9b2 go1.21.6.linux-armv6l.tar.gz -e872b1e9a3f2f08fd4554615a32ca9123a4ba877ab6d19d36abc3424f86bc07f go1.21.6.linux-ppc64le.tar.gz -92894d0f732d3379bc414ffdd617eaadad47e1d72610e10d69a1156db03fc052 go1.21.6.linux-s390x.tar.gz -65b38857135cf45c80e1d267e0ce4f80fe149326c68835217da4f2da9b7943fe go1.21.6.windows-386.zip -27ac9dd6e66fb3fd0acfa6792ff053c86e7d2c055b022f4b5d53bfddec9e3301 go1.21.6.windows-amd64.zip -b93aff8f3c882c764c66a39b7a1483b0460e051e9992bf3435479129e5051bcd go1.21.6.windows-arm64.zip +79c9b91d7f109515a25fc3ecdaad125d67e6bdb54f6d4d98580f46799caea321 go1.22.1.src.tar.gz +3bc971772f4712fec0364f4bc3de06af22a00a12daab10b6f717fdcd13156cc0 go1.22.1.darwin-amd64.tar.gz +f6a9cec6b8a002fcc9c0ee24ec04d67f430a52abc3cfd613836986bcc00d8383 go1.22.1.darwin-arm64.tar.gz +99f81c10d5a3f8a886faf8fa86aaa2aaf929fbed54a972ae5eec3c5e0bdb961a go1.22.1.freebsd-386.tar.gz +51c614ddd92ee4a9913a14c39bf80508d9cfba08561f24d2f075fd00f3cfb067 go1.22.1.freebsd-amd64.tar.gz +8484df36d3d40139eaf0fe5e647b006435d826cc12f9ae72973bf7ec265e0ae4 go1.22.1.linux-386.tar.gz +aab8e15785c997ae20f9c88422ee35d962c4562212bb0f879d052a35c8307c7f go1.22.1.linux-amd64.tar.gz +e56685a245b6a0c592fc4a55f0b7803af5b3f827aaa29feab1f40e491acf35b8 go1.22.1.linux-arm64.tar.gz +8cb7a90e48c20daed39a6ac8b8a40760030ba5e93c12274c42191d868687c281 go1.22.1.linux-armv6l.tar.gz +ac775e19d93cc1668999b77cfe8c8964abfbc658718feccfe6e0eb87663cd668 go1.22.1.linux-ppc64le.tar.gz +7bb7dd8e10f95c9a4cc4f6bef44c816a6e7c9e03f56ac6af6efbb082b19b379f go1.22.1.linux-s390x.tar.gz +0c5ebb7eb39b7884ec99f92b425d4c03a96a72443562aafbf6e7d15c42a3108a go1.22.1.windows-386.zip +cf9c66a208a106402a527f5b956269ca506cfe535fc388e828d249ea88ed28ba go1.22.1.windows-amd64.zip +85b8511b298c9f4199ecae26afafcc3d46155bac934d43f2357b9224bcaa310f go1.22.1.windows-arm64.zip # version:golangci 1.55.2 # https://github.com/golangci/golangci-lint/releases/ diff --git a/cmd/devp2p/discv4cmd.go b/cmd/devp2p/discv4cmd.go index 45bcdcd367..3b5400ca3a 100644 --- a/cmd/devp2p/discv4cmd.go +++ b/cmd/devp2p/discv4cmd.go @@ -20,6 +20,7 @@ import ( "errors" "fmt" "net" + "net/http" "strconv" "strings" "time" @@ -28,9 +29,11 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/internal/flags" + "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/p2p/discover" "github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/params" + "github.com/ethereum/go-ethereum/rpc" "github.com/urfave/cli/v2" ) @@ -45,6 +48,7 @@ var ( discv4ResolveJSONCommand, discv4CrawlCommand, discv4TestCommand, + discv4ListenCommand, }, } discv4PingCommand = &cli.Command{ @@ -75,6 +79,14 @@ var ( Flags: discoveryNodeFlags, ArgsUsage: "", } + discv4ListenCommand = &cli.Command{ + Name: "listen", + Usage: "Runs a discovery node", + Action: discv4Listen, + Flags: flags.Merge(discoveryNodeFlags, []cli.Flag{ + httpAddrFlag, + }), + } discv4CrawlCommand = &cli.Command{ Name: "crawl", Usage: "Updates a nodes.json file with random nodes found in the DHT", @@ -131,6 +143,10 @@ var ( Usage: "Enode of the remote node under test", EnvVars: []string{"REMOTE_ENODE"}, } + httpAddrFlag = &cli.StringFlag{ + Name: "rpc", + Usage: "HTTP server listening address", + } ) var discoveryNodeFlags = []cli.Flag{ @@ -154,6 +170,27 @@ func discv4Ping(ctx *cli.Context) error { return nil } +func discv4Listen(ctx *cli.Context) error { + disc, _ := startV4(ctx) + defer disc.Close() + + fmt.Println(disc.Self()) + + httpAddr := ctx.String(httpAddrFlag.Name) + if httpAddr == "" { + // Non-HTTP mode. + select {} + } + + api := &discv4API{disc} + log.Info("Starting RPC API server", "addr", httpAddr) + srv := rpc.NewServer() + srv.RegisterName("discv4", api) + http.DefaultServeMux.Handle("/", srv) + httpsrv := http.Server{Addr: httpAddr, Handler: http.DefaultServeMux} + return httpsrv.ListenAndServe() +} + func discv4RequestRecord(ctx *cli.Context) error { n := getNodeArg(ctx) disc, _ := startV4(ctx) @@ -362,3 +399,23 @@ func parseBootnodes(ctx *cli.Context) ([]*enode.Node, error) { } return nodes, nil } + +type discv4API struct { + host *discover.UDPv4 +} + +func (api *discv4API) LookupRandom(n int) (ns []*enode.Node) { + it := api.host.RandomNodes() + for len(ns) < n && it.Next() { + ns = append(ns, it.Node()) + } + return ns +} + +func (api *discv4API) Buckets() [][]discover.BucketNode { + return api.host.TableBuckets() +} + +func (api *discv4API) Self() *enode.Node { + return api.host.Self() +} diff --git a/cmd/devp2p/internal/v4test/framework.go b/cmd/devp2p/internal/v4test/framework.go index 9286594181..e8f4c021b8 100644 --- a/cmd/devp2p/internal/v4test/framework.go +++ b/cmd/devp2p/internal/v4test/framework.go @@ -110,7 +110,7 @@ func (te *testenv) localEndpoint(c net.PacketConn) v4wire.Endpoint { } func (te *testenv) remoteEndpoint() v4wire.Endpoint { - return v4wire.NewEndpoint(te.remoteAddr, 0) + return v4wire.NewEndpoint(te.remoteAddr.AddrPort(), 0) } func contains(ns []v4wire.Node, key v4wire.Pubkey) bool { diff --git a/cmd/era/main.go b/cmd/era/main.go index e27d8ccec6..c7f5de12bc 100644 --- a/cmd/era/main.go +++ b/cmd/era/main.go @@ -18,6 +18,7 @@ package main import ( "encoding/json" + "errors" "fmt" "math/big" "os" @@ -182,7 +183,7 @@ func open(ctx *cli.Context, epoch uint64) (*era.Era, error) { // that the accumulator matches the expected value. func verify(ctx *cli.Context) error { if ctx.Args().Len() != 1 { - return fmt.Errorf("missing accumulators file") + return errors.New("missing accumulators file") } roots, err := readHashes(ctx.Args().First()) @@ -203,7 +204,7 @@ func verify(ctx *cli.Context) error { } if len(entries) != len(roots) { - return fmt.Errorf("number of era1 files should match the number of accumulator hashes") + return errors.New("number of era1 files should match the number of accumulator hashes") } // Verify each epoch matches the expected root. @@ -308,7 +309,7 @@ func checkAccumulator(e *era.Era) error { func readHashes(f string) ([]common.Hash, error) { b, err := os.ReadFile(f) if err != nil { - return nil, fmt.Errorf("unable to open accumulators file") + return nil, errors.New("unable to open accumulators file") } s := strings.Split(string(b), "\n") // Remove empty last element, if present. diff --git a/cmd/geth/chaincmd.go b/cmd/geth/chaincmd.go index 2987bc5573..146f2a6b5c 100644 --- a/cmd/geth/chaincmd.go +++ b/cmd/geth/chaincmd.go @@ -446,7 +446,7 @@ func importHistory(ctx *cli.Context) error { return fmt.Errorf("no era1 files found in %s", dir) } if len(networks) > 1 { - return fmt.Errorf("multiple networks found, use a network flag to specify desired network") + return errors.New("multiple networks found, use a network flag to specify desired network") } network = networks[0] } diff --git a/cmd/utils/cmd.go b/cmd/utils/cmd.go index 6cbbeddfa3..8d00bfbacb 100644 --- a/cmd/utils/cmd.go +++ b/cmd/utils/cmd.go @@ -245,7 +245,7 @@ func readList(filename string) ([]string, error) { // starting from genesis. func ImportHistory(chain *core.BlockChain, db ethdb.Database, dir string, network string) error { if chain.CurrentSnapBlock().Number.BitLen() != 0 { - return fmt.Errorf("history import only supported when starting from genesis") + return errors.New("history import only supported when starting from genesis") } entries, err := era.ReadDir(dir, network) if err != nil { diff --git a/consensus/ethash/algorithm.go b/consensus/ethash/algorithm.go index 77af64dc70..4a3a397216 100644 --- a/consensus/ethash/algorithm.go +++ b/consensus/ethash/algorithm.go @@ -20,7 +20,6 @@ import ( "encoding/binary" "hash" "math/big" - "reflect" "runtime" "sync" "sync/atomic" @@ -176,12 +175,7 @@ func generateCache(dest []uint32, epoch uint64, epochLength uint64, seed []byte) logFn("Generated ethash verification cache", "epochLength", epochLength, "elapsed", common.PrettyDuration(elapsed)) }() // Convert our destination slice to a byte buffer - var cache []byte - cacheHdr := (*reflect.SliceHeader)(unsafe.Pointer(&cache)) - dstHdr := (*reflect.SliceHeader)(unsafe.Pointer(&dest)) - cacheHdr.Data = dstHdr.Data - cacheHdr.Len = dstHdr.Len * 4 - cacheHdr.Cap = dstHdr.Cap * 4 + cache := unsafe.Slice((*byte)(unsafe.Pointer(&dest[0])), len(dest)*4) // Calculate the number of theoretical rows (we'll store in one buffer nonetheless) size := uint64(len(cache)) @@ -310,12 +304,7 @@ func generateDataset(dest []uint32, epoch uint64, epochLength uint64, cache []ui swapped := !isLittleEndian() // Convert our destination slice to a byte buffer - var dataset []byte - datasetHdr := (*reflect.SliceHeader)(unsafe.Pointer(&dataset)) - destHdr := (*reflect.SliceHeader)(unsafe.Pointer(&dest)) - datasetHdr.Data = destHdr.Data - datasetHdr.Len = destHdr.Len * 4 - datasetHdr.Cap = destHdr.Cap * 4 + dataset := unsafe.Slice((*byte)(unsafe.Pointer(&dest[0])), len(dest)*4) // Generate the dataset on many goroutines since it takes a while threads := runtime.NumCPU() diff --git a/consensus/ethash/ethash.go b/consensus/ethash/ethash.go index 8258aaa225..978e6eeb2e 100644 --- a/consensus/ethash/ethash.go +++ b/consensus/ethash/ethash.go @@ -26,7 +26,6 @@ import ( "math/rand" "os" "path/filepath" - "reflect" "runtime" "strconv" "sync" @@ -144,11 +143,7 @@ func memoryMapFile(file *os.File, write bool) (mmap.MMap, []uint32, error) { return nil, nil, err } // The file is now memory-mapped. Create a []uint32 view of the file. - var view []uint32 - header := (*reflect.SliceHeader)(unsafe.Pointer(&view)) - header.Data = (*reflect.SliceHeader)(unsafe.Pointer(&mem)).Data - header.Cap = len(mem) / 4 - header.Len = header.Cap + view := unsafe.Slice((*uint32)(unsafe.Pointer(&mem[0])), len(mem)/4) return mem, view, nil } diff --git a/core/txpool/validation.go b/core/txpool/validation.go index f940224713..4dbe3946bc 100644 --- a/core/txpool/validation.go +++ b/core/txpool/validation.go @@ -18,6 +18,7 @@ package txpool import ( "crypto/sha256" + "errors" "fmt" "math/big" @@ -122,13 +123,13 @@ func ValidateTransaction(tx *types.Transaction, head *types.Header, signer types } sidecar := tx.BlobTxSidecar() if sidecar == nil { - return fmt.Errorf("missing sidecar in blob transaction") + return errors.New("missing sidecar in blob transaction") } // Ensure the number of items in the blob transaction and various side // data match up before doing any expensive validations hashes := tx.BlobHashes() if len(hashes) == 0 { - return fmt.Errorf("blobless blob transaction") + return errors.New("blobless blob transaction") } if len(hashes) > vars.MaxBlobGasPerBlock/vars.BlobTxBlobGasPerBlob { return fmt.Errorf("too many blobs in transaction: have %d, permitted %d", len(hashes), vars.MaxBlobGasPerBlock/vars.BlobTxBlobGasPerBlob) diff --git a/crypto/crypto.go b/crypto/crypto.go index 2492165d38..734feed5ca 100644 --- a/crypto/crypto.go +++ b/crypto/crypto.go @@ -51,6 +51,15 @@ var ( var errInvalidPubkey = errors.New("invalid secp256k1 public key") +// EllipticCurve contains curve operations. +type EllipticCurve interface { + elliptic.Curve + + // Point marshaling/unmarshaing. + Marshal(x, y *big.Int) []byte + Unmarshal(data []byte) (x, y *big.Int) +} + // KeccakState wraps sha3.state. In addition to the usual hash methods, it also supports // Read to get a variable amount of data from the hash state. Read is faster than Sum // because it doesn't copy the internal state, but also modifies the internal state. @@ -148,7 +157,7 @@ func toECDSA(d []byte, strict bool) (*ecdsa.PrivateKey, error) { return nil, errors.New("invalid private key, zero or negative") } - priv.PublicKey.X, priv.PublicKey.Y = priv.PublicKey.Curve.ScalarBaseMult(d) + priv.PublicKey.X, priv.PublicKey.Y = S256().ScalarBaseMult(d) if priv.PublicKey.X == nil { return nil, errors.New("invalid private key") } @@ -165,7 +174,7 @@ func FromECDSA(priv *ecdsa.PrivateKey) []byte { // UnmarshalPubkey converts bytes to a secp256k1 public key. func UnmarshalPubkey(pub []byte) (*ecdsa.PublicKey, error) { - x, y := elliptic.Unmarshal(S256(), pub) + x, y := S256().Unmarshal(pub) if x == nil { return nil, errInvalidPubkey } @@ -176,7 +185,7 @@ func FromECDSAPub(pub *ecdsa.PublicKey) []byte { if pub == nil || pub.X == nil || pub.Y == nil { return nil } - return elliptic.Marshal(S256(), pub.X, pub.Y) + return S256().Marshal(pub.X, pub.Y) } // HexToECDSA parses a secp256k1 private key. diff --git a/crypto/ecies/ecies.go b/crypto/ecies/ecies.go index 738bb8f584..1b6c9e97c1 100644 --- a/crypto/ecies/ecies.go +++ b/crypto/ecies/ecies.go @@ -40,6 +40,8 @@ import ( "hash" "io" "math/big" + + "github.com/ethereum/go-ethereum/crypto" ) var ( @@ -95,15 +97,15 @@ func ImportECDSA(prv *ecdsa.PrivateKey) *PrivateKey { // Generate an elliptic curve public / private keypair. If params is nil, // the recommended default parameters for the key will be chosen. func GenerateKey(rand io.Reader, curve elliptic.Curve, params *ECIESParams) (prv *PrivateKey, err error) { - pb, x, y, err := elliptic.GenerateKey(curve, rand) + sk, err := ecdsa.GenerateKey(curve, rand) if err != nil { return } prv = new(PrivateKey) - prv.PublicKey.X = x - prv.PublicKey.Y = y + prv.PublicKey.X = sk.X + prv.PublicKey.Y = sk.Y prv.PublicKey.Curve = curve - prv.D = new(big.Int).SetBytes(pb) + prv.D = new(big.Int).Set(sk.D) if params == nil { params = ParamsFromCurve(curve) } @@ -255,12 +257,15 @@ func Encrypt(rand io.Reader, pub *PublicKey, m, s1, s2 []byte) (ct []byte, err e d := messageTag(params.Hash, Km, em, s2) - Rb := elliptic.Marshal(pub.Curve, R.PublicKey.X, R.PublicKey.Y) - ct = make([]byte, len(Rb)+len(em)+len(d)) - copy(ct, Rb) - copy(ct[len(Rb):], em) - copy(ct[len(Rb)+len(em):], d) - return ct, nil + if curve, ok := pub.Curve.(crypto.EllipticCurve); ok { + Rb := curve.Marshal(R.PublicKey.X, R.PublicKey.Y) + ct = make([]byte, len(Rb)+len(em)+len(d)) + copy(ct, Rb) + copy(ct[len(Rb):], em) + copy(ct[len(Rb)+len(em):], d) + return ct, nil + } + return nil, ErrInvalidCurve } // Decrypt decrypts an ECIES ciphertext. @@ -297,21 +302,24 @@ func (prv *PrivateKey) Decrypt(c, s1, s2 []byte) (m []byte, err error) { R := new(PublicKey) R.Curve = prv.PublicKey.Curve - R.X, R.Y = elliptic.Unmarshal(R.Curve, c[:rLen]) - if R.X == nil { - return nil, ErrInvalidPublicKey - } - z, err := prv.GenerateShared(R, params.KeyLen, params.KeyLen) - if err != nil { - return nil, err - } - Ke, Km := deriveKeys(hash, z, s1, params.KeyLen) + if curve, ok := R.Curve.(crypto.EllipticCurve); ok { + R.X, R.Y = curve.Unmarshal(c[:rLen]) + if R.X == nil { + return nil, ErrInvalidPublicKey + } - d := messageTag(params.Hash, Km, c[mStart:mEnd], s2) - if subtle.ConstantTimeCompare(c[mEnd:], d) != 1 { - return nil, ErrInvalidMessage - } + z, err := prv.GenerateShared(R, params.KeyLen, params.KeyLen) + if err != nil { + return nil, err + } + Ke, Km := deriveKeys(hash, z, s1, params.KeyLen) - return symDecrypt(params, Ke, c[mStart:mEnd]) + d := messageTag(params.Hash, Km, c[mStart:mEnd], s2) + if subtle.ConstantTimeCompare(c[mEnd:], d) != 1 { + return nil, ErrInvalidMessage + } + return symDecrypt(params, Ke, c[mStart:mEnd]) + } + return nil, ErrInvalidCurve } diff --git a/crypto/secp256k1/secp256_test.go b/crypto/secp256k1/secp256_test.go index 74408d06d2..8bb870fa18 100644 --- a/crypto/secp256k1/secp256_test.go +++ b/crypto/secp256k1/secp256_test.go @@ -10,7 +10,6 @@ package secp256k1 import ( "bytes" "crypto/ecdsa" - "crypto/elliptic" "crypto/rand" "encoding/hex" "io" @@ -24,7 +23,7 @@ func generateKeyPair() (pubkey, privkey []byte) { if err != nil { panic(err) } - pubkey = elliptic.Marshal(S256(), key.X, key.Y) + pubkey = S256().Marshal(key.X, key.Y) privkey = make([]byte, 32) blob := key.D.Bytes() diff --git a/crypto/signature_cgo.go b/crypto/signature_cgo.go index 2339e52015..87289253c0 100644 --- a/crypto/signature_cgo.go +++ b/crypto/signature_cgo.go @@ -21,7 +21,6 @@ package crypto import ( "crypto/ecdsa" - "crypto/elliptic" "errors" "fmt" @@ -40,9 +39,7 @@ func SigToPub(hash, sig []byte) (*ecdsa.PublicKey, error) { if err != nil { return nil, err } - - x, y := elliptic.Unmarshal(S256(), s) - return &ecdsa.PublicKey{Curve: S256(), X: x, Y: y}, nil + return UnmarshalPubkey(s) } // Sign calculates an ECDSA signature. @@ -84,6 +81,6 @@ func CompressPubkey(pubkey *ecdsa.PublicKey) []byte { } // S256 returns an instance of the secp256k1 curve. -func S256() elliptic.Curve { +func S256() EllipticCurve { return secp256k1.S256() } diff --git a/crypto/signature_nocgo.go b/crypto/signature_nocgo.go index 6d628d758d..f70617019e 100644 --- a/crypto/signature_nocgo.go +++ b/crypto/signature_nocgo.go @@ -21,9 +21,9 @@ package crypto import ( "crypto/ecdsa" - "crypto/elliptic" "errors" "fmt" + "math/big" "github.com/btcsuite/btcd/btcec/v2" btc_ecdsa "github.com/btcsuite/btcd/btcec/v2/ecdsa" @@ -58,7 +58,13 @@ func SigToPub(hash, sig []byte) (*ecdsa.PublicKey, error) { if err != nil { return nil, err } - return pub.ToECDSA(), nil + // We need to explicitly set the curve here, because we're wrapping + // the original curve to add (un-)marshalling + return &ecdsa.PublicKey{ + Curve: S256(), + X: pub.X(), + Y: pub.Y(), + }, nil } // Sign calculates an ECDSA signature. @@ -73,7 +79,7 @@ func Sign(hash []byte, prv *ecdsa.PrivateKey) ([]byte, error) { if len(hash) != 32 { return nil, fmt.Errorf("hash is required to be exactly 32 bytes (%d)", len(hash)) } - if prv.Curve != btcec.S256() { + if prv.Curve != S256() { return nil, errors.New("private key curve is not secp256k1") } // ecdsa.PrivateKey -> btcec.PrivateKey @@ -128,7 +134,13 @@ func DecompressPubkey(pubkey []byte) (*ecdsa.PublicKey, error) { if err != nil { return nil, err } - return key.ToECDSA(), nil + // We need to explicitly set the curve here, because we're wrapping + // the original curve to add (un-)marshalling + return &ecdsa.PublicKey{ + Curve: S256(), + X: key.X(), + Y: key.Y(), + }, nil } // CompressPubkey encodes a public key to the 33-byte compressed format. The @@ -147,6 +159,38 @@ func CompressPubkey(pubkey *ecdsa.PublicKey) []byte { } // S256 returns an instance of the secp256k1 curve. -func S256() elliptic.Curve { - return btcec.S256() +func S256() EllipticCurve { + return btCurve{btcec.S256()} +} + +type btCurve struct { + *btcec.KoblitzCurve +} + +// Marshall converts a point given as (x, y) into a byte slice. +func (curve btCurve) Marshal(x, y *big.Int) []byte { + byteLen := (curve.Params().BitSize + 7) / 8 + + ret := make([]byte, 1+2*byteLen) + ret[0] = 4 // uncompressed point + + x.FillBytes(ret[1 : 1+byteLen]) + y.FillBytes(ret[1+byteLen : 1+2*byteLen]) + + return ret +} + +// Unmarshal converts a point, serialised by Marshal, into an x, y pair. On +// error, x = nil. +func (curve btCurve) Unmarshal(data []byte) (x, y *big.Int) { + byteLen := (curve.Params().BitSize + 7) / 8 + if len(data) != 1+2*byteLen { + return nil, nil + } + if data[0] != 4 { // uncompressed form + return nil, nil + } + x = new(big.Int).SetBytes(data[1 : 1+byteLen]) + y = new(big.Int).SetBytes(data[1+byteLen:]) + return } diff --git a/eth/peerset.go b/eth/peerset.go index 9e8ae1cc87..4cd8bf6b04 100644 --- a/eth/peerset.go +++ b/eth/peerset.go @@ -100,7 +100,7 @@ func (ps *peerSet) registerSnapExtension(peer *snap.Peer) error { return nil } -// waitExtensions blocks until all satellite protocols are connected and tracked +// waitSnapExtension blocks until all satellite protocols are connected and tracked // by the peerset. func (ps *peerSet) waitSnapExtension(peer *eth.Peer) (*snap.Peer, error) { // If the peer does not support a compatible `snap`, don't wait diff --git a/eth/protocols/eth/dispatcher.go b/eth/protocols/eth/dispatcher.go index ae98820cd6..146eec3f60 100644 --- a/eth/protocols/eth/dispatcher.go +++ b/eth/protocols/eth/dispatcher.go @@ -136,7 +136,7 @@ func (p *Peer) dispatchRequest(req *Request) error { } } -// dispatchRequest fulfils a pending request and delivers it to the requested +// dispatchResponse fulfils a pending request and delivers it to the requested // sink. func (p *Peer) dispatchResponse(res *Response, metadata func() interface{}) error { resOp := &response{ diff --git a/ethclient/ethclient_test.go b/ethclient/ethclient_test.go index 86a2ee92b9..eee767d559 100644 --- a/ethclient/ethclient_test.go +++ b/ethclient/ethclient_test.go @@ -867,7 +867,7 @@ func TestRPCDiscover(t *testing.T) { responseDocument, _ := json.MarshalIndent(r, "", " ") t.Logf(`Response Document: - + %s`, string(responseDocument)) t.Fatalf(`OVER (methods which do not appear in the current API, but exist in the hardcoded response document):): %v @@ -1169,6 +1169,7 @@ var allRPCMethods = []string{ "debug_dbAncient", "debug_dbAncients", "debug_dbGet", + "debug_discoveryV4Table", "debug_dumpBlock", "debug_freeOSMemory", "debug_gcStats", diff --git a/internal/era/accumulator.go b/internal/era/accumulator.go index 19e03973f1..2ece2755e1 100644 --- a/internal/era/accumulator.go +++ b/internal/era/accumulator.go @@ -17,6 +17,7 @@ package era import ( + "errors" "fmt" "math/big" @@ -28,7 +29,7 @@ import ( // accumulator of header records. func ComputeAccumulator(hashes []common.Hash, tds []*big.Int) (common.Hash, error) { if len(hashes) != len(tds) { - return common.Hash{}, fmt.Errorf("must have equal number hashes as td values") + return common.Hash{}, errors.New("must have equal number hashes as td values") } if len(hashes) > MaxEra1Size { return common.Hash{}, fmt.Errorf("too many records: have %d, max %d", len(hashes), MaxEra1Size) diff --git a/internal/era/builder.go b/internal/era/builder.go index 9217c049f3..75782a08c2 100644 --- a/internal/era/builder.go +++ b/internal/era/builder.go @@ -18,6 +18,7 @@ package era import ( "bytes" "encoding/binary" + "errors" "fmt" "io" "math/big" @@ -158,7 +159,7 @@ func (b *Builder) AddRLP(header, body, receipts []byte, number uint64, hash comm // corresponding e2store entries. func (b *Builder) Finalize() (common.Hash, error) { if b.startNum == nil { - return common.Hash{}, fmt.Errorf("finalize called on empty builder") + return common.Hash{}, errors.New("finalize called on empty builder") } // Compute accumulator root and write entry. root, err := ComputeAccumulator(b.hashes, b.tds) diff --git a/internal/era/e2store/e2store.go b/internal/era/e2store/e2store.go index d85b3e44e9..8e4d5dd24a 100644 --- a/internal/era/e2store/e2store.go +++ b/internal/era/e2store/e2store.go @@ -18,6 +18,7 @@ package e2store import ( "encoding/binary" + "errors" "fmt" "io" ) @@ -160,7 +161,7 @@ func (r *Reader) ReadMetadataAt(off int64) (typ uint16, length uint32, err error // Check reserved bytes of header. if b[6] != 0 || b[7] != 0 { - return 0, 0, fmt.Errorf("reserved bytes are non-zero") + return 0, 0, errors.New("reserved bytes are non-zero") } return typ, length, nil diff --git a/internal/era/e2store/e2store_test.go b/internal/era/e2store/e2store_test.go index febcffe4cf..b0803493c7 100644 --- a/internal/era/e2store/e2store_test.go +++ b/internal/era/e2store/e2store_test.go @@ -18,7 +18,7 @@ package e2store import ( "bytes" - "fmt" + "errors" "io" "testing" @@ -92,7 +92,7 @@ func TestDecode(t *testing.T) { }, { // basic invalid decoding have: "ffff000000000001", - err: fmt.Errorf("reserved bytes are non-zero"), + err: errors.New("reserved bytes are non-zero"), }, { // no more entries to read, returns EOF have: "", diff --git a/internal/era/era.go b/internal/era/era.go index a0e701b7e0..2099c2d575 100644 --- a/internal/era/era.go +++ b/internal/era/era.go @@ -18,6 +18,7 @@ package era import ( "encoding/binary" + "errors" "fmt" "io" "math/big" @@ -127,7 +128,7 @@ func (e *Era) Close() error { func (e *Era) GetBlockByNumber(num uint64) (*types.Block, error) { if e.m.start > num || e.m.start+e.m.count <= num { - return nil, fmt.Errorf("out-of-bounds") + return nil, errors.New("out-of-bounds") } off, err := e.readOffset(num) if err != nil { diff --git a/internal/era/iterator.go b/internal/era/iterator.go index e74a8154b1..d90e9586a4 100644 --- a/internal/era/iterator.go +++ b/internal/era/iterator.go @@ -17,6 +17,7 @@ package era import ( + "errors" "fmt" "io" "math/big" @@ -30,7 +31,7 @@ type Iterator struct { inner *RawIterator } -// NewRawIterator returns a new Iterator instance. Next must be immediately +// NewIterator returns a new Iterator instance. Next must be immediately // called on new iterators to load the first item. func NewIterator(e *Era) (*Iterator, error) { inner, err := NewRawIterator(e) @@ -61,7 +62,7 @@ func (it *Iterator) Error() error { // Block returns the block for the iterator's current position. func (it *Iterator) Block() (*types.Block, error) { if it.inner.Header == nil || it.inner.Body == nil { - return nil, fmt.Errorf("header and body must be non-nil") + return nil, errors.New("header and body must be non-nil") } var ( header types.Header diff --git a/internal/testlog/testlog.go b/internal/testlog/testlog.go index 037b7ee9c1..c6c656265d 100644 --- a/internal/testlog/testlog.go +++ b/internal/testlog/testlog.go @@ -55,7 +55,7 @@ func (h *bufHandler) Handle(_ context.Context, r slog.Record) error { } func (h *bufHandler) Enabled(_ context.Context, lvl slog.Level) bool { - return lvl <= h.level + return lvl >= h.level } func (h *bufHandler) WithAttrs(attrs []slog.Attr) slog.Handler { diff --git a/metrics/sample.go b/metrics/sample.go index 5398dd42d5..bb81e105cf 100644 --- a/metrics/sample.go +++ b/metrics/sample.go @@ -148,7 +148,7 @@ func (NilSample) Clear() {} func (NilSample) Snapshot() SampleSnapshot { return (*emptySnapshot)(nil) } func (NilSample) Update(v int64) {} -// SamplePercentiles returns an arbitrary percentile of the slice of int64. +// SamplePercentile returns an arbitrary percentile of the slice of int64. func SamplePercentile(values []int64, p float64) float64 { return CalculatePercentiles(values, []float64{p})[0] } diff --git a/miner/worker.go b/miner/worker.go index 68ec3bf379..6b2804da08 100644 --- a/miner/worker.go +++ b/miner/worker.go @@ -1086,7 +1086,7 @@ func (w *worker) prepareWork(genParams *generateParams) (*environment, error) { if genParams.parentHash != (common.Hash{}) { block := w.chain.GetBlockByHash(genParams.parentHash) if block == nil { - return nil, fmt.Errorf("missing parent") + return nil, errors.New("missing parent") } parent = block.Header() } diff --git a/node/api.go b/node/api.go index f81f394beb..e9e7374206 100644 --- a/node/api.go +++ b/node/api.go @@ -26,6 +26,7 @@ import ( "github.com/ethereum/go-ethereum/internal/debug" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/p2p" + "github.com/ethereum/go-ethereum/p2p/discover" "github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/rpc" ) @@ -39,6 +40,9 @@ func (n *Node) apis() []rpc.API { }, { Namespace: "debug", Service: debug.Handler, + }, { + Namespace: "debug", + Service: &p2pDebugAPI{n}, }, { Namespace: "web3", Service: &web3API{n}, @@ -335,3 +339,16 @@ func (s *web3API) ClientVersion() string { func (s *web3API) Sha3(input hexutil.Bytes) hexutil.Bytes { return crypto.Keccak256(input) } + +// p2pDebugAPI provides access to p2p internals for debugging. +type p2pDebugAPI struct { + stack *Node +} + +func (s *p2pDebugAPI) DiscoveryV4Table() [][]discover.BucketNode { + disc := s.stack.server.DiscoveryV4() + if disc != nil { + return disc.TableBuckets() + } + return nil +} diff --git a/node/rpcstack.go b/node/rpcstack.go index d80d5271a7..253db0d564 100644 --- a/node/rpcstack.go +++ b/node/rpcstack.go @@ -19,6 +19,7 @@ package node import ( "compress/gzip" "context" + "errors" "fmt" "io" "net" @@ -299,7 +300,7 @@ func (h *httpServer) enableRPC(apis []rpc.API, config httpConfig) error { defer h.mu.Unlock() if h.rpcAllowed() { - return fmt.Errorf("JSON-RPC over HTTP is already enabled") + return errors.New("JSON-RPC over HTTP is already enabled") } // Create RPC server and handler. @@ -335,7 +336,7 @@ func (h *httpServer) enableWS(apis []rpc.API, config wsConfig) error { defer h.mu.Unlock() if h.wsAllowed() { - return fmt.Errorf("JSON-RPC over WebSocket is already enabled") + return errors.New("JSON-RPC over WebSocket is already enabled") } // Create RPC server and handler. srv := rpc.NewServer() diff --git a/p2p/dial.go b/p2p/dial.go index 5e4ab1d50d..08e1db2877 100644 --- a/p2p/dial.go +++ b/p2p/dial.go @@ -25,6 +25,7 @@ import ( mrand "math/rand" "net" "sync" + "sync/atomic" "time" "github.com/ethereum/go-ethereum/common/mclock" @@ -248,7 +249,7 @@ loop: } case task := <-d.doneCh: - id := task.dest.ID() + id := task.dest().ID() delete(d.dialing, id) d.updateStaticPool(id) d.doneSinceLastLog++ @@ -410,7 +411,7 @@ func (d *dialScheduler) startStaticDials(n int) (started int) { // updateStaticPool attempts to move the given static dial back into staticPool. func (d *dialScheduler) updateStaticPool(id enode.ID) { task, ok := d.static[id] - if ok && task.staticPoolIndex < 0 && d.checkDial(task.dest) == nil { + if ok && task.staticPoolIndex < 0 && d.checkDial(task.dest()) == nil { d.addToStaticPool(task) } } @@ -437,10 +438,11 @@ func (d *dialScheduler) removeFromStaticPool(idx int) { // startDial runs the given dial task in a separate goroutine. func (d *dialScheduler) startDial(task *dialTask) { - d.log.Trace("Starting p2p dial", "id", task.dest.ID(), "ip", task.dest.IP(), "flag", task.flags) - hkey := string(task.dest.ID().Bytes()) + node := task.dest() + d.log.Trace("Starting p2p dial", "id", node.ID(), "ip", node.IP(), "flag", task.flags) + hkey := string(node.ID().Bytes()) d.history.add(hkey, d.clock.Now().Add(dialHistoryExpiration)) - d.dialing[task.dest.ID()] = task + d.dialing[node.ID()] = task go func() { task.run(d) d.doneCh <- task @@ -451,39 +453,46 @@ func (d *dialScheduler) startDial(task *dialTask) { type dialTask struct { staticPoolIndex int flags connFlag + // These fields are private to the task and should not be // accessed by dialScheduler while the task is running. - dest *enode.Node + destPtr atomic.Pointer[enode.Node] lastResolved mclock.AbsTime resolveDelay time.Duration } func newDialTask(dest *enode.Node, flags connFlag) *dialTask { - return &dialTask{dest: dest, flags: flags, staticPoolIndex: -1} + t := &dialTask{flags: flags, staticPoolIndex: -1} + t.destPtr.Store(dest) + return t } type dialError struct { error } +func (t *dialTask) dest() *enode.Node { + return t.destPtr.Load() +} + func (t *dialTask) run(d *dialScheduler) { if t.needResolve() && !t.resolve(d) { return } - err := t.dial(d, t.dest) + err := t.dial(d, t.dest()) if err != nil { // For static nodes, resolve one more time if dialing fails. if _, ok := err.(*dialError); ok && t.flags&staticDialedConn != 0 { if t.resolve(d) { - t.dial(d, t.dest) + t.dial(d, t.dest()) } } } } func (t *dialTask) needResolve() bool { - return t.flags&staticDialedConn != 0 && t.dest.IP() == nil + return t.flags&staticDialedConn != 0 && t.dest().IP() == nil } // resolve attempts to find the current endpoint for the destination @@ -502,29 +511,31 @@ func (t *dialTask) resolve(d *dialScheduler) bool { if t.lastResolved > 0 && time.Duration(d.clock.Now()-t.lastResolved) < t.resolveDelay { return false } - resolved := d.resolver.Resolve(t.dest) + + node := t.dest() + resolved := d.resolver.Resolve(node) t.lastResolved = d.clock.Now() if resolved == nil { t.resolveDelay *= 2 if t.resolveDelay > maxResolveDelay { t.resolveDelay = maxResolveDelay } - d.log.Debug("Resolving node failed", "id", t.dest.ID(), "newdelay", t.resolveDelay) + d.log.Debug("Resolving node failed", "id", node.ID(), "newdelay", t.resolveDelay) return false } // The node was found. t.resolveDelay = initialResolveDelay - t.dest = resolved - d.log.Debug("Resolved node", "id", t.dest.ID(), "addr", &net.TCPAddr{IP: t.dest.IP(), Port: t.dest.TCP()}) + t.destPtr.Store(resolved) + d.log.Debug("Resolved node", "id", resolved.ID(), "addr", &net.TCPAddr{IP: resolved.IP(), Port: resolved.TCP()}) return true } // dial performs the actual connection attempt. func (t *dialTask) dial(d *dialScheduler, dest *enode.Node) error { dialMeter.Mark(1) - fd, err := d.dialer.Dial(d.ctx, t.dest) + fd, err := d.dialer.Dial(d.ctx, dest) if err != nil { - d.log.Trace("Dial error", "id", t.dest.ID(), "addr", nodeAddr(t.dest), "conn", t.flags, "err", cleanupDialErr(err)) + d.log.Trace("Dial error", "id", dest.ID(), "addr", nodeAddr(dest), "conn", t.flags, "err", cleanupDialErr(err)) dialConnectionError.Mark(1) return &dialError{err} } @@ -532,8 +543,9 @@ func (t *dialTask) dial(d *dialScheduler, dest *enode.Node) error { } func (t *dialTask) String() string { - id := t.dest.ID() - return fmt.Sprintf("%v %x %v:%d", t.flags, id[:8], t.dest.IP(), t.dest.TCP()) + node := t.dest() + id := node.ID() + return fmt.Sprintf("%v %x %v:%d", t.flags, id[:8], node.IP(), node.TCP()) } func cleanupDialErr(err error) error { diff --git a/p2p/discover/common.go b/p2p/discover/common.go index c9f0477def..0716f7472f 100644 --- a/p2p/discover/common.go +++ b/p2p/discover/common.go @@ -18,7 +18,12 @@ package discover import ( "crypto/ecdsa" + crand "crypto/rand" + "encoding/binary" + "math/rand" "net" + "net/netip" + "sync" "time" "github.com/ethereum/go-ethereum/common/mclock" @@ -30,8 +35,8 @@ import ( // UDPConn is a network connection on which discovery can operate. type UDPConn interface { - ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error) - WriteToUDP(b []byte, addr *net.UDPAddr) (n int, err error) + ReadFromUDPAddrPort(b []byte) (n int, addr netip.AddrPort, err error) + WriteToUDPAddrPort(b []byte, addr netip.AddrPort) (n int, err error) Close() error LocalAddr() net.Addr } @@ -62,7 +67,7 @@ type Config struct { func (cfg Config) withDefaults() Config { // Node table configuration: if cfg.PingInterval == 0 { - cfg.PingInterval = 10 * time.Second + cfg.PingInterval = 3 * time.Second } if cfg.RefreshInterval == 0 { cfg.RefreshInterval = 30 * time.Minute @@ -90,12 +95,46 @@ func ListenUDP(c UDPConn, ln *enode.LocalNode, cfg Config) (*UDPv4, error) { // channel if configured. type ReadPacket struct { Data []byte - Addr *net.UDPAddr + Addr netip.AddrPort } -func min(x, y int) int { - if x > y { - return y - } - return x +type randomSource interface { + Intn(int) int + Int63n(int64) int64 + Shuffle(int, func(int, int)) +} + +// reseedingRandom is a random number generator that tracks when it was last re-seeded. +type reseedingRandom struct { + mu sync.Mutex + cur *rand.Rand +} + +func (r *reseedingRandom) seed() { + var b [8]byte + crand.Read(b[:]) + seed := binary.BigEndian.Uint64(b[:]) + new := rand.New(rand.NewSource(int64(seed))) + + r.mu.Lock() + r.cur = new + r.mu.Unlock() +} + +func (r *reseedingRandom) Intn(n int) int { + r.mu.Lock() + defer r.mu.Unlock() + return r.cur.Intn(n) +} + +func (r *reseedingRandom) Int63n(n int64) int64 { + r.mu.Lock() + defer r.mu.Unlock() + return r.cur.Int63n(n) +} + +func (r *reseedingRandom) Shuffle(n int, swap func(i, j int)) { + r.mu.Lock() + defer r.mu.Unlock() + r.cur.Shuffle(n, swap) } diff --git a/p2p/discover/lookup.go b/p2p/discover/lookup.go index b8d97b44e1..09808b71e0 100644 --- a/p2p/discover/lookup.go +++ b/p2p/discover/lookup.go @@ -29,16 +29,16 @@ import ( // not need to be an actual node identifier. type lookup struct { tab *Table - queryfunc func(*node) ([]*node, error) - replyCh chan []*node + queryfunc queryFunc + replyCh chan []*enode.Node cancelCh <-chan struct{} asked, seen map[enode.ID]bool result nodesByDistance - replyBuffer []*node + replyBuffer []*enode.Node queries int } -type queryFunc func(*node) ([]*node, error) +type queryFunc func(*enode.Node) ([]*enode.Node, error) func newLookup(ctx context.Context, tab *Table, target enode.ID, q queryFunc) *lookup { it := &lookup{ @@ -47,7 +47,7 @@ func newLookup(ctx context.Context, tab *Table, target enode.ID, q queryFunc) *l asked: make(map[enode.ID]bool), seen: make(map[enode.ID]bool), result: nodesByDistance{target: target}, - replyCh: make(chan []*node, alpha), + replyCh: make(chan []*enode.Node, alpha), cancelCh: ctx.Done(), queries: -1, } @@ -61,7 +61,7 @@ func newLookup(ctx context.Context, tab *Table, target enode.ID, q queryFunc) *l func (it *lookup) run() []*enode.Node { for it.advance() { } - return unwrapNodes(it.result.entries) + return it.result.entries } // advance advances the lookup until any new nodes have been found. @@ -139,33 +139,14 @@ func (it *lookup) slowdown() { } } -func (it *lookup) query(n *node, reply chan<- []*node) { - fails := it.tab.db.FindFails(n.ID(), n.IP()) +func (it *lookup) query(n *enode.Node, reply chan<- []*enode.Node) { r, err := it.queryfunc(n) - if errors.Is(err, errClosed) { - // Avoid recording failures on shutdown. - reply <- nil - return - } else if len(r) == 0 { - fails++ - it.tab.db.UpdateFindFails(n.ID(), n.IP(), fails) - // Remove the node from the local table if it fails to return anything useful too - // many times, but only if there are enough other nodes in the bucket. - dropped := false - if fails >= maxFindnodeFailures && it.tab.bucketLen(n.ID()) >= bucketSize/2 { - dropped = true - it.tab.delete(n) + if !errors.Is(err, errClosed) { // avoid recording failures on shutdown. + success := len(r) > 0 + it.tab.trackRequest(n, success, r) + if err != nil { + it.tab.log.Trace("FINDNODE failed", "id", n.ID(), "err", err) } - it.tab.log.Trace("FINDNODE failed", "id", n.ID(), "failcount", fails, "dropped", dropped, "err", err) - } else if fails > 0 { - // Reset failure counter because it counts _consecutive_ failures. - it.tab.db.UpdateFindFails(n.ID(), n.IP(), 0) - } - - // Grab as many nodes as possible. Some of them might not be alive anymore, but we'll - // just remove those again during revalidation. - for _, n := range r { - it.tab.addSeenNode(n) } reply <- r } @@ -173,7 +154,7 @@ func (it *lookup) query(n *node, reply chan<- []*node) { // lookupIterator performs lookup operations and iterates over all seen nodes. // When a lookup finishes, a new one is created through nextLookup. type lookupIterator struct { - buffer []*node + buffer []*enode.Node nextLookup lookupFunc ctx context.Context cancel func() @@ -192,7 +173,7 @@ func (it *lookupIterator) Node() *enode.Node { if len(it.buffer) == 0 { return nil } - return unwrapNode(it.buffer[0]) + return it.buffer[0] } // Next moves to the next node. diff --git a/p2p/discover/metrics.go b/p2p/discover/metrics.go index 56aae24285..24d2bb1706 100644 --- a/p2p/discover/metrics.go +++ b/p2p/discover/metrics.go @@ -18,7 +18,7 @@ package discover import ( "fmt" - "net" + "net/netip" "github.com/ethereum/go-ethereum/metrics" ) @@ -58,16 +58,16 @@ func newMeteredConn(conn UDPConn) UDPConn { return &meteredUdpConn{UDPConn: conn} } -// ReadFromUDP delegates a network read to the underlying connection, bumping the udp ingress traffic meter along the way. -func (c *meteredUdpConn) ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error) { - n, addr, err = c.UDPConn.ReadFromUDP(b) +// ReadFromUDPAddrPort delegates a network read to the underlying connection, bumping the udp ingress traffic meter along the way. +func (c *meteredUdpConn) ReadFromUDPAddrPort(b []byte) (n int, addr netip.AddrPort, err error) { + n, addr, err = c.UDPConn.ReadFromUDPAddrPort(b) ingressTrafficMeter.Mark(int64(n)) return n, addr, err } -// Write delegates a network write to the underlying connection, bumping the udp egress traffic meter along the way. -func (c *meteredUdpConn) WriteToUDP(b []byte, addr *net.UDPAddr) (n int, err error) { - n, err = c.UDPConn.WriteToUDP(b, addr) +// WriteToUDP delegates a network write to the underlying connection, bumping the udp egress traffic meter along the way. +func (c *meteredUdpConn) WriteToUDP(b []byte, addr netip.AddrPort) (n int, err error) { + n, err = c.UDPConn.WriteToUDPAddrPort(b, addr) egressTrafficMeter.Mark(int64(n)) return n, err } diff --git a/p2p/discover/node.go b/p2p/discover/node.go index 9ffe101ccf..042619221b 100644 --- a/p2p/discover/node.go +++ b/p2p/discover/node.go @@ -21,7 +21,8 @@ import ( "crypto/elliptic" "errors" "math/big" - "net" + "slices" + "sort" "time" "github.com/ethereum/go-ethereum/common/math" @@ -29,12 +30,22 @@ import ( "github.com/ethereum/go-ethereum/p2p/enode" ) -// node represents a host on the network. -// The fields of Node may not be modified. -type node struct { - enode.Node - addedAt time.Time // time when the node was added to the table - livenessChecks uint // how often liveness was checked +type BucketNode struct { + Node *enode.Node `json:"node"` + AddedToTable time.Time `json:"addedToTable"` + AddedToBucket time.Time `json:"addedToBucket"` + Checks int `json:"checks"` + Live bool `json:"live"` +} + +// tableNode is an entry in Table. +type tableNode struct { + *enode.Node + revalList *revalidationList + addedToTable time.Time // first time node was added to bucket or replacement list + addedToBucket time.Time // time it was added in the actual bucket + livenessChecks uint // how often liveness was checked + isValidatedLive bool // true if existence of node is considered validated right now } type encPubkey [64]byte @@ -64,34 +75,59 @@ func (e encPubkey) id() enode.ID { return enode.ID(crypto.Keccak256Hash(e[:])) } -func wrapNode(n *enode.Node) *node { - return &node{Node: *n} -} - -func wrapNodes(ns []*enode.Node) []*node { - result := make([]*node, len(ns)) +func unwrapNodes(ns []*tableNode) []*enode.Node { + result := make([]*enode.Node, len(ns)) for i, n := range ns { - result[i] = wrapNode(n) + result[i] = n.Node } return result } -func unwrapNode(n *node) *enode.Node { - return &n.Node +func (n *tableNode) String() string { + return n.Node.String() } -func unwrapNodes(ns []*node) []*enode.Node { - result := make([]*enode.Node, len(ns)) - for i, n := range ns { - result[i] = unwrapNode(n) +// nodesByDistance is a list of nodes, ordered by distance to target. +type nodesByDistance struct { + entries []*enode.Node + target enode.ID +} + +// push adds the given node to the list, keeping the total size below maxElems. +func (h *nodesByDistance) push(n *enode.Node, maxElems int) { + ix := sort.Search(len(h.entries), func(i int) bool { + return enode.DistCmp(h.target, h.entries[i].ID(), n.ID()) > 0 + }) + + end := len(h.entries) + if len(h.entries) < maxElems { + h.entries = append(h.entries, n) + } + if ix < end { + // Slide existing entries down to make room. + // This will overwrite the entry we just appended. + copy(h.entries[ix+1:], h.entries[ix:]) + h.entries[ix] = n } - return result } -func (n *node) addr() *net.UDPAddr { - return &net.UDPAddr{IP: n.IP(), Port: n.UDP()} +type nodeType interface { + ID() enode.ID } -func (n *node) String() string { - return n.Node.String() +// containsID reports whether ns contains a node with the given ID. +func containsID[N nodeType](ns []N, id enode.ID) bool { + for _, n := range ns { + if n.ID() == id { + return true + } + } + return false +} + +// deleteNode removes a node from the list. +func deleteNode[N nodeType](list []N, id enode.ID) []N { + return slices.DeleteFunc(list, func(n N) bool { + return n.ID() == id + }) } diff --git a/p2p/discover/table.go b/p2p/discover/table.go index 2b7a28708b..bd3c9b4143 100644 --- a/p2p/discover/table.go +++ b/p2p/discover/table.go @@ -24,16 +24,14 @@ package discover import ( "context" - crand "crypto/rand" - "encoding/binary" "fmt" - mrand "math/rand" "net" - "sort" + "slices" "sync" "time" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/common/mclock" "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/metrics" "github.com/ethereum/go-ethereum/p2p/enode" @@ -55,21 +53,21 @@ const ( bucketIPLimit, bucketSubnet = 2, 24 // at most 2 addresses from the same /24 tableIPLimit, tableSubnet = 10, 24 - copyNodesInterval = 30 * time.Second - seedMinTableTime = 5 * time.Minute - seedCount = 30 - seedMaxAge = 5 * 24 * time.Hour + seedMinTableTime = 5 * time.Minute + seedCount = 30 + seedMaxAge = 5 * 24 * time.Hour ) // Table is the 'node table', a Kademlia-like index of neighbor nodes. The table keeps // itself up-to-date by verifying the liveness of neighbors and requesting their node // records when announcements of a new record version are received. type Table struct { - mutex sync.Mutex // protects buckets, bucket content, nursery, rand - buckets [nBuckets]*bucket // index of known nodes by distance - nursery []*node // bootstrap nodes - rand *mrand.Rand // source of randomness, periodically reseeded - ips netutil.DistinctNetSet + mutex sync.Mutex // protects buckets, bucket content, nursery, rand + buckets [nBuckets]*bucket // index of known nodes by distance + nursery []*enode.Node // bootstrap nodes + rand reseedingRandom // source of randomness, periodically reseeded + ips netutil.DistinctNetSet + revalidation tableRevalidation db *enode.DB // database of known nodes net transport @@ -77,13 +75,17 @@ type Table struct { log log.Logger // loop channels - refreshReq chan chan struct{} - initDone chan struct{} - closeReq chan struct{} - closed chan struct{} + refreshReq chan chan struct{} + revalResponseCh chan revalidationResponse + addNodeCh chan addNodeOp + addNodeHandled chan bool + trackRequestCh chan trackRequestOp + initDone chan struct{} + closeReq chan struct{} + closed chan struct{} - nodeAddedHook func(*bucket, *node) - nodeRemovedHook func(*bucket, *node) + nodeAddedHook func(*bucket, *tableNode) + nodeRemovedHook func(*bucket, *tableNode) } // transport is implemented by the UDP transports. @@ -98,28 +100,40 @@ type transport interface { // bucket contains nodes, ordered by their last activity. the entry // that was most recently active is the first element in entries. type bucket struct { - entries []*node // live entries, sorted by time of last contact - replacements []*node // recently seen nodes to be used if revalidation fails + entries []*tableNode // live entries, sorted by time of last contact + replacements []*tableNode // recently seen nodes to be used if revalidation fails ips netutil.DistinctNetSet index int } +type addNodeOp struct { + node *enode.Node + isInbound bool + forceSetLive bool // for tests +} + +type trackRequestOp struct { + node *enode.Node + foundNodes []*enode.Node + success bool +} + func newTable(t transport, db *enode.DB, cfg Config) (*Table, error) { cfg = cfg.withDefaults() tab := &Table{ - net: t, - db: db, - cfg: cfg, - log: cfg.Log, - refreshReq: make(chan chan struct{}), - initDone: make(chan struct{}), - closeReq: make(chan struct{}), - closed: make(chan struct{}), - rand: mrand.New(mrand.NewSource(0)), - ips: netutil.DistinctNetSet{Subnet: tableSubnet, Limit: tableIPLimit}, - } - if err := tab.setFallbackNodes(cfg.Bootnodes); err != nil { - return nil, err + net: t, + db: db, + cfg: cfg, + log: cfg.Log, + refreshReq: make(chan chan struct{}), + revalResponseCh: make(chan revalidationResponse), + addNodeCh: make(chan addNodeOp), + addNodeHandled: make(chan bool), + trackRequestCh: make(chan trackRequestOp), + initDone: make(chan struct{}), + closeReq: make(chan struct{}), + closed: make(chan struct{}), + ips: netutil.DistinctNetSet{Subnet: tableSubnet, Limit: tableIPLimit}, } for i := range tab.buckets { tab.buckets[i] = &bucket{ @@ -127,41 +141,34 @@ func newTable(t transport, db *enode.DB, cfg Config) (*Table, error) { ips: netutil.DistinctNetSet{Subnet: bucketSubnet, Limit: bucketIPLimit}, } } - tab.seedRand() - tab.loadSeedNodes() - - return tab, nil -} + tab.rand.seed() + tab.revalidation.init(&cfg) -func newMeteredTable(t transport, db *enode.DB, cfg Config) (*Table, error) { - tab, err := newTable(t, db, cfg) - if err != nil { + // initial table content + if err := tab.setFallbackNodes(cfg.Bootnodes); err != nil { return nil, err } - if metrics.Enabled { - tab.nodeAddedHook = func(b *bucket, n *node) { - bucketsCounter[b.index].Inc(1) - } - tab.nodeRemovedHook = func(b *bucket, n *node) { - bucketsCounter[b.index].Dec(1) - } - } + tab.loadSeedNodes() + return tab, nil } // Nodes returns all nodes contained in the table. -func (tab *Table) Nodes() []*enode.Node { - if !tab.isInitDone() { - return nil - } - +func (tab *Table) Nodes() [][]BucketNode { tab.mutex.Lock() defer tab.mutex.Unlock() - var nodes []*enode.Node - for _, b := range &tab.buckets { - for _, n := range b.entries { - nodes = append(nodes, unwrapNode(n)) + nodes := make([][]BucketNode, len(tab.buckets)) + for i, b := range &tab.buckets { + nodes[i] = make([]BucketNode, len(b.entries)) + for j, n := range b.entries { + nodes[i][j] = BucketNode{ + Node: n.Node, + Checks: int(n.livenessChecks), + Live: n.isValidatedLive, + AddedToTable: n.addedToTable, + AddedToBucket: n.addedToBucket, + } } } return nodes @@ -171,15 +178,6 @@ func (tab *Table) self() *enode.Node { return tab.net.Self() } -func (tab *Table) seedRand() { - var b [8]byte - crand.Read(b[:]) - - tab.mutex.Lock() - tab.rand.Seed(int64(binary.BigEndian.Uint64(b[:]))) - tab.mutex.Unlock() -} - // getNode returns the node with the given ID or nil if it isn't in the table. func (tab *Table) getNode(id enode.ID) *enode.Node { tab.mutex.Lock() @@ -188,7 +186,7 @@ func (tab *Table) getNode(id enode.ID) *enode.Node { b := tab.bucket(id) for _, e := range b.entries { if e.ID() == id { - return unwrapNode(e) + return e.Node } } return nil @@ -204,7 +202,7 @@ func (tab *Table) close() { // are used to connect to the network if the table is empty and there // are no known nodes in the database. func (tab *Table) setFallbackNodes(nodes []*enode.Node) error { - nursery := make([]*node, 0, len(nodes)) + nursery := make([]*enode.Node, 0, len(nodes)) for _, n := range nodes { if err := n.ValidateComplete(); err != nil { return fmt.Errorf("bad bootstrap node %q: %v", n, err) @@ -213,7 +211,7 @@ func (tab *Table) setFallbackNodes(nodes []*enode.Node) error { tab.log.Error("Bootstrap node filtered by netrestrict", "id", n.ID(), "ip", n.IP()) continue } - nursery = append(nursery, wrapNode(n)) + nursery = append(nursery, n) } tab.nursery = nursery return nil @@ -239,52 +237,173 @@ func (tab *Table) refresh() <-chan struct{} { return done } -// loop schedules runs of doRefresh, doRevalidate and copyLiveNodes. +// findnodeByID returns the n nodes in the table that are closest to the given id. +// This is used by the FINDNODE/v4 handler. +// +// The preferLive parameter says whether the caller wants liveness-checked results. If +// preferLive is true and the table contains any verified nodes, the result will not +// contain unverified nodes. However, if there are no verified nodes at all, the result +// will contain unverified nodes. +func (tab *Table) findnodeByID(target enode.ID, nresults int, preferLive bool) *nodesByDistance { + tab.mutex.Lock() + defer tab.mutex.Unlock() + + // Scan all buckets. There might be a better way to do this, but there aren't that many + // buckets, so this solution should be fine. The worst-case complexity of this loop + // is O(tab.len() * nresults). + nodes := &nodesByDistance{target: target} + liveNodes := &nodesByDistance{target: target} + for _, b := range &tab.buckets { + for _, n := range b.entries { + nodes.push(n.Node, nresults) + if preferLive && n.isValidatedLive { + liveNodes.push(n.Node, nresults) + } + } + } + + if preferLive && len(liveNodes.entries) > 0 { + return liveNodes + } + return nodes +} + +// appendLiveNodes adds nodes at the given distance to the result slice. +// This is used by the FINDNODE/v5 handler. +func (tab *Table) appendLiveNodes(dist uint, result []*enode.Node) []*enode.Node { + if dist > 256 { + return result + } + if dist == 0 { + return append(result, tab.self()) + } + + tab.mutex.Lock() + for _, n := range tab.bucketAtDistance(int(dist)).entries { + if n.isValidatedLive { + result = append(result, n.Node) + } + } + tab.mutex.Unlock() + + // Shuffle result to avoid always returning same nodes in FINDNODE/v5. + tab.rand.Shuffle(len(result), func(i, j int) { + result[i], result[j] = result[j], result[i] + }) + return result +} + +// len returns the number of nodes in the table. +func (tab *Table) len() (n int) { + tab.mutex.Lock() + defer tab.mutex.Unlock() + + for _, b := range &tab.buckets { + n += len(b.entries) + } + return n +} + +// addFoundNode adds a node which may not be live. If the bucket has space available, +// adding the node succeeds immediately. Otherwise, the node is added to the replacements +// list. +// +// The caller must not hold tab.mutex. +func (tab *Table) addFoundNode(n *enode.Node, forceSetLive bool) bool { + op := addNodeOp{node: n, isInbound: false, forceSetLive: forceSetLive} + select { + case tab.addNodeCh <- op: + return <-tab.addNodeHandled + case <-tab.closeReq: + return false + } +} + +// addInboundNode adds a node from an inbound contact. If the bucket has no space, the +// node is added to the replacements list. +// +// There is an additional safety measure: if the table is still initializing the node is +// not added. This prevents an attack where the table could be filled by just sending ping +// repeatedly. +// +// The caller must not hold tab.mutex. +func (tab *Table) addInboundNode(n *enode.Node) bool { + op := addNodeOp{node: n, isInbound: true} + select { + case tab.addNodeCh <- op: + return <-tab.addNodeHandled + case <-tab.closeReq: + return false + } +} + +func (tab *Table) trackRequest(n *enode.Node, success bool, foundNodes []*enode.Node) { + op := trackRequestOp{n, foundNodes, success} + select { + case tab.trackRequestCh <- op: + case <-tab.closeReq: + } +} + +// loop is the main loop of Table. func (tab *Table) loop() { var ( - revalidate = time.NewTimer(tab.nextRevalidateTime()) - refresh = time.NewTimer(tab.nextRefreshTime()) - copyNodes = time.NewTicker(copyNodesInterval) - refreshDone = make(chan struct{}) // where doRefresh reports completion - revalidateDone chan struct{} // where doRevalidate reports completion - waiting = []chan struct{}{tab.initDone} // holds waiting callers while doRefresh runs + refresh = time.NewTimer(tab.nextRefreshTime()) + refreshDone = make(chan struct{}) // where doRefresh reports completion + waiting = []chan struct{}{tab.initDone} // holds waiting callers while doRefresh runs + revalTimer = mclock.NewAlarm(tab.cfg.Clock) + reseedRandTimer = time.NewTicker(10 * time.Minute) ) defer refresh.Stop() - defer revalidate.Stop() - defer copyNodes.Stop() + defer revalTimer.Stop() + defer reseedRandTimer.Stop() // Start initial refresh. go tab.doRefresh(refreshDone) loop: for { + nextTime := tab.revalidation.run(tab, tab.cfg.Clock.Now()) + revalTimer.Schedule(nextTime) + select { + case <-reseedRandTimer.C: + tab.rand.seed() + + case <-revalTimer.C(): + + case r := <-tab.revalResponseCh: + tab.revalidation.handleResponse(tab, r) + + case op := <-tab.addNodeCh: + tab.mutex.Lock() + ok := tab.handleAddNode(op) + tab.mutex.Unlock() + tab.addNodeHandled <- ok + + case op := <-tab.trackRequestCh: + tab.handleTrackRequest(op) + case <-refresh.C: - tab.seedRand() if refreshDone == nil { refreshDone = make(chan struct{}) go tab.doRefresh(refreshDone) } + case req := <-tab.refreshReq: waiting = append(waiting, req) if refreshDone == nil { refreshDone = make(chan struct{}) go tab.doRefresh(refreshDone) } + case <-refreshDone: for _, ch := range waiting { close(ch) } waiting, refreshDone = nil, nil refresh.Reset(tab.nextRefreshTime()) - case <-revalidate.C: - revalidateDone = make(chan struct{}) - go tab.doRevalidate(revalidateDone) - case <-revalidateDone: - revalidate.Reset(tab.nextRevalidateTime()) - revalidateDone = nil - case <-copyNodes.C: - go tab.copyLiveNodes() + case <-tab.closeReq: break loop } @@ -296,9 +415,6 @@ loop: for _, ch := range waiting { close(ch) } - if revalidateDone != nil { - <-revalidateDone - } close(tab.closed) } @@ -327,177 +443,24 @@ func (tab *Table) doRefresh(done chan struct{}) { } func (tab *Table) loadSeedNodes() { - seeds := wrapNodes(tab.db.QuerySeeds(seedCount, seedMaxAge)) + seeds := tab.db.QuerySeeds(seedCount, seedMaxAge) seeds = append(seeds, tab.nursery...) for i := range seeds { seed := seeds[i] if tab.log.Enabled(context.Background(), log.LevelTrace) { age := time.Since(tab.db.LastPongReceived(seed.ID(), seed.IP())) - tab.log.Trace("Found seed node in database", "id", seed.ID(), "addr", seed.addr(), "age", age) + addr, _ := seed.UDPEndpoint() + tab.log.Trace("Found seed node in database", "id", seed.ID(), "addr", addr, "age", age) } - tab.addSeenNode(seed) + tab.handleAddNode(addNodeOp{node: seed, isInbound: false}) } } -// doRevalidate checks that the last node in a random bucket is still live and replaces or -// deletes the node if it isn't. -func (tab *Table) doRevalidate(done chan<- struct{}) { - defer func() { done <- struct{}{} }() - - last, bi := tab.nodeToRevalidate() - if last == nil { - // No non-empty bucket found. - return - } - - // Ping the selected node and wait for a pong. - remoteSeq, err := tab.net.ping(unwrapNode(last)) - - // Also fetch record if the node replied and returned a higher sequence number. - if last.Seq() < remoteSeq { - n, err := tab.net.RequestENR(unwrapNode(last)) - if err != nil { - tab.log.Debug("ENR request failed", "id", last.ID(), "addr", last.addr(), "err", err) - } else { - last = &node{Node: *n, addedAt: last.addedAt, livenessChecks: last.livenessChecks} - } - } - - tab.mutex.Lock() - defer tab.mutex.Unlock() - b := tab.buckets[bi] - if err == nil { - // The node responded, move it to the front. - last.livenessChecks++ - tab.log.Debug("Revalidated node", "b", bi, "id", last.ID(), "checks", last.livenessChecks) - tab.bumpInBucket(b, last) - return - } - // No reply received, pick a replacement or delete the node if there aren't - // any replacements. - if r := tab.replace(b, last); r != nil { - tab.log.Debug("Replaced dead node", "b", bi, "id", last.ID(), "ip", last.IP(), "checks", last.livenessChecks, "r", r.ID(), "rip", r.IP()) - } else { - tab.log.Debug("Removed dead node", "b", bi, "id", last.ID(), "ip", last.IP(), "checks", last.livenessChecks) - } -} - -// nodeToRevalidate returns the last node in a random, non-empty bucket. -func (tab *Table) nodeToRevalidate() (n *node, bi int) { - tab.mutex.Lock() - defer tab.mutex.Unlock() - - for _, bi = range tab.rand.Perm(len(tab.buckets)) { - b := tab.buckets[bi] - if len(b.entries) > 0 { - last := b.entries[len(b.entries)-1] - return last, bi - } - } - return nil, 0 -} - -func (tab *Table) nextRevalidateTime() time.Duration { - tab.mutex.Lock() - defer tab.mutex.Unlock() - - return time.Duration(tab.rand.Int63n(int64(tab.cfg.PingInterval))) -} - func (tab *Table) nextRefreshTime() time.Duration { - tab.mutex.Lock() - defer tab.mutex.Unlock() - half := tab.cfg.RefreshInterval / 2 return half + time.Duration(tab.rand.Int63n(int64(half))) } -// copyLiveNodes adds nodes from the table to the database if they have been in the table -// longer than seedMinTableTime. -func (tab *Table) copyLiveNodes() { - tab.mutex.Lock() - defer tab.mutex.Unlock() - - now := time.Now() - for _, b := range &tab.buckets { - for _, n := range b.entries { - if n.livenessChecks > 0 && now.Sub(n.addedAt) >= seedMinTableTime { - tab.db.UpdateNode(unwrapNode(n)) - } - } - } -} - -// findnodeByID returns the n nodes in the table that are closest to the given id. -// This is used by the FINDNODE/v4 handler. -// -// The preferLive parameter says whether the caller wants liveness-checked results. If -// preferLive is true and the table contains any verified nodes, the result will not -// contain unverified nodes. However, if there are no verified nodes at all, the result -// will contain unverified nodes. -func (tab *Table) findnodeByID(target enode.ID, nresults int, preferLive bool) *nodesByDistance { - tab.mutex.Lock() - defer tab.mutex.Unlock() - - // Scan all buckets. There might be a better way to do this, but there aren't that many - // buckets, so this solution should be fine. The worst-case complexity of this loop - // is O(tab.len() * nresults). - nodes := &nodesByDistance{target: target} - liveNodes := &nodesByDistance{target: target} - for _, b := range &tab.buckets { - for _, n := range b.entries { - nodes.push(n, nresults) - if preferLive && n.livenessChecks > 0 { - liveNodes.push(n, nresults) - } - } - } - - if preferLive && len(liveNodes.entries) > 0 { - return liveNodes - } - return nodes -} - -// appendLiveNodes adds nodes at the given distance to the result slice. -func (tab *Table) appendLiveNodes(dist uint, result []*enode.Node) []*enode.Node { - if dist > 256 { - return result - } - if dist == 0 { - return append(result, tab.self()) - } - - tab.mutex.Lock() - defer tab.mutex.Unlock() - for _, n := range tab.bucketAtDistance(int(dist)).entries { - if n.livenessChecks >= 1 { - node := n.Node // avoid handing out pointer to struct field - result = append(result, &node) - } - } - return result -} - -// len returns the number of nodes in the table. -func (tab *Table) len() (n int) { - tab.mutex.Lock() - defer tab.mutex.Unlock() - - for _, b := range &tab.buckets { - n += len(b.entries) - } - return n -} - -// bucketLen returns the number of nodes in the bucket for the given ID. -func (tab *Table) bucketLen(id enode.ID) int { - tab.mutex.Lock() - defer tab.mutex.Unlock() - - return len(tab.bucket(id).entries) -} - // bucket returns the bucket for the given node ID hash. func (tab *Table) bucket(id enode.ID) *bucket { d := enode.LogDist(tab.self().ID(), id) @@ -511,95 +474,6 @@ func (tab *Table) bucketAtDistance(d int) *bucket { return tab.buckets[d-bucketMinDistance-1] } -// addSeenNode adds a node which may or may not be live to the end of a bucket. If the -// bucket has space available, adding the node succeeds immediately. Otherwise, the node is -// added to the replacements list. -// -// The caller must not hold tab.mutex. -func (tab *Table) addSeenNode(n *node) { - if n.ID() == tab.self().ID() { - return - } - - tab.mutex.Lock() - defer tab.mutex.Unlock() - b := tab.bucket(n.ID()) - if contains(b.entries, n.ID()) { - // Already in bucket, don't add. - return - } - if len(b.entries) >= bucketSize { - // Bucket full, maybe add as replacement. - tab.addReplacement(b, n) - return - } - if !tab.addIP(b, n.IP()) { - // Can't add: IP limit reached. - return - } - - // Add to end of bucket: - b.entries = append(b.entries, n) - b.replacements = deleteNode(b.replacements, n) - n.addedAt = time.Now() - - if tab.nodeAddedHook != nil { - tab.nodeAddedHook(b, n) - } -} - -// addVerifiedNode adds a node whose existence has been verified recently to the front of a -// bucket. If the node is already in the bucket, it is moved to the front. If the bucket -// has no space, the node is added to the replacements list. -// -// There is an additional safety measure: if the table is still initializing the node -// is not added. This prevents an attack where the table could be filled by just sending -// ping repeatedly. -// -// The caller must not hold tab.mutex. -func (tab *Table) addVerifiedNode(n *node) { - if !tab.isInitDone() { - return - } - if n.ID() == tab.self().ID() { - return - } - - tab.mutex.Lock() - defer tab.mutex.Unlock() - b := tab.bucket(n.ID()) - if tab.bumpInBucket(b, n) { - // Already in bucket, moved to front. - return - } - if len(b.entries) >= bucketSize { - // Bucket full, maybe add as replacement. - tab.addReplacement(b, n) - return - } - if !tab.addIP(b, n.IP()) { - // Can't add: IP limit reached. - return - } - - // Add to front of bucket. - b.entries, _ = pushNode(b.entries, n, bucketSize) - b.replacements = deleteNode(b.replacements, n) - n.addedAt = time.Now() - - if tab.nodeAddedHook != nil { - tab.nodeAddedHook(b, n) - } -} - -// delete removes an entry from the node table. It is used to evacuate dead nodes. -func (tab *Table) delete(node *node) { - tab.mutex.Lock() - defer tab.mutex.Unlock() - - tab.deleteInBucket(tab.bucket(node.ID()), node) -} - func (tab *Table) addIP(b *bucket, ip net.IP) bool { if len(ip) == 0 { return false // Nodes without IP cannot be added. @@ -627,128 +501,194 @@ func (tab *Table) removeIP(b *bucket, ip net.IP) { b.ips.Remove(ip) } -func (tab *Table) addReplacement(b *bucket, n *node) { - for _, e := range b.replacements { - if e.ID() == n.ID() { - return // already in list - } +// handleAddNode adds the node in the request to the table, if there is space. +// The caller must hold tab.mutex. +func (tab *Table) handleAddNode(req addNodeOp) bool { + if req.node.ID() == tab.self().ID() { + return false + } + // For nodes from inbound contact, there is an additional safety measure: if the table + // is still initializing the node is not added. + if req.isInbound && !tab.isInitDone() { + return false + } + + b := tab.bucket(req.node.ID()) + n, _ := tab.bumpInBucket(b, req.node, req.isInbound) + if n != nil { + // Already in bucket. + return false + } + if len(b.entries) >= bucketSize { + // Bucket full, maybe add as replacement. + tab.addReplacement(b, req.node) + return false + } + if !tab.addIP(b, req.node.IP()) { + // Can't add: IP limit reached. + return false + } + + // Add to bucket. + wn := &tableNode{Node: req.node} + if req.forceSetLive { + wn.livenessChecks = 1 + wn.isValidatedLive = true + } + b.entries = append(b.entries, wn) + b.replacements = deleteNode(b.replacements, wn.ID()) + tab.nodeAdded(b, wn) + return true +} + +// addReplacement adds n to the replacement cache of bucket b. +func (tab *Table) addReplacement(b *bucket, n *enode.Node) { + if containsID(b.replacements, n.ID()) { + // TODO: update ENR + return } if !tab.addIP(b, n.IP()) { return } - var removed *node - b.replacements, removed = pushNode(b.replacements, n, maxReplacements) + + wn := &tableNode{Node: n, addedToTable: time.Now()} + var removed *tableNode + b.replacements, removed = pushNode(b.replacements, wn, maxReplacements) if removed != nil { tab.removeIP(b, removed.IP()) } } -// replace removes n from the replacement list and replaces 'last' with it if it is the -// last entry in the bucket. If 'last' isn't the last entry, it has either been replaced -// with someone else or became active. -func (tab *Table) replace(b *bucket, last *node) *node { - if len(b.entries) == 0 || b.entries[len(b.entries)-1].ID() != last.ID() { - // Entry has moved, don't replace it. - return nil +func (tab *Table) nodeAdded(b *bucket, n *tableNode) { + if n.addedToTable == (time.Time{}) { + n.addedToTable = time.Now() } - // Still the last entry. - if len(b.replacements) == 0 { - tab.deleteInBucket(b, last) - return nil + n.addedToBucket = time.Now() + tab.revalidation.nodeAdded(tab, n) + if tab.nodeAddedHook != nil { + tab.nodeAddedHook(b, n) } - r := b.replacements[tab.rand.Intn(len(b.replacements))] - b.replacements = deleteNode(b.replacements, r) - b.entries[len(b.entries)-1] = r - tab.removeIP(b, last.IP()) - return r -} - -// bumpInBucket moves the given node to the front of the bucket entry list -// if it is contained in that list. -func (tab *Table) bumpInBucket(b *bucket, n *node) bool { - for i := range b.entries { - if b.entries[i].ID() == n.ID() { - if !n.IP().Equal(b.entries[i].IP()) { - // Endpoint has changed, ensure that the new IP fits into table limits. - tab.removeIP(b, b.entries[i].IP()) - if !tab.addIP(b, n.IP()) { - // It doesn't, put the previous one back. - tab.addIP(b, b.entries[i].IP()) - return false - } - } - // Move it to the front. - copy(b.entries[1:], b.entries[:i]) - b.entries[0] = n - return true - } + if metrics.Enabled { + bucketsCounter[b.index].Inc(1) } - return false } -func (tab *Table) deleteInBucket(b *bucket, n *node) { - // Check if the node is actually in the bucket so the removed hook - // isn't called multiple times for the same node. - if !contains(b.entries, n.ID()) { - return - } - b.entries = deleteNode(b.entries, n) - tab.removeIP(b, n.IP()) +func (tab *Table) nodeRemoved(b *bucket, n *tableNode) { + tab.revalidation.nodeRemoved(n) if tab.nodeRemovedHook != nil { tab.nodeRemovedHook(b, n) } + if metrics.Enabled { + bucketsCounter[b.index].Dec(1) + } } -func contains(ns []*node, id enode.ID) bool { - for _, n := range ns { - if n.ID() == id { - return true - } +// deleteInBucket removes node n from the table. +// If there are replacement nodes in the bucket, the node is replaced. +func (tab *Table) deleteInBucket(b *bucket, id enode.ID) *tableNode { + index := slices.IndexFunc(b.entries, func(e *tableNode) bool { return e.ID() == id }) + if index == -1 { + // Entry has been removed already. + return nil } - return false -} -// pushNode adds n to the front of list, keeping at most max items. -func pushNode(list []*node, n *node, max int) ([]*node, *node) { - if len(list) < max { - list = append(list, nil) + // Remove the node. + n := b.entries[index] + b.entries = slices.Delete(b.entries, index, index+1) + tab.removeIP(b, n.IP()) + tab.nodeRemoved(b, n) + + // Add replacement. + if len(b.replacements) == 0 { + tab.log.Debug("Removed dead node", "b", b.index, "id", n.ID(), "ip", n.IP()) + return nil } - removed := list[len(list)-1] - copy(list[1:], list) - list[0] = n - return list, removed + rindex := tab.rand.Intn(len(b.replacements)) + rep := b.replacements[rindex] + b.replacements = slices.Delete(b.replacements, rindex, rindex+1) + b.entries = append(b.entries, rep) + tab.nodeAdded(b, rep) + tab.log.Debug("Replaced dead node", "b", b.index, "id", n.ID(), "ip", n.IP(), "r", rep.ID(), "rip", rep.IP()) + return rep } -// deleteNode removes n from list. -func deleteNode(list []*node, n *node) []*node { - for i := range list { - if list[i].ID() == n.ID() { - return append(list[:i], list[i+1:]...) +// bumpInBucket updates a node record if it exists in the bucket. +// The second return value reports whether the node's endpoint (IP/port) was updated. +func (tab *Table) bumpInBucket(b *bucket, newRecord *enode.Node, isInbound bool) (n *tableNode, endpointChanged bool) { + i := slices.IndexFunc(b.entries, func(elem *tableNode) bool { + return elem.ID() == newRecord.ID() + }) + if i == -1 { + return nil, false // not in bucket + } + n = b.entries[i] + + // For inbound updates (from the node itself) we accept any change, even if it sets + // back the sequence number. For found nodes (!isInbound), seq has to advance. Note + // this check also ensures found discv4 nodes (which always have seq=0) can't be + // updated. + if newRecord.Seq() <= n.Seq() && !isInbound { + return n, false + } + + // Check endpoint update against IP limits. + ipchanged := newRecord.IPAddr() != n.IPAddr() + portchanged := newRecord.UDP() != n.UDP() + if ipchanged { + tab.removeIP(b, n.IP()) + if !tab.addIP(b, newRecord.IP()) { + // It doesn't fit with the limit, put the previous record back. + tab.addIP(b, n.IP()) + return n, false } } - return list -} -// nodesByDistance is a list of nodes, ordered by distance to target. -type nodesByDistance struct { - entries []*node - target enode.ID + // Apply update. + n.Node = newRecord + if ipchanged || portchanged { + // Ensure node is revalidated quickly for endpoint changes. + tab.revalidation.nodeEndpointChanged(tab, n) + return n, true + } + return n, false } -// push adds the given node to the list, keeping the total size below maxElems. -func (h *nodesByDistance) push(n *node, maxElems int) { - ix := sort.Search(len(h.entries), func(i int) bool { - return enode.DistCmp(h.target, h.entries[i].ID(), n.ID()) > 0 - }) +func (tab *Table) handleTrackRequest(op trackRequestOp) { + var fails int + if op.success { + // Reset failure counter because it counts _consecutive_ failures. + tab.db.UpdateFindFails(op.node.ID(), op.node.IP(), 0) + } else { + fails = tab.db.FindFails(op.node.ID(), op.node.IP()) + fails++ + tab.db.UpdateFindFails(op.node.ID(), op.node.IP(), fails) + } - end := len(h.entries) - if len(h.entries) < maxElems { - h.entries = append(h.entries, n) + tab.mutex.Lock() + defer tab.mutex.Unlock() + + b := tab.bucket(op.node.ID()) + // Remove the node from the local table if it fails to return anything useful too + // many times, but only if there are enough other nodes in the bucket. This latter + // condition specifically exists to make bootstrapping in smaller test networks more + // reliable. + if fails >= maxFindnodeFailures && len(b.entries) >= bucketSize/4 { + tab.deleteInBucket(b, op.node.ID()) } - if ix < end { - // Slide existing entries down to make room. - // This will overwrite the entry we just appended. - copy(h.entries[ix+1:], h.entries[ix:]) - h.entries[ix] = n + + // Add found nodes. + for _, n := range op.foundNodes { + tab.handleAddNode(addNodeOp{n, false, false}) } } + +// pushNode adds n to the front of list, keeping at most max items. +func pushNode(list []*tableNode, n *tableNode, max int) ([]*tableNode, *tableNode) { + if len(list) < max { + list = append(list, nil) + } + removed := list[len(list)-1] + copy(list[1:], list) + list[0] = n + return list, removed +} diff --git a/p2p/discover/table_reval.go b/p2p/discover/table_reval.go new file mode 100644 index 0000000000..f2ea8b34fa --- /dev/null +++ b/p2p/discover/table_reval.go @@ -0,0 +1,244 @@ +// Copyright 2024 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package discover + +import ( + "fmt" + "math" + "slices" + "time" + + "github.com/ethereum/go-ethereum/common/mclock" + "github.com/ethereum/go-ethereum/p2p/enode" +) + +const never = mclock.AbsTime(math.MaxInt64) + +const slowRevalidationFactor = 3 + +// tableRevalidation implements the node revalidation process. +// It tracks all nodes contained in Table, and schedules sending PING to them. +type tableRevalidation struct { + fast revalidationList + slow revalidationList + activeReq map[enode.ID]struct{} +} + +type revalidationResponse struct { + n *tableNode + newRecord *enode.Node + didRespond bool +} + +func (tr *tableRevalidation) init(cfg *Config) { + tr.activeReq = make(map[enode.ID]struct{}) + tr.fast.nextTime = never + tr.fast.interval = cfg.PingInterval + tr.fast.name = "fast" + tr.slow.nextTime = never + tr.slow.interval = cfg.PingInterval * slowRevalidationFactor + tr.slow.name = "slow" +} + +// nodeAdded is called when the table receives a new node. +func (tr *tableRevalidation) nodeAdded(tab *Table, n *tableNode) { + tr.fast.push(n, tab.cfg.Clock.Now(), &tab.rand) +} + +// nodeRemoved is called when a node was removed from the table. +func (tr *tableRevalidation) nodeRemoved(n *tableNode) { + if n.revalList == nil { + panic(fmt.Errorf("removed node %v has nil revalList", n.ID())) + } + n.revalList.remove(n) +} + +// nodeEndpointChanged is called when a change in IP or port is detected. +func (tr *tableRevalidation) nodeEndpointChanged(tab *Table, n *tableNode) { + n.isValidatedLive = false + tr.moveToList(&tr.fast, n, tab.cfg.Clock.Now(), &tab.rand) +} + +// run performs node revalidation. +// It returns the next time it should be invoked, which is used in the Table main loop +// to schedule a timer. However, run can be called at any time. +func (tr *tableRevalidation) run(tab *Table, now mclock.AbsTime) (nextTime mclock.AbsTime) { + if n := tr.fast.get(now, &tab.rand, tr.activeReq); n != nil { + tr.startRequest(tab, n) + tr.fast.schedule(now, &tab.rand) + } + if n := tr.slow.get(now, &tab.rand, tr.activeReq); n != nil { + tr.startRequest(tab, n) + tr.slow.schedule(now, &tab.rand) + } + + return min(tr.fast.nextTime, tr.slow.nextTime) +} + +// startRequest spawns a revalidation request for node n. +func (tr *tableRevalidation) startRequest(tab *Table, n *tableNode) { + if _, ok := tr.activeReq[n.ID()]; ok { + panic(fmt.Errorf("duplicate startRequest (node %v)", n.ID())) + } + tr.activeReq[n.ID()] = struct{}{} + resp := revalidationResponse{n: n} + + // Fetch the node while holding lock. + tab.mutex.Lock() + node := n.Node + tab.mutex.Unlock() + + go tab.doRevalidate(resp, node) +} + +func (tab *Table) doRevalidate(resp revalidationResponse, node *enode.Node) { + // Ping the selected node and wait for a pong response. + remoteSeq, err := tab.net.ping(node) + resp.didRespond = err == nil + + // Also fetch record if the node replied and returned a higher sequence number. + if remoteSeq > node.Seq() { + newrec, err := tab.net.RequestENR(node) + if err != nil { + tab.log.Debug("ENR request failed", "id", node.ID(), "err", err) + } else { + resp.newRecord = newrec + } + } + + select { + case tab.revalResponseCh <- resp: + case <-tab.closed: + } +} + +// handleResponse processes the result of a revalidation request. +func (tr *tableRevalidation) handleResponse(tab *Table, resp revalidationResponse) { + var ( + now = tab.cfg.Clock.Now() + n = resp.n + b = tab.bucket(n.ID()) + ) + delete(tr.activeReq, n.ID()) + + // If the node was removed from the table while getting checked, we need to stop + // processing here to avoid re-adding it. + if n.revalList == nil { + return + } + + // Store potential seeds in database. + // This is done via defer to avoid holding Table lock while writing to DB. + defer func() { + if n.isValidatedLive && n.livenessChecks > 5 { + tab.db.UpdateNode(resp.n.Node) + } + }() + + // Remaining logic needs access to Table internals. + tab.mutex.Lock() + defer tab.mutex.Unlock() + + if !resp.didRespond { + n.livenessChecks /= 3 + if n.livenessChecks <= 0 { + tab.deleteInBucket(b, n.ID()) + } else { + tab.log.Debug("Node revalidation failed", "b", b.index, "id", n.ID(), "checks", n.livenessChecks, "q", n.revalList.name) + tr.moveToList(&tr.fast, n, now, &tab.rand) + } + return + } + + // The node responded. + n.livenessChecks++ + n.isValidatedLive = true + tab.log.Debug("Node revalidated", "b", b.index, "id", n.ID(), "checks", n.livenessChecks, "q", n.revalList.name) + var endpointChanged bool + if resp.newRecord != nil { + _, endpointChanged = tab.bumpInBucket(b, resp.newRecord, false) + } + + // Node moves to slow list if it passed and hasn't changed. + if !endpointChanged { + tr.moveToList(&tr.slow, n, now, &tab.rand) + } +} + +// moveToList ensures n is in the 'dest' list. +func (tr *tableRevalidation) moveToList(dest *revalidationList, n *tableNode, now mclock.AbsTime, rand randomSource) { + if n.revalList == dest { + return + } + if n.revalList != nil { + n.revalList.remove(n) + } + dest.push(n, now, rand) +} + +// revalidationList holds a list nodes and the next revalidation time. +type revalidationList struct { + nodes []*tableNode + nextTime mclock.AbsTime + interval time.Duration + name string +} + +// get returns a random node from the queue. Nodes in the 'exclude' map are not returned. +func (list *revalidationList) get(now mclock.AbsTime, rand randomSource, exclude map[enode.ID]struct{}) *tableNode { + if now < list.nextTime || len(list.nodes) == 0 { + return nil + } + for i := 0; i < len(list.nodes)*3; i++ { + n := list.nodes[rand.Intn(len(list.nodes))] + _, excluded := exclude[n.ID()] + if !excluded { + return n + } + } + return nil +} + +func (list *revalidationList) schedule(now mclock.AbsTime, rand randomSource) { + list.nextTime = now.Add(time.Duration(rand.Int63n(int64(list.interval)))) +} + +func (list *revalidationList) push(n *tableNode, now mclock.AbsTime, rand randomSource) { + list.nodes = append(list.nodes, n) + if list.nextTime == never { + list.schedule(now, rand) + } + n.revalList = list +} + +func (list *revalidationList) remove(n *tableNode) { + i := slices.Index(list.nodes, n) + if i == -1 { + panic(fmt.Errorf("node %v not found in list", n.ID())) + } + list.nodes = slices.Delete(list.nodes, i, i+1) + if len(list.nodes) == 0 { + list.nextTime = never + } + n.revalList = nil +} + +func (list *revalidationList) contains(id enode.ID) bool { + return slices.ContainsFunc(list.nodes, func(n *tableNode) bool { + return n.ID() == id + }) +} diff --git a/p2p/discover/table_reval_test.go b/p2p/discover/table_reval_test.go new file mode 100644 index 0000000000..3605443934 --- /dev/null +++ b/p2p/discover/table_reval_test.go @@ -0,0 +1,119 @@ +// Copyright 2024 The go-ethereum Authors +// This file is part of the go-ethereum library. +// +// The go-ethereum library is free software: you can redistribute it and/or modify +// it under the terms of the GNU Lesser General Public License as published by +// the Free Software Foundation, either version 3 of the License, or +// (at your option) any later version. +// +// The go-ethereum library is distributed in the hope that it will be useful, +// but WITHOUT ANY WARRANTY; without even the implied warranty of +// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the +// GNU Lesser General Public License for more details. +// +// You should have received a copy of the GNU Lesser General Public License +// along with the go-ethereum library. If not, see . + +package discover + +import ( + "net" + "testing" + "time" + + "github.com/ethereum/go-ethereum/common/mclock" + "github.com/ethereum/go-ethereum/p2p/enode" + "github.com/ethereum/go-ethereum/p2p/enr" +) + +// This test checks that revalidation can handle a node disappearing while +// a request is active. +func TestRevalidation_nodeRemoved(t *testing.T) { + var ( + clock mclock.Simulated + transport = newPingRecorder() + tab, db = newInactiveTestTable(transport, Config{Clock: &clock}) + tr = &tab.revalidation + ) + defer db.Close() + + // Add a node to the table. + node := nodeAtDistance(tab.self().ID(), 255, net.IP{77, 88, 99, 1}) + tab.handleAddNode(addNodeOp{node: node}) + + // Start a revalidation request. Schedule once to get the next start time, + // then advance the clock to that point and schedule again to start. + next := tr.run(tab, clock.Now()) + clock.Run(time.Duration(next + 1)) + tr.run(tab, clock.Now()) + if len(tr.activeReq) != 1 { + t.Fatal("revalidation request did not start:", tr.activeReq) + } + + // Delete the node. + tab.deleteInBucket(tab.bucket(node.ID()), node.ID()) + + // Now finish the revalidation request. + var resp revalidationResponse + select { + case resp = <-tab.revalResponseCh: + case <-time.After(1 * time.Second): + t.Fatal("timed out waiting for revalidation") + } + tr.handleResponse(tab, resp) + + // Ensure the node was not re-added to the table. + if tab.getNode(node.ID()) != nil { + t.Fatal("node was re-added to Table") + } + if tr.fast.contains(node.ID()) || tr.slow.contains(node.ID()) { + t.Fatal("removed node contained in revalidation list") + } +} + +// This test checks that nodes with an updated endpoint remain in the fast revalidation list. +func TestRevalidation_endpointUpdate(t *testing.T) { + var ( + clock mclock.Simulated + transport = newPingRecorder() + tab, db = newInactiveTestTable(transport, Config{Clock: &clock}) + tr = &tab.revalidation + ) + defer db.Close() + + // Add node to table. + node := nodeAtDistance(tab.self().ID(), 255, net.IP{77, 88, 99, 1}) + tab.handleAddNode(addNodeOp{node: node}) + + // Update the record in transport, including endpoint update. + record := node.Record() + record.Set(enr.IP{100, 100, 100, 100}) + record.Set(enr.UDP(9999)) + nodev2 := enode.SignNull(record, node.ID()) + transport.updateRecord(nodev2) + + // Start a revalidation request. Schedule once to get the next start time, + // then advance the clock to that point and schedule again to start. + next := tr.run(tab, clock.Now()) + clock.Run(time.Duration(next + 1)) + tr.run(tab, clock.Now()) + if len(tr.activeReq) != 1 { + t.Fatal("revalidation request did not start:", tr.activeReq) + } + + // Now finish the revalidation request. + var resp revalidationResponse + select { + case resp = <-tab.revalResponseCh: + case <-time.After(1 * time.Second): + t.Fatal("timed out waiting for revalidation") + } + tr.handleResponse(tab, resp) + + if tr.fast.nodes[0].ID() != node.ID() { + t.Fatal("node not contained in fast revalidation list") + } + if tr.fast.nodes[0].isValidatedLive { + t.Fatal("node is marked live after endpoint change") + } +} diff --git a/p2p/discover/table_test.go b/p2p/discover/table_test.go index 3ba3422251..30e7d56f4a 100644 --- a/p2p/discover/table_test.go +++ b/p2p/discover/table_test.go @@ -20,14 +20,17 @@ import ( "crypto/ecdsa" "fmt" "math/rand" - "net" "reflect" + "slices" "testing" "testing/quick" "time" + "github.com/ethereum/go-ethereum/common/mclock" "github.com/ethereum/go-ethereum/crypto" + "github.com/ethereum/go-ethereum/internal/testlog" + "github.com/ethereum/go-ethereum/log" "github.com/ethereum/go-ethereum/p2p/enode" "github.com/ethereum/go-ethereum/p2p/enr" "github.com/ethereum/go-ethereum/p2p/netutil" @@ -49,106 +52,109 @@ func TestTable_pingReplace(t *testing.T) { } func testPingReplace(t *testing.T, newNodeIsResponding, lastInBucketIsResponding bool) { + simclock := new(mclock.Simulated) transport := newPingRecorder() - tab, db := newTestTable(transport) + tab, db := newTestTable(transport, Config{ + Clock: simclock, + Log: testlog.Logger(t, log.LevelTrace), + }) defer db.Close() defer tab.close() <-tab.initDone // Fill up the sender's bucket. - pingKey, _ := crypto.HexToECDSA("45a915e4d060149eb4365960e6a7a45f334393093061116b197e3240065ff2d8") - pingSender := wrapNode(enode.NewV4(&pingKey.PublicKey, net.IP{127, 0, 0, 1}, 99, 99)) - last := fillBucket(tab, pingSender) + replacementNodeKey, _ := crypto.HexToECDSA("45a915e4d060149eb4365960e6a7a45f334393093061116b197e3240065ff2d8") + replacementNode := enode.NewV4(&replacementNodeKey.PublicKey, net.IP{127, 0, 0, 1}, 99, 99) + last := fillBucket(tab, replacementNode.ID()) + tab.mutex.Lock() + nodeEvents := newNodeEventRecorder(128) + tab.nodeAddedHook = nodeEvents.nodeAdded + tab.nodeRemovedHook = nodeEvents.nodeRemoved + tab.mutex.Unlock() - // Add the sender as if it just pinged us. Revalidate should replace the last node in - // its bucket if it is unresponsive. Revalidate again to ensure that + // The revalidation process should replace + // this node in the bucket if it is unresponsive. transport.dead[last.ID()] = !lastInBucketIsResponding - transport.dead[pingSender.ID()] = !newNodeIsResponding - tab.addSeenNode(pingSender) - tab.doRevalidate(make(chan struct{}, 1)) - tab.doRevalidate(make(chan struct{}, 1)) - - if !transport.pinged[last.ID()] { - // Oldest node in bucket is pinged to see whether it is still alive. - t.Error("table did not ping last node in bucket") + transport.dead[replacementNode.ID()] = !newNodeIsResponding + + // Add replacement node to table. + tab.addFoundNode(replacementNode, false) + + t.Log("last:", last.ID()) + t.Log("replacement:", replacementNode.ID()) + + // Wait until the last node was pinged. + waitForRevalidationPing(t, transport, tab, last.ID()) + + if !lastInBucketIsResponding { + if !nodeEvents.waitNodeAbsent(last.ID(), 2*time.Second) { + t.Error("last node was not removed") + } + if !nodeEvents.waitNodePresent(replacementNode.ID(), 2*time.Second) { + t.Error("replacement node was not added") + } + + // If a replacement is expected, we also need to wait until the replacement node + // was pinged and added/removed. + waitForRevalidationPing(t, transport, tab, replacementNode.ID()) + if !newNodeIsResponding { + if !nodeEvents.waitNodeAbsent(replacementNode.ID(), 2*time.Second) { + t.Error("replacement node was not removed") + } + } } + // Check bucket content. tab.mutex.Lock() defer tab.mutex.Unlock() wantSize := bucketSize if !lastInBucketIsResponding && !newNodeIsResponding { wantSize-- } - if l := len(tab.bucket(pingSender.ID()).entries); l != wantSize { - t.Errorf("wrong bucket size after bond: got %d, want %d", l, wantSize) + bucket := tab.bucket(replacementNode.ID()) + if l := len(bucket.entries); l != wantSize { + t.Errorf("wrong bucket size after revalidation: got %d, want %d", l, wantSize) } - if found := contains(tab.bucket(pingSender.ID()).entries, last.ID()); found != lastInBucketIsResponding { - t.Errorf("last entry found: %t, want: %t", found, lastInBucketIsResponding) + if ok := containsID(bucket.entries, last.ID()); ok != lastInBucketIsResponding { + t.Errorf("revalidated node found: %t, want: %t", ok, lastInBucketIsResponding) } wantNewEntry := newNodeIsResponding && !lastInBucketIsResponding - if found := contains(tab.bucket(pingSender.ID()).entries, pingSender.ID()); found != wantNewEntry { - t.Errorf("new entry found: %t, want: %t", found, wantNewEntry) + if ok := containsID(bucket.entries, replacementNode.ID()); ok != wantNewEntry { + t.Errorf("replacement node found: %t, want: %t", ok, wantNewEntry) } } -func TestBucket_bumpNoDuplicates(t *testing.T) { - t.Parallel() - cfg := &quick.Config{ - MaxCount: 1000, - Rand: rand.New(rand.NewSource(time.Now().Unix())), - Values: func(args []reflect.Value, rand *rand.Rand) { - // generate a random list of nodes. this will be the content of the bucket. - n := rand.Intn(bucketSize-1) + 1 - nodes := make([]*node, n) - for i := range nodes { - nodes[i] = nodeAtDistance(enode.ID{}, 200, intIP(200)) - } - args[0] = reflect.ValueOf(nodes) - // generate random bump positions. - bumps := make([]int, rand.Intn(100)) - for i := range bumps { - bumps[i] = rand.Intn(len(nodes)) - } - args[1] = reflect.ValueOf(bumps) - }, - } - - prop := func(nodes []*node, bumps []int) (ok bool) { - tab, db := newTestTable(newPingRecorder()) - defer db.Close() - defer tab.close() +// waitForRevalidationPing waits until a PING message is sent to a node with the given id. +func waitForRevalidationPing(t *testing.T, transport *pingRecorder, tab *Table, id enode.ID) *enode.Node { + t.Helper() - b := &bucket{entries: make([]*node, len(nodes))} - copy(b.entries, nodes) - for i, pos := range bumps { - tab.bumpInBucket(b, b.entries[pos]) - if hasDuplicates(b.entries) { - t.Logf("bucket has duplicates after %d/%d bumps:", i+1, len(bumps)) - for _, n := range b.entries { - t.Logf(" %p", n) - } - return false - } + simclock := tab.cfg.Clock.(*mclock.Simulated) + maxAttempts := tab.len() * 8 + for i := 0; i < maxAttempts; i++ { + simclock.Run(tab.cfg.PingInterval * slowRevalidationFactor) + p := transport.waitPing(2 * time.Second) + if p == nil { + t.Fatal("Table did not send revalidation ping") + } + if id == (enode.ID{}) || p.ID() == id { + return p } - checkIPLimitInvariant(t, tab) - return true - } - if err := quick.Check(prop, cfg); err != nil { - t.Error(err) } + t.Fatalf("Table did not ping node %v (%d attempts)", id, maxAttempts) + return nil } // This checks that the table-wide IP limit is applied correctly. func TestTable_IPLimit(t *testing.T) { transport := newPingRecorder() - tab, db := newTestTable(transport) + tab, db := newTestTable(transport, Config{}) defer db.Close() defer tab.close() for i := 0; i < tableIPLimit+1; i++ { n := nodeAtDistance(tab.self().ID(), i, net.IP{172, 0, 1, byte(i)}) - tab.addSeenNode(n) + tab.addFoundNode(n, false) } if tab.len() > tableIPLimit { t.Errorf("too many nodes in table") @@ -159,14 +165,14 @@ func TestTable_IPLimit(t *testing.T) { // This checks that the per-bucket IP limit is applied correctly. func TestTable_BucketIPLimit(t *testing.T) { transport := newPingRecorder() - tab, db := newTestTable(transport) + tab, db := newTestTable(transport, Config{}) defer db.Close() defer tab.close() d := 3 for i := 0; i < bucketIPLimit+1; i++ { n := nodeAtDistance(tab.self().ID(), d, net.IP{172, 0, 1, byte(i)}) - tab.addSeenNode(n) + tab.addFoundNode(n, false) } if tab.len() > bucketIPLimit { t.Errorf("too many nodes in table") @@ -196,7 +202,7 @@ func TestTable_findnodeByID(t *testing.T) { test := func(test *closeTest) bool { // for any node table, Target and N transport := newPingRecorder() - tab, db := newTestTable(transport) + tab, db := newTestTable(transport, Config{}) defer db.Close() defer tab.close() fillTable(tab, test.All, true) @@ -227,7 +233,7 @@ func TestTable_findnodeByID(t *testing.T) { // check that the result nodes have minimum distance to target. for _, b := range tab.buckets { for _, n := range b.entries { - if contains(result, n.ID()) { + if containsID(result, n.ID()) { continue // don't run the check below for nodes in result } farthestResult := result[len(result)-1].ID() @@ -250,7 +256,7 @@ func TestTable_findnodeByID(t *testing.T) { type closeTest struct { Self enode.ID Target enode.ID - All []*node + All []*enode.Node N int } @@ -263,15 +269,14 @@ func (*closeTest) Generate(rand *rand.Rand, size int) reflect.Value { for _, id := range gen([]enode.ID{}, rand).([]enode.ID) { r := new(enr.Record) r.Set(enr.IP(genIP(rand))) - n := wrapNode(enode.SignNull(r, id)) - n.livenessChecks = 1 + n := enode.SignNull(r, id) t.All = append(t.All, n) } return reflect.ValueOf(t) } -func TestTable_addVerifiedNode(t *testing.T) { - tab, db := newTestTable(newPingRecorder()) +func TestTable_addInboundNode(t *testing.T) { + tab, db := newTestTable(newPingRecorder(), Config{}) <-tab.initDone defer db.Close() defer tab.close() @@ -279,31 +284,29 @@ func TestTable_addVerifiedNode(t *testing.T) { // Insert two nodes. n1 := nodeAtDistance(tab.self().ID(), 256, net.IP{88, 77, 66, 1}) n2 := nodeAtDistance(tab.self().ID(), 256, net.IP{88, 77, 66, 2}) - tab.addSeenNode(n1) - tab.addSeenNode(n2) + tab.addFoundNode(n1, false) + tab.addFoundNode(n2, false) + checkBucketContent(t, tab, []*enode.Node{n1, n2}) - // Verify bucket content: - bcontent := []*node{n1, n2} - if !reflect.DeepEqual(tab.bucket(n1.ID()).entries, bcontent) { - t.Fatalf("wrong bucket content: %v", tab.bucket(n1.ID()).entries) - } - - // Add a changed version of n2. + // Add a changed version of n2. The bucket should be updated. newrec := n2.Record() newrec.Set(enr.IP{99, 99, 99, 99}) - newn2 := wrapNode(enode.SignNull(newrec, n2.ID())) - tab.addVerifiedNode(newn2) - - // Check that bucket is updated correctly. - newBcontent := []*node{newn2, n1} - if !reflect.DeepEqual(tab.bucket(n1.ID()).entries, newBcontent) { - t.Fatalf("wrong bucket content after update: %v", tab.bucket(n1.ID()).entries) - } - checkIPLimitInvariant(t, tab) + n2v2 := enode.SignNull(newrec, n2.ID()) + tab.addInboundNode(n2v2) + checkBucketContent(t, tab, []*enode.Node{n1, n2v2}) + + // Try updating n2 without sequence number change. The update is accepted + // because it's inbound. + newrec = n2.Record() + newrec.Set(enr.IP{100, 100, 100, 100}) + newrec.SetSeq(n2.Seq()) + n2v3 := enode.SignNull(newrec, n2.ID()) + tab.addInboundNode(n2v3) + checkBucketContent(t, tab, []*enode.Node{n1, n2v3}) } -func TestTable_addSeenNode(t *testing.T) { - tab, db := newTestTable(newPingRecorder()) +func TestTable_addFoundNode(t *testing.T) { + tab, db := newTestTable(newPingRecorder(), Config{}) <-tab.initDone defer db.Close() defer tab.close() @@ -311,25 +314,86 @@ func TestTable_addSeenNode(t *testing.T) { // Insert two nodes. n1 := nodeAtDistance(tab.self().ID(), 256, net.IP{88, 77, 66, 1}) n2 := nodeAtDistance(tab.self().ID(), 256, net.IP{88, 77, 66, 2}) - tab.addSeenNode(n1) - tab.addSeenNode(n2) - - // Verify bucket content: - bcontent := []*node{n1, n2} - if !reflect.DeepEqual(tab.bucket(n1.ID()).entries, bcontent) { - t.Fatalf("wrong bucket content: %v", tab.bucket(n1.ID()).entries) - } + tab.addFoundNode(n1, false) + tab.addFoundNode(n2, false) + checkBucketContent(t, tab, []*enode.Node{n1, n2}) - // Add a changed version of n2. + // Add a changed version of n2. The bucket should be updated. newrec := n2.Record() newrec.Set(enr.IP{99, 99, 99, 99}) - newn2 := wrapNode(enode.SignNull(newrec, n2.ID())) - tab.addSeenNode(newn2) + n2v2 := enode.SignNull(newrec, n2.ID()) + tab.addFoundNode(n2v2, false) + checkBucketContent(t, tab, []*enode.Node{n1, n2v2}) + + // Try updating n2 without a sequence number change. + // The update should not be accepted. + newrec = n2.Record() + newrec.Set(enr.IP{100, 100, 100, 100}) + newrec.SetSeq(n2.Seq()) + n2v3 := enode.SignNull(newrec, n2.ID()) + tab.addFoundNode(n2v3, false) + checkBucketContent(t, tab, []*enode.Node{n1, n2v2}) +} + +// This test checks that discv4 nodes can update their own endpoint via PING. +func TestTable_addInboundNodeUpdateV4Accept(t *testing.T) { + tab, db := newTestTable(newPingRecorder(), Config{}) + <-tab.initDone + defer db.Close() + defer tab.close() + + // Add a v4 node. + key, _ := crypto.HexToECDSA("dd3757a8075e88d0f2b1431e7d3c5b1562e1c0aab9643707e8cbfcc8dae5cfe3") + n1 := enode.NewV4(&key.PublicKey, net.IP{88, 77, 66, 1}, 9000, 9000) + tab.addInboundNode(n1) + checkBucketContent(t, tab, []*enode.Node{n1}) + + // Add an updated version with changed IP. + // The update will be accepted because it is inbound. + n1v2 := enode.NewV4(&key.PublicKey, net.IP{99, 99, 99, 99}, 9000, 9000) + tab.addInboundNode(n1v2) + checkBucketContent(t, tab, []*enode.Node{n1v2}) +} + +// This test checks that discv4 node entries will NOT be updated when a +// changed record is found. +func TestTable_addFoundNodeV4UpdateReject(t *testing.T) { + tab, db := newTestTable(newPingRecorder(), Config{}) + <-tab.initDone + defer db.Close() + defer tab.close() - // Check that bucket content is unchanged. - if !reflect.DeepEqual(tab.bucket(n1.ID()).entries, bcontent) { - t.Fatalf("wrong bucket content after update: %v", tab.bucket(n1.ID()).entries) + // Add a v4 node. + key, _ := crypto.HexToECDSA("dd3757a8075e88d0f2b1431e7d3c5b1562e1c0aab9643707e8cbfcc8dae5cfe3") + n1 := enode.NewV4(&key.PublicKey, net.IP{88, 77, 66, 1}, 9000, 9000) + tab.addFoundNode(n1, false) + checkBucketContent(t, tab, []*enode.Node{n1}) + + // Add an updated version with changed IP. + // The update won't be accepted because it isn't inbound. + n1v2 := enode.NewV4(&key.PublicKey, net.IP{99, 99, 99, 99}, 9000, 9000) + tab.addFoundNode(n1v2, false) + checkBucketContent(t, tab, []*enode.Node{n1}) +} + +func checkBucketContent(t *testing.T, tab *Table, nodes []*enode.Node) { + t.Helper() + + b := tab.bucket(nodes[0].ID()) + if reflect.DeepEqual(unwrapNodes(b.entries), nodes) { + return } + t.Log("wrong bucket content. have nodes:") + for _, n := range b.entries { + t.Logf(" %v (seq=%v, ip=%v)", n.ID(), n.Seq(), n.IP()) + } + t.Log("want nodes:") + for _, n := range nodes { + t.Logf(" %v (seq=%v, ip=%v)", n.ID(), n.Seq(), n.IP()) + } + t.FailNow() + + // Also check IP limits. checkIPLimitInvariant(t, tab) } @@ -337,7 +401,10 @@ func TestTable_addSeenNode(t *testing.T) { // announces a new sequence number, the new record should be pulled. func TestTable_revalidateSyncRecord(t *testing.T) { transport := newPingRecorder() - tab, db := newTestTable(transport) + tab, db := newTestTable(transport, Config{ + Clock: new(mclock.Simulated), + Log: testlog.Logger(t, log.LevelTrace), + }) <-tab.initDone defer db.Close() defer tab.close() @@ -346,15 +413,19 @@ func TestTable_revalidateSyncRecord(t *testing.T) { var r enr.Record r.Set(enr.IP(net.IP{127, 0, 0, 1})) id := enode.ID{1} - n1 := wrapNode(enode.SignNull(&r, id)) - tab.addSeenNode(n1) + n1 := enode.SignNull(&r, id) + tab.addFoundNode(n1, false) // Update the node record. r.Set(enr.WithEntry("foo", "bar")) n2 := enode.SignNull(&r, id) transport.updateRecord(n2) - tab.doRevalidate(make(chan struct{}, 1)) + // Wait for revalidation. We wait for the node to be revalidated two times + // in order to synchronize with the update in the table. + waitForRevalidationPing(t, transport, tab, n2.ID()) + waitForRevalidationPing(t, transport, tab, n2.ID()) + intable := tab.getNode(id) if !reflect.DeepEqual(intable, n2) { t.Fatalf("table contains old record with seq %d, want seq %d", intable.Seq(), n2.Seq()) @@ -366,7 +437,7 @@ func TestNodesPush(t *testing.T) { n1 := nodeAtDistance(target, 255, intIP(1)) n2 := nodeAtDistance(target, 254, intIP(2)) n3 := nodeAtDistance(target, 253, intIP(3)) - perm := [][]*node{ + perm := [][]*enode.Node{ {n3, n2, n1}, {n3, n1, n2}, {n2, n3, n1}, @@ -381,7 +452,7 @@ func TestNodesPush(t *testing.T) { for _, n := range nodes { list.push(n, 3) } - if !slicesEqual(list.entries, perm[0], nodeIDEqual) { + if !slices.EqualFunc(list.entries, perm[0], nodeIDEqual) { t.Fatal("not equal") } } @@ -392,28 +463,16 @@ func TestNodesPush(t *testing.T) { for _, n := range nodes { list.push(n, 2) } - if !slicesEqual(list.entries, perm[0][:2], nodeIDEqual) { + if !slices.EqualFunc(list.entries, perm[0][:2], nodeIDEqual) { t.Fatal("not equal") } } } -func nodeIDEqual(n1, n2 *node) bool { +func nodeIDEqual[N nodeType](n1, n2 N) bool { return n1.ID() == n2.ID() } -func slicesEqual[T any](s1, s2 []T, check func(e1, e2 T) bool) bool { - if len(s1) != len(s2) { - return false - } - for i := range s1 { - if !check(s1[i], s2[i]) { - return false - } - } - return true -} - // gen wraps quick.Value so it's easier to use. // it generates a random value of the given value's type. func gen(typ interface{}, rand *rand.Rand) interface{} { diff --git a/p2p/discover/table_util_test.go b/p2p/discover/table_util_test.go index d6309dfd6c..34b831f5b1 100644 --- a/p2p/discover/table_util_test.go +++ b/p2p/discover/table_util_test.go @@ -25,6 +25,8 @@ import ( "math/rand" "net" "sync" + "sync/atomic" + "time" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/p2p/enode" @@ -40,27 +42,32 @@ func init() { nullNode = enode.SignNull(&r, enode.ID{}) } -func newTestTable(t transport) (*Table, *enode.DB) { - cfg := Config{} +func newTestTable(t transport, cfg Config) (*Table, *enode.DB) { + tab, db := newInactiveTestTable(t, cfg) + go tab.loop() + return tab, db +} + +// newInactiveTestTable creates a Table without running the main loop. +func newInactiveTestTable(t transport, cfg Config) (*Table, *enode.DB) { db, _ := enode.OpenDB("") tab, _ := newTable(t, db, cfg) - go tab.loop() return tab, db } // nodeAtDistance creates a node for which enode.LogDist(base, n.id) == ld. -func nodeAtDistance(base enode.ID, ld int, ip net.IP) *node { +func nodeAtDistance(base enode.ID, ld int, ip net.IP) *enode.Node { var r enr.Record r.Set(enr.IP(ip)) r.Set(enr.UDP(30303)) - return wrapNode(enode.SignNull(&r, idAtDistance(base, ld))) + return enode.SignNull(&r, idAtDistance(base, ld)) } // nodesAtDistance creates n nodes for which enode.LogDist(base, node.ID()) == ld. func nodesAtDistance(base enode.ID, ld int, n int) []*enode.Node { results := make([]*enode.Node, n) for i := range results { - results[i] = unwrapNode(nodeAtDistance(base, ld, intIP(i))) + results[i] = nodeAtDistance(base, ld, intIP(i)) } return results } @@ -98,31 +105,33 @@ func intIP(i int) net.IP { } // fillBucket inserts nodes into the given bucket until it is full. -func fillBucket(tab *Table, n *node) (last *node) { - ld := enode.LogDist(tab.self().ID(), n.ID()) - b := tab.bucket(n.ID()) +func fillBucket(tab *Table, id enode.ID) (last *tableNode) { + ld := enode.LogDist(tab.self().ID(), id) + b := tab.bucket(id) for len(b.entries) < bucketSize { - b.entries = append(b.entries, nodeAtDistance(tab.self().ID(), ld, intIP(ld))) + node := nodeAtDistance(tab.self().ID(), ld, intIP(ld)) + if !tab.addFoundNode(node, false) { + panic("node not added") + } } return b.entries[bucketSize-1] } // fillTable adds nodes the table to the end of their corresponding bucket // if the bucket is not full. The caller must not hold tab.mutex. -func fillTable(tab *Table, nodes []*node, setLive bool) { +func fillTable(tab *Table, nodes []*enode.Node, setLive bool) { for _, n := range nodes { - if setLive { - n.livenessChecks = 1 - } - tab.addSeenNode(n) + tab.addFoundNode(n, setLive) } } type pingRecorder struct { - mu sync.Mutex - dead, pinged map[enode.ID]bool - records map[enode.ID]*enode.Node - n *enode.Node + mu sync.Mutex + cond *sync.Cond + dead map[enode.ID]bool + records map[enode.ID]*enode.Node + pinged []*enode.Node + n *enode.Node } func newPingRecorder() *pingRecorder { @@ -130,12 +139,13 @@ func newPingRecorder() *pingRecorder { r.Set(enr.IP{0, 0, 0, 0}) n := enode.SignNull(&r, enode.ID{}) - return &pingRecorder{ + t := &pingRecorder{ dead: make(map[enode.ID]bool), - pinged: make(map[enode.ID]bool), records: make(map[enode.ID]*enode.Node), n: n, } + t.cond = sync.NewCond(&t.mu) + return t } // updateRecord updates a node record. Future calls to ping and @@ -151,12 +161,40 @@ func (t *pingRecorder) Self() *enode.Node { return nullNode } func (t *pingRecorder) lookupSelf() []*enode.Node { return nil } func (t *pingRecorder) lookupRandom() []*enode.Node { return nil } +func (t *pingRecorder) waitPing(timeout time.Duration) *enode.Node { + t.mu.Lock() + defer t.mu.Unlock() + + // Wake up the loop on timeout. + var timedout atomic.Bool + timer := time.AfterFunc(timeout, func() { + timedout.Store(true) + t.cond.Broadcast() + }) + defer timer.Stop() + + // Wait for a ping. + for { + if timedout.Load() { + return nil + } + if len(t.pinged) > 0 { + n := t.pinged[0] + t.pinged = append(t.pinged[:0], t.pinged[1:]...) + return n + } + t.cond.Wait() + } +} + // ping simulates a ping request. func (t *pingRecorder) ping(n *enode.Node) (seq uint64, err error) { t.mu.Lock() defer t.mu.Unlock() - t.pinged[n.ID()] = true + t.pinged = append(t.pinged, n) + t.cond.Broadcast() + if t.dead[n.ID()] { return 0, errTimeout } @@ -177,7 +215,7 @@ func (t *pingRecorder) RequestENR(n *enode.Node) (*enode.Node, error) { return t.records[n.ID()], nil } -func hasDuplicates(slice []*node) bool { +func hasDuplicates(slice []*enode.Node) bool { seen := make(map[enode.ID]bool, len(slice)) for i, e := range slice { if e == nil { @@ -219,14 +257,14 @@ func nodeEqual(n1 *enode.Node, n2 *enode.Node) bool { return n1.ID() == n2.ID() && n1.IP().Equal(n2.IP()) } -func sortByID(nodes []*enode.Node) { - slices.SortFunc(nodes, func(a, b *enode.Node) int { +func sortByID[N nodeType](nodes []N) { + slices.SortFunc(nodes, func(a, b N) int { return bytes.Compare(a.ID().Bytes(), b.ID().Bytes()) }) } -func sortedByDistanceTo(distbase enode.ID, slice []*node) bool { - return slices.IsSortedFunc(slice, func(a, b *node) int { +func sortedByDistanceTo(distbase enode.ID, slice []*enode.Node) bool { + return slices.IsSortedFunc(slice, func(a, b *enode.Node) int { return enode.DistCmp(distbase, a.ID(), b.ID()) }) } @@ -256,3 +294,57 @@ func hexEncPubkey(h string) (ret encPubkey) { copy(ret[:], b) return ret } + +type nodeEventRecorder struct { + evc chan recordedNodeEvent +} + +type recordedNodeEvent struct { + node *tableNode + added bool +} + +func newNodeEventRecorder(buffer int) *nodeEventRecorder { + return &nodeEventRecorder{ + evc: make(chan recordedNodeEvent, buffer), + } +} + +func (set *nodeEventRecorder) nodeAdded(b *bucket, n *tableNode) { + select { + case set.evc <- recordedNodeEvent{n, true}: + default: + panic("no space in event buffer") + } +} + +func (set *nodeEventRecorder) nodeRemoved(b *bucket, n *tableNode) { + select { + case set.evc <- recordedNodeEvent{n, false}: + default: + panic("no space in event buffer") + } +} + +func (set *nodeEventRecorder) waitNodePresent(id enode.ID, timeout time.Duration) bool { + return set.waitNodeEvent(id, timeout, true) +} + +func (set *nodeEventRecorder) waitNodeAbsent(id enode.ID, timeout time.Duration) bool { + return set.waitNodeEvent(id, timeout, false) +} + +func (set *nodeEventRecorder) waitNodeEvent(id enode.ID, timeout time.Duration, added bool) bool { + timer := time.NewTimer(timeout) + defer timer.Stop() + for { + select { + case ev := <-set.evc: + if ev.node.ID() == id && ev.added == added { + return true + } + case <-timer.C: + return false + } + } +} diff --git a/p2p/discover/v4_lookup_test.go b/p2p/discover/v4_lookup_test.go index 8867a5a8ac..ea75960566 100644 --- a/p2p/discover/v4_lookup_test.go +++ b/p2p/discover/v4_lookup_test.go @@ -19,7 +19,7 @@ package discover import ( "crypto/ecdsa" "fmt" - "net" + "net/netip" "testing" "github.com/ethereum/go-ethereum/crypto" @@ -40,7 +40,7 @@ func TestUDPv4_Lookup(t *testing.T) { } // Seed table with initial node. - fillTable(test.table, []*node{wrapNode(lookupTestnet.node(256, 0))}, true) + fillTable(test.table, []*enode.Node{lookupTestnet.node(256, 0)}, true) // Start the lookup. resultC := make(chan []*enode.Node, 1) @@ -70,9 +70,9 @@ func TestUDPv4_LookupIterator(t *testing.T) { defer test.close() // Seed table with initial nodes. - bootnodes := make([]*node, len(lookupTestnet.dists[256])) + bootnodes := make([]*enode.Node, len(lookupTestnet.dists[256])) for i := range lookupTestnet.dists[256] { - bootnodes[i] = wrapNode(lookupTestnet.node(256, i)) + bootnodes[i] = lookupTestnet.node(256, i) } fillTable(test.table, bootnodes, true) go serveTestnet(test, lookupTestnet) @@ -105,9 +105,9 @@ func TestUDPv4_LookupIteratorClose(t *testing.T) { defer test.close() // Seed table with initial nodes. - bootnodes := make([]*node, len(lookupTestnet.dists[256])) + bootnodes := make([]*enode.Node, len(lookupTestnet.dists[256])) for i := range lookupTestnet.dists[256] { - bootnodes[i] = wrapNode(lookupTestnet.node(256, i)) + bootnodes[i] = lookupTestnet.node(256, i) } fillTable(test.table, bootnodes, true) go serveTestnet(test, lookupTestnet) @@ -136,7 +136,7 @@ func TestUDPv4_LookupIteratorClose(t *testing.T) { func serveTestnet(test *udpTest, testnet *preminedTestnet) { for done := false; !done; { - done = test.waitPacketOut(func(p v4wire.Packet, to *net.UDPAddr, hash []byte) { + done = test.waitPacketOut(func(p v4wire.Packet, to netip.AddrPort, hash []byte) { n, key := testnet.nodeByAddr(to) switch p.(type) { case *v4wire.Ping: @@ -158,10 +158,10 @@ func checkLookupResults(t *testing.T, tn *preminedTestnet, results []*enode.Node for _, e := range results { t.Logf(" ld=%d, %x", enode.LogDist(tn.target.id(), e.ID()), e.ID().Bytes()) } - if hasDuplicates(wrapNodes(results)) { + if hasDuplicates(results) { t.Errorf("result set contains duplicate entries") } - if !sortedByDistanceTo(tn.target.id(), wrapNodes(results)) { + if !sortedByDistanceTo(tn.target.id(), results) { t.Errorf("result set not sorted by distance to target") } wantNodes := tn.closest(len(results)) @@ -264,9 +264,10 @@ func (tn *preminedTestnet) node(dist, index int) *enode.Node { return n } -func (tn *preminedTestnet) nodeByAddr(addr *net.UDPAddr) (*enode.Node, *ecdsa.PrivateKey) { - dist := int(addr.IP[1])<<8 + int(addr.IP[2]) - index := int(addr.IP[3]) +func (tn *preminedTestnet) nodeByAddr(addr netip.AddrPort) (*enode.Node, *ecdsa.PrivateKey) { + ip := addr.Addr().As4() + dist := int(ip[1])<<8 + int(ip[2]) + index := int(ip[3]) key := tn.dists[dist][index] return tn.node(dist, index), key } @@ -274,7 +275,7 @@ func (tn *preminedTestnet) nodeByAddr(addr *net.UDPAddr) (*enode.Node, *ecdsa.Pr func (tn *preminedTestnet) nodesAtDistance(dist int) []v4wire.Node { result := make([]v4wire.Node, len(tn.dists[dist])) for i := range result { - result[i] = nodeToRPC(wrapNode(tn.node(dist, i))) + result[i] = nodeToRPC(tn.node(dist, i)) } return result } diff --git a/p2p/discover/v4_udp.go b/p2p/discover/v4_udp.go index 988f16b01d..eb069c7305 100644 --- a/p2p/discover/v4_udp.go +++ b/p2p/discover/v4_udp.go @@ -26,6 +26,7 @@ import ( "fmt" "io" "net" + "net/netip" "sync" "time" @@ -45,6 +46,7 @@ var ( errClockWarp = errors.New("reply deadline too far in the future") errClosed = errors.New("socket closed") errLowPort = errors.New("low port") + errNoUDPEndpoint = errors.New("node has no UDP endpoint") ) const ( @@ -93,7 +95,7 @@ type UDPv4 struct { type replyMatcher struct { // these fields must match in the reply. from enode.ID - ip net.IP + ip netip.Addr ptype byte // time when the request must complete @@ -119,7 +121,7 @@ type replyMatchFunc func(v4wire.Packet) (matched bool, requestDone bool) // reply is a reply packet from a certain node. type reply struct { from enode.ID - ip net.IP + ip netip.Addr data v4wire.Packet // loop indicates whether there was // a matching request by sending on this channel. @@ -142,7 +144,7 @@ func ListenV4(c UDPConn, ln *enode.LocalNode, cfg Config) (*UDPv4, error) { log: cfg.Log, } - tab, err := newMeteredTable(t, ln.Database(), cfg) + tab, err := newTable(t, ln.Database(), cfg) if err != nil { return nil, err } @@ -201,9 +203,12 @@ func (t *UDPv4) Resolve(n *enode.Node) *enode.Node { } func (t *UDPv4) ourEndpoint() v4wire.Endpoint { - n := t.Self() - a := &net.UDPAddr{IP: n.IP(), Port: n.UDP()} - return v4wire.NewEndpoint(a, uint16(n.TCP())) + node := t.Self() + addr, ok := node.UDPEndpoint() + if !ok { + return v4wire.Endpoint{} + } + return v4wire.NewEndpoint(addr, uint16(node.TCP())) } // Ping sends a ping message to the given node. @@ -214,7 +219,11 @@ func (t *UDPv4) Ping(n *enode.Node) error { // ping sends a ping message to the given node and waits for a reply. func (t *UDPv4) ping(n *enode.Node) (seq uint64, err error) { - rm := t.sendPing(n.ID(), &net.UDPAddr{IP: n.IP(), Port: n.UDP()}, nil) + addr, ok := n.UDPEndpoint() + if !ok { + return 0, errNoUDPEndpoint + } + rm := t.sendPing(n.ID(), addr, nil) if err = <-rm.errc; err == nil { seq = rm.reply.(*v4wire.Pong).ENRSeq } @@ -223,7 +232,7 @@ func (t *UDPv4) ping(n *enode.Node) (seq uint64, err error) { // sendPing sends a ping message to the given node and invokes the callback // when the reply arrives. -func (t *UDPv4) sendPing(toid enode.ID, toaddr *net.UDPAddr, callback func()) *replyMatcher { +func (t *UDPv4) sendPing(toid enode.ID, toaddr netip.AddrPort, callback func()) *replyMatcher { req := t.makePing(toaddr) packet, hash, err := v4wire.Encode(t.priv, req) if err != nil { @@ -233,7 +242,7 @@ func (t *UDPv4) sendPing(toid enode.ID, toaddr *net.UDPAddr, callback func()) *r } // Add a matcher for the reply to the pending reply queue. Pongs are matched if they // reference the ping we're about to send. - rm := t.pending(toid, toaddr.IP, v4wire.PongPacket, func(p v4wire.Packet) (matched bool, requestDone bool) { + rm := t.pending(toid, toaddr.Addr(), v4wire.PongPacket, func(p v4wire.Packet) (matched bool, requestDone bool) { matched = bytes.Equal(p.(*v4wire.Pong).ReplyTok, hash) if matched && callback != nil { callback() @@ -241,12 +250,13 @@ func (t *UDPv4) sendPing(toid enode.ID, toaddr *net.UDPAddr, callback func()) *r return matched, matched }) // Send the packet. - t.localNode.UDPContact(toaddr) + toUDPAddr := &net.UDPAddr{IP: toaddr.Addr().AsSlice()} + t.localNode.UDPContact(toUDPAddr) t.write(toaddr, toid, req.Name(), packet) return rm } -func (t *UDPv4) makePing(toaddr *net.UDPAddr) *v4wire.Ping { +func (t *UDPv4) makePing(toaddr netip.AddrPort) *v4wire.Ping { return &v4wire.Ping{ Version: 4, From: t.ourEndpoint(), @@ -290,35 +300,39 @@ func (t *UDPv4) newRandomLookup(ctx context.Context) *lookup { func (t *UDPv4) newLookup(ctx context.Context, targetKey encPubkey) *lookup { target := enode.ID(crypto.Keccak256Hash(targetKey[:])) ekey := v4wire.Pubkey(targetKey) - it := newLookup(ctx, t.tab, target, func(n *node) ([]*node, error) { - return t.findnode(n.ID(), n.addr(), ekey) + it := newLookup(ctx, t.tab, target, func(n *enode.Node) ([]*enode.Node, error) { + addr, ok := n.UDPEndpoint() + if !ok { + return nil, errNoUDPEndpoint + } + return t.findnode(n.ID(), addr, ekey) }) return it } // findnode sends a findnode request to the given node and waits until // the node has sent up to k neighbors. -func (t *UDPv4) findnode(toid enode.ID, toaddr *net.UDPAddr, target v4wire.Pubkey) ([]*node, error) { - t.ensureBond(toid, toaddr) +func (t *UDPv4) findnode(toid enode.ID, toAddrPort netip.AddrPort, target v4wire.Pubkey) ([]*enode.Node, error) { + t.ensureBond(toid, toAddrPort) // Add a matcher for 'neighbours' replies to the pending reply queue. The matcher is // active until enough nodes have been received. - nodes := make([]*node, 0, bucketSize) + nodes := make([]*enode.Node, 0, bucketSize) nreceived := 0 - rm := t.pending(toid, toaddr.IP, v4wire.NeighborsPacket, func(r v4wire.Packet) (matched bool, requestDone bool) { + rm := t.pending(toid, toAddrPort.Addr(), v4wire.NeighborsPacket, func(r v4wire.Packet) (matched bool, requestDone bool) { reply := r.(*v4wire.Neighbors) for _, rn := range reply.Nodes { nreceived++ - n, err := t.nodeFromRPC(toaddr, rn) + n, err := t.nodeFromRPC(toAddrPort, rn) if err != nil { - t.log.Trace("Invalid neighbor node received", "ip", rn.IP, "addr", toaddr, "err", err) + t.log.Trace("Invalid neighbor node received", "ip", rn.IP, "addr", toAddrPort, "err", err) continue } nodes = append(nodes, n) } return true, nreceived >= bucketSize }) - t.send(toaddr, toid, &v4wire.Findnode{ + t.send(toAddrPort, toid, &v4wire.Findnode{ Target: target, Expiration: uint64(time.Now().Add(expiration).Unix()), }) @@ -336,7 +350,7 @@ func (t *UDPv4) findnode(toid enode.ID, toaddr *net.UDPAddr, target v4wire.Pubke // RequestENR sends ENRRequest to the given node and waits for a response. func (t *UDPv4) RequestENR(n *enode.Node) (*enode.Node, error) { - addr := &net.UDPAddr{IP: n.IP(), Port: n.UDP()} + addr, _ := n.UDPEndpoint() t.ensureBond(n.ID(), addr) req := &v4wire.ENRRequest{ @@ -349,7 +363,7 @@ func (t *UDPv4) RequestENR(n *enode.Node) (*enode.Node, error) { // Add a matcher for the reply to the pending reply queue. Responses are matched if // they reference the request we're about to send. - rm := t.pending(n.ID(), addr.IP, v4wire.ENRResponsePacket, func(r v4wire.Packet) (matched bool, requestDone bool) { + rm := t.pending(n.ID(), addr.Addr(), v4wire.ENRResponsePacket, func(r v4wire.Packet) (matched bool, requestDone bool) { matched = bytes.Equal(r.(*v4wire.ENRResponse).ReplyTok, hash) return matched, matched }) @@ -364,20 +378,24 @@ func (t *UDPv4) RequestENR(n *enode.Node) (*enode.Node, error) { return nil, err } if respN.ID() != n.ID() { - return nil, fmt.Errorf("invalid ID in response record") + return nil, errors.New("invalid ID in response record") } if respN.Seq() < n.Seq() { return n, nil // response record is older } - if err := netutil.CheckRelayIP(addr.IP, respN.IP()); err != nil { + if err := netutil.CheckRelayIP(addr.Addr().AsSlice(), respN.IP()); err != nil { return nil, fmt.Errorf("invalid IP in response record: %v", err) } return respN, nil } +func (t *UDPv4) TableBuckets() [][]BucketNode { + return t.tab.Nodes() +} + // pending adds a reply matcher to the pending reply queue. // see the documentation of type replyMatcher for a detailed explanation. -func (t *UDPv4) pending(id enode.ID, ip net.IP, ptype byte, callback replyMatchFunc) *replyMatcher { +func (t *UDPv4) pending(id enode.ID, ip netip.Addr, ptype byte, callback replyMatchFunc) *replyMatcher { ch := make(chan error, 1) p := &replyMatcher{from: id, ip: ip, ptype: ptype, callback: callback, errc: ch} select { @@ -391,7 +409,7 @@ func (t *UDPv4) pending(id enode.ID, ip net.IP, ptype byte, callback replyMatchF // handleReply dispatches a reply packet, invoking reply matchers. It returns // whether any matcher considered the packet acceptable. -func (t *UDPv4) handleReply(from enode.ID, fromIP net.IP, req v4wire.Packet) bool { +func (t *UDPv4) handleReply(from enode.ID, fromIP netip.Addr, req v4wire.Packet) bool { matched := make(chan bool, 1) select { case t.gotreply <- reply{from, fromIP, req, matched}: @@ -457,7 +475,7 @@ func (t *UDPv4) loop() { var matched bool // whether any replyMatcher considered the reply acceptable. for el := plist.Front(); el != nil; el = el.Next() { p := el.Value.(*replyMatcher) - if p.from == r.from && p.ptype == r.data.Kind() && p.ip.Equal(r.ip) { + if p.from == r.from && p.ptype == r.data.Kind() && p.ip == r.ip { ok, requestDone := p.callback(r.data) matched = matched || ok p.reply = r.data @@ -496,7 +514,7 @@ func (t *UDPv4) loop() { } } -func (t *UDPv4) send(toaddr *net.UDPAddr, toid enode.ID, req v4wire.Packet) ([]byte, error) { +func (t *UDPv4) send(toaddr netip.AddrPort, toid enode.ID, req v4wire.Packet) ([]byte, error) { packet, hash, err := v4wire.Encode(t.priv, req) if err != nil { return hash, err @@ -504,8 +522,8 @@ func (t *UDPv4) send(toaddr *net.UDPAddr, toid enode.ID, req v4wire.Packet) ([]b return hash, t.write(toaddr, toid, req.Name(), packet) } -func (t *UDPv4) write(toaddr *net.UDPAddr, toid enode.ID, what string, packet []byte) error { - _, err := t.conn.WriteToUDP(packet, toaddr) +func (t *UDPv4) write(toaddr netip.AddrPort, toid enode.ID, what string, packet []byte) error { + _, err := t.conn.WriteToUDPAddrPort(packet, toaddr) t.log.Trace(">> "+what, "id", toid, "addr", toaddr, "err", err) return err } @@ -519,7 +537,7 @@ func (t *UDPv4) readLoop(unhandled chan<- ReadPacket) { buf := make([]byte, maxPacketSize) for { - nbytes, from, err := t.conn.ReadFromUDP(buf) + nbytes, from, err := t.conn.ReadFromUDPAddrPort(buf) if netutil.IsTemporaryError(err) { // Ignore temporary read errors. t.log.Debug("Temporary UDP read error", "err", err) @@ -540,7 +558,7 @@ func (t *UDPv4) readLoop(unhandled chan<- ReadPacket) { } } -func (t *UDPv4) handlePacket(from *net.UDPAddr, buf []byte) error { +func (t *UDPv4) handlePacket(from netip.AddrPort, buf []byte) error { rawpacket, fromKey, hash, err := v4wire.Decode(buf) if err != nil { t.log.Debug("Bad discv4 packet", "addr", from, "err", err) @@ -559,15 +577,16 @@ func (t *UDPv4) handlePacket(from *net.UDPAddr, buf []byte) error { } // checkBond checks if the given node has a recent enough endpoint proof. -func (t *UDPv4) checkBond(id enode.ID, ip net.IP) bool { - return time.Since(t.db.LastPongReceived(id, ip)) < bondExpiration +func (t *UDPv4) checkBond(id enode.ID, ip netip.AddrPort) bool { + return time.Since(t.db.LastPongReceived(id, ip.Addr().AsSlice())) < bondExpiration } // ensureBond solicits a ping from a node if we haven't seen a ping from it for a while. // This ensures there is a valid endpoint proof on the remote end. -func (t *UDPv4) ensureBond(toid enode.ID, toaddr *net.UDPAddr) { - tooOld := time.Since(t.db.LastPingReceived(toid, toaddr.IP)) > bondExpiration - if tooOld || t.db.FindFails(toid, toaddr.IP) > maxFindnodeFailures { +func (t *UDPv4) ensureBond(toid enode.ID, toaddr netip.AddrPort) { + ip := toaddr.Addr().AsSlice() + tooOld := time.Since(t.db.LastPingReceived(toid, ip)) > bondExpiration + if tooOld || t.db.FindFails(toid, ip) > maxFindnodeFailures { rm := t.sendPing(toid, toaddr, nil) <-rm.errc // Wait for them to ping back and process our pong. @@ -575,11 +594,11 @@ func (t *UDPv4) ensureBond(toid enode.ID, toaddr *net.UDPAddr) { } } -func (t *UDPv4) nodeFromRPC(sender *net.UDPAddr, rn v4wire.Node) (*node, error) { +func (t *UDPv4) nodeFromRPC(sender netip.AddrPort, rn v4wire.Node) (*enode.Node, error) { if rn.UDP <= 1024 { return nil, errLowPort } - if err := netutil.CheckRelayIP(sender.IP, rn.IP); err != nil { + if err := netutil.CheckRelayIP(sender.Addr().AsSlice(), rn.IP); err != nil { return nil, err } if t.netrestrict != nil && !t.netrestrict.Contains(rn.IP) { @@ -589,12 +608,12 @@ func (t *UDPv4) nodeFromRPC(sender *net.UDPAddr, rn v4wire.Node) (*node, error) if err != nil { return nil, err } - n := wrapNode(enode.NewV4(key, rn.IP, int(rn.TCP), int(rn.UDP))) + n := enode.NewV4(key, rn.IP, int(rn.TCP), int(rn.UDP)) err = n.ValidateComplete() return n, err } -func nodeToRPC(n *node) v4wire.Node { +func nodeToRPC(n *enode.Node) v4wire.Node { var key ecdsa.PublicKey var ekey v4wire.Pubkey if err := n.Load((*enode.Secp256k1)(&key)); err == nil { @@ -633,14 +652,14 @@ type packetHandlerV4 struct { senderKey *ecdsa.PublicKey // used for ping // preverify checks whether the packet is valid and should be handled at all. - preverify func(p *packetHandlerV4, from *net.UDPAddr, fromID enode.ID, fromKey v4wire.Pubkey) error + preverify func(p *packetHandlerV4, from netip.AddrPort, fromID enode.ID, fromKey v4wire.Pubkey) error // handle handles the packet. - handle func(req *packetHandlerV4, from *net.UDPAddr, fromID enode.ID, mac []byte) + handle func(req *packetHandlerV4, from netip.AddrPort, fromID enode.ID, mac []byte) } // PING/v4 -func (t *UDPv4) verifyPing(h *packetHandlerV4, from *net.UDPAddr, fromID enode.ID, fromKey v4wire.Pubkey) error { +func (t *UDPv4) verifyPing(h *packetHandlerV4, from netip.AddrPort, fromID enode.ID, fromKey v4wire.Pubkey) error { req := h.Packet.(*v4wire.Ping) if v4wire.Expired(req.Expiration) { @@ -654,7 +673,7 @@ func (t *UDPv4) verifyPing(h *packetHandlerV4, from *net.UDPAddr, fromID enode.I return nil } -func (t *UDPv4) handlePing(h *packetHandlerV4, from *net.UDPAddr, fromID enode.ID, mac []byte) { +func (t *UDPv4) handlePing(h *packetHandlerV4, from netip.AddrPort, fromID enode.ID, mac []byte) { req := h.Packet.(*v4wire.Ping) // Reply. @@ -666,45 +685,51 @@ func (t *UDPv4) handlePing(h *packetHandlerV4, from *net.UDPAddr, fromID enode.I }) // Ping back if our last pong on file is too far in the past. - n := wrapNode(enode.NewV4(h.senderKey, from.IP, int(req.From.TCP), from.Port)) - if time.Since(t.db.LastPongReceived(n.ID(), from.IP)) > bondExpiration { + fromIP := from.Addr().AsSlice() + n := enode.NewV4(h.senderKey, fromIP, int(req.From.TCP), int(from.Port())) + if time.Since(t.db.LastPongReceived(n.ID(), fromIP)) > bondExpiration { t.sendPing(fromID, from, func() { - t.tab.addVerifiedNode(n) + t.tab.addInboundNode(n) }) } else { - t.tab.addVerifiedNode(n) + t.tab.addInboundNode(n) } // Update node database and endpoint predictor. - t.db.UpdateLastPingReceived(n.ID(), from.IP, time.Now()) - t.localNode.UDPEndpointStatement(from, &net.UDPAddr{IP: req.To.IP, Port: int(req.To.UDP)}) + t.db.UpdateLastPingReceived(n.ID(), fromIP, time.Now()) + fromUDPAddr := &net.UDPAddr{IP: fromIP, Port: int(from.Port())} + toUDPAddr := &net.UDPAddr{IP: req.To.IP, Port: int(req.To.UDP)} + t.localNode.UDPEndpointStatement(fromUDPAddr, toUDPAddr) } // PONG/v4 -func (t *UDPv4) verifyPong(h *packetHandlerV4, from *net.UDPAddr, fromID enode.ID, fromKey v4wire.Pubkey) error { +func (t *UDPv4) verifyPong(h *packetHandlerV4, from netip.AddrPort, fromID enode.ID, fromKey v4wire.Pubkey) error { req := h.Packet.(*v4wire.Pong) if v4wire.Expired(req.Expiration) { return errExpired } - if !t.handleReply(fromID, from.IP, req) { + if !t.handleReply(fromID, from.Addr(), req) { return errUnsolicitedReply } - t.localNode.UDPEndpointStatement(from, &net.UDPAddr{IP: req.To.IP, Port: int(req.To.UDP)}) - t.db.UpdateLastPongReceived(fromID, from.IP, time.Now()) + fromIP := from.Addr().AsSlice() + fromUDPAddr := &net.UDPAddr{IP: fromIP, Port: int(from.Port())} + toUDPAddr := &net.UDPAddr{IP: req.To.IP, Port: int(req.To.UDP)} + t.localNode.UDPEndpointStatement(fromUDPAddr, toUDPAddr) + t.db.UpdateLastPongReceived(fromID, fromIP, time.Now()) return nil } // FINDNODE/v4 -func (t *UDPv4) verifyFindnode(h *packetHandlerV4, from *net.UDPAddr, fromID enode.ID, fromKey v4wire.Pubkey) error { +func (t *UDPv4) verifyFindnode(h *packetHandlerV4, from netip.AddrPort, fromID enode.ID, fromKey v4wire.Pubkey) error { req := h.Packet.(*v4wire.Findnode) if v4wire.Expired(req.Expiration) { return errExpired } - if !t.checkBond(fromID, from.IP) { + if !t.checkBond(fromID, from) { // No endpoint proof pong exists, we don't process the packet. This prevents an // attack vector where the discovery protocol could be used to amplify traffic in a // DDOS attack. A malicious actor would send a findnode request with the IP address @@ -716,7 +741,7 @@ func (t *UDPv4) verifyFindnode(h *packetHandlerV4, from *net.UDPAddr, fromID eno return nil } -func (t *UDPv4) handleFindnode(h *packetHandlerV4, from *net.UDPAddr, fromID enode.ID, mac []byte) { +func (t *UDPv4) handleFindnode(h *packetHandlerV4, from netip.AddrPort, fromID enode.ID, mac []byte) { req := h.Packet.(*v4wire.Findnode) // Determine closest nodes. @@ -728,7 +753,8 @@ func (t *UDPv4) handleFindnode(h *packetHandlerV4, from *net.UDPAddr, fromID eno p := v4wire.Neighbors{Expiration: uint64(time.Now().Add(expiration).Unix())} var sent bool for _, n := range closest { - if netutil.CheckRelayIP(from.IP, n.IP()) == nil { + fromIP := from.Addr().AsSlice() + if netutil.CheckRelayIP(fromIP, n.IP()) == nil { p.Nodes = append(p.Nodes, nodeToRPC(n)) } if len(p.Nodes) == v4wire.MaxNeighbors { @@ -744,13 +770,13 @@ func (t *UDPv4) handleFindnode(h *packetHandlerV4, from *net.UDPAddr, fromID eno // NEIGHBORS/v4 -func (t *UDPv4) verifyNeighbors(h *packetHandlerV4, from *net.UDPAddr, fromID enode.ID, fromKey v4wire.Pubkey) error { +func (t *UDPv4) verifyNeighbors(h *packetHandlerV4, from netip.AddrPort, fromID enode.ID, fromKey v4wire.Pubkey) error { req := h.Packet.(*v4wire.Neighbors) if v4wire.Expired(req.Expiration) { return errExpired } - if !t.handleReply(fromID, from.IP, h.Packet) { + if !t.handleReply(fromID, from.Addr(), h.Packet) { return errUnsolicitedReply } return nil @@ -758,19 +784,19 @@ func (t *UDPv4) verifyNeighbors(h *packetHandlerV4, from *net.UDPAddr, fromID en // ENRREQUEST/v4 -func (t *UDPv4) verifyENRRequest(h *packetHandlerV4, from *net.UDPAddr, fromID enode.ID, fromKey v4wire.Pubkey) error { +func (t *UDPv4) verifyENRRequest(h *packetHandlerV4, from netip.AddrPort, fromID enode.ID, fromKey v4wire.Pubkey) error { req := h.Packet.(*v4wire.ENRRequest) if v4wire.Expired(req.Expiration) { return errExpired } - if !t.checkBond(fromID, from.IP) { + if !t.checkBond(fromID, from) { return errUnknownNode } return nil } -func (t *UDPv4) handleENRRequest(h *packetHandlerV4, from *net.UDPAddr, fromID enode.ID, mac []byte) { +func (t *UDPv4) handleENRRequest(h *packetHandlerV4, from netip.AddrPort, fromID enode.ID, mac []byte) { t.send(from, fromID, &v4wire.ENRResponse{ ReplyTok: mac, Record: *t.localNode.Node().Record(), @@ -779,8 +805,8 @@ func (t *UDPv4) handleENRRequest(h *packetHandlerV4, from *net.UDPAddr, fromID e // ENRRESPONSE/v4 -func (t *UDPv4) verifyENRResponse(h *packetHandlerV4, from *net.UDPAddr, fromID enode.ID, fromKey v4wire.Pubkey) error { - if !t.handleReply(fromID, from.IP, h.Packet) { +func (t *UDPv4) verifyENRResponse(h *packetHandlerV4, from netip.AddrPort, fromID enode.ID, fromKey v4wire.Pubkey) error { + if !t.handleReply(fromID, from.Addr(), h.Packet) { return errUnsolicitedReply } return nil diff --git a/p2p/discover/v4_udp_test.go b/p2p/discover/v4_udp_test.go index 361e379626..c77347429c 100644 --- a/p2p/discover/v4_udp_test.go +++ b/p2p/discover/v4_udp_test.go @@ -26,6 +26,7 @@ import ( "io" "math/rand" "net" + "net/netip" "reflect" "sync" "testing" @@ -55,7 +56,7 @@ type udpTest struct { udp *UDPv4 sent [][]byte localkey, remotekey *ecdsa.PrivateKey - remoteaddr *net.UDPAddr + remoteaddr netip.AddrPort } func newUDPTest(t *testing.T) *udpTest { @@ -64,7 +65,7 @@ func newUDPTest(t *testing.T) *udpTest { pipe: newpipe(), localkey: newkey(), remotekey: newkey(), - remoteaddr: &net.UDPAddr{IP: net.IP{10, 0, 1, 99}, Port: 30303}, + remoteaddr: netip.MustParseAddrPort("10.0.1.99:30303"), } test.db, _ = enode.OpenDB("") @@ -92,7 +93,7 @@ func (test *udpTest) packetIn(wantError error, data v4wire.Packet) { } // handles a packet as if it had been sent to the transport by the key/endpoint. -func (test *udpTest) packetInFrom(wantError error, key *ecdsa.PrivateKey, addr *net.UDPAddr, data v4wire.Packet) { +func (test *udpTest) packetInFrom(wantError error, key *ecdsa.PrivateKey, addr netip.AddrPort, data v4wire.Packet) { test.t.Helper() enc, _, err := v4wire.Encode(key, data) @@ -106,7 +107,7 @@ func (test *udpTest) packetInFrom(wantError error, key *ecdsa.PrivateKey, addr * } // waits for a packet to be sent by the transport. -// validate should have type func(X, *net.UDPAddr, []byte), where X is a packet type. +// validate should have type func(X, netip.AddrPort, []byte), where X is a packet type. func (test *udpTest) waitPacketOut(validate interface{}) (closed bool) { test.t.Helper() @@ -128,7 +129,7 @@ func (test *udpTest) waitPacketOut(validate interface{}) (closed bool) { test.t.Errorf("sent packet type mismatch, got: %v, want: %v", reflect.TypeOf(p), exptype) return false } - fn.Call([]reflect.Value{reflect.ValueOf(p), reflect.ValueOf(&dgram.to), reflect.ValueOf(hash)}) + fn.Call([]reflect.Value{reflect.ValueOf(p), reflect.ValueOf(dgram.to), reflect.ValueOf(hash)}) return false } @@ -236,7 +237,7 @@ func TestUDPv4_findnodeTimeout(t *testing.T) { test := newUDPTest(t) defer test.close() - toaddr := &net.UDPAddr{IP: net.ParseIP("1.2.3.4"), Port: 2222} + toaddr := netip.AddrPortFrom(netip.MustParseAddr("1.2.3.4"), 2222) toid := enode.ID{1, 2, 3, 4} target := v4wire.Pubkey{4, 5, 6, 7} result, err := test.udp.findnode(toid, toaddr, target) @@ -261,26 +262,25 @@ func TestUDPv4_findnode(t *testing.T) { for i := 0; i < numCandidates; i++ { key := newkey() ip := net.IP{10, 13, 0, byte(i)} - n := wrapNode(enode.NewV4(&key.PublicKey, ip, 0, 2000)) + n := enode.NewV4(&key.PublicKey, ip, 0, 2000) // Ensure half of table content isn't verified live yet. if i > numCandidates/2 { - n.livenessChecks = 1 live[n.ID()] = true } + test.table.addFoundNode(n, live[n.ID()]) nodes.push(n, numCandidates) } - fillTable(test.table, nodes.entries, false) // ensure there's a bond with the test node, // findnode won't be accepted otherwise. remoteID := v4wire.EncodePubkey(&test.remotekey.PublicKey).ID() - test.table.db.UpdateLastPongReceived(remoteID, test.remoteaddr.IP, time.Now()) + test.table.db.UpdateLastPongReceived(remoteID, test.remoteaddr.Addr().AsSlice(), time.Now()) // check that closest neighbors are returned. expected := test.table.findnodeByID(testTarget.ID(), bucketSize, true) test.packetIn(nil, &v4wire.Findnode{Target: testTarget, Expiration: futureExp}) - waitNeighbors := func(want []*node) { - test.waitPacketOut(func(p *v4wire.Neighbors, to *net.UDPAddr, hash []byte) { + waitNeighbors := func(want []*enode.Node) { + test.waitPacketOut(func(p *v4wire.Neighbors, to netip.AddrPort, hash []byte) { if len(p.Nodes) != len(want) { t.Errorf("wrong number of results: got %d, want %d", len(p.Nodes), bucketSize) return @@ -309,10 +309,10 @@ func TestUDPv4_findnodeMultiReply(t *testing.T) { defer test.close() rid := enode.PubkeyToIDV4(&test.remotekey.PublicKey) - test.table.db.UpdateLastPingReceived(rid, test.remoteaddr.IP, time.Now()) + test.table.db.UpdateLastPingReceived(rid, test.remoteaddr.Addr().AsSlice(), time.Now()) // queue a pending findnode request - resultc, errc := make(chan []*node, 1), make(chan error, 1) + resultc, errc := make(chan []*enode.Node, 1), make(chan error, 1) go func() { rid := encodePubkey(&test.remotekey.PublicKey).id() ns, err := test.udp.findnode(rid, test.remoteaddr, testTarget) @@ -325,18 +325,18 @@ func TestUDPv4_findnodeMultiReply(t *testing.T) { // wait for the findnode to be sent. // after it is sent, the transport is waiting for a reply - test.waitPacketOut(func(p *v4wire.Findnode, to *net.UDPAddr, hash []byte) { + test.waitPacketOut(func(p *v4wire.Findnode, to netip.AddrPort, hash []byte) { if p.Target != testTarget { t.Errorf("wrong target: got %v, want %v", p.Target, testTarget) } }) // send the reply as two packets. - list := []*node{ - wrapNode(enode.MustParse("enode://ba85011c70bcc5c04d8607d3a0ed29aa6179c092cbdda10d5d32684fb33ed01bd94f588ca8f91ac48318087dcb02eaf36773a7a453f0eedd6742af668097b29c@10.0.1.16:30303?discport=30304")), - wrapNode(enode.MustParse("enode://81fa361d25f157cd421c60dcc28d8dac5ef6a89476633339c5df30287474520caca09627da18543d9079b5b288698b542d56167aa5c09111e55acdbbdf2ef799@10.0.1.16:30303")), - wrapNode(enode.MustParse("enode://9bffefd833d53fac8e652415f4973bee289e8b1a5c6c4cbe70abf817ce8a64cee11b823b66a987f51aaa9fba0d6a91b3e6bf0d5a5d1042de8e9eeea057b217f8@10.0.1.36:30301?discport=17")), - wrapNode(enode.MustParse("enode://1b5b4aa662d7cb44a7221bfba67302590b643028197a7d5214790f3bac7aaa4a3241be9e83c09cf1f6c69d007c634faae3dc1b1221793e8446c0b3a09de65960@10.0.1.16:30303")), + list := []*enode.Node{ + enode.MustParse("enode://ba85011c70bcc5c04d8607d3a0ed29aa6179c092cbdda10d5d32684fb33ed01bd94f588ca8f91ac48318087dcb02eaf36773a7a453f0eedd6742af668097b29c@10.0.1.16:30303?discport=30304"), + enode.MustParse("enode://81fa361d25f157cd421c60dcc28d8dac5ef6a89476633339c5df30287474520caca09627da18543d9079b5b288698b542d56167aa5c09111e55acdbbdf2ef799@10.0.1.16:30303"), + enode.MustParse("enode://9bffefd833d53fac8e652415f4973bee289e8b1a5c6c4cbe70abf817ce8a64cee11b823b66a987f51aaa9fba0d6a91b3e6bf0d5a5d1042de8e9eeea057b217f8@10.0.1.36:30301?discport=17"), + enode.MustParse("enode://1b5b4aa662d7cb44a7221bfba67302590b643028197a7d5214790f3bac7aaa4a3241be9e83c09cf1f6c69d007c634faae3dc1b1221793e8446c0b3a09de65960@10.0.1.16:30303"), } rpclist := make([]v4wire.Node, len(list)) for i := range list { @@ -368,8 +368,8 @@ func TestUDPv4_pingMatch(t *testing.T) { crand.Read(randToken) test.packetIn(nil, &v4wire.Ping{From: testRemote, To: testLocalAnnounced, Version: 4, Expiration: futureExp}) - test.waitPacketOut(func(*v4wire.Pong, *net.UDPAddr, []byte) {}) - test.waitPacketOut(func(*v4wire.Ping, *net.UDPAddr, []byte) {}) + test.waitPacketOut(func(*v4wire.Pong, netip.AddrPort, []byte) {}) + test.waitPacketOut(func(*v4wire.Ping, netip.AddrPort, []byte) {}) test.packetIn(errUnsolicitedReply, &v4wire.Pong{ReplyTok: randToken, To: testLocalAnnounced, Expiration: futureExp}) } @@ -379,10 +379,10 @@ func TestUDPv4_pingMatchIP(t *testing.T) { defer test.close() test.packetIn(nil, &v4wire.Ping{From: testRemote, To: testLocalAnnounced, Version: 4, Expiration: futureExp}) - test.waitPacketOut(func(*v4wire.Pong, *net.UDPAddr, []byte) {}) + test.waitPacketOut(func(*v4wire.Pong, netip.AddrPort, []byte) {}) - test.waitPacketOut(func(p *v4wire.Ping, to *net.UDPAddr, hash []byte) { - wrongAddr := &net.UDPAddr{IP: net.IP{33, 44, 1, 2}, Port: 30000} + test.waitPacketOut(func(p *v4wire.Ping, to netip.AddrPort, hash []byte) { + wrongAddr := netip.MustParseAddrPort("33.44.1.2:30000") test.packetInFrom(errUnsolicitedReply, test.remotekey, wrongAddr, &v4wire.Pong{ ReplyTok: hash, To: testLocalAnnounced, @@ -393,41 +393,36 @@ func TestUDPv4_pingMatchIP(t *testing.T) { func TestUDPv4_successfulPing(t *testing.T) { test := newUDPTest(t) - added := make(chan *node, 1) - test.table.nodeAddedHook = func(b *bucket, n *node) { added <- n } + added := make(chan *tableNode, 1) + test.table.nodeAddedHook = func(b *bucket, n *tableNode) { added <- n } defer test.close() // The remote side sends a ping packet to initiate the exchange. go test.packetIn(nil, &v4wire.Ping{From: testRemote, To: testLocalAnnounced, Version: 4, Expiration: futureExp}) // The ping is replied to. - test.waitPacketOut(func(p *v4wire.Pong, to *net.UDPAddr, hash []byte) { + test.waitPacketOut(func(p *v4wire.Pong, to netip.AddrPort, hash []byte) { pinghash := test.sent[0][:32] if !bytes.Equal(p.ReplyTok, pinghash) { t.Errorf("got pong.ReplyTok %x, want %x", p.ReplyTok, pinghash) } - wantTo := v4wire.Endpoint{ - // The mirrored UDP address is the UDP packet sender - IP: test.remoteaddr.IP, UDP: uint16(test.remoteaddr.Port), - // The mirrored TCP port is the one from the ping packet - TCP: testRemote.TCP, - } + // The mirrored UDP address is the UDP packet sender. + // The mirrored TCP port is the one from the ping packet. + wantTo := v4wire.NewEndpoint(test.remoteaddr, testRemote.TCP) if !reflect.DeepEqual(p.To, wantTo) { t.Errorf("got pong.To %v, want %v", p.To, wantTo) } }) // Remote is unknown, the table pings back. - test.waitPacketOut(func(p *v4wire.Ping, to *net.UDPAddr, hash []byte) { - if !reflect.DeepEqual(p.From, test.udp.ourEndpoint()) { + test.waitPacketOut(func(p *v4wire.Ping, to netip.AddrPort, hash []byte) { + wantFrom := test.udp.ourEndpoint() + wantFrom.IP = net.IP{} + if !reflect.DeepEqual(p.From, wantFrom) { t.Errorf("got ping.From %#v, want %#v", p.From, test.udp.ourEndpoint()) } - wantTo := v4wire.Endpoint{ - // The mirrored UDP address is the UDP packet sender. - IP: test.remoteaddr.IP, - UDP: uint16(test.remoteaddr.Port), - TCP: 0, - } + // The mirrored UDP address is the UDP packet sender. + wantTo := v4wire.NewEndpoint(test.remoteaddr, 0) if !reflect.DeepEqual(p.To, wantTo) { t.Errorf("got ping.To %v, want %v", p.To, wantTo) } @@ -442,11 +437,11 @@ func TestUDPv4_successfulPing(t *testing.T) { if n.ID() != rid { t.Errorf("node has wrong ID: got %v, want %v", n.ID(), rid) } - if !n.IP().Equal(test.remoteaddr.IP) { - t.Errorf("node has wrong IP: got %v, want: %v", n.IP(), test.remoteaddr.IP) + if !n.IP().Equal(test.remoteaddr.Addr().AsSlice()) { + t.Errorf("node has wrong IP: got %v, want: %v", n.IP(), test.remoteaddr.Addr()) } - if n.UDP() != test.remoteaddr.Port { - t.Errorf("node has wrong UDP port: got %v, want: %v", n.UDP(), test.remoteaddr.Port) + if n.UDP() != int(test.remoteaddr.Port()) { + t.Errorf("node has wrong UDP port: got %v, want: %v", n.UDP(), test.remoteaddr.Port()) } if n.TCP() != int(testRemote.TCP) { t.Errorf("node has wrong TCP port: got %v, want: %v", n.TCP(), testRemote.TCP) @@ -469,12 +464,12 @@ func TestUDPv4_EIP868(t *testing.T) { // Perform endpoint proof and check for sequence number in packet tail. test.packetIn(nil, &v4wire.Ping{Expiration: futureExp}) - test.waitPacketOut(func(p *v4wire.Pong, addr *net.UDPAddr, hash []byte) { + test.waitPacketOut(func(p *v4wire.Pong, addr netip.AddrPort, hash []byte) { if p.ENRSeq != wantNode.Seq() { t.Errorf("wrong sequence number in pong: %d, want %d", p.ENRSeq, wantNode.Seq()) } }) - test.waitPacketOut(func(p *v4wire.Ping, addr *net.UDPAddr, hash []byte) { + test.waitPacketOut(func(p *v4wire.Ping, addr netip.AddrPort, hash []byte) { if p.ENRSeq != wantNode.Seq() { t.Errorf("wrong sequence number in ping: %d, want %d", p.ENRSeq, wantNode.Seq()) } @@ -483,7 +478,7 @@ func TestUDPv4_EIP868(t *testing.T) { // Request should work now. test.packetIn(nil, &v4wire.ENRRequest{Expiration: futureExp}) - test.waitPacketOut(func(p *v4wire.ENRResponse, addr *net.UDPAddr, hash []byte) { + test.waitPacketOut(func(p *v4wire.ENRResponse, addr netip.AddrPort, hash []byte) { n, err := enode.New(enode.ValidSchemes, &p.Record) if err != nil { t.Fatalf("invalid record: %v", err) @@ -584,7 +579,7 @@ type dgramPipe struct { } type dgram struct { - to net.UDPAddr + to netip.AddrPort data []byte } @@ -597,8 +592,8 @@ func newpipe() *dgramPipe { } } -// WriteToUDP queues a datagram. -func (c *dgramPipe) WriteToUDP(b []byte, to *net.UDPAddr) (n int, err error) { +// WriteToUDPAddrPort queues a datagram. +func (c *dgramPipe) WriteToUDPAddrPort(b []byte, to netip.AddrPort) (n int, err error) { msg := make([]byte, len(b)) copy(msg, b) c.mu.Lock() @@ -606,15 +601,15 @@ func (c *dgramPipe) WriteToUDP(b []byte, to *net.UDPAddr) (n int, err error) { if c.closed { return 0, errors.New("closed") } - c.queue = append(c.queue, dgram{*to, b}) + c.queue = append(c.queue, dgram{to, b}) c.cond.Signal() return len(b), nil } -// ReadFromUDP just hangs until the pipe is closed. -func (c *dgramPipe) ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error) { +// ReadFromUDPAddrPort just hangs until the pipe is closed. +func (c *dgramPipe) ReadFromUDPAddrPort(b []byte) (n int, addr netip.AddrPort, err error) { <-c.closing - return 0, nil, io.EOF + return 0, netip.AddrPort{}, io.EOF } func (c *dgramPipe) Close() error { diff --git a/p2p/discover/v4wire/v4wire.go b/p2p/discover/v4wire/v4wire.go index 9c59359fb2..958cca324d 100644 --- a/p2p/discover/v4wire/v4wire.go +++ b/p2p/discover/v4wire/v4wire.go @@ -25,6 +25,7 @@ import ( "fmt" "math/big" "net" + "net/netip" "time" "github.com/ethereum/go-ethereum/common/math" @@ -150,14 +151,15 @@ type Endpoint struct { } // NewEndpoint creates an endpoint. -func NewEndpoint(addr *net.UDPAddr, tcpPort uint16) Endpoint { - ip := net.IP{} - if ip4 := addr.IP.To4(); ip4 != nil { - ip = ip4 - } else if ip6 := addr.IP.To16(); ip6 != nil { - ip = ip6 +func NewEndpoint(addr netip.AddrPort, tcpPort uint16) Endpoint { + var ip net.IP + if addr.Addr().Is4() || addr.Addr().Is4In6() { + ip4 := addr.Addr().As4() + ip = ip4[:] + } else { + ip = addr.Addr().AsSlice() } - return Endpoint{IP: ip, UDP: uint16(addr.Port), TCP: tcpPort} + return Endpoint{IP: ip, UDP: addr.Port(), TCP: tcpPort} } type Packet interface { diff --git a/p2p/discover/v5_talk.go b/p2p/discover/v5_talk.go index c1f6787940..2246b47141 100644 --- a/p2p/discover/v5_talk.go +++ b/p2p/discover/v5_talk.go @@ -18,6 +18,7 @@ package discover import ( "net" + "net/netip" "sync" "time" @@ -70,7 +71,7 @@ func (t *talkSystem) register(protocol string, handler TalkRequestHandler) { } // handleRequest handles a talk request. -func (t *talkSystem) handleRequest(id enode.ID, addr *net.UDPAddr, req *v5wire.TalkRequest) { +func (t *talkSystem) handleRequest(id enode.ID, addr netip.AddrPort, req *v5wire.TalkRequest) { t.mutex.Lock() handler, ok := t.handlers[req.Protocol] t.mutex.Unlock() @@ -88,7 +89,8 @@ func (t *talkSystem) handleRequest(id enode.ID, addr *net.UDPAddr, req *v5wire.T case <-t.slots: go func() { defer func() { t.slots <- struct{}{} }() - respMessage := handler(id, addr, req.Message) + udpAddr := &net.UDPAddr{IP: addr.Addr().AsSlice(), Port: int(addr.Port())} + respMessage := handler(id, udpAddr, req.Message) resp := &v5wire.TalkResponse{ReqID: req.ReqID, Message: respMessage} t.transport.sendFromAnotherThread(id, addr, resp) }() diff --git a/p2p/discover/v5_udp.go b/p2p/discover/v5_udp.go index 8b3e33d37c..b816a9c17a 100644 --- a/p2p/discover/v5_udp.go +++ b/p2p/discover/v5_udp.go @@ -25,6 +25,7 @@ import ( "fmt" "io" "net" + "net/netip" "sync" "time" @@ -100,14 +101,14 @@ type UDPv5 struct { type sendRequest struct { destID enode.ID - destAddr *net.UDPAddr + destAddr netip.AddrPort msg v5wire.Packet } // callV5 represents a remote procedure call against another node. type callV5 struct { id enode.ID - addr *net.UDPAddr + addr netip.AddrPort node *enode.Node // This is required to perform handshakes. packet v5wire.Packet @@ -174,7 +175,7 @@ func newUDPv5(conn UDPConn, ln *enode.LocalNode, cfg Config) (*UDPv5, error) { cancelCloseCtx: cancelCloseCtx, } t.talk = newTalkSystem(t) - tab, err := newMeteredTable(t, t.db, cfg) + tab, err := newTable(t, t.db, cfg) if err != nil { return nil, err } @@ -232,7 +233,7 @@ func (t *UDPv5) AllNodes() []*enode.Node { for _, b := range &t.tab.buckets { for _, n := range b.entries { - nodes = append(nodes, unwrapNode(n)) + nodes = append(nodes, n.Node) } } return nodes @@ -265,7 +266,7 @@ func (t *UDPv5) TalkRequest(n *enode.Node, protocol string, request []byte) ([]b } // TalkRequestToID sends a talk request to a node and waits for a response. -func (t *UDPv5) TalkRequestToID(id enode.ID, addr *net.UDPAddr, protocol string, request []byte) ([]byte, error) { +func (t *UDPv5) TalkRequestToID(id enode.ID, addr netip.AddrPort, protocol string, request []byte) ([]byte, error) { req := &v5wire.TalkRequest{Protocol: protocol, Message: request} resp := t.callToID(id, addr, v5wire.TalkResponseMsg, req) defer t.callDone(resp) @@ -313,26 +314,26 @@ func (t *UDPv5) newRandomLookup(ctx context.Context) *lookup { } func (t *UDPv5) newLookup(ctx context.Context, target enode.ID) *lookup { - return newLookup(ctx, t.tab, target, func(n *node) ([]*node, error) { + return newLookup(ctx, t.tab, target, func(n *enode.Node) ([]*enode.Node, error) { return t.lookupWorker(n, target) }) } // lookupWorker performs FINDNODE calls against a single node during lookup. -func (t *UDPv5) lookupWorker(destNode *node, target enode.ID) ([]*node, error) { +func (t *UDPv5) lookupWorker(destNode *enode.Node, target enode.ID) ([]*enode.Node, error) { var ( dists = lookupDistances(target, destNode.ID()) nodes = nodesByDistance{target: target} err error ) var r []*enode.Node - r, err = t.findnode(unwrapNode(destNode), dists) + r, err = t.findnode(destNode, dists) if errors.Is(err, errClosed) { return nil, err } for _, n := range r { if n.ID() != t.Self().ID() { - nodes.push(wrapNode(n), findnodeResultLimit) + nodes.push(n, findnodeResultLimit) } } return nodes.entries, err @@ -426,7 +427,7 @@ func (t *UDPv5) verifyResponseNode(c *callV5, r *enr.Record, distances []uint, s if err != nil { return nil, err } - if err := netutil.CheckRelayIP(c.addr.IP, node.IP()); err != nil { + if err := netutil.CheckRelayIP(c.addr.Addr().AsSlice(), node.IP()); err != nil { return nil, err } if t.netrestrict != nil && !t.netrestrict.Contains(node.IP()) { @@ -442,7 +443,7 @@ func (t *UDPv5) verifyResponseNode(c *callV5, r *enr.Record, distances []uint, s } } if _, ok := seen[node.ID()]; ok { - return nil, fmt.Errorf("duplicate record") + return nil, errors.New("duplicate record") } seen[node.ID()] = struct{}{} return node, nil @@ -460,14 +461,14 @@ func containsUint(x uint, xs []uint) bool { // callToNode sends the given call and sets up a handler for response packets (of message // type responseType). Responses are dispatched to the call's response channel. func (t *UDPv5) callToNode(n *enode.Node, responseType byte, req v5wire.Packet) *callV5 { - addr := &net.UDPAddr{IP: n.IP(), Port: n.UDP()} + addr, _ := n.UDPEndpoint() c := &callV5{id: n.ID(), addr: addr, node: n} t.initCall(c, responseType, req) return c } // callToID is like callToNode, but for cases where the node record is not available. -func (t *UDPv5) callToID(id enode.ID, addr *net.UDPAddr, responseType byte, req v5wire.Packet) *callV5 { +func (t *UDPv5) callToID(id enode.ID, addr netip.AddrPort, responseType byte, req v5wire.Packet) *callV5 { c := &callV5{id: id, addr: addr} t.initCall(c, responseType, req) return c @@ -627,12 +628,12 @@ func (t *UDPv5) sendCall(c *callV5) { // sendResponse sends a response packet to the given node. // This doesn't trigger a handshake even if no keys are available. -func (t *UDPv5) sendResponse(toID enode.ID, toAddr *net.UDPAddr, packet v5wire.Packet) error { +func (t *UDPv5) sendResponse(toID enode.ID, toAddr netip.AddrPort, packet v5wire.Packet) error { _, err := t.send(toID, toAddr, packet, nil) return err } -func (t *UDPv5) sendFromAnotherThread(toID enode.ID, toAddr *net.UDPAddr, packet v5wire.Packet) { +func (t *UDPv5) sendFromAnotherThread(toID enode.ID, toAddr netip.AddrPort, packet v5wire.Packet) { select { case t.sendCh <- sendRequest{toID, toAddr, packet}: case <-t.closeCtx.Done(): @@ -640,7 +641,7 @@ func (t *UDPv5) sendFromAnotherThread(toID enode.ID, toAddr *net.UDPAddr, packet } // send sends a packet to the given node. -func (t *UDPv5) send(toID enode.ID, toAddr *net.UDPAddr, packet v5wire.Packet, c *v5wire.Whoareyou) (v5wire.Nonce, error) { +func (t *UDPv5) send(toID enode.ID, toAddr netip.AddrPort, packet v5wire.Packet, c *v5wire.Whoareyou) (v5wire.Nonce, error) { addr := toAddr.String() t.logcontext = append(t.logcontext[:0], "id", toID, "addr", addr) t.logcontext = packet.AppendLogInfo(t.logcontext) @@ -652,7 +653,7 @@ func (t *UDPv5) send(toID enode.ID, toAddr *net.UDPAddr, packet v5wire.Packet, c return nonce, err } - _, err = t.conn.WriteToUDP(enc, toAddr) + _, err = t.conn.WriteToUDPAddrPort(enc, toAddr) t.log.Trace(">> "+packet.Name(), t.logcontext...) return nonce, err } @@ -663,7 +664,7 @@ func (t *UDPv5) readLoop() { buf := make([]byte, maxPacketSize) for range t.readNextCh { - nbytes, from, err := t.conn.ReadFromUDP(buf) + nbytes, from, err := t.conn.ReadFromUDPAddrPort(buf) if netutil.IsTemporaryError(err) { // Ignore temporary read errors. t.log.Debug("Temporary UDP read error", "err", err) @@ -680,7 +681,7 @@ func (t *UDPv5) readLoop() { } // dispatchReadPacket sends a packet into the dispatch loop. -func (t *UDPv5) dispatchReadPacket(from *net.UDPAddr, content []byte) bool { +func (t *UDPv5) dispatchReadPacket(from netip.AddrPort, content []byte) bool { select { case t.packetInCh <- ReadPacket{content, from}: return true @@ -690,7 +691,7 @@ func (t *UDPv5) dispatchReadPacket(from *net.UDPAddr, content []byte) bool { } // handlePacket decodes and processes an incoming packet from the network. -func (t *UDPv5) handlePacket(rawpacket []byte, fromAddr *net.UDPAddr) error { +func (t *UDPv5) handlePacket(rawpacket []byte, fromAddr netip.AddrPort) error { addr := fromAddr.String() fromID, fromNode, packet, err := t.codec.Decode(rawpacket, addr) if err != nil { @@ -707,7 +708,7 @@ func (t *UDPv5) handlePacket(rawpacket []byte, fromAddr *net.UDPAddr) error { } if fromNode != nil { // Handshake succeeded, add to table. - t.tab.addSeenNode(wrapNode(fromNode)) + t.tab.addInboundNode(fromNode) } if packet.Kind() != v5wire.WhoareyouPacket { // WHOAREYOU logged separately to report errors. @@ -720,13 +721,13 @@ func (t *UDPv5) handlePacket(rawpacket []byte, fromAddr *net.UDPAddr) error { } // handleCallResponse dispatches a response packet to the call waiting for it. -func (t *UDPv5) handleCallResponse(fromID enode.ID, fromAddr *net.UDPAddr, p v5wire.Packet) bool { +func (t *UDPv5) handleCallResponse(fromID enode.ID, fromAddr netip.AddrPort, p v5wire.Packet) bool { ac := t.activeCallByNode[fromID] if ac == nil || !bytes.Equal(p.RequestID(), ac.reqid) { t.log.Debug(fmt.Sprintf("Unsolicited/late %s response", p.Name()), "id", fromID, "addr", fromAddr) return false } - if !fromAddr.IP.Equal(ac.addr.IP) || fromAddr.Port != ac.addr.Port { + if fromAddr != ac.addr { t.log.Debug(fmt.Sprintf("%s from wrong endpoint", p.Name()), "id", fromID, "addr", fromAddr) return false } @@ -751,7 +752,7 @@ func (t *UDPv5) getNode(id enode.ID) *enode.Node { } // handle processes incoming packets according to their message type. -func (t *UDPv5) handle(p v5wire.Packet, fromID enode.ID, fromAddr *net.UDPAddr) { +func (t *UDPv5) handle(p v5wire.Packet, fromID enode.ID, fromAddr netip.AddrPort) { switch p := p.(type) { case *v5wire.Unknown: t.handleUnknown(p, fromID, fromAddr) @@ -761,7 +762,9 @@ func (t *UDPv5) handle(p v5wire.Packet, fromID enode.ID, fromAddr *net.UDPAddr) t.handlePing(p, fromID, fromAddr) case *v5wire.Pong: if t.handleCallResponse(fromID, fromAddr, p) { - t.localNode.UDPEndpointStatement(fromAddr, &net.UDPAddr{IP: p.ToIP, Port: int(p.ToPort)}) + fromUDPAddr := &net.UDPAddr{IP: fromAddr.Addr().AsSlice(), Port: int(fromAddr.Port())} + toUDPAddr := &net.UDPAddr{IP: p.ToIP, Port: int(p.ToPort)} + t.localNode.UDPEndpointStatement(fromUDPAddr, toUDPAddr) } case *v5wire.Findnode: t.handleFindnode(p, fromID, fromAddr) @@ -775,7 +778,7 @@ func (t *UDPv5) handle(p v5wire.Packet, fromID enode.ID, fromAddr *net.UDPAddr) } // handleUnknown initiates a handshake by responding with WHOAREYOU. -func (t *UDPv5) handleUnknown(p *v5wire.Unknown, fromID enode.ID, fromAddr *net.UDPAddr) { +func (t *UDPv5) handleUnknown(p *v5wire.Unknown, fromID enode.ID, fromAddr netip.AddrPort) { challenge := &v5wire.Whoareyou{Nonce: p.Nonce} crand.Read(challenge.IDNonce[:]) if n := t.getNode(fromID); n != nil { @@ -791,7 +794,7 @@ var ( ) // handleWhoareyou resends the active call as a handshake packet. -func (t *UDPv5) handleWhoareyou(p *v5wire.Whoareyou, fromID enode.ID, fromAddr *net.UDPAddr) { +func (t *UDPv5) handleWhoareyou(p *v5wire.Whoareyou, fromID enode.ID, fromAddr netip.AddrPort) { c, err := t.matchWithCall(fromID, p.Nonce) if err != nil { t.log.Debug("Invalid "+p.Name(), "addr", fromAddr, "err", err) @@ -825,32 +828,35 @@ func (t *UDPv5) matchWithCall(fromID enode.ID, nonce v5wire.Nonce) (*callV5, err } // handlePing sends a PONG response. -func (t *UDPv5) handlePing(p *v5wire.Ping, fromID enode.ID, fromAddr *net.UDPAddr) { - remoteIP := fromAddr.IP - // Handle IPv4 mapped IPv6 addresses in the - // event the local node is binded to an - // ipv6 interface. - if remoteIP.To4() != nil { - remoteIP = remoteIP.To4() +func (t *UDPv5) handlePing(p *v5wire.Ping, fromID enode.ID, fromAddr netip.AddrPort) { + var remoteIP net.IP + // Handle IPv4 mapped IPv6 addresses in the event the local node is binded + // to an ipv6 interface. + if fromAddr.Addr().Is4() || fromAddr.Addr().Is4In6() { + ip4 := fromAddr.Addr().As4() + remoteIP = ip4[:] + } else { + remoteIP = fromAddr.Addr().AsSlice() } t.sendResponse(fromID, fromAddr, &v5wire.Pong{ ReqID: p.ReqID, ToIP: remoteIP, - ToPort: uint16(fromAddr.Port), + ToPort: fromAddr.Port(), ENRSeq: t.localNode.Node().Seq(), }) } // handleFindnode returns nodes to the requester. -func (t *UDPv5) handleFindnode(p *v5wire.Findnode, fromID enode.ID, fromAddr *net.UDPAddr) { - nodes := t.collectTableNodes(fromAddr.IP, p.Distances, findnodeResultLimit) +func (t *UDPv5) handleFindnode(p *v5wire.Findnode, fromID enode.ID, fromAddr netip.AddrPort) { + nodes := t.collectTableNodes(fromAddr.Addr(), p.Distances, findnodeResultLimit) for _, resp := range packNodes(p.ReqID, nodes) { t.sendResponse(fromID, fromAddr, resp) } } // collectTableNodes creates a FINDNODE result set for the given distances. -func (t *UDPv5) collectTableNodes(rip net.IP, distances []uint, limit int) []*enode.Node { +func (t *UDPv5) collectTableNodes(rip netip.Addr, distances []uint, limit int) []*enode.Node { + ripSlice := rip.AsSlice() var bn []*enode.Node var nodes []*enode.Node var processed = make(map[uint]struct{}) @@ -865,7 +871,7 @@ func (t *UDPv5) collectTableNodes(rip net.IP, distances []uint, limit int) []*en for _, n := range t.tab.appendLiveNodes(dist, bn[:0]) { // Apply some pre-checks to avoid sending invalid nodes. // Note liveness is checked by appendLiveNodes. - if netutil.CheckRelayIP(rip, n.IP()) != nil { + if netutil.CheckRelayIP(ripSlice, n.IP()) != nil { continue } nodes = append(nodes, n) diff --git a/p2p/discover/v5_udp_test.go b/p2p/discover/v5_udp_test.go index eaa969ea8b..eddacb1960 100644 --- a/p2p/discover/v5_udp_test.go +++ b/p2p/discover/v5_udp_test.go @@ -23,6 +23,7 @@ import ( "fmt" "math/rand" "net" + "net/netip" "reflect" "testing" "time" @@ -103,7 +104,7 @@ func TestUDPv5_pingHandling(t *testing.T) { defer test.close() test.packetIn(&v5wire.Ping{ReqID: []byte("foo")}) - test.waitPacketOut(func(p *v5wire.Pong, addr *net.UDPAddr, _ v5wire.Nonce) { + test.waitPacketOut(func(p *v5wire.Pong, addr netip.AddrPort, _ v5wire.Nonce) { if !bytes.Equal(p.ReqID, []byte("foo")) { t.Error("wrong request ID in response:", p.ReqID) } @@ -135,16 +136,16 @@ func TestUDPv5_unknownPacket(t *testing.T) { // Unknown packet from unknown node. test.packetIn(&v5wire.Unknown{Nonce: nonce}) - test.waitPacketOut(func(p *v5wire.Whoareyou, addr *net.UDPAddr, _ v5wire.Nonce) { + test.waitPacketOut(func(p *v5wire.Whoareyou, addr netip.AddrPort, _ v5wire.Nonce) { check(p, 0) }) // Make node known. n := test.getNode(test.remotekey, test.remoteaddr).Node() - test.table.addSeenNode(wrapNode(n)) + test.table.addFoundNode(n, false) test.packetIn(&v5wire.Unknown{Nonce: nonce}) - test.waitPacketOut(func(p *v5wire.Whoareyou, addr *net.UDPAddr, _ v5wire.Nonce) { + test.waitPacketOut(func(p *v5wire.Whoareyou, addr netip.AddrPort, _ v5wire.Nonce) { check(p, n.Seq()) }) } @@ -159,9 +160,9 @@ func TestUDPv5_findnodeHandling(t *testing.T) { nodes253 := nodesAtDistance(test.table.self().ID(), 253, 16) nodes249 := nodesAtDistance(test.table.self().ID(), 249, 4) nodes248 := nodesAtDistance(test.table.self().ID(), 248, 10) - fillTable(test.table, wrapNodes(nodes253), true) - fillTable(test.table, wrapNodes(nodes249), true) - fillTable(test.table, wrapNodes(nodes248), true) + fillTable(test.table, nodes253, true) + fillTable(test.table, nodes249, true) + fillTable(test.table, nodes248, true) // Requesting with distance zero should return the node's own record. test.packetIn(&v5wire.Findnode{ReqID: []byte{0}, Distances: []uint{0}}) @@ -199,7 +200,7 @@ func (test *udpV5Test) expectNodes(wantReqID []byte, wantTotal uint8, wantNodes } for { - test.waitPacketOut(func(p *v5wire.Nodes, addr *net.UDPAddr, _ v5wire.Nonce) { + test.waitPacketOut(func(p *v5wire.Nodes, addr netip.AddrPort, _ v5wire.Nonce) { if !bytes.Equal(p.ReqID, wantReqID) { test.t.Fatalf("wrong request ID %v in response, want %v", p.ReqID, wantReqID) } @@ -238,7 +239,7 @@ func TestUDPv5_pingCall(t *testing.T) { _, err := test.udp.ping(remote) done <- err }() - test.waitPacketOut(func(p *v5wire.Ping, addr *net.UDPAddr, _ v5wire.Nonce) {}) + test.waitPacketOut(func(p *v5wire.Ping, addr netip.AddrPort, _ v5wire.Nonce) {}) if err := <-done; err != errTimeout { t.Fatalf("want errTimeout, got %q", err) } @@ -248,7 +249,7 @@ func TestUDPv5_pingCall(t *testing.T) { _, err := test.udp.ping(remote) done <- err }() - test.waitPacketOut(func(p *v5wire.Ping, addr *net.UDPAddr, _ v5wire.Nonce) { + test.waitPacketOut(func(p *v5wire.Ping, addr netip.AddrPort, _ v5wire.Nonce) { test.packetInFrom(test.remotekey, test.remoteaddr, &v5wire.Pong{ReqID: p.ReqID}) }) if err := <-done; err != nil { @@ -260,8 +261,8 @@ func TestUDPv5_pingCall(t *testing.T) { _, err := test.udp.ping(remote) done <- err }() - test.waitPacketOut(func(p *v5wire.Ping, addr *net.UDPAddr, _ v5wire.Nonce) { - wrongAddr := &net.UDPAddr{IP: net.IP{33, 44, 55, 22}, Port: 10101} + test.waitPacketOut(func(p *v5wire.Ping, addr netip.AddrPort, _ v5wire.Nonce) { + wrongAddr := netip.MustParseAddrPort("33.44.55.22:10101") test.packetInFrom(test.remotekey, wrongAddr, &v5wire.Pong{ReqID: p.ReqID}) }) if err := <-done; err != errTimeout { @@ -291,7 +292,7 @@ func TestUDPv5_findnodeCall(t *testing.T) { }() // Serve the responses: - test.waitPacketOut(func(p *v5wire.Findnode, addr *net.UDPAddr, _ v5wire.Nonce) { + test.waitPacketOut(func(p *v5wire.Findnode, addr netip.AddrPort, _ v5wire.Nonce) { if !reflect.DeepEqual(p.Distances, distances) { t.Fatalf("wrong distances in request: %v", p.Distances) } @@ -337,15 +338,15 @@ func TestUDPv5_callResend(t *testing.T) { }() // Ping answered by WHOAREYOU. - test.waitPacketOut(func(p *v5wire.Ping, addr *net.UDPAddr, nonce v5wire.Nonce) { + test.waitPacketOut(func(p *v5wire.Ping, addr netip.AddrPort, nonce v5wire.Nonce) { test.packetIn(&v5wire.Whoareyou{Nonce: nonce}) }) // Ping should be re-sent. - test.waitPacketOut(func(p *v5wire.Ping, addr *net.UDPAddr, _ v5wire.Nonce) { + test.waitPacketOut(func(p *v5wire.Ping, addr netip.AddrPort, _ v5wire.Nonce) { test.packetIn(&v5wire.Pong{ReqID: p.ReqID}) }) // Answer the other ping. - test.waitPacketOut(func(p *v5wire.Ping, addr *net.UDPAddr, _ v5wire.Nonce) { + test.waitPacketOut(func(p *v5wire.Ping, addr netip.AddrPort, _ v5wire.Nonce) { test.packetIn(&v5wire.Pong{ReqID: p.ReqID}) }) if err := <-done; err != nil { @@ -370,11 +371,11 @@ func TestUDPv5_multipleHandshakeRounds(t *testing.T) { }() // Ping answered by WHOAREYOU. - test.waitPacketOut(func(p *v5wire.Ping, addr *net.UDPAddr, nonce v5wire.Nonce) { + test.waitPacketOut(func(p *v5wire.Ping, addr netip.AddrPort, nonce v5wire.Nonce) { test.packetIn(&v5wire.Whoareyou{Nonce: nonce}) }) // Ping answered by WHOAREYOU again. - test.waitPacketOut(func(p *v5wire.Ping, addr *net.UDPAddr, nonce v5wire.Nonce) { + test.waitPacketOut(func(p *v5wire.Ping, addr netip.AddrPort, nonce v5wire.Nonce) { test.packetIn(&v5wire.Whoareyou{Nonce: nonce}) }) if err := <-done; err != errTimeout { @@ -401,7 +402,7 @@ func TestUDPv5_callTimeoutReset(t *testing.T) { }() // Serve two responses, slowly. - test.waitPacketOut(func(p *v5wire.Findnode, addr *net.UDPAddr, _ v5wire.Nonce) { + test.waitPacketOut(func(p *v5wire.Findnode, addr netip.AddrPort, _ v5wire.Nonce) { time.Sleep(respTimeout - 50*time.Millisecond) test.packetIn(&v5wire.Nodes{ ReqID: p.ReqID, @@ -439,7 +440,7 @@ func TestUDPv5_talkHandling(t *testing.T) { Protocol: "test", Message: []byte("test request"), }) - test.waitPacketOut(func(p *v5wire.TalkResponse, addr *net.UDPAddr, _ v5wire.Nonce) { + test.waitPacketOut(func(p *v5wire.TalkResponse, addr netip.AddrPort, _ v5wire.Nonce) { if !bytes.Equal(p.ReqID, []byte("foo")) { t.Error("wrong request ID in response:", p.ReqID) } @@ -458,7 +459,7 @@ func TestUDPv5_talkHandling(t *testing.T) { Protocol: "wrong", Message: []byte("test request"), }) - test.waitPacketOut(func(p *v5wire.TalkResponse, addr *net.UDPAddr, _ v5wire.Nonce) { + test.waitPacketOut(func(p *v5wire.TalkResponse, addr netip.AddrPort, _ v5wire.Nonce) { if !bytes.Equal(p.ReqID, []byte("2")) { t.Error("wrong request ID in response:", p.ReqID) } @@ -485,7 +486,7 @@ func TestUDPv5_talkRequest(t *testing.T) { _, err := test.udp.TalkRequest(remote, "test", []byte("test request")) done <- err }() - test.waitPacketOut(func(p *v5wire.TalkRequest, addr *net.UDPAddr, _ v5wire.Nonce) {}) + test.waitPacketOut(func(p *v5wire.TalkRequest, addr netip.AddrPort, _ v5wire.Nonce) {}) if err := <-done; err != errTimeout { t.Fatalf("want errTimeout, got %q", err) } @@ -495,7 +496,7 @@ func TestUDPv5_talkRequest(t *testing.T) { _, err := test.udp.TalkRequest(remote, "test", []byte("test request")) done <- err }() - test.waitPacketOut(func(p *v5wire.TalkRequest, addr *net.UDPAddr, _ v5wire.Nonce) { + test.waitPacketOut(func(p *v5wire.TalkRequest, addr netip.AddrPort, _ v5wire.Nonce) { if p.Protocol != "test" { t.Errorf("wrong protocol ID in talk request: %q", p.Protocol) } @@ -516,7 +517,7 @@ func TestUDPv5_talkRequest(t *testing.T) { _, err := test.udp.TalkRequestToID(remote.ID(), test.remoteaddr, "test", []byte("test request 2")) done <- err }() - test.waitPacketOut(func(p *v5wire.TalkRequest, addr *net.UDPAddr, _ v5wire.Nonce) { + test.waitPacketOut(func(p *v5wire.TalkRequest, addr netip.AddrPort, _ v5wire.Nonce) { if p.Protocol != "test" { t.Errorf("wrong protocol ID in talk request: %q", p.Protocol) } @@ -583,13 +584,14 @@ func TestUDPv5_lookup(t *testing.T) { for d, nn := range lookupTestnet.dists { for i, key := range nn { n := lookupTestnet.node(d, i) - test.getNode(key, &net.UDPAddr{IP: n.IP(), Port: n.UDP()}) + addr, _ := n.UDPEndpoint() + test.getNode(key, addr) } } // Seed table with initial node. initialNode := lookupTestnet.node(256, 0) - fillTable(test.table, []*node{wrapNode(initialNode)}, true) + fillTable(test.table, []*enode.Node{initialNode}, true) // Start the lookup. resultC := make(chan []*enode.Node, 1) @@ -601,7 +603,7 @@ func TestUDPv5_lookup(t *testing.T) { // Answer lookup packets. asked := make(map[enode.ID]bool) for done := false; !done; { - done = test.waitPacketOut(func(p v5wire.Packet, to *net.UDPAddr, _ v5wire.Nonce) { + done = test.waitPacketOut(func(p v5wire.Packet, to netip.AddrPort, _ v5wire.Nonce) { recipient, key := lookupTestnet.nodeByAddr(to) switch p := p.(type) { case *v5wire.Ping: @@ -652,11 +654,8 @@ func TestUDPv5_PingWithIPV4MappedAddress(t *testing.T) { test := newUDPV5Test(t) defer test.close() - rawIP := net.IPv4(0xFF, 0x12, 0x33, 0xE5) - test.remoteaddr = &net.UDPAddr{ - IP: rawIP.To16(), - Port: 0, - } + rawIP := netip.AddrFrom4([4]byte{0xFF, 0x12, 0x33, 0xE5}) + test.remoteaddr = netip.AddrPortFrom(netip.AddrFrom16(rawIP.As16()), 0) remote := test.getNode(test.remotekey, test.remoteaddr).Node() done := make(chan struct{}, 1) @@ -665,14 +664,14 @@ func TestUDPv5_PingWithIPV4MappedAddress(t *testing.T) { test.udp.handlePing(&v5wire.Ping{ENRSeq: 1}, remote.ID(), test.remoteaddr) done <- struct{}{} }() - test.waitPacketOut(func(p *v5wire.Pong, addr *net.UDPAddr, _ v5wire.Nonce) { + test.waitPacketOut(func(p *v5wire.Pong, addr netip.AddrPort, _ v5wire.Nonce) { if len(p.ToIP) == net.IPv6len { t.Error("Received untruncated ip address") } if len(p.ToIP) != net.IPv4len { t.Errorf("Received ip address with incorrect length: %d", len(p.ToIP)) } - if !p.ToIP.Equal(rawIP) { + if !p.ToIP.Equal(rawIP.AsSlice()) { t.Errorf("Received incorrect ip address: wanted %s but received %s", rawIP.String(), p.ToIP.String()) } }) @@ -688,9 +687,9 @@ type udpV5Test struct { db *enode.DB udp *UDPv5 localkey, remotekey *ecdsa.PrivateKey - remoteaddr *net.UDPAddr + remoteaddr netip.AddrPort nodesByID map[enode.ID]*enode.LocalNode - nodesByIP map[string]*enode.LocalNode + nodesByIP map[netip.Addr]*enode.LocalNode } // testCodec is the packet encoding used by protocol tests. This codec does not perform encryption. @@ -750,9 +749,9 @@ func newUDPV5Test(t *testing.T) *udpV5Test { pipe: newpipe(), localkey: newkey(), remotekey: newkey(), - remoteaddr: &net.UDPAddr{IP: net.IP{10, 0, 1, 99}, Port: 30303}, + remoteaddr: netip.MustParseAddrPort("10.0.1.99:30303"), nodesByID: make(map[enode.ID]*enode.LocalNode), - nodesByIP: make(map[string]*enode.LocalNode), + nodesByIP: make(map[netip.Addr]*enode.LocalNode), } test.db, _ = enode.OpenDB("") ln := enode.NewLocalNode(test.db, test.localkey) @@ -777,8 +776,8 @@ func (test *udpV5Test) packetIn(packet v5wire.Packet) { test.packetInFrom(test.remotekey, test.remoteaddr, packet) } -// handles a packet as if it had been sent to the transport by the key/endpoint. -func (test *udpV5Test) packetInFrom(key *ecdsa.PrivateKey, addr *net.UDPAddr, packet v5wire.Packet) { +// packetInFrom handles a packet as if it had been sent to the transport by the key/endpoint. +func (test *udpV5Test) packetInFrom(key *ecdsa.PrivateKey, addr netip.AddrPort, packet v5wire.Packet) { test.t.Helper() ln := test.getNode(key, addr) @@ -793,22 +792,22 @@ func (test *udpV5Test) packetInFrom(key *ecdsa.PrivateKey, addr *net.UDPAddr, pa } // getNode ensures the test knows about a node at the given endpoint. -func (test *udpV5Test) getNode(key *ecdsa.PrivateKey, addr *net.UDPAddr) *enode.LocalNode { +func (test *udpV5Test) getNode(key *ecdsa.PrivateKey, addr netip.AddrPort) *enode.LocalNode { id := encodePubkey(&key.PublicKey).id() ln := test.nodesByID[id] if ln == nil { db, _ := enode.OpenDB("") ln = enode.NewLocalNode(db, key) - ln.SetStaticIP(addr.IP) - ln.Set(enr.UDP(addr.Port)) + ln.SetStaticIP(addr.Addr().AsSlice()) + ln.Set(enr.UDP(addr.Port())) test.nodesByID[id] = ln } - test.nodesByIP[string(addr.IP)] = ln + test.nodesByIP[addr.Addr()] = ln return ln } // waitPacketOut waits for the next output packet and handles it using the given 'validate' -// function. The function must be of type func (X, *net.UDPAddr, v5wire.Nonce) where X is +// function. The function must be of type func (X, netip.AddrPort, v5wire.Nonce) where X is // assignable to packetV5. func (test *udpV5Test) waitPacketOut(validate interface{}) (closed bool) { test.t.Helper() @@ -824,7 +823,7 @@ func (test *udpV5Test) waitPacketOut(validate interface{}) (closed bool) { test.t.Fatalf("timed out waiting for %v", exptype) return false } - ln := test.nodesByIP[string(dgram.to.IP)] + ln := test.nodesByIP[dgram.to.Addr()] if ln == nil { test.t.Fatalf("attempt to send to non-existing node %v", &dgram.to) return false @@ -839,7 +838,7 @@ func (test *udpV5Test) waitPacketOut(validate interface{}) (closed bool) { test.t.Errorf("sent packet type mismatch, got: %v, want: %v", reflect.TypeOf(p), exptype) return false } - fn.Call([]reflect.Value{reflect.ValueOf(p), reflect.ValueOf(&dgram.to), reflect.ValueOf(frame.AuthTag)}) + fn.Call([]reflect.Value{reflect.ValueOf(p), reflect.ValueOf(dgram.to), reflect.ValueOf(frame.AuthTag)}) return false } diff --git a/p2p/discover/v5wire/encoding.go b/p2p/discover/v5wire/encoding.go index 5108910620..904a3ddec6 100644 --- a/p2p/discover/v5wire/encoding.go +++ b/p2p/discover/v5wire/encoding.go @@ -367,11 +367,11 @@ func (c *Codec) makeHandshakeAuth(toID enode.ID, addr string, challenge *Whoarey // key is part of the ID nonce signature. var remotePubkey = new(ecdsa.PublicKey) if err := challenge.Node.Load((*enode.Secp256k1)(remotePubkey)); err != nil { - return nil, nil, fmt.Errorf("can't find secp256k1 key for recipient") + return nil, nil, errors.New("can't find secp256k1 key for recipient") } ephkey, err := c.sc.ephemeralKeyGen() if err != nil { - return nil, nil, fmt.Errorf("can't generate ephemeral key") + return nil, nil, errors.New("can't generate ephemeral key") } ephpubkey := EncodePubkey(&ephkey.PublicKey) auth.pubkey = ephpubkey[:] @@ -395,7 +395,7 @@ func (c *Codec) makeHandshakeAuth(toID enode.ID, addr string, challenge *Whoarey // Create session keys. sec := deriveKeys(sha256.New, ephkey, remotePubkey, c.localnode.ID(), challenge.Node.ID(), cdata) if sec == nil { - return nil, nil, fmt.Errorf("key derivation failed") + return nil, nil, errors.New("key derivation failed") } return auth, sec, err } diff --git a/p2p/discover/v5wire/encoding_test.go b/p2p/discover/v5wire/encoding_test.go index a5387311a5..27966f2afc 100644 --- a/p2p/discover/v5wire/encoding_test.go +++ b/p2p/discover/v5wire/encoding_test.go @@ -30,6 +30,7 @@ import ( "testing" "github.com/davecgh/go-spew/spew" + "github.com/ethereum/go-ethereum/common/hexutil" "github.com/ethereum/go-ethereum/common/mclock" "github.com/ethereum/go-ethereum/crypto" @@ -283,9 +284,38 @@ func TestDecodeErrorsV5(t *testing.T) { b = make([]byte, 63) net.nodeA.expectDecodeErr(t, errInvalidHeader, b) - // TODO some more tests would be nice :) - // - check invalid authdata sizes - // - check invalid handshake data sizes + t.Run("invalid-handshake-datasize", func(t *testing.T) { + requiredNumber := 108 + + testDataFile := filepath.Join("testdata", "v5.1-ping-handshake"+".txt") + enc := hexFile(testDataFile) + //delete some byte from handshake to make it invalid + enc = enc[:len(enc)-requiredNumber] + net.nodeB.expectDecodeErr(t, errMsgTooShort, enc) + }) + + t.Run("invalid-auth-datasize", func(t *testing.T) { + testPacket := []byte{} + testDataFiles := []string{"v5.1-whoareyou", "v5.1-ping-handshake"} + for counter, name := range testDataFiles { + file := filepath.Join("testdata", name+".txt") + enc := hexFile(file) + if counter == 0 { + //make whoareyou header + testPacket = enc[:sizeofStaticPacketData-1] + testPacket = append(testPacket, 255) + } + if counter == 1 { + //append invalid auth size + testPacket = append(testPacket, enc[sizeofStaticPacketData:]...) + } + } + + wantErr := "invalid auth size" + if _, err := net.nodeB.decode(testPacket); strings.HasSuffix(err.Error(), wantErr) { + t.Fatal(fmt.Errorf("(%s) got err %q, want %q", net.nodeB.ln.ID().TerminalString(), err, wantErr)) + } + }) } // This test checks that all test vectors can be decoded. diff --git a/p2p/dnsdisc/client.go b/p2p/dnsdisc/client.go index 8f1c221b80..4f14d860e1 100644 --- a/p2p/dnsdisc/client.go +++ b/p2p/dnsdisc/client.go @@ -191,7 +191,7 @@ func (c *Client) resolveEntry(ctx context.Context, domain, hash string) (entry, func (c *Client) doResolveEntry(ctx context.Context, domain, hash string) (entry, error) { wantHash, err := b32format.DecodeString(hash) if err != nil { - return nil, fmt.Errorf("invalid base32 hash") + return nil, errors.New("invalid base32 hash") } name := hash + "." + domain txts, err := c.cfg.Resolver.LookupTXT(ctx, hash+"."+domain) diff --git a/p2p/dnsdisc/client_test.go b/p2p/dnsdisc/client_test.go index abc35ddbd3..01912e1eab 100644 --- a/p2p/dnsdisc/client_test.go +++ b/p2p/dnsdisc/client_test.go @@ -20,6 +20,7 @@ import ( "context" "crypto/ecdsa" "errors" + "maps" "reflect" "testing" "time" @@ -214,7 +215,7 @@ func TestIteratorNodeUpdates(t *testing.T) { // Ensure RandomNode returns the new nodes after the tree is updated. updateSomeNodes(keys, nodes) tree2, _ := makeTestTree("n", nodes, nil) - resolver.clear() + clear(resolver) resolver.add(tree2.ToTXT("n")) t.Log("tree updated") @@ -255,7 +256,7 @@ func TestIteratorRootRecheckOnFail(t *testing.T) { // Ensure RandomNode returns the new nodes after the tree is updated. updateSomeNodes(keys, nodes) tree2, _ := makeTestTree("n", nodes, nil) - resolver.clear() + clear(resolver) resolver.add(tree2.ToTXT("n")) t.Log("tree updated") @@ -446,16 +447,8 @@ func newMapResolver(maps ...map[string]string) mapResolver { return mr } -func (mr mapResolver) clear() { - for k := range mr { - delete(mr, k) - } -} - func (mr mapResolver) add(m map[string]string) { - for k, v := range m { - mr[k] = v - } + maps.Copy(mr, m) } func (mr mapResolver) LookupTXT(ctx context.Context, name string) ([]string, error) { diff --git a/p2p/dnsdisc/tree.go b/p2p/dnsdisc/tree.go index 7d9703a345..dfac4fb372 100644 --- a/p2p/dnsdisc/tree.go +++ b/p2p/dnsdisc/tree.go @@ -21,6 +21,7 @@ import ( "crypto/ecdsa" "encoding/base32" "encoding/base64" + "errors" "fmt" "io" "strings" @@ -341,7 +342,7 @@ func parseLinkEntry(e string) (entry, error) { func parseLink(e string) (*linkEntry, error) { if !strings.HasPrefix(e, linkPrefix) { - return nil, fmt.Errorf("wrong/missing scheme 'enrtree' in URL") + return nil, errors.New("wrong/missing scheme 'enrtree' in URL") } e = e[len(linkPrefix):] diff --git a/p2p/enode/idscheme.go b/p2p/enode/idscheme.go index fd5d868b76..db7841c047 100644 --- a/p2p/enode/idscheme.go +++ b/p2p/enode/idscheme.go @@ -18,7 +18,7 @@ package enode import ( "crypto/ecdsa" - "fmt" + "errors" "io" "github.com/ethereum/go-ethereum/common/math" @@ -67,7 +67,7 @@ func (V4ID) Verify(r *enr.Record, sig []byte) error { if err := r.Load(&entry); err != nil { return err } else if len(entry) != 33 { - return fmt.Errorf("invalid public key") + return errors.New("invalid public key") } h := sha3.NewLegacyKeccak256() @@ -157,5 +157,5 @@ func SignNull(r *enr.Record, id ID) *Node { if err := r.SetSig(NullID{}, []byte{}); err != nil { panic(err) } - return &Node{r: *r, id: id} + return newNodeWithID(r, id) } diff --git a/p2p/enode/node.go b/p2p/enode/node.go index d7a1a9a156..cb4ac8d172 100644 --- a/p2p/enode/node.go +++ b/p2p/enode/node.go @@ -24,6 +24,7 @@ import ( "fmt" "math/bits" "net" + "net/netip" "strings" "github.com/ethereum/go-ethereum/p2p/enr" @@ -36,6 +37,10 @@ var errMissingPrefix = errors.New("missing 'enr:' prefix for base64-encoded reco type Node struct { r enr.Record id ID + // endpoint information + ip netip.Addr + udp uint16 + tcp uint16 } // New wraps a node record. The record must be valid according to the given @@ -44,11 +49,76 @@ func New(validSchemes enr.IdentityScheme, r *enr.Record) (*Node, error) { if err := r.VerifySignature(validSchemes); err != nil { return nil, err } - node := &Node{r: *r} - if n := copy(node.id[:], validSchemes.NodeAddr(&node.r)); n != len(ID{}) { - return nil, fmt.Errorf("invalid node ID length %d, need %d", n, len(ID{})) + var id ID + if n := copy(id[:], validSchemes.NodeAddr(r)); n != len(id) { + return nil, fmt.Errorf("invalid node ID length %d, need %d", n, len(id)) + } + return newNodeWithID(r, id), nil +} + +func newNodeWithID(r *enr.Record, id ID) *Node { + n := &Node{r: *r, id: id} + // Set the preferred endpoint. + // Here we decide between IPv4 and IPv6, choosing the 'most global' address. + var ip4 netip.Addr + var ip6 netip.Addr + n.Load((*enr.IPv4Addr)(&ip4)) + n.Load((*enr.IPv6Addr)(&ip6)) + valid4 := validIP(ip4) + valid6 := validIP(ip6) + switch { + case valid4 && valid6: + if localityScore(ip4) >= localityScore(ip6) { + n.setIP4(ip4) + } else { + n.setIP6(ip6) + } + case valid4: + n.setIP4(ip4) + case valid6: + n.setIP6(ip6) + } + return n +} + +// validIP reports whether 'ip' is a valid node endpoint IP address. +func validIP(ip netip.Addr) bool { + return ip.IsValid() && !ip.IsMulticast() +} + +func localityScore(ip netip.Addr) int { + switch { + case ip.IsUnspecified(): + return 0 + case ip.IsLoopback(): + return 1 + case ip.IsLinkLocalUnicast(): + return 2 + case ip.IsPrivate(): + return 3 + default: + return 4 + } +} + +func (n *Node) setIP4(ip netip.Addr) { + n.ip = ip + n.Load((*enr.UDP)(&n.udp)) + n.Load((*enr.TCP)(&n.tcp)) +} + +func (n *Node) setIP6(ip netip.Addr) { + if ip.Is4In6() { + n.setIP4(ip) + return + } + n.ip = ip + if err := n.Load((*enr.UDP6)(&n.udp)); err != nil { + n.Load((*enr.UDP)(&n.udp)) + } + if err := n.Load((*enr.TCP6)(&n.tcp)); err != nil { + n.Load((*enr.TCP)(&n.tcp)) } - return node, nil } // MustParse parses a node record or enode:// URL. It panics if the input is invalid. @@ -89,43 +159,45 @@ func (n *Node) Seq() uint64 { return n.r.Seq() } -// Incomplete returns true for nodes with no IP address. -func (n *Node) Incomplete() bool { - return n.IP() == nil -} - // Load retrieves an entry from the underlying record. func (n *Node) Load(k enr.Entry) error { return n.r.Load(k) } -// IP returns the IP address of the node. This prefers IPv4 addresses. +// IP returns the IP address of the node. func (n *Node) IP() net.IP { - var ( - ip4 enr.IPv4 - ip6 enr.IPv6 - ) - if n.Load(&ip4) == nil { - return net.IP(ip4) - } - if n.Load(&ip6) == nil { - return net.IP(ip6) - } - return nil + return net.IP(n.ip.AsSlice()) +} + +// IPAddr returns the IP address of the node. +func (n *Node) IPAddr() netip.Addr { + return n.ip } // UDP returns the UDP port of the node. func (n *Node) UDP() int { - var port enr.UDP - n.Load(&port) - return int(port) + return int(n.udp) } // TCP returns the TCP port of the node. func (n *Node) TCP() int { - var port enr.TCP - n.Load(&port) - return int(port) + return int(n.tcp) +} + +// UDPEndpoint returns the announced UDP endpoint. +func (n *Node) UDPEndpoint() (netip.AddrPort, bool) { + if !n.ip.IsValid() || n.ip.IsUnspecified() || n.udp == 0 { + return netip.AddrPort{}, false + } + return netip.AddrPortFrom(n.ip, n.udp), true +} + +// TCPEndpoint returns the announced TCP endpoint. +func (n *Node) TCPEndpoint() (netip.AddrPort, bool) { + if !n.ip.IsValid() || n.ip.IsUnspecified() || n.tcp == 0 { + return netip.AddrPort{}, false + } + return netip.AddrPortFrom(n.ip, n.tcp), true } // Pubkey returns the secp256k1 public key of the node, if present. @@ -147,16 +219,15 @@ func (n *Node) Record() *enr.Record { // ValidateComplete checks whether n has a valid IP and UDP port. // Deprecated: don't use this method. func (n *Node) ValidateComplete() error { - if n.Incomplete() { + if !n.ip.IsValid() { return errors.New("missing IP address") } - if n.UDP() == 0 { - return errors.New("missing UDP port") - } - ip := n.IP() - if ip.IsMulticast() || ip.IsUnspecified() { + if n.ip.IsMulticast() || n.ip.IsUnspecified() { return errors.New("invalid IP (multicast/unspecified)") } + if n.udp == 0 { + return errors.New("missing UDP port") + } // Validate the node key (on curve, etc.). var key Secp256k1 return n.Load(&key) diff --git a/p2p/enode/node_test.go b/p2p/enode/node_test.go index d15859c477..56e196e82e 100644 --- a/p2p/enode/node_test.go +++ b/p2p/enode/node_test.go @@ -21,6 +21,7 @@ import ( "encoding/hex" "fmt" "math/big" + "net/netip" "testing" "testing/quick" @@ -64,6 +65,167 @@ func TestPythonInterop(t *testing.T) { } } +func TestNodeEndpoints(t *testing.T) { + id := HexID("00000000000000806ad9b61fa5ae014307ebdc964253adcd9f2c0a392aa11abc") + type endpointTest struct { + name string + node *Node + wantIP netip.Addr + wantUDP int + wantTCP int + } + tests := []endpointTest{ + { + name: "no-addr", + node: func() *Node { + var r enr.Record + return SignNull(&r, id) + }(), + }, + { + name: "udp-only", + node: func() *Node { + var r enr.Record + r.Set(enr.UDP(9000)) + return SignNull(&r, id) + }(), + }, + { + name: "tcp-only", + node: func() *Node { + var r enr.Record + r.Set(enr.TCP(9000)) + return SignNull(&r, id) + }(), + }, + { + name: "ipv4-only-loopback", + node: func() *Node { + var r enr.Record + r.Set(enr.IPv4Addr(netip.MustParseAddr("127.0.0.1"))) + return SignNull(&r, id) + }(), + wantIP: netip.MustParseAddr("127.0.0.1"), + }, + { + name: "ipv4-only-unspecified", + node: func() *Node { + var r enr.Record + r.Set(enr.IPv4Addr(netip.MustParseAddr("0.0.0.0"))) + return SignNull(&r, id) + }(), + wantIP: netip.MustParseAddr("0.0.0.0"), + }, + { + name: "ipv4-only", + node: func() *Node { + var r enr.Record + r.Set(enr.IPv4Addr(netip.MustParseAddr("99.22.33.1"))) + return SignNull(&r, id) + }(), + wantIP: netip.MustParseAddr("99.22.33.1"), + }, + { + name: "ipv6-only", + node: func() *Node { + var r enr.Record + r.Set(enr.IPv6Addr(netip.MustParseAddr("2001::ff00:0042:8329"))) + return SignNull(&r, id) + }(), + wantIP: netip.MustParseAddr("2001::ff00:0042:8329"), + }, + { + name: "ipv4-loopback-and-ipv6-global", + node: func() *Node { + var r enr.Record + r.Set(enr.IPv4Addr(netip.MustParseAddr("127.0.0.1"))) + r.Set(enr.UDP(30304)) + r.Set(enr.IPv6Addr(netip.MustParseAddr("2001::ff00:0042:8329"))) + r.Set(enr.UDP6(30306)) + return SignNull(&r, id) + }(), + wantIP: netip.MustParseAddr("2001::ff00:0042:8329"), + wantUDP: 30306, + }, + { + name: "ipv4-unspecified-and-ipv6-loopback", + node: func() *Node { + var r enr.Record + r.Set(enr.IPv4Addr(netip.MustParseAddr("0.0.0.0"))) + r.Set(enr.IPv6Addr(netip.MustParseAddr("::1"))) + return SignNull(&r, id) + }(), + wantIP: netip.MustParseAddr("::1"), + }, + { + name: "ipv4-private-and-ipv6-global", + node: func() *Node { + var r enr.Record + r.Set(enr.IPv4Addr(netip.MustParseAddr("192.168.2.2"))) + r.Set(enr.UDP(30304)) + r.Set(enr.IPv6Addr(netip.MustParseAddr("2001::ff00:0042:8329"))) + r.Set(enr.UDP6(30306)) + return SignNull(&r, id) + }(), + wantIP: netip.MustParseAddr("2001::ff00:0042:8329"), + wantUDP: 30306, + }, + { + name: "ipv4-local-and-ipv6-global", + node: func() *Node { + var r enr.Record + r.Set(enr.IPv4Addr(netip.MustParseAddr("169.254.2.6"))) + r.Set(enr.UDP(30304)) + r.Set(enr.IPv6Addr(netip.MustParseAddr("2001::ff00:0042:8329"))) + r.Set(enr.UDP6(30306)) + return SignNull(&r, id) + }(), + wantIP: netip.MustParseAddr("2001::ff00:0042:8329"), + wantUDP: 30306, + }, + { + name: "ipv4-private-and-ipv6-private", + node: func() *Node { + var r enr.Record + r.Set(enr.IPv4Addr(netip.MustParseAddr("192.168.2.2"))) + r.Set(enr.UDP(30304)) + r.Set(enr.IPv6Addr(netip.MustParseAddr("fd00::abcd:1"))) + r.Set(enr.UDP6(30306)) + return SignNull(&r, id) + }(), + wantIP: netip.MustParseAddr("192.168.2.2"), + wantUDP: 30304, + }, + { + name: "ipv4-private-and-ipv6-link-local", + node: func() *Node { + var r enr.Record + r.Set(enr.IPv4Addr(netip.MustParseAddr("192.168.2.2"))) + r.Set(enr.UDP(30304)) + r.Set(enr.IPv6Addr(netip.MustParseAddr("fe80::1"))) + r.Set(enr.UDP6(30306)) + return SignNull(&r, id) + }(), + wantIP: netip.MustParseAddr("192.168.2.2"), + wantUDP: 30304, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.wantIP != test.node.IPAddr() { + t.Errorf("node has wrong IP %v, want %v", test.node.IPAddr(), test.wantIP) + } + if test.wantUDP != test.node.UDP() { + t.Errorf("node has wrong UDP port %d, want %d", test.node.UDP(), test.wantUDP) + } + if test.wantTCP != test.node.TCP() { + t.Errorf("node has wrong TCP port %d, want %d", test.node.TCP(), test.wantTCP) + } + }) + } +} + func TestHexID(t *testing.T) { ref := ID{0, 0, 0, 0, 0, 0, 0, 128, 106, 217, 182, 31, 165, 174, 1, 67, 7, 235, 220, 150, 66, 83, 173, 205, 159, 44, 10, 57, 42, 161, 26, 188} id1 := HexID("0x00000000000000806ad9b61fa5ae014307ebdc964253adcd9f2c0a392aa11abc") diff --git a/p2p/enode/nodedb.go b/p2p/enode/nodedb.go index 7e7fb69b29..654d71d47b 100644 --- a/p2p/enode/nodedb.go +++ b/p2p/enode/nodedb.go @@ -26,6 +26,7 @@ import ( "sync" "time" + "github.com/ethereum/go-ethereum/p2p/enr" "github.com/ethereum/go-ethereum/rlp" "github.com/syndtr/goleveldb/leveldb" "github.com/syndtr/goleveldb/leveldb/errors" @@ -84,7 +85,7 @@ func OpenDB(path string) (*DB, error) { return newPersistentDB(path) } -// newMemoryNodeDB creates a new in-memory node database without a persistent backend. +// newMemoryDB creates a new in-memory node database without a persistent backend. func newMemoryDB() (*DB, error) { db, err := leveldb.Open(storage.NewMemStorage(), nil) if err != nil { @@ -93,7 +94,7 @@ func newMemoryDB() (*DB, error) { return &DB{lvl: db, quit: make(chan struct{})}, nil } -// newPersistentNodeDB creates/opens a leveldb backed persistent node database, +// newPersistentDB creates/opens a leveldb backed persistent node database, // also flushing its contents in case of a version mismatch. func newPersistentDB(path string) (*DB, error) { opts := &opt.Options{OpenFilesCacheCapacity: 5} @@ -242,13 +243,14 @@ func (db *DB) Node(id ID) *Node { } func mustDecodeNode(id, data []byte) *Node { - node := new(Node) - if err := rlp.DecodeBytes(data, &node.r); err != nil { + var r enr.Record + if err := rlp.DecodeBytes(data, &r); err != nil { panic(fmt.Errorf("p2p/enode: can't decode node %x in DB: %v", id, err)) } - // Restore node id cache. - copy(node.id[:], id) - return node + if len(id) != len(ID{}) { + panic(fmt.Errorf("invalid id length %d", len(id))) + } + return newNodeWithID(&r, ID(id)) } // UpdateNode inserts - potentially overwriting - a node into the peer database. diff --git a/p2p/enode/urlv4.go b/p2p/enode/urlv4.go index 0272eee987..a55dfa6632 100644 --- a/p2p/enode/urlv4.go +++ b/p2p/enode/urlv4.go @@ -181,7 +181,7 @@ func (n *Node) URLv4() string { nodeid = fmt.Sprintf("%s.%x", scheme, n.id[:]) } u := url.URL{Scheme: "enode"} - if n.Incomplete() { + if !n.ip.IsValid() { u.Host = nodeid } else { addr := net.TCPAddr{IP: n.IP(), Port: n.TCP()} diff --git a/p2p/enr/entries.go b/p2p/enr/entries.go index 9945a436c9..917e1becba 100644 --- a/p2p/enr/entries.go +++ b/p2p/enr/entries.go @@ -21,6 +21,7 @@ import ( "fmt" "io" "net" + "net/netip" "github.com/ethereum/go-ethereum/rlp" ) @@ -167,6 +168,60 @@ func (v *IPv6) DecodeRLP(s *rlp.Stream) error { return nil } +// IPv4Addr is the "ip" key, which holds the IP address of the node. +type IPv4Addr netip.Addr + +func (v IPv4Addr) ENRKey() string { return "ip" } + +// EncodeRLP implements rlp.Encoder. +func (v IPv4Addr) EncodeRLP(w io.Writer) error { + addr := netip.Addr(v) + if !addr.Is4() { + return fmt.Errorf("address is not IPv4") + } + enc := rlp.NewEncoderBuffer(w) + bytes := addr.As4() + enc.WriteBytes(bytes[:]) + return enc.Flush() +} + +// DecodeRLP implements rlp.Decoder. +func (v *IPv4Addr) DecodeRLP(s *rlp.Stream) error { + var bytes [4]byte + if err := s.ReadBytes(bytes[:]); err != nil { + return err + } + *v = IPv4Addr(netip.AddrFrom4(bytes)) + return nil +} + +// IPv6Addr is the "ip6" key, which holds the IP address of the node. +type IPv6Addr netip.Addr + +func (v IPv6Addr) ENRKey() string { return "ip6" } + +// EncodeRLP implements rlp.Encoder. +func (v IPv6Addr) EncodeRLP(w io.Writer) error { + addr := netip.Addr(v) + if !addr.Is6() { + return fmt.Errorf("address is not IPv6") + } + enc := rlp.NewEncoderBuffer(w) + bytes := addr.As16() + enc.WriteBytes(bytes[:]) + return enc.Flush() +} + +// DecodeRLP implements rlp.Decoder. +func (v *IPv6Addr) DecodeRLP(s *rlp.Stream) error { + var bytes [16]byte + if err := s.ReadBytes(bytes[:]); err != nil { + return err + } + *v = IPv6Addr(netip.AddrFrom16(bytes)) + return nil +} + // KeyError is an error related to a key. type KeyError struct { Key string diff --git a/p2p/metrics.go b/p2p/metrics.go index a6e36b91a8..a2ae213b70 100644 --- a/p2p/metrics.go +++ b/p2p/metrics.go @@ -37,7 +37,9 @@ const ( ) var ( - activePeerGauge metrics.Gauge = metrics.NilGauge{} + activePeerGauge metrics.Gauge = metrics.NilGauge{} + activeInboundPeerGauge metrics.Gauge = metrics.NilGauge{} + activeOutboundPeerGauge metrics.Gauge = metrics.NilGauge{} ingressTrafficMeter = metrics.NewRegisteredMeter("p2p/ingress", nil) egressTrafficMeter = metrics.NewRegisteredMeter("p2p/egress", nil) @@ -65,6 +67,8 @@ func init() { } activePeerGauge = metrics.NewRegisteredGauge("p2p/peers", nil) + activeInboundPeerGauge = metrics.NewRegisteredGauge("p2p/peers/inbound", nil) + activeOutboundPeerGauge = metrics.NewRegisteredGauge("p2p/peers/outbound", nil) serveMeter = metrics.NewRegisteredMeter("p2p/serves", nil) serveSuccessMeter = metrics.NewRegisteredMeter("p2p/serves/success", nil) dialMeter = metrics.NewRegisteredMeter("p2p/dials", nil) diff --git a/p2p/nat/natpmp.go b/p2p/nat/natpmp.go index 97601c99dc..ea2d897829 100644 --- a/p2p/nat/natpmp.go +++ b/p2p/nat/natpmp.go @@ -17,6 +17,7 @@ package nat import ( + "errors" "fmt" "net" "strings" @@ -46,7 +47,7 @@ func (n *pmp) ExternalIP() (net.IP, error) { func (n *pmp) AddMapping(protocol string, extport, intport int, name string, lifetime time.Duration) (uint16, error) { if lifetime <= 0 { - return 0, fmt.Errorf("lifetime must not be <= 0") + return 0, errors.New("lifetime must not be <= 0") } // Note order of port arguments is switched between our // AddMapping and the client's AddPortMapping. diff --git a/p2p/nodestate/nodestate.go b/p2p/nodestate/nodestate.go deleted file mode 100644 index 1e1757559c..0000000000 --- a/p2p/nodestate/nodestate.go +++ /dev/null @@ -1,1023 +0,0 @@ -// Copyright 2020 The go-ethereum Authors -// This file is part of the go-ethereum library. -// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see . - -package nodestate - -import ( - "errors" - "reflect" - "sync" - "time" - "unsafe" - - "github.com/ethereum/go-ethereum/common/mclock" - "github.com/ethereum/go-ethereum/ethdb" - "github.com/ethereum/go-ethereum/log" - "github.com/ethereum/go-ethereum/metrics" - "github.com/ethereum/go-ethereum/p2p/enode" - "github.com/ethereum/go-ethereum/p2p/enr" - "github.com/ethereum/go-ethereum/rlp" -) - -var ( - ErrInvalidField = errors.New("invalid field type") - ErrClosed = errors.New("already closed") -) - -type ( - // NodeStateMachine implements a network node-related event subscription system. - // It can assign binary state flags and fields of arbitrary type to each node and allows - // subscriptions to flag/field changes which can also modify further flags and fields, - // potentially triggering further subscriptions. An operation includes an initial change - // and all resulting subsequent changes and always ends in a consistent global state. - // It is initiated by a "top level" SetState/SetField call that blocks (also blocking other - // top-level functions) until the operation is finished. Callbacks making further changes - // should use the non-blocking SetStateSub/SetFieldSub functions. The tree of events - // resulting from the initial changes is traversed in a breadth-first order, ensuring for - // each subscription callback that all other callbacks caused by the same change triggering - // the current callback are processed before anything is triggered by the changes made in the - // current callback. In practice this logic ensures that all subscriptions "see" events in - // the logical order, callbacks are never called concurrently and "back and forth" effects - // are also possible. The state machine design should ensure that infinite event cycles - // cannot happen. - // The caller can also add timeouts assigned to a certain node and a subset of state flags. - // If the timeout elapses, the flags are reset. If all relevant flags are reset then the timer - // is dropped. State flags with no timeout are persisted in the database if the flag - // descriptor enables saving. If a node has no state flags set at any moment then it is discarded. - // Note: in order to avoid mutex deadlocks the callbacks should never lock a mutex that - // might be locked when the top level SetState/SetField functions are called. If a function - // potentially performs state/field changes then it is recommended to mention this fact in the - // function description, along with whether it should run inside an operation callback. - NodeStateMachine struct { - started, closed bool - lock sync.Mutex - clock mclock.Clock - db ethdb.KeyValueStore - dbNodeKey []byte - nodes map[enode.ID]*nodeInfo - offlineCallbackList []offlineCallback - opFlag bool // an operation has started - opWait *sync.Cond // signaled when the operation ends - opPending []func() // pending callback list of the current operation - - // Registered state flags or fields. Modifications are allowed - // only when the node state machine has not been started. - setup *Setup - fields []*fieldInfo - saveFlags bitMask - - // Installed callbacks. Modifications are allowed only when the - // node state machine has not been started. - stateSubs []stateSub - - // Testing hooks, only for testing purposes. - saveNodeHook func(*nodeInfo) - } - - // Flags represents a set of flags from a certain setup - Flags struct { - mask bitMask - setup *Setup - } - - // Field represents a field from a certain setup - Field struct { - index int - setup *Setup - } - - // flagDefinition describes a node state flag. Each registered instance is automatically - // mapped to a bit of the 64 bit node states. - // If persistent is true then the node is saved when state machine is shutdown. - flagDefinition struct { - name string - persistent bool - } - - // fieldDefinition describes an optional node field of the given type. The contents - // of the field are only retained for each node as long as at least one of the - // state flags is set. - fieldDefinition struct { - name string - ftype reflect.Type - encode func(interface{}) ([]byte, error) - decode func([]byte) (interface{}, error) - } - - // Setup contains the list of flags and fields used by the application - Setup struct { - Version uint - flags []flagDefinition - fields []fieldDefinition - } - - // bitMask describes a node state or state mask. It represents a subset - // of node flags with each bit assigned to a flag index (LSB represents flag 0). - bitMask uint64 - - // StateCallback is a subscription callback which is called when one of the - // state flags that is included in the subscription state mask is changed. - // Note: oldState and newState are also masked with the subscription mask so only - // the relevant bits are included. - StateCallback func(n *enode.Node, oldState, newState Flags) - - // FieldCallback is a subscription callback which is called when the value of - // a specific field is changed. - FieldCallback func(n *enode.Node, state Flags, oldValue, newValue interface{}) - - // nodeInfo contains node state, fields and state timeouts - nodeInfo struct { - node *enode.Node - state bitMask - timeouts []*nodeStateTimeout - fields []interface{} - fieldCount int - db, dirty bool - } - - nodeInfoEnc struct { - Enr enr.Record - Version uint - State bitMask - Fields [][]byte - } - - stateSub struct { - mask bitMask - callback StateCallback - } - - nodeStateTimeout struct { - mask bitMask - timer mclock.Timer - } - - fieldInfo struct { - fieldDefinition - subs []FieldCallback - } - - offlineCallback struct { - node *nodeInfo - state bitMask - fields []interface{} - } -) - -// offlineState is a special state that is assumed to be set before a node is loaded from -// the database and after it is shut down. -const offlineState = bitMask(1) - -// NewFlag creates a new node state flag -func (s *Setup) NewFlag(name string) Flags { - if s.flags == nil { - s.flags = []flagDefinition{{name: "offline"}} - } - f := Flags{mask: bitMask(1) << uint(len(s.flags)), setup: s} - s.flags = append(s.flags, flagDefinition{name: name}) - return f -} - -// NewPersistentFlag creates a new persistent node state flag -func (s *Setup) NewPersistentFlag(name string) Flags { - if s.flags == nil { - s.flags = []flagDefinition{{name: "offline"}} - } - f := Flags{mask: bitMask(1) << uint(len(s.flags)), setup: s} - s.flags = append(s.flags, flagDefinition{name: name, persistent: true}) - return f -} - -// OfflineFlag returns the system-defined offline flag belonging to the given setup -func (s *Setup) OfflineFlag() Flags { - return Flags{mask: offlineState, setup: s} -} - -// NewField creates a new node state field -func (s *Setup) NewField(name string, ftype reflect.Type) Field { - f := Field{index: len(s.fields), setup: s} - s.fields = append(s.fields, fieldDefinition{ - name: name, - ftype: ftype, - }) - return f -} - -// NewPersistentField creates a new persistent node field -func (s *Setup) NewPersistentField(name string, ftype reflect.Type, encode func(interface{}) ([]byte, error), decode func([]byte) (interface{}, error)) Field { - f := Field{index: len(s.fields), setup: s} - s.fields = append(s.fields, fieldDefinition{ - name: name, - ftype: ftype, - encode: encode, - decode: decode, - }) - return f -} - -// flagOp implements binary flag operations and also checks whether the operands belong to the same setup -func flagOp(a, b Flags, trueIfA, trueIfB, trueIfBoth bool) Flags { - if a.setup == nil { - if a.mask != 0 { - panic("Node state flags have no setup reference") - } - a.setup = b.setup - } - if b.setup == nil { - if b.mask != 0 { - panic("Node state flags have no setup reference") - } - b.setup = a.setup - } - if a.setup != b.setup { - panic("Node state flags belong to a different setup") - } - res := Flags{setup: a.setup} - if trueIfA { - res.mask |= a.mask & ^b.mask - } - if trueIfB { - res.mask |= b.mask & ^a.mask - } - if trueIfBoth { - res.mask |= a.mask & b.mask - } - return res -} - -// And returns the set of flags present in both a and b -func (a Flags) And(b Flags) Flags { return flagOp(a, b, false, false, true) } - -// AndNot returns the set of flags present in a but not in b -func (a Flags) AndNot(b Flags) Flags { return flagOp(a, b, true, false, false) } - -// Or returns the set of flags present in either a or b -func (a Flags) Or(b Flags) Flags { return flagOp(a, b, true, true, true) } - -// Xor returns the set of flags present in either a or b but not both -func (a Flags) Xor(b Flags) Flags { return flagOp(a, b, true, true, false) } - -// HasAll returns true if b is a subset of a -func (a Flags) HasAll(b Flags) bool { return flagOp(a, b, false, true, false).mask == 0 } - -// HasNone returns true if a and b have no shared flags -func (a Flags) HasNone(b Flags) bool { return flagOp(a, b, false, false, true).mask == 0 } - -// Equals returns true if a and b have the same flags set -func (a Flags) Equals(b Flags) bool { return flagOp(a, b, true, true, false).mask == 0 } - -// IsEmpty returns true if a has no flags set -func (a Flags) IsEmpty() bool { return a.mask == 0 } - -// MergeFlags merges multiple sets of state flags -func MergeFlags(list ...Flags) Flags { - if len(list) == 0 { - return Flags{} - } - res := list[0] - for i := 1; i < len(list); i++ { - res = res.Or(list[i]) - } - return res -} - -// String returns a list of the names of the flags specified in the bit mask -func (f Flags) String() string { - if f.mask == 0 { - return "[]" - } - s := "[" - comma := false - for index, flag := range f.setup.flags { - if f.mask&(bitMask(1)< 8*int(unsafe.Sizeof(bitMask(0))) { - panic("Too many node state flags") - } - ns := &NodeStateMachine{ - db: db, - dbNodeKey: dbKey, - clock: clock, - setup: setup, - nodes: make(map[enode.ID]*nodeInfo), - fields: make([]*fieldInfo, len(setup.fields)), - } - ns.opWait = sync.NewCond(&ns.lock) - stateNameMap := make(map[string]int, len(setup.flags)) - for index, flag := range setup.flags { - if _, ok := stateNameMap[flag.name]; ok { - panic("Node state flag name collision: " + flag.name) - } - stateNameMap[flag.name] = index - if flag.persistent { - ns.saveFlags |= bitMask(1) << uint(index) - } - } - fieldNameMap := make(map[string]int, len(setup.fields)) - for index, field := range setup.fields { - if _, ok := fieldNameMap[field.name]; ok { - panic("Node field name collision: " + field.name) - } - ns.fields[index] = &fieldInfo{fieldDefinition: field} - fieldNameMap[field.name] = index - } - return ns -} - -// stateMask checks whether the set of flags belongs to the same setup and returns its internal bit mask -func (ns *NodeStateMachine) stateMask(flags Flags) bitMask { - if flags.setup != ns.setup && flags.mask != 0 { - panic("Node state flags belong to a different setup") - } - return flags.mask -} - -// fieldIndex checks whether the field belongs to the same setup and returns its internal index -func (ns *NodeStateMachine) fieldIndex(field Field) int { - if field.setup != ns.setup { - panic("Node field belongs to a different setup") - } - return field.index -} - -// SubscribeState adds a node state subscription. The callback is called while the state -// machine mutex is not held and it is allowed to make further state updates using the -// non-blocking SetStateSub/SetFieldSub functions. All callbacks of an operation are running -// from the thread/goroutine of the initial caller and parallel operations are not permitted. -// Therefore the callback is never called concurrently. It is the responsibility of the -// implemented state logic to avoid deadlocks and to reach a stable state in a finite amount -// of steps. -// State subscriptions should be installed before loading the node database or making the -// first state update. -func (ns *NodeStateMachine) SubscribeState(flags Flags, callback StateCallback) { - ns.lock.Lock() - defer ns.lock.Unlock() - - if ns.started { - panic("state machine already started") - } - ns.stateSubs = append(ns.stateSubs, stateSub{ns.stateMask(flags), callback}) -} - -// SubscribeField adds a node field subscription. Same rules apply as for SubscribeState. -func (ns *NodeStateMachine) SubscribeField(field Field, callback FieldCallback) { - ns.lock.Lock() - defer ns.lock.Unlock() - - if ns.started { - panic("state machine already started") - } - f := ns.fields[ns.fieldIndex(field)] - f.subs = append(f.subs, callback) -} - -// newNode creates a new nodeInfo -func (ns *NodeStateMachine) newNode(n *enode.Node) *nodeInfo { - return &nodeInfo{node: n, fields: make([]interface{}, len(ns.fields))} -} - -// checkStarted checks whether the state machine has already been started and panics otherwise. -func (ns *NodeStateMachine) checkStarted() { - if !ns.started { - panic("state machine not started yet") - } -} - -// Start starts the state machine, enabling state and field operations and disabling -// further subscriptions. -func (ns *NodeStateMachine) Start() { - ns.lock.Lock() - if ns.started { - panic("state machine already started") - } - ns.started = true - if ns.db != nil { - ns.loadFromDb() - } - - ns.opStart() - ns.offlineCallbacks(true) - ns.opFinish() - ns.lock.Unlock() -} - -// Stop stops the state machine and saves its state if a database was supplied -func (ns *NodeStateMachine) Stop() { - ns.lock.Lock() - defer ns.lock.Unlock() - - ns.checkStarted() - if !ns.opStart() { - panic("already closed") - } - for _, node := range ns.nodes { - fields := make([]interface{}, len(node.fields)) - copy(fields, node.fields) - ns.offlineCallbackList = append(ns.offlineCallbackList, offlineCallback{node, node.state, fields}) - } - if ns.db != nil { - ns.saveToDb() - } - ns.offlineCallbacks(false) - ns.closed = true - ns.opFinish() -} - -// loadFromDb loads persisted node states from the database -func (ns *NodeStateMachine) loadFromDb() { - it := ns.db.NewIterator(ns.dbNodeKey, nil) - for it.Next() { - var id enode.ID - if len(it.Key()) != len(ns.dbNodeKey)+len(id) { - log.Error("Node state db entry with invalid length", "found", len(it.Key()), "expected", len(ns.dbNodeKey)+len(id)) - continue - } - copy(id[:], it.Key()[len(ns.dbNodeKey):]) - ns.decodeNode(id, it.Value()) - } -} - -type dummyIdentity enode.ID - -func (id dummyIdentity) Verify(r *enr.Record, sig []byte) error { return nil } -func (id dummyIdentity) NodeAddr(r *enr.Record) []byte { return id[:] } - -// decodeNode decodes a node database entry and adds it to the node set if successful -func (ns *NodeStateMachine) decodeNode(id enode.ID, data []byte) { - var enc nodeInfoEnc - if err := rlp.DecodeBytes(data, &enc); err != nil { - log.Error("Failed to decode node info", "id", id, "error", err) - return - } - n, _ := enode.New(dummyIdentity(id), &enc.Enr) - node := ns.newNode(n) - node.db = true - - if enc.Version != ns.setup.Version { - log.Debug("Removing stored node with unknown version", "current", ns.setup.Version, "stored", enc.Version) - ns.deleteNode(id) - return - } - if len(enc.Fields) > len(ns.setup.fields) { - log.Error("Invalid node field count", "id", id, "stored", len(enc.Fields)) - return - } - // Resolve persisted node fields - for i, encField := range enc.Fields { - if len(encField) == 0 { - continue - } - if decode := ns.fields[i].decode; decode != nil { - if field, err := decode(encField); err == nil { - node.fields[i] = field - node.fieldCount++ - } else { - log.Error("Failed to decode node field", "id", id, "field name", ns.fields[i].name, "error", err) - return - } - } else { - log.Error("Cannot decode node field", "id", id, "field name", ns.fields[i].name) - return - } - } - // It's a compatible node record, add it to set. - ns.nodes[id] = node - node.state = enc.State - fields := make([]interface{}, len(node.fields)) - copy(fields, node.fields) - ns.offlineCallbackList = append(ns.offlineCallbackList, offlineCallback{node, node.state, fields}) - log.Debug("Loaded node state", "id", id, "state", Flags{mask: enc.State, setup: ns.setup}) -} - -// saveNode saves the given node info to the database -func (ns *NodeStateMachine) saveNode(id enode.ID, node *nodeInfo) error { - if ns.db == nil { - return nil - } - - storedState := node.state & ns.saveFlags - for _, t := range node.timeouts { - storedState &= ^t.mask - } - enc := nodeInfoEnc{ - Enr: *node.node.Record(), - Version: ns.setup.Version, - State: storedState, - Fields: make([][]byte, len(ns.fields)), - } - log.Debug("Saved node state", "id", id, "state", Flags{mask: enc.State, setup: ns.setup}) - lastIndex := -1 - for i, f := range node.fields { - if f == nil { - continue - } - encode := ns.fields[i].encode - if encode == nil { - continue - } - blob, err := encode(f) - if err != nil { - return err - } - enc.Fields[i] = blob - lastIndex = i - } - if storedState == 0 && lastIndex == -1 { - if node.db { - node.db = false - ns.deleteNode(id) - } - node.dirty = false - return nil - } - enc.Fields = enc.Fields[:lastIndex+1] - data, err := rlp.EncodeToBytes(&enc) - if err != nil { - return err - } - if err := ns.db.Put(append(ns.dbNodeKey, id[:]...), data); err != nil { - return err - } - node.dirty, node.db = false, true - - if ns.saveNodeHook != nil { - ns.saveNodeHook(node) - } - return nil -} - -// deleteNode removes a node info from the database -func (ns *NodeStateMachine) deleteNode(id enode.ID) { - ns.db.Delete(append(ns.dbNodeKey, id[:]...)) -} - -// saveToDb saves the persistent flags and fields of all nodes that have been changed -func (ns *NodeStateMachine) saveToDb() { - for id, node := range ns.nodes { - if node.dirty { - err := ns.saveNode(id, node) - if err != nil { - log.Error("Failed to save node", "id", id, "error", err) - } - } - } -} - -// updateEnode updates the enode entry belonging to the given node if it already exists -func (ns *NodeStateMachine) updateEnode(n *enode.Node) (enode.ID, *nodeInfo) { - id := n.ID() - node := ns.nodes[id] - if node != nil && n.Seq() > node.node.Seq() { - node.node = n - node.dirty = true - } - return id, node -} - -// Persist saves the persistent state and fields of the given node immediately -func (ns *NodeStateMachine) Persist(n *enode.Node) error { - ns.lock.Lock() - defer ns.lock.Unlock() - - ns.checkStarted() - if id, node := ns.updateEnode(n); node != nil && node.dirty { - err := ns.saveNode(id, node) - if err != nil { - log.Error("Failed to save node", "id", id, "error", err) - } - return err - } - return nil -} - -// SetState updates the given node state flags and blocks until the operation is finished. -// If a flag with a timeout is set again, the operation removes or replaces the existing timeout. -func (ns *NodeStateMachine) SetState(n *enode.Node, setFlags, resetFlags Flags, timeout time.Duration) error { - ns.lock.Lock() - defer ns.lock.Unlock() - - if !ns.opStart() { - return ErrClosed - } - ns.setState(n, setFlags, resetFlags, timeout) - ns.opFinish() - return nil -} - -// SetStateSub updates the given node state flags without blocking (should be called -// from a subscription/operation callback). -func (ns *NodeStateMachine) SetStateSub(n *enode.Node, setFlags, resetFlags Flags, timeout time.Duration) { - ns.lock.Lock() - defer ns.lock.Unlock() - - ns.opCheck() - ns.setState(n, setFlags, resetFlags, timeout) -} - -func (ns *NodeStateMachine) setState(n *enode.Node, setFlags, resetFlags Flags, timeout time.Duration) { - ns.checkStarted() - set, reset := ns.stateMask(setFlags), ns.stateMask(resetFlags) - id, node := ns.updateEnode(n) - if node == nil { - if set == 0 { - return - } - node = ns.newNode(n) - ns.nodes[id] = node - } - oldState := node.state - newState := (node.state & (^reset)) | set - changed := oldState ^ newState - node.state = newState - - // Remove the timeout callbacks for all reset and set flags, - // even they are not existent(it's noop). - ns.removeTimeouts(node, set|reset) - - // Register the timeout callback if required - if timeout != 0 && set != 0 { - ns.addTimeout(n, set, timeout) - } - if newState == oldState { - return - } - if newState == 0 && node.fieldCount == 0 { - delete(ns.nodes, id) - if node.db { - ns.deleteNode(id) - } - } else { - if changed&ns.saveFlags != 0 { - node.dirty = true - } - } - callback := func() { - for _, sub := range ns.stateSubs { - if changed&sub.mask != 0 { - sub.callback(n, Flags{mask: oldState & sub.mask, setup: ns.setup}, Flags{mask: newState & sub.mask, setup: ns.setup}) - } - } - } - ns.opPending = append(ns.opPending, callback) -} - -// opCheck checks whether an operation is active -func (ns *NodeStateMachine) opCheck() { - if !ns.opFlag { - panic("Operation has not started") - } -} - -// opStart waits until other operations are finished and starts a new one -func (ns *NodeStateMachine) opStart() bool { - for ns.opFlag { - ns.opWait.Wait() - } - if ns.closed { - return false - } - ns.opFlag = true - return true -} - -// opFinish finishes the current operation by running all pending callbacks. -// Callbacks resulting from a state/field change performed in a previous callback are always -// put at the end of the pending list and therefore processed after all callbacks resulting -// from the previous state/field change. -func (ns *NodeStateMachine) opFinish() { - for len(ns.opPending) != 0 { - list := ns.opPending - ns.lock.Unlock() - for _, cb := range list { - cb() - } - ns.lock.Lock() - ns.opPending = ns.opPending[len(list):] - } - ns.opPending = nil - ns.opFlag = false - ns.opWait.Broadcast() -} - -// Operation calls the given function as an operation callback. This allows the caller -// to start an operation with multiple initial changes. The same rules apply as for -// subscription callbacks. -func (ns *NodeStateMachine) Operation(fn func()) error { - ns.lock.Lock() - started := ns.opStart() - ns.lock.Unlock() - if !started { - return ErrClosed - } - fn() - ns.lock.Lock() - ns.opFinish() - ns.lock.Unlock() - return nil -} - -// offlineCallbacks calls state update callbacks at startup or shutdown -func (ns *NodeStateMachine) offlineCallbacks(start bool) { - for _, cb := range ns.offlineCallbackList { - cb := cb - callback := func() { - for _, sub := range ns.stateSubs { - offState := offlineState & sub.mask - onState := cb.state & sub.mask - if offState == onState { - continue - } - if start { - sub.callback(cb.node.node, Flags{mask: offState, setup: ns.setup}, Flags{mask: onState, setup: ns.setup}) - } else { - sub.callback(cb.node.node, Flags{mask: onState, setup: ns.setup}, Flags{mask: offState, setup: ns.setup}) - } - } - for i, f := range cb.fields { - if f == nil || ns.fields[i].subs == nil { - continue - } - for _, fsub := range ns.fields[i].subs { - if start { - fsub(cb.node.node, Flags{mask: offlineState, setup: ns.setup}, nil, f) - } else { - fsub(cb.node.node, Flags{mask: offlineState, setup: ns.setup}, f, nil) - } - } - } - } - ns.opPending = append(ns.opPending, callback) - } - ns.offlineCallbackList = nil -} - -// AddTimeout adds a node state timeout associated to the given state flag(s). -// After the specified time interval, the relevant states will be reset. -func (ns *NodeStateMachine) AddTimeout(n *enode.Node, flags Flags, timeout time.Duration) error { - ns.lock.Lock() - defer ns.lock.Unlock() - - ns.checkStarted() - if ns.closed { - return ErrClosed - } - ns.addTimeout(n, ns.stateMask(flags), timeout) - return nil -} - -// addTimeout adds a node state timeout associated to the given state flag(s). -func (ns *NodeStateMachine) addTimeout(n *enode.Node, mask bitMask, timeout time.Duration) { - _, node := ns.updateEnode(n) - if node == nil { - return - } - mask &= node.state - if mask == 0 { - return - } - ns.removeTimeouts(node, mask) - t := &nodeStateTimeout{mask: mask} - t.timer = ns.clock.AfterFunc(timeout, func() { - ns.lock.Lock() - defer ns.lock.Unlock() - - if !ns.opStart() { - return - } - ns.setState(n, Flags{}, Flags{mask: t.mask, setup: ns.setup}, 0) - ns.opFinish() - }) - node.timeouts = append(node.timeouts, t) - if mask&ns.saveFlags != 0 { - node.dirty = true - } -} - -// removeTimeout removes node state timeouts associated to the given state flag(s). -// If a timeout was associated to multiple flags which are not all included in the -// specified remove mask then only the included flags are de-associated and the timer -// stays active. -func (ns *NodeStateMachine) removeTimeouts(node *nodeInfo, mask bitMask) { - for i := 0; i < len(node.timeouts); i++ { - t := node.timeouts[i] - match := t.mask & mask - if match == 0 { - continue - } - t.mask -= match - if t.mask != 0 { - continue - } - t.timer.Stop() - node.timeouts[i] = node.timeouts[len(node.timeouts)-1] - node.timeouts = node.timeouts[:len(node.timeouts)-1] - i-- - if match&ns.saveFlags != 0 { - node.dirty = true - } - } -} - -// GetField retrieves the given field of the given node. Note that when used in a -// subscription callback the result can be out of sync with the state change represented -// by the callback parameters so extra safety checks might be necessary. -func (ns *NodeStateMachine) GetField(n *enode.Node, field Field) interface{} { - ns.lock.Lock() - defer ns.lock.Unlock() - - ns.checkStarted() - if ns.closed { - return nil - } - if _, node := ns.updateEnode(n); node != nil { - return node.fields[ns.fieldIndex(field)] - } - return nil -} - -// GetState retrieves the current state of the given node. Note that when used in a -// subscription callback the result can be out of sync with the state change represented -// by the callback parameters so extra safety checks might be necessary. -func (ns *NodeStateMachine) GetState(n *enode.Node) Flags { - ns.lock.Lock() - defer ns.lock.Unlock() - - ns.checkStarted() - if ns.closed { - return Flags{} - } - if _, node := ns.updateEnode(n); node != nil { - return Flags{mask: node.state, setup: ns.setup} - } - return Flags{} -} - -// SetField sets the given field of the given node and blocks until the operation is finished -func (ns *NodeStateMachine) SetField(n *enode.Node, field Field, value interface{}) error { - ns.lock.Lock() - defer ns.lock.Unlock() - - if !ns.opStart() { - return ErrClosed - } - err := ns.setField(n, field, value) - ns.opFinish() - return err -} - -// SetFieldSub sets the given field of the given node without blocking (should be called -// from a subscription/operation callback). -func (ns *NodeStateMachine) SetFieldSub(n *enode.Node, field Field, value interface{}) error { - ns.lock.Lock() - defer ns.lock.Unlock() - - ns.opCheck() - return ns.setField(n, field, value) -} - -func (ns *NodeStateMachine) setField(n *enode.Node, field Field, value interface{}) error { - ns.checkStarted() - id, node := ns.updateEnode(n) - if node == nil { - if value == nil { - return nil - } - node = ns.newNode(n) - ns.nodes[id] = node - } - fieldIndex := ns.fieldIndex(field) - f := ns.fields[fieldIndex] - if value != nil && reflect.TypeOf(value) != f.ftype { - log.Error("Invalid field type", "type", reflect.TypeOf(value), "required", f.ftype) - return ErrInvalidField - } - oldValue := node.fields[fieldIndex] - if value == oldValue { - return nil - } - if oldValue != nil { - node.fieldCount-- - } - if value != nil { - node.fieldCount++ - } - node.fields[fieldIndex] = value - if node.state == 0 && node.fieldCount == 0 { - delete(ns.nodes, id) - if node.db { - ns.deleteNode(id) - } - } else { - if f.encode != nil { - node.dirty = true - } - } - state := node.state - callback := func() { - for _, cb := range f.subs { - cb(n, Flags{mask: state, setup: ns.setup}, oldValue, value) - } - } - ns.opPending = append(ns.opPending, callback) - return nil -} - -// ForEach calls the callback for each node having all of the required and none of the -// disabled flags set. -// Note that this callback is not an operation callback but ForEach can be called from an -// Operation callback or Operation can also be called from a ForEach callback if necessary. -func (ns *NodeStateMachine) ForEach(requireFlags, disableFlags Flags, cb func(n *enode.Node, state Flags)) { - ns.lock.Lock() - ns.checkStarted() - type callback struct { - node *enode.Node - state bitMask - } - require, disable := ns.stateMask(requireFlags), ns.stateMask(disableFlags) - var callbacks []callback - for _, node := range ns.nodes { - if node.state&require == require && node.state&disable == 0 { - callbacks = append(callbacks, callback{node.node, node.state & (require | disable)}) - } - } - ns.lock.Unlock() - for _, c := range callbacks { - cb(c.node, Flags{mask: c.state, setup: ns.setup}) - } -} - -// GetNode returns the enode currently associated with the given ID -func (ns *NodeStateMachine) GetNode(id enode.ID) *enode.Node { - ns.lock.Lock() - defer ns.lock.Unlock() - - ns.checkStarted() - if node := ns.nodes[id]; node != nil { - return node.node - } - return nil -} - -// AddLogMetrics adds logging and/or metrics for nodes entering, exiting and currently -// being in a given set specified by required and disabled state flags -func (ns *NodeStateMachine) AddLogMetrics(requireFlags, disableFlags Flags, name string, inMeter, outMeter metrics.Meter, gauge metrics.Gauge) { - var count int64 - ns.SubscribeState(requireFlags.Or(disableFlags), func(n *enode.Node, oldState, newState Flags) { - oldMatch := oldState.HasAll(requireFlags) && oldState.HasNone(disableFlags) - newMatch := newState.HasAll(requireFlags) && newState.HasNone(disableFlags) - if newMatch == oldMatch { - return - } - - if newMatch { - count++ - if name != "" { - log.Debug("Node entered", "set", name, "id", n.ID(), "count", count) - } - if inMeter != nil { - inMeter.Mark(1) - } - } else { - count-- - if name != "" { - log.Debug("Node left", "set", name, "id", n.ID(), "count", count) - } - if outMeter != nil { - outMeter.Mark(1) - } - } - if gauge != nil { - gauge.Update(count) - } - }) -} diff --git a/p2p/nodestate/nodestate_test.go b/p2p/nodestate/nodestate_test.go deleted file mode 100644 index d06ad755e2..0000000000 --- a/p2p/nodestate/nodestate_test.go +++ /dev/null @@ -1,407 +0,0 @@ -// Copyright 2020 The go-ethereum Authors -// This file is part of the go-ethereum library. -// -// The go-ethereum library is free software: you can redistribute it and/or modify -// it under the terms of the GNU Lesser General Public License as published by -// the Free Software Foundation, either version 3 of the License, or -// (at your option) any later version. -// -// The go-ethereum library is distributed in the hope that it will be useful, -// but WITHOUT ANY WARRANTY; without even the implied warranty of -// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -// GNU Lesser General Public License for more details. -// -// You should have received a copy of the GNU Lesser General Public License -// along with the go-ethereum library. If not, see . - -package nodestate - -import ( - "errors" - "fmt" - "reflect" - "testing" - "time" - - "github.com/ethereum/go-ethereum/common/mclock" - "github.com/ethereum/go-ethereum/core/rawdb" - "github.com/ethereum/go-ethereum/p2p/enode" - "github.com/ethereum/go-ethereum/p2p/enr" - "github.com/ethereum/go-ethereum/rlp" -) - -func testSetup(flagPersist []bool, fieldType []reflect.Type) (*Setup, []Flags, []Field) { - setup := &Setup{} - flags := make([]Flags, len(flagPersist)) - for i, persist := range flagPersist { - if persist { - flags[i] = setup.NewPersistentFlag(fmt.Sprintf("flag-%d", i)) - } else { - flags[i] = setup.NewFlag(fmt.Sprintf("flag-%d", i)) - } - } - fields := make([]Field, len(fieldType)) - for i, ftype := range fieldType { - switch ftype { - case reflect.TypeOf(uint64(0)): - fields[i] = setup.NewPersistentField(fmt.Sprintf("field-%d", i), ftype, uint64FieldEnc, uint64FieldDec) - case reflect.TypeOf(""): - fields[i] = setup.NewPersistentField(fmt.Sprintf("field-%d", i), ftype, stringFieldEnc, stringFieldDec) - default: - fields[i] = setup.NewField(fmt.Sprintf("field-%d", i), ftype) - } - } - return setup, flags, fields -} - -func testNode(b byte) *enode.Node { - r := &enr.Record{} - r.SetSig(dummyIdentity{b}, []byte{42}) - n, _ := enode.New(dummyIdentity{b}, r) - return n -} - -func TestCallback(t *testing.T) { - mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{} - - s, flags, _ := testSetup([]bool{false, false, false}, nil) - ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s) - - set0 := make(chan struct{}, 1) - set1 := make(chan struct{}, 1) - set2 := make(chan struct{}, 1) - ns.SubscribeState(flags[0], func(n *enode.Node, oldState, newState Flags) { set0 <- struct{}{} }) - ns.SubscribeState(flags[1], func(n *enode.Node, oldState, newState Flags) { set1 <- struct{}{} }) - ns.SubscribeState(flags[2], func(n *enode.Node, oldState, newState Flags) { set2 <- struct{}{} }) - - ns.Start() - - ns.SetState(testNode(1), flags[0], Flags{}, 0) - ns.SetState(testNode(1), flags[1], Flags{}, time.Second) - ns.SetState(testNode(1), flags[2], Flags{}, 2*time.Second) - - for i := 0; i < 3; i++ { - select { - case <-set0: - case <-set1: - case <-set2: - case <-time.After(time.Second): - t.Fatalf("failed to invoke callback") - } - } -} - -func TestPersistentFlags(t *testing.T) { - mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{} - - s, flags, _ := testSetup([]bool{true, true, true, false}, nil) - ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s) - - saveNode := make(chan *nodeInfo, 5) - ns.saveNodeHook = func(node *nodeInfo) { - saveNode <- node - } - - ns.Start() - - ns.SetState(testNode(1), flags[0], Flags{}, time.Second) // state with timeout should not be saved - ns.SetState(testNode(2), flags[1], Flags{}, 0) - ns.SetState(testNode(3), flags[2], Flags{}, 0) - ns.SetState(testNode(4), flags[3], Flags{}, 0) - ns.SetState(testNode(5), flags[0], Flags{}, 0) - ns.Persist(testNode(5)) - select { - case <-saveNode: - case <-time.After(time.Second): - t.Fatalf("Timeout") - } - ns.Stop() - - for i := 0; i < 2; i++ { - select { - case <-saveNode: - case <-time.After(time.Second): - t.Fatalf("Timeout") - } - } - select { - case <-saveNode: - t.Fatalf("Unexpected saveNode") - case <-time.After(time.Millisecond * 100): - } -} - -func TestSetField(t *testing.T) { - mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{} - - s, flags, fields := testSetup([]bool{true}, []reflect.Type{reflect.TypeOf("")}) - ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s) - - saveNode := make(chan *nodeInfo, 1) - ns.saveNodeHook = func(node *nodeInfo) { - saveNode <- node - } - - ns.Start() - - // Set field before setting state - ns.SetField(testNode(1), fields[0], "hello world") - field := ns.GetField(testNode(1), fields[0]) - if field == nil { - t.Fatalf("Field should be set before setting states") - } - ns.SetField(testNode(1), fields[0], nil) - field = ns.GetField(testNode(1), fields[0]) - if field != nil { - t.Fatalf("Field should be unset") - } - // Set field after setting state - ns.SetState(testNode(1), flags[0], Flags{}, 0) - ns.SetField(testNode(1), fields[0], "hello world") - field = ns.GetField(testNode(1), fields[0]) - if field == nil { - t.Fatalf("Field should be set after setting states") - } - if err := ns.SetField(testNode(1), fields[0], 123); err == nil { - t.Fatalf("Invalid field should be rejected") - } - // Dirty node should be written back - ns.Stop() - select { - case <-saveNode: - case <-time.After(time.Second): - t.Fatalf("Timeout") - } -} - -func TestSetState(t *testing.T) { - mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{} - - s, flags, _ := testSetup([]bool{false, false, false}, nil) - ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s) - - type change struct{ old, new Flags } - set := make(chan change, 1) - ns.SubscribeState(flags[0].Or(flags[1]), func(n *enode.Node, oldState, newState Flags) { - set <- change{ - old: oldState, - new: newState, - } - }) - - ns.Start() - - check := func(expectOld, expectNew Flags, expectChange bool) { - if expectChange { - select { - case c := <-set: - if !c.old.Equals(expectOld) { - t.Fatalf("Old state mismatch") - } - if !c.new.Equals(expectNew) { - t.Fatalf("New state mismatch") - } - case <-time.After(time.Second): - } - return - } - select { - case <-set: - t.Fatalf("Unexpected change") - case <-time.After(time.Millisecond * 100): - return - } - } - ns.SetState(testNode(1), flags[0], Flags{}, 0) - check(Flags{}, flags[0], true) - - ns.SetState(testNode(1), flags[1], Flags{}, 0) - check(flags[0], flags[0].Or(flags[1]), true) - - ns.SetState(testNode(1), flags[2], Flags{}, 0) - check(Flags{}, Flags{}, false) - - ns.SetState(testNode(1), Flags{}, flags[0], 0) - check(flags[0].Or(flags[1]), flags[1], true) - - ns.SetState(testNode(1), Flags{}, flags[1], 0) - check(flags[1], Flags{}, true) - - ns.SetState(testNode(1), Flags{}, flags[2], 0) - check(Flags{}, Flags{}, false) - - ns.SetState(testNode(1), flags[0].Or(flags[1]), Flags{}, time.Second) - check(Flags{}, flags[0].Or(flags[1]), true) - clock.Run(time.Second) - check(flags[0].Or(flags[1]), Flags{}, true) -} - -func uint64FieldEnc(field interface{}) ([]byte, error) { - if u, ok := field.(uint64); ok { - enc, err := rlp.EncodeToBytes(&u) - return enc, err - } - return nil, errors.New("invalid field type") -} - -func uint64FieldDec(enc []byte) (interface{}, error) { - var u uint64 - err := rlp.DecodeBytes(enc, &u) - return u, err -} - -func stringFieldEnc(field interface{}) ([]byte, error) { - if s, ok := field.(string); ok { - return []byte(s), nil - } - return nil, errors.New("invalid field type") -} - -func stringFieldDec(enc []byte) (interface{}, error) { - return string(enc), nil -} - -func TestPersistentFields(t *testing.T) { - mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{} - - s, flags, fields := testSetup([]bool{true}, []reflect.Type{reflect.TypeOf(uint64(0)), reflect.TypeOf("")}) - ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s) - - ns.Start() - ns.SetState(testNode(1), flags[0], Flags{}, 0) - ns.SetField(testNode(1), fields[0], uint64(100)) - ns.SetField(testNode(1), fields[1], "hello world") - ns.Stop() - - ns2 := NewNodeStateMachine(mdb, []byte("-ns"), clock, s) - - ns2.Start() - field0 := ns2.GetField(testNode(1), fields[0]) - if !reflect.DeepEqual(field0, uint64(100)) { - t.Fatalf("Field changed") - } - field1 := ns2.GetField(testNode(1), fields[1]) - if !reflect.DeepEqual(field1, "hello world") { - t.Fatalf("Field changed") - } - - s.Version++ - ns3 := NewNodeStateMachine(mdb, []byte("-ns"), clock, s) - ns3.Start() - if ns3.GetField(testNode(1), fields[0]) != nil { - t.Fatalf("Old field version should have been discarded") - } -} - -func TestFieldSub(t *testing.T) { - mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{} - - s, flags, fields := testSetup([]bool{true}, []reflect.Type{reflect.TypeOf(uint64(0))}) - ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s) - - var ( - lastState Flags - lastOldValue, lastNewValue interface{} - ) - ns.SubscribeField(fields[0], func(n *enode.Node, state Flags, oldValue, newValue interface{}) { - lastState, lastOldValue, lastNewValue = state, oldValue, newValue - }) - check := func(state Flags, oldValue, newValue interface{}) { - if !lastState.Equals(state) || lastOldValue != oldValue || lastNewValue != newValue { - t.Fatalf("Incorrect field sub callback (expected [%v %v %v], got [%v %v %v])", state, oldValue, newValue, lastState, lastOldValue, lastNewValue) - } - } - ns.Start() - ns.SetState(testNode(1), flags[0], Flags{}, 0) - ns.SetField(testNode(1), fields[0], uint64(100)) - check(flags[0], nil, uint64(100)) - ns.Stop() - check(s.OfflineFlag(), uint64(100), nil) - - ns2 := NewNodeStateMachine(mdb, []byte("-ns"), clock, s) - ns2.SubscribeField(fields[0], func(n *enode.Node, state Flags, oldValue, newValue interface{}) { - lastState, lastOldValue, lastNewValue = state, oldValue, newValue - }) - ns2.Start() - check(s.OfflineFlag(), nil, uint64(100)) - ns2.SetState(testNode(1), Flags{}, flags[0], 0) - ns2.SetField(testNode(1), fields[0], nil) - check(Flags{}, uint64(100), nil) - ns2.Stop() -} - -func TestDuplicatedFlags(t *testing.T) { - mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{} - - s, flags, _ := testSetup([]bool{true}, nil) - ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s) - - type change struct{ old, new Flags } - set := make(chan change, 1) - ns.SubscribeState(flags[0], func(n *enode.Node, oldState, newState Flags) { - set <- change{oldState, newState} - }) - - ns.Start() - defer ns.Stop() - - check := func(expectOld, expectNew Flags, expectChange bool) { - if expectChange { - select { - case c := <-set: - if !c.old.Equals(expectOld) { - t.Fatalf("Old state mismatch") - } - if !c.new.Equals(expectNew) { - t.Fatalf("New state mismatch") - } - case <-time.After(time.Second): - } - return - } - select { - case <-set: - t.Fatalf("Unexpected change") - case <-time.After(time.Millisecond * 100): - return - } - } - ns.SetState(testNode(1), flags[0], Flags{}, time.Second) - check(Flags{}, flags[0], true) - ns.SetState(testNode(1), flags[0], Flags{}, 2*time.Second) // extend the timeout to 2s - check(Flags{}, flags[0], false) - - clock.Run(2 * time.Second) - check(flags[0], Flags{}, true) -} - -func TestCallbackOrder(t *testing.T) { - mdb, clock := rawdb.NewMemoryDatabase(), &mclock.Simulated{} - - s, flags, _ := testSetup([]bool{false, false, false, false}, nil) - ns := NewNodeStateMachine(mdb, []byte("-ns"), clock, s) - - ns.SubscribeState(flags[0], func(n *enode.Node, oldState, newState Flags) { - if newState.Equals(flags[0]) { - ns.SetStateSub(n, flags[1], Flags{}, 0) - ns.SetStateSub(n, flags[2], Flags{}, 0) - } - }) - ns.SubscribeState(flags[1], func(n *enode.Node, oldState, newState Flags) { - if newState.Equals(flags[1]) { - ns.SetStateSub(n, flags[3], Flags{}, 0) - } - }) - lastState := Flags{} - ns.SubscribeState(MergeFlags(flags[1], flags[2], flags[3]), func(n *enode.Node, oldState, newState Flags) { - if !oldState.Equals(lastState) { - t.Fatalf("Wrong callback order") - } - lastState = newState - }) - - ns.Start() - defer ns.Stop() - - ns.SetState(testNode(1), flags[0], Flags{}, 0) -} diff --git a/p2p/rlpx/rlpx.go b/p2p/rlpx/rlpx.go index 8bd6f64b9b..a338490e62 100644 --- a/p2p/rlpx/rlpx.go +++ b/p2p/rlpx/rlpx.go @@ -22,7 +22,6 @@ import ( "crypto/aes" "crypto/cipher" "crypto/ecdsa" - "crypto/elliptic" "crypto/hmac" "crypto/rand" "encoding/binary" @@ -664,7 +663,10 @@ func exportPubkey(pub *ecies.PublicKey) []byte { if pub == nil { panic("nil pubkey") } - return elliptic.Marshal(pub.Curve, pub.X, pub.Y)[1:] + if curve, ok := pub.Curve.(crypto.EllipticCurve); ok { + return curve.Marshal(pub.X, pub.Y)[1:] + } + return []byte{} } func xor(one, other []byte) (xor []byte) { diff --git a/p2p/server.go b/p2p/server.go index 975a3bb916..c5d5e23c44 100644 --- a/p2p/server.go +++ b/p2p/server.go @@ -24,6 +24,7 @@ import ( "errors" "fmt" "net" + "net/netip" "sync" "sync/atomic" "time" @@ -190,8 +191,8 @@ type Server struct { nodedb *enode.DB localnode *enode.LocalNode - ntab *discover.UDPv4 - DiscV5 *discover.UDPv5 + discv4 *discover.UDPv4 + discv5 *discover.UDPv5 discmix *enode.FairMix dialsched *dialScheduler @@ -400,6 +401,16 @@ func (srv *Server) Self() *enode.Node { return ln.Node() } +// DiscoveryV4 returns the discovery v4 instance, if configured. +func (srv *Server) DiscoveryV4() *discover.UDPv4 { + return srv.discv4 +} + +// DiscoveryV5 returns the discovery v5 instance, if configured. +func (srv *Server) DiscoveryV5() *discover.UDPv5 { + return srv.discv5 +} + // Stop terminates the server and all active peer connections. // It blocks until all active connections have been closed. func (srv *Server) Stop() { @@ -425,11 +436,11 @@ type sharedUDPConn struct { unhandled chan discover.ReadPacket } -// ReadFromUDP implements discover.UDPConn -func (s *sharedUDPConn) ReadFromUDP(b []byte) (n int, addr *net.UDPAddr, err error) { +// ReadFromUDPAddrPort implements discover.UDPConn +func (s *sharedUDPConn) ReadFromUDPAddrPort(b []byte) (n int, addr netip.AddrPort, err error) { packet, ok := <-s.unhandled if !ok { - return 0, nil, errors.New("connection was closed") + return 0, netip.AddrPort{}, errors.New("connection was closed") } l := len(packet.Data) if l > len(b) { @@ -547,13 +558,13 @@ func (srv *Server) setupDiscovery() error { ) // If both versions of discovery are running, setup a shared // connection, so v5 can read unhandled messages from v4. - if srv.DiscoveryV4 && srv.DiscoveryV5 { + if srv.Config.DiscoveryV4 && srv.Config.DiscoveryV5 { unhandled = make(chan discover.ReadPacket, 100) sconn = &sharedUDPConn{conn, unhandled} } // Start discovery services. - if srv.DiscoveryV4 { + if srv.Config.DiscoveryV4 { cfg := discover.Config{ PrivateKey: srv.PrivateKey, NetRestrict: srv.NetRestrict, @@ -565,17 +576,17 @@ func (srv *Server) setupDiscovery() error { if err != nil { return err } - srv.ntab = ntab + srv.discv4 = ntab srv.discmix.AddSource(ntab.RandomNodes()) } - if srv.DiscoveryV5 { + if srv.Config.DiscoveryV5 { cfg := discover.Config{ PrivateKey: srv.PrivateKey, NetRestrict: srv.NetRestrict, Bootnodes: srv.BootstrapNodesV5, Log: srv.log, } - srv.DiscV5, err = discover.ListenV5(sconn, srv.localnode, cfg) + srv.discv5, err = discover.ListenV5(sconn, srv.localnode, cfg) if err != nil { return err } @@ -602,8 +613,8 @@ func (srv *Server) setupDialScheduler() { dialer: srv.Dialer, clock: srv.clock, } - if srv.ntab != nil { - config.resolver = srv.ntab + if srv.discv4 != nil { + config.resolver = srv.discv4 } if config.dialer == nil { config.dialer = tcpDialer{&net.Dialer{Timeout: defaultDialTimeout}} @@ -771,8 +782,10 @@ running: if p.Inbound() { inboundCount++ serveSuccessMeter.Mark(1) + activeInboundPeerGauge.Inc(1) } else { dialSuccessMeter.Mark(1) + activeOutboundPeerGauge.Inc(1) } activePeerGauge.Inc(1) } @@ -786,6 +799,9 @@ running: srv.dialsched.peerRemoved(pd.rw) if pd.Inbound() { inboundCount-- + activeInboundPeerGauge.Dec(1) + } else { + activeOutboundPeerGauge.Dec(1) } activePeerGauge.Dec(1) } @@ -794,11 +810,11 @@ running: srv.log.Trace("P2P networking is spinning down") // Terminate discovery. If there is a running lookup it will terminate soon. - if srv.ntab != nil { - srv.ntab.Close() + if srv.discv4 != nil { + srv.discv4.Close() } - if srv.DiscV5 != nil { - srv.DiscV5.Close() + if srv.discv5 != nil { + srv.discv5.Close() } // Disconnect all peers. for _, p := range peers { @@ -937,7 +953,7 @@ func (srv *Server) SetupConn(fd net.Conn, flags connFlag, dialDest *enode.Node) c.transport = srv.newTransport(fd, dialDest.Pubkey()) } - err := srv.setupConn(c, flags, dialDest) + err := srv.setupConn(c, dialDest) if err != nil { if !c.is(inboundConn) { markDialError(err) @@ -947,7 +963,7 @@ func (srv *Server) SetupConn(fd net.Conn, flags connFlag, dialDest *enode.Node) return err } -func (srv *Server) setupConn(c *conn, flags connFlag, dialDest *enode.Node) error { +func (srv *Server) setupConn(c *conn, dialDest *enode.Node) error { // Prevent leftover pending conns from entering the handshake. srv.lock.Lock() running := srv.running diff --git a/p2p/simulations/adapters/exec.go b/p2p/simulations/adapters/exec.go index 63cc4936c1..17e0f75d5a 100644 --- a/p2p/simulations/adapters/exec.go +++ b/p2p/simulations/adapters/exec.go @@ -460,7 +460,7 @@ func startExecNodeStack() (*node.Node, error) { // decode the config confEnv := os.Getenv(envNodeConfig) if confEnv == "" { - return nil, fmt.Errorf("missing " + envNodeConfig) + return nil, errors.New("missing " + envNodeConfig) } var conf execNodeConfig if err := json.Unmarshal([]byte(confEnv), &conf); err != nil { diff --git a/p2p/simulations/adapters/inproc_test.go b/p2p/simulations/adapters/inproc_test.go index 2a61508fe1..d0539ca867 100644 --- a/p2p/simulations/adapters/inproc_test.go +++ b/p2p/simulations/adapters/inproc_test.go @@ -78,7 +78,7 @@ func TestTCPPipeBidirections(t *testing.T) { } if !bytes.Equal(expected, out) { - t.Fatalf("expected %#v, got %#v", out, expected) + t.Fatalf("expected %#v, got %#v", expected, out) } else { msg := []byte(fmt.Sprintf("pong %02d", i)) if _, err := c2.Write(msg); err != nil { @@ -94,7 +94,7 @@ func TestTCPPipeBidirections(t *testing.T) { t.Fatal(err) } if !bytes.Equal(expected, out) { - t.Fatalf("expected %#v, got %#v", out, expected) + t.Fatalf("expected %#v, got %#v", expected, out) } } } diff --git a/p2p/simulations/adapters/types.go b/p2p/simulations/adapters/types.go index fb8463d221..94b06b5903 100644 --- a/p2p/simulations/adapters/types.go +++ b/p2p/simulations/adapters/types.go @@ -42,7 +42,6 @@ import ( // // - SimNode, an in-memory node in the same process // - ExecNode, a child process node -// - DockerNode, a node running in a Docker container type Node interface { // Addr returns the node's address (e.g. an Enode URL) Addr() []byte diff --git a/p2p/simulations/examples/ping-pong.go b/p2p/simulations/examples/ping-pong.go index 70b35ad777..b0b8f22fdb 100644 --- a/p2p/simulations/examples/ping-pong.go +++ b/p2p/simulations/examples/ping-pong.go @@ -33,7 +33,7 @@ import ( "github.com/ethereum/go-ethereum/p2p/simulations/adapters" ) -var adapterType = flag.String("adapter", "sim", `node adapter to use (one of "sim", "exec" or "docker")`) +var adapterType = flag.String("adapter", "sim", `node adapter to use (one of "sim" or "exec")`) // main() starts a simulation network which contains nodes running a simple // ping-pong protocol diff --git a/rlp/unsafe.go b/rlp/unsafe.go index 2152ba35fc..10868caaf2 100644 --- a/rlp/unsafe.go +++ b/rlp/unsafe.go @@ -26,10 +26,5 @@ import ( // byteArrayBytes returns a slice of the byte array v. func byteArrayBytes(v reflect.Value, length int) []byte { - var s []byte - hdr := (*reflect.SliceHeader)(unsafe.Pointer(&s)) - hdr.Data = v.UnsafeAddr() - hdr.Cap = length - hdr.Len = length - return s + return unsafe.Slice((*byte)(unsafe.Pointer(v.UnsafeAddr())), length) } diff --git a/rpc/handler.go b/rpc/handler.go index f44e4d7b01..792581cbc0 100644 --- a/rpc/handler.go +++ b/rpc/handler.go @@ -324,7 +324,7 @@ func (h *handler) addRequestOp(op *requestOp) { } } -// removeRequestOps stops waiting for the given request IDs. +// removeRequestOp stops waiting for the given request IDs. func (h *handler) removeRequestOp(op *requestOp) { for _, id := range op.ids { delete(h.respWait, string(id)) diff --git a/signer/core/apitypes/types.go b/signer/core/apitypes/types.go index 6bfcd2a727..e28f059106 100644 --- a/signer/core/apitypes/types.go +++ b/signer/core/apitypes/types.go @@ -708,7 +708,7 @@ func formatPrimitiveValue(encType string, encValue interface{}) (string, error) func (t Types) validate() error { for typeKey, typeArr := range t { if len(typeKey) == 0 { - return fmt.Errorf("empty type key") + return errors.New("empty type key") } for i, typeObj := range typeArr { if len(typeObj.Type) == 0 { diff --git a/signer/core/signed_data.go b/signer/core/signed_data.go index c6ae7b1274..f8b3c9d86d 100644 --- a/signer/core/signed_data.go +++ b/signer/core/signed_data.go @@ -260,7 +260,7 @@ func fromHex(data any) ([]byte, error) { return nil, fmt.Errorf("wrong type %T", data) } -// typeDataRequest tries to convert the data into a SignDataRequest. +// typedDataRequest tries to convert the data into a SignDataRequest. func typedDataRequest(data any) (*SignDataRequest, error) { var typedData apitypes.TypedData if td, ok := data.(apitypes.TypedData); ok { diff --git a/trie/trie_test.go b/trie/trie_test.go index c141c52078..7f126eba64 100644 --- a/trie/trie_test.go +++ b/trie/trie_test.go @@ -556,7 +556,7 @@ func runRandTest(rt randTest) error { checktr.MustUpdate(it.Key, it.Value) } if tr.Hash() != checktr.Hash() { - rt[i].err = fmt.Errorf("hash mismatch in opItercheckhash") + rt[i].err = errors.New("hash mismatch in opItercheckhash") } case opNodeDiff: var ( @@ -594,19 +594,19 @@ func runRandTest(rt randTest) error { } } if len(insertExp) != len(tr.tracer.inserts) { - rt[i].err = fmt.Errorf("insert set mismatch") + rt[i].err = errors.New("insert set mismatch") } if len(deleteExp) != len(tr.tracer.deletes) { - rt[i].err = fmt.Errorf("delete set mismatch") + rt[i].err = errors.New("delete set mismatch") } for insert := range tr.tracer.inserts { if _, present := insertExp[insert]; !present { - rt[i].err = fmt.Errorf("missing inserted node") + rt[i].err = errors.New("missing inserted node") } } for del := range tr.tracer.deletes { if _, present := deleteExp[del]; !present { - rt[i].err = fmt.Errorf("missing deleted node") + rt[i].err = errors.New("missing deleted node") } } } diff --git a/triedb/pathdb/history.go b/triedb/pathdb/history.go index 6e3f3faaed..051e122bec 100644 --- a/triedb/pathdb/history.go +++ b/triedb/pathdb/history.go @@ -215,7 +215,7 @@ func (m *meta) encode() []byte { // decode unpacks the meta object from byte stream. func (m *meta) decode(blob []byte) error { if len(blob) < 1 { - return fmt.Errorf("no version tag") + return errors.New("no version tag") } switch blob[0] { case stateHistoryVersion: