diff --git a/core/logx/writer.go b/core/logx/writer.go index f83dd995a3dc..cb0c121f534e 100644 --- a/core/logx/writer.go +++ b/core/logx/writer.go @@ -5,12 +5,12 @@ import ( "fmt" "io" "log" - "os" "path" "strings" "sync" "sync/atomic" + fatihcolor "github.com/fatih/color" "github.com/zeromicro/go-zero/core/color" ) @@ -76,8 +76,8 @@ func (w *atomicWriter) Swap(v Writer) Writer { } func newConsoleWriter() Writer { - outLog := newLogWriter(log.New(os.Stdout, "", flags)) - errLog := newLogWriter(log.New(os.Stderr, "", flags)) + outLog := newLogWriter(log.New(fatihcolor.Output, "", flags)) + errLog := newLogWriter(log.New(fatihcolor.Error, "", flags)) return &concreteWriter{ infoLog: outLog, errorLog: errLog, diff --git a/core/stat/internal/cpu_linux.go b/core/stat/internal/cpu_linux.go index f9c1729ae2d2..5c8c6521d742 100644 --- a/core/stat/internal/cpu_linux.go +++ b/core/stat/internal/cpu_linux.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "strings" + "sync" "time" "github.com/zeromicro/go-zero/core/iox" @@ -20,10 +21,11 @@ var ( preTotal uint64 quota float64 cores uint64 + initOnce sync.Once ) // if /proc not present, ignore the cpu calculation, like wsl linux -func init() { +func initialize() { cpus, err := cpuSets() if err != nil { logx.Error(err) @@ -69,10 +71,13 @@ func init() { // RefreshCpu refreshes cpu usage and returns. func RefreshCpu() uint64 { + initOnce.Do(initialize) + total, err := totalCpuUsage() if err != nil { return 0 } + system, err := systemCpuUsage() if err != nil { return 0 diff --git a/core/stores/kv/store.go b/core/stores/kv/store.go index f7ee6b42c8e6..fcaf07c1e978 100644 --- a/core/stores/kv/store.go +++ b/core/stores/kv/store.go @@ -110,7 +110,9 @@ type ( Ttl(key string) (int, error) TtlCtx(ctx context.Context, key string) (int, error) Zadd(key string, score int64, value string) (bool, error) + ZaddFloat(key string, score float64, value string) (bool, error) ZaddCtx(ctx context.Context, key string, score int64, value string) (bool, error) + ZaddFloatCtx(ctx context.Context, key string, score float64, value string) (bool, error) Zadds(key string, ps ...redis.Pair) (int64, error) ZaddsCtx(ctx context.Context, key string, ps ...redis.Pair) (int64, error) Zcard(key string) (int, error) @@ -787,13 +789,21 @@ func (cs clusterStore) Zadd(key string, score int64, value string) (bool, error) return cs.ZaddCtx(context.Background(), key, score, value) } +func (cs clusterStore) ZaddFloat(key string, score float64, value string) (bool, error) { + return cs.ZaddFloatCtx(context.Background(), key, score, value) +} + func (cs clusterStore) ZaddCtx(ctx context.Context, key string, score int64, value string) (bool, error) { + return cs.ZaddFloatCtx(ctx, key, float64(score), value) +} + +func (cs clusterStore) ZaddFloatCtx(ctx context.Context, key string, score float64, value string) (bool, error) { node, err := cs.getRedis(key) if err != nil { return false, err } - return node.ZaddCtx(ctx, key, score, value) + return node.ZaddFloatCtx(ctx, key, score, value) } func (cs clusterStore) Zadds(key string, ps ...redis.Pair) (int64, error) { diff --git a/core/stores/kv/store_test.go b/core/stores/kv/store_test.go index 1803d0159dd1..96d0e29bc465 100644 --- a/core/stores/kv/store_test.go +++ b/core/stores/kv/store_test.go @@ -22,7 +22,7 @@ func TestRedis_Decr(t *testing.T) { _, err := store.Decr("a") assert.NotNil(t, err) - runOnCluster(t, func(client Store) { + runOnCluster(func(client Store) { val, err := client.Decr("a") assert.Nil(t, err) assert.Equal(t, int64(-1), val) @@ -37,7 +37,7 @@ func TestRedis_DecrBy(t *testing.T) { _, err := store.Incrby("a", 2) assert.NotNil(t, err) - runOnCluster(t, func(client Store) { + runOnCluster(func(client Store) { val, err := client.Decrby("a", 2) assert.Nil(t, err) assert.Equal(t, int64(-2), val) @@ -52,7 +52,7 @@ func TestRedis_Exists(t *testing.T) { _, err := store.Exists("foo") assert.NotNil(t, err) - runOnCluster(t, func(client Store) { + runOnCluster(func(client Store) { ok, err := client.Exists("a") assert.Nil(t, err) assert.False(t, ok) @@ -68,7 +68,7 @@ func TestRedis_Eval(t *testing.T) { _, err := store.Eval(`redis.call("EXISTS", KEYS[1])`, "key1") assert.NotNil(t, err) - runOnCluster(t, func(client Store) { + runOnCluster(func(client Store) { _, err := client.Eval(`redis.call("EXISTS", KEYS[1])`, "notexist") assert.Equal(t, redis.Nil, err) err = client.Set("key1", "value1") @@ -88,7 +88,7 @@ func TestRedis_Hgetall(t *testing.T) { _, err = store.Hgetall("a") assert.NotNil(t, err) - runOnCluster(t, func(client Store) { + runOnCluster(func(client Store) { assert.Nil(t, client.Hset("a", "aa", "aaa")) assert.Nil(t, client.Hset("a", "bb", "bbb")) vals, err := client.Hgetall("a") @@ -105,7 +105,7 @@ func TestRedis_Hvals(t *testing.T) { _, err := store.Hvals("a") assert.NotNil(t, err) - runOnCluster(t, func(client Store) { + runOnCluster(func(client Store) { assert.Nil(t, client.Hset("a", "aa", "aaa")) assert.Nil(t, client.Hset("a", "bb", "bbb")) vals, err := client.Hvals("a") @@ -119,7 +119,7 @@ func TestRedis_Hsetnx(t *testing.T) { _, err := store.Hsetnx("a", "dd", "ddd") assert.NotNil(t, err) - runOnCluster(t, func(client Store) { + runOnCluster(func(client Store) { assert.Nil(t, client.Hset("a", "aa", "aaa")) assert.Nil(t, client.Hset("a", "bb", "bbb")) ok, err := client.Hsetnx("a", "bb", "ccc") @@ -141,7 +141,7 @@ func TestRedis_HdelHlen(t *testing.T) { _, err = store.Hlen("a") assert.NotNil(t, err) - runOnCluster(t, func(client Store) { + runOnCluster(func(client Store) { assert.Nil(t, client.Hset("a", "aa", "aaa")) assert.Nil(t, client.Hset("a", "bb", "bbb")) num, err := client.Hlen("a") @@ -161,7 +161,7 @@ func TestRedis_HIncrBy(t *testing.T) { _, err := store.Hincrby("key", "field", 3) assert.NotNil(t, err) - runOnCluster(t, func(client Store) { + runOnCluster(func(client Store) { val, err := client.Hincrby("key", "field", 2) assert.Nil(t, err) assert.Equal(t, 2, val) @@ -176,7 +176,7 @@ func TestRedis_Hkeys(t *testing.T) { _, err := store.Hkeys("a") assert.NotNil(t, err) - runOnCluster(t, func(client Store) { + runOnCluster(func(client Store) { assert.Nil(t, client.Hset("a", "aa", "aaa")) assert.Nil(t, client.Hset("a", "bb", "bbb")) vals, err := client.Hkeys("a") @@ -190,7 +190,7 @@ func TestRedis_Hmget(t *testing.T) { _, err := store.Hmget("a", "aa", "bb") assert.NotNil(t, err) - runOnCluster(t, func(client Store) { + runOnCluster(func(client Store) { assert.Nil(t, client.Hset("a", "aa", "aaa")) assert.Nil(t, client.Hset("a", "bb", "bbb")) vals, err := client.Hmget("a", "aa", "bb") @@ -209,7 +209,7 @@ func TestRedis_Hmset(t *testing.T) { }) assert.NotNil(t, err) - runOnCluster(t, func(client Store) { + runOnCluster(func(client Store) { assert.Nil(t, client.Hmset("a", map[string]string{ "aa": "aaa", "bb": "bbb", @@ -225,7 +225,7 @@ func TestRedis_Incr(t *testing.T) { _, err := store.Incr("a") assert.NotNil(t, err) - runOnCluster(t, func(client Store) { + runOnCluster(func(client Store) { val, err := client.Incr("a") assert.Nil(t, err) assert.Equal(t, int64(1), val) @@ -240,7 +240,7 @@ func TestRedis_IncrBy(t *testing.T) { _, err := store.Incrby("a", 2) assert.NotNil(t, err) - runOnCluster(t, func(client Store) { + runOnCluster(func(client Store) { val, err := client.Incrby("a", 2) assert.Nil(t, err) assert.Equal(t, int64(2), val) @@ -267,7 +267,7 @@ func TestRedis_List(t *testing.T) { _, err = store.Lindex("key", 0) assert.NotNil(t, err) - runOnCluster(t, func(client Store) { + runOnCluster(func(client Store) { val, err := client.Lpush("key", "value1", "value2") assert.Nil(t, err) assert.Equal(t, 2, val) @@ -316,7 +316,7 @@ func TestRedis_Persist(t *testing.T) { err = store.Expireat("key", time.Now().Unix()+5) assert.NotNil(t, err) - runOnCluster(t, func(client Store) { + runOnCluster(func(client Store) { ok, err := client.Persist("key") assert.Nil(t, err) assert.False(t, ok) @@ -348,7 +348,7 @@ func TestRedis_Sscan(t *testing.T) { _, err = store.Del(key) assert.NotNil(t, err) - runOnCluster(t, func(client Store) { + runOnCluster(func(client Store) { var list []string for i := 0; i < 1550; i++ { list = append(list, stringx.Randn(i)) @@ -390,7 +390,7 @@ func TestRedis_Set(t *testing.T) { _, err = store.Spop("key") assert.NotNil(t, err) - runOnCluster(t, func(client Store) { + runOnCluster(func(client Store) { num, err := client.Sadd("key", 1, 2, 3, 4) assert.Nil(t, err) assert.Equal(t, 4, num) @@ -434,7 +434,7 @@ func TestRedis_SetGetDel(t *testing.T) { _, err = store.Del("hello") assert.NotNil(t, err) - runOnCluster(t, func(client Store) { + runOnCluster(func(client Store) { err := client.Set("hello", "world") assert.Nil(t, err) val, err := client.Get("hello") @@ -457,7 +457,7 @@ func TestRedis_SetExNx(t *testing.T) { _, err = store.SetnxEx("newhello", "newworld", 5) assert.NotNil(t, err) - runOnCluster(t, func(client Store) { + runOnCluster(func(client Store) { err := client.Setex("hello", "world", 5) assert.Nil(t, err) ok, err := client.Setnx("hello", "newworld") @@ -495,7 +495,7 @@ func TestRedis_Getset(t *testing.T) { _, err := store.GetSet("hello", "world") assert.NotNil(t, err) - runOnCluster(t, func(client Store) { + runOnCluster(func(client Store) { val, err := client.GetSet("hello", "world") assert.Nil(t, err) assert.Equal(t, "", val) @@ -524,7 +524,7 @@ func TestRedis_SetGetDelHashField(t *testing.T) { _, err = store.Hdel("key", "field") assert.NotNil(t, err) - runOnCluster(t, func(client Store) { + runOnCluster(func(client Store) { err := client.Hset("key", "field", "value") assert.Nil(t, err) val, err := client.Hget("key", "field") @@ -587,8 +587,8 @@ func TestRedis_SortedSet(t *testing.T) { }) assert.NotNil(t, err) - runOnCluster(t, func(client Store) { - ok, err := client.Zadd("key", 1, "value1") + runOnCluster(func(client Store) { + ok, err := client.ZaddFloat("key", 1, "value1") assert.Nil(t, err) assert.True(t, ok) ok, err = client.Zadd("key", 2, "value1") @@ -724,7 +724,7 @@ func TestRedis_HyperLogLog(t *testing.T) { _, err = store.Pfcount("key") assert.NotNil(t, err) - runOnCluster(t, func(cluster Store) { + runOnCluster(func(cluster Store) { ok, err := cluster.Pfadd("key", "value") assert.Nil(t, err) assert.True(t, ok) @@ -734,7 +734,7 @@ func TestRedis_HyperLogLog(t *testing.T) { }) } -func runOnCluster(t *testing.T, fn func(cluster Store)) { +func runOnCluster(fn func(cluster Store)) { s1.FlushAll() s2.FlushAll() diff --git a/core/stores/redis/redis.go b/core/stores/redis/redis.go index b31eb874a21f..c4d9c40132eb 100644 --- a/core/stores/redis/redis.go +++ b/core/stores/redis/redis.go @@ -1836,8 +1836,19 @@ func (s *Redis) Zadd(key string, score int64, value string) (bool, error) { return s.ZaddCtx(context.Background(), key, score, value) } +// ZaddFloat is the implementation of redis zadd command. +func (s *Redis) ZaddFloat(key string, score float64, value string) (bool, error) { + return s.ZaddFloatCtx(context.Background(), key, score, value) +} + // ZaddCtx is the implementation of redis zadd command. func (s *Redis) ZaddCtx(ctx context.Context, key string, score int64, value string) ( + val bool, err error) { + return s.ZaddFloatCtx(ctx, key, float64(score), value) +} + +// ZaddFloatCtx is the implementation of redis zadd command. +func (s *Redis) ZaddFloatCtx(ctx context.Context, key string, score float64, value string) ( val bool, err error) { err = s.brk.DoWithAcceptable(func() error { conn, err := getRedis(s) @@ -1846,7 +1857,7 @@ func (s *Redis) ZaddCtx(ctx context.Context, key string, score int64, value stri } v, err := conn.ZAdd(ctx, key, &red.Z{ - Score: float64(score), + Score: score, Member: value, }).Result() if err != nil { diff --git a/core/stores/redis/redis_test.go b/core/stores/redis/redis_test.go index fdf90e71a4da..4f960909f7c6 100644 --- a/core/stores/redis/redis_test.go +++ b/core/stores/redis/redis_test.go @@ -810,7 +810,7 @@ func TestRedis_SetGetDelHashField(t *testing.T) { func TestRedis_SortedSet(t *testing.T) { runOnRedis(t, func(client *Redis) { - ok, err := client.Zadd("key", 1, "value1") + ok, err := client.ZaddFloat("key", 1, "value1") assert.Nil(t, err) assert.True(t, ok) ok, err = client.Zadd("key", 2, "value1") @@ -988,8 +988,8 @@ func TestRedis_SortedSet(t *testing.T) { assert.Equal(t, 0, len(pairs)) _, err = New(client.Addr, badType()).Zrevrank("key", "value") assert.NotNil(t, err) - client.Zadd("second", 2, "aa") - client.Zadd("third", 3, "bbb") + _, _ = client.Zadd("second", 2, "aa") + _, _ = client.Zadd("third", 3, "bbb") val, err = client.Zunionstore("union", &ZStore{ Keys: []string{"second", "third"}, Weights: []float64{1, 2}, @@ -1176,7 +1176,7 @@ func runOnRedis(t *testing.T, fn func(client *Redis)) { } if client != nil { - client.Close() + _ = client.Close() } }() fn(New(s.Addr())) @@ -1198,7 +1198,7 @@ func runOnRedisTLS(t *testing.T, fn func(client *Redis)) { t.Error(err) } if client != nil { - client.Close() + _ = client.Close() } }() fn(New(s.Addr(), WithTLS())) @@ -1214,6 +1214,6 @@ type mockedNode struct { RedisNode } -func (n mockedNode) BLPop(ctx context.Context, timeout time.Duration, keys ...string) *red.StringSliceCmd { +func (n mockedNode) BLPop(_ context.Context, _ time.Duration, _ ...string) *red.StringSliceCmd { return red.NewStringSliceCmd(context.Background(), "foo", "bar") } diff --git a/core/syncx/resourcemanager.go b/core/syncx/resourcemanager.go index e863ba0a3190..f556f924a64e 100644 --- a/core/syncx/resourcemanager.go +++ b/core/syncx/resourcemanager.go @@ -57,8 +57,8 @@ func (manager *ResourceManager) GetResource(key string, create func() (io.Closer } manager.lock.Lock() + defer manager.lock.Unlock() manager.resources[key] = resource - manager.lock.Unlock() return resource, nil }) diff --git a/core/syncx/resourcemanager_test.go b/core/syncx/resourcemanager_test.go index a62b7892c0a5..725b8d13ee77 100644 --- a/core/syncx/resourcemanager_test.go +++ b/core/syncx/resourcemanager_test.go @@ -74,6 +74,12 @@ func TestResourceManager_UseAfterClose(t *testing.T) { return nil, errors.New("fail") }) assert.NotNil(t, err) + + assert.Panics(t, func() { + _, err = manager.GetResource("key", func() (io.Closer, error) { + return &dummyResource{age: 123}, nil + }) + }) } } diff --git a/go.mod b/go.mod index 9fb63a4bb7d8..8c743d3cb0d5 100644 --- a/go.mod +++ b/go.mod @@ -35,7 +35,7 @@ require ( go.uber.org/goleak v1.1.12 golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a golang.org/x/time v0.0.0-20220411224347-583f2d630306 - google.golang.org/grpc v1.48.0 + google.golang.org/grpc v1.49.0 google.golang.org/protobuf v1.28.1 gopkg.in/cheggaaa/pb.v1 v1.0.28 gopkg.in/h2non/gock.v1 v1.1.2 diff --git a/go.sum b/go.sum index a7ae8517ea85..62aeb74981ca 100644 --- a/go.sum +++ b/go.sum @@ -971,8 +971,9 @@ google.golang.org/grpc v1.36.0/go.mod h1:qjiiYl8FncCW8feJPdyg3v6XW24KsRHe+dy9BAG google.golang.org/grpc v1.38.0/go.mod h1:NREThFqKR1f3iQ6oBuvc5LadQuXVGo9rkm5ZGrQdJfM= google.golang.org/grpc v1.41.0/go.mod h1:U3l9uK9J0sini8mHphKoXyaqDA/8VyGnDee1zzIUK6k= google.golang.org/grpc v1.46.2/go.mod h1:vN9eftEi1UMyUsIF80+uQXhHjbXYbm0uXoFCACuMGWk= -google.golang.org/grpc v1.48.0 h1:rQOsyJ/8+ufEDJd/Gdsz7HG220Mh9HAhFHRGnIjda0w= google.golang.org/grpc v1.48.0/go.mod h1:vN9eftEi1UMyUsIF80+uQXhHjbXYbm0uXoFCACuMGWk= +google.golang.org/grpc v1.49.0 h1:WTLtQzmQori5FUH25Pq4WT22oCsv8USpQ+F6rqtsmxw= +google.golang.org/grpc v1.49.0/go.mod h1:ZgQEeidpAuNRZ8iRrlBKXZQP1ghovWIVhdJRyCDK+GI= google.golang.org/protobuf v0.0.0-20200109180630-ec00e32a8dfd/go.mod h1:DFci5gLYBciE7Vtevhsrf46CRTquxDuWsQurQQe4oz8= google.golang.org/protobuf v0.0.0-20200221191635-4d8936d0db64/go.mod h1:kwYJMbMJ01Woi6D6+Kah6886xMZcty6N08ah7+eCXa0= google.golang.org/protobuf v0.0.0-20200228230310-ab0ca4ff8a60/go.mod h1:cfTl7dwQJ+fmap5saPgwCLgHXTUD7jkjRqWcaiX5VyM= diff --git a/readme-cn.md b/readme-cn.md index b77547d58c51..06f87af1690d 100644 --- a/readme-cn.md +++ b/readme-cn.md @@ -289,6 +289,7 @@ go-zero 已被许多公司用于生产部署,接入场景如在线教育、电 >74. 驭势科技 >75. 叮当跳动 >76. Keep +>77. simba innovation 如果贵公司也已使用 go-zero,欢迎在 [登记地址](https://github.com/zeromicro/go-zero/issues/602) 登记,仅仅为了推广,不做其它用途。 diff --git a/readme.md b/readme.md index fba75a92994c..4673efcda226 100644 --- a/readme.md +++ b/readme.md @@ -259,6 +259,8 @@ go-zero enlisted in the [CNCF Cloud Native Landscape](https://landscape.cncf.io/ If you like or are using this project to learn or start your solution, please give it a star. Thanks! +[![Star History Chart](https://api.star-history.com/svg?repos=zeromicro/go-zero&type=Date)](#go-zero) + ## Buy me a coffee Buy Me A Coffee diff --git a/rest/server.go b/rest/server.go index 0d0bad13586e..4c7a30ffaedf 100644 --- a/rest/server.go +++ b/rest/server.go @@ -126,7 +126,7 @@ func WithChain(chn chain.Chain) RunOption { func WithCors(origin ...string) RunOption { return func(server *Server) { server.router.SetNotAllowedHandler(cors.NotAllowedHandler(nil, origin...)) - server.Use(cors.Middleware(nil, origin...)) + server.router = newCorsRouter(server.router, nil, origin...) } } @@ -136,7 +136,7 @@ func WithCustomCors(middlewareFn func(header http.Header), notAllowedFn func(htt origin ...string) RunOption { return func(server *Server) { server.router.SetNotAllowedHandler(cors.NotAllowedHandler(notAllowedFn, origin...)) - server.Use(cors.Middleware(middlewareFn, origin...)) + server.router = newCorsRouter(server.router, middlewareFn, origin...) } } @@ -291,3 +291,19 @@ func validateSecret(secret string) { panic("secret's length can't be less than 8") } } + +type corsRouter struct { + httpx.Router + middleware Middleware +} + +func newCorsRouter(router httpx.Router, headerFn func(http.Header), origins ...string) httpx.Router { + return &corsRouter{ + Router: router, + middleware: cors.Middleware(headerFn, origins...), + } +} + +func (c *corsRouter) ServeHTTP(w http.ResponseWriter, r *http.Request) { + c.middleware(c.Router.ServeHTTP)(w, r) +} diff --git a/rest/server_test.go b/rest/server_test.go index bc99f1259e89..f356ff4a61e9 100644 --- a/rest/server_test.go +++ b/rest/server_test.go @@ -18,6 +18,7 @@ import ( "github.com/zeromicro/go-zero/core/logx" "github.com/zeromicro/go-zero/rest/chain" "github.com/zeromicro/go-zero/rest/httpx" + "github.com/zeromicro/go-zero/rest/internal/cors" "github.com/zeromicro/go-zero/rest/router" ) @@ -515,3 +516,23 @@ func TestServer_WithChain(t *testing.T) { rt.ServeHTTP(httptest.NewRecorder(), req) assert.Equal(t, int32(5), atomic.LoadInt32(&called)) } + +func TestServer_WithCors(t *testing.T) { + var called int32 + middleware := func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + atomic.AddInt32(&called, 1) + next.ServeHTTP(w, r) + }) + } + r := router.NewRouter() + assert.Nil(t, r.Handle(http.MethodOptions, "/", middleware(http.NotFoundHandler()))) + + cr := &corsRouter{ + Router: r, + middleware: cors.Middleware(nil, "*"), + } + req := httptest.NewRequest(http.MethodOptions, "/", nil) + cr.ServeHTTP(httptest.NewRecorder(), req) + assert.Equal(t, int32(0), atomic.LoadInt32(&called)) +} diff --git a/zrpc/config.go b/zrpc/config.go index 2a9595a16cb7..36dd00744ba8 100644 --- a/zrpc/config.go +++ b/zrpc/config.go @@ -19,6 +19,8 @@ type ( // setting 0 means no timeout Timeout int64 `json:",default=2000"` CpuThreshold int64 `json:",default=900,range=[0:1000]"` + // grpc health check switch + Health bool `json:",default=true"` } // A RpcClientConf is a rpc client config. diff --git a/zrpc/internal/rpcserver.go b/zrpc/internal/rpcserver.go index 98ef126fa345..3640877622f6 100644 --- a/zrpc/internal/rpcserver.go +++ b/zrpc/internal/rpcserver.go @@ -16,6 +16,7 @@ type ( rpcServerOptions struct { metrics *stat.Metrics + health bool } rpcServer struct { @@ -74,13 +75,17 @@ func (s *rpcServer) Start(register RegisterFn) error { register(server) // register the health check service - grpc_health_v1.RegisterHealthServer(server, s.health) - s.health.Resume() + if s.health != nil { + grpc_health_v1.RegisterHealthServer(server, s.health) + s.health.Resume() + } // we need to make sure all others are wrapped up, // so we do graceful stop at shutdown phase instead of wrap up phase waitForCalled := proc.AddWrapUpListener(func() { - s.health.Shutdown() + if s.health != nil { + s.health.Shutdown() + } server.GracefulStop() }) defer waitForCalled() @@ -94,3 +99,10 @@ func WithMetrics(metrics *stat.Metrics) ServerOption { options.metrics = metrics } } + +// WithRpcHealth returns a func that sets rpc health switch to a Server. +func WithRpcHealth(health bool) ServerOption { + return func(options *rpcServerOptions) { + options.health = health + } +} diff --git a/zrpc/internal/server.go b/zrpc/internal/server.go index 5a5d56550968..ad405b3458a5 100644 --- a/zrpc/internal/server.go +++ b/zrpc/internal/server.go @@ -35,9 +35,13 @@ type ( ) func newBaseRpcServer(address string, rpcServerOpts *rpcServerOptions) *baseRpcServer { + var h *health.Server + if rpcServerOpts.health { + h = health.NewServer() + } return &baseRpcServer{ address: address, - health: health.NewServer(), + health: h, metrics: rpcServerOpts.metrics, options: []grpc.ServerOption{grpc.KeepaliveParams(keepalive.ServerParameters{ MaxConnectionIdle: defaultConnectionIdleDuration, diff --git a/zrpc/server.go b/zrpc/server.go index c73861844b8d..6e6b1cfc0a43 100644 --- a/zrpc/server.go +++ b/zrpc/server.go @@ -40,6 +40,7 @@ func NewServer(c RpcServerConf, register internal.RegisterFn) (*RpcServer, error metrics := stat.NewMetrics(c.ListenOn) serverOptions := []internal.ServerOption{ internal.WithMetrics(metrics), + internal.WithRpcHealth(c.Health), } if c.HasEtcd() {