Skip to content
This repository has been archived by the owner on Dec 12, 2024. It is now read-only.

Fix Redis pagination issue in readAllKeys function #576

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
108 changes: 76 additions & 32 deletions pkg/storage/redis.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,12 @@ package storage

import (
"context"
"encoding/base64"
"fmt"
"strconv"
"time"

"github.com/cenkalti/backoff/v4"
"github.com/goccy/go-json"
"github.com/pkg/errors"
"github.com/redis/go-redis/extra/redisotel/v9"
goredislib "github.com/redis/go-redis/v9"
Expand All @@ -31,24 +32,59 @@ type RedisDB struct {
}

func (b *RedisDB) ReadPage(ctx context.Context, namespace string, pageToken string, pageSize int) (map[string][]byte, string, error) {
cursor := uint64(0)
token := new(PageToken)
if pageToken != "" {
var err error
cursor, err = strconv.ParseUint(pageToken, 10, 64)
token, err = parseToken(pageToken)
if err != nil {
return nil, "", errors.Wrap(err, "parsing page token")
return nil, "", err
}
}

keys, nextCursor, err := readAllKeys(ctx, namespace, b, pageSize, cursor)
keys, nextCursor, offsetFromCursor, err := readAllKeys(ctx, namespace, b, pageSize, token.Cursor, token.OffsetFromCursor)
if err != nil {
return nil, "", err
}
results, err := readAll(ctx, namespace, keys, b)
if err != nil {
return nil, "", err
}
return results, nextCursor, nil
nextPageToken := PageToken{
Cursor: nextCursor,
OffsetFromCursor: offsetFromCursor,
}
encodedToken, err := encodeToken(nextPageToken)
if err != nil {
return nil, "", err
}
return results, encodedToken, nil
}

type PageToken struct {
Cursor uint64
OffsetFromCursor int
}

func parseToken(pageToken string) (*PageToken, error) {
pageTokenData, err := base64.RawURLEncoding.DecodeString(pageToken)
if err != nil {
return nil, errors.Wrap(err, "decoding page token")
}

var token PageToken
if err := json.Unmarshal(pageTokenData, &token); err != nil {
return nil, errors.Wrap(err, "unmarshalling page token data")
}

return &token, nil
}

func encodeToken(token PageToken) (string, error) {
data, err := json.Marshal(token)
if err != nil {
return "", errors.Wrap(err, "marshalling page token")
}
return base64.RawURLEncoding.EncodeToString(data), nil
}

var _ ServiceStorage = (*RedisDB)(nil)
Expand Down Expand Up @@ -234,7 +270,7 @@ func (b *RedisDB) Read(ctx context.Context, namespace, key string) ([]byte, erro
func (b *RedisDB) ReadPrefix(ctx context.Context, namespace, prefix string) (map[string][]byte, error) {
namespacePrefix := getRedisKey(namespace, prefix)

keys, _, err := readAllKeys(ctx, namespacePrefix, b, -1, 0)
keys, _, _, err := readAllKeys(ctx, namespacePrefix, b, -1, 0, 0)
if err != nil {
return nil, errors.Wrap(err, "read all keys")
}
Expand All @@ -243,7 +279,7 @@ func (b *RedisDB) ReadPrefix(ctx context.Context, namespace, prefix string) (map
}

func (b *RedisDB) ReadAll(ctx context.Context, namespace string) (map[string][]byte, error) {
keys, _, err := readAllKeys(ctx, namespace, b, -1, 0)
keys, _, _, err := readAllKeys(ctx, namespace, b, -1, 0, 0)
if err != nil {
return nil, errors.Wrap(err, "read all keys")
}
Expand Down Expand Up @@ -280,7 +316,7 @@ func readAll(ctx context.Context, namespace string, keys []string, b *RedisDB) (
}

func (b *RedisDB) ReadAllKeys(ctx context.Context, namespace string) ([]string, error) {
keys, _, err := readAllKeys(ctx, namespace, b, -1, 0)
keys, _, _, err := readAllKeys(ctx, namespace, b, -1, 0, 0)
if err != nil {
return nil, err
}
Expand All @@ -300,38 +336,46 @@ func (b *RedisDB) ReadAllKeys(ctx context.Context, namespace string) ([]string,

// NOTE: When passing pageSize == -1, **all** items are returns. Exercise caution regarding memory limits. Always
// prefer to set the pageSize.
// TODO: Remove all calls that set pageSize to -1. https://github.com/TBD54566975/ssi-service/issues/525
func readAllKeys(ctx context.Context, namespace string, b *RedisDB, pageSize int, cursor uint64) ([]string, string, error) {

var allKeys []string
func readAllKeys(ctx context.Context, namespace string, b *RedisDB, pageSize int, cursor uint64, offset int) ([]string, uint64, int, error) {

var nextCursor uint64
var err error
var keys []string
scanCount := RedisScanBatchSize
if pageSize != -1 {
scanCount = min(RedisScanBatchSize, pageSize)
}
for pageSize == -1 || (len(allKeys) < pageSize) {
keys, nextCursor, err = b.db.Scan(ctx, cursor, namespace+"*", int64(scanCount)).Result()
var scannedKeys []string
var err error
nextCursor := cursor

for {
scannedKeys, nextCursor, err = b.db.Scan(ctx, nextCursor, namespace+"*", int64(RedisScanBatchSize)).Result()
if err != nil {
return nil, "", errors.Wrap(err, "scan error")
return keys, 0, 0, err
}

allKeys = append(allKeys, keys...)

if nextCursor == 0 {
if len(scannedKeys) == 0 {
break
}

cursor = nextCursor
}
// Apply offset
if offset > 0 {
if offset >= len(scannedKeys) {
// Offset past end of results
offset -= len(scannedKeys)
continue
}

scannedKeys = scannedKeys[offset:]
offset = 0
}

// Append scanned keys
keys = append(keys, scannedKeys...)

var nextCursorToReturn string
if nextCursor != 0 {
nextCursorToReturn = strconv.FormatUint(nextCursor, 10)
// Break if we have enough keys
if len(keys) >= pageSize && pageSize != -1 {
keys = keys[:pageSize]
break
}
}
return allKeys, nextCursorToReturn, nil

return keys, nextCursor, offset, nil
}

func min(l int, r int) int {
Expand Down Expand Up @@ -360,7 +404,7 @@ func (b *RedisDB) Delete(ctx context.Context, namespace, key string) error {
}

func (b *RedisDB) DeleteNamespace(ctx context.Context, namespace string) error {
keys, _, err := readAllKeys(ctx, namespace, b, -1, 0)
keys, _, _, err := readAllKeys(ctx, namespace, b, -1, 0, 0)
if err != nil {
return errors.Wrap(err, "read all keys")
}
Expand Down