diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 4692e228d2d8..f53fe073d4e6 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -29,8 +29,17 @@ jobs: - name: Lint run: | go vet -stdmethods=false $(go list ./...) - go install mvdan.cc/gofumpt@latest - test -z "$(gofumpt -l -extra .)" || echo "Please run 'gofumpt -l -w -extra .'" + + if ! test -z "$(gofmt -l .)"; then + echo "Please run 'gofmt -l -w .'" + exit 1 + fi + + go mod tidy + if ! test -z "$(git status --porcelain)"; then + echo "Please run 'go mod tidy'" + exit 1 + fi - name: Test run: go test -race -coverprofile=coverage.txt -covermode=atomic ./... diff --git a/core/bloom/bloom.go b/core/bloom/bloom.go index bc08df25ca0f..8646b0174b54 100644 --- a/core/bloom/bloom.go +++ b/core/bloom/bloom.go @@ -11,25 +11,27 @@ import ( const ( // for detailed error rate table, see http://pages.cs.wisc.edu/~cao/papers/summary-cache/node8.html // maps as k in the error rate table - maps = 14 - setScript = ` + maps = 14 +) + +var ( + // ErrTooLargeOffset indicates the offset is too large in bitset. + ErrTooLargeOffset = errors.New("too large offset") + setScript = redis.NewScript(` for _, offset in ipairs(ARGV) do redis.call("setbit", KEYS[1], offset, 1) end -` - testScript = ` +`) + testScript = redis.NewScript(` for _, offset in ipairs(ARGV) do if tonumber(redis.call("getbit", KEYS[1], offset)) == 0 then return false end end return true -` +`) ) -// ErrTooLargeOffset indicates the offset is too large in bitset. -var ErrTooLargeOffset = errors.New("too large offset") - type ( // A Filter is a bloom filter. Filter struct { @@ -117,7 +119,7 @@ func (r *redisBitSet) check(offsets []uint) (bool, error) { return false, err } - resp, err := r.store.Eval(testScript, []string{r.key}, args) + resp, err := r.store.ScriptRun(testScript, []string{r.key}, args) if err == redis.Nil { return false, nil } else if err != nil { @@ -147,7 +149,7 @@ func (r *redisBitSet) set(offsets []uint) error { return err } - _, err = r.store.Eval(setScript, []string{r.key}, args) + _, err = r.store.ScriptRun(setScript, []string{r.key}, args) if err == redis.Nil { return nil } diff --git a/core/bloom/bloom_test.go b/core/bloom/bloom_test.go index 35bf15c85e68..6362e39005be 100644 --- a/core/bloom/bloom_test.go +++ b/core/bloom/bloom_test.go @@ -4,13 +4,12 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/zeromicro/go-zero/core/logx" "github.com/zeromicro/go-zero/core/stores/redis/redistest" ) func TestRedisBitSet_New_Set_Test(t *testing.T) { - store, clean, err := redistest.CreateRedis() - assert.Nil(t, err) - defer clean() + store := redistest.CreateRedis(t) bitSet := newRedisBitSet(store, "test_key", 1024) isSetBefore, err := bitSet.check([]uint{0}) @@ -42,9 +41,7 @@ func TestRedisBitSet_New_Set_Test(t *testing.T) { } func TestRedisBitSet_Add(t *testing.T) { - store, clean, err := redistest.CreateRedis() - assert.Nil(t, err) - defer clean() + store := redistest.CreateRedis(t) filter := New(store, "test_key", 64) assert.Nil(t, filter.Add([]byte("hello"))) @@ -53,3 +50,49 @@ func TestRedisBitSet_Add(t *testing.T) { assert.Nil(t, err) assert.True(t, ok) } + +func TestFilter_Exists(t *testing.T) { + store, clean := redistest.CreateRedisWithClean(t) + + rbs := New(store, "test", 64) + _, err := rbs.Exists([]byte{0, 1, 2}) + assert.NoError(t, err) + + clean() + rbs = New(store, "test", 64) + _, err = rbs.Exists([]byte{0, 1, 2}) + assert.Error(t, err) +} + +func TestRedisBitSet_check(t *testing.T) { + store, clean := redistest.CreateRedisWithClean(t) + + rbs := newRedisBitSet(store, "test", 0) + assert.Error(t, rbs.set([]uint{0, 1, 2})) + _, err := rbs.check([]uint{0, 1, 2}) + assert.Error(t, err) + + rbs = newRedisBitSet(store, "test", 64) + _, err = rbs.check([]uint{0, 1, 2}) + assert.NoError(t, err) + + clean() + rbs = newRedisBitSet(store, "test", 64) + _, err = rbs.check([]uint{0, 1, 2}) + assert.Error(t, err) +} + +func TestRedisBitSet_set(t *testing.T) { + logx.Disable() + store, clean := redistest.CreateRedisWithClean(t) + + rbs := newRedisBitSet(store, "test", 0) + assert.Error(t, rbs.set([]uint{0, 1, 2})) + + rbs = newRedisBitSet(store, "test", 64) + assert.NoError(t, rbs.set([]uint{0, 1, 2})) + + clean() + rbs = newRedisBitSet(store, "test", 64) + assert.Error(t, rbs.set([]uint{0, 1, 2})) +} diff --git a/core/limit/periodlimit.go b/core/limit/periodlimit.go index 5250878a25fc..332fe3c95c97 100644 --- a/core/limit/periodlimit.go +++ b/core/limit/periodlimit.go @@ -9,21 +9,6 @@ import ( "github.com/zeromicro/go-zero/core/stores/redis" ) -// to be compatible with aliyun redis, we cannot use `local key = KEYS[1]` to reuse the key -const periodScript = `local limit = tonumber(ARGV[1]) -local window = tonumber(ARGV[2]) -local current = redis.call("INCRBY", KEYS[1], 1) -if current == 1 then - redis.call("expire", KEYS[1], window) -end -if current < limit then - return 1 -elseif current == limit then - return 2 -else - return 0 -end` - const ( // Unknown means not initialized state. Unknown = iota @@ -39,8 +24,25 @@ const ( internalHitQuota = 2 ) -// ErrUnknownCode is an error that represents unknown status code. -var ErrUnknownCode = errors.New("unknown status code") +var ( + // ErrUnknownCode is an error that represents unknown status code. + ErrUnknownCode = errors.New("unknown status code") + + // to be compatible with aliyun redis, we cannot use `local key = KEYS[1]` to reuse the key + periodScript = redis.NewScript(`local limit = tonumber(ARGV[1]) +local window = tonumber(ARGV[2]) +local current = redis.call("INCRBY", KEYS[1], 1) +if current == 1 then + redis.call("expire", KEYS[1], window) +end +if current < limit then + return 1 +elseif current == limit then + return 2 +else + return 0 +end`) +) type ( // PeriodOption defines the method to customize a PeriodLimit. @@ -80,7 +82,7 @@ func (h *PeriodLimit) Take(key string) (int, error) { // TakeCtx requests a permit with context, it returns the permit state. func (h *PeriodLimit) TakeCtx(ctx context.Context, key string) (int, error) { - resp, err := h.limitStore.EvalCtx(ctx, periodScript, []string{h.keyPrefix + key}, []string{ + resp, err := h.limitStore.ScriptRunCtx(ctx, periodScript, []string{h.keyPrefix + key}, []string{ strconv.Itoa(h.quota), strconv.Itoa(h.calcExpireSeconds()), }) diff --git a/core/limit/periodlimit_test.go b/core/limit/periodlimit_test.go index c23c14db4dcb..9ff35ca3e1e2 100644 --- a/core/limit/periodlimit_test.go +++ b/core/limit/periodlimit_test.go @@ -33,9 +33,7 @@ func TestPeriodLimit_RedisUnavailable(t *testing.T) { } func testPeriodLimit(t *testing.T, opts ...PeriodOption) { - store, clean, err := redistest.CreateRedis() - assert.Nil(t, err) - defer clean() + store := redistest.CreateRedis(t) const ( seconds = 1 diff --git a/core/limit/tokenlimit.go b/core/limit/tokenlimit.go index aeb52f9ade36..fe576dd9d0fa 100644 --- a/core/limit/tokenlimit.go +++ b/core/limit/tokenlimit.go @@ -15,10 +15,15 @@ import ( ) const ( - // to be compatible with aliyun redis, we cannot use `local key = KEYS[1]` to reuse the key - // KEYS[1] as tokens_key - // KEYS[2] as timestamp_key - script = `local rate = tonumber(ARGV[1]) + tokenFormat = "{%s}.tokens" + timestampFormat = "{%s}.ts" + pingInterval = time.Millisecond * 100 +) + +// to be compatible with aliyun redis, we cannot use `local key = KEYS[1]` to reuse the key +// KEYS[1] as tokens_key +// KEYS[2] as timestamp_key +var script = redis.NewScript(`local rate = tonumber(ARGV[1]) local capacity = tonumber(ARGV[2]) local now = tonumber(ARGV[3]) local requested = tonumber(ARGV[4]) @@ -45,11 +50,7 @@ end redis.call("setex", KEYS[1], ttl, new_tokens) redis.call("setex", KEYS[2], ttl, now) -return allowed` - tokenFormat = "{%s}.tokens" - timestampFormat = "{%s}.ts" - pingInterval = time.Millisecond * 100 -) +return allowed`) // A TokenLimiter controls how frequently events are allowed to happen with in one second. type TokenLimiter struct { @@ -110,7 +111,7 @@ func (lim *TokenLimiter) reserveN(ctx context.Context, now time.Time, n int) boo return lim.rescueLimiter.AllowN(now, n) } - resp, err := lim.store.EvalCtx(ctx, + resp, err := lim.store.ScriptRunCtx(ctx, script, []string{ lim.tokenKey, diff --git a/core/limit/tokenlimit_test.go b/core/limit/tokenlimit_test.go index 65107d01a386..22829a822fa7 100644 --- a/core/limit/tokenlimit_test.go +++ b/core/limit/tokenlimit_test.go @@ -70,9 +70,7 @@ func TestTokenLimit_Rescue(t *testing.T) { } func TestTokenLimit_Take(t *testing.T) { - store, clean, err := redistest.CreateRedis() - assert.Nil(t, err) - defer clean() + store := redistest.CreateRedis(t) const ( total = 100 @@ -92,9 +90,7 @@ func TestTokenLimit_Take(t *testing.T) { } func TestTokenLimit_TakeBurst(t *testing.T) { - store, clean, err := redistest.CreateRedis() - assert.Nil(t, err) - defer clean() + store := redistest.CreateRedis(t) const ( total = 100 diff --git a/core/stores/cache/cache_test.go b/core/stores/cache/cache_test.go index d46e20cd2486..fef7f716532c 100644 --- a/core/stores/cache/cache_test.go +++ b/core/stores/cache/cache_test.go @@ -112,12 +112,8 @@ func (mc *mockedNode) TakeWithExpireCtx(ctx context.Context, val any, key string func TestCache_SetDel(t *testing.T) { t.Run("test set del", func(t *testing.T) { const total = 1000 - r1, clean1, err := redistest.CreateRedis() - assert.Nil(t, err) - defer clean1() - r2, clean2, err := redistest.CreateRedis() - assert.Nil(t, err) - defer clean2() + r1 := redistest.CreateRedis(t) + r2 := redistest.CreateRedis(t) conf := ClusterConf{ { RedisConf: redis.RedisConf{ @@ -193,9 +189,7 @@ func TestCache_SetDel(t *testing.T) { func TestCache_OneNode(t *testing.T) { const total = 1000 - r, clean, err := redistest.CreateRedis() - assert.Nil(t, err) - defer clean() + r := redistest.CreateRedis(t) conf := ClusterConf{ { RedisConf: redis.RedisConf{ diff --git a/core/stores/cache/cachenode_test.go b/core/stores/cache/cachenode_test.go index 8b2fe92212e8..0a2c358c6607 100644 --- a/core/stores/cache/cachenode_test.go +++ b/core/stores/cache/cachenode_test.go @@ -34,10 +34,8 @@ func init() { func TestCacheNode_DelCache(t *testing.T) { t.Run("del cache", func(t *testing.T) { - store, clean, err := redistest.CreateRedis() - assert.Nil(t, err) + store := redistest.CreateRedis(t) store.Type = redis.ClusterType - defer clean() cn := cacheNode{ rds: store, @@ -84,9 +82,7 @@ func TestCacheNode_DelCache(t *testing.T) { } func TestCacheNode_DelCacheWithErrors(t *testing.T) { - store, clean, err := redistest.CreateRedis() - assert.Nil(t, err) - defer clean() + store := redistest.CreateRedis(t) store.Type = redis.ClusterType cn := cacheNode{ @@ -122,9 +118,7 @@ func TestCacheNode_InvalidCache(t *testing.T) { } func TestCacheNode_SetWithExpire(t *testing.T) { - store, clean, err := redistest.CreateRedis() - assert.Nil(t, err) - defer clean() + store := redistest.CreateRedis(t) cn := cacheNode{ rds: store, @@ -139,14 +133,12 @@ func TestCacheNode_SetWithExpire(t *testing.T) { } func TestCacheNode_Take(t *testing.T) { - store, clean, err := redistest.CreateRedis() - assert.Nil(t, err) - defer clean() + store := redistest.CreateRedis(t) cn := NewNode(store, syncx.NewSingleFlight(), NewStat("any"), errTestNotFound, WithExpiry(time.Second), WithNotFoundExpiry(time.Second)) var str string - err = cn.Take(&str, "any", func(v any) error { + err := cn.Take(&str, "any", func(v any) error { *v.(*string) = "value" return nil }) @@ -174,9 +166,7 @@ func TestCacheNode_TakeBadRedis(t *testing.T) { } func TestCacheNode_TakeNotFound(t *testing.T) { - store, clean, err := redistest.CreateRedis() - assert.Nil(t, err) - defer clean() + store := redistest.CreateRedis(t) cn := cacheNode{ rds: store, @@ -188,7 +178,7 @@ func TestCacheNode_TakeNotFound(t *testing.T) { errNotFound: errTestNotFound, } var str string - err = cn.Take(&str, "any", func(v any) error { + err := cn.Take(&str, "any", func(v any) error { return errTestNotFound }) assert.True(t, cn.IsNotFound(err)) @@ -213,9 +203,7 @@ func TestCacheNode_TakeNotFound(t *testing.T) { } func TestCacheNode_TakeNotFoundButChangedByOthers(t *testing.T) { - store, clean, err := redistest.CreateRedis() - assert.NoError(t, err) - defer clean() + store := redistest.CreateRedis(t) cn := cacheNode{ rds: store, @@ -228,7 +216,7 @@ func TestCacheNode_TakeNotFoundButChangedByOthers(t *testing.T) { } var str string - err = cn.Take(&str, "any", func(v any) error { + err := cn.Take(&str, "any", func(v any) error { store.Set("any", "foo") return errTestNotFound }) @@ -242,9 +230,7 @@ func TestCacheNode_TakeNotFoundButChangedByOthers(t *testing.T) { } func TestCacheNode_TakeWithExpire(t *testing.T) { - store, clean, err := redistest.CreateRedis() - assert.Nil(t, err) - defer clean() + store := redistest.CreateRedis(t) cn := cacheNode{ rds: store, @@ -256,7 +242,7 @@ func TestCacheNode_TakeWithExpire(t *testing.T) { errNotFound: errors.New("any"), } var str string - err = cn.TakeWithExpire(&str, "any", func(v any, expire time.Duration) error { + err := cn.TakeWithExpire(&str, "any", func(v any, expire time.Duration) error { *v.(*string) = "value" return nil }) @@ -269,9 +255,7 @@ func TestCacheNode_TakeWithExpire(t *testing.T) { } func TestCacheNode_String(t *testing.T) { - store, clean, err := redistest.CreateRedis() - assert.Nil(t, err) - defer clean() + store := redistest.CreateRedis(t) cn := cacheNode{ rds: store, @@ -286,9 +270,7 @@ func TestCacheNode_String(t *testing.T) { } func TestCacheValueWithBigInt(t *testing.T) { - store, clean, err := redistest.CreateRedis() - assert.Nil(t, err) - defer clean() + store := redistest.CreateRedis(t) cn := cacheNode{ rds: store, diff --git a/core/stores/redis/redis.go b/core/stores/redis/redis.go index dde4be82f123..75b8b327cb08 100644 --- a/core/stores/redis/redis.go +++ b/core/stores/redis/redis.go @@ -88,6 +88,8 @@ type ( FloatCmd = red.FloatCmd // StringCmd is an alias of redis.StringCmd. StringCmd = red.StringCmd + // Script is an alias of redis.Script. + Script = red.Script ) // New returns a Redis with given options. @@ -146,6 +148,11 @@ func newRedis(addr string, opts ...Option) *Redis { return r } +// NewScript returns a new Script instance. +func NewScript(script string) *Script { + return red.NewScript(script) +} + // BitCount is redis bitcount command implementation. func (s *Redis) BitCount(key string, start, end int64) (int64, error) { return s.BitCountCtx(context.Background(), key, start, end) @@ -1646,6 +1653,25 @@ func (s *Redis) ScriptLoadCtx(ctx context.Context, script string) (string, error return conn.ScriptLoad(ctx, script).Result() } +// ScriptRun is the implementation of *redis.Script run command. +func (s *Redis) ScriptRun(script *Script, keys []string, args ...any) (any, error) { + return s.ScriptRunCtx(context.Background(), script, keys, args...) +} + +// ScriptRunCtx is the implementation of *redis.Script run command. +func (s *Redis) ScriptRunCtx(ctx context.Context, script *Script, keys []string, args ...any) (val any, err error) { + err = s.brk.DoWithAcceptable(func() error { + conn, err := getRedis(s) + if err != nil { + return err + } + + val, err = script.Run(ctx, conn, keys, args...).Result() + return err + }, acceptable) + return +} + // Set is the implementation of redis set command. func (s *Redis) Set(key, value string) error { return s.SetCtx(context.Background(), key, value) diff --git a/core/stores/redis/redis_test.go b/core/stores/redis/redis_test.go index b83463e158cf..5fc00388d800 100644 --- a/core/stores/redis/redis_test.go +++ b/core/stores/redis/redis_test.go @@ -241,6 +241,24 @@ func TestRedis_Eval(t *testing.T) { }) } +func TestRedis_ScriptRun(t *testing.T) { + runOnRedis(t, func(client *Redis) { + sc := NewScript(`redis.call("EXISTS", KEYS[1])`) + sc2 := NewScript(`return redis.call("EXISTS", KEYS[1])`) + _, err := New(client.Addr, badType()).ScriptRun(sc, []string{"notexist"}) + assert.NotNil(t, err) + _, err = client.ScriptRun(sc, []string{"notexist"}) + assert.Equal(t, Nil, err) + err = client.Set("key1", "value1") + assert.Nil(t, err) + _, err = client.ScriptRun(sc, []string{"key1"}) + assert.Equal(t, Nil, err) + val, err := client.ScriptRun(sc2, []string{"key1"}) + assert.Nil(t, err) + assert.Equal(t, int64(1), val) + }) +} + func TestRedis_GeoHash(t *testing.T) { runOnRedis(t, func(client *Redis) { _, err := client.GeoHash("parent", "child1", "child2") diff --git a/core/stores/redis/redislock.go b/core/stores/redis/redislock.go index de66e97e6a33..c740ef3689bd 100644 --- a/core/stores/redis/redislock.go +++ b/core/stores/redis/redislock.go @@ -17,17 +17,20 @@ const ( randomLen = 16 tolerance = 500 // milliseconds millisPerSecond = 1000 - lockCommand = `if redis.call("GET", KEYS[1]) == ARGV[1] then +) + +var ( + lockScript = NewScript(`if redis.call("GET", KEYS[1]) == ARGV[1] then redis.call("SET", KEYS[1], ARGV[1], "PX", ARGV[2]) return "OK" else return redis.call("SET", KEYS[1], ARGV[1], "NX", "PX", ARGV[2]) -end` - delCommand = `if redis.call("GET", KEYS[1]) == ARGV[1] then +end`) + delScript = NewScript(`if redis.call("GET", KEYS[1]) == ARGV[1] then return redis.call("DEL", KEYS[1]) else return 0 -end` +end`) ) // A RedisLock is a redis lock. @@ -59,7 +62,7 @@ func (rl *RedisLock) Acquire() (bool, error) { // AcquireCtx acquires the lock with the given ctx. func (rl *RedisLock) AcquireCtx(ctx context.Context) (bool, error) { seconds := atomic.LoadUint32(&rl.seconds) - resp, err := rl.store.EvalCtx(ctx, lockCommand, []string{rl.key}, []string{ + resp, err := rl.store.ScriptRunCtx(ctx, lockScript, []string{rl.key}, []string{ rl.id, strconv.Itoa(int(seconds)*millisPerSecond + tolerance), }) if err == red.Nil { @@ -87,7 +90,7 @@ func (rl *RedisLock) Release() (bool, error) { // ReleaseCtx releases the lock with the given ctx. func (rl *RedisLock) ReleaseCtx(ctx context.Context) (bool, error) { - resp, err := rl.store.EvalCtx(ctx, delCommand, []string{rl.key}, []string{rl.id}) + resp, err := rl.store.ScriptRunCtx(ctx, delScript, []string{rl.key}, []string{rl.id}) if err != nil { return false, err } diff --git a/core/stores/redis/redistest/redistest.go b/core/stores/redis/redistest/redistest.go index 82a1128c0a40..6c82c22e15a3 100644 --- a/core/stores/redis/redistest/redistest.go +++ b/core/stores/redis/redistest/redistest.go @@ -1,31 +1,20 @@ package redistest import ( - "time" + "testing" "github.com/alicebob/miniredis/v2" - "github.com/zeromicro/go-zero/core/lang" "github.com/zeromicro/go-zero/core/stores/redis" ) // CreateRedis returns an in process redis.Redis. -func CreateRedis() (r *redis.Redis, clean func(), err error) { - mr, err := miniredis.Run() - if err != nil { - return nil, nil, err - } - - return redis.New(mr.Addr()), func() { - ch := make(chan lang.PlaceholderType) - - go func() { - mr.Close() - close(ch) - }() +func CreateRedis(t *testing.T) *redis.Redis { + r, _ := CreateRedisWithClean(t) + return r +} - select { - case <-ch: - case <-time.After(time.Second): - } - }, nil +// CreateRedisWithClean returns an in process redis.Redis and a clean function. +func CreateRedisWithClean(t *testing.T) (r *redis.Redis, clean func()) { + mr := miniredis.RunT(t) + return redis.New(mr.Addr()), mr.Close } diff --git a/rest/handler/timeouthandler.go b/rest/handler/timeouthandler.go index 41fe38f29d42..40340a2efe00 100644 --- a/rest/handler/timeouthandler.go +++ b/rest/handler/timeouthandler.go @@ -183,7 +183,10 @@ func (tw *timeoutWriter) writeHeaderLocked(code int) { func (tw *timeoutWriter) WriteHeader(code int) { tw.mu.Lock() defer tw.mu.Unlock() - tw.writeHeaderLocked(code) + + if !tw.wroteHeader { + tw.writeHeaderLocked(code) + } } func checkWriteHeaderCode(code int) { diff --git a/zrpc/internal/auth/auth_test.go b/zrpc/internal/auth/auth_test.go index 20d6f8ef3f2e..c52fc871ae4e 100644 --- a/zrpc/internal/auth/auth_test.go +++ b/zrpc/internal/auth/auth_test.go @@ -43,9 +43,7 @@ func TestAuthenticator(t *testing.T) { }, } - store, clean, err := redistest.CreateRedis() - assert.Nil(t, err) - defer clean() + store := redistest.CreateRedis(t) for _, test := range tests { t.Run(test.name, func(t *testing.T) { diff --git a/zrpc/internal/serverinterceptors/authinterceptor_test.go b/zrpc/internal/serverinterceptors/authinterceptor_test.go index 587c3d7f7c9a..12d83a096a4b 100644 --- a/zrpc/internal/serverinterceptors/authinterceptor_test.go +++ b/zrpc/internal/serverinterceptors/authinterceptor_test.go @@ -45,9 +45,7 @@ func TestStreamAuthorizeInterceptor(t *testing.T) { }, } - store, clean, err := redistest.CreateRedis() - assert.Nil(t, err) - defer clean() + store := redistest.CreateRedis(t) for _, test := range tests { t.Run(test.name, func(t *testing.T) { @@ -111,9 +109,7 @@ func TestUnaryAuthorizeInterceptor(t *testing.T) { }, } - store, clean, err := redistest.CreateRedis() - assert.Nil(t, err) - defer clean() + store := redistest.CreateRedis(t) for _, test := range tests { t.Run(test.name, func(t *testing.T) {