This repository has been archived by the owner on Dec 12, 2024. It is now read-only.

Fix Redis pagination issue in readAllKeys function #576

Open
wants to merge 8 commits into base: main
17 changes: 11 additions & 6 deletions pkg/storage/redis.go
@@ -300,7 +300,6 @@ func (b *RedisDB) ReadAllKeys(ctx context.Context, namespace string) ([]string,

// NOTE: When passing pageSize == -1, **all** items are returned. Exercise caution regarding memory limits. Always
// prefer to set the pageSize.
// TODO: Remove all calls that set pageSize to -1. https://github.com/TBD54566975/ssi-service/issues/525
func readAllKeys(ctx context.Context, namespace string, b *RedisDB, pageSize int, cursor uint64) ([]string, string, error) {

var allKeys []string
@@ -312,15 +311,21 @@ func readAllKeys(ctx context.Context, namespace string, b *RedisDB, pageSize int
if pageSize != -1 {
scanCount = min(RedisScanBatchSize, pageSize)
}
for pageSize == -1 || (len(allKeys) < pageSize) {
// Scan keys starting at cursor until the end or until we have enough keys
for {
keys, nextCursor, err = b.db.Scan(ctx, cursor, namespace+"*", int64(scanCount)).Result()
if err != nil {
return nil, "", errors.Wrap(err, "scan error")
}

allKeys = append(allKeys, keys...)

if nextCursor == 0 {
// Append keys one by one to ensure we don't exceed the page size
for _, key := range keys {
Contributor

Could this be simplified by using b.db.Scan(ctx, cursor, namespace+"*", int64(scanCount)).Iterator() above?

(as it happens, I also think there might be an implementation error right now, where only a subset of keys are returned)

Author

Yes @andresuribe87, it's possible, but ...
While the iterator provides a more convenient way to iterate over the keys returned by the Scan method, it does not provide a way to retrieve the current cursor value. As a result, a call to the Scan method is still necessary to retrieve the next cursor value, which is what lets us keep the same function signature as the original implementation (it returns the next cursor value as a string).
I'm not sure that's good practice: calling the Scan method again after using the iterator, just to retrieve the next cursor value, is redundant and inefficient. A more efficient approach would be to use only the Scan method, without an iterator, which avoids that extra call.

I see these options:

  1. Keep the function as it is using the Scan.
  2. Use Iterator and then call Scan to get the cursor.
  3. Analyse whether changing the function's signature is a valid option.

If we change the function signature, we can support pagination without a cursor by using an offset:

func readAllKeys(ctx context.Context, namespace string, b *RedisDB, pageSize int, pageNum int) ([]string, error) {
	var allKeys []string
	offset := (pageNum - 1) * pageSize
	// Walk every key in the namespace and keep only those that fall within the requested page.
	iter := b.db.Scan(ctx, 0, namespace+"*", 0).Iterator()
	for i := 0; iter.Next(ctx); i++ {
		if i >= offset && i < offset+pageSize {
			allKeys = append(allKeys, iter.Val())
		}
	}
	if err := iter.Err(); err != nil {
		return nil, errors.Wrap(err, "scan error")
	}
	return allKeys, nil
}

Contributor

it does not provide a way to retrieve the current cursor value.

Ah yes, excellent point.

Changing the function's signature seems like the best approach to me. I suggest keeping the cursor, and adding offsetFromCursor int. It's important to keep the cursor because it's propagated from the API.
You'll likely also need to change the way in which the token is parsed in

if pageToken != "" {
var err error
cursor, err = strconv.ParseUint(pageToken, 10, 64)
if err != nil {
return nil, "", errors.Wrap(err, "parsing page token")
}
}
, and how it's encoded in https://github.com/TBD54566975/ssi-service/blob/c8541a8597ab1cd8ad1d650d7d682f58e5119ad5/pkg/storage/redis.go#L332-L335C1
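Roughly, that encoding is just the cursor formatted back into a decimal string, something like the sketch below (see the linked lines for the exact code):

// Sketch of the current cursor encoding, mirroring the ParseUint on the read side;
// the actual lines in redis.go may differ slightly.
nextPageToken := ""
if nextCursor != 0 {
	nextPageToken = strconv.FormatUint(nextCursor, 10)
}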

Author

Since it’s important to keep the cursor (it’s propagated from the API), an approach that doesn’t support cursor-based pagination may not be suitable. It’s not possible to use the Iterator method to implement pagination in the same way as the original readAllKeys. As I said previously, one possible solution would be to modify readAllKeys to paginate with an offset and limit instead of a cursor, but that would require removing the cursor.

I considered the possibility of manually calculating the cursor value inside the readAllKeys function while using the Iterator method. However, this is not possible because the cursor value is an internal implementation detail of the Redis SCAN command, and its value is determined by the Redis server based on the current state of the key space. The only way to retrieve the updated cursor value is to use the Scan method, which returns the next cursor value along with the keys.

Hence, using Iterator() and keeping the cursor value updated is not a possibility.

Contributor

Hence, using Iterator() and keeping the cursor value updated is not a possibility

I agree. I don't think I was clear with my suggestion, my apologies. What I'm proposing is to keep cursor as a parameter and add an additional parameter called offsetFromCursor. Scan would still be used, instead of Iterator. The function would then populate the result by making a Scan call from the given cursor and only adding elements whose index is greater than or equal to offsetFromCursor. Does that make sense?

Author

Okay, so we can add a new field to the RedisDB struct:

type RedisDB struct {
	db               *goredislib.Client
	offsetFromCursor uint64
}

This will be used as the new parameter.

In readAllKeys we'll start appending keys once i >= int(b.offsetFromCursor).
We can also have a new method to update the offsetFromCursor:

func (b *RedisDB) SetOffSetFromCursor(offset uint64) {
	b.offsetFromCursor = offset
}

However, this change will impact all services that use this package. The offsetFromCursor needs to be updated, but I'm not sure where.

Contributor

The parameter offsetFromCursor is something that changes with every request, so I would expect it to be a parameter to the readAllKeys function. Can you elaborate on the advantages of putting it in the RedisDB struct? Keep in mind that the RedisDB struct is long-lived.

What I was imagining is something where the signature of the function becomes the following:

func readAllKeys(ctx context.Context, namespace string, b *RedisDB, pageSize int, cursor uint64, offsetFromCursor int) (allKeys []string, nextCursor uint64, nextOffsetFromCursor int, err error) {

The implementation should add up to pageSize elements to allKeys by:

  1. calling Scan with the given cursor value
  2. adding elements from the above result only after the offsetFromCursor index.

When there are enough elements returned from Scan to fill allKeys up to the desired pageSize, nextOffsetFromCursor should advance by pageSize. Otherwise, keep iterating over the results from the DB until that condition is met.

As far as passing in the offsetFromCursor parameter from the callers, below is an example of what I was thinking:

func (b *RedisDB) ReadPage(ctx context.Context, namespace string, pageToken string, pageSize int) (map[string][]byte, string, error) {
	token := new(PageToken)
	if pageToken != "" {
		var err error
		token, err = parseToken(pageToken)
		if err != nil {
			return nil, "", err
		}
	}

	keys, nextCursor, offsetFromCursor, err := readAllKeys(ctx, namespace, b, pageSize, token.Cursor, token.OffsetFromCursor)
	if err != nil {
		return nil, "", err
	}
	results, err := readAll(ctx, namespace, keys, b)
	if err != nil {
		return nil, "", err
	}
	nextPageToken := PageToken{
		Cursor:           nextCursor,
		OffsetFromCursor: offsetFromCursor,
	}
	encodedToken, err := encodeToken(nextPageToken)
	if err != nil {
		return nil, "", err
	}
	return results, encodedToken, nil
}

type PageToken struct {
	Cursor           uint64
	OffsetFromCursor int
}

func parseToken(pageToken string) (*PageToken, error) {
	pageTokenData, err := base64.RawURLEncoding.DecodeString(pageToken)
	if err != nil {
		return nil, errors.Wrap(err, "decoding page token")
	}

	var token PageToken
	if err := json.Unmarshal(pageTokenData, &token); err != nil {
		return nil, errors.Wrap(err, "unmarshalling page token data")
	}

	return &token, nil
}

func encodeToken(token PageToken) (string, error) {
	data, err := json.Marshal(token)
	if err != nil {
		return "", errors.Wrap(err, "marshalling page token")
	}
	return base64.RawURLEncoding.EncodeToString(data), nil
}

Does that all make sense?

Author

Okay, so something like this should do the job:

func readAllKeys(ctx context.Context, namespace string, b *RedisDB, pageSize int, cursor uint64, offsetFromCursor int) (allKeys []string, nextCursor uint64, nextOffsetFromCursor int, err error) {
	allKeys = []string{}

	var keys []string
	scanCount := RedisScanBatchSize
	if pageSize != -1 {
		scanCount = min(RedisScanBatchSize, pageSize)
	}

	for {
		keys, nextCursor, err = b.db.Scan(ctx, cursor, namespace+"*", int64(scanCount)).Result()
		if err != nil {
			return nil, 0, 0, errors.Wrap(err, "scan error")
		}

		// Skip the keys already returned for this cursor, then fill the page.
		for i := offsetFromCursor; i < len(keys); i++ {
			allKeys = append(allKeys, keys[i])
			if pageSize != -1 && len(allKeys) >= pageSize {
				if i+1 < len(keys) {
					// Page filled mid-batch: the next page must re-scan from the same
					// cursor and skip the keys that were already returned.
					return allKeys, cursor, i + 1, nil
				}
				// The batch was consumed exactly: continue from the next cursor with no offset.
				return allKeys, nextCursor, 0, nil
			}
		}

		if nextCursor == 0 {
			// Reached the end of the keyspace.
			break
		}

		cursor = nextCursor
		offsetFromCursor = 0 // Reset offset when advancing to the next cursor
	}

	return allKeys, 0, 0, nil
}

@andresuribe87 what do you think?

Contributor

I think so, yes! This is awesome, thanks so much @radurobot

Just a note: I don't trust my eyes as much as I would trust tests :)
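
Something like the sketch below is the kind of test I have in mind; it's only a sketch that assumes a miniredis-backed client and the proposed signature, and the import path/alias, key prefix, and counts are just illustrative:

package storage

import (
	"context"
	"fmt"
	"testing"

	"github.com/alicebob/miniredis/v2"
	goredislib "github.com/redis/go-redis/v9" // assumed path/alias; match whatever the package already imports
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
)

func TestReadAllKeysPaginatesWithoutSkippingKeys(t *testing.T) {
	ctx := context.Background()
	mr := miniredis.RunT(t)
	client := goredislib.NewClient(&goredislib.Options{Addr: mr.Addr()})
	b := &RedisDB{db: client}

	// Seed more keys than one page so pagination has to span several calls.
	namespace := "testns"
	var want []string
	for i := 0; i < 25; i++ {
		key := fmt.Sprintf("%s-key-%03d", namespace, i)
		require.NoError(t, client.Set(ctx, key, "value", 0).Err())
		want = append(want, key)
	}

	// Page through with a small page size, collecting everything that comes back.
	var got []string
	cursor, offset := uint64(0), 0
	for {
		keys, nextCursor, nextOffset, err := readAllKeys(ctx, namespace, b, 10, cursor, offset)
		require.NoError(t, err)
		assert.LessOrEqual(t, len(keys), 10)
		got = append(got, keys...)

		if nextCursor == 0 && nextOffset == 0 {
			break
		}
		cursor, offset = nextCursor, nextOffset
	}

	// Every seeded key should come back exactly once, however SCAN batched them.
	assert.ElementsMatch(t, want, got)
}

If the offset handling regresses, ElementsMatch will point at exactly which keys were dropped or duplicated.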

if pageSize != -1 && len(allKeys) >= pageSize {
break
}
allKeys = append(allKeys, key)
}
// If we have enough keys or we reached the end, break
if nextCursor == 0 || (pageSize != -1 && len(allKeys) >= pageSize) {
break
}
