From 3477b0309b81fa038bc3d0f4cafd0d65ea0eb223 Mon Sep 17 00:00:00 2001 From: Kyle Xiao Date: Mon, 21 Oct 2024 18:00:40 +0800 Subject: [PATCH 1/6] feat: response support hijack for upgrade conn (#1214) --- pkg/protocol/http1/client.go | 58 ++++++++++++++++++++++++++---- pkg/protocol/http1/client_test.go | 59 +++++++++++++++++++++++++++++++ pkg/protocol/response.go | 20 +++++++++++ pkg/protocol/response_test.go | 19 ++++++++++ 4 files changed, 150 insertions(+), 6 deletions(-) diff --git a/pkg/protocol/http1/client.go b/pkg/protocol/http1/client.go index f7b04ba25..149029db0 100644 --- a/pkg/protocol/http1/client.go +++ b/pkg/protocol/http1/client.go @@ -48,6 +48,7 @@ import ( "errors" "io" "net" + "runtime" "strings" "sync" "sync/atomic" @@ -689,22 +690,27 @@ func (c *HostClient) doNonNilReqResp(req *protocol.Request, resp *protocol.Respo return nil }) } + zr.Release() //nolint:errcheck if err != nil { - zr.Release() //nolint:errcheck c.closeConn(cc) // Don't retry in case of ErrBodyTooLarge since we will just get the same again. retry := !errors.Is(err, errs.ErrBodyTooLarge) return retry, err } - - zr.Release() //nolint:errcheck - shouldCloseConn = resetConnection || req.ConnectionClose() || resp.ConnectionClose() + if resp.Header.StatusCode() == consts.StatusSwitchingProtocols && + bytes.Equal(resp.Header.Peek(consts.HeaderConnection), bytestr.StrUpgrade) { + // can not reuse connection in this case, it's no longer http1 protocol. + // set BodyStream for (*Response).Hijack + resp.SetBodyStream(newUpgradeConn(c, cc), -1) + return false, nil + } + // In stream mode, we still can close/release the connection immediately if there is no content on the wire. if c.ResponseBodyStream && resp.BodyStream() != protocol.NoResponseBody { - return false, err + return false, nil } if shouldCloseConn { @@ -712,8 +718,48 @@ func (c *HostClient) doNonNilReqResp(req *protocol.Request, resp *protocol.Respo } else { c.releaseConn(cc) } + return false, nil +} + +var poolUpgradeConn = sync.Pool{ + New: func() interface{} { + return &upgradeConn{} + }, +} + +type upgradeConn struct { + c *HostClient + cc *clientConn +} + +func newUpgradeConn(c *HostClient, cc *clientConn) *upgradeConn { + p := poolUpgradeConn.Get().(*upgradeConn) + p.c = c + p.cc = cc + runtime.SetFinalizer(p, (*upgradeConn).gc) + return p +} - return false, err +// Read implements io.Reader +func (p *upgradeConn) Read(b []byte) (int, error) { return p.cc.c.Read(b) } + +// Hijack returns underlying network.Conn. This method is called by (*Response).Hijack +func (p *upgradeConn) Hijack() (network.Conn, error) { return p.cc.c, nil } + +// gc closes conn and reuse upgradeConn. +// +// It MUST be called only by go runtime to avoid concurenccy issue. +// For the 1st GC, it closes conn, and put upgradeConn back to pool +// For the 2nd GC, it will be recycled if it's still in pool +func (p *upgradeConn) gc() error { + if p.c != nil { + runtime.SetFinalizer(p, nil) + p.c.closeConn(p.cc) + p.c = nil + p.cc = nil + poolUpgradeConn.Put(p) + } + return nil } func (c *HostClient) Close() error { diff --git a/pkg/protocol/http1/client_test.go b/pkg/protocol/http1/client_test.go index 8c0869dde..2ecb764a6 100644 --- a/pkg/protocol/http1/client_test.go +++ b/pkg/protocol/http1/client_test.go @@ -49,6 +49,7 @@ import ( "fmt" "io/ioutil" "net" + "net/http" "strings" "sync" "sync/atomic" @@ -311,6 +312,64 @@ func TestDoNonNilReqResp1(t *testing.T) { assert.NotNil(t, err) } +func TestConnUpgrade(t *testing.T) { + ln, _ := net.Listen("tcp", "localhost:0") + defer ln.Close() + svr := http.Server{} + svr.Handler = http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + hj, ok := w.(http.Hijacker) + if !ok { + http.Error(w, "webserver doesn't support hijacking", http.StatusInternalServerError) + return + } + conn, rw, err := hj.Hijack() + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + defer conn.Close() + _, err = rw.WriteString("HTTP/1.1 101 Switching Protocols\nConnection: Upgrade\n\n") + assert.Nil(t, err) + assert.Nil(t, rw.Flush()) + b := make([]byte, 100) + for { // echo with "echo:" prefix + n, err := rw.Read(b) + if err != nil { + return + } + _, err = rw.Write([]byte("echo:" + string(b[:n]))) + if err != nil { + return + } + _ = rw.Flush() + } + }) + go svr.Serve(ln) + + c := &HostClient{ + Addr: ln.Addr().String(), + ClientOptions: &ClientOptions{}, + } + req := protocol.AcquireRequest() + req.SetRequestURI("http://" + ln.Addr().String() + "/") + resp := protocol.AcquireResponse() + retry, err := c.doNonNilReqResp(req, resp) + assert.False(t, retry) + assert.Nil(t, err) + assert.DeepEqual(t, resp.StatusCode(), 101) + + s := resp.BodyStream() + assert.NotNil(t, s) + conn, err := resp.Hijack() + assert.Nil(t, err) + + b := make([]byte, 100) + _, _ = conn.Write(append(b[:0], "hello"...)) + n, err := s.Read(b) // same as conn.Read + assert.Nil(t, err) + assert.DeepEqual(t, string(b[:n]), "echo:hello") +} + func TestWriteTimeoutPriority(t *testing.T) { c := &HostClient{ ClientOptions: &ClientOptions{ diff --git a/pkg/protocol/response.go b/pkg/protocol/response.go index 8beb38597..ccace1399 100644 --- a/pkg/protocol/response.go +++ b/pkg/protocol/response.go @@ -42,6 +42,7 @@ package protocol import ( + "errors" "io" "net" "sync" @@ -346,6 +347,25 @@ func (resp *Response) BodyStream() io.Reader { return resp.bodyStream } +// Hijack returns the underlying network.Conn if available. +// +// It's only available when StatusCode() == 101 and "Connection: Upgrade", +// coz Hertz will NOT reuse connection in this case, +// then make it optional for users to implement their own protocols. +// +// The most common scenario is used with github.com/hertz-contrib/websocket +func (resp *Response) Hijack() (network.Conn, error) { + if resp.bodyStream != nil { + h, ok := resp.bodyStream.(interface { + Hijack() (network.Conn, error) + }) + if ok { + return h.Hijack() + } + } + return nil, errors.New("not available") +} + // AppendBody appends p to response body. // // It is safe re-using p after the function returns. diff --git a/pkg/protocol/response_test.go b/pkg/protocol/response_test.go index 20a18ffce..6b09c2e19 100644 --- a/pkg/protocol/response_test.go +++ b/pkg/protocol/response_test.go @@ -43,6 +43,7 @@ package protocol import ( "bytes" + "errors" "fmt" "math" "reflect" @@ -52,6 +53,7 @@ import ( "github.com/cloudwego/hertz/pkg/common/compress" "github.com/cloudwego/hertz/pkg/common/test/assert" "github.com/cloudwego/hertz/pkg/common/test/mock" + "github.com/cloudwego/hertz/pkg/network" "github.com/cloudwego/hertz/pkg/protocol/consts" ) @@ -294,3 +296,20 @@ func TestResponse_HijackWriter(t *testing.T) { resp.GetHijackWriter().Finalize() assert.True(t, isFinal) } + +type HijackerFunc func() (network.Conn, error) + +func (h HijackerFunc) Read(_ []byte) (int, error) { return 0, errors.New("not implemented") } +func (h HijackerFunc) Hijack() (network.Conn, error) { return h() } + +func TestResponse_Hijack(t *testing.T) { + resp := AcquireResponse() + defer ReleaseResponse(resp) + + _, err := resp.Hijack() + assert.NotNil(t, err) + + resp.SetBodyStream(HijackerFunc(func() (network.Conn, error) { return nil, nil }), -1) + _, err = resp.Hijack() + assert.Nil(t, err) +} From 3697f895e90fcec9182b74fabc78b00fb38a04c7 Mon Sep 17 00:00:00 2001 From: Kyle Xiao Date: Thu, 24 Oct 2024 14:28:31 +0800 Subject: [PATCH 2/6] test(app): speed up client & server tests (#1216) --- .github/workflows/tests.yml | 25 +- .gitignore | 5 + go.mod | 2 +- go.sum | 2 + pkg/app/client/client.go | 38 +- pkg/app/client/client_test.go | 396 +++++++----------- pkg/app/client/loadbalance/lbcache_test.go | 39 +- .../client/loadbalance/weight_random_test.go | 15 +- pkg/app/server/hertz_test.go | 165 ++++++-- pkg/app/server/hertz_unix_test.go | 29 +- 10 files changed, 393 insertions(+), 323 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 54ce8b96e..000f1d6c1 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -34,26 +34,41 @@ jobs: cache: false # don't use cache for self-hosted runners - name: Unit Test - run: go test -race -covermode=atomic -coverprofile=coverage.txt ./... - - - name: Codecov - run: bash <(curl -s https://codecov.io/bash) + run: go test -race ./... ut-windows: strategy: matrix: version: ["1.20", "1.21", "1.22", "1.23"] runs-on: windows-latest + env: # Fixes https://github.com/actions/setup-go/issues/240 + GOMODCACHE: 'D:\go\pkg\mod' + GOCACHE: 'D:\go\go-build' steps: - uses: actions/checkout@v4 - name: Set up Go uses: actions/setup-go@v5 with: go-version: ${{ matrix.version }} + + - name: Unit Test + run: go test -race ./... + + code-cov: + runs-on: [self-hosted, X64] + steps: + - uses: actions/checkout@v4 + - name: Set up Go + uses: actions/setup-go@v5 + with: + go-version: stable cache: false # don't use cache for self-hosted runners - name: Unit Test - run: go test -race -covermode=atomic ./... + run: go test -covermode=atomic -coverprofile=coverage.txt ./... + + - name: Codecov + run: bash <(curl -s https://codecov.io/bash) hz-test-unix: runs-on: [ self-hosted, X64 ] diff --git a/.gitignore b/.gitignore index 338957806..d4c580190 100644 --- a/.gitignore +++ b/.gitignore @@ -3,3 +3,8 @@ pkg/app/fs.go.hertz.gz coverage.txt coverage.out + +# test benchmark tmp output +cpu.out +mem.out +*.test diff --git a/go.mod b/go.mod index cc537142f..65cd910d6 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/bytedance/gopkg v0.1.0 github.com/bytedance/mockey v1.2.12 github.com/bytedance/sonic v1.12.0 - github.com/cloudwego/netpoll v0.6.2 + github.com/cloudwego/netpoll v0.6.4 github.com/fsnotify/fsnotify v1.5.4 github.com/tidwall/gjson v1.14.4 golang.org/x/sync v0.0.0-20210220032951-036812b2e83c diff --git a/go.sum b/go.sum index 446891e45..98999c738 100644 --- a/go.sum +++ b/go.sum @@ -16,6 +16,8 @@ github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg= github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= github.com/cloudwego/netpoll v0.6.2 h1:+KdILv5ATJU+222wNNXpHapYaBeRvvL8qhJyhcxRxrQ= github.com/cloudwego/netpoll v0.6.2/go.mod h1:kaqvfZ70qd4T2WtIIpCOi5Cxyob8viEpzLhCrTrz3HM= +github.com/cloudwego/netpoll v0.6.4 h1:z/dA4sOTUQof6zZIO4QNnLBXsDFFFEos9OOGloR6kno= +github.com/cloudwego/netpoll v0.6.4/go.mod h1:BtM+GjKTdwKoC8IOzD08/+8eEn2gYoiNLipFca6BVXQ= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= diff --git a/pkg/app/client/client.go b/pkg/app/client/client.go index 614b5b5c7..e75009671 100644 --- a/pkg/app/client/client.go +++ b/pkg/app/client/client.go @@ -551,33 +551,29 @@ func (c *Client) CloseIdleConnections() { } func (c *Client) mCleaner() { - mustStop := false - for { time.Sleep(10 * time.Second) - c.mLock.Lock() - for k, v := range c.m { - shouldRemove := v.ShouldRemove() - - if shouldRemove { - delete(c.m, k) - if f, ok := v.(io.Closer); ok { - err := f.Close() - if err != nil { - hlog.Warnf("clean hostclient error, addr: %s, err: %s", k, err.Error()) - } - } - } - } - if len(c.m) == 0 { - mustStop = true + if c.mClean() { + break } - c.mLock.Unlock() + } +} - if mustStop { - break +func (c *Client) mClean() bool { + c.mLock.Lock() + defer c.mLock.Unlock() + for k, v := range c.m { + if v.ShouldRemove() { + delete(c.m, k) + if f, ok := v.(io.Closer); ok { + err := f.Close() + if err != nil { + hlog.Warnf("clean hostclient error, addr: %s, err: %s", k, err.Error()) + } + } } } + return len(c.m) == 0 } func (c *Client) SetClientFactory(cf suite.ClientFactory) { diff --git a/pkg/app/client/client_test.go b/pkg/app/client/client_test.go index ed5959bea..c9332a2ec 100644 --- a/pkg/app/client/client_test.go +++ b/pkg/app/client/client_test.go @@ -57,6 +57,7 @@ import ( "path/filepath" "reflect" "regexp" + "strconv" "strings" "sync" "sync/atomic" @@ -82,9 +83,58 @@ import ( var errTooManyRedirects = errors.New("too many redirects detected when doing the request") +func assertNil(err error) { + if err != nil { + panic(err) + } +} + +var unixsockPath string + +func TestMain(m *testing.M) { + dir, err := os.MkdirTemp("", "tests-*") + assertNil(err) + unixsockPath = dir + defer os.RemoveAll(dir) + + m.Run() +} + +var nextUnixSockID = int32(10000) + +func nextUnixSock() string { + n := atomic.AddInt32(&nextUnixSockID, 1) + return filepath.Join(unixsockPath, strconv.Itoa(int(n))+".sock") +} + +func waitEngineRunning(e *route.Engine) { + for i := 0; i < 100; i++ { + if e.IsRunning() { + break + } + time.Sleep(10 * time.Millisecond) + } + opts := e.GetOptions() + network, addr := opts.Network, opts.Addr + if network == "" { + network = "tcp" + } + for i := 0; i < 100; i++ { + conn, err := net.Dial(network, addr) + if err != nil { + time.Sleep(10 * time.Millisecond) + continue + } + conn.Close() + return + } + + panic("not running") +} + func TestCloseIdleConnections(t *testing.T) { opt := config.NewOptions([]config.Option{}) - opt.Addr = "unix-test-10000" + opt.Addr = nextUnixSock() opt.Network = "unix" engine := route.NewEngine(opt) @@ -92,12 +142,7 @@ func TestCloseIdleConnections(t *testing.T) { defer func() { engine.Close() }() - for { - time.Sleep(1 * time.Second) - if engine.IsRunning() { - break - } - } + waitEngineRunning(engine) c, _ := NewClient(WithDialer(newMockDialerWithCustomFunc(opt.Network, opt.Addr, 1*time.Second, nil))) @@ -124,13 +169,21 @@ func TestCloseIdleConnections(t *testing.T) { if conns := connsLen(); conns > 0 { t.Errorf("expected 0 conns got %d", conns) } + + c.mClean() + + func() { + c.mLock.Lock() + defer c.mLock.Unlock() + if len(c.m) != 0 { + t.Errorf("expected 0 conns got %d", len(c.m)) + } + }() } func TestClientInvalidURI(t *testing.T) { - t.Parallel() - opt := config.NewOptions([]config.Option{}) - opt.Addr = "unix-test-10001" + opt.Addr = nextUnixSock() opt.Network = "unix" requests := int64(0) engine := route.NewEngine(opt) @@ -141,12 +194,7 @@ func TestClientInvalidURI(t *testing.T) { defer func() { engine.Close() }() - for { - time.Sleep(1 * time.Second) - if engine.IsRunning() { - break - } - } + waitEngineRunning(engine) c, _ := NewClient(WithDialer(newMockDialerWithCustomFunc(opt.Network, opt.Addr, 1*time.Second, nil))) req, res := protocol.AcquireRequest(), protocol.AcquireResponse() @@ -166,10 +214,8 @@ func TestClientInvalidURI(t *testing.T) { } func TestClientGetWithBody(t *testing.T) { - t.Parallel() - opt := config.NewOptions([]config.Option{}) - opt.Addr = "unix-test-10002" + opt.Addr = nextUnixSock() opt.Network = "unix" engine := route.NewEngine(opt) engine.GET("/", func(c context.Context, ctx *app.RequestContext) { @@ -180,12 +226,7 @@ func TestClientGetWithBody(t *testing.T) { defer func() { engine.Close() }() - for { - time.Sleep(1 * time.Second) - if engine.IsRunning() { - break - } - } + waitEngineRunning(engine) c, _ := NewClient(WithDialer(newMockDialerWithCustomFunc(opt.Network, opt.Addr, 1*time.Second, nil))) req, res := protocol.AcquireRequest(), protocol.AcquireResponse() @@ -206,10 +247,8 @@ func TestClientGetWithBody(t *testing.T) { } func TestClientPostBodyStream(t *testing.T) { - t.Parallel() - opt := config.NewOptions([]config.Option{}) - opt.Addr = "unix-test-10102" + opt.Addr = nextUnixSock() opt.Network = "unix" engine := route.NewEngine(opt) engine.POST("/", func(c context.Context, ctx *app.RequestContext) { @@ -220,12 +259,7 @@ func TestClientPostBodyStream(t *testing.T) { defer func() { engine.Close() }() - for { - time.Sleep(1 * time.Second) - if engine.IsRunning() { - break - } - } + waitEngineRunning(engine) cStream, _ := NewClient(WithDialer(newMockDialerWithCustomFunc(opt.Network, opt.Addr, 1*time.Second, nil)), WithResponseBodyStream(true)) args := &protocol.Args{} @@ -244,8 +278,6 @@ func TestClientPostBodyStream(t *testing.T) { } func TestClientURLAuth(t *testing.T) { - t.Parallel() - cases := map[string]string{ "foo:bar@": "Basic Zm9vOmJhcg==", "foo:@": "Basic Zm9vOg==", @@ -256,7 +288,7 @@ func TestClientURLAuth(t *testing.T) { ch := make(chan string, 1) opt := config.NewOptions([]config.Option{}) - opt.Addr = "unix-test-10003" + opt.Addr = nextUnixSock() opt.Network = "unix" engine := route.NewEngine(opt) engine.GET("/foo/bar", func(c context.Context, ctx *app.RequestContext) { @@ -266,12 +298,7 @@ func TestClientURLAuth(t *testing.T) { defer func() { engine.Close() }() - for { - time.Sleep(1 * time.Second) - if engine.IsRunning() { - break - } - } + waitEngineRunning(engine) c, _ := NewClient(WithDialer(newMockDialerWithCustomFunc(opt.Network, opt.Addr, 1*time.Second, nil))) for up, expected := range cases { @@ -293,7 +320,7 @@ func TestClientURLAuth(t *testing.T) { func TestClientNilResp(t *testing.T) { opt := config.NewOptions([]config.Option{}) - opt.Addr = "unix-test-10004" + opt.Addr = nextUnixSock() opt.Network = "unix" engine := route.NewEngine(opt) @@ -303,12 +330,7 @@ func TestClientNilResp(t *testing.T) { defer func() { engine.Close() }() - for { - time.Sleep(1 * time.Second) - if engine.IsRunning() { - break - } - } + waitEngineRunning(engine) c, _ := NewClient(WithDialer(newMockDialerWithCustomFunc(opt.Network, opt.Addr, 1*time.Second, nil))) @@ -324,7 +346,6 @@ func TestClientNilResp(t *testing.T) { } func TestClientParseConn(t *testing.T) { - t.Parallel() opt := config.NewOptions([]config.Option{}) opt.Addr = "127.0.0.1:10005" engine := route.NewEngine(opt) @@ -334,12 +355,7 @@ func TestClientParseConn(t *testing.T) { defer func() { engine.Close() }() - for { - time.Sleep(1 * time.Second) - if engine.IsRunning() { - break - } - } + waitEngineRunning(engine) c, _ := NewClient(WithDialer(newMockDialerWithCustomFunc(opt.Network, opt.Addr, 1*time.Second, nil))) req, res := protocol.AcquireRequest(), protocol.AcquireResponse() @@ -365,9 +381,8 @@ func TestClientParseConn(t *testing.T) { } func TestClientPostArgs(t *testing.T) { - t.Parallel() opt := config.NewOptions([]config.Option{}) - opt.Addr = "unix-test-10006" + opt.Addr = nextUnixSock() opt.Network = "unix" engine := route.NewEngine(opt) engine.POST("/", func(c context.Context, ctx *app.RequestContext) { @@ -381,12 +396,8 @@ func TestClientPostArgs(t *testing.T) { defer func() { engine.Close() }() - for { - time.Sleep(1 * time.Second) - if engine.IsRunning() { - break - } - } + waitEngineRunning(engine) + c, _ := NewClient(WithDialer(newMockDialerWithCustomFunc(opt.Network, opt.Addr, 1*time.Second, nil))) req, res := protocol.AcquireRequest(), protocol.AcquireResponse() defer func() { @@ -408,10 +419,8 @@ func TestClientPostArgs(t *testing.T) { } func TestClientHeaderCase(t *testing.T) { - t.Parallel() - opt := config.NewOptions([]config.Option{}) - opt.Addr = "unix-test-10007" + opt.Addr = nextUnixSock() opt.Network = "unix" engine := route.NewEngine(opt) engine.GET("/", func(c context.Context, ctx *app.RequestContext) { @@ -428,7 +437,7 @@ func TestClientHeaderCase(t *testing.T) { defer func() { engine.Close() }() - time.Sleep(time.Second) + waitEngineRunning(engine) c, _ := NewClient(WithDialer(newMockDialerWithCustomFunc(opt.Network, opt.Addr, time.Second, nil)), WithDisableHeaderNamesNormalizing(true)) code, body, err := c.Get(context.Background(), nil, "http://example.com") @@ -442,32 +451,32 @@ func TestClientHeaderCase(t *testing.T) { } func TestClientReadTimeout(t *testing.T) { - if testing.Short() { - t.Skip("skipping test in short mode") - } - opt := config.NewOptions([]config.Option{}) - opt.Addr = "localhost:10024" + opt.Addr = nextUnixSock() + opt.Network = "unix" engine := route.NewEngine(opt) + readtimeout := 50 * time.Millisecond + sleeptime := 75 * time.Millisecond // must > readtimeout + engine.GET("/normal", func(c context.Context, ctx *app.RequestContext) { ctx.String(201, "ok") }) engine.GET("/timeout", func(c context.Context, ctx *app.RequestContext) { - time.Sleep(time.Second * 60) + time.Sleep(sleeptime) ctx.String(202, "timeout ok") }) go engine.Run() defer func() { engine.Close() }() - time.Sleep(time.Second * 1) + waitEngineRunning(engine) c := &http1.HostClient{ ClientOptions: &http1.ClientOptions{ - ReadTimeout: time.Second * 4, + ReadTimeout: readtimeout, RetryConfig: &retry.Config{MaxAttemptTimes: 1}, - Dialer: standard.NewDialer(), + Dialer: newMockDialerWithCustomFunc(opt.Network, opt.Addr, readtimeout, nil), }, Addr: opt.Addr, } @@ -475,7 +484,7 @@ func TestClientReadTimeout(t *testing.T) { req := protocol.AcquireRequest() res := protocol.AcquireResponse() - req.SetRequestURI("http://" + opt.Addr + "/normal") + req.SetRequestURI("http://example.com/normal") req.Header.SetMethod(consts.MethodGet) // Setting Connection: Close will make the connection be returned to the pool. @@ -485,48 +494,36 @@ func TestClientReadTimeout(t *testing.T) { t.Fatal(err) } - protocol.ReleaseRequest(req) - protocol.ReleaseResponse(res) - - done := make(chan struct{}) - go func() { - req := protocol.AcquireRequest() - res := protocol.AcquireResponse() - - req.SetRequestURI("http://" + opt.Addr + "/timeout") - req.Header.SetMethod(consts.MethodGet) - req.SetConnectionClose() + req.Reset() + req.SetRequestURI("http://example.com/timeout") + req.Header.SetMethod(consts.MethodGet) + req.SetConnectionClose() + res.Reset() - if err := c.Do(context.Background(), req, res); !errors.Is(err, errs.ErrTimeout) { - if err == nil { - t.Errorf("expected ErrTimeout got nil, req url: %s, read resp body: %s, status: %d", string(req.URI().FullURI()), string(res.Body()), res.StatusCode()) - } else { - if !strings.Contains(err.Error(), "timeout") { - t.Errorf("expected ErrTimeout got %#v", err) - } + t0 := time.Now() + err := c.Do(context.Background(), req, res) + t1 := time.Now() + if !errors.Is(err, errs.ErrTimeout) { + if err == nil { + t.Errorf("expected ErrTimeout got nil, req url: %s, read resp body: %s, status: %d", string(req.URI().FullURI()), string(res.Body()), res.StatusCode()) + } else { + if !strings.Contains(err.Error(), "timeout") { + t.Errorf("expected ErrTimeout got %#v", err) } } - - protocol.ReleaseRequest(req) - protocol.ReleaseResponse(res) - close(done) - }() - - select { - case <-done: - // It is abnormal when waiting time exceeds the value of readTimeout times the number of retries. - // Give it extra 2 seconds just to be sure. - case <-time.After(c.ReadTimeout*time.Duration(c.RetryConfig.MaxAttemptTimes) + time.Second*2): - t.Fatal("Client.ReadTimeout didn't work") + } + protocol.ReleaseRequest(req) + protocol.ReleaseResponse(res) + if d := t1.Sub(t0) - readtimeout; d > readtimeout/2 { + t.Errorf("timeout more than expected: %v", d) + } else { + t.Log("latency", d) } } func TestClientDefaultUserAgent(t *testing.T) { - t.Parallel() - opt := config.NewOptions([]config.Option{}) - - opt.Addr = "unix-test-10009" + opt.Addr = nextUnixSock() opt.Network = "unix" engine := route.NewEngine(opt) @@ -537,12 +534,7 @@ func TestClientDefaultUserAgent(t *testing.T) { defer func() { engine.Close() }() - for { - time.Sleep(1 * time.Second) - if engine.IsRunning() { - break - } - } + waitEngineRunning(engine) c, _ := NewClient(WithDialer(newMockDialerWithCustomFunc(opt.Network, opt.Addr, 1*time.Second, nil))) req := protocol.AcquireRequest() @@ -561,11 +553,8 @@ func TestClientDefaultUserAgent(t *testing.T) { } func TestClientSetUserAgent(t *testing.T) { - t.Parallel() - opt := config.NewOptions([]config.Option{}) - - opt.Addr = "unix-test-10010" + opt.Addr = nextUnixSock() opt.Network = "unix" engine := route.NewEngine(opt) @@ -576,12 +565,7 @@ func TestClientSetUserAgent(t *testing.T) { defer func() { engine.Close() }() - for { - time.Sleep(1 * time.Second) - if engine.IsRunning() { - break - } - } + waitEngineRunning(engine) userAgent := "I'm not hertz" c, _ := NewClient(WithDialer(newMockDialerWithCustomFunc(opt.Network, opt.Addr, time.Second, nil)), WithName(userAgent)) @@ -601,7 +585,7 @@ func TestClientSetUserAgent(t *testing.T) { func TestClientNoUserAgent(t *testing.T) { opt := config.NewOptions([]config.Option{}) - opt.Addr = "unix-test-10011" + opt.Addr = nextUnixSock() opt.Network = "unix" engine := route.NewEngine(opt) @@ -612,12 +596,8 @@ func TestClientNoUserAgent(t *testing.T) { defer func() { engine.Close() }() - for { - time.Sleep(1 * time.Second) - if engine.IsRunning() { - break - } - } + waitEngineRunning(engine) + c, _ := NewClient(WithDialer(newMockDialerWithCustomFunc(opt.Network, opt.Addr, time.Second, nil)), WithDialTimeout(1*time.Second), WithNoDefaultUserAgentHeader(true)) req := protocol.AcquireRequest() @@ -635,8 +615,6 @@ func TestClientNoUserAgent(t *testing.T) { } func TestClientDoWithCustomHeaders(t *testing.T) { - t.Parallel() - ch := make(chan error) uri := "/foo/bar/baz?a=b&cd=12" headers := map[string]string{ @@ -647,8 +625,7 @@ func TestClientDoWithCustomHeaders(t *testing.T) { } body := "request body" opt := config.NewOptions([]config.Option{}) - - opt.Addr = "unix-test-10012" + opt.Addr = nextUnixSock() opt.Network = "unix" engine := route.NewEngine(opt) @@ -698,12 +675,7 @@ func TestClientDoWithCustomHeaders(t *testing.T) { defer func() { engine.Close() }() - for { - time.Sleep(1 * time.Second) - if engine.IsRunning() { - break - } - } + waitEngineRunning(engine) // make sure that the client sends all the request headers and body. c, _ := NewClient(WithDialer(newMockDialerWithCustomFunc(opt.Network, opt.Addr, 1*time.Second, nil))) @@ -731,10 +703,8 @@ func TestClientDoWithCustomHeaders(t *testing.T) { } func TestClientDoTimeoutDisablePathNormalizing(t *testing.T) { - t.Parallel() opt := config.NewOptions([]config.Option{}) - - opt.Addr = "unix-test-10013" + opt.Addr = nextUnixSock() opt.Network = "unix" engine := route.NewEngine(opt) @@ -748,12 +718,8 @@ func TestClientDoTimeoutDisablePathNormalizing(t *testing.T) { defer func() { engine.Close() }() - for { - time.Sleep(1 * time.Second) - if engine.IsRunning() { - break - } - } + waitEngineRunning(engine) + c, _ := NewClient(WithDialer(newMockDialerWithCustomFunc(opt.Network, opt.Addr, time.Second, nil)), WithDisablePathNormalizing(true)) urlWithEncodedPath := "http://example.com/encoded/Y%2BY%2FY%3D/stuff" @@ -773,14 +739,11 @@ func TestClientDoTimeoutDisablePathNormalizing(t *testing.T) { } func TestHostClientPendingRequests(t *testing.T) { - t.Parallel() - const concurrency = 10 doneCh := make(chan struct{}) readyCh := make(chan struct{}, concurrency) opt := config.NewOptions([]config.Option{}) - - opt.Addr = "unix-test-10014" + opt.Addr = nextUnixSock() opt.Network = "unix" engine := route.NewEngine(opt) @@ -792,7 +755,7 @@ func TestHostClientPendingRequests(t *testing.T) { defer func() { engine.Close() }() - time.Sleep(time.Second) + waitEngineRunning(engine) c := &http1.HostClient{ ClientOptions: &http1.ClientOptions{ @@ -867,8 +830,7 @@ func TestHostClientMaxConnsWithDeadline(t *testing.T) { wg sync.WaitGroup ) opt := config.NewOptions([]config.Option{}) - - opt.Addr = "unix-test-10015" + opt.Addr = nextUnixSock() opt.Network = "unix" engine := route.NewEngine(opt) @@ -883,12 +845,7 @@ func TestHostClientMaxConnsWithDeadline(t *testing.T) { defer func() { engine.Close() }() - for { - time.Sleep(1 * time.Second) - if engine.IsRunning() { - break - } - } + waitEngineRunning(engine) c := &http1.HostClient{ ClientOptions: &http1.ClientOptions{ @@ -912,7 +869,7 @@ func TestHostClientMaxConnsWithDeadline(t *testing.T) { for { if err := c.DoDeadline(context.Background(), req, resp, time.Now().Add(timeout)); err != nil { if err.Error() == errs.ErrNoFreeConns.Error() { - time.Sleep(time.Millisecond * 500) + time.Sleep(10 * time.Millisecond) continue } t.Errorf("unexpected error: %s", err) @@ -938,12 +895,9 @@ func TestHostClientMaxConnsWithDeadline(t *testing.T) { } func TestHostClientMaxConnDuration(t *testing.T) { - t.Parallel() - connectionCloseCount := uint32(0) opt := config.NewOptions([]config.Option{}) - - opt.Addr = "unix-test-10016" + opt.Addr = nextUnixSock() opt.Network = "unix" engine := route.NewEngine(opt) @@ -957,12 +911,7 @@ func TestHostClientMaxConnDuration(t *testing.T) { defer func() { engine.Close() }() - for { - time.Sleep(1 * time.Second) - if engine.IsRunning() { - break - } - } + waitEngineRunning(engine) c := &http1.HostClient{ ClientOptions: &http1.ClientOptions{ @@ -992,10 +941,8 @@ func TestHostClientMaxConnDuration(t *testing.T) { } func TestHostClientMultipleAddrs(t *testing.T) { - t.Parallel() opt := config.NewOptions([]config.Option{}) - - opt.Addr = "unix-test-10017" + opt.Addr = nextUnixSock() opt.Network = "unix" engine := route.NewEngine(opt) @@ -1007,12 +954,7 @@ func TestHostClientMultipleAddrs(t *testing.T) { defer func() { engine.Close() }() - for { - time.Sleep(1 * time.Second) - if engine.IsRunning() { - break - } - } + waitEngineRunning(engine) dialsCount := make(map[string]int) c := &http1.HostClient{ @@ -1048,10 +990,8 @@ func TestHostClientMultipleAddrs(t *testing.T) { } func TestClientFollowRedirects(t *testing.T) { - t.Parallel() opt := config.NewOptions([]config.Option{}) - - opt.Addr = "unix-test-10018" + opt.Addr = nextUnixSock() opt.Network = "unix" engine := route.NewEngine(opt) @@ -1079,7 +1019,7 @@ func TestClientFollowRedirects(t *testing.T) { defer func() { engine.Close() }() - time.Sleep(time.Second * 2) + waitEngineRunning(engine) c := &http1.HostClient{ ClientOptions: &http1.ClientOptions{ @@ -1157,8 +1097,7 @@ func TestHostClientMaxConnWaitTimeoutSuccess(t *testing.T) { wg sync.WaitGroup ) opt := config.NewOptions([]config.Option{}) - - opt.Addr = "unix-test-10019" + opt.Addr = nextUnixSock() opt.Network = "unix" engine := route.NewEngine(opt) @@ -1173,12 +1112,7 @@ func TestHostClientMaxConnWaitTimeoutSuccess(t *testing.T) { defer func() { engine.Close() }() - for { - time.Sleep(1 * time.Second) - if engine.IsRunning() { - break - } - } + waitEngineRunning(engine) c := &http1.HostClient{ ClientOptions: &http1.ClientOptions{ @@ -1231,8 +1165,7 @@ func TestHostClientMaxConnWaitTimeoutError(t *testing.T) { wg sync.WaitGroup ) opt := config.NewOptions([]config.Option{}) - - opt.Addr = "unix-test-10020" + opt.Addr = nextUnixSock() opt.Network = "unix" engine := route.NewEngine(opt) @@ -1247,12 +1180,7 @@ func TestHostClientMaxConnWaitTimeoutError(t *testing.T) { defer func() { engine.Close() }() - for { - time.Sleep(1 * time.Second) - if engine.IsRunning() { - break - } - } + waitEngineRunning(engine) c := &http1.HostClient{ ClientOptions: &http1.ClientOptions{ @@ -1317,7 +1245,7 @@ func TestNewClient(t *testing.T) { defer func() { engine.Close() }() - time.Sleep(1 * time.Second) + waitEngineRunning(engine) client, err := NewClient(WithDialTimeout(2 * time.Second)) if err != nil { @@ -1345,7 +1273,7 @@ func TestUseShortConnection(t *testing.T) { defer func() { engine.Close() }() - time.Sleep(1 * time.Second) + waitEngineRunning(engine) c, _ := NewClient(WithKeepAlive(false)) var wg sync.WaitGroup @@ -1392,8 +1320,8 @@ func TestPostWithFormData(t *testing.T) { defer func() { engine.Close() }() + waitEngineRunning(engine) - time.Sleep(1 * time.Second) client, _ := NewClient() req := protocol.AcquireRequest() rsp := protocol.AcquireResponse() @@ -1446,8 +1374,8 @@ func TestPostWithMultipartField(t *testing.T) { defer func() { engine.Close() }() + waitEngineRunning(engine) - time.Sleep(1 * time.Second) client, _ := NewClient() req := protocol.AcquireRequest() rsp := protocol.AcquireResponse() @@ -1492,8 +1420,8 @@ func TestSetFiles(t *testing.T) { defer func() { engine.Close() }() + waitEngineRunning(engine) - time.Sleep(1 * time.Second) client, _ := NewClient() req := protocol.AcquireRequest() rsp := protocol.AcquireResponse() @@ -1543,8 +1471,8 @@ func TestSetMultipartFields(t *testing.T) { defer func() { engine.Close() }() + waitEngineRunning(engine) - time.Sleep(1 * time.Second) client, _ := NewClient(WithDialTimeout(50 * time.Millisecond)) req := protocol.AcquireRequest() rsp := protocol.AcquireResponse() @@ -1598,7 +1526,7 @@ func TestClientReadResponseBodyStream(t *testing.T) { defer func() { engine.Close() }() - time.Sleep(1 * time.Second) + waitEngineRunning(engine) client, _ := NewClient(WithResponseBodyStream(true)) req, resp := protocol.AcquireRequest(), protocol.AcquireResponse() @@ -1651,7 +1579,8 @@ func TestWithBasicAuth(t *testing.T) { defer func() { engine.Close() }() - time.Sleep(1 * time.Second) + waitEngineRunning(engine) + client, _ := NewClient() req := protocol.AcquireRequest() rsp := protocol.AcquireResponse() @@ -2022,7 +1951,7 @@ func TestClientReadResponseBodyStreamWithDoubleRequest(t *testing.T) { defer func() { engine.Close() }() - time.Sleep(1 * time.Second) + waitEngineRunning(engine) client, _ := NewClient(WithResponseBodyStream(true)) req, resp := protocol.AcquireRequest(), protocol.AcquireResponse() @@ -2095,7 +2024,7 @@ func TestClientReadResponseBodyStreamWithConnectionClose(t *testing.T) { defer func() { engine.Close() }() - time.Sleep(1 * time.Second) + waitEngineRunning(engine) client, _ := NewClient(WithResponseBodyStream(true)) @@ -2150,7 +2079,6 @@ func (m *mockDialer) DialConnection(network, address string, timeout time.Durati } func TestClientRetry(t *testing.T) { - t.Parallel() client, err := NewClient( // Default dial function performs different in different os. So unit the performance of dial function. WithDialFunc(func(addr string) (network.Conn, error) { @@ -2368,14 +2296,11 @@ func TestClientDialerName(t *testing.T) { } func TestClientDoWithDialFunc(t *testing.T) { - t.Parallel() - ch := make(chan error, 1) uri := "/foo/bar/baz" body := "request body" opt := config.NewOptions([]config.Option{}) - - opt.Addr = "unix-test-10021" + opt.Addr = nextUnixSock() opt.Network = "unix" engine := route.NewEngine(opt) @@ -2405,7 +2330,7 @@ func TestClientDoWithDialFunc(t *testing.T) { defer func() { engine.Close() }() - time.Sleep(1 * time.Second) + waitEngineRunning(engine) c, _ := NewClient(WithDialFunc(func(addr string) (network.Conn, error) { return dialer.DialConnection(opt.Network, opt.Addr, time.Second, nil) @@ -2442,11 +2367,13 @@ func TestClientState(t *testing.T) { defer func() { engine.Close() }() + waitEngineRunning(engine) - time.Sleep(1 * time.Second) - + var wg sync.WaitGroup + wg.Add(2) state := int32(0) client, _ := NewClient( + WithMaxIdleConnDuration(75*time.Millisecond), WithConnStateObserve(func(hcs config.HostClientState) { switch atomic.LoadInt32(&state) { case int32(0): @@ -2454,19 +2381,18 @@ func TestClientState(t *testing.T) { assert.DeepEqual(t, 1, hcs.ConnPoolState().PoolConnNum) assert.DeepEqual(t, "127.0.0.1:10037", hcs.ConnPoolState().Addr) atomic.StoreInt32(&state, int32(1)) + wg.Done() case int32(1): assert.DeepEqual(t, 0, hcs.ConnPoolState().TotalConnNum) assert.DeepEqual(t, 0, hcs.ConnPoolState().PoolConnNum) assert.DeepEqual(t, "127.0.0.1:10037", hcs.ConnPoolState().Addr) atomic.StoreInt32(&state, int32(2)) - return - case int32(2): - t.Fatal("It shouldn't go to here") + wg.Done() } - }, time.Second*9)) - + }, 50*time.Millisecond)) client.Get(context.Background(), nil, "http://127.0.0.1:10037") - time.Sleep(time.Second * 22) + wg.Wait() + assert.DeepEqual(t, int32(2), atomic.LoadInt32(&state)) } func TestClientRetryErr(t *testing.T) { @@ -2486,7 +2412,8 @@ func TestClientRetryErr(t *testing.T) { defer func() { engine.Close() }() - time.Sleep(1 * time.Second) + waitEngineRunning(engine) + c, _ := NewClient(WithRetryConfig(retry.WithMaxAttemptTimes(3))) _, _, err := c.Get(context.Background(), nil, "http://127.0.0.1:10136/ping") assert.Nil(t, err) @@ -2511,7 +2438,8 @@ func TestClientRetryErr(t *testing.T) { defer func() { engine.Close() }() - time.Sleep(1 * time.Second) + waitEngineRunning(engine) + c, _ := NewClient(WithRetryConfig(retry.WithMaxAttemptTimes(3))) c.SetRetryIfFunc(func(req *protocol.Request, resp *protocol.Response, err error) bool { return resp.StatusCode() == 502 diff --git a/pkg/app/client/loadbalance/lbcache_test.go b/pkg/app/client/loadbalance/lbcache_test.go index c1e314f59..4e3fe6f62 100644 --- a/pkg/app/client/loadbalance/lbcache_test.go +++ b/pkg/app/client/loadbalance/lbcache_test.go @@ -19,6 +19,7 @@ package loadbalance import ( "context" "fmt" + "strconv" "sync/atomic" "testing" "time" @@ -114,9 +115,11 @@ func TestBalancerRefresh(t *testing.T) { }, NameFunc: func() string { return t.Name() }, } + opts := DefaultLbOpts + opts.RefreshInterval = 30 * time.Millisecond blf := NewBalancerFactory(Config{ Balancer: NewWeightedBalancer(), - LbOpts: DefaultLbOpts, + LbOpts: opts, Resolver: r, }) req := &protocol.Request{} @@ -128,12 +131,44 @@ func TestBalancerRefresh(t *testing.T) { addr, err = blf.GetInstance(context.Background(), req) assert.Assert(t, err == nil, err) assert.Assert(t, addr.Address().String() == "127.0.0.1:8888") - time.Sleep(6 * time.Second) + time.Sleep(2 * opts.RefreshInterval) addr, err = blf.GetInstance(context.Background(), req) assert.Assert(t, err == nil, err) assert.Assert(t, addr.Address().String() == "127.0.0.1:8889") } +func TestBalancerExpires(t *testing.T) { + n := int32(1000) + r := &discovery.SynthesizedResolver{ + TargetFunc: func(ctx context.Context, target *discovery.TargetInfo) string { + return target.Host + }, + ResolveFunc: func(ctx context.Context, key string) (discovery.Result, error) { + ins := discovery.NewInstance("tcp", "127.0.0.1:"+strconv.Itoa(int(atomic.AddInt32(&n, 1))), 10, nil) + return discovery.Result{CacheKey: "svc1", Instances: []discovery.Instance{ins}}, nil + }, + NameFunc: func() string { return t.Name() }, + } + opts := DefaultLbOpts + opts.ExpireInterval = 30 * time.Millisecond + blf := NewBalancerFactory(Config{ + Balancer: NewWeightedBalancer(), + LbOpts: opts, + Resolver: r, + }) + req := &protocol.Request{} + req.SetHost("svc1") + addr1, err := blf.GetInstance(context.Background(), req) + assert.Assert(t, err == nil, err) + addr2, err := blf.GetInstance(context.Background(), req) + assert.Assert(t, err == nil, err) + assert.Assert(t, addr1.Address().String() == addr2.Address().String()) + time.Sleep(3 * opts.ExpireInterval) + addr3, err := blf.GetInstance(context.Background(), req) + assert.Assert(t, err == nil, err) + assert.Assert(t, addr3.Address().String() != addr2.Address().String()) +} + func TestCacheKey(t *testing.T) { uniqueKey := cacheKey("hello", "world", Options{RefreshInterval: 15 * time.Second, ExpireInterval: 5 * time.Minute}) assert.Assert(t, uniqueKey == "hello|world|{15s 5m0s}") diff --git a/pkg/app/client/loadbalance/weight_random_test.go b/pkg/app/client/loadbalance/weight_random_test.go index 1cab4cb0f..720d06250 100644 --- a/pkg/app/client/loadbalance/weight_random_test.go +++ b/pkg/app/client/loadbalance/weight_random_test.go @@ -55,12 +55,11 @@ func TestWeightedBalancer(t *testing.T) { // multi instances, weightSum > 0 insList = []discovery.Instance{ - discovery.NewInstance("tcp", "127.0.0.1:8881", 10, nil), - discovery.NewInstance("tcp", "127.0.0.1:8882", 20, nil), - discovery.NewInstance("tcp", "127.0.0.1:8883", 50, nil), - discovery.NewInstance("tcp", "127.0.0.1:8884", 100, nil), - discovery.NewInstance("tcp", "127.0.0.1:8885", 200, nil), - discovery.NewInstance("tcp", "127.0.0.1:8886", 500, nil), + discovery.NewInstance("tcp", "127.0.0.1:8881", 100, nil), + discovery.NewInstance("tcp", "127.0.0.1:8882", 200, nil), + discovery.NewInstance("tcp", "127.0.0.1:8883", 300, nil), + discovery.NewInstance("tcp", "127.0.0.1:8884", 400, nil), + discovery.NewInstance("tcp", "127.0.0.1:8885", 500, nil), } var weightSum int @@ -69,7 +68,7 @@ func TestWeightedBalancer(t *testing.T) { weightSum += weight } - n := 10000000 + n := 1000000 pickedStat := map[int]int{} e = discovery.Result{ Instances: insList, @@ -91,7 +90,7 @@ func TestWeightedBalancer(t *testing.T) { expect := float64(weight) / float64(weightSum) * float64(n) actual := float64(pickedStat[weight]) delta := math.Abs(expect - actual) - assert.DeepEqual(t, true, delta/expect < 0.01) + assert.DeepEqual(t, true, delta/expect < 0.05) } // have instances that weight < 0 diff --git a/pkg/app/server/hertz_test.go b/pkg/app/server/hertz_test.go index 09e77dd40..b66043d29 100644 --- a/pkg/app/server/hertz_test.go +++ b/pkg/app/server/hertz_test.go @@ -51,6 +51,35 @@ import ( "github.com/cloudwego/hertz/pkg/route/param" ) +type routeEngine interface { + IsRunning() bool + GetOptions() *config.Options +} + +func waitEngineRunning(e routeEngine) { + for i := 0; i < 100; i++ { + if e.IsRunning() { + break + } + time.Sleep(10 * time.Millisecond) + } + opts := e.GetOptions() + network, addr := opts.Network, opts.Addr + if network == "" { + network = "tcp" + } + for i := 0; i < 100; i++ { + conn, err := net.Dial(network, addr) + if err != nil { + time.Sleep(10 * time.Millisecond) + continue + } + conn.Close() + return + } + panic("not running") +} + func TestHertz_Run(t *testing.T) { hertz := Default(WithHostPorts("127.0.0.1:6666")) hertz.GET("/test", func(c context.Context, ctx *app.RequestContext) { @@ -67,7 +96,7 @@ func TestHertz_Run(t *testing.T) { assert.Assert(t, len(hertz.Handlers) == 1) go hertz.Spin() - time.Sleep(100 * time.Millisecond) + waitEngineRunning(hertz) hertz.Close() resp, err := http.Get("http://127.0.0.1:6666/test") @@ -77,9 +106,12 @@ func TestHertz_Run(t *testing.T) { } func TestHertz_GracefulShutdown(t *testing.T) { + handling := make(chan struct{}) + closing := make(chan struct{}) engine := New(WithHostPorts("127.0.0.1:6667")) engine.GET("/test", func(c context.Context, ctx *app.RequestContext) { - time.Sleep(time.Second * 2) + close(handling) + <-closing path := ctx.Request.URI().PathOriginal() ctx.SetBodyString(string(path)) }) @@ -95,12 +127,11 @@ func TestHertz_GracefulShutdown(t *testing.T) { atomic.StoreUint32(&testint2, 2) }) engine.Engine.OnShutdown = append(engine.OnShutdown, func(ctx context.Context) { - time.Sleep(2 * time.Second) atomic.StoreUint32(&testint3, 3) }) go engine.Spin() - time.Sleep(time.Millisecond) + waitEngineRunning(engine) hc := http.Client{Timeout: time.Second} var err error @@ -108,7 +139,7 @@ func TestHertz_GracefulShutdown(t *testing.T) { ch := make(chan struct{}) ch2 := make(chan struct{}) go func() { - ticker := time.NewTicker(time.Millisecond * 100) + ticker := time.NewTicker(10 * time.Millisecond) defer ticker.Stop() for range ticker.C { t.Logf("[%v]begin listening\n", time.Now()) @@ -127,14 +158,16 @@ func TestHertz_GracefulShutdown(t *testing.T) { ch <- struct{}{} }() - time.Sleep(time.Second * 1) + <-handling + start := time.Now() - ctx, cancel := context.WithTimeout(context.Background(), time.Second*3) + ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond) t.Logf("[%v]begin shutdown\n", start) engine.Shutdown(ctx) end := time.Now() t.Logf("[%v]end shutdown\n", end) + close(closing) <-ch assert.Nil(t, err) assert.NotNil(t, resp) @@ -161,7 +194,8 @@ func TestLoadHTMLGlob(t *testing.T) { defer func() { engine.Close() }() - time.Sleep(1 * time.Second) + waitEngineRunning(engine) + resp, _ := http.Get("http://127.0.0.1:8893/index") assert.DeepEqual(t, consts.StatusOK, resp.StatusCode) b := make([]byte, 100) @@ -188,7 +222,8 @@ func TestLoadHTMLFiles(t *testing.T) { defer func() { engine.Close() }() - time.Sleep(1 * time.Second) + waitEngineRunning(engine) + resp, _ := http.Get("http://127.0.0.1:8891/raw") assert.DeepEqual(t, consts.StatusOK, resp.StatusCode) b := make([]byte, 100) @@ -233,7 +268,8 @@ func TestServer_Run(t *testing.T) { ctx.Redirect(consts.StatusMovedPermanently, []byte("http://127.0.0.1:8899/test")) }) go hertz.Run() - time.Sleep(1 * time.Second) + waitEngineRunning(hertz) + resp, err := http.Get("http://127.0.0.1:8899/test") assert.Nil(t, err) assert.DeepEqual(t, consts.StatusOK, resp.StatusCode) @@ -269,7 +305,7 @@ func TestNotAbsolutePath(t *testing.T) { defer func() { engine.Close() }() - time.Sleep(1 * time.Second) + waitEngineRunning(engine) s := "POST ?a=b HTTP/1.1\r\nHost: a.b.c\r\nContent-Length: 5\r\nContent-Type: foo/bar\r\n\r\nabcdef4343" zr := mock.NewZeroCopyReader(s) @@ -311,7 +347,7 @@ func TestNotAbsolutePathWithRawPath(t *testing.T) { defer func() { engine.Close() }() - time.Sleep(1 * time.Second) + waitEngineRunning(engine) s := "POST ?a=b HTTP/1.1\r\nHost: a.b.c\r\nContent-Length: 5\r\nContent-Type: foo/bar\r\n\r\nabcdef4343" zr := mock.NewZeroCopyReader(s) @@ -389,7 +425,8 @@ func TestWithBasePath(t *testing.T) { defer func() { engine.Close() }() - time.Sleep(1 * time.Second) + waitEngineRunning(engine) + var r http.Request r.ParseForm() r.Form.Add("xxxxxx", "xxx") @@ -407,7 +444,8 @@ func TestNotEnoughBodySize(t *testing.T) { defer func() { engine.Close() }() - time.Sleep(1 * time.Second) + waitEngineRunning(engine) + var r http.Request r.ParseForm() r.Form.Add("xxxxxx", "xxx") @@ -427,7 +465,8 @@ func TestEnoughBodySize(t *testing.T) { defer func() { engine.Close() }() - time.Sleep(1 * time.Second) + waitEngineRunning(engine) + var r http.Request r.ParseForm() r.Form.Add("xxxxxx", "xxx") @@ -549,18 +588,19 @@ func TestParamInconsist(t *testing.T) { } }) go h.Run() - time.Sleep(time.Millisecond * 50) + waitEngineRunning(h) + client, _ := c.NewClient() wg := sync.WaitGroup{} tr := func() { defer wg.Done() - for i := 0; i < 5000; i++ { + for i := 0; i < 500; i++ { client.Get(context.Background(), nil, "http://localhost:10091/test1") } } ti := func() { defer wg.Done() - for i := 0; i < 5000; i++ { + for i := 0; i < 500; i++ { client.Get(context.Background(), nil, "http://localhost:10091/test2") } } @@ -580,7 +620,8 @@ func TestDuplicateReleaseBodyStream(t *testing.T) { c.Response.SetBodyStream(stream, -1) }) go h.Spin() - time.Sleep(time.Second) + waitEngineRunning(h) + client, _ := c.NewClient(c.WithMaxConnsPerHost(1000000), c.WithDialTimeout(time.Minute)) bodyBytes := make([]byte, 102388) index := 0 @@ -616,6 +657,8 @@ func TestDuplicateReleaseBodyStream(t *testing.T) { } func TestServiceRegisterFailed(t *testing.T) { + t.Parallel() // slow test, make it parallel + mockRegErr := errors.New("mock register error") var rCount int32 var drCount int32 @@ -634,39 +677,50 @@ func TestServiceRegisterFailed(t *testing.T) { opts = append(opts, WithHostPorts("127.0.0.1:9222")) srv := New(opts...) srv.Spin() - time.Sleep(2 * time.Second) assert.Assert(t, atomic.LoadInt32(&rCount) == 1) } func TestServiceDeregisterFailed(t *testing.T) { + t.Parallel() // slow test, make it parallel + mockDeregErr := errors.New("mock deregister error") + + var wg sync.WaitGroup + wg.Add(2) // RegisterFunc && DeregisterFunc var rCount int32 var drCount int32 mockRegistry := MockRegistry{ RegisterFunc: func(info *registry.Info) error { + defer wg.Done() atomic.AddInt32(&rCount, 1) return nil }, DeregisterFunc: func(info *registry.Info) error { + defer wg.Done() atomic.AddInt32(&drCount, 1) return mockDeregErr }, } + var opts []config.Option opts = append(opts, WithRegistry(mockRegistry, nil)) opts = append(opts, WithHostPorts("127.0.0.1:9223")) srv := New(opts...) go srv.Spin() - time.Sleep(1 * time.Second) + waitEngineRunning(srv) + ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) defer cancel() _ = srv.Shutdown(ctx) - time.Sleep(1 * time.Second) + + wg.Wait() assert.Assert(t, atomic.LoadInt32(&rCount) == 1) assert.Assert(t, atomic.LoadInt32(&drCount) == 1) } func TestServiceRegistryInfo(t *testing.T) { + t.Parallel() // slow test, make it parallel + registryInfo := ®istry.Info{ Weight: 100, Tags: map[string]string{"aa": "bb"}, @@ -678,15 +732,20 @@ func TestServiceRegistryInfo(t *testing.T) { assert.Assert(t, len(info.Tags) == len(registryInfo.Tags), info.Tags) assert.Assert(t, info.Tags["aa"] == registryInfo.Tags["aa"], info.Tags) } + + var wg sync.WaitGroup + wg.Add(2) // RegisterFunc && DeregisterFunc var rCount int32 var drCount int32 mockRegistry := MockRegistry{ RegisterFunc: func(info *registry.Info) error { + defer wg.Done() checkInfo(info) atomic.AddInt32(&rCount, 1) return nil }, DeregisterFunc: func(info *registry.Info) error { + defer wg.Done() checkInfo(info) atomic.AddInt32(&drCount, 1) return nil @@ -697,28 +756,36 @@ func TestServiceRegistryInfo(t *testing.T) { opts = append(opts, WithHostPorts("127.0.0.1:9225")) srv := New(opts...) go srv.Spin() - time.Sleep(2 * time.Second) + waitEngineRunning(srv) + ctx, cancel := context.WithTimeout(context.Background(), 0) defer cancel() _ = srv.Shutdown(ctx) - time.Sleep(2 * time.Second) + wg.Wait() assert.Assert(t, atomic.LoadInt32(&rCount) == 1) assert.Assert(t, atomic.LoadInt32(&drCount) == 1) } func TestServiceRegistryNoInitInfo(t *testing.T) { + t.Parallel() // slow test, make it parallel + checkInfo := func(info *registry.Info) { assert.Assert(t, info == nil) } + + var wg sync.WaitGroup + wg.Add(2) // RegisterFunc && DeregisterFunc var rCount int32 var drCount int32 mockRegistry := MockRegistry{ RegisterFunc: func(info *registry.Info) error { + defer wg.Done() checkInfo(info) atomic.AddInt32(&rCount, 1) return nil }, DeregisterFunc: func(info *registry.Info) error { + defer wg.Done() checkInfo(info) atomic.AddInt32(&drCount, 1) return nil @@ -729,11 +796,12 @@ func TestServiceRegistryNoInitInfo(t *testing.T) { opts = append(opts, WithHostPorts("127.0.0.1:9227")) srv := New(opts...) go srv.Spin() - time.Sleep(2 * time.Second) + waitEngineRunning(srv) + ctx, cancel := context.WithTimeout(context.Background(), 0) defer cancel() _ = srv.Shutdown(ctx) - time.Sleep(2 * time.Second) + wg.Wait() assert.Assert(t, atomic.LoadInt32(&rCount) == 1) assert.Assert(t, atomic.LoadInt32(&drCount) == 1) } @@ -758,7 +826,8 @@ func TestReuseCtx(t *testing.T) { }) go h.Spin() - time.Sleep(time.Second) + waitEngineRunning(h) + for i := 0; i < 1000; i++ { _, _, err := c.Get(context.Background(), nil, "http://127.0.0.1:9228/ping") assert.Nil(t, err) @@ -770,9 +839,15 @@ type CloseWithoutResetBuffer interface { } func TestOnprepare(t *testing.T) { + n := int32(0) h1 := New( WithHostPorts("localhost:9333"), WithOnConnect(func(ctx context.Context, conn network.Conn) context.Context { + if atomic.AddInt32(&n, 1) == 1 { + // the 1st connection is from waitEngineRunning + conn.Close() + return ctx + } b, err := conn.Peek(3) assert.Nil(t, err) assert.DeepEqual(t, string(b), "GET") @@ -788,7 +863,8 @@ func TestOnprepare(t *testing.T) { }) go h1.Spin() - time.Sleep(time.Second) + waitEngineRunning(h1) + _, _, err := c.Get(context.Background(), nil, "http://127.0.0.1:9333/ping") assert.DeepEqual(t, "the server closed connection before returning the first response byte. Make sure the server returns 'Connection: close' response header before closing the connection", err.Error()) @@ -802,7 +878,8 @@ func TestOnprepare(t *testing.T) { c.JSON(consts.StatusOK, utils.H{"ping": "pong"}) }) go h2.Spin() - time.Sleep(time.Second) + waitEngineRunning(h2) + _, _, err = c.Get(context.Background(), nil, "http://127.0.0.1:9331/ping") if err == nil { t.Fatalf("err should not be nil") @@ -819,7 +896,8 @@ func TestOnprepare(t *testing.T) { c.JSON(consts.StatusOK, utils.H{"ping": "pong"}) }) go h3.Spin() - time.Sleep(time.Second) + waitEngineRunning(h3) + c.Get(context.Background(), nil, "http://127.0.0.1:9231/ping") } @@ -851,7 +929,7 @@ func TestSilentMode(t *testing.T) { ctx.Write([]byte("hello, world")) }) go h.Spin() - time.Sleep(time.Second) + waitEngineRunning(h) d := standard.NewDialer() conn, _ := d.DialConnection("tcp", "127.0.0.1:9232", 0, nil) @@ -886,7 +964,7 @@ func TestHertzDisableHeaderNamesNormalizing(t *testing.T) { }) go h.Spin() - time.Sleep(100 * time.Millisecond) + waitEngineRunning(h) cli, _ := c.NewClient(c.WithDisableHeaderNamesNormalizing(true)) @@ -917,7 +995,8 @@ func TestBindConfig(t *testing.T) { }) go h.Spin() - time.Sleep(100 * time.Millisecond) + waitEngineRunning(h) + hc := http.Client{Timeout: time.Second} _, err := hc.Get("http://127.0.0.1:9332/bind?a=") assert.Nil(t, err) @@ -936,7 +1015,7 @@ func TestBindConfig(t *testing.T) { }) go h2.Spin() - time.Sleep(100 * time.Millisecond) + waitEngineRunning(h2) _, err = hc.Get("http://127.0.0.1:9448/bind?a=") assert.Nil(t, err) @@ -998,7 +1077,8 @@ func TestCustomBinder(t *testing.T) { }) go h.Spin() - time.Sleep(100 * time.Millisecond) + waitEngineRunning(h) + hc := http.Client{Timeout: time.Second} _, err := hc.Get("http://127.0.0.1:9334/bind?a=") assert.Nil(t, err) @@ -1025,7 +1105,8 @@ func TestValidateConfigRegValidateFunc(t *testing.T) { }) go h.Spin() - time.Sleep(100 * time.Millisecond) + waitEngineRunning(h) + hc := http.Client{Timeout: time.Second} _, err := hc.Get("http://127.0.0.1:9229/bind?a=2") assert.Nil(t, err) @@ -1110,7 +1191,8 @@ func TestValidateConfigSetSetErrorFactory(t *testing.T) { }) go h.Spin() - time.Sleep(100 * time.Millisecond) + waitEngineRunning(h) + hc := http.Client{Timeout: time.Second} _, err := hc.Get("http://127.0.0.1:9666/bind?b=1") assert.Nil(t, err) @@ -1136,7 +1218,8 @@ func TestValidateConfigAndBindConfig(t *testing.T) { }) go h.Spin() - time.Sleep(100 * time.Millisecond) + waitEngineRunning(h) + hc := http.Client{Timeout: time.Second} _, err := hc.Get("http://127.0.0.1:9876/bind?a=135") assert.Nil(t, err) @@ -1150,7 +1233,8 @@ func TestWithDisableDefaultDate(t *testing.T) { ) h.GET("/", func(_ context.Context, c *app.RequestContext) {}) go h.Spin() - time.Sleep(100 * time.Millisecond) + waitEngineRunning(h) + hc := http.Client{Timeout: time.Second} r, _ := hc.Get("http://127.0.0.1:8321") //nolint:errcheck assert.DeepEqual(t, "", r.Header.Get("Date")) @@ -1163,7 +1247,8 @@ func TestWithDisableDefaultContentType(t *testing.T) { ) h.GET("/", func(_ context.Context, c *app.RequestContext) {}) go h.Spin() - time.Sleep(100 * time.Millisecond) + waitEngineRunning(h) + hc := http.Client{Timeout: time.Second} r, _ := hc.Get("http://127.0.0.1:8324") //nolint:errcheck assert.DeepEqual(t, "", r.Header.Get("Content-Type")) diff --git a/pkg/app/server/hertz_unix_test.go b/pkg/app/server/hertz_unix_test.go index b1f7d700c..7e1d8d18a 100644 --- a/pkg/app/server/hertz_unix_test.go +++ b/pkg/app/server/hertz_unix_test.go @@ -67,7 +67,10 @@ func TestReusePorts(t *testing.T) { go hb.Run() go hc.Run() go hd.Run() - time.Sleep(time.Second) + waitEngineRunning(ha) + waitEngineRunning(hb) + waitEngineRunning(hc) + waitEngineRunning(hd) client, _ := c.NewClient() for i := 0; i < 1000; i++ { @@ -81,7 +84,7 @@ func TestReusePorts(t *testing.T) { func TestHertz_Spin(t *testing.T) { engine := New(WithHostPorts("127.0.0.1:6668")) engine.GET("/test", func(c context.Context, ctx *app.RequestContext) { - time.Sleep(time.Second * 2) + time.Sleep(40 * time.Millisecond) path := ctx.Request.URI().PathOriginal() ctx.SetBodyString(string(path)) }) @@ -93,7 +96,7 @@ func TestHertz_Spin(t *testing.T) { }) go engine.Spin() - time.Sleep(time.Millisecond) + waitEngineRunning(engine) hc := http.Client{Timeout: time.Second} var err error @@ -101,7 +104,7 @@ func TestHertz_Spin(t *testing.T) { ch := make(chan struct{}) ch2 := make(chan struct{}) go func() { - ticker := time.NewTicker(time.Millisecond * 100) + ticker := time.NewTicker(10 * time.Millisecond) defer ticker.Stop() for range ticker.C { _, err := hc.Get("http://127.0.0.1:6668/test2") @@ -120,7 +123,7 @@ func TestHertz_Spin(t *testing.T) { ch <- struct{}{} }() - time.Sleep(time.Second * 1) + time.Sleep(20 * time.Millisecond) pid := strconv.Itoa(os.Getpid()) cmd := exec.Command("kill", "-SIGHUP", pid) t.Logf("[%v]begin SIGHUP\n", time.Now()) @@ -131,9 +134,9 @@ func TestHertz_Spin(t *testing.T) { <-ch assert.Nil(t, err) assert.NotNil(t, resp) - assert.DeepEqual(t, uint32(1), atomic.LoadUint32(&testint)) <-ch2 + assert.DeepEqual(t, uint32(1), atomic.LoadUint32(&testint)) } func TestWithSenseClientDisconnection(t *testing.T) { @@ -150,15 +153,16 @@ func TestWithSenseClientDisconnection(t *testing.T) { } }) go h.Spin() - time.Sleep(time.Second) + waitEngineRunning(h) + con, err := net.Dial("tcp", "127.0.0.1:6631") assert.Nil(t, err) _, err = con.Write([]byte("GET /ping HTTP/1.1\r\nHost: aa\r\n\r\n")) assert.Nil(t, err) - time.Sleep(time.Second) + time.Sleep(20 * time.Millisecond) assert.DeepEqual(t, atomic.LoadInt32(&closeFlag), int32(0)) assert.Nil(t, con.Close()) - time.Sleep(time.Second) + time.Sleep(20 * time.Millisecond) assert.DeepEqual(t, atomic.LoadInt32(&closeFlag), int32(1)) } @@ -178,14 +182,15 @@ func TestWithSenseClientDisconnectionAndWithOnConnect(t *testing.T) { } }) go h.Spin() - time.Sleep(time.Second) + waitEngineRunning(h) + con, err := net.Dial("tcp", "127.0.0.1:6632") assert.Nil(t, err) _, err = con.Write([]byte("GET /ping HTTP/1.1\r\nHost: aa\r\n\r\n")) assert.Nil(t, err) - time.Sleep(time.Second) + time.Sleep(20 * time.Millisecond) assert.DeepEqual(t, atomic.LoadInt32(&closeFlag), int32(0)) assert.Nil(t, con.Close()) - time.Sleep(time.Second) + time.Sleep(20 * time.Millisecond) assert.DeepEqual(t, atomic.LoadInt32(&closeFlag), int32(1)) } From 74815dd773b3a2a4c2b71bf5065b039741756f6b Mon Sep 17 00:00:00 2001 From: Kyle Xiao Date: Tue, 19 Nov 2024 20:57:23 +0800 Subject: [PATCH 3/6] fix(http1): use bytes.EqualFold for header value check (#1232) --- pkg/protocol/http1/client.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/protocol/http1/client.go b/pkg/protocol/http1/client.go index 149029db0..758d1649d 100644 --- a/pkg/protocol/http1/client.go +++ b/pkg/protocol/http1/client.go @@ -701,7 +701,7 @@ func (c *HostClient) doNonNilReqResp(req *protocol.Request, resp *protocol.Respo shouldCloseConn = resetConnection || req.ConnectionClose() || resp.ConnectionClose() if resp.Header.StatusCode() == consts.StatusSwitchingProtocols && - bytes.Equal(resp.Header.Peek(consts.HeaderConnection), bytestr.StrUpgrade) { + bytes.EqualFold(resp.Header.Peek(consts.HeaderConnection), bytestr.StrUpgrade) { // can not reuse connection in this case, it's no longer http1 protocol. // set BodyStream for (*Response).Hijack resp.SetBodyStream(newUpgradeConn(c, cc), -1) From 7d63572c3d16aabe42bf3e1e078f03c90fcd2c72 Mon Sep 17 00:00:00 2001 From: Kyle Xiao Date: Fri, 22 Nov 2024 15:17:22 +0800 Subject: [PATCH 4/6] refactor(binding): use internal tagexpr (#1234) --- go.mod | 5 +- go.sum | 14 - internal/tagexpr/LICENSE | 201 +++ internal/tagexpr/README.md | 3 + internal/tagexpr/example_test.go | 80 ++ internal/tagexpr/expr.go | 297 +++++ internal/tagexpr/expr_test.go | 248 ++++ internal/tagexpr/handler.go | 148 +++ internal/tagexpr/selector.go | 112 ++ internal/tagexpr/selector_test.go | 30 + internal/tagexpr/spec_func.go | 342 +++++ internal/tagexpr/spec_func_test.go | 104 ++ internal/tagexpr/spec_operand.go | 363 ++++++ internal/tagexpr/spec_operator.go | 290 +++++ internal/tagexpr/spec_range.go | 164 +++ internal/tagexpr/spec_range_test.go | 54 + internal/tagexpr/spec_selector.go | 109 ++ internal/tagexpr/spec_test.go | 162 +++ internal/tagexpr/tagexpr.go | 1225 ++++++++++++++++++ internal/tagexpr/tagexpr_test.go | 855 ++++++++++++ internal/tagexpr/tagparser.go | 190 +++ internal/tagexpr/tagparser_test.go | 93 ++ internal/tagexpr/utils.go | 101 ++ internal/tagexpr/validator/README.md | 204 +++ internal/tagexpr/validator/default.go | 42 + internal/tagexpr/validator/example_test.go | 122 ++ internal/tagexpr/validator/func.go | 116 ++ internal/tagexpr/validator/validator.go | 163 +++ internal/tagexpr/validator/validator_test.go | 354 +++++ pkg/app/server/binding/config.go | 2 +- pkg/app/server/binding/default.go | 2 +- 31 files changed, 6175 insertions(+), 20 deletions(-) create mode 100644 internal/tagexpr/LICENSE create mode 100644 internal/tagexpr/README.md create mode 100644 internal/tagexpr/example_test.go create mode 100644 internal/tagexpr/expr.go create mode 100644 internal/tagexpr/expr_test.go create mode 100644 internal/tagexpr/handler.go create mode 100644 internal/tagexpr/selector.go create mode 100644 internal/tagexpr/selector_test.go create mode 100644 internal/tagexpr/spec_func.go create mode 100644 internal/tagexpr/spec_func_test.go create mode 100644 internal/tagexpr/spec_operand.go create mode 100644 internal/tagexpr/spec_operator.go create mode 100644 internal/tagexpr/spec_range.go create mode 100644 internal/tagexpr/spec_range_test.go create mode 100644 internal/tagexpr/spec_selector.go create mode 100644 internal/tagexpr/spec_test.go create mode 100644 internal/tagexpr/tagexpr.go create mode 100644 internal/tagexpr/tagexpr_test.go create mode 100644 internal/tagexpr/tagparser.go create mode 100644 internal/tagexpr/tagparser_test.go create mode 100644 internal/tagexpr/utils.go create mode 100644 internal/tagexpr/validator/README.md create mode 100644 internal/tagexpr/validator/default.go create mode 100644 internal/tagexpr/validator/example_test.go create mode 100644 internal/tagexpr/validator/func.go create mode 100644 internal/tagexpr/validator/validator.go create mode 100644 internal/tagexpr/validator/validator_test.go diff --git a/go.mod b/go.mod index 65cd910d6..d2950316f 100644 --- a/go.mod +++ b/go.mod @@ -3,12 +3,12 @@ module github.com/cloudwego/hertz go 1.17 require ( - github.com/bytedance/go-tagexpr/v2 v2.9.2 github.com/bytedance/gopkg v0.1.0 github.com/bytedance/mockey v1.2.12 github.com/bytedance/sonic v1.12.0 github.com/cloudwego/netpoll v0.6.4 github.com/fsnotify/fsnotify v1.5.4 + github.com/nyaruka/phonenumbers v1.0.55 github.com/tidwall/gjson v1.14.4 golang.org/x/sync v0.0.0-20210220032951-036812b2e83c golang.org/x/sys v0.24.0 @@ -21,11 +21,8 @@ require ( github.com/cloudwego/iasm v0.2.0 // indirect github.com/golang/protobuf v1.5.0 // indirect github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1 // indirect - github.com/henrylee2cn/ameda v1.4.10 // indirect - github.com/henrylee2cn/goutil v0.0.0-20210127050712-89660552f6f8 // indirect github.com/jtolds/gls v4.20.0+incompatible // indirect github.com/klauspost/cpuid/v2 v2.0.9 // indirect - github.com/nyaruka/phonenumbers v1.0.55 // indirect github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d // indirect github.com/smartystreets/goconvey v1.6.4 // indirect github.com/tidwall/match v1.1.1 // indirect diff --git a/go.sum b/go.sum index 98999c738..7e920d459 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,3 @@ -github.com/bytedance/go-tagexpr/v2 v2.9.2 h1:QySJaAIQgOEDQBLS3x9BxOWrnhqu5sQ+f6HaZIxD39I= -github.com/bytedance/go-tagexpr/v2 v2.9.2/go.mod h1:5qsx05dYOiUXOUgnQ7w3Oz8BYs2qtM/bJokdLb79wRM= -github.com/bytedance/gopkg v0.0.0-20240507064146-197ded923ae3/go.mod h1:FtQG3YbQG9L/91pbKSw787yBQPutC+457AvDW77fgUQ= github.com/bytedance/gopkg v0.1.0 h1:aAxB7mm1qms4Wz4sp8e1AtKDOeFLtdqvGiUe7aonRJs= github.com/bytedance/gopkg v0.1.0/go.mod h1:FtQG3YbQG9L/91pbKSw787yBQPutC+457AvDW77fgUQ= github.com/bytedance/mockey v1.2.12 h1:aeszOmGw8CPX8CRx1DZ/Glzb1yXvhjDh6jdFBNZjsU4= @@ -14,8 +11,6 @@ github.com/cloudwego/base64x v0.1.4 h1:jwCgWpFanWmN8xoIUHa2rtzmkd5J2plF/dnLS6Xd/ github.com/cloudwego/base64x v0.1.4/go.mod h1:0zlkT4Wn5C6NdauXdJRhSKRlJvmclQ1hhJgA0rcu/8w= github.com/cloudwego/iasm v0.2.0 h1:1KNIy1I1H9hNNFEEH3DVnI4UujN+1zjpuk6gwHLTssg= github.com/cloudwego/iasm v0.2.0/go.mod h1:8rXZaNYT2n95jn+zTI1sDr+IgcD2GVs0nlbbQPiEFhY= -github.com/cloudwego/netpoll v0.6.2 h1:+KdILv5ATJU+222wNNXpHapYaBeRvvL8qhJyhcxRxrQ= -github.com/cloudwego/netpoll v0.6.2/go.mod h1:kaqvfZ70qd4T2WtIIpCOi5Cxyob8viEpzLhCrTrz3HM= github.com/cloudwego/netpoll v0.6.4 h1:z/dA4sOTUQof6zZIO4QNnLBXsDFFFEos9OOGloR6kno= github.com/cloudwego/netpoll v0.6.4/go.mod h1:BtM+GjKTdwKoC8IOzD08/+8eEn2gYoiNLipFca6BVXQ= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= @@ -30,11 +25,6 @@ github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1 h1:EGx4pi6eqNxGaHF6qqu48+N2wcFQ5qg5FXgOdqsJ5d8= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= -github.com/henrylee2cn/ameda v1.4.8/go.mod h1:liZulR8DgHxdK+MEwvZIylGnmcjzQ6N6f2PlWe7nEO4= -github.com/henrylee2cn/ameda v1.4.10 h1:JdvI2Ekq7tapdPsuhrc4CaFiqw6QXFvZIULWJgQyCAk= -github.com/henrylee2cn/ameda v1.4.10/go.mod h1:liZulR8DgHxdK+MEwvZIylGnmcjzQ6N6f2PlWe7nEO4= -github.com/henrylee2cn/goutil v0.0.0-20210127050712-89660552f6f8 h1:yE9ULgp02BhYIrO6sdV/FPe0xQM6fNHkVQW2IAymfM0= -github.com/henrylee2cn/goutil v0.0.0-20210127050712-89660552f6f8/go.mod h1:Nhe/DM3671a5udlv2AdV2ni/MZzgfv2qrPL5nIi3EGQ= github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo= github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= github.com/klauspost/cpuid/v2 v2.0.9 h1:lgaqFMSdTdQYdZ04uHyN2d/eKdOMyi2YLSvlQIBFYa4= @@ -51,14 +41,11 @@ github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9 github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= -github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= -github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= -github.com/tidwall/gjson v1.9.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/gjson v1.14.4 h1:uo0p8EbA09J7RQaflQ1aBRffTR7xedD2bcIVSYxLnkM= github.com/tidwall/gjson v1.14.4/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= @@ -92,7 +79,6 @@ google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp0 google.golang.org/protobuf v1.27.1 h1:SnqbnDw1V7RiZcXPx5MEeqPv2s79L9i7BJUlG/+RurQ= google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/internal/tagexpr/LICENSE b/internal/tagexpr/LICENSE new file mode 100644 index 000000000..5d7fd6bfa --- /dev/null +++ b/internal/tagexpr/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2019 Bytedance Inc. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/internal/tagexpr/README.md b/internal/tagexpr/README.md new file mode 100644 index 000000000..b248ac3b4 --- /dev/null +++ b/internal/tagexpr/README.md @@ -0,0 +1,3 @@ +# go-tagexpr + +originally from https://github.com/bytedance/go-tagexpr diff --git a/internal/tagexpr/example_test.go b/internal/tagexpr/example_test.go new file mode 100644 index 000000000..fab5e87a4 --- /dev/null +++ b/internal/tagexpr/example_test.go @@ -0,0 +1,80 @@ +// Copyright 2019 Bytedance Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tagexpr_test + +import ( + "fmt" + + "github.com/cloudwego/hertz/internal/tagexpr" +) + +func Example() { + type T struct { + A int `tagexpr:"$<0||$>=100"` + B string `tagexpr:"len($)>1 && regexp('^\\w*$')"` + C bool `tagexpr:"expr1:(f.g)$>0 && $; expr2:'C must be true when T.f.g>0'"` + d []string `tagexpr:"@:len($)>0 && $[0]=='D'; msg:sprintf('invalid d: %v',$)"` + e map[string]int `tagexpr:"len($)==$['len']"` + e2 map[string]*int `tagexpr:"len($)==$['len']"` + f struct { + g int `tagexpr:"$"` + } + h int `tagexpr:"$>minVal"` + } + + vm := tagexpr.New("tagexpr") + t := &T{ + A: 107, + B: "abc", + C: true, + d: []string{"x", "y"}, + e: map[string]int{"len": 1}, + e2: map[string]*int{"len": new(int)}, + f: struct { + g int `tagexpr:"$"` + }{1}, + h: 10, + } + + tagExpr, err := vm.Run(t) + if err != nil { + panic(err) + } + + fmt.Println(tagExpr.Eval("A")) + fmt.Println(tagExpr.Eval("B")) + fmt.Println(tagExpr.Eval("C@expr1")) + fmt.Println(tagExpr.Eval("C@expr2")) + if !tagExpr.Eval("d").(bool) { + fmt.Println(tagExpr.Eval("d@msg")) + } + fmt.Println(tagExpr.Eval("e")) + fmt.Println(tagExpr.Eval("e2")) + fmt.Println(tagExpr.Eval("f.g")) + fmt.Println(tagExpr.EvalWithEnv("h", map[string]interface{}{"minVal": 9})) + fmt.Println(tagExpr.EvalWithEnv("h", map[string]interface{}{"minVal": 11})) + + // Output: + // true + // true + // true + // C must be true when T.f.g>0 + // invalid d: [x y] + // true + // false + // 1 + // true + // false +} diff --git a/internal/tagexpr/expr.go b/internal/tagexpr/expr.go new file mode 100644 index 000000000..776a6795d --- /dev/null +++ b/internal/tagexpr/expr.go @@ -0,0 +1,297 @@ +// Copyright 2019 Bytedance Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tagexpr + +import ( + "context" + "fmt" +) + +type variableKeyType string + +const variableKey variableKeyType = "__ENV_KEY__" + +// Expr expression +type Expr struct { + expr ExprNode +} + +// parseExpr parses the expression. +func parseExpr(expr string) (*Expr, error) { + e := newGroupExprNode() + p := &Expr{ + expr: e, + } + s := expr + err := p.parseExprNode(&s, e) + if err != nil { + return nil, err + } + sortPriority(e) + return p, nil +} + +func (p *Expr) parseExprNode(expr *string, e ExprNode) error { + trimLeftSpace(expr) + if *expr == "" { + return nil + } + operand := p.readSelectorExprNode(expr) + if operand == nil { + operand = p.readRangeKvExprNode(expr) + if operand == nil { + var subExprNode *string + operand, subExprNode = readGroupExprNode(expr) + if operand != nil { + err := p.parseExprNode(subExprNode, operand) + if err != nil { + return err + } + } else { + operand = p.parseOperand(expr) + } + } + } + if operand == nil { + return fmt.Errorf("syntax error: %q", *expr) + } + trimLeftSpace(expr) + operator := p.parseOperator(expr) + if operator == nil { + e.SetRightOperand(operand) + operand.SetParent(e) + return nil + } + if _, ok := e.(*groupExprNode); ok { + operator.SetLeftOperand(operand) + operand.SetParent(operator) + e.SetRightOperand(operator) + operator.SetParent(e) + } else { + operator.SetParent(e.Parent()) + operator.Parent().SetRightOperand(operator) + operator.SetLeftOperand(e) + e.SetParent(operator) + e.SetRightOperand(operand) + operand.SetParent(e) + } + return p.parseExprNode(expr, operator) +} + +func (p *Expr) parseOperand(expr *string) (e ExprNode) { + for _, fn := range funcList { + if e = fn(p, expr); e != nil { + return e + } + } + if e = readStringExprNode(expr); e != nil { + return e + } + if e = readDigitalExprNode(expr); e != nil { + return e + } + if e = readBoolExprNode(expr); e != nil { + return e + } + if e = readNilExprNode(expr); e != nil { + return e + } + if e = readVariableExprNode(expr); e != nil { + return e + } + return nil +} + +func (*Expr) parseOperator(expr *string) (e ExprNode) { + s := *expr + if len(s) < 2 { + return nil + } + defer func() { + if e != nil && *expr == s { + *expr = (*expr)[2:] + } + }() + a := s[:2] + switch a { + // case "<<": + // case ">>": + // case "&^": + case "||": + return newOrExprNode() + case "&&": + return newAndExprNode() + case "==": + return newEqualExprNode() + case ">=": + return newGreaterEqualExprNode() + case "<=": + return newLessEqualExprNode() + case "!=": + return newNotEqualExprNode() + } + defer func() { + if e != nil { + *expr = (*expr)[1:] + } + }() + switch a[0] { + // case '&': + // case '|': + // case '^': + case '+': + return newAdditionExprNode() + case '-': + return newSubtractionExprNode() + case '*': + return newMultiplicationExprNode() + case '/': + return newDivisionExprNode() + case '%': + return newRemainderExprNode() + case '<': + return newLessExprNode() + case '>': + return newGreaterExprNode() + } + return nil +} + +// run calculates the value of expression. +func (p *Expr) run(field string, tagExpr *TagExpr) interface{} { + return p.expr.Run(context.Background(), field, tagExpr) +} + +func (p *Expr) runWithEnv(field string, tagExpr *TagExpr, env map[string]interface{}) interface{} { + ctx := context.WithValue(context.Background(), variableKey, env) + return p.expr.Run(ctx, field, tagExpr) +} + +/** + * Priority: + * () ! bool float64 string nil + * * / % + * + - + * < <= > >= + * == != + * && + * || +**/ + +func sortPriority(e ExprNode) { + for subSortPriority(e.RightOperand(), false) { + } +} + +func subSortPriority(e ExprNode, isLeft bool) bool { + if e == nil { + return false + } + leftChanged := subSortPriority(e.LeftOperand(), true) + rightChanged := subSortPriority(e.RightOperand(), false) + if getPriority(e) > getPriority(e.LeftOperand()) { + leftOperandToParent(e, isLeft) + return true + } + return leftChanged || rightChanged +} + +func leftOperandToParent(e ExprNode, isLeft bool) { + le := e.LeftOperand() + if le == nil { + return + } + p := e.Parent() + le.SetParent(p) + if p != nil { + if isLeft { + p.SetLeftOperand(le) + } else { + p.SetRightOperand(le) + } + } + e.SetParent(le) + e.SetLeftOperand(le.RightOperand()) + le.RightOperand().SetParent(e) + le.SetRightOperand(e) +} + +func getPriority(e ExprNode) (i int) { + // defer func() { + // printf("expr:%T %d\n", e, i) + // }() + switch e.(type) { + default: // () ! bool float64 string nil + return 7 + case *multiplicationExprNode, *divisionExprNode, *remainderExprNode: // * / % + return 6 + case *additionExprNode, *subtractionExprNode: // + - + return 5 + case *lessExprNode, *lessEqualExprNode, *greaterExprNode, *greaterEqualExprNode: // < <= > >= + return 4 + case *equalExprNode, *notEqualExprNode: // == != + return 3 + case *andExprNode: // && + return 2 + case *orExprNode: // || + return 1 + } +} + +// ExprNode expression interface +type ExprNode interface { + SetParent(ExprNode) + Parent() ExprNode + LeftOperand() ExprNode + RightOperand() ExprNode + SetLeftOperand(ExprNode) + SetRightOperand(ExprNode) + String() string + Run(context.Context, string, *TagExpr) interface{} +} + +// var _ ExprNode = new(exprBackground) + +type exprBackground struct { + parent ExprNode + leftOperand ExprNode + rightOperand ExprNode +} + +func (eb *exprBackground) SetParent(e ExprNode) { + eb.parent = e +} + +func (eb *exprBackground) Parent() ExprNode { + return eb.parent +} + +func (eb *exprBackground) LeftOperand() ExprNode { + return eb.leftOperand +} + +func (eb *exprBackground) RightOperand() ExprNode { + return eb.rightOperand +} + +func (eb *exprBackground) SetLeftOperand(left ExprNode) { + eb.leftOperand = left +} + +func (eb *exprBackground) SetRightOperand(right ExprNode) { + eb.rightOperand = right +} + +func (*exprBackground) Run(context.Context, string, *TagExpr) interface{} { return nil } diff --git a/internal/tagexpr/expr_test.go b/internal/tagexpr/expr_test.go new file mode 100644 index 000000000..bec82411c --- /dev/null +++ b/internal/tagexpr/expr_test.go @@ -0,0 +1,248 @@ +// Copyright 2019 Bytedance Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tagexpr + +import ( + "math" + "reflect" + "testing" +) + +func TestExpr(t *testing.T) { + cases := []struct { + expr string + val interface{} + }{ + // Simple string + {expr: "'a'", val: "a"}, + {expr: "('a')", val: "a"}, + // Simple digital + {expr: " 10 ", val: 10.0}, + {expr: "(10)", val: 10.0}, + // Simple bool + {expr: "true", val: true}, + {expr: "!true", val: false}, + {expr: "!!true", val: true}, + {expr: "false", val: false}, + {expr: "!false", val: true}, + {expr: "!!false", val: false}, + {expr: "(false)", val: false}, + {expr: "(!false)", val: true}, + {expr: "(!!false)", val: false}, + {expr: "!!(!false)", val: true}, + {expr: "!(!false)", val: false}, + // Join string + {expr: "'true '+('a')", val: "true a"}, + {expr: "'a'+('b'+'c')+'d'", val: "abcd"}, + // Arithmetic operator + {expr: "1+7+2", val: 10.0}, + {expr: "1+(7)+(2)", val: 10.0}, + {expr: "1.1+ 2", val: 3.1}, + {expr: "-1.1+4", val: 2.9}, + {expr: "10-7-2", val: 1.0}, + {expr: "20/2", val: 10.0}, + {expr: "1/0", val: math.NaN()}, + {expr: "20%2", val: 0.0}, + {expr: "6 % 5", val: 1.0}, + {expr: "20%7 %5", val: 1.0}, + {expr: "1*2+7+2.2", val: 11.2}, + {expr: "-20/2+1+2", val: -7.0}, + {expr: "20/2+1-2-1", val: 8.0}, + {expr: "30/(2+1)/5-2-1", val: -1.0}, + {expr: "100/(( 2+8)*5 )-(1 +1- 0)", val: 0.0}, + {expr: "(2*3)+(4*2)", val: 14.0}, + {expr: "1+(2*(3+4))", val: 15.0}, + {expr: "20%(7%5)", val: 0.0}, + // Relational operator + {expr: "50 == 5", val: false}, + {expr: "'50'==50", val: true}, + {expr: "'50'=='50'", val: true}, + {expr: "'50' =='5' == true", val: false}, + {expr: "50== 50 == false", val: false}, + {expr: "50== 50 == true ==true==true", val: true}, + {expr: "50 != 5", val: true}, + {expr: "'50'!=50", val: false}, + {expr: "'50'!= '50'", val: false}, + {expr: "'50' !='5' != true", val: false}, + {expr: "50!= 50 == false", val: true}, + {expr: "50== 50 != true ==true!=true", val: true}, + {expr: "50 > 5", val: true}, + {expr: "50.1 > 50.1", val: false}, + {expr: "3.2 > 2.1", val: true}, + {expr: "'3.2' > '2.1'", val: true}, + {expr: "'13.2'>'2.1'", val: false}, + {expr: "3.2 >= 2.1", val: true}, + {expr: "2.1 >= 2.1", val: true}, + {expr: "2.05 >= 2.1", val: false}, + {expr: "'2.05'>='2.1'", val: false}, + {expr: "'12.05'>='2.1'", val: false}, + {expr: "50 < 5", val: false}, + {expr: "50.1 < 50.1", val: false}, + {expr: "3 <12.11", val: true}, + {expr: "3.2 < 2.1", val: false}, + {expr: "'3.2' < '2.1'", val: false}, + {expr: "'13.2' < '2.1'", val: true}, + {expr: "3.2 <= 2.1", val: false}, + {expr: "2.1 <= 2.1", val: true}, + {expr: "2.05 <= 2.1", val: true}, + {expr: "'2.05'<='2.1'", val: true}, + {expr: "'12.05'<='2.1'", val: true}, + // Logical operator + {expr: "!('13.2' < '2.1')", val: false}, + {expr: "(3.2 <= 2.1) &&true", val: false}, + {expr: "true&&(2.1<=2.1)", val: true}, + {expr: "(2.05<=2.1)&&false", val: false}, + {expr: "true&&!true&&false", val: false}, + {expr: "true&&true&&true", val: true}, + {expr: "true&&true&&false", val: false}, + {expr: "false&&true&&true", val: false}, + {expr: "true && false && true", val: false}, + {expr: "true||false", val: true}, + {expr: "false ||true", val: true}, + {expr: "true&&true || false", val: true}, + {expr: "true&&false || false", val: false}, + {expr: "true && false || true ", val: true}, + } + for _, c := range cases { + t.Log(c.expr) + vm, err := parseExpr(c.expr) + if err != nil { + t.Fatal(err) + } + val := vm.run("", nil) + if !reflect.DeepEqual(val, c.val) { + if f, ok := c.val.(float64); ok && math.IsNaN(f) && math.IsNaN(val.(float64)) { + continue + } + t.Fatalf("expr: %q, got: %v, expect: %v", c.expr, val, c.val) + } + } +} + +func TestExprWithEnv(t *testing.T) { + cases := []struct { + expr string + val interface{} + }{ + // env: a = 10, b = "string value", + {expr: "a", val: 10.0}, + {expr: "b", val: "string value"}, + {expr: "a>10", val: false}, + {expr: "a<11", val: true}, + {expr: "a+1", val: 11.0}, + {expr: "a==10", val: true}, + } + + for _, c := range cases { + t.Log(c.expr) + vm, err := parseExpr(c.expr) + if err != nil { + t.Fatal(err) + } + val := vm.runWithEnv("", nil, map[string]interface{}{"a": 10, "b": "string value"}) + if !reflect.DeepEqual(val, c.val) { + if f, ok := c.val.(float64); ok && math.IsNaN(f) && math.IsNaN(val.(float64)) { + continue + } + t.Fatalf("expr: %q, got: %v, expect: %v", c.expr, val, c.val) + } + } +} + +func TestPriority(t *testing.T) { + cases := []struct { + expr string + val interface{} + }{ + {expr: "false||true&&8==8", val: true}, + {expr: "1+2>5-4", val: true}, + {expr: "1+2*4/2", val: 5.0}, + {expr: "(true||false)&&false||false", val: false}, + {expr: "true||false&&false||false", val: true}, + {expr: "true||1<0&&'a'!='a'||0!=0", val: true}, + } + for _, c := range cases { + t.Log(c.expr) + vm, err := parseExpr(c.expr) + if err != nil { + t.Fatal(err) + } + val := vm.run("", nil) + if !reflect.DeepEqual(val, c.val) { + if f, ok := c.val.(float64); ok && math.IsNaN(f) && math.IsNaN(val.(float64)) { + continue + } + t.Fatalf("expr: %q, got: %v, expect: %v", c.expr, val, c.val) + } + } +} + +func TestBuiltInFunc(t *testing.T) { + cases := []struct { + expr string + val interface{} + }{ + {expr: "len('abc')", val: 3.0}, + {expr: "len('abc')+2*2/len('cd')", val: 5.0}, + {expr: "len(0)", val: 0.0}, + + {expr: "regexp('a\\d','a0')", val: true}, + {expr: "regexp('^a\\d$','a0')", val: true}, + {expr: "regexp('a\\d','a')", val: false}, + {expr: "regexp('^a\\d$','a')", val: false}, + + {expr: "sprintf('test string: %s','a')", val: "test string: a"}, + {expr: "sprintf('test string: %s','a'+'b')", val: "test string: ab"}, + {expr: "sprintf('test string: %s,%v','a',1)", val: "test string: a,1"}, + {expr: "sprintf('')+'a'", val: "a"}, + {expr: "sprintf('%v',10+2*2)", val: "14"}, + } + for _, c := range cases { + t.Log(c.expr) + vm, err := parseExpr(c.expr) + if err != nil { + t.Fatal(err) + } + val := vm.run("", nil) + if !reflect.DeepEqual(val, c.val) { + if f, ok := c.val.(float64); ok && math.IsNaN(f) && math.IsNaN(val.(float64)) { + continue + } + t.Fatalf("expr: %q, got: %v, expect: %v", c.expr, val, c.val) + } + } +} + +func TestSyntaxIncorrect(t *testing.T) { + cases := []struct { + incorrectExpr string + }{ + {incorrectExpr: "1 + + 'a'"}, + {incorrectExpr: "regexp()"}, + {incorrectExpr: "regexp('^'+'a','a')"}, + {incorrectExpr: "regexp('^a','a','b')"}, + {incorrectExpr: "sprintf()"}, + {incorrectExpr: "sprintf(0)"}, + {incorrectExpr: "sprintf('a'+'b')"}, + } + for _, c := range cases { + _, err := parseExpr(c.incorrectExpr) + if err == nil { + t.Fatalf("expect syntax incorrect: %s", c.incorrectExpr) + } else { + t.Log(err) + } + } +} diff --git a/internal/tagexpr/handler.go b/internal/tagexpr/handler.go new file mode 100644 index 000000000..a64d6825e --- /dev/null +++ b/internal/tagexpr/handler.go @@ -0,0 +1,148 @@ +// Copyright 2019 Bytedance Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tagexpr + +import "reflect" + +// FieldHandler field handler +type FieldHandler struct { + selector string + field *fieldVM + expr *TagExpr +} + +func newFieldHandler(expr *TagExpr, fieldSelector string, field *fieldVM) *FieldHandler { + return &FieldHandler{ + selector: fieldSelector, + field: field, + expr: expr, + } +} + +// StringSelector returns the field selector of string type. +func (f *FieldHandler) StringSelector() string { + return f.selector +} + +// FieldSelector returns the field selector of FieldSelector type. +func (f *FieldHandler) FieldSelector() FieldSelector { + return FieldSelector(f.selector) +} + +// Value returns the field value. +// NOTE: +// +// If initZero==true, initialize nil pointer to zero value +func (f *FieldHandler) Value(initZero bool) reflect.Value { + return f.field.reflectValueGetter(f.expr.ptr, initZero) +} + +// EvalFuncs returns the tag expression eval functions. +func (f *FieldHandler) EvalFuncs() map[ExprSelector]func() interface{} { + targetTagExpr, _ := f.expr.checkout(f.selector) + evals := make(map[ExprSelector]func() interface{}, len(f.field.exprs)) + for k, v := range f.field.exprs { + expr := v + exprSelector := ExprSelector(k) + evals[exprSelector] = func() interface{} { + return expr.run(exprSelector.Name(), targetTagExpr) + } + } + return evals +} + +// StructField returns the field StructField object. +func (f *FieldHandler) StructField() reflect.StructField { + return f.field.structField +} + +// ExprHandler expr handler +type ExprHandler struct { + base string + path string + selector string + expr *TagExpr + targetExpr *TagExpr +} + +func newExprHandler(te, tte *TagExpr, base, es string) *ExprHandler { + return &ExprHandler{ + base: base, + selector: es, + expr: te, + targetExpr: tte, + } +} + +// TagExpr returns the *TagExpr. +func (e *ExprHandler) TagExpr() *TagExpr { + return e.expr +} + +// StringSelector returns the expression selector of string type. +func (e *ExprHandler) StringSelector() string { + return e.selector +} + +// ExprSelector returns the expression selector of ExprSelector type. +func (e *ExprHandler) ExprSelector() ExprSelector { + return ExprSelector(e.selector) +} + +// Path returns the path description of the expression. +func (e *ExprHandler) Path() string { + if e.path == "" { + if e.targetExpr.path == "" { + e.path = e.selector + } else { + e.path = e.targetExpr.path + FieldSeparator + e.selector + } + } + return e.path +} + +// Eval evaluate the value of the struct tag expression. +// NOTE: +// +// result types: float64, string, bool, nil +func (e *ExprHandler) Eval() interface{} { + return e.expr.s.exprs[e.selector].run(e.base, e.targetExpr) +} + +// EvalFloat evaluates the value of the struct tag expression. +// NOTE: +// +// If the expression value type is not float64, return 0. +func (e *ExprHandler) EvalFloat() float64 { + r, _ := e.Eval().(float64) + return r +} + +// EvalString evaluates the value of the struct tag expression. +// NOTE: +// +// If the expression value type is not string, return "". +func (e *ExprHandler) EvalString() string { + r, _ := e.Eval().(string) + return r +} + +// EvalBool evaluates the value of the struct tag expression. +// NOTE: +// +// If the expression value is not 0, '' or nil, return true. +func (e *ExprHandler) EvalBool() bool { + return FakeBool(e.Eval()) +} diff --git a/internal/tagexpr/selector.go b/internal/tagexpr/selector.go new file mode 100644 index 000000000..7d9e737b2 --- /dev/null +++ b/internal/tagexpr/selector.go @@ -0,0 +1,112 @@ +// Copyright 2019 Bytedance Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tagexpr + +import ( + "strings" +) + +const ( + // FieldSeparator in the expression selector, + // the separator between field names + FieldSeparator = "." + // ExprNameSeparator in the expression selector, + // the separator of the field name and expression name + ExprNameSeparator = "@" + // DefaultExprName the default name of single model expression + DefaultExprName = ExprNameSeparator +) + +// FieldSelector expression selector +type FieldSelector string + +// Name returns the current field name. +func (f FieldSelector) Name() string { + s := string(f) + idx := strings.LastIndex(s, FieldSeparator) + if idx == -1 { + return s + } + return s[idx+1:] +} + +// Split returns the path segments and the current field name. +func (f FieldSelector) Split() (paths []string, name string) { + s := string(f) + a := strings.Split(s, FieldSeparator) + idx := len(a) - 1 + if idx > 0 { + return a[:idx], a[idx] + } + return nil, s +} + +// Parent returns the parent FieldSelector. +func (f FieldSelector) Parent() (string, bool) { + s := string(f) + i := strings.LastIndex(s, FieldSeparator) + if i < 0 { + return "", false + } + return s[:i], true +} + +// String returns string type value. +func (f FieldSelector) String() string { + return string(f) +} + +// ExprSelector expression selector +type ExprSelector string + +// Name returns the name of the expression. +func (e ExprSelector) Name() string { + s := string(e) + atIdx := strings.LastIndex(s, ExprNameSeparator) + if atIdx == -1 { + return DefaultExprName + } + return s[atIdx+1:] +} + +// Field returns the field selector it belongs to. +func (e ExprSelector) Field() string { + s := string(e) + idx := strings.LastIndex(s, ExprNameSeparator) + if idx != -1 { + s = s[:idx] + } + return s +} + +// ParentField returns the parent field selector it belongs to. +func (e ExprSelector) ParentField() (string, bool) { + return FieldSelector(e.Field()).Parent() +} + +// Split returns the field selector and the expression name. +func (e ExprSelector) Split() (field FieldSelector, name string) { + s := string(e) + atIdx := strings.LastIndex(s, ExprNameSeparator) + if atIdx == -1 { + return FieldSelector(s), DefaultExprName + } + return FieldSelector(s[:atIdx]), s[atIdx+1:] +} + +// String returns string type value. +func (e ExprSelector) String() string { + return string(e) +} diff --git a/internal/tagexpr/selector_test.go b/internal/tagexpr/selector_test.go new file mode 100644 index 000000000..bfb5dd4f7 --- /dev/null +++ b/internal/tagexpr/selector_test.go @@ -0,0 +1,30 @@ +// Copyright 2019 Bytedance Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tagexpr + +import ( + "testing" +) + +func TestExprSelector(t *testing.T) { + es := ExprSelector("F1.Index") + field, ok := es.ParentField() + if !ok { + t.Fatal("not ok") + } + if "F1" != field { + t.Fatal(field) + } +} diff --git a/internal/tagexpr/spec_func.go b/internal/tagexpr/spec_func.go new file mode 100644 index 000000000..8e6291da8 --- /dev/null +++ b/internal/tagexpr/spec_func.go @@ -0,0 +1,342 @@ +// Copyright 2019 Bytedance Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tagexpr + +import ( + "context" + "fmt" + "reflect" + "regexp" + "strings" +) + +// --------------------------- Custom function --------------------------- + +var funcList = map[string]func(p *Expr, expr *string) ExprNode{} + +// MustRegFunc registers function expression. +// NOTE: +// +// example: len($), regexp("\\d") or regexp("\\d",$); +// If @force=true, allow to cover the existed same @funcName; +// The go number types always are float64; +// The go string types always are string; +// Panic if there is an error. +func MustRegFunc(funcName string, fn func(...interface{}) interface{}, force ...bool) { + err := RegFunc(funcName, fn, force...) + if err != nil { + panic(err) + } +} + +// RegFunc registers function expression. +// NOTE: +// +// example: len($), regexp("\\d") or regexp("\\d",$); +// If @force=true, allow to cover the existed same @funcName; +// The go number types always are float64; +// The go string types always are string. +func RegFunc(funcName string, fn func(...interface{}) interface{}, force ...bool) error { + if len(force) == 0 || !force[0] { + _, ok := funcList[funcName] + if ok { + return fmt.Errorf("duplicate registration expression function: %s", funcName) + } + } + funcList[funcName] = newFunc(funcName, fn) + return nil +} + +func (p *Expr) parseFuncSign(funcName string, expr *string) (boolOpposite *bool, signOpposite *bool, args []ExprNode, found bool) { + prefix := funcName + "(" + length := len(funcName) + last, boolOpposite, signOpposite := getBoolAndSignOpposite(expr) + if !strings.HasPrefix(last, prefix) { + return + } + *expr = last[length:] + lastStr := *expr + subExprNode := readPairedSymbol(expr, '(', ')') + if subExprNode == nil { + return + } + *subExprNode = "," + *subExprNode + for { + if strings.HasPrefix(*subExprNode, ",") { + *subExprNode = (*subExprNode)[1:] + operand := newGroupExprNode() + err := p.parseExprNode(trimLeftSpace(subExprNode), operand) + if err != nil { + *expr = lastStr + return + } + sortPriority(operand) + args = append(args, operand) + } else { + *expr = lastStr + return + } + trimLeftSpace(subExprNode) + if len(*subExprNode) == 0 { + found = true + return + } + } +} + +func newFunc(funcName string, fn func(...interface{}) interface{}) func(*Expr, *string) ExprNode { + return func(p *Expr, expr *string) ExprNode { + boolOpposite, signOpposite, args, found := p.parseFuncSign(funcName, expr) + if !found { + return nil + } + return &funcExprNode{ + fn: fn, + boolOpposite: boolOpposite, + signOpposite: signOpposite, + args: args, + } + } +} + +type funcExprNode struct { + exprBackground + args []ExprNode + fn func(...interface{}) interface{} + boolOpposite *bool + signOpposite *bool +} + +func (f *funcExprNode) String() string { + return "func()" +} + +func (f *funcExprNode) Run(ctx context.Context, currField string, tagExpr *TagExpr) interface{} { + var args []interface{} + if n := len(f.args); n > 0 { + args = make([]interface{}, n) + for k, v := range f.args { + args[k] = v.Run(ctx, currField, tagExpr) + } + } + return realValue(f.fn(args...), f.boolOpposite, f.signOpposite) +} + +// --------------------------- Built-in function --------------------------- +func init() { + funcList["regexp"] = readRegexpFuncExprNode + funcList["sprintf"] = readSprintfFuncExprNode + funcList["range"] = readRangeFuncExprNode + // len: Built-in function len, the length of struct field X + MustRegFunc("len", func(args ...interface{}) (n interface{}) { + if len(args) != 1 { + return 0 + } + v := args[0] + switch e := v.(type) { + case string: + return float64(len(e)) + case float64, bool, nil: + return 0 + } + defer func() { + if recover() != nil { + n = 0 + } + }() + return float64(reflect.ValueOf(v).Len()) + }, true) + // mblen: get the length of string field X (character number) + MustRegFunc("mblen", func(args ...interface{}) (n interface{}) { + if len(args) != 1 { + return 0 + } + v := args[0] + switch e := v.(type) { + case string: + return float64(len([]rune(e))) + case float64, bool, nil: + return 0 + } + defer func() { + if recover() != nil { + n = 0 + } + }() + return float64(reflect.ValueOf(v).Len()) + }, true) + + // in: Check if the first parameter is one of the enumerated parameters + MustRegFunc("in", func(args ...interface{}) interface{} { + switch len(args) { + case 0: + return true + case 1: + return false + default: + elem := args[0] + set := args[1:] + for _, e := range set { + if elem == e { + return true + } + } + return false + } + }, true) +} + +type regexpFuncExprNode struct { + exprBackground + re *regexp.Regexp + boolOpposite bool +} + +func (re *regexpFuncExprNode) String() string { + return "regexp()" +} + +func readRegexpFuncExprNode(p *Expr, expr *string) ExprNode { + last, boolOpposite, _ := getBoolAndSignOpposite(expr) + if !strings.HasPrefix(last, "regexp(") { + return nil + } + *expr = last[6:] + lastStr := *expr + subExprNode := readPairedSymbol(expr, '(', ')') + if subExprNode == nil { + return nil + } + s := readPairedSymbol(trimLeftSpace(subExprNode), '\'', '\'') + if s == nil { + *expr = lastStr + return nil + } + rege, err := regexp.Compile(*s) + if err != nil { + *expr = lastStr + return nil + } + operand := newGroupExprNode() + trimLeftSpace(subExprNode) + if strings.HasPrefix(*subExprNode, ",") { + *subExprNode = (*subExprNode)[1:] + err = p.parseExprNode(trimLeftSpace(subExprNode), operand) + if err != nil { + *expr = lastStr + return nil + } + } else { + currFieldVal := "$" + p.parseExprNode(&currFieldVal, operand) + } + trimLeftSpace(subExprNode) + if *subExprNode != "" { + *expr = lastStr + return nil + } + e := ®expFuncExprNode{ + re: rege, + } + if boolOpposite != nil { + e.boolOpposite = *boolOpposite + } + e.SetRightOperand(operand) + return e +} + +func (re *regexpFuncExprNode) Run(ctx context.Context, currField string, tagExpr *TagExpr) interface{} { + param := re.rightOperand.Run(ctx, currField, tagExpr) + switch v := param.(type) { + case string: + bol := re.re.MatchString(v) + if re.boolOpposite { + return !bol + } + return bol + case float64, bool: + return false + } + v := reflect.ValueOf(param) + if v.Kind() == reflect.String { + bol := re.re.MatchString(v.String()) + if re.boolOpposite { + return !bol + } + return bol + } + return false +} + +type sprintfFuncExprNode struct { + exprBackground + format string + args []ExprNode +} + +func (se *sprintfFuncExprNode) String() string { + return "sprintf()" +} + +func readSprintfFuncExprNode(p *Expr, expr *string) ExprNode { + if !strings.HasPrefix(*expr, "sprintf(") { + return nil + } + *expr = (*expr)[7:] + lastStr := *expr + subExprNode := readPairedSymbol(expr, '(', ')') + if subExprNode == nil { + return nil + } + format := readPairedSymbol(trimLeftSpace(subExprNode), '\'', '\'') + if format == nil { + *expr = lastStr + return nil + } + e := &sprintfFuncExprNode{ + format: *format, + } + for { + trimLeftSpace(subExprNode) + if len(*subExprNode) == 0 { + return e + } + if strings.HasPrefix(*subExprNode, ",") { + *subExprNode = (*subExprNode)[1:] + operand := newGroupExprNode() + err := p.parseExprNode(trimLeftSpace(subExprNode), operand) + if err != nil { + *expr = lastStr + return nil + } + sortPriority(operand) + e.args = append(e.args, operand) + } else { + *expr = lastStr + return nil + } + } +} + +func (se *sprintfFuncExprNode) Run(ctx context.Context, currField string, tagExpr *TagExpr) interface{} { + var args []interface{} + if n := len(se.args); n > 0 { + args = make([]interface{}, n) + for i, e := range se.args { + args[i] = e.Run(ctx, currField, tagExpr) + } + } + return fmt.Sprintf(se.format, args...) +} diff --git a/internal/tagexpr/spec_func_test.go b/internal/tagexpr/spec_func_test.go new file mode 100644 index 000000000..3d66a49bc --- /dev/null +++ b/internal/tagexpr/spec_func_test.go @@ -0,0 +1,104 @@ +// Copyright 2019 Bytedance Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tagexpr_test + +import ( + "reflect" + "regexp" + "testing" + + "github.com/cloudwego/hertz/internal/tagexpr" +) + +func TestFunc(t *testing.T) { + emailRegexp := regexp.MustCompile( + "^([A-Za-z0-9_\\-\\.\u4e00-\u9fa5])+\\@([A-Za-z0-9_\\-\\.])+\\.([A-Za-z]{2,8})$", + ) + tagexpr.RegFunc("email", func(args ...interface{}) interface{} { + if len(args) == 0 { + return false + } + s, ok := args[0].(string) + if !ok { + return false + } + t.Log(s) + return emailRegexp.MatchString(s) + }) + + vm := tagexpr.New("te") + + type T struct { + Email string `te:"email($)"` + } + cases := []struct { + email string + expect bool + }{ + {"", false}, + {"henrylee2cn@gmail.com", true}, + } + + obj := new(T) + for _, c := range cases { + obj.Email = c.email + te := vm.MustRun(obj) + got := te.EvalBool("Email") + if got != c.expect { + t.Fatalf("email: %s, expect: %v, but got: %v", c.email, c.expect, got) + } + } + + // test len + type R struct { + Str string `vd:"mblen($)<6"` + } + lenCases := []struct { + str string + expect bool + }{ + {"123", true}, + {"一二三四五六七", false}, + {"一二三四五", true}, + } + + lenObj := new(R) + vm = tagexpr.New("vd") + for _, lenCase := range lenCases { + lenObj.Str = lenCase.str + te := vm.MustRun(lenObj) + got := te.EvalBool("Str") + if got != lenCase.expect { + t.Fatalf("string: %v, expect: %v, but got: %v", lenCase.str, lenCase.expect, got) + } + } +} + +func TestRangeIn(t *testing.T) { + vm := tagexpr.New("te") + type S struct { + F []string `te:"range($, in(#v, '', 'ttp', 'euttp'))"` + } + a := []string{"ttp", "", "euttp"} + r := vm.MustRun(S{ + F: a, + // F: b, + }) + expect := []interface{}{true, true, true} + actual := r.Eval("F") + if !reflect.DeepEqual(expect, actual) { + t.Fatal("not equal", expect, actual) + } +} diff --git a/internal/tagexpr/spec_operand.go b/internal/tagexpr/spec_operand.go new file mode 100644 index 000000000..b14b05515 --- /dev/null +++ b/internal/tagexpr/spec_operand.go @@ -0,0 +1,363 @@ +// Copyright 2019 Bytedance Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tagexpr + +import ( + "context" + "fmt" + "reflect" + "regexp" + "strconv" + "strings" +) + +// --------------------------- Operand --------------------------- + +type groupExprNode struct { + exprBackground + boolOpposite *bool + signOpposite *bool +} + +func newGroupExprNode() ExprNode { return &groupExprNode{} } + +func readGroupExprNode(expr *string) (grp ExprNode, subExprNode *string) { + last, boolOpposite, signOpposite := getBoolAndSignOpposite(expr) + sptr := readPairedSymbol(&last, '(', ')') + if sptr == nil { + return nil, nil + } + *expr = last + e := &groupExprNode{boolOpposite: boolOpposite, signOpposite: signOpposite} + return e, sptr +} + +func (ge *groupExprNode) String() string { + return "()" +} + +func (ge *groupExprNode) Run(ctx context.Context, currField string, tagExpr *TagExpr) interface{} { + if ge.rightOperand == nil { + return nil + } + return realValue(ge.rightOperand.Run(ctx, currField, tagExpr), ge.boolOpposite, ge.signOpposite) +} + +type boolExprNode struct { + exprBackground + val bool +} + +func (be *boolExprNode) String() string { + return fmt.Sprintf("%v", be.val) +} + +var boolRegexp = regexp.MustCompile(`^!*(true|false)([\)\],\|&!= \t]{1}|$)`) + +func readBoolExprNode(expr *string) ExprNode { + s := boolRegexp.FindString(*expr) + if s == "" { + return nil + } + last := s[len(s)-1] + if last != 'e' { + s = s[:len(s)-1] + } + *expr = (*expr)[len(s):] + e := &boolExprNode{} + if strings.Contains(s, "t") { + e.val = (len(s)-4)&1 == 0 + } else { + e.val = (len(s)-5)&1 == 1 + } + return e +} + +func (be *boolExprNode) Run(ctx context.Context, currField string, tagExpr *TagExpr) interface{} { + return be.val +} + +type stringExprNode struct { + exprBackground + val interface{} +} + +func (se *stringExprNode) String() string { + return fmt.Sprintf("%v", se.val) +} + +func readStringExprNode(expr *string) ExprNode { + last, boolOpposite, _ := getBoolAndSignOpposite(expr) + sptr := readPairedSymbol(&last, '\'', '\'') + if sptr == nil { + return nil + } + *expr = last + e := &stringExprNode{val: realValue(*sptr, boolOpposite, nil)} + return e +} + +func (se *stringExprNode) Run(ctx context.Context, currField string, tagExpr *TagExpr) interface{} { + return se.val +} + +type digitalExprNode struct { + exprBackground + val interface{} +} + +func (de *digitalExprNode) String() string { + return fmt.Sprintf("%v", de.val) +} + +var digitalRegexp = regexp.MustCompile(`^[\+\-]?\d+(\.\d+)?([\)\],\+\-\*\/%><\|&!=\^ \t\\]|$)`) + +func readDigitalExprNode(expr *string) ExprNode { + last, boolOpposite := getOpposite(expr, "!") + s := digitalRegexp.FindString(last) + if s == "" { + return nil + } + if r := s[len(s)-1]; r < '0' || r > '9' { + s = s[:len(s)-1] + } + *expr = last[len(s):] + f64, _ := strconv.ParseFloat(s, 64) + return &digitalExprNode{val: realValue(f64, boolOpposite, nil)} +} + +func (de *digitalExprNode) Run(ctx context.Context, currField string, tagExpr *TagExpr) interface{} { + return de.val +} + +type nilExprNode struct { + exprBackground + val interface{} +} + +func (ne *nilExprNode) String() string { + return "" +} + +var nilRegexp = regexp.MustCompile(`^nil([\)\],\|&!= \t]{1}|$)`) + +func readNilExprNode(expr *string) ExprNode { + last, boolOpposite := getOpposite(expr, "!") + s := nilRegexp.FindString(last) + if s == "" { + return nil + } + *expr = last[3:] + return &nilExprNode{val: realValue(nil, boolOpposite, nil)} +} + +func (ne *nilExprNode) Run(ctx context.Context, currField string, tagExpr *TagExpr) interface{} { + return ne.val +} + +type variableExprNode struct { + exprBackground + boolOpposite *bool + val string +} + +func (ve *variableExprNode) String() string { + return fmt.Sprintf("%v", ve.val) +} + +func (ve *variableExprNode) Run(ctx context.Context, variableName string, _ *TagExpr) interface{} { + envObj := ctx.Value(variableKey) + if envObj == nil { + return nil + } + + env := envObj.(map[string]interface{}) + if len(env) == 0 { + return nil + } + + if value, ok := env[ve.val]; ok && value != nil { + return realValue(value, ve.boolOpposite, nil) + } else { + return nil + } +} + +var variableRegex = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*`) + +func readVariableExprNode(expr *string) ExprNode { + last, boolOpposite := getOpposite(expr, "!") + variable := variableRegex.FindString(last) + if variable == "" { + return nil + } + + *expr = (*expr)[len(*expr)-len(last)+len(variable):] + + return &variableExprNode{ + val: variable, + boolOpposite: boolOpposite, + } +} + +func getBoolAndSignOpposite(expr *string) (last string, boolOpposite *bool, signOpposite *bool) { + last, boolOpposite = getOpposite(expr, "!") + last = strings.TrimLeft(last, "+") + last, signOpposite = getOpposite(&last, "-") + last = strings.TrimLeft(last, "+") + return +} + +func getOpposite(expr *string, cutset string) (string, *bool) { + last := strings.TrimLeft(*expr, cutset) + n := len(*expr) - len(last) + if n == 0 { + return last, nil + } + bol := n&1 == 1 + return last, &bol +} + +func toString(i interface{}, enforce bool) (string, bool) { + switch vv := i.(type) { + case string: + return vv, true + case nil: + return "", false + default: + rv := dereferenceValue(reflect.ValueOf(i)) + if rv.Kind() == reflect.String { + return rv.String(), true + } + if enforce { + if rv.IsValid() && rv.CanInterface() { + return fmt.Sprint(rv.Interface()), true + } else { + return fmt.Sprint(i), true + } + } + } + return "", false +} + +func toFloat64(i interface{}, tryParse bool) (float64, bool) { + var v float64 + ok := true + switch t := i.(type) { + case float64: + v = t + case float32: + v = float64(t) + case int: + v = float64(t) + case int8: + v = float64(t) + case int16: + v = float64(t) + case int32: + v = float64(t) + case int64: + v = float64(t) + case uint: + v = float64(t) + case uint8: + v = float64(t) + case uint16: + v = float64(t) + case uint32: + v = float64(t) + case uint64: + v = float64(t) + case nil: + ok = false + default: + rv := dereferenceValue(reflect.ValueOf(t)) + switch rv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + v = float64(rv.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + v = float64(rv.Uint()) + case reflect.Float32, reflect.Float64: + v = rv.Float() + default: + if tryParse { + if s, ok := toString(i, false); ok { + var err error + v, err = strconv.ParseFloat(s, 64) + return v, err == nil + } + } + ok = false + } + } + return v, ok +} + +func realValue(v interface{}, boolOpposite *bool, signOpposite *bool) interface{} { + if boolOpposite != nil { + bol := FakeBool(v) + if *boolOpposite { + return !bol + } + return bol + } + switch t := v.(type) { + case float64, string: + case float32: + v = float64(t) + case int: + v = float64(t) + case int8: + v = float64(t) + case int16: + v = float64(t) + case int32: + v = float64(t) + case int64: + v = float64(t) + case uint: + v = float64(t) + case uint8: + v = float64(t) + case uint16: + v = float64(t) + case uint32: + v = float64(t) + case uint64: + v = float64(t) + case []interface{}: + for k, v := range t { + t[k] = realValue(v, boolOpposite, signOpposite) + } + default: + rv := dereferenceValue(reflect.ValueOf(v)) + switch rv.Kind() { + case reflect.String: + v = rv.String() + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + v = float64(rv.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + v = float64(rv.Uint()) + case reflect.Float32, reflect.Float64: + v = rv.Float() + } + } + if signOpposite != nil && *signOpposite { + if f, ok := v.(float64); ok { + v = -f + } + } + return v +} diff --git a/internal/tagexpr/spec_operator.go b/internal/tagexpr/spec_operator.go new file mode 100644 index 000000000..72c4fc6c0 --- /dev/null +++ b/internal/tagexpr/spec_operator.go @@ -0,0 +1,290 @@ +// Copyright 2019 Bytedance Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tagexpr + +import ( + "context" + "math" +) + +// --------------------------- Operator --------------------------- + +type additionExprNode struct{ exprBackground } + +func (ae *additionExprNode) String() string { + return "+" +} + +func newAdditionExprNode() ExprNode { return &additionExprNode{} } + +func (ae *additionExprNode) Run(ctx context.Context, currField string, tagExpr *TagExpr) interface{} { + // positive number or Addition + v0 := ae.leftOperand.Run(ctx, currField, tagExpr) + v1 := ae.rightOperand.Run(ctx, currField, tagExpr) + if s0, ok := toFloat64(v0, false); ok { + s1, _ := toFloat64(v1, true) + return s0 + s1 + } + if s0, ok := toString(v0, false); ok { + s1, _ := toString(v1, true) + return s0 + s1 + } + return v0 +} + +type multiplicationExprNode struct{ exprBackground } + +func (ae *multiplicationExprNode) String() string { + return "*" +} + +func newMultiplicationExprNode() ExprNode { return &multiplicationExprNode{} } + +func (ae *multiplicationExprNode) Run(ctx context.Context, currField string, tagExpr *TagExpr) interface{} { + v0, _ := toFloat64(ae.leftOperand.Run(ctx, currField, tagExpr), true) + v1, _ := toFloat64(ae.rightOperand.Run(ctx, currField, tagExpr), true) + return v0 * v1 +} + +type divisionExprNode struct{ exprBackground } + +func (de *divisionExprNode) String() string { + return "/" +} + +func newDivisionExprNode() ExprNode { return &divisionExprNode{} } + +func (de *divisionExprNode) Run(ctx context.Context, currField string, tagExpr *TagExpr) interface{} { + v1, _ := toFloat64(de.rightOperand.Run(ctx, currField, tagExpr), true) + if v1 == 0 { + return math.NaN() + } + v0, _ := toFloat64(de.leftOperand.Run(ctx, currField, tagExpr), true) + return v0 / v1 +} + +type subtractionExprNode struct{ exprBackground } + +func (de *subtractionExprNode) String() string { + return "-" +} + +func newSubtractionExprNode() ExprNode { return &subtractionExprNode{} } + +func (de *subtractionExprNode) Run(ctx context.Context, currField string, tagExpr *TagExpr) interface{} { + v0, _ := toFloat64(de.leftOperand.Run(ctx, currField, tagExpr), true) + v1, _ := toFloat64(de.rightOperand.Run(ctx, currField, tagExpr), true) + return v0 - v1 +} + +type remainderExprNode struct{ exprBackground } + +func (re *remainderExprNode) String() string { + return "%" +} + +func newRemainderExprNode() ExprNode { return &remainderExprNode{} } + +func (re *remainderExprNode) Run(ctx context.Context, currField string, tagExpr *TagExpr) interface{} { + v1, _ := toFloat64(re.rightOperand.Run(ctx, currField, tagExpr), true) + if v1 == 0 { + return math.NaN() + } + v0, _ := toFloat64(re.leftOperand.Run(ctx, currField, tagExpr), true) + return float64(int64(v0) % int64(v1)) +} + +type equalExprNode struct{ exprBackground } + +func (ee *equalExprNode) String() string { + return "==" +} + +func newEqualExprNode() ExprNode { return &equalExprNode{} } + +func (ee *equalExprNode) Run(ctx context.Context, currField string, tagExpr *TagExpr) interface{} { + v0 := ee.leftOperand.Run(ctx, currField, tagExpr) + v1 := ee.rightOperand.Run(ctx, currField, tagExpr) + if v0 == v1 { + return true + } + if s0, ok := toFloat64(v0, false); ok { + if s1, ok := toFloat64(v1, true); ok { + return s0 == s1 + } + } + if s0, ok := toString(v0, false); ok { + if s1, ok := toString(v1, true); ok { + return s0 == s1 + } + return false + } + switch r := v0.(type) { + case bool: + r1, ok := v1.(bool) + if ok { + return r == r1 + } + case nil: + return v1 == nil + } + return false +} + +type notEqualExprNode struct{ equalExprNode } + +func (ne *notEqualExprNode) String() string { + return "!=" +} + +func newNotEqualExprNode() ExprNode { return ¬EqualExprNode{} } + +func (ne *notEqualExprNode) Run(ctx context.Context, currField string, tagExpr *TagExpr) interface{} { + return !ne.equalExprNode.Run(ctx, currField, tagExpr).(bool) +} + +type greaterExprNode struct{ exprBackground } + +func (ge *greaterExprNode) String() string { + return ">" +} + +func newGreaterExprNode() ExprNode { return &greaterExprNode{} } + +func (ge *greaterExprNode) Run(ctx context.Context, currField string, tagExpr *TagExpr) interface{} { + v0 := ge.leftOperand.Run(ctx, currField, tagExpr) + v1 := ge.rightOperand.Run(ctx, currField, tagExpr) + if s0, ok := toFloat64(v0, false); ok { + if s1, ok := toFloat64(v1, true); ok { + return s0 > s1 + } + } + if s0, ok := toString(v0, false); ok { + if s1, ok := toString(v1, true); ok { + return s0 > s1 + } + return false + } + return false +} + +type greaterEqualExprNode struct{ exprBackground } + +func (ge *greaterEqualExprNode) String() string { + return ">=" +} + +func newGreaterEqualExprNode() ExprNode { return &greaterEqualExprNode{} } + +func (ge *greaterEqualExprNode) Run(ctx context.Context, currField string, tagExpr *TagExpr) interface{} { + v0 := ge.leftOperand.Run(ctx, currField, tagExpr) + v1 := ge.rightOperand.Run(ctx, currField, tagExpr) + if s0, ok := toFloat64(v0, false); ok { + if s1, ok := toFloat64(v1, true); ok { + return s0 >= s1 + } + } + if s0, ok := toString(v0, false); ok { + if s1, ok := toString(v1, true); ok { + return s0 >= s1 + } + return false + } + return false +} + +type lessExprNode struct{ exprBackground } + +func (le *lessExprNode) String() string { + return "<" +} + +func newLessExprNode() ExprNode { return &lessExprNode{} } + +func (le *lessExprNode) Run(ctx context.Context, currField string, tagExpr *TagExpr) interface{} { + v0 := le.leftOperand.Run(ctx, currField, tagExpr) + v1 := le.rightOperand.Run(ctx, currField, tagExpr) + if s0, ok := toFloat64(v0, false); ok { + if s1, ok := toFloat64(v1, true); ok { + return s0 < s1 + } + } + if s0, ok := toString(v0, false); ok { + if s1, ok := toString(v1, true); ok { + return s0 < s1 + } + return false + } + return false +} + +type lessEqualExprNode struct{ exprBackground } + +func (le *lessEqualExprNode) String() string { + return "<=" +} + +func newLessEqualExprNode() ExprNode { return &lessEqualExprNode{} } + +func (le *lessEqualExprNode) Run(ctx context.Context, currField string, tagExpr *TagExpr) interface{} { + v0 := le.leftOperand.Run(ctx, currField, tagExpr) + v1 := le.rightOperand.Run(ctx, currField, tagExpr) + if s0, ok := toFloat64(v0, false); ok { + if s1, ok := toFloat64(v1, true); ok { + return s0 <= s1 + } + } + if s0, ok := toString(v0, false); ok { + if s1, ok := toString(v1, true); ok { + return s0 <= s1 + } + return false + } + return false +} + +type andExprNode struct{ exprBackground } + +func (ae *andExprNode) String() string { + return "&&" +} + +func newAndExprNode() ExprNode { return &andExprNode{} } + +func (ae *andExprNode) Run(ctx context.Context, currField string, tagExpr *TagExpr) interface{} { + for _, e := range [2]ExprNode{ae.leftOperand, ae.rightOperand} { + if !FakeBool(e.Run(ctx, currField, tagExpr)) { + return false + } + } + return true +} + +type orExprNode struct{ exprBackground } + +func (oe *orExprNode) String() string { + return "||" +} + +func newOrExprNode() ExprNode { return &orExprNode{} } + +func (oe *orExprNode) Run(ctx context.Context, currField string, tagExpr *TagExpr) interface{} { + for _, e := range [2]ExprNode{oe.leftOperand, oe.rightOperand} { + if FakeBool(e.Run(ctx, currField, tagExpr)) { + return true + } + } + return false +} diff --git a/internal/tagexpr/spec_range.go b/internal/tagexpr/spec_range.go new file mode 100644 index 000000000..dcb88722a --- /dev/null +++ b/internal/tagexpr/spec_range.go @@ -0,0 +1,164 @@ +// Copyright 2019 Bytedance Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tagexpr + +import ( + "context" + "reflect" + "regexp" +) + +type rangeCtxKey string + +const ( + rangeKey rangeCtxKey = "#k" + rangeValue rangeCtxKey = "#v" + rangeLen rangeCtxKey = "##" +) + +type rangeKvExprNode struct { + exprBackground + ctxKey rangeCtxKey + boolOpposite *bool + signOpposite *bool +} + +func (re *rangeKvExprNode) String() string { + return string(re.ctxKey) +} + +func (p *Expr) readRangeKvExprNode(expr *string) ExprNode { + name, boolOpposite, signOpposite, found := findRangeKv(expr) + if !found { + return nil + } + operand := &rangeKvExprNode{ + ctxKey: rangeCtxKey(name), + boolOpposite: boolOpposite, + signOpposite: signOpposite, + } + // fmt.Printf("operand: %#v\n", operand) + return operand +} + +var rangeKvRegexp = regexp.MustCompile(`^([\!\+\-]*)(#[kv#])([\)\[\],\+\-\*\/%><\|&!=\^ \t\\]|$)`) + +func findRangeKv(expr *string) (name string, boolOpposite, signOpposite *bool, found bool) { + raw := *expr + a := rangeKvRegexp.FindAllStringSubmatch(raw, -1) + if len(a) != 1 { + return + } + r := a[0] + name = r[2] + *expr = (*expr)[len(a[0][0])-len(r[3]):] + prefix := r[1] + if len(prefix) == 0 { + found = true + return + } + _, boolOpposite, signOpposite = getBoolAndSignOpposite(&prefix) + found = true + return +} + +func (re *rangeKvExprNode) Run(ctx context.Context, _ string, _ *TagExpr) interface{} { + var v interface{} + switch val := ctx.Value(re.ctxKey).(type) { + case reflect.Value: + if !val.IsValid() || !val.CanInterface() { + return nil + } + v = val.Interface() + default: + v = val + } + return realValue(v, re.boolOpposite, re.signOpposite) +} + +type rangeFuncExprNode struct { + exprBackground + object ExprNode + elemExprNode ExprNode + boolOpposite *bool + signOpposite *bool +} + +func (e *rangeFuncExprNode) String() string { + return "range()" +} + +// range($, gt($v,10)) +// range($, $v>10) +func readRangeFuncExprNode(p *Expr, expr *string) ExprNode { + boolOpposite, signOpposite, args, found := p.parseFuncSign("range", expr) + if !found { + return nil + } + if len(args) != 2 { + return nil + } + return &rangeFuncExprNode{ + boolOpposite: boolOpposite, + signOpposite: signOpposite, + object: args[0], + elemExprNode: args[1], + } +} + +func (e *rangeFuncExprNode) Run(ctx context.Context, currField string, tagExpr *TagExpr) interface{} { + var r []interface{} + obj := e.object.Run(ctx, currField, tagExpr) + // fmt.Printf("%v\n", obj) + objval := reflect.ValueOf(obj) + switch objval.Kind() { + case reflect.Array, reflect.Slice: + count := objval.Len() + r = make([]interface{}, count) + ctx = context.WithValue(ctx, rangeLen, count) + for i := 0; i < count; i++ { + // fmt.Printf("%#v, (%v)\n", e.elemExprNode, objval.Index(i)) + r[i] = realValue(e.elemExprNode.Run( + context.WithValue( + context.WithValue( + ctx, + rangeKey, i, + ), + rangeValue, objval.Index(i), + ), + currField, tagExpr, + ), e.boolOpposite, e.signOpposite) + } + case reflect.Map: + keys := objval.MapKeys() + count := len(keys) + r = make([]interface{}, count) + ctx = context.WithValue(ctx, rangeLen, count) + for i, key := range keys { + r[i] = realValue(e.elemExprNode.Run( + context.WithValue( + context.WithValue( + ctx, + rangeKey, key, + ), + rangeValue, objval.MapIndex(key), + ), + currField, tagExpr, + ), e.boolOpposite, e.signOpposite) + } + default: + } + return r +} diff --git a/internal/tagexpr/spec_range_test.go b/internal/tagexpr/spec_range_test.go new file mode 100644 index 000000000..bc4f73a36 --- /dev/null +++ b/internal/tagexpr/spec_range_test.go @@ -0,0 +1,54 @@ +// Copyright 2019 Bytedance Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tagexpr_test + +import ( + "reflect" + "testing" + + "github.com/cloudwego/hertz/internal/tagexpr" +) + +func TestIssue12(t *testing.T) { + vm := tagexpr.New("te") + type I int + type S struct { + F []I `te:"range($, '>'+sprintf('%v:%v', #k, #v+2+len($)))"` + Fs [][]I `te:"range($, range(#v, '>'+sprintf('%v:%v', #k, #v+2+##)))"` + M map[string]I `te:"range($, '>'+sprintf('%s:%v', #k, #v+2+##))"` + MFs []map[string][]I `te:"range($, range(#v, range(#v, '>'+sprintf('%v:%v', #k, #v+2+##))))"` + MFs2 []map[string][]I `te:"range($, range(#v, range(#v, '>'+sprintf('%v:%v', #k, #v+2+##))))"` + } + a := []I{2, 3} + r := vm.MustRun(S{ + F: a, + Fs: [][]I{a}, + M: map[string]I{"m0": 2, "m1": 3}, + MFs: []map[string][]I{{"m": a}}, + MFs2: []map[string][]I{}, + }) + assertEqual(t, []interface{}{">0:6", ">1:7"}, r.Eval("F")) + assertEqual(t, []interface{}{[]interface{}{">0:6", ">1:7"}}, r.Eval("Fs")) + assertEqual(t, []interface{}{[]interface{}{[]interface{}{">0:6", ">1:7"}}}, r.Eval("MFs")) + assertEqual(t, []interface{}{}, r.Eval("MFs2")) + assertEqual(t, true, r.EvalBool("MFs2")) + + // result may not stable for map + got := r.Eval("M") + if !reflect.DeepEqual([]interface{}{">m0:6", ">m1:7"}, got) && + !reflect.DeepEqual([]interface{}{">m1:7", ">m0:6"}, got) { + t.Fatal(got) + } +} diff --git a/internal/tagexpr/spec_selector.go b/internal/tagexpr/spec_selector.go new file mode 100644 index 000000000..7e00990b1 --- /dev/null +++ b/internal/tagexpr/spec_selector.go @@ -0,0 +1,109 @@ +// Copyright 2019 Bytedance Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tagexpr + +import ( + "context" + "fmt" + "regexp" + "strings" +) + +type selectorExprNode struct { + exprBackground + field, name string + subExprs []ExprNode + boolOpposite *bool + signOpposite *bool +} + +func (se *selectorExprNode) String() string { + return fmt.Sprintf("(%s)%s", se.field, se.name) +} + +func (p *Expr) readSelectorExprNode(expr *string) ExprNode { + field, name, subSelector, boolOpposite, signOpposite, found := findSelector(expr) + if !found { + return nil + } + operand := &selectorExprNode{ + field: field, + name: name, + boolOpposite: boolOpposite, + signOpposite: signOpposite, + } + operand.subExprs = make([]ExprNode, 0, len(subSelector)) + for _, s := range subSelector { + grp := newGroupExprNode() + err := p.parseExprNode(&s, grp) + if err != nil { + return nil + } + sortPriority(grp) + operand.subExprs = append(operand.subExprs, grp) + } + return operand +} + +var selectorRegexp = regexp.MustCompile(`^([\!\+\-]*)(\([ \t]*[A-Za-z_]+[A-Za-z0-9_\.]*[ \t]*\))?(\$)([\)\[\],\+\-\*\/%><\|&!=\^ \t\\]|$)`) + +func findSelector(expr *string) (field string, name string, subSelector []string, boolOpposite, signOpposite *bool, found bool) { + raw := *expr + a := selectorRegexp.FindAllStringSubmatch(raw, -1) + if len(a) != 1 { + return + } + r := a[0] + if s0 := r[2]; len(s0) > 0 { + field = strings.TrimSpace(s0[1 : len(s0)-1]) + } + name = r[3] + *expr = (*expr)[len(a[0][0])-len(r[4]):] + for { + sub := readPairedSymbol(expr, '[', ']') + if sub == nil { + break + } + if *sub == "" || (*sub)[0] == '[' { + *expr = raw + return "", "", nil, nil, nil, false + } + subSelector = append(subSelector, strings.TrimSpace(*sub)) + } + prefix := r[1] + if len(prefix) == 0 { + found = true + return + } + _, boolOpposite, signOpposite = getBoolAndSignOpposite(&prefix) + found = true + return +} + +func (se *selectorExprNode) Run(ctx context.Context, currField string, tagExpr *TagExpr) interface{} { + var subFields []interface{} + if n := len(se.subExprs); n > 0 { + subFields = make([]interface{}, n) + for i, e := range se.subExprs { + subFields[i] = e.Run(ctx, currField, tagExpr) + } + } + field := se.field + if field == "" { + field = currField + } + v := tagExpr.getValue(field, subFields) + return realValue(v, se.boolOpposite, se.signOpposite) +} diff --git a/internal/tagexpr/spec_test.go b/internal/tagexpr/spec_test.go new file mode 100644 index 000000000..07f93d1ec --- /dev/null +++ b/internal/tagexpr/spec_test.go @@ -0,0 +1,162 @@ +// Copyright 2019 Bytedance Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tagexpr + +import ( + "context" + "reflect" + "testing" +) + +func TestReadPairedSymbol(t *testing.T) { + cases := []struct { + left, right rune + expr, val, lastExprNode string + }{ + {left: '\'', right: '\'', expr: "'true '+'a'", val: "true ", lastExprNode: "+'a'"}, + {left: '(', right: ')', expr: "((0+1)/(2-1)*9)%2", val: "(0+1)/(2-1)*9", lastExprNode: "%2"}, + {left: '(', right: ')', expr: `(\)\(\))`, val: `)()`}, + {left: '\'', right: '\'', expr: `'\\'`, val: `\\`}, + {left: '\'', right: '\'', expr: `'\'\''`, val: `''`}, + } + for _, c := range cases { + t.Log(c.expr) + expr := c.expr + got := readPairedSymbol(&expr, c.left, c.right) + if got == nil { + t.Fatalf("expr: %q, got: %v, %q, want: %q, %q", c.expr, got, expr, c.val, c.lastExprNode) + } else if *got != c.val || expr != c.lastExprNode { + t.Fatalf("expr: %q, got: %q, %q, want: %q, %q", c.expr, *got, expr, c.val, c.lastExprNode) + } + } +} + +func TestReadBoolExprNode(t *testing.T) { + cases := []struct { + expr string + val bool + lastExprNode string + }{ + {expr: "false", val: false, lastExprNode: ""}, + {expr: "true", val: true, lastExprNode: ""}, + {expr: "true ", val: true, lastExprNode: " "}, + {expr: "!true&", val: false, lastExprNode: "&"}, + {expr: "!false|", val: true, lastExprNode: "|"}, + {expr: "!!!!false =", val: !!!!false, lastExprNode: " ="}, + } + for _, c := range cases { + t.Log(c.expr) + expr := c.expr + e := readBoolExprNode(&expr) + got := e.Run(context.TODO(), "", nil).(bool) + if got != c.val || expr != c.lastExprNode { + t.Fatalf("expr: %s, got: %v, %s, want: %v, %s", c.expr, got, expr, c.val, c.lastExprNode) + } + } +} + +func TestReadDigitalExprNode(t *testing.T) { + cases := []struct { + expr string + val float64 + lastExprNode string + }{ + {expr: "0.1 +1", val: 0.1, lastExprNode: " +1"}, + {expr: "-1\\1", val: -1, lastExprNode: "\\1"}, + {expr: "1a", val: 0, lastExprNode: ""}, + {expr: "1", val: 1, lastExprNode: ""}, + {expr: "1.1", val: 1.1, lastExprNode: ""}, + {expr: "1.1/", val: 1.1, lastExprNode: "/"}, + } + for _, c := range cases { + expr := c.expr + e := readDigitalExprNode(&expr) + if c.expr == "1a" { + if e != nil { + t.Fatalf("expr: %s, got:%v, want:%v", c.expr, e.Run(context.TODO(), "", nil), nil) + } + continue + } + got := e.Run(context.TODO(), "", nil).(float64) + if got != c.val || expr != c.lastExprNode { + t.Fatalf("expr: %s, got: %f, %s, want: %f, %s", c.expr, got, expr, c.val, c.lastExprNode) + } + } +} + +func TestFindSelector(t *testing.T) { + cases := []struct { + expr string + field string + name string + subSelector []string + boolOpposite bool + signOpposite bool + found bool + last string + }{ + {expr: "$", name: "$", found: true}, + {expr: "!!$", name: "$", found: true}, + {expr: "!$", name: "$", boolOpposite: true, found: true}, + {expr: "+$", name: "$", found: true}, + {expr: "--$", name: "$", found: true}, + {expr: "-$", name: "$", signOpposite: true, found: true}, + {expr: "---$", name: "$", signOpposite: true, found: true}, + {expr: "()$", last: "()$"}, + {expr: "(0)$", last: "(0)$"}, + {expr: "(A)$", field: "A", name: "$", found: true}, + {expr: "+(A)$", field: "A", name: "$", found: true}, + {expr: "++(A)$", field: "A", name: "$", found: true}, + {expr: "!(A)$", field: "A", name: "$", boolOpposite: true, found: true}, + {expr: "-(A)$", field: "A", name: "$", signOpposite: true, found: true}, + {expr: "(A0)$", field: "A0", name: "$", found: true}, + {expr: "!!(A0)$", field: "A0", name: "$", found: true}, + {expr: "--(A0)$", field: "A0", name: "$", found: true}, + {expr: "(A0)$(A1)$", last: "(A0)$(A1)$"}, + {expr: "(A0)$ $(A1)$", field: "A0", name: "$", found: true, last: " $(A1)$"}, + {expr: "$a", last: "$a"}, + {expr: "$[1]['a']", name: "$", subSelector: []string{"1", "'a'"}, found: true}, + {expr: "$[1][]", last: "$[1][]"}, + {expr: "$[[]]", last: "$[[]]"}, + {expr: "$[[[]]]", last: "$[[[]]]"}, + {expr: "$[(A)$[1]]", name: "$", subSelector: []string{"(A)$[1]"}, found: true}, + {expr: "$>0&&$<10", name: "$", found: true, last: ">0&&$<10"}, + } + for _, c := range cases { + last := c.expr + field, name, subSelector, boolOpposite, signOpposite, found := findSelector(&last) + if found != c.found { + t.Fatalf("%q found: got: %v, want: %v", c.expr, found, c.found) + } + if c.boolOpposite && (boolOpposite == nil || !*boolOpposite) { + t.Fatalf("%q boolOpposite: got: %v, want: %v", c.expr, boolOpposite, c.boolOpposite) + } + if c.signOpposite && (signOpposite == nil || !*signOpposite) { + t.Fatalf("%q signOpposite: got: %v, want: %v", c.expr, signOpposite, c.signOpposite) + } + if field != c.field { + t.Fatalf("%q field: got: %q, want: %q", c.expr, field, c.field) + } + if name != c.name { + t.Fatalf("%q name: got: %q, want: %q", c.expr, name, c.name) + } + if !reflect.DeepEqual(subSelector, c.subSelector) { + t.Fatalf("%q subSelector: got: %v, want: %v", c.expr, subSelector, c.subSelector) + } + if last != c.last { + t.Fatalf("%q last: got: %q, want: %q", c.expr, last, c.last) + } + } +} diff --git a/internal/tagexpr/tagexpr.go b/internal/tagexpr/tagexpr.go new file mode 100644 index 000000000..9b14576e3 --- /dev/null +++ b/internal/tagexpr/tagexpr.go @@ -0,0 +1,1225 @@ +// Copyright 2019 Bytedance Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package tagexpr is an interesting go struct tag expression syntax for field validation, etc. +package tagexpr + +import ( + "errors" + "fmt" + "reflect" + "strconv" + "strings" + "sync" + "unsafe" +) + +// Internally unified data types +type ( + Number = float64 + Null = interface{} + Boolean = bool + String = string +) + +// VM struct tag expression interpreter +type VM struct { + tagName string + structJar map[uintptr]*structVM + rw sync.RWMutex +} + +// structVM tag expression set of struct +type structVM struct { + vm *VM + name string + fields map[string]*fieldVM + fieldSelectorList []string + fieldsWithIndirectStructVM []*fieldVM + exprs map[string]*Expr + exprSelectorList []string + ifaceTagExprGetters []func(unsafe.Pointer, string, func(*TagExpr, error) error) error + err error +} + +// fieldVM tag expression set of struct field +type fieldVM struct { + structField reflect.StructField + ptrDeep int + getPtr func(unsafe.Pointer) unsafe.Pointer + elemType reflect.Type + elemKind reflect.Kind + valueGetter func(unsafe.Pointer) interface{} + reflectValueGetter func(unsafe.Pointer, bool) reflect.Value + exprs map[string]*Expr + origin *structVM + mapKeyStructVM *structVM + mapOrSliceElemStructVM *structVM + mapOrSliceIfaceKinds [2]bool // [value, key/index] + fieldSelector string + tagOp string +} + +// New creates a tag expression interpreter that uses tagName as the tag name. +// NOTE: +// +// If no tagName is specified, no tag expression will be interpreted, +// but still can operate the various fields. +func New(tagName ...string) *VM { + if len(tagName) == 0 { + tagName = append(tagName, "") + } + return &VM{ + tagName: tagName[0], + structJar: make(map[uintptr]*structVM, 256), + } +} + +// MustRun is similar to Run, but panic when error. +func (vm *VM) MustRun(structOrStructPtrOrReflectValue interface{}) *TagExpr { + te, err := vm.Run(structOrStructPtrOrReflectValue) + if err != nil { + panic(err) + } + return te +} + +var ( + unsupportedNil = errors.New("unsupported data: nil") + unsupportedCannotAddr = errors.New("unsupported data: can not addr") +) + +// Run returns the tag expression handler of the @structPtrOrReflectValue. +// NOTE: +// +// If the structure type has not been warmed up, +// it will be slower when it is first called. +// +// Disable new -d=checkptr behaviour for Go 1.14 +// +//go:nocheckptr +func (vm *VM) Run(structPtrOrReflectValue interface{}) (*TagExpr, error) { + var v reflect.Value + switch t := structPtrOrReflectValue.(type) { + case reflect.Value: + v = dereferenceValue(t) + default: + v = dereferenceValue(reflect.ValueOf(t)) + } + if err := checkStructMapAddr(v); err != nil { + return nil, err + } + + ptr := rvPtr(v) + if ptr == nil { + return nil, unsupportedNil + } + + tid := rvType(v) + var err error + vm.rw.RLock() + s, ok := vm.structJar[tid] + vm.rw.RUnlock() + if !ok { + vm.rw.Lock() + s, ok = vm.structJar[tid] + if !ok { + s, err = vm.registerStructLocked(v.Type()) + if err != nil { + vm.rw.Unlock() + return nil, err + } + } + vm.rw.Unlock() + } + if s.err != nil { + return nil, s.err + } + return s.newTagExpr(ptr, ""), nil +} + +// RunAny returns the tag expression handler for the @v. +// NOTE: +// +// The @v can be structured data such as struct, map, slice, array, interface, reflcet.Value, etc. +// If the structure type has not been warmed up, +// it will be slower when it is first called. +func (vm *VM) RunAny(v interface{}, fn func(*TagExpr, error) error) error { + vv, isReflectValue := v.(reflect.Value) + if !isReflectValue { + vv = reflect.ValueOf(v) + } + return vm.subRunAll(false, "", vv, fn) +} + +// check type: struct{F map[T1]T2} +func checkStructMapAddr(v reflect.Value) error { + if !v.IsValid() || v.CanAddr() || v.NumField() != 1 || v.Field(0).Kind() != reflect.Map { + return nil + } + return unsupportedCannotAddr +} + +func (vm *VM) subRunAll(omitNil bool, tePath string, value reflect.Value, fn func(*TagExpr, error) error) error { + rv := dereferenceInterfaceValue(value) + if !rv.IsValid() { + return nil + } + rt := dereferenceType(rv.Type()) + rv = dereferenceValue(rv) + switch rt.Kind() { + case reflect.Struct: + if len(tePath) == 0 { + if err := checkStructMapAddr(rv); err != nil { + return err + } + } + ptr := rvPtr(rv) + if ptr == nil { + if omitNil { + return nil + } + return fn(nil, unsupportedNil) + } + return fn(vm.subRun(tePath, rt, rvType(rv), ptr)) + + case reflect.Slice, reflect.Array: + count := rv.Len() + if count == 0 { + return nil + } + switch dereferenceType(rv.Type().Elem()).Kind() { + case reflect.Struct, reflect.Interface, reflect.Slice, reflect.Array, reflect.Map: + for i := count - 1; i >= 0; i-- { + err := vm.subRunAll(omitNil, tePath+"["+strconv.Itoa(i)+"]", rv.Index(i), fn) + if err != nil { + return err + } + } + default: + return nil + } + + case reflect.Map: + if rv.Len() == 0 { + return nil + } + var canKey, canValue bool + rt := rv.Type() + switch dereferenceType(rt.Key()).Kind() { + case reflect.Struct, reflect.Interface, reflect.Slice, reflect.Array, reflect.Map: + canKey = true + } + switch dereferenceType(rt.Elem()).Kind() { + case reflect.Struct, reflect.Interface, reflect.Slice, reflect.Array, reflect.Map: + canValue = true + } + if !canKey && !canValue { + return nil + } + for _, key := range rv.MapKeys() { + if canKey { + err := vm.subRunAll(omitNil, tePath+"{k}", key, fn) + if err != nil { + return err + } + } + if canValue { + err := vm.subRunAll(omitNil, tePath+"{v for k="+key.String()+"}", rv.MapIndex(key), fn) + if err != nil { + return err + } + } + } + } + return nil +} + +func (vm *VM) subRun(path string, t reflect.Type, tid uintptr, ptr unsafe.Pointer) (*TagExpr, error) { + var err error + vm.rw.RLock() + s, ok := vm.structJar[tid] + vm.rw.RUnlock() + if !ok { + vm.rw.Lock() + s, ok = vm.structJar[tid] + if !ok { + s, err = vm.registerStructLocked(t) + if err != nil { + vm.rw.Unlock() + return nil, err + } + } + vm.rw.Unlock() + } + if s.err != nil { + return nil, s.err + } + return s.newTagExpr(ptr, path), nil +} + +func (vm *VM) registerStructLocked(structType reflect.Type) (*structVM, error) { + structType, err := vm.getStructType(structType) + if err != nil { + return nil, err + } + tid := rtType(structType) + s, had := vm.structJar[tid] + if had { + return s, s.err + } + s = vm.newStructVM() + s.name = structType.String() + vm.structJar[tid] = s + numField := structType.NumField() + var structField reflect.StructField + var sub *structVM + for i := 0; i < numField; i++ { + structField = structType.Field(i) + field, ok, err := s.newFieldVM(structField) + if err != nil { + s.err = err + return nil, err + } + // skip omitted tag + if !ok { + continue + } + switch field.elemKind { + default: + field.setUnsupportedGetter() + switch field.elemKind { + case reflect.Struct: + sub, err = vm.registerStructLocked(field.structField.Type) + if err != nil { + s.err = err + return nil, err + } + s.mergeSubStructVM(field, sub) + case reflect.Interface: + s.setIfaceTagExprGetter(field) + } + case reflect.Float32, reflect.Float64, + reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + field.setFloatGetter() + case reflect.String: + field.setStringGetter() + case reflect.Bool: + field.setBoolGetter() + case reflect.Array, reflect.Slice, reflect.Map: + err = vm.registerIndirectStructLocked(field) + if err != nil { + s.err = err + return nil, err + } + } + } + return s, nil +} + +func (vm *VM) registerIndirectStructLocked(field *fieldVM) error { + field.setLengthGetter() + if field.tagOp == tagOmit { + return nil + } + a := make([]reflect.Type, 1, 2) + a[0] = derefType(field.elemType.Elem()) + if field.elemKind == reflect.Map { + a = append(a, derefType(field.elemType.Key())) + } + for i, t := range a { + kind := t.Kind() + switch kind { + case reflect.Interface: + field.mapOrSliceIfaceKinds[i] = true + field.origin.fieldsWithIndirectStructVM = appendDistinct(field.origin.fieldsWithIndirectStructVM, field) + case reflect.Slice, reflect.Array, reflect.Map: + tt := t.Elem() + checkMap := kind == reflect.Map + F2: + for { + switch tt.Kind() { + case reflect.Slice, reflect.Array, reflect.Map, reflect.Ptr: + tt = tt.Elem() + case reflect.Struct: + _, err := vm.registerStructLocked(tt) + if err != nil { + return err + } + field.mapOrSliceIfaceKinds[i] = true + field.origin.fieldsWithIndirectStructVM = appendDistinct(field.origin.fieldsWithIndirectStructVM, field) + break F2 + default: + break F2 + } + } + if checkMap { + tt = t.Key() + checkMap = false + goto F2 + } + case reflect.Struct: + s, err := vm.registerStructLocked(t) + if err != nil { + return err + } + if len(s.exprSelectorList) > 0 || + len(s.ifaceTagExprGetters) > 0 || + len(s.fieldsWithIndirectStructVM) > 0 { + if i == 0 { + field.mapOrSliceElemStructVM = s + } else { + field.mapKeyStructVM = s + } + field.origin.fieldsWithIndirectStructVM = appendDistinct(field.origin.fieldsWithIndirectStructVM, field) + } + } + } + return nil +} + +func appendDistinct(a []*fieldVM, i *fieldVM) []*fieldVM { + has := false + for _, e := range a { + if e == i { + has = true + break + } + } + if !has { + return append(a, i) + } + return a +} + +func (vm *VM) newStructVM() *structVM { + return &structVM{ + vm: vm, + fields: make(map[string]*fieldVM, 32), + fieldSelectorList: make([]string, 0, 32), + fieldsWithIndirectStructVM: make([]*fieldVM, 0, 32), + exprs: make(map[string]*Expr, 64), + exprSelectorList: make([]string, 0, 64), + } +} + +func (s *structVM) newFieldVM(structField reflect.StructField) (*fieldVM, bool, error) { + tag := structField.Tag.Get(s.vm.tagName) + if tag == tagOmit { + return nil, false, nil + } + f := &fieldVM{ + structField: structField, + exprs: make(map[string]*Expr, 8), + origin: s, + fieldSelector: structField.Name, + } + err := f.parseExprs(tag) + if err != nil { + return nil, false, err + } + s.fields[f.fieldSelector] = f + s.fieldSelectorList = append(s.fieldSelectorList, f.fieldSelector) + + t := structField.Type + var ptrDeep int + for t.Kind() == reflect.Ptr { + t = t.Elem() + ptrDeep++ + } + f.ptrDeep = ptrDeep + + offset := structField.Offset + f.getPtr = func(ptr unsafe.Pointer) unsafe.Pointer { + if ptr == nil { + return nil + } + return unsafe.Pointer(uintptr(ptr) + offset) + } + + f.elemType = t + f.elemKind = t.Kind() + f.reflectValueGetter = func(ptr unsafe.Pointer, initZero bool) reflect.Value { + v := f.packRawFrom(ptr) + if initZero { + f.ensureInit(v) + } + return v + } + + return f, true, nil +} + +func (f *fieldVM) ensureInit(v reflect.Value) { + if safeIsNil(v) && v.CanSet() { + newField := reflect.New(f.elemType).Elem() + for i := 0; i < f.ptrDeep; i++ { + if newField.CanAddr() { + newField = newField.Addr() + } else { + newField2 := reflect.New(newField.Type()) + newField2.Elem().Set(newField) + newField = newField2 + } + } + v.Set(newField) + } +} + +func (s *structVM) mergeSubStructVM(field *fieldVM, sub *structVM) { + field.origin = sub + fieldsWithIndirectStructVM := make(map[*fieldVM]struct{}, len(sub.fieldsWithIndirectStructVM)) + for _, subField := range sub.fieldsWithIndirectStructVM { + fieldsWithIndirectStructVM[subField] = struct{}{} + } + for _, k := range sub.fieldSelectorList { + v := sub.fields[k] + f := s.newChildField(field, v, true) + if _, ok := fieldsWithIndirectStructVM[v]; ok { + s.fieldsWithIndirectStructVM = append(s.fieldsWithIndirectStructVM, f) + // TODO: maybe needed? + // delete(fieldsWithIndirectStructVM, v) + } + } + // TODO: maybe needed? + // for v := range fieldsWithIndirectStructVM { + // f := s.newChildField(field, v, false) + // s.fieldsWithIndirectStructVM = append(s.fieldsWithIndirectStructVM, f) + // } + + for _, _subFn := range sub.ifaceTagExprGetters { + subFn := _subFn + s.ifaceTagExprGetters = append(s.ifaceTagExprGetters, func(ptr unsafe.Pointer, pathPrefix string, fn func(*TagExpr, error) error) error { + ptr = field.getElemPtr(ptr) + if ptr == nil { + return nil + } + var path string + if pathPrefix == "" { + path = field.fieldSelector + } else { + path = pathPrefix + FieldSeparator + field.fieldSelector + } + return subFn(ptr, path, fn) + }) + } +} + +func (s *structVM) newChildField(parent *fieldVM, child *fieldVM, toBind bool) *fieldVM { + f := &fieldVM{ + structField: child.structField, + exprs: make(map[string]*Expr, len(child.exprs)), + ptrDeep: child.ptrDeep, + elemType: child.elemType, + elemKind: child.elemKind, + origin: child.origin, + mapKeyStructVM: child.mapKeyStructVM, + mapOrSliceElemStructVM: child.mapOrSliceElemStructVM, + mapOrSliceIfaceKinds: child.mapOrSliceIfaceKinds, + fieldSelector: parent.fieldSelector + FieldSeparator + child.fieldSelector, + } + if parent.tagOp != tagOmit { + f.tagOp = child.tagOp + } else { + f.tagOp = parent.tagOp + } + f.getPtr = func(ptr unsafe.Pointer) unsafe.Pointer { + ptr = parent.getElemPtr(ptr) + if ptr == nil { + return nil + } + return child.getPtr(ptr) + } + if child.valueGetter != nil { + if parent.ptrDeep == 0 { + f.valueGetter = func(ptr unsafe.Pointer) interface{} { + return child.valueGetter(parent.getPtr(ptr)) + } + f.reflectValueGetter = func(ptr unsafe.Pointer, initZero bool) reflect.Value { + return child.reflectValueGetter(parent.getPtr(ptr), initZero) + } + } else { + f.valueGetter = func(ptr unsafe.Pointer) interface{} { + newField := reflect.NewAt(parent.structField.Type, parent.getPtr(ptr)) + for i := 0; i < parent.ptrDeep; i++ { + newField = newField.Elem() + } + if newField.IsNil() { + return nil + } + return child.valueGetter(unsafe.Pointer(newField.Pointer())) + } + f.reflectValueGetter = func(ptr unsafe.Pointer, initZero bool) reflect.Value { + newField := reflect.NewAt(parent.structField.Type, parent.getPtr(ptr)) + if initZero { + parent.ensureInit(newField.Elem()) + } + for i := 0; i < parent.ptrDeep; i++ { + newField = newField.Elem() + } + if (newField == reflect.Value{}) || (!initZero && newField.IsNil()) { + return reflect.Value{} + } + return child.reflectValueGetter(unsafe.Pointer(newField.Pointer()), initZero) + } + } + } + + if toBind { + s.fields[f.fieldSelector] = f + s.fieldSelectorList = append(s.fieldSelectorList, f.fieldSelector) + if parent.tagOp != tagOmit { + for k, v := range child.exprs { + selector := parent.fieldSelector + FieldSeparator + k + f.exprs[selector] = v + s.exprs[selector] = v + s.exprSelectorList = append(s.exprSelectorList, selector) + } + } + } + return f +} + +func (f *fieldVM) getElemPtr(ptr unsafe.Pointer) unsafe.Pointer { + ptr = f.getPtr(ptr) + for i := f.ptrDeep; ptr != nil && i > 0; i-- { + ptr = ptrElem(ptr) + } + return ptr +} + +func (f *fieldVM) packRawFrom(ptr unsafe.Pointer) reflect.Value { + return reflect.NewAt(f.structField.Type, f.getPtr(ptr)).Elem() +} + +func (f *fieldVM) packElemFrom(ptr unsafe.Pointer) reflect.Value { + return reflect.NewAt(f.elemType, f.getElemPtr(ptr)).Elem() +} + +func (s *structVM) setIfaceTagExprGetter(f *fieldVM) { + if f.tagOp == tagOmit { + return + } + s.ifaceTagExprGetters = append(s.ifaceTagExprGetters, func(ptr unsafe.Pointer, pathPrefix string, fn func(*TagExpr, error) error) error { + v := f.packElemFrom(ptr) + if !v.IsValid() || v.IsNil() { + return nil + } + var path string + if pathPrefix == "" { + path = f.fieldSelector + } else { + path = pathPrefix + FieldSeparator + f.fieldSelector + } + return s.vm.subRunAll(f.tagOp == tagOmitNil, path, v, fn) + }) +} + +func (f *fieldVM) setFloatGetter() { + if f.ptrDeep == 0 { + f.valueGetter = func(ptr unsafe.Pointer) interface{} { + ptr = f.getPtr(ptr) + if ptr == nil { + return nil + } + return getFloat64(f.elemKind, ptr) + } + } else { + f.valueGetter = func(ptr unsafe.Pointer) interface{} { + v := f.packElemFrom(ptr) + if v.CanAddr() { + return getFloat64(f.elemKind, unsafe.Pointer(v.UnsafeAddr())) + } + return nil + } + } +} + +func (f *fieldVM) setBoolGetter() { + if f.ptrDeep == 0 { + f.valueGetter = func(ptr unsafe.Pointer) interface{} { + ptr = f.getPtr(ptr) + if ptr == nil { + return nil + } + return *(*bool)(ptr) + } + } else { + f.valueGetter = func(ptr unsafe.Pointer) interface{} { + v := f.packElemFrom(ptr) + if v.IsValid() { + return v.Bool() + } + return nil + } + } +} + +func (f *fieldVM) setStringGetter() { + if f.ptrDeep == 0 { + f.valueGetter = func(ptr unsafe.Pointer) interface{} { + ptr = f.getPtr(ptr) + if ptr == nil { + return nil + } + return *(*string)(ptr) + } + } else { + f.valueGetter = func(ptr unsafe.Pointer) interface{} { + v := f.packElemFrom(ptr) + if v.IsValid() { + return v.String() + } + return nil + } + } +} + +func (f *fieldVM) setLengthGetter() { + f.valueGetter = func(ptr unsafe.Pointer) interface{} { + v := f.packElemFrom(ptr) + if v.IsValid() { + return v.Interface() + } + return nil + } +} + +func (f *fieldVM) setUnsupportedGetter() { + f.valueGetter = func(ptr unsafe.Pointer) interface{} { + raw := f.packRawFrom(ptr) + if safeIsNil(raw) { + return nil + } + v := raw + for i := 0; i < f.ptrDeep; i++ { + v = v.Elem() + } + for v.Kind() == reflect.Interface { + v = v.Elem() + } + return anyValueGetter(raw, v) + } +} + +func (vm *VM) getStructType(t reflect.Type) (reflect.Type, error) { + structType := t + for structType.Kind() == reflect.Ptr { + structType = structType.Elem() + } + if structType.Kind() != reflect.Struct { + return nil, fmt.Errorf("unsupported type: %s", t.String()) + } + return structType, nil +} + +func (s *structVM) newTagExpr(ptr unsafe.Pointer, path string) *TagExpr { + te := &TagExpr{ + s: s, + ptr: ptr, + sub: make(map[string]*TagExpr, 8), + path: strings.TrimPrefix(path, "."), + } + return te +} + +// TagExpr struct tag expression evaluator +type TagExpr struct { + s *structVM + ptr unsafe.Pointer + sub map[string]*TagExpr + path string +} + +// EvalFloat evaluates the value of the struct tag expression by the selector expression. +// NOTE: +// +// If the expression value type is not float64, return 0. +func (t *TagExpr) EvalFloat(exprSelector string) float64 { + r, _ := t.Eval(exprSelector).(float64) + return r +} + +// EvalString evaluates the value of the struct tag expression by the selector expression. +// NOTE: +// +// If the expression value type is not string, return "". +func (t *TagExpr) EvalString(exprSelector string) string { + r, _ := t.Eval(exprSelector).(string) + return r +} + +// EvalBool evaluates the value of the struct tag expression by the selector expression. +// NOTE: +// +// If the expression value is not 0, '' or nil, return true. +func (t *TagExpr) EvalBool(exprSelector string) bool { + return FakeBool(t.Eval(exprSelector)) +} + +// FakeBool fakes any type as a boolean. +func FakeBool(v interface{}) bool { + switch r := v.(type) { + case float64: + return r != 0 + case float32: + return r != 0 + case int: + return r != 0 + case int8: + return r != 0 + case int16: + return r != 0 + case int32: + return r != 0 + case int64: + return r != 0 + case uint: + return r != 0 + case uint8: + return r != 0 + case uint16: + return r != 0 + case uint32: + return r != 0 + case uint64: + return r != 0 + case string: + return r != "" + case bool: + return r + case nil, error: + return false + case []interface{}: + bol := true + for _, v := range r { + bol = bol && FakeBool(v) + } + return bol + default: + vv := dereferenceValue(reflect.ValueOf(v)) + if vv.IsValid() || vv.IsZero() { + return false + } + return true + } +} + +// Field returns the field handler specified by the selector. +func (t *TagExpr) Field(fieldSelector string) (fh *FieldHandler, found bool) { + f, ok := t.s.fields[fieldSelector] + if !ok { + return nil, false + } + return newFieldHandler(t, fieldSelector, f), true +} + +// RangeFields loop through each field. +// When fn returns false, interrupt traversal and return false. +func (t *TagExpr) RangeFields(fn func(*FieldHandler) bool) bool { + if list := t.s.fieldSelectorList; len(list) > 0 { + fields := t.s.fields + for _, fieldSelector := range list { + if !fn(newFieldHandler(t, fieldSelector, fields[fieldSelector])) { + return false + } + } + } + return true +} + +// Eval evaluates the value of the struct tag expression by the selector expression. +// NOTE: +// +// format: fieldName, fieldName.exprName, fieldName1.fieldName2.exprName1 +// result types: float64, string, bool, nil +func (t *TagExpr) Eval(exprSelector string) interface{} { + expr, ok := t.s.exprs[exprSelector] + if !ok { + // Compatible with single mode or the expression with the name @ + if strings.HasSuffix(exprSelector, ExprNameSeparator) { + exprSelector = exprSelector[:len(exprSelector)-1] + if strings.HasSuffix(exprSelector, ExprNameSeparator) { + exprSelector = exprSelector[:len(exprSelector)-1] + } + expr, ok = t.s.exprs[exprSelector] + } + if !ok { + return nil + } + } + dir, base := splitFieldSelector(exprSelector) + targetTagExpr, err := t.checkout(dir) + if err != nil { + return nil + } + return expr.run(base, targetTagExpr) +} + +// EvalWithEnv evaluates the value with the given env +// NOTE: +// +// format: fieldName, fieldName.exprName, fieldName1.fieldName2.exprName1 +// result types: float64, string, bool, nil +func (t *TagExpr) EvalWithEnv(exprSelector string, env map[string]interface{}) interface{} { + expr, ok := t.s.exprs[exprSelector] + if !ok { + // Compatible with single mode or the expression with the name @ + if strings.HasSuffix(exprSelector, ExprNameSeparator) { + exprSelector = exprSelector[:len(exprSelector)-1] + if strings.HasSuffix(exprSelector, ExprNameSeparator) { + exprSelector = exprSelector[:len(exprSelector)-1] + } + expr, ok = t.s.exprs[exprSelector] + } + if !ok { + return nil + } + } + dir, base := splitFieldSelector(exprSelector) + targetTagExpr, err := t.checkout(dir) + if err != nil { + return nil + } + return expr.runWithEnv(base, targetTagExpr, env) +} + +// Range loop through each tag expression. +// When fn returns false, interrupt traversal and return false. +// NOTE: +// +// eval result types: float64, string, bool, nil +func (t *TagExpr) Range(fn func(*ExprHandler) error) error { + var err error + if list := t.s.exprSelectorList; len(list) > 0 { + for _, es := range list { + dir, base := splitFieldSelector(es) + targetTagExpr, err := t.checkout(dir) + if err != nil { + continue + } + err = fn(newExprHandler(t, targetTagExpr, base, es)) + if err != nil { + return err + } + } + } + + ptr := t.ptr + + if list := t.s.fieldsWithIndirectStructVM; len(list) > 0 { + for _, f := range list { + v := f.packElemFrom(ptr) + if !v.IsValid() { + continue + } + omitNil := f.tagOp == tagOmitNil + mapKeyStructVM := f.mapKeyStructVM + mapOrSliceElemStructVM := f.mapOrSliceElemStructVM + valueIface := f.mapOrSliceIfaceKinds[0] + keyIface := f.mapOrSliceIfaceKinds[1] + + if f.elemKind == reflect.Map && + (mapOrSliceElemStructVM != nil || mapKeyStructVM != nil || valueIface || keyIface) { + keyPath := f.fieldSelector + "{k}" + for _, key := range v.MapKeys() { + if mapKeyStructVM != nil { + p := rvPtr(derefValue(key)) + if omitNil && p == nil { + continue + } + err = mapKeyStructVM.newTagExpr(p, keyPath).Range(fn) + if err != nil { + return err + } + } else if keyIface { + err = t.subRange(omitNil, keyPath, key, fn) + if err != nil { + return err + } + } + if mapOrSliceElemStructVM != nil { + p := rvPtr(derefValue(v.MapIndex(key))) + if omitNil && p == nil { + continue + } + err = mapOrSliceElemStructVM.newTagExpr(p, f.fieldSelector+"{v for k="+key.String()+"}").Range(fn) + if err != nil { + return err + } + } else if valueIface { + err = t.subRange(omitNil, f.fieldSelector+"{v for k="+key.String()+"}", v.MapIndex(key), fn) + if err != nil { + return err + } + } + } + + } else if mapOrSliceElemStructVM != nil || valueIface { + // slice or array + for i := v.Len() - 1; i >= 0; i-- { + if mapOrSliceElemStructVM != nil { + p := rvPtr(derefValue(v.Index(i))) + if omitNil && p == nil { + continue + } + err = mapOrSliceElemStructVM.newTagExpr(p, f.fieldSelector+"["+strconv.Itoa(i)+"]").Range(fn) + if err != nil { + return err + } + } else if valueIface { + err = t.subRange(omitNil, f.fieldSelector+"["+strconv.Itoa(i)+"]", v.Index(i), fn) + if err != nil { + return err + } + } + } + } + } + } + + if list := t.s.ifaceTagExprGetters; len(list) > 0 { + for _, getter := range list { + err = getter(ptr, "", func(te *TagExpr, err error) error { + if err != nil { + return err + } + return te.Range(fn) + }) + if err != nil { + return err + } + } + } + return nil +} + +func (t *TagExpr) subRange(omitNil bool, path string, value reflect.Value, fn func(*ExprHandler) error) error { + return t.s.vm.subRunAll(omitNil, path, value, func(te *TagExpr, err error) error { + if err != nil { + return err + } + return te.Range(fn) + }) +} + +var ( + errFieldSelector = errors.New("field selector does not exist") + errOmitNil = errors.New("omit nil") +) + +func (t *TagExpr) checkout(fs string) (*TagExpr, error) { + if fs == "" { + return t, nil + } + subTagExpr, ok := t.sub[fs] + if ok { + if subTagExpr == nil { + return nil, errOmitNil + } + return subTagExpr, nil + } + f, ok := t.s.fields[fs] + if !ok { + return nil, errFieldSelector + } + ptr := f.getElemPtr(t.ptr) + if f.tagOp == tagOmitNil && ptr == nil { + t.sub[fs] = nil + return nil, errOmitNil + } + subTagExpr = f.origin.newTagExpr(ptr, t.path) + t.sub[fs] = subTagExpr + return subTagExpr, nil +} + +func (t *TagExpr) getValue(fieldSelector string, subFields []interface{}) (v interface{}) { + f := t.s.fields[fieldSelector] + if f == nil { + return nil + } + if f.valueGetter == nil { + return nil + } + v = f.valueGetter(t.ptr) + if v == nil { + return nil + } + if len(subFields) == 0 { + return v + } + vv := reflect.ValueOf(v) + var kind reflect.Kind + for i, k := range subFields { + kind = vv.Kind() + for kind == reflect.Ptr || kind == reflect.Interface { + vv = vv.Elem() + kind = vv.Kind() + } + switch kind { + case reflect.Slice, reflect.Array, reflect.String: + if float, ok := k.(float64); ok { + idx := int(float) + if idx >= vv.Len() { + return nil + } + vv = vv.Index(idx) + } else { + return nil + } + case reflect.Map: + k := safeConvert(reflect.ValueOf(k), vv.Type().Key()) + if !k.IsValid() { + return nil + } + vv = vv.MapIndex(k) + case reflect.Struct: + if float, ok := k.(float64); ok { + idx := int(float) + if idx < 0 || idx >= vv.NumField() { + return nil + } + vv = vv.Field(idx) + } else if str, ok := k.(string); ok { + vv = vv.FieldByName(str) + } else { + return nil + } + default: + if i < len(subFields)-1 { + return nil + } + } + if !vv.IsValid() { + return nil + } + } + raw := vv + for vv.Kind() == reflect.Ptr || vv.Kind() == reflect.Interface { + vv = vv.Elem() + } + return anyValueGetter(raw, vv) +} + +func safeConvert(v reflect.Value, t reflect.Type) reflect.Value { + defer func() { recover() }() + return v.Convert(t) +} + +func splitFieldSelector(selector string) (dir, base string) { + idx := strings.LastIndex(selector, ExprNameSeparator) + if idx != -1 { + selector = selector[:idx] + } + idx = strings.LastIndex(selector, FieldSeparator) + if idx != -1 { + return selector[:idx], selector[idx+1:] + } + return "", selector +} + +func getFloat64(kind reflect.Kind, p unsafe.Pointer) interface{} { + switch kind { + case reflect.Float32: + return float64(*(*float32)(p)) + case reflect.Float64: + return *(*float64)(p) + case reflect.Int: + return float64(*(*int)(p)) + case reflect.Int8: + return float64(*(*int8)(p)) + case reflect.Int16: + return float64(*(*int16)(p)) + case reflect.Int32: + return float64(*(*int32)(p)) + case reflect.Int64: + return float64(*(*int64)(p)) + case reflect.Uint: + return float64(*(*uint)(p)) + case reflect.Uint8: + return float64(*(*uint8)(p)) + case reflect.Uint16: + return float64(*(*uint16)(p)) + case reflect.Uint32: + return float64(*(*uint32)(p)) + case reflect.Uint64: + return float64(*(*uint64)(p)) + case reflect.Uintptr: + return float64(*(*uintptr)(p)) + } + return nil +} + +func anyValueGetter(raw, elem reflect.Value) interface{} { + if !elem.IsValid() || !raw.IsValid() { + return nil + } + kind := elem.Kind() + switch kind { + case reflect.Float32, reflect.Float64, + reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, + reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + if elem.CanAddr() { + return getFloat64(kind, unsafe.Pointer(elem.UnsafeAddr())) + } + switch kind { + case reflect.Float32, reflect.Float64: + return elem.Float() + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return float64(elem.Int()) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return float64(elem.Uint()) + } + case reflect.String: + return elem.String() + case reflect.Bool: + return elem.Bool() + } + if raw.CanInterface() { + return raw.Interface() + } + return nil +} + +func safeIsNil(v reflect.Value) bool { + if !v.IsValid() { + return true + } + switch v.Kind() { + case reflect.Chan, reflect.Func, reflect.Map, reflect.Ptr, + reflect.UnsafePointer, reflect.Interface, reflect.Slice: + return v.IsNil() + } + return false +} + +//go:nocheckptr +func ptrElem(ptr unsafe.Pointer) unsafe.Pointer { + return unsafe.Pointer(*(*uintptr)(ptr)) +} + +func derefType(t reflect.Type) reflect.Type { + for t.Kind() == reflect.Ptr { + t = t.Elem() + } + return t +} + +func derefValue(v reflect.Value) reflect.Value { + for v.Kind() == reflect.Ptr { + v = v.Elem() + } + return v +} diff --git a/internal/tagexpr/tagexpr_test.go b/internal/tagexpr/tagexpr_test.go new file mode 100644 index 000000000..d571c4106 --- /dev/null +++ b/internal/tagexpr/tagexpr_test.go @@ -0,0 +1,855 @@ +// Copyright 2019 Bytedance Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tagexpr_test + +import ( + "fmt" + "reflect" + "strconv" + "testing" + "time" + + "github.com/cloudwego/hertz/internal/tagexpr" +) + +func assertEqual(t *testing.T, v1, v2 interface{}, msgs ...interface{}) { + t.Helper() + if reflect.DeepEqual(v1, v2) { + return + } + t.Fatal(fmt.Sprintf("not equal %v %v", v1, v2) + "\n" + fmt.Sprint(msgs...)) +} + +func BenchmarkTagExpr(b *testing.B) { + type T struct { + a int `bench:"$%3"` + } + vm := tagexpr.New("bench") + vm.MustRun(new(T)) // warm up + b.ReportAllocs() + b.ResetTimer() + t := &T{10} + for i := 0; i < b.N; i++ { + tagExpr, err := vm.Run(t) + if err != nil { + b.FailNow() + } + if tagExpr.EvalFloat("a") != 1 { + b.FailNow() + } + } +} + +func BenchmarkReflect(b *testing.B) { + type T struct { + a int `remainder:"3"` + } + b.ReportAllocs() + b.ResetTimer() + t := &T{1} + for i := 0; i < b.N; i++ { + v := reflect.ValueOf(t).Elem() + ft, ok := v.Type().FieldByName("a") + if !ok { + b.FailNow() + } + x, err := strconv.ParseInt(ft.Tag.Get("remainder"), 10, 64) + if err != nil { + b.FailNow() + } + fv := v.FieldByName("a") + if fv.Int()%x != 1 { + b.FailNow() + } + } +} + +func Test(t *testing.T) { + g := &struct { + _ int + h string `tagexpr:"$"` + s []string + m map[string][]string + }{ + h: "haha", + s: []string{"1"}, + m: map[string][]string{"0": {"2"}}, + } + d := "ddd" + e := new(int) + *e = 3 + type iface interface{} + cases := []struct { + tagName string + structure interface{} + tests map[string]interface{} + }{ + { + tagName: "tagexpr", + structure: &struct { + A int `tagexpr:"$>0&&$<10&&!''&&!!!0&&!nil&&$"` + A2 int `tagexpr:"@:$>0&&$<10"` + b string `tagexpr:"is:$=='test';msg:sprintf('expect: test, but got: %s',$)"` + c float32 `tagexpr:"(A)$+$"` + d *string `tagexpr:"$"` + e **int `tagexpr:"$"` + f *[3]int `tagexpr:"x:len($)"` + g string `tagexpr:"x:!regexp('xxx',$);y:regexp('g\\d{3}$')"` + h []string `tagexpr:"x:$[1];y:$[10]"` + i map[string]int `tagexpr:"x:$['a'];y:$[0];z:$==nil"` + i2 *map[string]int `tagexpr:"x:$['a'];y:$[0];z:$"` + j, j2 iface `tagexpr:"@:$==1;y:$"` + k *iface `tagexpr:"$==nil"` + m *struct{ i int } `tagexpr:"@:$;x:$['a']['x']"` + }{ + A: 5.0, + A2: 5.0, + b: "x", + c: 1, + d: &d, + e: &e, + f: new([3]int), + g: "g123", + h: []string{"", "hehe"}, + i: map[string]int{"a": 7}, + j2: iface(1), + m: &struct{ i int }{1}, + }, + tests: map[string]interface{}{ + "A": true, + "A2": true, + "b@is": false, + "b@msg": "expect: test, but got: x", + "c": 6.0, + "d": d, + "e": float64(*e), + "f@x": float64(3), + "g@x": true, + "g@y": true, + "h@x": "hehe", + "h@y": nil, + "i@x": 7.0, + "i@y": nil, + "i@z": false, + "i2@x": nil, + "i2@y": nil, + "i2@z": nil, + "j": false, + "j@y": nil, + "j2": true, + "j2@y": 1.0, + "k": true, + "m": &struct{ i int }{1}, + "m@x": nil, + }, + }, + { + tagName: "tagexpr", + structure: &struct { + A int `tagexpr:"$>0&&$<10"` + b string `tagexpr:"is:$=='test';msg:sprintf('expect: test, but got: %s',$)"` + c struct { + _ int + d bool `tagexpr:"$"` + } + e *struct { + _ int + f bool `tagexpr:"$"` + } + g **struct { + _ int + h string `tagexpr:"$"` + s []string + m map[string][]string + } `tagexpr:"$['h']"` + i string `tagexpr:"(g.s)$[0]+(g.m)$['0'][0]==$"` + j bool `tagexpr:"!$"` + k int `tagexpr:"!$"` + m *int `tagexpr:"$==nil"` + n *bool `tagexpr:"$==nil"` + p *string `tagexpr:"$"` + }{ + A: 5, + b: "x", + c: struct { + _ int + d bool `tagexpr:"$"` + }{d: true}, + e: &struct { + _ int + f bool `tagexpr:"$"` + }{f: true}, + g: &g, + i: "12", + }, + tests: map[string]interface{}{ + "A": true, + "b@is": false, + "b@msg": "expect: test, but got: x", + "c.d": true, + "e.f": true, + "g": "haha", + "g.h": "haha", + "i": true, + "j": true, + "k": true, + "m": true, + "n": true, + "p": nil, + }, + }, + { + tagName: "p", + structure: &struct { + q *struct { + x int + } `p:"(q.x)$"` + }{}, + tests: map[string]interface{}{ + "q": nil, + }, + }, + } + for i, c := range cases { + vm := tagexpr.New(c.tagName) + // vm.WarmUp(c.structure) + tagExpr, err := vm.Run(c.structure) + if err != nil { + t.Fatal(err) + } + for selector, value := range c.tests { + val := tagExpr.Eval(selector) + if !reflect.DeepEqual(val, value) { + t.Fatalf("Eval Serial: %d, selector: %q, got: %v, expect: %v", i, selector, val, value) + } + } + tagExpr.Range(func(eh *tagexpr.ExprHandler) error { + es := eh.ExprSelector() + t.Logf("Range selector: %s, field: %q exprName: %q", es, es.Field(), es.Name()) + value := c.tests[es.String()] + val := eh.Eval() + if !reflect.DeepEqual(val, value) { + t.Fatalf("Range NO: %d, selector: %q, got: %v, expect: %v", i, es, val, value) + } + return nil + }) + } +} + +func TestFieldNotInit(t *testing.T) { + g := &struct { + _ int + h string + s []string + m map[string][]string + }{ + h: "haha", + s: []string{"1"}, + m: map[string][]string{"0": {"2"}}, + } + structure := &struct { + A int + b string + c struct { + _ int + d *bool `expr:"test:nil"` + } + e *struct { + _ int + f bool + } + g **struct { + _ int + h string + s []string + m map[string][]string + } + i string + j bool + k int + m *int + n *bool + p *string + }{ + A: 5, + b: "x", + e: &struct { + _ int + f bool + }{f: true}, + g: &g, + i: "12", + } + vm := tagexpr.New("expr") + e, err := vm.Run(structure) + if err != nil { + t.Fatal(err) + } + cases := []struct { + fieldSelector string + value interface{} + }{ + {"A", structure.A}, + {"b", structure.b}, + {"c", structure.c}, + {"c._", 0}, + {"c.d", structure.c.d}, + {"e", structure.e}, + {"e._", 0}, + {"e.f", structure.e.f}, + {"g", structure.g}, + {"g._", 0}, + {"g.h", (*structure.g).h}, + {"g.s", (*structure.g).s}, + {"g.m", (*structure.g).m}, + {"i", structure.i}, + {"j", structure.j}, + {"k", structure.k}, + {"m", structure.m}, + {"n", structure.n}, + {"p", structure.p}, + } + for _, c := range cases { + fh, _ := e.Field(c.fieldSelector) + val := fh.Value(false).Interface() + assertEqual(t, c.value, val, c.fieldSelector) + } + var i int + e.RangeFields(func(fh *tagexpr.FieldHandler) bool { + val := fh.Value(false).Interface() + if fh.StringSelector() == "c.d" { + if fh.EvalFuncs()["c.d@test"] == nil { + t.Fatal("nil") + } + } + assertEqual(t, cases[i].value, val, fh.StringSelector()) + i++ + return true + }) + var wall uint64 = 1024 + unix := time.Unix(1549186325, int64(wall)) + e, err = vm.Run(&unix) + if err != nil { + t.Fatal(err) + } + fh, _ := e.Field("wall") + val := fh.Value(false).Interface() + if !reflect.DeepEqual(val, wall) { + t.Fatalf("Time.wall: got: %v(%[1]T), expect: %v(%[2]T)", val, wall) + } +} + +func TestFieldInitZero(t *testing.T) { + g := &struct { + _ int + h string + s []string + m map[string][]string + }{ + h: "haha", + s: []string{"1"}, + m: map[string][]string{"0": {"2"}}, + } + + structure := &struct { + A int + b string + c struct { + _ int + d *bool + } + e *struct { + _ int + f bool + } + g **struct { + _ int + h string + s []string + m map[string][]string + } + g2 ****struct { + _ int + h string + s []string + m map[string][]string + } + i string + j bool + k int + m *int + n *bool + p *string + }{ + A: 5, + b: "x", + e: &struct { + _ int + f bool + }{f: true}, + g: &g, + i: "12", + } + + vm := tagexpr.New("") + e, err := vm.Run(structure) + if err != nil { + t.Fatal(err) + } + + cases := []struct { + fieldSelector string + value interface{} + }{ + {"A", structure.A}, + {"b", structure.b}, + {"c", struct { + _ int + d *bool + }{}}, + {"c._", 0}, + {"c.d", new(bool)}, + {"e", structure.e}, + {"e._", 0}, + {"e.f", structure.e.f}, + {"g", structure.g}, + {"g._", 0}, + {"g.h", (*structure.g).h}, + {"g.s", (*structure.g).s}, + {"g.m", (*structure.g).m}, + {"g2.m", (map[string][]string)(nil)}, + {"i", structure.i}, + {"j", structure.j}, + {"k", structure.k}, + {"m", new(int)}, + {"n", new(bool)}, + {"p", new(string)}, + } + for _, c := range cases { + fh, _ := e.Field(c.fieldSelector) + val := fh.Value(true).Interface() + assertEqual(t, c.value, val, c.fieldSelector) + } +} + +func TestOperator(t *testing.T) { + type Tmp1 struct { + A string `tagexpr:$=="1"||$=="2"||$="3"` //nolint:govet + B []int `tagexpr:len($)>=10&&$[0]<10` //nolint:govet + C interface{} + } + + type Tmp2 struct { + A *Tmp1 + B interface{} + } + + type Target struct { + A int `tagexpr:"-$+$<=10"` + B int `tagexpr:"+$-$<=10"` + C int `tagexpr:"-$+(M)$*(N)$/$%(D.B)$[2]+$==1"` + D *Tmp1 `tagexpr:"(D.A)$!=nil"` + E string `tagexpr:"((D.A)$=='1'&&len($)>1)||((D.A)$=='2'&&len($)>2)||((D.A)$=='3'&&len($)>3)"` + F map[string]int `tagexpr:"x:len($);y:$['a']>10&&$['b']>1"` + G *map[string]int `tagexpr:"x:$['a']+(F)$['a']>20"` + H []string `tagexpr:"len($)>=1&&len($)<10&&$[0]=='123'&&$[1]!='456'"` + I interface{} `tagexpr:"$!=nil"` + K *string `tagexpr:"len((D.A)$)+len($)<10&&len((D.A)$+$)<10"` + L **string `tagexpr:"false"` + M float64 `tagexpr:"$/2>10&&$%2==0"` + N *float64 `tagexpr:"($+$*$-$/$+1)/$==$+1"` + O *[3]float64 `tagexpr:"$[0]>10&&$[0]<20||$[0]>20&&$[0]<30"` + P *Tmp2 `tagexpr:"x:$!=nil;y:len((P.A.A)$)<=1&&(P.A.B)$[0]==1;z:$['A']['C']==nil;w:$['A']['B'][0]==1;r:$[0][1][2]==3;s1:$[2]==nil;s2:$[0][3]==nil;s3:(ZZ)$;s4:(P.B)$!=nil"` + Q *Tmp2 `tagexpr:"s1:$['A']['B']!=nil;s2:(Q.A)$['B']!=nil;s3:$['A']['C']==nil;s4:(Q.A)$['C']==nil;s5:(Q.A)$['B'][0]==1;s6:$['X']['Z']==nil"` + } + + k := "123456" + n := float64(-12.5) + o := [3]float64{15, 9, 9} + cases := []struct { + tagName string + structure interface{} + tests map[string]interface{} + }{ + { + tagName: "tagexpr", + structure: &Target{ + A: 5, + B: 10, + C: -10, + D: &Tmp1{A: "3", B: []int{1, 2, 3}}, + E: "1234", + F: map[string]int{"a": 11, "b": 9}, + G: &map[string]int{"a": 11}, + H: []string{"123", "45"}, + I: struct{}{}, + K: &k, + L: nil, + M: float64(30), + N: &n, + O: &o, + P: &Tmp2{A: &Tmp1{A: "3", B: []int{1, 2, 3}}, B: struct{}{}}, + Q: &Tmp2{A: &Tmp1{A: "3", B: []int{1, 2, 3}}, B: struct{}{}}, + }, + tests: map[string]interface{}{ + "A": true, + "B": true, + "C": true, + "D": true, + "E": true, + "F@x": float64(2), + "F@y": true, + "G@x": true, + "H": true, + "I": true, + "K": true, + "L": false, + "M": true, + "N": true, + "O": true, + + "P@x": true, + "P@y": true, + "P@z": true, + "P@w": true, + "P@r": true, + "P@s1": true, + "P@s2": true, + "P@s3": nil, + "P@s4": true, + + "Q@s1": true, + "Q@s2": true, + "Q@s3": true, + "Q@s4": true, + "Q@s5": true, + "Q@s6": true, + }, + }, + } + + for i, c := range cases { + vm := tagexpr.New(c.tagName) + // vm.WarmUp(c.structure) + tagExpr, err := vm.Run(c.structure) + if err != nil { + t.Fatal(err) + } + for selector, value := range c.tests { + val := tagExpr.Eval(selector) + if !reflect.DeepEqual(val, value) { + t.Fatalf("Eval NO: %d, selector: %q, got: %v, expect: %v", i, selector, val, value) + } + } + tagExpr.Range(func(eh *tagexpr.ExprHandler) error { + es := eh.ExprSelector() + t.Logf("Range selector: %s, field: %q exprName: %q", es, es.Field(), es.Name()) + value := c.tests[es.String()] + val := eh.Eval() + if !reflect.DeepEqual(val, value) { + t.Fatalf("Range NO: %d, selector: %q, got: %v, expect: %v", i, es, val, value) + } + return nil + }) + } +} + +func TestStruct(t *testing.T) { + type A struct { + B struct { + C struct { + D struct { + X string `vd:"$"` + } + } `vd:"@:$['D']['X']"` + C2 string `vd:"@:(C)$['D']['X']"` + C3 string `vd:"@:(C.D.X)$"` + } + } + a := new(A) + a.B.C.D.X = "xxx" + vm := tagexpr.New("vd") + expr := vm.MustRun(a) + assertEqual(t, "xxx", expr.EvalString("B.C2")) + assertEqual(t, "xxx", expr.EvalString("B.C3")) + assertEqual(t, "xxx", expr.EvalString("B.C")) + assertEqual(t, "xxx", expr.EvalString("B.C.D.X")) + expr.Range(func(eh *tagexpr.ExprHandler) error { + es := eh.ExprSelector() + t.Logf("Range selector: %s, field: %q exprName: %q", es, es.Field(), es.Name()) + if eh.Eval().(string) != "xxx" { + t.FailNow() + } + return nil + }) +} + +func TestStruct2(t *testing.T) { + type IframeBlock struct { + XBlock struct { + BlockType string `vd:"$"` + } + Props struct { + Data struct { + DataType string `vd:"$"` + } + } + } + b := new(IframeBlock) + b.XBlock.BlockType = "BlockType" + b.Props.Data.DataType = "DataType" + vm := tagexpr.New("vd") + expr := vm.MustRun(b) + if expr.EvalString("XBlock.BlockType") != "BlockType" { + t.Fatal(expr.EvalString("XBlock.BlockType")) + } + if expr.EvalString("Props.Data.DataType") != "DataType" { + t.Fatal(expr.EvalString("Props.Data.DataType")) + } +} + +func TestStruct3(t *testing.T) { + type Data struct { + DataType string `vd:"$"` + } + type Prop struct { + PropType string `vd:"$"` + DD []*Data `vd:"$"` + DD2 []*Data `vd:"$"` + DataMap map[int]Data `vd:"$"` + DataMap2 map[int]Data `vd:"$"` + } + type IframeBlock struct { + XBlock struct { + BlockType string `vd:"$"` + } + Props []Prop `vd:"$"` + Props1 [2]Prop `vd:"$"` + Props2 []Prop `vd:"$"` + PropMap map[int]*Prop `vd:"$"` + PropMap2 map[int]*Prop `vd:"$"` + } + + b := new(IframeBlock) + b.XBlock.BlockType = "BlockType" + p1 := Prop{ + PropType: "p1", + DD: []*Data{ + {"p1s1"}, + {"p1s2"}, + nil, + }, + DataMap: map[int]Data{ + 1: {"p1m1"}, + 2: {"p1m2"}, + 0: {}, + }, + } + b.Props = []Prop{p1} + p2 := &Prop{ + PropType: "p2", + DD: []*Data{ + {"p2s1"}, + {"p2s2"}, + nil, + }, + DataMap: map[int]Data{ + 1: {"p2m1"}, + 2: {"p2m2"}, + 0: {}, + }, + } + b.Props1 = [2]Prop{p1, {}} + b.PropMap = map[int]*Prop{ + 9: p2, + } + + vm := tagexpr.New("vd") + expr := vm.MustRun(b) + if expr.EvalString("XBlock.BlockType") != "BlockType" { + t.Fatal(expr.EvalString("XBlock.BlockType")) + } + err := expr.Range(func(eh *tagexpr.ExprHandler) error { + es := eh.ExprSelector() + t.Logf("Range selector: %s, field: %q exprName: %q, eval: %v", eh.Path(), es.Field(), es.Name(), eh.Eval()) + return nil + }) + if err != nil { + t.Fatal(err) + } +} + +func TestNilField(t *testing.T) { + type P struct { + X **struct { + A *[]uint16 `tagexpr:"$"` + } `tagexpr:"$"` + Y **struct{} `tagexpr:"$"` + } + vm := tagexpr.New("tagexpr") + te := vm.MustRun(&P{}) + te.Range(func(eh *tagexpr.ExprHandler) error { + r := eh.Eval() + if r != nil { + t.Fatal(eh.Path()) + } + return nil + }) + + type G struct { + // Nil1 *int `tagexpr:"nil!=$"` + Nil2 *int `tagexpr:"$!=nil"` + } + g := &G{ + // Nil1: new(int), + Nil2: new(int), + } + vm.MustRun(g).Range(func(eh *tagexpr.ExprHandler) error { + r, ok := eh.Eval().(bool) + if !ok || !r { + t.Fatal(eh.Path()) + } + return nil + }) +} + +func TestDeepNested(t *testing.T) { + type testInner struct { + Address string `tagexpr:"name:$"` + } + type struct1 struct { + I *testInner + A []*testInner + X interface{} + } + type struct2 struct { + S *struct1 + } + type Data struct { + S1 *struct2 + S2 *struct2 + } + data := &Data{ + S1: &struct2{ + S: &struct1{ + I: &testInner{Address: "I:address"}, + A: []*testInner{{Address: "A:address"}}, + X: []*testInner{{Address: "X:address"}}, + }, + }, + S2: &struct2{ + S: &struct1{ + A: []*testInner{nil}, + }, + }, + } + expectKey := [...]interface{}{"S1.S.I.Address@name", "S2.S.I.Address@name", "S1.S.A[0].Address@name", "S2.S.A[0].Address@name", "S1.S.X[0].Address@name"} + expectValue := [...]interface{}{"I:address", nil, "A:address", nil, "X:address"} + var i int + vm := tagexpr.New("tagexpr") + vm.MustRun(data).Range(func(eh *tagexpr.ExprHandler) error { + assertEqual(t, expectKey[i], eh.Path()) + assertEqual(t, expectValue[i], eh.Eval()) + i++ + t.Log(eh.Path(), eh.ExprSelector(), eh.Eval()) + return nil + }) + assertEqual(t, 5, i) +} + +func TestIssue3(t *testing.T) { + type C struct { + Id string + Index int32 `vd:"$"` + P *int `vd:"$!=nil"` + } + type A struct { + F1 *C + F2 *C + } + a := &A{ + F1: &C{ + Id: "test", + Index: 1, + P: new(int), + }, + } + vm := tagexpr.New("vd") + err := vm.MustRun(a).Range(func(eh *tagexpr.ExprHandler) error { + switch eh.Path() { + case "F1.Index": + assertEqual(t, float64(1), eh.Eval(), eh.Path()) + case "F2.Index": + assertEqual(t, nil, eh.Eval(), eh.Path()) + case "F1.P": + assertEqual(t, true, eh.Eval(), eh.Path()) + case "F2.P": + assertEqual(t, false, eh.Eval(), eh.Path()) + } + return nil + }) + if err != nil { + t.Fatal(err) + } +} + +func TestIssue4(t *testing.T) { + type T struct { + A *string `te:"len($)+mblen($)"` + B *string `te:"len($)+mblen($)"` + C *string `te:"len($)+mblen($)"` + } + c := "c" + v := &T{ + B: new(string), + C: &c, + } + vm := tagexpr.New("te") + err := vm.MustRun(v).Range(func(eh *tagexpr.ExprHandler) error { + t.Logf("eval:%v, path:%s", eh.EvalFloat(), eh.Path()) + return nil + }) + if err != nil { + t.Fatal(err) + } +} + +func TestIssue5(t *testing.T) { + type A struct { + F1 int `vd:"true && $ <= 24*60*60"` // 1500 ok + F2 int `vd:"$%60 == 0 && $ <= (24*60*60)"` // 1500 ok + F3 int `vd:"$ <= 24*60*60"` // 1500 ok + } + a := &A{ + F1: 1500, + F2: 1500, + F3: 1500, + } + vm := tagexpr.New("vd") + err := vm.MustRun(a).Range(func(eh *tagexpr.ExprHandler) error { + switch eh.Path() { + case "F1": + assertEqual(t, true, eh.Eval(), eh.Path()) + case "F2": + assertEqual(t, true, eh.Eval(), eh.Path()) + case "F3": + assertEqual(t, true, eh.Eval(), eh.Path()) + } + return nil + }) + if err != nil { + t.Fatal(err) + } +} diff --git a/internal/tagexpr/tagparser.go b/internal/tagexpr/tagparser.go new file mode 100644 index 000000000..fd9547845 --- /dev/null +++ b/internal/tagexpr/tagparser.go @@ -0,0 +1,190 @@ +// Copyright 2019 Bytedance Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tagexpr + +import ( + "fmt" + "strings" + "unicode" +) + +const ( + tagOmit = "-" + tagOmitNil = "?" +) + +func (f *fieldVM) parseExprs(tag string) error { + switch tag { + case tagOmit, tagOmitNil: + f.tagOp = tag + return nil + } + + kvs, err := parseTag(tag) + if err != nil { + return err + } + exprSelectorPrefix := f.structField.Name + + for exprSelector, exprString := range kvs { + expr, err := parseExpr(exprString) + if err != nil { + return err + } + if exprSelector == ExprNameSeparator { + exprSelector = exprSelectorPrefix + } else { + exprSelector = exprSelectorPrefix + ExprNameSeparator + exprSelector + } + f.exprs[exprSelector] = expr + f.origin.exprs[exprSelector] = expr + f.origin.exprSelectorList = append(f.origin.exprSelectorList, exprSelector) + } + return nil +} + +func parseTag(tag string) (map[string]string, error) { + s := tag + ptr := &s + kvs := make(map[string]string) + for { + one, err := readOneExpr(ptr) + if err != nil { + return nil, err + } + if one == "" { + return kvs, nil + } + key, val := splitExpr(one) + if val == "" { + return nil, fmt.Errorf("syntax error: %q expression string can not be empty", tag) + } + if _, ok := kvs[key]; ok { + return nil, fmt.Errorf("syntax error: %q duplicate expression name %q", tag, key) + } + kvs[key] = val + } +} + +func splitExpr(one string) (key, val string) { + one = strings.TrimSpace(one) + if one == "" { + return DefaultExprName, "" + } + var rs []rune + for _, r := range one { + if r == '@' || + r == '_' || + (r >= '0' && r <= '9') || + (r >= 'A' && r <= 'Z') || + (r >= 'a' && r <= 'z') { + rs = append(rs, r) + } else { + break + } + } + key = string(rs) + val = strings.TrimSpace(one[len(key):]) + if val == "" || val[0] != ':' { + return DefaultExprName, one + } + val = val[1:] + if key == "" { + key = DefaultExprName + } + return key, val +} + +func readOneExpr(tag *string) (string, error) { + s := *(trimRightSpace(trimLeftSpace(tag))) + s = strings.TrimLeft(s, ";") + if s == "" { + return "", nil + } + if s[len(s)-1] != ';' { + s += ";" + } + a := strings.SplitAfter(strings.Replace(s, "\\'", "##", -1), ";") + idx := -1 + var patch int + for _, v := range a { + idx += len(v) + count := strings.Count(v, "'") + if (count+patch)%2 == 0 { + *tag = s[idx+1:] + return s[:idx], nil + } + if count > 0 { + patch++ + } + } + return "", fmt.Errorf("syntax error: %q unclosed single quote \"'\"", s) +} + +func trimLeftSpace(p *string) *string { + *p = strings.TrimLeftFunc(*p, unicode.IsSpace) + return p +} + +func trimRightSpace(p *string) *string { + *p = strings.TrimRightFunc(*p, unicode.IsSpace) + return p +} + +func readPairedSymbol(p *string, left, right rune) *string { + s := *p + if len(s) == 0 || rune(s[0]) != left { + return nil + } + s = s[1:] + last1 := left + var last2 rune + var leftLevel, rightLevel int + escapeIndexes := make(map[int]bool) + var realEqual, escapeEqual bool + for i, r := range s { + if realEqual, escapeEqual = equalRune(right, r, last1, last2); realEqual { + if leftLevel == rightLevel { + *p = s[i+1:] + sub := make([]rune, 0, i) + for k, v := range s[:i] { + if !escapeIndexes[k] { + sub = append(sub, v) + } + } + s = string(sub) + return &s + } + rightLevel++ + } else if escapeEqual { + escapeIndexes[i-1] = true + } else if realEqual, escapeEqual = equalRune(left, r, last1, last2); realEqual { + leftLevel++ + } else if escapeEqual { + escapeIndexes[i-1] = true + } + last2 = last1 + last1 = r + } + return nil +} + +func equalRune(a, b, last1, last2 rune) (real, escape bool) { + if a == b { + real = last1 != '\\' || last2 == '\\' + escape = last1 == '\\' && last2 != '\\' + } + return +} diff --git a/internal/tagexpr/tagparser_test.go b/internal/tagexpr/tagparser_test.go new file mode 100644 index 000000000..5f36fb80e --- /dev/null +++ b/internal/tagexpr/tagparser_test.go @@ -0,0 +1,93 @@ +// Copyright 2019 Bytedance Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tagexpr + +import ( + "reflect" + "testing" +) + +func TestTagparser(t *testing.T) { + cases := []struct { + tag reflect.StructTag + expect map[string]string + fail bool + }{ + { + tag: `tagexpr:"$>0"`, + expect: map[string]string{ + "@": "$>0", + }, + }, { + tag: `tagexpr:"$>0;'xxx'"`, + fail: true, + }, { + tag: `tagexpr:"$>0;b:sprintf('%[1]T; %[1]v',(X)$)"`, + expect: map[string]string{ + "@": `$>0`, + "b": `sprintf('%[1]T; %[1]v',(X)$)`, + }, + }, { + tag: `tagexpr:"a:$=='0;1;';b:sprintf('%[1]T; %[1]v',(X)$)"`, + expect: map[string]string{ + "a": `$=='0;1;'`, + "b": `sprintf('%[1]T; %[1]v',(X)$)`, + }, + }, { + tag: `tagexpr:"a:1;;b:2"`, + expect: map[string]string{ + "a": `1`, + "b": `2`, + }, + }, { + tag: `tagexpr:";a:1;;b:2;;;"`, + expect: map[string]string{ + "a": `1`, + "b": `2`, + }, + }, { + tag: `tagexpr:";a:'123\\'';;b:'1\\'23';c:'1\\'2\\'3';;"`, + expect: map[string]string{ + "a": `'123\''`, + "b": `'1\'23'`, + "c": `'1\'2\'3'`, + }, + }, { + tag: `tagexpr:"email($)"`, + expect: map[string]string{ + "@": `email($)`, + }, + }, { + tag: `tagexpr:"false"`, + expect: map[string]string{ + "@": `false`, + }, + }, + } + + for _, c := range cases { + r, e := parseTag(c.tag.Get("tagexpr")) + if e != nil == c.fail { + if !reflect.DeepEqual(c.expect, r) { + t.Fatal(c.expect, r, c.tag) + } + } else { + t.Fatalf("tag:%s kvs:%v, err:%v", c.tag, r, e) + } + if e != nil { + t.Logf("tag:%q, errMsg:%v", c.tag, e) + } + } +} diff --git a/internal/tagexpr/utils.go b/internal/tagexpr/utils.go new file mode 100644 index 000000000..11b3ca622 --- /dev/null +++ b/internal/tagexpr/utils.go @@ -0,0 +1,101 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package tagexpr + +import ( + "reflect" + "unsafe" +) + +func init() { + testhack() +} + +func dereferenceValue(v reflect.Value) reflect.Value { + for v.Kind() == reflect.Ptr || v.Kind() == reflect.Interface { + v = v.Elem() + } + return v +} + +func dereferenceType(t reflect.Type) reflect.Type { + for t.Kind() == reflect.Ptr { + t = t.Elem() + } + return t +} + +func dereferenceInterfaceValue(v reflect.Value) reflect.Value { + for v.Kind() == reflect.Interface { + v = v.Elem() + } + return v +} + +type rvtype struct { // reflect.Value + abiType uintptr + ptr unsafe.Pointer // data pointer +} + +func rvPtr(rv reflect.Value) unsafe.Pointer { + return (*rvtype)(unsafe.Pointer(&rv)).ptr +} + +func rvType(rv reflect.Value) uintptr { + return (*rvtype)(unsafe.Pointer(&rv)).abiType +} + +func rtType(rt reflect.Type) uintptr { + type iface struct { + tab uintptr + data uintptr + } + return (*iface)(unsafe.Pointer(&rt)).data +} + +// quick test make sure the hack above works +func testhack() { + type T1 struct { + a int + } + type T2 struct { + a int + } + p0 := &T1{1} + p1 := &T1{2} + p2 := &T2{3} + + if rvPtr(reflect.ValueOf(p0)) != unsafe.Pointer(p0) || + rvPtr(reflect.ValueOf(p0).Elem()) != unsafe.Pointer(p0) || + rvPtr(reflect.ValueOf(p0)) == rvPtr(reflect.ValueOf(p1)) { + panic("rvPtr() compatibility issue found") + } + + if rvType(reflect.ValueOf(p0)) != rvType(reflect.ValueOf(p1)) || + rvType(reflect.ValueOf(p0)) == rvType(reflect.ValueOf(p2)) || + rvType(reflect.ValueOf(p0).Elem()) != rvType(reflect.ValueOf(p1).Elem()) || + rvType(reflect.ValueOf(p0).Elem()) == rvType(reflect.ValueOf(p2).Elem()) { + panic("rvType() compatibility issue found") + } + + if rtType(reflect.TypeOf(p0)) != rtType(reflect.TypeOf(p1)) || + rtType(reflect.TypeOf(p0)) == rtType(reflect.TypeOf(p2)) || + rtType(reflect.TypeOf(p0).Elem()) != rtType(reflect.TypeOf(p1).Elem()) || + rtType(reflect.TypeOf(p0).Elem()) == rtType(reflect.TypeOf(p2).Elem()) { + panic("rtType() compatibility issue found") + } +} diff --git a/internal/tagexpr/validator/README.md b/internal/tagexpr/validator/README.md new file mode 100644 index 000000000..b3321a671 --- /dev/null +++ b/internal/tagexpr/validator/README.md @@ -0,0 +1,204 @@ +# validator [![GoDoc](https://img.shields.io/badge/godoc-reference-blue.svg?style=flat-square)](http://godoc.org/github.com/bytedance/go-tagexpr/v2/validator) + +A powerful validator that supports struct tag expression. + +## Feature + +- Support for a variety of common operator +- Support for accessing arrays, slices, members of the dictionary +- Support access to any field in the current structure +- Support access to nested fields, non-exported fields, etc. +- Support registers validator function expression +- Built-in len, sprintf, regexp, email, phone functions +- Support simple mode, or specify error message mode +- Use offset pointers to directly take values, better performance +- Required go version ≥1.9 + +## Example + +```go +package validator_test + +import ( + "fmt" + + vd "github.com/bytedance/go-tagexpr/v2/validator" +) + +func Example() { + type InfoRequest struct { + Name string `vd:"($!='Alice'||(Age)$==18) && regexp('\\w')"` + Age int `vd:"$>0"` + Email string `vd:"email($)"` + Phone1 string `vd:"phone($)"` + OtherPhones []string `vd:"range($, phone(#v,'CN'))"` + *InfoRequest `vd:"?"` + Info1 *InfoRequest `vd:"?"` + Info2 *InfoRequest `vd:"-"` + } + info := &InfoRequest{ + Name: "Alice", + Age: 18, + Email: "henrylee2cn@gmail.com", + Phone1: "+8618812345678", + OtherPhones: []string{"18812345679", "18812345680"}, + } + fmt.Println(vd.Validate(info)) + + type A struct { + A int `vd:"$<0||$>=100"` + Info interface{} + } + info.Email = "xxx" + a := &A{A: 107, Info: info} + fmt.Println(vd.Validate(a)) + type B struct { + B string `vd:"len($)>1 && regexp('^\\w*$')"` + } + b := &B{"abc"} + fmt.Println(vd.Validate(b) == nil) + + type C struct { + C bool `vd:"@:(S.A)$>0 && !$; msg:'C must be false when S.A>0'"` + S *A + } + c := &C{C: true, S: a} + fmt.Println(vd.Validate(c)) + + type D struct { + d []string `vd:"@:len($)>0 && $[0]=='D'; msg:sprintf('invalid d: %v',$)"` + } + d := &D{d: []string{"x", "y"}} + fmt.Println(vd.Validate(d)) + + type E struct { + e map[string]int `vd:"len($)==$['len']"` + } + e := &E{map[string]int{"len": 2}} + fmt.Println(vd.Validate(e)) + + // Customizes the factory of validation error. + vd.SetErrorFactory(func(failPath, msg string) error { + return fmt.Errorf(`{"succ":false, "error":"validation failed: %s"}`, failPath) + }) + + type F struct { + f struct { + g int `vd:"$%3==0"` + } + } + f := &F{} + f.f.g = 10 + fmt.Println(vd.Validate(f)) + + fmt.Println(vd.Validate(map[string]*F{"a": f})) + fmt.Println(vd.Validate(map[string]map[string]*F{"a": {"b": f}})) + fmt.Println(vd.Validate([]map[string]*F{{"a": f}})) + fmt.Println(vd.Validate(struct { + A []map[string]*F + }{A: []map[string]*F{{"x": f}}})) + fmt.Println(vd.Validate(map[*F]int{f: 1})) + fmt.Println(vd.Validate([][1]*F{{f}})) + fmt.Println(vd.Validate((*F)(nil))) + fmt.Println(vd.Validate(map[string]*F{})) + fmt.Println(vd.Validate(map[string]map[string]*F{})) + fmt.Println(vd.Validate([]map[string]*F{})) + fmt.Println(vd.Validate([]*F{})) + + // Output: + // + // email format is incorrect + // true + // C must be false when S.A>0 + // invalid d: [x y] + // invalid parameter: e + // {"succ":false, "error":"validation failed: f.g"} + // {"succ":false, "error":"validation failed: {v for k=a}.f.g"} + // {"succ":false, "error":"validation failed: {v for k=a}{v for k=b}.f.g"} + // {"succ":false, "error":"validation failed: [0]{v for k=a}.f.g"} + // {"succ":false, "error":"validation failed: A[0]{v for k=x}.f.g"} + // {"succ":false, "error":"validation failed: {k}.f.g"} + // {"succ":false, "error":"validation failed: [0][0].f.g"} + // unsupported data: nil + // + // + // + // +} +``` + +## Syntax + +Struct tag syntax spec: + +``` +type T struct { + // Simple model + Field1 T1 `tagName:"expression"` + // Specify error message mode + Field2 T2 `tagName:"@:expression; msg:expression2"` + // Omit it + Field3 T3 `tagName:"-"` + // Omit it when it is nil + Field4 T4 `tagName:"?"` + ... +} +``` + +|Operator or Operand|Explain| +|-----|---------| +|`true` `false`|boolean| +|`0` `0.0`|float64 "0"| +|`''`|String| +|`\\'`| Escape `'` delims in string| +|`\"`| Escape `"` delims in string| +|`nil`|nil, undefined| +|`!`|not| +|`+`|Digital addition or string splicing| +|`-`|Digital subtraction or negative| +|`*`|Digital multiplication| +|`/`|Digital division| +|`%`|division remainder, as: `float64(int64(a)%int64(b))`| +|`==`|`eq`| +|`!=`|`ne`| +|`>`|`gt`| +|`>=`|`ge`| +|`<`|`lt`| +|`<=`|`le`| +|`&&`|Logic `and`| +|`\|\|`|Logic `or`| +|`()`|Expression group| +|`(X)$`|Struct field value named X| +|`(X.Y)$`|Struct field value named X.Y| +|`$`|Shorthand for `(X)$`, omit `(X)` to indicate current struct field value| +|`(X)$['A']`|Map value with key A or struct A sub-field in the struct field X| +|`(X)$[0]`|The 0th element or sub-field of the struct field X(type: map, slice, array, struct)| +|`len((X)$)`|Built-in function `len`, the length of struct field X| +|`mblen((X)$)`|the length of string field X (character number)| +|`regexp('^\\w*$', (X)$)`|Regular match the struct field X, return boolean| +|`regexp('^\\w*$')`|Regular match the current struct field, return boolean| +|`sprintf('X value: %v', (X)$)`|`fmt.Sprintf`, format the value of struct field X| +|`range(KvExpr, forEachExpr)`|Iterate over an array, slice, or dictionary
- `#k` is the element key var
- `#v` is the element value var
- `##` is the number of elements
- e.g. [example](../spec_range_test.go)| +|`in((X)$, enum_1, ...enum_n)`|Check if the first parameter is one of the enumerated parameters| +|`email((X)$)`|Regular match the struct field X, return true if it is email| +|`phone((X)$,<'defaultRegion'>)`|Regular match the struct field X, return true if it is phone| + + + + + +Operator priority(high -> low): + +* `()` `!` `bool` `float64` `string` `nil` +* `*` `/` `%` +* `+` `-` +* `<` `<=` `>` `>=` +* `==` `!=` +* `&&` +* `||` diff --git a/internal/tagexpr/validator/default.go b/internal/tagexpr/validator/default.go new file mode 100644 index 000000000..667b5f5cf --- /dev/null +++ b/internal/tagexpr/validator/default.go @@ -0,0 +1,42 @@ +// Copyright 2019 Bytedance Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package validator + +var defaultValidator = New("vd").SetErrorFactory(defaultErrorFactory) + +// Default returns the default validator. +// NOTE: +// +// The tag name is 'vd' +func Default() *Validator { + return defaultValidator +} + +// Validate uses the default validator to validate whether the fields of value is valid. +// NOTE: +// +// The tag name is 'vd' +// If checkAll=true, validate all the error. +func Validate(value interface{}, checkAll ...bool) error { + return defaultValidator.Validate(value, checkAll...) +} + +// SetErrorFactory customizes the factory of validation error for the default validator. +// NOTE: +// +// The tag name is 'vd' +func SetErrorFactory(errFactory func(fieldSelector, msg string) error) { + defaultValidator.SetErrorFactory(errFactory) +} diff --git a/internal/tagexpr/validator/example_test.go b/internal/tagexpr/validator/example_test.go new file mode 100644 index 000000000..0c1788787 --- /dev/null +++ b/internal/tagexpr/validator/example_test.go @@ -0,0 +1,122 @@ +// Copyright 2019 Bytedance Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package validator_test + +import ( + "fmt" + + vd "github.com/cloudwego/hertz/internal/tagexpr/validator" +) + +func Example() { + type InfoRequest struct { + Name string `vd:"($!='Alice'||(Age)$==18) && regexp('\\w')"` + Age int `vd:"$>0"` + Email string `vd:"email($)"` + Phone1 string `vd:"phone($)"` + OtherPhones []string `vd:"range($, phone(#v,'CN'))"` + *InfoRequest `vd:"?"` + Info1 *InfoRequest `vd:"?"` + Info2 *InfoRequest `vd:"-"` + } + info := &InfoRequest{ + Name: "Alice", + Age: 18, + Email: "henrylee2cn@gmail.com", + Phone1: "+8618812345678", + OtherPhones: []string{"18812345679", "18812345680"}, + } + fmt.Println(vd.Validate(info)) + + type A struct { + A int `vd:"$<0||$>=100"` + Info interface{} + } + info.Email = "xxx" + a := &A{A: 107, Info: info} + fmt.Println(vd.Validate(a)) + type B struct { + B string `vd:"len($)>1 && regexp('^\\w*$')"` + } + b := &B{"abc"} + fmt.Println(vd.Validate(b) == nil) + + type C struct { + C bool `vd:"@:(S.A)$>0 && !$; msg:'C must be false when S.A>0'"` + S *A + } + c := &C{C: true, S: a} + fmt.Println(vd.Validate(c)) + + type D struct { + d []string `vd:"@:len($)>0 && $[0]=='D'; msg:sprintf('invalid d: %v',$)"` + } + d := &D{d: []string{"x", "y"}} + fmt.Println(vd.Validate(d)) + + type E struct { + e map[string]int `vd:"len($)==$['len']"` + } + e := &E{map[string]int{"len": 2}} + fmt.Println(vd.Validate(e)) + + // Customizes the factory of validation error. + vd.SetErrorFactory(func(failPath, msg string) error { + return fmt.Errorf(`{"succ":false, "error":"validation failed: %s"}`, failPath) + }) + + type F struct { + f struct { + g int `vd:"$%3==0"` + } + } + f := &F{} + f.f.g = 10 + fmt.Println(vd.Validate(f)) + + fmt.Println(vd.Validate(map[string]*F{"a": f})) + fmt.Println(vd.Validate(map[string]map[string]*F{"a": {"b": f}})) + fmt.Println(vd.Validate([]map[string]*F{{"a": f}})) + fmt.Println(vd.Validate(struct { + A []map[string]*F + }{A: []map[string]*F{{"x": f}}})) + fmt.Println(vd.Validate(map[*F]int{f: 1})) + fmt.Println(vd.Validate([][1]*F{{f}})) + fmt.Println(vd.Validate((*F)(nil))) + fmt.Println(vd.Validate(map[string]*F{})) + fmt.Println(vd.Validate(map[string]map[string]*F{})) + fmt.Println(vd.Validate([]map[string]*F{})) + fmt.Println(vd.Validate([]*F{})) + + // Output: + // + // email format is incorrect + // true + // C must be false when S.A>0 + // invalid d: [x y] + // invalid parameter: e + // {"succ":false, "error":"validation failed: f.g"} + // {"succ":false, "error":"validation failed: {v for k=a}.f.g"} + // {"succ":false, "error":"validation failed: {v for k=a}{v for k=b}.f.g"} + // {"succ":false, "error":"validation failed: [0]{v for k=a}.f.g"} + // {"succ":false, "error":"validation failed: A[0]{v for k=x}.f.g"} + // {"succ":false, "error":"validation failed: {k}.f.g"} + // {"succ":false, "error":"validation failed: [0][0].f.g"} + // unsupported data: nil + // + // + // + // +} diff --git a/internal/tagexpr/validator/func.go b/internal/tagexpr/validator/func.go new file mode 100644 index 000000000..17800cc65 --- /dev/null +++ b/internal/tagexpr/validator/func.go @@ -0,0 +1,116 @@ +// Copyright 2019 Bytedance Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package validator + +import ( + "errors" + "regexp" + + "github.com/nyaruka/phonenumbers" + + "github.com/cloudwego/hertz/internal/tagexpr" +) + +// ErrInvalidWithoutMsg verification error without error message. +var ErrInvalidWithoutMsg = errors.New("") + +// MustRegFunc registers validator function expression. +// NOTE: +// +// panic if exist error; +// example: phone($) or phone($,'CN'); +// If @force=true, allow to cover the existed same @funcName; +// The go number types always are float64; +// The go string types always are string. +func MustRegFunc(funcName string, fn func(args ...interface{}) error, force ...bool) { + err := RegFunc(funcName, fn, force...) + if err != nil { + panic(err) + } +} + +// RegFunc registers validator function expression. +// NOTE: +// +// example: phone($) or phone($,'CN'); +// If @force=true, allow to cover the existed same @funcName; +// The go number types always are float64; +// The go string types always are string. +func RegFunc(funcName string, fn func(args ...interface{}) error, force ...bool) error { + return tagexpr.RegFunc(funcName, func(args ...interface{}) interface{} { + err := fn(args...) + if err == nil { + // nil defaults to false, so returns true + return true + } + return err + }, force...) +} + +func init() { + pattern := "^([A-Za-z0-9_\\-\\.\u4e00-\u9fa5])+\\@([A-Za-z0-9_\\-\\.])+\\.([A-Za-z]{2,8})$" + emailRegexp := regexp.MustCompile(pattern) + MustRegFunc("email", func(args ...interface{}) error { + if len(args) != 1 { + return errors.New("number of parameters of email function is not one") + } + s, ok := args[0].(string) + if !ok { + return errors.New("parameter of email function is not string type") + } + matched := emailRegexp.MatchString(s) + if !matched { + // return ErrInvalidWithoutMsg + return errors.New("email format is incorrect") + } + return nil + }, true) +} + +func init() { + // phone: defaultRegion is 'CN' + MustRegFunc("phone", func(args ...interface{}) error { + var numberToParse, defaultRegion string + var ok bool + switch len(args) { + default: + return errors.New("the number of parameters of phone function is not one or two") + case 2: + defaultRegion, ok = args[1].(string) + if !ok { + return errors.New("the 2nd parameter of phone function is not string type") + } + fallthrough + case 1: + numberToParse, ok = args[0].(string) + if !ok { + return errors.New("the 1st parameter of phone function is not string type") + } + } + if defaultRegion == "" { + defaultRegion = "CN" + } + num, err := phonenumbers.Parse(numberToParse, defaultRegion) + if err != nil { + return err + } + matched := phonenumbers.IsValidNumber(num) + if !matched { + // return ErrInvalidWithoutMsg + return errors.New("phone format is incorrect") + } + return nil + }, true) +} diff --git a/internal/tagexpr/validator/validator.go b/internal/tagexpr/validator/validator.go new file mode 100644 index 000000000..2c57010b5 --- /dev/null +++ b/internal/tagexpr/validator/validator.go @@ -0,0 +1,163 @@ +// Package validator is a powerful validator that supports struct tag expression. +// +// Copyright 2019 Bytedance Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +package validator + +import ( + "errors" + "io" + "reflect" + "strings" + _ "unsafe" + + "github.com/cloudwego/hertz/internal/tagexpr" +) + +const ( + // MatchExprName the name of the expression used for validation + MatchExprName = tagexpr.DefaultExprName + // ErrMsgExprName the name of the expression used to specify the message + // returned when validation failed + ErrMsgExprName = "msg" +) + +// Validator struct fields validator +type Validator struct { + vm *tagexpr.VM + errFactory func(failPath, msg string) error +} + +// New creates a struct fields validator. +func New(tagName string) *Validator { + v := &Validator{ + vm: tagexpr.New(tagName), + errFactory: defaultErrorFactory, + } + return v +} + +// VM returns the struct tag expression interpreter. +func (v *Validator) VM() *tagexpr.VM { + return v.vm +} + +// Validate validates whether the fields of value is valid. +// NOTE: +// +// If checkAll=true, validate all the error. +func (v *Validator) Validate(value interface{}, checkAll ...bool) error { + var all bool + if len(checkAll) > 0 { + all = checkAll[0] + } + errs := make([]error, 0, 8) + err := v.vm.RunAny(value, func(te *tagexpr.TagExpr, err error) error { + if err != nil { + errs = append(errs, err) + if all { + return nil + } + return io.EOF + } + nilParentFields := make(map[string]bool, 16) + err = te.Range(func(eh *tagexpr.ExprHandler) error { + if strings.Contains(eh.StringSelector(), tagexpr.ExprNameSeparator) { + return nil + } + r := eh.Eval() + if r == nil { + return nil + } + rerr, ok := r.(error) + if !ok && tagexpr.FakeBool(r) { + return nil + } + // Ignore this error if the value of the parent is nil + if pfs, ok := eh.ExprSelector().ParentField(); ok { + if nilParentFields[pfs] { + return nil + } + if fh, ok := eh.TagExpr().Field(pfs); ok { + v := fh.Value(false) + if !v.IsValid() || (v.Kind() == reflect.Ptr && v.IsNil()) { + nilParentFields[pfs] = true + return nil + } + } + } + msg := eh.TagExpr().EvalString(eh.StringSelector() + tagexpr.ExprNameSeparator + ErrMsgExprName) + if msg == "" && rerr != nil { + msg = rerr.Error() + } + errs = append(errs, v.errFactory(eh.Path(), msg)) + if all { + return nil + } + return io.EOF + }) + if err != nil && !all { + return err + } + return nil + }) + if err != io.EOF && err != nil { + return err + } + switch len(errs) { + case 0: + return nil + case 1: + return errs[0] + default: + var errStr string + for _, e := range errs { + errStr += e.Error() + "\t" + } + return errors.New(errStr[:len(errStr)-1]) + } +} + +// SetErrorFactory customizes the factory of validation error. +// NOTE: +// +// If errFactory==nil, the default is used +func (v *Validator) SetErrorFactory(errFactory func(failPath, msg string) error) *Validator { + if errFactory == nil { + errFactory = defaultErrorFactory + } + v.errFactory = errFactory + return v +} + +// Error validate error +type Error struct { + FailPath, Msg string +} + +// Error implements error interface. +func (e *Error) Error() string { + if e.Msg != "" { + return e.Msg + } + return "invalid parameter: " + e.FailPath +} + +//go:nosplit +func defaultErrorFactory(failPath, msg string) error { + return &Error{ + FailPath: failPath, + Msg: msg, + } +} diff --git a/internal/tagexpr/validator/validator_test.go b/internal/tagexpr/validator/validator_test.go new file mode 100644 index 000000000..5cc2d7fb1 --- /dev/null +++ b/internal/tagexpr/validator/validator_test.go @@ -0,0 +1,354 @@ +// Copyright 2019 Bytedance Inc. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package validator_test + +import ( + "encoding/json" + "errors" + "testing" + + vd "github.com/cloudwego/hertz/internal/tagexpr/validator" +) + +func assertEqualError(t *testing.T, err error, s string) { + t.Helper() + if err.Error() != s { + t.Fatal("not equal", err, s) + } +} + +func assertNoError(t *testing.T, err error) { + t.Helper() + if err != nil { + t.Fatal(err) + } +} + +func TestNil(t *testing.T) { + type F struct { + F struct { + G int `vd:"$%3==1"` + } + } + assertEqualError(t, vd.Validate((*F)(nil)), "unsupported data: nil") +} + +func TestAll(t *testing.T) { + type T struct { + A string `vd:"email($)"` + F struct { + G int `vd:"$%3==1"` + } + } + assertEqualError(t, vd.Validate(new(T), true), "email format is incorrect\tinvalid parameter: F.G") +} + +func TestIssue1(t *testing.T) { + type MailBox struct { + Address *string `vd:"email($)"` + Name *string + } + type EmailMsg struct { + Recipients []*MailBox + RecipientsCc []*MailBox + RecipientsBcc []*MailBox + Subject *string + Content *string + AttachmentIDList []string + ReplyTo *string + Params map[string]string + FromEmailAddress *string + FromEmailName *string + } + type EmailTaskInfo struct { + Msg *EmailMsg + StartTimeMS *int64 + LogTag *string + } + type BatchCreateEmailTaskRequest struct { + InfoList []*EmailTaskInfo + } + invalid := "invalid email" + req := &BatchCreateEmailTaskRequest{ + InfoList: []*EmailTaskInfo{ + { + Msg: &EmailMsg{ + Recipients: []*MailBox{ + { + Address: &invalid, + }, + }, + }, + }, + }, + } + assertEqualError(t, vd.Validate(req, false), "email format is incorrect") +} + +func TestIssue2(t *testing.T) { + type a struct { + m map[string]interface{} + } + A := &a{ + m: map[string]interface{}{ + "1": 1, + "2": nil, + }, + } + v := vd.New("vd") + assertNoError(t, v.Validate(A)) +} + +func TestIssue3(t *testing.T) { + type C struct { + Id string + Index int32 `vd:"$==1"` + } + type A struct { + F1 *C + F2 *C + } + a := &A{ + F1: &C{ + Id: "test", + Index: 1, + }, + } + v := vd.New("vd") + assertNoError(t, v.Validate(a)) +} + +func TestIssue4(t *testing.T) { + type C struct { + Index *int32 `vd:"@:$!=nil;msg:'index is nil'"` + Index2 *int32 `vd:"$!=nil"` + Index3 *int32 `vd:"$!=nil"` + } + type A struct { + F1 *C + F2 map[string]*C + F3 []*C + } + v := vd.New("vd") + + a := &A{} + assertNoError(t, v.Validate(a)) + + a = &A{F1: new(C)} + assertEqualError(t, v.Validate(a), "index is nil") + + a = &A{F2: map[string]*C{"x": {Index: new(int32)}}} + assertEqualError(t, v.Validate(a), "invalid parameter: F2{v for k=x}.Index2") + + a = &A{F3: []*C{{Index: new(int32)}}} + assertEqualError(t, v.Validate(a), "invalid parameter: F3[0].Index2") + + type B struct { + F1 *C `vd:"$!=nil"` + F2 *C + } + b := &B{} + assertEqualError(t, v.Validate(b), "invalid parameter: F1") + + type D struct { + F1 *C + F2 *C + } + + type E struct { + D []*D + } + b.F1 = new(C) + e := &E{D: []*D{nil}} + assertNoError(t, v.Validate(e)) +} + +func TestIssue5(t *testing.T) { + type SubSheet struct{} + type CopySheet struct { + Source *SubSheet `json:"source" vd:"$!=nil"` + Destination *SubSheet `json:"destination" vd:"$!=nil"` + } + type UpdateSheetsRequest struct { + CopySheet *CopySheet `json:"copySheet"` + } + type BatchUpdateSheetRequestArg struct { + Requests []*UpdateSheetsRequest `json:"requests"` + } + b := `{"requests": [{}]}` + var data BatchUpdateSheetRequestArg + err := json.Unmarshal([]byte(b), &data) + assertNoError(t, err) + if len(data.Requests) != 1 { + t.Fatal(len(data.Requests)) + } + if data.Requests[0].CopySheet != nil { + t.Fatal(data.Requests[0].CopySheet) + } + v := vd.New("vd") + assertNoError(t, v.Validate(&data)) +} + +func TestIn(t *testing.T) { + type S string + type I int16 + type T struct { + X *int `vd:"$==nil || len($)>0"` + A S `vd:"in($,'a','b','c')"` + B I `vd:"in($,1,2.0,3)"` + } + v := vd.New("vd") + data := &T{} + err := v.Validate(data) + assertEqualError(t, err, "invalid parameter: A") + data.A = "b" + err = v.Validate(data) + assertEqualError(t, err, "invalid parameter: B") + data.B = 2 + err = v.Validate(data) + assertNoError(t, err) + + type T2 struct { + C string `vd:"in($)"` + } + data2 := &T2{} + err = v.Validate(data2) + assertEqualError(t, err, "invalid parameter: C") + + type T3 struct { + C string `vd:"in($,1)"` + } + data3 := &T3{} + err = v.Validate(data3) + assertEqualError(t, err, "invalid parameter: C") +} + +type ( + Issue23A struct { + B *Issue23B + V int64 `vd:"$==0"` + } + Issue23B struct { + A *Issue23A + V int64 `vd:"$==0"` + } +) + +func TestIssue23(t *testing.T) { + data := &Issue23B{A: &Issue23A{B: new(Issue23B)}} + err := vd.Validate(data, true) + assertNoError(t, err) +} + +func TestIssue24(t *testing.T) { + type SubmitDoctorImportItem struct { + Name string `form:"name,required" json:"name,required" query:"name,required"` + Avatar *string `form:"avatar,omitempty" json:"avatar,omitempty" query:"avatar,omitempty"` + Idcard string `form:"idcard,required" json:"idcard,required" query:"idcard,required" vd:"len($)==18"` + IdcardPics []string `form:"idcard_pics,omitempty" json:"idcard_pics,omitempty" query:"idcard_pics,omitempty"` + Hosp string `form:"hosp,required" json:"hosp,required" query:"hosp,required"` + HospDept string `form:"hosp_dept,required" json:"hosp_dept,required" query:"hosp_dept,required"` + HospProv *string `form:"hosp_prov,omitempty" json:"hosp_prov,omitempty" query:"hosp_prov,omitempty"` + HospCity *string `form:"hosp_city,omitempty" json:"hosp_city,omitempty" query:"hosp_city,omitempty"` + HospCounty *string `form:"hosp_county,omitempty" json:"hosp_county,omitempty" query:"hosp_county,omitempty"` + ProTit string `form:"pro_tit,required" json:"pro_tit,required" query:"pro_tit,required"` + ThTit *string `form:"th_tit,omitempty" json:"th_tit,omitempty" query:"th_tit,omitempty"` + ServDepts *string `form:"serv_depts,omitempty" json:"serv_depts,omitempty" query:"serv_depts,omitempty"` + TitCerts []string `form:"tit_certs,omitempty" json:"tit_certs,omitempty" query:"tit_certs,omitempty"` + ThTitCerts []string `form:"th_tit_certs,omitempty" json:"th_tit_certs,omitempty" query:"th_tit_certs,omitempty"` + PracCerts []string `form:"prac_certs,omitempty" json:"prac_certs,omitempty" query:"prac_certs,omitempty"` + QualCerts []string `form:"qual_certs,omitempty" json:"qual_certs,omitempty" query:"qual_certs,omitempty"` + PracCertNo string `form:"prac_cert_no,required" json:"prac_cert_no,required" query:"prac_cert_no,required" vd:"len($)==15"` + Goodat *string `form:"goodat,omitempty" json:"goodat,omitempty" query:"goodat,omitempty"` + Intro *string `form:"intro,omitempty" json:"intro,omitempty" query:"intro,omitempty"` + Linkman string `form:"linkman,required" json:"linkman,required" query:"linkman,required" vd:"email($)"` + Phone string `form:"phone,required" json:"phone,required" query:"phone,required" vd:"phone($,'CN')"` + } + + type SubmitDoctorImportRequest struct { + SubmitDoctorImport []*SubmitDoctorImportItem `form:"submit_doctor_import,required" json:"submit_doctor_import,required"` + } + data := &SubmitDoctorImportRequest{SubmitDoctorImport: []*SubmitDoctorImportItem{{}}} + err := vd.Validate(data, true) + assertEqualError(t, err, "invalid parameter: SubmitDoctorImport[0].Idcard\tinvalid parameter: SubmitDoctorImport[0].PracCertNo\temail format is incorrect\tthe phone number supplied is not a number") +} + +func TestStructSliceMap(t *testing.T) { + type F struct { + f struct { + g int `vd:"$%3==0"` + } + } + f := &F{} + f.f.g = 10 + type S struct { + A map[string]*F + B []map[string]*F + C map[string][]map[string]F + // _ int + } + s := S{ + A: map[string]*F{"x": f}, + B: []map[string]*F{{"y": f}}, + C: map[string][]map[string]F{"z": {{"zz": *f}}}, + } + err := vd.Validate(s, true) + assertEqualError(t, err, "invalid parameter: A{v for k=x}.f.g\tinvalid parameter: B[0]{v for k=y}.f.g\tinvalid parameter: C{v for k=z}[0]{v for k=zz}.f.g") +} + +func TestIssue30(t *testing.T) { + type TStruct struct { + TOk string `vd:"gt($,'0') && gt($, '1')" json:"t_ok"` + // TFail string `vd:"gt($,'0')" json:"t_fail"` + } + vd.RegFunc("gt", func(args ...interface{}) error { + return errors.New("force error") + }) + assertEqualError(t, vd.Validate(&TStruct{TOk: "1"}), "invalid parameter: TOk") + // assertNoError(t, vd.Validate(&TStruct{TOk: "1", TFail: "1"})) +} + +func TestIssue31(t *testing.T) { + type TStruct struct { + A []int32 `vd:"$ == nil || ($ != nil && range($, in(#v, 1, 2, 3))"` + } + assertEqualError(t, vd.Validate(&TStruct{A: []int32{1}}), "syntax error: \"($ != nil && range($, in(#v, 1, 2, 3))\"") + assertEqualError(t, vd.Validate(&TStruct{A: []int32{1}}), "syntax error: \"($ != nil && range($, in(#v, 1, 2, 3))\"") + assertEqualError(t, vd.Validate(&TStruct{A: []int32{1}}), "syntax error: \"($ != nil && range($, in(#v, 1, 2, 3))\"") +} + +func TestRegexp(t *testing.T) { + type TStruct struct { + A string `vd:"regexp('(\\d+\\.){3}\\d+')"` + } + assertNoError(t, vd.Validate(&TStruct{A: "0.0.0.0"})) + assertEqualError(t, vd.Validate(&TStruct{A: "0...0"}), "invalid parameter: A") + assertEqualError(t, vd.Validate(&TStruct{A: "abc1"}), "invalid parameter: A") + assertEqualError(t, vd.Validate(&TStruct{A: "0?0?0?0"}), "invalid parameter: A") +} + +func TestRangeIn(t *testing.T) { + type S struct { + F []string `vd:"range($, in(#v, '', 'ttp', 'euttp'))"` + } + err := vd.Validate(S{ + F: []string{"ttp", "", "euttp"}, + }) + assertNoError(t, err) + err = vd.Validate(S{ + F: []string{"ttp", "?", "euttp"}, + }) + assertEqualError(t, err, "invalid parameter: F") +} diff --git a/pkg/app/server/binding/config.go b/pkg/app/server/binding/config.go index 81cf30e56..4ee349e93 100644 --- a/pkg/app/server/binding/config.go +++ b/pkg/app/server/binding/config.go @@ -22,7 +22,7 @@ import ( "reflect" "time" - exprValidator "github.com/bytedance/go-tagexpr/v2/validator" + exprValidator "github.com/cloudwego/hertz/internal/tagexpr/validator" inDecoder "github.com/cloudwego/hertz/pkg/app/server/binding/internal/decoder" hJson "github.com/cloudwego/hertz/pkg/common/json" "github.com/cloudwego/hertz/pkg/protocol" diff --git a/pkg/app/server/binding/default.go b/pkg/app/server/binding/default.go index 7e09ac9bb..42f55ee87 100644 --- a/pkg/app/server/binding/default.go +++ b/pkg/app/server/binding/default.go @@ -70,8 +70,8 @@ import ( "strings" "sync" - exprValidator "github.com/bytedance/go-tagexpr/v2/validator" "github.com/cloudwego/hertz/internal/bytesconv" + exprValidator "github.com/cloudwego/hertz/internal/tagexpr/validator" inDecoder "github.com/cloudwego/hertz/pkg/app/server/binding/internal/decoder" hJson "github.com/cloudwego/hertz/pkg/common/json" "github.com/cloudwego/hertz/pkg/common/utils" From e644b1cd522dc00d408318444aac1d4bcec676a8 Mon Sep 17 00:00:00 2001 From: Wenju Gao Date: Wed, 11 Dec 2024 16:18:59 +0800 Subject: [PATCH 5/6] fix(client): use the latest host in Location header when redirect (#1246) --- go.mod | 5 ++ go.sum | 2 + pkg/protocol/client/client.go | 3 ++ pkg/protocol/client/client_test.go | 73 ++++++++++++++++++++++++++++++ 4 files changed, 83 insertions(+) create mode 100644 pkg/protocol/client/client_test.go diff --git a/go.mod b/go.mod index d2950316f..31972f936 100644 --- a/go.mod +++ b/go.mod @@ -9,6 +9,7 @@ require ( github.com/cloudwego/netpoll v0.6.4 github.com/fsnotify/fsnotify v1.5.4 github.com/nyaruka/phonenumbers v1.0.55 + github.com/stretchr/testify v1.8.1 github.com/tidwall/gjson v1.14.4 golang.org/x/sync v0.0.0-20210220032951-036812b2e83c golang.org/x/sys v0.24.0 @@ -19,14 +20,18 @@ require ( github.com/bytedance/sonic/loader v0.2.0 // indirect github.com/cloudwego/base64x v0.1.4 // indirect github.com/cloudwego/iasm v0.2.0 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/golang/protobuf v1.5.0 // indirect github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1 // indirect github.com/jtolds/gls v4.20.0+incompatible // indirect github.com/klauspost/cpuid/v2 v2.0.9 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d // indirect github.com/smartystreets/goconvey v1.6.4 // indirect + github.com/stretchr/objx v0.5.0 // indirect github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.0 // indirect github.com/twitchyliquid64/golang-asm v0.15.1 // indirect golang.org/x/arch v0.0.0-20210923205945-b76863e36670 // indirect + gopkg.in/yaml.v3 v3.0.1 // indirect ) diff --git a/go.sum b/go.sum index 7e920d459..fe72a680c 100644 --- a/go.sum +++ b/go.sum @@ -40,6 +40,7 @@ github.com/smartystreets/goconvey v1.6.4 h1:fv0U8FUIMPNf1L9lnHLvLhgicrIVChEkdzIK github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= +github.com/stretchr/objx v0.5.0 h1:1zr/of2m5FGMsad5YfcqgdqdWrIhu+EBEJRhR1U7z/c= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= @@ -78,6 +79,7 @@ golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8T google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= google.golang.org/protobuf v1.27.1 h1:SnqbnDw1V7RiZcXPx5MEeqPv2s79L9i7BJUlG/+RurQ= google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/pkg/protocol/client/client.go b/pkg/protocol/client/client.go index 777f55cdd..a5b2c88e8 100644 --- a/pkg/protocol/client/client.go +++ b/pkg/protocol/client/client.go @@ -242,6 +242,9 @@ func DoRequestFollowRedirects(ctx context.Context, req *protocol.Request, resp * break } url = getRedirectURL(url, location) + + // Remove the former host header. + req.Header.Del(consts.HeaderHost) } return statusCode, body, err diff --git a/pkg/protocol/client/client_test.go b/pkg/protocol/client/client_test.go new file mode 100644 index 000000000..49e7e7df6 --- /dev/null +++ b/pkg/protocol/client/client_test.go @@ -0,0 +1,73 @@ +/* + * Copyright 2024 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package client + +import ( + "context" + "errors" + "testing" + + "github.com/cloudwego/hertz/internal/bytestr" + "github.com/cloudwego/hertz/pkg/protocol" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" +) + +var firstTime = true + +type MockDoer struct { + mock.Mock +} + +func (m *MockDoer) Do(ctx context.Context, req *protocol.Request, resp *protocol.Response) error { + + // this is the real logic in (c *HostClient) doNonNilReqResp method + if len(req.Header.Host()) == 0 { + req.Header.SetHostBytes(req.URI().Host()) + } + + if firstTime { + // req.Header.Host() is the real host writing to the wire + if string(req.Header.Host()) != "example.com" { + return errors.New("host not match") + } + // this is the real logic in (c *HostClient) doNonNilReqResp method + if len(req.Header.Host()) == 0 { + req.Header.SetHostBytes(req.URI().Host()) + } + resp.Header.SetCanonical(bytestr.StrLocation, []byte("https://a.b.c/foo")) + resp.SetStatusCode(301) + firstTime = false + return nil + } + + if string(req.Header.Host()) != "a.b.c" { + resp.SetStatusCode(400) + return errors.New("host not match") + } + + resp.SetStatusCode(200) + + return nil +} + +func TestDoRequestFollowRedirects(t *testing.T) { + mockDoer := new(MockDoer) + mockDoer.On("Do", mock.Anything, mock.Anything, mock.Anything).Return(nil) + statusCode, _, err := DoRequestFollowRedirects(context.Background(), &protocol.Request{}, &protocol.Response{}, "https://example.com", defaultMaxRedirectsCount, mockDoer) + assert.NoError(t, err) + assert.Equal(t, 200, statusCode) +} From 67455c5f64780458ad23efd22cb17bd9832a963b Mon Sep 17 00:00:00 2001 From: alice <90381261+alice-yyds@users.noreply.github.com> Date: Thu, 12 Dec 2024 20:04:37 +0800 Subject: [PATCH 6/6] chore: update version v0.9.4 --- version.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/version.go b/version.go index 1f5fabefb..027817a09 100644 --- a/version.go +++ b/version.go @@ -19,5 +19,5 @@ package hertz // Name and Version info of this framework, used for statistics and debug const ( Name = "Hertz" - Version = "v0.9.3" + Version = "v0.9.4" )