diff --git a/cmd/backfill-index/main.go b/cmd/backfill-index/main.go index 6169b2a57..6abc137e7 100644 --- a/cmd/backfill-index/main.go +++ b/cmd/backfill-index/main.go @@ -38,6 +38,7 @@ import ( "log" "os" "os/signal" + "strconv" "strings" "syscall" "time" @@ -90,6 +91,10 @@ const ( providerMySQL ) +const ( + redisLastProcessedIndexKey = "last_processed_index" +) + type indexClient interface { idempotentAddToIndex(ctx context.Context, key, value string) error } @@ -127,6 +132,7 @@ var ( redisPassword = flag.String("redis-password", "", "Password for Redis authentication") redisEnableTLS = flag.Bool("redis-enable-tls", false, "Enable TLS for Redis client") redisInsecureSkipVerify = flag.Bool("redis-insecure-skip-verify", false, "Whether to skip TLS verification for Redis client or not") + enableRedisIndexResume = flag.Bool("enable-redis-index-resume", false, "Enable resuming from the last processed index stored in Redis. When enabled, the '--start' flag becomes optional for the Redis provider.") mysqlDSN = flag.String("mysql-dsn", "", "MySQL Data Source Name") startIndex = flag.Int("start", -1, "First index to backfill") endIndex = flag.Int("end", -1, "Last index to backfill") @@ -158,7 +164,7 @@ func main() { if *mysqlDSN != "" { provider = providerMySQL } - if *redisHostname != "" || *redisPort != "" || *redisPassword != "" { + if *redisHostname != "" || *redisPort != "" || *redisPassword != "" || *enableRedisIndexResume { provider = providerRedis } if provider == providerUnset { @@ -172,7 +178,10 @@ func main() { log.Fatal("Redis port must be set") } } - if *startIndex == -1 { + if *enableRedisIndexResume && *startIndex != -1 { + log.Fatal("--enable-redis-index-resume and --start cannot be set simultaneously") + } + if *startIndex == -1 && !*enableRedisIndexResume { log.Fatal("start must be set to >=0") } if *endIndex == -1 { @@ -276,6 +285,24 @@ func populate(indexClient indexClient, rekorClient *rekorclient.Rekor) (err erro } }() + var lastFilled int + if *enableRedisIndexResume && !*dryRun { + redisClient, ok := indexClient.(*redisClient) + if !ok { + return fmt.Errorf("enableRedisIndexResume is only supported with Redis backend") + } + lastFilled, err = redisClient.getLastFilledIndex(ctx) + if err != nil { + return fmt.Errorf("failed to retrieve last filled index: %v", err) + } + if lastFilled == -1 { + log.Printf("%s not found, starting from index 0", redisLastProcessedIndexKey) + *startIndex = 0 + } else { + *startIndex = lastFilled + 1 // Start from the next index + } + } + for i := *startIndex; i <= *endIndex; i++ { index := i // capture loop variable for closure group.Go(func() error { @@ -345,6 +372,18 @@ func populate(indexClient indexClient, rekorClient *rekorclient.Rekor) (err erro return nil }) } + + if *enableRedisIndexResume && !*dryRun { + redisClient, ok := indexClient.(*redisClient) + if !ok { + return fmt.Errorf("--enable-redis-index-resume is only supported with Redis backend") + } + if err := redisClient.setLastFilledIndex(ctx, *endIndex); err != nil { + return fmt.Errorf("failed to set last filled index: %v", err) + } + fmt.Printf("Last filled index updated to %d\n", *endIndex) + } + err = group.Wait() if err != nil { return fmt.Errorf("error running backfill: %v", err) @@ -393,3 +432,22 @@ func (c *mysqlClient) idempotentAddToIndex(ctx context.Context, key, value strin _, err := c.client.NamedExecContext(ctx, mysqlWriteStmt, map[string]any{"key": key, "uuid": value}) return err } + +func (c *redisClient) getLastFilledIndex(ctx context.Context) (int, error) { + val, err := c.client.Get(ctx, redisLastProcessedIndexKey).Result() + if err != nil { + if err == redis.Nil { + return -1, nil // No index has been filled yet + } + return 0, fmt.Errorf("failed to get last filled index from Redis: %w", err) + } + index, err := strconv.Atoi(val) + if err != nil { + return 0, fmt.Errorf("invalid last filled index value in Redis: %w", err) + } + return index, nil +} + +func (c *redisClient) setLastFilledIndex(ctx context.Context, index int) error { + return c.client.Set(ctx, redisLastProcessedIndexKey, index, 0).Err() +}