From 74cedc74ad2804f9eb88a70ad2565893cb260edf Mon Sep 17 00:00:00 2001 From: kinggo Date: Fri, 1 Sep 2023 17:21:53 +0800 Subject: [PATCH 01/20] optimize(hz): use sprig to add template function (#784) Co-authored-by: fgy --- cmd/hz/generator/template_funcs.go | 19 ++++++++++++------ cmd/hz/go.mod | 1 + cmd/hz/go.sum | 32 ++++++++++++++++++++++++++++++ licenses/LICENSE-sprig.txt | 19 ++++++++++++++++++ 4 files changed, 65 insertions(+), 6 deletions(-) create mode 100644 licenses/LICENSE-sprig.txt diff --git a/cmd/hz/generator/template_funcs.go b/cmd/hz/generator/template_funcs.go index 556ffb1e8..2d33d72ea 100644 --- a/cmd/hz/generator/template_funcs.go +++ b/cmd/hz/generator/template_funcs.go @@ -20,15 +20,22 @@ import ( "strings" "text/template" + "github.com/Masterminds/sprig/v3" "github.com/cloudwego/hertz/cmd/hz/util" ) -var funcMap = template.FuncMap{ - "GetUniqueHandlerOutDir": getUniqueHandlerOutDir, - "ToSnakeCase": util.ToSnakeCase, - "Split": strings.Split, - "Trim": strings.Trim, -} +var funcMap = func() template.FuncMap { + m := template.FuncMap{ + "GetUniqueHandlerOutDir": getUniqueHandlerOutDir, + "ToSnakeCase": util.ToSnakeCase, + "Split": strings.Split, + "Trim": strings.Trim, + } + for key, f := range sprig.TxtFuncMap() { + m[key] = f + } + return m +}() // getUniqueHandlerOutDir uses to get unique "api.handler_path" func getUniqueHandlerOutDir(methods []*HttpMethod) (ret []string) { diff --git a/cmd/hz/go.mod b/cmd/hz/go.mod index b65eeeb65..712c5a28d 100644 --- a/cmd/hz/go.mod +++ b/cmd/hz/go.mod @@ -3,6 +3,7 @@ module github.com/cloudwego/hertz/cmd/hz go 1.16 require ( + github.com/Masterminds/sprig/v3 v3.2.3 github.com/cloudwego/thriftgo v0.1.7 github.com/hashicorp/go-version v1.5.0 github.com/jhump/protoreflect v1.12.0 diff --git a/cmd/hz/go.sum b/cmd/hz/go.sum index 23aa6038b..de5c00d51 100644 --- a/cmd/hz/go.sum +++ b/cmd/hz/go.sum @@ -1,6 +1,12 @@ cloud.google.com/go v0.26.0/go.mod h1:aQUYkXzVsufM+DwF1aE+0xfcU+56JwCaLick0ClmMTw= github.com/BurntSushi/toml v0.3.1/go.mod h1:xHWCNGjB5oqiDr8zfno3MHue2Ht5sIBksp03qcyfWMU= github.com/BurntSushi/toml v1.1.0/go.mod h1:CxXYINrC8qIiEnFrOxCa7Jy5BFHlXnUU2pbicEuybxQ= +github.com/Masterminds/goutils v1.1.1 h1:5nUrii3FMTL5diU80unEVvNevw1nH4+ZV4DSLVJLSYI= +github.com/Masterminds/goutils v1.1.1/go.mod h1:8cTjp+g8YejhMuvIA5y2vz3BpJxksy863GQaJW2MFNU= +github.com/Masterminds/semver/v3 v3.2.0 h1:3MEsd0SM6jqZojhjLWWeBY+Kcjy9i6MQAeY7YgDP83g= +github.com/Masterminds/semver/v3 v3.2.0/go.mod h1:qvl/7zhW3nngYb5+80sSMF+FG2BjYrf8m9wsX0PNOMQ= +github.com/Masterminds/sprig/v3 v3.2.3 h1:eL2fZNezLomi0uOLqjQoN6BfsDD+fyLtgbJMAj9n6YA= +github.com/Masterminds/sprig/v3 v3.2.3/go.mod h1:rXcFaZ2zZbLRJv/xSysmlgIM1u11eBaRMhvYXJNkGuM= github.com/apache/thrift v0.13.0 h1:5hryIiq9gtn+MiLVn0wP37kb/uTeRZgN08WoCsAhIhI= github.com/apache/thrift v0.13.0/go.mod h1:cp2SuWMxlEZw2r+iP2GNCdIi4C1qmUzdZFSVb+bacwQ= github.com/census-instrumentation/opencensus-proto v0.2.1/go.mod h1:f6KPmirojxKA12rnyqOA5BBL4O983OfeGPqjHWSTneU= @@ -11,6 +17,8 @@ github.com/cncf/udpa/go v0.0.0-20201120205902-5459f2c99403/go.mod h1:WmhPx2Nbnht github.com/cpuguy83/go-md2man/v2 v2.0.2 h1:p1EgwI/C7NhT0JmVkwCD2ZBK8j4aeHQX2pMHHBfMQ6w= github.com/cpuguy83/go-md2man/v2 v2.0.2/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/envoyproxy/go-control-plane v0.9.0/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.1-0.20191026205805-5f8ba28d4473/go.mod h1:YTl/9mNaCwkRvm6d1a2C3ymFceY/DCBVvsKhRF0iEA4= github.com/envoyproxy/go-control-plane v0.9.9-0.20210217033140-668b12f5399d/go.mod h1:cXg6YxExXjJnVBQHBLXeUAgxn2UodCpnH306RInaBQk= @@ -35,9 +43,15 @@ github.com/google/go-cmp v0.4.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/ github.com/google/go-cmp v0.5.0/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/uuid v1.1.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/google/uuid v1.1.2 h1:EVhdT+1Kseyi1/pUmXKaFxYsDNy9RQYkMWRH68J/W7Y= github.com/google/uuid v1.1.2/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/hashicorp/go-version v1.5.0 h1:O293SZ2Eg+AAYijkVK3jR786Am1bhDEh2GHT0tIVE5E= github.com/hashicorp/go-version v1.5.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA= +github.com/huandu/xstrings v1.3.3 h1:/Gcsuc1x8JVbJ9/rlye4xZnVAbEkGauT8lbebqcQws4= +github.com/huandu/xstrings v1.3.3/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE= +github.com/imdario/mergo v0.3.11 h1:3tnifQM4i+fbajXKBHXWEH+KvNHqojZ778UH75j3bGA= +github.com/imdario/mergo v0.3.11/go.mod h1:jmQim1M+e3UYxmgPu/WyfjB3N3VflVyUjjjwH0dnCYA= github.com/jhump/gopoet v0.0.0-20190322174617-17282ff210b3/go.mod h1:me9yfT6IJSlOL3FCfrg+L6yzUEZ+5jW6WHt4Sk+UPUI= github.com/jhump/gopoet v0.1.0/go.mod h1:me9yfT6IJSlOL3FCfrg+L6yzUEZ+5jW6WHt4Sk+UPUI= github.com/jhump/goprotoc v0.5.0/go.mod h1:VrbvcYrQOrTi3i0Vf+m+oqQWk9l72mjkJCYo7UvLHRQ= @@ -49,11 +63,22 @@ github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORN github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/mitchellh/copystructure v1.0.0 h1:Laisrj+bAB6b/yJwB5Bt3ITZhGJdqmxquMKeZ+mmkFQ= +github.com/mitchellh/copystructure v1.0.0/go.mod h1:SNtv71yrdKgLRyLFxmLdkAbkKEFWgYaq1OVrnRcwhnw= +github.com/mitchellh/reflectwalk v1.0.0 h1:9D+8oIskB4VJBN5SFlmc27fSlIBZaov1Wpk/IfikLNY= +github.com/mitchellh/reflectwalk v1.0.0/go.mod h1:mSTlrgnPZtwu0c4WaC2kGObEpuNDbx0jmZXqmk4esnw= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA= github.com/russross/blackfriday/v2 v2.1.0 h1:JIOH55/0cWyOuilr9/qlrm0BSXldqnqwMsf35Ld67mk= github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/shopspring/decimal v1.2.0 h1:abSATXmQEYyShuxI4/vyW3tV1MrKAJzCZ/0zLUXYbsQ= +github.com/shopspring/decimal v1.2.0/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= +github.com/spf13/cast v1.3.1 h1:nFm6S0SMdyzrzcmThSipiEubIDy8WEXKNZ0UOgiRpng= +github.com/spf13/cast v1.3.1/go.mod h1:Qx5cxh0v+4UWYiBimWS+eyWzqEqokIECu5etghLkUJE= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= +github.com/stretchr/testify v1.5.1 h1:nOGnQDM7FYENwehXlg/kFVnos3rEvtKTjRvOWSzb6H4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/urfave/cli/v2 v2.23.0 h1:pkly7gKIeYv3olPAeNajNpLjeJrmTPYCoZWaV+2VfvE= github.com/urfave/cli/v2 v2.23.0/go.mod h1:1CNUng3PtjQMtRzJO4FMXBQvkGtuYRxxiR9xMa7jMwI= @@ -63,6 +88,8 @@ github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5t golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.3.0 h1:a06MkbcxBrEFc0w0QIZWXrH/9cCX6KJyWbBOIwAn+7A= +golang.org/x/crypto v0.3.0/go.mod h1:hebNnKkNXi2UzZN1eVRvBB7co0a+JxK6XbPiWVs/3J4= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/lint v0.0.0-20181026193005-c67002cb31c3/go.mod h1:UVdnD1Gm6xHRNCYTkRU2/jEulfH38KcIWyp/GAMgvoE= golang.org/x/lint v0.0.0-20190227174305-5b3e6a55c961/go.mod h1:wehouNa3lNwaWXcvxsM5YxQ5yQlVC4a0KAMCusXpPoU= @@ -78,6 +105,7 @@ golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLL golang.org/x/net v0.0.0-20200625001655-4c5254603344/go.mod h1:/O7V0waA8r7cgGh81Ro3o1hOxt32SMVPicZroKQ2sZA= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.2.0/go.mod h1:KqCZLdyyvdV855qA2rE3GC2aiw5xGR5TEjj8smXukLY= golang.org/x/net v0.3.0 h1:VWL6FNY2bEEmsGVKabSlHu5Irp34xmMRoqb/9lF9lxk= golang.org/x/net v0.3.0/go.mod h1:MBQ8lrhLObU/6UmLb4fmbmk5OcyYmqtbGd/9yIeKjEE= golang.org/x/oauth2 v0.0.0-20180821212333-d2e6202438be/go.mod h1:N/0e6XlmueqKjAGxoOufVs8QHGRruUQn6yWY3a++T0U= @@ -95,15 +123,18 @@ golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7w golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.2.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.3.0 h1:w8ZOecv6NaNa/zC8944JTU3vz4u6Lagfk4RPQxv92NQ= golang.org/x/sys v0.3.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/term v0.2.0/go.mod h1:TVmDHMZPmdnySmBfhjOoOdhjzdE1h4u1VwSiw2l1Nuc= golang.org/x/term v0.3.0/go.mod h1:q750SLmJuPmVoN1blW3UFBPREJfb1KmY3vwxfr+nFDA= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/text v0.5.0 h1:OLmvp0KP+FVG99Ct/qFiL/Fhk4zp4QQnZ7b2U+5piUM= golang.org/x/text v0.5.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= @@ -148,6 +179,7 @@ gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogR gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v2 v2.3.0/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/licenses/LICENSE-sprig.txt b/licenses/LICENSE-sprig.txt new file mode 100644 index 000000000..9e6cf575c --- /dev/null +++ b/licenses/LICENSE-sprig.txt @@ -0,0 +1,19 @@ +Copyright (C) 2013-2020 Masterminds + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. \ No newline at end of file From 45f7b422ff680e5280fe1e11a4d644ff50a62c8a Mon Sep 17 00:00:00 2001 From: Wenju Gao Date: Wed, 6 Sep 2023 20:56:13 +0800 Subject: [PATCH 02/20] fix(client): conn leak in stream mode when resp is no content(204) (#936) --- pkg/protocol/http1/client.go | 12 ++++++---- pkg/protocol/http1/client_test.go | 40 +++++++++++++++++++++++++++++++ pkg/protocol/request.go | 1 - pkg/protocol/response.go | 11 ++++++++- 4 files changed, 58 insertions(+), 6 deletions(-) diff --git a/pkg/protocol/http1/client.go b/pkg/protocol/http1/client.go index c2b7cf5a4..09fbc10cd 100644 --- a/pkg/protocol/http1/client.go +++ b/pkg/protocol/http1/client.go @@ -72,9 +72,12 @@ import ( respI "github.com/cloudwego/hertz/pkg/protocol/http1/resp" ) -var errConnectionClosed = errs.NewPublic("the server closed connection before returning the first response byte. " + - "Make sure the server returns 'Connection: close' response header before closing the connection") -var errTimeout = errs.New(errs.ErrTimeout, errs.ErrorTypePublic, "host client") +var ( + errConnectionClosed = errs.NewPublic("the server closed connection before returning the first response byte. " + + "Make sure the server returns 'Connection: close' response header before closing the connection") + + errTimeout = errs.New(errs.ErrTimeout, errs.ErrorTypePublic, "host client") +) // HostClient balances http requests among hosts listed in Addr. // @@ -699,7 +702,8 @@ func (c *HostClient) doNonNilReqResp(req *protocol.Request, resp *protocol.Respo shouldCloseConn = resetConnection || req.ConnectionClose() || resp.ConnectionClose() - if c.ResponseBodyStream { + // In stream mode, we still can close/release the connection immediately if there is no content on the wire. + if c.ResponseBodyStream && resp.BodyStream() != protocol.NoResponseBody { return false, err } diff --git a/pkg/protocol/http1/client_test.go b/pkg/protocol/http1/client_test.go index 7bc389bf9..7c45d1406 100644 --- a/pkg/protocol/http1/client_test.go +++ b/pkg/protocol/http1/client_test.go @@ -536,3 +536,43 @@ func TestConnNotRetry(t *testing.T) { assert.True(t, logbuf.String() == "") protocol.ReleaseResponse(resp) } + +type countCloseConn struct { + network.Conn + isClose bool +} + +func (c *countCloseConn) Close() error { + c.isClose = true + return nil +} + +func newCountCloseConn(s string) *countCloseConn { + return &countCloseConn{ + Conn: mock.NewConn(s), + } +} + +func TestStreamNoContent(t *testing.T) { + conn := newCountCloseConn("HTTP/1.1 204 Foo Bar\r\nContent-Type: aab\r\nTrailer: Foo\r\nContent-Encoding: deflate\r\nTransfer-Encoding: chunked\r\n\r\n0\r\nFoo: bar\r\n\r\nHTTP/1.2") + + c := &HostClient{ + ClientOptions: &ClientOptions{ + Dialer: newSlowConnDialer(func(network, addr string) (network.Conn, error) { + return conn, nil + }), + }, + Addr: "foobar", + } + + c.ResponseBodyStream = true + + req := protocol.AcquireRequest() + req.SetRequestURI("http://foobar/baz") + req.Header.SetConnectionClose(true) + resp := protocol.AcquireResponse() + + c.Do(context.Background(), req, resp) + + assert.True(t, conn.isClose) +} diff --git a/pkg/protocol/request.go b/pkg/protocol/request.go index ef70f92ca..d04731867 100644 --- a/pkg/protocol/request.go +++ b/pkg/protocol/request.go @@ -76,7 +76,6 @@ var ( // NoBody is an io.ReadCloser with no bytes. Read always returns EOF // and Close always returns nil. It can be used in an outgoing client // request to explicitly signal that a request has zero bytes. -// An alternative, however, is to simply set Request.Body to nil. var NoBody = noBody{} type noBody struct{} diff --git a/pkg/protocol/response.go b/pkg/protocol/response.go index 016af9234..8beb38597 100644 --- a/pkg/protocol/response.go +++ b/pkg/protocol/response.go @@ -54,7 +54,13 @@ import ( "github.com/cloudwego/hertz/pkg/network" ) -var responsePool sync.Pool +var ( + responsePool sync.Pool + // NoResponseBody is an io.ReadCloser with no bytes. Read always returns EOF + // and Close always returns nil. It can be used in an ingoing client + // response to explicitly signal that a response has zero bytes. + NoResponseBody = noBody{} +) // Response represents HTTP response. // @@ -334,6 +340,9 @@ func (resp *Response) SetBody(body []byte) { } func (resp *Response) BodyStream() io.Reader { + if resp.bodyStream == nil { + resp.bodyStream = NoResponseBody + } return resp.bodyStream } From 03376c92b0cc8aaa3016aedfb8c2954fb5ac34c5 Mon Sep 17 00:00:00 2001 From: Lorain <87760338+justlorain@users.noreply.github.com> Date: Thu, 7 Sep 2023 15:05:02 +0800 Subject: [PATCH 03/20] fix(hz): shebang position (#928) --- cmd/hz/generator/layout_tpl.go | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/cmd/hz/generator/layout_tpl.go b/cmd/hz/generator/layout_tpl.go index 6f2b4c351..68a6d636b 100644 --- a/cmd/hz/generator/layout_tpl.go +++ b/cmd/hz/generator/layout_tpl.go @@ -201,24 +201,20 @@ func GeneratedRegister(r *server.Hertz){ }, { Path: "build.sh", - Body: ` -#!/bin/bash + Body: `#!/bin/bash RUN_NAME={{.ServiceName}} mkdir -p output/bin cp script/* output 2>/dev/null chmod +x output/bootstrap.sh -go build -o output/bin/${RUN_NAME} -`, +go build -o output/bin/${RUN_NAME}`, }, { Path: defaultScriptDir + sp + "bootstrap.sh", - Body: ` -#!/bin/bash + Body: `#!/bin/bash CURDIR=$(cd $(dirname $0); pwd) BinaryName={{.ServiceName}} echo "$CURDIR/bin/${BinaryName}" -exec $CURDIR/bin/${BinaryName} -`, +exec $CURDIR/bin/${BinaryName}`, }, }, } From 8b1eaba889cf6e4aa8e287290e667052a6ed81e0 Mon Sep 17 00:00:00 2001 From: chaoranz758 <110596971+chaoranz758@users.noreply.github.com> Date: Thu, 14 Sep 2023 20:38:19 +0800 Subject: [PATCH 04/20] refactor(hz): fix hz client custom template function (#941) --- cmd/hz/app/app.go | 2 -- cmd/hz/generator/package.go | 3 +++ 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/cmd/hz/app/app.go b/cmd/hz/app/app.go index 269e5bad9..a6c44c011 100644 --- a/cmd/hz/app/app.go +++ b/cmd/hz/app/app.go @@ -321,8 +321,6 @@ func Init() *cli.App { &snakeNameFlag, &rmTagFlag, &excludeFilesFlag, - &customLayout, - &customLayoutData, &customPackage, &protoPluginsFlag, &thriftPluginsFlag, diff --git a/cmd/hz/generator/package.go b/cmd/hz/generator/package.go index 9eb00adf2..01a6884a1 100644 --- a/cmd/hz/generator/package.go +++ b/cmd/hz/generator/package.go @@ -167,6 +167,9 @@ func (pkgGen *HttpPackageGenerator) Generate(pkg *HttpPackage) error { if err := pkgGen.genClient(pkg, clientDir); err != nil { return err } + if err := pkgGen.genCustomizedFile(pkg); err != nil { + return err + } return nil } From d21a04a9b8cb76825a1994aa6ec3c37f72f9208d Mon Sep 17 00:00:00 2001 From: Joway Date: Mon, 18 Sep 2023 17:44:19 +0800 Subject: [PATCH 05/20] chore: netpoll pre release v0.4.2 (#877) --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 78039a7ec..b5eeb5251 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/bytedance/gopkg v0.0.0-20220413063733-65bf48ffb3a7 github.com/bytedance/mockey v1.2.1 github.com/bytedance/sonic v1.8.1 - github.com/cloudwego/netpoll v0.3.2 + github.com/cloudwego/netpoll v0.4.2-0.20230807055039-52fd5fb7b00f github.com/fsnotify/fsnotify v1.5.4 github.com/tidwall/gjson v1.13.0 // indirect golang.org/x/sync v0.0.0-20210220032951-036812b2e83c diff --git a/go.sum b/go.sum index fe32d2b48..59e21cf1c 100644 --- a/go.sum +++ b/go.sum @@ -10,8 +10,8 @@ github.com/bytedance/sonic v1.8.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZX github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams= github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk= -github.com/cloudwego/netpoll v0.3.2 h1:/998ICrNMVBo4mlul4j7qcIeY7QnEfuCCPPwck9S3X4= -github.com/cloudwego/netpoll v0.3.2/go.mod h1:xVefXptcyheopwNDZjDPcfU6kIjZXZ4nY550k1yH9eQ= +github.com/cloudwego/netpoll v0.4.2-0.20230807055039-52fd5fb7b00f h1:8iWPKjHdXl4tjcSxUJTavnhRL5JPupYvxbtsAlm2Igw= +github.com/cloudwego/netpoll v0.4.2-0.20230807055039-52fd5fb7b00f/go.mod h1:xVefXptcyheopwNDZjDPcfU6kIjZXZ4nY550k1yH9eQ= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= From c52e1c9be6cd12d6bbbbf5107c612f8eeb6b388b Mon Sep 17 00:00:00 2001 From: chaoranz758 <110596971+chaoranz758@users.noreply.github.com> Date: Mon, 18 Sep 2023 20:34:08 +0800 Subject: [PATCH 06/20] feat: add Del func for requestHeader (#950) --- pkg/protocol/header.go | 6 ++++++ pkg/protocol/header_test.go | 6 ++++++ 2 files changed, 12 insertions(+) diff --git a/pkg/protocol/header.go b/pkg/protocol/header.go index 3ac88016b..f7533ef2f 100644 --- a/pkg/protocol/header.go +++ b/pkg/protocol/header.go @@ -1064,6 +1064,12 @@ func (h *RequestHeader) DelBytes(key []byte) { h.del(h.bufKV.key) } +// Del deletes header with the given key. +func (h *RequestHeader) Del(key string) { + k := getHeaderKeyBytes(&h.bufKV, key, h.disableNormalizing) + h.del(k) +} + func (h *RequestHeader) SetArgBytes(key, value []byte, noValue bool) { h.h = setArgBytes(h.h, key, value, noValue) } diff --git a/pkg/protocol/header_test.go b/pkg/protocol/header_test.go index 1fb021100..8d768fd78 100644 --- a/pkg/protocol/header_test.go +++ b/pkg/protocol/header_test.go @@ -209,6 +209,7 @@ func TestRequestHeaderDel(t *testing.T) { var h RequestHeader h.Set("Foo-Bar", "baz") h.Set("aaa", "bbb") + h.Set("ccc", "ddd") h.Set(consts.HeaderConnection, "keep-alive") h.Set(consts.HeaderContentType, "aaa") h.Set(consts.HeaderServer, "aaabbb") @@ -226,11 +227,16 @@ func TestRequestHeaderDel(t *testing.T) { h.del([]byte("Host")) h.del([]byte(consts.HeaderTrailer)) h.DelCookie("foo") + h.Del("ccc") hv := h.Peek("aaa") if string(hv) != "bbb" { t.Fatalf("unexpected header value: %q. Expecting %q", hv, "bbb") } + hv = h.Peek("ccc") + if string(hv) != "" { + t.Fatalf("unexpected header value: %q. Expecting %q", hv, "") + } hv = h.Peek("Foo-Bar") if len(hv) > 0 { t.Fatalf("non-zero header value: %q", hv) From e92838ee66d61ad87c92e256f7b44b83d7cafbda Mon Sep 17 00:00:00 2001 From: Jiun Lee <1808644906@qq.com> Date: Mon, 18 Sep 2023 22:59:07 +0800 Subject: [PATCH 07/20] fix: fix error printing of escaped characters in accesslog (#819) --- pkg/common/hlog/default.go | 7 ++++++- pkg/common/hlog/default_test.go | 10 ++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/pkg/common/hlog/default.go b/pkg/common/hlog/default.go index 6bc624a4a..bdb2a802d 100644 --- a/pkg/common/hlog/default.go +++ b/pkg/common/hlog/default.go @@ -149,10 +149,15 @@ func (ll *defaultLogger) logf(lv Level, format *string, v ...interface{}) { } msg := lv.toString() if format != nil { - msg += fmt.Sprintf(*format, v...) + if len(v) > 0 { + msg += fmt.Sprintf(*format, v...) + } else { + msg += *format + } } else { msg += fmt.Sprint(v...) } + ll.stdlog.Output(ll.depth, msg) if lv == LevelFatal { os.Exit(1) diff --git a/pkg/common/hlog/default_test.go b/pkg/common/hlog/default_test.go index ec8310ca2..36f063f9b 100644 --- a/pkg/common/hlog/default_test.go +++ b/pkg/common/hlog/default_test.go @@ -107,6 +107,16 @@ func TestCtxLogger(t *testing.T) { "[Error] work failed\n", string(w.b)) } +func TestFormatLoggerWithEscapedCharacters(t *testing.T) { + initTestLogger() + + var w byteSliceWriter + SetOutput(&w) + + Infof("http://localhost:8080/ping?f=http://localhost:3000/hello?c=%E5%A4%A7hi%E5%93%A6%E5%95%8A%E8%AF%B4%E5%BE%97%E5%A5%BD") + assert.DeepEqual(t, "[Info] http://localhost:8080/ping?f=http://localhost:3000/hello?c=%E5%A4%A7hi%E5%93%A6%E5%95%8A%E8%AF%B4%E5%BE%97%E5%A5%BD\n", string(w.b)) +} + func TestSetLevel(t *testing.T) { setLogger := &defaultLogger{ stdlog: log.New(os.Stderr, "", log.LstdFlags|log.Lshortfile|log.Lmicroseconds), From 025f404d8c23bec5b9a73d8ccc89c6ea3b8f96ea Mon Sep 17 00:00:00 2001 From: chaoranz758 <110596971+chaoranz758@users.noreply.github.com> Date: Tue, 19 Sep 2023 11:16:21 +0800 Subject: [PATCH 08/20] feat(hz): hz tool don't gen sh in windows (#942) --- cmd/hz/generator/layout.go | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/cmd/hz/generator/layout.go b/cmd/hz/generator/layout.go index d007fa82e..ea189f5a8 100644 --- a/cmd/hz/generator/layout.go +++ b/cmd/hz/generator/layout.go @@ -124,11 +124,22 @@ func (lg *LayoutGenerator) GenerateByService(service Layout) error { if !service.NeedGoMod { gomodFile := "go.mod" - if _, exist := lg.tpls["go.mod"]; exist { + if _, exist := lg.tpls[gomodFile]; exist { delete(lg.tpls, gomodFile) } } + if util.IsWindows() { + buildSh := "build.sh" + bootstrapSh := defaultScriptDir + sp + "bootstrap.sh" + if _, exist := lg.tpls[buildSh]; exist { + delete(lg.tpls, buildSh) + } + if _, exist := lg.tpls[bootstrapSh]; exist { + delete(lg.tpls, bootstrapSh) + } + } + sd, err := serviceToLayoutData(service) if err != nil { return err From a2d383fd041176b5ca5208175b61a57cacd9563c Mon Sep 17 00:00:00 2001 From: Xian <117006076+liuxianloveqiqi@users.noreply.github.com> Date: Tue, 19 Sep 2023 17:04:52 +0800 Subject: [PATCH 09/20] test: Increase pkg/common/utils coverage to more than 90% (#948) --- pkg/common/utils/ioutil_test.go | 64 +++++++++++++++++++++++++++++++- pkg/common/utils/network_test.go | 11 ++++++ pkg/common/utils/path_test.go | 5 +++ 3 files changed, 79 insertions(+), 1 deletion(-) diff --git a/pkg/common/utils/ioutil_test.go b/pkg/common/utils/ioutil_test.go index c47718014..e9a573b47 100644 --- a/pkg/common/utils/ioutil_test.go +++ b/pkg/common/utils/ioutil_test.go @@ -106,6 +106,23 @@ func TestIoutilCopyBuffer(t *testing.T) { assert.DeepEqual(t, written, srcLen) assert.DeepEqual(t, err, nil) assert.DeepEqual(t, []byte(str), writeBuffer.Bytes()) + + // Test when no data is readable + writeBuffer.Reset() + emptySrc := bytes.NewBufferString("") + written, err = CopyBuffer(dst, emptySrc, buf) + assert.DeepEqual(t, written, int64(0)) + assert.Nil(t, err) + assert.DeepEqual(t, []byte(""), writeBuffer.Bytes()) + + // Test a LimitedReader + writeBuffer.Reset() + limit := int64(5) + limitedSrc := io.LimitedReader{R: bytes.NewBufferString(str), N: limit} + written, err = CopyBuffer(dst, &limitedSrc, buf) + assert.DeepEqual(t, written, limit) + assert.Nil(t, err) + assert.DeepEqual(t, []byte(str[:limit]), writeBuffer.Bytes()) } func TestIoutilCopyBufferWithIoWriter(t *testing.T) { @@ -198,7 +215,7 @@ func TestIoutilCopyBufferWithNilBufferAndIoLimitedReader(t *testing.T) { func TestIoutilCopyZeroAlloc(t *testing.T) { var writeBuffer bytes.Buffer - str := string("hertz is very good!!!") + str := "hertz is very good!!!" src := bytes.NewBufferString(str) dst := network.NewWriter(&writeBuffer) srcLen := int64(src.Len()) @@ -207,4 +224,49 @@ func TestIoutilCopyZeroAlloc(t *testing.T) { assert.DeepEqual(t, written, srcLen) assert.DeepEqual(t, err, nil) assert.DeepEqual(t, []byte(str), writeBuffer.Bytes()) + + // Test when no data is readable + writeBuffer.Reset() + emptySrc := bytes.NewBufferString("") + written, err = CopyZeroAlloc(dst, emptySrc) + assert.DeepEqual(t, written, int64(0)) + assert.Nil(t, err) + assert.DeepEqual(t, []byte(""), writeBuffer.Bytes()) +} + +func TestIoutilCopyBufferWithEmptyBuffer(t *testing.T) { + var writeBuffer bytes.Buffer + str := "hertz is very good!!!" + src := bytes.NewBufferString(str) + dst := network.NewWriter(&writeBuffer) + // Use a non-empty buffer of length 0 + emptyBuf := make([]byte, 0) + func() { + defer func() { + if r := recover(); r != nil { + assert.DeepEqual(t, "empty buffer in io.CopyBuffer", r) + } + }() + + written, err := CopyBuffer(dst, src, emptyBuf) + assert.Nil(t, err) + assert.DeepEqual(t, written, int64(len(str))) + assert.DeepEqual(t, []byte(str), writeBuffer.Bytes()) + }() +} + +func TestIoutilCopyBufferWithLimitedReader(t *testing.T) { + var writeBuffer bytes.Buffer + str := "hertz is very good!!!" + src := bytes.NewBufferString(str) + limit := int64(5) + limitedSrc := io.LimitedReader{R: src, N: limit} + dst := network.NewWriter(&writeBuffer) + var buf []byte + + // Test LimitedReader status + written, err := CopyBuffer(dst, &limitedSrc, buf) + assert.Nil(t, err) + assert.DeepEqual(t, written, limit) + assert.DeepEqual(t, []byte(str[:limit]), writeBuffer.Bytes()) } diff --git a/pkg/common/utils/network_test.go b/pkg/common/utils/network_test.go index 232e35a88..78be77dbc 100644 --- a/pkg/common/utils/network_test.go +++ b/pkg/common/utils/network_test.go @@ -41,3 +41,14 @@ func TestTLSRecordHeaderLooksLikeHTTP(t *testing.T) { assert.DeepEqual(t, expectedResult, TLSRecordHeaderLooksLikeHTTP(value)) } } + +func TestLocalIP(t *testing.T) { + // Mock the localIP variable for testing purposes. + localIP = "192.168.0.1" + + // Ensure that LocalIP() returns the expected local IP. + expectedIP := "192.168.0.1" + if got := LocalIP(); got != expectedIP { + assert.DeepEqual(t, got, expectedIP) + } +} diff --git a/pkg/common/utils/path_test.go b/pkg/common/utils/path_test.go index 6ebb8430c..b98a8c747 100644 --- a/pkg/common/utils/path_test.go +++ b/pkg/common/utils/path_test.go @@ -82,6 +82,11 @@ func TestPathCleanPath(t *testing.T) { expectedMultiSlashPath := "/Foo/Bar/go/src/github.com/cloudwego" cleanMultiSlashPath := CleanPath(multiSlashPath) assert.DeepEqual(t, expectedMultiSlashPath, cleanMultiSlashPath) + + inputPath := "/Foo/Bar/go/src/github.com/cloudwego/hertz/pkg/common/utils/path_test.go/." + expectedPath := "/Foo/Bar/go/src/github.com/cloudwego/hertz/pkg/common/utils/path_test.go/" + cleanedPath := CleanPath(inputPath) + assert.DeepEqual(t, expectedPath, cleanedPath) } // The Function AddMissingPort can only add the missed port, don't consider the other error case. From fecbe811144c305bf415521758d6bab9cd33e64c Mon Sep 17 00:00:00 2001 From: kinggo Date: Tue, 19 Sep 2023 17:05:10 +0800 Subject: [PATCH 10/20] optimize: adjusting the judgement logic of isTrustedProxy (#934) --- pkg/app/context.go | 54 +++++++++++++++++++------- pkg/app/context_test.go | 86 ++++++++++++++++++++++++----------------- 2 files changed, 90 insertions(+), 50 deletions(-) diff --git a/pkg/app/context.go b/pkg/app/context.go index 559b84d03..9a6e45440 100644 --- a/pkg/app/context.go +++ b/pkg/app/context.go @@ -80,48 +80,72 @@ type ClientIP func(ctx *RequestContext) string type ClientIPOptions struct { RemoteIPHeaders []string - TrustedProxies map[string]bool + TrustedCIDRs []*net.IPNet } -var defaultClientIPOptions = ClientIPOptions{ - RemoteIPHeaders: []string{"X-Real-IP", "X-Forwarded-For"}, - TrustedProxies: map[string]bool{ - "0.0.0.0": true, +var defaultTrustedCIDRs = []*net.IPNet{ + { // 0.0.0.0/0 (IPv4) + IP: net.IP{0x0, 0x0, 0x0, 0x0}, + Mask: net.IPMask{0x0, 0x0, 0x0, 0x0}, + }, + { // ::/0 (IPv6) + IP: net.IP{0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}, + Mask: net.IPMask{0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0}, }, } +var defaultClientIPOptions = ClientIPOptions{ + RemoteIPHeaders: []string{"X-Forwarded-For", "X-Real-IP"}, + TrustedCIDRs: defaultTrustedCIDRs, +} + // ClientIPWithOption used to generate custom ClientIP function and set by engine.SetClientIPFunc func ClientIPWithOption(opts ClientIPOptions) ClientIP { return func(ctx *RequestContext) string { RemoteIPHeaders := opts.RemoteIPHeaders - TrustedProxies := opts.TrustedProxies + TrustedCIDRs := opts.TrustedCIDRs - remoteIP, _, err := net.SplitHostPort(strings.TrimSpace(ctx.RemoteAddr().String())) + remoteIPStr, _, err := net.SplitHostPort(strings.TrimSpace(ctx.RemoteAddr().String())) if err != nil { return "" } - trusted := isTrustedProxy(TrustedProxies, remoteIP) + + remoteIP := net.ParseIP(remoteIPStr) + if remoteIP == nil { + return "" + } + + trusted := isTrustedProxy(TrustedCIDRs, remoteIP) if trusted { for _, headerName := range RemoteIPHeaders { - ip, valid := validateHeader(TrustedProxies, ctx.Request.Header.Get(headerName)) + ip, valid := validateHeader(TrustedCIDRs, ctx.Request.Header.Get(headerName)) if valid { return ip } } } - return remoteIP + return remoteIPStr } } -// isTrustedProxy will check whether the IP address is included in the trusted list according to TrustedProxies -func isTrustedProxy(trustedProxies map[string]bool, remoteIP string) bool { - return trustedProxies[remoteIP] +// isTrustedProxy will check whether the IP address is included in the trusted list according to trustedCIDRs +func isTrustedProxy(trustedCIDRs []*net.IPNet, remoteIP net.IP) bool { + if trustedCIDRs == nil { + return false + } + + for _, cidr := range trustedCIDRs { + if cidr.Contains(remoteIP) { + return true + } + } + return false } // validateHeader will parse X-Real-IP and X-Forwarded-For header and return the Initial client IP address or an untrusted IP address -func validateHeader(trustedProxies map[string]bool, header string) (clientIP string, valid bool) { +func validateHeader(trustedCIDRs []*net.IPNet, header string) (clientIP string, valid bool) { if header == "" { return "", false } @@ -135,7 +159,7 @@ func validateHeader(trustedProxies map[string]bool, header string) (clientIP str // X-Forwarded-For is appended by proxy // Check IPs in reverse order and stop when find untrusted proxy - if (i == 0) || (!isTrustedProxy(trustedProxies, ipStr)) { + if (i == 0) || (!isTrustedProxy(trustedCIDRs, ip)) { return ipStr, true } } diff --git a/pkg/app/context_test.go b/pkg/app/context_test.go index 7211e3910..88cd26338 100644 --- a/pkg/app/context_test.go +++ b/pkg/app/context_test.go @@ -24,6 +24,7 @@ import ( "fmt" "html/template" "io/ioutil" + "net" "os" "reflect" "strings" @@ -802,51 +803,66 @@ func TestContextContentType(t *testing.T) { assert.DeepEqual(t, consts.MIMEApplicationJSONUTF8, bytesconv.B2s(c.ContentType())) } -func TestClientIp(t *testing.T) { +type MockIpConn struct { + *mock.Conn + RemoteIp string + Port int +} + +func (c *MockIpConn) RemoteAddr() net.Addr { + return &net.UDPAddr{ + IP: net.ParseIP(c.RemoteIp), + Port: c.Port, + } +} + +func newContextClientIPTest() *RequestContext { c := NewContext(0) - c.conn = mock.NewConn("") - // 0.0.0.0 simulates a trusted proxy server - c.Request.Header.Set("X-Forwarded-For", " 126.0.0.2, 0.0.0.0 ") - val := c.ClientIP() - if val != "126.0.0.2" { - t.Fatalf("unexpected %v. Expecting %v", val, "126.0.0.2") - } - // no proxy server - c = NewContext(0) - c.conn = mock.NewConn("") - c.Request.Header.Set("X-Real-Ip", "126.0.0.1") - val = c.ClientIP() - if val != "126.0.0.1" { - t.Fatalf("unexpected %v. Expecting %v", val, "126.0.0.1") - } - // custom RemoteIPHeaders and TrustedProxies + c.conn = &MockIpConn{ + Conn: mock.NewConn(""), + RemoteIp: "127.0.0.1", + Port: 8080, + } + c.Request.Header.Set("X-Real-IP", " 10.10.10.10 ") + c.Request.Header.Set("X-Forwarded-For", " 20.20.20.20, 30.30.30.30") + return c +} + +func TestClientIp(t *testing.T) { + c := newContextClientIPTest() + // default X-Forwarded-For and X-Real-IP behaviour + assert.DeepEqual(t, "20.20.20.20", c.ClientIP()) + + c.Request.Header.DelBytes([]byte("X-Forwarded-For")) + assert.DeepEqual(t, "10.10.10.10", c.ClientIP()) + + c.Request.Header.Set("X-Forwarded-For", "30.30.30.30 ") + assert.DeepEqual(t, "30.30.30.30", c.ClientIP()) + + // No trusted CIDRS + c = newContextClientIPTest() opts := ClientIPOptions{ RemoteIPHeaders: []string{"X-Forwarded-For", "X-Real-IP"}, - TrustedProxies: map[string]bool{ - "0.0.0.0": true, - }, + TrustedCIDRs: nil, } - c = NewContext(0) c.SetClientIPFunc(ClientIPWithOption(opts)) - c.conn = mock.NewConn("") - c.Request.Header.Set("X-Forwarded-For", " 126.0.0.2, 0.0.0.0 ") - val = c.ClientIP() - if val != "126.0.0.2" { - t.Fatalf("unexpected %v. Expecting %v", val, "126.0.0.2") - } - // no trusted proxy server + assert.DeepEqual(t, "127.0.0.1", c.ClientIP()) + + _, cidr, _ := net.ParseCIDR("30.30.30.30/32") opts = ClientIPOptions{ RemoteIPHeaders: []string{"X-Forwarded-For", "X-Real-IP"}, - TrustedProxies: nil, + TrustedCIDRs: []*net.IPNet{cidr}, } - c = NewContext(0) c.SetClientIPFunc(ClientIPWithOption(opts)) - c.conn = mock.NewConn("") - c.Request.Header.Set("X-Forwarded-For", " 126.0.0.2, 0.0.0.0 ") - val = c.ClientIP() - if val != "0.0.0.0" { - t.Fatalf("unexpected %v. Expecting %v", val, "0.0.0.0") + assert.DeepEqual(t, "127.0.0.1", c.ClientIP()) + + _, cidr, _ = net.ParseCIDR("127.0.0.1/32") + opts = ClientIPOptions{ + RemoteIPHeaders: []string{"X-Forwarded-For", "X-Real-IP"}, + TrustedCIDRs: []*net.IPNet{cidr}, } + c.SetClientIPFunc(ClientIPWithOption(opts)) + assert.DeepEqual(t, "30.30.30.30", c.ClientIP()) } func TestSetClientIPFunc(t *testing.T) { From a91d4dd2b6954430ff859af863a24c2cd4003a97 Mon Sep 17 00:00:00 2001 From: GuangyuFan <97507466+FGYFFFF@users.noreply.github.com> Date: Thu, 21 Sep 2023 14:35:00 +0800 Subject: [PATCH 11/20] fix(hz): client default tag error (#944) --- cmd/hz/protobuf/ast.go | 2 +- cmd/hz/thrift/ast.go | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/cmd/hz/protobuf/ast.go b/cmd/hz/protobuf/ast.go index 264dc4e30..d607a8394 100644 --- a/cmd/hz/protobuf/ast.go +++ b/cmd/hz/protobuf/ast.go @@ -354,7 +354,7 @@ func parseAnnotationToClient(clientMethod *generator.ClientMethod, gen *protogen clientMethod.FormFileCode += fmt.Sprintf("%q: req.Get%s(),\n", val, f.GoName) } if !hasAnnotation && strings.EqualFold(clientMethod.HTTPMethod, "get") { - clientMethod.QueryParamsCode += fmt.Sprintf("%q: req.Get%s(),\n", checkSnakeName(f.GoName), f.GoName) + clientMethod.QueryParamsCode += fmt.Sprintf("%q: req.Get%s(),\n", checkSnakeName(string(f.Desc.Name())), f.GoName) } } clientMethod.BodyParamsCode = meta.SetBodyParam diff --git a/cmd/hz/thrift/ast.go b/cmd/hz/thrift/ast.go index f634d9605..83722358f 100644 --- a/cmd/hz/thrift/ast.go +++ b/cmd/hz/thrift/ast.go @@ -312,7 +312,7 @@ func parseAnnotationToClient(clientMethod *generator.ClientMethod, p *parser.Typ clientMethod.FormFileCode += fmt.Sprintf("%q: req.Get%s(),\n", fileName, field.GoName().String()) } if !hasAnnotation && strings.EqualFold(clientMethod.HTTPMethod, "get") { - clientMethod.QueryParamsCode += fmt.Sprintf("%q: req.Get%s(),\n", checkSnakeName(field.GoName().String()), field.GoName().String()) + clientMethod.QueryParamsCode += fmt.Sprintf("%q: req.Get%s(),\n", checkSnakeName(field.GetName()), field.GoName().String()) } } clientMethod.BodyParamsCode = meta.SetBodyParam From b62c1c6630f3b55af07c7f94773ef74a7b227737 Mon Sep 17 00:00:00 2001 From: GuangyuFan <97507466+FGYFFFF@users.noreply.github.com> Date: Thu, 21 Sep 2023 15:44:05 +0800 Subject: [PATCH 12/20] feat: add a public method to update engine status (#889) --- pkg/route/engine.go | 13 +++++++++++-- pkg/route/engine_test.go | 8 ++++++++ 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/pkg/route/engine.go b/pkg/route/engine.go index 10cffac7a..480e09258 100644 --- a/pkg/route/engine.go +++ b/pkg/route/engine.go @@ -341,8 +341,8 @@ func (engine *Engine) Run() (err error) { return err } - if !atomic.CompareAndSwapUint32(&engine.status, statusInitialized, statusRunning) { - return errAlreadyRunning + if err = engine.MarkAsRunning(); err != nil { + return err } defer atomic.StoreUint32(&engine.status, statusClosed) @@ -1022,3 +1022,12 @@ func versionToALNP(v uint32) string { } return "" } + +// MarkAsRunning will mark the status of the hertz engine as "running". +// Warning: do not call this method by yourself, unless you know what you are doing. +func (engine *Engine) MarkAsRunning() (err error) { + if !atomic.CompareAndSwapUint32(&engine.status, statusInitialized, statusRunning) { + return errAlreadyRunning + } + return nil +} diff --git a/pkg/route/engine_test.go b/pkg/route/engine_test.go index fe528b457..350da8177 100644 --- a/pkg/route/engine_test.go +++ b/pkg/route/engine_test.go @@ -502,6 +502,14 @@ func TestRenderHtmlOfFilesWithAutoRender(t *testing.T) { assert.DeepEqual(t, "text/html; charset=utf-8", rr.Header().Get("Content-Type")) } +func TestSetEngineRun(t *testing.T) { + e := NewEngine(config.NewOptions(nil)) + e.Init() + assert.True(t, !e.IsRunning()) + e.MarkAsRunning() + assert.True(t, e.IsRunning()) +} + type mockConn struct{} func (m *mockConn) SetWriteTimeout(t time.Duration) error { From 59196eb9da0c257a18345091093467a073311efc Mon Sep 17 00:00:00 2001 From: Wenju Gao Date: Thu, 21 Sep 2023 16:12:52 +0800 Subject: [PATCH 13/20] feat: add config hook for host client (#938) --- pkg/app/client/client.go | 10 ++++++++ pkg/app/client/client_test.go | 38 ++++++++++++++++++++++++++++++ pkg/app/client/option.go | 7 ++++++ pkg/common/config/client_option.go | 4 ++++ 4 files changed, 59 insertions(+) diff --git a/pkg/app/client/client.go b/pkg/app/client/client.go index 33f9fb539..614b5b5c7 100644 --- a/pkg/app/client/client.go +++ b/pkg/app/client/client.go @@ -513,6 +513,16 @@ func (c *Client) do(ctx context.Context, req *protocol.Request, resp *protocol.R ProxyURI: proxyURI, IsTLS: isTLS, }) + + // re-configure hook + if c.options.HostClientConfigHook != nil { + err = c.options.HostClientConfigHook(hc) + if err != nil { + c.mLock.Unlock() + return err + } + } + m[h] = hc if len(m) == 1 { startCleaner = true diff --git a/pkg/app/client/client_test.go b/pkg/app/client/client_test.go index e7dcbf506..d231fac54 100644 --- a/pkg/app/client/client_test.go +++ b/pkg/app/client/client_test.go @@ -2151,6 +2151,44 @@ func TestClientRetry(t *testing.T) { } } +func TestClientHostClientConfigHookError(t *testing.T) { + client, _ := NewClient(WithHostClientConfigHook(func(hc interface{}) error { + hct, ok := hc.(*http1.HostClient) + assert.True(t, ok) + assert.DeepEqual(t, "foo.bar:80", hct.Addr) + return errors.New("hook return") + })) + + req := protocol.AcquireRequest() + req.SetMethod(consts.MethodGet) + req.SetRequestURI("http://foo.bar/") + resp := protocol.AcquireResponse() + err := client.do(nil, req, resp) + assert.DeepEqual(t, "hook return", err.Error()) +} + +func TestClientHostClientConfigHook(t *testing.T) { + client, _ := NewClient(WithHostClientConfigHook(func(hc interface{}) error { + hct, ok := hc.(*http1.HostClient) + assert.True(t, ok) + assert.DeepEqual(t, "foo.bar:80", hct.Addr) + hct.Addr = "FOO.BAR:443" + return nil + })) + + req := protocol.AcquireRequest() + req.SetMethod(consts.MethodGet) + req.SetRequestURI("http://foo.bar/") + resp := protocol.AcquireResponse() + client.do(context.Background(), req, resp) + client.mLock.Lock() + hc := client.m["foo.bar"] + client.mLock.Unlock() + hcr, ok := hc.(*http1.HostClient) + assert.True(t, ok) + assert.DeepEqual(t, "FOO.BAR:443", hcr.Addr) +} + func TestClientDialerName(t *testing.T) { client, _ := NewClient() dName, err := client.GetDialerName() diff --git a/pkg/app/client/option.go b/pkg/app/client/option.go index 3e03c4431..c65cc76c6 100644 --- a/pkg/app/client/option.go +++ b/pkg/app/client/option.go @@ -99,6 +99,13 @@ func WithResponseBodyStream(b bool) config.ClientOption { }} } +// WithHostClientConfigHook is used to set the function hook for re-configure the host client. +func WithHostClientConfigHook(h func(hc interface{}) error) config.ClientOption { + return config.ClientOption{F: func(o *config.ClientOptions) { + o.HostClientConfigHook = h + }} +} + // WithDisableHeaderNamesNormalizing is used to set whether disable header names normalizing. func WithDisableHeaderNamesNormalizing(disable bool) config.ClientOption { return config.ClientOption{F: func(o *config.ClientOptions) { diff --git a/pkg/common/config/client_option.go b/pkg/common/config/client_option.go index 45133838a..8abc15b0d 100644 --- a/pkg/common/config/client_option.go +++ b/pkg/common/config/client_option.go @@ -128,6 +128,10 @@ type ClientOptions struct { // StateObserve execution interval ObservationInterval time.Duration + + // Callback hook for re-configuring host client + // If an error is returned, the request will be terminated. + HostClientConfigHook func(hc interface{}) error } func NewClientOptions(opts []ClientOption) *ClientOptions { From 02112cc7697a89bb3ee15ca6ed829b4f456be4da Mon Sep 17 00:00:00 2001 From: Wenju Gao Date: Thu, 21 Sep 2023 20:32:50 +0800 Subject: [PATCH 14/20] feat: add disable normalizing (#940) Co-authored-by: kinggo --- pkg/app/client/client_test.go | 2 +- pkg/app/server/hertz_test.go | 36 ++++++++++++++++++++++++++++++++ pkg/app/server/option.go | 7 +++++++ pkg/app/server/option_test.go | 3 +++ pkg/common/config/option.go | 19 +++++++++++++++++ pkg/common/config/option_test.go | 1 + pkg/protocol/header.go | 2 ++ pkg/protocol/http1/server.go | 35 ++++++++++++++++++------------- pkg/route/engine.go | 29 ++++++++++++------------- 9 files changed, 105 insertions(+), 29 deletions(-) diff --git a/pkg/app/client/client_test.go b/pkg/app/client/client_test.go index d231fac54..5f3e8c62f 100644 --- a/pkg/app/client/client_test.go +++ b/pkg/app/client/client_test.go @@ -2163,7 +2163,7 @@ func TestClientHostClientConfigHookError(t *testing.T) { req.SetMethod(consts.MethodGet) req.SetRequestURI("http://foo.bar/") resp := protocol.AcquireResponse() - err := client.do(nil, req, resp) + err := client.do(context.TODO(), req, resp) assert.DeepEqual(t, "hook return", err.Error()) } diff --git a/pkg/app/server/hertz_test.go b/pkg/app/server/hertz_test.go index 3ce91a83f..a5cf7d350 100644 --- a/pkg/app/server/hertz_test.go +++ b/pkg/app/server/hertz_test.go @@ -784,3 +784,39 @@ func TestSilentMode(t *testing.T) { t.Fatalf("unexpected error in log: %s", b.String()) } } + +func TestHertzDisableHeaderNamesNormalizing(t *testing.T) { + h := New( + WithHostPorts("localhost:9212"), + WithDisableHeaderNamesNormalizing(true), + ) + headerName := "CASE-senSITive-HEAder-NAME" + headerValue := "foobar-baz" + succeed := false + h.GET("/test", func(c context.Context, ctx *app.RequestContext) { + ctx.VisitAllHeaders(func(key, value []byte) { + if string(key) == headerName && string(value) == headerValue { + succeed = true + return + } + }) + if !succeed { + t.Fatalf("DisableHeaderNamesNormalizing failed") + } else { + ctx.Header(headerName, headerValue) + } + }) + + go h.Spin() + time.Sleep(100 * time.Millisecond) + + cli, _ := c.NewClient(c.WithDisableHeaderNamesNormalizing(true)) + + r := protocol.NewRequest("GET", "http://localhost:9212/test", nil) + r.Header.DisableNormalizing() + r.Header.Set(headerName, headerValue) + res := protocol.AcquireResponse() + err := cli.Do(context.Background(), r, res) + assert.Nil(t, err) + assert.DeepEqual(t, headerValue, res.Header.Get(headerName)) +} diff --git a/pkg/app/server/option.go b/pkg/app/server/option.go index bf4482904..c9e3735be 100644 --- a/pkg/app/server/option.go +++ b/pkg/app/server/option.go @@ -346,3 +346,10 @@ func WithOnConnect(fn func(ctx context.Context, conn network.Conn) context.Conte o.OnConnect = fn }} } + +// WithDisableHeaderNamesNormalizing is used to set whether disable header names normalizing. +func WithDisableHeaderNamesNormalizing(disable bool) config.Option { + return config.Option{F: func(o *config.Options) { + o.DisableHeaderNamesNormalizing = disable + }} +} diff --git a/pkg/app/server/option_test.go b/pkg/app/server/option_test.go index aef554c14..f5d7f7b32 100644 --- a/pkg/app/server/option_test.go +++ b/pkg/app/server/option_test.go @@ -75,6 +75,7 @@ func TestOptions(t *testing.T) { WithAutoReloadRender(true, 5*time.Second), WithListenConfig(cfg), WithAltTransport(transporter), + WithDisableHeaderNamesNormalizing(true), }) assert.DeepEqual(t, opt.ReadTimeout, time.Second) assert.DeepEqual(t, opt.WriteTimeout, time.Second) @@ -107,6 +108,7 @@ func TestOptions(t *testing.T) { assert.DeepEqual(t, opt.AutoReloadInterval, 5*time.Second) assert.DeepEqual(t, opt.ListenConfig, cfg) assert.Assert(t, reflect.TypeOf(opt.AltTransporterNewer) == reflect.TypeOf(transporter)) + assert.DeepEqual(t, opt.DisableHeaderNamesNormalizing, true) } func TestDefaultOptions(t *testing.T) { @@ -139,6 +141,7 @@ func TestDefaultOptions(t *testing.T) { assert.Assert(t, opt.RegistryInfo == nil) assert.DeepEqual(t, opt.AutoReloadRender, false) assert.DeepEqual(t, opt.AutoReloadInterval, time.Duration(0)) + assert.DeepEqual(t, opt.DisableHeaderNamesNormalizing, false) } type mockTransporter struct{} diff --git a/pkg/common/config/option.go b/pkg/common/config/option.go index 03c84a02a..9ef7ddf42 100644 --- a/pkg/common/config/option.go +++ b/pkg/common/config/option.go @@ -97,6 +97,22 @@ type Options struct { // The HTML template will reload according to files' changing event // otherwise it will reload after AutoReloadInterval. AutoReloadInterval time.Duration + + // Header names are passed as-is without normalization + // if this option is set. + // + // Disabled header names' normalization may be useful only for proxying + // responses to other clients expecting case-sensitive header names. + // + // By default, request and response header names are normalized, i.e. + // The first letter and the first letters following dashes + // are uppercased, while all the other letters are lowercased. + // Examples: + // + // * HOST -> Host + // * content-type -> Content-Type + // * cONTENT-lenGTH -> Content-Length + DisableHeaderNamesNormalizing bool } func (o *Options) Apply(opts []Option) { @@ -225,6 +241,9 @@ func NewOptions(opts []Option) *Options { TraceLevel: new(interface{}), Registry: registry.NoopRegistry, + + // Disabled header names' normalization, default false + DisableHeaderNamesNormalizing: false, } options.Apply(opts) return options diff --git a/pkg/common/config/option_test.go b/pkg/common/config/option_test.go index 49ea661f3..39d92d736 100644 --- a/pkg/common/config/option_test.go +++ b/pkg/common/config/option_test.go @@ -53,6 +53,7 @@ func TestDefaultOptions(t *testing.T) { assert.DeepEqual(t, []interface{}{}, options.Tracers) assert.DeepEqual(t, new(interface{}), options.TraceLevel) assert.DeepEqual(t, registry.NoopRegistry, options.Registry) + assert.DeepEqual(t, false, options.DisableHeaderNamesNormalizing) } // TestApplyCustomOptions test apply options with custom values after init diff --git a/pkg/protocol/header.go b/pkg/protocol/header.go index f7533ef2f..14df744ef 100644 --- a/pkg/protocol/header.go +++ b/pkg/protocol/header.go @@ -1497,6 +1497,7 @@ func (h *RequestHeader) UserAgent() []byte { // Disable header names' normalization only if you know what are you doing. func (h *RequestHeader) DisableNormalizing() { h.disableNormalizing = true + h.Trailer().DisableNormalizing() } func (h *RequestHeader) IsDisableNormalizing() bool { @@ -1697,6 +1698,7 @@ func (h *RequestHeader) SetMethodBytes(method []byte) { // Disable header names' normalization only if you know what are you doing. func (h *ResponseHeader) DisableNormalizing() { h.disableNormalizing = true + h.Trailer().DisableNormalizing() } // setSpecialHeader handles special headers and return true when a header is processed. diff --git a/pkg/protocol/http1/server.go b/pkg/protocol/http1/server.go index 77cc74e18..66cb01e85 100644 --- a/pkg/protocol/http1/server.go +++ b/pkg/protocol/http1/server.go @@ -54,20 +54,21 @@ var ( ) type Option struct { - StreamRequestBody bool - GetOnly bool - DisablePreParseMultipartForm bool - DisableKeepalive bool - NoDefaultServerHeader bool - MaxRequestBodySize int - IdleTimeout time.Duration - ReadTimeout time.Duration - ServerName []byte - TLS *tls.Config - HTMLRender render.HTMLRender - EnableTrace bool - ContinueHandler func(header *protocol.RequestHeader) bool - HijackConnHandle func(c network.Conn, h app.HijackHandler) + StreamRequestBody bool + GetOnly bool + DisablePreParseMultipartForm bool + DisableKeepalive bool + NoDefaultServerHeader bool + DisableHeaderNamesNormalizing bool + MaxRequestBodySize int + IdleTimeout time.Duration + ReadTimeout time.Duration + ServerName []byte + TLS *tls.Config + HTMLRender render.HTMLRender + EnableTrace bool + ContinueHandler func(header *protocol.RequestHeader) bool + HijackConnHandle func(c network.Conn, h app.HijackHandler) } type Server struct { @@ -179,6 +180,12 @@ func (s Server) Serve(c context.Context, conn network.Conn) (err error) { internalStats.Record(ti, stats.ReadHeaderFinish, err) }) } + + if s.DisableHeaderNamesNormalizing { + ctx.Request.Header.DisableNormalizing() + ctx.Response.Header.DisableNormalizing() + } + // Read Headers if err = req.ReadHeader(&ctx.Request.Header, zr); err == nil { if s.EnableTrace { diff --git a/pkg/route/engine.go b/pkg/route/engine.go index 480e09258..bd8fbc1a9 100644 --- a/pkg/route/engine.go +++ b/pkg/route/engine.go @@ -990,20 +990,21 @@ func iterate(method string, routes RoutesInfo, root *node) RoutesInfo { // for built-in http1 impl only. func newHttp1OptionFromEngine(engine *Engine) *http1.Option { opt := &http1.Option{ - StreamRequestBody: engine.options.StreamRequestBody, - GetOnly: engine.options.GetOnly, - DisablePreParseMultipartForm: engine.options.DisablePreParseMultipartForm, - DisableKeepalive: engine.options.DisableKeepalive, - NoDefaultServerHeader: engine.options.NoDefaultServerHeader, - MaxRequestBodySize: engine.options.MaxRequestBodySize, - IdleTimeout: engine.options.IdleTimeout, - ReadTimeout: engine.options.ReadTimeout, - ServerName: engine.GetServerName(), - ContinueHandler: engine.ContinueHandler, - TLS: engine.options.TLS, - HTMLRender: engine.htmlRender, - EnableTrace: engine.IsTraceEnable(), - HijackConnHandle: engine.HijackConnHandle, + StreamRequestBody: engine.options.StreamRequestBody, + GetOnly: engine.options.GetOnly, + DisablePreParseMultipartForm: engine.options.DisablePreParseMultipartForm, + DisableKeepalive: engine.options.DisableKeepalive, + NoDefaultServerHeader: engine.options.NoDefaultServerHeader, + MaxRequestBodySize: engine.options.MaxRequestBodySize, + IdleTimeout: engine.options.IdleTimeout, + ReadTimeout: engine.options.ReadTimeout, + ServerName: engine.GetServerName(), + ContinueHandler: engine.ContinueHandler, + TLS: engine.options.TLS, + HTMLRender: engine.htmlRender, + EnableTrace: engine.IsTraceEnable(), + HijackConnHandle: engine.HijackConnHandle, + DisableHeaderNamesNormalizing: engine.options.DisableHeaderNamesNormalizing, } // Idle timeout of standard network must not be zero. Set it to -1 seconds if it is zero. // Due to the different triggering ways of the network library, see the actual use of this value for the detailed reasons. From 8904972b8bd419b23c43a9c3fa28ba611f759225 Mon Sep 17 00:00:00 2001 From: GuangyuFan <97507466+FGYFFFF@users.noreply.github.com> Date: Fri, 22 Sep 2023 14:49:12 +0800 Subject: [PATCH 15/20] refactor: binding (#541) --- .github/workflows/pr-check.yml | 2 +- go.mod | 2 +- go.sum | 4 +- pkg/app/context.go | 93 +- pkg/app/context_test.go | 118 ++ pkg/app/server/binding/binder.go | 58 + pkg/app/server/binding/binder_test.go | 1479 +++++++++++++++++ pkg/app/server/binding/binding.go | 122 -- pkg/app/server/binding/binding_test.go | 450 ----- pkg/app/server/binding/config.go | 170 ++ pkg/app/server/binding/default.go | 410 +++++ .../internal/decoder/base_type_decoder.go | 181 ++ .../decoder/customized_type_decoder.go | 141 ++ .../binding/internal/decoder/decoder.go | 191 +++ .../server/binding/internal/decoder/getter.go | 134 ++ .../internal/decoder/gjson_required.go | 49 + .../internal/decoder/map_type_decoder.go | 165 ++ .../decoder/multipart_file_decoder.go | 165 ++ .../binding/internal/decoder/reflect.go | 113 ++ .../binding/internal/decoder/slice_getter.go | 143 ++ .../internal/decoder/slice_type_decoder.go | 250 +++ .../internal/decoder/sonic_required.go | 62 + .../internal/decoder/struct_type_decoder.go | 142 ++ .../server/binding/internal/decoder/tag.go | 164 ++ .../binding/internal/decoder/text_decoder.go | 169 ++ pkg/app/server/binding/reflect.go | 73 + .../server/binding/reflect_internal_test.go | 90 + pkg/app/server/binding/reflect_test.go | 87 + pkg/app/server/binding/request.go | 138 -- pkg/app/server/binding/request_test.go | 235 --- pkg/app/server/binding/tagexpr_bind_test.go | 1281 ++++++++++++++ pkg/app/server/binding/testdata/hello.pb.go | 157 ++ pkg/app/server/binding/testdata/hello.proto | 24 + pkg/app/server/binding/validator.go | 46 + pkg/app/server/binding/validator_test.go | 35 + pkg/app/server/hertz_test.go | 177 +- pkg/app/server/option.go | 22 + pkg/common/config/option.go | 3 + pkg/common/config/option_test.go | 3 + pkg/common/utils/utils.go | 9 + pkg/common/utils/utils_test.go | 6 + pkg/protocol/consts/headers.go | 2 + pkg/route/engine.go | 43 + pkg/route/engine_test.go | 181 ++ 44 files changed, 6633 insertions(+), 956 deletions(-) create mode 100644 pkg/app/server/binding/binder.go create mode 100644 pkg/app/server/binding/binder_test.go delete mode 100644 pkg/app/server/binding/binding.go delete mode 100644 pkg/app/server/binding/binding_test.go create mode 100644 pkg/app/server/binding/config.go create mode 100644 pkg/app/server/binding/default.go create mode 100644 pkg/app/server/binding/internal/decoder/base_type_decoder.go create mode 100644 pkg/app/server/binding/internal/decoder/customized_type_decoder.go create mode 100644 pkg/app/server/binding/internal/decoder/decoder.go create mode 100644 pkg/app/server/binding/internal/decoder/getter.go create mode 100644 pkg/app/server/binding/internal/decoder/gjson_required.go create mode 100644 pkg/app/server/binding/internal/decoder/map_type_decoder.go create mode 100644 pkg/app/server/binding/internal/decoder/multipart_file_decoder.go create mode 100644 pkg/app/server/binding/internal/decoder/reflect.go create mode 100644 pkg/app/server/binding/internal/decoder/slice_getter.go create mode 100644 pkg/app/server/binding/internal/decoder/slice_type_decoder.go create mode 100644 pkg/app/server/binding/internal/decoder/sonic_required.go create mode 100644 pkg/app/server/binding/internal/decoder/struct_type_decoder.go create mode 100644 pkg/app/server/binding/internal/decoder/tag.go create mode 100644 pkg/app/server/binding/internal/decoder/text_decoder.go create mode 100644 pkg/app/server/binding/reflect.go create mode 100644 pkg/app/server/binding/reflect_internal_test.go create mode 100644 pkg/app/server/binding/reflect_test.go delete mode 100644 pkg/app/server/binding/request.go delete mode 100644 pkg/app/server/binding/request_test.go create mode 100644 pkg/app/server/binding/tagexpr_bind_test.go create mode 100644 pkg/app/server/binding/testdata/hello.pb.go create mode 100644 pkg/app/server/binding/testdata/hello.proto create mode 100644 pkg/app/server/binding/validator.go create mode 100644 pkg/app/server/binding/validator_test.go diff --git a/.github/workflows/pr-check.yml b/.github/workflows/pr-check.yml index 683bd08c3..dba5baaf3 100644 --- a/.github/workflows/pr-check.yml +++ b/.github/workflows/pr-check.yml @@ -34,4 +34,4 @@ jobs: # Exit with 1 when it find at least one finding. fail_on_error: true # Set staticcheck flags - staticcheck_flags: -checks=inherit,-SA1029 + staticcheck_flags: -checks=inherit,-SA1029,-SA5008 diff --git a/go.mod b/go.mod index b5eeb5251..25f20f8a7 100644 --- a/go.mod +++ b/go.mod @@ -9,7 +9,7 @@ require ( github.com/bytedance/sonic v1.8.1 github.com/cloudwego/netpoll v0.4.2-0.20230807055039-52fd5fb7b00f github.com/fsnotify/fsnotify v1.5.4 - github.com/tidwall/gjson v1.13.0 // indirect + github.com/tidwall/gjson v1.14.4 golang.org/x/sync v0.0.0-20210220032951-036812b2e83c golang.org/x/sys v0.0.0-20220412211240-33da011f77ad google.golang.org/protobuf v1.27.1 diff --git a/go.sum b/go.sum index 59e21cf1c..f86006603 100644 --- a/go.sum +++ b/go.sum @@ -52,8 +52,8 @@ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/tidwall/gjson v1.9.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= -github.com/tidwall/gjson v1.13.0 h1:3TFY9yxOQShrvmjdM76K+jc66zJeT6D3/VFFYCGQf7M= -github.com/tidwall/gjson v1.13.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.14.4 h1:uo0p8EbA09J7RQaflQ1aBRffTR7xedD2bcIVSYxLnkM= +github.com/tidwall/gjson v1.14.4/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= diff --git a/pkg/app/context.go b/pkg/app/context.go index 9a6e45440..c607ce73b 100644 --- a/pkg/app/context.go +++ b/pkg/app/context.go @@ -233,6 +233,9 @@ type RequestContext struct { // clientIPFunc get form value by use custom function. formValueFunc FormValueFunc + + binder binding.Binder + validator binding.StructValidator } // Flush is the shortcut for ctx.Response.GetHijackWriter().Flush(). @@ -252,6 +255,14 @@ func (ctx *RequestContext) SetFormValueFunc(f FormValueFunc) { ctx.formValueFunc = f } +func (ctx *RequestContext) SetBinder(binder binding.Binder) { + ctx.binder = binder +} + +func (ctx *RequestContext) SetValidator(validator binding.StructValidator) { + ctx.validator = validator +} + func (ctx *RequestContext) GetTraceInfo() traceinfo.TraceInfo { return ctx.traceInfo } @@ -732,6 +743,10 @@ func (ctx *RequestContext) Copy() *RequestContext { paramCopy := make([]param.Param, len(cp.Params)) copy(paramCopy, cp.Params) cp.Params = paramCopy + cp.clientIPFunc = ctx.clientIPFunc + cp.formValueFunc = ctx.formValueFunc + cp.binder = ctx.binder + cp.validator = ctx.validator return cp } @@ -1302,22 +1317,94 @@ func bodyAllowedForStatus(status int) bool { return true } +func (ctx *RequestContext) getBinder() binding.Binder { + if ctx.binder != nil { + return ctx.binder + } + return binding.DefaultBinder() +} + +func (ctx *RequestContext) getValidator() binding.StructValidator { + if ctx.validator != nil { + return ctx.validator + } + return binding.DefaultValidator() +} + // BindAndValidate binds data from *RequestContext to obj and validates them if needed. // NOTE: obj should be a pointer. func (ctx *RequestContext) BindAndValidate(obj interface{}) error { - return binding.BindAndValidate(&ctx.Request, obj, ctx.Params) + return ctx.getBinder().BindAndValidate(&ctx.Request, obj, ctx.Params) } // Bind binds data from *RequestContext to obj. // NOTE: obj should be a pointer. func (ctx *RequestContext) Bind(obj interface{}) error { - return binding.Bind(&ctx.Request, obj, ctx.Params) + return ctx.getBinder().Bind(&ctx.Request, obj, ctx.Params) } // Validate validates obj with "vd" tag // NOTE: obj should be a pointer. func (ctx *RequestContext) Validate(obj interface{}) error { - return binding.Validate(obj) + return ctx.getValidator().ValidateStruct(obj) +} + +// BindQuery binds query parameters from *RequestContext to obj with 'query' tag. It will only use 'query' tag for binding. +// NOTE: obj should be a pointer. +func (ctx *RequestContext) BindQuery(obj interface{}) error { + return ctx.getBinder().BindQuery(&ctx.Request, obj) +} + +// BindHeader binds header parameters from *RequestContext to obj with 'header' tag. It will only use 'header' tag for binding. +// NOTE: obj should be a pointer. +func (ctx *RequestContext) BindHeader(obj interface{}) error { + return ctx.getBinder().BindHeader(&ctx.Request, obj) +} + +// BindPath binds router parameters from *RequestContext to obj with 'path' tag. It will only use 'path' tag for binding. +// NOTE: obj should be a pointer. +func (ctx *RequestContext) BindPath(obj interface{}) error { + return ctx.getBinder().BindPath(&ctx.Request, obj, ctx.Params) +} + +// BindForm binds form parameters from *RequestContext to obj with 'form' tag. It will only use 'form' tag for binding. +// NOTE: obj should be a pointer. +func (ctx *RequestContext) BindForm(obj interface{}) error { + if len(ctx.Request.Body()) == 0 { + return fmt.Errorf("missing form body") + } + return ctx.getBinder().BindForm(&ctx.Request, obj) +} + +// BindJSON binds JSON body from *RequestContext. +// NOTE: obj should be a pointer. +func (ctx *RequestContext) BindJSON(obj interface{}) error { + return ctx.getBinder().BindJSON(&ctx.Request, obj) +} + +// BindProtobuf binds protobuf body from *RequestContext. +// NOTE: obj should be a pointer. +func (ctx *RequestContext) BindProtobuf(obj interface{}) error { + return ctx.getBinder().BindProtobuf(&ctx.Request, obj) +} + +// BindByContentType will select the binding type on the ContentType automatically. +// NOTE: obj should be a pointer. +func (ctx *RequestContext) BindByContentType(obj interface{}) error { + if ctx.Request.Header.IsGet() { + return ctx.BindQuery(obj) + } + ct := utils.FilterContentType(bytesconv.B2s(ctx.Request.Header.ContentType())) + switch ct { + case consts.MIMEApplicationJSON: + return ctx.BindJSON(obj) + case consts.MIMEPROTOBUF: + return ctx.BindProtobuf(obj) + case consts.MIMEApplicationHTMLForm, consts.MIMEMultipartPOSTForm: + return ctx.BindForm(obj) + default: + return fmt.Errorf("unsupported bind content-type for '%s'", ct) + } } // VisitAllQueryArgs calls f for each existing query arg. diff --git a/pkg/app/context_test.go b/pkg/app/context_test.go index 88cd26338..22e7f8608 100644 --- a/pkg/app/context_test.go +++ b/pkg/app/context_test.go @@ -33,6 +33,7 @@ import ( "github.com/cloudwego/hertz/internal/bytesconv" "github.com/cloudwego/hertz/internal/bytestr" + "github.com/cloudwego/hertz/pkg/app/server/binding" "github.com/cloudwego/hertz/pkg/app/server/render" errs "github.com/cloudwego/hertz/pkg/common/errors" "github.com/cloudwego/hertz/pkg/common/test/assert" @@ -873,6 +874,35 @@ func TestSetClientIPFunc(t *testing.T) { assert.DeepEqual(t, reflect.ValueOf(fn).Pointer(), reflect.ValueOf(defaultClientIP).Pointer()) } +type mockValidator struct{} + +func (m *mockValidator) ValidateStruct(interface{}) error { + return fmt.Errorf("test mock") +} + +func (m *mockValidator) Engine() interface{} { + return nil +} + +func TestSetValidator(t *testing.T) { + m := &mockValidator{} + c := NewContext(0) + c.SetValidator(m) + c.SetBinder(binding.NewDefaultBinder(&binding.BindConfig{ValidateTag: "vt"})) + type User struct { + Age int `vt:"$>=0&&$<=130"` + } + + user := &User{ + Age: 135, + } + err := c.Validate(user) + if err == nil { + t.Fatalf("expected an error, but got nil") + } + assert.DeepEqual(t, "test mock", err.Error()) +} + func TestGetQuery(t *testing.T) { c := NewContext(0) c.Request.SetRequestURI("http://aaa.com?a=1&b=") @@ -1457,6 +1487,94 @@ func TestBindAndValidate(t *testing.T) { } } +func TestBindForm(t *testing.T) { + type Test struct { + A string + B int + } + + c := &RequestContext{} + c.Request.SetRequestURI("/foo/bar?a=123&b=11") + c.Request.SetBody([]byte("A=123&B=11")) + c.Request.Header.SetContentTypeBytes([]byte("application/x-www-form-urlencoded")) + + var req Test + err := c.BindForm(&req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + assert.DeepEqual(t, "123", req.A) + assert.DeepEqual(t, 11, req.B) + + c.Request.SetBody([]byte("")) + err = c.BindForm(&req) + if err == nil { + t.Fatalf("expected error, but get nil") + } +} + +type mockBinder struct{} + +func (m *mockBinder) Name() string { + return "test binder" +} + +func (m *mockBinder) Bind(request *protocol.Request, i interface{}, params param.Params) error { + return nil +} + +func (m *mockBinder) BindAndValidate(request *protocol.Request, i interface{}, params param.Params) error { + return fmt.Errorf("test binder") +} + +func (m *mockBinder) BindQuery(request *protocol.Request, i interface{}) error { + return nil +} + +func (m *mockBinder) BindHeader(request *protocol.Request, i interface{}) error { + return nil +} + +func (m *mockBinder) BindPath(request *protocol.Request, i interface{}, params param.Params) error { + return nil +} + +func (m *mockBinder) BindForm(request *protocol.Request, i interface{}) error { + return nil +} + +func (m *mockBinder) BindJSON(request *protocol.Request, i interface{}) error { + return nil +} + +func (m *mockBinder) BindProtobuf(request *protocol.Request, i interface{}) error { + return nil +} + +func TestSetBinder(t *testing.T) { + c := NewContext(0) + c.SetBinder(&mockBinder{}) + type T struct{} + req := T{} + err := c.Bind(&req) + assert.Nil(t, err) + err = c.BindAndValidate(&req) + assert.NotNil(t, err) + assert.DeepEqual(t, "test binder", err.Error()) + err = c.BindProtobuf(&req) + assert.Nil(t, err) + err = c.BindJSON(&req) + assert.Nil(t, err) + err = c.BindForm(&req) + assert.NotNil(t, err) + err = c.BindPath(&req) + assert.Nil(t, err) + err = c.BindQuery(&req) + assert.Nil(t, err) + err = c.BindHeader(&req) + assert.Nil(t, err) +} + func TestRequestContext_SetCookie(t *testing.T) { c := NewContext(0) c.SetCookie("user", "hertz", 1, "/", "localhost", protocol.CookieSameSiteLaxMode, true, true) diff --git a/pkg/app/server/binding/binder.go b/pkg/app/server/binding/binder.go new file mode 100644 index 000000000..f97b80dbd --- /dev/null +++ b/pkg/app/server/binding/binder.go @@ -0,0 +1,58 @@ +/* + * Copyright 2023 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * MIT License + * + * Copyright (c) 2019-present Fenny and Contributors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + * This file may have been modified by CloudWeGo authors. All CloudWeGo + * Modifications are Copyright 2023 CloudWeGo Authors + */ + +package binding + +import ( + "github.com/cloudwego/hertz/pkg/protocol" + "github.com/cloudwego/hertz/pkg/route/param" +) + +type Binder interface { + Name() string + Bind(*protocol.Request, interface{}, param.Params) error + BindAndValidate(*protocol.Request, interface{}, param.Params) error + BindQuery(*protocol.Request, interface{}) error + BindHeader(*protocol.Request, interface{}) error + BindPath(*protocol.Request, interface{}, param.Params) error + BindForm(*protocol.Request, interface{}) error + BindJSON(*protocol.Request, interface{}) error + BindProtobuf(*protocol.Request, interface{}) error +} diff --git a/pkg/app/server/binding/binder_test.go b/pkg/app/server/binding/binder_test.go new file mode 100644 index 000000000..d106ed7ad --- /dev/null +++ b/pkg/app/server/binding/binder_test.go @@ -0,0 +1,1479 @@ +/* + * Copyright 2023 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * MIT License + * + * Copyright (c) 2019-present Fenny and Contributors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + * This file may have been modified by CloudWeGo authors. All CloudWeGo + * Modifications are Copyright 2023 CloudWeGo Authors + */ + +package binding + +import ( + "encoding/json" + "fmt" + "mime/multipart" + "net/url" + "reflect" + "testing" + + "github.com/cloudwego/hertz/pkg/app/server/binding/testdata" + "github.com/cloudwego/hertz/pkg/common/test/assert" + "github.com/cloudwego/hertz/pkg/protocol" + "github.com/cloudwego/hertz/pkg/protocol/consts" + req2 "github.com/cloudwego/hertz/pkg/protocol/http1/req" + "github.com/cloudwego/hertz/pkg/route/param" + "google.golang.org/protobuf/proto" +) + +type mockRequest struct { + Req *protocol.Request +} + +func newMockRequest() *mockRequest { + return &mockRequest{ + Req: &protocol.Request{}, + } +} + +func (m *mockRequest) SetRequestURI(uri string) *mockRequest { + m.Req.SetRequestURI(uri) + return m +} + +func (m *mockRequest) SetFile(param, fileName string) *mockRequest { + m.Req.SetFile(param, fileName) + return m +} + +func (m *mockRequest) SetHeader(key, value string) *mockRequest { + m.Req.Header.Set(key, value) + return m +} + +func (m *mockRequest) SetHeaders(key, value string) *mockRequest { + m.Req.Header.Set(key, value) + return m +} + +func (m *mockRequest) SetPostArg(key, value string) *mockRequest { + m.Req.PostArgs().Add(key, value) + return m +} + +func (m *mockRequest) SetUrlEncodeContentType() *mockRequest { + m.Req.Header.SetContentTypeBytes([]byte("application/x-www-form-urlencoded")) + return m +} + +func (m *mockRequest) SetJSONContentType() *mockRequest { + m.Req.Header.SetContentTypeBytes([]byte(consts.MIMEApplicationJSON)) + return m +} + +func (m *mockRequest) SetProtobufContentType() *mockRequest { + m.Req.Header.SetContentTypeBytes([]byte(consts.MIMEPROTOBUF)) + return m +} + +func (m *mockRequest) SetBody(data []byte) *mockRequest { + m.Req.SetBody(data) + m.Req.Header.SetContentLength(len(data)) + return m +} + +func TestBind_BaseType(t *testing.T) { + type Req struct { + Version int `path:"v"` + ID int `query:"id"` + Header string `header:"H"` + Form string `form:"f"` + } + + req := newMockRequest(). + SetRequestURI("http://foobar.com?id=12"). + SetHeaders("H", "header"). + SetPostArg("f", "form"). + SetUrlEncodeContentType() + var params param.Params + params = append(params, param.Param{ + Key: "v", + Value: "1", + }) + + var result Req + + err := DefaultBinder().Bind(req.Req, &result, params) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, 1, result.Version) + assert.DeepEqual(t, 12, result.ID) + assert.DeepEqual(t, "header", result.Header) + assert.DeepEqual(t, "form", result.Form) +} + +func TestBind_SliceType(t *testing.T) { + type Req struct { + ID *[]int `query:"id"` + Str [3]string `query:"str"` + Byte []byte `query:"b"` + } + IDs := []int{11, 12, 13} + Strs := [3]string{"qwe", "asd", "zxc"} + Bytes := []byte("123") + + req := newMockRequest(). + SetRequestURI(fmt.Sprintf("http://foobar.com?id=%d&id=%d&id=%d&str=%s&str=%s&str=%s&b=%d&b=%d&b=%d", IDs[0], IDs[1], IDs[2], Strs[0], Strs[1], Strs[2], Bytes[0], Bytes[1], Bytes[2])) + + var result Req + + err := DefaultBinder().Bind(req.Req, &result, nil) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, 3, len(*result.ID)) + for idx, val := range IDs { + assert.DeepEqual(t, val, (*result.ID)[idx]) + } + assert.DeepEqual(t, 3, len(result.Str)) + for idx, val := range Strs { + assert.DeepEqual(t, val, result.Str[idx]) + } + assert.DeepEqual(t, 3, len(result.Byte)) + for idx, val := range Bytes { + assert.DeepEqual(t, val, result.Byte[idx]) + } +} + +func TestBind_StructType(t *testing.T) { + type FFF struct { + F1 string `query:"F1"` + } + + type TTT struct { + T1 string `query:"F1"` + T2 FFF + } + + type Foo struct { + F1 string `query:"F1"` + F2 string `header:"f2"` + F3 TTT + } + + type Bar struct { + B1 string `query:"B1"` + B2 Foo `query:"B2"` + } + + var result Bar + + req := newMockRequest().SetRequestURI("http://foobar.com?F1=f1&B1=b1").SetHeader("f2", "f2") + + err := DefaultBinder().Bind(req.Req, &result, nil) + if err != nil { + t.Error(err) + } + + assert.DeepEqual(t, "b1", result.B1) + assert.DeepEqual(t, "f1", result.B2.F1) + assert.DeepEqual(t, "f2", result.B2.F2) + assert.DeepEqual(t, "f1", result.B2.F3.T1) + assert.DeepEqual(t, "f1", result.B2.F3.T2.F1) +} + +func TestBind_PointerType(t *testing.T) { + type TT struct { + T1 string `query:"F1"` + } + + type Foo struct { + F1 *TT `query:"F1"` + F2 *******************string `query:"F1"` + } + + type Bar struct { + B1 ***string `query:"B1"` + B2 ****Foo `query:"B2"` + B3 []*string `query:"B3"` + B4 [2]*int `query:"B4"` + } + + result := Bar{} + + F1 := "f1" + B1 := "b1" + B2 := "b2" + B3s := []string{"b31", "b32"} + B4s := [2]int{0, 1} + + req := newMockRequest().SetRequestURI(fmt.Sprintf("http://foobar.com?F1=%s&B1=%s&B2=%s&B3=%s&B3=%s&B4=%d&B4=%d", F1, B1, B2, B3s[0], B3s[1], B4s[0], B4s[1])). + SetHeader("f2", "f2") + + err := DefaultBinder().Bind(req.Req, &result, nil) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, B1, ***result.B1) + assert.DeepEqual(t, F1, (*(****result.B2).F1).T1) + assert.DeepEqual(t, F1, *******************(****result.B2).F2) + assert.DeepEqual(t, len(B3s), len(result.B3)) + for idx, val := range B3s { + assert.DeepEqual(t, val, *result.B3[idx]) + } + assert.DeepEqual(t, len(B4s), len(result.B4)) + for idx, val := range B4s { + assert.DeepEqual(t, val, *result.B4[idx]) + } +} + +func TestBind_NestedStruct(t *testing.T) { + type Foo struct { + F1 string `query:"F1"` + } + + type Bar struct { + Foo + Nested struct { + N1 string `query:"F1"` + } + } + + result := Bar{} + + req := newMockRequest().SetRequestURI("http://foobar.com?F1=qwe") + err := DefaultBinder().Bind(req.Req, &result, nil) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, "qwe", result.Foo.F1) + assert.DeepEqual(t, "qwe", result.Nested.N1) +} + +func TestBind_SliceStruct(t *testing.T) { + type Foo struct { + F1 string `json:"f1"` + } + + type Bar struct { + B1 []Foo `query:"F1"` + } + + result := Bar{} + B1s := []string{"1", "2", "3"} + + req := newMockRequest().SetRequestURI(fmt.Sprintf("http://foobar.com?F1={\"f1\":\"%s\"}&F1={\"f1\":\"%s\"}&F1={\"f1\":\"%s\"}", B1s[0], B1s[1], B1s[2])) + err := DefaultBinder().Bind(req.Req, &result, nil) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, len(result.B1), len(B1s)) + for idx, val := range B1s { + assert.DeepEqual(t, B1s[idx], val) + } +} + +func TestBind_MapType(t *testing.T) { + var result map[string]string + req := newMockRequest(). + SetJSONContentType(). + SetBody([]byte(`{"j1":"j1", "j2":"j2"}`)) + err := DefaultBinder().Bind(req.Req, &result, nil) + if err != nil { + t.Fatal(err) + } + assert.DeepEqual(t, 2, len(result)) + assert.DeepEqual(t, "j1", result["j1"]) + assert.DeepEqual(t, "j2", result["j2"]) +} + +func TestBind_MapFieldType(t *testing.T) { + type Foo struct { + F1 ***map[string]string `query:"f1" json:"f1"` + } + + req := newMockRequest(). + SetRequestURI("http://foobar.com?f1={\"f1\":\"f1\"}"). + SetJSONContentType(). + SetBody([]byte(`{"j1":"j1", "j2":"j2"}`)) + result := Foo{} + err := DefaultBinder().Bind(req.Req, &result, nil) + if err != nil { + t.Fatal(err) + } + assert.DeepEqual(t, 1, len(***result.F1)) + assert.DeepEqual(t, "f1", (***result.F1)["f1"]) + + type Foo2 struct { + F1 map[string]string `query:"f1" json:"f1"` + } + result2 := Foo2{} + err = DefaultBinder().Bind(req.Req, &result2, nil) + if err != nil { + t.Fatal(err) + } + assert.DeepEqual(t, 1, len(result2.F1)) + assert.DeepEqual(t, "f1", result2.F1["f1"]) + req = newMockRequest(). + SetRequestURI("http://foobar.com?f1={\"f1\":\"f1\"") + result2 = Foo2{} + err = DefaultBinder().Bind(req.Req, &result2, nil) + if err == nil { + t.Error(err) + } +} + +func TestBind_UnexportedField(t *testing.T) { + var s struct { + A int `query:"a"` + b int `query:"b"` + } + req := newMockRequest(). + SetRequestURI("http://foobar.com?a=1&b=2") + err := DefaultBinder().Bind(req.Req, &s, nil) + if err != nil { + t.Fatal(err) + } + assert.DeepEqual(t, 1, s.A) + assert.DeepEqual(t, 0, s.b) +} + +func TestBind_NoTagField(t *testing.T) { + var s struct { + A string + B string + C string + } + req := newMockRequest(). + SetRequestURI("http://foobar.com?B=b1&C=c1"). + SetHeader("A", "a2") + + var params param.Params + params = append(params, param.Param{ + Key: "B", + Value: "b2", + }) + + err := DefaultBinder().Bind(req.Req, &s, params) + if err != nil { + t.Fatal(err) + } + assert.DeepEqual(t, "a2", s.A) + assert.DeepEqual(t, "b2", s.B) + assert.DeepEqual(t, "c1", s.C) +} + +func TestBind_ZeroValueBind(t *testing.T) { + var s struct { + A int `query:"a"` + B float64 `query:"b"` + } + req := newMockRequest(). + SetRequestURI("http://foobar.com?a=&b") + + bindConfig := &BindConfig{} + bindConfig.LooseZeroMode = true + binder := NewDefaultBinder(bindConfig) + err := binder.Bind(req.Req, &s, nil) + if err != nil { + t.Fatal(err) + } + assert.DeepEqual(t, 0, s.A) + assert.DeepEqual(t, float64(0), s.B) +} + +func TestBind_DefaultValueBind(t *testing.T) { + var s struct { + A int `default:"15"` + B float64 `query:"b" default:"17"` + C []int `default:"15"` + D []string `default:"qwe"` + } + req := newMockRequest(). + SetRequestURI("http://foobar.com") + + err := DefaultBinder().Bind(req.Req, &s, nil) + if err != nil { + t.Fatal(err) + } + assert.DeepEqual(t, 15, s.A) + assert.DeepEqual(t, float64(17), s.B) + assert.DeepEqual(t, 15, s.C[0]) + assert.DeepEqual(t, "qwe", s.D[0]) + + var d struct { + D [2]string `default:"qwe"` + } + + err = DefaultBinder().Bind(req.Req, &d, nil) + if err == nil { + t.Fatal("expected err") + } +} + +func TestBind_RequiredBind(t *testing.T) { + var s struct { + A int `query:"a,required"` + } + req := newMockRequest(). + SetRequestURI("http://foobar.com"). + SetHeader("A", "1") + + err := DefaultBinder().Bind(req.Req, &s, nil) + if err == nil { + t.Fatal("expected error") + } + + var d struct { + A int `query:"a,required" header:"A"` + } + err = DefaultBinder().Bind(req.Req, &d, nil) + if err != nil { + t.Fatal(err) + } + assert.DeepEqual(t, 1, d.A) +} + +func TestBind_TypedefType(t *testing.T) { + type Foo string + type Bar *int + type T struct { + T1 string `query:"a"` + } + type TT T + + var s struct { + A Foo `query:"a"` + B Bar `query:"b"` + T1 TT + } + req := newMockRequest(). + SetRequestURI("http://foobar.com?a=1&b=2") + err := DefaultBinder().Bind(req.Req, &s, nil) + if err != nil { + t.Fatal(err) + } + assert.DeepEqual(t, Foo("1"), s.A) + assert.DeepEqual(t, 2, *s.B) + assert.DeepEqual(t, "1", s.T1.T1) +} + +type EnumType int64 + +const ( + EnumType_TWEET EnumType = 0 + EnumType_RETWEET EnumType = 2 +) + +func (p EnumType) String() string { + switch p { + case EnumType_TWEET: + return "TWEET" + case EnumType_RETWEET: + return "RETWEET" + } + return "" +} + +func TestBind_EnumBind(t *testing.T) { + var s struct { + A EnumType `query:"a"` + B EnumType `query:"b"` + } + req := newMockRequest(). + SetRequestURI("http://foobar.com?a=0&b=2") + err := DefaultBinder().Bind(req.Req, &s, nil) + if err != nil { + t.Fatal(err) + } +} + +type CustomizedDecode struct { + A string +} + +func TestBind_CustomizedTypeDecode(t *testing.T) { + type Foo struct { + F ***CustomizedDecode + } + + bindConfig := &BindConfig{} + err := bindConfig.RegTypeUnmarshal(reflect.TypeOf(CustomizedDecode{}), func(req *protocol.Request, params param.Params, text string) (reflect.Value, error) { + q1 := req.URI().QueryArgs().Peek("a") + if len(q1) == 0 { + return reflect.Value{}, fmt.Errorf("can be nil") + } + val := CustomizedDecode{ + A: string(q1), + } + return reflect.ValueOf(val), nil + }) + if err != nil { + t.Fatal(err) + } + binder := NewDefaultBinder(bindConfig) + + req := newMockRequest(). + SetRequestURI("http://foobar.com?a=1&b=2") + result := Foo{} + err = binder.Bind(req.Req, &result, nil) + if err != nil { + t.Fatal(err) + } + assert.DeepEqual(t, "1", (***result.F).A) + + type Bar struct { + B *Foo + } + + result2 := Bar{} + err = binder.Bind(req.Req, &result2, nil) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, "1", (***(*result2.B).F).A) +} + +func TestBind_CustomizedTypeDecodeForPanic(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Errorf("expect a panic, but get nil") + } + }() + + bindConfig := &BindConfig{} + bindConfig.MustRegTypeUnmarshal(reflect.TypeOf(string("")), func(req *protocol.Request, params param.Params, text string) (reflect.Value, error) { + return reflect.Value{}, nil + }) +} + +func TestBind_JSON(t *testing.T) { + type Req struct { + J1 string `json:"j1"` + J2 int `json:"j2" query:"j2"` // 1. json unmarshal 2. query binding cover + J3 []byte `json:"j3"` + J4 [2]string `json:"j4"` + } + J3s := []byte("12") + J4s := [2]string{"qwe", "asd"} + + req := newMockRequest(). + SetRequestURI("http://foobar.com?j2=13"). + SetJSONContentType(). + SetBody([]byte(fmt.Sprintf(`{"j1":"j1", "j2":12, "j3":[%d, %d], "j4":["%s", "%s"]}`, J3s[0], J3s[1], J4s[0], J4s[1]))) + var result Req + err := DefaultBinder().Bind(req.Req, &result, nil) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, "j1", result.J1) + assert.DeepEqual(t, 13, result.J2) + for idx, val := range J3s { + assert.DeepEqual(t, val, result.J3[idx]) + } + for idx, val := range J4s { + assert.DeepEqual(t, val, result.J4[idx]) + } +} + +func TestBind_ResetJSONUnmarshal(t *testing.T) { + bindConfig := &BindConfig{} + bindConfig.UseStdJSONUnmarshaler() + binder := NewDefaultBinder(bindConfig) + type Req struct { + J1 string `json:"j1"` + J2 int `json:"j2"` + J3 []byte `json:"j3"` + J4 [2]string `json:"j4"` + } + J3s := []byte("12") + J4s := [2]string{"qwe", "asd"} + + req := newMockRequest(). + SetJSONContentType(). + SetBody([]byte(fmt.Sprintf(`{"j1":"j1", "j2":12, "j3":[%d, %d], "j4":["%s", "%s"]}`, J3s[0], J3s[1], J4s[0], J4s[1]))) + var result Req + err := binder.Bind(req.Req, &result, nil) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, "j1", result.J1) + assert.DeepEqual(t, 12, result.J2) + for idx, val := range J3s { + assert.DeepEqual(t, val, result.J3[idx]) + } + for idx, val := range J4s { + assert.DeepEqual(t, val, result.J4[idx]) + } +} + +func TestBind_FileBind(t *testing.T) { + type Nest struct { + N multipart.FileHeader `file_name:"d"` + } + + var s struct { + A *multipart.FileHeader `file_name:"a"` + B *multipart.FileHeader `form:"b"` + C multipart.FileHeader + D **Nest `file_name:"d"` + } + fileName := "binder_test.go" + req := newMockRequest(). + SetRequestURI("http://foobar.com"). + SetFile("a", fileName). + SetFile("b", fileName). + SetFile("C", fileName). + SetFile("d", fileName) + // to parse multipart files + req2 := req2.GetHTTP1Request(req.Req) + _ = req2.String() + err := DefaultBinder().Bind(req.Req, &s, nil) + if err != nil { + t.Fatal(err) + } + assert.DeepEqual(t, fileName, s.A.Filename) + assert.DeepEqual(t, fileName, s.B.Filename) + assert.DeepEqual(t, fileName, s.C.Filename) + assert.DeepEqual(t, fileName, (**s.D).N.Filename) +} + +func TestBind_FileSliceBind(t *testing.T) { + type Nest struct { + N *[]*multipart.FileHeader `form:"b"` + } + var s struct { + A []multipart.FileHeader `form:"a"` + B [3]multipart.FileHeader `form:"b"` + C []*multipart.FileHeader `form:"b"` + D Nest + } + fileName := "binder_test.go" + req := newMockRequest(). + SetRequestURI("http://foobar.com"). + SetFile("a", fileName). + SetFile("a", fileName). + SetFile("a", fileName). + SetFile("b", fileName). + SetFile("b", fileName). + SetFile("b", fileName) + // to parse multipart files + req2 := req2.GetHTTP1Request(req.Req) + _ = req2.String() + err := DefaultBinder().Bind(req.Req, &s, nil) + if err != nil { + t.Fatal(err) + } + assert.DeepEqual(t, 3, len(s.A)) + for _, file := range s.A { + assert.DeepEqual(t, fileName, file.Filename) + } + assert.DeepEqual(t, 3, len(s.B)) + for _, file := range s.B { + assert.DeepEqual(t, fileName, file.Filename) + } + assert.DeepEqual(t, 3, len(s.C)) + for _, file := range s.C { + assert.DeepEqual(t, fileName, file.Filename) + } + assert.DeepEqual(t, 3, len(*s.D.N)) + for _, file := range *s.D.N { + assert.DeepEqual(t, fileName, file.Filename) + } +} + +func TestBind_AnonymousField(t *testing.T) { + type nest struct { + n1 string `query:"n1"` // bind default value + N2 ***string `query:"n2"` // bind n2 value + string `query:"n3"` // bind default value + } + + var s struct { + s1 int `query:"s1"` // bind default value + int `query:"s2"` // bind default value + nest + } + req := newMockRequest(). + SetRequestURI("http://foobar.com?s1=1&s2=2&n1=1&n2=2&n3=3") + err := DefaultBinder().Bind(req.Req, &s, nil) + if err != nil { + t.Fatal(err) + } + assert.DeepEqual(t, 0, s.s1) + assert.DeepEqual(t, 0, s.int) + assert.DeepEqual(t, "", s.nest.n1) + assert.DeepEqual(t, "2", ***s.nest.N2) + assert.DeepEqual(t, "", s.nest.string) +} + +func TestBind_IgnoreField(t *testing.T) { + type Req struct { + Version int `path:"-"` + ID int `query:"-"` + Header string `header:"-"` + Form string `form:"-"` + } + + req := newMockRequest(). + SetRequestURI("http://foobar.com?ID=12"). + SetHeaders("Header", "header"). + SetPostArg("Form", "form"). + SetUrlEncodeContentType() + var params param.Params + params = append(params, param.Param{ + Key: "Version", + Value: "1", + }) + + var result Req + + err := DefaultBinder().Bind(req.Req, &result, params) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, 0, result.Version) + assert.DeepEqual(t, 0, result.ID) + assert.DeepEqual(t, "", result.Header) + assert.DeepEqual(t, "", result.Form) +} + +func TestBind_DefaultTag(t *testing.T) { + type Req struct { + Version int + ID int + Header string + Form string + } + type Req2 struct { + Version int + ID int + Header string + Form string + } + req := newMockRequest(). + SetRequestURI("http://foobar.com?ID=12"). + SetHeaders("Header", "header"). + SetPostArg("Form", "form"). + SetUrlEncodeContentType() + var params param.Params + params = append(params, param.Param{ + Key: "Version", + Value: "1", + }) + var result Req + err := DefaultBinder().Bind(req.Req, &result, params) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, 1, result.Version) + assert.DeepEqual(t, 12, result.ID) + assert.DeepEqual(t, "header", result.Header) + assert.DeepEqual(t, "form", result.Form) + + bindConfig := &BindConfig{} + bindConfig.DisableDefaultTag = true + binder := NewDefaultBinder(bindConfig) + result2 := Req2{} + err = binder.Bind(req.Req, &result2, params) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, 0, result2.Version) + assert.DeepEqual(t, 0, result2.ID) + assert.DeepEqual(t, "", result2.Header) + assert.DeepEqual(t, "", result2.Form) +} + +func TestBind_StructFieldResolve(t *testing.T) { + type Nested struct { + A int `query:"a" json:"a"` + B int `query:"b" json:"b"` + } + type Req struct { + N Nested `query:"n"` + } + + req := newMockRequest(). + SetRequestURI("http://foobar.com?n={\"a\":1,\"b\":2}"). + SetHeaders("Header", "header"). + SetPostArg("Form", "form"). + SetUrlEncodeContentType() + var result Req + bindConfig := &BindConfig{} + bindConfig.DisableStructFieldResolve = false + binder := NewDefaultBinder(bindConfig) + err := binder.Bind(req.Req, &result, nil) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, 1, result.N.A) + assert.DeepEqual(t, 2, result.N.B) + + req = newMockRequest(). + SetRequestURI("http://foobar.com?n={\"a\":1,\"b\":2}&a=11&b=22"). + SetHeaders("Header", "header"). + SetPostArg("Form", "form"). + SetUrlEncodeContentType() + err = DefaultBinder().Bind(req.Req, &result, nil) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, 11, result.N.A) + assert.DeepEqual(t, 22, result.N.B) +} + +func TestBind_JSONRequiredField(t *testing.T) { + type Nested2 struct { + C int `json:"c,required"` + D int `json:"dd,required"` + } + type Nested struct { + A int `json:"a,required"` + B int `json:"b,required"` + N2 Nested2 `json:"n2"` + } + type Req struct { + N Nested `json:"n,required"` + } + bodyBytes := []byte(`{ + "n": { + "a": 1, + "b": 2, + "n2": { + "dd": 4 + } + } +}`) + req := newMockRequest(). + SetRequestURI("http://foobar.com?j2=13"). + SetJSONContentType(). + SetBody(bodyBytes) + var result Req + err := DefaultBinder().Bind(req.Req, &result, nil) + if err == nil { + t.Errorf("expected an error, but get nil") + } + assert.DeepEqual(t, 1, result.N.A) + assert.DeepEqual(t, 2, result.N.B) + assert.DeepEqual(t, 0, result.N.N2.C) + assert.DeepEqual(t, 4, result.N.N2.D) + + bodyBytes = []byte(`{ + "n": { + "a": 1, + "b": 2 + } +}`) + req = newMockRequest(). + SetRequestURI("http://foobar.com?j2=13"). + SetJSONContentType(). + SetBody(bodyBytes) + var result2 Req + err = DefaultBinder().Bind(req.Req, &result2, nil) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, 1, result2.N.A) + assert.DeepEqual(t, 2, result2.N.B) + assert.DeepEqual(t, 0, result2.N.N2.C) + assert.DeepEqual(t, 0, result2.N.N2.D) +} + +func TestValidate_MultipleValidate(t *testing.T) { + type Test1 struct { + A int `query:"a" vd:"$>10"` + } + req := newMockRequest(). + SetRequestURI("http://foobar.com?a=9") + var result Test1 + err := DefaultBinder().BindAndValidate(req.Req, &result, nil) + if err == nil { + t.Fatalf("expected an error, but get nil") + } +} + +func TestBind_BindQuery(t *testing.T) { + type Req struct { + Q1 int `query:"q1"` + Q2 int + Q3 string + Q4 string + Q5 []int + } + + req := newMockRequest(). + SetRequestURI("http://foobar.com?q1=1&Q2=2&Q3=3&Q4=4&Q5=51&Q5=52") + + var result Req + + err := DefaultBinder().BindQuery(req.Req, &result) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, 1, result.Q1) + assert.DeepEqual(t, 2, result.Q2) + assert.DeepEqual(t, "3", result.Q3) + assert.DeepEqual(t, "4", result.Q4) + assert.DeepEqual(t, 51, result.Q5[0]) + assert.DeepEqual(t, 52, result.Q5[1]) +} + +func TestBind_LooseMode(t *testing.T) { + bindConfig := &BindConfig{} + bindConfig.LooseZeroMode = false + binder := NewDefaultBinder(bindConfig) + type Req struct { + ID int `query:"id"` + } + + req := newMockRequest(). + SetRequestURI("http://foobar.com?id=") + + var result Req + + err := binder.Bind(req.Req, &result, nil) + if err == nil { + t.Fatal("expected err") + } + assert.DeepEqual(t, 0, result.ID) + + bindConfig.LooseZeroMode = true + binder = NewDefaultBinder(bindConfig) + var result2 Req + + err = binder.Bind(req.Req, &result2, nil) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, 0, result.ID) +} + +func TestBind_NonStruct(t *testing.T) { + req := newMockRequest(). + SetRequestURI("http://foobar.com?id=1&id=2") + var id interface{} + err := DefaultBinder().Bind(req.Req, &id, nil) + if err != nil { + t.Error(err) + } + + err = DefaultBinder().BindAndValidate(req.Req, &id, nil) + if err != nil { + t.Error(err) + } +} + +func TestBind_BindTag(t *testing.T) { + type Req struct { + Query string + Header string + Path string + Form string + } + req := newMockRequest(). + SetRequestURI("http://foobar.com?Query=query"). + SetHeader("Header", "header"). + SetPostArg("Form", "form") + var params param.Params + params = append(params, param.Param{ + Key: "Path", + Value: "path", + }) + result := Req{} + + // test query tag + err := DefaultBinder().BindQuery(req.Req, &result) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, "query", result.Query) + + // test header tag + result = Req{} + err = DefaultBinder().BindHeader(req.Req, &result) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, "header", result.Header) + + // test form tag + result = Req{} + err = DefaultBinder().BindForm(req.Req, &result) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, "form", result.Form) + + // test path tag + result = Req{} + err = DefaultBinder().BindPath(req.Req, &result, params) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, "path", result.Path) + + // test json tag + req = newMockRequest(). + SetRequestURI("http://foobar.com"). + SetJSONContentType(). + SetBody([]byte("{\n \"Query\": \"query\",\n \"Path\": \"path\",\n \"Header\": \"header\",\n \"Form\": \"form\"\n}")) + result = Req{} + err = DefaultBinder().BindJSON(req.Req, &result) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, "form", result.Form) + assert.DeepEqual(t, "query", result.Query) + assert.DeepEqual(t, "header", result.Header) + assert.DeepEqual(t, "path", result.Path) +} + +func TestBind_BindAndValidate(t *testing.T) { + type Req struct { + ID int `query:"id" vd:"$>10"` + } + req := newMockRequest(). + SetRequestURI("http://foobar.com?id=12") + + // test bindAndValidate + var result Req + err := BindAndValidate(req.Req, &result, nil) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, 12, result.ID) + + // test bind + result = Req{} + err = Bind(req.Req, &result, nil) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, 12, result.ID) + + // test validate + req = newMockRequest(). + SetRequestURI("http://foobar.com?id=9") + result = Req{} + err = Bind(req.Req, &result, nil) + if err != nil { + t.Error(err) + } + err = Validate(result) + if err == nil { + t.Errorf("expect an error, but get nil") + } + assert.DeepEqual(t, 9, result.ID) +} + +func TestBind_FastPath(t *testing.T) { + type Req struct { + ID int `query:"id" vd:"$>10"` + } + req := newMockRequest(). + SetRequestURI("http://foobar.com?id=12") + + // test bindAndValidate + var result Req + err := BindAndValidate(req.Req, &result, nil) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, 12, result.ID) + // execute multiple times, test cache + for i := 0; i < 10; i++ { + result = Req{} + err := BindAndValidate(req.Req, &result, nil) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, 12, result.ID) + } +} + +func TestBind_NonPointer(t *testing.T) { + type Req struct { + ID int `query:"id" vd:"$>10"` + } + req := newMockRequest(). + SetRequestURI("http://foobar.com?id=12") + + // test bindAndValidate + var result Req + err := BindAndValidate(req.Req, result, nil) + if err == nil { + t.Error("expect an error, but get nil") + } + + err = Bind(req.Req, result, nil) + if err == nil { + t.Error("expect an error, but get nil") + } +} + +func TestBind_PreBind(t *testing.T) { + type Req struct { + Query string + Header string + Path string + Form string + } + // test json tag + req := newMockRequest(). + SetRequestURI("http://foobar.com"). + SetJSONContentType(). + SetBody([]byte("\n \"Query\": \"query\",\n \"Path\": \"path\",\n \"Header\": \"header\",\n \"Form\": \"form\"\n}")) + result := Req{} + err := DefaultBinder().Bind(req.Req, &result, nil) + if err == nil { + t.Error("expect an error, but get nil") + } + err = DefaultBinder().BindAndValidate(req.Req, &result, nil) + if err == nil { + t.Error("expect an error, but get nil") + } +} + +func TestBind_BindProtobuf(t *testing.T) { + data := testdata.HertzReq{Name: "hertz"} + body, err := proto.Marshal(&data) + if err != nil { + t.Fatal(err) + } + req := newMockRequest(). + SetRequestURI("http://foobar.com"). + SetProtobufContentType(). + SetBody(body) + + result := testdata.HertzReq{} + err = DefaultBinder().BindAndValidate(req.Req, &result, nil) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, "hertz", result.Name) + + result = testdata.HertzReq{} + err = DefaultBinder().BindProtobuf(req.Req, &result) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, "hertz", result.Name) +} + +func TestBind_PointerStruct(t *testing.T) { + bindConfig := &BindConfig{} + bindConfig.DisableStructFieldResolve = false + binder := NewDefaultBinder(bindConfig) + type Foo struct { + F1 string `query:"F1"` + } + type Bar struct { + B1 **Foo `query:"B1,required"` + } + query := make(url.Values) + query.Add("B1", "{\n \"F1\": \"111\"\n}") + + var result Bar + req := newMockRequest(). + SetRequestURI(fmt.Sprintf("http://foobar.com?%s", query.Encode())) + + err := binder.Bind(req.Req, &result, nil) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, "111", (**result.B1).F1) + + result = Bar{} + req = newMockRequest(). + SetRequestURI(fmt.Sprintf("http://foobar.com?%s&F1=222", query.Encode())) + err = binder.Bind(req.Req, &result, nil) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, "222", (**result.B1).F1) +} + +func TestBind_StructRequired(t *testing.T) { + bindConfig := &BindConfig{} + bindConfig.DisableStructFieldResolve = false + binder := NewDefaultBinder(bindConfig) + type Foo struct { + F1 string `query:"F1"` + } + type Bar struct { + B1 **Foo `query:"B1,required"` + } + + var result Bar + req := newMockRequest(). + SetRequestURI("http://foobar.com") + + err := binder.Bind(req.Req, &result, nil) + if err == nil { + t.Error("expect an error, but get nil") + } + + type Bar2 struct { + B1 **Foo `query:"B1"` + } + var result2 Bar2 + req = newMockRequest(). + SetRequestURI("http://foobar.com") + + err = binder.Bind(req.Req, &result2, nil) + if err != nil { + t.Error(err) + } +} + +func TestBind_StructErrorToWarn(t *testing.T) { + bindConfig := &BindConfig{} + bindConfig.DisableStructFieldResolve = false + binder := NewDefaultBinder(bindConfig) + type Foo struct { + F1 string `query:"F1"` + } + type Bar struct { + B1 **Foo `query:"B1,required"` + } + + var result Bar + req := newMockRequest(). + SetRequestURI("http://foobar.com?B1=111&F1=222") + + err := binder.Bind(req.Req, &result, nil) + // transfer 'unmarsahl err' to 'warn' + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, "222", (**result.B1).F1) + + type Bar2 struct { + B1 Foo `query:"B1,required"` + } + var result2 Bar2 + err = binder.Bind(req.Req, &result2, nil) + // transfer 'unmarsahl err' to 'warn' + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, "222", result2.B1.F1) +} + +func TestBind_DisallowUnknownFieldsConfig(t *testing.T) { + bindConfig := &BindConfig{} + bindConfig.EnableDecoderDisallowUnknownFields = true + binder := NewDefaultBinder(bindConfig) + type FooStructUseNumber struct { + Foo interface{} `json:"foo"` + } + req := newMockRequest(). + SetRequestURI("http://foobar.com"). + SetJSONContentType(). + SetBody([]byte(`{"foo": 123,"bar": "456"}`)) + var result FooStructUseNumber + + err := binder.BindJSON(req.Req, &result) + if err == nil { + t.Errorf("expected an error, but get nil") + } +} + +func TestBind_UseNumberConfig(t *testing.T) { + bindConfig := &BindConfig{} + bindConfig.EnableDecoderUseNumber = true + binder := NewDefaultBinder(bindConfig) + type FooStructUseNumber struct { + Foo interface{} `json:"foo"` + } + req := newMockRequest(). + SetRequestURI("http://foobar.com"). + SetJSONContentType(). + SetBody([]byte(`{"foo": 123}`)) + var result FooStructUseNumber + + err := binder.BindJSON(req.Req, &result) + if err != nil { + t.Error(err) + } + v, err := result.Foo.(json.Number).Int64() + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, int64(123), v) +} + +func TestBind_InterfaceType(t *testing.T) { + type Bar struct { + B1 interface{} `query:"B1"` + } + + var result Bar + query := make(url.Values) + query.Add("B1", `{"B1":"111"}`) + req := newMockRequest(). + SetRequestURI(fmt.Sprintf("http://foobar.com?%s", query.Encode())) + err := DefaultBinder().Bind(req.Req, &result, nil) + if err != nil { + t.Error(err) + } + + type Bar2 struct { + B2 *interface{} `query:"B1"` + } + + var result2 Bar2 + err = DefaultBinder().Bind(req.Req, &result2, nil) + if err != nil { + t.Error(err) + } +} + +func Test_BindHeaderNormalize(t *testing.T) { + type Req struct { + Header string `header:"h"` + } + + req := newMockRequest(). + SetRequestURI("http://foobar.com"). + SetHeaders("h", "header") + var result Req + + err := DefaultBinder().Bind(req.Req, &result, nil) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, "header", result.Header) + req = newMockRequest(). + SetRequestURI("http://foobar.com"). + SetHeaders("H", "header") + err = DefaultBinder().Bind(req.Req, &result, nil) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, "header", result.Header) + + type Req2 struct { + Header string `header:"H"` + } + + req2 := newMockRequest(). + SetRequestURI("http://foobar.com"). + SetHeaders("h", "header") + var result2 Req2 + + err2 := DefaultBinder().Bind(req2.Req, &result2, nil) + if err != nil { + t.Error(err2) + } + assert.DeepEqual(t, "header", result2.Header) + req2 = newMockRequest(). + SetRequestURI("http://foobar.com"). + SetHeaders("H", "header") + err2 = DefaultBinder().Bind(req2.Req, &result2, nil) + if err2 != nil { + t.Error(err2) + } + assert.DeepEqual(t, "header", result2.Header) + + type Req3 struct { + Header string `header:"h"` + } + + // without normalize, the header key & tag key need to be consistent + req3 := newMockRequest(). + SetRequestURI("http://foobar.com") + req3.Req.Header.DisableNormalizing() + req3.SetHeaders("h", "header") + var result3 Req3 + err3 := DefaultBinder().Bind(req3.Req, &result3, nil) + if err3 != nil { + t.Error(err3) + } + assert.DeepEqual(t, "header", result3.Header) + req3 = newMockRequest(). + SetRequestURI("http://foobar.com") + req3.Req.Header.DisableNormalizing() + req3.SetHeaders("H", "header") + result3 = Req3{} + err3 = DefaultBinder().Bind(req3.Req, &result3, nil) + if err3 != nil { + t.Error(err3) + } + assert.DeepEqual(t, "", result3.Header) +} + +func Benchmark_Binding(b *testing.B) { + type Req struct { + Version string `path:"v"` + ID int `query:"id"` + Header string `header:"h"` + Form string `form:"f"` + } + + req := newMockRequest(). + SetRequestURI("http://foobar.com?id=12"). + SetHeaders("H", "header"). + SetPostArg("f", "form"). + SetUrlEncodeContentType() + + var params param.Params + params = append(params, param.Param{ + Key: "v", + Value: "1", + }) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + var result Req + err := DefaultBinder().Bind(req.Req, &result, params) + if err != nil { + b.Error(err) + } + if result.ID != 12 { + b.Error("Id failed") + } + if result.Form != "form" { + b.Error("form failed") + } + if result.Header != "header" { + b.Error("header failed") + } + if result.Version != "1" { + b.Error("path failed") + } + } +} diff --git a/pkg/app/server/binding/binding.go b/pkg/app/server/binding/binding.go deleted file mode 100644 index fa4af9d97..000000000 --- a/pkg/app/server/binding/binding.go +++ /dev/null @@ -1,122 +0,0 @@ -/* - * Copyright 2022 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package binding - -import ( - "encoding/json" - "reflect" - - "github.com/bytedance/go-tagexpr/v2/binding" - "github.com/bytedance/go-tagexpr/v2/binding/gjson" - "github.com/bytedance/go-tagexpr/v2/validator" - hjson "github.com/cloudwego/hertz/pkg/common/json" - "github.com/cloudwego/hertz/pkg/protocol" - "github.com/cloudwego/hertz/pkg/route/param" -) - -func init() { - binding.ResetJSONUnmarshaler(hjson.Unmarshal) -} - -var defaultBinder = binding.Default() - -// BindAndValidate binds data from *protocol.Request to obj and validates them if needed. -// NOTE: -// -// obj should be a pointer. -func BindAndValidate(req *protocol.Request, obj interface{}, pathParams param.Params) error { - return defaultBinder.IBindAndValidate(obj, wrapRequest(req), pathParams) -} - -// Bind binds data from *protocol.Request to obj. -// NOTE: -// -// obj should be a pointer. -func Bind(req *protocol.Request, obj interface{}, pathParams param.Params) error { - return defaultBinder.IBind(obj, wrapRequest(req), pathParams) -} - -// Validate validates obj with "vd" tag -// NOTE: -// -// obj should be a pointer. -// Validate should be called after Bind. -func Validate(obj interface{}) error { - return defaultBinder.Validate(obj) -} - -// SetLooseZeroMode if set to true, -// the empty string request parameter is bound to the zero value of parameter. -// NOTE: -// -// The default is false. -// Suitable for these parameter types: query/header/cookie/form . -func SetLooseZeroMode(enable bool) { - defaultBinder.SetLooseZeroMode(enable) -} - -// SetErrorFactory customizes the factory of validation error. -// NOTE: -// -// If errFactory==nil, the default is used. -// SetErrorFactory will remain in effect once it has been called. -func SetErrorFactory(bindErrFactory, validatingErrFactory func(failField, msg string) error) { - defaultBinder.SetErrorFactory(bindErrFactory, validatingErrFactory) -} - -// MustRegTypeUnmarshal registers unmarshal function of type. -// NOTE: -// -// It will panic if exist error. -// MustRegTypeUnmarshal will remain in effect once it has been called. -func MustRegTypeUnmarshal(t reflect.Type, fn func(v string, emptyAsZero bool) (reflect.Value, error)) { - binding.MustRegTypeUnmarshal(t, fn) -} - -// MustRegValidateFunc registers validator function expression. -// NOTE: -// -// If force=true, allow to cover the existed same funcName. -// MustRegValidateFunc will remain in effect once it has been called. -func MustRegValidateFunc(funcName string, fn func(args ...interface{}) error, force ...bool) { - validator.RegFunc(funcName, fn, force...) -} - -// UseStdJSONUnmarshaler uses encoding/json as json library -// NOTE: -// -// The current version uses encoding/json by default. -// UseStdJSONUnmarshaler will remain in effect once it has been called. -func UseStdJSONUnmarshaler() { - binding.ResetJSONUnmarshaler(json.Unmarshal) -} - -// UseGJSONUnmarshaler uses github.com/bytedance/go-tagexpr/v2/binding/gjson as json library -// NOTE: -// -// UseGJSONUnmarshaler will remain in effect once it has been called. -func UseGJSONUnmarshaler() { - gjson.UseJSONUnmarshaler() -} - -// UseThirdPartyJSONUnmarshaler uses third-party json library for binding -// NOTE: -// -// UseThirdPartyJSONUnmarshaler will remain in effect once it has been called. -func UseThirdPartyJSONUnmarshaler(unmarshaler func(data []byte, v interface{}) error) { - binding.ResetJSONUnmarshaler(unmarshaler) -} diff --git a/pkg/app/server/binding/binding_test.go b/pkg/app/server/binding/binding_test.go deleted file mode 100644 index a9050f940..000000000 --- a/pkg/app/server/binding/binding_test.go +++ /dev/null @@ -1,450 +0,0 @@ -/* - * Copyright 2022 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package binding - -import ( - "bytes" - "fmt" - "mime/multipart" - "reflect" - "testing" - - "github.com/cloudwego/hertz/pkg/common/test/assert" - "github.com/cloudwego/hertz/pkg/protocol" - "github.com/cloudwego/hertz/pkg/route/param" -) - -func TestBindAndValidate(t *testing.T) { - type TestBind struct { - A string `query:"a"` - B []string `query:"b"` - C string `query:"c"` - D string `header:"d"` - E string `path:"e"` - F string `form:"f"` - G multipart.FileHeader `form:"g"` - H string `cookie:"h"` - } - - s := `------WebKitFormBoundaryJwfATyF8tmxSJnLg -Content-Disposition: form-data; name="f" - -fff -------WebKitFormBoundaryJwfATyF8tmxSJnLg -Content-Disposition: form-data; name="g"; filename="TODO" -Content-Type: application/octet-stream - -- SessionClient with referer and cookies support. -- Client with requests' pipelining support. -- ProxyHandler similar to FSHandler. -- WebSockets. See https://tools.ietf.org/html/rfc6455 . -- HTTP/2.0. See https://tools.ietf.org/html/rfc7540 . - -------WebKitFormBoundaryJwfATyF8tmxSJnLg-- -tailfoobar` - - mr := bytes.NewBufferString(s) - r := protocol.NewRequest("POST", "/foo", mr) - r.SetRequestURI("/foo/bar?a=aaa&b=b1&b=b2&c&i=19") - r.SetHeader("d", "ddd") - r.Header.SetContentLength(len(s)) - r.Header.SetContentTypeBytes([]byte("multipart/form-data; boundary=----WebKitFormBoundaryJwfATyF8tmxSJnLg")) - - r.SetCookie("h", "hhh") - - para := param.Params{ - {Key: "e", Value: "eee"}, - } - - // test BindAndValidate() - SetLooseZeroMode(true) - var req TestBind - err := BindAndValidate(r, &req, para) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - assert.DeepEqual(t, "aaa", req.A) - assert.DeepEqual(t, 2, len(req.B)) - assert.DeepEqual(t, "", req.C) - assert.DeepEqual(t, "ddd", req.D) - assert.DeepEqual(t, "eee", req.E) - assert.DeepEqual(t, "fff", req.F) - assert.DeepEqual(t, "TODO", req.G.Filename) - assert.DeepEqual(t, "hhh", req.H) - - // test Bind() - req = TestBind{} - err = Bind(r, &req, para) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - assert.DeepEqual(t, "aaa", req.A) - assert.DeepEqual(t, 2, len(req.B)) - assert.DeepEqual(t, "", req.C) - assert.DeepEqual(t, "ddd", req.D) - assert.DeepEqual(t, "eee", req.E) - assert.DeepEqual(t, "fff", req.F) - assert.DeepEqual(t, "TODO", req.G.Filename) - assert.DeepEqual(t, "hhh", req.H) - - type TestValidate struct { - I int `query:"i" vd:"$>20"` - } - - // test BindAndValidate() - var bindReq TestValidate - err = BindAndValidate(r, &bindReq, para) - if err == nil { - t.Fatalf("unexpected nil, expected an error") - } - - // test Validate() - bindReq = TestValidate{} - err = Bind(r, &bindReq, para) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - assert.DeepEqual(t, 19, bindReq.I) - err = Validate(&bindReq) - if err == nil { - t.Fatalf("unexpected nil, expected an error") - } -} - -func TestJsonBind(t *testing.T) { - type Test struct { - A string `json:"a"` - B []string `json:"b"` - C string `json:"c"` - D int `json:"d,string"` - } - - data := `{"a":"aaa", "b":["b1","b2"], "c":"ccc", "d":"100"}` - mr := bytes.NewBufferString(data) - r := protocol.NewRequest("POST", "/foo", mr) - r.Header.Set("Content-Type", "application/json; charset=utf-8") - r.SetHeader("d", "ddd") - r.Header.SetContentLength(len(data)) - - var req Test - err := BindAndValidate(r, &req, nil) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - assert.DeepEqual(t, "aaa", req.A) - assert.DeepEqual(t, 2, len(req.B)) - assert.DeepEqual(t, "ccc", req.C) - // NOTE: The default does not support string to go int conversion in json. - // You can add "string" tags or use other json unmarshal libraries that support this feature - assert.DeepEqual(t, 100, req.D) - - req = Test{} - UseGJSONUnmarshaler() - err = BindAndValidate(r, &req, nil) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - assert.DeepEqual(t, "aaa", req.A) - assert.DeepEqual(t, 2, len(req.B)) - assert.DeepEqual(t, "ccc", req.C) - // NOTE: The default does not support string to go int conversion in json. - // You can add "string" tags or use other json unmarshal libraries that support this feature - assert.DeepEqual(t, 100, req.D) - - req = Test{} - UseStdJSONUnmarshaler() - err = BindAndValidate(r, &req, nil) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - assert.DeepEqual(t, "aaa", req.A) - assert.DeepEqual(t, 2, len(req.B)) - assert.DeepEqual(t, "ccc", req.C) - // NOTE: The default does not support string to go int conversion in json. - // You can add "string" tags or use other json unmarshal libraries that support this feature - assert.DeepEqual(t, 100, req.D) -} - -// TestQueryParamInconsistency tests the Inconsistency for GetQuery(), the other unit test for GetFunc() in request.go are similar to it -func TestQueryParamInconsistency(t *testing.T) { - type QueryPara struct { - Para1 string `query:"para1"` - Para2 *string `query:"para2"` - } - - r := protocol.NewRequest("GET", "/foo", nil) - r.SetRequestURI("/foo/bar?para1=hertz¶2=binding") - - var req QueryPara - err := BindAndValidate(r, &req, nil) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - beforePara1 := deepCopyString(req.Para1) - beforePara2 := deepCopyString(*req.Para2) - r.URI().QueryArgs().Set("para1", "test") - r.URI().QueryArgs().Set("para2", "test") - afterPara1 := req.Para1 - afterPara2 := *req.Para2 - assert.DeepEqual(t, beforePara1, afterPara1) - assert.DeepEqual(t, beforePara2, afterPara2) -} - -func deepCopyString(str string) string { - tmp := make([]byte, len(str)) - copy(tmp, str) - c := string(tmp) - - return c -} - -func TestBindingFile(t *testing.T) { - type FileParas struct { - F *multipart.FileHeader `form:"F1"` - F1 multipart.FileHeader - Fs []multipart.FileHeader `form:"F1"` - Fs1 []*multipart.FileHeader `form:"F1"` - F2 *multipart.FileHeader `form:"F2"` - } - - s := `------WebKitFormBoundaryJwfATyF8tmxSJnLg -Content-Disposition: form-data; name="f" - -fff -------WebKitFormBoundaryJwfATyF8tmxSJnLg -Content-Disposition: form-data; name="F1"; filename="TODO1" -Content-Type: application/octet-stream - -- SessionClient with referer and cookies support. -- Client with requests' pipelining support. -- ProxyHandler similar to FSHandler. -- WebSockets. See https://tools.ietf.org/html/rfc6455 . -- HTTP/2.0. See https://tools.ietf.org/html/rfc7540 . -------WebKitFormBoundaryJwfATyF8tmxSJnLg -Content-Disposition: form-data; name="F1"; filename="TODO2" -Content-Type: application/octet-stream - -- SessionClient with referer and cookies support. -- Client with requests' pipelining support. -- ProxyHandler similar to FSHandler. -- WebSockets. See https://tools.ietf.org/html/rfc6455 . -- HTTP/2.0. See https://tools.ietf.org/html/rfc7540 . -------WebKitFormBoundaryJwfATyF8tmxSJnLg -Content-Disposition: form-data; name="F2"; filename="TODO3" -Content-Type: application/octet-stream - -- SessionClient with referer and cookies support. -- Client with requests' pipelining support. -- ProxyHandler similar to FSHandler. -- WebSockets. See https://tools.ietf.org/html/rfc6455 . -- HTTP/2.0. See https://tools.ietf.org/html/rfc7540 . - -------WebKitFormBoundaryJwfATyF8tmxSJnLg-- -tailfoobar` - - mr := bytes.NewBufferString(s) - r := protocol.NewRequest("POST", "/foo", mr) - r.SetRequestURI("/foo/bar?a=aaa&b=b1&b=b2&c&i=19") - r.SetHeader("d", "ddd") - r.Header.SetContentLength(len(s)) - r.Header.SetContentTypeBytes([]byte("multipart/form-data; boundary=----WebKitFormBoundaryJwfATyF8tmxSJnLg")) - - var req FileParas - err := BindAndValidate(r, &req, nil) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - assert.DeepEqual(t, "TODO1", req.F.Filename) - assert.DeepEqual(t, "TODO1", req.F1.Filename) - assert.DeepEqual(t, 2, len(req.Fs)) - assert.DeepEqual(t, 2, len(req.Fs1)) - assert.DeepEqual(t, "TODO3", req.F2.Filename) -} - -type BindError struct { - ErrType, FailField, Msg string -} - -// Error implements error interface. -func (e *BindError) Error() string { - if e.Msg != "" { - return e.ErrType + ": expr_path=" + e.FailField + ", cause=" + e.Msg - } - return e.ErrType + ": expr_path=" + e.FailField + ", cause=invalid" -} - -type ValidateError struct { - ErrType, FailField, Msg string -} - -// Error implements error interface. -func (e *ValidateError) Error() string { - if e.Msg != "" { - return e.ErrType + ": expr_path=" + e.FailField + ", cause=" + e.Msg - } - return e.ErrType + ": expr_path=" + e.FailField + ", cause=invalid" -} - -func TestSetErrorFactory(t *testing.T) { - type TestBind struct { - A string `query:"a,required"` - } - - r := protocol.NewRequest("GET", "/foo", nil) - r.SetRequestURI("/foo/bar?b=20") - - CustomBindErrFunc := func(failField, msg string) error { - err := BindError{ - ErrType: "bindErr", - FailField: "[bindFailField]: " + failField, - Msg: "[bindErrMsg]: " + msg, - } - - return &err - } - - CustomValidateErrFunc := func(failField, msg string) error { - err := ValidateError{ - ErrType: "validateErr", - FailField: "[validateFailField]: " + failField, - Msg: "[validateErrMsg]: " + msg, - } - - return &err - } - - SetErrorFactory(CustomBindErrFunc, CustomValidateErrFunc) - - var req TestBind - err := Bind(r, &req, nil) - if err == nil { - t.Fatalf("unexpected nil, expected an error") - } - assert.DeepEqual(t, "bindErr: expr_path=[bindFailField]: A, cause=[bindErrMsg]: missing required parameter", err.Error()) - - type TestValidate struct { - B int `query:"b" vd:"$>100"` - } - - var reqValidate TestValidate - err = Bind(r, &reqValidate, nil) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - err = Validate(&reqValidate) - if err == nil { - t.Fatalf("unexpected nil, expected an error") - } - assert.DeepEqual(t, "validateErr: expr_path=[validateFailField]: B, cause=[validateErrMsg]: ", err.Error()) -} - -func TestMustRegTypeUnmarshal(t *testing.T) { - type Nested struct { - B string - C string - } - - type TestBind struct { - A Nested `query:"a,required"` - } - - r := protocol.NewRequest("GET", "/foo", nil) - r.SetRequestURI("/foo/bar?a=hertzbinding") - - MustRegTypeUnmarshal(reflect.TypeOf(Nested{}), func(v string, emptyAsZero bool) (reflect.Value, error) { - if v == "" && emptyAsZero { - return reflect.ValueOf(Nested{}), nil - } - val := Nested{ - B: v[:5], - C: v[5:], - } - return reflect.ValueOf(val), nil - }) - - var req TestBind - err := Bind(r, &req, nil) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - assert.DeepEqual(t, "hertz", req.A.B) - assert.DeepEqual(t, "binding", req.A.C) -} - -func TestMustRegValidateFunc(t *testing.T) { - type TestValidate struct { - A string `query:"a" vd:"test($)"` - } - - r := protocol.NewRequest("GET", "/foo", nil) - r.SetRequestURI("/foo/bar?a=123") - - MustRegValidateFunc("test", func(args ...interface{}) error { - if len(args) != 1 { - return fmt.Errorf("the args must be one") - } - s, _ := args[0].(string) - if s == "123" { - return fmt.Errorf("the args can not be 123") - } - return nil - }) - - var req TestValidate - err := Bind(r, &req, nil) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - err = Validate(&req) - if err == nil { - t.Fatalf("unexpected nil, expected an error") - } -} - -func TestQueryAlias(t *testing.T) { - type MyInt int - type MyString string - type MyIntSlice []int - type MyStringSlice []string - type Test struct { - A []MyInt `query:"a"` - B MyIntSlice `query:"b"` - C MyString `query:"c"` - D MyStringSlice `query:"d"` - } - - r := protocol.NewRequest("GET", "/foo", nil) - r.SetRequestURI("/foo/bar?a=1&a=2&b=2&b=3&c=string1&d=string2&d=string3") - - var req Test - err := Bind(r, &req, nil) - if err != nil { - t.Fatalf("unexpected error: %v", err) - return - } - assert.DeepEqual(t, 2, len(req.A)) - assert.DeepEqual(t, 1, int(req.A[0])) - assert.DeepEqual(t, 2, int(req.A[1])) - assert.DeepEqual(t, 2, len(req.B)) - assert.DeepEqual(t, 2, req.B[0]) - assert.DeepEqual(t, 3, req.B[1]) - assert.DeepEqual(t, "string1", string(req.C)) - assert.DeepEqual(t, 2, len(req.D)) - assert.DeepEqual(t, "string2", req.D[0]) - assert.DeepEqual(t, "string3", req.D[1]) -} diff --git a/pkg/app/server/binding/config.go b/pkg/app/server/binding/config.go new file mode 100644 index 000000000..c122c54c6 --- /dev/null +++ b/pkg/app/server/binding/config.go @@ -0,0 +1,170 @@ +/* + * Copyright 2023 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package binding + +import ( + stdJson "encoding/json" + "fmt" + "reflect" + "time" + + "github.com/bytedance/go-tagexpr/v2/validator" + inDecoder "github.com/cloudwego/hertz/pkg/app/server/binding/internal/decoder" + hJson "github.com/cloudwego/hertz/pkg/common/json" + "github.com/cloudwego/hertz/pkg/protocol" + "github.com/cloudwego/hertz/pkg/route/param" +) + +// BindConfig contains options for default bind behavior. +type BindConfig struct { + // LooseZeroMode if set to true, + // the empty string request parameter is bound to the zero value of parameter. + // NOTE: + // The default is false. + // Suitable for these parameter types: query/header/cookie/form . + LooseZeroMode bool + // DisableDefaultTag is used to add default tags to a field when it has no tag + // If is false, the field with no tag will be added default tags, for more automated binding. But there may be additional overhead. + // NOTE: + // The default is false. + DisableDefaultTag bool + // DisableStructFieldResolve is used to generate a separate decoder for a struct. + // If is false, the 'struct' field will get a single inDecoder.structTypeFieldTextDecoder, and use json.Unmarshal for decode it. + // It usually used to add json string to query parameter. + // NOTE: + // The default is false. + DisableStructFieldResolve bool + // EnableDecoderUseNumber is used to call the UseNumber method on the JSON + // Decoder instance. UseNumber causes the Decoder to unmarshal a number into an + // interface{} as a Number instead of as a float64. + // NOTE: + // The default is false. + // It is used for BindJSON(). + EnableDecoderUseNumber bool + // EnableDecoderDisallowUnknownFields is used to call the DisallowUnknownFields method + // on the JSON Decoder instance. DisallowUnknownFields causes the Decoder to + // return an error when the destination is a struct and the input contains object + // keys which do not match any non-ignored, exported fields in the destination. + // NOTE: + // The default is false. + // It is used for BindJSON(). + EnableDecoderDisallowUnknownFields bool + // ValidateTag is used to determine if a filed needs to be validated. + // NOTE: + // The default is "vd". + ValidateTag string + // TypeUnmarshalFuncs registers customized type unmarshaler. + // NOTE: + // time.Time is registered by default + TypeUnmarshalFuncs map[reflect.Type]inDecoder.CustomizeDecodeFunc + // Validator is used to validate for BindAndValidate() + Validator StructValidator +} + +func NewBindConfig() *BindConfig { + return &BindConfig{ + LooseZeroMode: false, + DisableDefaultTag: false, + DisableStructFieldResolve: false, + EnableDecoderUseNumber: false, + EnableDecoderDisallowUnknownFields: false, + ValidateTag: "vd", + TypeUnmarshalFuncs: make(map[reflect.Type]inDecoder.CustomizeDecodeFunc), + Validator: defaultValidate, + } +} + +// RegTypeUnmarshal registers customized type unmarshaler. +func (config *BindConfig) RegTypeUnmarshal(t reflect.Type, fn inDecoder.CustomizeDecodeFunc) error { + // check + switch t.Kind() { + case reflect.String, reflect.Bool, + reflect.Float32, reflect.Float64, + reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8, + reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8: + return fmt.Errorf("registration type cannot be a basic type") + case reflect.Ptr: + return fmt.Errorf("registration type cannot be a pointer type") + } + if config.TypeUnmarshalFuncs == nil { + config.TypeUnmarshalFuncs = make(map[reflect.Type]inDecoder.CustomizeDecodeFunc) + } + config.TypeUnmarshalFuncs[t] = fn + return nil +} + +// MustRegTypeUnmarshal registers customized type unmarshaler. It will panic if exist error. +func (config *BindConfig) MustRegTypeUnmarshal(t reflect.Type, fn func(req *protocol.Request, params param.Params, text string) (reflect.Value, error)) { + err := config.RegTypeUnmarshal(t, fn) + if err != nil { + panic(err) + } +} + +func (config *BindConfig) initTypeUnmarshal() { + config.MustRegTypeUnmarshal(reflect.TypeOf(time.Time{}), func(req *protocol.Request, params param.Params, text string) (reflect.Value, error) { + if text == "" { + return reflect.ValueOf(time.Time{}), nil + } + t, err := time.Parse(time.RFC3339, text) + if err != nil { + return reflect.Value{}, err + } + return reflect.ValueOf(t), nil + }) +} + +// UseThirdPartyJSONUnmarshaler uses third-party json library for binding +// NOTE: +// +// UseThirdPartyJSONUnmarshaler will remain in effect once it has been called. +func (config *BindConfig) UseThirdPartyJSONUnmarshaler(fn func(data []byte, v interface{}) error) { + hJson.Unmarshal = fn +} + +// UseStdJSONUnmarshaler uses encoding/json as json library +// NOTE: +// +// The current version uses encoding/json by default. +// UseStdJSONUnmarshaler will remain in effect once it has been called. +func (config *BindConfig) UseStdJSONUnmarshaler() { + config.UseThirdPartyJSONUnmarshaler(stdJson.Unmarshal) +} + +type ValidateConfig struct{} + +func NewValidateConfig() *ValidateConfig { + return &ValidateConfig{} +} + +// MustRegValidateFunc registers validator function expression. +// NOTE: +// +// If force=true, allow to cover the existed same funcName. +// MustRegValidateFunc will remain in effect once it has been called. +func (config *ValidateConfig) MustRegValidateFunc(funcName string, fn func(args ...interface{}) error, force ...bool) { + validator.MustRegFunc(funcName, fn, force...) +} + +// SetValidatorErrorFactory customizes the factory of validation error. +func (config *ValidateConfig) SetValidatorErrorFactory(validatingErrFactory func(failField, msg string) error) { + if val, ok := DefaultValidator().(*defaultValidator); ok { + val.validate.SetErrorFactory(validatingErrFactory) + } else { + panic("customized validator can not use 'SetValidatorErrorFactory'") + } +} diff --git a/pkg/app/server/binding/default.go b/pkg/app/server/binding/default.go new file mode 100644 index 000000000..28bbc5311 --- /dev/null +++ b/pkg/app/server/binding/default.go @@ -0,0 +1,410 @@ +/* + * Copyright 2023 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * The MIT License + * + * Copyright (c) 2019-present Fenny and Contributors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + * Copyright (c) 2014 Manuel Martínez-Almeida + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + * This file may have been modified by CloudWeGo authors. All CloudWeGo + * Modifications are Copyright 2023 CloudWeGo Authors + */ + +package binding + +import ( + "bytes" + stdJson "encoding/json" + "fmt" + "io" + "net/url" + "reflect" + "sync" + + "github.com/bytedance/go-tagexpr/v2/validator" + "github.com/cloudwego/hertz/internal/bytesconv" + inDecoder "github.com/cloudwego/hertz/pkg/app/server/binding/internal/decoder" + hJson "github.com/cloudwego/hertz/pkg/common/json" + "github.com/cloudwego/hertz/pkg/common/utils" + "github.com/cloudwego/hertz/pkg/protocol" + "github.com/cloudwego/hertz/pkg/protocol/consts" + "github.com/cloudwego/hertz/pkg/route/param" + "google.golang.org/protobuf/proto" +) + +const ( + queryTag = "query" + headerTag = "header" + formTag = "form" + pathTag = "path" +) + +type decoderInfo struct { + decoder inDecoder.Decoder + needValidate bool +} + +var defaultBind = NewDefaultBinder(nil) + +func DefaultBinder() Binder { + return defaultBind +} + +type defaultBinder struct { + config *BindConfig + decoderCache sync.Map + queryDecoderCache sync.Map + formDecoderCache sync.Map + headerDecoderCache sync.Map + pathDecoderCache sync.Map +} + +func NewDefaultBinder(config *BindConfig) Binder { + if config == nil { + bindConfig := NewBindConfig() + bindConfig.initTypeUnmarshal() + return &defaultBinder{ + config: bindConfig, + } + } + config.initTypeUnmarshal() + if config.Validator == nil { + config.Validator = DefaultValidator() + } + return &defaultBinder{ + config: config, + } +} + +// BindAndValidate binds data from *protocol.Request to obj and validates them if needed. +// NOTE: +// +// obj should be a pointer. +func BindAndValidate(req *protocol.Request, obj interface{}, pathParams param.Params) error { + return DefaultBinder().BindAndValidate(req, obj, pathParams) +} + +// Bind binds data from *protocol.Request to obj. +// NOTE: +// +// obj should be a pointer. +func Bind(req *protocol.Request, obj interface{}, pathParams param.Params) error { + return DefaultBinder().Bind(req, obj, pathParams) +} + +// Validate validates obj with "vd" tag +// NOTE: +// +// obj should be a pointer. +// Validate should be called after Bind. +func Validate(obj interface{}) error { + return DefaultValidator().ValidateStruct(obj) +} + +func (b *defaultBinder) tagCache(tag string) *sync.Map { + switch tag { + case queryTag: + return &b.queryDecoderCache + case headerTag: + return &b.headerDecoderCache + case formTag: + return &b.formDecoderCache + case pathTag: + return &b.pathDecoderCache + default: + return &b.decoderCache + } +} + +func (b *defaultBinder) bindTag(req *protocol.Request, v interface{}, params param.Params, tag string) error { + rv, typeID := valueAndTypeID(v) + if err := checkPointer(rv); err != nil { + return err + } + rt := dereferPointer(rv) + if rt.Kind() != reflect.Struct { + return b.bindNonStruct(req, v) + } + + err := b.preBindBody(req, v) + if err != nil { + return fmt.Errorf("bind body failed, err=%v", err) + } + cache := b.tagCache(tag) + cached, ok := cache.Load(typeID) + if ok { + // cached fieldDecoder, fast path + decoder := cached.(decoderInfo) + return decoder.decoder(req, params, rv.Elem()) + } + + decodeConfig := &inDecoder.DecodeConfig{ + LooseZeroMode: b.config.LooseZeroMode, + DisableDefaultTag: b.config.DisableDefaultTag, + DisableStructFieldResolve: b.config.DisableStructFieldResolve, + EnableDecoderUseNumber: b.config.EnableDecoderUseNumber, + EnableDecoderDisallowUnknownFields: b.config.EnableDecoderDisallowUnknownFields, + ValidateTag: b.config.ValidateTag, + TypeUnmarshalFuncs: b.config.TypeUnmarshalFuncs, + } + decoder, needValidate, err := inDecoder.GetReqDecoder(rv.Type(), tag, decodeConfig) + if err != nil { + return err + } + + cache.Store(typeID, decoderInfo{decoder: decoder, needValidate: needValidate}) + return decoder(req, params, rv.Elem()) +} + +func (b *defaultBinder) bindTagWithValidate(req *protocol.Request, v interface{}, params param.Params, tag string) error { + rv, typeID := valueAndTypeID(v) + if err := checkPointer(rv); err != nil { + return err + } + rt := dereferPointer(rv) + if rt.Kind() != reflect.Struct { + return b.bindNonStruct(req, v) + } + + err := b.preBindBody(req, v) + if err != nil { + return fmt.Errorf("bind body failed, err=%v", err) + } + cache := b.tagCache(tag) + cached, ok := cache.Load(typeID) + if ok { + // cached fieldDecoder, fast path + decoder := cached.(decoderInfo) + err = decoder.decoder(req, params, rv.Elem()) + if err != nil { + return err + } + if decoder.needValidate { + err = b.config.Validator.ValidateStruct(rv.Elem()) + } + return err + } + decodeConfig := &inDecoder.DecodeConfig{ + LooseZeroMode: b.config.LooseZeroMode, + DisableDefaultTag: b.config.DisableDefaultTag, + DisableStructFieldResolve: b.config.DisableStructFieldResolve, + EnableDecoderUseNumber: b.config.EnableDecoderUseNumber, + EnableDecoderDisallowUnknownFields: b.config.EnableDecoderDisallowUnknownFields, + ValidateTag: b.config.ValidateTag, + TypeUnmarshalFuncs: b.config.TypeUnmarshalFuncs, + } + decoder, needValidate, err := inDecoder.GetReqDecoder(rv.Type(), tag, decodeConfig) + if err != nil { + return err + } + + cache.Store(typeID, decoderInfo{decoder: decoder, needValidate: needValidate}) + err = decoder(req, params, rv.Elem()) + if err != nil { + return err + } + if needValidate { + err = b.config.Validator.ValidateStruct(rv.Elem()) + } + return err +} + +func (b *defaultBinder) BindQuery(req *protocol.Request, v interface{}) error { + return b.bindTag(req, v, nil, queryTag) +} + +func (b *defaultBinder) BindHeader(req *protocol.Request, v interface{}) error { + return b.bindTag(req, v, nil, headerTag) +} + +func (b *defaultBinder) BindPath(req *protocol.Request, v interface{}, params param.Params) error { + return b.bindTag(req, v, params, pathTag) +} + +func (b *defaultBinder) BindForm(req *protocol.Request, v interface{}) error { + return b.bindTag(req, v, nil, formTag) +} + +func (b *defaultBinder) BindJSON(req *protocol.Request, v interface{}) error { + return b.decodeJSON(bytes.NewReader(req.Body()), v) +} + +func (b *defaultBinder) decodeJSON(r io.Reader, obj interface{}) error { + decoder := hJson.NewDecoder(r) + if b.config.EnableDecoderUseNumber { + decoder.UseNumber() + } + if b.config.EnableDecoderDisallowUnknownFields { + decoder.DisallowUnknownFields() + } + + return decoder.Decode(obj) +} + +func (b *defaultBinder) BindProtobuf(req *protocol.Request, v interface{}) error { + msg, ok := v.(proto.Message) + if !ok { + return fmt.Errorf("%s does not implement 'proto.Message'", v) + } + return proto.Unmarshal(req.Body(), msg) +} + +func (b *defaultBinder) Name() string { + return "hertz" +} + +func (b *defaultBinder) BindAndValidate(req *protocol.Request, v interface{}, params param.Params) error { + return b.bindTagWithValidate(req, v, params, "") +} + +func (b *defaultBinder) Bind(req *protocol.Request, v interface{}, params param.Params) error { + return b.bindTag(req, v, params, "") +} + +// best effort binding +func (b *defaultBinder) preBindBody(req *protocol.Request, v interface{}) error { + if req.Header.ContentLength() <= 0 { + return nil + } + ct := bytesconv.B2s(req.Header.ContentType()) + switch utils.FilterContentType(ct) { + case consts.MIMEApplicationJSON: + return hJson.Unmarshal(req.Body(), v) + case consts.MIMEPROTOBUF: + msg, ok := v.(proto.Message) + if !ok { + return fmt.Errorf("%s can not implement 'proto.Message'", v) + } + return proto.Unmarshal(req.Body(), msg) + default: + return nil + } +} + +func (b *defaultBinder) bindNonStruct(req *protocol.Request, v interface{}) (err error) { + ct := bytesconv.B2s(req.Header.ContentType()) + switch utils.FilterContentType(ct) { + case consts.MIMEApplicationJSON: + err = hJson.Unmarshal(req.Body(), v) + case consts.MIMEPROTOBUF: + msg, ok := v.(proto.Message) + if !ok { + return fmt.Errorf("%s can not implement 'proto.Message'", v) + } + err = proto.Unmarshal(req.Body(), msg) + case consts.MIMEMultipartPOSTForm: + form := make(url.Values) + mf, err1 := req.MultipartForm() + if err1 == nil && mf.Value != nil { + for k, v := range mf.Value { + for _, vv := range v { + form.Add(k, vv) + } + } + } + b, _ := stdJson.Marshal(form) + err = hJson.Unmarshal(b, v) + case consts.MIMEApplicationHTMLForm: + form := make(url.Values) + req.PostArgs().VisitAll(func(formKey, value []byte) { + form.Add(string(formKey), string(value)) + }) + b, _ := stdJson.Marshal(form) + err = hJson.Unmarshal(b, v) + default: + // using query to decode + query := make(url.Values) + req.URI().QueryArgs().VisitAll(func(queryKey, value []byte) { + query.Add(string(queryKey), string(value)) + }) + b, _ := stdJson.Marshal(query) + err = hJson.Unmarshal(b, v) + } + return +} + +var _ StructValidator = (*defaultValidator)(nil) + +type defaultValidator struct { + once sync.Once + validate *validator.Validator +} + +func NewDefaultValidator(config *ValidateConfig) StructValidator { + return &defaultValidator{} +} + +// ValidateStruct receives any kind of type, but only performed struct or pointer to struct type. +func (v *defaultValidator) ValidateStruct(obj interface{}) error { + if obj == nil { + return nil + } + v.lazyinit() + return v.validate.Validate(obj) +} + +func (v *defaultValidator) lazyinit() { + v.once.Do(func() { + v.validate = validator.Default() + }) +} + +// Engine returns the underlying validator +func (v *defaultValidator) Engine() interface{} { + v.lazyinit() + return v.validate +} + +var defaultValidate = NewDefaultValidator(nil) + +func DefaultValidator() StructValidator { + return defaultValidate +} diff --git a/pkg/app/server/binding/internal/decoder/base_type_decoder.go b/pkg/app/server/binding/internal/decoder/base_type_decoder.go new file mode 100644 index 000000000..ece04f737 --- /dev/null +++ b/pkg/app/server/binding/internal/decoder/base_type_decoder.go @@ -0,0 +1,181 @@ +/* + * Copyright 2023 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * MIT License + * + * Copyright (c) 2019-present Fenny and Contributors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + * This file may have been modified by CloudWeGo authors. All CloudWeGo + * Modifications are Copyright 2023 CloudWeGo Authors + */ + +package decoder + +import ( + "fmt" + "reflect" + + "github.com/cloudwego/hertz/pkg/protocol" + "github.com/cloudwego/hertz/pkg/route/param" +) + +type fieldInfo struct { + index int + parentIndex []int + fieldName string + tagInfos []TagInfo + fieldType reflect.Type + config *DecodeConfig +} + +type baseTypeFieldTextDecoder struct { + fieldInfo + decoder TextDecoder +} + +func (d *baseTypeFieldTextDecoder) Decode(req *protocol.Request, params param.Params, reqValue reflect.Value) error { + var err error + var text string + var exist bool + var defaultValue string + for _, tagInfo := range d.tagInfos { + if tagInfo.Skip || tagInfo.Key == jsonTag || tagInfo.Key == fileNameTag { + defaultValue = tagInfo.Default + if tagInfo.Key == jsonTag { + found := checkRequireJSON(req, tagInfo) + if found { + err = nil + } else { + err = fmt.Errorf("'%s' field is a 'required' parameter, but the request body does not have this parameter '%s'", d.fieldName, tagInfo.JSONName) + } + } + continue + } + text, exist = tagInfo.Getter(req, params, tagInfo.Value) + defaultValue = tagInfo.Default + if exist { + err = nil + break + } + if tagInfo.Required { + err = fmt.Errorf("'%s' field is a 'required' parameter, but the request does not have this parameter", d.fieldName) + } + } + if err != nil { + return err + } + if len(text) == 0 && len(defaultValue) != 0 { + text = defaultValue + } + if !exist && len(text) == 0 { + return nil + } + + // get the non-nil value for the parent field + reqValue = GetFieldValue(reqValue, d.parentIndex) + field := reqValue.Field(d.index) + if field.Kind() == reflect.Ptr { + t := field.Type() + var ptrDepth int + for t.Kind() == reflect.Ptr { + t = t.Elem() + ptrDepth++ + } + var vv reflect.Value + vv, err := stringToValue(t, text, req, params, d.config) + if err != nil { + return err + } + field.Set(ReferenceValue(vv, ptrDepth)) + return nil + } + + // Non-pointer elems + err = d.decoder.UnmarshalString(text, field, d.config.LooseZeroMode) + if err != nil { + return fmt.Errorf("unable to decode '%s' as %s: %w", text, d.fieldType.Name(), err) + } + + return nil +} + +func getBaseTypeTextDecoder(field reflect.StructField, index int, tagInfos []TagInfo, parentIdx []int, config *DecodeConfig) ([]fieldDecoder, error) { + for idx, tagInfo := range tagInfos { + switch tagInfo.Key { + case pathTag: + tagInfos[idx].SliceGetter = pathSlice + tagInfos[idx].Getter = path + case formTag: + tagInfos[idx].SliceGetter = postFormSlice + tagInfos[idx].Getter = postForm + case queryTag: + tagInfos[idx].SliceGetter = querySlice + tagInfos[idx].Getter = query + case cookieTag: + tagInfos[idx].SliceGetter = cookieSlice + tagInfos[idx].Getter = cookie + case headerTag: + tagInfos[idx].SliceGetter = headerSlice + tagInfos[idx].Getter = header + case jsonTag: + // do nothing + case rawBodyTag: + tagInfos[idx].SliceGetter = rawBodySlice + tagInfos[idx].Getter = rawBody + case fileNameTag: + // do nothing + default: + } + } + + fieldType := field.Type + for field.Type.Kind() == reflect.Ptr { + fieldType = field.Type.Elem() + } + + textDecoder, err := SelectTextDecoder(fieldType) + if err != nil { + return nil, err + } + + return []fieldDecoder{&baseTypeFieldTextDecoder{ + fieldInfo: fieldInfo{ + index: index, + parentIndex: parentIdx, + fieldName: field.Name, + tagInfos: tagInfos, + fieldType: fieldType, + config: config, + }, + decoder: textDecoder, + }}, nil +} diff --git a/pkg/app/server/binding/internal/decoder/customized_type_decoder.go b/pkg/app/server/binding/internal/decoder/customized_type_decoder.go new file mode 100644 index 000000000..8bf0f0121 --- /dev/null +++ b/pkg/app/server/binding/internal/decoder/customized_type_decoder.go @@ -0,0 +1,141 @@ +/* + * Copyright 2023 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * MIT License + * + * Copyright (c) 2019-present Fenny and Contributors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + * This file may have been modified by CloudWeGo authors. All CloudWeGo + * Modifications are Copyright 2023 CloudWeGo Authors + */ + +package decoder + +import ( + "reflect" + + "github.com/cloudwego/hertz/pkg/protocol" + "github.com/cloudwego/hertz/pkg/route/param" +) + +type CustomizeDecodeFunc func(req *protocol.Request, params param.Params, text string) (reflect.Value, error) + +type customizedFieldTextDecoder struct { + fieldInfo + decodeFunc CustomizeDecodeFunc +} + +func (d *customizedFieldTextDecoder) Decode(req *protocol.Request, params param.Params, reqValue reflect.Value) error { + var text string + var exist bool + var defaultValue string + for _, tagInfo := range d.tagInfos { + if tagInfo.Skip || tagInfo.Key == jsonTag || tagInfo.Key == fileNameTag { + defaultValue = tagInfo.Default + continue + } + text, exist = tagInfo.Getter(req, params, tagInfo.Value) + defaultValue = tagInfo.Default + if exist { + break + } + } + if len(text) == 0 && len(defaultValue) != 0 { + text = defaultValue + } + + v, err := d.decodeFunc(req, params, text) + if err != nil { + return err + } + + reqValue = GetFieldValue(reqValue, d.parentIndex) + field := reqValue.Field(d.index) + if field.Kind() == reflect.Ptr { + t := field.Type() + var ptrDepth int + for t.Kind() == reflect.Ptr { + t = t.Elem() + ptrDepth++ + } + field.Set(ReferenceValue(v, ptrDepth)) + return nil + } + + field.Set(v) + return nil +} + +func getCustomizedFieldDecoder(field reflect.StructField, index int, tagInfos []TagInfo, parentIdx []int, decodeFunc CustomizeDecodeFunc, config *DecodeConfig) ([]fieldDecoder, error) { + for idx, tagInfo := range tagInfos { + switch tagInfo.Key { + case pathTag: + tagInfos[idx].SliceGetter = pathSlice + tagInfos[idx].Getter = path + case formTag: + tagInfos[idx].SliceGetter = postFormSlice + tagInfos[idx].Getter = postForm + case queryTag: + tagInfos[idx].SliceGetter = querySlice + tagInfos[idx].Getter = query + case cookieTag: + tagInfos[idx].SliceGetter = cookieSlice + tagInfos[idx].Getter = cookie + case headerTag: + tagInfos[idx].SliceGetter = headerSlice + tagInfos[idx].Getter = header + case jsonTag: + // do nothing + case rawBodyTag: + tagInfos[idx].SliceGetter = rawBodySlice + tagInfos[idx].Getter = rawBody + case fileNameTag: + // do nothing + default: + } + } + fieldType := field.Type + for field.Type.Kind() == reflect.Ptr { + fieldType = field.Type.Elem() + } + return []fieldDecoder{&customizedFieldTextDecoder{ + fieldInfo: fieldInfo{ + index: index, + parentIndex: parentIdx, + fieldName: field.Name, + tagInfos: tagInfos, + fieldType: fieldType, + config: config, + }, + decodeFunc: decodeFunc, + }}, nil +} diff --git a/pkg/app/server/binding/internal/decoder/decoder.go b/pkg/app/server/binding/internal/decoder/decoder.go new file mode 100644 index 000000000..0bd13442a --- /dev/null +++ b/pkg/app/server/binding/internal/decoder/decoder.go @@ -0,0 +1,191 @@ +/* + * Copyright 2023 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * MIT License + * + * Copyright (c) 2019-present Fenny and Contributors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + * This file may have been modified by CloudWeGo authors. All CloudWeGo + * Modifications are Copyright 2023 CloudWeGo Authors + */ + +package decoder + +import ( + "fmt" + "mime/multipart" + "reflect" + + "github.com/cloudwego/hertz/pkg/protocol" + "github.com/cloudwego/hertz/pkg/route/param" +) + +type fieldDecoder interface { + Decode(req *protocol.Request, params param.Params, reqValue reflect.Value) error +} + +type Decoder func(req *protocol.Request, params param.Params, rv reflect.Value) error + +type DecodeConfig struct { + LooseZeroMode bool + DisableDefaultTag bool + DisableStructFieldResolve bool + EnableDecoderUseNumber bool + EnableDecoderDisallowUnknownFields bool + ValidateTag string + TypeUnmarshalFuncs map[reflect.Type]CustomizeDecodeFunc +} + +func GetReqDecoder(rt reflect.Type, byTag string, config *DecodeConfig) (Decoder, bool, error) { + var decoders []fieldDecoder + var needValidate bool + + el := rt.Elem() + if el.Kind() != reflect.Struct { + return nil, false, fmt.Errorf("unsupported \"%s\" type binding", rt.String()) + } + + for i := 0; i < el.NumField(); i++ { + if el.Field(i).PkgPath != "" && !el.Field(i).Anonymous { + // ignore unexported field + continue + } + + dec, needValidate2, err := getFieldDecoder(el.Field(i), i, []int{}, "", byTag, config) + if err != nil { + return nil, false, err + } + needValidate = needValidate || needValidate2 + + if dec != nil { + decoders = append(decoders, dec...) + } + } + + return func(req *protocol.Request, params param.Params, rv reflect.Value) error { + for _, decoder := range decoders { + err := decoder.Decode(req, params, rv) + if err != nil { + return err + } + } + + return nil + }, needValidate, nil +} + +func getFieldDecoder(field reflect.StructField, index int, parentIdx []int, parentJSONName string, byTag string, config *DecodeConfig) ([]fieldDecoder, bool, error) { + for field.Type.Kind() == reflect.Ptr { + field.Type = field.Type.Elem() + } + // skip anonymous definitions, like: + // type A struct { + // string + // } + if field.Type.Kind() != reflect.Struct && field.Anonymous { + return nil, false, nil + } + + // JSONName is like 'a.b.c' for 'required validate' + fieldTagInfos, newParentJSONName, needValidate := lookupFieldTags(field, parentJSONName, config) + if len(fieldTagInfos) == 0 && !config.DisableDefaultTag { + fieldTagInfos = getDefaultFieldTags(field) + } + if len(byTag) != 0 { + fieldTagInfos = getFieldTagInfoByTag(field, byTag) + } + + // customized type decoder has the highest priority + if customizedFunc, exist := config.TypeUnmarshalFuncs[field.Type]; exist { + dec, err := getCustomizedFieldDecoder(field, index, fieldTagInfos, parentIdx, customizedFunc, config) + return dec, needValidate, err + } + + // slice/array field decoder + if field.Type.Kind() == reflect.Slice || field.Type.Kind() == reflect.Array { + dec, err := getSliceFieldDecoder(field, index, fieldTagInfos, parentIdx, config) + return dec, needValidate, err + } + + // map filed decoder + if field.Type.Kind() == reflect.Map { + dec, err := getMapTypeTextDecoder(field, index, fieldTagInfos, parentIdx, config) + return dec, needValidate, err + } + + // struct field will be resolved recursively + if field.Type.Kind() == reflect.Struct { + var decoders []fieldDecoder + el := field.Type + // todo: more built-in common struct binding, ex. time... + switch el { + case reflect.TypeOf(multipart.FileHeader{}): // file binding + dec, err := getMultipartFileDecoder(field, index, fieldTagInfos, parentIdx, config) + return dec, needValidate, err + } + if !config.DisableStructFieldResolve { // decode struct type separately + structFieldDecoder, err := getStructTypeFieldDecoder(field, index, fieldTagInfos, parentIdx, config) + if err != nil { + return nil, needValidate, err + } + if structFieldDecoder != nil { + decoders = append(decoders, structFieldDecoder...) + } + } + + for i := 0; i < el.NumField(); i++ { + if el.Field(i).PkgPath != "" && !el.Field(i).Anonymous { + // ignore unexported field + continue + } + var idxes []int + if len(parentIdx) > 0 { + idxes = append(idxes, parentIdx...) + } + idxes = append(idxes, index) + dec, needValidate2, err := getFieldDecoder(el.Field(i), i, idxes, newParentJSONName, byTag, config) + needValidate = needValidate || needValidate2 + if err != nil { + return nil, false, err + } + if dec != nil { + decoders = append(decoders, dec...) + } + } + + return decoders, needValidate, nil + } + + // base type decoder + dec, err := getBaseTypeTextDecoder(field, index, fieldTagInfos, parentIdx, config) + return dec, needValidate, err +} diff --git a/pkg/app/server/binding/internal/decoder/getter.go b/pkg/app/server/binding/internal/decoder/getter.go new file mode 100644 index 000000000..81f8202c8 --- /dev/null +++ b/pkg/app/server/binding/internal/decoder/getter.go @@ -0,0 +1,134 @@ +/* + * Copyright 2023 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * MIT License + * + * Copyright (c) 2019-present Fenny and Contributors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + * This file may have been modified by CloudWeGo authors. All CloudWeGo + * Modifications are Copyright 2023 CloudWeGo Authors + */ + +package decoder + +import ( + "github.com/cloudwego/hertz/pkg/protocol" + "github.com/cloudwego/hertz/pkg/route/param" +) + +type getter func(req *protocol.Request, params param.Params, key string, defaultValue ...string) (ret string, exist bool) + +func path(req *protocol.Request, params param.Params, key string, defaultValue ...string) (ret string, exist bool) { + if params != nil { + ret, exist = params.Get(key) + } + + if len(ret) == 0 && len(defaultValue) != 0 { + ret = defaultValue[0] + } + return ret, exist +} + +func postForm(req *protocol.Request, params param.Params, key string, defaultValue ...string) (ret string, exist bool) { + if ret, exist = req.PostArgs().PeekExists(key); exist { + return + } + + mf, err := req.MultipartForm() + if err == nil && mf.Value != nil { + for k, v := range mf.Value { + if k == key && len(v) > 0 { + ret = v[0] + } + } + } + + if len(ret) != 0 { + return ret, true + } + if ret, exist = req.URI().QueryArgs().PeekExists(key); exist { + return + } + + if len(ret) == 0 && len(defaultValue) != 0 { + ret = defaultValue[0] + } + + return ret, false +} + +func query(req *protocol.Request, params param.Params, key string, defaultValue ...string) (ret string, exist bool) { + if ret, exist = req.URI().QueryArgs().PeekExists(key); exist { + return + } + + if len(ret) == 0 && len(defaultValue) != 0 { + ret = defaultValue[0] + } + + return +} + +func cookie(req *protocol.Request, params param.Params, key string, defaultValue ...string) (ret string, exist bool) { + if val := req.Header.Cookie(key); val != nil { + ret = string(val) + return ret, true + } + + if len(ret) == 0 && len(defaultValue) != 0 { + ret = defaultValue[0] + } + + return ret, false +} + +func header(req *protocol.Request, params param.Params, key string, defaultValue ...string) (ret string, exist bool) { + if val := req.Header.Peek(key); val != nil { + ret = string(val) + return ret, true + } + + if len(ret) == 0 && len(defaultValue) != 0 { + ret = defaultValue[0] + } + + return ret, false +} + +func rawBody(req *protocol.Request, params param.Params, key string, defaultValue ...string) (ret string, exist bool) { + exist = false + if req.Header.ContentLength() > 0 { + ret = string(req.Body()) + exist = true + } + return +} diff --git a/pkg/app/server/binding/internal/decoder/gjson_required.go b/pkg/app/server/binding/internal/decoder/gjson_required.go new file mode 100644 index 000000000..88697e0f3 --- /dev/null +++ b/pkg/app/server/binding/internal/decoder/gjson_required.go @@ -0,0 +1,49 @@ +// Copyright 2023 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +//go:build gjson || !(amd64 && (linux || windows || darwin)) +// +build gjson !amd64 !linux,!windows,!darwin + +package decoder + +import ( + "strings" + + "github.com/cloudwego/hertz/internal/bytesconv" + "github.com/cloudwego/hertz/pkg/common/utils" + "github.com/cloudwego/hertz/pkg/protocol" + "github.com/cloudwego/hertz/pkg/protocol/consts" + "github.com/tidwall/gjson" +) + +func checkRequireJSON(req *protocol.Request, tagInfo TagInfo) bool { + if !tagInfo.Required { + return true + } + ct := bytesconv.B2s(req.Header.ContentType()) + if utils.FilterContentType(ct) != consts.MIMEApplicationJSON { + return false + } + result := gjson.GetBytes(req.Body(), tagInfo.JSONName) + if !result.Exists() { + idx := strings.LastIndex(tagInfo.JSONName, ".") + // There should be a superior if it is empty, it will report 'true' for required + if idx > 0 && !gjson.GetBytes(req.Body(), tagInfo.JSONName[:idx]).Exists() { + return true + } + return false + } + return true +} diff --git a/pkg/app/server/binding/internal/decoder/map_type_decoder.go b/pkg/app/server/binding/internal/decoder/map_type_decoder.go new file mode 100644 index 000000000..31fe85a1b --- /dev/null +++ b/pkg/app/server/binding/internal/decoder/map_type_decoder.go @@ -0,0 +1,165 @@ +/* + * Copyright 2023 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * MIT License + * + * Copyright (c) 2019-present Fenny and Contributors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + * This file may have been modified by CloudWeGo authors. All CloudWeGo + * Modifications are Copyright 2023 CloudWeGo Authors + */ + +package decoder + +import ( + "fmt" + "reflect" + + "github.com/cloudwego/hertz/internal/bytesconv" + hJson "github.com/cloudwego/hertz/pkg/common/json" + "github.com/cloudwego/hertz/pkg/protocol" + "github.com/cloudwego/hertz/pkg/route/param" +) + +type mapTypeFieldTextDecoder struct { + fieldInfo +} + +func (d *mapTypeFieldTextDecoder) Decode(req *protocol.Request, params param.Params, reqValue reflect.Value) error { + var err error + var text string + var exist bool + var defaultValue string + for _, tagInfo := range d.tagInfos { + if tagInfo.Skip || tagInfo.Key == jsonTag || tagInfo.Key == fileNameTag { + defaultValue = tagInfo.Default + if tagInfo.Key == jsonTag { + found := checkRequireJSON(req, tagInfo) + if found { + err = nil + } else { + err = fmt.Errorf("'%s' field is a 'required' parameter, but the request does not have this parameter", d.fieldName) + } + } + continue + } + text, exist = tagInfo.Getter(req, params, tagInfo.Value) + defaultValue = tagInfo.Default + if exist { + err = nil + break + } + if tagInfo.Required { + err = fmt.Errorf("'%s' field is a 'required' parameter, but the request does not have this parameter", d.fieldName) + } + } + if err != nil { + return err + } + if len(text) == 0 && len(defaultValue) != 0 { + text = defaultValue + } + if !exist && len(text) == 0 { + return nil + } + + reqValue = GetFieldValue(reqValue, d.parentIndex) + field := reqValue.Field(d.index) + if field.Kind() == reflect.Ptr { + t := field.Type() + var ptrDepth int + for t.Kind() == reflect.Ptr { + t = t.Elem() + ptrDepth++ + } + var vv reflect.Value + vv, err := stringToValue(t, text, req, params, d.config) + if err != nil { + return fmt.Errorf("unable to decode '%s' as %s: %w", text, d.fieldType.Name(), err) + } + field.Set(ReferenceValue(vv, ptrDepth)) + return nil + } + + err = hJson.Unmarshal(bytesconv.S2b(text), field.Addr().Interface()) + if err != nil { + return fmt.Errorf("unable to decode '%s' as %s: %w", text, d.fieldType.Name(), err) + } + + return nil +} + +func getMapTypeTextDecoder(field reflect.StructField, index int, tagInfos []TagInfo, parentIdx []int, config *DecodeConfig) ([]fieldDecoder, error) { + for idx, tagInfo := range tagInfos { + switch tagInfo.Key { + case pathTag: + tagInfos[idx].SliceGetter = pathSlice + tagInfos[idx].Getter = path + case formTag: + tagInfos[idx].SliceGetter = postFormSlice + tagInfos[idx].Getter = postForm + case queryTag: + tagInfos[idx].SliceGetter = querySlice + tagInfos[idx].Getter = query + case cookieTag: + tagInfos[idx].SliceGetter = cookieSlice + tagInfos[idx].Getter = cookie + case headerTag: + tagInfos[idx].SliceGetter = headerSlice + tagInfos[idx].Getter = header + case jsonTag: + // do nothing + case rawBodyTag: + tagInfos[idx].SliceGetter = rawBodySlice + tagInfos[idx].Getter = rawBody + case fileNameTag: + // do nothing + default: + } + } + + fieldType := field.Type + for field.Type.Kind() == reflect.Ptr { + fieldType = field.Type.Elem() + } + + return []fieldDecoder{&mapTypeFieldTextDecoder{ + fieldInfo: fieldInfo{ + index: index, + parentIndex: parentIdx, + fieldName: field.Name, + tagInfos: tagInfos, + fieldType: fieldType, + config: config, + }, + }}, nil +} diff --git a/pkg/app/server/binding/internal/decoder/multipart_file_decoder.go b/pkg/app/server/binding/internal/decoder/multipart_file_decoder.go new file mode 100644 index 000000000..ae32dfea5 --- /dev/null +++ b/pkg/app/server/binding/internal/decoder/multipart_file_decoder.go @@ -0,0 +1,165 @@ +/* + * Copyright 2023 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package decoder + +import ( + "fmt" + "reflect" + + "github.com/cloudwego/hertz/pkg/protocol" + "github.com/cloudwego/hertz/pkg/route/param" +) + +type fileTypeDecoder struct { + fieldInfo + isRepeated bool +} + +func (d *fileTypeDecoder) Decode(req *protocol.Request, params param.Params, reqValue reflect.Value) error { + fieldValue := GetFieldValue(reqValue, d.parentIndex) + field := fieldValue.Field(d.index) + + if d.isRepeated { + return d.fileSliceDecode(req, params, reqValue) + } + var fileName string + // file_name > form > fieldName + for _, tagInfo := range d.tagInfos { + if tagInfo.Key == fileNameTag { + fileName = tagInfo.Value + break + } + if tagInfo.Key == formTag { + fileName = tagInfo.Value + } + } + if len(fileName) == 0 { + fileName = d.fieldName + } + file, err := req.FormFile(fileName) + if err != nil { + return fmt.Errorf("can not get file '%s', err: %v", fileName, err) + } + if field.Kind() == reflect.Ptr { + t := field.Type() + var ptrDepth int + for t.Kind() == reflect.Ptr { + t = t.Elem() + ptrDepth++ + } + v := reflect.New(t).Elem() + v.Set(reflect.ValueOf(*file)) + field.Set(ReferenceValue(v, ptrDepth)) + return nil + } + + // Non-pointer elems + field.Set(reflect.ValueOf(*file)) + + return nil +} + +func (d *fileTypeDecoder) fileSliceDecode(req *protocol.Request, params param.Params, reqValue reflect.Value) error { + fieldValue := GetFieldValue(reqValue, d.parentIndex) + field := fieldValue.Field(d.index) + // 如果没值,需要为其建一个值 + if field.Kind() == reflect.Ptr { + if field.IsNil() { + nonNilVal, ptrDepth := GetNonNilReferenceValue(field) + field.Set(ReferenceValue(nonNilVal, ptrDepth)) + } + } + var parentPtrDepth int + for field.Kind() == reflect.Ptr { + field = field.Elem() + parentPtrDepth++ + } + + var fileName string + // file_name > form > fieldName + for _, tagInfo := range d.tagInfos { + if tagInfo.Key == fileNameTag { + fileName = tagInfo.Value + break + } + if tagInfo.Key == formTag { + fileName = tagInfo.Value + } + } + if len(fileName) == 0 { + fileName = d.fieldName + } + multipartForm, err := req.MultipartForm() + if err != nil { + return fmt.Errorf("can not get multipartForm info, err: %v", err) + } + files, exist := multipartForm.File[fileName] + if !exist { + return fmt.Errorf("the file '%s' is not existed", fileName) + } + + if field.Kind() == reflect.Array { + if len(files) != field.Len() { + return fmt.Errorf("the numbers(%d) of file '%s' does not match the length(%d) of %s", len(files), fileName, field.Len(), field.Type().String()) + } + } else { + // slice need creating enough capacity + field = reflect.MakeSlice(field.Type(), len(files), len(files)) + } + + // handle multiple pointer + var ptrDepth int + t := d.fieldType.Elem() + elemKind := t.Kind() + for elemKind == reflect.Ptr { + t = t.Elem() + elemKind = t.Kind() + ptrDepth++ + } + + for idx, file := range files { + v := reflect.New(t).Elem() + v.Set(reflect.ValueOf(*file)) + field.Index(idx).Set(ReferenceValue(v, ptrDepth)) + } + fieldValue.Field(d.index).Set(ReferenceValue(field, parentPtrDepth)) + + return nil +} + +func getMultipartFileDecoder(field reflect.StructField, index int, tagInfos []TagInfo, parentIdx []int, config *DecodeConfig) ([]fieldDecoder, error) { + fieldType := field.Type + for field.Type.Kind() == reflect.Ptr { + fieldType = field.Type.Elem() + } + isRepeated := false + if fieldType.Kind() == reflect.Array || fieldType.Kind() == reflect.Slice { + isRepeated = true + } + + return []fieldDecoder{&fileTypeDecoder{ + fieldInfo: fieldInfo{ + index: index, + parentIndex: parentIdx, + fieldName: field.Name, + tagInfos: tagInfos, + fieldType: fieldType, + config: config, + }, + isRepeated: isRepeated, + }}, nil +} diff --git a/pkg/app/server/binding/internal/decoder/reflect.go b/pkg/app/server/binding/internal/decoder/reflect.go new file mode 100644 index 000000000..8d9b115e5 --- /dev/null +++ b/pkg/app/server/binding/internal/decoder/reflect.go @@ -0,0 +1,113 @@ +/* + * Copyright 2023 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * MIT License + * + * Copyright (c) 2019-present Fenny and Contributors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + * This file may have been modified by CloudWeGo authors. All CloudWeGo + * Modifications are Copyright 2023 CloudWeGo Authors + */ + +package decoder + +import ( + "reflect" +) + +// ReferenceValue convert T to *T, the ptrDepth is the count of '*'. +func ReferenceValue(v reflect.Value, ptrDepth int) reflect.Value { + switch { + case ptrDepth > 0: + for ; ptrDepth > 0; ptrDepth-- { + vv := reflect.New(v.Type()) + vv.Elem().Set(v) + v = vv + } + case ptrDepth < 0: + for ; ptrDepth < 0 && v.Kind() == reflect.Ptr; ptrDepth++ { + v = v.Elem() + } + } + return v +} + +func GetNonNilReferenceValue(v reflect.Value) (reflect.Value, int) { + var ptrDepth int + t := v.Type() + elemKind := t.Kind() + for elemKind == reflect.Ptr { + t = t.Elem() + elemKind = t.Kind() + ptrDepth++ + } + val := reflect.New(t).Elem() + return val, ptrDepth +} + +func GetFieldValue(reqValue reflect.Value, parentIndex []int) reflect.Value { + // reqValue -> (***bar)(nil) need new a default + if reqValue.Kind() == reflect.Ptr && reqValue.IsNil() { + nonNilVal, ptrDepth := GetNonNilReferenceValue(reqValue) + reqValue = ReferenceValue(nonNilVal, ptrDepth) + } + for _, idx := range parentIndex { + if reqValue.Kind() == reflect.Ptr && reqValue.IsNil() { + nonNilVal, ptrDepth := GetNonNilReferenceValue(reqValue) + reqValue.Set(ReferenceValue(nonNilVal, ptrDepth)) + } + for reqValue.Kind() == reflect.Ptr { + reqValue = reqValue.Elem() + } + reqValue = reqValue.Field(idx) + } + + // It is possible that the parent struct is also a pointer, + // so need to create a non-nil reflect.Value for it at runtime. + for reqValue.Kind() == reflect.Ptr { + if reqValue.IsNil() { + nonNilVal, ptrDepth := GetNonNilReferenceValue(reqValue) + reqValue.Set(ReferenceValue(nonNilVal, ptrDepth)) + } + reqValue = reqValue.Elem() + } + + return reqValue +} + +func getElemType(t reflect.Type) reflect.Type { + for t.Kind() == reflect.Ptr { + t = t.Elem() + } + + return t +} diff --git a/pkg/app/server/binding/internal/decoder/slice_getter.go b/pkg/app/server/binding/internal/decoder/slice_getter.go new file mode 100644 index 000000000..27d2b4174 --- /dev/null +++ b/pkg/app/server/binding/internal/decoder/slice_getter.go @@ -0,0 +1,143 @@ +/* + * Copyright 2023 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * MIT License + * + * Copyright (c) 2019-present Fenny and Contributors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + * This file may have been modified by CloudWeGo authors. All CloudWeGo + * Modifications are Copyright 2023 CloudWeGo Authors + */ + +package decoder + +import ( + "github.com/cloudwego/hertz/internal/bytesconv" + "github.com/cloudwego/hertz/pkg/protocol" + "github.com/cloudwego/hertz/pkg/route/param" +) + +type sliceGetter func(req *protocol.Request, params param.Params, key string, defaultValue ...string) (ret []string) + +func pathSlice(req *protocol.Request, params param.Params, key string, defaultValue ...string) (ret []string) { + var value string + if params != nil { + value, _ = params.Get(key) + } + + if len(value) == 0 && len(defaultValue) != 0 { + value = defaultValue[0] + } + if len(value) != 0 { + ret = append(ret, value) + } + + return +} + +func postFormSlice(req *protocol.Request, params param.Params, key string, defaultValue ...string) (ret []string) { + req.PostArgs().VisitAll(func(formKey, value []byte) { + if bytesconv.B2s(formKey) == key { + ret = append(ret, string(value)) + } + }) + if len(ret) > 0 { + return + } + + mf, err := req.MultipartForm() + if err == nil && mf.Value != nil { + for k, v := range mf.Value { + if k == key && len(v) > 0 { + ret = append(ret, v...) + } + } + } + if len(ret) > 0 { + return + } + + if len(ret) == 0 && len(defaultValue) != 0 { + ret = append(ret, defaultValue...) + } + + return +} + +func querySlice(req *protocol.Request, params param.Params, key string, defaultValue ...string) (ret []string) { + req.URI().QueryArgs().VisitAll(func(queryKey, value []byte) { + if key == bytesconv.B2s(queryKey) { + ret = append(ret, string(value)) + } + }) + + if len(ret) == 0 && len(defaultValue) != 0 { + ret = append(ret, defaultValue...) + } + + return +} + +func cookieSlice(req *protocol.Request, params param.Params, key string, defaultValue ...string) (ret []string) { + req.Header.VisitAllCookie(func(cookieKey, value []byte) { + if bytesconv.B2s(cookieKey) == key { + ret = append(ret, string(value)) + } + }) + + if len(ret) == 0 && len(defaultValue) != 0 { + ret = append(ret, defaultValue...) + } + + return +} + +func headerSlice(req *protocol.Request, params param.Params, key string, defaultValue ...string) (ret []string) { + req.Header.VisitAll(func(headerKey, value []byte) { + if bytesconv.B2s(headerKey) == key { + ret = append(ret, string(value)) + } + }) + + if len(ret) == 0 && len(defaultValue) != 0 { + ret = append(ret, defaultValue...) + } + + return +} + +func rawBodySlice(req *protocol.Request, params param.Params, key string, defaultValue ...string) (ret []string) { + if req.Header.ContentLength() > 0 { + ret = append(ret, string(req.Body())) + } + return +} diff --git a/pkg/app/server/binding/internal/decoder/slice_type_decoder.go b/pkg/app/server/binding/internal/decoder/slice_type_decoder.go new file mode 100644 index 000000000..fc5c9814f --- /dev/null +++ b/pkg/app/server/binding/internal/decoder/slice_type_decoder.go @@ -0,0 +1,250 @@ +/* + * Copyright 2023 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * MIT License + * + * Copyright (c) 2019-present Fenny and Contributors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + * This file may have been modified by CloudWeGo authors. All CloudWeGo + * Modifications are Copyright 2023 CloudWeGo Authors + */ + +package decoder + +import ( + "fmt" + "mime/multipart" + "reflect" + + "github.com/cloudwego/hertz/internal/bytesconv" + hJson "github.com/cloudwego/hertz/pkg/common/json" + "github.com/cloudwego/hertz/pkg/protocol" + "github.com/cloudwego/hertz/pkg/route/param" +) + +type sliceTypeFieldTextDecoder struct { + fieldInfo + isArray bool +} + +func (d *sliceTypeFieldTextDecoder) Decode(req *protocol.Request, params param.Params, reqValue reflect.Value) error { + var err error + var texts []string + var defaultValue string + var bindRawBody bool + for _, tagInfo := range d.tagInfos { + if tagInfo.Skip || tagInfo.Key == jsonTag || tagInfo.Key == fileNameTag { + defaultValue = tagInfo.Default + if tagInfo.Key == jsonTag { + found := checkRequireJSON(req, tagInfo) + if found { + err = nil + } else { + err = fmt.Errorf("'%s' field is a 'required' parameter, but the request does not have this parameter", d.fieldName) + } + } + continue + } + if tagInfo.Key == rawBodyTag { + bindRawBody = true + } + texts = tagInfo.SliceGetter(req, params, tagInfo.Value) + defaultValue = tagInfo.Default + if len(texts) != 0 { + err = nil + break + } + if tagInfo.Required { + err = fmt.Errorf("'%s' field is a 'required' parameter, but the request does not have this parameter", d.fieldName) + } + } + if err != nil { + return err + } + if len(texts) == 0 && len(defaultValue) != 0 { + texts = append(texts, defaultValue) + } + if len(texts) == 0 { + return nil + } + + reqValue = GetFieldValue(reqValue, d.parentIndex) + field := reqValue.Field(d.index) + // **[]**int + if field.Kind() == reflect.Ptr { + if field.IsNil() { + nonNilVal, ptrDepth := GetNonNilReferenceValue(field) + field.Set(ReferenceValue(nonNilVal, ptrDepth)) + } + } + var parentPtrDepth int + for field.Kind() == reflect.Ptr { + field = field.Elem() + parentPtrDepth++ + } + + if d.isArray { + if len(texts) != field.Len() { + return fmt.Errorf("%q is not valid value for %s", texts, field.Type().String()) + } + } else { + // slice need creating enough capacity + field = reflect.MakeSlice(field.Type(), len(texts), len(texts)) + } + // raw_body && []byte binding + if bindRawBody && field.Type().Elem().Kind() == reflect.Uint8 { + reqValue.Field(d.index).Set(reflect.ValueOf(req.Body())) + return nil + } + + // handle internal multiple pointer, []**int + var ptrDepth int + t := d.fieldType.Elem() // d.fieldType is non-pointer type for the field + elemKind := t.Kind() + for elemKind == reflect.Ptr { + t = t.Elem() + elemKind = t.Kind() + ptrDepth++ + } + + for idx, text := range texts { + var vv reflect.Value + vv, err = stringToValue(t, text, req, params, d.config) + if err != nil { + break + } + field.Index(idx).Set(ReferenceValue(vv, ptrDepth)) + } + if err != nil { + if !reqValue.Field(d.index).CanAddr() { + return err + } + // text[0] can be a complete json content for []Type. + err = hJson.Unmarshal(bytesconv.S2b(texts[0]), reqValue.Field(d.index).Addr().Interface()) + if err != nil { + return fmt.Errorf("using '%s' to unmarshal field '%s: %s' failed, %v", texts[0], d.fieldName, d.fieldType.String(), err) + } + } else { + reqValue.Field(d.index).Set(ReferenceValue(field, parentPtrDepth)) + } + + return nil +} + +func getSliceFieldDecoder(field reflect.StructField, index int, tagInfos []TagInfo, parentIdx []int, config *DecodeConfig) ([]fieldDecoder, error) { + if !(field.Type.Kind() == reflect.Slice || field.Type.Kind() == reflect.Array) { + return nil, fmt.Errorf("unexpected type %s, expected slice or array", field.Type.String()) + } + isArray := false + if field.Type.Kind() == reflect.Array { + isArray = true + } + for idx, tagInfo := range tagInfos { + switch tagInfo.Key { + case pathTag: + tagInfos[idx].SliceGetter = pathSlice + tagInfos[idx].Getter = path + case formTag: + tagInfos[idx].SliceGetter = postFormSlice + tagInfos[idx].Getter = postForm + case queryTag: + tagInfos[idx].SliceGetter = querySlice + tagInfos[idx].Getter = query + case cookieTag: + tagInfos[idx].SliceGetter = cookieSlice + tagInfos[idx].Getter = cookie + case headerTag: + tagInfos[idx].SliceGetter = headerSlice + tagInfos[idx].Getter = header + case jsonTag: + // do nothing + case rawBodyTag: + tagInfos[idx].SliceGetter = rawBodySlice + tagInfos[idx].Getter = rawBody + case fileNameTag: + // do nothing + default: + } + } + + fieldType := field.Type + for field.Type.Kind() == reflect.Ptr { + fieldType = field.Type.Elem() + } + // fieldType.Elem() is the type for array/slice elem + t := getElemType(fieldType.Elem()) + if t == reflect.TypeOf(multipart.FileHeader{}) { + return getMultipartFileDecoder(field, index, tagInfos, parentIdx, config) + } + + return []fieldDecoder{&sliceTypeFieldTextDecoder{ + fieldInfo: fieldInfo{ + index: index, + parentIndex: parentIdx, + fieldName: field.Name, + tagInfos: tagInfos, + fieldType: fieldType, + config: config, + }, + isArray: isArray, + }}, nil +} + +func stringToValue(elemType reflect.Type, text string, req *protocol.Request, params param.Params, config *DecodeConfig) (v reflect.Value, err error) { + v = reflect.New(elemType).Elem() + if customizedFunc, exist := config.TypeUnmarshalFuncs[elemType]; exist { + val, err := customizedFunc(req, params, text) + if err != nil { + return reflect.Value{}, err + } + return val, nil + } + switch elemType.Kind() { + case reflect.Struct: + err = hJson.Unmarshal(bytesconv.S2b(text), v.Addr().Interface()) + case reflect.Map: + err = hJson.Unmarshal(bytesconv.S2b(text), v.Addr().Interface()) + case reflect.Array, reflect.Slice: + // do nothing + default: + decoder, err := SelectTextDecoder(elemType) + if err != nil { + return reflect.Value{}, fmt.Errorf("unsupported type %s for slice/array", elemType.String()) + } + err = decoder.UnmarshalString(text, v, config.LooseZeroMode) + if err != nil { + return reflect.Value{}, fmt.Errorf("unable to decode '%s' as %s: %w", text, elemType.String(), err) + } + } + + return v, err +} diff --git a/pkg/app/server/binding/internal/decoder/sonic_required.go b/pkg/app/server/binding/internal/decoder/sonic_required.go new file mode 100644 index 000000000..2aae0c3a4 --- /dev/null +++ b/pkg/app/server/binding/internal/decoder/sonic_required.go @@ -0,0 +1,62 @@ +// Copyright 2023 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +//go:build (linux || windows || darwin) && amd64 && !gjson +// +build linux windows darwin +// +build amd64 +// +build !gjson + +package decoder + +import ( + "strings" + + "github.com/bytedance/sonic" + "github.com/cloudwego/hertz/internal/bytesconv" + "github.com/cloudwego/hertz/pkg/common/utils" + "github.com/cloudwego/hertz/pkg/protocol" + "github.com/cloudwego/hertz/pkg/protocol/consts" +) + +func checkRequireJSON(req *protocol.Request, tagInfo TagInfo) bool { + if !tagInfo.Required { + return true + } + ct := bytesconv.B2s(req.Header.ContentType()) + if utils.FilterContentType(ct) != consts.MIMEApplicationJSON { + return false + } + node, _ := sonic.Get(req.Body(), stringSliceForInterface(tagInfo.JSONName)...) + if !node.Exists() { + idx := strings.LastIndex(tagInfo.JSONName, ".") + if idx > 0 { + // There should be a superior if it is empty, it will report 'true' for required + node, _ := sonic.Get(req.Body(), stringSliceForInterface(tagInfo.JSONName[:idx])...) + if !node.Exists() { + return true + } + } + return false + } + return true +} + +func stringSliceForInterface(s string) (ret []interface{}) { + x := strings.Split(s, ".") + for _, val := range x { + ret = append(ret, val) + } + return +} diff --git a/pkg/app/server/binding/internal/decoder/struct_type_decoder.go b/pkg/app/server/binding/internal/decoder/struct_type_decoder.go new file mode 100644 index 000000000..3030f2ac6 --- /dev/null +++ b/pkg/app/server/binding/internal/decoder/struct_type_decoder.go @@ -0,0 +1,142 @@ +/* + * Copyright 2023 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package decoder + +import ( + "fmt" + "reflect" + + "github.com/cloudwego/hertz/internal/bytesconv" + "github.com/cloudwego/hertz/pkg/common/hlog" + hjson "github.com/cloudwego/hertz/pkg/common/json" + "github.com/cloudwego/hertz/pkg/protocol" + "github.com/cloudwego/hertz/pkg/route/param" +) + +type structTypeFieldTextDecoder struct { + fieldInfo +} + +func (d *structTypeFieldTextDecoder) Decode(req *protocol.Request, params param.Params, reqValue reflect.Value) error { + var err error + var text string + var exist bool + var defaultValue string + for _, tagInfo := range d.tagInfos { + if tagInfo.Skip || tagInfo.Key == jsonTag || tagInfo.Key == fileNameTag { + defaultValue = tagInfo.Default + if tagInfo.Key == jsonTag { + found := checkRequireJSON(req, tagInfo) + if found { + err = nil + } else { + err = fmt.Errorf("'%s' field is a 'required' parameter, but the request does not have this parameter", d.fieldName) + } + } + continue + } + text, exist = tagInfo.Getter(req, params, tagInfo.Value) + defaultValue = tagInfo.Default + if exist { + err = nil + break + } + if tagInfo.Required { + err = fmt.Errorf("'%s' field is a 'required' parameter, but the request does not have this parameter", d.fieldName) + } + } + if err != nil { + return err + } + if len(text) == 0 && len(defaultValue) != 0 { + text = defaultValue + } + if !exist && len(text) == 0 { + return nil + } + reqValue = GetFieldValue(reqValue, d.parentIndex) + field := reqValue.Field(d.index) + if field.Kind() == reflect.Ptr { + t := field.Type() + var ptrDepth int + for t.Kind() == reflect.Ptr { + t = t.Elem() + ptrDepth++ + } + var vv reflect.Value + vv, err := stringToValue(t, text, req, params, d.config) + if err != nil { + hlog.Infof("unable to decode '%s' as %s: %v, but it may not affect correctness, so skip it", text, d.fieldType.Name(), err) + return nil + } + field.Set(ReferenceValue(vv, ptrDepth)) + return nil + } + + err = hjson.Unmarshal(bytesconv.S2b(text), field.Addr().Interface()) + if err != nil { + hlog.Infof("unable to decode '%s' as %s: %v, but it may not affect correctness, so skip it", text, d.fieldType.Name(), err) + } + + return nil +} + +func getStructTypeFieldDecoder(field reflect.StructField, index int, tagInfos []TagInfo, parentIdx []int, config *DecodeConfig) ([]fieldDecoder, error) { + for idx, tagInfo := range tagInfos { + switch tagInfo.Key { + case pathTag: + tagInfos[idx].SliceGetter = pathSlice + tagInfos[idx].Getter = path + case formTag: + tagInfos[idx].SliceGetter = postFormSlice + tagInfos[idx].Getter = postForm + case queryTag: + tagInfos[idx].SliceGetter = querySlice + tagInfos[idx].Getter = query + case cookieTag: + tagInfos[idx].SliceGetter = cookieSlice + tagInfos[idx].Getter = cookie + case headerTag: + tagInfos[idx].SliceGetter = headerSlice + tagInfos[idx].Getter = header + case jsonTag: + // do nothing + case rawBodyTag: + tagInfos[idx].SliceGetter = rawBodySlice + tagInfos[idx].Getter = rawBody + case fileNameTag: + // do nothing + default: + } + } + + fieldType := field.Type + for field.Type.Kind() == reflect.Ptr { + fieldType = field.Type.Elem() + } + + return []fieldDecoder{&structTypeFieldTextDecoder{ + fieldInfo: fieldInfo{ + index: index, + parentIndex: parentIdx, + fieldName: field.Name, + tagInfos: tagInfos, + fieldType: fieldType, + config: config, + }, + }}, nil +} diff --git a/pkg/app/server/binding/internal/decoder/tag.go b/pkg/app/server/binding/internal/decoder/tag.go new file mode 100644 index 000000000..6df09aaa3 --- /dev/null +++ b/pkg/app/server/binding/internal/decoder/tag.go @@ -0,0 +1,164 @@ +/* + * Copyright 2023 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package decoder + +import ( + "reflect" + "strings" +) + +const ( + pathTag = "path" + formTag = "form" + queryTag = "query" + cookieTag = "cookie" + headerTag = "header" + jsonTag = "json" + rawBodyTag = "raw_body" + fileNameTag = "file_name" +) + +const ( + defaultTag = "default" +) + +const ( + requiredTagOpt = "required" +) + +type TagInfo struct { + Key string + Value string + JSONName string + Required bool + Skip bool + Default string + Options []string + Getter getter + SliceGetter sliceGetter +} + +func head(str, sep string) (head, tail string) { + idx := strings.Index(str, sep) + if idx < 0 { + return str, "" + } + return str[:idx], str[idx+len(sep):] +} + +func lookupFieldTags(field reflect.StructField, parentJSONName string, config *DecodeConfig) ([]TagInfo, string, bool) { + var ret []string + var needValidate bool + if _, ok := field.Tag.Lookup(config.ValidateTag); ok { + needValidate = true + } + tags := []string{pathTag, formTag, queryTag, cookieTag, headerTag, jsonTag, rawBodyTag, fileNameTag} + for _, tag := range tags { + if _, ok := field.Tag.Lookup(tag); ok { + ret = append(ret, tag) + } + } + + defaultVal := "" + if val, ok := field.Tag.Lookup(defaultTag); ok { + defaultVal = val + } + + var tagInfos []TagInfo + var newParentJSONName string + for _, tag := range ret { + tagContent := field.Tag.Get(tag) + tagValue, opts := head(tagContent, ",") + if len(tagValue) == 0 { + tagValue = field.Name + } + skip := false + jsonName := "" + if tag == jsonTag { + jsonName = parentJSONName + "." + tagValue + } + if tagValue == "-" { + skip = true + if tag == jsonTag { + jsonName = parentJSONName + "." + field.Name + } + } + if jsonName != "" { + jsonName = strings.TrimPrefix(jsonName, ".") + newParentJSONName = jsonName + } + var options []string + var opt string + var required bool + for len(opts) > 0 { + opt, opts = head(opts, ",") + options = append(options, opt) + if opt == requiredTagOpt { + required = true + } + } + tagInfos = append(tagInfos, TagInfo{Key: tag, Value: tagValue, JSONName: jsonName, Options: options, Required: required, Default: defaultVal, Skip: skip}) + } + if len(newParentJSONName) == 0 { + newParentJSONName = strings.TrimPrefix(parentJSONName+"."+field.Name, ".") + } + + return tagInfos, newParentJSONName, needValidate +} + +func getDefaultFieldTags(field reflect.StructField) (tagInfos []TagInfo) { + defaultVal := "" + if val, ok := field.Tag.Lookup(defaultTag); ok { + defaultVal = val + } + + tags := []string{pathTag, formTag, queryTag, cookieTag, headerTag, jsonTag, fileNameTag} + for _, tag := range tags { + tagInfos = append(tagInfos, TagInfo{Key: tag, Value: field.Name, Default: defaultVal}) + } + + return +} + +func getFieldTagInfoByTag(field reflect.StructField, tag string) []TagInfo { + var tagInfos []TagInfo + if content, ok := field.Tag.Lookup(tag); ok { + tagValue, opts := head(content, ",") + if len(tagValue) == 0 { + tagValue = field.Name + } + skip := false + if tagValue == "-" { + skip = true + } + var options []string + var opt string + var required bool + for len(opts) > 0 { + opt, opts = head(opts, ",") + options = append(options, opt) + if opt == requiredTagOpt { + required = true + } + } + tagInfos = append(tagInfos, TagInfo{Key: tag, Value: tagValue, Options: options, Required: required, Skip: skip}) + } else { + tagInfos = append(tagInfos, TagInfo{Key: tag, Value: field.Name}) + } + + return tagInfos +} diff --git a/pkg/app/server/binding/internal/decoder/text_decoder.go b/pkg/app/server/binding/internal/decoder/text_decoder.go new file mode 100644 index 000000000..8b53c2bf5 --- /dev/null +++ b/pkg/app/server/binding/internal/decoder/text_decoder.go @@ -0,0 +1,169 @@ +/* + * Copyright 2023 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * MIT License + * + * Copyright (c) 2019-present Fenny and Contributors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + * This file may have been modified by CloudWeGo authors. All CloudWeGo + * Modifications are Copyright 2023 CloudWeGo Authors + */ + +package decoder + +import ( + "fmt" + "reflect" + "strconv" + + "github.com/cloudwego/hertz/internal/bytesconv" + hJson "github.com/cloudwego/hertz/pkg/common/json" +) + +type TextDecoder interface { + UnmarshalString(s string, fieldValue reflect.Value, looseZeroMode bool) error +} + +func SelectTextDecoder(rt reflect.Type) (TextDecoder, error) { + switch rt.Kind() { + case reflect.Bool: + return &boolDecoder{}, nil + case reflect.Uint8: + return &uintDecoder{bitSize: 8}, nil + case reflect.Uint16: + return &uintDecoder{bitSize: 16}, nil + case reflect.Uint32: + return &uintDecoder{bitSize: 32}, nil + case reflect.Uint64: + return &uintDecoder{bitSize: 64}, nil + case reflect.Uint: + return &uintDecoder{}, nil + case reflect.Int8: + return &intDecoder{bitSize: 8}, nil + case reflect.Int16: + return &intDecoder{bitSize: 16}, nil + case reflect.Int32: + return &intDecoder{bitSize: 32}, nil + case reflect.Int64: + return &intDecoder{bitSize: 64}, nil + case reflect.Int: + return &intDecoder{}, nil + case reflect.String: + return &stringDecoder{}, nil + case reflect.Float32: + return &floatDecoder{bitSize: 32}, nil + case reflect.Float64: + return &floatDecoder{bitSize: 64}, nil + case reflect.Interface: + return &interfaceDecoder{}, nil + } + + return nil, fmt.Errorf("unsupported type " + rt.String()) +} + +type boolDecoder struct{} + +func (d *boolDecoder) UnmarshalString(s string, fieldValue reflect.Value, looseZeroMode bool) error { + if s == "" && looseZeroMode { + s = "false" + } + v, err := strconv.ParseBool(s) + if err != nil { + return err + } + fieldValue.SetBool(v) + return nil +} + +type floatDecoder struct { + bitSize int +} + +func (d *floatDecoder) UnmarshalString(s string, fieldValue reflect.Value, looseZeroMode bool) error { + if s == "" && looseZeroMode { + s = "0.0" + } + v, err := strconv.ParseFloat(s, d.bitSize) + if err != nil { + return err + } + fieldValue.SetFloat(v) + return nil +} + +type intDecoder struct { + bitSize int +} + +func (d *intDecoder) UnmarshalString(s string, fieldValue reflect.Value, looseZeroMode bool) error { + if s == "" && looseZeroMode { + s = "0" + } + v, err := strconv.ParseInt(s, 10, d.bitSize) + if err != nil { + return err + } + fieldValue.SetInt(v) + return nil +} + +type stringDecoder struct{} + +func (d *stringDecoder) UnmarshalString(s string, fieldValue reflect.Value, looseZeroMode bool) error { + fieldValue.SetString(s) + return nil +} + +type uintDecoder struct { + bitSize int +} + +func (d *uintDecoder) UnmarshalString(s string, fieldValue reflect.Value, looseZeroMode bool) error { + if s == "" && looseZeroMode { + s = "0" + } + v, err := strconv.ParseUint(s, 10, d.bitSize) + if err != nil { + return err + } + fieldValue.SetUint(v) + return nil +} + +type interfaceDecoder struct{} + +func (d *interfaceDecoder) UnmarshalString(s string, fieldValue reflect.Value, looseZeroMode bool) error { + if s == "" && looseZeroMode { + s = "0" + } + return hJson.Unmarshal(bytesconv.S2b(s), fieldValue.Addr().Interface()) +} diff --git a/pkg/app/server/binding/reflect.go b/pkg/app/server/binding/reflect.go new file mode 100644 index 000000000..502de11d2 --- /dev/null +++ b/pkg/app/server/binding/reflect.go @@ -0,0 +1,73 @@ +/* + * Copyright 2023 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * MIT License + * + * Copyright (c) 2019-present Fenny and Contributors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + * This file may have been modified by CloudWeGo authors. All CloudWeGo + * Modifications are Copyright 2023 CloudWeGo Authors + */ + +package binding + +import ( + "fmt" + "reflect" + "unsafe" +) + +func valueAndTypeID(v interface{}) (reflect.Value, uintptr) { + header := (*emptyInterface)(unsafe.Pointer(&v)) + rv := reflect.ValueOf(v) + return rv, header.typeID +} + +type emptyInterface struct { + typeID uintptr + dataPtr unsafe.Pointer +} + +func checkPointer(rv reflect.Value) error { + if rv.Kind() != reflect.Ptr || rv.IsNil() { + return fmt.Errorf("receiver must be a non-nil pointer") + } + return nil +} + +func dereferPointer(rv reflect.Value) reflect.Type { + rt := rv.Type() + for rt.Kind() == reflect.Ptr { + rt = rt.Elem() + } + return rt +} diff --git a/pkg/app/server/binding/reflect_internal_test.go b/pkg/app/server/binding/reflect_internal_test.go new file mode 100644 index 000000000..65dc68fc8 --- /dev/null +++ b/pkg/app/server/binding/reflect_internal_test.go @@ -0,0 +1,90 @@ +// Copyright 2023 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +package binding + +import ( + "reflect" + "testing" + + "github.com/cloudwego/hertz/pkg/app/server/binding/internal/decoder" + "github.com/cloudwego/hertz/pkg/common/test/assert" +) + +type foo2 struct { + F1 string +} + +type fooq struct { + F1 **string +} + +func Test_ReferenceValue(t *testing.T) { + foo1 := foo2{F1: "f1"} + foo1Val := reflect.ValueOf(foo1) + foo1PointerVal := decoder.ReferenceValue(foo1Val, 5) + assert.DeepEqual(t, "f1", foo1.F1) + assert.DeepEqual(t, "f1", foo1Val.Field(0).Interface().(string)) + if foo1PointerVal.Kind() != reflect.Ptr { + t.Errorf("expect a pointer, but get nil") + } + assert.DeepEqual(t, "*****binding.foo2", foo1PointerVal.Type().String()) + + deFoo1PointerVal := decoder.ReferenceValue(foo1PointerVal, -5) + if deFoo1PointerVal.Kind() == reflect.Ptr { + t.Errorf("expect a non-pointer, but get a pointer") + } + assert.DeepEqual(t, "f1", deFoo1PointerVal.Field(0).Interface().(string)) +} + +func Test_GetNonNilReferenceValue(t *testing.T) { + foo1 := (****foo)(nil) + foo1Val := reflect.ValueOf(foo1) + foo1ValNonNil, ptrDepth := decoder.GetNonNilReferenceValue(foo1Val) + if !foo1ValNonNil.IsValid() { + t.Errorf("expect a valid value, but get nil") + } + if !foo1ValNonNil.CanSet() { + t.Errorf("expect can set value, but not") + } + + foo1ReferPointer := decoder.ReferenceValue(foo1ValNonNil, ptrDepth) + if foo1ReferPointer.Kind() != reflect.Ptr { + t.Errorf("expect a pointer, but get nil") + } +} + +func Test_GetFieldValue(t *testing.T) { + type bar struct { + B1 **fooq + } + bar1 := (***bar)(nil) + parentIdx := []int{0} + idx := 0 + + bar1Val := reflect.ValueOf(bar1) + parentFieldVal := decoder.GetFieldValue(bar1Val, parentIdx) + if parentFieldVal.Kind() == reflect.Ptr { + t.Errorf("expect a non-pointer, but get a pointer") + } + if !parentFieldVal.CanSet() { + t.Errorf("expect can set value, but not") + } + fooFieldVal := parentFieldVal.Field(idx) + assert.DeepEqual(t, "**string", fooFieldVal.Type().String()) + if !fooFieldVal.CanSet() { + t.Errorf("expect can set value, but not") + } +} diff --git a/pkg/app/server/binding/reflect_test.go b/pkg/app/server/binding/reflect_test.go new file mode 100644 index 000000000..036eb7f35 --- /dev/null +++ b/pkg/app/server/binding/reflect_test.go @@ -0,0 +1,87 @@ +/* + * Copyright 2023 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package binding + +import ( + "reflect" + "testing" + + "github.com/cloudwego/hertz/pkg/common/test/assert" +) + +type foo struct { + f1 string +} + +func TestReflect_TypeID(t *testing.T) { + _, intType := valueAndTypeID(int(1)) + _, uintType := valueAndTypeID(uint(1)) + _, shouldBeIntType := valueAndTypeID(int(1)) + assert.DeepEqual(t, intType, shouldBeIntType) + assert.NotEqual(t, intType, uintType) + + foo1 := foo{f1: "1"} + foo2 := foo{f1: "2"} + _, foo1Type := valueAndTypeID(foo1) + _, foo2Type := valueAndTypeID(foo2) + _, foo2PointerType := valueAndTypeID(&foo2) + _, foo1PointerType := valueAndTypeID(&foo1) + assert.DeepEqual(t, foo1Type, foo2Type) + assert.NotEqual(t, foo1Type, foo2PointerType) + assert.DeepEqual(t, foo1PointerType, foo2PointerType) +} + +func TestReflect_CheckPointer(t *testing.T) { + foo1 := foo{} + foo1Val := reflect.ValueOf(foo1) + err := checkPointer(foo1Val) + if err == nil { + t.Errorf("expect an err, but get nil") + } + + foo2 := &foo{} + foo2Val := reflect.ValueOf(foo2) + err = checkPointer(foo2Val) + if err != nil { + t.Error(err) + } + + foo3 := (*foo)(nil) + foo3Val := reflect.ValueOf(foo3) + err = checkPointer(foo3Val) + if err == nil { + t.Errorf("expect an err, but get nil") + } +} + +func TestReflect_DereferPointer(t *testing.T) { + var foo1 ****foo + foo1Val := reflect.ValueOf(foo1) + rt := dereferPointer(foo1Val) + if rt.Kind() == reflect.Ptr { + t.Errorf("expect non-pointer type, but get pointer") + } + assert.DeepEqual(t, "foo", rt.Name()) + + var foo2 foo + foo2Val := reflect.ValueOf(foo2) + rt2 := dereferPointer(foo2Val) + if rt2.Kind() == reflect.Ptr { + t.Errorf("expect non-pointer type, but get pointer") + } + assert.DeepEqual(t, "foo", rt2.Name()) +} diff --git a/pkg/app/server/binding/request.go b/pkg/app/server/binding/request.go deleted file mode 100644 index e4d70ba0d..000000000 --- a/pkg/app/server/binding/request.go +++ /dev/null @@ -1,138 +0,0 @@ -/* - * Copyright 2022 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package binding - -import ( - "mime/multipart" - "net/http" - "net/url" - - "github.com/bytedance/go-tagexpr/v2/binding" - "github.com/cloudwego/hertz/internal/bytesconv" - "github.com/cloudwego/hertz/pkg/protocol" -) - -func wrapRequest(req *protocol.Request) binding.Request { - r := &bindRequest{ - req: req, - } - return r -} - -type bindRequest struct { - req *protocol.Request -} - -func (r *bindRequest) GetQuery() url.Values { - queryMap := make(url.Values) - r.req.URI().QueryArgs().VisitAll(func(key, value []byte) { - keyStr := string(key) - values := queryMap[keyStr] - values = append(values, string(value)) - queryMap[keyStr] = values - }) - - return queryMap -} - -func (r *bindRequest) GetPostForm() (url.Values, error) { - postMap := make(url.Values) - r.req.PostArgs().VisitAll(func(key, value []byte) { - keyStr := string(key) - values := postMap[keyStr] - values = append(values, string(value)) - postMap[keyStr] = values - }) - mf, err := r.req.MultipartForm() - if err == nil { - for k, v := range mf.Value { - if len(v) > 0 { - postMap[k] = v - } - } - } - - return postMap, nil -} - -func (r *bindRequest) GetForm() (url.Values, error) { - formMap := make(url.Values) - r.req.URI().QueryArgs().VisitAll(func(key, value []byte) { - keyStr := string(key) - values := formMap[keyStr] - values = append(values, string(value)) - formMap[keyStr] = values - }) - r.req.PostArgs().VisitAll(func(key, value []byte) { - keyStr := string(key) - values := formMap[keyStr] - values = append(values, string(value)) - formMap[keyStr] = values - }) - - return formMap, nil -} - -func (r *bindRequest) GetCookies() []*http.Cookie { - var cookies []*http.Cookie - r.req.Header.VisitAllCookie(func(key, value []byte) { - cookies = append(cookies, &http.Cookie{ - Name: string(key), - Value: string(value), - }) - }) - - return cookies -} - -func (r *bindRequest) GetHeader() http.Header { - header := make(http.Header) - r.req.Header.VisitAll(func(key, value []byte) { - keyStr := string(key) - values := header[keyStr] - values = append(values, string(value)) - header[keyStr] = values - }) - - return header -} - -func (r *bindRequest) GetMethod() string { - return bytesconv.B2s(r.req.Method()) -} - -func (r *bindRequest) GetContentType() string { - return bytesconv.B2s(r.req.Header.ContentType()) -} - -func (r *bindRequest) GetBody() ([]byte, error) { - return r.req.Body(), nil -} - -func (r *bindRequest) GetFileHeaders() (map[string][]*multipart.FileHeader, error) { - files := make(map[string][]*multipart.FileHeader) - mf, err := r.req.MultipartForm() - if err == nil { - for k, v := range mf.File { - if len(v) > 0 { - files[k] = v - } - } - } - - return files, nil -} diff --git a/pkg/app/server/binding/request_test.go b/pkg/app/server/binding/request_test.go deleted file mode 100644 index b3bb70523..000000000 --- a/pkg/app/server/binding/request_test.go +++ /dev/null @@ -1,235 +0,0 @@ -/* - * Copyright 2022 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package binding - -import ( - "bytes" - "fmt" - "testing" - - "github.com/cloudwego/hertz/pkg/common/test/assert" - "github.com/cloudwego/hertz/pkg/protocol" - "github.com/cloudwego/hertz/pkg/protocol/consts" -) - -func TestGetQuery(t *testing.T) { - r := protocol.NewRequest("GET", "/foo", nil) - r.SetRequestURI("/foo/bar?para1=hertz¶2=query1¶2=query2¶3=1¶3=2") - - bindReq := bindRequest{ - req: r, - } - - values := bindReq.GetQuery() - - assert.DeepEqual(t, []string{"hertz"}, values["para1"]) - assert.DeepEqual(t, []string{"query1", "query2"}, values["para2"]) - assert.DeepEqual(t, []string{"1", "2"}, values["para3"]) -} - -func TestGetPostForm(t *testing.T) { - data := "a=aaa&b=b1&b=b2&c=ccc&d=100" - mr := bytes.NewBufferString(data) - - r := protocol.NewRequest("POST", "/foo", mr) - r.Header.SetContentTypeBytes([]byte(consts.MIMEApplicationHTMLForm)) - r.Header.SetContentLength(len(data)) - - bindReq := bindRequest{ - req: r, - } - - values, err := bindReq.GetPostForm() - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - assert.DeepEqual(t, []string{"aaa"}, values["a"]) - assert.DeepEqual(t, []string{"b1", "b2"}, values["b"]) - assert.DeepEqual(t, []string{"ccc"}, values["c"]) - assert.DeepEqual(t, []string{"100"}, values["d"]) -} - -func TestGetForm(t *testing.T) { - data := "a=aaa&b=b1&b=b2&c=ccc&d=100" - mr := bytes.NewBufferString(data) - - r := protocol.NewRequest("POST", "/foo", mr) - r.SetRequestURI("/foo/bar?para1=hertz¶2=query1¶2=query2¶3=1¶3=2") - r.Header.SetContentTypeBytes([]byte(consts.MIMEApplicationHTMLForm)) - r.Header.SetContentLength(len(data)) - - bindReq := bindRequest{ - req: r, - } - - values, err := bindReq.GetForm() - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - assert.DeepEqual(t, []string{"aaa"}, values["a"]) - assert.DeepEqual(t, []string{"b1", "b2"}, values["b"]) - assert.DeepEqual(t, []string{"ccc"}, values["c"]) - assert.DeepEqual(t, []string{"100"}, values["d"]) - assert.DeepEqual(t, []string{"hertz"}, values["para1"]) - assert.DeepEqual(t, []string{"query1", "query2"}, values["para2"]) - assert.DeepEqual(t, []string{"1", "2"}, values["para3"]) -} - -func TestGetCookies(t *testing.T) { - r := protocol.NewRequest("POST", "/foo", nil) - r.SetCookie("cookie1", "cookies1") - r.SetCookie("cookie2", "cookies2") - - bindReq := bindRequest{ - req: r, - } - - values := bindReq.GetCookies() - - assert.DeepEqual(t, "cookies1", values[0].Value) - assert.DeepEqual(t, "cookies2", values[1].Value) -} - -func TestGetHeader(t *testing.T) { - headers := map[string]string{ - "Header1": "headers1", - "Header2": "headers2", - } - - r := protocol.NewRequest("GET", "/foo", nil) - r.SetHeaders(headers) - r.SetHeader("Header3", "headers3") - - bindReq := bindRequest{ - req: r, - } - - values := bindReq.GetHeader() - - assert.DeepEqual(t, []string{"headers1"}, values["Header1"]) - assert.DeepEqual(t, []string{"headers2"}, values["Header2"]) - assert.DeepEqual(t, []string{"headers3"}, values["Header3"]) -} - -func TestGetMethod(t *testing.T) { - r := protocol.NewRequest("GET", "/foo", nil) - - bindReq := bindRequest{ - req: r, - } - - values := bindReq.GetMethod() - - assert.DeepEqual(t, "GET", values) -} - -func TestGetContentType(t *testing.T) { - data := "a=aaa&b=b1&b=b2&c=ccc&d=100" - mr := bytes.NewBufferString(data) - - r := protocol.NewRequest("POST", "/foo", mr) - r.Header.SetContentTypeBytes([]byte(consts.MIMEApplicationHTMLForm)) - r.Header.SetContentLength(len(data)) - - bindReq := bindRequest{ - req: r, - } - - values := bindReq.GetContentType() - - assert.DeepEqual(t, consts.MIMEApplicationHTMLForm, values) -} - -func TestGetBody(t *testing.T) { - data := "a=aaa&b=b1&b=b2&c=ccc&d=100" - mr := bytes.NewBufferString(data) - - r := protocol.NewRequest("POST", "/foo", mr) - r.Header.SetContentTypeBytes([]byte(consts.MIMEApplicationHTMLForm)) - r.Header.SetContentLength(len(data)) - - bindReq := bindRequest{ - req: r, - } - - values, err := bindReq.GetBody() - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - assert.DeepEqual(t, []byte("a=aaa&b=b1&b=b2&c=ccc&d=100"), values) -} - -func TestGetFileHeaders(t *testing.T) { - s := `------WebKitFormBoundaryJwfATyF8tmxSJnLg -Content-Disposition: form-data; name="f" - -fff -------WebKitFormBoundaryJwfATyF8tmxSJnLg -Content-Disposition: form-data; name="F1"; filename="TODO1" -Content-Type: application/octet-stream - -- SessionClient with referer and cookies support. -- Client with requests' pipelining support. -- ProxyHandler similar to FSHandler. -- WebSockets. See https://tools.ietf.org/html/rfc6455 . -- HTTP/2.0. See https://tools.ietf.org/html/rfc7540 . -------WebKitFormBoundaryJwfATyF8tmxSJnLg -Content-Disposition: form-data; name="F1"; filename="TODO2" -Content-Type: application/octet-stream - -- SessionClient with referer and cookies support. -- Client with requests' pipelining support. -- ProxyHandler similar to FSHandler. -- WebSockets. See https://tools.ietf.org/html/rfc6455 . -- HTTP/2.0. See https://tools.ietf.org/html/rfc7540 . -------WebKitFormBoundaryJwfATyF8tmxSJnLg -Content-Disposition: form-data; name="F2"; filename="TODO3" -Content-Type: application/octet-stream - -- SessionClient with referer and cookies support. -- Client with requests' pipelining support. -- ProxyHandler similar to FSHandler. -- WebSockets. See https://tools.ietf.org/html/rfc6455 . -- HTTP/2.0. See https://tools.ietf.org/html/rfc7540 . - -------WebKitFormBoundaryJwfATyF8tmxSJnLg-- -tailfoobar` - - mr := bytes.NewBufferString(s) - - r := protocol.NewRequest("POST", "/foo", mr) - r.Header.SetContentTypeBytes([]byte("multipart/form-data; boundary=----WebKitFormBoundaryJwfATyF8tmxSJnLg")) - r.Header.SetContentLength(len(s)) - - bindReq := bindRequest{ - req: r, - } - - values, err := bindReq.GetFileHeaders() - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - fmt.Printf("%v\n", values) - - assert.DeepEqual(t, "TODO1", values["F1"][0].Filename) - assert.DeepEqual(t, "TODO2", values["F1"][1].Filename) - assert.DeepEqual(t, "TODO3", values["F2"][0].Filename) -} diff --git a/pkg/app/server/binding/tagexpr_bind_test.go b/pkg/app/server/binding/tagexpr_bind_test.go new file mode 100644 index 000000000..82221745c --- /dev/null +++ b/pkg/app/server/binding/tagexpr_bind_test.go @@ -0,0 +1,1281 @@ +/* + * Copyright 2023 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * MIT License + * + * Copyright 2019 Bytedance Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * This file may have been modified by CloudWeGo authors. All CloudWeGo + * Modifications are Copyright 2023 CloudWeGo Authors + */ + +package binding + +import ( + "bytes" + "encoding/json" + "io" + "io/ioutil" + "mime/multipart" + "net/http" + "net/url" + "strings" + "testing" + "time" + + "github.com/cloudwego/hertz/pkg/common/test/assert" + "github.com/cloudwego/hertz/pkg/protocol/consts" + "github.com/cloudwego/hertz/pkg/route/param" + "google.golang.org/protobuf/proto" +) + +func TestRawBody(t *testing.T) { + type Recv struct { + S []byte `raw_body:""` + F **string `raw_body:""` + } + bodyBytes := []byte("raw_body.............") + req := newRequest("", nil, nil, bytes.NewReader(bodyBytes)) + recv := new(Recv) + err := DefaultBinder().Bind(req.Req, recv, nil) + if err != nil { + if err != nil { + t.Error(err) + } + } + + assert.DeepEqual(t, bodyBytes, recv.S) + assert.DeepEqual(t, string(bodyBytes), **recv.F) +} + +func TestQueryString(t *testing.T) { + type metric string + type count int32 + + type Recv struct { + X **struct { + A []string `query:"a"` + B string `query:"b"` + C *[]string `query:"c,required"` + D *string `query:"d"` + E *[]***int `query:"e"` + F metric `query:"f"` + G []count `query:"g"` + } + Y string `query:"y,required"` + Z *string `query:"z"` + } + req := newRequest("http://localhost:8080/?a=a1&a=a2&b=b1&c=c1&c=c2&d=d1&d=d&f=qps&g=1002&g=1003&e=&e=2&y=y1", nil, nil, nil) + recv := new(Recv) + bindConfig := &BindConfig{} + bindConfig.LooseZeroMode = true + binder := NewDefaultBinder(bindConfig) + err := binder.Bind(req.Req, recv, nil) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, 0, ***(*(**recv.X).E)[0]) + assert.DeepEqual(t, 2, ***(*(**recv.X).E)[1]) + assert.DeepEqual(t, []string{"a1", "a2"}, (**recv.X).A) + assert.DeepEqual(t, "b1", (**recv.X).B) + assert.DeepEqual(t, []string{"c1", "c2"}, *(**recv.X).C) + assert.DeepEqual(t, "d1", *(**recv.X).D) + assert.DeepEqual(t, metric("qps"), (**recv.X).F) + assert.DeepEqual(t, []count{1002, 1003}, (**recv.X).G) + assert.DeepEqual(t, "y1", recv.Y) + assert.DeepEqual(t, (*string)(nil), recv.Z) +} + +func TestGetBody(t *testing.T) { + type Recv struct { + X **struct { + E string `json:"e,required" query:"e,required"` + } + } + req := newRequest("http://localhost:8080/", nil, nil, nil) + recv := new(Recv) + err := DefaultBinder().Bind(req.Req, recv, nil) + if err == nil { + t.Fatalf("expected an error, but get nil") + } + assert.DeepEqual(t, err.Error(), "'E' field is a 'required' parameter, but the request body does not have this parameter 'X.e'") +} + +func TestQueryNum(t *testing.T) { + type Recv struct { + X **struct { + A []int `query:"a"` + B int32 `query:"b"` + C *[]uint16 `query:"c,required"` + D *float32 `query:"d"` + } + Y bool `query:"y,required"` + Z *int64 `query:"z"` + } + req := newRequest("http://localhost:8080/?a=11&a=12&b=21&c=31&c=32&d=41&d=42&y=true", nil, nil, nil) + recv := new(Recv) + err := DefaultBinder().Bind(req.Req, recv, nil) + if err != nil { + if err != nil { + t.Error(err) + } + } + assert.DeepEqual(t, []int{11, 12}, (**recv.X).A) + assert.DeepEqual(t, int32(21), (**recv.X).B) + assert.DeepEqual(t, &[]uint16{31, 32}, (**recv.X).C) + assert.DeepEqual(t, float32(41), *(**recv.X).D) + assert.DeepEqual(t, true, recv.Y) + assert.DeepEqual(t, (*int64)(nil), recv.Z) +} + +func TestHeaderString(t *testing.T) { + type Recv struct { + X **struct { + A []string `header:"X-A"` + B string `header:"X-B"` + C *[]string `header:"X-C,required"` + D *string `header:"X-D"` + } + Y string `header:"X-Y,required"` + Z *string `header:"X-Z"` + } + header := make(http.Header) + header.Add("X-A", "a1") + header.Add("X-A", "a2") + header.Add("X-B", "b1") + header.Add("X-C", "c1") + header.Add("X-C", "c2") + header.Add("X-D", "d1") + header.Add("X-D", "d2") + header.Add("X-Y", "y1") + req := newRequest("", header, nil, nil) + recv := new(Recv) + err := DefaultBinder().Bind(req.Req, recv, nil) + if err != nil { + if err != nil { + t.Error(err) + } + } + assert.DeepEqual(t, []string{"a1", "a2"}, (**recv.X).A) + assert.DeepEqual(t, "b1", (**recv.X).B) + assert.DeepEqual(t, []string{"c1", "c2"}, *(**recv.X).C) + assert.DeepEqual(t, "d1", *(**recv.X).D) + assert.DeepEqual(t, "y1", recv.Y) + assert.DeepEqual(t, (*string)(nil), recv.Z) +} + +func TestHeaderNum(t *testing.T) { + type Recv struct { + X **struct { + A []int `header:"X-A"` + B int32 `header:"X-B"` + C *[]uint16 `header:"X-C,required"` + D *float32 `header:"X-D"` + } + Y bool `header:"X-Y,required"` + Z *int64 `header:"X-Z"` + } + header := make(http.Header) + header.Add("X-A", "11") + header.Add("X-A", "12") + header.Add("X-B", "21") + header.Add("X-C", "31") + header.Add("X-C", "32") + header.Add("X-D", "41") + header.Add("X-D", "42") + header.Add("X-Y", "true") + req := newRequest("", header, nil, nil) + recv := new(Recv) + + err := DefaultBinder().Bind(req.Req, recv, nil) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, []int{11, 12}, (**recv.X).A) + assert.DeepEqual(t, int32(21), (**recv.X).B) + assert.DeepEqual(t, &[]uint16{31, 32}, (**recv.X).C) + assert.DeepEqual(t, float32(41), *(**recv.X).D) + assert.DeepEqual(t, true, recv.Y) + assert.DeepEqual(t, (*int64)(nil), recv.Z) +} + +func TestCookieString(t *testing.T) { + type Recv struct { + X **struct { + A []string `cookie:"a"` + B string `cookie:"b"` + C *[]string `cookie:"c,required"` + D *string `cookie:"d"` + } + Y string `cookie:"y,required"` + Z *string `cookie:"z"` + } + cookies := []*http.Cookie{ + {Name: "a", Value: "a1"}, + {Name: "a", Value: "a2"}, + {Name: "b", Value: "b1"}, + {Name: "c", Value: "c1"}, + {Name: "c", Value: "c2"}, + {Name: "d", Value: "d1"}, + {Name: "d", Value: "d2"}, + {Name: "y", Value: "y1"}, + } + req := newRequest("", nil, cookies, nil) + recv := new(Recv) + + err := DefaultBinder().Bind(req.Req, recv, nil) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, []string{"a2"}, (**recv.X).A) + assert.DeepEqual(t, "b1", (**recv.X).B) + assert.DeepEqual(t, []string{"c2"}, *(**recv.X).C) + assert.DeepEqual(t, "d2", *(**recv.X).D) + assert.DeepEqual(t, "y1", recv.Y) + assert.DeepEqual(t, (*string)(nil), recv.Z) +} + +func TestCookieNum(t *testing.T) { + type Recv struct { + X **struct { + A []int `cookie:"a"` + B int32 `cookie:"b"` + C *[]uint16 `cookie:"c,required"` + D *float32 `cookie:"d"` + } + Y bool `cookie:"y,required"` + Z *int64 `cookie:"z"` + } + cookies := []*http.Cookie{ + {Name: "a", Value: "11"}, + {Name: "b", Value: "21"}, + {Name: "c", Value: "31"}, + {Name: "d", Value: "41"}, + {Name: "y", Value: "t"}, + } + req := newRequest("", nil, cookies, nil) + recv := new(Recv) + + err := DefaultBinder().Bind(req.Req, recv, nil) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, []int{11}, (**recv.X).A) + assert.DeepEqual(t, int32(21), (**recv.X).B) + assert.DeepEqual(t, &[]uint16{31}, (**recv.X).C) + assert.DeepEqual(t, float32(41), *(**recv.X).D) + assert.DeepEqual(t, true, recv.Y) + assert.DeepEqual(t, (*int64)(nil), recv.Z) +} + +func TestFormString(t *testing.T) { + type Recv struct { + X **struct { + A []string `form:"a"` + B string `form:"b"` + C *[]string `form:"c,required"` + D *string `form:"d"` + } + Y string `form:"y,required"` + Z *string `form:"z"` + F *multipart.FileHeader `form:"F1"` + F1 multipart.FileHeader + Fs []multipart.FileHeader `form:"F1"` + Fs1 []*multipart.FileHeader `form:"F1"` + } + values := make(url.Values) + values.Add("a", "a1") + values.Add("a", "a2") + values.Add("b", "b1") + values.Add("c", "c1") + values.Add("c", "c2") + values.Add("d", "d1") + values.Add("d", "d2") + values.Add("y", "y1") + for i, f := range []files{{ + "F1": []file{ + newFile("txt", strings.NewReader("0123")), + }, + }} { + contentType, bodyReader := newFormBody2(values, f) + header := make(http.Header) + header.Set("Content-Type", contentType) + req := newRequest("", header, nil, bodyReader) + recv := new(Recv) + err := DefaultBinder().Bind(req.Req, recv, nil) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, []string{"a1", "a2"}, (**recv.X).A) + assert.DeepEqual(t, "b1", (**recv.X).B) + assert.DeepEqual(t, []string{"c1", "c2"}, *(**recv.X).C) + assert.DeepEqual(t, "d1", *(**recv.X).D) + assert.DeepEqual(t, "y1", recv.Y) + assert.DeepEqual(t, (*string)(nil), recv.Z) + t.Logf("[%d] F: %#v", i, recv.F) + t.Logf("[%d] F1: %#v", i, recv.F1) + t.Logf("[%d] Fs: %#v", i, recv.Fs) + t.Logf("[%d] Fs1: %#v", i, recv.Fs1) + if len(recv.Fs1) > 0 { + t.Logf("[%d] Fs1[0]: %#v", i, recv.Fs1[0]) + } + } +} + +func TestFormNum(t *testing.T) { + type Recv struct { + X **struct { + A []int `form:"a"` + B int32 `form:"b"` + C *[]uint16 `form:"c,required"` + D *float32 `form:"d"` + } + Y bool `form:"y,required"` + Z *int64 `form:"z"` + } + values := make(url.Values) + values.Add("a", "11") + values.Add("a", "12") + values.Add("b", "-21") + values.Add("c", "31") + values.Add("c", "32") + values.Add("d", "41") + values.Add("d", "42") + values.Add("y", "1") + for _, f := range []files{nil, { + "f1": []file{ + newFile("txt", strings.NewReader("f11 text.")), + }, + }} { + contentType, bodyReader := newFormBody2(values, f) + header := make(http.Header) + header.Set("Content-Type", contentType) + req := newRequest("", header, nil, bodyReader) + recv := new(Recv) + + err := DefaultBinder().Bind(req.Req, recv, nil) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, []int{11, 12}, (**recv.X).A) + assert.DeepEqual(t, int32(-21), (**recv.X).B) + assert.DeepEqual(t, &[]uint16{31, 32}, (**recv.X).C) + assert.DeepEqual(t, float32(41), *(**recv.X).D) + assert.DeepEqual(t, true, recv.Y) + assert.DeepEqual(t, (*int64)(nil), recv.Z) + } +} + +func TestJSON(t *testing.T) { + type metric string + type count int32 + type ZS struct { + Z *int64 + } + type Recv struct { + X **struct { + A []string `json:"a"` + B int32 `json:""` + C *[]uint16 `json:",required"` + D *float32 `json:"d"` + E metric `json:"e"` + F count `json:"f"` + M map[string]string `json:"m"` + } + Y string `json:"y,required"` + ZS + } + + bodyReader := strings.NewReader(`{ + "X": { + "a": ["a1","a2"], + "B": 21, + "C": [31,32], + "d": 41, + "e": "qps", + "f": 100, + "m": {"a":"x"} + }, + "Z": 6 + }`) + + header := make(http.Header) + header.Set("Content-Type", consts.MIMEApplicationJSON) + req := newRequest("", header, nil, bodyReader) + recv := new(Recv) + + err := DefaultBinder().Bind(req.Req, recv, nil) + if err == nil { + t.Error("expected an error, but get nil") + } + assert.DeepEqual(t, err.Error(), "'Y' field is a 'required' parameter, but the request body does not have this parameter 'y'") + assert.DeepEqual(t, []string{"a1", "a2"}, (**recv.X).A) + assert.DeepEqual(t, int32(21), (**recv.X).B) + assert.DeepEqual(t, &[]uint16{31, 32}, (**recv.X).C) + assert.DeepEqual(t, float32(41), *(**recv.X).D) + assert.DeepEqual(t, metric("qps"), (**recv.X).E) + assert.DeepEqual(t, count(100), (**recv.X).F) + assert.DeepEqual(t, map[string]string{"a": "x"}, (**recv.X).M) + assert.DeepEqual(t, "", recv.Y) + assert.DeepEqual(t, (int64)(6), *recv.Z) +} + +func TestNonstruct(t *testing.T) { + bodyReader := strings.NewReader(`{ + "X": { + "a": ["a1","a2"], + "B": 21, + "C": [31,32], + "d": 41, + "e": "qps", + "f": 100 + }, + "Z": 6 + }`) + + header := make(http.Header) + header.Set("Content-Type", "application/json") + req := newRequest("", header, nil, bodyReader) + var recv interface{} + err := DefaultBinder().Bind(req.Req, &recv, nil) + if err != nil { + t.Error(err) + } + b, err := json.Marshal(recv) + if err != nil { + t.Error(err) + } + t.Logf("%s", b) + + bodyReader = strings.NewReader("b=334ddddd&token=yoMba34uspjVQEbhflgTRe2ceeDFUK32&type=url_verification") + header.Set("Content-Type", "application/x-www-form-urlencoded; charset=utf-8") + req = newRequest("", header, nil, bodyReader) + recv = nil + err = DefaultBinder().Bind(req.Req, &recv, nil) + if err != nil { + t.Error(err) + } + b, err = json.Marshal(recv) + if err != nil { + t.Error(err) + } + t.Logf("%s", b) +} + +func TestPath(t *testing.T) { + type Recv struct { + X **struct { + A []string `path:"a"` + B int32 `path:"b"` + C *[]uint16 `path:"c,required"` + D *float32 `path:"d"` + } + Y string `path:"y,required"` + Z *int64 + } + + req := newRequest("", nil, nil, nil) + recv := new(Recv) + + params := param.Params{ + { + Key: "a", + Value: "a1", + }, + { + Key: "b", + Value: "-21", + }, + { + Key: "c", + Value: "31", + }, + { + Key: "d", + Value: "41", + }, + { + Key: "y", + Value: "y1", + }, + { + Key: "name", + Value: "henrylee2cn", + }, + } + + err := DefaultBinder().Bind(req.Req, recv, params) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, []string{"a1"}, (**recv.X).A) + assert.DeepEqual(t, int32(-21), (**recv.X).B) + assert.DeepEqual(t, &[]uint16{31}, (**recv.X).C) + assert.DeepEqual(t, float32(41), *(**recv.X).D) + assert.DeepEqual(t, "y1", recv.Y) + assert.DeepEqual(t, (*int64)(nil), recv.Z) +} + +// FIXME: 复杂类型的默认值,暂时先不做,低优 +func TestDefault(t *testing.T) { + //type S struct { + // SS string `json:"ss"` + //} + type Recv struct { + X **struct { + A []string `path:"a" json:"a"` + B int32 `path:"b" default:"32"` + C bool `json:"c" default:"true"` + D *float32 `default:"123.4"` + // E *[]string `default:"['a','b','c','d,e,f']"` + // F map[string]string `default:"{'a':'\"\\'1','\"b':'c','c':'2'}"` + // G map[string]int64 `default:"{'a':1,'b':2,'c':3}"` + // H map[string]float64 `default:"{'a':0.1,'b':1.2,'c':2.3}"` + // I map[string]float64 `default:"{'\"a\"':0.1,'b':1.2,'c':2.3}"` + Empty string `default:""` + Null string `default:""` + CommaSpace string `default:",a:c "` + Dash string `default:"-"` + // InvalidInt int `default:"abc"` + // InvalidMap map[string]string `default:"abc"` + } + Y string `json:"y" default:"y1"` + Z int64 + W string `json:"w"` + // V []int64 `json:"u" default:"[1,2,3]"` + // U []float32 `json:"u" default:"[1.1,2,3]"` + T *string `json:"t" default:"t1"` + // S S `default:"{'ss':'test'}"` + // O *S `default:"{'ss':'test2'}"` + // Complex map[string][]map[string][]int64 `default:"{'a':[{'aa':[1,2,3], 'bb':[4,5]}],'b':[{}]}"` + } + + bodyReader := strings.NewReader(`{ + "X": { + "a": ["a1","a2"] + }, + "Z": 6 + }`) + + // var nilMap map[string]string + header := make(http.Header) + header.Set("Content-Type", consts.MIMEApplicationJSON) + req := newRequest("", header, nil, bodyReader) + recv := new(Recv) + + param2 := param.Params{ + { + Key: "e", + Value: "123", + }, + } + + err := DefaultBinder().Bind(req.Req, recv, param2) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, []string{"a1", "a2"}, (**recv.X).A) + assert.DeepEqual(t, int32(32), (**recv.X).B) + assert.DeepEqual(t, true, (**recv.X).C) + assert.DeepEqual(t, float32(123.4), *(**recv.X).D) + // assert.DeepEqual(t, []string{"a", "b", "c", "d,e,f"}, *(**recv.X).E) + // assert.DeepEqual(t, map[string]string{"a": "\"'1", "\"b": "c", "c": "2"}, (**recv.X).F) + // assert.DeepEqual(t, map[string]int64{"a": 1, "b": 2, "c": 3}, (**recv.X).G) + // assert.DeepEqual(t, map[string]float64{"a": 0.1, "b": 1.2, "c": 2.3}, (**recv.X).H) + // assert.DeepEqual(t, map[string]float64{"\"a\"": 0.1, "b": 1.2, "c": 2.3}, (**recv.X).I) + assert.DeepEqual(t, "", (**recv.X).Empty) + assert.DeepEqual(t, "", (**recv.X).Null) + assert.DeepEqual(t, ",a:c ", (**recv.X).CommaSpace) + assert.DeepEqual(t, "-", (**recv.X).Dash) + // assert.DeepEqual(t, 0, (**recv.X).InvalidInt) + // assert.DeepEqual(t, nilMap, (**recv.X).InvalidMap) + assert.DeepEqual(t, "y1", recv.Y) + assert.DeepEqual(t, "t1", *recv.T) + assert.DeepEqual(t, int64(6), recv.Z) + // assert.DeepEqual(t, []int64{1, 2, 3}, recv.V) + // assert.DeepEqual(t, []float32{1.1, 2, 3}, recv.U) + // assert.DeepEqual(t, S{SS: "test"}, recv.S) + // assert.DeepEqual(t, &S{SS: "test2"}, recv.O) + // assert.DeepEqual(t, map[string][]map[string][]int64{"a": {{"aa": {1, 2, 3}, "bb": []int64{4, 5}}}, "b": {map[string][]int64{}}}, recv.Complex) +} + +func TestAuto(t *testing.T) { + type Recv struct { + A string + B string + C string + D string `query:"D,required" form:"D,required"` + E string `cookie:"e" json:"e"` + } + query := make(url.Values) + query.Add("A", "a") + query.Add("B", "b") + query.Add("C", "c") + query.Add("D", "d-from-query") + contentType, bodyReader, err := newJSONBody(map[string]string{"e": "e-from-jsonbody"}) + if err != nil { + t.Error(err) + } + header := make(http.Header) + header.Set("Content-Type", contentType) + req := newRequest("http://localhost/?"+query.Encode(), header, []*http.Cookie{ + {Name: "e", Value: "e-from-cookie"}, + }, bodyReader) + recv := new(Recv) + + err = DefaultBinder().Bind(req.Req, recv, nil) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, "a", recv.A) + assert.DeepEqual(t, "b", recv.B) + assert.DeepEqual(t, "c", recv.C) + assert.DeepEqual(t, "d-from-query", recv.D) + assert.DeepEqual(t, "e-from-cookie", recv.E) + + query = make(url.Values) + query.Add("D", "d-from-query") + form := make(url.Values) + form.Add("B", "b") + form.Add("C", "c") + form.Add("D", "d-from-form") + contentType, bodyReader = newFormBody2(form, nil) + header = make(http.Header) + header.Set("Content-Type", contentType) + req = newRequest("http://localhost/?"+query.Encode(), header, nil, bodyReader) + recv = new(Recv) + err = DefaultBinder().Bind(req.Req, recv, nil) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, "", recv.A) + assert.DeepEqual(t, "b", recv.B) + assert.DeepEqual(t, "c", recv.C) + assert.DeepEqual(t, "d-from-form", recv.D) +} + +func TestTypeUnmarshal(t *testing.T) { + type Recv struct { + A time.Time `form:"t1"` + B *time.Time `query:"t2"` + C []time.Time `query:"t2"` + } + query := make(url.Values) + query.Add("t2", "2019-09-04T14:05:24+08:00") + query.Add("t2", "2019-09-04T18:05:24+08:00") + form := make(url.Values) + form.Add("t1", "2019-09-03T18:05:24+08:00") + contentType, bodyReader := newFormBody2(form, nil) + header := make(http.Header) + header.Set("Content-Type", contentType) + req := newRequest("http://localhost/?"+query.Encode(), header, nil, bodyReader) + recv := new(Recv) + + err := DefaultBinder().Bind(req.Req, recv, nil) + if err != nil { + t.Error(err) + } + t1, err := time.Parse(time.RFC3339, "2019-09-03T18:05:24+08:00") + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, t1, recv.A) + t21, err := time.Parse(time.RFC3339, "2019-09-04T14:05:24+08:00") + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, t21, *recv.B) + t22, err := time.Parse(time.RFC3339, "2019-09-04T18:05:24+08:00") + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, []time.Time{t21, t22}, recv.C) + t.Logf("%v", recv) +} + +// test: required +func TestOption(t *testing.T) { + type Recv struct { + X *struct { + C int `json:"c,required"` + D int `json:"d"` + } `json:"X"` + Y string `json:"y"` + } + header := make(http.Header) + header.Set("Content-Type", consts.MIMEApplicationJSON) + + bodyReader := strings.NewReader(`{ + "X": { + "c": 21, + "d": 41 + }, + "y": "y1" + }`) + req := newRequest("", header, nil, bodyReader) + recv := new(Recv) + + err := DefaultBinder().Bind(req.Req, recv, nil) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, 21, recv.X.C) + assert.DeepEqual(t, 41, recv.X.D) + assert.DeepEqual(t, "y1", recv.Y) + + bodyReader = strings.NewReader(`{ + "X": { + }, + "y": "y1" + }`) + req = newRequest("", header, nil, bodyReader) + recv = new(Recv) + err = DefaultBinder().Bind(req.Req, recv, nil) + assert.DeepEqual(t, err.Error(), "'C' field is a 'required' parameter, but the request body does not have this parameter 'X.c'") + assert.DeepEqual(t, 0, recv.X.C) + assert.DeepEqual(t, 0, recv.X.D) + assert.DeepEqual(t, "y1", recv.Y) + + bodyReader = strings.NewReader(`{ + "y": "y1" + }`) + req = newRequest("", header, nil, bodyReader) + recv = new(Recv) + err = DefaultBinder().Bind(req.Req, recv, nil) + if err != nil { + t.Error(err) + } + assert.True(t, recv.X == nil) + assert.DeepEqual(t, "y1", recv.Y) + + type Recv2 struct { + X *struct { + C int `json:"c,required"` + D int `json:"d"` + } `json:"X,required"` + Y string `json:"y"` + } + bodyReader = strings.NewReader(`{ + "y": "y1" + }`) + req = newRequest("", header, nil, bodyReader) + recv2 := new(Recv2) + bindConfig := &BindConfig{} + bindConfig.DisableStructFieldResolve = false + binder := NewDefaultBinder(bindConfig) + err = binder.Bind(req.Req, recv2, nil) + assert.DeepEqual(t, err.Error(), "'X' field is a 'required' parameter, but the request does not have this parameter") + assert.True(t, recv2.X == nil) + assert.DeepEqual(t, "y1", recv2.Y) +} + +func newRequest(u string, header http.Header, cookies []*http.Cookie, bodyReader io.Reader) *mockRequest { + if header == nil { + header = make(http.Header) + } + method := "GET" + var body []byte + if bodyReader != nil { + body, _ = ioutil.ReadAll(bodyReader) + method = "POST" + } + if u == "" { + u = "http://localhost" + } + req := newMockRequest() + req.SetRequestURI(u) + for k, v := range header { + for _, val := range v { + req.Req.Header.Add(k, val) + } + } + if len(body) != 0 { + req.SetBody(body) + req.Req.Header.SetContentLength(len(body)) + } + req.Req.SetMethod(method) + for _, c := range cookies { + req.Req.Header.SetCookie(c.Name, c.Value) + } + return req +} + +func TestQueryStringIssue(t *testing.T) { + type Timestamp struct { + time.Time + } + type Recv struct { + Name *string `query:"name"` + T *Timestamp `query:"t"` + } + req := newRequest("http://localhost:8080/?name=test", nil, nil, nil) + recv := new(Recv) + + err := DefaultBinder().Bind(req.Req, recv, nil) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, "test", *recv.Name) + // DIFF: the type with customized decoder must be a non-nil value + // assert.DeepEqual(t, (*Timestamp)(nil), recv.T) +} + +func TestQueryTypes(t *testing.T) { + type metric string + type count int32 + type metrics []string + type filter struct { + Col1 string + } + + type Recv struct { + A metric + B count + C *count + D metrics `query:"D,required" form:"D,required"` + E metric `cookie:"e" json:"e"` + F filter `json:"f"` + } + query := make(url.Values) + query.Add("A", "qps") + query.Add("B", "123") + query.Add("C", "321") + query.Add("D", "dau") + query.Add("D", "dnu") + contentType, bodyReader, err := newJSONBody( + map[string]interface{}{ + "e": "e-from-jsonbody", + "f": filter{Col1: "abc"}, + }, + ) + if err != nil { + t.Error(err) + } + header := make(http.Header) + header.Set("Content-Type", contentType) + req := newRequest("http://localhost/?"+query.Encode(), header, []*http.Cookie{ + {Name: "e", Value: "e-from-cookie"}, + }, bodyReader) + recv := new(Recv) + + err = DefaultBinder().Bind(req.Req, recv, nil) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, metric("qps"), recv.A) + assert.DeepEqual(t, count(123), recv.B) + assert.DeepEqual(t, count(321), *recv.C) + assert.DeepEqual(t, metrics{"dau", "dnu"}, recv.D) + assert.DeepEqual(t, metric("e-from-cookie"), recv.E) + assert.DeepEqual(t, filter{Col1: "abc"}, recv.F) +} + +func TestNoTagIssue(t *testing.T) { + type x int + type T struct { + x + x2 x + a int + B int + } + req := newRequest("http://localhost:8080/?x=11&x2=12&a=1&B=2", nil, nil, nil) + recv := new(T) + + err := DefaultBinder().Bind(req.Req, recv, nil) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, x(0), recv.x) + assert.DeepEqual(t, x(0), recv.x2) + assert.DeepEqual(t, 0, recv.a) + assert.DeepEqual(t, 2, recv.B) +} + +func TestRegTypeUnmarshal(t *testing.T) { + type Q struct { + A int + B string + } + type T struct { + Q Q `query:"q"` + Qs []*Q `query:"qs"` + Qs2 ***[]*Q `query:"qs"` + } + values := url.Values{} + b, err := json.Marshal(Q{A: 2, B: "y"}) + if err != nil { + t.Error(err) + } + values.Add("q", string(b)) + bs, _ := json.Marshal([]Q{{A: 1, B: "x"}, {A: 2, B: "y"}}) + values.Add("qs", string(bs)) + req := newRequest("http://localhost:8080/?"+values.Encode(), nil, nil, nil) + recv := new(T) + + bindConfig := &BindConfig{} + bindConfig.DisableStructFieldResolve = false + binder := NewDefaultBinder(bindConfig) + err = binder.Bind(req.Req, recv, nil) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, 2, recv.Q.A) + assert.DeepEqual(t, "y", recv.Q.B) + assert.DeepEqual(t, 1, recv.Qs[0].A) + assert.DeepEqual(t, "x", recv.Qs[0].B) + assert.DeepEqual(t, 2, recv.Qs[1].A) + assert.DeepEqual(t, "y", recv.Qs[1].B) + assert.DeepEqual(t, 1, (***recv.Qs2)[0].A) + assert.DeepEqual(t, "x", (***recv.Qs2)[0].B) + assert.DeepEqual(t, 2, (***recv.Qs2)[1].A) + assert.DeepEqual(t, "y", (***recv.Qs2)[1].B) +} + +func TestPathnameBUG(t *testing.T) { + type Currency struct { + CurrencyName *string `form:"currency_name,required" json:"currency_name,required" protobuf:"bytes,1,req,name=currency_name,json=currencyName" query:"currency_name,required"` + CurrencySymbol *string `form:"currency_symbol,required" json:"currency_symbol,required" protobuf:"bytes,2,req,name=currency_symbol,json=currencySymbol" query:"currency_symbol,required"` + SymbolPosition *int32 `form:"symbol_position,required" json:"symbol_position,required" protobuf:"varint,3,req,name=symbol_position,json=symbolPosition" query:"symbol_position,required"` + DecimalPlaces *int32 `form:"decimal_places,required" json:"decimal_places,required" protobuf:"varint,4,req,name=decimal_places,json=decimalPlaces" query:"decimal_places,required"` // 56x56 + DecimalSymbol *string `form:"decimal_symbol,required" json:"decimal_symbol,required" protobuf:"bytes,5,req,name=decimal_symbol,json=decimalSymbol" query:"decimal_symbol,required"` + Separator *string `form:"separator,required" json:"separator,required" protobuf:"bytes,6,req,name=separator" query:"separator,required"` + SeparatorIndex *string `form:"separator_index,required" json:"separator_index,required" protobuf:"bytes,7,req,name=separator_index,json=separatorIndex" query:"separator_index,required"` + Between *string `form:"between,required" json:"between,required" protobuf:"bytes,8,req,name=between" query:"between,required"` + MinPrice *string `form:"min_price" json:"min_price,omitempty" protobuf:"bytes,9,opt,name=min_price,json=minPrice" query:"min_price"` + MaxPrice *string `form:"max_price" json:"max_price,omitempty" protobuf:"bytes,10,opt,name=max_price,json=maxPrice" query:"max_price"` + } + + type CurrencyData struct { + Amount *string `form:"amount,required" json:"amount,required" protobuf:"bytes,1,req,name=amount" query:"amount,required"` + Currency *Currency `form:"currency,required" json:"currency,required" protobuf:"bytes,2,req,name=currency" query:"currency,required"` + } + + type ExchangeCurrencyRequest struct { + PromotionRegion *string `form:"promotion_region,required" json:"promotion_region,required" protobuf:"bytes,1,req,name=promotion_region,json=promotionRegion" query:"promotion_region,required"` + Currency *CurrencyData `form:"currency,required" json:"currency,required" protobuf:"bytes,2,req,name=currency" query:"currency,required"` + Version *int32 `json:"version,omitempty" path:"version" protobuf:"varint,100,opt,name=version"` + } + + z := new(ExchangeCurrencyRequest) + z.Currency = new(CurrencyData) + z.Currency.Currency = new(Currency) + z.PromotionRegion = proto.String("?") + z.Version = proto.Int32(-32) + z.Currency.Amount = proto.String("?") + z.Currency.Currency.CurrencyName = proto.String("?") + z.Currency.Currency.CurrencySymbol = proto.String("?") + z.Currency.Currency.SymbolPosition = proto.Int32(-32) + z.Currency.Currency.DecimalPlaces = proto.Int32(-32) + z.Currency.Currency.DecimalSymbol = proto.String("?") + z.Currency.Currency.Separator = proto.String("?") + z.Currency.Currency.Between = proto.String("?") + z.Currency.Currency.MinPrice = proto.String("?") + z.Currency.Currency.MaxPrice = proto.String("?") + + b, err := json.MarshalIndent(z, "", " ") + if err != nil { + t.Error(err) + } + header := make(http.Header) + header.Set("Content-Type", "application/json;charset=utf-8") + req := newRequest("http://localhost", header, nil, bytes.NewReader(b)) + recv := new(ExchangeCurrencyRequest) + + err = DefaultBinder().Bind(req.Req, recv, nil) + if err != nil { + t.Error(err) + } +} + +// test: required +func TestPathnameBUG2(t *testing.T) { + type CurrencyData struct { + Amount *string `form:"amount,required" json:"amount,required" protobuf:"bytes,1,req,name=amount" query:"amount,required"` + Name *string `form:"name,required" json:"name,required" protobuf:"bytes,2,req,name=name" query:"name,required"` + Symbol *string `form:"symbol" json:"symbol,omitempty" protobuf:"bytes,3,opt,name=symbol" query:"symbol"` + } + type TimeRange struct { + StartTime *int64 `form:"start_time,required" json:"start_time,required" protobuf:"varint,1,req,name=start_time,json=startTime" query:"start_time,required"` + EndTime *int64 `form:"end_time,required" json:"end_time,required" protobuf:"varint,2,req,name=end_time,json=endTime" query:"end_time,required"` + } + type CreateFreeShippingRequest struct { + PromotionName *string `form:"promotion_name,required" json:"promotion_name,required" protobuf:"bytes,1,req,name=promotion_name,json=promotionName" query:"promotion_name,required"` + PromotionRegion *string `form:"promotion_region,required" json:"promotion_region,required" protobuf:"bytes,2,req,name=promotion_region,json=promotionRegion" query:"promotion_region,required"` + TimeRange *TimeRange `form:"time_range,required" json:"time_range,required" protobuf:"bytes,3,req,name=time_range,json=timeRange" query:"time_range,required"` + PromotionBudget *CurrencyData `form:"promotion_budget,required" json:"promotion_budget,required" protobuf:"bytes,4,req,name=promotion_budget,json=promotionBudget" query:"promotion_budget,required"` + Loaded_SellerIds []string `form:"loaded_Seller_ids" json:"loaded_Seller_ids,omitempty" protobuf:"bytes,5,rep,name=loaded_Seller_ids,json=loadedSellerIds" query:"loaded_Seller_ids"` + Version *int32 `json:"version,omitempty" path:"version" protobuf:"varint,100,opt,name=version"` + } + + // z := &CreateFreeShippingRequest{} + // v := ameda.InitSampleValue(reflect.TypeOf(z), 10).Interface().(*CreateFreeShippingRequest) + // b, err := json.MarshalIndent(v, "", " ") + // t.Log(string(b)) + b := []byte(`{ + "promotion_name": "mu", + "promotion_region": "ID", + "time_range": { + "start_time": 1616420139, + "end_time": 1616520139 + }, + "promotion_budget": { + "amount":"10000000", + "name":"USD", + "symbol":"$" + }, + "loaded_Seller_ids": [ + "7493989780026655762","11111","111212121" + ] +}`) + v := new(CreateFreeShippingRequest) + err := json.Unmarshal(b, v) + if err != nil { + t.Error(err) + } + + header := make(http.Header) + header.Set("Content-Type", "application/json;charset=utf-8") + req := newRequest("http://localhost", header, nil, bytes.NewReader(b)) + recv := new(CreateFreeShippingRequest) + + err = DefaultBinder().Bind(req.Req, recv, nil) + if err != nil { + t.Error(err) + } + + assert.DeepEqual(t, v, recv) +} + +func TestRequiredBUG(t *testing.T) { + type Currency struct { + // currencyName *string `form:"currency_name,required" json:"currency_name,required" protobuf:"bytes,1,req,name=currency_name,json=currencyName" query:"currency_name,required"` + CurrencySymbol *string `form:"currency_symbol,required" json:"currency_symbol,required" protobuf:"bytes,2,req,name=currency_symbol,json=currencySymbol" query:"currency_symbol,required"` + } + + type CurrencyData struct { + Amount *string `form:"amount,required" json:"amount,required" protobuf:"bytes,1,req,name=amount" query:"amount,required"` + Slice []*Currency `form:"slice,required" json:"slice,required" protobuf:"bytes,2,req,name=slice" query:"slice,required"` + Map map[string]*Currency `form:"map,required" json:"map,required" protobuf:"bytes,2,req,name=map" query:"map,required"` + } + + type ExchangeCurrencyRequest struct { + PromotionRegion *string `form:"promotion_region,required" json:"promotion_region,required" protobuf:"bytes,1,req,name=promotion_region,json=promotionRegion" query:"promotion_region,required"` + Currency *CurrencyData `form:"currency,required" json:"currency,required" protobuf:"bytes,2,req,name=currency" query:"currency,required"` + } + + z := &ExchangeCurrencyRequest{} + b := []byte(`{ + "promotion_region": "?", + "currency": { + "amount": "?", + "slice": [ + { + "currency_symbol": "?" + } + ], + "map": { + "?": { + "currency_name": "?" + } + } + } + }`) + json.Unmarshal(b, z) + header := make(http.Header) + header.Set("Content-Type", "application/json;charset=utf-8") + req := newRequest("http://localhost", header, nil, bytes.NewReader(b)) + recv := new(ExchangeCurrencyRequest) + + err := DefaultBinder().Bind(req.Req, recv, nil) + // no need for validate + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, z, recv) +} + +func TestIssue25(t *testing.T) { + type Recv struct { + A string + } + header := make(http.Header) + header.Set("A", "from header") + cookies := []*http.Cookie{ + {Name: "A", Value: "from cookie"}, + } + req := newRequest("/1", header, cookies, nil) + recv := new(Recv) + + err := DefaultBinder().Bind(req.Req, recv, nil) + if err != nil { + t.Error(err) + } + // assert.DeepEqual(t, "from cookie", recv.A) + + header2 := make(http.Header) + header2.Set("A", "from header") + cookies2 := []*http.Cookie{} + req2 := newRequest("/2", header2, cookies2, nil) + recv2 := new(Recv) + err2 := DefaultBinder().Bind(req2.Req, recv2, nil) + if err2 != nil { + t.Error(err2) + } + assert.DeepEqual(t, "from header", recv2.A) +} + +func TestIssue26(t *testing.T) { + type Recv struct { + Type string `json:"type,required" vd:"($=='update_target_threshold' && (TargetThreshold)$!='-1') || ($=='update_status' && (Status)$!='-1')"` + RuleName string `json:"rule_name,required" vd:"regexp('^rule[0-9]+$')"` + TargetThreshold string `json:"target_threshold" vd:"regexp('^-?[0-9]+(\\.[0-9]+)?$')"` + Status string `json:"status" vd:"$=='0' || $=='1'"` + Operator string `json:"operator,required" vd:"len($)>0"` + } + + b := []byte(`{ + "status": "1", + "adv": "11520", + "target_deep_external_action": "39", + "package": "test.bytedance.com", + "previous_target_threshold": "0.6", + "deep_external_action": "675", + "rule_name": "rule2", + "deep_bid_type": "54", + "modify_time": "2021-08-24:14:35:20", + "aid": "111", + "operator": "yanghaoze", + "external_action": "76", + "target_threshold": "0.1", + "type": "update_status" +}`) + + recv := new(Recv) + err := json.Unmarshal(b, recv) + if err != nil { + t.Error(err) + } + + header := make(http.Header) + header.Set("Content-Type", consts.MIMEApplicationJSON) + header.Set("A", "from header") + cookies := []*http.Cookie{ + {Name: "A", Value: "from cookie"}, + } + + req := newRequest("/1", header, cookies, bytes.NewReader(b)) + + recv2 := new(Recv) + err = DefaultBinder().Bind(req.Req, recv2, nil) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, recv, recv2) +} + +// FIXME: after 'json unmarshal', the default value will change it +//func TestDefault2(t *testing.T) { +// type Recv struct { +// X **struct { +// Dash string `default:"xxxx"` +// } +// } +// bodyReader := strings.NewReader(`{ +// "X": { +// "Dash": "hello Dash" +// } +// }`) +// header := make(http.Header) +// header.Set("Content-Type", consts.MIMEApplicationJSON) +// req := newRequest("", header, nil, bodyReader) +// recv := new(Recv) +// +// err := DefaultBinder().Bind(req.Req, nil, recv) +// if err != nil { +// t.Error(err) +// } +// assert.DeepEqual(t, "hello Dash", (**recv.X).Dash) +//} + +type ( + files map[string][]file + file interface { + Name() string + Read(p []byte) (n int, err error) + } +) + +func newFormBody2(values url.Values, files files) (contentType string, bodyReader io.Reader) { + if len(files) == 0 { + return "application/x-www-form-urlencoded", strings.NewReader(values.Encode()) + } + pr, pw := io.Pipe() + bodyWriter := multipart.NewWriter(pw) + var fileWriter io.Writer + buf := make([]byte, 32*1024) + go func() { + for fieldName, postfiles := range files { + for _, file := range postfiles { + fileWriter, _ = bodyWriter.CreateFormFile(fieldName, file.Name()) + io.CopyBuffer(fileWriter, file, buf) + } + } + for k, v := range values { + for _, vv := range v { + bodyWriter.WriteField(k, vv) + } + } + bodyWriter.Close() + pw.Close() + }() + return bodyWriter.FormDataContentType(), pr +} + +func newFile(name string, bodyReader io.Reader) file { + return &fileReader{name, bodyReader} +} + +// fileReader file name and bytes. +type fileReader struct { + name string + bodyReader io.Reader +} + +func (f *fileReader) Name() string { + return f.name +} + +func (f *fileReader) Read(p []byte) (int, error) { + return f.bodyReader.Read(p) +} + +func newJSONBody(v interface{}) (contentType string, bodyReader io.Reader, err error) { + b, err := json.Marshal(v) + if err != nil { + return + } + return "application/json;charset=utf-8", bytes.NewReader(b), nil +} diff --git a/pkg/app/server/binding/testdata/hello.pb.go b/pkg/app/server/binding/testdata/hello.pb.go new file mode 100644 index 000000000..8b4bce477 --- /dev/null +++ b/pkg/app/server/binding/testdata/hello.pb.go @@ -0,0 +1,157 @@ +/* + * Copyright 2023 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.30.0 +// protoc v3.21.12 +// source: hello.proto + +package testdata + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type HertzReq struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Name string `protobuf:"bytes,1,opt,name=Name,proto3" json:"Name,omitempty"` +} + +func (x *HertzReq) Reset() { + *x = HertzReq{} + if protoimpl.UnsafeEnabled { + mi := &file_hello_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *HertzReq) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*HertzReq) ProtoMessage() {} + +func (x *HertzReq) ProtoReflect() protoreflect.Message { + mi := &file_hello_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use HertzReq.ProtoReflect.Descriptor instead. +func (*HertzReq) Descriptor() ([]byte, []int) { + return file_hello_proto_rawDescGZIP(), []int{0} +} + +func (x *HertzReq) GetName() string { + if x != nil { + return x.Name + } + return "" +} + +var File_hello_proto protoreflect.FileDescriptor + +var file_hello_proto_rawDesc = []byte{ + 0x0a, 0x0b, 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x05, 0x68, + 0x65, 0x72, 0x74, 0x7a, 0x22, 0x1e, 0x0a, 0x08, 0x48, 0x65, 0x72, 0x74, 0x7a, 0x52, 0x65, 0x71, + 0x12, 0x12, 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, + 0x4e, 0x61, 0x6d, 0x65, 0x42, 0x0d, 0x5a, 0x0b, 0x68, 0x65, 0x72, 0x74, 0x7a, 0x2f, 0x68, 0x65, + 0x6c, 0x6c, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, +} + +var ( + file_hello_proto_rawDescOnce sync.Once + file_hello_proto_rawDescData = file_hello_proto_rawDesc +) + +func file_hello_proto_rawDescGZIP() []byte { + file_hello_proto_rawDescOnce.Do(func() { + file_hello_proto_rawDescData = protoimpl.X.CompressGZIP(file_hello_proto_rawDescData) + }) + return file_hello_proto_rawDescData +} + +var file_hello_proto_msgTypes = make([]protoimpl.MessageInfo, 1) +var file_hello_proto_goTypes = []interface{}{ + (*HertzReq)(nil), // 0: hertz.HertzReq +} +var file_hello_proto_depIdxs = []int32{ + 0, // [0:0] is the sub-list for method output_type + 0, // [0:0] is the sub-list for method input_type + 0, // [0:0] is the sub-list for extension type_name + 0, // [0:0] is the sub-list for extension extendee + 0, // [0:0] is the sub-list for field type_name +} + +func init() { file_hello_proto_init() } +func file_hello_proto_init() { + if File_hello_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_hello_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*HertzReq); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_hello_proto_rawDesc, + NumEnums: 0, + NumMessages: 1, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_hello_proto_goTypes, + DependencyIndexes: file_hello_proto_depIdxs, + MessageInfos: file_hello_proto_msgTypes, + }.Build() + File_hello_proto = out.File + file_hello_proto_rawDesc = nil + file_hello_proto_goTypes = nil + file_hello_proto_depIdxs = nil +} diff --git a/pkg/app/server/binding/testdata/hello.proto b/pkg/app/server/binding/testdata/hello.proto new file mode 100644 index 000000000..e880c3fec --- /dev/null +++ b/pkg/app/server/binding/testdata/hello.proto @@ -0,0 +1,24 @@ +/* + * Copyright 2023 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +syntax = "proto3"; +package hertz; +option go_package = "hertz/hello"; + +message HertzReq { + string Name = 1; +} + diff --git a/pkg/app/server/binding/validator.go b/pkg/app/server/binding/validator.go new file mode 100644 index 000000000..0939b7aef --- /dev/null +++ b/pkg/app/server/binding/validator.go @@ -0,0 +1,46 @@ +/* + * Copyright 2023 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * The MIT License (MIT) + * + * Copyright (c) 2014 Manuel Martínez-Almeida + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + * This file may have been modified by CloudWeGo authors. All CloudWeGo + * Modifications are Copyright 2023 CloudWeGo Authors + */ + +package binding + +type StructValidator interface { + ValidateStruct(interface{}) error + Engine() interface{} +} diff --git a/pkg/app/server/binding/validator_test.go b/pkg/app/server/binding/validator_test.go new file mode 100644 index 000000000..2f85716b5 --- /dev/null +++ b/pkg/app/server/binding/validator_test.go @@ -0,0 +1,35 @@ +/* + * Copyright 2023 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package binding + +import ( + "testing" +) + +func Test_ValidateStruct(t *testing.T) { + type User struct { + Age int `vd:"$>=0&&$<=130"` + } + + user := &User{ + Age: 135, + } + err := DefaultValidator().ValidateStruct(user) + if err == nil { + t.Fatalf("expected an error, but got nil") + } +} diff --git a/pkg/app/server/hertz_test.go b/pkg/app/server/hertz_test.go index a5cf7d350..c70ff7340 100644 --- a/pkg/app/server/hertz_test.go +++ b/pkg/app/server/hertz_test.go @@ -34,6 +34,7 @@ import ( "github.com/cloudwego/hertz/pkg/app" c "github.com/cloudwego/hertz/pkg/app/client" + "github.com/cloudwego/hertz/pkg/app/server/binding" "github.com/cloudwego/hertz/pkg/app/server/registry" "github.com/cloudwego/hertz/pkg/common/config" errs "github.com/cloudwego/hertz/pkg/common/errors" @@ -47,6 +48,7 @@ import ( "github.com/cloudwego/hertz/pkg/protocol/consts" "github.com/cloudwego/hertz/pkg/protocol/http1/req" "github.com/cloudwego/hertz/pkg/protocol/http1/resp" + "github.com/cloudwego/hertz/pkg/route/param" ) func TestHertz_Run(t *testing.T) { @@ -693,7 +695,7 @@ type CloseWithoutResetBuffer interface { func TestOnprepare(t *testing.T) { h1 := New( - WithHostPorts("localhost:9229"), + WithHostPorts("localhost:9333"), WithOnConnect(func(ctx context.Context, conn network.Conn) context.Context { b, err := conn.Peek(3) assert.Nil(t, err) @@ -711,7 +713,7 @@ func TestOnprepare(t *testing.T) { go h1.Spin() time.Sleep(time.Second) - _, _, err := c.Get(context.Background(), nil, "http://127.0.0.1:9229/ping") + _, _, err := c.Get(context.Background(), nil, "http://127.0.0.1:9333/ping") assert.DeepEqual(t, "the server closed connection before returning the first response byte. Make sure the server returns 'Connection: close' response header before closing the connection", err.Error()) h2 := New( @@ -719,13 +721,13 @@ func TestOnprepare(t *testing.T) { conn.Close() return context.Background() }), - WithHostPorts("localhost:9230")) + WithHostPorts("localhost:9331")) h2.GET("/ping", func(ctx context.Context, c *app.RequestContext) { c.JSON(consts.StatusOK, utils.H{"ping": "pong"}) }) go h2.Spin() time.Sleep(time.Second) - _, _, err = c.Get(context.Background(), nil, "http://127.0.0.1:9230/ping") + _, _, err = c.Get(context.Background(), nil, "http://127.0.0.1:9331/ping") if err == nil { t.Fatalf("err should not be nil") } @@ -820,3 +822,170 @@ func TestHertzDisableHeaderNamesNormalizing(t *testing.T) { assert.Nil(t, err) assert.DeepEqual(t, headerValue, res.Header.Get(headerName)) } + +func TestBindConfig(t *testing.T) { + type Req struct { + A int `query:"a"` + } + bindConfig := binding.NewBindConfig() + bindConfig.LooseZeroMode = true + h := New( + WithHostPorts("localhost:9332"), + WithBindConfig(bindConfig)) + h.GET("/bind", func(c context.Context, ctx *app.RequestContext) { + var req Req + err := ctx.BindAndValidate(&req) + if err != nil { + t.Fatal("unexpected error") + } + }) + + go h.Spin() + time.Sleep(100 * time.Millisecond) + hc := http.Client{Timeout: time.Second} + _, err := hc.Get("http://127.0.0.1:9332/bind?a=") + assert.Nil(t, err) + + bindConfig = binding.NewBindConfig() + bindConfig.LooseZeroMode = false + h2 := New( + WithHostPorts("localhost:9448"), + WithBindConfig(bindConfig)) + h2.GET("/bind", func(c context.Context, ctx *app.RequestContext) { + var req Req + err := ctx.BindAndValidate(&req) + if err == nil { + t.Fatal("expect an error") + } + }) + + go h2.Spin() + time.Sleep(100 * time.Millisecond) + + _, err = hc.Get("http://127.0.0.1:9448/bind?a=") + assert.Nil(t, err) + time.Sleep(100 * time.Millisecond) +} + +type mockBinder struct{} + +func (m *mockBinder) Name() string { + return "test binder" +} + +func (m *mockBinder) Bind(request *protocol.Request, i interface{}, params param.Params) error { + return nil +} + +func (m *mockBinder) BindAndValidate(request *protocol.Request, i interface{}, params param.Params) error { + return fmt.Errorf("test binder") +} + +func (m *mockBinder) BindQuery(request *protocol.Request, i interface{}) error { + return nil +} + +func (m *mockBinder) BindHeader(request *protocol.Request, i interface{}) error { + return nil +} + +func (m *mockBinder) BindPath(request *protocol.Request, i interface{}, params param.Params) error { + return nil +} + +func (m *mockBinder) BindForm(request *protocol.Request, i interface{}) error { + return nil +} + +func (m *mockBinder) BindJSON(request *protocol.Request, i interface{}) error { + return nil +} + +func (m *mockBinder) BindProtobuf(request *protocol.Request, i interface{}) error { + return nil +} + +func TestCustomBinder(t *testing.T) { + type Req struct { + A int `query:"a"` + } + h := New( + WithHostPorts("localhost:9334"), + WithCustomBinder(&mockBinder{})) + h.GET("/bind", func(c context.Context, ctx *app.RequestContext) { + var req Req + err := ctx.BindAndValidate(&req) + if err == nil { + t.Fatal("expect an error") + } + assert.DeepEqual(t, "test binder", err.Error()) + }) + + go h.Spin() + time.Sleep(100 * time.Millisecond) + hc := http.Client{Timeout: time.Second} + _, err := hc.Get("http://127.0.0.1:9334/bind?a=") + assert.Nil(t, err) + time.Sleep(100 * time.Millisecond) +} + +func TestValidateConfig(t *testing.T) { + type Req struct { + A int `query:"a" vd:"f($)"` + } + validateConfig := &binding.ValidateConfig{} + validateConfig.MustRegValidateFunc("f", func(args ...interface{}) error { + return fmt.Errorf("test validator") + }) + h := New( + WithHostPorts("localhost:9229")) + h.GET("/bind", func(c context.Context, ctx *app.RequestContext) { + var req Req + err := ctx.BindAndValidate(&req) + if err == nil { + t.Fatal("expect an error") + } + assert.DeepEqual(t, "test validator", err.Error()) + }) + + go h.Spin() + time.Sleep(100 * time.Millisecond) + hc := http.Client{Timeout: time.Second} + _, err := hc.Get("http://127.0.0.1:9229/bind?a=2") + assert.Nil(t, err) + time.Sleep(100 * time.Millisecond) +} + +type mockValidator struct{} + +func (m *mockValidator) ValidateStruct(interface{}) error { + return fmt.Errorf("test mock validator") +} + +func (m *mockValidator) Engine() interface{} { + return nil +} + +func TestCustomValidator(t *testing.T) { + type Req struct { + A int `query:"a" vd:"f($)"` + } + h := New( + WithHostPorts("localhost:9555"), + WithCustomValidator(&mockValidator{})) + h.GET("/bind", func(c context.Context, ctx *app.RequestContext) { + var req Req + err := ctx.BindAndValidate(&req) + if err == nil { + t.Fatal("expect an error") + } + assert.DeepEqual(t, "test mock validator", err.Error()) + }) + + go h.Spin() + time.Sleep(100 * time.Millisecond) + hc := http.Client{Timeout: time.Second} + _, err := hc.Get("http://127.0.0.1:9555/bind?a=2") + assert.Nil(t, err) + time.Sleep(100 * time.Millisecond) +} diff --git a/pkg/app/server/option.go b/pkg/app/server/option.go index c9e3735be..d94a7b0cc 100644 --- a/pkg/app/server/option.go +++ b/pkg/app/server/option.go @@ -23,6 +23,7 @@ import ( "strings" "time" + "github.com/cloudwego/hertz/pkg/app/server/binding" "github.com/cloudwego/hertz/pkg/app/server/registry" "github.com/cloudwego/hertz/pkg/common/config" "github.com/cloudwego/hertz/pkg/common/tracer" @@ -347,6 +348,27 @@ func WithOnConnect(fn func(ctx context.Context, conn network.Conn) context.Conte }} } +// WithBindConfig sets bind config. +func WithBindConfig(bc *binding.BindConfig) config.Option { + return config.Option{F: func(o *config.Options) { + o.BindConfig = bc + }} +} + +// WithCustomBinder sets customized Binder. +func WithCustomBinder(b binding.Binder) config.Option { + return config.Option{F: func(o *config.Options) { + o.CustomBinder = b + }} +} + +// WithCustomValidator sets customized Binder. +func WithCustomValidator(b binding.StructValidator) config.Option { + return config.Option{F: func(o *config.Options) { + o.CustomValidator = b + }} +} + // WithDisableHeaderNamesNormalizing is used to set whether disable header names normalizing. func WithDisableHeaderNamesNormalizing(disable bool) config.Option { return config.Option{F: func(o *config.Options) { diff --git a/pkg/common/config/option.go b/pkg/common/config/option.go index 9ef7ddf42..d8e6de2d0 100644 --- a/pkg/common/config/option.go +++ b/pkg/common/config/option.go @@ -72,6 +72,9 @@ type Options struct { Tracers []interface{} TraceLevel interface{} ListenConfig *net.ListenConfig + BindConfig interface{} + CustomBinder interface{} + CustomValidator interface{} // TransporterNewer is the function to create a transporter. TransporterNewer func(opt *Options) network.Transporter diff --git a/pkg/common/config/option_test.go b/pkg/common/config/option_test.go index 39d92d736..6ee0fee95 100644 --- a/pkg/common/config/option_test.go +++ b/pkg/common/config/option_test.go @@ -53,6 +53,9 @@ func TestDefaultOptions(t *testing.T) { assert.DeepEqual(t, []interface{}{}, options.Tracers) assert.DeepEqual(t, new(interface{}), options.TraceLevel) assert.DeepEqual(t, registry.NoopRegistry, options.Registry) + assert.Nil(t, options.BindConfig) + assert.Nil(t, options.CustomBinder) + assert.Nil(t, options.CustomValidator) assert.DeepEqual(t, false, options.DisableHeaderNamesNormalizing) } diff --git a/pkg/common/utils/utils.go b/pkg/common/utils/utils.go index f002f2964..68778a468 100644 --- a/pkg/common/utils/utils.go +++ b/pkg/common/utils/utils.go @@ -117,3 +117,12 @@ func NextLine(b []byte) ([]byte, []byte, error) { } return b[:n], b[nNext+1:], nil } + +func FilterContentType(content string) string { + for i, char := range content { + if char == ' ' || char == ';' { + return content[:i] + } + } + return content +} diff --git a/pkg/common/utils/utils_test.go b/pkg/common/utils/utils_test.go index 231ad930f..92873b51d 100644 --- a/pkg/common/utils/utils_test.go +++ b/pkg/common/utils/utils_test.go @@ -136,3 +136,9 @@ func TestUtilsNextLine(t *testing.T) { _, _, sErr = NextLine(singleHeaderStr) assert.DeepEqual(t, errNeedMore, sErr) } + +func TestFilterContentType(t *testing.T) { + contentType := "text/plain; charset=utf-8" + contentType = FilterContentType(contentType) + assert.DeepEqual(t, "text/plain", contentType) +} diff --git a/pkg/protocol/consts/headers.go b/pkg/protocol/consts/headers.go index 3c2b82e7e..e4b7b316b 100644 --- a/pkg/protocol/consts/headers.go +++ b/pkg/protocol/consts/headers.go @@ -96,6 +96,7 @@ const ( MIMETextHtml = "text/html" MIMETextCss = "text/css" MIMETextJavascript = "text/javascript" + MIMEMultipartPOSTForm = "multipart/form-data" // MIME application MIMEApplicationOctetStream = "application/octet-stream" @@ -121,6 +122,7 @@ const ( MIMEApplicationOpenXMLWord = "application/vnd.openxmlformats-officedocument.wordprocessingml.document" MIMEApplicationOpenXMLExcel = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" MIMEApplicationOpenXMLPPT = "application/vnd.openxmlformats-officedocument.presentationml.presentation" + MIMEPROTOBUF = "application/x-protobuf" // MIME image MIMEImageJPEG = "image/jpeg" diff --git a/pkg/route/engine.go b/pkg/route/engine.go index bd8fbc1a9..ff8cf5a7e 100644 --- a/pkg/route/engine.go +++ b/pkg/route/engine.go @@ -59,6 +59,7 @@ import ( "github.com/cloudwego/hertz/internal/nocopy" internalStats "github.com/cloudwego/hertz/internal/stats" "github.com/cloudwego/hertz/pkg/app" + "github.com/cloudwego/hertz/pkg/app/server/binding" "github.com/cloudwego/hertz/pkg/app/server/render" "github.com/cloudwego/hertz/pkg/common/config" errs "github.com/cloudwego/hertz/pkg/common/errors" @@ -192,6 +193,10 @@ type Engine struct { // Custom Functions clientIPFunc app.ClientIP formValueFunc app.FormValueFunc + + // Custom Binder and Validator + binder binding.Binder + validator binding.StructValidator } func (engine *Engine) IsTraceEnable() bool { @@ -551,6 +556,41 @@ func (engine *Engine) ServeStream(ctx context.Context, conn network.StreamConn) return errs.ErrNotSupportProtocol } +func (engine *Engine) initBinderAndValidator(opt *config.Options) { + // init validator + engine.validator = binding.DefaultValidator() + if opt.CustomValidator != nil { + customValidator, ok := opt.CustomValidator.(binding.StructValidator) + if !ok { + panic("customized validator can not implement binding.StructValidator") + } + engine.validator = customValidator + } + + if opt.CustomBinder != nil { + customBinder, ok := opt.CustomBinder.(binding.Binder) + if !ok { + panic("customized binder can not implement binding.Binder") + } + engine.binder = customBinder + return + } + // Init binder. Due to the existence of the "BindAndValidate" interface, the Validator needs to be injected here. + defaultBindConfig := binding.NewBindConfig() + defaultBindConfig.Validator = engine.validator + engine.binder = binding.NewDefaultBinder(defaultBindConfig) + if opt.BindConfig != nil { + bConf, ok := opt.BindConfig.(*binding.BindConfig) + if !ok { + panic("bind config error") + } + if bConf.Validator == nil { + bConf.Validator = engine.validator + } + engine.binder = binding.NewDefaultBinder(bConf) + } +} + func NewEngine(opt *config.Options) *Engine { engine := &Engine{ trees: make(MethodTrees, 0, 9), @@ -566,6 +606,7 @@ func NewEngine(opt *config.Options) *Engine { enableTrace: true, options: opt, } + engine.initBinderAndValidator(opt) if opt.TransporterNewer != nil { engine.transport = opt.TransporterNewer(opt) } @@ -665,6 +706,8 @@ func (engine *Engine) recv(ctx *app.RequestContext) { // ServeHTTP makes the router implement the Handler interface. func (engine *Engine) ServeHTTP(c context.Context, ctx *app.RequestContext) { + ctx.SetBinder(engine.binder) + ctx.SetValidator(engine.validator) if engine.PanicHandler != nil { defer engine.recv(ctx) } diff --git a/pkg/route/engine_test.go b/pkg/route/engine_test.go index 350da8177..65f4fb16a 100644 --- a/pkg/route/engine_test.go +++ b/pkg/route/engine_test.go @@ -54,13 +54,16 @@ import ( "time" "github.com/cloudwego/hertz/pkg/app" + "github.com/cloudwego/hertz/pkg/app/server/binding" "github.com/cloudwego/hertz/pkg/common/config" errs "github.com/cloudwego/hertz/pkg/common/errors" "github.com/cloudwego/hertz/pkg/common/test/assert" "github.com/cloudwego/hertz/pkg/common/test/mock" "github.com/cloudwego/hertz/pkg/network" "github.com/cloudwego/hertz/pkg/network/standard" + "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/protocol/consts" + "github.com/cloudwego/hertz/pkg/route/param" ) func TestNew_Engine(t *testing.T) { @@ -623,3 +626,181 @@ func (f *fakeTransporter) ListenAndServe(onData network.OnData) error { // TODO implement me panic("implement me") } + +type mockBinder struct{} + +func (m *mockBinder) Name() string { + return "test binder" +} + +func (m *mockBinder) Bind(request *protocol.Request, i interface{}, params param.Params) error { + return nil +} + +func (m *mockBinder) BindAndValidate(request *protocol.Request, i interface{}, params param.Params) error { + return nil +} + +func (m *mockBinder) BindQuery(request *protocol.Request, i interface{}) error { + return nil +} + +func (m *mockBinder) BindHeader(request *protocol.Request, i interface{}) error { + return nil +} + +func (m *mockBinder) BindPath(request *protocol.Request, i interface{}, params param.Params) error { + return nil +} + +func (m *mockBinder) BindForm(request *protocol.Request, i interface{}) error { + return nil +} + +func (m *mockBinder) BindJSON(request *protocol.Request, i interface{}) error { + return nil +} + +func (m *mockBinder) BindProtobuf(request *protocol.Request, i interface{}) error { + return nil +} + +type mockValidator struct{} + +func (m *mockValidator) ValidateStruct(interface{}) error { + return fmt.Errorf("test mock") +} + +func (m *mockValidator) Engine() interface{} { + return nil +} + +type mockNonValidator struct{} + +func (m *mockNonValidator) ValidateStruct(interface{}) error { + return fmt.Errorf("test mock") +} + +func TestInitBinderAndValidator(t *testing.T) { + defer func() { + if r := recover(); r != nil { + t.Errorf("unexpected panic, %v", r) + } + }() + opt := config.NewOptions([]config.Option{}) + bindConfig := binding.NewBindConfig() + bindConfig.LooseZeroMode = true + opt.BindConfig = bindConfig + binder := &mockBinder{} + opt.CustomBinder = binder + validator := &mockValidator{} + opt.CustomValidator = validator + NewEngine(opt) +} + +func TestInitBinderAndValidatorForPanic(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Errorf("expect a panic, but get nil") + } + }() + opt := config.NewOptions([]config.Option{}) + bindConfig := binding.NewBindConfig() + bindConfig.LooseZeroMode = true + opt.BindConfig = bindConfig + binder := &mockBinder{} + opt.CustomBinder = binder + nonValidator := &mockNonValidator{} + opt.CustomValidator = nonValidator + NewEngine(opt) +} + +func TestBindConfig(t *testing.T) { + type Req struct { + A int `query:"a"` + } + opt := config.NewOptions([]config.Option{}) + bindConfig := binding.NewBindConfig() + bindConfig.LooseZeroMode = false + opt.BindConfig = bindConfig + e := NewEngine(opt) + e.GET("/bind", func(c context.Context, ctx *app.RequestContext) { + var req Req + err := ctx.BindAndValidate(&req) + if err == nil { + t.Fatal("expect an error") + } + }) + performRequest(e, "GET", "/bind?a=") + + bindConfig = binding.NewBindConfig() + bindConfig.LooseZeroMode = true + opt.BindConfig = bindConfig + e = NewEngine(opt) + e.GET("/bind", func(c context.Context, ctx *app.RequestContext) { + var req Req + err := ctx.BindAndValidate(&req) + if err != nil { + t.Fatal("unexpected error") + } + assert.DeepEqual(t, 0, req.A) + }) + performRequest(e, "GET", "/bind?a=") +} + +func TestCustomBinder(t *testing.T) { + type Req struct { + A int `query:"a"` + } + opt := config.NewOptions([]config.Option{}) + opt.CustomBinder = &mockBinder{} + e := NewEngine(opt) + e.GET("/bind", func(c context.Context, ctx *app.RequestContext) { + var req Req + err := ctx.BindAndValidate(&req) + if err != nil { + t.Fatal("unexpected error") + } + assert.NotEqual(t, 2, req.A) + }) + performRequest(e, "GET", "/bind?a=2") +} + +func TestValidateConfig(t *testing.T) { + type Req struct { + A int `query:"a" vd:"f($)"` + } + opt := config.NewOptions([]config.Option{}) + validateConfig := &binding.ValidateConfig{} + validateConfig.MustRegValidateFunc("f", func(args ...interface{}) error { + return fmt.Errorf("test error") + }) + e := NewEngine(opt) + e.GET("/validate", func(c context.Context, ctx *app.RequestContext) { + var req Req + err := ctx.BindAndValidate(&req) + assert.NotNil(t, err) + assert.DeepEqual(t, "test error", err.Error()) + }) + performRequest(e, "GET", "/validate?a=2") +} + +func TestCustomValidator(t *testing.T) { + type Req struct { + A int `query:"a" vd:"d($)"` + } + opt := config.NewOptions([]config.Option{}) + validateConfig := &binding.ValidateConfig{} + validateConfig.MustRegValidateFunc("d", func(args ...interface{}) error { + return fmt.Errorf("test error") + }) + opt.CustomValidator = &mockValidator{} + e := NewEngine(opt) + e.GET("/validate", func(c context.Context, ctx *app.RequestContext) { + var req Req + err := ctx.BindAndValidate(&req) + assert.NotNil(t, err) + assert.DeepEqual(t, "test mock", err.Error()) + }) + performRequest(e, "GET", "/validate?a=2") +} From a924461975fba97e4283d937c28ca80a323a28fa Mon Sep 17 00:00:00 2001 From: cqqqq777 <115757192+cqqqq777@users.noreply.github.com> Date: Fri, 22 Sep 2023 01:49:23 -0500 Subject: [PATCH 16/20] docs: amend signature (#954) --- pkg/app/context.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pkg/app/context.go b/pkg/app/context.go index c607ce73b..46aeee7b3 100644 --- a/pkg/app/context.go +++ b/pkg/app/context.go @@ -535,8 +535,8 @@ func (ctx *RequestContext) String(code int, format string, values ...interface{} // FullPath returns a matched route full path. For not found routes // returns an empty string. // -// router.GET("/user/:id", func(c *hertz.RequestContext) { -// c.FullPath() == "/user/:id" // true +// router.GET("/user/:id", func(c context.Context, ctx *app.RequestContext) { +// ctx.FullPath() == "/user/:id" // true // }) func (ctx *RequestContext) FullPath() string { return ctx.fullPath @@ -989,9 +989,9 @@ func (ctx *RequestContext) GetStringMapStringSlice(key string) (smss map[string] // Param returns the value of the URL param. // It is a shortcut for c.Params.ByName(key) // -// router.GET("/user/:id", func(c *hertz.RequestContext) { +// router.GET("/user/:id", func(c context.Context, ctx *app.RequestContext) { // // a GET request to /user/john -// id := c.Param("id") // id == "john" +// id := ctx.Param("id") // id == "john" // }) func (ctx *RequestContext) Param(key string) string { return ctx.Params.ByName(key) From 8bc14077cf3f382c4d3b6b47543cb72c24610903 Mon Sep 17 00:00:00 2001 From: GuangyuFan <97507466+FGYFFFF@users.noreply.github.com> Date: Tue, 26 Sep 2023 15:44:46 +0800 Subject: [PATCH 17/20] refactor: validate config (#955) --- pkg/app/context_test.go | 6 +- pkg/app/server/binding/binder_test.go | 55 +++++++++++++++ pkg/app/server/binding/config.go | 29 ++++---- pkg/app/server/binding/default.go | 85 +++++++++++++++++------- pkg/app/server/binding/validator.go | 1 + pkg/app/server/binding/validator_test.go | 29 ++++++++ pkg/app/server/hertz_test.go | 79 +++++++++++++++++++++- pkg/app/server/option.go | 7 ++ pkg/common/config/option.go | 1 + pkg/common/config/option_test.go | 1 + pkg/route/engine.go | 14 +++- pkg/route/engine_test.go | 52 ++++++++++++++- 12 files changed, 314 insertions(+), 45 deletions(-) diff --git a/pkg/app/context_test.go b/pkg/app/context_test.go index 22e7f8608..c065d482c 100644 --- a/pkg/app/context_test.go +++ b/pkg/app/context_test.go @@ -884,11 +884,15 @@ func (m *mockValidator) Engine() interface{} { return nil } +func (m *mockValidator) ValidateTag() string { + return "vt" +} + func TestSetValidator(t *testing.T) { m := &mockValidator{} c := NewContext(0) c.SetValidator(m) - c.SetBinder(binding.NewDefaultBinder(&binding.BindConfig{ValidateTag: "vt"})) + c.SetBinder(binding.NewDefaultBinder(&binding.BindConfig{Validator: m})) type User struct { Age int `vt:"$>=0&&$<=130"` } diff --git a/pkg/app/server/binding/binder_test.go b/pkg/app/server/binding/binder_test.go index d106ed7ad..c971776e3 100644 --- a/pkg/app/server/binding/binder_test.go +++ b/pkg/app/server/binding/binder_test.go @@ -1436,6 +1436,61 @@ func Test_BindHeaderNormalize(t *testing.T) { assert.DeepEqual(t, "", result3.Header) } +type ValidateError struct { + ErrType, FailField, Msg string +} + +// Error implements error interface. +func (e *ValidateError) Error() string { + if e.Msg != "" { + return e.ErrType + ": expr_path=" + e.FailField + ", cause=" + e.Msg + } + return e.ErrType + ": expr_path=" + e.FailField + ", cause=invalid" +} + +func Test_ValidatorErrorFactory(t *testing.T) { + type TestBind struct { + A string `query:"a,required"` + } + + r := protocol.NewRequest("GET", "/foo", nil) + r.SetRequestURI("/foo/bar?b=20") + CustomValidateErrFunc := func(failField, msg string) error { + err := ValidateError{ + ErrType: "validateErr", + FailField: "[validateFailField]: " + failField, + Msg: "[validateErrMsg]: " + msg, + } + + return &err + } + + validateConfig := NewValidateConfig() + validateConfig.SetValidatorErrorFactory(CustomValidateErrFunc) + validator := NewValidator(validateConfig) + + var req TestBind + err := Bind(r, &req, nil) + if err == nil { + t.Fatalf("unexpected nil, expected an error") + } + + type TestValidate struct { + B int `query:"b" vd:"$>100"` + } + + var reqValidate TestValidate + err = Bind(r, &reqValidate, nil) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + err = validator.ValidateStruct(&reqValidate) + if err == nil { + t.Fatalf("unexpected nil, expected an error") + } + assert.DeepEqual(t, "validateErr: expr_path=[validateFailField]: B, cause=[validateErrMsg]: ", err.Error()) +} + func Benchmark_Binding(b *testing.B) { type Req struct { Version string `path:"v"` diff --git a/pkg/app/server/binding/config.go b/pkg/app/server/binding/config.go index c122c54c6..81cf30e56 100644 --- a/pkg/app/server/binding/config.go +++ b/pkg/app/server/binding/config.go @@ -22,7 +22,7 @@ import ( "reflect" "time" - "github.com/bytedance/go-tagexpr/v2/validator" + exprValidator "github.com/bytedance/go-tagexpr/v2/validator" inDecoder "github.com/cloudwego/hertz/pkg/app/server/binding/internal/decoder" hJson "github.com/cloudwego/hertz/pkg/common/json" "github.com/cloudwego/hertz/pkg/protocol" @@ -63,10 +63,6 @@ type BindConfig struct { // The default is false. // It is used for BindJSON(). EnableDecoderDisallowUnknownFields bool - // ValidateTag is used to determine if a filed needs to be validated. - // NOTE: - // The default is "vd". - ValidateTag string // TypeUnmarshalFuncs registers customized type unmarshaler. // NOTE: // time.Time is registered by default @@ -82,7 +78,6 @@ func NewBindConfig() *BindConfig { DisableStructFieldResolve: false, EnableDecoderUseNumber: false, EnableDecoderDisallowUnknownFields: false, - ValidateTag: "vd", TypeUnmarshalFuncs: make(map[reflect.Type]inDecoder.CustomizeDecodeFunc), Validator: defaultValidate, } @@ -145,7 +140,12 @@ func (config *BindConfig) UseStdJSONUnmarshaler() { config.UseThirdPartyJSONUnmarshaler(stdJson.Unmarshal) } -type ValidateConfig struct{} +type ValidateErrFactory func(fieldSelector, msg string) error + +type ValidateConfig struct { + ValidateTag string + ErrFactory ValidateErrFactory +} func NewValidateConfig() *ValidateConfig { return &ValidateConfig{} @@ -157,14 +157,15 @@ func NewValidateConfig() *ValidateConfig { // If force=true, allow to cover the existed same funcName. // MustRegValidateFunc will remain in effect once it has been called. func (config *ValidateConfig) MustRegValidateFunc(funcName string, fn func(args ...interface{}) error, force ...bool) { - validator.MustRegFunc(funcName, fn, force...) + exprValidator.MustRegFunc(funcName, fn, force...) } // SetValidatorErrorFactory customizes the factory of validation error. -func (config *ValidateConfig) SetValidatorErrorFactory(validatingErrFactory func(failField, msg string) error) { - if val, ok := DefaultValidator().(*defaultValidator); ok { - val.validate.SetErrorFactory(validatingErrFactory) - } else { - panic("customized validator can not use 'SetValidatorErrorFactory'") - } +func (config *ValidateConfig) SetValidatorErrorFactory(errFactory ValidateErrFactory) { + config.ErrFactory = errFactory +} + +// SetValidatorTag customizes the factory of validation error. +func (config *ValidateConfig) SetValidatorTag(tag string) { + config.ValidateTag = tag } diff --git a/pkg/app/server/binding/default.go b/pkg/app/server/binding/default.go index 28bbc5311..0634f26cf 100644 --- a/pkg/app/server/binding/default.go +++ b/pkg/app/server/binding/default.go @@ -69,7 +69,7 @@ import ( "reflect" "sync" - "github.com/bytedance/go-tagexpr/v2/validator" + exprValidator "github.com/bytedance/go-tagexpr/v2/validator" "github.com/cloudwego/hertz/internal/bytesconv" inDecoder "github.com/cloudwego/hertz/pkg/app/server/binding/internal/decoder" hJson "github.com/cloudwego/hertz/pkg/common/json" @@ -81,10 +81,11 @@ import ( ) const ( - queryTag = "query" - headerTag = "header" - formTag = "form" - pathTag = "path" + queryTag = "query" + headerTag = "header" + formTag = "form" + pathTag = "path" + defaultValidateTag = "vd" ) type decoderInfo struct { @@ -185,14 +186,17 @@ func (b *defaultBinder) bindTag(req *protocol.Request, v interface{}, params par decoder := cached.(decoderInfo) return decoder.decoder(req, params, rv.Elem()) } - + validateTag := defaultValidateTag + if len(b.config.Validator.ValidateTag()) != 0 { + validateTag = b.config.Validator.ValidateTag() + } decodeConfig := &inDecoder.DecodeConfig{ LooseZeroMode: b.config.LooseZeroMode, DisableDefaultTag: b.config.DisableDefaultTag, DisableStructFieldResolve: b.config.DisableStructFieldResolve, EnableDecoderUseNumber: b.config.EnableDecoderUseNumber, EnableDecoderDisallowUnknownFields: b.config.EnableDecoderDisallowUnknownFields, - ValidateTag: b.config.ValidateTag, + ValidateTag: validateTag, TypeUnmarshalFuncs: b.config.TypeUnmarshalFuncs, } decoder, needValidate, err := inDecoder.GetReqDecoder(rv.Type(), tag, decodeConfig) @@ -232,13 +236,17 @@ func (b *defaultBinder) bindTagWithValidate(req *protocol.Request, v interface{} } return err } + validateTag := defaultValidateTag + if len(b.config.Validator.ValidateTag()) != 0 { + validateTag = b.config.Validator.ValidateTag() + } decodeConfig := &inDecoder.DecodeConfig{ LooseZeroMode: b.config.LooseZeroMode, DisableDefaultTag: b.config.DisableDefaultTag, DisableStructFieldResolve: b.config.DisableStructFieldResolve, EnableDecoderUseNumber: b.config.EnableDecoderUseNumber, EnableDecoderDisallowUnknownFields: b.config.EnableDecoderDisallowUnknownFields, - ValidateTag: b.config.ValidateTag, + ValidateTag: validateTag, TypeUnmarshalFuncs: b.config.TypeUnmarshalFuncs, } decoder, needValidate, err := inDecoder.GetReqDecoder(rv.Type(), tag, decodeConfig) @@ -371,39 +379,66 @@ func (b *defaultBinder) bindNonStruct(req *protocol.Request, v interface{}) (err return } -var _ StructValidator = (*defaultValidator)(nil) +var _ StructValidator = (*validator)(nil) + +type validator struct { + validateTag string + validate *exprValidator.Validator +} + +func NewValidator(config *ValidateConfig) StructValidator { + validateTag := defaultValidateTag + if config != nil && len(config.ValidateTag) != 0 { + validateTag = config.ValidateTag + } + vd := exprValidator.New(validateTag).SetErrorFactory(defaultValidateErrorFactory) + if config != nil && config.ErrFactory != nil { + vd.SetErrorFactory(config.ErrFactory) + } + return &validator{ + validateTag: validateTag, + validate: vd, + } +} + +// Error validate error +type validateError struct { + FailPath, Msg string +} -type defaultValidator struct { - once sync.Once - validate *validator.Validator +// Error implements error interface. +func (e *validateError) Error() string { + if e.Msg != "" { + return e.Msg + } + return "invalid parameter: " + e.FailPath } -func NewDefaultValidator(config *ValidateConfig) StructValidator { - return &defaultValidator{} +func defaultValidateErrorFactory(failPath, msg string) error { + return &validateError{ + FailPath: failPath, + Msg: msg, + } } // ValidateStruct receives any kind of type, but only performed struct or pointer to struct type. -func (v *defaultValidator) ValidateStruct(obj interface{}) error { +func (v *validator) ValidateStruct(obj interface{}) error { if obj == nil { return nil } - v.lazyinit() return v.validate.Validate(obj) } -func (v *defaultValidator) lazyinit() { - v.once.Do(func() { - v.validate = validator.Default() - }) -} - // Engine returns the underlying validator -func (v *defaultValidator) Engine() interface{} { - v.lazyinit() +func (v *validator) Engine() interface{} { return v.validate } -var defaultValidate = NewDefaultValidator(nil) +func (v *validator) ValidateTag() string { + return v.validateTag +} + +var defaultValidate = NewValidator(NewValidateConfig()) func DefaultValidator() StructValidator { return defaultValidate diff --git a/pkg/app/server/binding/validator.go b/pkg/app/server/binding/validator.go index 0939b7aef..14d618364 100644 --- a/pkg/app/server/binding/validator.go +++ b/pkg/app/server/binding/validator.go @@ -43,4 +43,5 @@ package binding type StructValidator interface { ValidateStruct(interface{}) error Engine() interface{} + ValidateTag() string } diff --git a/pkg/app/server/binding/validator_test.go b/pkg/app/server/binding/validator_test.go index 2f85716b5..5564282ef 100644 --- a/pkg/app/server/binding/validator_test.go +++ b/pkg/app/server/binding/validator_test.go @@ -33,3 +33,32 @@ func Test_ValidateStruct(t *testing.T) { t.Fatalf("expected an error, but got nil") } } + +func Test_ValidateTag(t *testing.T) { + type User struct { + Age int `query:"age" vt:"$>=0&&$<=130"` + } + + user := &User{ + Age: 135, + } + validateConfig := NewValidateConfig() + validateConfig.ValidateTag = "vt" + vd := NewValidator(validateConfig) + err := vd.ValidateStruct(user) + if err == nil { + t.Fatalf("expected an error, but got nil") + } + + bindConfig := NewBindConfig() + bindConfig.Validator = vd + binder := NewDefaultBinder(bindConfig) + user = &User{} + req := newMockRequest(). + SetRequestURI("http://foobar.com?age=135"). + SetHeaders("h", "header") + err = binder.BindAndValidate(req.Req, user, nil) + if err == nil { + t.Fatalf("expected an error, but got nil") + } +} diff --git a/pkg/app/server/hertz_test.go b/pkg/app/server/hertz_test.go index c70ff7340..5baa3ecf5 100644 --- a/pkg/app/server/hertz_test.go +++ b/pkg/app/server/hertz_test.go @@ -929,7 +929,7 @@ func TestCustomBinder(t *testing.T) { time.Sleep(100 * time.Millisecond) } -func TestValidateConfig(t *testing.T) { +func TestValidateConfigRegValidateFunc(t *testing.T) { type Req struct { A int `query:"a" vd:"f($)"` } @@ -966,6 +966,10 @@ func (m *mockValidator) Engine() interface{} { return nil } +func (m *mockValidator) ValidateTag() string { + return "vd" +} + func TestCustomValidator(t *testing.T) { type Req struct { A int `query:"a" vd:"f($)"` @@ -989,3 +993,76 @@ func TestCustomValidator(t *testing.T) { assert.Nil(t, err) time.Sleep(100 * time.Millisecond) } + +type ValidateError struct { + ErrType, FailField, Msg string +} + +// Error implements error interface. +func (e *ValidateError) Error() string { + if e.Msg != "" { + return e.ErrType + ": expr_path=" + e.FailField + ", cause=" + e.Msg + } + return e.ErrType + ": expr_path=" + e.FailField + ", cause=invalid" +} + +func TestValidateConfigSetSetErrorFactory(t *testing.T) { + type TestValidate struct { + B int `query:"b" vd:"$>100"` + } + CustomValidateErrFunc := func(failField, msg string) error { + err := ValidateError{ + ErrType: "validateErr", + FailField: "[validateFailField]: " + failField, + Msg: "[validateErrMsg]: " + msg, + } + + return &err + } + validateConfig := binding.NewValidateConfig() + validateConfig.SetValidatorErrorFactory(CustomValidateErrFunc) + h := New( + WithHostPorts("localhost:9666"), + WithValidateConfig(validateConfig)) + h.GET("/bind", func(c context.Context, ctx *app.RequestContext) { + var req TestValidate + err := ctx.BindAndValidate(&req) + if err == nil { + t.Fatal("expect an error") + } + assert.DeepEqual(t, "validateErr: expr_path=[validateFailField]: B, cause=[validateErrMsg]: ", err.Error()) + }) + + go h.Spin() + time.Sleep(100 * time.Millisecond) + hc := http.Client{Timeout: time.Second} + _, err := hc.Get("http://127.0.0.1:9666/bind?b=1") + assert.Nil(t, err) + time.Sleep(100 * time.Millisecond) +} + +func TestValidateConfigAndBindConfig(t *testing.T) { + type Req struct { + A int `query:"a" vt:"$>=0&&$<=130"` + } + validateConfig := binding.NewValidateConfig() + validateConfig.ValidateTag = "vt" + h := New( + WithHostPorts("localhost:9876"), + WithValidateConfig(validateConfig)) + h.GET("/bind", func(c context.Context, ctx *app.RequestContext) { + var req Req + err := ctx.BindAndValidate(&req) + if err == nil { + t.Fatal("expect an error") + } + t.Log(err) + }) + + go h.Spin() + time.Sleep(100 * time.Millisecond) + hc := http.Client{Timeout: time.Second} + _, err := hc.Get("http://127.0.0.1:9876/bind?a=135") + assert.Nil(t, err) + time.Sleep(100 * time.Millisecond) +} diff --git a/pkg/app/server/option.go b/pkg/app/server/option.go index d94a7b0cc..fcf380485 100644 --- a/pkg/app/server/option.go +++ b/pkg/app/server/option.go @@ -355,6 +355,13 @@ func WithBindConfig(bc *binding.BindConfig) config.Option { }} } +// WithValidateConfig sets validate config. +func WithValidateConfig(vc *binding.ValidateConfig) config.Option { + return config.Option{F: func(o *config.Options) { + o.ValidateConfig = vc + }} +} + // WithCustomBinder sets customized Binder. func WithCustomBinder(b binding.Binder) config.Option { return config.Option{F: func(o *config.Options) { diff --git a/pkg/common/config/option.go b/pkg/common/config/option.go index d8e6de2d0..048fb366f 100644 --- a/pkg/common/config/option.go +++ b/pkg/common/config/option.go @@ -73,6 +73,7 @@ type Options struct { TraceLevel interface{} ListenConfig *net.ListenConfig BindConfig interface{} + ValidateConfig interface{} CustomBinder interface{} CustomValidator interface{} diff --git a/pkg/common/config/option_test.go b/pkg/common/config/option_test.go index 6ee0fee95..67fcab796 100644 --- a/pkg/common/config/option_test.go +++ b/pkg/common/config/option_test.go @@ -54,6 +54,7 @@ func TestDefaultOptions(t *testing.T) { assert.DeepEqual(t, new(interface{}), options.TraceLevel) assert.DeepEqual(t, registry.NoopRegistry, options.Registry) assert.Nil(t, options.BindConfig) + assert.Nil(t, options.ValidateConfig) assert.Nil(t, options.CustomBinder) assert.Nil(t, options.CustomValidator) assert.DeepEqual(t, false, options.DisableHeaderNamesNormalizing) diff --git a/pkg/route/engine.go b/pkg/route/engine.go index ff8cf5a7e..168c79b48 100644 --- a/pkg/route/engine.go +++ b/pkg/route/engine.go @@ -558,13 +558,21 @@ func (engine *Engine) ServeStream(ctx context.Context, conn network.StreamConn) func (engine *Engine) initBinderAndValidator(opt *config.Options) { // init validator - engine.validator = binding.DefaultValidator() if opt.CustomValidator != nil { customValidator, ok := opt.CustomValidator.(binding.StructValidator) if !ok { - panic("customized validator can not implement binding.StructValidator") + panic("customized validator does not implement binding.StructValidator") } engine.validator = customValidator + } else { + engine.validator = binding.NewValidator(binding.NewValidateConfig()) + if opt.ValidateConfig != nil { + vConf, ok := opt.ValidateConfig.(*binding.ValidateConfig) + if !ok { + panic("opt.ValidateConfig is not the '*binding.ValidateConfig' type") + } + engine.validator = binding.NewValidator(vConf) + } } if opt.CustomBinder != nil { @@ -582,7 +590,7 @@ func (engine *Engine) initBinderAndValidator(opt *config.Options) { if opt.BindConfig != nil { bConf, ok := opt.BindConfig.(*binding.BindConfig) if !ok { - panic("bind config error") + panic("opt.BindConfig is not the '*binding.BindConfig' type") } if bConf.Validator == nil { bConf.Validator = engine.validator diff --git a/pkg/route/engine_test.go b/pkg/route/engine_test.go index 65f4fb16a..37a154bc2 100644 --- a/pkg/route/engine_test.go +++ b/pkg/route/engine_test.go @@ -675,6 +675,10 @@ func (m *mockValidator) Engine() interface{} { return nil } +func (m *mockValidator) ValidateTag() string { + return "vd" +} + type mockNonValidator struct{} func (m *mockNonValidator) ValidateStruct(interface{}) error { @@ -696,6 +700,10 @@ func TestInitBinderAndValidator(t *testing.T) { validator := &mockValidator{} opt.CustomValidator = validator NewEngine(opt) + validateConfig := binding.NewValidateConfig() + opt.ValidateConfig = validateConfig + opt.CustomValidator = nil + NewEngine(opt) } func TestInitBinderAndValidatorForPanic(t *testing.T) { @@ -748,6 +756,48 @@ func TestBindConfig(t *testing.T) { performRequest(e, "GET", "/bind?a=") } +type ValidateError struct { + ErrType, FailField, Msg string +} + +// Error implements error interface. +func (e *ValidateError) Error() string { + if e.Msg != "" { + return e.ErrType + ": expr_path=" + e.FailField + ", cause=" + e.Msg + } + return e.ErrType + ": expr_path=" + e.FailField + ", cause=invalid" +} + +func TestValidateConfigSetErrorFactory(t *testing.T) { + type TestValidate struct { + B int `query:"b" vd:"$>100"` + } + opt := config.NewOptions([]config.Option{}) + CustomValidateErrFunc := func(failField, msg string) error { + err := ValidateError{ + ErrType: "validateErr", + FailField: "[validateFailField]: " + failField, + Msg: "[validateErrMsg]: " + msg, + } + + return &err + } + + validateConfig := binding.NewValidateConfig() + validateConfig.SetValidatorErrorFactory(CustomValidateErrFunc) + opt.ValidateConfig = validateConfig + e := NewEngine(opt) + e.GET("/bind", func(c context.Context, ctx *app.RequestContext) { + var req TestValidate + err := ctx.BindAndValidate(&req) + if err == nil { + t.Fatal("expect an error") + } + assert.DeepEqual(t, "validateErr: expr_path=[validateFailField]: B, cause=[validateErrMsg]: ", err.Error()) + }) + performRequest(e, "GET", "/bind?b=1") +} + func TestCustomBinder(t *testing.T) { type Req struct { A int `query:"a"` @@ -766,7 +816,7 @@ func TestCustomBinder(t *testing.T) { performRequest(e, "GET", "/bind?a=2") } -func TestValidateConfig(t *testing.T) { +func TestValidateRegValidateFunc(t *testing.T) { type Req struct { A int `query:"a" vd:"f($)"` } From 3d88c963bf39c4a6c8f197ff61ef67d9fa8e5276 Mon Sep 17 00:00:00 2001 From: GuangyuFan <97507466+FGYFFFF@users.noreply.github.com> Date: Tue, 26 Sep 2023 16:47:19 +0800 Subject: [PATCH 18/20] chore(hz): release v070 (#959) --- cmd/hz/meta/const.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cmd/hz/meta/const.go b/cmd/hz/meta/const.go index abda7c71c..8a36ba2f6 100644 --- a/cmd/hz/meta/const.go +++ b/cmd/hz/meta/const.go @@ -19,7 +19,7 @@ package meta import "runtime" // Version hz version -const Version = "v0.6.7" +const Version = "v0.7.0" const DefaultServiceName = "hertz_service" From b36830f891a7c720e61cbb755f89f864c0f8fb73 Mon Sep 17 00:00:00 2001 From: GuangyuFan <97507466+FGYFFFF@users.noreply.github.com> Date: Tue, 26 Sep 2023 19:07:08 +0800 Subject: [PATCH 19/20] chore: upgrade netpoll to v0.5.0 (#961) --- go.mod | 2 +- go.sum | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.mod b/go.mod index 25f20f8a7..3619b833d 100644 --- a/go.mod +++ b/go.mod @@ -7,7 +7,7 @@ require ( github.com/bytedance/gopkg v0.0.0-20220413063733-65bf48ffb3a7 github.com/bytedance/mockey v1.2.1 github.com/bytedance/sonic v1.8.1 - github.com/cloudwego/netpoll v0.4.2-0.20230807055039-52fd5fb7b00f + github.com/cloudwego/netpoll v0.5.0 github.com/fsnotify/fsnotify v1.5.4 github.com/tidwall/gjson v1.14.4 golang.org/x/sync v0.0.0-20210220032951-036812b2e83c diff --git a/go.sum b/go.sum index f86006603..3cf3c2f96 100644 --- a/go.sum +++ b/go.sum @@ -10,8 +10,8 @@ github.com/bytedance/sonic v1.8.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZX github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams= github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk= -github.com/cloudwego/netpoll v0.4.2-0.20230807055039-52fd5fb7b00f h1:8iWPKjHdXl4tjcSxUJTavnhRL5JPupYvxbtsAlm2Igw= -github.com/cloudwego/netpoll v0.4.2-0.20230807055039-52fd5fb7b00f/go.mod h1:xVefXptcyheopwNDZjDPcfU6kIjZXZ4nY550k1yH9eQ= +github.com/cloudwego/netpoll v0.5.0 h1:oRrOp58cPCvK2QbMozZNDESvrxQaEHW2dCimmwH1lcU= +github.com/cloudwego/netpoll v0.5.0/go.mod h1:xVefXptcyheopwNDZjDPcfU6kIjZXZ4nY550k1yH9eQ= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= From 40a3d0e7d1f0952883895ba201ea31a452c38465 Mon Sep 17 00:00:00 2001 From: alice <90381261+alice-yyds@users.noreply.github.com> Date: Tue, 26 Sep 2023 19:07:41 +0800 Subject: [PATCH 20/20] chore: update version v0.7.0 --- version.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/version.go b/version.go index 279360628..2db9ab991 100644 --- a/version.go +++ b/version.go @@ -19,5 +19,5 @@ package hertz // Name and Version info of this framework, used for statistics and debug const ( Name = "Hertz" - Version = "v0.6.8" + Version = "v0.7.0" )