diff --git a/internal/ratelimit/redis_slide_window.go b/internal/ratelimit/redis_slide_window.go index 2e06ae9..3e81890 100644 --- a/internal/ratelimit/redis_slide_window.go +++ b/internal/ratelimit/redis_slide_window.go @@ -17,8 +17,10 @@ package ratelimit import ( "context" _ "embed" + "fmt" "time" + "github.com/google/uuid" "github.com/redis/go-redis/v9" ) @@ -38,6 +40,10 @@ type RedisSlidingWindowLimiter struct { } func (r *RedisSlidingWindowLimiter) Limit(ctx context.Context, key string) (bool, error) { + uid, err := uuid.NewUUID() + if err != nil { + return false, fmt.Errorf("generate uuid failed: %w", err) + } return r.Cmd.Eval(ctx, luaSlideWindow, []string{key}, - r.Interval.Milliseconds(), r.Rate, time.Now().UnixMilli()).Bool() + r.Interval.Milliseconds(), r.Rate, time.Now().UnixMilli(), uid.String()).Bool() } diff --git a/internal/ratelimit/redis_slide_window_test.go b/internal/ratelimit/redis_slide_window_test.go index 6d303a8..f73eb3b 100644 --- a/internal/ratelimit/redis_slide_window_test.go +++ b/internal/ratelimit/redis_slide_window_test.go @@ -82,3 +82,33 @@ func initRedis() redis.Cmdable { }) return redisClient } + +func TestRedisSlidingWindowLimiter(t *testing.T) { + r := &RedisSlidingWindowLimiter{ + Cmd: initRedis(), + Interval: time.Second, + Rate: 1200, + } + var ( + total = 1500 // 总请求数 + succCount int // 成功请求数 + limitCount int // 被限流的请求数 + ) + start := time.Now() + for i := 0; i < total; i++ { + limit, err := r.Limit(context.Background(), "test") + if err != nil { + t.Fatalf("limit error: %v", err) + return + } + if limit { + limitCount++ + continue + } + succCount++ + } + end := time.Now() + t.Logf("开始时间: %v", start.Format(time.StampMilli)) + t.Logf("结束时间: %v", end.Format(time.StampMilli)) + t.Logf("total: %d, succ: %d, limited: %d", total, succCount, limitCount) +} diff --git a/internal/ratelimit/slide_window.lua b/internal/ratelimit/slide_window.lua index 522f1be..a933872 100644 --- a/internal/ratelimit/slide_window.lua +++ b/internal/ratelimit/slide_window.lua @@ -5,6 +5,9 @@ local window = tonumber(ARGV[1]) -- 阈值 local threshold = tonumber(ARGV[2]) local now = tonumber(ARGV[3]) +-- 唯一ID, 用于解决同一时间内多个请求只统计一次的问题 +-- SEE: issue #27 +local uid = ARGV[4] -- 窗口的起始时间 local min = now - window @@ -16,7 +19,7 @@ if cnt >= threshold then return "true" else -- 把 score 和 member 都设置成 now - redis.call('ZADD', key, now, now) + redis.call('ZADD', key, now, now .. uid) redis.call('PEXPIRE', key, window) return "false" end \ No newline at end of file