diff --git a/pkg/storage/redis.go b/pkg/storage/redis.go index 9834a13bb..8799edd33 100644 --- a/pkg/storage/redis.go +++ b/pkg/storage/redis.go @@ -2,11 +2,12 @@ package storage import ( "context" + "encoding/base64" "fmt" - "strconv" "time" "github.com/cenkalti/backoff/v4" + "github.com/goccy/go-json" "github.com/pkg/errors" "github.com/redis/go-redis/extra/redisotel/v9" goredislib "github.com/redis/go-redis/v9" @@ -31,16 +32,16 @@ type RedisDB struct { } func (b *RedisDB) ReadPage(ctx context.Context, namespace string, pageToken string, pageSize int) (map[string][]byte, string, error) { - cursor := uint64(0) + token := new(PageToken) if pageToken != "" { var err error - cursor, err = strconv.ParseUint(pageToken, 10, 64) + token, err = parseToken(pageToken) if err != nil { - return nil, "", errors.Wrap(err, "parsing page token") + return nil, "", err } } - keys, nextCursor, err := readAllKeys(ctx, namespace, b, pageSize, cursor) + keys, nextCursor, offsetFromCursor, err := readAllKeys(ctx, namespace, b, pageSize, token.Cursor, token.OffsetFromCursor) if err != nil { return nil, "", err } @@ -48,7 +49,42 @@ func (b *RedisDB) ReadPage(ctx context.Context, namespace string, pageToken stri if err != nil { return nil, "", err } - return results, nextCursor, nil + nextPageToken := PageToken{ + Cursor: nextCursor, + OffsetFromCursor: offsetFromCursor, + } + encodedToken, err := encodeToken(nextPageToken) + if err != nil { + return nil, "", err + } + return results, encodedToken, nil +} + +type PageToken struct { + Cursor uint64 + OffsetFromCursor int +} + +func parseToken(pageToken string) (*PageToken, error) { + pageTokenData, err := base64.RawURLEncoding.DecodeString(pageToken) + if err != nil { + return nil, errors.Wrap(err, "decoding page token") + } + + var token PageToken + if err := json.Unmarshal(pageTokenData, &token); err != nil { + return nil, errors.Wrap(err, "unmarshalling page token data") + } + + return &token, nil +} + +func encodeToken(token PageToken) (string, error) { + data, err := json.Marshal(token) + if err != nil { + return "", errors.Wrap(err, "marshalling page token") + } + return base64.RawURLEncoding.EncodeToString(data), nil } var _ ServiceStorage = (*RedisDB)(nil) @@ -234,7 +270,7 @@ func (b *RedisDB) Read(ctx context.Context, namespace, key string) ([]byte, erro func (b *RedisDB) ReadPrefix(ctx context.Context, namespace, prefix string) (map[string][]byte, error) { namespacePrefix := getRedisKey(namespace, prefix) - keys, _, err := readAllKeys(ctx, namespacePrefix, b, -1, 0) + keys, _, _, err := readAllKeys(ctx, namespacePrefix, b, -1, 0, 0) if err != nil { return nil, errors.Wrap(err, "read all keys") } @@ -243,7 +279,7 @@ func (b *RedisDB) ReadPrefix(ctx context.Context, namespace, prefix string) (map } func (b *RedisDB) ReadAll(ctx context.Context, namespace string) (map[string][]byte, error) { - keys, _, err := readAllKeys(ctx, namespace, b, -1, 0) + keys, _, _, err := readAllKeys(ctx, namespace, b, -1, 0, 0) if err != nil { return nil, errors.Wrap(err, "read all keys") } @@ -280,7 +316,7 @@ func readAll(ctx context.Context, namespace string, keys []string, b *RedisDB) ( } func (b *RedisDB) ReadAllKeys(ctx context.Context, namespace string) ([]string, error) { - keys, _, err := readAllKeys(ctx, namespace, b, -1, 0) + keys, _, _, err := readAllKeys(ctx, namespace, b, -1, 0, 0) if err != nil { return nil, err } @@ -300,38 +336,46 @@ func (b *RedisDB) ReadAllKeys(ctx context.Context, namespace string) ([]string, // NOTE: When passing pageSize == -1, **all** items are returns. Exercise caution regarding memory limits. Always // prefer to set the pageSize. -// TODO: Remove all calls that set pageSize to -1. https://github.com/TBD54566975/ssi-service/issues/525 -func readAllKeys(ctx context.Context, namespace string, b *RedisDB, pageSize int, cursor uint64) ([]string, string, error) { - - var allKeys []string +func readAllKeys(ctx context.Context, namespace string, b *RedisDB, pageSize int, cursor uint64, offset int) ([]string, uint64, int, error) { - var nextCursor uint64 - var err error var keys []string - scanCount := RedisScanBatchSize - if pageSize != -1 { - scanCount = min(RedisScanBatchSize, pageSize) - } - for pageSize == -1 || (len(allKeys) < pageSize) { - keys, nextCursor, err = b.db.Scan(ctx, cursor, namespace+"*", int64(scanCount)).Result() + var scannedKeys []string + var err error + nextCursor := cursor + + for { + scannedKeys, nextCursor, err = b.db.Scan(ctx, nextCursor, namespace+"*", int64(RedisScanBatchSize)).Result() if err != nil { - return nil, "", errors.Wrap(err, "scan error") + return keys, 0, 0, err } - allKeys = append(allKeys, keys...) - - if nextCursor == 0 { + if len(scannedKeys) == 0 { break } - cursor = nextCursor - } + // Apply offset + if offset > 0 { + if offset >= len(scannedKeys) { + // Offset past end of results + offset -= len(scannedKeys) + continue + } + + scannedKeys = scannedKeys[offset:] + offset = 0 + } + + // Append scanned keys + keys = append(keys, scannedKeys...) - var nextCursorToReturn string - if nextCursor != 0 { - nextCursorToReturn = strconv.FormatUint(nextCursor, 10) + // Break if we have enough keys + if len(keys) >= pageSize && pageSize != -1 { + keys = keys[:pageSize] + break + } } - return allKeys, nextCursorToReturn, nil + + return keys, nextCursor, offset, nil } func min(l int, r int) int { @@ -360,7 +404,7 @@ func (b *RedisDB) Delete(ctx context.Context, namespace, key string) error { } func (b *RedisDB) DeleteNamespace(ctx context.Context, namespace string) error { - keys, _, err := readAllKeys(ctx, namespace, b, -1, 0) + keys, _, _, err := readAllKeys(ctx, namespace, b, -1, 0, 0) if err != nil { return errors.Wrap(err, "read all keys") }