diff --git a/impl/README.md b/impl/README.md index 99e699f8..b9018507 100644 --- a/impl/README.md +++ b/impl/README.md @@ -62,4 +62,4 @@ docker run \ ### Postgres To use a postgres database as the storage backend, set configuration option `storage_uri` to a `postgres://` URI with -the database connection string. The schema will be created or updated as needed while the program starts. +the database connection string. The schema will be created or updated as needed while the program starts. \ No newline at end of file diff --git a/impl/cmd/main.go b/impl/cmd/main.go index f143354e..d15265ec 100644 --- a/impl/cmd/main.go +++ b/impl/cmd/main.go @@ -75,8 +75,7 @@ func run() error { } // create a channel of buffer size 1 to handle shutdown. - // buffer's size is 1 in order to ignore any additional ctrl+c - // spamming. + // buffer's size is 1 in order to ignore any additional ctrl+c spamming. shutdown := make(chan os.Signal, 1) signal.Notify(shutdown, os.Interrupt, syscall.SIGTERM) diff --git a/impl/config/config.go b/impl/config/config.go index d04dc785..13c2e8ba 100644 --- a/impl/config/config.go +++ b/impl/config/config.go @@ -58,9 +58,10 @@ type DHTServiceConfig struct { } type PKARRServiceConfig struct { - RepublishCRON string `toml:"republish_cron"` - CacheTTLSeconds int `toml:"cache_ttl_seconds"` - CacheSizeLimitMB int `toml:"cache_size_limit_mb"` + RepublishCRON string `toml:"republish_cron"` + CacheTTLSeconds int `toml:"cache_ttl_seconds"` + CacheSizeLimitMB int `toml:"cache_size_limit_mb"` + PutTimeoutSeconds int `toml:"put_timeout_seconds"` } type LogConfig struct { @@ -81,9 +82,10 @@ func GetDefaultConfig() Config { BootstrapPeers: GetDefaultBootstrapPeers(), }, PkarrConfig: PKARRServiceConfig{ - RepublishCRON: "0 */2 * * *", - CacheTTLSeconds: 600, - CacheSizeLimitMB: 500, + RepublishCRON: "0 */2 * * *", + CacheTTLSeconds: 600, + CacheSizeLimitMB: 500, + PutTimeoutSeconds: 5, }, Log: LogConfig{ Level: logrus.InfoLevel.String(), diff --git a/impl/config/config.toml b/impl/config/config.toml index a0c87c63..e71deb8d 100644 --- a/impl/config/config.toml +++ b/impl/config/config.toml @@ -13,4 +13,5 @@ bootstrap_peers = ["router.magnets.im:6881", "router.bittorrent.com:6881", "dht. [pkarr] republish_cron = "0 */2 * * *" # every 2 hours cache_ttl_seconds = 600 # 10 minutes -cache_size_limit_mb = 500 # 512 MB +cache_size_limit_mb = 1000 # 1000 MB +put_timeout_seconds = 5 # 5 seconds before puts time out \ No newline at end of file diff --git a/impl/integrationtest/main.go b/impl/integrationtest/main.go index 954fa3f5..138c393f 100644 --- a/impl/integrationtest/main.go +++ b/impl/integrationtest/main.go @@ -47,7 +47,7 @@ func run(server string) { continue } - if err := get(ctx, server, suffix); err != nil { + if err = get(ctx, server, suffix); err != nil { logrus.WithError(err).Error("error making GET request") continue } @@ -71,7 +71,7 @@ func put(ctx context.Context, server string) (string, error) { return "", err } - if err := doRequest(ctx, req); err != nil { + if err = doRequest(ctx, req); err != nil { return "", err } @@ -84,7 +84,7 @@ func get(ctx context.Context, server string, suffix string) error { return err } - if err := doRequest(ctx, req); err != nil { + if err = doRequest(ctx, req); err != nil { return err } diff --git a/impl/pkg/dht/dht.go b/impl/pkg/dht/dht.go index 0f0b580a..b3a68932 100644 --- a/impl/pkg/dht/dht.go +++ b/impl/pkg/dht/dht.go @@ -4,6 +4,7 @@ import ( "context" "net" "testing" + "time" errutil "github.com/TBD54566975/ssi-sdk/util" "github.com/anacrolix/dht/v2" @@ -13,6 +14,7 @@ import ( "github.com/anacrolix/torrent/types/infohash" "github.com/sirupsen/logrus" "github.com/stretchr/testify/require" + "golang.org/x/time/rate" dhtint "github.com/TBD54566975/did-dht-method/internal/dht" "github.com/TBD54566975/did-dht-method/internal/util" @@ -29,6 +31,9 @@ func NewDHT(bootstrapPeers []string) (*DHT, error) { logrus.WithField("bootstrap_peers", len(bootstrapPeers)).Info("initializing DHT") c := dht.NewDefaultServerConfig() + // change default expire to 24 hours + c.Exp = time.Hour * 24 + c.NoSecurity = false conn, err := net.ListenPacket("udp", "0.0.0.0:6881") if err != nil { return nil, errutil.LoggingErrorMsg(err, "failed to listen on udp port 6881") @@ -37,6 +42,8 @@ func NewDHT(bootstrapPeers []string) (*DHT, error) { c.Logger = log.NewLogger().WithFilterLevel(log.Debug) c.Logger.SetHandlers(logHandler{}) c.StartingNodes = func() ([]dht.Addr, error) { return dht.ResolveHostPorts(bootstrapPeers) } + // set up rate limiter - 100 requests per second, 500 requests burst + c.SendLimiter = rate.NewLimiter(100, 500) s, err := dht.NewServer(c) if err != nil { return nil, errutil.LoggingErrorMsg(err, "failed to create dht server") @@ -84,14 +91,17 @@ func (d *DHT) Put(ctx context.Context, request bep44.Put) (string, error) { logrus.Warn("no nodes available in the DHT for publishing") } + key := util.Z32Encode(request.K[:]) t, err := getput.Put(ctx, request.Target(), d.Server, nil, func(int64) bep44.Put { return request }) if err != nil { if t == nil { - return "", errutil.LoggingNewErrorf("failed to put key into dht: %v", err) + return "", errutil.LoggingNewErrorf("failed to put key[%s] into dht: %v", key, err) } - return "", errutil.LoggingNewErrorf("failed to put key into dht, tried %d nodes, got %d responses", t.NumAddrsTried, t.NumResponses) + return "", errutil.LoggingNewErrorf("failed to put key[%s] into dht, tried %d nodes, got %d responses", key, t.NumAddrsTried, t.NumResponses) + } else { + logrus.WithField("key", key).Debug("successfully put key into dht") } return util.Z32Encode(request.K[:]), nil } diff --git a/impl/pkg/pkarr/record.go b/impl/pkg/pkarr/record.go index 0f250f9d..862624c4 100644 --- a/impl/pkg/pkarr/record.go +++ b/impl/pkg/pkarr/record.go @@ -1,6 +1,8 @@ package pkarr import ( + "bytes" + "crypto/sha256" "encoding/base64" "errors" "fmt" @@ -8,9 +10,21 @@ import ( "github.com/TBD54566975/ssi-sdk/util" "github.com/anacrolix/dht/v2/bep44" "github.com/anacrolix/torrent/bencode" + "github.com/goccy/go-json" "github.com/tv42/zbase32" ) +type Response struct { + V []byte `validate:"required"` + Seq int64 `validate:"required"` + Sig [64]byte `validate:"required"` +} + +// Equals returns true if the response is equal to the other response +func (r Response) Equals(other Response) bool { + return r.Seq == other.Seq && bytes.Equal(r.V, other.V) && r.Sig == other.Sig +} + type Record struct { Value []byte `json:"v" validate:"required"` Key [32]byte `json:"k" validate:"required"` @@ -18,12 +32,7 @@ type Record struct { SequenceNumber int64 `json:"seq" validate:"required"` } -type Response struct { - V []byte `validate:"required"` - Seq int64 `validate:"required"` - Sig [64]byte `validate:"required"` -} - +// NewRecord returns a new Record with the given key, value, signature, and sequence number func NewRecord(k []byte, v []byte, sig []byte, seq int64) (*Record, error) { record := Record{SequenceNumber: seq} @@ -67,6 +76,7 @@ func (r Record) IsValid() error { return nil } +// Response returns the record as a Response func (r Record) Response() Response { return Response{ V: r.Value, @@ -75,6 +85,7 @@ func (r Record) Response() Response { } } +// BEP44 returns the record as a BEP44 Put message func (r Record) BEP44() bep44.Put { return bep44.Put{ V: r.Value, @@ -84,11 +95,27 @@ func (r Record) BEP44() bep44.Put { } } +// String returns a string representation of the record func (r Record) String() string { e := base64.RawURLEncoding return fmt.Sprintf("pkarr.Record{K=%s V=%s Sig=%s Seq=%d}", zbase32.EncodeToString(r.Key[:]), e.EncodeToString(r.Value), e.EncodeToString(r.Signature[:]), r.SequenceNumber) } +// ID returns the base32 encoded key as a string +func (r Record) ID() string { + return zbase32.EncodeToString(r.Key[:]) +} + +// Hash returns the SHA256 hash of the record as a string +func (r Record) Hash() (string, error) { + recordBytes, err := json.Marshal(r) + if err != nil { + return "", err + } + return string(sha256.New().Sum(recordBytes)), nil +} + +// RecordFromBEP44 returns a Record from a BEP44 Put message func RecordFromBEP44(putMsg *bep44.Put) Record { return Record{ Key: *putMsg.K, diff --git a/impl/pkg/server/server.go b/impl/pkg/server/server.go index ed7645d4..1c02727b 100644 --- a/impl/pkg/server/server.go +++ b/impl/pkg/server/server.go @@ -1,6 +1,7 @@ package server import ( + "context" "fmt" "net/http" "os" @@ -44,6 +45,13 @@ func NewServer(cfg *config.Config, shutdown chan os.Signal, d *dht.DHT) (*Server return nil, util.LoggingErrorMsg(err, "failed to instantiate storage") } + recordCnt, err := db.RecordCount(context.Background()) + if err != nil { + logrus.WithError(err).Error("failed to get record count") + } else { + logrus.WithField("record_count", recordCnt).Info("storage instantiated with record count") + } + pkarrService, err := service.NewPkarrService(cfg, db, d) if err != nil { return nil, util.LoggingErrorMsg(err, "could not instantiate pkarr service") diff --git a/impl/pkg/service/pkarr.go b/impl/pkg/service/pkarr.go index a560c31f..6be23bf1 100644 --- a/impl/pkg/service/pkarr.go +++ b/impl/pkg/service/pkarr.go @@ -72,16 +72,23 @@ func (s *PkarrService) PublishPkarr(ctx context.Context, id string, record pkarr return err } + // check if the message is already in the cache + if got, err := s.cache.Get(id); err == nil { + var resp pkarr.Response + if err = json.Unmarshal(got, &resp); err == nil && record.Response().Equals(resp) { + logrus.WithField("record_id", id).Debug("resolved pkarr record from cache with matching response") + return nil + } + } + // write to db and cache if err := s.db.WriteRecord(ctx, record); err != nil { return err } - recordBytes, err := json.Marshal(record.Response()) if err != nil { return err } - if err = s.cache.Set(id, recordBytes); err != nil { return err } @@ -90,11 +97,11 @@ func (s *PkarrService) PublishPkarr(ctx context.Context, id string, record pkarr // TODO(gabe): consider a background process to monitor failures go func() { // Create a new context with a timeout so that the parent context does not cancel the put - putCtx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + putCtx, cancel := context.WithTimeout(context.Background(), time.Duration(s.cfg.PkarrConfig.PutTimeoutSeconds)*time.Second) defer cancel() if _, err = s.dht.Put(putCtx, record.BEP44()); err != nil { - logrus.WithError(err).Error("error from dht.Put") + logrus.WithError(err).Errorf("error from dht.Put for record: %s", id) } }() @@ -109,12 +116,11 @@ func (s *PkarrService) GetPkarr(ctx context.Context, id string) (*pkarr.Response // first do a cache lookup if got, err := s.cache.Get(id); err == nil { var resp pkarr.Response - err = json.Unmarshal(got, &resp) - if err == nil { + if err = json.Unmarshal(got, &resp); err == nil { logrus.WithField("record_id", id).Debug("resolved pkarr record from cache") return &resp, nil } - logrus.WithError(err).WithField("record", id).Warn("failed to unmarshal pkarr record from cache, falling back to dht") + logrus.WithError(err).WithField("record", id).Warn("failed to get pkarr record from cache, falling back to dht") } // next do a dht lookup @@ -182,9 +188,15 @@ func (s *PkarrService) republish() { ctx, span := telemetry.GetTracer().Start(context.Background(), "PkarrService.republish") defer span.End() + recordCnt, err := s.db.RecordCount(ctx) + if err != nil { + logrus.WithError(err).Error("failed to get record count") + } else { + logrus.WithField("record_count", recordCnt).Info("republishing records") + } + var nextPageToken []byte var allRecords []pkarr.Record - var err error errCnt := 0 successCnt := 0 for { @@ -199,12 +211,13 @@ func (s *PkarrService) republish() { return } - logrus.WithField("record_count", len(allRecords)).Info("Republishing record") + logrus.WithField("record_count", len(allRecords)).Info("republishing records in batch") for _, record := range allRecords { - logrus.Infof("Republishing record: %s", zbase32.EncodeToString(record.Key[:])) + recordID := zbase32.EncodeToString(record.Key[:]) + logrus.Debugf("republishing record: %s", recordID) if _, err = s.dht.Put(ctx, record.BEP44()); err != nil { - logrus.WithError(err).Error("failed to republish record") + logrus.WithError(err).Errorf("failed to republish record: %s", recordID) errCnt++ continue } diff --git a/impl/pkg/service/pkarr_test.go b/impl/pkg/service/pkarr_test.go index c6a29cee..f034d52d 100644 --- a/impl/pkg/service/pkarr_test.go +++ b/impl/pkg/service/pkarr_test.go @@ -18,8 +18,8 @@ import ( "github.com/TBD54566975/did-dht-method/pkg/storage" ) -func TestPKARRService(t *testing.T) { - svc := newPKARRService(t, "a") +func TestPkarrService(t *testing.T) { + svc := newPkarrService(t, "a") t.Run("test put bad record", func(t *testing.T) { err := svc.PublishPkarr(context.Background(), "", pkarr.Record{}) @@ -128,7 +128,7 @@ func TestPKARRService(t *testing.T) { } func TestDHT(t *testing.T) { - svc1 := newPKARRService(t, "b") + svc1 := newPkarrService(t, "b") // create and publish a record to service1 sk, doc, err := did.GenerateDIDDHT(did.CreateDIDDHTOpts{}) @@ -155,7 +155,7 @@ func TestDHT(t *testing.T) { assert.Equal(t, putMsg.Seq, got.Seq) // create service2 with service1 as a bootstrap peer - svc2 := newPKARRService(t, "c", anacrolixdht.NewAddr(svc1.dht.Addr())) + svc2 := newPkarrService(t, "c", anacrolixdht.NewAddr(svc1.dht.Addr())) // get the record via service2 gotFrom2, err := svc2.GetPkarr(context.Background(), suffix) @@ -188,7 +188,7 @@ func TestNoConfig(t *testing.T) { assert.Nil(t, svc) } -func newPKARRService(t *testing.T, id string, bootstrapPeers ...anacrolixdht.Addr) PkarrService { +func newPkarrService(t *testing.T, id string, bootstrapPeers ...anacrolixdht.Addr) PkarrService { defaultConfig := config.GetDefaultConfig() db, err := storage.NewStorage(fmt.Sprintf("bolt://diddht-test-%s.db", id)) diff --git a/impl/pkg/storage/db/bolt/bolt.go b/impl/pkg/storage/db/bolt/bolt.go index 506f9360..54d5ec10 100644 --- a/impl/pkg/storage/db/bolt/bolt.go +++ b/impl/pkg/storage/db/bolt/bolt.go @@ -17,7 +17,7 @@ const ( pkarrNamespace = "pkarr" ) -type BoltDB struct { +type Bolt struct { db *bolt.DB } @@ -26,7 +26,7 @@ type boltRecord struct { } // NewBolt creates a BoltDB-based implementation of storage.Storage -func NewBolt(path string) (*BoltDB, error) { +func NewBolt(path string) (*Bolt, error) { if path == "" { return nil, errors.New("path is required") } @@ -35,12 +35,12 @@ func NewBolt(path string) (*BoltDB, error) { return nil, err } - return &BoltDB{db: db}, nil + return &Bolt{db: db}, nil } // WriteRecord writes the given record to the storage // TODO: don't overwrite existing records, store unique seq numbers -func (s *BoltDB) WriteRecord(ctx context.Context, record pkarr.Record) error { +func (b *Bolt) WriteRecord(ctx context.Context, record pkarr.Record) error { ctx, span := telemetry.GetTracer().Start(ctx, "bolt.WriteRecord") defer span.End() @@ -50,15 +50,15 @@ func (s *BoltDB) WriteRecord(ctx context.Context, record pkarr.Record) error { return err } - return s.write(ctx, pkarrNamespace, encoded.K, recordBytes) + return b.write(ctx, pkarrNamespace, encoded.K, recordBytes) } // ReadRecord reads the record with the given id from the storage -func (s *BoltDB) ReadRecord(ctx context.Context, id []byte) (*pkarr.Record, error) { +func (b *Bolt) ReadRecord(ctx context.Context, id []byte) (*pkarr.Record, error) { ctx, span := telemetry.GetTracer().Start(ctx, "bolt.ReadRecord") defer span.End() - recordBytes, err := s.read(ctx, pkarrNamespace, encoding.EncodeToString(id)) + recordBytes, err := b.read(ctx, pkarrNamespace, encoding.EncodeToString(id)) if err != nil { return nil, err } @@ -80,11 +80,11 @@ func (s *BoltDB) ReadRecord(ctx context.Context, id []byte) (*pkarr.Record, erro } // ListRecords lists all records in the storage -func (s *BoltDB) ListRecords(ctx context.Context, nextPageToken []byte, pagesize int) ([]pkarr.Record, []byte, error) { +func (b *Bolt) ListRecords(ctx context.Context, nextPageToken []byte, pagesize int) ([]pkarr.Record, []byte, error) { ctx, span := telemetry.GetTracer().Start(ctx, "bolt.ListRecords") defer span.End() - boltRecords, err := s.readSeveral(ctx, pkarrNamespace, nextPageToken, pagesize) + boltRecords, err := b.readSeveral(ctx, pkarrNamespace, nextPageToken, pagesize) if err != nil { return nil, nil, err } @@ -113,15 +113,15 @@ func (s *BoltDB) ListRecords(ctx context.Context, nextPageToken []byte, pagesize return records, nextPageToken, nil } -func (s *BoltDB) Close() error { - return s.db.Close() +func (b *Bolt) Close() error { + return b.db.Close() } -func (s *BoltDB) write(ctx context.Context, namespace string, key string, value []byte) error { +func (b *Bolt) write(ctx context.Context, namespace string, key string, value []byte) error { _, span := telemetry.GetTracer().Start(ctx, "bolt.write") defer span.End() - return s.db.Update(func(tx *bolt.Tx) error { + return b.db.Update(func(tx *bolt.Tx) error { bucket, err := tx.CreateBucketIfNotExists([]byte(namespace)) if err != nil { return err @@ -133,12 +133,12 @@ func (s *BoltDB) write(ctx context.Context, namespace string, key string, value }) } -func (s *BoltDB) read(ctx context.Context, namespace, key string) ([]byte, error) { +func (b *Bolt) read(ctx context.Context, namespace, key string) ([]byte, error) { _, span := telemetry.GetTracer().Start(ctx, "bolt.read") defer span.End() var result []byte - err := s.db.View(func(tx *bolt.Tx) error { + err := b.db.View(func(tx *bolt.Tx) error { bucket := tx.Bucket([]byte(namespace)) if bucket == nil { logrus.WithField("namespace", namespace).Info("namespace does not exist") @@ -150,9 +150,9 @@ func (s *BoltDB) read(ctx context.Context, namespace, key string) ([]byte, error return result, err } -func (s *BoltDB) readAll(namespace string) (map[string][]byte, error) { +func (b *Bolt) readAll(namespace string) (map[string][]byte, error) { result := make(map[string][]byte) - err := s.db.View(func(tx *bolt.Tx) error { + err := b.db.View(func(tx *bolt.Tx) error { bucket := tx.Bucket([]byte(namespace)) if bucket == nil { logrus.WithField("namespace", namespace).Warn("namespace does not exist") @@ -167,12 +167,12 @@ func (s *BoltDB) readAll(namespace string) (map[string][]byte, error) { return result, err } -func (s *BoltDB) readSeveral(ctx context.Context, namespace string, after []byte, count int) ([]boltRecord, error) { +func (b *Bolt) readSeveral(ctx context.Context, namespace string, after []byte, count int) ([]boltRecord, error) { _, span := telemetry.GetTracer().Start(ctx, "bolt.readSeveral") defer span.End() var result []boltRecord - err := s.db.View(func(tx *bolt.Tx) error { + err := b.db.View(func(tx *bolt.Tx) error { bucket := tx.Bucket([]byte(namespace)) if bucket == nil { logrus.WithField("namespace", namespace).Warn("namespace does not exist") @@ -200,3 +200,21 @@ func (s *BoltDB) readSeveral(ctx context.Context, namespace string, after []byte }) return result, err } + +// RecordCount returns the number of records in the storage for the pkarr namespace +func (b *Bolt) RecordCount(ctx context.Context) (int, error) { + _, span := telemetry.GetTracer().Start(ctx, "bolt.RecordCount") + defer span.End() + + var count int + err := b.db.View(func(tx *bolt.Tx) error { + bucket := tx.Bucket([]byte(pkarrNamespace)) + if bucket == nil { + logrus.WithField("namespace", pkarrNamespace).Warn("namespace does not exist") + return nil + } + count = bucket.Stats().KeyN + return nil + }) + return count, err +} diff --git a/impl/pkg/storage/db/bolt/bolt_test.go b/impl/pkg/storage/db/bolt/bolt_test.go index 74ae425e..b4fbd289 100644 --- a/impl/pkg/storage/db/bolt/bolt_test.go +++ b/impl/pkg/storage/db/bolt/bolt_test.go @@ -94,7 +94,7 @@ func TestBoltDB_PrefixAndKeys(t *testing.T) { assert.NoError(t, err) } -func getTestDB(t *testing.T) *BoltDB { +func getTestDB(t *testing.T) *Bolt { path := "test.db" db, err := NewBolt(path) assert.NoError(t, err) @@ -109,6 +109,10 @@ func getTestDB(t *testing.T) *BoltDB { func TestReadWrite(t *testing.T) { db := getTestDB(t) + ctx := context.Background() + + beforeCnt, err := db.RecordCount(ctx) + require.NoError(t, err) // create a did doc as a packet to store sk, doc, err := did.GenerateDIDDHT(did.CreateDIDDHTOpts{}) @@ -125,7 +129,6 @@ func TestReadWrite(t *testing.T) { r := pkarr.RecordFromBEP44(putMsg) - ctx := context.Background() err = db.WriteRecord(ctx, r) require.NoError(t, err) @@ -136,6 +139,10 @@ func TestReadWrite(t *testing.T) { assert.Equal(t, r.Value, r2.Value) assert.Equal(t, r.Signature, r2.Signature) assert.Equal(t, r.SequenceNumber, r2.SequenceNumber) + + afterCnt, err := db.RecordCount(ctx) + assert.NoError(t, err) + assert.Equal(t, beforeCnt+1, afterCnt) } func TestDBPagination(t *testing.T) { @@ -144,6 +151,9 @@ func TestDBPagination(t *testing.T) { ctx := context.Background() + beforeCnt, err := db.RecordCount(ctx) + require.NoError(t, err) + preTestRecords, _, err := db.ListRecords(ctx, nil, 10) require.NoError(t, err) @@ -198,6 +208,10 @@ func TestDBPagination(t *testing.T) { assert.NoError(t, err) assert.Nil(t, nextPageToken) assert.Len(t, page, 1+len(preTestRecords)) + + afterCnt, err := db.RecordCount(ctx) + assert.NoError(t, err) + assert.Equal(t, beforeCnt+11, afterCnt) } func TestNewBolt(t *testing.T) { diff --git a/impl/pkg/storage/db/bolt/pkarr.go b/impl/pkg/storage/db/bolt/pkarr.go index ee973296..97f30040 100644 --- a/impl/pkg/storage/db/bolt/pkarr.go +++ b/impl/pkg/storage/db/bolt/pkarr.go @@ -4,11 +4,14 @@ import ( "encoding/base64" "fmt" + "github.com/TBD54566975/ssi-sdk/util" + "github.com/TBD54566975/did-dht-method/pkg/pkarr" - "github.com/sirupsen/logrus" ) -var encoding = base64.RawURLEncoding +var ( + encoding = base64.RawURLEncoding +) type base64PkarrRecord struct { // Up to an 1000 byte base64URL encoded string @@ -48,8 +51,7 @@ func (b base64PkarrRecord) Decode() (*pkarr.Record, error) { record, err := pkarr.NewRecord(k, v, sig, b.Seq) if err != nil { // TODO: do something useful if this happens - logrus.WithError(err).Warn("error loading record from database, skipping") - return nil, err + return nil, util.LoggingErrorMsg(err, "error loading record from database, skipping") } return record, nil } diff --git a/impl/pkg/storage/db/postgres/db.go b/impl/pkg/storage/db/postgres/db.go index 2c187adf..0c307ea8 100644 --- a/impl/pkg/storage/db/postgres/db.go +++ b/impl/pkg/storage/db/postgres/db.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.25.0 +// sqlc v1.26.0 package postgres diff --git a/impl/pkg/storage/db/postgres/models.go b/impl/pkg/storage/db/postgres/models.go index b409172e..03958671 100644 --- a/impl/pkg/storage/db/postgres/models.go +++ b/impl/pkg/storage/db/postgres/models.go @@ -1,11 +1,9 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.25.0 +// sqlc v1.26.0 package postgres -import () - type PkarrRecord struct { ID int32 Key []byte diff --git a/impl/pkg/storage/db/postgres/postgres.go b/impl/pkg/storage/db/postgres/postgres.go index 6d415c40..7d3872a5 100644 --- a/impl/pkg/storage/db/postgres/postgres.go +++ b/impl/pkg/storage/db/postgres/postgres.go @@ -6,9 +6,9 @@ import ( "embed" "fmt" - pgx "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5" _ "github.com/jackc/pgx/v5/stdlib" - goose "github.com/pressly/goose/v3" + "github.com/pressly/goose/v3" "github.com/sirupsen/logrus" "github.com/TBD54566975/did-dht-method/pkg/pkarr" @@ -38,11 +38,11 @@ func (p Postgres) migrate() error { defer db.Close() goose.SetBaseFS(migrations) - if err := goose.SetDialect("postgres"); err != nil { + if err = goose.SetDialect("postgres"); err != nil { return err } - if err := goose.Up(db, "migrations"); err != nil { + if err = goose.Up(db, "migrations"); err != nil { return err } @@ -159,3 +159,21 @@ func (p Postgres) Close() error { func (row PkarrRecord) Record() (*pkarr.Record, error) { return pkarr.NewRecord(row.Key, row.Value, row.Sig, row.Seq) } + +func (p Postgres) RecordCount(ctx context.Context) (int, error) { + ctx, span := telemetry.GetTracer().Start(ctx, "postgres.RecordCount") + defer span.End() + + queries, db, err := p.connect(ctx) + if err != nil { + return 0, err + } + defer db.Close(ctx) + + count, err := queries.RecordCount(ctx) + if err != nil { + return 0, err + } + + return int(count), nil +} diff --git a/impl/pkg/storage/db/postgres/postgres_test.go b/impl/pkg/storage/db/postgres/postgres_test.go index c3d4f042..9fc88ba9 100644 --- a/impl/pkg/storage/db/postgres/postgres_test.go +++ b/impl/pkg/storage/db/postgres/postgres_test.go @@ -36,6 +36,10 @@ func getTestDB(t *testing.T) storage.Storage { func TestReadWrite(t *testing.T) { db := getTestDB(t) + ctx := context.Background() + + beforeCnt, err := db.RecordCount(ctx) + require.NoError(t, err) // create a did doc as a packet to store sk, doc, err := did.GenerateDIDDHT(did.CreateDIDDHTOpts{}) @@ -52,7 +56,6 @@ func TestReadWrite(t *testing.T) { r := pkarr.RecordFromBEP44(putMsg) - ctx := context.Background() err = db.WriteRecord(ctx, r) require.NoError(t, err) @@ -63,6 +66,10 @@ func TestReadWrite(t *testing.T) { assert.Equal(t, r.Value, r2.Value) assert.Equal(t, r.Signature, r2.Signature) assert.Equal(t, r.SequenceNumber, r2.SequenceNumber) + + afterCnt, err := db.RecordCount(ctx) + require.NoError(t, err) + assert.Equal(t, beforeCnt+1, afterCnt) } func TestDBPagination(t *testing.T) { @@ -71,6 +78,9 @@ func TestDBPagination(t *testing.T) { ctx := context.Background() + beforeCnt, err := db.RecordCount(ctx) + require.NoError(t, err) + preTestRecords, _, err := db.ListRecords(ctx, nil, 10) require.NoError(t, err) @@ -125,4 +135,8 @@ func TestDBPagination(t *testing.T) { assert.NoError(t, err) assert.Nil(t, nextPageToken) assert.Len(t, page, 1+len(preTestRecords)) + + afterCnt, err := db.RecordCount(ctx) + require.NoError(t, err) + assert.Equal(t, beforeCnt+11, afterCnt) } diff --git a/impl/pkg/storage/db/postgres/queries.sql.go b/impl/pkg/storage/db/postgres/queries.sql.go index 4ddee1e3..3687e39a 100644 --- a/impl/pkg/storage/db/postgres/queries.sql.go +++ b/impl/pkg/storage/db/postgres/queries.sql.go @@ -1,6 +1,6 @@ // Code generated by sqlc. DO NOT EDIT. // versions: -// sqlc v1.25.0 +// sqlc v1.26.0 // source: queries.sql package postgres @@ -91,6 +91,17 @@ func (q *Queries) ReadRecord(ctx context.Context, key []byte) (PkarrRecord, erro return i, err } +const recordCount = `-- name: RecordCount :one +SELECT count(*) AS exact_count FROM pkarr_records +` + +func (q *Queries) RecordCount(ctx context.Context) (int64, error) { + row := q.db.QueryRow(ctx, recordCount) + var exact_count int64 + err := row.Scan(&exact_count) + return exact_count, err +} + const writeRecord = `-- name: WriteRecord :exec INSERT INTO pkarr_records(key, value, sig, seq) VALUES($1, $2, $3, $4) ` diff --git a/impl/pkg/storage/db/postgres/queries/queries.sql b/impl/pkg/storage/db/postgres/queries/queries.sql index 7f7002da..8ae96ea3 100644 --- a/impl/pkg/storage/db/postgres/queries/queries.sql +++ b/impl/pkg/storage/db/postgres/queries/queries.sql @@ -8,4 +8,7 @@ SELECT * FROM pkarr_records WHERE key = $1 LIMIT 1; SELECT * FROM pkarr_records WHERE id > (SELECT id FROM pkarr_records WHERE pkarr_records.key = $1) ORDER BY id ASC LIMIT $2; -- name: ListRecordsFirstPage :many -SELECT * FROM pkarr_records ORDER BY id ASC LIMIT $1; \ No newline at end of file +SELECT * FROM pkarr_records ORDER BY id ASC LIMIT $1; + +-- name: RecordCount :one +SELECT count(*) AS exact_count FROM pkarr_records; diff --git a/impl/pkg/storage/storage.go b/impl/pkg/storage/storage.go index 469b21c2..c8e460f0 100644 --- a/impl/pkg/storage/storage.go +++ b/impl/pkg/storage/storage.go @@ -16,7 +16,8 @@ import ( type Storage interface { WriteRecord(ctx context.Context, record pkarr.Record) error ReadRecord(ctx context.Context, id []byte) (*pkarr.Record, error) - ListRecords(ctx context.Context, nextPageToken []byte, pagesize int) (records []pkarr.Record, nextPage []byte, err error) + ListRecords(ctx context.Context, nextPageToken []byte, pageSize int) (records []pkarr.Record, nextPage []byte, err error) + RecordCount(ctx context.Context) (int, error) Close() error } diff --git a/impl/pkg/storage/storage_test.go b/impl/pkg/storage/storage_test.go index 274c4d13..81f82d19 100644 --- a/impl/pkg/storage/storage_test.go +++ b/impl/pkg/storage/storage_test.go @@ -33,7 +33,7 @@ func TestNewStoragePostgres(t *testing.T) { func TestNewStorageBolt(t *testing.T) { db, err := storage.NewStorage("bolt:///tmp/bolt.db") require.NoError(t, err) - assert.IsType(t, &bolt.BoltDB{}, db) + assert.IsType(t, &bolt.Bolt{}, db) } func TestNewStorageUnsupported(t *testing.T) {