diff --git a/impl/cmd/main.go b/impl/cmd/main.go
index 2c341752..83b6a9c2 100644
--- a/impl/cmd/main.go
+++ b/impl/cmd/main.go
@@ -109,8 +109,8 @@ func configureLogger(level string) {
 	if level != "" {
 		logLevel, err := logrus.ParseLevel(level)
 		if err != nil {
-			logrus.WithError(err).WithField("level", level).Error("could not parse log level, setting to info")
-			logrus.SetLevel(logrus.InfoLevel)
+			logrus.WithError(err).WithField("level", level).Error("could not parse log level, setting to debug")
+			logrus.SetLevel(logrus.DebugLevel)
 		} else {
 			logrus.SetLevel(logLevel)
 		}
diff --git a/impl/config/config.go b/impl/config/config.go
index a95ec23e..e17d5d5e 100644
--- a/impl/config/config.go
+++ b/impl/config/config.go
@@ -89,7 +89,7 @@ func GetDefaultConfig() Config {
 			CacheSizeLimitMB: 500,
 		},
 		Log: LogConfig{
-			Level: logrus.InfoLevel.String(),
+			Level: logrus.DebugLevel.String(),
 		},
 	}
 }
diff --git a/impl/integrationtest/main.go b/impl/integrationtest/main.go
index bad1df58..5765c773 100644
--- a/impl/integrationtest/main.go
+++ b/impl/integrationtest/main.go
@@ -22,7 +22,7 @@ var (
 )
 
 func main() {
-	logrus.SetLevel(logrus.InfoLevel)
+	logrus.SetLevel(logrus.DebugLevel)
 	if len(os.Args) < 2 {
 		logrus.Fatal("must specify 1 argument (server URL)")
 	}
diff --git a/impl/internal/dht/getput.go b/impl/internal/dht/getput.go
index 52cbe9c1..ac62ab2c 100644
--- a/impl/internal/dht/getput.go
+++ b/impl/internal/dht/getput.go
@@ -6,6 +6,7 @@ import (
 	"errors"
 	"math"
 	"sync"
+	"time"
 
 	k_nearest_nodes "github.com/anacrolix/dht/v2/k-nearest-nodes"
 	"github.com/anacrolix/torrent/bencode"
@@ -37,7 +38,10 @@ func startGetTraversal(
 		Alpha:  15,
 		Target: target,
 		DoQuery: func(ctx context.Context, addr krpc.NodeAddr) traversal.QueryResult {
-			res := s.Get(ctx, dht.NewAddr(addr.UDP()), target, seq, dht.QueryRateLimiting{})
+			queryCtx, cancel := context.WithTimeout(ctx, 8*time.Second)
+			defer cancel()
+
+			res := s.Get(queryCtx, dht.NewAddr(addr.UDP()), target, seq, dht.QueryRateLimiting{})
 			err := res.ToError()
 			if err != nil && !errors.Is(err, context.Canceled) && !errors.Is(err, dht.TransactionTimeout) {
 				logrus.WithContext(ctx).WithError(err).Debugf("error querying %v", addr)
@@ -52,7 +56,7 @@ func startGetTraversal(
 						Sig:     r.Sig,
 						Mutable: false,
 					}:
-					case <-ctx.Done():
+					case <-queryCtx.Done():
 					}
 				} else if sha1.Sum(append(r.K[:], salt...)) == target && bep44.Verify(r.K[:], salt, *r.Seq, bv, r.Sig[:]) {
 					select {
@@ -62,15 +66,13 @@ func startGetTraversal(
 						Sig:     r.Sig,
 						Mutable: true,
 					}:
-					case <-ctx.Done():
+					case <-queryCtx.Done():
 					}
 				} else if rv != nil {
 					logrus.WithContext(ctx).Debugf("get response item hash didn't match target: %q", rv)
 				}
 			}
 			tqr := res.TraversalQueryResult(addr)
-			// Filter replies from nodes that don't have a string token. This doesn't look prettier
-			// with generics. "The token value should be a short binary string." ¯\_(ツ)_/¯ (BEP 5).
 			tqr.ClosestData, _ = tqr.ClosestData.(string)
 			if tqr.ClosestData == nil {
 				tqr.ResponseFrom = nil
@@ -80,7 +82,7 @@ func startGetTraversal(
 		NodeFilter: s.TraversalNodeFilter,
 	})
 
-	// list for context cancellation or stalled traversal
+	// Listen for context cancellation or stalled traversal
 	go func() {
 		select {
 		case <-ctx.Done():
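Note on the getput.go change above: each DHT query now runs under its own 8-second deadline derived from the traversal context, and the result-channel sends select on queryCtx, so a single unresponsive node times out on its own instead of stalling the whole traversal. A compilable sketch of the pattern follows; queryPeer is a hypothetical stand-in for s.Get:

    package dhtsketch

    import (
        "context"
        "time"
    )

    // queryPeer is a hypothetical stand-in for the real DHT query (s.Get).
    func queryPeer(ctx context.Context, addr string) error { return ctx.Err() }

    // queryWithDeadline derives a per-query deadline from the traversal
    // context: cancelling the parent still cancels the query, but a slow
    // peer is abandoned after 8 seconds.
    func queryWithDeadline(ctx context.Context, addr string) error {
        queryCtx, cancel := context.WithTimeout(ctx, 8*time.Second)
        defer cancel() // release the timer as soon as the query returns
        return queryPeer(queryCtx, addr)
    }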
diff --git a/impl/pkg/dht/logging.go b/impl/pkg/dht/logging.go
index f54fb3c1..47e3af2a 100644
--- a/impl/pkg/dht/logging.go
+++ b/impl/pkg/dht/logging.go
@@ -1,12 +1,13 @@
 package dht
 
 import (
+	"strings"
+
 	"github.com/anacrolix/log"
 	"github.com/sirupsen/logrus"
 )
 
 func init() {
-	logrus.SetFormatter(&logrus.JSONFormatter{})
 	log.Default.Handlers = []log.Handler{logrusHandler{}}
 }
 
@@ -16,7 +17,7 @@ type logrusHandler struct{}
 // It intentionally downgrades the log level to reduce verbosity.
 func (logrusHandler) Handle(record log.Record) {
 	entry := logrus.WithFields(logrus.Fields{"names": record.Names})
-	msg := record.Msg.String()
+	msg := strings.Replace(record.Msg.String(), "\n", "", -1)
 
 	switch record.Level {
 	case log.Debug:
diff --git a/impl/pkg/server/pkarr.go b/impl/pkg/server/pkarr.go
index fc61c500..d78293cf 100644
--- a/impl/pkg/server/pkarr.go
+++ b/impl/pkg/server/pkarr.go
@@ -43,7 +43,18 @@ func (r *PkarrRouter) GetRecord(c *gin.Context) {
 		return
 	}
 
-	resp, err := r.service.GetPkarr(c.Request.Context(), *id)
+	// make sure the key is valid
+	key, err := util.Z32Decode(*id)
+	if err != nil {
+		LoggingRespondErrWithMsg(c, err, "invalid record id", http.StatusBadRequest)
+		return
+	}
+	if len(key) != ed25519.PublicKeySize {
+		LoggingRespondErrMsg(c, "invalid z32 encoded ed25519 public key", http.StatusBadRequest)
+		return
+	}
+
+	resp, err := r.service.GetPkarr(c, *id)
 	if err != nil {
 		LoggingRespondErrWithMsg(c, err, "failed to get pkarr record", http.StatusInternalServerError)
 		return
@@ -82,7 +93,7 @@ func (r *PkarrRouter) PutRecord(c *gin.Context) {
 	}
 	key, err := util.Z32Decode(*id)
 	if err != nil {
-		LoggingRespondErrWithMsg(c, err, "failed to read id", http.StatusInternalServerError)
+		LoggingRespondErrWithMsg(c, err, "invalid record id", http.StatusBadRequest)
 		return
 	}
 	if len(key) != ed25519.PublicKeySize {
@@ -114,7 +125,7 @@ func (r *PkarrRouter) PutRecord(c *gin.Context) {
 		return
 	}
 
-	if err = r.service.PublishPkarr(c.Request.Context(), *id, *request); err != nil {
+	if err = r.service.PublishPkarr(c, *id, *request); err != nil {
 		LoggingRespondErrWithMsg(c, err, "failed to publish pkarr record", http.StatusInternalServerError)
 		return
 	}
diff --git a/impl/pkg/server/pkarr_test.go b/impl/pkg/server/pkarr_test.go
index f635600b..f87bd5c6 100644
--- a/impl/pkg/server/pkarr_test.go
+++ b/impl/pkg/server/pkarr_test.go
@@ -140,7 +140,7 @@ func TestPkarrRouter(t *testing.T) {
 
 	t.Run("test get not found", func(t *testing.T) {
 		w := httptest.NewRecorder()
-		suffix := "aaa"
+		suffix := "uqaj3fcr9db6jg6o9pjs53iuftyj45r46aubogfaceqjbo6pp9sy"
 		req := httptest.NewRequest(http.MethodGet, fmt.Sprintf("%s/%s", testServerURL, suffix), nil)
 		c := newRequestContextWithParams(w, req, map[string]string{IDParam: suffix})
 		pkarrRouter.GetRecord(c)
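Note on the pkarr.go handler changes above: both GetRecord and PutRecord now reject malformed IDs up front with a 400 instead of letting them surface as server errors; a valid ID is z-base-32 text that decodes to a 32-byte ed25519 public key. A standalone sketch of the check, assuming github.com/tv42/zbase32 is the decoder behind util.Z32Decode (an assumption, not confirmed by this diff):

    package main

    import (
        "crypto/ed25519"
        "fmt"

        "github.com/tv42/zbase32"
    )

    // decodeID mirrors the handlers' validation: z-base-32 text that must
    // decode to exactly ed25519.PublicKeySize (32) bytes.
    func decodeID(id string) ([]byte, error) {
        key, err := zbase32.DecodeString(id)
        if err != nil {
            return nil, fmt.Errorf("invalid z-base-32 id %q: %w", id, err)
        }
        if len(key) != ed25519.PublicKeySize {
            return nil, fmt.Errorf("want %d-byte key, got %d", ed25519.PublicKeySize, len(key))
        }
        return key, nil
    }

    func main() {
        if _, err := decodeID("aaa"); err != nil {
            fmt.Println(err) // "aaa" decodes, but not to 32 bytes
        }
    }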
diff --git a/impl/pkg/server/server.go b/impl/pkg/server/server.go
index 66bb7138..6da80e70 100644
--- a/impl/pkg/server/server.go
+++ b/impl/pkg/server/server.go
@@ -70,9 +70,10 @@ func NewServer(cfg *config.Config, shutdown chan os.Signal, d *dht.DHT) (*Server
 		Server: &http.Server{
 			Addr:              fmt.Sprintf("%s:%d", cfg.ServerConfig.APIHost, cfg.ServerConfig.APIPort),
 			Handler:           handler,
-			ReadTimeout:       time.Second * 10,
+			ReadTimeout:       time.Second * 15,
 			ReadHeaderTimeout: time.Second * 10,
 			WriteTimeout:      time.Second * 10,
+			MaxHeaderBytes:    1 << 20,
 		},
 		cfg: cfg,
 		svc: pkarrService,
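Note on the server.go hunk above: ReadTimeout now allows 15 seconds for the full request while header reads stay capped at 10 seconds, and MaxHeaderBytes of 1 << 20 is one MiB, which equals net/http's DefaultMaxHeaderBytes, so the line pins the bound explicitly rather than changing it. For reference, a sketch of what each knob bounds:

    package main

    import (
        "net/http"
        "time"
    )

    // newHTTPServer shows what each timeout covers. Values match the patch.
    func newHTTPServer(addr string, h http.Handler) *http.Server {
        return &http.Server{
            Addr:              addr,
            Handler:           h,
            ReadTimeout:       15 * time.Second, // whole request: headers plus body
            ReadHeaderTimeout: 10 * time.Second, // headers only; guards against slowloris
            WriteTimeout:      10 * time.Second, // from end of header read to end of response write
            MaxHeaderBytes:    1 << 20,          // 1 MiB; equals http.DefaultMaxHeaderBytes
        }
    }

    func main() {
        srv := newHTTPServer(":8080", http.NotFoundHandler())
        _ = srv // start with srv.ListenAndServe() in real code
    }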
diff --git a/impl/pkg/service/pkarr.go b/impl/pkg/service/pkarr.go
index 781384c8..018cc11c 100644
--- a/impl/pkg/service/pkarr.go
+++ b/impl/pkg/service/pkarr.go
@@ -28,11 +28,12 @@ const recordSizeLimit = 1000
 
 // PkarrService is the Pkarr service responsible for managing the Pkarr DHT and reading/writing records
 type PkarrService struct {
-	cfg       *config.Config
-	db        storage.Storage
-	dht       *dht.DHT
-	cache     *bigcache.BigCache
-	scheduler *dhtint.Scheduler
+	cfg         *config.Config
+	db          storage.Storage
+	dht         *dht.DHT
+	cache       *bigcache.BigCache
+	badGetCache *bigcache.BigCache
+	scheduler   *dhtint.Scheduler
 }
 
 // NewPkarrService returns a new instance of the Pkarr service
@@ -41,7 +42,7 @@ func NewPkarrService(cfg *config.Config, db storage.Storage, d *dht.DHT) (*Pkarr
 		return nil, ssiutil.LoggingNewError("config is required")
 	}
 
-	// create and start cache and scheduler
+	// create and start get cache
 	cacheTTL := time.Duration(cfg.PkarrConfig.CacheTTLSeconds) * time.Second
 	cacheConfig := bigcache.DefaultConfig(cacheTTL)
 	cacheConfig.MaxEntrySize = recordSizeLimit
@@ -51,13 +52,24 @@ func NewPkarrService(cfg *config.Config, db storage.Storage, d *dht.DHT) (*Pkarr
 	if err != nil {
 		return nil, ssiutil.LoggingErrorMsg(err, "failed to instantiate cache")
 	}
+
+	// create a new cache for bad gets to prevent spamming the DHT
+	cacheConfig.LifeWindow = 120 * time.Second
+	cacheConfig.CleanWindow = 60 * time.Second
+	badGetCache, err := bigcache.New(context.Background(), cacheConfig)
+	if err != nil {
+		return nil, ssiutil.LoggingErrorMsg(err, "failed to instantiate badGetCache")
+	}
+
+	// start scheduler for republishing
 	scheduler := dhtint.NewScheduler()
 	svc := PkarrService{
-		cfg:       cfg,
-		db:        db,
-		dht:       d,
-		cache:     cache,
-		scheduler: &scheduler,
+		cfg:         cfg,
+		db:          db,
+		dht:         d,
+		cache:       cache,
+		badGetCache: badGetCache,
+		scheduler:   &scheduler,
 	}
 	if err = scheduler.Schedule(cfg.PkarrConfig.RepublishCRON, svc.republish); err != nil {
 		return nil, ssiutil.LoggingErrorMsg(err, "failed to start republisher")
@@ -70,6 +82,11 @@ func (s *PkarrService) PublishPkarr(ctx context.Context, id string, record pkarr
 	ctx, span := telemetry.GetTracer().Start(ctx, "PkarrService.PublishPkarr")
 	defer span.End()
 
+	// make sure the key is valid
+	if _, err := util.Z32Decode(id); err != nil {
+		return ssiutil.LoggingCtxErrorMsgf(ctx, err, "failed to decode z-base-32 encoded ID: %s", id)
+	}
+
 	if err := record.IsValid(); err != nil {
 		return err
 	}
@@ -115,6 +132,16 @@ func (s *PkarrService) GetPkarr(ctx context.Context, id string) (*pkarr.Response
 	ctx, span := telemetry.GetTracer().Start(ctx, "PkarrService.GetPkarr")
 	defer span.End()
 
+	// make sure the key is valid
+	if _, err := util.Z32Decode(id); err != nil {
+		return nil, ssiutil.LoggingCtxErrorMsgf(ctx, err, "failed to decode z-base-32 encoded ID: %s", id)
+	}
+
+	// if the key is in the badGetCache, return an error
+	if _, err := s.badGetCache.Get(id); err == nil {
+		return nil, ssiutil.LoggingCtxErrorMsgf(ctx, err, "key [%s] looked up too frequently, please wait a bit before trying again", id)
+	}
+
 	// first do a cache lookup
 	if got, err := s.cache.Get(id); err == nil {
 		var resp pkarr.Response
@@ -138,7 +165,13 @@ func (s *PkarrService) GetPkarr(ctx context.Context, id string) (*pkarr.Response
 
 	record, err := s.db.ReadRecord(ctx, rawID)
 	if err != nil || record == nil {
-		logrus.WithContext(ctx).WithError(err).WithField("record", id).Error("failed to resolve pkarr record from storage")
+		logrus.WithContext(ctx).WithError(err).WithField("record", id).Error("failed to resolve pkarr record from storage; adding to badGetCache")
+
+		// add the key to the badGetCache to prevent spamming the DHT
+		if err = s.badGetCache.Set(id, []byte{0}); err != nil {
+			logrus.WithContext(ctx).WithError(err).WithField("record", id).Error("failed to set key in badGetCache")
+		}
+
 		return nil, err
 	}
 
@@ -193,67 +226,93 @@ func (s *PkarrService) republish() {
 	recordCnt, err := s.db.RecordCount(ctx)
 	if err != nil {
 		logrus.WithContext(ctx).WithError(err).Error("failed to get record count before republishing")
+		return
 	} else {
 		logrus.WithContext(ctx).WithField("record_count", recordCnt).Info("republishing records")
 	}
 
 	var nextPageToken []byte
 	var recordsBatch []pkarr.Record
-	var seenRecords, batchCnt, successCnt, errCnt int32 = 0, 0, 0, 0
+	var seenRecords, batchCnt, successCnt, errCnt int32 = 0, 1, 0, 0
+
 	for {
 		recordsBatch, nextPageToken, err = s.db.ListRecords(ctx, nextPageToken, 1000)
 		if err != nil {
 			logrus.WithContext(ctx).WithError(err).Error("failed to list record(s) for republishing")
 			return
 		}
-		seenRecords += int32(len(recordsBatch))
-		if len(recordsBatch) == 0 {
+		batchSize := len(recordsBatch)
+		seenRecords += int32(batchSize)
+		if batchSize == 0 {
 			logrus.WithContext(ctx).Info("no records to republish")
 			return
 		}
 		logrus.WithContext(ctx).WithFields(logrus.Fields{
-			"record_count": len(recordsBatch),
+			"record_count": batchSize,
 			"batch_number": batchCnt,
 			"total_seen":   seenRecords,
-		}).Infof("republishing next batch of records")
-		batchCnt++
+		}).Infof("republishing batch [%d] of [%d] records", batchCnt, batchSize)
 
 		var wg sync.WaitGroup
-		wg.Add(len(recordsBatch))
+		wg.Add(batchSize)
 
+		var batchErrCnt, batchSuccessCnt int32 = 0, 0
 		for _, record := range recordsBatch {
-			go func(record pkarr.Record) {
+			go func(ctx context.Context, record pkarr.Record) {
 				defer wg.Done()
 
 				recordID := zbase32.EncodeToString(record.Key[:])
 				logrus.WithContext(ctx).Debugf("republishing record: %s", recordID)
 
-				// Create a new context with a timeout of 10 seconds
-				putCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
+				putCtx, cancel := context.WithTimeout(ctx, 10*time.Second)
 				defer cancel()
 
-				if _, err = s.dht.Put(putCtx, record.BEP44()); err != nil {
-					logrus.WithContext(ctx).WithError(err).Errorf("failed to republish record: %s", recordID)
-					atomic.AddInt32(&errCnt, 1)
+				if _, putErr := s.dht.Put(putCtx, record.BEP44()); putErr != nil {
+					logrus.WithContext(putCtx).WithError(putErr).Errorf("failed to republish record: %s", recordID)
+					atomic.AddInt32(&batchErrCnt, 1)
 				} else {
-					atomic.AddInt32(&successCnt, 1)
+					atomic.AddInt32(&batchSuccessCnt, 1)
 				}
-			}(record)
+			}(ctx, record)
 		}
 
+		// Wait for all goroutines in this batch to finish before moving on to the next batch
 		wg.Wait()
 
+		// Update the success and error counts
+		atomic.AddInt32(&successCnt, batchSuccessCnt)
+		atomic.AddInt32(&errCnt, batchErrCnt)
+
+		successRate := float64(batchSuccessCnt) / float64(batchSize)
+
+		logrus.WithContext(ctx).WithFields(logrus.Fields{
+			"batch_number": batchCnt,
+			"success":      successCnt,
+			"errors":       errCnt,
+		}).Infof("batch [%d] completed with a [%.2f] percent success rate", batchCnt, successRate*100)
+
+		if successRate < 0.8 {
+			logrus.WithContext(ctx).WithFields(logrus.Fields{
+				"batch_number": batchCnt,
+				"success":      successCnt,
+				"errors":       errCnt,
+			}).Errorf("batch [%d] failed to meet success rate threshold; exiting republishing early", batchCnt)
+			break
+		}
+
 		if nextPageToken == nil {
 			break
 		}
+		batchCnt++
 	}
 
+	successRate := float64(successCnt) / float64(seenRecords)
 	logrus.WithContext(ctx).WithFields(logrus.Fields{
 		"success": seenRecords - errCnt,
 		"errors":  errCnt,
 		"total":   seenRecords,
-	}).Infof("republishing complete with [%d] batches", batchCnt)
+	}).Infof("republishing complete with [%d] batches of [%d] total records with a [%.2f] percent success rate", batchCnt, seenRecords, successRate*100)
 }
 
 // Close closes the Pkarr service gracefully
@@ -269,6 +328,11 @@ func (s *PkarrService) Close() {
 			logrus.WithError(err).Error("failed to close cache")
 		}
 	}
+	if s.badGetCache != nil {
+		if err := s.badGetCache.Close(); err != nil {
+			logrus.WithError(err).Error("failed to close badGetCache")
+		}
+	}
 	if err := s.db.Close(); err != nil {
 		logrus.WithError(err).Error("failed to close db")
 	}
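Note on the republish changes above: records are republished in pages of 1000, each page fanned out as one goroutine per record with per-batch atomic counters, and paging stops early once a batch falls below an 80% success rate. A compilable sketch of that accounting; publish is a hypothetical stand-in for s.dht.Put:

    package main

    import (
        "fmt"
        "sync"
        "sync/atomic"
    )

    // republishBatch fans out one goroutine per record, tallies outcomes
    // atomically, waits for the whole batch, and tells the caller whether
    // to keep paging (success rate >= 0.8, the patch's threshold).
    func republishBatch(records []string, publish func(string) error) bool {
        var wg sync.WaitGroup
        var success, failure int32
        wg.Add(len(records))
        for _, r := range records {
            go func(r string) {
                defer wg.Done()
                if err := publish(r); err != nil {
                    atomic.AddInt32(&failure, 1)
                    return
                }
                atomic.AddInt32(&success, 1)
            }(r)
        }
        wg.Wait() // the batch completes before the next page is fetched

        rate := float64(success) / float64(len(records))
        fmt.Printf("batch done: %d ok, %d failed, %.2f%% success\n", success, failure, rate*100)
        return rate >= 0.8
    }

    func main() {
        ok := republishBatch([]string{"a", "b"}, func(string) error { return nil })
        fmt.Println("continue paging:", ok)
    }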
diff --git a/impl/pkg/service/pkarr_test.go b/impl/pkg/service/pkarr_test.go
index 3a08e1d8..1275abb8 100644
--- a/impl/pkg/service/pkarr_test.go
+++ b/impl/pkg/service/pkarr_test.go
@@ -34,7 +34,7 @@ func TestPkarrService(t *testing.T) {
 
 	t.Run("test get record with invalid ID", func(t *testing.T) {
 		got, err := svc.GetPkarr(context.Background(), "---")
-		assert.EqualError(t, err, "illegal z-base-32 data at input byte 0")
+		assert.ErrorContains(t, err, "illegal z-base-32 data at input byte 0")
 		assert.Nil(t, got)
 	})
 
@@ -125,6 +125,17 @@ func TestPkarrService(t *testing.T) {
 		assert.Equal(t, putMsg.Seq, got.Seq)
 	})
 
+	t.Run("test get nonexistent record is rate limited", func(t *testing.T) {
+		got, err := svc.GetPkarr(context.Background(), "uqaj3fcr9db6jg6o9pjs53iuftyj45r46aubogfaceqjbo6pp9sy")
+		assert.NoError(t, err)
+		assert.Empty(t, got)
+
+		// try it again to make sure the badGetCache is working
+		got, err = svc.GetPkarr(context.Background(), "uqaj3fcr9db6jg6o9pjs53iuftyj45r46aubogfaceqjbo6pp9sy")
+		assert.ErrorContains(t, err, "looked up too frequently, please wait a bit before trying again")
+		assert.Empty(t, got)
+	})
+
 	t.Cleanup(func() { svc.Close() })
 }
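Note on the badGetCache flow exercised by the new test above: the first lookup of an unknown key goes through (and returns empty), the miss is then remembered for two minutes, and repeat lookups inside that window fail fast instead of re-querying the DHT. A standalone sketch of the same negative-caching pattern with bigcache, using the patch's window sizes (this shows the pattern, not the service's actual wiring):

    package main

    import (
        "context"
        "fmt"
        "time"

        "github.com/allegro/bigcache/v3"
    )

    func main() {
        cfg := bigcache.DefaultConfig(120 * time.Second) // LifeWindow: how long a miss is remembered
        cfg.CleanWindow = 60 * time.Second               // how often expired entries are swept
        badGets, err := bigcache.New(context.Background(), cfg)
        if err != nil {
            panic(err)
        }

        id := "some-z32-key"

        // First miss: record it so repeat lookups can be refused cheaply.
        if _, err := badGets.Get(id); err != nil {
            _ = badGets.Set(id, []byte{0}) // value is a placeholder
        }

        // Second lookup within the window: short-circuit before the DHT.
        if _, err := badGets.Get(id); err == nil {
            fmt.Println("looked up too frequently, please wait a bit")
        }
    }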