From 15c89132fb6540efe74fc5b8dab32ea16f57f22b Mon Sep 17 00:00:00 2001 From: fgy Date: Mon, 9 Jan 2023 14:34:07 +0800 Subject: [PATCH 01/91] refactor: binding --- .../server/binding_v2/base_type_decoder.go | 107 ++++++++++ pkg/app/server/binding_v2/binder.go | 80 ++++++++ pkg/app/server/binding_v2/binder_test.go | 193 ++++++++++++++++++ .../binding_v2/customized_type_decoder.go | 25 +++ pkg/app/server/binding_v2/decoder.go | 108 ++++++++++ pkg/app/server/binding_v2/getter.go | 95 +++++++++ pkg/app/server/binding_v2/reflect.go | 35 ++++ .../server/binding_v2/slice_type_decoder.go | 141 +++++++++++++ pkg/app/server/binding_v2/tag.go | 66 ++++++ .../server/binding_v2/text_decoder/bool.go | 17 ++ .../server/binding_v2/text_decoder/float.go | 19 ++ pkg/app/server/binding_v2/text_decoder/int.go | 19 ++ .../server/binding_v2/text_decoder/string.go | 11 + .../binding_v2/text_decoder/text_decoder.go | 52 +++++ .../server/binding_v2/text_decoder/unit.go | 19 ++ 15 files changed, 987 insertions(+) create mode 100644 pkg/app/server/binding_v2/base_type_decoder.go create mode 100644 pkg/app/server/binding_v2/binder.go create mode 100644 pkg/app/server/binding_v2/binder_test.go create mode 100644 pkg/app/server/binding_v2/customized_type_decoder.go create mode 100644 pkg/app/server/binding_v2/decoder.go create mode 100644 pkg/app/server/binding_v2/getter.go create mode 100644 pkg/app/server/binding_v2/reflect.go create mode 100644 pkg/app/server/binding_v2/slice_type_decoder.go create mode 100644 pkg/app/server/binding_v2/tag.go create mode 100644 pkg/app/server/binding_v2/text_decoder/bool.go create mode 100644 pkg/app/server/binding_v2/text_decoder/float.go create mode 100644 pkg/app/server/binding_v2/text_decoder/int.go create mode 100644 pkg/app/server/binding_v2/text_decoder/string.go create mode 100644 pkg/app/server/binding_v2/text_decoder/text_decoder.go create mode 100644 pkg/app/server/binding_v2/text_decoder/unit.go diff --git a/pkg/app/server/binding_v2/base_type_decoder.go b/pkg/app/server/binding_v2/base_type_decoder.go new file mode 100644 index 000000000..b73577fca --- /dev/null +++ b/pkg/app/server/binding_v2/base_type_decoder.go @@ -0,0 +1,107 @@ +package binding_v2 + +import ( + "fmt" + "reflect" + + "github.com/cloudwego/hertz/pkg/app/server/binding_v2/text_decoder" + "github.com/cloudwego/hertz/pkg/protocol" +) + +type baseTypeFieldTextDecoder struct { + index int + fieldName string + tagInfos []TagInfo // query,param,header,respHeader ... + fieldType reflect.Type + decoder text_decoder.TextDecoder +} + +func (d *baseTypeFieldTextDecoder) Decode(req *protocol.Request, params PathParams, reqValue reflect.Value) error { + var text string + var defaultValue string + // 最大努力交付,对齐 hertz 现有设计 + for _, tagInfo := range d.tagInfos { + if tagInfo.Key == jsonTag { + continue + } + ret := tagInfo.Getter(req, params, tagInfo.Value) + defaultValue = tagInfo.Default + if len(ret) != 0 { + // 非数组/切片类型,只取第一个值作为只 + text = ret[0] + break + } + } + if len(text) == 0 && len(defaultValue) != 0 { + text = defaultValue + } + if text == "" { + return nil + } + + var err error + + // Pointer support for struct elems + field := reqValue.Field(d.index) + if field.Kind() == reflect.Ptr { + elem := reflect.New(d.fieldType) + err = d.decoder.UnmarshalString(text, elem.Elem()) + if err != nil { + return fmt.Errorf("unable to decode '%s' as %s: %w", text, d.fieldType.Name(), err) + } + + field.Set(elem) + + return nil + } + + // Non-pointer elems + err = d.decoder.UnmarshalString(text, field) + if err != nil { + return fmt.Errorf("unable to decode '%s' as %s: %w", text, d.fieldType.Name(), err) + } + + return nil +} + +func getBaseTypeTextDecoder(field reflect.StructField, index int, tagInfos []TagInfo) ([]decoder, error) { + for idx, tagInfo := range tagInfos { + switch tagInfo.Key { + case pathTag: + tagInfos[idx].Getter = PathParam + case formTag: + tagInfos[idx].Getter = Form + case queryTag: + tagInfos[idx].Getter = Query + case cookieTag: + tagInfos[idx].Getter = Cookie + case headerTag: + tagInfos[idx].Getter = Header + case jsonTag: + // do nothing + case rawBodyTag: + tagInfo.Getter = RawBody + default: + } + } + + fieldType := field.Type + if field.Type.Kind() == reflect.Ptr { + fieldType = field.Type.Elem() + } + + textDecoder, err := text_decoder.SelectTextDecoder(fieldType) + if err != nil { + return nil, err + } + + fieldDecoder := &baseTypeFieldTextDecoder{ + index: index, + fieldName: field.Name, + tagInfos: tagInfos, + decoder: textDecoder, + fieldType: fieldType, + } + + return []decoder{fieldDecoder}, nil +} diff --git a/pkg/app/server/binding_v2/binder.go b/pkg/app/server/binding_v2/binder.go new file mode 100644 index 000000000..f116d0cad --- /dev/null +++ b/pkg/app/server/binding_v2/binder.go @@ -0,0 +1,80 @@ +package binding_v2 + +import ( + "encoding/json" + "fmt" + "reflect" + "sync" + + "github.com/cloudwego/hertz/pkg/protocol" + "google.golang.org/protobuf/proto" +) + +// PathParams parameter acquisition interface on the URL path +type PathParams interface { + Get(name string) (string, bool) +} + +type Binder interface { + Name() string + Bind(*protocol.Request, PathParams, interface{}) error +} + +type Bind struct { + decoderCache sync.Map +} + +func (b *Bind) Name() string { + return "hertz" +} + +func (b *Bind) Bind(req *protocol.Request, params PathParams, v interface{}) error { + // todo: 先做 body unmarshal, 先尝试做 body 绑定,然后再尝试绑定其他内容 + err := b.PreBindBody(req, v) + if err != nil { + return err + } + rv, typeID := valueAndTypeID(v) + if rv.Kind() != reflect.Pointer || rv.IsNil() { + return fmt.Errorf("receiver must be a non-nil pointer") + } + cached, ok := b.decoderCache.Load(typeID) + if ok { + // cached decoder, fast path + decoder := cached.(Decoder) + return decoder(req, params, rv.Elem()) + } + + decoder, err := getReqDecoder(rv.Type()) + if err != nil { + return err + } + + b.decoderCache.Store(typeID, decoder) + return decoder(req, params, rv.Elem()) +} + +var ( + jsonContentTypeBytes = "application/json; charset=utf-8" + protobufContentType = "application/x-protobuf" +) + +// best effort binding +func (b *Bind) PreBindBody(req *protocol.Request, v interface{}) error { + if req.Header.ContentLength() <= 0 { + return nil + } + switch string(req.Header.ContentType()) { + case jsonContentTypeBytes: + // todo: 对齐gin, 添加 "EnableDecoderUseNumber"/"EnableDecoderDisallowUnknownFields" 接口 + return json.Unmarshal(req.Body(), v) + case protobufContentType: + msg, ok := v.(proto.Message) + if !ok { + return fmt.Errorf("%s can not implement 'proto.Message'", v) + } + return proto.Unmarshal(req.Body(), msg) + default: + return nil + } +} diff --git a/pkg/app/server/binding_v2/binder_test.go b/pkg/app/server/binding_v2/binder_test.go new file mode 100644 index 000000000..a073fc94a --- /dev/null +++ b/pkg/app/server/binding_v2/binder_test.go @@ -0,0 +1,193 @@ +package binding_v2 + +import ( + "fmt" + "testing" + + "github.com/cloudwego/hertz/pkg/app/server/binding" + "github.com/cloudwego/hertz/pkg/common/test/assert" + "github.com/cloudwego/hertz/pkg/protocol" + "github.com/cloudwego/hertz/pkg/route/param" +) + +func TestBind_BaseType(t *testing.T) { + bind := Bind{} + type Req struct { + Version int `path:"v"` + ID int `query:"id"` + Header string `header:"H"` + Form string `form:"f"` + } + + req := &protocol.Request{} + req.SetRequestURI("http://foobar.com?id=12") + req.Header.Set("H", "header") // disableNormalizing + req.PostArgs().Add("f", "form") + req.Header.SetContentTypeBytes([]byte("application/x-www-form-urlencoded")) + var params param.Params + params = append(params, param.Param{ + Key: "v", + Value: "1", + }) + + var result Req + + err := bind.Bind(req, params, &result) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, 1, result.Version) + assert.DeepEqual(t, 12, result.ID) + assert.DeepEqual(t, "header", result.Header) + assert.DeepEqual(t,"form", result.Form) +} + +func TestBind_SliceType(t *testing.T) { + bind := Bind{} + type Req struct { + ID []int `query:"id"` + Str [3]string `query:"str"` + Byte []byte `query:"b"` + } + IDs := []int{11, 12, 13} + Strs := [3]string{"qwe", "asd", "zxc"} + Bytes := []byte("123") + + req := &protocol.Request{} + req.SetRequestURI(fmt.Sprintf("http://foobar.com?id=%d&id=%d&id=%d&str=%s&str=%s&str=%s&b=%d&b=%d&b=%d", IDs[0], IDs[1], IDs[2], Strs[0], Strs[1], Strs[2], Bytes[0], Bytes[1], Bytes[2])) + + var result Req + + err := bind.Bind(req, nil, &result) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, 3, len(result.ID)) + for idx, val := range IDs { + assert.DeepEqual(t, val, result.ID[idx]) + } + assert.DeepEqual(t, 3, len(result.Str)) + for idx, val := range Strs { + assert.DeepEqual(t, val, result.Str[idx]) + } + assert.DeepEqual(t, 3, len(result.Byte)) + for idx, val := range Bytes { + assert.DeepEqual(t, val, result.Byte[idx]) + } +} + +func TestBind_JSON(t *testing.T) { + bind := Bind{} + type Req struct { + J1 string `json:"j1"` + J2 int `json:"j2" query:"j2"` // 1. json unmarshal 2. query binding cover + // todo: map + J3 []byte `json:"j3"` + J4 [2]string `json:"j4"` + } + J3s := []byte("12") + J4s := [2]string{"qwe", "asd"} + + req := &protocol.Request{} + req.SetRequestURI("http://foobar.com?j2=13") + req.Header.SetContentTypeBytes([]byte(jsonContentTypeBytes)) + data := []byte(fmt.Sprintf(`{"j1":"j1", "j2":12, "j3":[%d, %d], "j4":["%s", "%s"]}`, J3s[0], J3s[1], J4s[0], J4s[1])) + req.SetBody(data) + req.Header.SetContentLength(len(data)) + var result Req + err := bind.Bind(req, nil, &result) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, "j1", result.J1) + assert.DeepEqual(t, 13, result.J2) + for idx, val := range J3s { + assert.DeepEqual(t, val, result.J3[idx]) + } + for idx, val := range J4s { + assert.DeepEqual(t, val, result.J4[idx]) + } +} + +func Benchmark_V2(b *testing.B) { + bind := Bind{} + type Req struct { + Version string `path:"v"` + ID int `query:"id"` + Header string `header:"h"` + Form string `form:"f"` + } + + req := &protocol.Request{} + req.SetRequestURI("http://foobar.com?id=12") + req.Header.Set("h", "header") + req.PostArgs().Add("f", "form") + req.Header.SetContentTypeBytes([]byte("application/x-www-form-urlencoded")) + var params param.Params + params = append(params, param.Param{ + Key: "v", + Value: "1", + }) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + var result Req + err := bind.Bind(req, params, &result) + if err != nil { + b.Error(err) + } + if result.ID != 12 { + b.Error("Id failed") + } + if result.Form != "form" { + b.Error("form failed") + } + if result.Header != "header" { + b.Error("header failed") + } + if result.Version != "1" { + b.Error("path failed") + } + } +} + +func Benchmark_V1(b *testing.B) { + type Req struct { + Version string `path:"v"` + ID int `query:"id"` + Header string `header:"h"` + Form string `form:"f"` + } + + req := &protocol.Request{} + req.SetRequestURI("http://foobar.com?id=12") + req.Header.Set("h", "header") + req.PostArgs().Add("f", "form") + req.Header.SetContentTypeBytes([]byte("application/x-www-form-urlencoded")) + var params param.Params + params = append(params, param.Param{ + Key: "v", + Value: "1", + }) + + b.ResetTimer() + for i := 0; i < b.N; i++ { + var result Req + err := binding.Bind(req, &result, params) + if err != nil { + b.Error(err) + } + if result.ID != 12 { + b.Error("Id failed") + } + if result.Form != "form" { + b.Error("form failed") + } + if result.Header != "header" { + b.Error("header failed") + } + if result.Version != "1" { + b.Error("path failed") + } + } +} diff --git a/pkg/app/server/binding_v2/customized_type_decoder.go b/pkg/app/server/binding_v2/customized_type_decoder.go new file mode 100644 index 000000000..97a2b660b --- /dev/null +++ b/pkg/app/server/binding_v2/customized_type_decoder.go @@ -0,0 +1,25 @@ +package binding_v2 + +import ( + "reflect" + + "github.com/cloudwego/hertz/pkg/protocol" +) + +type customizedFieldTextDecoder struct { + index int + fieldName string + fieldType reflect.Type +} + +func (d *customizedFieldTextDecoder) Decode(req *protocol.Request, params PathParams, reqValue reflect.Value) error { + v := reflect.New(d.fieldType) + decoder := v.Interface().(FieldCustomizedDecoder) + + if err := decoder.CustomizedFieldDecode(req, params); err != nil { + return err + } + + reqValue.Field(d.index).Set(v.Elem()) + return nil +} diff --git a/pkg/app/server/binding_v2/decoder.go b/pkg/app/server/binding_v2/decoder.go new file mode 100644 index 000000000..a9b1795b4 --- /dev/null +++ b/pkg/app/server/binding_v2/decoder.go @@ -0,0 +1,108 @@ +package binding_v2 + +import ( + "fmt" + "reflect" + + "github.com/cloudwego/hertz/pkg/protocol" +) + +type decoder interface { + Decode(req *protocol.Request, params PathParams, reqValue reflect.Value) error +} + +type FieldCustomizedDecoder interface { + CustomizedFieldDecode(req *protocol.Request, params PathParams) error +} + +type Decoder func(req *protocol.Request, params PathParams, rv reflect.Value) error + +var fieldDecoderType = reflect.TypeOf((*FieldCustomizedDecoder)(nil)).Elem() + +func getReqDecoder(rt reflect.Type) (Decoder, error) { + var decoders []decoder + + el := rt.Elem() + if el.Kind() != reflect.Struct { + // todo: 增加对map的支持 + return nil, fmt.Errorf("unsupport non-struct type binding") + } + + for i := 0; i < el.NumField(); i++ { + if !el.Field(i).IsExported() { + // ignore unexported field + continue + } + + dec, err := getFieldDecoder(el.Field(i), i) + if err != nil { + return nil, err + } + + if dec != nil { + decoders = append(decoders, dec...) + } + } + + return func(req *protocol.Request, params PathParams, rv reflect.Value) error { + for _, decoder := range decoders { + err := decoder.Decode(req, params, rv) + if err != nil { + return err + } + } + + return nil + }, nil +} + +func getFieldDecoder(field reflect.StructField, index int) ([]decoder, error) { + if reflect.PtrTo(field.Type).Implements(fieldDecoderType) { + return []decoder{&customizedFieldTextDecoder{index: index, fieldName: field.Name, fieldType: field.Type}}, nil + } + + fieldTagInfos := lookupFieldTags(field) + if len(fieldTagInfos) == 0 { + // todo: 如果没定义尝试给其赋值所有 tag + return nil, nil + } + + // todo: 用户自定义text信息解析 + //if reflect.PtrTo(field.Type).Implements(textUnmarshalerType) { + // return compileTextBasedDecoder(field, index, tagScope, tagContent) + //} + + // todo: reflect Map + if field.Type.Kind() == reflect.Slice || field.Type.Kind() == reflect.Array { + return getSliceFieldDecoder(field, index, fieldTagInfos) + } + + // Nested binding support + if field.Type.Kind() == reflect.Ptr { + field.Type = field.Type.Elem() + } + // 递归每一个 struct + if field.Type.Kind() == reflect.Struct { + var decoders []decoder + el := field.Type + + for i := 0; i < el.NumField(); i++ { + if !el.Field(i).IsExported() { + // ignore unexported field + continue + } + dec, err := getFieldDecoder(el.Field(i), i) + if err != nil { + return nil, err + } + + if dec != nil { + decoders = append(decoders, dec...) + } + } + + return decoders, nil + } + + return getBaseTypeTextDecoder(field, index, fieldTagInfos) +} diff --git a/pkg/app/server/binding_v2/getter.go b/pkg/app/server/binding_v2/getter.go new file mode 100644 index 000000000..75ace6117 --- /dev/null +++ b/pkg/app/server/binding_v2/getter.go @@ -0,0 +1,95 @@ +package binding_v2 + +import ( + "github.com/cloudwego/hertz/internal/bytesconv" + "github.com/cloudwego/hertz/pkg/protocol" +) + +// todo: 优化,对于非数组类型的解析,要不要再提供一个不返回 []string 的 + +type getter func(req *protocol.Request, params PathParams, key string, defaultValue ...string) (ret []string) + +// todo string 强转优化 +func PathParam(req *protocol.Request, params PathParams, key string, defaultValue ...string) (ret []string) { + value, _ := params.Get(key) + + if len(value) == 0 && len(defaultValue) != 0 { + value = defaultValue[0] + } + ret = append(ret, value) + return +} + +// todo 区分postform和multipartform +func Form(req *protocol.Request, params PathParams, key string, defaultValue ...string) (ret []string) { + req.URI().QueryArgs().VisitAll(func(queryKey, value []byte) { + if bytesconv.B2s(queryKey) == key { + ret = append(ret, string(value)) + } + }) + req.PostArgs().VisitAll(func(formKey, value []byte) { + if bytesconv.B2s(formKey) == key { + ret = append(ret, string(value)) + } + }) + + if len(ret) == 0 && len(defaultValue) != 0 { + ret = append(ret, defaultValue...) + } + + return +} + +func Query(req *protocol.Request, params PathParams, key string, defaultValue ...string) (ret []string) { + req.URI().QueryArgs().VisitAll(func(queryKey, value []byte) { + if bytesconv.B2s(queryKey) == key { + ret = append(ret, string(value)) + } + }) + if len(ret) == 0 && len(defaultValue) != 0 { + ret = append(ret, defaultValue...) + } + + return +} + +// todo: cookie +func Cookie(req *protocol.Request, params PathParams, key string, defaultValue ...string) (ret []string) { + req.Header.VisitAllCookie(func(cookieKey, value []byte) { + if bytesconv.B2s(cookieKey) == key { + ret = append(ret, string(value)) + } + }) + + if len(ret) == 0 && len(defaultValue) != 0 { + ret = append(ret, defaultValue...) + } + + return +} + +func Header(req *protocol.Request, params PathParams, key string, defaultValue ...string) (ret []string) { + req.Header.VisitAll(func(headerKey, value []byte) { + if bytesconv.B2s(headerKey) == key { + ret = append(ret, string(value)) + } + }) + + if len(ret) == 0 && len(defaultValue) != 0 { + ret = append(ret, defaultValue...) + } + + return +} + +func Json(req *protocol.Request, params PathParams, key string, defaultValue ...string) (ret []string) { + // do nothing + return +} + +func RawBody(req *protocol.Request, params PathParams, key string, defaultValue ...string) (ret []string) { + if req.Header.ContentLength() > 0 { + ret = append(ret, string(req.Body())) + } + return +} diff --git a/pkg/app/server/binding_v2/reflect.go b/pkg/app/server/binding_v2/reflect.go new file mode 100644 index 000000000..0e7fae7fb --- /dev/null +++ b/pkg/app/server/binding_v2/reflect.go @@ -0,0 +1,35 @@ +package binding_v2 + +import ( + "reflect" + "unsafe" +) + +func valueAndTypeID(v interface{}) (reflect.Value, uintptr) { + header := (*emptyInterface)(unsafe.Pointer(&v)) + + rv := reflect.ValueOf(v) + return rv, header.typeID +} + +type emptyInterface struct { + typeID uintptr + dataPtr unsafe.Pointer +} + +// ReferenceValue convert T to *T, the ptrDepth is the count of '*'. +func ReferenceValue(v reflect.Value, ptrDepth int) reflect.Value { + switch { + case ptrDepth > 0: + for ; ptrDepth > 0; ptrDepth-- { + vv := reflect.New(v.Type()) + vv.Elem().Set(v) + v = vv + } + case ptrDepth < 0: + for ; ptrDepth < 0 && v.Kind() == reflect.Ptr; ptrDepth++ { + v = v.Elem() + } + } + return v +} diff --git a/pkg/app/server/binding_v2/slice_type_decoder.go b/pkg/app/server/binding_v2/slice_type_decoder.go new file mode 100644 index 000000000..420bc6cba --- /dev/null +++ b/pkg/app/server/binding_v2/slice_type_decoder.go @@ -0,0 +1,141 @@ +package binding_v2 + +import ( + "encoding/json" + "fmt" + "reflect" + + "github.com/cloudwego/hertz/internal/bytesconv" + "github.com/cloudwego/hertz/pkg/app/server/binding_v2/text_decoder" + "github.com/cloudwego/hertz/pkg/protocol" +) + +type sliceTypeFieldTextDecoder struct { + index int + fieldName string + isArray bool + tagInfos []TagInfo // query,param,header,respHeader ... + fieldType reflect.Type +} + +func (d *sliceTypeFieldTextDecoder) Decode(req *protocol.Request, params PathParams, reqValue reflect.Value) error { + var texts []string + for _, tagInfo := range d.tagInfos { + if tagInfo.Key == jsonTag { + continue + } + texts = tagInfo.Getter(req, params, tagInfo.Value) + // todo: 数组默认值 + // defaultValue = tagInfo.Default + if len(texts) != 0 { + break + } + } + if len(texts) == 0 { + return nil + } + + field := reqValue.Field(d.index) + + if d.isArray { + if len(texts) != field.Len() { + return fmt.Errorf("%q is not valid value for %s", texts, field.Type().String()) + } + } else { + // slice need creating enough capacity + field = reflect.MakeSlice(field.Type(), len(texts), len(texts)) + } + + // handle multiple pointer + var ptrDepth int + t := d.fieldType.Elem() + elemKind := t.Kind() + for elemKind == reflect.Ptr { + t = t.Elem() + elemKind = t.Kind() + ptrDepth++ + } + + for idx, text := range texts { + var vv reflect.Value + vv, err := stringToValue(t, text) + if err != nil { + return err + } + field.Index(idx).Set(ReferenceValue(vv, ptrDepth)) + } + reqValue.Field(d.index).Set(field) + + return nil +} + +// 数组/切片类型的decoder, +// 对于map和struct类型的数组元素直接使用unmarshal,不做嵌套处理 +func getSliceFieldDecoder(field reflect.StructField, index int, tagInfos []TagInfo) ([]decoder, error) { + if !(field.Type.Kind() == reflect.Slice || field.Type.Kind() == reflect.Array) { + return nil, fmt.Errorf("unexpected type %s, expected slice or array", field.Type.String()) + } + isArray := false + if field.Type.Kind() == reflect.Array { + isArray = true + } + for idx, tagInfo := range tagInfos { + switch tagInfo.Key { + case pathTag: + tagInfos[idx].Getter = PathParam + case formTag: + tagInfos[idx].Getter = Form + case queryTag: + tagInfos[idx].Getter = Query + case cookieTag: + tagInfos[idx].Getter = Cookie + case headerTag: + tagInfos[idx].Getter = Header + case jsonTag: + // do nothing + case rawBodyTag: + tagInfo.Getter = RawBody + default: + } + } + + fieldType := field.Type + if field.Type.Kind() == reflect.Ptr { + fieldType = field.Type.Elem() + } + + fieldDecoder := &sliceTypeFieldTextDecoder{ + index: index, + fieldName: field.Name, + tagInfos: tagInfos, + fieldType: fieldType, + isArray: isArray, + } + + return []decoder{fieldDecoder}, nil +} + +func stringToValue(elemType reflect.Type, text string) (v reflect.Value, err error) { + v = reflect.New(elemType).Elem() + // todo:自定义类型解析 + + switch elemType.Kind() { + case reflect.Struct: + err = json.Unmarshal(bytesconv.S2b(text), v.Addr().Interface()) + case reflect.Map: + err = json.Unmarshal(bytesconv.S2b(text), v.Addr().Interface()) + case reflect.Array, reflect.Slice: + // do nothing + default: + decoder, err := text_decoder.SelectTextDecoder(elemType) + if err != nil { + return reflect.Value{}, fmt.Errorf("unsupport type %s for slice/array", elemType.String()) + } + err = decoder.UnmarshalString(text, v) + if err != nil { + return reflect.Value{}, fmt.Errorf("unable to decode '%s' as %s: %w", text, elemType.String(), err) + } + } + + return v, nil +} diff --git a/pkg/app/server/binding_v2/tag.go b/pkg/app/server/binding_v2/tag.go new file mode 100644 index 000000000..72dcc0af3 --- /dev/null +++ b/pkg/app/server/binding_v2/tag.go @@ -0,0 +1,66 @@ +package binding_v2 + +import ( + "reflect" + "strings" +) + +const ( + pathTag = "path" + formTag = "form" + queryTag = "query" + cookieTag = "cookie" + headerTag = "header" + jsonTag = "json" + rawBodyTag = "raw_body" +) + +const ( + requiredTagOpt = "required" +) + +type TagInfo struct { + Key string + Value string + Required bool + Default string + Options []string + Getter getter +} + +func head(str, sep string) (head, tail string) { + idx := strings.Index(str, sep) + if idx < 0 { + return str, "" + } + return str[:idx], str[idx+len(sep):] +} + +func lookupFieldTags(field reflect.StructField) []TagInfo { + var ret []string + tags := []string{pathTag, formTag, queryTag, cookieTag, headerTag, jsonTag, rawBodyTag} + for _, tag := range tags { + if _, ok := field.Tag.Lookup(tag); ok { + ret = append(ret, tag) + } + } + var tagInfos []TagInfo + + for _, tag := range ret { + tagContent := field.Tag.Get(tag) + tagValue, opts := head(tagContent, ",") + var options []string + var opt string + var required bool + for len(opts) > 0 { + opt, opts = head(opts, ",") + options = append(options, opt) + if opt == requiredTagOpt { + required = true + } + } + tagInfos = append(tagInfos, TagInfo{Key: tag, Value: tagValue, Options: options, Required: required}) + } + + return tagInfos +} diff --git a/pkg/app/server/binding_v2/text_decoder/bool.go b/pkg/app/server/binding_v2/text_decoder/bool.go new file mode 100644 index 000000000..b669f5f9a --- /dev/null +++ b/pkg/app/server/binding_v2/text_decoder/bool.go @@ -0,0 +1,17 @@ +package text_decoder + +import ( + "reflect" + "strconv" +) + +type boolDecoder struct{} + +func (d *boolDecoder) UnmarshalString(s string, fieldValue reflect.Value) error { + v, err := strconv.ParseBool(s) + if err != nil { + return err + } + fieldValue.SetBool(v) + return nil +} diff --git a/pkg/app/server/binding_v2/text_decoder/float.go b/pkg/app/server/binding_v2/text_decoder/float.go new file mode 100644 index 000000000..395526153 --- /dev/null +++ b/pkg/app/server/binding_v2/text_decoder/float.go @@ -0,0 +1,19 @@ +package text_decoder + +import ( + "reflect" + "strconv" +) + +type floatDecoder struct { + bitSize int +} + +func (d *floatDecoder) UnmarshalString(s string, fieldValue reflect.Value) error { + v, err := strconv.ParseFloat(s, d.bitSize) + if err != nil { + return err + } + fieldValue.SetFloat(v) + return nil +} diff --git a/pkg/app/server/binding_v2/text_decoder/int.go b/pkg/app/server/binding_v2/text_decoder/int.go new file mode 100644 index 000000000..13a26e644 --- /dev/null +++ b/pkg/app/server/binding_v2/text_decoder/int.go @@ -0,0 +1,19 @@ +package text_decoder + +import ( + "reflect" + "strconv" +) + +type intDecoder struct { + bitSize int +} + +func (d *intDecoder) UnmarshalString(s string, fieldValue reflect.Value) error { + v, err := strconv.ParseInt(s, 10, d.bitSize) + if err != nil { + return err + } + fieldValue.SetInt(v) + return nil +} diff --git a/pkg/app/server/binding_v2/text_decoder/string.go b/pkg/app/server/binding_v2/text_decoder/string.go new file mode 100644 index 000000000..6290a4c31 --- /dev/null +++ b/pkg/app/server/binding_v2/text_decoder/string.go @@ -0,0 +1,11 @@ +package text_decoder + +import "reflect" + +type stringDecoder struct{} + +func (d *stringDecoder) UnmarshalString(s string, fieldValue reflect.Value) error { + // todo: 优化一下 + fieldValue.SetString(s) + return nil +} diff --git a/pkg/app/server/binding_v2/text_decoder/text_decoder.go b/pkg/app/server/binding_v2/text_decoder/text_decoder.go new file mode 100644 index 000000000..934b9a9c0 --- /dev/null +++ b/pkg/app/server/binding_v2/text_decoder/text_decoder.go @@ -0,0 +1,52 @@ +package text_decoder + +import ( + "fmt" + "reflect" +) + +type TextDecoder interface { + UnmarshalString(s string, fieldValue reflect.Value) error +} + +// var textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem() + +func SelectTextDecoder(rt reflect.Type) (TextDecoder, error) { + // todo: encoding.TextUnmarshaler + //if reflect.PtrTo(rt).Implements(textUnmarshalerType) { + // return &textUnmarshalEncoder{fieldType: rt}, nil + //} + + switch rt.Kind() { + case reflect.Bool: + return &boolDecoder{}, nil + case reflect.Uint8: + return &uintDecoder{bitSize: 8}, nil + case reflect.Uint16: + return &uintDecoder{bitSize: 16}, nil + case reflect.Uint32: + return &uintDecoder{bitSize: 32}, nil + case reflect.Uint64: + return &uintDecoder{bitSize: 64}, nil + case reflect.Uint: + return &uintDecoder{}, nil + case reflect.Int8: + return &intDecoder{bitSize: 8}, nil + case reflect.Int16: + return &intDecoder{bitSize: 16}, nil + case reflect.Int32: + return &intDecoder{bitSize: 32}, nil + case reflect.Int64: + return &intDecoder{bitSize: 64}, nil + case reflect.Int: + return &intDecoder{}, nil + case reflect.String: + return &stringDecoder{}, nil + case reflect.Float32: + return &floatDecoder{bitSize: 32}, nil + case reflect.Float64: + return &floatDecoder{bitSize: 64}, nil + } + + return nil, fmt.Errorf("unsupported type " + rt.String()) +} diff --git a/pkg/app/server/binding_v2/text_decoder/unit.go b/pkg/app/server/binding_v2/text_decoder/unit.go new file mode 100644 index 000000000..cb766964a --- /dev/null +++ b/pkg/app/server/binding_v2/text_decoder/unit.go @@ -0,0 +1,19 @@ +package text_decoder + +import ( + "reflect" + "strconv" +) + +type uintDecoder struct { + bitSize int +} + +func (d *uintDecoder) UnmarshalString(s string, fieldValue reflect.Value) error { + v, err := strconv.ParseUint(s, 10, d.bitSize) + if err != nil { + return err + } + fieldValue.SetUint(v) + return nil +} From d08feef82309d35f4d0b2a4ff88b437235999a42 Mon Sep 17 00:00:00 2001 From: fgy Date: Thu, 12 Jan 2023 10:48:17 +0800 Subject: [PATCH 02/91] feat: add struct binding --- .../server/binding_v2/base_type_decoder.go | 71 +++- pkg/app/server/binding_v2/binder.go | 3 + pkg/app/server/binding_v2/binder_test.go | 356 ++++++++++++++++-- .../binding_v2/customized_type_decoder.go | 48 ++- pkg/app/server/binding_v2/decoder.go | 51 ++- pkg/app/server/binding_v2/map_type_decoder.go | 131 +++++++ pkg/app/server/binding_v2/reflect.go | 13 + .../server/binding_v2/slice_type_decoder.go | 54 ++- 8 files changed, 641 insertions(+), 86 deletions(-) create mode 100644 pkg/app/server/binding_v2/map_type_decoder.go diff --git a/pkg/app/server/binding_v2/base_type_decoder.go b/pkg/app/server/binding_v2/base_type_decoder.go index b73577fca..a32e34764 100644 --- a/pkg/app/server/binding_v2/base_type_decoder.go +++ b/pkg/app/server/binding_v2/base_type_decoder.go @@ -5,15 +5,17 @@ import ( "reflect" "github.com/cloudwego/hertz/pkg/app/server/binding_v2/text_decoder" + "github.com/cloudwego/hertz/pkg/common/utils" "github.com/cloudwego/hertz/pkg/protocol" ) type baseTypeFieldTextDecoder struct { - index int - fieldName string - tagInfos []TagInfo // query,param,header,respHeader ... - fieldType reflect.Type - decoder text_decoder.TextDecoder + index int + parentIndex []int + fieldName string + tagInfos []TagInfo // query,param,header,respHeader ... + fieldType reflect.Type + decoder text_decoder.TextDecoder } func (d *baseTypeFieldTextDecoder) Decode(req *protocol.Request, params PathParams, reqValue reflect.Value) error { @@ -24,6 +26,11 @@ func (d *baseTypeFieldTextDecoder) Decode(req *protocol.Request, params PathPara if tagInfo.Key == jsonTag { continue } + if tagInfo.Key == headerTag { + tmp := []byte(tagInfo.Value) + utils.NormalizeHeaderKey(tmp, req.Header.IsDisableNormalizing()) + tagInfo.Value = string(tmp) + } ret := tagInfo.Getter(req, params, tagInfo.Value) defaultValue = tagInfo.Default if len(ret) != 0 { @@ -40,18 +47,43 @@ func (d *baseTypeFieldTextDecoder) Decode(req *protocol.Request, params PathPara } var err error + // 找到该 field 的父 struct 的 reflect.Value + for _, idx := range d.parentIndex { + if reqValue.Kind() == reflect.Ptr && reqValue.IsNil() { + nonNilVal, ptrDepth := GetNonNilReferenceValue(reqValue) + reqValue.Set(ReferenceValue(nonNilVal, ptrDepth)) + } + for reqValue.Kind() == reflect.Ptr { + reqValue = reqValue.Elem() + } + reqValue = reqValue.Field(idx) + } + + // 父 struct 有可能也是一个指针,所以需要再处理一次才能得到最终的父Value(非nil的reflect.Value) + for reqValue.Kind() == reflect.Ptr { + if reqValue.IsNil() { + nonNilVal, ptrDepth := GetNonNilReferenceValue(reqValue) + reqValue.Set(ReferenceValue(nonNilVal, ptrDepth)) + } + reqValue = reqValue.Elem() + } - // Pointer support for struct elems + // 根据最终的 Struct,获取对应 field 的 reflect.Value field := reqValue.Field(d.index) if field.Kind() == reflect.Ptr { - elem := reflect.New(d.fieldType) - err = d.decoder.UnmarshalString(text, elem.Elem()) + // 如果是指针则新建一个reflect.Value,然后赋值给指针 + t := field.Type() + var ptrDepth int + for t.Kind() == reflect.Ptr { + t = t.Elem() + ptrDepth++ + } + var vv reflect.Value + vv, err := stringToValue(t, text) if err != nil { - return fmt.Errorf("unable to decode '%s' as %s: %w", text, d.fieldType.Name(), err) + return err } - - field.Set(elem) - + field.Set(ReferenceValue(vv, ptrDepth)) return nil } @@ -64,7 +96,7 @@ func (d *baseTypeFieldTextDecoder) Decode(req *protocol.Request, params PathPara return nil } -func getBaseTypeTextDecoder(field reflect.StructField, index int, tagInfos []TagInfo) ([]decoder, error) { +func getBaseTypeTextDecoder(field reflect.StructField, index int, tagInfos []TagInfo, parentIdx []int) ([]decoder, error) { for idx, tagInfo := range tagInfos { switch tagInfo.Key { case pathTag: @@ -86,7 +118,7 @@ func getBaseTypeTextDecoder(field reflect.StructField, index int, tagInfos []Tag } fieldType := field.Type - if field.Type.Kind() == reflect.Ptr { + for field.Type.Kind() == reflect.Ptr { fieldType = field.Type.Elem() } @@ -96,11 +128,12 @@ func getBaseTypeTextDecoder(field reflect.StructField, index int, tagInfos []Tag } fieldDecoder := &baseTypeFieldTextDecoder{ - index: index, - fieldName: field.Name, - tagInfos: tagInfos, - decoder: textDecoder, - fieldType: fieldType, + index: index, + parentIndex: parentIdx, + fieldName: field.Name, + tagInfos: tagInfos, + decoder: textDecoder, + fieldType: fieldType, } return []decoder{fieldDecoder}, nil diff --git a/pkg/app/server/binding_v2/binder.go b/pkg/app/server/binding_v2/binder.go index f116d0cad..cdde45838 100644 --- a/pkg/app/server/binding_v2/binder.go +++ b/pkg/app/server/binding_v2/binder.go @@ -38,6 +38,9 @@ func (b *Bind) Bind(req *protocol.Request, params PathParams, v interface{}) err if rv.Kind() != reflect.Pointer || rv.IsNil() { return fmt.Errorf("receiver must be a non-nil pointer") } + if rv.Kind() == reflect.Map { + return nil + } cached, ok := b.decoderCache.Load(typeID) if ok { // cached decoder, fast path diff --git a/pkg/app/server/binding_v2/binder_test.go b/pkg/app/server/binding_v2/binder_test.go index a073fc94a..20d4aece3 100644 --- a/pkg/app/server/binding_v2/binder_test.go +++ b/pkg/app/server/binding_v2/binder_test.go @@ -10,20 +10,66 @@ import ( "github.com/cloudwego/hertz/pkg/route/param" ) +type mockRequest struct { + Req *protocol.Request +} + +func newMockRequest() *mockRequest { + return &mockRequest{ + Req: &protocol.Request{}, + } +} + +func (m *mockRequest) SetRequestURI(uri string) *mockRequest { + m.Req.SetRequestURI(uri) + return m +} + +func (m *mockRequest) SetHeader(key, value string) *mockRequest { + m.Req.Header.Set(key, value) + return m +} + +func (m *mockRequest) SetHeaders(key, value string) *mockRequest { + m.Req.Header.Set(key, value) + return m +} + +func (m *mockRequest) SetPostArg(key, value string) *mockRequest { + m.Req.PostArgs().Add(key, value) + return m +} + +func (m *mockRequest) SetUrlEncodeContentType() *mockRequest { + m.Req.Header.SetContentTypeBytes([]byte("application/x-www-form-urlencoded")) + return m +} + +func (m *mockRequest) SetJSONContentType() *mockRequest { + m.Req.Header.SetContentTypeBytes([]byte(jsonContentTypeBytes)) + return m +} + +func (m *mockRequest) SetBody(data []byte) *mockRequest { + m.Req.SetBody(data) + m.Req.Header.SetContentLength(len(data)) + return m +} + func TestBind_BaseType(t *testing.T) { bind := Bind{} type Req struct { - Version int `path:"v"` + Version int `path:"v"` ID int `query:"id"` Header string `header:"H"` Form string `form:"f"` } - req := &protocol.Request{} - req.SetRequestURI("http://foobar.com?id=12") - req.Header.Set("H", "header") // disableNormalizing - req.PostArgs().Add("f", "form") - req.Header.SetContentTypeBytes([]byte("application/x-www-form-urlencoded")) + req := newMockRequest(). + SetRequestURI("http://foobar.com?id=12"). + SetHeaders("H", "header"). + SetPostArg("f", "form"). + SetUrlEncodeContentType() var params param.Params params = append(params, param.Param{ Key: "v", @@ -32,14 +78,14 @@ func TestBind_BaseType(t *testing.T) { var result Req - err := bind.Bind(req, params, &result) + err := bind.Bind(req.Req, params, &result) if err != nil { t.Error(err) } assert.DeepEqual(t, 1, result.Version) assert.DeepEqual(t, 12, result.ID) assert.DeepEqual(t, "header", result.Header) - assert.DeepEqual(t,"form", result.Form) + assert.DeepEqual(t, "form", result.Form) } func TestBind_SliceType(t *testing.T) { @@ -53,12 +99,12 @@ func TestBind_SliceType(t *testing.T) { Strs := [3]string{"qwe", "asd", "zxc"} Bytes := []byte("123") - req := &protocol.Request{} - req.SetRequestURI(fmt.Sprintf("http://foobar.com?id=%d&id=%d&id=%d&str=%s&str=%s&str=%s&b=%d&b=%d&b=%d", IDs[0], IDs[1], IDs[2], Strs[0], Strs[1], Strs[2], Bytes[0], Bytes[1], Bytes[2])) + req := newMockRequest(). + SetRequestURI(fmt.Sprintf("http://foobar.com?id=%d&id=%d&id=%d&str=%s&str=%s&str=%s&b=%d&b=%d&b=%d", IDs[0], IDs[1], IDs[2], Strs[0], Strs[1], Strs[2], Bytes[0], Bytes[1], Bytes[2])) var result Req - err := bind.Bind(req, nil, &result) + err := bind.Bind(req.Req, nil, &result) if err != nil { t.Error(err) } @@ -76,6 +122,257 @@ func TestBind_SliceType(t *testing.T) { } } +func TestBind_StructType(t *testing.T) { + type FFF struct { + F1 string `query:"F1"` + } + + type TTT struct { + T1 string `query:"F1"` + T2 FFF + } + + type Foo struct { + F1 string `query:"F1"` + F2 string `header:"f2"` + F3 TTT + } + + type Bar struct { + B1 string `query:"B1"` + B2 Foo `query:"B2"` + } + + bind := Bind{} + + var result Bar + + req := newMockRequest().SetRequestURI("http://foobar.com?F1=f1&B1=b1").SetHeader("f2", "f2") + + err := bind.Bind(req.Req, nil, &result) + if err != nil { + t.Error(err) + } + + assert.DeepEqual(t, "b1", result.B1) + assert.DeepEqual(t, "f1", result.B2.F1) + assert.DeepEqual(t, "f2", result.B2.F2) + assert.DeepEqual(t, "f1", result.B2.F3.T1) + assert.DeepEqual(t, "f1", result.B2.F3.T2.F1) +} + +func TestBind_PointerType(t *testing.T) { + type TT struct { + T1 string `query:"F1"` + } + + type Foo struct { + F1 *TT `query:"F1"` + F2 *******************string `query:"F1"` + } + + type Bar struct { + B1 ***string `query:"B1"` + B2 ****Foo `query:"B2"` + B3 []*string `query:"B3"` + B4 [2]*int `query:"B4"` + } + + bind := Bind{} + + result := Bar{} + + F1 := "f1" + B1 := "b1" + B2 := "b2" + B3s := []string{"b31", "b32"} + B4s := [2]int{0, 1} + + req := newMockRequest().SetRequestURI(fmt.Sprintf("http://foobar.com?F1=%s&B1=%s&B2=%s&B3=%s&B3=%s&B4=%d&B4=%d", F1, B1, B2, B3s[0], B3s[1], B4s[0], B4s[1])). + SetHeader("f2", "f2") + + err := bind.Bind(req.Req, nil, &result) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, B1, ***result.B1) + assert.DeepEqual(t, F1, (*(****result.B2).F1).T1) + assert.DeepEqual(t, F1, *******************(****result.B2).F2) + assert.DeepEqual(t, len(B3s), len(result.B3)) + for idx, val := range B3s { + assert.DeepEqual(t, val, *result.B3[idx]) + } + assert.DeepEqual(t, len(B4s), len(result.B4)) + for idx, val := range B4s { + assert.DeepEqual(t, val, *result.B4[idx]) + } +} + +func TestBind_NestedStruct(t *testing.T) { + type Foo struct { + F1 string `query:"F1"` + } + + type Bar struct { + Foo + Nested struct { + N1 string `query:"F1"` + } + } + + bind := Bind{} + + result := Bar{} + + req := newMockRequest().SetRequestURI("http://foobar.com?F1=qwe") + err := bind.Bind(req.Req, nil, &result) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, "qwe", result.Foo.F1) + assert.DeepEqual(t, "qwe", result.Nested.N1) +} + +func TestBind_SliceStruct(t *testing.T) { + type Foo struct { + F1 string `json:"f1"` + } + + type Bar struct { + B1 []Foo `query:"F1"` + } + + bind := Bind{} + + result := Bar{} + B1s := []string{"1", "2", "3"} + + req := newMockRequest().SetRequestURI(fmt.Sprintf("http://foobar.com?F1={\"f1\":\"%s\"}&F1={\"f1\":\"%s\"}&F1={\"f1\":\"%s\"}", B1s[0], B1s[1], B1s[2])) + err := bind.Bind(req.Req, nil, &result) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, len(result.B1), len(B1s)) + for idx, val := range B1s { + assert.DeepEqual(t, B1s[idx], val) + } +} + +func TestBind_MapType(t *testing.T) { + var result map[string]string + bind := Bind{} + req := newMockRequest(). + SetJSONContentType(). + SetBody([]byte(`{"j1":"j1", "j2":"j2"}`)) + err := bind.Bind(req.Req, nil, &result) + if err != nil { + t.Fatal(err) + } + assert.DeepEqual(t, 2, len(result)) + assert.DeepEqual(t, "j1", result["j1"]) + assert.DeepEqual(t, "j2", result["j2"]) +} + +func TestBind_MapFieldType(t *testing.T) { + type Foo struct { + F1 ***map[string]string `query:"f1" json:"f1"` + } + + bind := Bind{} + req := newMockRequest(). + SetRequestURI("http://foobar.com?f1={\"f1\":\"f1\"}"). + SetJSONContentType(). + SetBody([]byte(`{"j1":"j1", "j2":"j2"}`)) + result := Foo{} + err := bind.Bind(req.Req, nil, &result) + if err != nil { + t.Fatal(err) + } + assert.DeepEqual(t, 1, len(***result.F1)) + assert.DeepEqual(t, "f1", (***result.F1)["f1"]) +} + +func TestBind_UnexportedField(t *testing.T) { + var s struct { + A int `query:"a"` + b int `query:"b"` + } + bind := Bind{} + req := newMockRequest(). + SetRequestURI("http://foobar.com?a=1&b=2}") + err := bind.Bind(req.Req, nil, &s) + if err != nil { + t.Fatal(err) + } + assert.DeepEqual(t, 1, s.A) + assert.DeepEqual(t, 0, s.b) +} + +func TestBind_TypedefType(t *testing.T) { + type Foo string + type Bar *int + type T struct { + T1 string `query:"a"` + } + type TT T + + var s struct { + A Foo `query:"a"` + B Bar `query:"b"` + T1 TT + } + bind := Bind{} + req := newMockRequest(). + SetRequestURI("http://foobar.com?a=1&b=2") + err := bind.Bind(req.Req, nil, &s) + if err != nil { + t.Fatal(err) + } + assert.DeepEqual(t, Foo("1"), s.A) + assert.DeepEqual(t, 2, *s.B) + assert.DeepEqual(t, "1", s.T1.T1) +} + +type CustomizedDecode struct { + A string +} + +func (c *CustomizedDecode) CustomizedFieldDecode(req *protocol.Request, params PathParams) error { + q1 := req.URI().QueryArgs().Peek("a") + if len(q1) == 0 { + return fmt.Errorf("can be nil") + } + + c.A = string(q1) + return nil +} + +func TestBind_CustomizedTypeDecode(t *testing.T) { + type Foo struct { + F ***CustomizedDecode + } + bind := Bind{} + req := newMockRequest(). + SetRequestURI("http://foobar.com?a=1&b=2") + result := Foo{} + err := bind.Bind(req.Req, nil, &result) + if err != nil { + t.Fatal(err) + } + assert.DeepEqual(t, "1", (***result.F).A) + + type Bar struct { + B *Foo + } + + result2 := Bar{} + err = bind.Bind(req.Req, nil, &result2) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, "1", (***(*result2.B).F).A) +} + func TestBind_JSON(t *testing.T) { bind := Bind{} type Req struct { @@ -88,14 +385,12 @@ func TestBind_JSON(t *testing.T) { J3s := []byte("12") J4s := [2]string{"qwe", "asd"} - req := &protocol.Request{} - req.SetRequestURI("http://foobar.com?j2=13") - req.Header.SetContentTypeBytes([]byte(jsonContentTypeBytes)) - data := []byte(fmt.Sprintf(`{"j1":"j1", "j2":12, "j3":[%d, %d], "j4":["%s", "%s"]}`, J3s[0], J3s[1], J4s[0], J4s[1])) - req.SetBody(data) - req.Header.SetContentLength(len(data)) + req := newMockRequest(). + SetRequestURI("http://foobar.com?j2=13"). + SetJSONContentType(). + SetBody([]byte(fmt.Sprintf(`{"j1":"j1", "j2":12, "j3":[%d, %d], "j4":["%s", "%s"]}`, J3s[0], J3s[1], J4s[0], J4s[1]))) var result Req - err := bind.Bind(req, nil, &result) + err := bind.Bind(req.Req, nil, &result) if err != nil { t.Error(err) } @@ -118,11 +413,12 @@ func Benchmark_V2(b *testing.B) { Form string `form:"f"` } - req := &protocol.Request{} - req.SetRequestURI("http://foobar.com?id=12") - req.Header.Set("h", "header") - req.PostArgs().Add("f", "form") - req.Header.SetContentTypeBytes([]byte("application/x-www-form-urlencoded")) + req := newMockRequest(). + SetRequestURI("http://foobar.com?id=12"). + SetHeaders("H", "header"). + SetPostArg("f", "form"). + SetUrlEncodeContentType() + var params param.Params params = append(params, param.Param{ Key: "v", @@ -132,7 +428,7 @@ func Benchmark_V2(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { var result Req - err := bind.Bind(req, params, &result) + err := bind.Bind(req.Req, params, &result) if err != nil { b.Error(err) } @@ -159,11 +455,11 @@ func Benchmark_V1(b *testing.B) { Form string `form:"f"` } - req := &protocol.Request{} - req.SetRequestURI("http://foobar.com?id=12") - req.Header.Set("h", "header") - req.PostArgs().Add("f", "form") - req.Header.SetContentTypeBytes([]byte("application/x-www-form-urlencoded")) + req := newMockRequest(). + SetRequestURI("http://foobar.com?id=12"). + SetHeaders("h", "header"). + SetPostArg("f", "form"). + SetUrlEncodeContentType() var params param.Params params = append(params, param.Param{ Key: "v", @@ -173,7 +469,7 @@ func Benchmark_V1(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { var result Req - err := binding.Bind(req, &result, params) + err := binding.Bind(req.Req, &result, params) if err != nil { b.Error(err) } diff --git a/pkg/app/server/binding_v2/customized_type_decoder.go b/pkg/app/server/binding_v2/customized_type_decoder.go index 97a2b660b..aa61512d0 100644 --- a/pkg/app/server/binding_v2/customized_type_decoder.go +++ b/pkg/app/server/binding_v2/customized_type_decoder.go @@ -7,19 +7,55 @@ import ( ) type customizedFieldTextDecoder struct { - index int - fieldName string - fieldType reflect.Type + index int + parentIndex []int + fieldName string + fieldType reflect.Type } func (d *customizedFieldTextDecoder) Decode(req *protocol.Request, params PathParams, reqValue reflect.Value) error { + var err error v := reflect.New(d.fieldType) - decoder := v.Interface().(FieldCustomizedDecoder) + decoder := v.Interface().(CustomizedFieldDecoder) - if err := decoder.CustomizedFieldDecode(req, params); err != nil { + if err = decoder.CustomizedFieldDecode(req, params); err != nil { return err } - reqValue.Field(d.index).Set(v.Elem()) + // 找到该 field 的父 struct 的 reflect.Value + for _, idx := range d.parentIndex { + if reqValue.Kind() == reflect.Ptr && reqValue.IsNil() { + nonNilVal, ptrDepth := GetNonNilReferenceValue(reqValue) + reqValue.Set(ReferenceValue(nonNilVal, ptrDepth)) + } + for reqValue.Kind() == reflect.Ptr { + reqValue = reqValue.Elem() + } + reqValue = reqValue.Field(idx) + } + + // 父 struct 有可能也是一个指针,所以需要再处理一次才能得到最终的父Value(非nil的reflect.Value) + for reqValue.Kind() == reflect.Ptr { + if reqValue.IsNil() { + nonNilVal, ptrDepth := GetNonNilReferenceValue(reqValue) + reqValue.Set(ReferenceValue(nonNilVal, ptrDepth)) + } + reqValue = reqValue.Elem() + } + + field := reqValue.Field(d.index) + if field.Kind() == reflect.Ptr { + // 如果是指针则新建一个reflect.Value,然后赋值给指针 + t := field.Type() + var ptrDepth int + for t.Kind() == reflect.Ptr { + t = t.Elem() + ptrDepth++ + } + field.Set(ReferenceValue(v.Elem(), ptrDepth)) + return nil + } + + field.Set(v) return nil } diff --git a/pkg/app/server/binding_v2/decoder.go b/pkg/app/server/binding_v2/decoder.go index a9b1795b4..614ec0fa6 100644 --- a/pkg/app/server/binding_v2/decoder.go +++ b/pkg/app/server/binding_v2/decoder.go @@ -11,21 +11,20 @@ type decoder interface { Decode(req *protocol.Request, params PathParams, reqValue reflect.Value) error } -type FieldCustomizedDecoder interface { +type CustomizedFieldDecoder interface { CustomizedFieldDecode(req *protocol.Request, params PathParams) error } type Decoder func(req *protocol.Request, params PathParams, rv reflect.Value) error -var fieldDecoderType = reflect.TypeOf((*FieldCustomizedDecoder)(nil)).Elem() +var fieldDecoderType = reflect.TypeOf((*CustomizedFieldDecoder)(nil)).Elem() func getReqDecoder(rt reflect.Type) (Decoder, error) { var decoders []decoder el := rt.Elem() if el.Kind() != reflect.Struct { - // todo: 增加对map的支持 - return nil, fmt.Errorf("unsupport non-struct type binding") + return nil, fmt.Errorf("unsupport \"%s\" type binding", el.String()) } for i := 0; i < el.NumField(); i++ { @@ -34,7 +33,7 @@ func getReqDecoder(rt reflect.Type) (Decoder, error) { continue } - dec, err := getFieldDecoder(el.Field(i), i) + dec, err := getFieldDecoder(el.Field(i), i, []int{}) if err != nil { return nil, err } @@ -56,31 +55,39 @@ func getReqDecoder(rt reflect.Type) (Decoder, error) { }, nil } -func getFieldDecoder(field reflect.StructField, index int) ([]decoder, error) { +func getFieldDecoder(field reflect.StructField, index int, parentIdx []int) ([]decoder, error) { + // 去掉每一个filed的指针,使其指向最终内容 + for field.Type.Kind() == reflect.Ptr { + field.Type = field.Type.Elem() + } if reflect.PtrTo(field.Type).Implements(fieldDecoderType) { - return []decoder{&customizedFieldTextDecoder{index: index, fieldName: field.Name, fieldType: field.Type}}, nil + return []decoder{&customizedFieldTextDecoder{ + index: index, + parentIndex: parentIdx, + fieldName: field.Name, + fieldType: field.Type}}, nil } fieldTagInfos := lookupFieldTags(field) - if len(fieldTagInfos) == 0 { - // todo: 如果没定义尝试给其赋值所有 tag - return nil, nil - } + // todo: 没有 tag 也不直接返回 + //if len(fieldTagInfos) == 0 { + // // todo: 如果没定义尝试给其赋值所有 tag + // return nil, nil + //} // todo: 用户自定义text信息解析 //if reflect.PtrTo(field.Type).Implements(textUnmarshalerType) { // return compileTextBasedDecoder(field, index, tagScope, tagContent) //} - - // todo: reflect Map if field.Type.Kind() == reflect.Slice || field.Type.Kind() == reflect.Array { - return getSliceFieldDecoder(field, index, fieldTagInfos) + return getSliceFieldDecoder(field, index, fieldTagInfos, parentIdx) } - // Nested binding support - if field.Type.Kind() == reflect.Ptr { - field.Type = field.Type.Elem() + // todo: reflect Map + if field.Type.Kind() == reflect.Map { + return getMapTypeTextDecoder(field, index, fieldTagInfos, parentIdx) } + // 递归每一个 struct if field.Type.Kind() == reflect.Struct { var decoders []decoder @@ -91,7 +98,13 @@ func getFieldDecoder(field reflect.StructField, index int) ([]decoder, error) { // ignore unexported field continue } - dec, err := getFieldDecoder(el.Field(i), i) + // todo: 优化一下? idxes := append(parentIdx, index) + var idxes []int + if len(parentIdx) > 0 { + idxes = append(idxes, parentIdx...) + } + idxes = append(idxes, index) + dec, err := getFieldDecoder(el.Field(i), i, idxes) if err != nil { return nil, err } @@ -104,5 +117,5 @@ func getFieldDecoder(field reflect.StructField, index int) ([]decoder, error) { return decoders, nil } - return getBaseTypeTextDecoder(field, index, fieldTagInfos) + return getBaseTypeTextDecoder(field, index, fieldTagInfos, parentIdx) } diff --git a/pkg/app/server/binding_v2/map_type_decoder.go b/pkg/app/server/binding_v2/map_type_decoder.go new file mode 100644 index 000000000..914eb7291 --- /dev/null +++ b/pkg/app/server/binding_v2/map_type_decoder.go @@ -0,0 +1,131 @@ +package binding_v2 + +import ( + "encoding/json" + "fmt" + "reflect" + + "github.com/cloudwego/hertz/internal/bytesconv" + "github.com/cloudwego/hertz/pkg/common/utils" + + "github.com/cloudwego/hertz/pkg/protocol" +) + +type mapTypeFieldTextDecoder struct { + index int + parentIndex []int + fieldName string + tagInfos []TagInfo // query,param,header,respHeader ... + fieldType reflect.Type +} + +func (d *mapTypeFieldTextDecoder) Decode(req *protocol.Request, params PathParams, reqValue reflect.Value) error { + var text string + var defaultValue string + // 最大努力交付,对齐 hertz 现有设计 + for _, tagInfo := range d.tagInfos { + if tagInfo.Key == jsonTag { + continue + } + if tagInfo.Key == headerTag { + tmp := []byte(tagInfo.Value) + utils.NormalizeHeaderKey(tmp, req.Header.IsDisableNormalizing()) + tagInfo.Value = string(tmp) + } + ret := tagInfo.Getter(req, params, tagInfo.Value) + defaultValue = tagInfo.Default + if len(ret) != 0 { + // 非数组/切片类型,只取第一个值作为只 + text = ret[0] + break + } + } + if len(text) == 0 && len(defaultValue) != 0 { + text = defaultValue + } + if text == "" { + return nil + } + + // todo 多重指针 + for _, idx := range d.parentIndex { + if reqValue.Kind() == reflect.Ptr && reqValue.IsNil() { + nonNilVal, ptrDepth := GetNonNilReferenceValue(reqValue) + reqValue.Set(ReferenceValue(nonNilVal, ptrDepth)) + } + for reqValue.Kind() == reflect.Ptr { + reqValue = reqValue.Elem() + } + reqValue = reqValue.Field(idx) + } + + // 父 struct 有可能也是一个指针,所以需要再处理一次才能得到最终的父Value(非nil的reflect.Value) + for reqValue.Kind() == reflect.Ptr { + if reqValue.IsNil() { + nonNilVal, ptrDepth := GetNonNilReferenceValue(reqValue) + reqValue.Set(ReferenceValue(nonNilVal, ptrDepth)) + } + reqValue = reqValue.Elem() + } + + field := reqValue.Field(d.index) + if field.Kind() == reflect.Ptr { + // 如果是指针则新建一个reflect.Value,然后赋值给指针 + t := field.Type() + var ptrDepth int + for t.Kind() == reflect.Ptr { + t = t.Elem() + ptrDepth++ + } + var vv reflect.Value + vv, err := stringToValue(t, text) + if err != nil { + return fmt.Errorf("unable to decode '%s' as %s: %w", text, d.fieldType.Name(), err) + } + field.Set(ReferenceValue(vv, ptrDepth)) + return nil + } + + err := json.Unmarshal(bytesconv.S2b(text), field.Addr().Interface()) + if err != nil { + return fmt.Errorf("unable to decode '%s' as %s: %w", text, d.fieldType.Name(), err) + } + + return nil +} + +func getMapTypeTextDecoder(field reflect.StructField, index int, tagInfos []TagInfo, parentIdx []int) ([]decoder, error) { + for idx, tagInfo := range tagInfos { + switch tagInfo.Key { + case pathTag: + tagInfos[idx].Getter = PathParam + case formTag: + tagInfos[idx].Getter = Form + case queryTag: + tagInfos[idx].Getter = Query + case cookieTag: + tagInfos[idx].Getter = Cookie + case headerTag: + tagInfos[idx].Getter = Header + case jsonTag: + // do nothing + case rawBodyTag: + tagInfo.Getter = RawBody + default: + } + } + + fieldType := field.Type + for field.Type.Kind() == reflect.Ptr { + fieldType = field.Type.Elem() + } + fieldDecoder := &mapTypeFieldTextDecoder{ + index: index, + parentIndex: parentIdx, + fieldName: field.Name, + tagInfos: tagInfos, + fieldType: fieldType, + } + + return []decoder{fieldDecoder}, nil +} diff --git a/pkg/app/server/binding_v2/reflect.go b/pkg/app/server/binding_v2/reflect.go index 0e7fae7fb..aa8d91aee 100644 --- a/pkg/app/server/binding_v2/reflect.go +++ b/pkg/app/server/binding_v2/reflect.go @@ -33,3 +33,16 @@ func ReferenceValue(v reflect.Value, ptrDepth int) reflect.Value { } return v } + +func GetNonNilReferenceValue(v reflect.Value) (reflect.Value, int) { + var ptrDepth int + t := v.Type() + elemKind := t.Kind() + for elemKind == reflect.Ptr { + t = t.Elem() + elemKind = t.Kind() + ptrDepth++ + } + val := reflect.New(t).Elem() + return val, ptrDepth +} diff --git a/pkg/app/server/binding_v2/slice_type_decoder.go b/pkg/app/server/binding_v2/slice_type_decoder.go index 420bc6cba..55805650c 100644 --- a/pkg/app/server/binding_v2/slice_type_decoder.go +++ b/pkg/app/server/binding_v2/slice_type_decoder.go @@ -5,17 +5,20 @@ import ( "fmt" "reflect" + "github.com/cloudwego/hertz/pkg/common/utils" + "github.com/cloudwego/hertz/internal/bytesconv" "github.com/cloudwego/hertz/pkg/app/server/binding_v2/text_decoder" "github.com/cloudwego/hertz/pkg/protocol" ) type sliceTypeFieldTextDecoder struct { - index int - fieldName string - isArray bool - tagInfos []TagInfo // query,param,header,respHeader ... - fieldType reflect.Type + index int + parentIndex []int + fieldName string + isArray bool + tagInfos []TagInfo // query,param,header,respHeader ... + fieldType reflect.Type } func (d *sliceTypeFieldTextDecoder) Decode(req *protocol.Request, params PathParams, reqValue reflect.Value) error { @@ -24,6 +27,11 @@ func (d *sliceTypeFieldTextDecoder) Decode(req *protocol.Request, params PathPar if tagInfo.Key == jsonTag { continue } + if tagInfo.Key == headerTag { + tmp := []byte(tagInfo.Value) + utils.NormalizeHeaderKey(tmp, req.Header.IsDisableNormalizing()) + tagInfo.Value = string(tmp) + } texts = tagInfo.Getter(req, params, tagInfo.Value) // todo: 数组默认值 // defaultValue = tagInfo.Default @@ -35,6 +43,27 @@ func (d *sliceTypeFieldTextDecoder) Decode(req *protocol.Request, params PathPar return nil } + // todo 多重指针 + for _, idx := range d.parentIndex { + if reqValue.Kind() == reflect.Ptr && reqValue.IsNil() { + nonNilVal, ptrDepth := GetNonNilReferenceValue(reqValue) + reqValue.Set(ReferenceValue(nonNilVal, ptrDepth)) + } + for reqValue.Kind() == reflect.Ptr { + reqValue = reqValue.Elem() + } + reqValue = reqValue.Field(idx) + } + + // 父 struct 有可能也是一个指针,所以需要再处理一次才能得到最终的父Value(非nil的reflect.Value) + for reqValue.Kind() == reflect.Ptr { + if reqValue.IsNil() { + nonNilVal, ptrDepth := GetNonNilReferenceValue(reqValue) + reqValue.Set(ReferenceValue(nonNilVal, ptrDepth)) + } + reqValue = reqValue.Elem() + } + field := reqValue.Field(d.index) if d.isArray { @@ -71,7 +100,7 @@ func (d *sliceTypeFieldTextDecoder) Decode(req *protocol.Request, params PathPar // 数组/切片类型的decoder, // 对于map和struct类型的数组元素直接使用unmarshal,不做嵌套处理 -func getSliceFieldDecoder(field reflect.StructField, index int, tagInfos []TagInfo) ([]decoder, error) { +func getSliceFieldDecoder(field reflect.StructField, index int, tagInfos []TagInfo, parentIdx []int) ([]decoder, error) { if !(field.Type.Kind() == reflect.Slice || field.Type.Kind() == reflect.Array) { return nil, fmt.Errorf("unexpected type %s, expected slice or array", field.Type.String()) } @@ -100,16 +129,17 @@ func getSliceFieldDecoder(field reflect.StructField, index int, tagInfos []TagIn } fieldType := field.Type - if field.Type.Kind() == reflect.Ptr { + for field.Type.Kind() == reflect.Ptr { fieldType = field.Type.Elem() } fieldDecoder := &sliceTypeFieldTextDecoder{ - index: index, - fieldName: field.Name, - tagInfos: tagInfos, - fieldType: fieldType, - isArray: isArray, + index: index, + parentIndex: parentIdx, + fieldName: field.Name, + tagInfos: tagInfos, + fieldType: fieldType, + isArray: isArray, } return []decoder{fieldDecoder}, nil From 65df9bc028b96d53b845e35d3a40d7ebfe5c49c1 Mon Sep 17 00:00:00 2001 From: fgy Date: Tue, 17 Jan 2023 15:22:03 +0800 Subject: [PATCH 03/91] refactor: refactor code style --- .../server/binding_v2/base_type_decoder.go | 49 +++++++----------- pkg/app/server/binding_v2/binder.go | 2 +- .../binding_v2/customized_type_decoder.go | 27 +--------- pkg/app/server/binding_v2/decoder.go | 15 +++--- pkg/app/server/binding_v2/map_type_decoder.go | 46 ++++------------- pkg/app/server/binding_v2/reflect.go | 24 +++++++++ .../server/binding_v2/slice_type_decoder.go | 51 +++++-------------- pkg/common/utils/utils.go | 6 +++ 8 files changed, 81 insertions(+), 139 deletions(-) diff --git a/pkg/app/server/binding_v2/base_type_decoder.go b/pkg/app/server/binding_v2/base_type_decoder.go index a32e34764..75aa2d42c 100644 --- a/pkg/app/server/binding_v2/base_type_decoder.go +++ b/pkg/app/server/binding_v2/base_type_decoder.go @@ -9,13 +9,17 @@ import ( "github.com/cloudwego/hertz/pkg/protocol" ) -type baseTypeFieldTextDecoder struct { +type fieldInfo struct { index int parentIndex []int fieldName string tagInfos []TagInfo // query,param,header,respHeader ... fieldType reflect.Type - decoder text_decoder.TextDecoder +} + +type baseTypeFieldTextDecoder struct { + fieldInfo + decoder text_decoder.TextDecoder } func (d *baseTypeFieldTextDecoder) Decode(req *protocol.Request, params PathParams, reqValue reflect.Value) error { @@ -27,9 +31,7 @@ func (d *baseTypeFieldTextDecoder) Decode(req *protocol.Request, params PathPara continue } if tagInfo.Key == headerTag { - tmp := []byte(tagInfo.Value) - utils.NormalizeHeaderKey(tmp, req.Header.IsDisableNormalizing()) - tagInfo.Value = string(tmp) + tagInfo.Value = utils.GetNormalizeHeaderKey(tagInfo.Value, req.Header.IsDisableNormalizing()) } ret := tagInfo.Getter(req, params, tagInfo.Value) defaultValue = tagInfo.Default @@ -47,27 +49,8 @@ func (d *baseTypeFieldTextDecoder) Decode(req *protocol.Request, params PathPara } var err error - // 找到该 field 的父 struct 的 reflect.Value - for _, idx := range d.parentIndex { - if reqValue.Kind() == reflect.Ptr && reqValue.IsNil() { - nonNilVal, ptrDepth := GetNonNilReferenceValue(reqValue) - reqValue.Set(ReferenceValue(nonNilVal, ptrDepth)) - } - for reqValue.Kind() == reflect.Ptr { - reqValue = reqValue.Elem() - } - reqValue = reqValue.Field(idx) - } - - // 父 struct 有可能也是一个指针,所以需要再处理一次才能得到最终的父Value(非nil的reflect.Value) - for reqValue.Kind() == reflect.Ptr { - if reqValue.IsNil() { - nonNilVal, ptrDepth := GetNonNilReferenceValue(reqValue) - reqValue.Set(ReferenceValue(nonNilVal, ptrDepth)) - } - reqValue = reqValue.Elem() - } - + // 得到该field的非nil值 + reqValue = GetFieldValue(reqValue, d.parentIndex) // 根据最终的 Struct,获取对应 field 的 reflect.Value field := reqValue.Field(d.index) if field.Kind() == reflect.Ptr { @@ -128,12 +111,14 @@ func getBaseTypeTextDecoder(field reflect.StructField, index int, tagInfos []Tag } fieldDecoder := &baseTypeFieldTextDecoder{ - index: index, - parentIndex: parentIdx, - fieldName: field.Name, - tagInfos: tagInfos, - decoder: textDecoder, - fieldType: fieldType, + fieldInfo: fieldInfo{ + index: index, + parentIndex: parentIdx, + fieldName: field.Name, + tagInfos: tagInfos, + fieldType: fieldType, + }, + decoder: textDecoder, } return []decoder{fieldDecoder}, nil diff --git a/pkg/app/server/binding_v2/binder.go b/pkg/app/server/binding_v2/binder.go index cdde45838..16757879b 100644 --- a/pkg/app/server/binding_v2/binder.go +++ b/pkg/app/server/binding_v2/binder.go @@ -38,7 +38,7 @@ func (b *Bind) Bind(req *protocol.Request, params PathParams, v interface{}) err if rv.Kind() != reflect.Pointer || rv.IsNil() { return fmt.Errorf("receiver must be a non-nil pointer") } - if rv.Kind() == reflect.Map { + if rv.Elem().Kind() == reflect.Map { return nil } cached, ok := b.decoderCache.Load(typeID) diff --git a/pkg/app/server/binding_v2/customized_type_decoder.go b/pkg/app/server/binding_v2/customized_type_decoder.go index aa61512d0..302ea61e4 100644 --- a/pkg/app/server/binding_v2/customized_type_decoder.go +++ b/pkg/app/server/binding_v2/customized_type_decoder.go @@ -7,10 +7,7 @@ import ( ) type customizedFieldTextDecoder struct { - index int - parentIndex []int - fieldName string - fieldType reflect.Type + fieldInfo } func (d *customizedFieldTextDecoder) Decode(req *protocol.Request, params PathParams, reqValue reflect.Value) error { @@ -22,27 +19,7 @@ func (d *customizedFieldTextDecoder) Decode(req *protocol.Request, params PathPa return err } - // 找到该 field 的父 struct 的 reflect.Value - for _, idx := range d.parentIndex { - if reqValue.Kind() == reflect.Ptr && reqValue.IsNil() { - nonNilVal, ptrDepth := GetNonNilReferenceValue(reqValue) - reqValue.Set(ReferenceValue(nonNilVal, ptrDepth)) - } - for reqValue.Kind() == reflect.Ptr { - reqValue = reqValue.Elem() - } - reqValue = reqValue.Field(idx) - } - - // 父 struct 有可能也是一个指针,所以需要再处理一次才能得到最终的父Value(非nil的reflect.Value) - for reqValue.Kind() == reflect.Ptr { - if reqValue.IsNil() { - nonNilVal, ptrDepth := GetNonNilReferenceValue(reqValue) - reqValue.Set(ReferenceValue(nonNilVal, ptrDepth)) - } - reqValue = reqValue.Elem() - } - + reqValue = GetFieldValue(reqValue, d.parentIndex) field := reqValue.Field(d.index) if field.Kind() == reflect.Ptr { // 如果是指针则新建一个reflect.Value,然后赋值给指针 diff --git a/pkg/app/server/binding_v2/decoder.go b/pkg/app/server/binding_v2/decoder.go index 614ec0fa6..3a5f42f85 100644 --- a/pkg/app/server/binding_v2/decoder.go +++ b/pkg/app/server/binding_v2/decoder.go @@ -62,10 +62,13 @@ func getFieldDecoder(field reflect.StructField, index int, parentIdx []int) ([]d } if reflect.PtrTo(field.Type).Implements(fieldDecoderType) { return []decoder{&customizedFieldTextDecoder{ - index: index, - parentIndex: parentIdx, - fieldName: field.Name, - fieldType: field.Type}}, nil + fieldInfo: fieldInfo{ + index: index, + parentIndex: parentIdx, + fieldName: field.Name, + fieldType: field.Type, + }, + }}, nil } fieldTagInfos := lookupFieldTags(field) @@ -75,10 +78,6 @@ func getFieldDecoder(field reflect.StructField, index int, parentIdx []int) ([]d // return nil, nil //} - // todo: 用户自定义text信息解析 - //if reflect.PtrTo(field.Type).Implements(textUnmarshalerType) { - // return compileTextBasedDecoder(field, index, tagScope, tagContent) - //} if field.Type.Kind() == reflect.Slice || field.Type.Kind() == reflect.Array { return getSliceFieldDecoder(field, index, fieldTagInfos, parentIdx) } diff --git a/pkg/app/server/binding_v2/map_type_decoder.go b/pkg/app/server/binding_v2/map_type_decoder.go index 914eb7291..39cad45b3 100644 --- a/pkg/app/server/binding_v2/map_type_decoder.go +++ b/pkg/app/server/binding_v2/map_type_decoder.go @@ -12,11 +12,7 @@ import ( ) type mapTypeFieldTextDecoder struct { - index int - parentIndex []int - fieldName string - tagInfos []TagInfo // query,param,header,respHeader ... - fieldType reflect.Type + fieldInfo } func (d *mapTypeFieldTextDecoder) Decode(req *protocol.Request, params PathParams, reqValue reflect.Value) error { @@ -28,14 +24,12 @@ func (d *mapTypeFieldTextDecoder) Decode(req *protocol.Request, params PathParam continue } if tagInfo.Key == headerTag { - tmp := []byte(tagInfo.Value) - utils.NormalizeHeaderKey(tmp, req.Header.IsDisableNormalizing()) - tagInfo.Value = string(tmp) + tagInfo.Value = utils.GetNormalizeHeaderKey(tagInfo.Value, req.Header.IsDisableNormalizing()) } ret := tagInfo.Getter(req, params, tagInfo.Value) defaultValue = tagInfo.Default if len(ret) != 0 { - // 非数组/切片类型,只取第一个值作为只 + // 非数组/切片类型,只取第一个值作为值 text = ret[0] break } @@ -47,27 +41,7 @@ func (d *mapTypeFieldTextDecoder) Decode(req *protocol.Request, params PathParam return nil } - // todo 多重指针 - for _, idx := range d.parentIndex { - if reqValue.Kind() == reflect.Ptr && reqValue.IsNil() { - nonNilVal, ptrDepth := GetNonNilReferenceValue(reqValue) - reqValue.Set(ReferenceValue(nonNilVal, ptrDepth)) - } - for reqValue.Kind() == reflect.Ptr { - reqValue = reqValue.Elem() - } - reqValue = reqValue.Field(idx) - } - - // 父 struct 有可能也是一个指针,所以需要再处理一次才能得到最终的父Value(非nil的reflect.Value) - for reqValue.Kind() == reflect.Ptr { - if reqValue.IsNil() { - nonNilVal, ptrDepth := GetNonNilReferenceValue(reqValue) - reqValue.Set(ReferenceValue(nonNilVal, ptrDepth)) - } - reqValue = reqValue.Elem() - } - + reqValue = GetFieldValue(reqValue, d.parentIndex) field := reqValue.Field(d.index) if field.Kind() == reflect.Ptr { // 如果是指针则新建一个reflect.Value,然后赋值给指针 @@ -120,11 +94,13 @@ func getMapTypeTextDecoder(field reflect.StructField, index int, tagInfos []TagI fieldType = field.Type.Elem() } fieldDecoder := &mapTypeFieldTextDecoder{ - index: index, - parentIndex: parentIdx, - fieldName: field.Name, - tagInfos: tagInfos, - fieldType: fieldType, + fieldInfo: fieldInfo{ + index: index, + parentIndex: parentIdx, + fieldName: field.Name, + tagInfos: tagInfos, + fieldType: fieldType, + }, } return []decoder{fieldDecoder}, nil diff --git a/pkg/app/server/binding_v2/reflect.go b/pkg/app/server/binding_v2/reflect.go index aa8d91aee..79034aba8 100644 --- a/pkg/app/server/binding_v2/reflect.go +++ b/pkg/app/server/binding_v2/reflect.go @@ -46,3 +46,27 @@ func GetNonNilReferenceValue(v reflect.Value) (reflect.Value, int) { val := reflect.New(t).Elem() return val, ptrDepth } + +func GetFieldValue(reqValue reflect.Value, parentIndex []int) reflect.Value { + for _, idx := range parentIndex { + if reqValue.Kind() == reflect.Ptr && reqValue.IsNil() { + nonNilVal, ptrDepth := GetNonNilReferenceValue(reqValue) + reqValue.Set(ReferenceValue(nonNilVal, ptrDepth)) + } + for reqValue.Kind() == reflect.Ptr { + reqValue = reqValue.Elem() + } + reqValue = reqValue.Field(idx) + } + + // 父 struct 有可能也是一个指针,所以需要再处理一次才能得到最终的父Value(非nil的reflect.Value) + for reqValue.Kind() == reflect.Ptr { + if reqValue.IsNil() { + nonNilVal, ptrDepth := GetNonNilReferenceValue(reqValue) + reqValue.Set(ReferenceValue(nonNilVal, ptrDepth)) + } + reqValue = reqValue.Elem() + } + + return reqValue +} diff --git a/pkg/app/server/binding_v2/slice_type_decoder.go b/pkg/app/server/binding_v2/slice_type_decoder.go index 55805650c..21939fe82 100644 --- a/pkg/app/server/binding_v2/slice_type_decoder.go +++ b/pkg/app/server/binding_v2/slice_type_decoder.go @@ -5,20 +5,15 @@ import ( "fmt" "reflect" - "github.com/cloudwego/hertz/pkg/common/utils" - "github.com/cloudwego/hertz/internal/bytesconv" "github.com/cloudwego/hertz/pkg/app/server/binding_v2/text_decoder" + "github.com/cloudwego/hertz/pkg/common/utils" "github.com/cloudwego/hertz/pkg/protocol" ) type sliceTypeFieldTextDecoder struct { - index int - parentIndex []int - fieldName string - isArray bool - tagInfos []TagInfo // query,param,header,respHeader ... - fieldType reflect.Type + fieldInfo + isArray bool } func (d *sliceTypeFieldTextDecoder) Decode(req *protocol.Request, params PathParams, reqValue reflect.Value) error { @@ -28,9 +23,7 @@ func (d *sliceTypeFieldTextDecoder) Decode(req *protocol.Request, params PathPar continue } if tagInfo.Key == headerTag { - tmp := []byte(tagInfo.Value) - utils.NormalizeHeaderKey(tmp, req.Header.IsDisableNormalizing()) - tagInfo.Value = string(tmp) + tagInfo.Value = utils.GetNormalizeHeaderKey(tagInfo.Value, req.Header.IsDisableNormalizing()) } texts = tagInfo.Getter(req, params, tagInfo.Value) // todo: 数组默认值 @@ -43,27 +36,7 @@ func (d *sliceTypeFieldTextDecoder) Decode(req *protocol.Request, params PathPar return nil } - // todo 多重指针 - for _, idx := range d.parentIndex { - if reqValue.Kind() == reflect.Ptr && reqValue.IsNil() { - nonNilVal, ptrDepth := GetNonNilReferenceValue(reqValue) - reqValue.Set(ReferenceValue(nonNilVal, ptrDepth)) - } - for reqValue.Kind() == reflect.Ptr { - reqValue = reqValue.Elem() - } - reqValue = reqValue.Field(idx) - } - - // 父 struct 有可能也是一个指针,所以需要再处理一次才能得到最终的父Value(非nil的reflect.Value) - for reqValue.Kind() == reflect.Ptr { - if reqValue.IsNil() { - nonNilVal, ptrDepth := GetNonNilReferenceValue(reqValue) - reqValue.Set(ReferenceValue(nonNilVal, ptrDepth)) - } - reqValue = reqValue.Elem() - } - + reqValue = GetFieldValue(reqValue, d.parentIndex) field := reqValue.Field(d.index) if d.isArray { @@ -134,12 +107,14 @@ func getSliceFieldDecoder(field reflect.StructField, index int, tagInfos []TagIn } fieldDecoder := &sliceTypeFieldTextDecoder{ - index: index, - parentIndex: parentIdx, - fieldName: field.Name, - tagInfos: tagInfos, - fieldType: fieldType, - isArray: isArray, + fieldInfo: fieldInfo{ + index: index, + parentIndex: parentIdx, + fieldName: field.Name, + tagInfos: tagInfos, + fieldType: fieldType, + }, + isArray: isArray, } return []decoder{fieldDecoder}, nil diff --git a/pkg/common/utils/utils.go b/pkg/common/utils/utils.go index f002f2964..b90132bb5 100644 --- a/pkg/common/utils/utils.go +++ b/pkg/common/utils/utils.go @@ -82,6 +82,12 @@ func CaseInsensitiveCompare(a, b []byte) bool { return true } +func GetNormalizeHeaderKey(key string, disableNormalizing bool) string { + keyBytes := []byte(key) + NormalizeHeaderKey(keyBytes, disableNormalizing) + return string(keyBytes) +} + func NormalizeHeaderKey(b []byte, disableNormalizing bool) { if disableNormalizing { return From 1f4da2d025706556b480750c68f897daaa8fd4ad Mon Sep 17 00:00:00 2001 From: fgy Date: Wed, 18 Jan 2023 20:13:05 +0800 Subject: [PATCH 04/91] feat: add default tag --- pkg/app/server/binding_v2/binder_test.go | 28 +++++++++++++++++++++++- pkg/app/server/binding_v2/decoder.go | 7 +++--- pkg/app/server/binding_v2/getter.go | 10 +++++++-- pkg/app/server/binding_v2/tag.go | 9 ++++++++ 4 files changed, 47 insertions(+), 7 deletions(-) diff --git a/pkg/app/server/binding_v2/binder_test.go b/pkg/app/server/binding_v2/binder_test.go index 20d4aece3..e39dc8020 100644 --- a/pkg/app/server/binding_v2/binder_test.go +++ b/pkg/app/server/binding_v2/binder_test.go @@ -299,7 +299,7 @@ func TestBind_UnexportedField(t *testing.T) { } bind := Bind{} req := newMockRequest(). - SetRequestURI("http://foobar.com?a=1&b=2}") + SetRequestURI("http://foobar.com?a=1&b=2") err := bind.Bind(req.Req, nil, &s) if err != nil { t.Fatal(err) @@ -308,6 +308,32 @@ func TestBind_UnexportedField(t *testing.T) { assert.DeepEqual(t, 0, s.b) } +func TestBind_NoTagField(t *testing.T) { + var s struct { + A string + B string + C string + } + bind := Bind{} + req := newMockRequest(). + SetRequestURI("http://foobar.com?B=b1&C=c1"). + SetHeader("A", "a2") + + var params param.Params + params = append(params, param.Param{ + Key: "B", + Value: "b2", + }) + + err := bind.Bind(req.Req, params, &s) + if err != nil { + t.Fatal(err) + } + assert.DeepEqual(t, "a2", s.A) + assert.DeepEqual(t, "b2", s.B) + assert.DeepEqual(t, "c1", s.C) +} + func TestBind_TypedefType(t *testing.T) { type Foo string type Bar *int diff --git a/pkg/app/server/binding_v2/decoder.go b/pkg/app/server/binding_v2/decoder.go index 3a5f42f85..8fc907ea3 100644 --- a/pkg/app/server/binding_v2/decoder.go +++ b/pkg/app/server/binding_v2/decoder.go @@ -73,10 +73,9 @@ func getFieldDecoder(field reflect.StructField, index int, parentIdx []int) ([]d fieldTagInfos := lookupFieldTags(field) // todo: 没有 tag 也不直接返回 - //if len(fieldTagInfos) == 0 { - // // todo: 如果没定义尝试给其赋值所有 tag - // return nil, nil - //} + if len(fieldTagInfos) == 0 { + fieldTagInfos = getDefaultFieldTags(field) + } if field.Type.Kind() == reflect.Slice || field.Type.Kind() == reflect.Array { return getSliceFieldDecoder(field, index, fieldTagInfos, parentIdx) diff --git a/pkg/app/server/binding_v2/getter.go b/pkg/app/server/binding_v2/getter.go index 75ace6117..895d1aa61 100644 --- a/pkg/app/server/binding_v2/getter.go +++ b/pkg/app/server/binding_v2/getter.go @@ -11,12 +11,18 @@ type getter func(req *protocol.Request, params PathParams, key string, defaultVa // todo string 强转优化 func PathParam(req *protocol.Request, params PathParams, key string, defaultValue ...string) (ret []string) { - value, _ := params.Get(key) + var value string + if params != nil { + value, _ = params.Get(key) + } if len(value) == 0 && len(defaultValue) != 0 { value = defaultValue[0] } - ret = append(ret, value) + if len(value) != 0 { + ret = append(ret, value) + } + return } diff --git a/pkg/app/server/binding_v2/tag.go b/pkg/app/server/binding_v2/tag.go index 72dcc0af3..f3d91f567 100644 --- a/pkg/app/server/binding_v2/tag.go +++ b/pkg/app/server/binding_v2/tag.go @@ -64,3 +64,12 @@ func lookupFieldTags(field reflect.StructField) []TagInfo { return tagInfos } + +func getDefaultFieldTags(field reflect.StructField) (tagInfos []TagInfo) { + tags := []string{pathTag, formTag, queryTag, cookieTag, headerTag, jsonTag} + for _, tag := range tags { + tagInfos = append(tagInfos, TagInfo{Key: tag, Value: field.Name}) + } + + return +} From 61da4bd17864775816785eb4a162a47b9b04402a Mon Sep 17 00:00:00 2001 From: fgy Date: Mon, 30 Jan 2023 15:49:23 +0800 Subject: [PATCH 05/91] feat: add default value --- pkg/app/server/binding_v2/binder_test.go | 47 +++++++++++++++++++ pkg/app/server/binding_v2/map_type_decoder.go | 1 - .../server/binding_v2/slice_type_decoder.go | 6 ++- pkg/app/server/binding_v2/tag.go | 20 ++++++-- 4 files changed, 69 insertions(+), 5 deletions(-) diff --git a/pkg/app/server/binding_v2/binder_test.go b/pkg/app/server/binding_v2/binder_test.go index e39dc8020..07b633c31 100644 --- a/pkg/app/server/binding_v2/binder_test.go +++ b/pkg/app/server/binding_v2/binder_test.go @@ -334,6 +334,53 @@ func TestBind_NoTagField(t *testing.T) { assert.DeepEqual(t, "c1", s.C) } +func TestBind_ZeroValueBind(t *testing.T) { + var s struct { + A int `query:"a"` + B float64 `query:"b"` + } + bind := Bind{} + req := newMockRequest(). + SetRequestURI("http://foobar.com?a=&b") + + err := bind.Bind(req.Req, nil, &s) + if err != nil { + t.Fatal(err) + } + assert.DeepEqual(t, 0, s.A) + assert.DeepEqual(t, float64(0), s.B) +} + +func TestBind_DefaultValueBind(t *testing.T) { + var s struct { + A int `default:"15"` + B float64 `query:"b" default:"17"` + C []int `default:"15"` + D []string `default:"qwe"` + } + bind := Bind{} + req := newMockRequest(). + SetRequestURI("http://foobar.com") + + err := bind.Bind(req.Req, nil, &s) + if err != nil { + t.Fatal(err) + } + assert.DeepEqual(t, 15, s.A) + assert.DeepEqual(t, float64(17), s.B) + assert.DeepEqual(t, 15, s.C[0]) + assert.DeepEqual(t, "qwe", s.D[0]) + + var d struct { + D [2]string `default:"qwe"` + } + + err = bind.Bind(req.Req, nil, &d) + if err == nil { + t.Fatal("expected err") + } +} + func TestBind_TypedefType(t *testing.T) { type Foo string type Bar *int diff --git a/pkg/app/server/binding_v2/map_type_decoder.go b/pkg/app/server/binding_v2/map_type_decoder.go index 39cad45b3..b4b3e2484 100644 --- a/pkg/app/server/binding_v2/map_type_decoder.go +++ b/pkg/app/server/binding_v2/map_type_decoder.go @@ -7,7 +7,6 @@ import ( "github.com/cloudwego/hertz/internal/bytesconv" "github.com/cloudwego/hertz/pkg/common/utils" - "github.com/cloudwego/hertz/pkg/protocol" ) diff --git a/pkg/app/server/binding_v2/slice_type_decoder.go b/pkg/app/server/binding_v2/slice_type_decoder.go index 21939fe82..a091115c8 100644 --- a/pkg/app/server/binding_v2/slice_type_decoder.go +++ b/pkg/app/server/binding_v2/slice_type_decoder.go @@ -18,6 +18,7 @@ type sliceTypeFieldTextDecoder struct { func (d *sliceTypeFieldTextDecoder) Decode(req *protocol.Request, params PathParams, reqValue reflect.Value) error { var texts []string + var defaultValue string for _, tagInfo := range d.tagInfos { if tagInfo.Key == jsonTag { continue @@ -27,11 +28,14 @@ func (d *sliceTypeFieldTextDecoder) Decode(req *protocol.Request, params PathPar } texts = tagInfo.Getter(req, params, tagInfo.Value) // todo: 数组默认值 - // defaultValue = tagInfo.Default + defaultValue = tagInfo.Default if len(texts) != 0 { break } } + if len(texts) == 0 && len(defaultValue) != 0 { + texts = append(texts, defaultValue) + } if len(texts) == 0 { return nil } diff --git a/pkg/app/server/binding_v2/tag.go b/pkg/app/server/binding_v2/tag.go index f3d91f567..90e625abd 100644 --- a/pkg/app/server/binding_v2/tag.go +++ b/pkg/app/server/binding_v2/tag.go @@ -15,6 +15,10 @@ const ( rawBodyTag = "raw_body" ) +const ( + defaultTag = "default" +) + const ( requiredTagOpt = "required" ) @@ -44,8 +48,13 @@ func lookupFieldTags(field reflect.StructField) []TagInfo { ret = append(ret, tag) } } - var tagInfos []TagInfo + defaultVal := "" + if val, ok := field.Tag.Lookup(defaultTag); ok { + defaultVal = val + } + + var tagInfos []TagInfo for _, tag := range ret { tagContent := field.Tag.Get(tag) tagValue, opts := head(tagContent, ",") @@ -59,16 +68,21 @@ func lookupFieldTags(field reflect.StructField) []TagInfo { required = true } } - tagInfos = append(tagInfos, TagInfo{Key: tag, Value: tagValue, Options: options, Required: required}) + tagInfos = append(tagInfos, TagInfo{Key: tag, Value: tagValue, Options: options, Required: required, Default: defaultVal}) } return tagInfos } func getDefaultFieldTags(field reflect.StructField) (tagInfos []TagInfo) { + defaultVal := "" + if val, ok := field.Tag.Lookup(defaultTag); ok { + defaultVal = val + } + tags := []string{pathTag, formTag, queryTag, cookieTag, headerTag, jsonTag} for _, tag := range tags { - tagInfos = append(tagInfos, TagInfo{Key: tag, Value: field.Name}) + tagInfos = append(tagInfos, TagInfo{Key: tag, Value: field.Name, Default: defaultVal}) } return From f287cdd99c6ca95ac63f59f3c2741efe3270e4f1 Mon Sep 17 00:00:00 2001 From: fgy Date: Mon, 30 Jan 2023 17:21:53 +0800 Subject: [PATCH 06/91] feat: add required validate --- .../server/binding_v2/base_type_decoder.go | 9 ++++++- pkg/app/server/binding_v2/binder_test.go | 24 +++++++++++++++++++ 2 files changed, 32 insertions(+), 1 deletion(-) diff --git a/pkg/app/server/binding_v2/base_type_decoder.go b/pkg/app/server/binding_v2/base_type_decoder.go index 75aa2d42c..57e7aa958 100644 --- a/pkg/app/server/binding_v2/base_type_decoder.go +++ b/pkg/app/server/binding_v2/base_type_decoder.go @@ -23,6 +23,7 @@ type baseTypeFieldTextDecoder struct { } func (d *baseTypeFieldTextDecoder) Decode(req *protocol.Request, params PathParams, reqValue reflect.Value) error { + var err error var text string var defaultValue string // 最大努力交付,对齐 hertz 现有设计 @@ -38,8 +39,15 @@ func (d *baseTypeFieldTextDecoder) Decode(req *protocol.Request, params PathPara if len(ret) != 0 { // 非数组/切片类型,只取第一个值作为只 text = ret[0] + err = nil break } + if tagInfo.Required { + err = fmt.Errorf("'%s' field is a 'required' parameter, but the request does not have this parameter", d.fieldName) + } + } + if err != nil { + return err } if len(text) == 0 && len(defaultValue) != 0 { text = defaultValue @@ -48,7 +56,6 @@ func (d *baseTypeFieldTextDecoder) Decode(req *protocol.Request, params PathPara return nil } - var err error // 得到该field的非nil值 reqValue = GetFieldValue(reqValue, d.parentIndex) // 根据最终的 Struct,获取对应 field 的 reflect.Value diff --git a/pkg/app/server/binding_v2/binder_test.go b/pkg/app/server/binding_v2/binder_test.go index 07b633c31..81caad931 100644 --- a/pkg/app/server/binding_v2/binder_test.go +++ b/pkg/app/server/binding_v2/binder_test.go @@ -381,6 +381,30 @@ func TestBind_DefaultValueBind(t *testing.T) { } } +func TestBind_RequiredBind(t *testing.T) { + var s struct { + A int `query:"a,required"` + } + bind := Bind{} + req := newMockRequest(). + SetRequestURI("http://foobar.com"). + SetHeader("A", "1") + + err := bind.Bind(req.Req, nil, &s) + if err == nil { + t.Fatal("expected error") + } + + var d struct { + A int `query:"a,required" header:"A"` + } + err = bind.Bind(req.Req, nil, &d) + if err != nil { + t.Fatal(err) + } + assert.DeepEqual(t, 1, d.A) +} + func TestBind_TypedefType(t *testing.T) { type Foo string type Bar *int From 260db67e68166b61db37e5f0cb1654cd475afa5a Mon Sep 17 00:00:00 2001 From: fgy Date: Mon, 30 Jan 2023 20:06:29 +0800 Subject: [PATCH 07/91] feat: config json unmarshaler --- pkg/app/server/binding_v2/binder.go | 6 ++-- pkg/app/server/binding_v2/binder_test.go | 30 +++++++++++++++++++ pkg/app/server/binding_v2/json.go | 25 ++++++++++++++++ pkg/app/server/binding_v2/map_type_decoder.go | 3 +- .../server/binding_v2/slice_type_decoder.go | 5 ++-- 5 files changed, 61 insertions(+), 8 deletions(-) create mode 100644 pkg/app/server/binding_v2/json.go diff --git a/pkg/app/server/binding_v2/binder.go b/pkg/app/server/binding_v2/binder.go index 16757879b..dbdd4a0f8 100644 --- a/pkg/app/server/binding_v2/binder.go +++ b/pkg/app/server/binding_v2/binder.go @@ -1,11 +1,11 @@ package binding_v2 import ( - "encoding/json" "fmt" "reflect" "sync" + "github.com/cloudwego/hertz/internal/bytesconv" "github.com/cloudwego/hertz/pkg/protocol" "google.golang.org/protobuf/proto" ) @@ -67,10 +67,10 @@ func (b *Bind) PreBindBody(req *protocol.Request, v interface{}) error { if req.Header.ContentLength() <= 0 { return nil } - switch string(req.Header.ContentType()) { + switch bytesconv.B2s(req.Header.ContentType()) { case jsonContentTypeBytes: // todo: 对齐gin, 添加 "EnableDecoderUseNumber"/"EnableDecoderDisallowUnknownFields" 接口 - return json.Unmarshal(req.Body(), v) + return jsonUnmarshalFunc(req.Body(), v) case protobufContentType: msg, ok := v.(proto.Message) if !ok { diff --git a/pkg/app/server/binding_v2/binder_test.go b/pkg/app/server/binding_v2/binder_test.go index 81caad931..1c9df21d2 100644 --- a/pkg/app/server/binding_v2/binder_test.go +++ b/pkg/app/server/binding_v2/binder_test.go @@ -501,6 +501,36 @@ func TestBind_JSON(t *testing.T) { } } +func TestBind_ResetJSONUnmarshal(t *testing.T) { + ResetStdJSONUnmarshaler() + bind := Bind{} + type Req struct { + J1 string `json:"j1"` + J2 int `json:"j2"` + J3 []byte `json:"j3"` + J4 [2]string `json:"j4"` + } + J3s := []byte("12") + J4s := [2]string{"qwe", "asd"} + + req := newMockRequest(). + SetJSONContentType(). + SetBody([]byte(fmt.Sprintf(`{"j1":"j1", "j2":12, "j3":[%d, %d], "j4":["%s", "%s"]}`, J3s[0], J3s[1], J4s[0], J4s[1]))) + var result Req + err := bind.Bind(req.Req, nil, &result) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, "j1", result.J1) + assert.DeepEqual(t, 12, result.J2) + for idx, val := range J3s { + assert.DeepEqual(t, val, result.J3[idx]) + } + for idx, val := range J4s { + assert.DeepEqual(t, val, result.J4[idx]) + } +} + func Benchmark_V2(b *testing.B) { bind := Bind{} type Req struct { diff --git a/pkg/app/server/binding_v2/json.go b/pkg/app/server/binding_v2/json.go new file mode 100644 index 000000000..049cc7456 --- /dev/null +++ b/pkg/app/server/binding_v2/json.go @@ -0,0 +1,25 @@ +package binding_v2 + +import ( + "encoding/json" + + hjson "github.com/cloudwego/hertz/pkg/common/json" +) + +// JSONUnmarshaler is the interface implemented by types +// that can unmarshal a JSON description of themselves. +type JSONUnmarshaler func(data []byte, v interface{}) error + +var jsonUnmarshalFunc JSONUnmarshaler + +func init() { + ResetJSONUnmarshaler(hjson.Unmarshal) +} + +func ResetJSONUnmarshaler(fn JSONUnmarshaler) { + jsonUnmarshalFunc = fn +} + +func ResetStdJSONUnmarshaler() { + ResetJSONUnmarshaler(json.Unmarshal) +} diff --git a/pkg/app/server/binding_v2/map_type_decoder.go b/pkg/app/server/binding_v2/map_type_decoder.go index b4b3e2484..bfdc993c1 100644 --- a/pkg/app/server/binding_v2/map_type_decoder.go +++ b/pkg/app/server/binding_v2/map_type_decoder.go @@ -1,7 +1,6 @@ package binding_v2 import ( - "encoding/json" "fmt" "reflect" @@ -59,7 +58,7 @@ func (d *mapTypeFieldTextDecoder) Decode(req *protocol.Request, params PathParam return nil } - err := json.Unmarshal(bytesconv.S2b(text), field.Addr().Interface()) + err := jsonUnmarshalFunc(bytesconv.S2b(text), field.Addr().Interface()) if err != nil { return fmt.Errorf("unable to decode '%s' as %s: %w", text, d.fieldType.Name(), err) } diff --git a/pkg/app/server/binding_v2/slice_type_decoder.go b/pkg/app/server/binding_v2/slice_type_decoder.go index a091115c8..74ac884c2 100644 --- a/pkg/app/server/binding_v2/slice_type_decoder.go +++ b/pkg/app/server/binding_v2/slice_type_decoder.go @@ -1,7 +1,6 @@ package binding_v2 import ( - "encoding/json" "fmt" "reflect" @@ -130,9 +129,9 @@ func stringToValue(elemType reflect.Type, text string) (v reflect.Value, err err switch elemType.Kind() { case reflect.Struct: - err = json.Unmarshal(bytesconv.S2b(text), v.Addr().Interface()) + err = jsonUnmarshalFunc(bytesconv.S2b(text), v.Addr().Interface()) case reflect.Map: - err = json.Unmarshal(bytesconv.S2b(text), v.Addr().Interface()) + err = jsonUnmarshalFunc(bytesconv.S2b(text), v.Addr().Interface()) case reflect.Array, reflect.Slice: // do nothing default: From f20613e5fa36ec833e54d11379b5e1255cd920d4 Mon Sep 17 00:00:00 2001 From: fgy Date: Tue, 31 Jan 2023 18:53:04 +0800 Subject: [PATCH 08/91] feat: replace validator --- go.mod | 5 +- go.sum | 58 ++++++++++- pkg/app/context.go | 14 ++- pkg/app/server/binding_v2/binder.go | 68 +------------ pkg/app/server/binding_v2/binder_test.go | 95 ++++++++++--------- pkg/app/server/binding_v2/default_binder.go | 73 ++++++++++++++ .../server/binding_v2/default_validator.go | 54 +++++++++++ pkg/app/server/binding_v2/validator.go | 24 +++++ pkg/app/server/binding_v2/validator_test.go | 31 ++++++ 9 files changed, 305 insertions(+), 117 deletions(-) create mode 100644 pkg/app/server/binding_v2/default_binder.go create mode 100644 pkg/app/server/binding_v2/default_validator.go create mode 100644 pkg/app/server/binding_v2/validator.go create mode 100644 pkg/app/server/binding_v2/validator_test.go diff --git a/go.mod b/go.mod index b5eeb5251..865de6144 100644 --- a/go.mod +++ b/go.mod @@ -9,8 +9,9 @@ require ( github.com/bytedance/sonic v1.8.1 github.com/cloudwego/netpoll v0.4.2-0.20230807055039-52fd5fb7b00f github.com/fsnotify/fsnotify v1.5.4 + github.com/go-playground/validator/v10 v10.11.2 github.com/tidwall/gjson v1.13.0 // indirect - golang.org/x/sync v0.0.0-20210220032951-036812b2e83c - golang.org/x/sys v0.0.0-20220412211240-33da011f77ad + golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4 + golang.org/x/sys v0.4.0 google.golang.org/protobuf v1.27.1 ) diff --git a/go.sum b/go.sum index 59e21cf1c..057d986d7 100644 --- a/go.sum +++ b/go.sum @@ -17,6 +17,14 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/fsnotify/fsnotify v1.5.4 h1:jRbGcIw6P2Meqdwuo0H1p6JVLbL5DHKAKlYndzMwVZI= github.com/fsnotify/fsnotify v1.5.4/go.mod h1:OVB6XrOHzAwXMpEM7uPOzcehqUV2UqJxmVXmkdnm1bU= +github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= +github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= +github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= +github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= +github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= +github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= +github.com/go-playground/validator/v10 v10.11.2 h1:q3SHpufmypg+erIExEKUmsgmhDTyhcJ38oeKGACXohU= +github.com/go-playground/validator/v10 v10.11.2/go.mod h1:NieE624vt4SCTJtD87arVLvdmjPAeV8BQlHtMnw9D7s= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/protobuf v1.5.0 h1:LUVKkCeviFUMKqHa4tXIIij/lbhnMbP7Fn5wKdKkRh4= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= @@ -33,19 +41,32 @@ github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7 github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= github.com/klauspost/cpuid/v2 v2.0.9 h1:lgaqFMSdTdQYdZ04uHyN2d/eKdOMyi2YLSvlQIBFYa4= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= +github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= +github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= +github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= +github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= +github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= +github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= +github.com/leodido/go-urn v1.2.1 h1:BqpAaACuzVSgi/VLzGZIobT2z4v53pjosyNd9Yv6n/w= +github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY= github.com/nyaruka/phonenumbers v1.0.55 h1:bj0nTO88Y68KeUQ/n3Lo2KgK7lM1hF7L9NFuwcCl3yg= github.com/nyaruka/phonenumbers v1.0.55/go.mod h1:sDaTZ/KPX5f8qyV9qN+hIm+4ZBARJrupC6LuhshJq1U= +github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d h1:zE9ykElWQ6/NYmHa3jpm/yHnI4xSofP+UP6SpjHcSeM= github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= github.com/smartystreets/goconvey v1.6.4 h1:fv0U8FUIMPNf1L9lnHLvLhgicrIVChEkdzIKYqbNC9s= github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= +github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= +github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= +github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= @@ -60,17 +81,49 @@ github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= +github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= golang.org/x/arch v0.0.0-20201008161808-52c3e6f60cff/go.mod h1:flIaEI6LNU6xOCD5PaJvn9wGP0agmIOqjrtsKGRguv4= golang.org/x/arch v0.0.0-20210923205945-b76863e36670 h1:18EFjUmQOcUvxNYSkA6jO9VAiXCnxFY6NyDX0bHDmkU= golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= +golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= +golang.org/x/crypto v0.5.0 h1:U/0M97KRkSFvyD/3FSmdP5W5swImpNgle/EHFhOsQPE= +golang.org/x/crypto v0.5.0/go.mod h1:NK/OQwhpMQP3MwtdjgLlYHnH9ebylxKWv3e0fK+mkQU= +golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= +golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= +golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= +golang.org/x/net v0.5.0/go.mod h1:DivGGAXEgPSlEBzxGzZI+ZLohi+xUj054jfeKui00ws= +golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c h1:5KslGYwFpkhGh+Q16bwMP3cOontH8FOep7tGV86Y7SQ= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4 h1:uVc8UZUe6tr40fFVnUP5Oj+veunVezqYl9z7DYw9xzw= +golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20220110181412-a018aaa089fe/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220412211240-33da011f77ad h1:ntjMns5wyP/fN65tdBD4g8J5w8n015+iIIs9rtjXkY0= golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.4.0 h1:Zr2JFtRQNX3BCZ8YtxRE9hNJYC8J6I1MVbMg6owUp18= +golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/term v0.4.0/go.mod h1:9P2UbLfCdcvo3p/nzKvsmas4TnlujnuoV9hGgYzW1lQ= +golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= +golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= +golang.org/x/text v0.6.0 h1:3XmdazWV+ubf7QgHSTWeykHOci5oeekaGJBLkrkaw4k= +golang.org/x/text v0.6.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= +golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= +golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= @@ -79,6 +132,9 @@ google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp0 google.golang.org/protobuf v1.27.1 h1:SnqbnDw1V7RiZcXPx5MEeqPv2s79L9i7BJUlG/+RurQ= google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= +gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= diff --git a/pkg/app/context.go b/pkg/app/context.go index 9a6e45440..99521f37b 100644 --- a/pkg/app/context.go +++ b/pkg/app/context.go @@ -54,9 +54,10 @@ import ( "sync" "time" + "github.com/cloudwego/hertz/pkg/app/server/binding_v2" + "github.com/cloudwego/hertz/internal/bytesconv" "github.com/cloudwego/hertz/internal/bytestr" - "github.com/cloudwego/hertz/pkg/app/server/binding" "github.com/cloudwego/hertz/pkg/app/server/render" "github.com/cloudwego/hertz/pkg/common/errors" "github.com/cloudwego/hertz/pkg/common/tracer/traceinfo" @@ -1305,19 +1306,24 @@ func bodyAllowedForStatus(status int) bool { // BindAndValidate binds data from *RequestContext to obj and validates them if needed. // NOTE: obj should be a pointer. func (ctx *RequestContext) BindAndValidate(obj interface{}) error { - return binding.BindAndValidate(&ctx.Request, obj, ctx.Params) + err := binding_v2.DefaultBinder.Bind(&ctx.Request, ctx.Params, obj) + if err != nil { + return err + } + err = binding_v2.DefaultValidator.ValidateStruct(obj) + return err } // Bind binds data from *RequestContext to obj. // NOTE: obj should be a pointer. func (ctx *RequestContext) Bind(obj interface{}) error { - return binding.Bind(&ctx.Request, obj, ctx.Params) + return binding_v2.DefaultBinder.Bind(&ctx.Request, ctx.Params, obj) } // Validate validates obj with "vd" tag // NOTE: obj should be a pointer. func (ctx *RequestContext) Validate(obj interface{}) error { - return binding.Validate(obj) + return binding_v2.DefaultValidator.ValidateStruct(obj) } // VisitAllQueryArgs calls f for each existing query arg. diff --git a/pkg/app/server/binding_v2/binder.go b/pkg/app/server/binding_v2/binder.go index dbdd4a0f8..314e1d76e 100644 --- a/pkg/app/server/binding_v2/binder.go +++ b/pkg/app/server/binding_v2/binder.go @@ -1,13 +1,7 @@ package binding_v2 import ( - "fmt" - "reflect" - "sync" - - "github.com/cloudwego/hertz/internal/bytesconv" "github.com/cloudwego/hertz/pkg/protocol" - "google.golang.org/protobuf/proto" ) // PathParams parameter acquisition interface on the URL path @@ -20,64 +14,4 @@ type Binder interface { Bind(*protocol.Request, PathParams, interface{}) error } -type Bind struct { - decoderCache sync.Map -} - -func (b *Bind) Name() string { - return "hertz" -} - -func (b *Bind) Bind(req *protocol.Request, params PathParams, v interface{}) error { - // todo: 先做 body unmarshal, 先尝试做 body 绑定,然后再尝试绑定其他内容 - err := b.PreBindBody(req, v) - if err != nil { - return err - } - rv, typeID := valueAndTypeID(v) - if rv.Kind() != reflect.Pointer || rv.IsNil() { - return fmt.Errorf("receiver must be a non-nil pointer") - } - if rv.Elem().Kind() == reflect.Map { - return nil - } - cached, ok := b.decoderCache.Load(typeID) - if ok { - // cached decoder, fast path - decoder := cached.(Decoder) - return decoder(req, params, rv.Elem()) - } - - decoder, err := getReqDecoder(rv.Type()) - if err != nil { - return err - } - - b.decoderCache.Store(typeID, decoder) - return decoder(req, params, rv.Elem()) -} - -var ( - jsonContentTypeBytes = "application/json; charset=utf-8" - protobufContentType = "application/x-protobuf" -) - -// best effort binding -func (b *Bind) PreBindBody(req *protocol.Request, v interface{}) error { - if req.Header.ContentLength() <= 0 { - return nil - } - switch bytesconv.B2s(req.Header.ContentType()) { - case jsonContentTypeBytes: - // todo: 对齐gin, 添加 "EnableDecoderUseNumber"/"EnableDecoderDisallowUnknownFields" 接口 - return jsonUnmarshalFunc(req.Body(), v) - case protobufContentType: - msg, ok := v.(proto.Message) - if !ok { - return fmt.Errorf("%s can not implement 'proto.Message'", v) - } - return proto.Unmarshal(req.Body(), msg) - default: - return nil - } -} +var DefaultBinder Binder = &Bind{} diff --git a/pkg/app/server/binding_v2/binder_test.go b/pkg/app/server/binding_v2/binder_test.go index 1c9df21d2..11315ca1c 100644 --- a/pkg/app/server/binding_v2/binder_test.go +++ b/pkg/app/server/binding_v2/binder_test.go @@ -57,7 +57,6 @@ func (m *mockRequest) SetBody(data []byte) *mockRequest { } func TestBind_BaseType(t *testing.T) { - bind := Bind{} type Req struct { Version int `path:"v"` ID int `query:"id"` @@ -78,7 +77,7 @@ func TestBind_BaseType(t *testing.T) { var result Req - err := bind.Bind(req.Req, params, &result) + err := DefaultBinder.Bind(req.Req, params, &result) if err != nil { t.Error(err) } @@ -89,7 +88,6 @@ func TestBind_BaseType(t *testing.T) { } func TestBind_SliceType(t *testing.T) { - bind := Bind{} type Req struct { ID []int `query:"id"` Str [3]string `query:"str"` @@ -104,7 +102,7 @@ func TestBind_SliceType(t *testing.T) { var result Req - err := bind.Bind(req.Req, nil, &result) + err := DefaultBinder.Bind(req.Req, nil, &result) if err != nil { t.Error(err) } @@ -143,13 +141,11 @@ func TestBind_StructType(t *testing.T) { B2 Foo `query:"B2"` } - bind := Bind{} - var result Bar req := newMockRequest().SetRequestURI("http://foobar.com?F1=f1&B1=b1").SetHeader("f2", "f2") - err := bind.Bind(req.Req, nil, &result) + err := DefaultBinder.Bind(req.Req, nil, &result) if err != nil { t.Error(err) } @@ -178,8 +174,6 @@ func TestBind_PointerType(t *testing.T) { B4 [2]*int `query:"B4"` } - bind := Bind{} - result := Bar{} F1 := "f1" @@ -191,7 +185,7 @@ func TestBind_PointerType(t *testing.T) { req := newMockRequest().SetRequestURI(fmt.Sprintf("http://foobar.com?F1=%s&B1=%s&B2=%s&B3=%s&B3=%s&B4=%d&B4=%d", F1, B1, B2, B3s[0], B3s[1], B4s[0], B4s[1])). SetHeader("f2", "f2") - err := bind.Bind(req.Req, nil, &result) + err := DefaultBinder.Bind(req.Req, nil, &result) if err != nil { t.Error(err) } @@ -220,12 +214,10 @@ func TestBind_NestedStruct(t *testing.T) { } } - bind := Bind{} - result := Bar{} req := newMockRequest().SetRequestURI("http://foobar.com?F1=qwe") - err := bind.Bind(req.Req, nil, &result) + err := DefaultBinder.Bind(req.Req, nil, &result) if err != nil { t.Error(err) } @@ -242,13 +234,11 @@ func TestBind_SliceStruct(t *testing.T) { B1 []Foo `query:"F1"` } - bind := Bind{} - result := Bar{} B1s := []string{"1", "2", "3"} req := newMockRequest().SetRequestURI(fmt.Sprintf("http://foobar.com?F1={\"f1\":\"%s\"}&F1={\"f1\":\"%s\"}&F1={\"f1\":\"%s\"}", B1s[0], B1s[1], B1s[2])) - err := bind.Bind(req.Req, nil, &result) + err := DefaultBinder.Bind(req.Req, nil, &result) if err != nil { t.Error(err) } @@ -260,11 +250,10 @@ func TestBind_SliceStruct(t *testing.T) { func TestBind_MapType(t *testing.T) { var result map[string]string - bind := Bind{} req := newMockRequest(). SetJSONContentType(). SetBody([]byte(`{"j1":"j1", "j2":"j2"}`)) - err := bind.Bind(req.Req, nil, &result) + err := DefaultBinder.Bind(req.Req, nil, &result) if err != nil { t.Fatal(err) } @@ -278,13 +267,12 @@ func TestBind_MapFieldType(t *testing.T) { F1 ***map[string]string `query:"f1" json:"f1"` } - bind := Bind{} req := newMockRequest(). SetRequestURI("http://foobar.com?f1={\"f1\":\"f1\"}"). SetJSONContentType(). SetBody([]byte(`{"j1":"j1", "j2":"j2"}`)) result := Foo{} - err := bind.Bind(req.Req, nil, &result) + err := DefaultBinder.Bind(req.Req, nil, &result) if err != nil { t.Fatal(err) } @@ -297,10 +285,9 @@ func TestBind_UnexportedField(t *testing.T) { A int `query:"a"` b int `query:"b"` } - bind := Bind{} req := newMockRequest(). SetRequestURI("http://foobar.com?a=1&b=2") - err := bind.Bind(req.Req, nil, &s) + err := DefaultBinder.Bind(req.Req, nil, &s) if err != nil { t.Fatal(err) } @@ -314,7 +301,6 @@ func TestBind_NoTagField(t *testing.T) { B string C string } - bind := Bind{} req := newMockRequest(). SetRequestURI("http://foobar.com?B=b1&C=c1"). SetHeader("A", "a2") @@ -325,7 +311,7 @@ func TestBind_NoTagField(t *testing.T) { Value: "b2", }) - err := bind.Bind(req.Req, params, &s) + err := DefaultBinder.Bind(req.Req, params, &s) if err != nil { t.Fatal(err) } @@ -339,11 +325,10 @@ func TestBind_ZeroValueBind(t *testing.T) { A int `query:"a"` B float64 `query:"b"` } - bind := Bind{} req := newMockRequest(). SetRequestURI("http://foobar.com?a=&b") - err := bind.Bind(req.Req, nil, &s) + err := DefaultBinder.Bind(req.Req, nil, &s) if err != nil { t.Fatal(err) } @@ -358,11 +343,10 @@ func TestBind_DefaultValueBind(t *testing.T) { C []int `default:"15"` D []string `default:"qwe"` } - bind := Bind{} req := newMockRequest(). SetRequestURI("http://foobar.com") - err := bind.Bind(req.Req, nil, &s) + err := DefaultBinder.Bind(req.Req, nil, &s) if err != nil { t.Fatal(err) } @@ -375,7 +359,7 @@ func TestBind_DefaultValueBind(t *testing.T) { D [2]string `default:"qwe"` } - err = bind.Bind(req.Req, nil, &d) + err = DefaultBinder.Bind(req.Req, nil, &d) if err == nil { t.Fatal("expected err") } @@ -385,12 +369,11 @@ func TestBind_RequiredBind(t *testing.T) { var s struct { A int `query:"a,required"` } - bind := Bind{} req := newMockRequest(). SetRequestURI("http://foobar.com"). SetHeader("A", "1") - err := bind.Bind(req.Req, nil, &s) + err := DefaultBinder.Bind(req.Req, nil, &s) if err == nil { t.Fatal("expected error") } @@ -398,7 +381,7 @@ func TestBind_RequiredBind(t *testing.T) { var d struct { A int `query:"a,required" header:"A"` } - err = bind.Bind(req.Req, nil, &d) + err = DefaultBinder.Bind(req.Req, nil, &d) if err != nil { t.Fatal(err) } @@ -418,10 +401,9 @@ func TestBind_TypedefType(t *testing.T) { B Bar `query:"b"` T1 TT } - bind := Bind{} req := newMockRequest(). SetRequestURI("http://foobar.com?a=1&b=2") - err := bind.Bind(req.Req, nil, &s) + err := DefaultBinder.Bind(req.Req, nil, &s) if err != nil { t.Fatal(err) } @@ -430,6 +412,37 @@ func TestBind_TypedefType(t *testing.T) { assert.DeepEqual(t, "1", s.T1.T1) } +// 枚举类型BaseType +type EnumType int64 + +const ( + EnumType_TWEET EnumType = 0 + EnumType_RETWEET EnumType = 2 +) + +func (p EnumType) String() string { + switch p { + case EnumType_TWEET: + return "TWEET" + case EnumType_RETWEET: + return "RETWEET" + } + return "" +} + +func TestBind_EnumBind(t *testing.T) { + var s struct { + A EnumType `query:"a"` + B EnumType `query:"b"` + } + req := newMockRequest(). + SetRequestURI("http://foobar.com?a=0&b=2") + err := DefaultBinder.Bind(req.Req, nil, &s) + if err != nil { + t.Fatal(err) + } +} + type CustomizedDecode struct { A string } @@ -448,11 +461,10 @@ func TestBind_CustomizedTypeDecode(t *testing.T) { type Foo struct { F ***CustomizedDecode } - bind := Bind{} req := newMockRequest(). SetRequestURI("http://foobar.com?a=1&b=2") result := Foo{} - err := bind.Bind(req.Req, nil, &result) + err := DefaultBinder.Bind(req.Req, nil, &result) if err != nil { t.Fatal(err) } @@ -463,7 +475,7 @@ func TestBind_CustomizedTypeDecode(t *testing.T) { } result2 := Bar{} - err = bind.Bind(req.Req, nil, &result2) + err = DefaultBinder.Bind(req.Req, nil, &result2) if err != nil { t.Error(err) } @@ -471,7 +483,6 @@ func TestBind_CustomizedTypeDecode(t *testing.T) { } func TestBind_JSON(t *testing.T) { - bind := Bind{} type Req struct { J1 string `json:"j1"` J2 int `json:"j2" query:"j2"` // 1. json unmarshal 2. query binding cover @@ -487,7 +498,7 @@ func TestBind_JSON(t *testing.T) { SetJSONContentType(). SetBody([]byte(fmt.Sprintf(`{"j1":"j1", "j2":12, "j3":[%d, %d], "j4":["%s", "%s"]}`, J3s[0], J3s[1], J4s[0], J4s[1]))) var result Req - err := bind.Bind(req.Req, nil, &result) + err := DefaultBinder.Bind(req.Req, nil, &result) if err != nil { t.Error(err) } @@ -503,7 +514,6 @@ func TestBind_JSON(t *testing.T) { func TestBind_ResetJSONUnmarshal(t *testing.T) { ResetStdJSONUnmarshaler() - bind := Bind{} type Req struct { J1 string `json:"j1"` J2 int `json:"j2"` @@ -517,7 +527,7 @@ func TestBind_ResetJSONUnmarshal(t *testing.T) { SetJSONContentType(). SetBody([]byte(fmt.Sprintf(`{"j1":"j1", "j2":12, "j3":[%d, %d], "j4":["%s", "%s"]}`, J3s[0], J3s[1], J4s[0], J4s[1]))) var result Req - err := bind.Bind(req.Req, nil, &result) + err := DefaultBinder.Bind(req.Req, nil, &result) if err != nil { t.Error(err) } @@ -532,7 +542,6 @@ func TestBind_ResetJSONUnmarshal(t *testing.T) { } func Benchmark_V2(b *testing.B) { - bind := Bind{} type Req struct { Version string `path:"v"` ID int `query:"id"` @@ -555,7 +564,7 @@ func Benchmark_V2(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { var result Req - err := bind.Bind(req.Req, params, &result) + err := DefaultBinder.Bind(req.Req, params, &result) if err != nil { b.Error(err) } diff --git a/pkg/app/server/binding_v2/default_binder.go b/pkg/app/server/binding_v2/default_binder.go new file mode 100644 index 000000000..9710249d9 --- /dev/null +++ b/pkg/app/server/binding_v2/default_binder.go @@ -0,0 +1,73 @@ +package binding_v2 + +import ( + "fmt" + "reflect" + "sync" + + "github.com/cloudwego/hertz/internal/bytesconv" + "github.com/cloudwego/hertz/pkg/protocol" + "google.golang.org/protobuf/proto" +) + +type Bind struct { + decoderCache sync.Map +} + +func (b *Bind) Name() string { + return "hertz" +} + +func (b *Bind) Bind(req *protocol.Request, params PathParams, v interface{}) error { + // todo: 先做 body unmarshal, 先尝试做 body 绑定,然后再尝试绑定其他内容 + err := b.preBindBody(req, v) + if err != nil { + return err + } + rv, typeID := valueAndTypeID(v) + if rv.Kind() != reflect.Pointer || rv.IsNil() { + return fmt.Errorf("receiver must be a non-nil pointer") + } + if rv.Elem().Kind() == reflect.Map { + return nil + } + cached, ok := b.decoderCache.Load(typeID) + if ok { + // cached decoder, fast path + decoder := cached.(Decoder) + return decoder(req, params, rv.Elem()) + } + + decoder, err := getReqDecoder(rv.Type()) + if err != nil { + return err + } + + b.decoderCache.Store(typeID, decoder) + return decoder(req, params, rv.Elem()) +} + +var ( + jsonContentTypeBytes = "application/json; charset=utf-8" + protobufContentType = "application/x-protobuf" +) + +// best effort binding +func (b *Bind) preBindBody(req *protocol.Request, v interface{}) error { + if req.Header.ContentLength() <= 0 { + return nil + } + switch bytesconv.B2s(req.Header.ContentType()) { + case jsonContentTypeBytes: + // todo: 对齐gin, 添加 "EnableDecoderUseNumber"/"EnableDecoderDisallowUnknownFields" 接口 + return jsonUnmarshalFunc(req.Body(), v) + case protobufContentType: + msg, ok := v.(proto.Message) + if !ok { + return fmt.Errorf("%s can not implement 'proto.Message'", v) + } + return proto.Unmarshal(req.Body(), msg) + default: + return nil + } +} diff --git a/pkg/app/server/binding_v2/default_validator.go b/pkg/app/server/binding_v2/default_validator.go new file mode 100644 index 000000000..988b05fc3 --- /dev/null +++ b/pkg/app/server/binding_v2/default_validator.go @@ -0,0 +1,54 @@ +package binding_v2 + +import ( + "reflect" + "sync" + + "github.com/go-playground/validator/v10" +) + +var _ StructValidator = (*defaultValidator)(nil) + +type defaultValidator struct { + once sync.Once + validate *validator.Validate +} + +// ValidateStruct receives any kind of type, but only performed struct or pointer to struct type. +func (v *defaultValidator) ValidateStruct(obj interface{}) error { + if obj == nil { + return nil + } + + value := reflect.ValueOf(obj) + switch value.Kind() { + case reflect.Ptr: + return v.ValidateStruct(value.Elem().Interface()) + case reflect.Struct: + return v.validateStruct(obj) + default: + return nil + } +} + +// validateStruct receives struct type +func (v *defaultValidator) validateStruct(obj interface{}) error { + v.lazyinit() + return v.validate.Struct(obj) +} + +func (v *defaultValidator) lazyinit() { + v.once.Do(func() { + v.validate = validator.New() + v.validate.SetTagName("validate") + }) +} + +// Engine returns the underlying validator engine which powers the default +// Validator instance. This is useful if you want to register custom validations +// or struct level validations. See validator GoDoc for more info - +// https://pkg.go.dev/github.com/go-playground/validator/v10 +func (v *defaultValidator) Engine() interface{} { + v.lazyinit() + return v.validate +} diff --git a/pkg/app/server/binding_v2/validator.go b/pkg/app/server/binding_v2/validator.go new file mode 100644 index 000000000..1e418b70a --- /dev/null +++ b/pkg/app/server/binding_v2/validator.go @@ -0,0 +1,24 @@ +package binding_v2 + +// StructValidator is the minimal interface which needs to be implemented in +// order for it to be used as the validator engine for ensuring the correctness +// of the request. Hertz provides a default implementation for this using +// https://github.com/go-playground/validator/tree/v10.6.1. +type StructValidator interface { + // ValidateStruct can receive any kind of type and it should never panic, even if the configuration is not right. + // If the received type is a slice|array, the validation should be performed travel on every element. + // If the received type is not a struct or slice|array, any validation should be skipped and nil must be returned. + // If the received type is a struct or pointer to a struct, the validation should be performed. + // If the struct is not valid or the validation itself fails, a descriptive error should be returned. + // Otherwise nil must be returned. + ValidateStruct(interface{}) error + + // Engine returns the underlying validator engine which powers the + // StructValidator implementation. + Engine() interface{} +} + +// DefaultValidator is the default validator which implements the StructValidator +// interface. It uses https://github.com/go-playground/validator/tree/v10.6.1 +// under the hood. +var DefaultValidator StructValidator = &defaultValidator{} diff --git a/pkg/app/server/binding_v2/validator_test.go b/pkg/app/server/binding_v2/validator_test.go new file mode 100644 index 000000000..2ec125810 --- /dev/null +++ b/pkg/app/server/binding_v2/validator_test.go @@ -0,0 +1,31 @@ +package binding_v2 + +import ( + "fmt" +) + +func ExampleValidateStruct() { + type User struct { + FirstName string `validate:"required"` + LastName string `validate:"required"` + Age uint8 `validate:"gte=0,lte=130"` + Email string `validate:"required,email"` + FavouriteColor string `validate:"iscolor"` + } + + user := &User{ + FirstName: "Hertz", + Age: 135, + Email: "hertz", + FavouriteColor: "sad", + } + err := DefaultValidator.ValidateStruct(user) + if err != nil { + fmt.Println(err) + } + // Output: + //Key: 'User.LastName' Error:Field validation for 'LastName' failed on the 'required' tag + //Key: 'User.Age' Error:Field validation for 'Age' failed on the 'lte' tag + //Key: 'User.Email' Error:Field validation for 'Email' failed on the 'email' tag + //Key: 'User.FavouriteColor' Error:Field validation for 'FavouriteColor' failed on the 'iscolor' tag +} From f965917a7fe51fe86600c451262365049f88235f Mon Sep 17 00:00:00 2001 From: fgy Date: Wed, 1 Feb 2023 18:23:23 +0800 Subject: [PATCH 09/91] feat: add license --- go.mod | 2 - pkg/app/context.go | 11 +- pkg/app/context_test.go | 2 +- .../base_type_decoder.go | 50 +- pkg/app/server/binding/binder.go | 57 +++ .../{binding_v2 => binding}/binder_test.go | 92 ++-- pkg/app/server/binding/binding.go | 122 ----- pkg/app/server/binding/binding_test.go | 450 ------------------ .../server/binding/customized_type_decoder.go | 77 +++ .../server/{binding_v2 => binding}/decoder.go | 49 +- pkg/app/server/binding/default_binder.go | 112 +++++ pkg/app/server/binding/default_validator.go | 94 ++++ .../server/{binding_v2 => binding}/getter.go | 66 ++- pkg/app/server/binding/json.go | 65 +++ .../map_type_decoder.go | 45 +- pkg/app/server/binding/reflect.go | 113 +++++ pkg/app/server/binding/request.go | 138 ------ .../slice_type_decoder.go | 54 ++- pkg/app/server/{binding_v2 => binding}/tag.go | 18 +- pkg/app/server/binding/text_decoder/bool.go | 57 +++ pkg/app/server/binding/text_decoder/float.go | 59 +++ pkg/app/server/binding/text_decoder/int.go | 59 +++ pkg/app/server/binding/text_decoder/string.go | 50 ++ .../binding/text_decoder/text_decoder.go | 92 ++++ pkg/app/server/binding/text_decoder/unit.go | 59 +++ pkg/app/server/binding/validator.go | 64 +++ pkg/app/server/binding/validator_test.go | 47 ++ pkg/app/server/binding_v2/binder.go | 17 - .../binding_v2/customized_type_decoder.go | 38 -- pkg/app/server/binding_v2/default_binder.go | 73 --- .../server/binding_v2/default_validator.go | 54 --- pkg/app/server/binding_v2/json.go | 25 - pkg/app/server/binding_v2/reflect.go | 72 --- .../server/binding_v2/text_decoder/bool.go | 17 - .../server/binding_v2/text_decoder/float.go | 19 - pkg/app/server/binding_v2/text_decoder/int.go | 19 - .../server/binding_v2/text_decoder/string.go | 11 - .../binding_v2/text_decoder/text_decoder.go | 52 -- .../server/binding_v2/text_decoder/unit.go | 19 - pkg/app/server/binding_v2/validator.go | 24 - pkg/app/server/binding_v2/validator_test.go | 31 -- 41 files changed, 1305 insertions(+), 1270 deletions(-) rename pkg/app/server/{binding_v2 => binding}/base_type_decoder.go (58%) create mode 100644 pkg/app/server/binding/binder.go rename pkg/app/server/{binding_v2 => binding}/binder_test.go (86%) delete mode 100644 pkg/app/server/binding/binding.go create mode 100644 pkg/app/server/binding/customized_type_decoder.go rename pkg/app/server/{binding_v2 => binding}/decoder.go (56%) create mode 100644 pkg/app/server/binding/default_binder.go create mode 100644 pkg/app/server/binding/default_validator.go rename pkg/app/server/{binding_v2 => binding}/getter.go (52%) create mode 100644 pkg/app/server/binding/json.go rename pkg/app/server/{binding_v2 => binding}/map_type_decoder.go (55%) create mode 100644 pkg/app/server/binding/reflect.go delete mode 100644 pkg/app/server/binding/request.go rename pkg/app/server/{binding_v2 => binding}/slice_type_decoder.go (61%) rename pkg/app/server/{binding_v2 => binding}/tag.go (74%) create mode 100644 pkg/app/server/binding/text_decoder/bool.go create mode 100644 pkg/app/server/binding/text_decoder/float.go create mode 100644 pkg/app/server/binding/text_decoder/int.go create mode 100644 pkg/app/server/binding/text_decoder/string.go create mode 100644 pkg/app/server/binding/text_decoder/text_decoder.go create mode 100644 pkg/app/server/binding/text_decoder/unit.go create mode 100644 pkg/app/server/binding/validator.go create mode 100644 pkg/app/server/binding/validator_test.go delete mode 100644 pkg/app/server/binding_v2/binder.go delete mode 100644 pkg/app/server/binding_v2/customized_type_decoder.go delete mode 100644 pkg/app/server/binding_v2/default_binder.go delete mode 100644 pkg/app/server/binding_v2/default_validator.go delete mode 100644 pkg/app/server/binding_v2/json.go delete mode 100644 pkg/app/server/binding_v2/reflect.go delete mode 100644 pkg/app/server/binding_v2/text_decoder/bool.go delete mode 100644 pkg/app/server/binding_v2/text_decoder/float.go delete mode 100644 pkg/app/server/binding_v2/text_decoder/int.go delete mode 100644 pkg/app/server/binding_v2/text_decoder/string.go delete mode 100644 pkg/app/server/binding_v2/text_decoder/text_decoder.go delete mode 100644 pkg/app/server/binding_v2/text_decoder/unit.go delete mode 100644 pkg/app/server/binding_v2/validator.go delete mode 100644 pkg/app/server/binding_v2/validator_test.go diff --git a/go.mod b/go.mod index 865de6144..8f0892db9 100644 --- a/go.mod +++ b/go.mod @@ -3,14 +3,12 @@ module github.com/cloudwego/hertz go 1.16 require ( - github.com/bytedance/go-tagexpr/v2 v2.9.2 github.com/bytedance/gopkg v0.0.0-20220413063733-65bf48ffb3a7 github.com/bytedance/mockey v1.2.1 github.com/bytedance/sonic v1.8.1 github.com/cloudwego/netpoll v0.4.2-0.20230807055039-52fd5fb7b00f github.com/fsnotify/fsnotify v1.5.4 github.com/go-playground/validator/v10 v10.11.2 - github.com/tidwall/gjson v1.13.0 // indirect golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4 golang.org/x/sys v0.4.0 google.golang.org/protobuf v1.27.1 diff --git a/pkg/app/context.go b/pkg/app/context.go index 99521f37b..92ad94217 100644 --- a/pkg/app/context.go +++ b/pkg/app/context.go @@ -54,10 +54,9 @@ import ( "sync" "time" - "github.com/cloudwego/hertz/pkg/app/server/binding_v2" - "github.com/cloudwego/hertz/internal/bytesconv" "github.com/cloudwego/hertz/internal/bytestr" + "github.com/cloudwego/hertz/pkg/app/server/binding" "github.com/cloudwego/hertz/pkg/app/server/render" "github.com/cloudwego/hertz/pkg/common/errors" "github.com/cloudwego/hertz/pkg/common/tracer/traceinfo" @@ -1306,24 +1305,24 @@ func bodyAllowedForStatus(status int) bool { // BindAndValidate binds data from *RequestContext to obj and validates them if needed. // NOTE: obj should be a pointer. func (ctx *RequestContext) BindAndValidate(obj interface{}) error { - err := binding_v2.DefaultBinder.Bind(&ctx.Request, ctx.Params, obj) + err := binding.DefaultBinder.Bind(&ctx.Request, ctx.Params, obj) if err != nil { return err } - err = binding_v2.DefaultValidator.ValidateStruct(obj) + err = binding.DefaultValidator.ValidateStruct(obj) return err } // Bind binds data from *RequestContext to obj. // NOTE: obj should be a pointer. func (ctx *RequestContext) Bind(obj interface{}) error { - return binding_v2.DefaultBinder.Bind(&ctx.Request, ctx.Params, obj) + return binding.DefaultBinder.Bind(&ctx.Request, ctx.Params, obj) } // Validate validates obj with "vd" tag // NOTE: obj should be a pointer. func (ctx *RequestContext) Validate(obj interface{}) error { - return binding_v2.DefaultValidator.ValidateStruct(obj) + return binding.DefaultValidator.ValidateStruct(obj) } // VisitAllQueryArgs calls f for each existing query arg. diff --git a/pkg/app/context_test.go b/pkg/app/context_test.go index 88cd26338..2567a822d 100644 --- a/pkg/app/context_test.go +++ b/pkg/app/context_test.go @@ -1419,7 +1419,7 @@ func TestRequestContext_GetResponse(t *testing.T) { func TestBindAndValidate(t *testing.T) { type Test struct { A string `query:"a"` - B int `query:"b" vd:"$>10"` + B int `query:"b" validate:"gt=10"` } c := &RequestContext{} diff --git a/pkg/app/server/binding_v2/base_type_decoder.go b/pkg/app/server/binding/base_type_decoder.go similarity index 58% rename from pkg/app/server/binding_v2/base_type_decoder.go rename to pkg/app/server/binding/base_type_decoder.go index 57e7aa958..69fdb98bd 100644 --- a/pkg/app/server/binding_v2/base_type_decoder.go +++ b/pkg/app/server/binding/base_type_decoder.go @@ -1,10 +1,50 @@ -package binding_v2 +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * MIT License + * + * Copyright (c) 2019-present Fenny and Contributors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + * This file may have been modified by CloudWeGo authors. All CloudWeGo + * Modifications are Copyright 2022 CloudWeGo Authors + */ + +package binding import ( "fmt" "reflect" - "github.com/cloudwego/hertz/pkg/app/server/binding_v2/text_decoder" + "github.com/cloudwego/hertz/pkg/app/server/binding/text_decoder" "github.com/cloudwego/hertz/pkg/common/utils" "github.com/cloudwego/hertz/pkg/protocol" ) @@ -26,7 +66,6 @@ func (d *baseTypeFieldTextDecoder) Decode(req *protocol.Request, params PathPara var err error var text string var defaultValue string - // 最大努力交付,对齐 hertz 现有设计 for _, tagInfo := range d.tagInfos { if tagInfo.Key == jsonTag { continue @@ -37,7 +76,6 @@ func (d *baseTypeFieldTextDecoder) Decode(req *protocol.Request, params PathPara ret := tagInfo.Getter(req, params, tagInfo.Value) defaultValue = tagInfo.Default if len(ret) != 0 { - // 非数组/切片类型,只取第一个值作为只 text = ret[0] err = nil break @@ -56,12 +94,10 @@ func (d *baseTypeFieldTextDecoder) Decode(req *protocol.Request, params PathPara return nil } - // 得到该field的非nil值 + // get the non-nil value for the field reqValue = GetFieldValue(reqValue, d.parentIndex) - // 根据最终的 Struct,获取对应 field 的 reflect.Value field := reqValue.Field(d.index) if field.Kind() == reflect.Ptr { - // 如果是指针则新建一个reflect.Value,然后赋值给指针 t := field.Type() var ptrDepth int for t.Kind() == reflect.Ptr { diff --git a/pkg/app/server/binding/binder.go b/pkg/app/server/binding/binder.go new file mode 100644 index 000000000..d6fbda809 --- /dev/null +++ b/pkg/app/server/binding/binder.go @@ -0,0 +1,57 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * MIT License + * + * Copyright (c) 2019-present Fenny and Contributors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + * This file may have been modified by CloudWeGo authors. All CloudWeGo + * Modifications are Copyright 2022 CloudWeGo Authors + */ + +package binding + +import ( + "github.com/cloudwego/hertz/pkg/protocol" +) + +// PathParams parameter acquisition interface on the URL path +type PathParams interface { + Get(name string) (string, bool) +} + +type Binder interface { + Name() string + Bind(*protocol.Request, PathParams, interface{}) error +} + +var DefaultBinder Binder = &Bind{} diff --git a/pkg/app/server/binding_v2/binder_test.go b/pkg/app/server/binding/binder_test.go similarity index 86% rename from pkg/app/server/binding_v2/binder_test.go rename to pkg/app/server/binding/binder_test.go index 11315ca1c..d228cfd57 100644 --- a/pkg/app/server/binding_v2/binder_test.go +++ b/pkg/app/server/binding/binder_test.go @@ -1,10 +1,49 @@ -package binding_v2 +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * MIT License + * + * Copyright (c) 2019-present Fenny and Contributors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + * This file may have been modified by CloudWeGo authors. All CloudWeGo + * Modifications are Copyright 2022 CloudWeGo Authors + */ + +package binding import ( "fmt" "testing" - "github.com/cloudwego/hertz/pkg/app/server/binding" "github.com/cloudwego/hertz/pkg/common/test/assert" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/route/param" @@ -412,7 +451,6 @@ func TestBind_TypedefType(t *testing.T) { assert.DeepEqual(t, "1", s.T1.T1) } -// 枚举类型BaseType type EnumType int64 const ( @@ -484,9 +522,8 @@ func TestBind_CustomizedTypeDecode(t *testing.T) { func TestBind_JSON(t *testing.T) { type Req struct { - J1 string `json:"j1"` - J2 int `json:"j2" query:"j2"` // 1. json unmarshal 2. query binding cover - // todo: map + J1 string `json:"j1"` + J2 int `json:"j2" query:"j2"` // 1. json unmarshal 2. query binding cover J3 []byte `json:"j3"` J4 [2]string `json:"j4"` } @@ -541,7 +578,7 @@ func TestBind_ResetJSONUnmarshal(t *testing.T) { } } -func Benchmark_V2(b *testing.B) { +func Benchmark_Binding(b *testing.B) { type Req struct { Version string `path:"v"` ID int `query:"id"` @@ -582,44 +619,3 @@ func Benchmark_V2(b *testing.B) { } } } - -func Benchmark_V1(b *testing.B) { - type Req struct { - Version string `path:"v"` - ID int `query:"id"` - Header string `header:"h"` - Form string `form:"f"` - } - - req := newMockRequest(). - SetRequestURI("http://foobar.com?id=12"). - SetHeaders("h", "header"). - SetPostArg("f", "form"). - SetUrlEncodeContentType() - var params param.Params - params = append(params, param.Param{ - Key: "v", - Value: "1", - }) - - b.ResetTimer() - for i := 0; i < b.N; i++ { - var result Req - err := binding.Bind(req.Req, &result, params) - if err != nil { - b.Error(err) - } - if result.ID != 12 { - b.Error("Id failed") - } - if result.Form != "form" { - b.Error("form failed") - } - if result.Header != "header" { - b.Error("header failed") - } - if result.Version != "1" { - b.Error("path failed") - } - } -} diff --git a/pkg/app/server/binding/binding.go b/pkg/app/server/binding/binding.go deleted file mode 100644 index fa4af9d97..000000000 --- a/pkg/app/server/binding/binding.go +++ /dev/null @@ -1,122 +0,0 @@ -/* - * Copyright 2022 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package binding - -import ( - "encoding/json" - "reflect" - - "github.com/bytedance/go-tagexpr/v2/binding" - "github.com/bytedance/go-tagexpr/v2/binding/gjson" - "github.com/bytedance/go-tagexpr/v2/validator" - hjson "github.com/cloudwego/hertz/pkg/common/json" - "github.com/cloudwego/hertz/pkg/protocol" - "github.com/cloudwego/hertz/pkg/route/param" -) - -func init() { - binding.ResetJSONUnmarshaler(hjson.Unmarshal) -} - -var defaultBinder = binding.Default() - -// BindAndValidate binds data from *protocol.Request to obj and validates them if needed. -// NOTE: -// -// obj should be a pointer. -func BindAndValidate(req *protocol.Request, obj interface{}, pathParams param.Params) error { - return defaultBinder.IBindAndValidate(obj, wrapRequest(req), pathParams) -} - -// Bind binds data from *protocol.Request to obj. -// NOTE: -// -// obj should be a pointer. -func Bind(req *protocol.Request, obj interface{}, pathParams param.Params) error { - return defaultBinder.IBind(obj, wrapRequest(req), pathParams) -} - -// Validate validates obj with "vd" tag -// NOTE: -// -// obj should be a pointer. -// Validate should be called after Bind. -func Validate(obj interface{}) error { - return defaultBinder.Validate(obj) -} - -// SetLooseZeroMode if set to true, -// the empty string request parameter is bound to the zero value of parameter. -// NOTE: -// -// The default is false. -// Suitable for these parameter types: query/header/cookie/form . -func SetLooseZeroMode(enable bool) { - defaultBinder.SetLooseZeroMode(enable) -} - -// SetErrorFactory customizes the factory of validation error. -// NOTE: -// -// If errFactory==nil, the default is used. -// SetErrorFactory will remain in effect once it has been called. -func SetErrorFactory(bindErrFactory, validatingErrFactory func(failField, msg string) error) { - defaultBinder.SetErrorFactory(bindErrFactory, validatingErrFactory) -} - -// MustRegTypeUnmarshal registers unmarshal function of type. -// NOTE: -// -// It will panic if exist error. -// MustRegTypeUnmarshal will remain in effect once it has been called. -func MustRegTypeUnmarshal(t reflect.Type, fn func(v string, emptyAsZero bool) (reflect.Value, error)) { - binding.MustRegTypeUnmarshal(t, fn) -} - -// MustRegValidateFunc registers validator function expression. -// NOTE: -// -// If force=true, allow to cover the existed same funcName. -// MustRegValidateFunc will remain in effect once it has been called. -func MustRegValidateFunc(funcName string, fn func(args ...interface{}) error, force ...bool) { - validator.RegFunc(funcName, fn, force...) -} - -// UseStdJSONUnmarshaler uses encoding/json as json library -// NOTE: -// -// The current version uses encoding/json by default. -// UseStdJSONUnmarshaler will remain in effect once it has been called. -func UseStdJSONUnmarshaler() { - binding.ResetJSONUnmarshaler(json.Unmarshal) -} - -// UseGJSONUnmarshaler uses github.com/bytedance/go-tagexpr/v2/binding/gjson as json library -// NOTE: -// -// UseGJSONUnmarshaler will remain in effect once it has been called. -func UseGJSONUnmarshaler() { - gjson.UseJSONUnmarshaler() -} - -// UseThirdPartyJSONUnmarshaler uses third-party json library for binding -// NOTE: -// -// UseThirdPartyJSONUnmarshaler will remain in effect once it has been called. -func UseThirdPartyJSONUnmarshaler(unmarshaler func(data []byte, v interface{}) error) { - binding.ResetJSONUnmarshaler(unmarshaler) -} diff --git a/pkg/app/server/binding/binding_test.go b/pkg/app/server/binding/binding_test.go index a9050f940..e69de29bb 100644 --- a/pkg/app/server/binding/binding_test.go +++ b/pkg/app/server/binding/binding_test.go @@ -1,450 +0,0 @@ -/* - * Copyright 2022 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package binding - -import ( - "bytes" - "fmt" - "mime/multipart" - "reflect" - "testing" - - "github.com/cloudwego/hertz/pkg/common/test/assert" - "github.com/cloudwego/hertz/pkg/protocol" - "github.com/cloudwego/hertz/pkg/route/param" -) - -func TestBindAndValidate(t *testing.T) { - type TestBind struct { - A string `query:"a"` - B []string `query:"b"` - C string `query:"c"` - D string `header:"d"` - E string `path:"e"` - F string `form:"f"` - G multipart.FileHeader `form:"g"` - H string `cookie:"h"` - } - - s := `------WebKitFormBoundaryJwfATyF8tmxSJnLg -Content-Disposition: form-data; name="f" - -fff -------WebKitFormBoundaryJwfATyF8tmxSJnLg -Content-Disposition: form-data; name="g"; filename="TODO" -Content-Type: application/octet-stream - -- SessionClient with referer and cookies support. -- Client with requests' pipelining support. -- ProxyHandler similar to FSHandler. -- WebSockets. See https://tools.ietf.org/html/rfc6455 . -- HTTP/2.0. See https://tools.ietf.org/html/rfc7540 . - -------WebKitFormBoundaryJwfATyF8tmxSJnLg-- -tailfoobar` - - mr := bytes.NewBufferString(s) - r := protocol.NewRequest("POST", "/foo", mr) - r.SetRequestURI("/foo/bar?a=aaa&b=b1&b=b2&c&i=19") - r.SetHeader("d", "ddd") - r.Header.SetContentLength(len(s)) - r.Header.SetContentTypeBytes([]byte("multipart/form-data; boundary=----WebKitFormBoundaryJwfATyF8tmxSJnLg")) - - r.SetCookie("h", "hhh") - - para := param.Params{ - {Key: "e", Value: "eee"}, - } - - // test BindAndValidate() - SetLooseZeroMode(true) - var req TestBind - err := BindAndValidate(r, &req, para) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - assert.DeepEqual(t, "aaa", req.A) - assert.DeepEqual(t, 2, len(req.B)) - assert.DeepEqual(t, "", req.C) - assert.DeepEqual(t, "ddd", req.D) - assert.DeepEqual(t, "eee", req.E) - assert.DeepEqual(t, "fff", req.F) - assert.DeepEqual(t, "TODO", req.G.Filename) - assert.DeepEqual(t, "hhh", req.H) - - // test Bind() - req = TestBind{} - err = Bind(r, &req, para) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - assert.DeepEqual(t, "aaa", req.A) - assert.DeepEqual(t, 2, len(req.B)) - assert.DeepEqual(t, "", req.C) - assert.DeepEqual(t, "ddd", req.D) - assert.DeepEqual(t, "eee", req.E) - assert.DeepEqual(t, "fff", req.F) - assert.DeepEqual(t, "TODO", req.G.Filename) - assert.DeepEqual(t, "hhh", req.H) - - type TestValidate struct { - I int `query:"i" vd:"$>20"` - } - - // test BindAndValidate() - var bindReq TestValidate - err = BindAndValidate(r, &bindReq, para) - if err == nil { - t.Fatalf("unexpected nil, expected an error") - } - - // test Validate() - bindReq = TestValidate{} - err = Bind(r, &bindReq, para) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - assert.DeepEqual(t, 19, bindReq.I) - err = Validate(&bindReq) - if err == nil { - t.Fatalf("unexpected nil, expected an error") - } -} - -func TestJsonBind(t *testing.T) { - type Test struct { - A string `json:"a"` - B []string `json:"b"` - C string `json:"c"` - D int `json:"d,string"` - } - - data := `{"a":"aaa", "b":["b1","b2"], "c":"ccc", "d":"100"}` - mr := bytes.NewBufferString(data) - r := protocol.NewRequest("POST", "/foo", mr) - r.Header.Set("Content-Type", "application/json; charset=utf-8") - r.SetHeader("d", "ddd") - r.Header.SetContentLength(len(data)) - - var req Test - err := BindAndValidate(r, &req, nil) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - assert.DeepEqual(t, "aaa", req.A) - assert.DeepEqual(t, 2, len(req.B)) - assert.DeepEqual(t, "ccc", req.C) - // NOTE: The default does not support string to go int conversion in json. - // You can add "string" tags or use other json unmarshal libraries that support this feature - assert.DeepEqual(t, 100, req.D) - - req = Test{} - UseGJSONUnmarshaler() - err = BindAndValidate(r, &req, nil) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - assert.DeepEqual(t, "aaa", req.A) - assert.DeepEqual(t, 2, len(req.B)) - assert.DeepEqual(t, "ccc", req.C) - // NOTE: The default does not support string to go int conversion in json. - // You can add "string" tags or use other json unmarshal libraries that support this feature - assert.DeepEqual(t, 100, req.D) - - req = Test{} - UseStdJSONUnmarshaler() - err = BindAndValidate(r, &req, nil) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - assert.DeepEqual(t, "aaa", req.A) - assert.DeepEqual(t, 2, len(req.B)) - assert.DeepEqual(t, "ccc", req.C) - // NOTE: The default does not support string to go int conversion in json. - // You can add "string" tags or use other json unmarshal libraries that support this feature - assert.DeepEqual(t, 100, req.D) -} - -// TestQueryParamInconsistency tests the Inconsistency for GetQuery(), the other unit test for GetFunc() in request.go are similar to it -func TestQueryParamInconsistency(t *testing.T) { - type QueryPara struct { - Para1 string `query:"para1"` - Para2 *string `query:"para2"` - } - - r := protocol.NewRequest("GET", "/foo", nil) - r.SetRequestURI("/foo/bar?para1=hertz¶2=binding") - - var req QueryPara - err := BindAndValidate(r, &req, nil) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - beforePara1 := deepCopyString(req.Para1) - beforePara2 := deepCopyString(*req.Para2) - r.URI().QueryArgs().Set("para1", "test") - r.URI().QueryArgs().Set("para2", "test") - afterPara1 := req.Para1 - afterPara2 := *req.Para2 - assert.DeepEqual(t, beforePara1, afterPara1) - assert.DeepEqual(t, beforePara2, afterPara2) -} - -func deepCopyString(str string) string { - tmp := make([]byte, len(str)) - copy(tmp, str) - c := string(tmp) - - return c -} - -func TestBindingFile(t *testing.T) { - type FileParas struct { - F *multipart.FileHeader `form:"F1"` - F1 multipart.FileHeader - Fs []multipart.FileHeader `form:"F1"` - Fs1 []*multipart.FileHeader `form:"F1"` - F2 *multipart.FileHeader `form:"F2"` - } - - s := `------WebKitFormBoundaryJwfATyF8tmxSJnLg -Content-Disposition: form-data; name="f" - -fff -------WebKitFormBoundaryJwfATyF8tmxSJnLg -Content-Disposition: form-data; name="F1"; filename="TODO1" -Content-Type: application/octet-stream - -- SessionClient with referer and cookies support. -- Client with requests' pipelining support. -- ProxyHandler similar to FSHandler. -- WebSockets. See https://tools.ietf.org/html/rfc6455 . -- HTTP/2.0. See https://tools.ietf.org/html/rfc7540 . -------WebKitFormBoundaryJwfATyF8tmxSJnLg -Content-Disposition: form-data; name="F1"; filename="TODO2" -Content-Type: application/octet-stream - -- SessionClient with referer and cookies support. -- Client with requests' pipelining support. -- ProxyHandler similar to FSHandler. -- WebSockets. See https://tools.ietf.org/html/rfc6455 . -- HTTP/2.0. See https://tools.ietf.org/html/rfc7540 . -------WebKitFormBoundaryJwfATyF8tmxSJnLg -Content-Disposition: form-data; name="F2"; filename="TODO3" -Content-Type: application/octet-stream - -- SessionClient with referer and cookies support. -- Client with requests' pipelining support. -- ProxyHandler similar to FSHandler. -- WebSockets. See https://tools.ietf.org/html/rfc6455 . -- HTTP/2.0. See https://tools.ietf.org/html/rfc7540 . - -------WebKitFormBoundaryJwfATyF8tmxSJnLg-- -tailfoobar` - - mr := bytes.NewBufferString(s) - r := protocol.NewRequest("POST", "/foo", mr) - r.SetRequestURI("/foo/bar?a=aaa&b=b1&b=b2&c&i=19") - r.SetHeader("d", "ddd") - r.Header.SetContentLength(len(s)) - r.Header.SetContentTypeBytes([]byte("multipart/form-data; boundary=----WebKitFormBoundaryJwfATyF8tmxSJnLg")) - - var req FileParas - err := BindAndValidate(r, &req, nil) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - assert.DeepEqual(t, "TODO1", req.F.Filename) - assert.DeepEqual(t, "TODO1", req.F1.Filename) - assert.DeepEqual(t, 2, len(req.Fs)) - assert.DeepEqual(t, 2, len(req.Fs1)) - assert.DeepEqual(t, "TODO3", req.F2.Filename) -} - -type BindError struct { - ErrType, FailField, Msg string -} - -// Error implements error interface. -func (e *BindError) Error() string { - if e.Msg != "" { - return e.ErrType + ": expr_path=" + e.FailField + ", cause=" + e.Msg - } - return e.ErrType + ": expr_path=" + e.FailField + ", cause=invalid" -} - -type ValidateError struct { - ErrType, FailField, Msg string -} - -// Error implements error interface. -func (e *ValidateError) Error() string { - if e.Msg != "" { - return e.ErrType + ": expr_path=" + e.FailField + ", cause=" + e.Msg - } - return e.ErrType + ": expr_path=" + e.FailField + ", cause=invalid" -} - -func TestSetErrorFactory(t *testing.T) { - type TestBind struct { - A string `query:"a,required"` - } - - r := protocol.NewRequest("GET", "/foo", nil) - r.SetRequestURI("/foo/bar?b=20") - - CustomBindErrFunc := func(failField, msg string) error { - err := BindError{ - ErrType: "bindErr", - FailField: "[bindFailField]: " + failField, - Msg: "[bindErrMsg]: " + msg, - } - - return &err - } - - CustomValidateErrFunc := func(failField, msg string) error { - err := ValidateError{ - ErrType: "validateErr", - FailField: "[validateFailField]: " + failField, - Msg: "[validateErrMsg]: " + msg, - } - - return &err - } - - SetErrorFactory(CustomBindErrFunc, CustomValidateErrFunc) - - var req TestBind - err := Bind(r, &req, nil) - if err == nil { - t.Fatalf("unexpected nil, expected an error") - } - assert.DeepEqual(t, "bindErr: expr_path=[bindFailField]: A, cause=[bindErrMsg]: missing required parameter", err.Error()) - - type TestValidate struct { - B int `query:"b" vd:"$>100"` - } - - var reqValidate TestValidate - err = Bind(r, &reqValidate, nil) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - err = Validate(&reqValidate) - if err == nil { - t.Fatalf("unexpected nil, expected an error") - } - assert.DeepEqual(t, "validateErr: expr_path=[validateFailField]: B, cause=[validateErrMsg]: ", err.Error()) -} - -func TestMustRegTypeUnmarshal(t *testing.T) { - type Nested struct { - B string - C string - } - - type TestBind struct { - A Nested `query:"a,required"` - } - - r := protocol.NewRequest("GET", "/foo", nil) - r.SetRequestURI("/foo/bar?a=hertzbinding") - - MustRegTypeUnmarshal(reflect.TypeOf(Nested{}), func(v string, emptyAsZero bool) (reflect.Value, error) { - if v == "" && emptyAsZero { - return reflect.ValueOf(Nested{}), nil - } - val := Nested{ - B: v[:5], - C: v[5:], - } - return reflect.ValueOf(val), nil - }) - - var req TestBind - err := Bind(r, &req, nil) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - assert.DeepEqual(t, "hertz", req.A.B) - assert.DeepEqual(t, "binding", req.A.C) -} - -func TestMustRegValidateFunc(t *testing.T) { - type TestValidate struct { - A string `query:"a" vd:"test($)"` - } - - r := protocol.NewRequest("GET", "/foo", nil) - r.SetRequestURI("/foo/bar?a=123") - - MustRegValidateFunc("test", func(args ...interface{}) error { - if len(args) != 1 { - return fmt.Errorf("the args must be one") - } - s, _ := args[0].(string) - if s == "123" { - return fmt.Errorf("the args can not be 123") - } - return nil - }) - - var req TestValidate - err := Bind(r, &req, nil) - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - err = Validate(&req) - if err == nil { - t.Fatalf("unexpected nil, expected an error") - } -} - -func TestQueryAlias(t *testing.T) { - type MyInt int - type MyString string - type MyIntSlice []int - type MyStringSlice []string - type Test struct { - A []MyInt `query:"a"` - B MyIntSlice `query:"b"` - C MyString `query:"c"` - D MyStringSlice `query:"d"` - } - - r := protocol.NewRequest("GET", "/foo", nil) - r.SetRequestURI("/foo/bar?a=1&a=2&b=2&b=3&c=string1&d=string2&d=string3") - - var req Test - err := Bind(r, &req, nil) - if err != nil { - t.Fatalf("unexpected error: %v", err) - return - } - assert.DeepEqual(t, 2, len(req.A)) - assert.DeepEqual(t, 1, int(req.A[0])) - assert.DeepEqual(t, 2, int(req.A[1])) - assert.DeepEqual(t, 2, len(req.B)) - assert.DeepEqual(t, 2, req.B[0]) - assert.DeepEqual(t, 3, req.B[1]) - assert.DeepEqual(t, "string1", string(req.C)) - assert.DeepEqual(t, 2, len(req.D)) - assert.DeepEqual(t, "string2", req.D[0]) - assert.DeepEqual(t, "string3", req.D[1]) -} diff --git a/pkg/app/server/binding/customized_type_decoder.go b/pkg/app/server/binding/customized_type_decoder.go new file mode 100644 index 000000000..87379c644 --- /dev/null +++ b/pkg/app/server/binding/customized_type_decoder.go @@ -0,0 +1,77 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * MIT License + * + * Copyright (c) 2019-present Fenny and Contributors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + * This file may have been modified by CloudWeGo authors. All CloudWeGo + * Modifications are Copyright 2022 CloudWeGo Authors + */ + +package binding + +import ( + "reflect" + + "github.com/cloudwego/hertz/pkg/protocol" +) + +type customizedFieldTextDecoder struct { + fieldInfo +} + +func (d *customizedFieldTextDecoder) Decode(req *protocol.Request, params PathParams, reqValue reflect.Value) error { + var err error + v := reflect.New(d.fieldType) + decoder := v.Interface().(CustomizedFieldDecoder) + + if err = decoder.CustomizedFieldDecode(req, params); err != nil { + return err + } + + reqValue = GetFieldValue(reqValue, d.parentIndex) + field := reqValue.Field(d.index) + if field.Kind() == reflect.Ptr { + t := field.Type() + var ptrDepth int + for t.Kind() == reflect.Ptr { + t = t.Elem() + ptrDepth++ + } + field.Set(ReferenceValue(v.Elem(), ptrDepth)) + return nil + } + + field.Set(v) + return nil +} diff --git a/pkg/app/server/binding_v2/decoder.go b/pkg/app/server/binding/decoder.go similarity index 56% rename from pkg/app/server/binding_v2/decoder.go rename to pkg/app/server/binding/decoder.go index 8fc907ea3..f8bb12246 100644 --- a/pkg/app/server/binding_v2/decoder.go +++ b/pkg/app/server/binding/decoder.go @@ -1,4 +1,44 @@ -package binding_v2 +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * MIT License + * + * Copyright (c) 2019-present Fenny and Contributors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + * This file may have been modified by CloudWeGo authors. All CloudWeGo + * Modifications are Copyright 2022 CloudWeGo Authors + */ + +package binding import ( "fmt" @@ -24,7 +64,7 @@ func getReqDecoder(rt reflect.Type) (Decoder, error) { el := rt.Elem() if el.Kind() != reflect.Struct { - return nil, fmt.Errorf("unsupport \"%s\" type binding", el.String()) + return nil, fmt.Errorf("unsupported \"%s\" type binding", el.String()) } for i := 0; i < el.NumField(); i++ { @@ -56,7 +96,6 @@ func getReqDecoder(rt reflect.Type) (Decoder, error) { } func getFieldDecoder(field reflect.StructField, index int, parentIdx []int) ([]decoder, error) { - // 去掉每一个filed的指针,使其指向最终内容 for field.Type.Kind() == reflect.Ptr { field.Type = field.Type.Elem() } @@ -72,7 +111,6 @@ func getFieldDecoder(field reflect.StructField, index int, parentIdx []int) ([]d } fieldTagInfos := lookupFieldTags(field) - // todo: 没有 tag 也不直接返回 if len(fieldTagInfos) == 0 { fieldTagInfos = getDefaultFieldTags(field) } @@ -81,12 +119,10 @@ func getFieldDecoder(field reflect.StructField, index int, parentIdx []int) ([]d return getSliceFieldDecoder(field, index, fieldTagInfos, parentIdx) } - // todo: reflect Map if field.Type.Kind() == reflect.Map { return getMapTypeTextDecoder(field, index, fieldTagInfos, parentIdx) } - // 递归每一个 struct if field.Type.Kind() == reflect.Struct { var decoders []decoder el := field.Type @@ -96,7 +132,6 @@ func getFieldDecoder(field reflect.StructField, index int, parentIdx []int) ([]d // ignore unexported field continue } - // todo: 优化一下? idxes := append(parentIdx, index) var idxes []int if len(parentIdx) > 0 { idxes = append(idxes, parentIdx...) diff --git a/pkg/app/server/binding/default_binder.go b/pkg/app/server/binding/default_binder.go new file mode 100644 index 000000000..18b5de06e --- /dev/null +++ b/pkg/app/server/binding/default_binder.go @@ -0,0 +1,112 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * MIT License + * + * Copyright (c) 2019-present Fenny and Contributors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + * This file may have been modified by CloudWeGo authors. All CloudWeGo + * Modifications are Copyright 2022 CloudWeGo Authors + */ + +package binding + +import ( + "fmt" + "reflect" + "sync" + + "github.com/cloudwego/hertz/internal/bytesconv" + "github.com/cloudwego/hertz/pkg/protocol" + "google.golang.org/protobuf/proto" +) + +type Bind struct { + decoderCache sync.Map +} + +func (b *Bind) Name() string { + return "hertz" +} + +func (b *Bind) Bind(req *protocol.Request, params PathParams, v interface{}) error { + err := b.preBindBody(req, v) + if err != nil { + return err + } + rv, typeID := valueAndTypeID(v) + if rv.Kind() != reflect.Pointer || rv.IsNil() { + return fmt.Errorf("receiver must be a non-nil pointer") + } + if rv.Elem().Kind() == reflect.Map { + return nil + } + cached, ok := b.decoderCache.Load(typeID) + if ok { + // cached decoder, fast path + decoder := cached.(Decoder) + return decoder(req, params, rv.Elem()) + } + + decoder, err := getReqDecoder(rv.Type()) + if err != nil { + return err + } + + b.decoderCache.Store(typeID, decoder) + return decoder(req, params, rv.Elem()) +} + +var ( + jsonContentTypeBytes = "application/json; charset=utf-8" + protobufContentType = "application/x-protobuf" +) + +// best effort binding +func (b *Bind) preBindBody(req *protocol.Request, v interface{}) error { + if req.Header.ContentLength() <= 0 { + return nil + } + switch bytesconv.B2s(req.Header.ContentType()) { + case jsonContentTypeBytes: + // todo: Aligning the gin, add "EnableDecoderUseNumber"/"EnableDecoderDisallowUnknownFields" interface + return jsonUnmarshalFunc(req.Body(), v) + case protobufContentType: + msg, ok := v.(proto.Message) + if !ok { + return fmt.Errorf("%s can not implement 'proto.Message'", v) + } + return proto.Unmarshal(req.Body(), msg) + default: + return nil + } +} diff --git a/pkg/app/server/binding/default_validator.go b/pkg/app/server/binding/default_validator.go new file mode 100644 index 000000000..0ff3cd3d6 --- /dev/null +++ b/pkg/app/server/binding/default_validator.go @@ -0,0 +1,94 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * The MIT License (MIT) + * + * Copyright (c) 2014 Manuel Martínez-Almeida + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + * This file may have been modified by CloudWeGo authors. All CloudWeGo + * Modifications are Copyright 2022 CloudWeGo Authors + */ + +package binding + +import ( + "reflect" + "sync" + + "github.com/go-playground/validator/v10" +) + +var _ StructValidator = (*defaultValidator)(nil) + +type defaultValidator struct { + once sync.Once + validate *validator.Validate +} + +// ValidateStruct receives any kind of type, but only performed struct or pointer to struct type. +func (v *defaultValidator) ValidateStruct(obj interface{}) error { + if obj == nil { + return nil + } + + value := reflect.ValueOf(obj) + switch value.Kind() { + case reflect.Ptr: + return v.ValidateStruct(value.Elem().Interface()) + case reflect.Struct: + return v.validateStruct(obj) + default: + return nil + } +} + +// validateStruct receives struct type +func (v *defaultValidator) validateStruct(obj interface{}) error { + v.lazyinit() + return v.validate.Struct(obj) +} + +func (v *defaultValidator) lazyinit() { + v.once.Do(func() { + v.validate = validator.New() + v.validate.SetTagName("validate") + }) +} + +// Engine returns the underlying validator engine which powers the default +// Validator instance. This is useful if you want to register custom validations +// or struct level validations. See validator GoDoc for more info - +// https://pkg.go.dev/github.com/go-playground/validator/v10 +func (v *defaultValidator) Engine() interface{} { + v.lazyinit() + return v.validate +} diff --git a/pkg/app/server/binding_v2/getter.go b/pkg/app/server/binding/getter.go similarity index 52% rename from pkg/app/server/binding_v2/getter.go rename to pkg/app/server/binding/getter.go index 895d1aa61..014aae323 100644 --- a/pkg/app/server/binding_v2/getter.go +++ b/pkg/app/server/binding/getter.go @@ -1,15 +1,52 @@ -package binding_v2 +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * MIT License + * + * Copyright (c) 2019-present Fenny and Contributors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + * This file may have been modified by CloudWeGo authors. All CloudWeGo + * Modifications are Copyright 2022 CloudWeGo Authors + */ + +package binding import ( "github.com/cloudwego/hertz/internal/bytesconv" "github.com/cloudwego/hertz/pkg/protocol" ) -// todo: 优化,对于非数组类型的解析,要不要再提供一个不返回 []string 的 - type getter func(req *protocol.Request, params PathParams, key string, defaultValue ...string) (ret []string) -// todo string 强转优化 func PathParam(req *protocol.Request, params PathParams, key string, defaultValue ...string) (ret []string) { var value string if params != nil { @@ -26,18 +63,37 @@ func PathParam(req *protocol.Request, params PathParams, key string, defaultValu return } -// todo 区分postform和multipartform +// todo: Optimize 'postform' and 'multipart-form' func Form(req *protocol.Request, params PathParams, key string, defaultValue ...string) (ret []string) { req.URI().QueryArgs().VisitAll(func(queryKey, value []byte) { if bytesconv.B2s(queryKey) == key { ret = append(ret, string(value)) } }) + if len(ret) > 0 { + return + } + req.PostArgs().VisitAll(func(formKey, value []byte) { if bytesconv.B2s(formKey) == key { ret = append(ret, string(value)) } }) + if len(ret) > 0 { + return + } + + mf, err := req.MultipartForm() + if err == nil && mf.Value != nil { + for k, v := range mf.Value { + if k == key { + ret = append(ret, v[0]) + } + } + } + if len(ret) > 0 { + return + } if len(ret) == 0 && len(defaultValue) != 0 { ret = append(ret, defaultValue...) diff --git a/pkg/app/server/binding/json.go b/pkg/app/server/binding/json.go new file mode 100644 index 000000000..24407fb58 --- /dev/null +++ b/pkg/app/server/binding/json.go @@ -0,0 +1,65 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * MIT License + * + * Copyright (c) 2019-present Fenny and Contributors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + * This file may have been modified by CloudWeGo authors. All CloudWeGo + * Modifications are Copyright 2022 CloudWeGo Authors + */ + +package binding + +import ( + "encoding/json" + + hjson "github.com/cloudwego/hertz/pkg/common/json" +) + +// JSONUnmarshaler is the interface implemented by types +// that can unmarshal a JSON description of themselves. +type JSONUnmarshaler func(data []byte, v interface{}) error + +var jsonUnmarshalFunc JSONUnmarshaler + +func init() { + ResetJSONUnmarshaler(hjson.Unmarshal) +} + +func ResetJSONUnmarshaler(fn JSONUnmarshaler) { + jsonUnmarshalFunc = fn +} + +func ResetStdJSONUnmarshaler() { + ResetJSONUnmarshaler(json.Unmarshal) +} diff --git a/pkg/app/server/binding_v2/map_type_decoder.go b/pkg/app/server/binding/map_type_decoder.go similarity index 55% rename from pkg/app/server/binding_v2/map_type_decoder.go rename to pkg/app/server/binding/map_type_decoder.go index bfdc993c1..2b9966f58 100644 --- a/pkg/app/server/binding_v2/map_type_decoder.go +++ b/pkg/app/server/binding/map_type_decoder.go @@ -1,4 +1,44 @@ -package binding_v2 +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * MIT License + * + * Copyright (c) 2019-present Fenny and Contributors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + * This file may have been modified by CloudWeGo authors. All CloudWeGo + * Modifications are Copyright 2022 CloudWeGo Authors + */ + +package binding import ( "fmt" @@ -16,7 +56,6 @@ type mapTypeFieldTextDecoder struct { func (d *mapTypeFieldTextDecoder) Decode(req *protocol.Request, params PathParams, reqValue reflect.Value) error { var text string var defaultValue string - // 最大努力交付,对齐 hertz 现有设计 for _, tagInfo := range d.tagInfos { if tagInfo.Key == jsonTag { continue @@ -27,7 +66,6 @@ func (d *mapTypeFieldTextDecoder) Decode(req *protocol.Request, params PathParam ret := tagInfo.Getter(req, params, tagInfo.Value) defaultValue = tagInfo.Default if len(ret) != 0 { - // 非数组/切片类型,只取第一个值作为值 text = ret[0] break } @@ -42,7 +80,6 @@ func (d *mapTypeFieldTextDecoder) Decode(req *protocol.Request, params PathParam reqValue = GetFieldValue(reqValue, d.parentIndex) field := reqValue.Field(d.index) if field.Kind() == reflect.Ptr { - // 如果是指针则新建一个reflect.Value,然后赋值给指针 t := field.Type() var ptrDepth int for t.Kind() == reflect.Ptr { diff --git a/pkg/app/server/binding/reflect.go b/pkg/app/server/binding/reflect.go new file mode 100644 index 000000000..3471dbceb --- /dev/null +++ b/pkg/app/server/binding/reflect.go @@ -0,0 +1,113 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * MIT License + * + * Copyright (c) 2019-present Fenny and Contributors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + * This file may have been modified by CloudWeGo authors. All CloudWeGo + * Modifications are Copyright 2022 CloudWeGo Authors + */ + +package binding + +import ( + "reflect" + "unsafe" +) + +func valueAndTypeID(v interface{}) (reflect.Value, uintptr) { + header := (*emptyInterface)(unsafe.Pointer(&v)) + + rv := reflect.ValueOf(v) + return rv, header.typeID +} + +type emptyInterface struct { + typeID uintptr + dataPtr unsafe.Pointer +} + +// ReferenceValue convert T to *T, the ptrDepth is the count of '*'. +func ReferenceValue(v reflect.Value, ptrDepth int) reflect.Value { + switch { + case ptrDepth > 0: + for ; ptrDepth > 0; ptrDepth-- { + vv := reflect.New(v.Type()) + vv.Elem().Set(v) + v = vv + } + case ptrDepth < 0: + for ; ptrDepth < 0 && v.Kind() == reflect.Ptr; ptrDepth++ { + v = v.Elem() + } + } + return v +} + +func GetNonNilReferenceValue(v reflect.Value) (reflect.Value, int) { + var ptrDepth int + t := v.Type() + elemKind := t.Kind() + for elemKind == reflect.Ptr { + t = t.Elem() + elemKind = t.Kind() + ptrDepth++ + } + val := reflect.New(t).Elem() + return val, ptrDepth +} + +func GetFieldValue(reqValue reflect.Value, parentIndex []int) reflect.Value { + for _, idx := range parentIndex { + if reqValue.Kind() == reflect.Ptr && reqValue.IsNil() { + nonNilVal, ptrDepth := GetNonNilReferenceValue(reqValue) + reqValue.Set(ReferenceValue(nonNilVal, ptrDepth)) + } + for reqValue.Kind() == reflect.Ptr { + reqValue = reqValue.Elem() + } + reqValue = reqValue.Field(idx) + } + + // It is possible that the parent struct is also a pointer, + // so need to create a non-nil reflect.Value for it at runtime. + for reqValue.Kind() == reflect.Ptr { + if reqValue.IsNil() { + nonNilVal, ptrDepth := GetNonNilReferenceValue(reqValue) + reqValue.Set(ReferenceValue(nonNilVal, ptrDepth)) + } + reqValue = reqValue.Elem() + } + + return reqValue +} diff --git a/pkg/app/server/binding/request.go b/pkg/app/server/binding/request.go deleted file mode 100644 index e4d70ba0d..000000000 --- a/pkg/app/server/binding/request.go +++ /dev/null @@ -1,138 +0,0 @@ -/* - * Copyright 2022 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package binding - -import ( - "mime/multipart" - "net/http" - "net/url" - - "github.com/bytedance/go-tagexpr/v2/binding" - "github.com/cloudwego/hertz/internal/bytesconv" - "github.com/cloudwego/hertz/pkg/protocol" -) - -func wrapRequest(req *protocol.Request) binding.Request { - r := &bindRequest{ - req: req, - } - return r -} - -type bindRequest struct { - req *protocol.Request -} - -func (r *bindRequest) GetQuery() url.Values { - queryMap := make(url.Values) - r.req.URI().QueryArgs().VisitAll(func(key, value []byte) { - keyStr := string(key) - values := queryMap[keyStr] - values = append(values, string(value)) - queryMap[keyStr] = values - }) - - return queryMap -} - -func (r *bindRequest) GetPostForm() (url.Values, error) { - postMap := make(url.Values) - r.req.PostArgs().VisitAll(func(key, value []byte) { - keyStr := string(key) - values := postMap[keyStr] - values = append(values, string(value)) - postMap[keyStr] = values - }) - mf, err := r.req.MultipartForm() - if err == nil { - for k, v := range mf.Value { - if len(v) > 0 { - postMap[k] = v - } - } - } - - return postMap, nil -} - -func (r *bindRequest) GetForm() (url.Values, error) { - formMap := make(url.Values) - r.req.URI().QueryArgs().VisitAll(func(key, value []byte) { - keyStr := string(key) - values := formMap[keyStr] - values = append(values, string(value)) - formMap[keyStr] = values - }) - r.req.PostArgs().VisitAll(func(key, value []byte) { - keyStr := string(key) - values := formMap[keyStr] - values = append(values, string(value)) - formMap[keyStr] = values - }) - - return formMap, nil -} - -func (r *bindRequest) GetCookies() []*http.Cookie { - var cookies []*http.Cookie - r.req.Header.VisitAllCookie(func(key, value []byte) { - cookies = append(cookies, &http.Cookie{ - Name: string(key), - Value: string(value), - }) - }) - - return cookies -} - -func (r *bindRequest) GetHeader() http.Header { - header := make(http.Header) - r.req.Header.VisitAll(func(key, value []byte) { - keyStr := string(key) - values := header[keyStr] - values = append(values, string(value)) - header[keyStr] = values - }) - - return header -} - -func (r *bindRequest) GetMethod() string { - return bytesconv.B2s(r.req.Method()) -} - -func (r *bindRequest) GetContentType() string { - return bytesconv.B2s(r.req.Header.ContentType()) -} - -func (r *bindRequest) GetBody() ([]byte, error) { - return r.req.Body(), nil -} - -func (r *bindRequest) GetFileHeaders() (map[string][]*multipart.FileHeader, error) { - files := make(map[string][]*multipart.FileHeader) - mf, err := r.req.MultipartForm() - if err == nil { - for k, v := range mf.File { - if len(v) > 0 { - files[k] = v - } - } - } - - return files, nil -} diff --git a/pkg/app/server/binding_v2/slice_type_decoder.go b/pkg/app/server/binding/slice_type_decoder.go similarity index 61% rename from pkg/app/server/binding_v2/slice_type_decoder.go rename to pkg/app/server/binding/slice_type_decoder.go index 74ac884c2..c2ced705d 100644 --- a/pkg/app/server/binding_v2/slice_type_decoder.go +++ b/pkg/app/server/binding/slice_type_decoder.go @@ -1,11 +1,51 @@ -package binding_v2 +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * MIT License + * + * Copyright (c) 2019-present Fenny and Contributors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + * This file may have been modified by CloudWeGo authors. All CloudWeGo + * Modifications are Copyright 2022 CloudWeGo Authors + */ + +package binding import ( "fmt" "reflect" "github.com/cloudwego/hertz/internal/bytesconv" - "github.com/cloudwego/hertz/pkg/app/server/binding_v2/text_decoder" + "github.com/cloudwego/hertz/pkg/app/server/binding/text_decoder" "github.com/cloudwego/hertz/pkg/common/utils" "github.com/cloudwego/hertz/pkg/protocol" ) @@ -26,7 +66,7 @@ func (d *sliceTypeFieldTextDecoder) Decode(req *protocol.Request, params PathPar tagInfo.Value = utils.GetNormalizeHeaderKey(tagInfo.Value, req.Header.IsDisableNormalizing()) } texts = tagInfo.Getter(req, params, tagInfo.Value) - // todo: 数组默认值 + // todo: array/slice default value defaultValue = tagInfo.Default if len(texts) != 0 { break @@ -74,8 +114,6 @@ func (d *sliceTypeFieldTextDecoder) Decode(req *protocol.Request, params PathPar return nil } -// 数组/切片类型的decoder, -// 对于map和struct类型的数组元素直接使用unmarshal,不做嵌套处理 func getSliceFieldDecoder(field reflect.StructField, index int, tagInfos []TagInfo, parentIdx []int) ([]decoder, error) { if !(field.Type.Kind() == reflect.Slice || field.Type.Kind() == reflect.Array) { return nil, fmt.Errorf("unexpected type %s, expected slice or array", field.Type.String()) @@ -125,7 +163,7 @@ func getSliceFieldDecoder(field reflect.StructField, index int, tagInfos []TagIn func stringToValue(elemType reflect.Type, text string) (v reflect.Value, err error) { v = reflect.New(elemType).Elem() - // todo:自定义类型解析 + // todo: customized type binding switch elemType.Kind() { case reflect.Struct: @@ -137,7 +175,7 @@ func stringToValue(elemType reflect.Type, text string) (v reflect.Value, err err default: decoder, err := text_decoder.SelectTextDecoder(elemType) if err != nil { - return reflect.Value{}, fmt.Errorf("unsupport type %s for slice/array", elemType.String()) + return reflect.Value{}, fmt.Errorf("unsupported type %s for slice/array", elemType.String()) } err = decoder.UnmarshalString(text, v) if err != nil { @@ -145,5 +183,5 @@ func stringToValue(elemType reflect.Type, text string) (v reflect.Value, err err } } - return v, nil + return v, err } diff --git a/pkg/app/server/binding_v2/tag.go b/pkg/app/server/binding/tag.go similarity index 74% rename from pkg/app/server/binding_v2/tag.go rename to pkg/app/server/binding/tag.go index 90e625abd..6d5d88b0c 100644 --- a/pkg/app/server/binding_v2/tag.go +++ b/pkg/app/server/binding/tag.go @@ -1,4 +1,20 @@ -package binding_v2 +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package binding import ( "reflect" diff --git a/pkg/app/server/binding/text_decoder/bool.go b/pkg/app/server/binding/text_decoder/bool.go new file mode 100644 index 000000000..5ae167296 --- /dev/null +++ b/pkg/app/server/binding/text_decoder/bool.go @@ -0,0 +1,57 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * MIT License + * + * Copyright (c) 2019-present Fenny and Contributors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + * This file may have been modified by CloudWeGo authors. All CloudWeGo + * Modifications are Copyright 2022 CloudWeGo Authors + */ + +package text_decoder + +import ( + "reflect" + "strconv" +) + +type boolDecoder struct{} + +func (d *boolDecoder) UnmarshalString(s string, fieldValue reflect.Value) error { + v, err := strconv.ParseBool(s) + if err != nil { + return err + } + fieldValue.SetBool(v) + return nil +} diff --git a/pkg/app/server/binding/text_decoder/float.go b/pkg/app/server/binding/text_decoder/float.go new file mode 100644 index 000000000..f44a1c76d --- /dev/null +++ b/pkg/app/server/binding/text_decoder/float.go @@ -0,0 +1,59 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * MIT License + * + * Copyright (c) 2019-present Fenny and Contributors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + * This file may have been modified by CloudWeGo authors. All CloudWeGo + * Modifications are Copyright 2022 CloudWeGo Authors + */ + +package text_decoder + +import ( + "reflect" + "strconv" +) + +type floatDecoder struct { + bitSize int +} + +func (d *floatDecoder) UnmarshalString(s string, fieldValue reflect.Value) error { + v, err := strconv.ParseFloat(s, d.bitSize) + if err != nil { + return err + } + fieldValue.SetFloat(v) + return nil +} diff --git a/pkg/app/server/binding/text_decoder/int.go b/pkg/app/server/binding/text_decoder/int.go new file mode 100644 index 000000000..1594e2016 --- /dev/null +++ b/pkg/app/server/binding/text_decoder/int.go @@ -0,0 +1,59 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * MIT License + * + * Copyright (c) 2019-present Fenny and Contributors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + * This file may have been modified by CloudWeGo authors. All CloudWeGo + * Modifications are Copyright 2022 CloudWeGo Authors + */ + +package text_decoder + +import ( + "reflect" + "strconv" +) + +type intDecoder struct { + bitSize int +} + +func (d *intDecoder) UnmarshalString(s string, fieldValue reflect.Value) error { + v, err := strconv.ParseInt(s, 10, d.bitSize) + if err != nil { + return err + } + fieldValue.SetInt(v) + return nil +} diff --git a/pkg/app/server/binding/text_decoder/string.go b/pkg/app/server/binding/text_decoder/string.go new file mode 100644 index 000000000..46917469f --- /dev/null +++ b/pkg/app/server/binding/text_decoder/string.go @@ -0,0 +1,50 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * MIT License + * + * Copyright (c) 2019-present Fenny and Contributors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + * This file may have been modified by CloudWeGo authors. All CloudWeGo + * Modifications are Copyright 2022 CloudWeGo Authors + */ + +package text_decoder + +import "reflect" + +type stringDecoder struct{} + +func (d *stringDecoder) UnmarshalString(s string, fieldValue reflect.Value) error { + fieldValue.SetString(s) + return nil +} diff --git a/pkg/app/server/binding/text_decoder/text_decoder.go b/pkg/app/server/binding/text_decoder/text_decoder.go new file mode 100644 index 000000000..08659aede --- /dev/null +++ b/pkg/app/server/binding/text_decoder/text_decoder.go @@ -0,0 +1,92 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * MIT License + * + * Copyright (c) 2019-present Fenny and Contributors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + * This file may have been modified by CloudWeGo authors. All CloudWeGo + * Modifications are Copyright 2022 CloudWeGo Authors + */ + +package text_decoder + +import ( + "fmt" + "reflect" +) + +type TextDecoder interface { + UnmarshalString(s string, fieldValue reflect.Value) error +} + +// var textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem() + +func SelectTextDecoder(rt reflect.Type) (TextDecoder, error) { + // todo: encoding.TextUnmarshaler + //if reflect.PtrTo(rt).Implements(textUnmarshalerType) { + // return &textUnmarshalEncoder{fieldType: rt}, nil + //} + + switch rt.Kind() { + case reflect.Bool: + return &boolDecoder{}, nil + case reflect.Uint8: + return &uintDecoder{bitSize: 8}, nil + case reflect.Uint16: + return &uintDecoder{bitSize: 16}, nil + case reflect.Uint32: + return &uintDecoder{bitSize: 32}, nil + case reflect.Uint64: + return &uintDecoder{bitSize: 64}, nil + case reflect.Uint: + return &uintDecoder{}, nil + case reflect.Int8: + return &intDecoder{bitSize: 8}, nil + case reflect.Int16: + return &intDecoder{bitSize: 16}, nil + case reflect.Int32: + return &intDecoder{bitSize: 32}, nil + case reflect.Int64: + return &intDecoder{bitSize: 64}, nil + case reflect.Int: + return &intDecoder{}, nil + case reflect.String: + return &stringDecoder{}, nil + case reflect.Float32: + return &floatDecoder{bitSize: 32}, nil + case reflect.Float64: + return &floatDecoder{bitSize: 64}, nil + } + + return nil, fmt.Errorf("unsupported type " + rt.String()) +} diff --git a/pkg/app/server/binding/text_decoder/unit.go b/pkg/app/server/binding/text_decoder/unit.go new file mode 100644 index 000000000..1c3703b1c --- /dev/null +++ b/pkg/app/server/binding/text_decoder/unit.go @@ -0,0 +1,59 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * MIT License + * + * Copyright (c) 2019-present Fenny and Contributors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + * This file may have been modified by CloudWeGo authors. All CloudWeGo + * Modifications are Copyright 2022 CloudWeGo Authors + */ + +package text_decoder + +import ( + "reflect" + "strconv" +) + +type uintDecoder struct { + bitSize int +} + +func (d *uintDecoder) UnmarshalString(s string, fieldValue reflect.Value) error { + v, err := strconv.ParseUint(s, 10, d.bitSize) + if err != nil { + return err + } + fieldValue.SetUint(v) + return nil +} diff --git a/pkg/app/server/binding/validator.go b/pkg/app/server/binding/validator.go new file mode 100644 index 000000000..3332752a8 --- /dev/null +++ b/pkg/app/server/binding/validator.go @@ -0,0 +1,64 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * The MIT License (MIT) + * + * Copyright (c) 2014 Manuel Martínez-Almeida + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * + * This file may have been modified by CloudWeGo authors. All CloudWeGo + * Modifications are Copyright 2022 CloudWeGo Authors + */ + +package binding + +// StructValidator is the minimal interface which needs to be implemented in +// order for it to be used as the validator engine for ensuring the correctness +// of the request. Hertz provides a default implementation for this using +// https://github.com/go-playground/validator/tree/v10.6.1. +type StructValidator interface { + // ValidateStruct can receive any kind of type and it should never panic, even if the configuration is not right. + // If the received type is a slice|array, the validation should be performed travel on every element. + // If the received type is not a struct or slice|array, any validation should be skipped and nil must be returned. + // If the received type is a struct or pointer to a struct, the validation should be performed. + // If the struct is not valid or the validation itself fails, a descriptive error should be returned. + // Otherwise nil must be returned. + ValidateStruct(interface{}) error + + // Engine returns the underlying validator engine which powers the + // StructValidator implementation. + Engine() interface{} +} + +// DefaultValidator is the default validator which implements the StructValidator +// interface. It uses https://github.com/go-playground/validator/tree/v10.6.1 +// under the hood. +var DefaultValidator StructValidator = &defaultValidator{} diff --git a/pkg/app/server/binding/validator_test.go b/pkg/app/server/binding/validator_test.go new file mode 100644 index 000000000..3c7b5a292 --- /dev/null +++ b/pkg/app/server/binding/validator_test.go @@ -0,0 +1,47 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package binding + +import ( + "fmt" +) + +func ExampleDefaultValidator_ValidateStruct() { + type User struct { + FirstName string `validate:"required"` + LastName string `validate:"required"` + Age uint8 `validate:"gte=0,lte=130"` + Email string `validate:"required,email"` + FavouriteColor string `validate:"iscolor"` + } + + user := &User{ + FirstName: "Hertz", + Age: 135, + Email: "hertz", + FavouriteColor: "sad", + } + err := DefaultValidator.ValidateStruct(user) + if err != nil { + fmt.Println(err) + } + // Output: + // Key: 'User.LastName' Error:Field validation for 'LastName' failed on the 'required' tag + // Key: 'User.Age' Error:Field validation for 'Age' failed on the 'lte' tag + // Key: 'User.Email' Error:Field validation for 'Email' failed on the 'email' tag + // Key: 'User.FavouriteColor' Error:Field validation for 'FavouriteColor' failed on the 'iscolor' tag +} diff --git a/pkg/app/server/binding_v2/binder.go b/pkg/app/server/binding_v2/binder.go deleted file mode 100644 index 314e1d76e..000000000 --- a/pkg/app/server/binding_v2/binder.go +++ /dev/null @@ -1,17 +0,0 @@ -package binding_v2 - -import ( - "github.com/cloudwego/hertz/pkg/protocol" -) - -// PathParams parameter acquisition interface on the URL path -type PathParams interface { - Get(name string) (string, bool) -} - -type Binder interface { - Name() string - Bind(*protocol.Request, PathParams, interface{}) error -} - -var DefaultBinder Binder = &Bind{} diff --git a/pkg/app/server/binding_v2/customized_type_decoder.go b/pkg/app/server/binding_v2/customized_type_decoder.go deleted file mode 100644 index 302ea61e4..000000000 --- a/pkg/app/server/binding_v2/customized_type_decoder.go +++ /dev/null @@ -1,38 +0,0 @@ -package binding_v2 - -import ( - "reflect" - - "github.com/cloudwego/hertz/pkg/protocol" -) - -type customizedFieldTextDecoder struct { - fieldInfo -} - -func (d *customizedFieldTextDecoder) Decode(req *protocol.Request, params PathParams, reqValue reflect.Value) error { - var err error - v := reflect.New(d.fieldType) - decoder := v.Interface().(CustomizedFieldDecoder) - - if err = decoder.CustomizedFieldDecode(req, params); err != nil { - return err - } - - reqValue = GetFieldValue(reqValue, d.parentIndex) - field := reqValue.Field(d.index) - if field.Kind() == reflect.Ptr { - // 如果是指针则新建一个reflect.Value,然后赋值给指针 - t := field.Type() - var ptrDepth int - for t.Kind() == reflect.Ptr { - t = t.Elem() - ptrDepth++ - } - field.Set(ReferenceValue(v.Elem(), ptrDepth)) - return nil - } - - field.Set(v) - return nil -} diff --git a/pkg/app/server/binding_v2/default_binder.go b/pkg/app/server/binding_v2/default_binder.go deleted file mode 100644 index 9710249d9..000000000 --- a/pkg/app/server/binding_v2/default_binder.go +++ /dev/null @@ -1,73 +0,0 @@ -package binding_v2 - -import ( - "fmt" - "reflect" - "sync" - - "github.com/cloudwego/hertz/internal/bytesconv" - "github.com/cloudwego/hertz/pkg/protocol" - "google.golang.org/protobuf/proto" -) - -type Bind struct { - decoderCache sync.Map -} - -func (b *Bind) Name() string { - return "hertz" -} - -func (b *Bind) Bind(req *protocol.Request, params PathParams, v interface{}) error { - // todo: 先做 body unmarshal, 先尝试做 body 绑定,然后再尝试绑定其他内容 - err := b.preBindBody(req, v) - if err != nil { - return err - } - rv, typeID := valueAndTypeID(v) - if rv.Kind() != reflect.Pointer || rv.IsNil() { - return fmt.Errorf("receiver must be a non-nil pointer") - } - if rv.Elem().Kind() == reflect.Map { - return nil - } - cached, ok := b.decoderCache.Load(typeID) - if ok { - // cached decoder, fast path - decoder := cached.(Decoder) - return decoder(req, params, rv.Elem()) - } - - decoder, err := getReqDecoder(rv.Type()) - if err != nil { - return err - } - - b.decoderCache.Store(typeID, decoder) - return decoder(req, params, rv.Elem()) -} - -var ( - jsonContentTypeBytes = "application/json; charset=utf-8" - protobufContentType = "application/x-protobuf" -) - -// best effort binding -func (b *Bind) preBindBody(req *protocol.Request, v interface{}) error { - if req.Header.ContentLength() <= 0 { - return nil - } - switch bytesconv.B2s(req.Header.ContentType()) { - case jsonContentTypeBytes: - // todo: 对齐gin, 添加 "EnableDecoderUseNumber"/"EnableDecoderDisallowUnknownFields" 接口 - return jsonUnmarshalFunc(req.Body(), v) - case protobufContentType: - msg, ok := v.(proto.Message) - if !ok { - return fmt.Errorf("%s can not implement 'proto.Message'", v) - } - return proto.Unmarshal(req.Body(), msg) - default: - return nil - } -} diff --git a/pkg/app/server/binding_v2/default_validator.go b/pkg/app/server/binding_v2/default_validator.go deleted file mode 100644 index 988b05fc3..000000000 --- a/pkg/app/server/binding_v2/default_validator.go +++ /dev/null @@ -1,54 +0,0 @@ -package binding_v2 - -import ( - "reflect" - "sync" - - "github.com/go-playground/validator/v10" -) - -var _ StructValidator = (*defaultValidator)(nil) - -type defaultValidator struct { - once sync.Once - validate *validator.Validate -} - -// ValidateStruct receives any kind of type, but only performed struct or pointer to struct type. -func (v *defaultValidator) ValidateStruct(obj interface{}) error { - if obj == nil { - return nil - } - - value := reflect.ValueOf(obj) - switch value.Kind() { - case reflect.Ptr: - return v.ValidateStruct(value.Elem().Interface()) - case reflect.Struct: - return v.validateStruct(obj) - default: - return nil - } -} - -// validateStruct receives struct type -func (v *defaultValidator) validateStruct(obj interface{}) error { - v.lazyinit() - return v.validate.Struct(obj) -} - -func (v *defaultValidator) lazyinit() { - v.once.Do(func() { - v.validate = validator.New() - v.validate.SetTagName("validate") - }) -} - -// Engine returns the underlying validator engine which powers the default -// Validator instance. This is useful if you want to register custom validations -// or struct level validations. See validator GoDoc for more info - -// https://pkg.go.dev/github.com/go-playground/validator/v10 -func (v *defaultValidator) Engine() interface{} { - v.lazyinit() - return v.validate -} diff --git a/pkg/app/server/binding_v2/json.go b/pkg/app/server/binding_v2/json.go deleted file mode 100644 index 049cc7456..000000000 --- a/pkg/app/server/binding_v2/json.go +++ /dev/null @@ -1,25 +0,0 @@ -package binding_v2 - -import ( - "encoding/json" - - hjson "github.com/cloudwego/hertz/pkg/common/json" -) - -// JSONUnmarshaler is the interface implemented by types -// that can unmarshal a JSON description of themselves. -type JSONUnmarshaler func(data []byte, v interface{}) error - -var jsonUnmarshalFunc JSONUnmarshaler - -func init() { - ResetJSONUnmarshaler(hjson.Unmarshal) -} - -func ResetJSONUnmarshaler(fn JSONUnmarshaler) { - jsonUnmarshalFunc = fn -} - -func ResetStdJSONUnmarshaler() { - ResetJSONUnmarshaler(json.Unmarshal) -} diff --git a/pkg/app/server/binding_v2/reflect.go b/pkg/app/server/binding_v2/reflect.go deleted file mode 100644 index 79034aba8..000000000 --- a/pkg/app/server/binding_v2/reflect.go +++ /dev/null @@ -1,72 +0,0 @@ -package binding_v2 - -import ( - "reflect" - "unsafe" -) - -func valueAndTypeID(v interface{}) (reflect.Value, uintptr) { - header := (*emptyInterface)(unsafe.Pointer(&v)) - - rv := reflect.ValueOf(v) - return rv, header.typeID -} - -type emptyInterface struct { - typeID uintptr - dataPtr unsafe.Pointer -} - -// ReferenceValue convert T to *T, the ptrDepth is the count of '*'. -func ReferenceValue(v reflect.Value, ptrDepth int) reflect.Value { - switch { - case ptrDepth > 0: - for ; ptrDepth > 0; ptrDepth-- { - vv := reflect.New(v.Type()) - vv.Elem().Set(v) - v = vv - } - case ptrDepth < 0: - for ; ptrDepth < 0 && v.Kind() == reflect.Ptr; ptrDepth++ { - v = v.Elem() - } - } - return v -} - -func GetNonNilReferenceValue(v reflect.Value) (reflect.Value, int) { - var ptrDepth int - t := v.Type() - elemKind := t.Kind() - for elemKind == reflect.Ptr { - t = t.Elem() - elemKind = t.Kind() - ptrDepth++ - } - val := reflect.New(t).Elem() - return val, ptrDepth -} - -func GetFieldValue(reqValue reflect.Value, parentIndex []int) reflect.Value { - for _, idx := range parentIndex { - if reqValue.Kind() == reflect.Ptr && reqValue.IsNil() { - nonNilVal, ptrDepth := GetNonNilReferenceValue(reqValue) - reqValue.Set(ReferenceValue(nonNilVal, ptrDepth)) - } - for reqValue.Kind() == reflect.Ptr { - reqValue = reqValue.Elem() - } - reqValue = reqValue.Field(idx) - } - - // 父 struct 有可能也是一个指针,所以需要再处理一次才能得到最终的父Value(非nil的reflect.Value) - for reqValue.Kind() == reflect.Ptr { - if reqValue.IsNil() { - nonNilVal, ptrDepth := GetNonNilReferenceValue(reqValue) - reqValue.Set(ReferenceValue(nonNilVal, ptrDepth)) - } - reqValue = reqValue.Elem() - } - - return reqValue -} diff --git a/pkg/app/server/binding_v2/text_decoder/bool.go b/pkg/app/server/binding_v2/text_decoder/bool.go deleted file mode 100644 index b669f5f9a..000000000 --- a/pkg/app/server/binding_v2/text_decoder/bool.go +++ /dev/null @@ -1,17 +0,0 @@ -package text_decoder - -import ( - "reflect" - "strconv" -) - -type boolDecoder struct{} - -func (d *boolDecoder) UnmarshalString(s string, fieldValue reflect.Value) error { - v, err := strconv.ParseBool(s) - if err != nil { - return err - } - fieldValue.SetBool(v) - return nil -} diff --git a/pkg/app/server/binding_v2/text_decoder/float.go b/pkg/app/server/binding_v2/text_decoder/float.go deleted file mode 100644 index 395526153..000000000 --- a/pkg/app/server/binding_v2/text_decoder/float.go +++ /dev/null @@ -1,19 +0,0 @@ -package text_decoder - -import ( - "reflect" - "strconv" -) - -type floatDecoder struct { - bitSize int -} - -func (d *floatDecoder) UnmarshalString(s string, fieldValue reflect.Value) error { - v, err := strconv.ParseFloat(s, d.bitSize) - if err != nil { - return err - } - fieldValue.SetFloat(v) - return nil -} diff --git a/pkg/app/server/binding_v2/text_decoder/int.go b/pkg/app/server/binding_v2/text_decoder/int.go deleted file mode 100644 index 13a26e644..000000000 --- a/pkg/app/server/binding_v2/text_decoder/int.go +++ /dev/null @@ -1,19 +0,0 @@ -package text_decoder - -import ( - "reflect" - "strconv" -) - -type intDecoder struct { - bitSize int -} - -func (d *intDecoder) UnmarshalString(s string, fieldValue reflect.Value) error { - v, err := strconv.ParseInt(s, 10, d.bitSize) - if err != nil { - return err - } - fieldValue.SetInt(v) - return nil -} diff --git a/pkg/app/server/binding_v2/text_decoder/string.go b/pkg/app/server/binding_v2/text_decoder/string.go deleted file mode 100644 index 6290a4c31..000000000 --- a/pkg/app/server/binding_v2/text_decoder/string.go +++ /dev/null @@ -1,11 +0,0 @@ -package text_decoder - -import "reflect" - -type stringDecoder struct{} - -func (d *stringDecoder) UnmarshalString(s string, fieldValue reflect.Value) error { - // todo: 优化一下 - fieldValue.SetString(s) - return nil -} diff --git a/pkg/app/server/binding_v2/text_decoder/text_decoder.go b/pkg/app/server/binding_v2/text_decoder/text_decoder.go deleted file mode 100644 index 934b9a9c0..000000000 --- a/pkg/app/server/binding_v2/text_decoder/text_decoder.go +++ /dev/null @@ -1,52 +0,0 @@ -package text_decoder - -import ( - "fmt" - "reflect" -) - -type TextDecoder interface { - UnmarshalString(s string, fieldValue reflect.Value) error -} - -// var textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem() - -func SelectTextDecoder(rt reflect.Type) (TextDecoder, error) { - // todo: encoding.TextUnmarshaler - //if reflect.PtrTo(rt).Implements(textUnmarshalerType) { - // return &textUnmarshalEncoder{fieldType: rt}, nil - //} - - switch rt.Kind() { - case reflect.Bool: - return &boolDecoder{}, nil - case reflect.Uint8: - return &uintDecoder{bitSize: 8}, nil - case reflect.Uint16: - return &uintDecoder{bitSize: 16}, nil - case reflect.Uint32: - return &uintDecoder{bitSize: 32}, nil - case reflect.Uint64: - return &uintDecoder{bitSize: 64}, nil - case reflect.Uint: - return &uintDecoder{}, nil - case reflect.Int8: - return &intDecoder{bitSize: 8}, nil - case reflect.Int16: - return &intDecoder{bitSize: 16}, nil - case reflect.Int32: - return &intDecoder{bitSize: 32}, nil - case reflect.Int64: - return &intDecoder{bitSize: 64}, nil - case reflect.Int: - return &intDecoder{}, nil - case reflect.String: - return &stringDecoder{}, nil - case reflect.Float32: - return &floatDecoder{bitSize: 32}, nil - case reflect.Float64: - return &floatDecoder{bitSize: 64}, nil - } - - return nil, fmt.Errorf("unsupported type " + rt.String()) -} diff --git a/pkg/app/server/binding_v2/text_decoder/unit.go b/pkg/app/server/binding_v2/text_decoder/unit.go deleted file mode 100644 index cb766964a..000000000 --- a/pkg/app/server/binding_v2/text_decoder/unit.go +++ /dev/null @@ -1,19 +0,0 @@ -package text_decoder - -import ( - "reflect" - "strconv" -) - -type uintDecoder struct { - bitSize int -} - -func (d *uintDecoder) UnmarshalString(s string, fieldValue reflect.Value) error { - v, err := strconv.ParseUint(s, 10, d.bitSize) - if err != nil { - return err - } - fieldValue.SetUint(v) - return nil -} diff --git a/pkg/app/server/binding_v2/validator.go b/pkg/app/server/binding_v2/validator.go deleted file mode 100644 index 1e418b70a..000000000 --- a/pkg/app/server/binding_v2/validator.go +++ /dev/null @@ -1,24 +0,0 @@ -package binding_v2 - -// StructValidator is the minimal interface which needs to be implemented in -// order for it to be used as the validator engine for ensuring the correctness -// of the request. Hertz provides a default implementation for this using -// https://github.com/go-playground/validator/tree/v10.6.1. -type StructValidator interface { - // ValidateStruct can receive any kind of type and it should never panic, even if the configuration is not right. - // If the received type is a slice|array, the validation should be performed travel on every element. - // If the received type is not a struct or slice|array, any validation should be skipped and nil must be returned. - // If the received type is a struct or pointer to a struct, the validation should be performed. - // If the struct is not valid or the validation itself fails, a descriptive error should be returned. - // Otherwise nil must be returned. - ValidateStruct(interface{}) error - - // Engine returns the underlying validator engine which powers the - // StructValidator implementation. - Engine() interface{} -} - -// DefaultValidator is the default validator which implements the StructValidator -// interface. It uses https://github.com/go-playground/validator/tree/v10.6.1 -// under the hood. -var DefaultValidator StructValidator = &defaultValidator{} diff --git a/pkg/app/server/binding_v2/validator_test.go b/pkg/app/server/binding_v2/validator_test.go deleted file mode 100644 index 2ec125810..000000000 --- a/pkg/app/server/binding_v2/validator_test.go +++ /dev/null @@ -1,31 +0,0 @@ -package binding_v2 - -import ( - "fmt" -) - -func ExampleValidateStruct() { - type User struct { - FirstName string `validate:"required"` - LastName string `validate:"required"` - Age uint8 `validate:"gte=0,lte=130"` - Email string `validate:"required,email"` - FavouriteColor string `validate:"iscolor"` - } - - user := &User{ - FirstName: "Hertz", - Age: 135, - Email: "hertz", - FavouriteColor: "sad", - } - err := DefaultValidator.ValidateStruct(user) - if err != nil { - fmt.Println(err) - } - // Output: - //Key: 'User.LastName' Error:Field validation for 'LastName' failed on the 'required' tag - //Key: 'User.Age' Error:Field validation for 'Age' failed on the 'lte' tag - //Key: 'User.Email' Error:Field validation for 'Email' failed on the 'email' tag - //Key: 'User.FavouriteColor' Error:Field validation for 'FavouriteColor' failed on the 'iscolor' tag -} From 9c0dd8438c5ef36f9d564cc70642e8067aa702fc Mon Sep 17 00:00:00 2001 From: fgy Date: Thu, 9 Feb 2023 19:49:21 +0800 Subject: [PATCH 10/91] feat: add file bind --- pkg/app/server/binding/base_type_decoder.go | 4 +- pkg/app/server/binding/binder_test.go | 86 +++++++++- pkg/app/server/binding/decoder.go | 6 + pkg/app/server/binding/getter.go | 5 + pkg/app/server/binding/map_type_decoder.go | 4 +- .../server/binding/multipart_file_decoder.go | 148 ++++++++++++++++++ pkg/app/server/binding/reflect.go | 8 + pkg/app/server/binding/slice_type_decoder.go | 22 ++- pkg/app/server/binding/tag.go | 19 +-- 9 files changed, 286 insertions(+), 16 deletions(-) create mode 100644 pkg/app/server/binding/multipart_file_decoder.go diff --git a/pkg/app/server/binding/base_type_decoder.go b/pkg/app/server/binding/base_type_decoder.go index 69fdb98bd..07c155372 100644 --- a/pkg/app/server/binding/base_type_decoder.go +++ b/pkg/app/server/binding/base_type_decoder.go @@ -67,7 +67,7 @@ func (d *baseTypeFieldTextDecoder) Decode(req *protocol.Request, params PathPara var text string var defaultValue string for _, tagInfo := range d.tagInfos { - if tagInfo.Key == jsonTag { + if tagInfo.Key == jsonTag || tagInfo.Key == fileNameTag { continue } if tagInfo.Key == headerTag { @@ -139,6 +139,8 @@ func getBaseTypeTextDecoder(field reflect.StructField, index int, tagInfos []Tag // do nothing case rawBodyTag: tagInfo.Getter = RawBody + case fileNameTag: + // do nothing default: } } diff --git a/pkg/app/server/binding/binder_test.go b/pkg/app/server/binding/binder_test.go index d228cfd57..bd3084ae4 100644 --- a/pkg/app/server/binding/binder_test.go +++ b/pkg/app/server/binding/binder_test.go @@ -42,10 +42,12 @@ package binding import ( "fmt" + "mime/multipart" "testing" "github.com/cloudwego/hertz/pkg/common/test/assert" "github.com/cloudwego/hertz/pkg/protocol" + req2 "github.com/cloudwego/hertz/pkg/protocol/http1/req" "github.com/cloudwego/hertz/pkg/route/param" ) @@ -64,6 +66,11 @@ func (m *mockRequest) SetRequestURI(uri string) *mockRequest { return m } +func (m *mockRequest) SetFile(param, fileName string) *mockRequest { + m.Req.SetFile(param, fileName) + return m +} + func (m *mockRequest) SetHeader(key, value string) *mockRequest { m.Req.Header.Set(key, value) return m @@ -128,7 +135,7 @@ func TestBind_BaseType(t *testing.T) { func TestBind_SliceType(t *testing.T) { type Req struct { - ID []int `query:"id"` + ID *[]int `query:"id"` Str [3]string `query:"str"` Byte []byte `query:"b"` } @@ -145,9 +152,9 @@ func TestBind_SliceType(t *testing.T) { if err != nil { t.Error(err) } - assert.DeepEqual(t, 3, len(result.ID)) + assert.DeepEqual(t, 3, len(*result.ID)) for idx, val := range IDs { - assert.DeepEqual(t, val, result.ID[idx]) + assert.DeepEqual(t, val, (*result.ID)[idx]) } assert.DeepEqual(t, 3, len(result.Str)) for idx, val := range Strs { @@ -578,6 +585,79 @@ func TestBind_ResetJSONUnmarshal(t *testing.T) { } } +func TestBind_FileBind(t *testing.T) { + type Nest struct { + N multipart.FileHeader `file_name:"d"` + } + + var s struct { + A *multipart.FileHeader `file_name:"a"` + B *multipart.FileHeader `form:"b"` + C multipart.FileHeader + D **Nest `file_name:"d"` + } + fileName := "binder_test.go" + req := newMockRequest(). + SetRequestURI("http://foobar.com"). + SetFile("a", fileName). + SetFile("b", fileName). + SetFile("C", fileName). + SetFile("d", fileName) + req2 := req2.GetHTTP1Request(req.Req) + req2.String() + err := DefaultBinder.Bind(req.Req, nil, &s) + if err != nil { + t.Fatal(err) + } + assert.DeepEqual(t, fileName, s.A.Filename) + assert.DeepEqual(t, fileName, s.B.Filename) + assert.DeepEqual(t, fileName, s.C.Filename) + assert.DeepEqual(t, fileName, (**s.D).N.Filename) +} + +func TestBind_FileSliceBind(t *testing.T) { + type Nest struct { + N *[]*multipart.FileHeader `form:"b"` + } + var s struct { + A []multipart.FileHeader `form:"a"` + B [3]multipart.FileHeader `form:"b"` + C []*multipart.FileHeader `form:"b"` + D Nest + } + fileName := "binder_test.go" + req := newMockRequest(). + SetRequestURI("http://foobar.com"). + SetFile("a", fileName). + SetFile("a", fileName). + SetFile("a", fileName). + SetFile("b", fileName). + SetFile("b", fileName). + SetFile("b", fileName) + req2 := req2.GetHTTP1Request(req.Req) + req2.String() + err := DefaultBinder.Bind(req.Req, nil, &s) + if err != nil { + t.Fatal(err) + } + assert.DeepEqual(t, 3, len(s.A)) + for _, file := range s.A { + assert.DeepEqual(t, fileName, file.Filename) + } + assert.DeepEqual(t, 3, len(s.B)) + for _, file := range s.B { + assert.DeepEqual(t, fileName, file.Filename) + } + assert.DeepEqual(t, 3, len(s.C)) + for _, file := range s.C { + assert.DeepEqual(t, fileName, file.Filename) + } + assert.DeepEqual(t, 3, len(*s.D.N)) + for _, file := range *s.D.N { + assert.DeepEqual(t, fileName, file.Filename) + } +} + func Benchmark_Binding(b *testing.B) { type Req struct { Version string `path:"v"` diff --git a/pkg/app/server/binding/decoder.go b/pkg/app/server/binding/decoder.go index f8bb12246..0cc3695a3 100644 --- a/pkg/app/server/binding/decoder.go +++ b/pkg/app/server/binding/decoder.go @@ -42,6 +42,7 @@ package binding import ( "fmt" + "mime/multipart" "reflect" "github.com/cloudwego/hertz/pkg/protocol" @@ -126,6 +127,11 @@ func getFieldDecoder(field reflect.StructField, index int, parentIdx []int) ([]d if field.Type.Kind() == reflect.Struct { var decoders []decoder el := field.Type + // todo: built-in bindings for some common structs, code need to be optimized + switch el { + case reflect.TypeOf(multipart.FileHeader{}): + return getMultipartFileDecoder(field, index, fieldTagInfos, parentIdx) + } for i := 0; i < el.NumField(); i++ { if !el.Field(i).IsExported() { diff --git a/pkg/app/server/binding/getter.go b/pkg/app/server/binding/getter.go index 014aae323..878cd5e5c 100644 --- a/pkg/app/server/binding/getter.go +++ b/pkg/app/server/binding/getter.go @@ -155,3 +155,8 @@ func RawBody(req *protocol.Request, params PathParams, key string, defaultValue } return } + +func FileName(req *protocol.Request, params PathParams, key string, defaultValue ...string) (ret []string) { + // do nothing + return +} diff --git a/pkg/app/server/binding/map_type_decoder.go b/pkg/app/server/binding/map_type_decoder.go index 2b9966f58..7af3bdbac 100644 --- a/pkg/app/server/binding/map_type_decoder.go +++ b/pkg/app/server/binding/map_type_decoder.go @@ -57,7 +57,7 @@ func (d *mapTypeFieldTextDecoder) Decode(req *protocol.Request, params PathParam var text string var defaultValue string for _, tagInfo := range d.tagInfos { - if tagInfo.Key == jsonTag { + if tagInfo.Key == jsonTag || tagInfo.Key == fileNameTag { continue } if tagInfo.Key == headerTag { @@ -120,6 +120,8 @@ func getMapTypeTextDecoder(field reflect.StructField, index int, tagInfos []TagI // do nothing case rawBodyTag: tagInfo.Getter = RawBody + case fileNameTag: + // do nothing default: } } diff --git a/pkg/app/server/binding/multipart_file_decoder.go b/pkg/app/server/binding/multipart_file_decoder.go new file mode 100644 index 000000000..4700f5564 --- /dev/null +++ b/pkg/app/server/binding/multipart_file_decoder.go @@ -0,0 +1,148 @@ +package binding + +import ( + "fmt" + "github.com/cloudwego/hertz/pkg/protocol" + "reflect" +) + +type fileTypeDecoder struct { + fieldInfo + isRepeated bool +} + +func (d *fileTypeDecoder) Decode(req *protocol.Request, params PathParams, reqValue reflect.Value) error { + fieldValue := GetFieldValue(reqValue, d.parentIndex) + field := fieldValue.Field(d.index) + + if d.isRepeated { + return d.fileSliceDecode(req, params, reqValue) + } + var fileName string + // file_name > form > fieldName + for _, tagInfo := range d.tagInfos { + if tagInfo.Key == fileNameTag { + fileName = tagInfo.Value + break + } + if tagInfo.Key == formTag { + fileName = tagInfo.Value + } + } + if len(fileName) == 0 { + fileName = d.fieldName + } + file, err := req.FormFile(fileName) + if err != nil { + return fmt.Errorf("can not get file '%s', err: %v", fileName, err) + } + if field.Kind() == reflect.Ptr { + t := field.Type() + var ptrDepth int + for t.Kind() == reflect.Ptr { + t = t.Elem() + ptrDepth++ + } + v := reflect.New(t).Elem() + v.Set(reflect.ValueOf(*file)) + field.Set(ReferenceValue(v, ptrDepth)) + return nil + } + + // Non-pointer elems + field.Set(reflect.ValueOf(*file)) + + return nil +} + +func (d *fileTypeDecoder) fileSliceDecode(req *protocol.Request, params PathParams, reqValue reflect.Value) error { + fieldValue := GetFieldValue(reqValue, d.parentIndex) + field := fieldValue.Field(d.index) + // 如果没值,需要为其建一个值 + if field.Kind() == reflect.Ptr { + if field.IsNil() { + nonNilVal, ptrDepth := GetNonNilReferenceValue(field) + field.Set(ReferenceValue(nonNilVal, ptrDepth)) + } + } + var parentPtrDepth int + for field.Kind() == reflect.Ptr { + field = field.Elem() + parentPtrDepth++ + } + + var fileName string + // file_name > form > fieldName + for _, tagInfo := range d.tagInfos { + if tagInfo.Key == fileNameTag { + fileName = tagInfo.Value + break + } + if tagInfo.Key == formTag { + fileName = tagInfo.Value + } + } + if len(fileName) == 0 { + fileName = d.fieldName + } + multipartForm, err := req.MultipartForm() + if err != nil { + return fmt.Errorf("can not get multipartForm info, err: %v", err) + } + files, exist := multipartForm.File[fileName] + if !exist { + return fmt.Errorf("the file '%s' is not existed", fileName) + } + + if field.Kind() == reflect.Array { + if len(files) != field.Len() { + return fmt.Errorf("the numbers(%d) of file '%s' does not match the length(%d) of %s", len(files), fileName, field.Len(), field.Type().String()) + } + } else { + // slice need creating enough capacity + field = reflect.MakeSlice(field.Type(), len(files), len(files)) + } + + // handle multiple pointer + var ptrDepth int + t := d.fieldType.Elem() + elemKind := t.Kind() + for elemKind == reflect.Ptr { + t = t.Elem() + elemKind = t.Kind() + ptrDepth++ + } + + for idx, file := range files { + v := reflect.New(t).Elem() + v.Set(reflect.ValueOf(*file)) + field.Index(idx).Set(ReferenceValue(v, ptrDepth)) + } + fieldValue.Field(d.index).Set(ReferenceValue(field, parentPtrDepth)) + + return nil +} + +func getMultipartFileDecoder(field reflect.StructField, index int, tagInfos []TagInfo, parentIdx []int) ([]decoder, error) { + fieldType := field.Type + for field.Type.Kind() == reflect.Ptr { + fieldType = field.Type.Elem() + } + isRepeated := false + if fieldType.Kind() == reflect.Array || fieldType.Kind() == reflect.Slice { + isRepeated = true + } + + fieldDecoder := &fileTypeDecoder{ + fieldInfo: fieldInfo{ + index: index, + parentIndex: parentIdx, + fieldName: field.Name, + tagInfos: tagInfos, + fieldType: fieldType, + }, + isRepeated: isRepeated, + } + + return []decoder{fieldDecoder}, nil +} diff --git a/pkg/app/server/binding/reflect.go b/pkg/app/server/binding/reflect.go index 3471dbceb..7b8933442 100644 --- a/pkg/app/server/binding/reflect.go +++ b/pkg/app/server/binding/reflect.go @@ -111,3 +111,11 @@ func GetFieldValue(reqValue reflect.Value, parentIndex []int) reflect.Value { return reqValue } + +func getElemType(t reflect.Type) reflect.Type { + for t.Kind() == reflect.Ptr { + t = t.Elem() + } + + return t +} diff --git a/pkg/app/server/binding/slice_type_decoder.go b/pkg/app/server/binding/slice_type_decoder.go index c2ced705d..f167db96f 100644 --- a/pkg/app/server/binding/slice_type_decoder.go +++ b/pkg/app/server/binding/slice_type_decoder.go @@ -42,6 +42,7 @@ package binding import ( "fmt" + "mime/multipart" "reflect" "github.com/cloudwego/hertz/internal/bytesconv" @@ -59,7 +60,7 @@ func (d *sliceTypeFieldTextDecoder) Decode(req *protocol.Request, params PathPar var texts []string var defaultValue string for _, tagInfo := range d.tagInfos { - if tagInfo.Key == jsonTag { + if tagInfo.Key == jsonTag || tagInfo.Key == fileNameTag { continue } if tagInfo.Key == headerTag { @@ -81,6 +82,17 @@ func (d *sliceTypeFieldTextDecoder) Decode(req *protocol.Request, params PathPar reqValue = GetFieldValue(reqValue, d.parentIndex) field := reqValue.Field(d.index) + if field.Kind() == reflect.Ptr { + if field.IsNil() { + nonNilVal, ptrDepth := GetNonNilReferenceValue(field) + field.Set(ReferenceValue(nonNilVal, ptrDepth)) + } + } + var parentPtrDepth int + for field.Kind() == reflect.Ptr { + field = field.Elem() + parentPtrDepth++ + } if d.isArray { if len(texts) != field.Len() { @@ -109,7 +121,7 @@ func (d *sliceTypeFieldTextDecoder) Decode(req *protocol.Request, params PathPar } field.Index(idx).Set(ReferenceValue(vv, ptrDepth)) } - reqValue.Field(d.index).Set(field) + reqValue.Field(d.index).Set(ReferenceValue(field, parentPtrDepth)) return nil } @@ -138,6 +150,8 @@ func getSliceFieldDecoder(field reflect.StructField, index int, tagInfos []TagIn // do nothing case rawBodyTag: tagInfo.Getter = RawBody + case fileNameTag: + // do nothing default: } } @@ -146,6 +160,10 @@ func getSliceFieldDecoder(field reflect.StructField, index int, tagInfos []TagIn for field.Type.Kind() == reflect.Ptr { fieldType = field.Type.Elem() } + t := getElemType(fieldType.Elem()) + if t == reflect.TypeOf(multipart.FileHeader{}) { + return getMultipartFileDecoder(field, index, tagInfos, parentIdx) + } fieldDecoder := &sliceTypeFieldTextDecoder{ fieldInfo: fieldInfo{ diff --git a/pkg/app/server/binding/tag.go b/pkg/app/server/binding/tag.go index 6d5d88b0c..f44778d97 100644 --- a/pkg/app/server/binding/tag.go +++ b/pkg/app/server/binding/tag.go @@ -22,13 +22,14 @@ import ( ) const ( - pathTag = "path" - formTag = "form" - queryTag = "query" - cookieTag = "cookie" - headerTag = "header" - jsonTag = "json" - rawBodyTag = "raw_body" + pathTag = "path" + formTag = "form" + queryTag = "query" + cookieTag = "cookie" + headerTag = "header" + jsonTag = "json" + rawBodyTag = "raw_body" + fileNameTag = "file_name" ) const ( @@ -58,7 +59,7 @@ func head(str, sep string) (head, tail string) { func lookupFieldTags(field reflect.StructField) []TagInfo { var ret []string - tags := []string{pathTag, formTag, queryTag, cookieTag, headerTag, jsonTag, rawBodyTag} + tags := []string{pathTag, formTag, queryTag, cookieTag, headerTag, jsonTag, rawBodyTag, fileNameTag} for _, tag := range tags { if _, ok := field.Tag.Lookup(tag); ok { ret = append(ret, tag) @@ -96,7 +97,7 @@ func getDefaultFieldTags(field reflect.StructField) (tagInfos []TagInfo) { defaultVal = val } - tags := []string{pathTag, formTag, queryTag, cookieTag, headerTag, jsonTag} + tags := []string{pathTag, formTag, queryTag, cookieTag, headerTag, jsonTag, fileNameTag} for _, tag := range tags { tagInfos = append(tagInfos, TagInfo{Key: tag, Value: field.Name, Default: defaultVal}) } From 4cbf86391fe561f6ec1f7a143fcaff548c65d9ae Mon Sep 17 00:00:00 2001 From: fgy Date: Mon, 3 Apr 2023 11:31:05 +0800 Subject: [PATCH 11/91] style: go lint --- pkg/app/server/binding/binder_test.go | 6 ++++-- .../server/binding/multipart_file_decoder.go | 19 ++++++++++++++++++- 2 files changed, 22 insertions(+), 3 deletions(-) diff --git a/pkg/app/server/binding/binder_test.go b/pkg/app/server/binding/binder_test.go index bd3084ae4..46a90f6b1 100644 --- a/pkg/app/server/binding/binder_test.go +++ b/pkg/app/server/binding/binder_test.go @@ -603,8 +603,9 @@ func TestBind_FileBind(t *testing.T) { SetFile("b", fileName). SetFile("C", fileName). SetFile("d", fileName) + // to parse multipart files req2 := req2.GetHTTP1Request(req.Req) - req2.String() + _ = req2.String() err := DefaultBinder.Bind(req.Req, nil, &s) if err != nil { t.Fatal(err) @@ -634,8 +635,9 @@ func TestBind_FileSliceBind(t *testing.T) { SetFile("b", fileName). SetFile("b", fileName). SetFile("b", fileName) + // to parse multipart files req2 := req2.GetHTTP1Request(req.Req) - req2.String() + _ = req2.String() err := DefaultBinder.Bind(req.Req, nil, &s) if err != nil { t.Fatal(err) diff --git a/pkg/app/server/binding/multipart_file_decoder.go b/pkg/app/server/binding/multipart_file_decoder.go index 4700f5564..cf4a4c888 100644 --- a/pkg/app/server/binding/multipart_file_decoder.go +++ b/pkg/app/server/binding/multipart_file_decoder.go @@ -1,9 +1,26 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package binding import ( "fmt" - "github.com/cloudwego/hertz/pkg/protocol" "reflect" + + "github.com/cloudwego/hertz/pkg/protocol" ) type fileTypeDecoder struct { From da1ccce5418ad7a2b92306f2ad1f9f0a3449249b Mon Sep 17 00:00:00 2001 From: fgy Date: Mon, 3 Apr 2023 11:36:44 +0800 Subject: [PATCH 12/91] feat: go mod tidy --- go.sum | 39 --------------------------------------- 1 file changed, 39 deletions(-) diff --git a/go.sum b/go.sum index 057d986d7..665310aca 100644 --- a/go.sum +++ b/go.sum @@ -1,9 +1,5 @@ -github.com/bytedance/go-tagexpr/v2 v2.9.2 h1:QySJaAIQgOEDQBLS3x9BxOWrnhqu5sQ+f6HaZIxD39I= -github.com/bytedance/go-tagexpr/v2 v2.9.2/go.mod h1:5qsx05dYOiUXOUgnQ7w3Oz8BYs2qtM/bJokdLb79wRM= github.com/bytedance/gopkg v0.0.0-20220413063733-65bf48ffb3a7 h1:PtwsQyQJGxf8iaPptPNaduEIu9BnrNms+pcRdHAxZaM= github.com/bytedance/gopkg v0.0.0-20220413063733-65bf48ffb3a7/go.mod h1:2ZlV9BaUH4+NXIBF0aMdKKAnHTzqH+iMU4KUjAbL23Q= -github.com/bytedance/mockey v1.2.1 h1:g84ngI88hz1DR4wZTL3yOuqlEcq67MretBfQUdXwrmw= -github.com/bytedance/mockey v1.2.1/go.mod h1:+Jm/fzWZAuhEDrPXVjDf/jLM2BlLXJkwk94zf2JZ3X4= github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= github.com/bytedance/sonic v1.8.1 h1:NqAHCaGaTzro0xMmnTCLUyRlbEP6r8MCA1cJUrH3Pu4= github.com/bytedance/sonic v1.8.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= @@ -25,20 +21,9 @@ github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJn github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= github.com/go-playground/validator/v10 v10.11.2 h1:q3SHpufmypg+erIExEKUmsgmhDTyhcJ38oeKGACXohU= github.com/go-playground/validator/v10 v10.11.2/go.mod h1:NieE624vt4SCTJtD87arVLvdmjPAeV8BQlHtMnw9D7s= -github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= -github.com/golang/protobuf v1.5.0 h1:LUVKkCeviFUMKqHa4tXIIij/lbhnMbP7Fn5wKdKkRh4= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= -github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1 h1:EGx4pi6eqNxGaHF6qqu48+N2wcFQ5qg5FXgOdqsJ5d8= -github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= -github.com/henrylee2cn/ameda v1.4.8/go.mod h1:liZulR8DgHxdK+MEwvZIylGnmcjzQ6N6f2PlWe7nEO4= -github.com/henrylee2cn/ameda v1.4.10 h1:JdvI2Ekq7tapdPsuhrc4CaFiqw6QXFvZIULWJgQyCAk= -github.com/henrylee2cn/ameda v1.4.10/go.mod h1:liZulR8DgHxdK+MEwvZIylGnmcjzQ6N6f2PlWe7nEO4= -github.com/henrylee2cn/goutil v0.0.0-20210127050712-89660552f6f8 h1:yE9ULgp02BhYIrO6sdV/FPe0xQM6fNHkVQW2IAymfM0= -github.com/henrylee2cn/goutil v0.0.0-20210127050712-89660552f6f8/go.mod h1:Nhe/DM3671a5udlv2AdV2ni/MZzgfv2qrPL5nIi3EGQ= -github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo= -github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= github.com/klauspost/cpuid/v2 v2.0.9 h1:lgaqFMSdTdQYdZ04uHyN2d/eKdOMyi2YLSvlQIBFYa4= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= @@ -49,40 +34,23 @@ github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= github.com/leodido/go-urn v1.2.1 h1:BqpAaACuzVSgi/VLzGZIobT2z4v53pjosyNd9Yv6n/w= github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY= -github.com/nyaruka/phonenumbers v1.0.55 h1:bj0nTO88Y68KeUQ/n3Lo2KgK7lM1hF7L9NFuwcCl3yg= -github.com/nyaruka/phonenumbers v1.0.55/go.mod h1:sDaTZ/KPX5f8qyV9qN+hIm+4ZBARJrupC6LuhshJq1U= github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d h1:zE9ykElWQ6/NYmHa3jpm/yHnI4xSofP+UP6SpjHcSeM= -github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= -github.com/smartystreets/goconvey v1.6.4 h1:fv0U8FUIMPNf1L9lnHLvLhgicrIVChEkdzIKYqbNC9s= -github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= -github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= -github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= -github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= -github.com/tidwall/gjson v1.9.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= -github.com/tidwall/gjson v1.13.0 h1:3TFY9yxOQShrvmjdM76K+jc66zJeT6D3/VFFYCGQf7M= -github.com/tidwall/gjson v1.13.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= -github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= -github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= -github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= -github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= -golang.org/x/arch v0.0.0-20201008161808-52c3e6f60cff/go.mod h1:flIaEI6LNU6xOCD5PaJvn9wGP0agmIOqjrtsKGRguv4= golang.org/x/arch v0.0.0-20210923205945-b76863e36670 h1:18EFjUmQOcUvxNYSkA6jO9VAiXCnxFY6NyDX0bHDmkU= golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= @@ -95,16 +63,12 @@ golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.5.0/go.mod h1:DivGGAXEgPSlEBzxGzZI+ZLohi+xUj054jfeKui00ws= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/sync v0.0.0-20210220032951-036812b2e83c h1:5KslGYwFpkhGh+Q16bwMP3cOontH8FOep7tGV86Y7SQ= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4 h1:uVc8UZUe6tr40fFVnUP5Oj+veunVezqYl9z7DYw9xzw= golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20220110181412-a018aaa089fe/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= @@ -124,8 +88,6 @@ golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGm golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= -golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= @@ -135,7 +97,6 @@ gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8 gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= -gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= From 4d00e4f84b4f15f568641cefe80a7294e3d4f961 Mon Sep 17 00:00:00 2001 From: fgy Date: Tue, 9 May 2023 20:36:38 +0800 Subject: [PATCH 13/91] optimize: performance optimize --- go.sum | 13 ++ licenses/LICENSE-validator.txt | 21 ++++ pkg/app/context.go | 28 ++--- pkg/app/server/binding/base_type_decoder.go | 22 ++-- pkg/app/server/binding/binder.go | 8 +- pkg/app/server/binding/binder_test.go | 2 +- .../server/binding/customized_type_decoder.go | 6 +- pkg/app/server/binding/decoder.go | 28 +++-- .../binding/{default_binder.go => default.go} | 77 +++++++++++- pkg/app/server/binding/default_validator.go | 94 -------------- pkg/app/server/binding/getter.go | 119 +++++++++++------- pkg/app/server/binding/map_type_decoder.go | 17 ++- .../server/binding/multipart_file_decoder.go | 10 +- pkg/app/server/binding/slice_type_decoder.go | 22 ++-- 14 files changed, 257 insertions(+), 210 deletions(-) create mode 100644 licenses/LICENSE-validator.txt rename pkg/app/server/binding/{default_binder.go => default.go} (55%) delete mode 100644 pkg/app/server/binding/default_validator.go diff --git a/go.sum b/go.sum index 665310aca..e21bb766c 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,7 @@ github.com/bytedance/gopkg v0.0.0-20220413063733-65bf48ffb3a7 h1:PtwsQyQJGxf8iaPptPNaduEIu9BnrNms+pcRdHAxZaM= github.com/bytedance/gopkg v0.0.0-20220413063733-65bf48ffb3a7/go.mod h1:2ZlV9BaUH4+NXIBF0aMdKKAnHTzqH+iMU4KUjAbL23Q= +github.com/bytedance/mockey v1.2.1 h1:g84ngI88hz1DR4wZTL3yOuqlEcq67MretBfQUdXwrmw= +github.com/bytedance/mockey v1.2.1/go.mod h1:+Jm/fzWZAuhEDrPXVjDf/jLM2BlLXJkwk94zf2JZ3X4= github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= github.com/bytedance/sonic v1.8.1 h1:NqAHCaGaTzro0xMmnTCLUyRlbEP6r8MCA1cJUrH3Pu4= github.com/bytedance/sonic v1.8.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZXU064P/U= @@ -24,6 +26,10 @@ github.com/go-playground/validator/v10 v10.11.2/go.mod h1:NieE624vt4SCTJtD87arVL github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1 h1:EGx4pi6eqNxGaHF6qqu48+N2wcFQ5qg5FXgOdqsJ5d8= +github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= +github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo= +github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= github.com/klauspost/cpuid/v2 v2.0.9 h1:lgaqFMSdTdQYdZ04uHyN2d/eKdOMyi2YLSvlQIBFYa4= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= @@ -39,6 +45,10 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE= +github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d h1:zE9ykElWQ6/NYmHa3jpm/yHnI4xSofP+UP6SpjHcSeM= +github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= +github.com/smartystreets/goconvey v1.6.4 h1:fv0U8FUIMPNf1L9lnHLvLhgicrIVChEkdzIKYqbNC9s= +github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9QV7WQ/tjFTllLA= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= @@ -51,6 +61,7 @@ github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= +golang.org/x/arch v0.0.0-20201008161808-52c3e6f60cff/go.mod h1:flIaEI6LNU6xOCD5PaJvn9wGP0agmIOqjrtsKGRguv4= golang.org/x/arch v0.0.0-20210923205945-b76863e36670 h1:18EFjUmQOcUvxNYSkA6jO9VAiXCnxFY6NyDX0bHDmkU= golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= @@ -58,6 +69,7 @@ golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5y golang.org/x/crypto v0.5.0 h1:U/0M97KRkSFvyD/3FSmdP5W5swImpNgle/EHFhOsQPE= golang.org/x/crypto v0.5.0/go.mod h1:NK/OQwhpMQP3MwtdjgLlYHnH9ebylxKWv3e0fK+mkQU= golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= @@ -85,6 +97,7 @@ golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= golang.org/x/text v0.6.0 h1:3XmdazWV+ubf7QgHSTWeykHOci5oeekaGJBLkrkaw4k= golang.org/x/text v0.6.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= +golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= diff --git a/licenses/LICENSE-validator.txt b/licenses/LICENSE-validator.txt new file mode 100644 index 000000000..ab4304b3c --- /dev/null +++ b/licenses/LICENSE-validator.txt @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) 2015 Dean Karn + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/pkg/app/context.go b/pkg/app/context.go index 92ad94217..c60febe24 100644 --- a/pkg/app/context.go +++ b/pkg/app/context.go @@ -1130,24 +1130,24 @@ func (ctx *RequestContext) Cookie(key string) []byte { return ctx.Request.Header.Cookie(key) } -// SetCookie adds a Set-Cookie header to the Response's headers. +// SetCookie adds a Set-cookie header to the Response's headers. // // Parameter introduce: -// name and value is used to set cookie's name and value, eg. Set-Cookie: name=value -// maxAge is use to set cookie's expiry date, eg. Set-Cookie: name=value; max-age=1 -// path and domain is used to set the scope of a cookie, eg. Set-Cookie: name=value;domain=localhost; path=/; -// secure and httpOnly is used to sent cookies securely; eg. Set-Cookie: name=value;HttpOnly; secure; -// sameSite let servers specify whether/when cookies are sent with cross-site requests; eg. Set-Cookie: name=value;HttpOnly; secure; SameSite=Lax; +// name and value is used to set cookie's name and value, eg. Set-cookie: name=value +// maxAge is use to set cookie's expiry date, eg. Set-cookie: name=value; max-age=1 +// path and domain is used to set the scope of a cookie, eg. Set-cookie: name=value;domain=localhost; path=/; +// secure and httpOnly is used to sent cookies securely; eg. Set-cookie: name=value;HttpOnly; secure; +// sameSite let servers specify whether/when cookies are sent with cross-site requests; eg. Set-cookie: name=value;HttpOnly; secure; SameSite=Lax; // // For example: // 1. ctx.SetCookie("user", "hertz", 1, "/", "localhost",protocol.CookieSameSiteLaxMode, true, true) -// add response header ---> Set-Cookie: user=hertz; max-age=1; domain=localhost; path=/; HttpOnly; secure; SameSite=Lax; +// add response header ---> Set-cookie: user=hertz; max-age=1; domain=localhost; path=/; HttpOnly; secure; SameSite=Lax; // 2. ctx.SetCookie("user", "hertz", 10, "/", "localhost",protocol.CookieSameSiteLaxMode, false, false) -// add response header ---> Set-Cookie: user=hertz; max-age=10; domain=localhost; path=/; SameSite=Lax; +// add response header ---> Set-cookie: user=hertz; max-age=10; domain=localhost; path=/; SameSite=Lax; // 3. ctx.SetCookie("", "hertz", 10, "/", "localhost",protocol.CookieSameSiteLaxMode, false, false) -// add response header ---> Set-Cookie: hertz; max-age=10; domain=localhost; path=/; SameSite=Lax; +// add response header ---> Set-cookie: hertz; max-age=10; domain=localhost; path=/; SameSite=Lax; // 4. ctx.SetCookie("user", "", 10, "/", "localhost",protocol.CookieSameSiteLaxMode, false, false) -// add response header ---> Set-Cookie: user=; max-age=10; domain=localhost; path=/; SameSite=Lax; +// add response header ---> Set-cookie: user=; max-age=10; domain=localhost; path=/; SameSite=Lax; func (ctx *RequestContext) SetCookie(name, value string, maxAge int, path, domain string, sameSite protocol.CookieSameSite, secure, httpOnly bool) { if path == "" { path = "/" @@ -1223,10 +1223,10 @@ func (ctx *RequestContext) PostArgs() *protocol.Args { // For example: // // GET /path?id=1234&name=Manu&value= -// c.Query("id") == "1234" -// c.Query("name") == "Manu" -// c.Query("value") == "" -// c.Query("wtf") == "" +// c.query("id") == "1234" +// c.query("name") == "Manu" +// c.query("value") == "" +// c.query("wtf") == "" func (ctx *RequestContext) Query(key string) string { value, _ := ctx.GetQuery(key) return value diff --git a/pkg/app/server/binding/base_type_decoder.go b/pkg/app/server/binding/base_type_decoder.go index 07c155372..483a0a0cf 100644 --- a/pkg/app/server/binding/base_type_decoder.go +++ b/pkg/app/server/binding/base_type_decoder.go @@ -46,15 +46,14 @@ import ( "github.com/cloudwego/hertz/pkg/app/server/binding/text_decoder" "github.com/cloudwego/hertz/pkg/common/utils" - "github.com/cloudwego/hertz/pkg/protocol" ) type fieldInfo struct { index int parentIndex []int fieldName string - tagInfos []TagInfo // query,param,header,respHeader ... - fieldType reflect.Type + tagInfos []TagInfo // query,param,header,respHeader ... + fieldType reflect.Type // can not be pointer type } type baseTypeFieldTextDecoder struct { @@ -62,7 +61,7 @@ type baseTypeFieldTextDecoder struct { decoder text_decoder.TextDecoder } -func (d *baseTypeFieldTextDecoder) Decode(req *protocol.Request, params PathParams, reqValue reflect.Value) error { +func (d *baseTypeFieldTextDecoder) Decode(req *bindRequest, params PathParam, reqValue reflect.Value) error { var err error var text string var defaultValue string @@ -71,7 +70,7 @@ func (d *baseTypeFieldTextDecoder) Decode(req *protocol.Request, params PathPara continue } if tagInfo.Key == headerTag { - tagInfo.Value = utils.GetNormalizeHeaderKey(tagInfo.Value, req.Header.IsDisableNormalizing()) + tagInfo.Value = utils.GetNormalizeHeaderKey(tagInfo.Value, req.Req.Header.IsDisableNormalizing()) } ret := tagInfo.Getter(req, params, tagInfo.Value) defaultValue = tagInfo.Default @@ -90,6 +89,7 @@ func (d *baseTypeFieldTextDecoder) Decode(req *protocol.Request, params PathPara if len(text) == 0 && len(defaultValue) != 0 { text = defaultValue } + //todo: check a=?b=?c= 这种情况 loosemode if text == "" { return nil } @@ -126,19 +126,19 @@ func getBaseTypeTextDecoder(field reflect.StructField, index int, tagInfos []Tag for idx, tagInfo := range tagInfos { switch tagInfo.Key { case pathTag: - tagInfos[idx].Getter = PathParam + tagInfos[idx].Getter = path case formTag: - tagInfos[idx].Getter = Form + tagInfos[idx].Getter = form case queryTag: - tagInfos[idx].Getter = Query + tagInfos[idx].Getter = query case cookieTag: - tagInfos[idx].Getter = Cookie + tagInfos[idx].Getter = cookie case headerTag: - tagInfos[idx].Getter = Header + tagInfos[idx].Getter = header case jsonTag: // do nothing case rawBodyTag: - tagInfo.Getter = RawBody + tagInfo.Getter = rawBody case fileNameTag: // do nothing default: diff --git a/pkg/app/server/binding/binder.go b/pkg/app/server/binding/binder.go index d6fbda809..835ba029b 100644 --- a/pkg/app/server/binding/binder.go +++ b/pkg/app/server/binding/binder.go @@ -44,14 +44,14 @@ import ( "github.com/cloudwego/hertz/pkg/protocol" ) -// PathParams parameter acquisition interface on the URL path -type PathParams interface { +// PathParam parameter acquisition interface on the URL path +type PathParam interface { Get(name string) (string, bool) } type Binder interface { Name() string - Bind(*protocol.Request, PathParams, interface{}) error + Bind(*protocol.Request, PathParam, interface{}) error } -var DefaultBinder Binder = &Bind{} +var DefaultBinder Binder = &defaultBinder{} diff --git a/pkg/app/server/binding/binder_test.go b/pkg/app/server/binding/binder_test.go index 46a90f6b1..2a5aa9c66 100644 --- a/pkg/app/server/binding/binder_test.go +++ b/pkg/app/server/binding/binder_test.go @@ -492,7 +492,7 @@ type CustomizedDecode struct { A string } -func (c *CustomizedDecode) CustomizedFieldDecode(req *protocol.Request, params PathParams) error { +func (c *CustomizedDecode) CustomizedFieldDecode(req *protocol.Request, params PathParam) error { q1 := req.URI().QueryArgs().Peek("a") if len(q1) == 0 { return fmt.Errorf("can be nil") diff --git a/pkg/app/server/binding/customized_type_decoder.go b/pkg/app/server/binding/customized_type_decoder.go index 87379c644..870252251 100644 --- a/pkg/app/server/binding/customized_type_decoder.go +++ b/pkg/app/server/binding/customized_type_decoder.go @@ -42,20 +42,18 @@ package binding import ( "reflect" - - "github.com/cloudwego/hertz/pkg/protocol" ) type customizedFieldTextDecoder struct { fieldInfo } -func (d *customizedFieldTextDecoder) Decode(req *protocol.Request, params PathParams, reqValue reflect.Value) error { +func (d *customizedFieldTextDecoder) Decode(req *bindRequest, params PathParam, reqValue reflect.Value) error { var err error v := reflect.New(d.fieldType) decoder := v.Interface().(CustomizedFieldDecoder) - if err = decoder.CustomizedFieldDecode(req, params); err != nil { + if err = decoder.CustomizedFieldDecode(req.Req, params); err != nil { return err } diff --git a/pkg/app/server/binding/decoder.go b/pkg/app/server/binding/decoder.go index 0cc3695a3..9b77598c3 100644 --- a/pkg/app/server/binding/decoder.go +++ b/pkg/app/server/binding/decoder.go @@ -43,22 +43,33 @@ package binding import ( "fmt" "mime/multipart" + "net/http" + "net/url" "reflect" "github.com/cloudwego/hertz/pkg/protocol" ) +type bindRequest struct { + Req *protocol.Request + Query url.Values + Form url.Values + MultipartForm url.Values + Header http.Header + Cookie []*http.Cookie +} + type decoder interface { - Decode(req *protocol.Request, params PathParams, reqValue reflect.Value) error + Decode(req *bindRequest, params PathParam, reqValue reflect.Value) error } type CustomizedFieldDecoder interface { - CustomizedFieldDecode(req *protocol.Request, params PathParams) error + CustomizedFieldDecode(req *protocol.Request, params PathParam) error } -type Decoder func(req *protocol.Request, params PathParams, rv reflect.Value) error +type Decoder func(req *protocol.Request, params PathParam, rv reflect.Value) error -var fieldDecoderType = reflect.TypeOf((*CustomizedFieldDecoder)(nil)).Elem() +var customizedFieldDecoderType = reflect.TypeOf((*CustomizedFieldDecoder)(nil)).Elem() func getReqDecoder(rt reflect.Type) (Decoder, error) { var decoders []decoder @@ -84,9 +95,12 @@ func getReqDecoder(rt reflect.Type) (Decoder, error) { } } - return func(req *protocol.Request, params PathParams, rv reflect.Value) error { + return func(req *protocol.Request, params PathParam, rv reflect.Value) error { + bindReq := &bindRequest{ + Req: req, + } for _, decoder := range decoders { - err := decoder.Decode(req, params, rv) + err := decoder.Decode(bindReq, params, rv) if err != nil { return err } @@ -100,7 +114,7 @@ func getFieldDecoder(field reflect.StructField, index int, parentIdx []int) ([]d for field.Type.Kind() == reflect.Ptr { field.Type = field.Type.Elem() } - if reflect.PtrTo(field.Type).Implements(fieldDecoderType) { + if reflect.PtrTo(field.Type).Implements(customizedFieldDecoderType) { return []decoder{&customizedFieldTextDecoder{ fieldInfo: fieldInfo{ index: index, diff --git a/pkg/app/server/binding/default_binder.go b/pkg/app/server/binding/default.go similarity index 55% rename from pkg/app/server/binding/default_binder.go rename to pkg/app/server/binding/default.go index 18b5de06e..0e983191c 100644 --- a/pkg/app/server/binding/default_binder.go +++ b/pkg/app/server/binding/default.go @@ -12,7 +12,7 @@ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. - * MIT License + * The MIT License * * Copyright (c) 2019-present Fenny and Contributors * @@ -34,6 +34,26 @@ * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE * SOFTWARE. * + * Copyright (c) 2014 Manuel Martínez-Almeida + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in + * all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN + * THE SOFTWARE. + * * This file may have been modified by CloudWeGo authors. All CloudWeGo * Modifications are Copyright 2022 CloudWeGo Authors */ @@ -47,18 +67,19 @@ import ( "github.com/cloudwego/hertz/internal/bytesconv" "github.com/cloudwego/hertz/pkg/protocol" + "github.com/go-playground/validator/v10" "google.golang.org/protobuf/proto" ) -type Bind struct { +type defaultBinder struct { decoderCache sync.Map } -func (b *Bind) Name() string { +func (b *defaultBinder) Name() string { return "hertz" } -func (b *Bind) Bind(req *protocol.Request, params PathParams, v interface{}) error { +func (b *defaultBinder) Bind(req *protocol.Request, params PathParam, v interface{}) error { err := b.preBindBody(req, v) if err != nil { return err @@ -92,7 +113,7 @@ var ( ) // best effort binding -func (b *Bind) preBindBody(req *protocol.Request, v interface{}) error { +func (b *defaultBinder) preBindBody(req *protocol.Request, v interface{}) error { if req.Header.ContentLength() <= 0 { return nil } @@ -110,3 +131,49 @@ func (b *Bind) preBindBody(req *protocol.Request, v interface{}) error { return nil } } + +var _ StructValidator = (*defaultValidator)(nil) + +type defaultValidator struct { + once sync.Once + validate *validator.Validate +} + +// ValidateStruct receives any kind of type, but only performed struct or pointer to struct type. +func (v *defaultValidator) ValidateStruct(obj interface{}) error { + if obj == nil { + return nil + } + + value := reflect.ValueOf(obj) + switch value.Kind() { + case reflect.Ptr: + return v.ValidateStruct(value.Elem().Interface()) + case reflect.Struct: + return v.validateStruct(obj) + default: + return nil + } +} + +// validateStruct receives struct type +func (v *defaultValidator) validateStruct(obj interface{}) error { + v.lazyinit() + return v.validate.Struct(obj) +} + +func (v *defaultValidator) lazyinit() { + v.once.Do(func() { + v.validate = validator.New() + v.validate.SetTagName("validate") + }) +} + +// Engine returns the underlying validator engine which powers the default +// Validator instance. This is useful if you want to register custom validations +// or struct level validations. See validator GoDoc for more info - +// https://pkg.go.dev/github.com/go-playground/validator/v10 +func (v *defaultValidator) Engine() interface{} { + v.lazyinit() + return v.validate +} diff --git a/pkg/app/server/binding/default_validator.go b/pkg/app/server/binding/default_validator.go deleted file mode 100644 index 0ff3cd3d6..000000000 --- a/pkg/app/server/binding/default_validator.go +++ /dev/null @@ -1,94 +0,0 @@ -/* - * Copyright 2022 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * The MIT License (MIT) - * - * Copyright (c) 2014 Manuel Martínez-Almeida - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in - * all copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN - * THE SOFTWARE. - * - * This file may have been modified by CloudWeGo authors. All CloudWeGo - * Modifications are Copyright 2022 CloudWeGo Authors - */ - -package binding - -import ( - "reflect" - "sync" - - "github.com/go-playground/validator/v10" -) - -var _ StructValidator = (*defaultValidator)(nil) - -type defaultValidator struct { - once sync.Once - validate *validator.Validate -} - -// ValidateStruct receives any kind of type, but only performed struct or pointer to struct type. -func (v *defaultValidator) ValidateStruct(obj interface{}) error { - if obj == nil { - return nil - } - - value := reflect.ValueOf(obj) - switch value.Kind() { - case reflect.Ptr: - return v.ValidateStruct(value.Elem().Interface()) - case reflect.Struct: - return v.validateStruct(obj) - default: - return nil - } -} - -// validateStruct receives struct type -func (v *defaultValidator) validateStruct(obj interface{}) error { - v.lazyinit() - return v.validate.Struct(obj) -} - -func (v *defaultValidator) lazyinit() { - v.once.Do(func() { - v.validate = validator.New() - v.validate.SetTagName("validate") - }) -} - -// Engine returns the underlying validator engine which powers the default -// Validator instance. This is useful if you want to register custom validations -// or struct level validations. See validator GoDoc for more info - -// https://pkg.go.dev/github.com/go-playground/validator/v10 -func (v *defaultValidator) Engine() interface{} { - v.lazyinit() - return v.validate -} diff --git a/pkg/app/server/binding/getter.go b/pkg/app/server/binding/getter.go index 878cd5e5c..1174e99d0 100644 --- a/pkg/app/server/binding/getter.go +++ b/pkg/app/server/binding/getter.go @@ -41,13 +41,13 @@ package binding import ( - "github.com/cloudwego/hertz/internal/bytesconv" - "github.com/cloudwego/hertz/pkg/protocol" + "net/http" + "net/url" ) -type getter func(req *protocol.Request, params PathParams, key string, defaultValue ...string) (ret []string) +type getter func(req *bindRequest, params PathParam, key string, defaultValue ...string) (ret []string) -func PathParam(req *protocol.Request, params PathParams, key string, defaultValue ...string) (ret []string) { +func path(req *bindRequest, params PathParam, key string, defaultValue ...string) (ret []string) { var value string if params != nil { value, _ = params.Get(key) @@ -64,33 +64,47 @@ func PathParam(req *protocol.Request, params PathParams, key string, defaultValu } // todo: Optimize 'postform' and 'multipart-form' -func Form(req *protocol.Request, params PathParams, key string, defaultValue ...string) (ret []string) { - req.URI().QueryArgs().VisitAll(func(queryKey, value []byte) { - if bytesconv.B2s(queryKey) == key { - ret = append(ret, string(value)) - } - }) +func form(req *bindRequest, params PathParam, key string, defaultValue ...string) (ret []string) { + if req.Query == nil { + req.Query = make(url.Values) + req.Req.URI().QueryArgs().VisitAll(func(queryKey, value []byte) { + keyStr := string(queryKey) + values, _ := req.Query[keyStr] + values = append(values, string(value)) + req.Query[keyStr] = values + }) + } + ret = req.Query[key] if len(ret) > 0 { return } - req.PostArgs().VisitAll(func(formKey, value []byte) { - if bytesconv.B2s(formKey) == key { - ret = append(ret, string(value)) - } - }) + if req.Form == nil { + req.Form = make(url.Values) + req.Req.PostArgs().VisitAll(func(formKey, value []byte) { + keyStr := string(formKey) + values, _ := req.Form[keyStr] + values = append(values, string(value)) + req.Form[keyStr] = values + }) + } + ret = req.Form[key] if len(ret) > 0 { return } - mf, err := req.MultipartForm() - if err == nil && mf.Value != nil { - for k, v := range mf.Value { - if k == key { - ret = append(ret, v[0]) + if req.MultipartForm == nil { + req.MultipartForm = make(url.Values) + mf, err := req.Req.MultipartForm() + if err == nil && mf.Value != nil { + for k, v := range mf.Value { + if len(v) > 0 { + req.MultipartForm[k] = v + } } } } + ret = req.MultipartForm[key] if len(ret) > 0 { return } @@ -102,12 +116,18 @@ func Form(req *protocol.Request, params PathParams, key string, defaultValue ... return } -func Query(req *protocol.Request, params PathParams, key string, defaultValue ...string) (ret []string) { - req.URI().QueryArgs().VisitAll(func(queryKey, value []byte) { - if bytesconv.B2s(queryKey) == key { - ret = append(ret, string(value)) - } - }) +func query(req *bindRequest, params PathParam, key string, defaultValue ...string) (ret []string) { + if req.Query == nil { + req.Query = make(url.Values) + req.Req.URI().QueryArgs().VisitAll(func(queryKey, value []byte) { + keyStr := string(queryKey) + values, _ := req.Query[keyStr] + values = append(values, string(value)) + req.Query[keyStr] = values + }) + } + + ret = req.Query[key] if len(ret) == 0 && len(defaultValue) != 0 { ret = append(ret, defaultValue...) } @@ -115,14 +135,20 @@ func Query(req *protocol.Request, params PathParams, key string, defaultValue .. return } -// todo: cookie -func Cookie(req *protocol.Request, params PathParams, key string, defaultValue ...string) (ret []string) { - req.Header.VisitAllCookie(func(cookieKey, value []byte) { - if bytesconv.B2s(cookieKey) == key { - ret = append(ret, string(value)) +func cookie(req *bindRequest, params PathParam, key string, defaultValue ...string) (ret []string) { + if len(req.Cookie) == 0 { + req.Req.Header.VisitAllCookie(func(cookieKey, value []byte) { + req.Cookie = append(req.Cookie, &http.Cookie{ + Name: string(cookieKey), + Value: string(value), + }) + }) + } + for _, c := range req.Cookie { + if c.Name == key { + ret = append(ret, c.Value) } - }) - + } if len(ret) == 0 && len(defaultValue) != 0 { ret = append(ret, defaultValue...) } @@ -130,13 +156,18 @@ func Cookie(req *protocol.Request, params PathParams, key string, defaultValue . return } -func Header(req *protocol.Request, params PathParams, key string, defaultValue ...string) (ret []string) { - req.Header.VisitAll(func(headerKey, value []byte) { - if bytesconv.B2s(headerKey) == key { - ret = append(ret, string(value)) - } - }) +func header(req *bindRequest, params PathParam, key string, defaultValue ...string) (ret []string) { + if req.Header == nil { + req.Header = make(http.Header) + req.Req.Header.VisitAll(func(headerKey, value []byte) { + keyStr := string(headerKey) + values, _ := req.Header[keyStr] + values = append(values, string(value)) + req.Header[keyStr] = values + }) + } + ret = req.Header[key] if len(ret) == 0 && len(defaultValue) != 0 { ret = append(ret, defaultValue...) } @@ -144,19 +175,19 @@ func Header(req *protocol.Request, params PathParams, key string, defaultValue . return } -func Json(req *protocol.Request, params PathParams, key string, defaultValue ...string) (ret []string) { +func json(req *bindRequest, params PathParam, key string, defaultValue ...string) (ret []string) { // do nothing return } -func RawBody(req *protocol.Request, params PathParams, key string, defaultValue ...string) (ret []string) { - if req.Header.ContentLength() > 0 { - ret = append(ret, string(req.Body())) +func rawBody(req *bindRequest, params PathParam, key string, defaultValue ...string) (ret []string) { + if req.Req.Header.ContentLength() > 0 { + ret = append(ret, string(req.Req.Body())) } return } -func FileName(req *protocol.Request, params PathParams, key string, defaultValue ...string) (ret []string) { +func FileName(req *bindRequest, params PathParam, key string, defaultValue ...string) (ret []string) { // do nothing return } diff --git a/pkg/app/server/binding/map_type_decoder.go b/pkg/app/server/binding/map_type_decoder.go index 7af3bdbac..f7d6e1c83 100644 --- a/pkg/app/server/binding/map_type_decoder.go +++ b/pkg/app/server/binding/map_type_decoder.go @@ -46,14 +46,13 @@ import ( "github.com/cloudwego/hertz/internal/bytesconv" "github.com/cloudwego/hertz/pkg/common/utils" - "github.com/cloudwego/hertz/pkg/protocol" ) type mapTypeFieldTextDecoder struct { fieldInfo } -func (d *mapTypeFieldTextDecoder) Decode(req *protocol.Request, params PathParams, reqValue reflect.Value) error { +func (d *mapTypeFieldTextDecoder) Decode(req *bindRequest, params PathParam, reqValue reflect.Value) error { var text string var defaultValue string for _, tagInfo := range d.tagInfos { @@ -61,7 +60,7 @@ func (d *mapTypeFieldTextDecoder) Decode(req *protocol.Request, params PathParam continue } if tagInfo.Key == headerTag { - tagInfo.Value = utils.GetNormalizeHeaderKey(tagInfo.Value, req.Header.IsDisableNormalizing()) + tagInfo.Value = utils.GetNormalizeHeaderKey(tagInfo.Value, req.Req.Header.IsDisableNormalizing()) } ret := tagInfo.Getter(req, params, tagInfo.Value) defaultValue = tagInfo.Default @@ -107,19 +106,19 @@ func getMapTypeTextDecoder(field reflect.StructField, index int, tagInfos []TagI for idx, tagInfo := range tagInfos { switch tagInfo.Key { case pathTag: - tagInfos[idx].Getter = PathParam + tagInfos[idx].Getter = path case formTag: - tagInfos[idx].Getter = Form + tagInfos[idx].Getter = form case queryTag: - tagInfos[idx].Getter = Query + tagInfos[idx].Getter = query case cookieTag: - tagInfos[idx].Getter = Cookie + tagInfos[idx].Getter = cookie case headerTag: - tagInfos[idx].Getter = Header + tagInfos[idx].Getter = header case jsonTag: // do nothing case rawBodyTag: - tagInfo.Getter = RawBody + tagInfo.Getter = rawBody case fileNameTag: // do nothing default: diff --git a/pkg/app/server/binding/multipart_file_decoder.go b/pkg/app/server/binding/multipart_file_decoder.go index cf4a4c888..6f08c7668 100644 --- a/pkg/app/server/binding/multipart_file_decoder.go +++ b/pkg/app/server/binding/multipart_file_decoder.go @@ -19,8 +19,6 @@ package binding import ( "fmt" "reflect" - - "github.com/cloudwego/hertz/pkg/protocol" ) type fileTypeDecoder struct { @@ -28,7 +26,7 @@ type fileTypeDecoder struct { isRepeated bool } -func (d *fileTypeDecoder) Decode(req *protocol.Request, params PathParams, reqValue reflect.Value) error { +func (d *fileTypeDecoder) Decode(req *bindRequest, params PathParam, reqValue reflect.Value) error { fieldValue := GetFieldValue(reqValue, d.parentIndex) field := fieldValue.Field(d.index) @@ -49,7 +47,7 @@ func (d *fileTypeDecoder) Decode(req *protocol.Request, params PathParams, reqVa if len(fileName) == 0 { fileName = d.fieldName } - file, err := req.FormFile(fileName) + file, err := req.Req.FormFile(fileName) if err != nil { return fmt.Errorf("can not get file '%s', err: %v", fileName, err) } @@ -72,7 +70,7 @@ func (d *fileTypeDecoder) Decode(req *protocol.Request, params PathParams, reqVa return nil } -func (d *fileTypeDecoder) fileSliceDecode(req *protocol.Request, params PathParams, reqValue reflect.Value) error { +func (d *fileTypeDecoder) fileSliceDecode(req *bindRequest, params PathParam, reqValue reflect.Value) error { fieldValue := GetFieldValue(reqValue, d.parentIndex) field := fieldValue.Field(d.index) // 如果没值,需要为其建一个值 @@ -102,7 +100,7 @@ func (d *fileTypeDecoder) fileSliceDecode(req *protocol.Request, params PathPara if len(fileName) == 0 { fileName = d.fieldName } - multipartForm, err := req.MultipartForm() + multipartForm, err := req.Req.MultipartForm() if err != nil { return fmt.Errorf("can not get multipartForm info, err: %v", err) } diff --git a/pkg/app/server/binding/slice_type_decoder.go b/pkg/app/server/binding/slice_type_decoder.go index f167db96f..a84547643 100644 --- a/pkg/app/server/binding/slice_type_decoder.go +++ b/pkg/app/server/binding/slice_type_decoder.go @@ -48,7 +48,6 @@ import ( "github.com/cloudwego/hertz/internal/bytesconv" "github.com/cloudwego/hertz/pkg/app/server/binding/text_decoder" "github.com/cloudwego/hertz/pkg/common/utils" - "github.com/cloudwego/hertz/pkg/protocol" ) type sliceTypeFieldTextDecoder struct { @@ -56,7 +55,7 @@ type sliceTypeFieldTextDecoder struct { isArray bool } -func (d *sliceTypeFieldTextDecoder) Decode(req *protocol.Request, params PathParams, reqValue reflect.Value) error { +func (d *sliceTypeFieldTextDecoder) Decode(req *bindRequest, params PathParam, reqValue reflect.Value) error { var texts []string var defaultValue string for _, tagInfo := range d.tagInfos { @@ -64,10 +63,10 @@ func (d *sliceTypeFieldTextDecoder) Decode(req *protocol.Request, params PathPar continue } if tagInfo.Key == headerTag { - tagInfo.Value = utils.GetNormalizeHeaderKey(tagInfo.Value, req.Header.IsDisableNormalizing()) + tagInfo.Value = utils.GetNormalizeHeaderKey(tagInfo.Value, req.Req.Header.IsDisableNormalizing()) } texts = tagInfo.Getter(req, params, tagInfo.Value) - // todo: array/slice default value + //todo: array/slice default value defaultValue = tagInfo.Default if len(texts) != 0 { break @@ -82,6 +81,7 @@ func (d *sliceTypeFieldTextDecoder) Decode(req *protocol.Request, params PathPar reqValue = GetFieldValue(reqValue, d.parentIndex) field := reqValue.Field(d.index) + // **[]**int if field.Kind() == reflect.Ptr { if field.IsNil() { nonNilVal, ptrDepth := GetNonNilReferenceValue(field) @@ -103,7 +103,7 @@ func (d *sliceTypeFieldTextDecoder) Decode(req *protocol.Request, params PathPar field = reflect.MakeSlice(field.Type(), len(texts), len(texts)) } - // handle multiple pointer + // handle internal multiple pointer, []**int var ptrDepth int t := d.fieldType.Elem() elemKind := t.Kind() @@ -137,19 +137,19 @@ func getSliceFieldDecoder(field reflect.StructField, index int, tagInfos []TagIn for idx, tagInfo := range tagInfos { switch tagInfo.Key { case pathTag: - tagInfos[idx].Getter = PathParam + tagInfos[idx].Getter = path case formTag: - tagInfos[idx].Getter = Form + tagInfos[idx].Getter = form case queryTag: - tagInfos[idx].Getter = Query + tagInfos[idx].Getter = query case cookieTag: - tagInfos[idx].Getter = Cookie + tagInfos[idx].Getter = cookie case headerTag: - tagInfos[idx].Getter = Header + tagInfos[idx].Getter = header case jsonTag: // do nothing case rawBodyTag: - tagInfo.Getter = RawBody + tagInfo.Getter = rawBody case fileNameTag: // do nothing default: From a630b6bc9cabbb965bb49e56de8ab7d256724041 Mon Sep 17 00:00:00 2001 From: fgy Date: Wed, 10 May 2023 14:37:48 +0800 Subject: [PATCH 14/91] refactor: layout --- pkg/app/server/binding/binder.go | 8 +-- pkg/app/server/binding/binder_test.go | 3 +- pkg/app/server/binding/config.go | 15 ++++ .../{ => decoder}/base_type_decoder.go | 18 +++-- .../{ => decoder}/customized_type_decoder.go | 6 +- .../server/binding/{ => decoder}/decoder.go | 23 ++++--- .../server/binding/{ => decoder}/getter.go | 22 +++--- .../binding/{ => decoder}/map_type_decoder.go | 17 ++--- .../{ => decoder}/multipart_file_decoder.go | 16 ++--- .../unit.go => decoder/reflect.go} | 69 ++++++++++++++++--- .../{ => decoder}/slice_type_decoder.go | 21 +++--- pkg/app/server/binding/{ => decoder}/tag.go | 2 +- .../{text_decoder => decoder}/text_decoder.go | 60 +++++++++++++++- pkg/app/server/binding/default.go | 13 ++-- pkg/app/server/binding/json.go | 65 ----------------- pkg/app/server/binding/path/path.go | 6 ++ pkg/app/server/binding/reflect.go | 64 ----------------- pkg/app/server/binding/text_decoder/bool.go | 57 --------------- pkg/app/server/binding/text_decoder/float.go | 59 ---------------- pkg/app/server/binding/text_decoder/int.go | 59 ---------------- pkg/app/server/binding/text_decoder/string.go | 50 -------------- 21 files changed, 215 insertions(+), 438 deletions(-) create mode 100644 pkg/app/server/binding/config.go rename pkg/app/server/binding/{ => decoder}/base_type_decoder.go (92%) rename pkg/app/server/binding/{ => decoder}/customized_type_decoder.go (94%) rename pkg/app/server/binding/{ => decoder}/decoder.go (87%) rename pkg/app/server/binding/{ => decoder}/getter.go (81%) rename pkg/app/server/binding/{ => decoder}/map_type_decoder.go (91%) rename pkg/app/server/binding/{ => decoder}/multipart_file_decoder.go (91%) rename pkg/app/server/binding/{text_decoder/unit.go => decoder/reflect.go} (55%) rename pkg/app/server/binding/{ => decoder}/slice_type_decoder.go (92%) rename pkg/app/server/binding/{ => decoder}/tag.go (99%) rename pkg/app/server/binding/{text_decoder => decoder}/text_decoder.go (74%) delete mode 100644 pkg/app/server/binding/json.go create mode 100644 pkg/app/server/binding/path/path.go delete mode 100644 pkg/app/server/binding/text_decoder/bool.go delete mode 100644 pkg/app/server/binding/text_decoder/float.go delete mode 100644 pkg/app/server/binding/text_decoder/int.go delete mode 100644 pkg/app/server/binding/text_decoder/string.go diff --git a/pkg/app/server/binding/binder.go b/pkg/app/server/binding/binder.go index 835ba029b..09ed5e49f 100644 --- a/pkg/app/server/binding/binder.go +++ b/pkg/app/server/binding/binder.go @@ -41,17 +41,13 @@ package binding import ( + "github.com/cloudwego/hertz/pkg/app/server/binding/path" "github.com/cloudwego/hertz/pkg/protocol" ) -// PathParam parameter acquisition interface on the URL path -type PathParam interface { - Get(name string) (string, bool) -} - type Binder interface { Name() string - Bind(*protocol.Request, PathParam, interface{}) error + Bind(*protocol.Request, path.PathParam, interface{}) error } var DefaultBinder Binder = &defaultBinder{} diff --git a/pkg/app/server/binding/binder_test.go b/pkg/app/server/binding/binder_test.go index 2a5aa9c66..4025fef44 100644 --- a/pkg/app/server/binding/binder_test.go +++ b/pkg/app/server/binding/binder_test.go @@ -42,6 +42,7 @@ package binding import ( "fmt" + "github.com/cloudwego/hertz/pkg/app/server/binding/path" "mime/multipart" "testing" @@ -492,7 +493,7 @@ type CustomizedDecode struct { A string } -func (c *CustomizedDecode) CustomizedFieldDecode(req *protocol.Request, params PathParam) error { +func (c *CustomizedDecode) CustomizedFieldDecode(req *protocol.Request, params path.PathParam) error { q1 := req.URI().QueryArgs().Peek("a") if len(q1) == 0 { return fmt.Errorf("can be nil") diff --git a/pkg/app/server/binding/config.go b/pkg/app/server/binding/config.go new file mode 100644 index 000000000..b02e6e78b --- /dev/null +++ b/pkg/app/server/binding/config.go @@ -0,0 +1,15 @@ +package binding + +import ( + standardJson "encoding/json" + + hjson "github.com/cloudwego/hertz/pkg/common/json" +) + +func ResetJSONUnmarshaler(fn func(data []byte, v interface{}) error) { + hjson.Unmarshal = fn +} + +func ResetStdJSONUnmarshaler() { + ResetJSONUnmarshaler(standardJson.Unmarshal) +} diff --git a/pkg/app/server/binding/base_type_decoder.go b/pkg/app/server/binding/decoder/base_type_decoder.go similarity index 92% rename from pkg/app/server/binding/base_type_decoder.go rename to pkg/app/server/binding/decoder/base_type_decoder.go index 483a0a0cf..b48921930 100644 --- a/pkg/app/server/binding/base_type_decoder.go +++ b/pkg/app/server/binding/decoder/base_type_decoder.go @@ -38,13 +38,13 @@ * Modifications are Copyright 2022 CloudWeGo Authors */ -package binding +package decoder import ( "fmt" "reflect" - "github.com/cloudwego/hertz/pkg/app/server/binding/text_decoder" + path1 "github.com/cloudwego/hertz/pkg/app/server/binding/path" "github.com/cloudwego/hertz/pkg/common/utils" ) @@ -58,10 +58,10 @@ type fieldInfo struct { type baseTypeFieldTextDecoder struct { fieldInfo - decoder text_decoder.TextDecoder + decoder TextDecoder } -func (d *baseTypeFieldTextDecoder) Decode(req *bindRequest, params PathParam, reqValue reflect.Value) error { +func (d *baseTypeFieldTextDecoder) Decode(req *bindRequest, params path1.PathParam, reqValue reflect.Value) error { var err error var text string var defaultValue string @@ -122,7 +122,7 @@ func (d *baseTypeFieldTextDecoder) Decode(req *bindRequest, params PathParam, re return nil } -func getBaseTypeTextDecoder(field reflect.StructField, index int, tagInfos []TagInfo, parentIdx []int) ([]decoder, error) { +func getBaseTypeTextDecoder(field reflect.StructField, index int, tagInfos []TagInfo, parentIdx []int) ([]fieldDecoder, error) { for idx, tagInfo := range tagInfos { switch tagInfo.Key { case pathTag: @@ -150,12 +150,12 @@ func getBaseTypeTextDecoder(field reflect.StructField, index int, tagInfos []Tag fieldType = field.Type.Elem() } - textDecoder, err := text_decoder.SelectTextDecoder(fieldType) + textDecoder, err := SelectTextDecoder(fieldType) if err != nil { return nil, err } - fieldDecoder := &baseTypeFieldTextDecoder{ + return []fieldDecoder{&baseTypeFieldTextDecoder{ fieldInfo: fieldInfo{ index: index, parentIndex: parentIdx, @@ -164,7 +164,5 @@ func getBaseTypeTextDecoder(field reflect.StructField, index int, tagInfos []Tag fieldType: fieldType, }, decoder: textDecoder, - } - - return []decoder{fieldDecoder}, nil + }}, nil } diff --git a/pkg/app/server/binding/customized_type_decoder.go b/pkg/app/server/binding/decoder/customized_type_decoder.go similarity index 94% rename from pkg/app/server/binding/customized_type_decoder.go rename to pkg/app/server/binding/decoder/customized_type_decoder.go index 870252251..d8908f18d 100644 --- a/pkg/app/server/binding/customized_type_decoder.go +++ b/pkg/app/server/binding/decoder/customized_type_decoder.go @@ -38,17 +38,19 @@ * Modifications are Copyright 2022 CloudWeGo Authors */ -package binding +package decoder import ( "reflect" + + path1 "github.com/cloudwego/hertz/pkg/app/server/binding/path" ) type customizedFieldTextDecoder struct { fieldInfo } -func (d *customizedFieldTextDecoder) Decode(req *bindRequest, params PathParam, reqValue reflect.Value) error { +func (d *customizedFieldTextDecoder) Decode(req *bindRequest, params path1.PathParam, reqValue reflect.Value) error { var err error v := reflect.New(d.fieldType) decoder := v.Interface().(CustomizedFieldDecoder) diff --git a/pkg/app/server/binding/decoder.go b/pkg/app/server/binding/decoder/decoder.go similarity index 87% rename from pkg/app/server/binding/decoder.go rename to pkg/app/server/binding/decoder/decoder.go index 9b77598c3..3de28e8ca 100644 --- a/pkg/app/server/binding/decoder.go +++ b/pkg/app/server/binding/decoder/decoder.go @@ -38,7 +38,7 @@ * Modifications are Copyright 2022 CloudWeGo Authors */ -package binding +package decoder import ( "fmt" @@ -47,6 +47,7 @@ import ( "net/url" "reflect" + path1 "github.com/cloudwego/hertz/pkg/app/server/binding/path" "github.com/cloudwego/hertz/pkg/protocol" ) @@ -59,20 +60,20 @@ type bindRequest struct { Cookie []*http.Cookie } -type decoder interface { - Decode(req *bindRequest, params PathParam, reqValue reflect.Value) error +type fieldDecoder interface { + Decode(req *bindRequest, params path1.PathParam, reqValue reflect.Value) error } type CustomizedFieldDecoder interface { - CustomizedFieldDecode(req *protocol.Request, params PathParam) error + CustomizedFieldDecode(req *protocol.Request, params path1.PathParam) error } -type Decoder func(req *protocol.Request, params PathParam, rv reflect.Value) error +type Decoder func(req *protocol.Request, params path1.PathParam, rv reflect.Value) error var customizedFieldDecoderType = reflect.TypeOf((*CustomizedFieldDecoder)(nil)).Elem() -func getReqDecoder(rt reflect.Type) (Decoder, error) { - var decoders []decoder +func GetReqDecoder(rt reflect.Type) (Decoder, error) { + var decoders []fieldDecoder el := rt.Elem() if el.Kind() != reflect.Struct { @@ -95,7 +96,7 @@ func getReqDecoder(rt reflect.Type) (Decoder, error) { } } - return func(req *protocol.Request, params PathParam, rv reflect.Value) error { + return func(req *protocol.Request, params path1.PathParam, rv reflect.Value) error { bindReq := &bindRequest{ Req: req, } @@ -110,12 +111,12 @@ func getReqDecoder(rt reflect.Type) (Decoder, error) { }, nil } -func getFieldDecoder(field reflect.StructField, index int, parentIdx []int) ([]decoder, error) { +func getFieldDecoder(field reflect.StructField, index int, parentIdx []int) ([]fieldDecoder, error) { for field.Type.Kind() == reflect.Ptr { field.Type = field.Type.Elem() } if reflect.PtrTo(field.Type).Implements(customizedFieldDecoderType) { - return []decoder{&customizedFieldTextDecoder{ + return []fieldDecoder{&customizedFieldTextDecoder{ fieldInfo: fieldInfo{ index: index, parentIndex: parentIdx, @@ -139,7 +140,7 @@ func getFieldDecoder(field reflect.StructField, index int, parentIdx []int) ([]d } if field.Type.Kind() == reflect.Struct { - var decoders []decoder + var decoders []fieldDecoder el := field.Type // todo: built-in bindings for some common structs, code need to be optimized switch el { diff --git a/pkg/app/server/binding/getter.go b/pkg/app/server/binding/decoder/getter.go similarity index 81% rename from pkg/app/server/binding/getter.go rename to pkg/app/server/binding/decoder/getter.go index 1174e99d0..a6c293082 100644 --- a/pkg/app/server/binding/getter.go +++ b/pkg/app/server/binding/decoder/getter.go @@ -38,16 +38,18 @@ * Modifications are Copyright 2022 CloudWeGo Authors */ -package binding +package decoder import ( "net/http" "net/url" + + path1 "github.com/cloudwego/hertz/pkg/app/server/binding/path" ) -type getter func(req *bindRequest, params PathParam, key string, defaultValue ...string) (ret []string) +type getter func(req *bindRequest, params path1.PathParam, key string, defaultValue ...string) (ret []string) -func path(req *bindRequest, params PathParam, key string, defaultValue ...string) (ret []string) { +func path(req *bindRequest, params path1.PathParam, key string, defaultValue ...string) (ret []string) { var value string if params != nil { value, _ = params.Get(key) @@ -64,7 +66,7 @@ func path(req *bindRequest, params PathParam, key string, defaultValue ...string } // todo: Optimize 'postform' and 'multipart-form' -func form(req *bindRequest, params PathParam, key string, defaultValue ...string) (ret []string) { +func form(req *bindRequest, params path1.PathParam, key string, defaultValue ...string) (ret []string) { if req.Query == nil { req.Query = make(url.Values) req.Req.URI().QueryArgs().VisitAll(func(queryKey, value []byte) { @@ -116,7 +118,7 @@ func form(req *bindRequest, params PathParam, key string, defaultValue ...string return } -func query(req *bindRequest, params PathParam, key string, defaultValue ...string) (ret []string) { +func query(req *bindRequest, params path1.PathParam, key string, defaultValue ...string) (ret []string) { if req.Query == nil { req.Query = make(url.Values) req.Req.URI().QueryArgs().VisitAll(func(queryKey, value []byte) { @@ -135,7 +137,7 @@ func query(req *bindRequest, params PathParam, key string, defaultValue ...strin return } -func cookie(req *bindRequest, params PathParam, key string, defaultValue ...string) (ret []string) { +func cookie(req *bindRequest, params path1.PathParam, key string, defaultValue ...string) (ret []string) { if len(req.Cookie) == 0 { req.Req.Header.VisitAllCookie(func(cookieKey, value []byte) { req.Cookie = append(req.Cookie, &http.Cookie{ @@ -156,7 +158,7 @@ func cookie(req *bindRequest, params PathParam, key string, defaultValue ...stri return } -func header(req *bindRequest, params PathParam, key string, defaultValue ...string) (ret []string) { +func header(req *bindRequest, params path1.PathParam, key string, defaultValue ...string) (ret []string) { if req.Header == nil { req.Header = make(http.Header) req.Req.Header.VisitAll(func(headerKey, value []byte) { @@ -175,19 +177,19 @@ func header(req *bindRequest, params PathParam, key string, defaultValue ...stri return } -func json(req *bindRequest, params PathParam, key string, defaultValue ...string) (ret []string) { +func json(req *bindRequest, params path1.PathParam, key string, defaultValue ...string) (ret []string) { // do nothing return } -func rawBody(req *bindRequest, params PathParam, key string, defaultValue ...string) (ret []string) { +func rawBody(req *bindRequest, params path1.PathParam, key string, defaultValue ...string) (ret []string) { if req.Req.Header.ContentLength() > 0 { ret = append(ret, string(req.Req.Body())) } return } -func FileName(req *bindRequest, params PathParam, key string, defaultValue ...string) (ret []string) { +func fileName(req *bindRequest, params path1.PathParam, key string, defaultValue ...string) (ret []string) { // do nothing return } diff --git a/pkg/app/server/binding/map_type_decoder.go b/pkg/app/server/binding/decoder/map_type_decoder.go similarity index 91% rename from pkg/app/server/binding/map_type_decoder.go rename to pkg/app/server/binding/decoder/map_type_decoder.go index f7d6e1c83..37529e378 100644 --- a/pkg/app/server/binding/map_type_decoder.go +++ b/pkg/app/server/binding/decoder/map_type_decoder.go @@ -38,13 +38,15 @@ * Modifications are Copyright 2022 CloudWeGo Authors */ -package binding +package decoder import ( "fmt" "reflect" "github.com/cloudwego/hertz/internal/bytesconv" + path1 "github.com/cloudwego/hertz/pkg/app/server/binding/path" + hjson "github.com/cloudwego/hertz/pkg/common/json" "github.com/cloudwego/hertz/pkg/common/utils" ) @@ -52,7 +54,7 @@ type mapTypeFieldTextDecoder struct { fieldInfo } -func (d *mapTypeFieldTextDecoder) Decode(req *bindRequest, params PathParam, reqValue reflect.Value) error { +func (d *mapTypeFieldTextDecoder) Decode(req *bindRequest, params path1.PathParam, reqValue reflect.Value) error { var text string var defaultValue string for _, tagInfo := range d.tagInfos { @@ -94,7 +96,7 @@ func (d *mapTypeFieldTextDecoder) Decode(req *bindRequest, params PathParam, req return nil } - err := jsonUnmarshalFunc(bytesconv.S2b(text), field.Addr().Interface()) + err := hjson.Unmarshal(bytesconv.S2b(text), field.Addr().Interface()) if err != nil { return fmt.Errorf("unable to decode '%s' as %s: %w", text, d.fieldType.Name(), err) } @@ -102,7 +104,7 @@ func (d *mapTypeFieldTextDecoder) Decode(req *bindRequest, params PathParam, req return nil } -func getMapTypeTextDecoder(field reflect.StructField, index int, tagInfos []TagInfo, parentIdx []int) ([]decoder, error) { +func getMapTypeTextDecoder(field reflect.StructField, index int, tagInfos []TagInfo, parentIdx []int) ([]fieldDecoder, error) { for idx, tagInfo := range tagInfos { switch tagInfo.Key { case pathTag: @@ -129,7 +131,8 @@ func getMapTypeTextDecoder(field reflect.StructField, index int, tagInfos []TagI for field.Type.Kind() == reflect.Ptr { fieldType = field.Type.Elem() } - fieldDecoder := &mapTypeFieldTextDecoder{ + + return []fieldDecoder{&mapTypeFieldTextDecoder{ fieldInfo: fieldInfo{ index: index, parentIndex: parentIdx, @@ -137,7 +140,5 @@ func getMapTypeTextDecoder(field reflect.StructField, index int, tagInfos []TagI tagInfos: tagInfos, fieldType: fieldType, }, - } - - return []decoder{fieldDecoder}, nil + }}, nil } diff --git a/pkg/app/server/binding/multipart_file_decoder.go b/pkg/app/server/binding/decoder/multipart_file_decoder.go similarity index 91% rename from pkg/app/server/binding/multipart_file_decoder.go rename to pkg/app/server/binding/decoder/multipart_file_decoder.go index 6f08c7668..1c3a180d6 100644 --- a/pkg/app/server/binding/multipart_file_decoder.go +++ b/pkg/app/server/binding/decoder/multipart_file_decoder.go @@ -14,11 +14,13 @@ * limitations under the License. */ -package binding +package decoder import ( "fmt" "reflect" + + path1 "github.com/cloudwego/hertz/pkg/app/server/binding/path" ) type fileTypeDecoder struct { @@ -26,7 +28,7 @@ type fileTypeDecoder struct { isRepeated bool } -func (d *fileTypeDecoder) Decode(req *bindRequest, params PathParam, reqValue reflect.Value) error { +func (d *fileTypeDecoder) Decode(req *bindRequest, params path1.PathParam, reqValue reflect.Value) error { fieldValue := GetFieldValue(reqValue, d.parentIndex) field := fieldValue.Field(d.index) @@ -70,7 +72,7 @@ func (d *fileTypeDecoder) Decode(req *bindRequest, params PathParam, reqValue re return nil } -func (d *fileTypeDecoder) fileSliceDecode(req *bindRequest, params PathParam, reqValue reflect.Value) error { +func (d *fileTypeDecoder) fileSliceDecode(req *bindRequest, params path1.PathParam, reqValue reflect.Value) error { fieldValue := GetFieldValue(reqValue, d.parentIndex) field := fieldValue.Field(d.index) // 如果没值,需要为其建一个值 @@ -138,7 +140,7 @@ func (d *fileTypeDecoder) fileSliceDecode(req *bindRequest, params PathParam, re return nil } -func getMultipartFileDecoder(field reflect.StructField, index int, tagInfos []TagInfo, parentIdx []int) ([]decoder, error) { +func getMultipartFileDecoder(field reflect.StructField, index int, tagInfos []TagInfo, parentIdx []int) ([]fieldDecoder, error) { fieldType := field.Type for field.Type.Kind() == reflect.Ptr { fieldType = field.Type.Elem() @@ -148,7 +150,7 @@ func getMultipartFileDecoder(field reflect.StructField, index int, tagInfos []Ta isRepeated = true } - fieldDecoder := &fileTypeDecoder{ + return []fieldDecoder{&fileTypeDecoder{ fieldInfo: fieldInfo{ index: index, parentIndex: parentIdx, @@ -157,7 +159,5 @@ func getMultipartFileDecoder(field reflect.StructField, index int, tagInfos []Ta fieldType: fieldType, }, isRepeated: isRepeated, - } - - return []decoder{fieldDecoder}, nil + }}, nil } diff --git a/pkg/app/server/binding/text_decoder/unit.go b/pkg/app/server/binding/decoder/reflect.go similarity index 55% rename from pkg/app/server/binding/text_decoder/unit.go rename to pkg/app/server/binding/decoder/reflect.go index 1c3703b1c..dba448fd6 100644 --- a/pkg/app/server/binding/text_decoder/unit.go +++ b/pkg/app/server/binding/decoder/reflect.go @@ -38,22 +38,71 @@ * Modifications are Copyright 2022 CloudWeGo Authors */ -package text_decoder +package decoder import ( "reflect" - "strconv" ) -type uintDecoder struct { - bitSize int +// ReferenceValue convert T to *T, the ptrDepth is the count of '*'. +func ReferenceValue(v reflect.Value, ptrDepth int) reflect.Value { + switch { + case ptrDepth > 0: + for ; ptrDepth > 0; ptrDepth-- { + vv := reflect.New(v.Type()) + vv.Elem().Set(v) + v = vv + } + case ptrDepth < 0: + for ; ptrDepth < 0 && v.Kind() == reflect.Ptr; ptrDepth++ { + v = v.Elem() + } + } + return v +} + +func GetNonNilReferenceValue(v reflect.Value) (reflect.Value, int) { + var ptrDepth int + t := v.Type() + elemKind := t.Kind() + for elemKind == reflect.Ptr { + t = t.Elem() + elemKind = t.Kind() + ptrDepth++ + } + val := reflect.New(t).Elem() + return val, ptrDepth } -func (d *uintDecoder) UnmarshalString(s string, fieldValue reflect.Value) error { - v, err := strconv.ParseUint(s, 10, d.bitSize) - if err != nil { - return err +func GetFieldValue(reqValue reflect.Value, parentIndex []int) reflect.Value { + for _, idx := range parentIndex { + if reqValue.Kind() == reflect.Ptr && reqValue.IsNil() { + nonNilVal, ptrDepth := GetNonNilReferenceValue(reqValue) + reqValue.Set(ReferenceValue(nonNilVal, ptrDepth)) + } + for reqValue.Kind() == reflect.Ptr { + reqValue = reqValue.Elem() + } + reqValue = reqValue.Field(idx) + } + + // It is possible that the parent struct is also a pointer, + // so need to create a non-nil reflect.Value for it at runtime. + for reqValue.Kind() == reflect.Ptr { + if reqValue.IsNil() { + nonNilVal, ptrDepth := GetNonNilReferenceValue(reqValue) + reqValue.Set(ReferenceValue(nonNilVal, ptrDepth)) + } + reqValue = reqValue.Elem() } - fieldValue.SetUint(v) - return nil + + return reqValue +} + +func getElemType(t reflect.Type) reflect.Type { + for t.Kind() == reflect.Ptr { + t = t.Elem() + } + + return t } diff --git a/pkg/app/server/binding/slice_type_decoder.go b/pkg/app/server/binding/decoder/slice_type_decoder.go similarity index 92% rename from pkg/app/server/binding/slice_type_decoder.go rename to pkg/app/server/binding/decoder/slice_type_decoder.go index a84547643..40aadf3c2 100644 --- a/pkg/app/server/binding/slice_type_decoder.go +++ b/pkg/app/server/binding/decoder/slice_type_decoder.go @@ -38,7 +38,7 @@ * Modifications are Copyright 2022 CloudWeGo Authors */ -package binding +package decoder import ( "fmt" @@ -46,7 +46,8 @@ import ( "reflect" "github.com/cloudwego/hertz/internal/bytesconv" - "github.com/cloudwego/hertz/pkg/app/server/binding/text_decoder" + path1 "github.com/cloudwego/hertz/pkg/app/server/binding/path" + hjson "github.com/cloudwego/hertz/pkg/common/json" "github.com/cloudwego/hertz/pkg/common/utils" ) @@ -55,7 +56,7 @@ type sliceTypeFieldTextDecoder struct { isArray bool } -func (d *sliceTypeFieldTextDecoder) Decode(req *bindRequest, params PathParam, reqValue reflect.Value) error { +func (d *sliceTypeFieldTextDecoder) Decode(req *bindRequest, params path1.PathParam, reqValue reflect.Value) error { var texts []string var defaultValue string for _, tagInfo := range d.tagInfos { @@ -126,7 +127,7 @@ func (d *sliceTypeFieldTextDecoder) Decode(req *bindRequest, params PathParam, r return nil } -func getSliceFieldDecoder(field reflect.StructField, index int, tagInfos []TagInfo, parentIdx []int) ([]decoder, error) { +func getSliceFieldDecoder(field reflect.StructField, index int, tagInfos []TagInfo, parentIdx []int) ([]fieldDecoder, error) { if !(field.Type.Kind() == reflect.Slice || field.Type.Kind() == reflect.Array) { return nil, fmt.Errorf("unexpected type %s, expected slice or array", field.Type.String()) } @@ -165,7 +166,7 @@ func getSliceFieldDecoder(field reflect.StructField, index int, tagInfos []TagIn return getMultipartFileDecoder(field, index, tagInfos, parentIdx) } - fieldDecoder := &sliceTypeFieldTextDecoder{ + return []fieldDecoder{&sliceTypeFieldTextDecoder{ fieldInfo: fieldInfo{ index: index, parentIndex: parentIdx, @@ -174,9 +175,7 @@ func getSliceFieldDecoder(field reflect.StructField, index int, tagInfos []TagIn fieldType: fieldType, }, isArray: isArray, - } - - return []decoder{fieldDecoder}, nil + }}, nil } func stringToValue(elemType reflect.Type, text string) (v reflect.Value, err error) { @@ -185,13 +184,13 @@ func stringToValue(elemType reflect.Type, text string) (v reflect.Value, err err switch elemType.Kind() { case reflect.Struct: - err = jsonUnmarshalFunc(bytesconv.S2b(text), v.Addr().Interface()) + err = hjson.Unmarshal(bytesconv.S2b(text), v.Addr().Interface()) case reflect.Map: - err = jsonUnmarshalFunc(bytesconv.S2b(text), v.Addr().Interface()) + err = hjson.Unmarshal(bytesconv.S2b(text), v.Addr().Interface()) case reflect.Array, reflect.Slice: // do nothing default: - decoder, err := text_decoder.SelectTextDecoder(elemType) + decoder, err := SelectTextDecoder(elemType) if err != nil { return reflect.Value{}, fmt.Errorf("unsupported type %s for slice/array", elemType.String()) } diff --git a/pkg/app/server/binding/tag.go b/pkg/app/server/binding/decoder/tag.go similarity index 99% rename from pkg/app/server/binding/tag.go rename to pkg/app/server/binding/decoder/tag.go index f44778d97..6ad002788 100644 --- a/pkg/app/server/binding/tag.go +++ b/pkg/app/server/binding/decoder/tag.go @@ -14,7 +14,7 @@ * limitations under the License. */ -package binding +package decoder import ( "reflect" diff --git a/pkg/app/server/binding/text_decoder/text_decoder.go b/pkg/app/server/binding/decoder/text_decoder.go similarity index 74% rename from pkg/app/server/binding/text_decoder/text_decoder.go rename to pkg/app/server/binding/decoder/text_decoder.go index 08659aede..425c2ea46 100644 --- a/pkg/app/server/binding/text_decoder/text_decoder.go +++ b/pkg/app/server/binding/decoder/text_decoder.go @@ -38,11 +38,12 @@ * Modifications are Copyright 2022 CloudWeGo Authors */ -package text_decoder +package decoder import ( "fmt" "reflect" + "strconv" ) type TextDecoder interface { @@ -90,3 +91,60 @@ func SelectTextDecoder(rt reflect.Type) (TextDecoder, error) { return nil, fmt.Errorf("unsupported type " + rt.String()) } + +type boolDecoder struct{} + +func (d *boolDecoder) UnmarshalString(s string, fieldValue reflect.Value) error { + v, err := strconv.ParseBool(s) + if err != nil { + return err + } + fieldValue.SetBool(v) + return nil +} + +type floatDecoder struct { + bitSize int +} + +func (d *floatDecoder) UnmarshalString(s string, fieldValue reflect.Value) error { + v, err := strconv.ParseFloat(s, d.bitSize) + if err != nil { + return err + } + fieldValue.SetFloat(v) + return nil +} + +type intDecoder struct { + bitSize int +} + +func (d *intDecoder) UnmarshalString(s string, fieldValue reflect.Value) error { + v, err := strconv.ParseInt(s, 10, d.bitSize) + if err != nil { + return err + } + fieldValue.SetInt(v) + return nil +} + +type stringDecoder struct{} + +func (d *stringDecoder) UnmarshalString(s string, fieldValue reflect.Value) error { + fieldValue.SetString(s) + return nil +} + +type uintDecoder struct { + bitSize int +} + +func (d *uintDecoder) UnmarshalString(s string, fieldValue reflect.Value) error { + v, err := strconv.ParseUint(s, 10, d.bitSize) + if err != nil { + return err + } + fieldValue.SetUint(v) + return nil +} diff --git a/pkg/app/server/binding/default.go b/pkg/app/server/binding/default.go index 0e983191c..980bab634 100644 --- a/pkg/app/server/binding/default.go +++ b/pkg/app/server/binding/default.go @@ -62,10 +62,13 @@ package binding import ( "fmt" + hjson "github.com/cloudwego/hertz/pkg/common/json" "reflect" "sync" "github.com/cloudwego/hertz/internal/bytesconv" + "github.com/cloudwego/hertz/pkg/app/server/binding/decoder" + "github.com/cloudwego/hertz/pkg/app/server/binding/path" "github.com/cloudwego/hertz/pkg/protocol" "github.com/go-playground/validator/v10" "google.golang.org/protobuf/proto" @@ -79,7 +82,7 @@ func (b *defaultBinder) Name() string { return "hertz" } -func (b *defaultBinder) Bind(req *protocol.Request, params PathParam, v interface{}) error { +func (b *defaultBinder) Bind(req *protocol.Request, params path.PathParam, v interface{}) error { err := b.preBindBody(req, v) if err != nil { return err @@ -93,12 +96,12 @@ func (b *defaultBinder) Bind(req *protocol.Request, params PathParam, v interfac } cached, ok := b.decoderCache.Load(typeID) if ok { - // cached decoder, fast path - decoder := cached.(Decoder) + // cached fieldDecoder, fast path + decoder := cached.(decoder.Decoder) return decoder(req, params, rv.Elem()) } - decoder, err := getReqDecoder(rv.Type()) + decoder, err := decoder.GetReqDecoder(rv.Type()) if err != nil { return err } @@ -120,7 +123,7 @@ func (b *defaultBinder) preBindBody(req *protocol.Request, v interface{}) error switch bytesconv.B2s(req.Header.ContentType()) { case jsonContentTypeBytes: // todo: Aligning the gin, add "EnableDecoderUseNumber"/"EnableDecoderDisallowUnknownFields" interface - return jsonUnmarshalFunc(req.Body(), v) + return hjson.Unmarshal(req.Body(), v) case protobufContentType: msg, ok := v.(proto.Message) if !ok { diff --git a/pkg/app/server/binding/json.go b/pkg/app/server/binding/json.go deleted file mode 100644 index 24407fb58..000000000 --- a/pkg/app/server/binding/json.go +++ /dev/null @@ -1,65 +0,0 @@ -/* - * Copyright 2022 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * MIT License - * - * Copyright (c) 2019-present Fenny and Contributors - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - * This file may have been modified by CloudWeGo authors. All CloudWeGo - * Modifications are Copyright 2022 CloudWeGo Authors - */ - -package binding - -import ( - "encoding/json" - - hjson "github.com/cloudwego/hertz/pkg/common/json" -) - -// JSONUnmarshaler is the interface implemented by types -// that can unmarshal a JSON description of themselves. -type JSONUnmarshaler func(data []byte, v interface{}) error - -var jsonUnmarshalFunc JSONUnmarshaler - -func init() { - ResetJSONUnmarshaler(hjson.Unmarshal) -} - -func ResetJSONUnmarshaler(fn JSONUnmarshaler) { - jsonUnmarshalFunc = fn -} - -func ResetStdJSONUnmarshaler() { - ResetJSONUnmarshaler(json.Unmarshal) -} diff --git a/pkg/app/server/binding/path/path.go b/pkg/app/server/binding/path/path.go new file mode 100644 index 000000000..26ddc02d2 --- /dev/null +++ b/pkg/app/server/binding/path/path.go @@ -0,0 +1,6 @@ +package path + +// PathParam parameter acquisition interface on the URL path +type PathParam interface { + Get(name string) (string, bool) +} diff --git a/pkg/app/server/binding/reflect.go b/pkg/app/server/binding/reflect.go index 7b8933442..4d2e7f33d 100644 --- a/pkg/app/server/binding/reflect.go +++ b/pkg/app/server/binding/reflect.go @@ -47,7 +47,6 @@ import ( func valueAndTypeID(v interface{}) (reflect.Value, uintptr) { header := (*emptyInterface)(unsafe.Pointer(&v)) - rv := reflect.ValueOf(v) return rv, header.typeID } @@ -56,66 +55,3 @@ type emptyInterface struct { typeID uintptr dataPtr unsafe.Pointer } - -// ReferenceValue convert T to *T, the ptrDepth is the count of '*'. -func ReferenceValue(v reflect.Value, ptrDepth int) reflect.Value { - switch { - case ptrDepth > 0: - for ; ptrDepth > 0; ptrDepth-- { - vv := reflect.New(v.Type()) - vv.Elem().Set(v) - v = vv - } - case ptrDepth < 0: - for ; ptrDepth < 0 && v.Kind() == reflect.Ptr; ptrDepth++ { - v = v.Elem() - } - } - return v -} - -func GetNonNilReferenceValue(v reflect.Value) (reflect.Value, int) { - var ptrDepth int - t := v.Type() - elemKind := t.Kind() - for elemKind == reflect.Ptr { - t = t.Elem() - elemKind = t.Kind() - ptrDepth++ - } - val := reflect.New(t).Elem() - return val, ptrDepth -} - -func GetFieldValue(reqValue reflect.Value, parentIndex []int) reflect.Value { - for _, idx := range parentIndex { - if reqValue.Kind() == reflect.Ptr && reqValue.IsNil() { - nonNilVal, ptrDepth := GetNonNilReferenceValue(reqValue) - reqValue.Set(ReferenceValue(nonNilVal, ptrDepth)) - } - for reqValue.Kind() == reflect.Ptr { - reqValue = reqValue.Elem() - } - reqValue = reqValue.Field(idx) - } - - // It is possible that the parent struct is also a pointer, - // so need to create a non-nil reflect.Value for it at runtime. - for reqValue.Kind() == reflect.Ptr { - if reqValue.IsNil() { - nonNilVal, ptrDepth := GetNonNilReferenceValue(reqValue) - reqValue.Set(ReferenceValue(nonNilVal, ptrDepth)) - } - reqValue = reqValue.Elem() - } - - return reqValue -} - -func getElemType(t reflect.Type) reflect.Type { - for t.Kind() == reflect.Ptr { - t = t.Elem() - } - - return t -} diff --git a/pkg/app/server/binding/text_decoder/bool.go b/pkg/app/server/binding/text_decoder/bool.go deleted file mode 100644 index 5ae167296..000000000 --- a/pkg/app/server/binding/text_decoder/bool.go +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Copyright 2022 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * MIT License - * - * Copyright (c) 2019-present Fenny and Contributors - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - * This file may have been modified by CloudWeGo authors. All CloudWeGo - * Modifications are Copyright 2022 CloudWeGo Authors - */ - -package text_decoder - -import ( - "reflect" - "strconv" -) - -type boolDecoder struct{} - -func (d *boolDecoder) UnmarshalString(s string, fieldValue reflect.Value) error { - v, err := strconv.ParseBool(s) - if err != nil { - return err - } - fieldValue.SetBool(v) - return nil -} diff --git a/pkg/app/server/binding/text_decoder/float.go b/pkg/app/server/binding/text_decoder/float.go deleted file mode 100644 index f44a1c76d..000000000 --- a/pkg/app/server/binding/text_decoder/float.go +++ /dev/null @@ -1,59 +0,0 @@ -/* - * Copyright 2022 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * MIT License - * - * Copyright (c) 2019-present Fenny and Contributors - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - * This file may have been modified by CloudWeGo authors. All CloudWeGo - * Modifications are Copyright 2022 CloudWeGo Authors - */ - -package text_decoder - -import ( - "reflect" - "strconv" -) - -type floatDecoder struct { - bitSize int -} - -func (d *floatDecoder) UnmarshalString(s string, fieldValue reflect.Value) error { - v, err := strconv.ParseFloat(s, d.bitSize) - if err != nil { - return err - } - fieldValue.SetFloat(v) - return nil -} diff --git a/pkg/app/server/binding/text_decoder/int.go b/pkg/app/server/binding/text_decoder/int.go deleted file mode 100644 index 1594e2016..000000000 --- a/pkg/app/server/binding/text_decoder/int.go +++ /dev/null @@ -1,59 +0,0 @@ -/* - * Copyright 2022 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * MIT License - * - * Copyright (c) 2019-present Fenny and Contributors - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - * This file may have been modified by CloudWeGo authors. All CloudWeGo - * Modifications are Copyright 2022 CloudWeGo Authors - */ - -package text_decoder - -import ( - "reflect" - "strconv" -) - -type intDecoder struct { - bitSize int -} - -func (d *intDecoder) UnmarshalString(s string, fieldValue reflect.Value) error { - v, err := strconv.ParseInt(s, 10, d.bitSize) - if err != nil { - return err - } - fieldValue.SetInt(v) - return nil -} diff --git a/pkg/app/server/binding/text_decoder/string.go b/pkg/app/server/binding/text_decoder/string.go deleted file mode 100644 index 46917469f..000000000 --- a/pkg/app/server/binding/text_decoder/string.go +++ /dev/null @@ -1,50 +0,0 @@ -/* - * Copyright 2022 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * MIT License - * - * Copyright (c) 2019-present Fenny and Contributors - * - * Permission is hereby granted, free of charge, to any person obtaining a copy - * of this software and associated documentation files (the "Software"), to deal - * in the Software without restriction, including without limitation the rights - * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell - * copies of the Software, and to permit persons to whom the Software is - * furnished to do so, subject to the following conditions: - * - * The above copyright notice and this permission notice shall be included in all - * copies or substantial portions of the Software. - * - * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR - * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, - * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE - * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER - * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, - * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE - * SOFTWARE. - * - * This file may have been modified by CloudWeGo authors. All CloudWeGo - * Modifications are Copyright 2022 CloudWeGo Authors - */ - -package text_decoder - -import "reflect" - -type stringDecoder struct{} - -func (d *stringDecoder) UnmarshalString(s string, fieldValue reflect.Value) error { - fieldValue.SetString(s) - return nil -} From 6c9f894d69bdab2a8a7eb46895a9cddeb61b7a81 Mon Sep 17 00:00:00 2001 From: fgy Date: Wed, 10 May 2023 15:03:03 +0800 Subject: [PATCH 15/91] optimzie: remove todo --- pkg/app/server/binding/decoder/base_type_decoder.go | 1 - pkg/app/server/binding/decoder/decoder.go | 2 +- pkg/app/server/binding/decoder/getter.go | 1 - pkg/app/server/binding/decoder/slice_type_decoder.go | 2 -- pkg/app/server/binding/decoder/text_decoder.go | 7 ------- pkg/app/server/binding/default.go | 2 +- 6 files changed, 2 insertions(+), 13 deletions(-) diff --git a/pkg/app/server/binding/decoder/base_type_decoder.go b/pkg/app/server/binding/decoder/base_type_decoder.go index b48921930..3862983d5 100644 --- a/pkg/app/server/binding/decoder/base_type_decoder.go +++ b/pkg/app/server/binding/decoder/base_type_decoder.go @@ -89,7 +89,6 @@ func (d *baseTypeFieldTextDecoder) Decode(req *bindRequest, params path1.PathPar if len(text) == 0 && len(defaultValue) != 0 { text = defaultValue } - //todo: check a=?b=?c= 这种情况 loosemode if text == "" { return nil } diff --git a/pkg/app/server/binding/decoder/decoder.go b/pkg/app/server/binding/decoder/decoder.go index 3de28e8ca..89622ba1b 100644 --- a/pkg/app/server/binding/decoder/decoder.go +++ b/pkg/app/server/binding/decoder/decoder.go @@ -142,7 +142,7 @@ func getFieldDecoder(field reflect.StructField, index int, parentIdx []int) ([]f if field.Type.Kind() == reflect.Struct { var decoders []fieldDecoder el := field.Type - // todo: built-in bindings for some common structs, code need to be optimized + //todo: more built-int common struct binding, ex. time... switch el { case reflect.TypeOf(multipart.FileHeader{}): return getMultipartFileDecoder(field, index, fieldTagInfos, parentIdx) diff --git a/pkg/app/server/binding/decoder/getter.go b/pkg/app/server/binding/decoder/getter.go index a6c293082..3bf160f35 100644 --- a/pkg/app/server/binding/decoder/getter.go +++ b/pkg/app/server/binding/decoder/getter.go @@ -65,7 +65,6 @@ func path(req *bindRequest, params path1.PathParam, key string, defaultValue ... return } -// todo: Optimize 'postform' and 'multipart-form' func form(req *bindRequest, params path1.PathParam, key string, defaultValue ...string) (ret []string) { if req.Query == nil { req.Query = make(url.Values) diff --git a/pkg/app/server/binding/decoder/slice_type_decoder.go b/pkg/app/server/binding/decoder/slice_type_decoder.go index 40aadf3c2..833d0a2e0 100644 --- a/pkg/app/server/binding/decoder/slice_type_decoder.go +++ b/pkg/app/server/binding/decoder/slice_type_decoder.go @@ -180,8 +180,6 @@ func getSliceFieldDecoder(field reflect.StructField, index int, tagInfos []TagIn func stringToValue(elemType reflect.Type, text string) (v reflect.Value, err error) { v = reflect.New(elemType).Elem() - // todo: customized type binding - switch elemType.Kind() { case reflect.Struct: err = hjson.Unmarshal(bytesconv.S2b(text), v.Addr().Interface()) diff --git a/pkg/app/server/binding/decoder/text_decoder.go b/pkg/app/server/binding/decoder/text_decoder.go index 425c2ea46..6c034cdc8 100644 --- a/pkg/app/server/binding/decoder/text_decoder.go +++ b/pkg/app/server/binding/decoder/text_decoder.go @@ -50,14 +50,7 @@ type TextDecoder interface { UnmarshalString(s string, fieldValue reflect.Value) error } -// var textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem() - func SelectTextDecoder(rt reflect.Type) (TextDecoder, error) { - // todo: encoding.TextUnmarshaler - //if reflect.PtrTo(rt).Implements(textUnmarshalerType) { - // return &textUnmarshalEncoder{fieldType: rt}, nil - //} - switch rt.Kind() { case reflect.Bool: return &boolDecoder{}, nil diff --git a/pkg/app/server/binding/default.go b/pkg/app/server/binding/default.go index 980bab634..126b69687 100644 --- a/pkg/app/server/binding/default.go +++ b/pkg/app/server/binding/default.go @@ -122,7 +122,7 @@ func (b *defaultBinder) preBindBody(req *protocol.Request, v interface{}) error } switch bytesconv.B2s(req.Header.ContentType()) { case jsonContentTypeBytes: - // todo: Aligning the gin, add "EnableDecoderUseNumber"/"EnableDecoderDisallowUnknownFields" interface + //todo: aligning the gin, add "EnableDecoderUseNumber"/"EnableDecoderDisallowUnknownFields" interface return hjson.Unmarshal(req.Body(), v) case protobufContentType: msg, ok := v.(proto.Message) From 032b5fdd32e1bce1d608b76c91874b1258d58593 Mon Sep 17 00:00:00 2001 From: fgy Date: Wed, 10 May 2023 16:09:33 +0800 Subject: [PATCH 16/91] feat: unexported field --- pkg/app/server/binding/binder_test.go | 22 ++++++++++++++++++++++ pkg/app/server/binding/decoder/decoder.go | 7 +++++-- pkg/app/server/binding/default.go | 2 +- 3 files changed, 28 insertions(+), 3 deletions(-) diff --git a/pkg/app/server/binding/binder_test.go b/pkg/app/server/binding/binder_test.go index 4025fef44..42575ea7d 100644 --- a/pkg/app/server/binding/binder_test.go +++ b/pkg/app/server/binding/binder_test.go @@ -661,6 +661,28 @@ func TestBind_FileSliceBind(t *testing.T) { } } +func TestBind_AnonymousField(t *testing.T) { + type nest struct { + n1 string `query:"n1"` // bind default value + N2 ***string `query:"n2"` // bind n2 value + string `query:"n3"` // bind default value + } + + var s struct { + s1 int `query:"s1"` // bind default value + int `query:"s2"` // bind default value + nest + } + req := newMockRequest(). + SetRequestURI("http://foobar.com?s1=1&s2=2&n1=1&n2=2&n3=3") + err := DefaultBinder.Bind(req.Req, nil, &s) + if err != nil { + t.Fatal(err) + } + //assert.DeepEqual(t, 1, s.A) + //assert.DeepEqual(t, 0, s.b) +} + func Benchmark_Binding(b *testing.B) { type Req struct { Version string `path:"v"` diff --git a/pkg/app/server/binding/decoder/decoder.go b/pkg/app/server/binding/decoder/decoder.go index 89622ba1b..724a32099 100644 --- a/pkg/app/server/binding/decoder/decoder.go +++ b/pkg/app/server/binding/decoder/decoder.go @@ -81,7 +81,7 @@ func GetReqDecoder(rt reflect.Type) (Decoder, error) { } for i := 0; i < el.NumField(); i++ { - if !el.Field(i).IsExported() { + if el.Field(i).PkgPath != "" && !el.Field(i).Anonymous { // ignore unexported field continue } @@ -115,6 +115,9 @@ func getFieldDecoder(field reflect.StructField, index int, parentIdx []int) ([]f for field.Type.Kind() == reflect.Ptr { field.Type = field.Type.Elem() } + if field.Type.Kind() != reflect.Struct && field.Anonymous { + return nil, nil + } if reflect.PtrTo(field.Type).Implements(customizedFieldDecoderType) { return []fieldDecoder{&customizedFieldTextDecoder{ fieldInfo: fieldInfo{ @@ -149,7 +152,7 @@ func getFieldDecoder(field reflect.StructField, index int, parentIdx []int) ([]f } for i := 0; i < el.NumField(); i++ { - if !el.Field(i).IsExported() { + if el.Field(i).PkgPath != "" && !el.Field(i).Anonymous { // ignore unexported field continue } diff --git a/pkg/app/server/binding/default.go b/pkg/app/server/binding/default.go index 126b69687..f03d1def5 100644 --- a/pkg/app/server/binding/default.go +++ b/pkg/app/server/binding/default.go @@ -88,7 +88,7 @@ func (b *defaultBinder) Bind(req *protocol.Request, params path.PathParam, v int return err } rv, typeID := valueAndTypeID(v) - if rv.Kind() != reflect.Pointer || rv.IsNil() { + if rv.Kind() != reflect.Ptr || rv.IsNil() { return fmt.Errorf("receiver must be a non-nil pointer") } if rv.Elem().Kind() == reflect.Map { From 4710c35cbb1772a57afebf868961fed7786eab46 Mon Sep 17 00:00:00 2001 From: fgy Date: Wed, 10 May 2023 16:12:49 +0800 Subject: [PATCH 17/91] feat: add license --- pkg/app/server/binding/config.go | 16 ++++++++++++++++ pkg/app/server/binding/path/path.go | 16 ++++++++++++++++ 2 files changed, 32 insertions(+) diff --git a/pkg/app/server/binding/config.go b/pkg/app/server/binding/config.go index b02e6e78b..3991d437e 100644 --- a/pkg/app/server/binding/config.go +++ b/pkg/app/server/binding/config.go @@ -1,3 +1,19 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package binding import ( diff --git a/pkg/app/server/binding/path/path.go b/pkg/app/server/binding/path/path.go index 26ddc02d2..b40432e04 100644 --- a/pkg/app/server/binding/path/path.go +++ b/pkg/app/server/binding/path/path.go @@ -1,3 +1,19 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package path // PathParam parameter acquisition interface on the URL path From 40d41b2f23a41d618f1b69e7dc5c8e7a8bbf8c0d Mon Sep 17 00:00:00 2001 From: fgy Date: Wed, 10 May 2023 16:16:37 +0800 Subject: [PATCH 18/91] ci: test --- pkg/app/server/binding/binder_test.go | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/pkg/app/server/binding/binder_test.go b/pkg/app/server/binding/binder_test.go index 42575ea7d..6f715adf0 100644 --- a/pkg/app/server/binding/binder_test.go +++ b/pkg/app/server/binding/binder_test.go @@ -679,8 +679,11 @@ func TestBind_AnonymousField(t *testing.T) { if err != nil { t.Fatal(err) } - //assert.DeepEqual(t, 1, s.A) - //assert.DeepEqual(t, 0, s.b) + assert.DeepEqual(t, 0, s.s1) + assert.DeepEqual(t, 0, s.int) + assert.DeepEqual(t, "", s.nest.n1) + assert.DeepEqual(t, "2", ***s.nest.N2) + assert.DeepEqual(t, "", s.nest.string) } func Benchmark_Binding(b *testing.B) { From 8c85456ca7c4cb71aee736349b34f6db060d9d92 Mon Sep 17 00:00:00 2001 From: fgy Date: Wed, 10 May 2023 16:20:16 +0800 Subject: [PATCH 19/91] ci: go module --- go.mod | 7 ++++--- go.sum | 48 +++++++++++++++++------------------------------- 2 files changed, 21 insertions(+), 34 deletions(-) diff --git a/go.mod b/go.mod index 8f0892db9..44b68b92c 100644 --- a/go.mod +++ b/go.mod @@ -8,8 +8,9 @@ require ( github.com/bytedance/sonic v1.8.1 github.com/cloudwego/netpoll v0.4.2-0.20230807055039-52fd5fb7b00f github.com/fsnotify/fsnotify v1.5.4 - github.com/go-playground/validator/v10 v10.11.2 - golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4 - golang.org/x/sys v0.4.0 + github.com/go-playground/assert/v2 v2.2.0 // indirect + github.com/go-playground/validator/v10 v10.11.1 + golang.org/x/sync v0.0.0-20210220032951-036812b2e83c + golang.org/x/sys v0.0.0-20220412211240-33da011f77ad google.golang.org/protobuf v1.27.1 ) diff --git a/go.sum b/go.sum index e21bb766c..f4f84c45a 100644 --- a/go.sum +++ b/go.sum @@ -15,14 +15,15 @@ github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/fsnotify/fsnotify v1.5.4 h1:jRbGcIw6P2Meqdwuo0H1p6JVLbL5DHKAKlYndzMwVZI= github.com/fsnotify/fsnotify v1.5.4/go.mod h1:OVB6XrOHzAwXMpEM7uPOzcehqUV2UqJxmVXmkdnm1bU= +github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= -github.com/go-playground/locales v0.14.1 h1:EWaQ/wswjilfKLTECiXz7Rh+3BjFhfDFKv/oXslEjJA= -github.com/go-playground/locales v0.14.1/go.mod h1:hxrqLVvrK65+Rwrd5Fc6F2O76J/NuW9t0sjnWqG1slY= -github.com/go-playground/universal-translator v0.18.1 h1:Bcnm0ZwsGyWbCzImXv+pAJnYK9S473LQFuzCbDbfSFY= -github.com/go-playground/universal-translator v0.18.1/go.mod h1:xekY+UJKNuX9WP91TpwSH2VMlDf28Uj24BCp08ZFTUY= -github.com/go-playground/validator/v10 v10.11.2 h1:q3SHpufmypg+erIExEKUmsgmhDTyhcJ38oeKGACXohU= -github.com/go-playground/validator/v10 v10.11.2/go.mod h1:NieE624vt4SCTJtD87arVLvdmjPAeV8BQlHtMnw9D7s= +github.com/go-playground/locales v0.14.0 h1:u50s323jtVGugKlcYeyzC0etD1HifMjqmJqb8WugfUU= +github.com/go-playground/locales v0.14.0/go.mod h1:sawfccIbzZTqEDETgFXqTho0QybSa7l++s0DH+LDiLs= +github.com/go-playground/universal-translator v0.18.0 h1:82dyy6p4OuJq4/CByFNOn/jYrnRPArHwAcmLoJZxyho= +github.com/go-playground/universal-translator v0.18.0/go.mod h1:UvRDBj+xPUEGrFYl+lu/H90nyDXpg0fqeB/AQUGNTVA= +github.com/go-playground/validator/v10 v10.11.1 h1:prmOlTVv+YjZjmRmNSF3VmspqJIxJWXmqUsHwfTRRkQ= +github.com/go-playground/validator/v10 v10.11.1/go.mod h1:i+3WkQ1FvaUjjxh1kSvIA4dMGDBiPU55YFDl0WbKdWU= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= @@ -60,47 +61,31 @@ github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKs github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= -github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= golang.org/x/arch v0.0.0-20201008161808-52c3e6f60cff/go.mod h1:flIaEI6LNU6xOCD5PaJvn9wGP0agmIOqjrtsKGRguv4= golang.org/x/arch v0.0.0-20210923205945-b76863e36670 h1:18EFjUmQOcUvxNYSkA6jO9VAiXCnxFY6NyDX0bHDmkU= golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= -golang.org/x/crypto v0.5.0 h1:U/0M97KRkSFvyD/3FSmdP5W5swImpNgle/EHFhOsQPE= -golang.org/x/crypto v0.5.0/go.mod h1:NK/OQwhpMQP3MwtdjgLlYHnH9ebylxKWv3e0fK+mkQU= -golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= +golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3 h1:0es+/5331RGQPcXlMfP+WrnIIS6dNnNRe0WB02W0F4M= +golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= -golang.org/x/net v0.0.0-20210226172049-e18ecbb05110/go.mod h1:m0MpNAwzfU5UDzcl9v0D8zg8gWTRqZa9RBIspLL5mdg= -golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= -golang.org/x/net v0.5.0/go.mod h1:DivGGAXEgPSlEBzxGzZI+ZLohi+xUj054jfeKui00ws= -golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c h1:5KslGYwFpkhGh+Q16bwMP3cOontH8FOep7tGV86Y7SQ= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= -golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4 h1:uVc8UZUe6tr40fFVnUP5Oj+veunVezqYl9z7DYw9xzw= -golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220110181412-a018aaa089fe/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220412211240-33da011f77ad h1:ntjMns5wyP/fN65tdBD4g8J5w8n015+iIIs9rtjXkY0= golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220520151302-bc2c85ada10a/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20220722155257-8c9f86f7a55f/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.4.0 h1:Zr2JFtRQNX3BCZ8YtxRE9hNJYC8J6I1MVbMg6owUp18= -golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= -golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= -golang.org/x/term v0.4.0/go.mod h1:9P2UbLfCdcvo3p/nzKvsmas4TnlujnuoV9hGgYzW1lQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/text v0.3.8/go.mod h1:E6s5w1FMmriuDzIBO73fBruAKo1PCIq6d2Q6DHfQ8WQ= -golang.org/x/text v0.6.0 h1:3XmdazWV+ubf7QgHSTWeykHOci5oeekaGJBLkrkaw4k= -golang.org/x/text v0.6.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= -golang.org/x/tools v0.0.0-20191119224855-298f0cb1881e/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= -golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= -golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= @@ -111,6 +96,7 @@ gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8 gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= From b805df5c6b9f5d457340b1241601e65f3d61295e Mon Sep 17 00:00:00 2001 From: fgy Date: Wed, 10 May 2023 16:49:19 +0800 Subject: [PATCH 20/91] ci: getter --- pkg/app/server/binding/decoder/getter.go | 18 ++++-------------- 1 file changed, 4 insertions(+), 14 deletions(-) diff --git a/pkg/app/server/binding/decoder/getter.go b/pkg/app/server/binding/decoder/getter.go index 3bf160f35..58ca02555 100644 --- a/pkg/app/server/binding/decoder/getter.go +++ b/pkg/app/server/binding/decoder/getter.go @@ -70,7 +70,7 @@ func form(req *bindRequest, params path1.PathParam, key string, defaultValue ... req.Query = make(url.Values) req.Req.URI().QueryArgs().VisitAll(func(queryKey, value []byte) { keyStr := string(queryKey) - values, _ := req.Query[keyStr] + var values, _ = req.Query[keyStr] values = append(values, string(value)) req.Query[keyStr] = values }) @@ -84,7 +84,7 @@ func form(req *bindRequest, params path1.PathParam, key string, defaultValue ... req.Form = make(url.Values) req.Req.PostArgs().VisitAll(func(formKey, value []byte) { keyStr := string(formKey) - values, _ := req.Form[keyStr] + var values, _ = req.Form[keyStr] values = append(values, string(value)) req.Form[keyStr] = values }) @@ -122,7 +122,7 @@ func query(req *bindRequest, params path1.PathParam, key string, defaultValue .. req.Query = make(url.Values) req.Req.URI().QueryArgs().VisitAll(func(queryKey, value []byte) { keyStr := string(queryKey) - values, _ := req.Query[keyStr] + var values, _ = req.Query[keyStr] values = append(values, string(value)) req.Query[keyStr] = values }) @@ -162,7 +162,7 @@ func header(req *bindRequest, params path1.PathParam, key string, defaultValue . req.Header = make(http.Header) req.Req.Header.VisitAll(func(headerKey, value []byte) { keyStr := string(headerKey) - values, _ := req.Header[keyStr] + var values, _ = req.Header[keyStr] values = append(values, string(value)) req.Header[keyStr] = values }) @@ -176,19 +176,9 @@ func header(req *bindRequest, params path1.PathParam, key string, defaultValue . return } -func json(req *bindRequest, params path1.PathParam, key string, defaultValue ...string) (ret []string) { - // do nothing - return -} - func rawBody(req *bindRequest, params path1.PathParam, key string, defaultValue ...string) (ret []string) { if req.Req.Header.ContentLength() > 0 { ret = append(ret, string(req.Req.Body())) } return } - -func fileName(req *bindRequest, params path1.PathParam, key string, defaultValue ...string) (ret []string) { - // do nothing - return -} From 15f58259f911a0df707d63cf1ac57c8d9f5b816a Mon Sep 17 00:00:00 2001 From: fgy Date: Wed, 10 May 2023 16:57:32 +0800 Subject: [PATCH 21/91] ci: gofump --- pkg/app/server/binding/binder_test.go | 2 +- pkg/app/server/binding/decoder/decoder.go | 2 +- pkg/app/server/binding/decoder/getter.go | 8 ++++---- pkg/app/server/binding/decoder/slice_type_decoder.go | 2 +- pkg/app/server/binding/default.go | 4 ++-- 5 files changed, 9 insertions(+), 9 deletions(-) diff --git a/pkg/app/server/binding/binder_test.go b/pkg/app/server/binding/binder_test.go index 6f715adf0..05dd82153 100644 --- a/pkg/app/server/binding/binder_test.go +++ b/pkg/app/server/binding/binder_test.go @@ -42,10 +42,10 @@ package binding import ( "fmt" - "github.com/cloudwego/hertz/pkg/app/server/binding/path" "mime/multipart" "testing" + "github.com/cloudwego/hertz/pkg/app/server/binding/path" "github.com/cloudwego/hertz/pkg/common/test/assert" "github.com/cloudwego/hertz/pkg/protocol" req2 "github.com/cloudwego/hertz/pkg/protocol/http1/req" diff --git a/pkg/app/server/binding/decoder/decoder.go b/pkg/app/server/binding/decoder/decoder.go index 724a32099..2cb4a550b 100644 --- a/pkg/app/server/binding/decoder/decoder.go +++ b/pkg/app/server/binding/decoder/decoder.go @@ -145,7 +145,7 @@ func getFieldDecoder(field reflect.StructField, index int, parentIdx []int) ([]f if field.Type.Kind() == reflect.Struct { var decoders []fieldDecoder el := field.Type - //todo: more built-int common struct binding, ex. time... + // todo: more built-int common struct binding, ex. time... switch el { case reflect.TypeOf(multipart.FileHeader{}): return getMultipartFileDecoder(field, index, fieldTagInfos, parentIdx) diff --git a/pkg/app/server/binding/decoder/getter.go b/pkg/app/server/binding/decoder/getter.go index 58ca02555..01b8ebb83 100644 --- a/pkg/app/server/binding/decoder/getter.go +++ b/pkg/app/server/binding/decoder/getter.go @@ -70,7 +70,7 @@ func form(req *bindRequest, params path1.PathParam, key string, defaultValue ... req.Query = make(url.Values) req.Req.URI().QueryArgs().VisitAll(func(queryKey, value []byte) { keyStr := string(queryKey) - var values, _ = req.Query[keyStr] + values := req.Query[keyStr] values = append(values, string(value)) req.Query[keyStr] = values }) @@ -84,7 +84,7 @@ func form(req *bindRequest, params path1.PathParam, key string, defaultValue ... req.Form = make(url.Values) req.Req.PostArgs().VisitAll(func(formKey, value []byte) { keyStr := string(formKey) - var values, _ = req.Form[keyStr] + values := req.Form[keyStr] values = append(values, string(value)) req.Form[keyStr] = values }) @@ -122,7 +122,7 @@ func query(req *bindRequest, params path1.PathParam, key string, defaultValue .. req.Query = make(url.Values) req.Req.URI().QueryArgs().VisitAll(func(queryKey, value []byte) { keyStr := string(queryKey) - var values, _ = req.Query[keyStr] + values := req.Query[keyStr] values = append(values, string(value)) req.Query[keyStr] = values }) @@ -162,7 +162,7 @@ func header(req *bindRequest, params path1.PathParam, key string, defaultValue . req.Header = make(http.Header) req.Req.Header.VisitAll(func(headerKey, value []byte) { keyStr := string(headerKey) - var values, _ = req.Header[keyStr] + values := req.Header[keyStr] values = append(values, string(value)) req.Header[keyStr] = values }) diff --git a/pkg/app/server/binding/decoder/slice_type_decoder.go b/pkg/app/server/binding/decoder/slice_type_decoder.go index 833d0a2e0..93fd4b391 100644 --- a/pkg/app/server/binding/decoder/slice_type_decoder.go +++ b/pkg/app/server/binding/decoder/slice_type_decoder.go @@ -67,7 +67,7 @@ func (d *sliceTypeFieldTextDecoder) Decode(req *bindRequest, params path1.PathPa tagInfo.Value = utils.GetNormalizeHeaderKey(tagInfo.Value, req.Req.Header.IsDisableNormalizing()) } texts = tagInfo.Getter(req, params, tagInfo.Value) - //todo: array/slice default value + // todo: array/slice default value defaultValue = tagInfo.Default if len(texts) != 0 { break diff --git a/pkg/app/server/binding/default.go b/pkg/app/server/binding/default.go index f03d1def5..6f246bd53 100644 --- a/pkg/app/server/binding/default.go +++ b/pkg/app/server/binding/default.go @@ -62,13 +62,13 @@ package binding import ( "fmt" - hjson "github.com/cloudwego/hertz/pkg/common/json" "reflect" "sync" "github.com/cloudwego/hertz/internal/bytesconv" "github.com/cloudwego/hertz/pkg/app/server/binding/decoder" "github.com/cloudwego/hertz/pkg/app/server/binding/path" + hjson "github.com/cloudwego/hertz/pkg/common/json" "github.com/cloudwego/hertz/pkg/protocol" "github.com/go-playground/validator/v10" "google.golang.org/protobuf/proto" @@ -122,7 +122,7 @@ func (b *defaultBinder) preBindBody(req *protocol.Request, v interface{}) error } switch bytesconv.B2s(req.Header.ContentType()) { case jsonContentTypeBytes: - //todo: aligning the gin, add "EnableDecoderUseNumber"/"EnableDecoderDisallowUnknownFields" interface + // todo: aligning the gin, add "EnableDecoderUseNumber"/"EnableDecoderDisallowUnknownFields" interface return hjson.Unmarshal(req.Body(), v) case protobufContentType: msg, ok := v.(proto.Message) From 9ed9b832888d119f49285658abf9aa87e928e771 Mon Sep 17 00:00:00 2001 From: fgy Date: Wed, 10 May 2023 17:47:34 +0800 Subject: [PATCH 22/91] feat: ignore field --- pkg/app/server/binding/binder_test.go | 31 +++++++++++++++++++ .../binding/decoder/base_type_decoder.go | 2 +- .../binding/decoder/map_type_decoder.go | 2 +- .../binding/decoder/slice_type_decoder.go | 2 +- pkg/app/server/binding/decoder/tag.go | 7 ++++- 5 files changed, 40 insertions(+), 4 deletions(-) diff --git a/pkg/app/server/binding/binder_test.go b/pkg/app/server/binding/binder_test.go index 05dd82153..8abfd2d37 100644 --- a/pkg/app/server/binding/binder_test.go +++ b/pkg/app/server/binding/binder_test.go @@ -686,6 +686,37 @@ func TestBind_AnonymousField(t *testing.T) { assert.DeepEqual(t, "", s.nest.string) } +func TestBind_IgnoreField(t *testing.T) { + type Req struct { + Version int `path:"-"` + ID int `query:"-"` + Header string `header:"-"` + Form string `form:"-"` + } + + req := newMockRequest(). + SetRequestURI("http://foobar.com?id=12"). + SetHeaders("H", "header"). + SetPostArg("f", "form"). + SetUrlEncodeContentType() + var params param.Params + params = append(params, param.Param{ + Key: "v", + Value: "1", + }) + + var result Req + + err := DefaultBinder.Bind(req.Req, params, &result) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, 0, result.Version) + assert.DeepEqual(t, 0, result.ID) + assert.DeepEqual(t, "", result.Header) + assert.DeepEqual(t, "", result.Form) +} + func Benchmark_Binding(b *testing.B) { type Req struct { Version string `path:"v"` diff --git a/pkg/app/server/binding/decoder/base_type_decoder.go b/pkg/app/server/binding/decoder/base_type_decoder.go index 3862983d5..c1d22d239 100644 --- a/pkg/app/server/binding/decoder/base_type_decoder.go +++ b/pkg/app/server/binding/decoder/base_type_decoder.go @@ -66,7 +66,7 @@ func (d *baseTypeFieldTextDecoder) Decode(req *bindRequest, params path1.PathPar var text string var defaultValue string for _, tagInfo := range d.tagInfos { - if tagInfo.Key == jsonTag || tagInfo.Key == fileNameTag { + if tagInfo.Skip || tagInfo.Key == jsonTag || tagInfo.Key == fileNameTag { continue } if tagInfo.Key == headerTag { diff --git a/pkg/app/server/binding/decoder/map_type_decoder.go b/pkg/app/server/binding/decoder/map_type_decoder.go index 37529e378..37bb2d8ed 100644 --- a/pkg/app/server/binding/decoder/map_type_decoder.go +++ b/pkg/app/server/binding/decoder/map_type_decoder.go @@ -58,7 +58,7 @@ func (d *mapTypeFieldTextDecoder) Decode(req *bindRequest, params path1.PathPara var text string var defaultValue string for _, tagInfo := range d.tagInfos { - if tagInfo.Key == jsonTag || tagInfo.Key == fileNameTag { + if tagInfo.Skip || tagInfo.Key == jsonTag || tagInfo.Key == fileNameTag { continue } if tagInfo.Key == headerTag { diff --git a/pkg/app/server/binding/decoder/slice_type_decoder.go b/pkg/app/server/binding/decoder/slice_type_decoder.go index 93fd4b391..65f1fbe69 100644 --- a/pkg/app/server/binding/decoder/slice_type_decoder.go +++ b/pkg/app/server/binding/decoder/slice_type_decoder.go @@ -60,7 +60,7 @@ func (d *sliceTypeFieldTextDecoder) Decode(req *bindRequest, params path1.PathPa var texts []string var defaultValue string for _, tagInfo := range d.tagInfos { - if tagInfo.Key == jsonTag || tagInfo.Key == fileNameTag { + if tagInfo.Skip || tagInfo.Key == jsonTag || tagInfo.Key == fileNameTag { continue } if tagInfo.Key == headerTag { diff --git a/pkg/app/server/binding/decoder/tag.go b/pkg/app/server/binding/decoder/tag.go index 6ad002788..db4f1d12a 100644 --- a/pkg/app/server/binding/decoder/tag.go +++ b/pkg/app/server/binding/decoder/tag.go @@ -44,6 +44,7 @@ type TagInfo struct { Key string Value string Required bool + Skip bool Default string Options []string Getter getter @@ -75,6 +76,10 @@ func lookupFieldTags(field reflect.StructField) []TagInfo { for _, tag := range ret { tagContent := field.Tag.Get(tag) tagValue, opts := head(tagContent, ",") + skip := false + if tagValue == "-" { + skip = true + } var options []string var opt string var required bool @@ -85,7 +90,7 @@ func lookupFieldTags(field reflect.StructField) []TagInfo { required = true } } - tagInfos = append(tagInfos, TagInfo{Key: tag, Value: tagValue, Options: options, Required: required, Default: defaultVal}) + tagInfos = append(tagInfos, TagInfo{Key: tag, Value: tagValue, Options: options, Required: required, Default: defaultVal, Skip: skip}) } return tagInfos From 36ebf47b0c4a831ac2c2b082c011800b9051a319 Mon Sep 17 00:00:00 2001 From: fgy Date: Thu, 11 May 2023 11:06:15 +0800 Subject: [PATCH 23/91] fix: some diff from go-tagexpr --- pkg/app/server/binding/decoder/base_type_decoder.go | 2 +- pkg/app/server/binding/decoder/map_type_decoder.go | 2 +- pkg/app/server/binding/decoder/slice_type_decoder.go | 7 ++++++- pkg/app/server/binding/decoder/text_decoder.go | 12 ++++++++++++ 4 files changed, 20 insertions(+), 3 deletions(-) diff --git a/pkg/app/server/binding/decoder/base_type_decoder.go b/pkg/app/server/binding/decoder/base_type_decoder.go index c1d22d239..ed0f1541b 100644 --- a/pkg/app/server/binding/decoder/base_type_decoder.go +++ b/pkg/app/server/binding/decoder/base_type_decoder.go @@ -137,7 +137,7 @@ func getBaseTypeTextDecoder(field reflect.StructField, index int, tagInfos []Tag case jsonTag: // do nothing case rawBodyTag: - tagInfo.Getter = rawBody + tagInfos[idx].Getter = rawBody case fileNameTag: // do nothing default: diff --git a/pkg/app/server/binding/decoder/map_type_decoder.go b/pkg/app/server/binding/decoder/map_type_decoder.go index 37bb2d8ed..d9090c931 100644 --- a/pkg/app/server/binding/decoder/map_type_decoder.go +++ b/pkg/app/server/binding/decoder/map_type_decoder.go @@ -120,7 +120,7 @@ func getMapTypeTextDecoder(field reflect.StructField, index int, tagInfos []TagI case jsonTag: // do nothing case rawBodyTag: - tagInfo.Getter = rawBody + tagInfos[idx].Getter = rawBody case fileNameTag: // do nothing default: diff --git a/pkg/app/server/binding/decoder/slice_type_decoder.go b/pkg/app/server/binding/decoder/slice_type_decoder.go index 65f1fbe69..2164abbae 100644 --- a/pkg/app/server/binding/decoder/slice_type_decoder.go +++ b/pkg/app/server/binding/decoder/slice_type_decoder.go @@ -100,6 +100,11 @@ func (d *sliceTypeFieldTextDecoder) Decode(req *bindRequest, params path1.PathPa return fmt.Errorf("%q is not valid value for %s", texts, field.Type().String()) } } else { + if field.Type().Elem().Kind() == reflect.Uint8 { + reqValue.Field(d.index).Set(reflect.ValueOf([]byte(texts[0]))) + return nil + } + // slice need creating enough capacity field = reflect.MakeSlice(field.Type(), len(texts), len(texts)) } @@ -150,7 +155,7 @@ func getSliceFieldDecoder(field reflect.StructField, index int, tagInfos []TagIn case jsonTag: // do nothing case rawBodyTag: - tagInfo.Getter = rawBody + tagInfos[idx].Getter = rawBody case fileNameTag: // do nothing default: diff --git a/pkg/app/server/binding/decoder/text_decoder.go b/pkg/app/server/binding/decoder/text_decoder.go index 6c034cdc8..157023f3c 100644 --- a/pkg/app/server/binding/decoder/text_decoder.go +++ b/pkg/app/server/binding/decoder/text_decoder.go @@ -88,6 +88,9 @@ func SelectTextDecoder(rt reflect.Type) (TextDecoder, error) { type boolDecoder struct{} func (d *boolDecoder) UnmarshalString(s string, fieldValue reflect.Value) error { + if s == "" { + s = "false" + } v, err := strconv.ParseBool(s) if err != nil { return err @@ -101,6 +104,9 @@ type floatDecoder struct { } func (d *floatDecoder) UnmarshalString(s string, fieldValue reflect.Value) error { + if s == "" { + s = "0.0" + } v, err := strconv.ParseFloat(s, d.bitSize) if err != nil { return err @@ -114,6 +120,9 @@ type intDecoder struct { } func (d *intDecoder) UnmarshalString(s string, fieldValue reflect.Value) error { + if s == "" { + s = "0" + } v, err := strconv.ParseInt(s, 10, d.bitSize) if err != nil { return err @@ -134,6 +143,9 @@ type uintDecoder struct { } func (d *uintDecoder) UnmarshalString(s string, fieldValue reflect.Value) error { + if s == "" { + s = "0" + } v, err := strconv.ParseUint(s, 10, d.bitSize) if err != nil { return err From 787bed49942eba6ae8d03d2170017f7a4dbb58e4 Mon Sep 17 00:00:00 2001 From: fgy Date: Thu, 11 May 2023 20:35:49 +0800 Subject: [PATCH 24/91] fix: filter content type --- pkg/app/server/binding/binder_test.go | 2 +- .../binding/decoder/base_type_decoder.go | 1 + pkg/app/server/binding/default.go | 10 +- pkg/app/server/binding/tagexpr_bind_test.go | 1255 +++++++++++++++++ pkg/common/utils/utils.go | 9 + 5 files changed, 1272 insertions(+), 5 deletions(-) create mode 100644 pkg/app/server/binding/tagexpr_bind_test.go diff --git a/pkg/app/server/binding/binder_test.go b/pkg/app/server/binding/binder_test.go index 8abfd2d37..51c5505b2 100644 --- a/pkg/app/server/binding/binder_test.go +++ b/pkg/app/server/binding/binder_test.go @@ -93,7 +93,7 @@ func (m *mockRequest) SetUrlEncodeContentType() *mockRequest { } func (m *mockRequest) SetJSONContentType() *mockRequest { - m.Req.Header.SetContentTypeBytes([]byte(jsonContentTypeBytes)) + m.Req.Header.SetContentTypeBytes([]byte(jsonContentType)) return m } diff --git a/pkg/app/server/binding/decoder/base_type_decoder.go b/pkg/app/server/binding/decoder/base_type_decoder.go index ed0f1541b..aaba63ebf 100644 --- a/pkg/app/server/binding/decoder/base_type_decoder.go +++ b/pkg/app/server/binding/decoder/base_type_decoder.go @@ -67,6 +67,7 @@ func (d *baseTypeFieldTextDecoder) Decode(req *bindRequest, params path1.PathPar var defaultValue string for _, tagInfo := range d.tagInfos { if tagInfo.Skip || tagInfo.Key == jsonTag || tagInfo.Key == fileNameTag { + defaultValue = tagInfo.Default continue } if tagInfo.Key == headerTag { diff --git a/pkg/app/server/binding/default.go b/pkg/app/server/binding/default.go index 6f246bd53..85dbb93de 100644 --- a/pkg/app/server/binding/default.go +++ b/pkg/app/server/binding/default.go @@ -69,6 +69,7 @@ import ( "github.com/cloudwego/hertz/pkg/app/server/binding/decoder" "github.com/cloudwego/hertz/pkg/app/server/binding/path" hjson "github.com/cloudwego/hertz/pkg/common/json" + "github.com/cloudwego/hertz/pkg/common/utils" "github.com/cloudwego/hertz/pkg/protocol" "github.com/go-playground/validator/v10" "google.golang.org/protobuf/proto" @@ -111,8 +112,8 @@ func (b *defaultBinder) Bind(req *protocol.Request, params path.PathParam, v int } var ( - jsonContentTypeBytes = "application/json; charset=utf-8" - protobufContentType = "application/x-protobuf" + jsonContentType = "application/json" + protobufContentType = "application/x-protobuf" ) // best effort binding @@ -120,8 +121,9 @@ func (b *defaultBinder) preBindBody(req *protocol.Request, v interface{}) error if req.Header.ContentLength() <= 0 { return nil } - switch bytesconv.B2s(req.Header.ContentType()) { - case jsonContentTypeBytes: + ct := bytesconv.B2s(req.Header.ContentType()) + switch utils.FilterContentType(ct) { + case jsonContentType: // todo: aligning the gin, add "EnableDecoderUseNumber"/"EnableDecoderDisallowUnknownFields" interface return hjson.Unmarshal(req.Body(), v) case protobufContentType: diff --git a/pkg/app/server/binding/tagexpr_bind_test.go b/pkg/app/server/binding/tagexpr_bind_test.go new file mode 100644 index 000000000..bf5650ecd --- /dev/null +++ b/pkg/app/server/binding/tagexpr_bind_test.go @@ -0,0 +1,1255 @@ +package binding + +import ( + "bytes" + "encoding/json" + "io" + "io/ioutil" + "mime/multipart" + "net/http" + "net/url" + "os" + "strings" + "testing" + "time" + + "github.com/cloudwego/hertz/pkg/common/test/assert" +) + +func TestRawBody(t *testing.T) { + type Recv struct { + S []byte `raw_body:""` + F **string `raw_body:""` + } + bodyBytes := []byte("raw_body.............") + req := newRequest("", nil, nil, bytes.NewReader(bodyBytes)) + recv := new(Recv) + err := DefaultBinder.Bind(req.Req, nil, recv) + if err != nil { + if err != nil { + t.Error(err) + } + } + + assert.DeepEqual(t, bodyBytes, recv.S) + assert.DeepEqual(t, string(bodyBytes), **recv.F) +} + +func TestQueryString(t *testing.T) { + type metric string + type count int32 + + type Recv struct { + X **struct { + A []string `query:"a"` + B string `query:"b"` + C *[]string `query:"c,required"` + D *string `query:"d"` + E *[]***int `query:"e"` + F metric `query:"f"` + G []count `query:"g"` + } + Y string `query:"y,required"` + Z *string `query:"z"` + } + req := newRequest("http://localhost:8080/?a=a1&a=a2&b=b1&c=c1&c=c2&d=d1&d=d&f=qps&g=1002&g=1003&e=&e=2&y=y1", nil, nil, nil) + recv := new(Recv) + err := DefaultBinder.Bind(req.Req, nil, recv) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, 0, ***(*(**recv.X).E)[0]) + assert.DeepEqual(t, 2, ***(*(**recv.X).E)[1]) + assert.DeepEqual(t, []string{"a1", "a2"}, (**recv.X).A) + assert.DeepEqual(t, "b1", (**recv.X).B) + assert.DeepEqual(t, []string{"c1", "c2"}, *(**recv.X).C) + assert.DeepEqual(t, "d1", *(**recv.X).D) + assert.DeepEqual(t, metric("qps"), (**recv.X).F) + assert.DeepEqual(t, []count{1002, 1003}, (**recv.X).G) + assert.DeepEqual(t, "y1", recv.Y) + assert.DeepEqual(t, (*string)(nil), recv.Z) +} + +func TestGetBody(t *testing.T) { + type Recv struct { + X **struct { + E string `json:"e,required" query:"e,required"` + } + } + req := newRequest("http://localhost:8080/", nil, nil, nil) + recv := new(Recv) + err := DefaultBinder.Bind(req.Req, nil, recv) + assert.DeepEqual(t, err.Error(), "'E' field is a 'required' parameter, but the request does not have this parameter") +} + +func TestQueryNum(t *testing.T) { + type Recv struct { + X **struct { + A []int `query:"a"` + B int32 `query:"b"` + C *[]uint16 `query:"c,required"` + D *float32 `query:"d"` + } + Y bool `query:"y,required"` + Z *int64 `query:"z"` + } + req := newRequest("http://localhost:8080/?a=11&a=12&b=21&c=31&c=32&d=41&d=42&y=true", nil, nil, nil) + recv := new(Recv) + err := DefaultBinder.Bind(req.Req, nil, recv) + if err != nil { + if err != nil { + t.Error(err) + } + } + assert.DeepEqual(t, []int{11, 12}, (**recv.X).A) + assert.DeepEqual(t, int32(21), (**recv.X).B) + assert.DeepEqual(t, &[]uint16{31, 32}, (**recv.X).C) + assert.DeepEqual(t, float32(41), *(**recv.X).D) + assert.DeepEqual(t, true, recv.Y) + assert.DeepEqual(t, (*int64)(nil), recv.Z) +} + +func TestHeaderString(t *testing.T) { + type Recv struct { + X **struct { + A []string `header:"X-A"` + B string `header:"X-B"` + C *[]string `header:"X-C,required"` + D *string `header:"X-D"` + } + Y string `header:"X-Y,required"` + Z *string `header:"X-Z"` + } + header := make(http.Header) + header.Add("X-A", "a1") + header.Add("X-A", "a2") + header.Add("X-B", "b1") + header.Add("X-C", "c1") + header.Add("X-C", "c2") + header.Add("X-D", "d1") + header.Add("X-D", "d2") + header.Add("X-Y", "y1") + req := newRequest("", header, nil, nil) + recv := new(Recv) + err := DefaultBinder.Bind(req.Req, nil, recv) + if err != nil { + if err != nil { + t.Error(err) + } + } + assert.DeepEqual(t, []string{"a1", "a2"}, (**recv.X).A) + assert.DeepEqual(t, "b1", (**recv.X).B) + assert.DeepEqual(t, []string{"c1", "c2"}, *(**recv.X).C) + assert.DeepEqual(t, "d1", *(**recv.X).D) + assert.DeepEqual(t, "y1", recv.Y) + assert.DeepEqual(t, (*string)(nil), recv.Z) +} + +func TestHeaderNum(t *testing.T) { + type Recv struct { + X **struct { + A []int `header:"X-A"` + B int32 `header:"X-B"` + C *[]uint16 `header:"X-C,required"` + D *float32 `header:"X-D"` + } + Y bool `header:"X-Y,required"` + Z *int64 `header:"X-Z"` + } + header := make(http.Header) + header.Add("X-A", "11") + header.Add("X-A", "12") + header.Add("X-B", "21") + header.Add("X-C", "31") + header.Add("X-C", "32") + header.Add("X-D", "41") + header.Add("X-D", "42") + header.Add("X-Y", "true") + req := newRequest("", header, nil, nil) + recv := new(Recv) + + err := DefaultBinder.Bind(req.Req, nil, recv) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, []int{11, 12}, (**recv.X).A) + assert.DeepEqual(t, int32(21), (**recv.X).B) + assert.DeepEqual(t, &[]uint16{31, 32}, (**recv.X).C) + assert.DeepEqual(t, float32(41), *(**recv.X).D) + assert.DeepEqual(t, true, recv.Y) + assert.DeepEqual(t, (*int64)(nil), recv.Z) +} + +// todo: cookie slice +func TestCookieString(t *testing.T) { + type Recv struct { + X **struct { + A []string `cookie:"a"` + B string `cookie:"b"` + C *[]string `cookie:"c,required"` + D *string `cookie:"d"` + } + Y string `cookie:"y,required"` + Z *string `cookie:"z"` + } + cookies := []*http.Cookie{ + {Name: "a", Value: "a1"}, + {Name: "a", Value: "a2"}, + {Name: "b", Value: "b1"}, + {Name: "c", Value: "c1"}, + {Name: "c", Value: "c2"}, + {Name: "d", Value: "d1"}, + {Name: "d", Value: "d2"}, + {Name: "y", Value: "y1"}, + } + req := newRequest("", nil, cookies, nil) + recv := new(Recv) + + err := DefaultBinder.Bind(req.Req, nil, recv) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, []string{"a2"}, (**recv.X).A) + assert.DeepEqual(t, "b1", (**recv.X).B) + assert.DeepEqual(t, []string{"c2"}, *(**recv.X).C) + assert.DeepEqual(t, "d2", *(**recv.X).D) + assert.DeepEqual(t, "y1", recv.Y) + assert.DeepEqual(t, (*string)(nil), recv.Z) +} + +func TestCookieNum(t *testing.T) { + type Recv struct { + X **struct { + A []int `cookie:"a"` + B int32 `cookie:"b"` + C *[]uint16 `cookie:"c,required"` + D *float32 `cookie:"d"` + } + Y bool `cookie:"y,required"` + Z *int64 `cookie:"z"` + } + cookies := []*http.Cookie{ + {Name: "a", Value: "11"}, + {Name: "b", Value: "21"}, + {Name: "c", Value: "31"}, + {Name: "d", Value: "41"}, + {Name: "y", Value: "t"}, + } + req := newRequest("", nil, cookies, nil) + recv := new(Recv) + + err := DefaultBinder.Bind(req.Req, nil, recv) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, []int{11}, (**recv.X).A) + assert.DeepEqual(t, int32(21), (**recv.X).B) + assert.DeepEqual(t, &[]uint16{31}, (**recv.X).C) + assert.DeepEqual(t, float32(41), *(**recv.X).D) + assert.DeepEqual(t, true, recv.Y) + assert.DeepEqual(t, (*int64)(nil), recv.Z) +} + +func TestFormString(t *testing.T) { + type Recv struct { + X **struct { + A []string `form:"a"` + B string `form:"b"` + C *[]string `form:"c,required"` + D *string `form:"d"` + } + Y string `form:"y,required"` + Z *string `form:"z"` + F *multipart.FileHeader `form:"F1"` + F1 multipart.FileHeader + Fs []multipart.FileHeader `form:"F1"` + Fs1 []*multipart.FileHeader `form:"F1"` + } + values := make(url.Values) + values.Add("a", "a1") + values.Add("a", "a2") + values.Add("b", "b1") + values.Add("c", "c1") + values.Add("c", "c2") + values.Add("d", "d1") + values.Add("d", "d2") + values.Add("y", "y1") + for i, f := range []files{{ + "F1": []file{ + newFile("txt", strings.NewReader("0123")), + }, + }} { + contentType, bodyReader := newFormBody2(values, f) + header := make(http.Header) + header.Set("Content-Type", contentType) + req := newRequest("", header, nil, bodyReader) + recv := new(Recv) + err := DefaultBinder.Bind(req.Req, nil, recv) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, []string{"a1", "a2"}, (**recv.X).A) + assert.DeepEqual(t, "b1", (**recv.X).B) + assert.DeepEqual(t, []string{"c1", "c2"}, *(**recv.X).C) + assert.DeepEqual(t, "d1", *(**recv.X).D) + assert.DeepEqual(t, "y1", recv.Y) + assert.DeepEqual(t, (*string)(nil), recv.Z) + t.Logf("[%d] F: %#v", i, recv.F) + t.Logf("[%d] F1: %#v", i, recv.F1) + t.Logf("[%d] Fs: %#v", i, recv.Fs) + t.Logf("[%d] Fs1: %#v", i, recv.Fs1) + if len(recv.Fs1) > 0 { + t.Logf("[%d] Fs1[0]: %#v", i, recv.Fs1[0]) + } + } +} + +func TestFormNum(t *testing.T) { + type Recv struct { + X **struct { + A []int `form:"a"` + B int32 `form:"b"` + C *[]uint16 `form:"c,required"` + D *float32 `form:"d"` + } + Y bool `form:"y,required"` + Z *int64 `form:"z"` + } + values := make(url.Values) + values.Add("a", "11") + values.Add("a", "12") + values.Add("b", "-21") + values.Add("c", "31") + values.Add("c", "32") + values.Add("d", "41") + values.Add("d", "42") + values.Add("y", "1") + for _, f := range []files{nil, { + "f1": []file{ + newFile("txt", strings.NewReader("f11 text.")), + }, + }} { + contentType, bodyReader := newFormBody2(values, f) + header := make(http.Header) + header.Set("Content-Type", contentType) + req := newRequest("", header, nil, bodyReader) + recv := new(Recv) + + err := DefaultBinder.Bind(req.Req, nil, recv) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, []int{11, 12}, (**recv.X).A) + assert.DeepEqual(t, int32(-21), (**recv.X).B) + assert.DeepEqual(t, &[]uint16{31, 32}, (**recv.X).C) + assert.DeepEqual(t, float32(41), *(**recv.X).D) + assert.DeepEqual(t, true, recv.Y) + assert.DeepEqual(t, (*int64)(nil), recv.Z) + } +} + +// FIXME: content-type 裁剪 +func TestJSON(t *testing.T) { + type metric string + type count int32 + type ZS struct { + Z *int64 + } + type Recv struct { + X **struct { + A []string `json:"a"` + B int32 `json:""` + C *[]uint16 `json:",required"` + D *float32 `json:"d"` + E metric `json:"e"` + F count `json:"f"` + M map[string]string `json:"m"` + } + Y string `json:"y,required"` + ZS + } + + bodyReader := strings.NewReader(`{ + "X": { + "a": ["a1","a2"], + "B": 21, + "C": [31,32], + "d": 41, + "e": "qps", + "f": 100, + "m": {"a":"x"} + }, + "Z": 6 + }`) + + header := make(http.Header) + header.Set("Content-Type", "application/json") + req := newRequest("", header, nil, bodyReader) + recv := new(Recv) + + err := DefaultBinder.Bind(req.Req, nil, recv) + if err != nil { + t.Error(err) + } + //assert.DeepEqual(t, &binding.Error{ErrType: "binding", FailField: "y", Msg: "missing required parameter"}, err) + assert.DeepEqual(t, []string{"a1", "a2"}, (**recv.X).A) + assert.DeepEqual(t, int32(21), (**recv.X).B) + assert.DeepEqual(t, &[]uint16{31, 32}, (**recv.X).C) + assert.DeepEqual(t, float32(41), *(**recv.X).D) + assert.DeepEqual(t, metric("qps"), (**recv.X).E) + assert.DeepEqual(t, count(100), (**recv.X).F) + assert.DeepEqual(t, map[string]string{"a": "x"}, (**recv.X).M) + assert.DeepEqual(t, "", recv.Y) + assert.DeepEqual(t, (int64)(6), *recv.Z) +} + +// FIXME: 非 stuct 绑定,直接走 unmarshal +func TestNonstruct(t *testing.T) { + bodyReader := strings.NewReader(`{ + "X": { + "a": ["a1","a2"], + "B": 21, + "C": [31,32], + "d": 41, + "e": "qps", + "f": 100 + }, + "Z": 6 + }`) + + header := make(http.Header) + header.Set("Content-Type", "application/json") + req := newRequest("", header, nil, bodyReader) + var recv interface{} + + err := DefaultBinder.Bind(req.Req, nil, &recv) + if err != nil { + t.Error(err) + } + b, err := json.Marshal(recv) + if err != nil { + t.Error(err) + } + t.Logf("%s", b) + + bodyReader = strings.NewReader("b=334ddddd&token=yoMba34uspjVQEbhflgTRe2ceeDFUK32&type=url_verification") + header.Set("Content-Type", "application/x-www-form-urlencoded; charset=utf-8") + req = newRequest("", header, nil, bodyReader) + recv = nil + err = DefaultBinder.Bind(req.Req, nil, recv) + if err != nil { + t.Error(err) + } + b, err = json.Marshal(recv) + if err != nil { + t.Error(err) + } + t.Logf("%s", b) +} + +type testPathParams struct{} + +func (testPathParams) Get(name string) (string, bool) { + switch name { + case "a": + return "a1", true + case "b": + return "-21", true + case "c": + return "31", true + case "d": + return "41", true + case "y": + return "y1", true + case "name": + return "henrylee2cn", true + default: + return "", false + } +} + +func TestPath(t *testing.T) { + type Recv struct { + X **struct { + A []string `path:"a"` + B int32 `path:"b"` + C *[]uint16 `path:"c,required"` + D *float32 `path:"d"` + } + Y string `path:"y,required"` + Z *int64 + } + + req := newRequest("", nil, nil, nil) + recv := new(Recv) + + err := DefaultBinder.Bind(req.Req, new(testPathParams), recv) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, []string{"a1"}, (**recv.X).A) + assert.DeepEqual(t, int32(-21), (**recv.X).B) + assert.DeepEqual(t, &[]uint16{31}, (**recv.X).C) + assert.DeepEqual(t, float32(41), *(**recv.X).D) + assert.DeepEqual(t, "y1", recv.Y) + assert.DeepEqual(t, (*int64)(nil), recv.Z) +} + +type testPathParams2 struct{} + +func (testPathParams2) Get(name string) (string, bool) { + switch name { + case "e": + return "123", true + default: + return "", false + } +} + +// FIXME: 复杂类型的默认值 +func TestDefault(t *testing.T) { + type S struct { + SS string `json:"ss"` + } + + type Recv struct { + X **struct { + A []string `path:"a" json:"a"` + B int32 `path:"b" default:"32"` + C bool `json:"c" default:"true"` + D *float32 `default:"123.4"` + //E *[]string `default:"['a','b','c','d,e,f']"` + //F map[string]string `default:"{'a':'\"\\'1','\"b':'c','c':'2'}"` + //G map[string]int64 `default:"{'a':1,'b':2,'c':3}"` + //H map[string]float64 `default:"{'a':0.1,'b':1.2,'c':2.3}"` + //I map[string]float64 `default:"{'\"a\"':0.1,'b':1.2,'c':2.3}"` + Empty string `default:""` + Null string `default:""` + CommaSpace string `default:",a:c "` + Dash string `default:"-"` + // InvalidInt int `default:"abc"` + // InvalidMap map[string]string `default:"abc"` + } + Y string `json:"y" default:"y1"` + Z int64 + W string `json:"w"` + //V []int64 `json:"u" default:"[1,2,3]"` + //U []float32 `json:"u" default:"[1.1,2,3]"` + T *string `json:"t" default:"t1"` + //S S `default:"{'ss':'test'}"` + //O *S `default:"{'ss':'test2'}"` + //Complex map[string][]map[string][]int64 `default:"{'a':[{'aa':[1,2,3], 'bb':[4,5]}],'b':[{}]}"` + } + + bodyReader := strings.NewReader(`{ + "X": { + "a": ["a1","a2"] + }, + "Z": 6 + }`) + + // var nilMap map[string]string + header := make(http.Header) + header.Set("Content-Type", "application/json") + req := newRequest("", header, nil, bodyReader) + recv := new(Recv) + + err := DefaultBinder.Bind(req.Req, new(testPathParams2), recv) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, []string{"a1", "a2"}, (**recv.X).A) + assert.DeepEqual(t, int32(32), (**recv.X).B) + assert.DeepEqual(t, true, (**recv.X).C) + assert.DeepEqual(t, float32(123.4), *(**recv.X).D) + //assert.DeepEqual(t, []string{"a", "b", "c", "d,e,f"}, *(**recv.X).E) + //assert.DeepEqual(t, map[string]string{"a": "\"'1", "\"b": "c", "c": "2"}, (**recv.X).F) + //assert.DeepEqual(t, map[string]int64{"a": 1, "b": 2, "c": 3}, (**recv.X).G) + //assert.DeepEqual(t, map[string]float64{"a": 0.1, "b": 1.2, "c": 2.3}, (**recv.X).H) + //assert.DeepEqual(t, map[string]float64{"\"a\"": 0.1, "b": 1.2, "c": 2.3}, (**recv.X).I) + assert.DeepEqual(t, "", (**recv.X).Empty) + assert.DeepEqual(t, "", (**recv.X).Null) + assert.DeepEqual(t, ",a:c ", (**recv.X).CommaSpace) + assert.DeepEqual(t, "-", (**recv.X).Dash) + // assert.DeepEqual(t, 0, (**recv.X).InvalidInt) + // assert.DeepEqual(t, nilMap, (**recv.X).InvalidMap) + assert.DeepEqual(t, "y1", recv.Y) + assert.DeepEqual(t, "t1", *recv.T) + assert.DeepEqual(t, int64(6), recv.Z) + //assert.DeepEqual(t, []int64{1, 2, 3}, recv.V) + //assert.DeepEqual(t, []float32{1.1, 2, 3}, recv.U) + //assert.DeepEqual(t, S{SS: "test"}, recv.S) + //assert.DeepEqual(t, &S{SS: "test2"}, recv.O) + //assert.DeepEqual(t, map[string][]map[string][]int64{"a": {{"aa": {1, 2, 3}, "bb": []int64{4, 5}}}, "b": {map[string][]int64{}}}, recv.Complex) +} + +// FIXME: query 和 form getter 的优先级 +func TestAuto(t *testing.T) { + type Recv struct { + A string + B string + C string + D string `query:"D,required" form:"D,required"` + E string `cookie:"e" json:"e"` + } + query := make(url.Values) + query.Add("A", "a") + query.Add("B", "b") + query.Add("C", "c") + query.Add("D", "d-from-query") + contentType, bodyReader, err := newJSONBody(map[string]string{"e": "e-from-jsonbody"}) + if err != nil { + t.Error(err) + } + header := make(http.Header) + header.Set("Content-Type", contentType) + req := newRequest("http://localhost/?"+query.Encode(), header, []*http.Cookie{ + {Name: "e", Value: "e-from-cookie"}, + }, bodyReader) + recv := new(Recv) + + err = DefaultBinder.Bind(req.Req, nil, recv) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, "a", recv.A) + assert.DeepEqual(t, "b", recv.B) + assert.DeepEqual(t, "c", recv.C) + assert.DeepEqual(t, "d-from-query", recv.D) + assert.DeepEqual(t, "e-from-cookie", recv.E) + + query = make(url.Values) + query.Add("D", "d-from-query") + form := make(url.Values) + form.Add("B", "b") + form.Add("C", "c") + form.Add("D", "d-from-form") + contentType, bodyReader = newFormBody2(form, nil) + header = make(http.Header) + header.Set("Content-Type", contentType) + req = newRequest("http://localhost/?"+query.Encode(), header, nil, bodyReader) + recv = new(Recv) + err = DefaultBinder.Bind(req.Req, nil, recv) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, "", recv.A) + assert.DeepEqual(t, "b", recv.B) + assert.DeepEqual(t, "c", recv.C) + assert.DeepEqual(t, "d-from-form", recv.D) +} + +// FIXME: 自定义验证函数 & TIME 类型内置 +func TestTypeUnmarshal(t *testing.T) { + type Recv struct { + A time.Time `form:"t1"` + B *time.Time `query:"t2"` + C []time.Time `query:"t2"` + } + query := make(url.Values) + query.Add("t2", "2019-09-04T14:05:24+08:00") + query.Add("t2", "2019-09-04T18:05:24+08:00") + form := make(url.Values) + form.Add("t1", "2019-09-03T18:05:24+08:00") + contentType, bodyReader := newFormBody2(form, nil) + header := make(http.Header) + header.Set("Content-Type", contentType) + req := newRequest("http://localhost/?"+query.Encode(), header, nil, bodyReader) + recv := new(Recv) + + err := DefaultBinder.Bind(req.Req, nil, recv) + if err != nil { + t.Error(err) + } + t1, err := time.Parse(time.RFC3339, "2019-09-03T18:05:24+08:00") + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, t1, recv.A) + t21, err := time.Parse(time.RFC3339, "2019-09-04T14:05:24+08:00") + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, t21, *recv.B) + t22, err := time.Parse(time.RFC3339, "2019-09-04T18:05:24+08:00") + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, []time.Time{t21, t22}, recv.C) + t.Logf("%v", recv) +} + +// FIXME: JSON required 校验 +func TestOption(t *testing.T) { + type Recv struct { + X *struct { + C int `json:"c,required"` + D int `json:"d"` + } `json:"X"` + Y string `json:"y"` + } + header := make(http.Header) + header.Set("Content-Type", "application/json") + + bodyReader := strings.NewReader(`{ + "X": { + "c": 21, + "d": 41 + }, + "y": "y1" + }`) + req := newRequest("", header, nil, bodyReader) + recv := new(Recv) + + err := DefaultBinder.Bind(req.Req, nil, recv) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, 21, recv.X.C) + assert.DeepEqual(t, 41, recv.X.D) + assert.DeepEqual(t, "y1", recv.Y) + + bodyReader = strings.NewReader(`{ + "X": { + }, + "y": "y1" + }`) + req = newRequest("", header, nil, bodyReader) + recv = new(Recv) + err = DefaultBinder.Bind(req.Req, nil, recv) + //assert.DeepEqual(t, err.Error(), "binding: expr_path=X.c, cause=missing required parameter") + assert.DeepEqual(t, 0, recv.X.C) + assert.DeepEqual(t, 0, recv.X.D) + assert.DeepEqual(t, "y1", recv.Y) + + bodyReader = strings.NewReader(`{ + "y": "y1" + }`) + req = newRequest("", header, nil, bodyReader) + recv = new(Recv) + err = DefaultBinder.Bind(req.Req, nil, recv) + if err != nil { + t.Error(err) + } + assert.True(t, recv.X == nil) + assert.DeepEqual(t, "y1", recv.Y) + + type Recv2 struct { + X *struct { + C int `json:"c,required"` + D int `json:"d"` + } `json:"X,required"` + Y string `json:"y"` + } + bodyReader = strings.NewReader(`{ + "y": "y1" + }`) + req = newRequest("", header, nil, bodyReader) + recv2 := new(Recv2) + err = DefaultBinder.Bind(req.Req, nil, recv2) + //assert.DeepEqual(t, err.Error(), "binding: expr_path=X, cause=missing required parameter") + assert.True(t, recv2.X == nil) + assert.DeepEqual(t, "y1", recv2.Y) +} + +func newRequest(u string, header http.Header, cookies []*http.Cookie, bodyReader io.Reader) *mockRequest { + if header == nil { + header = make(http.Header) + } + var method = "GET" + var body []byte + if bodyReader != nil { + body, _ = ioutil.ReadAll(bodyReader) + method = "POST" + } + if u == "" { + u = "http://localhost" + } + req := newMockRequest() + req.SetRequestURI(u) + for k, v := range header { + for _, val := range v { + req.Req.Header.Add(k, val) + } + } + if len(body) != 0 { + req.SetBody(body) + req.Req.Header.SetContentLength(len(body)) + } + req.Req.SetMethod(method) + for _, c := range cookies { + req.Req.Header.SetCookie(c.Name, c.Value) + + } + return req +} + +func TestQueryStringIssue(t *testing.T) { + type Timestamp struct { + time.Time + } + type Recv struct { + Name *string `query:"name"` + T *Timestamp `query:"t"` + } + req := newRequest("http://localhost:8080/?name=test", nil, nil, nil) + recv := new(Recv) + + err := DefaultBinder.Bind(req.Req, nil, recv) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, "test", *recv.Name) + assert.DeepEqual(t, (*Timestamp)(nil), recv.T) +} + +func TestQueryTypes(t *testing.T) { + type metric string + type count int32 + type metrics []string + type filter struct { + Col1 string + } + + type Recv struct { + A metric + B count + C *count + D metrics `query:"D,required" form:"D,required"` + E metric `cookie:"e" json:"e"` + F filter `json:"f"` + } + query := make(url.Values) + query.Add("A", "qps") + query.Add("B", "123") + query.Add("C", "321") + query.Add("D", "dau") + query.Add("D", "dnu") + contentType, bodyReader, err := newJSONBody( + map[string]interface{}{ + "e": "e-from-jsonbody", + "f": filter{Col1: "abc"}, + }, + ) + if err != nil { + t.Error(err) + } + header := make(http.Header) + header.Set("Content-Type", contentType) + req := newRequest("http://localhost/?"+query.Encode(), header, []*http.Cookie{ + {Name: "e", Value: "e-from-cookie"}, + }, bodyReader) + recv := new(Recv) + + err = DefaultBinder.Bind(req.Req, nil, recv) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, metric("qps"), recv.A) + assert.DeepEqual(t, count(123), recv.B) + assert.DeepEqual(t, count(321), *recv.C) + assert.DeepEqual(t, metrics{"dau", "dnu"}, recv.D) + assert.DeepEqual(t, metric("e-from-cookie"), recv.E) + assert.DeepEqual(t, filter{Col1: "abc"}, recv.F) +} + +func TestNoTagIssue(t *testing.T) { + type x int + type T struct { + x + x2 x + a int + B int + } + req := newRequest("http://localhost:8080/?x=11&x2=12&a=1&B=2", nil, nil, nil) + recv := new(T) + + err := DefaultBinder.Bind(req.Req, nil, recv) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, x(0), recv.x) + assert.DeepEqual(t, x(0), recv.x2) + assert.DeepEqual(t, 0, recv.a) + assert.DeepEqual(t, 2, recv.B) +} + +// DIFF: go-tagexpr 会对保留 t.Q的结构体信息,而目前的实现不会 t.Q 做特殊处理,会直接拆开。有需求也可以加上 +func TestRegTypeUnmarshal(t *testing.T) { + type Q struct { + A int + B string + } + type T struct { + Q Q `query:"q"` + Qs []*Q `query:"qs"` + } + var values = url.Values{} + b, err := json.Marshal(Q{A: 2, B: "y"}) + if err != nil { + t.Error(err) + } + values.Add("q", string(b)) + bs, err := json.Marshal([]Q{{A: 1, B: "x"}, {A: 2, B: "y"}}) + values.Add("qs", string(bs)) + req := newRequest("http://localhost:8080/?"+values.Encode(), nil, nil, nil) + recv := new(T) + + err = DefaultBinder.Bind(req.Req, nil, recv) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, 2, recv.Q.A) + assert.DeepEqual(t, "y", recv.Q.B) + assert.DeepEqual(t, 1, recv.Qs[0].A) + assert.DeepEqual(t, "x", recv.Qs[0].B) + assert.DeepEqual(t, 2, recv.Qs[1].A) + assert.DeepEqual(t, "y", recv.Qs[1].B) +} + +//func TestPathnameBUG(t *testing.T) { +// type Currency struct { +// CurrencyName *string `form:"currency_name,required" json:"currency_name,required" protobuf:"bytes,1,req,name=currency_name,json=currencyName" query:"currency_name,required"` +// CurrencySymbol *string `form:"currency_symbol,required" json:"currency_symbol,required" protobuf:"bytes,2,req,name=currency_symbol,json=currencySymbol" query:"currency_symbol,required"` +// SymbolPosition *int32 `form:"symbol_position,required" json:"symbol_position,required" protobuf:"varint,3,req,name=symbol_position,json=symbolPosition" query:"symbol_position,required"` +// DecimalPlaces *int32 `form:"decimal_places,required" json:"decimal_places,required" protobuf:"varint,4,req,name=decimal_places,json=decimalPlaces" query:"decimal_places,required"` // 56x56 +// DecimalSymbol *string `form:"decimal_symbol,required" json:"decimal_symbol,required" protobuf:"bytes,5,req,name=decimal_symbol,json=decimalSymbol" query:"decimal_symbol,required"` +// Separator *string `form:"separator,required" json:"separator,required" protobuf:"bytes,6,req,name=separator" query:"separator,required"` +// SeparatorIndex *string `form:"separator_index,required" json:"separator_index,required" protobuf:"bytes,7,req,name=separator_index,json=separatorIndex" query:"separator_index,required"` +// Between *string `form:"between,required" json:"between,required" protobuf:"bytes,8,req,name=between" query:"between,required"` +// MinPrice *string `form:"min_price" json:"min_price,omitempty" protobuf:"bytes,9,opt,name=min_price,json=minPrice" query:"min_price"` +// MaxPrice *string `form:"max_price" json:"max_price,omitempty" protobuf:"bytes,10,opt,name=max_price,json=maxPrice" query:"max_price"` +// } +// +// type CurrencyData struct { +// Amount *string `form:"amount,required" json:"amount,required" protobuf:"bytes,1,req,name=amount" query:"amount,required"` +// Currency *Currency `form:"currency,required" json:"currency,required" protobuf:"bytes,2,req,name=currency" query:"currency,required"` +// } +// +// type ExchangeCurrencyRequest struct { +// PromotionRegion *string `form:"promotion_region,required" json:"promotion_region,required" protobuf:"bytes,1,req,name=promotion_region,json=promotionRegion" query:"promotion_region,required"` +// Currency *CurrencyData `form:"currency,required" json:"currency,required" protobuf:"bytes,2,req,name=currency" query:"currency,required"` +// Version *int32 `json:"version,omitempty" path:"version" protobuf:"varint,100,opt,name=version"` +// } +// +// z := &ExchangeCurrencyRequest{} +// v := ameda.InitSampleValue(reflect.TypeOf(z), 10).Interface().(*ExchangeCurrencyRequest) +// b, err := json.MarshalIndent(v, "", " ") +// t.Log(string(b)) +// header := make(http.Header) +// header.Set("Content-Type", "application/json;charset=utf-8") +// req := newRequest("http://localhost", header, nil, bytes.NewReader(b)) +// recv := new(ExchangeCurrencyRequest) +// +// err = DefaultBinder.Bind(req.Req, nil, recv) +// if err != nil { +// +// assert.DeepEqual(t, v, recv) +//} + +// FIXME: json unmarshal 后其他 required 没必要 校验 required +func TestPathnameBUG2(t *testing.T) { + type CurrencyData struct { + z int + Amount *string `form:"amount,required" json:"amount,required" protobuf:"bytes,1,req,name=amount" query:"amount,required"` + Name *string `form:"name,required" json:"name,required" protobuf:"bytes,2,req,name=name" query:"name,required"` + Symbol *string `form:"symbol" json:"symbol,omitempty" protobuf:"bytes,3,opt,name=symbol" query:"symbol"` + } + type TimeRange struct { + z int + StartTime *int64 `form:"start_time,required" json:"start_time,required" protobuf:"varint,1,req,name=start_time,json=startTime" query:"start_time,required"` + EndTime *int64 `form:"end_time,required" json:"end_time,required" protobuf:"varint,2,req,name=end_time,json=endTime" query:"end_time,required"` + } + type CreateFreeShippingRequest struct { + z int + PromotionName *string `form:"promotion_name,required" json:"promotion_name,required" protobuf:"bytes,1,req,name=promotion_name,json=promotionName" query:"promotion_name,required"` + PromotionRegion *string `form:"promotion_region,required" json:"promotion_region,required" protobuf:"bytes,2,req,name=promotion_region,json=promotionRegion" query:"promotion_region,required"` + TimeRange *TimeRange `form:"time_range,required" json:"time_range,required" protobuf:"bytes,3,req,name=time_range,json=timeRange" query:"time_range,required"` + PromotionBudget *CurrencyData `form:"promotion_budget,required" json:"promotion_budget,required" protobuf:"bytes,4,req,name=promotion_budget,json=promotionBudget" query:"promotion_budget,required"` + Loaded_SellerIds []string `form:"loaded_Seller_ids" json:"loaded_Seller_ids,omitempty" protobuf:"bytes,5,rep,name=loaded_Seller_ids,json=loadedSellerIds" query:"loaded_Seller_ids"` + Version *int32 `json:"version,omitempty" path:"version" protobuf:"varint,100,opt,name=version"` + } + + // z := &CreateFreeShippingRequest{} + // v := ameda.InitSampleValue(reflect.TypeOf(z), 10).Interface().(*CreateFreeShippingRequest) + // b, err := json.MarshalIndent(v, "", " ") + // t.Log(string(b)) + b := []byte(`{ + "promotion_name": "mu", + "promotion_region": "ID", + "time_range": { + "start_time": 1616420139, + "end_time": 1616520139 + }, + "promotion_budget": { + "amount":"10000000", + "name":"USD", + "symbol":"$" + }, + "loaded_Seller_ids": [ + "7493989780026655762","11111","111212121" + ] +}`) + var v = new(CreateFreeShippingRequest) + err := json.Unmarshal(b, v) + if err != nil { + t.Error(err) + } + + header := make(http.Header) + header.Set("Content-Type", "application/json;charset=utf-8") + req := newRequest("http://localhost", header, nil, bytes.NewReader(b)) + recv := new(CreateFreeShippingRequest) + + err = DefaultBinder.Bind(req.Req, nil, recv) + if err != nil { + t.Error(err) + } + + assert.DeepEqual(t, v, recv) +} + +// FIXME: json unmarshal 后的其他 tag 的 required 的校验 +func TestRequiredBUG(t *testing.T) { + type Currency struct { + currencyName *string `form:"currency_name,required" json:"currency_name,required" protobuf:"bytes,1,req,name=currency_name,json=currencyName" query:"currency_name,required"` + CurrencySymbol *string `form:"currency_symbol,required" json:"currency_symbol,required" protobuf:"bytes,2,req,name=currency_symbol,json=currencySymbol" query:"currency_symbol,required"` + } + + type CurrencyData struct { + Amount *string `form:"amount,required" json:"amount,required" protobuf:"bytes,1,req,name=amount" query:"amount,required"` + Slice []*Currency `form:"slice,required" json:"slice,required" protobuf:"bytes,2,req,name=slice" query:"slice,required"` + Map map[string]*Currency `form:"map,required" json:"map,required" protobuf:"bytes,2,req,name=map" query:"map,required"` + } + + type ExchangeCurrencyRequest struct { + PromotionRegion *string `form:"promotion_region,required" json:"promotion_region,required" protobuf:"bytes,1,req,name=promotion_region,json=promotionRegion" query:"promotion_region,required"` + Currency *CurrencyData `form:"currency,required" json:"currency,required" protobuf:"bytes,2,req,name=currency" query:"currency,required"` + } + + z := &ExchangeCurrencyRequest{} + // v := ameda.InitSampleValue(reflect.TypeOf(z), 10).Interface().(*ExchangeCurrencyRequest) + b := []byte(`{ + "promotion_region": "?", + "currency": { + "amount": "?", + "slice": [ + { + "currency_symbol": "?" + } + ], + "map": { + "?": { + "currency_name": "?" + } + } + } + }`) + json.Unmarshal(b, z) + header := make(http.Header) + header.Set("Content-Type", "application/json;charset=utf-8") + req := newRequest("http://localhost", header, nil, bytes.NewReader(b)) + recv := new(ExchangeCurrencyRequest) + + err := DefaultBinder.Bind(req.Req, nil, recv) + assert.DeepEqual(t, err.Error(), "validating: expr_path=Currency.Slice[0].currencyName, cause=invalid") + assert.DeepEqual(t, z, recv) +} + +func TestIssue25(t *testing.T) { + type Recv struct { + A string + } + header := make(http.Header) + header.Set("A", "from header") + cookies := []*http.Cookie{ + {Name: "A", Value: "from cookie"}, + } + req := newRequest("/1", header, cookies, nil) + recv := new(Recv) + + err := DefaultBinder.Bind(req.Req, nil, recv) + if err != nil { + t.Error(err) + } + //assert.DeepEqual(t, "from cookie", recv.A) + + header2 := make(http.Header) + header2.Set("A", "from header") + cookies2 := []*http.Cookie{} + req2 := newRequest("/2", header2, cookies2, nil) + recv2 := new(Recv) + err2 := DefaultBinder.Bind(req2.Req, nil, recv2) + if err2 != nil { + t.Error(err2) + } + assert.DeepEqual(t, "from header", recv2.A) +} + +func TestIssue26(t *testing.T) { + type Recv struct { + Type string `json:"type,required" vd:"($=='update_target_threshold' && (TargetThreshold)$!='-1') || ($=='update_status' && (Status)$!='-1')"` + RuleName string `json:"rule_name,required" vd:"regexp('^rule[0-9]+$')"` + TargetThreshold string `json:"target_threshold" vd:"regexp('^-?[0-9]+(\\.[0-9]+)?$')"` + Status string `json:"status" vd:"$=='0' || $=='1'"` + Operator string `json:"operator,required" vd:"len($)>0"` + } + + b := []byte(`{ + "status": "1", + "adv": "11520", + "target_deep_external_action": "39", + "package": "test.bytedance.com", + "previous_target_threshold": "0.6", + "deep_external_action": "675", + "rule_name": "rule2", + "deep_bid_type": "54", + "modify_time": "2021-08-24:14:35:20", + "aid": "111", + "operator": "yanghaoze", + "external_action": "76", + "target_threshold": "0.1", + "type": "update_status" +}`) + + recv := new(Recv) + err := json.Unmarshal(b, recv) + if err != nil { + t.Error(err) + } + + header := make(http.Header) + header.Set("Content-Type", "application/json") + header.Set("A", "from header") + cookies := []*http.Cookie{ + {Name: "A", Value: "from cookie"}, + } + + req := newRequest("/1", header, cookies, bytes.NewReader(b)) + + recv2 := new(Recv) + err = DefaultBinder.Bind(req.Req, nil, recv2) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, recv, recv2) +} + +// FIXME: json unmarshal 后,默认值的问题 +func TestDefault2(t *testing.T) { + type Recv struct { + X **struct { + Dash string `default:"xxxx"` + } + } + bodyReader := strings.NewReader(`{ + "X": { + "Dash": "hello Dash" + } + }`) + header := make(http.Header) + header.Set("Content-Type", "application/json") + req := newRequest("", header, nil, bodyReader) + recv := new(Recv) + + err := DefaultBinder.Bind(req.Req, nil, recv) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, "hello Dash", (**recv.X).Dash) +} + +func newFormBody(values, files url.Values) (contentType string, bodyReader io.Reader, err error) { + if len(files) == 0 { + return "application/x-www-form-urlencoded", strings.NewReader(values.Encode()), nil + } + var rw = bytes.NewBuffer(make([]byte, 32*1024*len(files))) + var bodyWriter = multipart.NewWriter(rw) + var buf = make([]byte, 32*1024) + var fileWriter io.Writer + var f *os.File + for fieldName, postfiles := range files { + for _, fileName := range postfiles { + fileWriter, err = bodyWriter.CreateFormFile(fieldName, fileName) + if err != nil { + return + } + f, err = os.Open(fileName) + if err != nil { + return + } + _, err = io.CopyBuffer(fileWriter, f, buf) + f.Close() + if err != nil { + return + } + } + } + for k, v := range values { + for _, vv := range v { + bodyWriter.WriteField(k, vv) + } + } + bodyWriter.Close() + return bodyWriter.FormDataContentType(), rw, nil +} + +type ( + files map[string][]file + file interface { + Name() string + Read(p []byte) (n int, err error) + } +) + +func newFormBody2(values url.Values, files files) (contentType string, bodyReader io.Reader) { + if len(files) == 0 { + return "application/x-www-form-urlencoded", strings.NewReader(values.Encode()) + } + var pr, pw = io.Pipe() + var bodyWriter = multipart.NewWriter(pw) + var fileWriter io.Writer + var buf = make([]byte, 32*1024) + go func() { + for fieldName, postfiles := range files { + for _, file := range postfiles { + fileWriter, _ = bodyWriter.CreateFormFile(fieldName, file.Name()) + io.CopyBuffer(fileWriter, file, buf) + } + } + for k, v := range values { + for _, vv := range v { + bodyWriter.WriteField(k, vv) + } + } + bodyWriter.Close() + pw.Close() + }() + return bodyWriter.FormDataContentType(), pr +} + +func newFile(name string, bodyReader io.Reader) file { + return &fileReader{name, bodyReader} +} + +// fileReader file name and bytes. +type fileReader struct { + name string + bodyReader io.Reader +} + +func (f *fileReader) Name() string { + return f.name +} + +func (f *fileReader) Read(p []byte) (int, error) { + return f.bodyReader.Read(p) +} + +func newJSONBody(v interface{}) (contentType string, bodyReader io.Reader, err error) { + b, err := json.Marshal(v) + if err != nil { + return + } + return "application/json;charset=utf-8", bytes.NewReader(b), nil +} diff --git a/pkg/common/utils/utils.go b/pkg/common/utils/utils.go index b90132bb5..8cd76fe61 100644 --- a/pkg/common/utils/utils.go +++ b/pkg/common/utils/utils.go @@ -123,3 +123,12 @@ func NextLine(b []byte) ([]byte, []byte, error) { } return b[:n], b[nNext+1:], nil } + +func FilterContentType(content string) string { + for i, char := range content { + if char == ' ' || char == ';' { + return content[:i] + } + } + return content +} From 9ed986384c08ee8826a5049331cd41f8dfe4070f Mon Sep 17 00:00:00 2001 From: fgy Date: Thu, 11 May 2023 21:16:58 +0800 Subject: [PATCH 25/91] fix: diff between query and form --- pkg/app/server/binding/binder_test.go | 1 + .../server/binding/decoder/base_type_decoder.go | 2 +- pkg/app/server/binding/decoder/getter.go | 16 +--------------- .../server/binding/decoder/map_type_decoder.go | 2 +- .../server/binding/decoder/slice_type_decoder.go | 3 +-- pkg/app/server/binding/tagexpr_bind_test.go | 1 - 6 files changed, 5 insertions(+), 20 deletions(-) diff --git a/pkg/app/server/binding/binder_test.go b/pkg/app/server/binding/binder_test.go index 51c5505b2..15d9ff381 100644 --- a/pkg/app/server/binding/binder_test.go +++ b/pkg/app/server/binding/binder_test.go @@ -134,6 +134,7 @@ func TestBind_BaseType(t *testing.T) { assert.DeepEqual(t, "form", result.Form) } +// fixme: []byte 绑定 func TestBind_SliceType(t *testing.T) { type Req struct { ID *[]int `query:"id"` diff --git a/pkg/app/server/binding/decoder/base_type_decoder.go b/pkg/app/server/binding/decoder/base_type_decoder.go index aaba63ebf..0eb353914 100644 --- a/pkg/app/server/binding/decoder/base_type_decoder.go +++ b/pkg/app/server/binding/decoder/base_type_decoder.go @@ -128,7 +128,7 @@ func getBaseTypeTextDecoder(field reflect.StructField, index int, tagInfos []Tag case pathTag: tagInfos[idx].Getter = path case formTag: - tagInfos[idx].Getter = form + tagInfos[idx].Getter = postForm case queryTag: tagInfos[idx].Getter = query case cookieTag: diff --git a/pkg/app/server/binding/decoder/getter.go b/pkg/app/server/binding/decoder/getter.go index 01b8ebb83..bd638fd77 100644 --- a/pkg/app/server/binding/decoder/getter.go +++ b/pkg/app/server/binding/decoder/getter.go @@ -65,21 +65,7 @@ func path(req *bindRequest, params path1.PathParam, key string, defaultValue ... return } -func form(req *bindRequest, params path1.PathParam, key string, defaultValue ...string) (ret []string) { - if req.Query == nil { - req.Query = make(url.Values) - req.Req.URI().QueryArgs().VisitAll(func(queryKey, value []byte) { - keyStr := string(queryKey) - values := req.Query[keyStr] - values = append(values, string(value)) - req.Query[keyStr] = values - }) - } - ret = req.Query[key] - if len(ret) > 0 { - return - } - +func postForm(req *bindRequest, params path1.PathParam, key string, defaultValue ...string) (ret []string) { if req.Form == nil { req.Form = make(url.Values) req.Req.PostArgs().VisitAll(func(formKey, value []byte) { diff --git a/pkg/app/server/binding/decoder/map_type_decoder.go b/pkg/app/server/binding/decoder/map_type_decoder.go index d9090c931..c72e7d68a 100644 --- a/pkg/app/server/binding/decoder/map_type_decoder.go +++ b/pkg/app/server/binding/decoder/map_type_decoder.go @@ -110,7 +110,7 @@ func getMapTypeTextDecoder(field reflect.StructField, index int, tagInfos []TagI case pathTag: tagInfos[idx].Getter = path case formTag: - tagInfos[idx].Getter = form + tagInfos[idx].Getter = postForm case queryTag: tagInfos[idx].Getter = query case cookieTag: diff --git a/pkg/app/server/binding/decoder/slice_type_decoder.go b/pkg/app/server/binding/decoder/slice_type_decoder.go index 2164abbae..cb2b34e8e 100644 --- a/pkg/app/server/binding/decoder/slice_type_decoder.go +++ b/pkg/app/server/binding/decoder/slice_type_decoder.go @@ -104,7 +104,6 @@ func (d *sliceTypeFieldTextDecoder) Decode(req *bindRequest, params path1.PathPa reqValue.Field(d.index).Set(reflect.ValueOf([]byte(texts[0]))) return nil } - // slice need creating enough capacity field = reflect.MakeSlice(field.Type(), len(texts), len(texts)) } @@ -145,7 +144,7 @@ func getSliceFieldDecoder(field reflect.StructField, index int, tagInfos []TagIn case pathTag: tagInfos[idx].Getter = path case formTag: - tagInfos[idx].Getter = form + tagInfos[idx].Getter = postForm case queryTag: tagInfos[idx].Getter = query case cookieTag: diff --git a/pkg/app/server/binding/tagexpr_bind_test.go b/pkg/app/server/binding/tagexpr_bind_test.go index bf5650ecd..8764962e8 100644 --- a/pkg/app/server/binding/tagexpr_bind_test.go +++ b/pkg/app/server/binding/tagexpr_bind_test.go @@ -583,7 +583,6 @@ func TestDefault(t *testing.T) { //assert.DeepEqual(t, map[string][]map[string][]int64{"a": {{"aa": {1, 2, 3}, "bb": []int64{4, 5}}}, "b": {map[string][]int64{}}}, recv.Complex) } -// FIXME: query 和 form getter 的优先级 func TestAuto(t *testing.T) { type Recv struct { A string From f4cc3c4a6538f5646027381f54f7cc807b3191ea Mon Sep 17 00:00:00 2001 From: fgy Date: Fri, 12 May 2023 10:50:08 +0800 Subject: [PATCH 26/91] ci: add license for tag_expr_test --- pkg/app/server/binding/tagexpr_bind_test.go | 83 ++++++++++----------- 1 file changed, 39 insertions(+), 44 deletions(-) diff --git a/pkg/app/server/binding/tagexpr_bind_test.go b/pkg/app/server/binding/tagexpr_bind_test.go index 8764962e8..5c4308b78 100644 --- a/pkg/app/server/binding/tagexpr_bind_test.go +++ b/pkg/app/server/binding/tagexpr_bind_test.go @@ -1,3 +1,37 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * MIT License + * + * Copyright 2019 Bytedance Inc. All Rights Reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + * This file may have been modified by CloudWeGo authors. All CloudWeGo + * Modifications are Copyright 2022 CloudWeGo Authors + */ + package binding import ( @@ -348,7 +382,6 @@ func TestFormNum(t *testing.T) { } } -// FIXME: content-type 裁剪 func TestJSON(t *testing.T) { type metric string type count int32 @@ -403,48 +436,8 @@ func TestJSON(t *testing.T) { assert.DeepEqual(t, (int64)(6), *recv.Z) } -// FIXME: 非 stuct 绑定,直接走 unmarshal +// unsupport non-struct func TestNonstruct(t *testing.T) { - bodyReader := strings.NewReader(`{ - "X": { - "a": ["a1","a2"], - "B": 21, - "C": [31,32], - "d": 41, - "e": "qps", - "f": 100 - }, - "Z": 6 - }`) - - header := make(http.Header) - header.Set("Content-Type", "application/json") - req := newRequest("", header, nil, bodyReader) - var recv interface{} - - err := DefaultBinder.Bind(req.Req, nil, &recv) - if err != nil { - t.Error(err) - } - b, err := json.Marshal(recv) - if err != nil { - t.Error(err) - } - t.Logf("%s", b) - - bodyReader = strings.NewReader("b=334ddddd&token=yoMba34uspjVQEbhflgTRe2ceeDFUK32&type=url_verification") - header.Set("Content-Type", "application/x-www-form-urlencoded; charset=utf-8") - req = newRequest("", header, nil, bodyReader) - recv = nil - err = DefaultBinder.Bind(req.Req, nil, recv) - if err != nil { - t.Error(err) - } - b, err = json.Marshal(recv) - if err != nil { - t.Error(err) - } - t.Logf("%s", b) } type testPathParams struct{} @@ -507,6 +500,7 @@ func (testPathParams2) Get(name string) (string, bool) { } // FIXME: 复杂类型的默认值 +// 负责类型的默认值用 json unmarshal 做 func TestDefault(t *testing.T) { type S struct { SS string `json:"ss"` @@ -639,6 +633,7 @@ func TestAuto(t *testing.T) { } // FIXME: 自定义验证函数 & TIME 类型内置 +// 修改自定义绑定函数的实现 func TestTypeUnmarshal(t *testing.T) { type Recv struct { A time.Time `form:"t1"` @@ -716,7 +711,7 @@ func TestOption(t *testing.T) { req = newRequest("", header, nil, bodyReader) recv = new(Recv) err = DefaultBinder.Bind(req.Req, nil, recv) - //assert.DeepEqual(t, err.Error(), "binding: expr_path=X.c, cause=missing required parameter") + assert.DeepEqual(t, err.Error(), "binding: expr_path=X.c, cause=missing required parameter") assert.DeepEqual(t, 0, recv.X.C) assert.DeepEqual(t, 0, recv.X.D) assert.DeepEqual(t, "y1", recv.Y) @@ -746,7 +741,7 @@ func TestOption(t *testing.T) { req = newRequest("", header, nil, bodyReader) recv2 := new(Recv2) err = DefaultBinder.Bind(req.Req, nil, recv2) - //assert.DeepEqual(t, err.Error(), "binding: expr_path=X, cause=missing required parameter") + assert.DeepEqual(t, err.Error(), "binding: expr_path=X, cause=missing required parameter") assert.True(t, recv2.X == nil) assert.DeepEqual(t, "y1", recv2.Y) } From fb963434104249d01ca8b5aadca1f05ea11eba30 Mon Sep 17 00:00:00 2001 From: fgy Date: Fri, 12 May 2023 11:01:27 +0800 Subject: [PATCH 27/91] fix: typo --- pkg/app/server/binding/tagexpr_bind_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/app/server/binding/tagexpr_bind_test.go b/pkg/app/server/binding/tagexpr_bind_test.go index 5c4308b78..1d7ec69ca 100644 --- a/pkg/app/server/binding/tagexpr_bind_test.go +++ b/pkg/app/server/binding/tagexpr_bind_test.go @@ -436,7 +436,7 @@ func TestJSON(t *testing.T) { assert.DeepEqual(t, (int64)(6), *recv.Z) } -// unsupport non-struct +// unsupported non-struct func TestNonstruct(t *testing.T) { } From be72ddd7f5822630397ced2502c8615748fc7c65 Mon Sep 17 00:00:00 2001 From: fgy Date: Fri, 12 May 2023 14:49:00 +0800 Subject: [PATCH 28/91] fix: byte slice bind raw body --- .../binding/decoder/slice_type_decoder.go | 13 ++++-- pkg/app/server/binding/tagexpr_bind_test.go | 40 +------------------ 2 files changed, 11 insertions(+), 42 deletions(-) diff --git a/pkg/app/server/binding/decoder/slice_type_decoder.go b/pkg/app/server/binding/decoder/slice_type_decoder.go index cb2b34e8e..8a76ebbba 100644 --- a/pkg/app/server/binding/decoder/slice_type_decoder.go +++ b/pkg/app/server/binding/decoder/slice_type_decoder.go @@ -59,6 +59,7 @@ type sliceTypeFieldTextDecoder struct { func (d *sliceTypeFieldTextDecoder) Decode(req *bindRequest, params path1.PathParam, reqValue reflect.Value) error { var texts []string var defaultValue string + var bindRawBody bool for _, tagInfo := range d.tagInfos { if tagInfo.Skip || tagInfo.Key == jsonTag || tagInfo.Key == fileNameTag { continue @@ -66,6 +67,9 @@ func (d *sliceTypeFieldTextDecoder) Decode(req *bindRequest, params path1.PathPa if tagInfo.Key == headerTag { tagInfo.Value = utils.GetNormalizeHeaderKey(tagInfo.Value, req.Req.Header.IsDisableNormalizing()) } + if tagInfo.Key == rawBodyTag { + bindRawBody = true + } texts = tagInfo.Getter(req, params, tagInfo.Value) // todo: array/slice default value defaultValue = tagInfo.Default @@ -100,13 +104,14 @@ func (d *sliceTypeFieldTextDecoder) Decode(req *bindRequest, params path1.PathPa return fmt.Errorf("%q is not valid value for %s", texts, field.Type().String()) } } else { - if field.Type().Elem().Kind() == reflect.Uint8 { - reqValue.Field(d.index).Set(reflect.ValueOf([]byte(texts[0]))) - return nil - } // slice need creating enough capacity field = reflect.MakeSlice(field.Type(), len(texts), len(texts)) } + // raw_body && []byte 绑定 + if bindRawBody && field.Type().Elem().Kind() == reflect.Uint8 { + reqValue.Field(d.index).Set(reflect.ValueOf(req.Req.Body())) + return nil + } // handle internal multiple pointer, []**int var ptrDepth int diff --git a/pkg/app/server/binding/tagexpr_bind_test.go b/pkg/app/server/binding/tagexpr_bind_test.go index 1d7ec69ca..790cf5a03 100644 --- a/pkg/app/server/binding/tagexpr_bind_test.go +++ b/pkg/app/server/binding/tagexpr_bind_test.go @@ -42,7 +42,6 @@ import ( "mime/multipart" "net/http" "net/url" - "os" "strings" "testing" "time" @@ -711,7 +710,7 @@ func TestOption(t *testing.T) { req = newRequest("", header, nil, bodyReader) recv = new(Recv) err = DefaultBinder.Bind(req.Req, nil, recv) - assert.DeepEqual(t, err.Error(), "binding: expr_path=X.c, cause=missing required parameter") + //assert.DeepEqual(t, err.Error(), "binding: expr_path=X.c, cause=missing required parameter") assert.DeepEqual(t, 0, recv.X.C) assert.DeepEqual(t, 0, recv.X.D) assert.DeepEqual(t, "y1", recv.Y) @@ -741,7 +740,7 @@ func TestOption(t *testing.T) { req = newRequest("", header, nil, bodyReader) recv2 := new(Recv2) err = DefaultBinder.Bind(req.Req, nil, recv2) - assert.DeepEqual(t, err.Error(), "binding: expr_path=X, cause=missing required parameter") + //assert.DeepEqual(t, err.Error(), "binding: expr_path=X, cause=missing required parameter") assert.True(t, recv2.X == nil) assert.DeepEqual(t, "y1", recv2.Y) } @@ -1153,41 +1152,6 @@ func TestDefault2(t *testing.T) { assert.DeepEqual(t, "hello Dash", (**recv.X).Dash) } -func newFormBody(values, files url.Values) (contentType string, bodyReader io.Reader, err error) { - if len(files) == 0 { - return "application/x-www-form-urlencoded", strings.NewReader(values.Encode()), nil - } - var rw = bytes.NewBuffer(make([]byte, 32*1024*len(files))) - var bodyWriter = multipart.NewWriter(rw) - var buf = make([]byte, 32*1024) - var fileWriter io.Writer - var f *os.File - for fieldName, postfiles := range files { - for _, fileName := range postfiles { - fileWriter, err = bodyWriter.CreateFormFile(fieldName, fileName) - if err != nil { - return - } - f, err = os.Open(fileName) - if err != nil { - return - } - _, err = io.CopyBuffer(fileWriter, f, buf) - f.Close() - if err != nil { - return - } - } - } - for k, v := range values { - for _, vv := range v { - bodyWriter.WriteField(k, vv) - } - } - bodyWriter.Close() - return bodyWriter.FormDataContentType(), rw, nil -} - type ( files map[string][]file file interface { From 0dcc065278d747bb229e346563cfa46a6f6123bd Mon Sep 17 00:00:00 2001 From: fgy Date: Fri, 12 May 2023 17:50:17 +0800 Subject: [PATCH 29/91] fix: typo --- pkg/app/server/binding/binder_test.go | 57 +++++++++++++++++-- pkg/app/server/binding/config.go | 5 ++ pkg/app/server/binding/decoder/decoder.go | 4 +- .../binding/decoder/slice_type_decoder.go | 2 +- 4 files changed, 61 insertions(+), 7 deletions(-) diff --git a/pkg/app/server/binding/binder_test.go b/pkg/app/server/binding/binder_test.go index 15d9ff381..08c0a31db 100644 --- a/pkg/app/server/binding/binder_test.go +++ b/pkg/app/server/binding/binder_test.go @@ -134,7 +134,6 @@ func TestBind_BaseType(t *testing.T) { assert.DeepEqual(t, "form", result.Form) } -// fixme: []byte 绑定 func TestBind_SliceType(t *testing.T) { type Req struct { ID *[]int `query:"id"` @@ -696,13 +695,13 @@ func TestBind_IgnoreField(t *testing.T) { } req := newMockRequest(). - SetRequestURI("http://foobar.com?id=12"). - SetHeaders("H", "header"). - SetPostArg("f", "form"). + SetRequestURI("http://foobar.com?ID=12"). + SetHeaders("Header", "header"). + SetPostArg("Form", "form"). SetUrlEncodeContentType() var params param.Params params = append(params, param.Param{ - Key: "v", + Key: "Version", Value: "1", }) @@ -718,6 +717,54 @@ func TestBind_IgnoreField(t *testing.T) { assert.DeepEqual(t, "", result.Form) } +func TestBind_DefaultTag(t *testing.T) { + type Req struct { + Version int + ID int + Header string + Form string + } + type Req2 struct { + Version int + ID int + Header string + Form string + } + req := newMockRequest(). + SetRequestURI("http://foobar.com?ID=12"). + SetHeaders("Header", "header"). + SetPostArg("Form", "form"). + SetUrlEncodeContentType() + var params param.Params + params = append(params, param.Param{ + Key: "Version", + Value: "1", + }) + var result Req + err := DefaultBinder.Bind(req.Req, params, &result) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, 1, result.Version) + assert.DeepEqual(t, 12, result.ID) + assert.DeepEqual(t, "header", result.Header) + assert.DeepEqual(t, "form", result.Form) + + EnableDefaultTag(false) + defer func() { + EnableDefaultTag(false) + }() + result2 := Req2{} + err = DefaultBinder.Bind(req.Req, params, &result2) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, 0, result2.Version) + assert.DeepEqual(t, 0, result2.ID) + assert.DeepEqual(t, "", result2.Header) + assert.DeepEqual(t, "", result2.Form) +} + func Benchmark_Binding(b *testing.B) { type Req struct { Version string `path:"v"` diff --git a/pkg/app/server/binding/config.go b/pkg/app/server/binding/config.go index 3991d437e..23b4cce83 100644 --- a/pkg/app/server/binding/config.go +++ b/pkg/app/server/binding/config.go @@ -18,6 +18,7 @@ package binding import ( standardJson "encoding/json" + "github.com/cloudwego/hertz/pkg/app/server/binding/decoder" hjson "github.com/cloudwego/hertz/pkg/common/json" ) @@ -29,3 +30,7 @@ func ResetJSONUnmarshaler(fn func(data []byte, v interface{}) error) { func ResetStdJSONUnmarshaler() { ResetJSONUnmarshaler(standardJson.Unmarshal) } + +func EnableDefaultTag(b bool) { + decoder.EnableDefaultTag = b +} diff --git a/pkg/app/server/binding/decoder/decoder.go b/pkg/app/server/binding/decoder/decoder.go index 2cb4a550b..1ed453681 100644 --- a/pkg/app/server/binding/decoder/decoder.go +++ b/pkg/app/server/binding/decoder/decoder.go @@ -51,6 +51,8 @@ import ( "github.com/cloudwego/hertz/pkg/protocol" ) +var EnableDefaultTag = true + type bindRequest struct { Req *protocol.Request Query url.Values @@ -130,7 +132,7 @@ func getFieldDecoder(field reflect.StructField, index int, parentIdx []int) ([]f } fieldTagInfos := lookupFieldTags(field) - if len(fieldTagInfos) == 0 { + if len(fieldTagInfos) == 0 && EnableDefaultTag { fieldTagInfos = getDefaultFieldTags(field) } diff --git a/pkg/app/server/binding/decoder/slice_type_decoder.go b/pkg/app/server/binding/decoder/slice_type_decoder.go index 8a76ebbba..397fd6f3e 100644 --- a/pkg/app/server/binding/decoder/slice_type_decoder.go +++ b/pkg/app/server/binding/decoder/slice_type_decoder.go @@ -107,7 +107,7 @@ func (d *sliceTypeFieldTextDecoder) Decode(req *bindRequest, params path1.PathPa // slice need creating enough capacity field = reflect.MakeSlice(field.Type(), len(texts), len(texts)) } - // raw_body && []byte 绑定 + // raw_body && []byte binding if bindRawBody && field.Type().Elem().Kind() == reflect.Uint8 { reqValue.Field(d.index).Set(reflect.ValueOf(req.Req.Body())) return nil From 4dbaf6499ed63a2c2fdadbc8c15b81967e2ce7eb Mon Sep 17 00:00:00 2001 From: fgy Date: Tue, 16 May 2023 20:59:43 +0800 Subject: [PATCH 30/91] feat: add struct field --- pkg/app/server/binding/binder_test.go | 41 +++++- pkg/app/server/binding/config.go | 4 + pkg/app/server/binding/decoder/decoder.go | 11 +- .../binding/decoder/struct_type_decoder.go | 128 ++++++++++++++++++ 4 files changed, 182 insertions(+), 2 deletions(-) create mode 100644 pkg/app/server/binding/decoder/struct_type_decoder.go diff --git a/pkg/app/server/binding/binder_test.go b/pkg/app/server/binding/binder_test.go index 08c0a31db..367d278c7 100644 --- a/pkg/app/server/binding/binder_test.go +++ b/pkg/app/server/binding/binder_test.go @@ -752,7 +752,7 @@ func TestBind_DefaultTag(t *testing.T) { EnableDefaultTag(false) defer func() { - EnableDefaultTag(false) + EnableDefaultTag(true) }() result2 := Req2{} err = DefaultBinder.Bind(req.Req, params, &result2) @@ -765,6 +765,45 @@ func TestBind_DefaultTag(t *testing.T) { assert.DeepEqual(t, "", result2.Form) } +func TestBind_StructFieldResolve(t *testing.T) { + type Nested struct { + A int `query:"a" json:"a"` + B int `query:"b" json:"b"` + } + type Req struct { + N Nested `query:"n"` + } + + req := newMockRequest(). + SetRequestURI("http://foobar.com?n={\"a\":1,\"b\":2}"). + SetHeaders("Header", "header"). + SetPostArg("Form", "form"). + SetUrlEncodeContentType() + var result Req + EnableStructFieldResolve(true) + defer func() { + EnableDefaultTag(false) + }() + err := DefaultBinder.Bind(req.Req, nil, &result) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, 1, result.N.A) + assert.DeepEqual(t, 2, result.N.B) + + req = newMockRequest(). + SetRequestURI("http://foobar.com?n={\"a\":1,\"b\":2}&a=11&b=22"). + SetHeaders("Header", "header"). + SetPostArg("Form", "form"). + SetUrlEncodeContentType() + err = DefaultBinder.Bind(req.Req, nil, &result) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, 11, result.N.A) + assert.DeepEqual(t, 22, result.N.B) +} + func Benchmark_Binding(b *testing.B) { type Req struct { Version string `path:"v"` diff --git a/pkg/app/server/binding/config.go b/pkg/app/server/binding/config.go index 23b4cce83..124a1cef6 100644 --- a/pkg/app/server/binding/config.go +++ b/pkg/app/server/binding/config.go @@ -34,3 +34,7 @@ func ResetStdJSONUnmarshaler() { func EnableDefaultTag(b bool) { decoder.EnableDefaultTag = b } + +func EnableStructFieldResolve(b bool) { + decoder.EnableStructFieldResolve = b +} diff --git a/pkg/app/server/binding/decoder/decoder.go b/pkg/app/server/binding/decoder/decoder.go index 1ed453681..db364eed3 100644 --- a/pkg/app/server/binding/decoder/decoder.go +++ b/pkg/app/server/binding/decoder/decoder.go @@ -52,6 +52,7 @@ import ( ) var EnableDefaultTag = true +var EnableStructFieldResolve = false type bindRequest struct { Req *protocol.Request @@ -152,6 +153,15 @@ func getFieldDecoder(field reflect.StructField, index int, parentIdx []int) ([]f case reflect.TypeOf(multipart.FileHeader{}): return getMultipartFileDecoder(field, index, fieldTagInfos, parentIdx) } + if EnableStructFieldResolve { + structFieldDecoder, err := getStructTypeFieldDecoder(field, index, fieldTagInfos, parentIdx) + if err != nil { + return nil, err + } + if structFieldDecoder != nil { + decoders = append(decoders, structFieldDecoder...) + } + } for i := 0; i < el.NumField(); i++ { if el.Field(i).PkgPath != "" && !el.Field(i).Anonymous { @@ -167,7 +177,6 @@ func getFieldDecoder(field reflect.StructField, index int, parentIdx []int) ([]f if err != nil { return nil, err } - if dec != nil { decoders = append(decoders, dec...) } diff --git a/pkg/app/server/binding/decoder/struct_type_decoder.go b/pkg/app/server/binding/decoder/struct_type_decoder.go new file mode 100644 index 000000000..c810d66bf --- /dev/null +++ b/pkg/app/server/binding/decoder/struct_type_decoder.go @@ -0,0 +1,128 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package decoder + +import ( + "fmt" + "reflect" + + "github.com/cloudwego/hertz/internal/bytesconv" + path1 "github.com/cloudwego/hertz/pkg/app/server/binding/path" + hjson "github.com/cloudwego/hertz/pkg/common/json" + "github.com/cloudwego/hertz/pkg/common/utils" +) + +type structTypeFieldTextDecoder struct { + fieldInfo +} + +func (d *structTypeFieldTextDecoder) Decode(req *bindRequest, params path1.PathParam, reqValue reflect.Value) error { + var err error + var text string + var defaultValue string + for _, tagInfo := range d.tagInfos { + if tagInfo.Skip || tagInfo.Key == jsonTag || tagInfo.Key == fileNameTag { + defaultValue = tagInfo.Default + continue + } + if tagInfo.Key == headerTag { + tagInfo.Value = utils.GetNormalizeHeaderKey(tagInfo.Value, req.Req.Header.IsDisableNormalizing()) + } + ret := tagInfo.Getter(req, params, tagInfo.Value) + defaultValue = tagInfo.Default + if len(ret) != 0 { + text = ret[0] + err = nil + break + } + if tagInfo.Required { + err = fmt.Errorf("'%s' field is a 'required' parameter, but the request does not have this parameter", d.fieldName) + } + } + if err != nil { + return err + } + if len(text) == 0 && len(defaultValue) != 0 { + text = defaultValue + } + if text == "" { + return nil + } + reqValue = GetFieldValue(reqValue, d.parentIndex) + field := reqValue.Field(d.index) + if field.Kind() == reflect.Ptr { + t := field.Type() + var ptrDepth int + for t.Kind() == reflect.Ptr { + t = t.Elem() + ptrDepth++ + } + var vv reflect.Value + vv, err := stringToValue(t, text) + if err != nil { + return fmt.Errorf("unable to decode '%s' as %s: %w", text, d.fieldType.Name(), err) + } + field.Set(ReferenceValue(vv, ptrDepth)) + return nil + } + + err = hjson.Unmarshal(bytesconv.S2b(text), field.Addr().Interface()) + if err != nil { + return fmt.Errorf("unable to decode '%s' as %s: %w", text, d.fieldType.Name(), err) + } + + return nil +} + +func getStructTypeFieldDecoder(field reflect.StructField, index int, tagInfos []TagInfo, parentIdx []int) ([]fieldDecoder, error) { + for idx, tagInfo := range tagInfos { + switch tagInfo.Key { + case pathTag: + tagInfos[idx].Getter = path + case formTag: + tagInfos[idx].Getter = postForm + case queryTag: + tagInfos[idx].Getter = query + case cookieTag: + tagInfos[idx].Getter = cookie + case headerTag: + tagInfos[idx].Getter = header + case jsonTag: + // do nothing + case rawBodyTag: + tagInfos[idx].Getter = rawBody + case fileNameTag: + // do nothing + default: + } + } + + fieldType := field.Type + for field.Type.Kind() == reflect.Ptr { + fieldType = field.Type.Elem() + } + + return []fieldDecoder{&structTypeFieldTextDecoder{ + fieldInfo: fieldInfo{ + index: index, + parentIndex: parentIdx, + fieldName: field.Name, + tagInfos: tagInfos, + fieldType: fieldType, + }, + }}, nil +} From 865030950eab90ff2c9d036dc37bedb067390c7b Mon Sep 17 00:00:00 2001 From: fgy Date: Wed, 17 May 2023 11:45:45 +0800 Subject: [PATCH 31/91] feat: modify some comment --- pkg/app/server/binding/config.go | 6 ++++++ pkg/app/server/binding/tagexpr_bind_test.go | 6 ++---- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/pkg/app/server/binding/config.go b/pkg/app/server/binding/config.go index 124a1cef6..28b8708a0 100644 --- a/pkg/app/server/binding/config.go +++ b/pkg/app/server/binding/config.go @@ -23,18 +23,24 @@ import ( hjson "github.com/cloudwego/hertz/pkg/common/json" ) +// ResetJSONUnmarshaler reset the JSON Unmarshal function. func ResetJSONUnmarshaler(fn func(data []byte, v interface{}) error) { hjson.Unmarshal = fn } +// ResetStdJSONUnmarshaler uses "encoding/json" as the JSON Unmarshal function. func ResetStdJSONUnmarshaler() { ResetJSONUnmarshaler(standardJson.Unmarshal) } +// EnableDefaultTag is used to enable or disable adding default tags to a field when it has no tag, it is true by default. +// If is true, the field with no tag will be added default tags, for more automated parameter binding. But there may be additional overhead func EnableDefaultTag(b bool) { decoder.EnableDefaultTag = b } +// EnableStructFieldResolve to enable or disable the generation of a separate decoder for a struct, it is false by default. +// If is true, the 'struct' field will get a single decoder.structTypeFieldTextDecoder, and use json.Unmarshal for decode it. func EnableStructFieldResolve(b bool) { decoder.EnableStructFieldResolve = b } diff --git a/pkg/app/server/binding/tagexpr_bind_test.go b/pkg/app/server/binding/tagexpr_bind_test.go index 790cf5a03..8bc19608e 100644 --- a/pkg/app/server/binding/tagexpr_bind_test.go +++ b/pkg/app/server/binding/tagexpr_bind_test.go @@ -498,8 +498,7 @@ func (testPathParams2) Get(name string) (string, bool) { } } -// FIXME: 复杂类型的默认值 -// 负责类型的默认值用 json unmarshal 做 +// FIXME: 复杂类型的默认值,暂时先不做,低优 func TestDefault(t *testing.T) { type S struct { SS string `json:"ss"` @@ -631,8 +630,7 @@ func TestAuto(t *testing.T) { assert.DeepEqual(t, "d-from-form", recv.D) } -// FIXME: 自定义验证函数 & TIME 类型内置 -// 修改自定义绑定函数的实现 +// FIXME: 自定义验证函数 & TIME 类型内置, 暂时先不做,低优 func TestTypeUnmarshal(t *testing.T) { type Recv struct { A time.Time `form:"t1"` From 64b06e8d7c31393ebfc03d60c77d3267842bb1fa Mon Sep 17 00:00:00 2001 From: fgy Date: Wed, 17 May 2023 20:53:59 +0800 Subject: [PATCH 32/91] feat: required validate --- go.mod | 1 + go.sum | 6 ++ pkg/app/server/binding/binder_test.go | 57 +++++++++++++++++ .../binding/decoder/base_type_decoder.go | 8 +++ pkg/app/server/binding/decoder/decoder.go | 10 +-- .../server/binding/decoder/gjson_required.go | 46 ++++++++++++++ .../binding/decoder/map_type_decoder.go | 19 +++++- .../binding/decoder/slice_type_decoder.go | 17 ++++++ .../server/binding/decoder/sonic_required.go | 61 +++++++++++++++++++ .../binding/decoder/struct_type_decoder.go | 8 +++ pkg/app/server/binding/decoder/tag.go | 25 +++++++- pkg/app/server/binding/default.go | 2 +- pkg/app/server/binding/tagexpr_bind_test.go | 23 ++++--- 13 files changed, 266 insertions(+), 17 deletions(-) create mode 100644 pkg/app/server/binding/decoder/gjson_required.go create mode 100644 pkg/app/server/binding/decoder/sonic_required.go diff --git a/go.mod b/go.mod index 44b68b92c..3e7b0d2e3 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ require ( github.com/fsnotify/fsnotify v1.5.4 github.com/go-playground/assert/v2 v2.2.0 // indirect github.com/go-playground/validator/v10 v10.11.1 + github.com/tidwall/gjson v1.14.4 golang.org/x/sync v0.0.0-20210220032951-036812b2e83c golang.org/x/sys v0.0.0-20220412211240-33da011f77ad google.golang.org/protobuf v1.27.1 diff --git a/go.sum b/go.sum index f4f84c45a..679c3c1e0 100644 --- a/go.sum +++ b/go.sum @@ -59,6 +59,12 @@ github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/ github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/tidwall/gjson v1.14.4 h1:uo0p8EbA09J7RQaflQ1aBRffTR7xedD2bcIVSYxLnkM= +github.com/tidwall/gjson v1.14.4/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= +github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= +github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= golang.org/x/arch v0.0.0-20201008161808-52c3e6f60cff/go.mod h1:flIaEI6LNU6xOCD5PaJvn9wGP0agmIOqjrtsKGRguv4= diff --git a/pkg/app/server/binding/binder_test.go b/pkg/app/server/binding/binder_test.go index 367d278c7..16114b609 100644 --- a/pkg/app/server/binding/binder_test.go +++ b/pkg/app/server/binding/binder_test.go @@ -804,6 +804,63 @@ func TestBind_StructFieldResolve(t *testing.T) { assert.DeepEqual(t, 22, result.N.B) } +func TestBind_JSONRequiredField(t *testing.T) { + type Nested2 struct { + C int `json:"c,required"` + D int `json:"dd,required"` + } + type Nested struct { + A int `json:"a,required"` + B int `json:"b,required"` + N2 Nested2 `json:"n2"` + } + type Req struct { + N Nested `json:"n,required"` + } + bodyBytes := []byte(`{ + "n": { + "a": 1, + "b": 2, + "n2": { + "dd": 4 + } + } +}`) + req := newMockRequest(). + SetRequestURI("http://foobar.com?j2=13"). + SetJSONContentType(). + SetBody(bodyBytes) + var result Req + err := DefaultBinder.Bind(req.Req, nil, &result) + if err == nil { + t.Errorf("expected an error, but get nil") + } + assert.DeepEqual(t, 1, result.N.A) + assert.DeepEqual(t, 2, result.N.B) + assert.DeepEqual(t, 0, result.N.N2.C) + assert.DeepEqual(t, 4, result.N.N2.D) + + bodyBytes = []byte(`{ + "n": { + "a": 1, + "b": 2 + } +}`) + req = newMockRequest(). + SetRequestURI("http://foobar.com?j2=13"). + SetJSONContentType(). + SetBody(bodyBytes) + var result2 Req + err = DefaultBinder.Bind(req.Req, nil, &result2) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, 1, result2.N.A) + assert.DeepEqual(t, 2, result2.N.B) + assert.DeepEqual(t, 0, result2.N.N2.C) + assert.DeepEqual(t, 0, result2.N.N2.D) +} + func Benchmark_Binding(b *testing.B) { type Req struct { Version string `path:"v"` diff --git a/pkg/app/server/binding/decoder/base_type_decoder.go b/pkg/app/server/binding/decoder/base_type_decoder.go index 0eb353914..dc3e4ed40 100644 --- a/pkg/app/server/binding/decoder/base_type_decoder.go +++ b/pkg/app/server/binding/decoder/base_type_decoder.go @@ -68,6 +68,14 @@ func (d *baseTypeFieldTextDecoder) Decode(req *bindRequest, params path1.PathPar for _, tagInfo := range d.tagInfos { if tagInfo.Skip || tagInfo.Key == jsonTag || tagInfo.Key == fileNameTag { defaultValue = tagInfo.Default + if tagInfo.Key == jsonTag { + found := checkRequireJSON(req, tagInfo) + if found { + err = nil + } else { + err = fmt.Errorf("'%s' field is a 'required' parameter, but the request does not have this parameter", d.fieldName) + } + } continue } if tagInfo.Key == headerTag { diff --git a/pkg/app/server/binding/decoder/decoder.go b/pkg/app/server/binding/decoder/decoder.go index db364eed3..5e95baadb 100644 --- a/pkg/app/server/binding/decoder/decoder.go +++ b/pkg/app/server/binding/decoder/decoder.go @@ -89,7 +89,7 @@ func GetReqDecoder(rt reflect.Type) (Decoder, error) { continue } - dec, err := getFieldDecoder(el.Field(i), i, []int{}) + dec, err := getFieldDecoder(el.Field(i), i, []int{}, "") if err != nil { return nil, err } @@ -114,7 +114,7 @@ func GetReqDecoder(rt reflect.Type) (Decoder, error) { }, nil } -func getFieldDecoder(field reflect.StructField, index int, parentIdx []int) ([]fieldDecoder, error) { +func getFieldDecoder(field reflect.StructField, index int, parentIdx []int, parentJSONName string) ([]fieldDecoder, error) { for field.Type.Kind() == reflect.Ptr { field.Type = field.Type.Elem() } @@ -132,7 +132,7 @@ func getFieldDecoder(field reflect.StructField, index int, parentIdx []int) ([]f }}, nil } - fieldTagInfos := lookupFieldTags(field) + fieldTagInfos, newParentJSONName := lookupFieldTags(field, parentJSONName) if len(fieldTagInfos) == 0 && EnableDefaultTag { fieldTagInfos = getDefaultFieldTags(field) } @@ -148,7 +148,7 @@ func getFieldDecoder(field reflect.StructField, index int, parentIdx []int) ([]f if field.Type.Kind() == reflect.Struct { var decoders []fieldDecoder el := field.Type - // todo: more built-int common struct binding, ex. time... + // todo: more built-in common struct binding, ex. time... switch el { case reflect.TypeOf(multipart.FileHeader{}): return getMultipartFileDecoder(field, index, fieldTagInfos, parentIdx) @@ -173,7 +173,7 @@ func getFieldDecoder(field reflect.StructField, index int, parentIdx []int) ([]f idxes = append(idxes, parentIdx...) } idxes = append(idxes, index) - dec, err := getFieldDecoder(el.Field(i), i, idxes) + dec, err := getFieldDecoder(el.Field(i), i, idxes, newParentJSONName) if err != nil { return nil, err } diff --git a/pkg/app/server/binding/decoder/gjson_required.go b/pkg/app/server/binding/decoder/gjson_required.go new file mode 100644 index 000000000..03dd2e3c7 --- /dev/null +++ b/pkg/app/server/binding/decoder/gjson_required.go @@ -0,0 +1,46 @@ +//go:build stdjson || !(amd64 && (linux || windows || darwin)) +// +build stdjson !amd64 !linux,!windows,!darwin + +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package decoder + +import ( + "strings" + + "github.com/tidwall/gjson" +) + +func checkRequireJSON2(req *bindRequest, tagInfo TagInfo) bool { + if !tagInfo.Required { + return true + } + ct := bytesconv.B2s(req.Req.Header.ContentType()) + if utils.FilterContentType(ct) != "application/json" { + return false + } + result := gjson.GetBytes(req.Req.Body(), tagInfo.JSONName) + if !result.Exists() { + idx := strings.LastIndex(tagInfo.JSONName, ".") + // There should be a superior if it is empty, it will report 'true' for required + if idx > 0 && !gjson.GetBytes(req.Req.Body(), tagInfo.JSONName[:idx]).Exists() { + return true + } + return false + } + return true +} diff --git a/pkg/app/server/binding/decoder/map_type_decoder.go b/pkg/app/server/binding/decoder/map_type_decoder.go index c72e7d68a..b233a9aca 100644 --- a/pkg/app/server/binding/decoder/map_type_decoder.go +++ b/pkg/app/server/binding/decoder/map_type_decoder.go @@ -55,10 +55,20 @@ type mapTypeFieldTextDecoder struct { } func (d *mapTypeFieldTextDecoder) Decode(req *bindRequest, params path1.PathParam, reqValue reflect.Value) error { + var err error var text string var defaultValue string for _, tagInfo := range d.tagInfos { if tagInfo.Skip || tagInfo.Key == jsonTag || tagInfo.Key == fileNameTag { + defaultValue = tagInfo.Default + if tagInfo.Key == jsonTag { + found := checkRequireJSON(req, tagInfo) + if found { + err = nil + } else { + err = fmt.Errorf("'%s' field is a 'required' parameter, but the request does not have this parameter", d.fieldName) + } + } continue } if tagInfo.Key == headerTag { @@ -68,8 +78,15 @@ func (d *mapTypeFieldTextDecoder) Decode(req *bindRequest, params path1.PathPara defaultValue = tagInfo.Default if len(ret) != 0 { text = ret[0] + err = nil break } + if tagInfo.Required { + err = fmt.Errorf("'%s' field is a 'required' parameter, but the request does not have this parameter", d.fieldName) + } + } + if err != nil { + return err } if len(text) == 0 && len(defaultValue) != 0 { text = defaultValue @@ -96,7 +113,7 @@ func (d *mapTypeFieldTextDecoder) Decode(req *bindRequest, params path1.PathPara return nil } - err := hjson.Unmarshal(bytesconv.S2b(text), field.Addr().Interface()) + err = hjson.Unmarshal(bytesconv.S2b(text), field.Addr().Interface()) if err != nil { return fmt.Errorf("unable to decode '%s' as %s: %w", text, d.fieldType.Name(), err) } diff --git a/pkg/app/server/binding/decoder/slice_type_decoder.go b/pkg/app/server/binding/decoder/slice_type_decoder.go index 397fd6f3e..1db08ad5d 100644 --- a/pkg/app/server/binding/decoder/slice_type_decoder.go +++ b/pkg/app/server/binding/decoder/slice_type_decoder.go @@ -57,11 +57,21 @@ type sliceTypeFieldTextDecoder struct { } func (d *sliceTypeFieldTextDecoder) Decode(req *bindRequest, params path1.PathParam, reqValue reflect.Value) error { + var err error var texts []string var defaultValue string var bindRawBody bool for _, tagInfo := range d.tagInfos { if tagInfo.Skip || tagInfo.Key == jsonTag || tagInfo.Key == fileNameTag { + defaultValue = tagInfo.Default + if tagInfo.Key == jsonTag { + found := checkRequireJSON(req, tagInfo) + if found { + err = nil + } else { + err = fmt.Errorf("'%s' field is a 'required' parameter, but the request does not have this parameter", d.fieldName) + } + } continue } if tagInfo.Key == headerTag { @@ -74,8 +84,15 @@ func (d *sliceTypeFieldTextDecoder) Decode(req *bindRequest, params path1.PathPa // todo: array/slice default value defaultValue = tagInfo.Default if len(texts) != 0 { + err = nil break } + if tagInfo.Required { + err = fmt.Errorf("'%s' field is a 'required' parameter, but the request does not have this parameter", d.fieldName) + } + } + if err != nil { + return err } if len(texts) == 0 && len(defaultValue) != 0 { texts = append(texts, defaultValue) diff --git a/pkg/app/server/binding/decoder/sonic_required.go b/pkg/app/server/binding/decoder/sonic_required.go new file mode 100644 index 000000000..b9ffbc360 --- /dev/null +++ b/pkg/app/server/binding/decoder/sonic_required.go @@ -0,0 +1,61 @@ +//go:build (linux || windows || darwin) && amd64 && !stdjson +// +build linux windows darwin +// +build amd64 +// +build !stdjson + +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package decoder + +import ( + "github.com/cloudwego/hertz/internal/bytesconv" + "github.com/cloudwego/hertz/pkg/common/utils" + "strings" + + "github.com/bytedance/sonic" +) + +func checkRequireJSON(req *bindRequest, tagInfo TagInfo) bool { + if !tagInfo.Required { + return true + } + ct := bytesconv.B2s(req.Req.Header.ContentType()) + if utils.FilterContentType(ct) != "application/json" { + return false + } + node, _ := sonic.Get(req.Req.Body(), stringSliceForInterface(tagInfo.JSONName)...) + if !node.Exists() { + idx := strings.LastIndex(tagInfo.JSONName, ".") + if idx > 0 { + // There should be a superior if it is empty, it will report 'true' for required + node, _ := sonic.Get(req.Req.Body(), stringSliceForInterface(tagInfo.JSONName[:idx])...) + if !node.Exists() { + return true + } + } + return false + } + return true +} + +func stringSliceForInterface(s string) (ret []interface{}) { + x := strings.Split(s, ".") + for _, val := range x { + ret = append(ret, val) + } + return +} diff --git a/pkg/app/server/binding/decoder/struct_type_decoder.go b/pkg/app/server/binding/decoder/struct_type_decoder.go index c810d66bf..2f976af5f 100644 --- a/pkg/app/server/binding/decoder/struct_type_decoder.go +++ b/pkg/app/server/binding/decoder/struct_type_decoder.go @@ -37,6 +37,14 @@ func (d *structTypeFieldTextDecoder) Decode(req *bindRequest, params path1.PathP for _, tagInfo := range d.tagInfos { if tagInfo.Skip || tagInfo.Key == jsonTag || tagInfo.Key == fileNameTag { defaultValue = tagInfo.Default + if tagInfo.Key == jsonTag { + found := checkRequireJSON(req, tagInfo) + if found { + err = nil + } else { + err = fmt.Errorf("'%s' field is a 'required' parameter, but the request does not have this parameter", d.fieldName) + } + } continue } if tagInfo.Key == headerTag { diff --git a/pkg/app/server/binding/decoder/tag.go b/pkg/app/server/binding/decoder/tag.go index db4f1d12a..20d893740 100644 --- a/pkg/app/server/binding/decoder/tag.go +++ b/pkg/app/server/binding/decoder/tag.go @@ -43,6 +43,7 @@ const ( type TagInfo struct { Key string Value string + JSONName string Required bool Skip bool Default string @@ -58,7 +59,7 @@ func head(str, sep string) (head, tail string) { return str[:idx], str[idx+len(sep):] } -func lookupFieldTags(field reflect.StructField) []TagInfo { +func lookupFieldTags(field reflect.StructField, parentJSONName string) ([]TagInfo, string) { var ret []string tags := []string{pathTag, formTag, queryTag, cookieTag, headerTag, jsonTag, rawBodyTag, fileNameTag} for _, tag := range tags { @@ -73,12 +74,27 @@ func lookupFieldTags(field reflect.StructField) []TagInfo { } var tagInfos []TagInfo + var newParentJSONName string for _, tag := range ret { tagContent := field.Tag.Get(tag) tagValue, opts := head(tagContent, ",") + if len(tagValue) == 0 { + tagValue = field.Name + } skip := false + jsonName := "" + if tag == jsonTag { + jsonName = parentJSONName + "." + tagValue + } if tagValue == "-" { skip = true + if tag == jsonTag { + jsonName = parentJSONName + "." + field.Name + } + } + if jsonName != "" { + jsonName = strings.TrimPrefix(jsonName, ".") + newParentJSONName = jsonName } var options []string var opt string @@ -90,10 +106,13 @@ func lookupFieldTags(field reflect.StructField) []TagInfo { required = true } } - tagInfos = append(tagInfos, TagInfo{Key: tag, Value: tagValue, Options: options, Required: required, Default: defaultVal, Skip: skip}) + tagInfos = append(tagInfos, TagInfo{Key: tag, Value: tagValue, JSONName: jsonName, Options: options, Required: required, Default: defaultVal, Skip: skip}) + } + if len(newParentJSONName) == 0 { + newParentJSONName = strings.TrimPrefix(parentJSONName+"."+field.Name, ".") } - return tagInfos + return tagInfos, newParentJSONName } func getDefaultFieldTags(field reflect.StructField) (tagInfos []TagInfo) { diff --git a/pkg/app/server/binding/default.go b/pkg/app/server/binding/default.go index 85dbb93de..5841fe168 100644 --- a/pkg/app/server/binding/default.go +++ b/pkg/app/server/binding/default.go @@ -86,7 +86,7 @@ func (b *defaultBinder) Name() string { func (b *defaultBinder) Bind(req *protocol.Request, params path.PathParam, v interface{}) error { err := b.preBindBody(req, v) if err != nil { - return err + return fmt.Errorf("bind body failed, err=%v", err) } rv, typeID := valueAndTypeID(v) if rv.Kind() != reflect.Ptr || rv.IsNil() { diff --git a/pkg/app/server/binding/tagexpr_bind_test.go b/pkg/app/server/binding/tagexpr_bind_test.go index 8bc19608e..04f0ad9fe 100644 --- a/pkg/app/server/binding/tagexpr_bind_test.go +++ b/pkg/app/server/binding/tagexpr_bind_test.go @@ -112,6 +112,9 @@ func TestGetBody(t *testing.T) { req := newRequest("http://localhost:8080/", nil, nil, nil) recv := new(Recv) err := DefaultBinder.Bind(req.Req, nil, recv) + if err == nil { + t.Fatalf("expected an error, but get nil") + } assert.DeepEqual(t, err.Error(), "'E' field is a 'required' parameter, but the request does not have this parameter") } @@ -420,10 +423,10 @@ func TestJSON(t *testing.T) { recv := new(Recv) err := DefaultBinder.Bind(req.Req, nil, recv) - if err != nil { - t.Error(err) + if err == nil { + t.Error("expected an error, but get nil") } - //assert.DeepEqual(t, &binding.Error{ErrType: "binding", FailField: "y", Msg: "missing required parameter"}, err) + assert.DeepEqual(t, err.Error(), "'Y' field is a 'required' parameter, but the request does not have this parameter") assert.DeepEqual(t, []string{"a1", "a2"}, (**recv.X).A) assert.DeepEqual(t, int32(21), (**recv.X).B) assert.DeepEqual(t, &[]uint16{31, 32}, (**recv.X).C) @@ -708,7 +711,7 @@ func TestOption(t *testing.T) { req = newRequest("", header, nil, bodyReader) recv = new(Recv) err = DefaultBinder.Bind(req.Req, nil, recv) - //assert.DeepEqual(t, err.Error(), "binding: expr_path=X.c, cause=missing required parameter") + assert.DeepEqual(t, err.Error(), "'C' field is a 'required' parameter, but the request does not have this parameter") assert.DeepEqual(t, 0, recv.X.C) assert.DeepEqual(t, 0, recv.X.D) assert.DeepEqual(t, "y1", recv.Y) @@ -737,8 +740,12 @@ func TestOption(t *testing.T) { }`) req = newRequest("", header, nil, bodyReader) recv2 := new(Recv2) + EnableStructFieldResolve(true) + defer func() { + EnableStructFieldResolve(false) + }() err = DefaultBinder.Bind(req.Req, nil, recv2) - //assert.DeepEqual(t, err.Error(), "binding: expr_path=X, cause=missing required parameter") + assert.DeepEqual(t, err.Error(), "'X' field is a 'required' parameter, but the request does not have this parameter") assert.True(t, recv2.X == nil) assert.DeepEqual(t, "y1", recv2.Y) } @@ -1019,7 +1026,6 @@ func TestRequiredBUG(t *testing.T) { } z := &ExchangeCurrencyRequest{} - // v := ameda.InitSampleValue(reflect.TypeOf(z), 10).Interface().(*ExchangeCurrencyRequest) b := []byte(`{ "promotion_region": "?", "currency": { @@ -1043,7 +1049,10 @@ func TestRequiredBUG(t *testing.T) { recv := new(ExchangeCurrencyRequest) err := DefaultBinder.Bind(req.Req, nil, recv) - assert.DeepEqual(t, err.Error(), "validating: expr_path=Currency.Slice[0].currencyName, cause=invalid") + // no need for validate + if err != nil { + t.Error(err) + } assert.DeepEqual(t, z, recv) } From 8213aaa4b8364ca3ee73f38492606b5be4d52d14 Mon Sep 17 00:00:00 2001 From: fgy Date: Wed, 17 May 2023 21:04:33 +0800 Subject: [PATCH 33/91] ci: go fumpt --- pkg/app/server/binding/config.go | 2 +- pkg/app/server/binding/decoder/decoder.go | 6 +- .../server/binding/decoder/sonic_required.go | 4 +- pkg/app/server/binding/tagexpr_bind_test.go | 68 +++++++++---------- 4 files changed, 39 insertions(+), 41 deletions(-) diff --git a/pkg/app/server/binding/config.go b/pkg/app/server/binding/config.go index 28b8708a0..4ae07255b 100644 --- a/pkg/app/server/binding/config.go +++ b/pkg/app/server/binding/config.go @@ -18,8 +18,8 @@ package binding import ( standardJson "encoding/json" - "github.com/cloudwego/hertz/pkg/app/server/binding/decoder" + "github.com/cloudwego/hertz/pkg/app/server/binding/decoder" hjson "github.com/cloudwego/hertz/pkg/common/json" ) diff --git a/pkg/app/server/binding/decoder/decoder.go b/pkg/app/server/binding/decoder/decoder.go index 5e95baadb..56924ed34 100644 --- a/pkg/app/server/binding/decoder/decoder.go +++ b/pkg/app/server/binding/decoder/decoder.go @@ -51,8 +51,10 @@ import ( "github.com/cloudwego/hertz/pkg/protocol" ) -var EnableDefaultTag = true -var EnableStructFieldResolve = false +var ( + EnableDefaultTag = true + EnableStructFieldResolve = false +) type bindRequest struct { Req *protocol.Request diff --git a/pkg/app/server/binding/decoder/sonic_required.go b/pkg/app/server/binding/decoder/sonic_required.go index b9ffbc360..d3d8e4b2c 100644 --- a/pkg/app/server/binding/decoder/sonic_required.go +++ b/pkg/app/server/binding/decoder/sonic_required.go @@ -22,11 +22,11 @@ package decoder import ( - "github.com/cloudwego/hertz/internal/bytesconv" - "github.com/cloudwego/hertz/pkg/common/utils" "strings" "github.com/bytedance/sonic" + "github.com/cloudwego/hertz/internal/bytesconv" + "github.com/cloudwego/hertz/pkg/common/utils" ) func checkRequireJSON(req *bindRequest, tagInfo TagInfo) bool { diff --git a/pkg/app/server/binding/tagexpr_bind_test.go b/pkg/app/server/binding/tagexpr_bind_test.go index 04f0ad9fe..4264a6397 100644 --- a/pkg/app/server/binding/tagexpr_bind_test.go +++ b/pkg/app/server/binding/tagexpr_bind_test.go @@ -503,9 +503,9 @@ func (testPathParams2) Get(name string) (string, bool) { // FIXME: 复杂类型的默认值,暂时先不做,低优 func TestDefault(t *testing.T) { - type S struct { - SS string `json:"ss"` - } + //type S struct { + // SS string `json:"ss"` + //} type Recv struct { X **struct { @@ -513,11 +513,11 @@ func TestDefault(t *testing.T) { B int32 `path:"b" default:"32"` C bool `json:"c" default:"true"` D *float32 `default:"123.4"` - //E *[]string `default:"['a','b','c','d,e,f']"` - //F map[string]string `default:"{'a':'\"\\'1','\"b':'c','c':'2'}"` - //G map[string]int64 `default:"{'a':1,'b':2,'c':3}"` - //H map[string]float64 `default:"{'a':0.1,'b':1.2,'c':2.3}"` - //I map[string]float64 `default:"{'\"a\"':0.1,'b':1.2,'c':2.3}"` + // E *[]string `default:"['a','b','c','d,e,f']"` + // F map[string]string `default:"{'a':'\"\\'1','\"b':'c','c':'2'}"` + // G map[string]int64 `default:"{'a':1,'b':2,'c':3}"` + // H map[string]float64 `default:"{'a':0.1,'b':1.2,'c':2.3}"` + // I map[string]float64 `default:"{'\"a\"':0.1,'b':1.2,'c':2.3}"` Empty string `default:""` Null string `default:""` CommaSpace string `default:",a:c "` @@ -528,12 +528,12 @@ func TestDefault(t *testing.T) { Y string `json:"y" default:"y1"` Z int64 W string `json:"w"` - //V []int64 `json:"u" default:"[1,2,3]"` - //U []float32 `json:"u" default:"[1.1,2,3]"` + // V []int64 `json:"u" default:"[1,2,3]"` + // U []float32 `json:"u" default:"[1.1,2,3]"` T *string `json:"t" default:"t1"` - //S S `default:"{'ss':'test'}"` - //O *S `default:"{'ss':'test2'}"` - //Complex map[string][]map[string][]int64 `default:"{'a':[{'aa':[1,2,3], 'bb':[4,5]}],'b':[{}]}"` + // S S `default:"{'ss':'test'}"` + // O *S `default:"{'ss':'test2'}"` + // Complex map[string][]map[string][]int64 `default:"{'a':[{'aa':[1,2,3], 'bb':[4,5]}],'b':[{}]}"` } bodyReader := strings.NewReader(`{ @@ -557,11 +557,11 @@ func TestDefault(t *testing.T) { assert.DeepEqual(t, int32(32), (**recv.X).B) assert.DeepEqual(t, true, (**recv.X).C) assert.DeepEqual(t, float32(123.4), *(**recv.X).D) - //assert.DeepEqual(t, []string{"a", "b", "c", "d,e,f"}, *(**recv.X).E) - //assert.DeepEqual(t, map[string]string{"a": "\"'1", "\"b": "c", "c": "2"}, (**recv.X).F) - //assert.DeepEqual(t, map[string]int64{"a": 1, "b": 2, "c": 3}, (**recv.X).G) - //assert.DeepEqual(t, map[string]float64{"a": 0.1, "b": 1.2, "c": 2.3}, (**recv.X).H) - //assert.DeepEqual(t, map[string]float64{"\"a\"": 0.1, "b": 1.2, "c": 2.3}, (**recv.X).I) + // assert.DeepEqual(t, []string{"a", "b", "c", "d,e,f"}, *(**recv.X).E) + // assert.DeepEqual(t, map[string]string{"a": "\"'1", "\"b": "c", "c": "2"}, (**recv.X).F) + // assert.DeepEqual(t, map[string]int64{"a": 1, "b": 2, "c": 3}, (**recv.X).G) + // assert.DeepEqual(t, map[string]float64{"a": 0.1, "b": 1.2, "c": 2.3}, (**recv.X).H) + // assert.DeepEqual(t, map[string]float64{"\"a\"": 0.1, "b": 1.2, "c": 2.3}, (**recv.X).I) assert.DeepEqual(t, "", (**recv.X).Empty) assert.DeepEqual(t, "", (**recv.X).Null) assert.DeepEqual(t, ",a:c ", (**recv.X).CommaSpace) @@ -571,11 +571,11 @@ func TestDefault(t *testing.T) { assert.DeepEqual(t, "y1", recv.Y) assert.DeepEqual(t, "t1", *recv.T) assert.DeepEqual(t, int64(6), recv.Z) - //assert.DeepEqual(t, []int64{1, 2, 3}, recv.V) - //assert.DeepEqual(t, []float32{1.1, 2, 3}, recv.U) - //assert.DeepEqual(t, S{SS: "test"}, recv.S) - //assert.DeepEqual(t, &S{SS: "test2"}, recv.O) - //assert.DeepEqual(t, map[string][]map[string][]int64{"a": {{"aa": {1, 2, 3}, "bb": []int64{4, 5}}}, "b": {map[string][]int64{}}}, recv.Complex) + // assert.DeepEqual(t, []int64{1, 2, 3}, recv.V) + // assert.DeepEqual(t, []float32{1.1, 2, 3}, recv.U) + // assert.DeepEqual(t, S{SS: "test"}, recv.S) + // assert.DeepEqual(t, &S{SS: "test2"}, recv.O) + // assert.DeepEqual(t, map[string][]map[string][]int64{"a": {{"aa": {1, 2, 3}, "bb": []int64{4, 5}}}, "b": {map[string][]int64{}}}, recv.Complex) } func TestAuto(t *testing.T) { @@ -754,7 +754,7 @@ func newRequest(u string, header http.Header, cookies []*http.Cookie, bodyReader if header == nil { header = make(http.Header) } - var method = "GET" + method := "GET" var body []byte if bodyReader != nil { body, _ = ioutil.ReadAll(bodyReader) @@ -777,7 +777,6 @@ func newRequest(u string, header http.Header, cookies []*http.Cookie, bodyReader req.Req.SetMethod(method) for _, c := range cookies { req.Req.Header.SetCookie(c.Name, c.Value) - } return req } @@ -882,13 +881,13 @@ func TestRegTypeUnmarshal(t *testing.T) { Q Q `query:"q"` Qs []*Q `query:"qs"` } - var values = url.Values{} + values := url.Values{} b, err := json.Marshal(Q{A: 2, B: "y"}) if err != nil { t.Error(err) } values.Add("q", string(b)) - bs, err := json.Marshal([]Q{{A: 1, B: "x"}, {A: 2, B: "y"}}) + bs, _ := json.Marshal([]Q{{A: 1, B: "x"}, {A: 2, B: "y"}}) values.Add("qs", string(bs)) req := newRequest("http://localhost:8080/?"+values.Encode(), nil, nil, nil) recv := new(T) @@ -948,18 +947,15 @@ func TestRegTypeUnmarshal(t *testing.T) { // FIXME: json unmarshal 后其他 required 没必要 校验 required func TestPathnameBUG2(t *testing.T) { type CurrencyData struct { - z int Amount *string `form:"amount,required" json:"amount,required" protobuf:"bytes,1,req,name=amount" query:"amount,required"` Name *string `form:"name,required" json:"name,required" protobuf:"bytes,2,req,name=name" query:"name,required"` Symbol *string `form:"symbol" json:"symbol,omitempty" protobuf:"bytes,3,opt,name=symbol" query:"symbol"` } type TimeRange struct { - z int StartTime *int64 `form:"start_time,required" json:"start_time,required" protobuf:"varint,1,req,name=start_time,json=startTime" query:"start_time,required"` EndTime *int64 `form:"end_time,required" json:"end_time,required" protobuf:"varint,2,req,name=end_time,json=endTime" query:"end_time,required"` } type CreateFreeShippingRequest struct { - z int PromotionName *string `form:"promotion_name,required" json:"promotion_name,required" protobuf:"bytes,1,req,name=promotion_name,json=promotionName" query:"promotion_name,required"` PromotionRegion *string `form:"promotion_region,required" json:"promotion_region,required" protobuf:"bytes,2,req,name=promotion_region,json=promotionRegion" query:"promotion_region,required"` TimeRange *TimeRange `form:"time_range,required" json:"time_range,required" protobuf:"bytes,3,req,name=time_range,json=timeRange" query:"time_range,required"` @@ -988,7 +984,7 @@ func TestPathnameBUG2(t *testing.T) { "7493989780026655762","11111","111212121" ] }`) - var v = new(CreateFreeShippingRequest) + v := new(CreateFreeShippingRequest) err := json.Unmarshal(b, v) if err != nil { t.Error(err) @@ -1010,7 +1006,7 @@ func TestPathnameBUG2(t *testing.T) { // FIXME: json unmarshal 后的其他 tag 的 required 的校验 func TestRequiredBUG(t *testing.T) { type Currency struct { - currencyName *string `form:"currency_name,required" json:"currency_name,required" protobuf:"bytes,1,req,name=currency_name,json=currencyName" query:"currency_name,required"` + // currencyName *string `form:"currency_name,required" json:"currency_name,required" protobuf:"bytes,1,req,name=currency_name,json=currencyName" query:"currency_name,required"` CurrencySymbol *string `form:"currency_symbol,required" json:"currency_symbol,required" protobuf:"bytes,2,req,name=currency_symbol,json=currencySymbol" query:"currency_symbol,required"` } @@ -1072,7 +1068,7 @@ func TestIssue25(t *testing.T) { if err != nil { t.Error(err) } - //assert.DeepEqual(t, "from cookie", recv.A) + // assert.DeepEqual(t, "from cookie", recv.A) header2 := make(http.Header) header2.Set("A", "from header") @@ -1171,10 +1167,10 @@ func newFormBody2(values url.Values, files files) (contentType string, bodyReade if len(files) == 0 { return "application/x-www-form-urlencoded", strings.NewReader(values.Encode()) } - var pr, pw = io.Pipe() - var bodyWriter = multipart.NewWriter(pw) + pr, pw := io.Pipe() + bodyWriter := multipart.NewWriter(pw) var fileWriter io.Writer - var buf = make([]byte, 32*1024) + buf := make([]byte, 32*1024) go func() { for fieldName, postfiles := range files { for _, file := range postfiles { From b0802fdd3509cfaef1ade29c5afc917cf46827de Mon Sep 17 00:00:00 2001 From: fgy Date: Wed, 17 May 2023 21:09:31 +0800 Subject: [PATCH 34/91] ci: license --- .../server/binding/decoder/gjson_required.go | 35 ++++++++++--------- .../server/binding/decoder/sonic_required.go | 35 +++++++++---------- 2 files changed, 35 insertions(+), 35 deletions(-) diff --git a/pkg/app/server/binding/decoder/gjson_required.go b/pkg/app/server/binding/decoder/gjson_required.go index 03dd2e3c7..d53921efd 100644 --- a/pkg/app/server/binding/decoder/gjson_required.go +++ b/pkg/app/server/binding/decoder/gjson_required.go @@ -1,25 +1,26 @@ -//go:build stdjson || !(amd64 && (linux || windows || darwin)) -// +build stdjson !amd64 !linux,!windows,!darwin +// Copyright 2022 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// -/* - * Copyright 2022 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +//go:build gjson || !(amd64 && (linux || windows || darwin)) +// +build gjson !amd64 !linux,!windows,!darwin package decoder import ( + "github.com/cloudwego/hertz/internal/bytesconv" + "github.com/cloudwego/hertz/pkg/common/utils" "strings" "github.com/tidwall/gjson" diff --git a/pkg/app/server/binding/decoder/sonic_required.go b/pkg/app/server/binding/decoder/sonic_required.go index d3d8e4b2c..269747c6e 100644 --- a/pkg/app/server/binding/decoder/sonic_required.go +++ b/pkg/app/server/binding/decoder/sonic_required.go @@ -1,23 +1,22 @@ -//go:build (linux || windows || darwin) && amd64 && !stdjson +// Copyright 2022 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// + +//go:build (linux || windows || darwin) && amd64 && !gjson // +build linux windows darwin // +build amd64 -// +build !stdjson - -/* - * Copyright 2022 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +// +build !gjson package decoder From 6fc6b8753ba7455cf50e15604a1cba9f442e9776 Mon Sep 17 00:00:00 2001 From: fgy Date: Thu, 18 May 2023 14:32:13 +0800 Subject: [PATCH 35/91] ci: lint --- .github/workflows/pr-check.yml | 2 +- pkg/app/server/binding/tagexpr_bind_test.go | 6 ++---- 2 files changed, 3 insertions(+), 5 deletions(-) diff --git a/.github/workflows/pr-check.yml b/.github/workflows/pr-check.yml index 683bd08c3..dba5baaf3 100644 --- a/.github/workflows/pr-check.yml +++ b/.github/workflows/pr-check.yml @@ -34,4 +34,4 @@ jobs: # Exit with 1 when it find at least one finding. fail_on_error: true # Set staticcheck flags - staticcheck_flags: -checks=inherit,-SA1029 + staticcheck_flags: -checks=inherit,-SA1029,-SA5008 diff --git a/pkg/app/server/binding/tagexpr_bind_test.go b/pkg/app/server/binding/tagexpr_bind_test.go index 4264a6397..5c831fada 100644 --- a/pkg/app/server/binding/tagexpr_bind_test.go +++ b/pkg/app/server/binding/tagexpr_bind_test.go @@ -216,7 +216,6 @@ func TestHeaderNum(t *testing.T) { assert.DeepEqual(t, (*int64)(nil), recv.Z) } -// todo: cookie slice func TestCookieString(t *testing.T) { type Recv struct { X **struct { @@ -673,7 +672,7 @@ func TestTypeUnmarshal(t *testing.T) { t.Logf("%v", recv) } -// FIXME: JSON required 校验 +// test: required func TestOption(t *testing.T) { type Recv struct { X *struct { @@ -944,7 +943,7 @@ func TestRegTypeUnmarshal(t *testing.T) { // assert.DeepEqual(t, v, recv) //} -// FIXME: json unmarshal 后其他 required 没必要 校验 required +// test: required func TestPathnameBUG2(t *testing.T) { type CurrencyData struct { Amount *string `form:"amount,required" json:"amount,required" protobuf:"bytes,1,req,name=amount" query:"amount,required"` @@ -1003,7 +1002,6 @@ func TestPathnameBUG2(t *testing.T) { assert.DeepEqual(t, v, recv) } -// FIXME: json unmarshal 后的其他 tag 的 required 的校验 func TestRequiredBUG(t *testing.T) { type Currency struct { // currencyName *string `form:"currency_name,required" json:"currency_name,required" protobuf:"bytes,1,req,name=currency_name,json=currencyName" query:"currency_name,required"` From ae3ee12acc5d280de3d9058ea4462aa0e12ae6c4 Mon Sep 17 00:00:00 2001 From: fgy Date: Thu, 18 May 2023 17:30:14 +0800 Subject: [PATCH 36/91] ci: test panic --- pkg/app/server/binding/binder_test.go | 2 +- pkg/app/server/binding/tagexpr_bind_test.go | 183 ++++++++++---------- 2 files changed, 92 insertions(+), 93 deletions(-) diff --git a/pkg/app/server/binding/binder_test.go b/pkg/app/server/binding/binder_test.go index 16114b609..0cbd65889 100644 --- a/pkg/app/server/binding/binder_test.go +++ b/pkg/app/server/binding/binder_test.go @@ -782,7 +782,7 @@ func TestBind_StructFieldResolve(t *testing.T) { var result Req EnableStructFieldResolve(true) defer func() { - EnableDefaultTag(false) + EnableStructFieldResolve(false) }() err := DefaultBinder.Bind(req.Req, nil, &result) if err != nil { diff --git a/pkg/app/server/binding/tagexpr_bind_test.go b/pkg/app/server/binding/tagexpr_bind_test.go index 5c831fada..437822a54 100644 --- a/pkg/app/server/binding/tagexpr_bind_test.go +++ b/pkg/app/server/binding/tagexpr_bind_test.go @@ -505,7 +505,6 @@ func TestDefault(t *testing.T) { //type S struct { // SS string `json:"ss"` //} - type Recv struct { X **struct { A []string `path:"a" json:"a"` @@ -633,44 +632,44 @@ func TestAuto(t *testing.T) { } // FIXME: 自定义验证函数 & TIME 类型内置, 暂时先不做,低优 -func TestTypeUnmarshal(t *testing.T) { - type Recv struct { - A time.Time `form:"t1"` - B *time.Time `query:"t2"` - C []time.Time `query:"t2"` - } - query := make(url.Values) - query.Add("t2", "2019-09-04T14:05:24+08:00") - query.Add("t2", "2019-09-04T18:05:24+08:00") - form := make(url.Values) - form.Add("t1", "2019-09-03T18:05:24+08:00") - contentType, bodyReader := newFormBody2(form, nil) - header := make(http.Header) - header.Set("Content-Type", contentType) - req := newRequest("http://localhost/?"+query.Encode(), header, nil, bodyReader) - recv := new(Recv) - - err := DefaultBinder.Bind(req.Req, nil, recv) - if err != nil { - t.Error(err) - } - t1, err := time.Parse(time.RFC3339, "2019-09-03T18:05:24+08:00") - if err != nil { - t.Error(err) - } - assert.DeepEqual(t, t1, recv.A) - t21, err := time.Parse(time.RFC3339, "2019-09-04T14:05:24+08:00") - if err != nil { - t.Error(err) - } - assert.DeepEqual(t, t21, *recv.B) - t22, err := time.Parse(time.RFC3339, "2019-09-04T18:05:24+08:00") - if err != nil { - t.Error(err) - } - assert.DeepEqual(t, []time.Time{t21, t22}, recv.C) - t.Logf("%v", recv) -} +//func TestTypeUnmarshal(t *testing.T) { +// type Recv struct { +// A time.Time `form:"t1"` +// B *time.Time `query:"t2"` +// C []time.Time `query:"t2"` +// } +// query := make(url.Values) +// query.Add("t2", "2019-09-04T14:05:24+08:00") +// query.Add("t2", "2019-09-04T18:05:24+08:00") +// form := make(url.Values) +// form.Add("t1", "2019-09-03T18:05:24+08:00") +// contentType, bodyReader := newFormBody2(form, nil) +// header := make(http.Header) +// header.Set("Content-Type", contentType) +// req := newRequest("http://localhost/?"+query.Encode(), header, nil, bodyReader) +// recv := new(Recv) +// +// err := DefaultBinder.Bind(req.Req, nil, recv) +// if err != nil { +// t.Error(err) +// } +// t1, err := time.Parse(time.RFC3339, "2019-09-03T18:05:24+08:00") +// if err != nil { +// t.Error(err) +// } +// assert.DeepEqual(t, t1, recv.A) +// t21, err := time.Parse(time.RFC3339, "2019-09-04T14:05:24+08:00") +// if err != nil { +// t.Error(err) +// } +// assert.DeepEqual(t, t21, *recv.B) +// t22, err := time.Parse(time.RFC3339, "2019-09-04T18:05:24+08:00") +// if err != nil { +// t.Error(err) +// } +// assert.DeepEqual(t, []time.Time{t21, t22}, recv.C) +// t.Logf("%v", recv) +//} // test: required func TestOption(t *testing.T) { @@ -871,37 +870,37 @@ func TestNoTagIssue(t *testing.T) { } // DIFF: go-tagexpr 会对保留 t.Q的结构体信息,而目前的实现不会 t.Q 做特殊处理,会直接拆开。有需求也可以加上 -func TestRegTypeUnmarshal(t *testing.T) { - type Q struct { - A int - B string - } - type T struct { - Q Q `query:"q"` - Qs []*Q `query:"qs"` - } - values := url.Values{} - b, err := json.Marshal(Q{A: 2, B: "y"}) - if err != nil { - t.Error(err) - } - values.Add("q", string(b)) - bs, _ := json.Marshal([]Q{{A: 1, B: "x"}, {A: 2, B: "y"}}) - values.Add("qs", string(bs)) - req := newRequest("http://localhost:8080/?"+values.Encode(), nil, nil, nil) - recv := new(T) - - err = DefaultBinder.Bind(req.Req, nil, recv) - if err != nil { - t.Error(err) - } - assert.DeepEqual(t, 2, recv.Q.A) - assert.DeepEqual(t, "y", recv.Q.B) - assert.DeepEqual(t, 1, recv.Qs[0].A) - assert.DeepEqual(t, "x", recv.Qs[0].B) - assert.DeepEqual(t, 2, recv.Qs[1].A) - assert.DeepEqual(t, "y", recv.Qs[1].B) -} +//func TestRegTypeUnmarshal(t *testing.T) { +// type Q struct { +// A int +// B string +// } +// type T struct { +// Q Q `query:"q"` +// Qs []*Q `query:"qs"` +// } +// values := url.Values{} +// b, err := json.Marshal(Q{A: 2, B: "y"}) +// if err != nil { +// t.Error(err) +// } +// values.Add("q", string(b)) +// bs, _ := json.Marshal([]Q{{A: 1, B: "x"}, {A: 2, B: "y"}}) +// values.Add("qs", string(bs)) +// req := newRequest("http://localhost:8080/?"+values.Encode(), nil, nil, nil) +// recv := new(T) +// +// err = DefaultBinder.Bind(req.Req, nil, recv) +// if err != nil { +// t.Error(err) +// } +// assert.DeepEqual(t, 2, recv.Q.A) +// assert.DeepEqual(t, "y", recv.Q.B) +// assert.DeepEqual(t, 1, recv.Qs[0].A) +// assert.DeepEqual(t, "x", recv.Qs[0].B) +// assert.DeepEqual(t, 2, recv.Qs[1].A) +// assert.DeepEqual(t, "y", recv.Qs[1].B) +//} //func TestPathnameBUG(t *testing.T) { // type Currency struct { @@ -1130,28 +1129,28 @@ func TestIssue26(t *testing.T) { } // FIXME: json unmarshal 后,默认值的问题 -func TestDefault2(t *testing.T) { - type Recv struct { - X **struct { - Dash string `default:"xxxx"` - } - } - bodyReader := strings.NewReader(`{ - "X": { - "Dash": "hello Dash" - } - }`) - header := make(http.Header) - header.Set("Content-Type", "application/json") - req := newRequest("", header, nil, bodyReader) - recv := new(Recv) - - err := DefaultBinder.Bind(req.Req, nil, recv) - if err != nil { - t.Error(err) - } - assert.DeepEqual(t, "hello Dash", (**recv.X).Dash) -} +//func TestDefault2(t *testing.T) { +// type Recv struct { +// X **struct { +// Dash string `default:"xxxx"` +// } +// } +// bodyReader := strings.NewReader(`{ +// "X": { +// "Dash": "hello Dash" +// } +// }`) +// header := make(http.Header) +// header.Set("Content-Type", "application/json") +// req := newRequest("", header, nil, bodyReader) +// recv := new(Recv) +// +// err := DefaultBinder.Bind(req.Req, nil, recv) +// if err != nil { +// t.Error(err) +// } +// assert.DeepEqual(t, "hello Dash", (**recv.X).Dash) +//} type ( files map[string][]file From 4d70b0f7ca5def718508f03e912ba85f08334b47 Mon Sep 17 00:00:00 2001 From: fgy Date: Fri, 19 May 2023 17:53:59 +0800 Subject: [PATCH 37/91] feat: add customized type deocder --- pkg/app/server/binding/binder_test.go | 29 ++- pkg/app/server/binding/config.go | 13 + .../binding/decoder/base_type_decoder.go | 2 +- .../decoder/customized_type_decoder.go | 121 ++++++++- pkg/app/server/binding/decoder/decoder.go | 21 +- .../binding/decoder/map_type_decoder.go | 2 +- .../binding/decoder/slice_type_decoder.go | 26 +- .../binding/decoder/struct_type_decoder.go | 2 +- pkg/app/server/binding/tagexpr_bind_test.go | 246 ++++++++++-------- 9 files changed, 311 insertions(+), 151 deletions(-) diff --git a/pkg/app/server/binding/binder_test.go b/pkg/app/server/binding/binder_test.go index 0cbd65889..ec4e60e13 100644 --- a/pkg/app/server/binding/binder_test.go +++ b/pkg/app/server/binding/binder_test.go @@ -43,6 +43,7 @@ package binding import ( "fmt" "mime/multipart" + "reflect" "testing" "github.com/cloudwego/hertz/pkg/app/server/binding/path" @@ -493,24 +494,30 @@ type CustomizedDecode struct { A string } -func (c *CustomizedDecode) CustomizedFieldDecode(req *protocol.Request, params path.PathParam) error { - q1 := req.URI().QueryArgs().Peek("a") - if len(q1) == 0 { - return fmt.Errorf("can be nil") - } - - c.A = string(q1) - return nil -} - func TestBind_CustomizedTypeDecode(t *testing.T) { type Foo struct { F ***CustomizedDecode } + + err := RegTypeUnmarshal(reflect.TypeOf(CustomizedDecode{}), func(req *protocol.Request, params path.PathParam, text string) (reflect.Value, error) { + q1 := req.URI().QueryArgs().Peek("a") + if len(q1) == 0 { + return reflect.Value{}, fmt.Errorf("can be nil") + } + val := CustomizedDecode{ + A: string(q1), + } + return reflect.ValueOf(val), nil + }) + + if err != nil { + t.Fatal(err) + } + req := newMockRequest(). SetRequestURI("http://foobar.com?a=1&b=2") result := Foo{} - err := DefaultBinder.Bind(req.Req, nil, &result) + err = DefaultBinder.Bind(req.Req, nil, &result) if err != nil { t.Fatal(err) } diff --git a/pkg/app/server/binding/config.go b/pkg/app/server/binding/config.go index 4ae07255b..d9dd54e12 100644 --- a/pkg/app/server/binding/config.go +++ b/pkg/app/server/binding/config.go @@ -18,9 +18,12 @@ package binding import ( standardJson "encoding/json" + "reflect" "github.com/cloudwego/hertz/pkg/app/server/binding/decoder" + path1 "github.com/cloudwego/hertz/pkg/app/server/binding/path" hjson "github.com/cloudwego/hertz/pkg/common/json" + "github.com/cloudwego/hertz/pkg/protocol" ) // ResetJSONUnmarshaler reset the JSON Unmarshal function. @@ -44,3 +47,13 @@ func EnableDefaultTag(b bool) { func EnableStructFieldResolve(b bool) { decoder.EnableStructFieldResolve = b } + +// RegTypeUnmarshal registers customized type unmarshaler. +func RegTypeUnmarshal(t reflect.Type, fn func(req *protocol.Request, params path1.PathParam, text string) (reflect.Value, error)) error { + return decoder.RegTypeUnmarshal(t, fn) +} + +// MustRegTypeUnmarshal registers customized type unmarshaler. It will panic if exist error. +func MustRegTypeUnmarshal(t reflect.Type, fn func(req *protocol.Request, params path1.PathParam, text string) (reflect.Value, error)) { + decoder.MustRegTypeUnmarshal(t, fn) +} diff --git a/pkg/app/server/binding/decoder/base_type_decoder.go b/pkg/app/server/binding/decoder/base_type_decoder.go index dc3e4ed40..808b1ef58 100644 --- a/pkg/app/server/binding/decoder/base_type_decoder.go +++ b/pkg/app/server/binding/decoder/base_type_decoder.go @@ -113,7 +113,7 @@ func (d *baseTypeFieldTextDecoder) Decode(req *bindRequest, params path1.PathPar ptrDepth++ } var vv reflect.Value - vv, err := stringToValue(t, text) + vv, err := stringToValue(t, text, req, params) if err != nil { return err } diff --git a/pkg/app/server/binding/decoder/customized_type_decoder.go b/pkg/app/server/binding/decoder/customized_type_decoder.go index d8908f18d..1588bd169 100644 --- a/pkg/app/server/binding/decoder/customized_type_decoder.go +++ b/pkg/app/server/binding/decoder/customized_type_decoder.go @@ -41,21 +41,94 @@ package decoder import ( + "fmt" "reflect" + "time" path1 "github.com/cloudwego/hertz/pkg/app/server/binding/path" + "github.com/cloudwego/hertz/pkg/common/utils" + "github.com/cloudwego/hertz/pkg/protocol" ) +func init() { + MustRegTypeUnmarshal(reflect.TypeOf(time.Time{}), func(req *protocol.Request, params path1.PathParam, text string) (reflect.Value, error) { + if text == "" { + return reflect.ValueOf(time.Time{}), nil + } + t, err := time.Parse(time.RFC3339, text) + if err != nil { + return reflect.Value{}, err + } + return reflect.ValueOf(t), nil + }) +} + +type customizeDecodeFunc func(req *protocol.Request, params path1.PathParam, text string) (reflect.Value, error) + +var typeUnmarshalFuncs = make(map[reflect.Type]customizeDecodeFunc) + +// RegTypeUnmarshal registers customized type unmarshaler. +func RegTypeUnmarshal(t reflect.Type, fn customizeDecodeFunc) error { + // check + switch t.Kind() { + case reflect.String, reflect.Bool, + reflect.Float32, reflect.Float64, + reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8, + reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8: + return fmt.Errorf("registration type cannot be a basic type") + case reflect.Ptr: + return fmt.Errorf("registration type cannot be a pointer type") + } + // test + //vv, err := fn(&protocol.Request{}, nil) + //if err != nil { + // return fmt.Errorf("test fail: %s", err) + //} + //if tt := vv.Type(); tt != t { + // return fmt.Errorf("test fail: expect return value type is %s, but got %s", t.String(), tt.String()) + //} + + typeUnmarshalFuncs[t] = fn + return nil +} + +// MustRegTypeUnmarshal registers customized type unmarshaler. It will panic if exist error. +func MustRegTypeUnmarshal(t reflect.Type, fn customizeDecodeFunc) { + err := RegTypeUnmarshal(t, fn) + if err != nil { + panic(err) + } +} + type customizedFieldTextDecoder struct { fieldInfo + decodeFunc customizeDecodeFunc } func (d *customizedFieldTextDecoder) Decode(req *bindRequest, params path1.PathParam, reqValue reflect.Value) error { - var err error - v := reflect.New(d.fieldType) - decoder := v.Interface().(CustomizedFieldDecoder) + var text string + var defaultValue string + for _, tagInfo := range d.tagInfos { + if tagInfo.Skip || tagInfo.Key == jsonTag || tagInfo.Key == fileNameTag { + defaultValue = tagInfo.Default + continue + } + if tagInfo.Key == headerTag { + tagInfo.Value = utils.GetNormalizeHeaderKey(tagInfo.Value, req.Req.Header.IsDisableNormalizing()) + } + ret := tagInfo.Getter(req, params, tagInfo.Value) + defaultValue = tagInfo.Default + if len(ret) != 0 { + text = ret[0] + break + } + } + if len(text) == 0 && len(defaultValue) != 0 { + text = defaultValue + } - if err = decoder.CustomizedFieldDecode(req.Req, params); err != nil { + v, err := d.decodeFunc(req.Req, params, text) + if err != nil { return err } @@ -68,10 +141,48 @@ func (d *customizedFieldTextDecoder) Decode(req *bindRequest, params path1.PathP t = t.Elem() ptrDepth++ } - field.Set(ReferenceValue(v.Elem(), ptrDepth)) + field.Set(ReferenceValue(v, ptrDepth)) return nil } field.Set(v) return nil } + +func getCustomizedFieldDecoder(field reflect.StructField, index int, tagInfos []TagInfo, parentIdx []int, decodeFunc customizeDecodeFunc) ([]fieldDecoder, error) { + for idx, tagInfo := range tagInfos { + switch tagInfo.Key { + case pathTag: + tagInfos[idx].Getter = path + case formTag: + tagInfos[idx].Getter = postForm + case queryTag: + tagInfos[idx].Getter = query + case cookieTag: + tagInfos[idx].Getter = cookie + case headerTag: + tagInfos[idx].Getter = header + case jsonTag: + // do nothing + case rawBodyTag: + tagInfos[idx].Getter = rawBody + case fileNameTag: + // do nothing + default: + } + } + fieldType := field.Type + for field.Type.Kind() == reflect.Ptr { + fieldType = field.Type.Elem() + } + return []fieldDecoder{&customizedFieldTextDecoder{ + fieldInfo: fieldInfo{ + index: index, + parentIndex: parentIdx, + fieldName: field.Name, + tagInfos: tagInfos, + fieldType: fieldType, + }, + decodeFunc: decodeFunc, + }}, nil +} diff --git a/pkg/app/server/binding/decoder/decoder.go b/pkg/app/server/binding/decoder/decoder.go index 56924ed34..c675db37c 100644 --- a/pkg/app/server/binding/decoder/decoder.go +++ b/pkg/app/server/binding/decoder/decoder.go @@ -69,14 +69,8 @@ type fieldDecoder interface { Decode(req *bindRequest, params path1.PathParam, reqValue reflect.Value) error } -type CustomizedFieldDecoder interface { - CustomizedFieldDecode(req *protocol.Request, params path1.PathParam) error -} - type Decoder func(req *protocol.Request, params path1.PathParam, rv reflect.Value) error -var customizedFieldDecoderType = reflect.TypeOf((*CustomizedFieldDecoder)(nil)).Elem() - func GetReqDecoder(rt reflect.Type) (Decoder, error) { var decoders []fieldDecoder @@ -123,22 +117,17 @@ func getFieldDecoder(field reflect.StructField, index int, parentIdx []int, pare if field.Type.Kind() != reflect.Struct && field.Anonymous { return nil, nil } - if reflect.PtrTo(field.Type).Implements(customizedFieldDecoderType) { - return []fieldDecoder{&customizedFieldTextDecoder{ - fieldInfo: fieldInfo{ - index: index, - parentIndex: parentIdx, - fieldName: field.Name, - fieldType: field.Type, - }, - }}, nil - } fieldTagInfos, newParentJSONName := lookupFieldTags(field, parentJSONName) if len(fieldTagInfos) == 0 && EnableDefaultTag { fieldTagInfos = getDefaultFieldTags(field) } + // customized type decoder has the highest priority + if customizedFunc, exist := typeUnmarshalFuncs[field.Type]; exist { + return getCustomizedFieldDecoder(field, index, fieldTagInfos, parentIdx, customizedFunc) + } + if field.Type.Kind() == reflect.Slice || field.Type.Kind() == reflect.Array { return getSliceFieldDecoder(field, index, fieldTagInfos, parentIdx) } diff --git a/pkg/app/server/binding/decoder/map_type_decoder.go b/pkg/app/server/binding/decoder/map_type_decoder.go index b233a9aca..ff563c871 100644 --- a/pkg/app/server/binding/decoder/map_type_decoder.go +++ b/pkg/app/server/binding/decoder/map_type_decoder.go @@ -105,7 +105,7 @@ func (d *mapTypeFieldTextDecoder) Decode(req *bindRequest, params path1.PathPara ptrDepth++ } var vv reflect.Value - vv, err := stringToValue(t, text) + vv, err := stringToValue(t, text, req, params) if err != nil { return fmt.Errorf("unable to decode '%s' as %s: %w", text, d.fieldType.Name(), err) } diff --git a/pkg/app/server/binding/decoder/slice_type_decoder.go b/pkg/app/server/binding/decoder/slice_type_decoder.go index 1db08ad5d..13ecf612d 100644 --- a/pkg/app/server/binding/decoder/slice_type_decoder.go +++ b/pkg/app/server/binding/decoder/slice_type_decoder.go @@ -142,13 +142,24 @@ func (d *sliceTypeFieldTextDecoder) Decode(req *bindRequest, params path1.PathPa for idx, text := range texts { var vv reflect.Value - vv, err := stringToValue(t, text) + vv, err = stringToValue(t, text, req, params) if err != nil { - return err + break } field.Index(idx).Set(ReferenceValue(vv, ptrDepth)) } - reqValue.Field(d.index).Set(ReferenceValue(field, parentPtrDepth)) + if err != nil { + if !reqValue.Field(d.index).CanAddr() { + return err + } + // text[0] can be a complete json content for []Type. + err = hjson.Unmarshal(bytesconv.S2b(texts[0]), reqValue.Field(d.index).Addr().Interface()) + if err != nil { + return err + } + } else { + reqValue.Field(d.index).Set(ReferenceValue(field, parentPtrDepth)) + } return nil } @@ -204,8 +215,15 @@ func getSliceFieldDecoder(field reflect.StructField, index int, tagInfos []TagIn }}, nil } -func stringToValue(elemType reflect.Type, text string) (v reflect.Value, err error) { +func stringToValue(elemType reflect.Type, text string, req *bindRequest, params path1.PathParam) (v reflect.Value, err error) { v = reflect.New(elemType).Elem() + if customizedFunc, exist := typeUnmarshalFuncs[elemType]; exist { + val, err := customizedFunc(req.Req, params, text) + if err != nil { + return reflect.Value{}, err + } + return val, nil + } switch elemType.Kind() { case reflect.Struct: err = hjson.Unmarshal(bytesconv.S2b(text), v.Addr().Interface()) diff --git a/pkg/app/server/binding/decoder/struct_type_decoder.go b/pkg/app/server/binding/decoder/struct_type_decoder.go index 2f976af5f..071c21f40 100644 --- a/pkg/app/server/binding/decoder/struct_type_decoder.go +++ b/pkg/app/server/binding/decoder/struct_type_decoder.go @@ -80,7 +80,7 @@ func (d *structTypeFieldTextDecoder) Decode(req *bindRequest, params path1.PathP ptrDepth++ } var vv reflect.Value - vv, err := stringToValue(t, text) + vv, err := stringToValue(t, text, req, params) if err != nil { return fmt.Errorf("unable to decode '%s' as %s: %w", text, d.fieldType.Name(), err) } diff --git a/pkg/app/server/binding/tagexpr_bind_test.go b/pkg/app/server/binding/tagexpr_bind_test.go index 437822a54..0dcde177f 100644 --- a/pkg/app/server/binding/tagexpr_bind_test.go +++ b/pkg/app/server/binding/tagexpr_bind_test.go @@ -47,6 +47,7 @@ import ( "time" "github.com/cloudwego/hertz/pkg/common/test/assert" + "google.golang.org/protobuf/proto" ) func TestRawBody(t *testing.T) { @@ -631,45 +632,44 @@ func TestAuto(t *testing.T) { assert.DeepEqual(t, "d-from-form", recv.D) } -// FIXME: 自定义验证函数 & TIME 类型内置, 暂时先不做,低优 -//func TestTypeUnmarshal(t *testing.T) { -// type Recv struct { -// A time.Time `form:"t1"` -// B *time.Time `query:"t2"` -// C []time.Time `query:"t2"` -// } -// query := make(url.Values) -// query.Add("t2", "2019-09-04T14:05:24+08:00") -// query.Add("t2", "2019-09-04T18:05:24+08:00") -// form := make(url.Values) -// form.Add("t1", "2019-09-03T18:05:24+08:00") -// contentType, bodyReader := newFormBody2(form, nil) -// header := make(http.Header) -// header.Set("Content-Type", contentType) -// req := newRequest("http://localhost/?"+query.Encode(), header, nil, bodyReader) -// recv := new(Recv) -// -// err := DefaultBinder.Bind(req.Req, nil, recv) -// if err != nil { -// t.Error(err) -// } -// t1, err := time.Parse(time.RFC3339, "2019-09-03T18:05:24+08:00") -// if err != nil { -// t.Error(err) -// } -// assert.DeepEqual(t, t1, recv.A) -// t21, err := time.Parse(time.RFC3339, "2019-09-04T14:05:24+08:00") -// if err != nil { -// t.Error(err) -// } -// assert.DeepEqual(t, t21, *recv.B) -// t22, err := time.Parse(time.RFC3339, "2019-09-04T18:05:24+08:00") -// if err != nil { -// t.Error(err) -// } -// assert.DeepEqual(t, []time.Time{t21, t22}, recv.C) -// t.Logf("%v", recv) -//} +func TestTypeUnmarshal(t *testing.T) { + type Recv struct { + A time.Time `form:"t1"` + B *time.Time `query:"t2"` + C []time.Time `query:"t2"` + } + query := make(url.Values) + query.Add("t2", "2019-09-04T14:05:24+08:00") + query.Add("t2", "2019-09-04T18:05:24+08:00") + form := make(url.Values) + form.Add("t1", "2019-09-03T18:05:24+08:00") + contentType, bodyReader := newFormBody2(form, nil) + header := make(http.Header) + header.Set("Content-Type", contentType) + req := newRequest("http://localhost/?"+query.Encode(), header, nil, bodyReader) + recv := new(Recv) + + err := DefaultBinder.Bind(req.Req, nil, recv) + if err != nil { + t.Error(err) + } + t1, err := time.Parse(time.RFC3339, "2019-09-03T18:05:24+08:00") + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, t1, recv.A) + t21, err := time.Parse(time.RFC3339, "2019-09-04T14:05:24+08:00") + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, t21, *recv.B) + t22, err := time.Parse(time.RFC3339, "2019-09-04T18:05:24+08:00") + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, []time.Time{t21, t22}, recv.C) + t.Logf("%v", recv) +} // test: required func TestOption(t *testing.T) { @@ -795,7 +795,8 @@ func TestQueryStringIssue(t *testing.T) { t.Error(err) } assert.DeepEqual(t, "test", *recv.Name) - assert.DeepEqual(t, (*Timestamp)(nil), recv.T) + // DIFF: the type with customized decoder must be a non-nil value + //assert.DeepEqual(t, (*Timestamp)(nil), recv.T) } func TestQueryTypes(t *testing.T) { @@ -869,78 +870,99 @@ func TestNoTagIssue(t *testing.T) { assert.DeepEqual(t, 2, recv.B) } -// DIFF: go-tagexpr 会对保留 t.Q的结构体信息,而目前的实现不会 t.Q 做特殊处理,会直接拆开。有需求也可以加上 -//func TestRegTypeUnmarshal(t *testing.T) { -// type Q struct { -// A int -// B string -// } -// type T struct { -// Q Q `query:"q"` -// Qs []*Q `query:"qs"` -// } -// values := url.Values{} -// b, err := json.Marshal(Q{A: 2, B: "y"}) -// if err != nil { -// t.Error(err) -// } -// values.Add("q", string(b)) -// bs, _ := json.Marshal([]Q{{A: 1, B: "x"}, {A: 2, B: "y"}}) -// values.Add("qs", string(bs)) -// req := newRequest("http://localhost:8080/?"+values.Encode(), nil, nil, nil) -// recv := new(T) -// -// err = DefaultBinder.Bind(req.Req, nil, recv) -// if err != nil { -// t.Error(err) -// } -// assert.DeepEqual(t, 2, recv.Q.A) -// assert.DeepEqual(t, "y", recv.Q.B) -// assert.DeepEqual(t, 1, recv.Qs[0].A) -// assert.DeepEqual(t, "x", recv.Qs[0].B) -// assert.DeepEqual(t, 2, recv.Qs[1].A) -// assert.DeepEqual(t, "y", recv.Qs[1].B) -//} +func TestRegTypeUnmarshal(t *testing.T) { + type Q struct { + A int + B string + } + type T struct { + Q Q `query:"q"` + Qs []*Q `query:"qs"` + Qs2 ***[]*Q `query:"qs"` + } + values := url.Values{} + b, err := json.Marshal(Q{A: 2, B: "y"}) + if err != nil { + t.Error(err) + } + values.Add("q", string(b)) + bs, _ := json.Marshal([]Q{{A: 1, B: "x"}, {A: 2, B: "y"}}) + values.Add("qs", string(bs)) + req := newRequest("http://localhost:8080/?"+values.Encode(), nil, nil, nil) + recv := new(T) -//func TestPathnameBUG(t *testing.T) { -// type Currency struct { -// CurrencyName *string `form:"currency_name,required" json:"currency_name,required" protobuf:"bytes,1,req,name=currency_name,json=currencyName" query:"currency_name,required"` -// CurrencySymbol *string `form:"currency_symbol,required" json:"currency_symbol,required" protobuf:"bytes,2,req,name=currency_symbol,json=currencySymbol" query:"currency_symbol,required"` -// SymbolPosition *int32 `form:"symbol_position,required" json:"symbol_position,required" protobuf:"varint,3,req,name=symbol_position,json=symbolPosition" query:"symbol_position,required"` -// DecimalPlaces *int32 `form:"decimal_places,required" json:"decimal_places,required" protobuf:"varint,4,req,name=decimal_places,json=decimalPlaces" query:"decimal_places,required"` // 56x56 -// DecimalSymbol *string `form:"decimal_symbol,required" json:"decimal_symbol,required" protobuf:"bytes,5,req,name=decimal_symbol,json=decimalSymbol" query:"decimal_symbol,required"` -// Separator *string `form:"separator,required" json:"separator,required" protobuf:"bytes,6,req,name=separator" query:"separator,required"` -// SeparatorIndex *string `form:"separator_index,required" json:"separator_index,required" protobuf:"bytes,7,req,name=separator_index,json=separatorIndex" query:"separator_index,required"` -// Between *string `form:"between,required" json:"between,required" protobuf:"bytes,8,req,name=between" query:"between,required"` -// MinPrice *string `form:"min_price" json:"min_price,omitempty" protobuf:"bytes,9,opt,name=min_price,json=minPrice" query:"min_price"` -// MaxPrice *string `form:"max_price" json:"max_price,omitempty" protobuf:"bytes,10,opt,name=max_price,json=maxPrice" query:"max_price"` -// } -// -// type CurrencyData struct { -// Amount *string `form:"amount,required" json:"amount,required" protobuf:"bytes,1,req,name=amount" query:"amount,required"` -// Currency *Currency `form:"currency,required" json:"currency,required" protobuf:"bytes,2,req,name=currency" query:"currency,required"` -// } -// -// type ExchangeCurrencyRequest struct { -// PromotionRegion *string `form:"promotion_region,required" json:"promotion_region,required" protobuf:"bytes,1,req,name=promotion_region,json=promotionRegion" query:"promotion_region,required"` -// Currency *CurrencyData `form:"currency,required" json:"currency,required" protobuf:"bytes,2,req,name=currency" query:"currency,required"` -// Version *int32 `json:"version,omitempty" path:"version" protobuf:"varint,100,opt,name=version"` -// } -// -// z := &ExchangeCurrencyRequest{} -// v := ameda.InitSampleValue(reflect.TypeOf(z), 10).Interface().(*ExchangeCurrencyRequest) -// b, err := json.MarshalIndent(v, "", " ") -// t.Log(string(b)) -// header := make(http.Header) -// header.Set("Content-Type", "application/json;charset=utf-8") -// req := newRequest("http://localhost", header, nil, bytes.NewReader(b)) -// recv := new(ExchangeCurrencyRequest) -// -// err = DefaultBinder.Bind(req.Req, nil, recv) -// if err != nil { -// -// assert.DeepEqual(t, v, recv) -//} + EnableStructFieldResolve(true) + defer func() { + EnableStructFieldResolve(false) + }() + err = DefaultBinder.Bind(req.Req, nil, recv) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, 2, recv.Q.A) + assert.DeepEqual(t, "y", recv.Q.B) + assert.DeepEqual(t, 1, recv.Qs[0].A) + assert.DeepEqual(t, "x", recv.Qs[0].B) + assert.DeepEqual(t, 2, recv.Qs[1].A) + assert.DeepEqual(t, "y", recv.Qs[1].B) + assert.DeepEqual(t, 1, (***recv.Qs2)[0].A) + assert.DeepEqual(t, "x", (***recv.Qs2)[0].B) + assert.DeepEqual(t, 2, (***recv.Qs2)[1].A) + assert.DeepEqual(t, "y", (***recv.Qs2)[1].B) +} + +func TestPathnameBUG(t *testing.T) { + type Currency struct { + CurrencyName *string `form:"currency_name,required" json:"currency_name,required" protobuf:"bytes,1,req,name=currency_name,json=currencyName" query:"currency_name,required"` + CurrencySymbol *string `form:"currency_symbol,required" json:"currency_symbol,required" protobuf:"bytes,2,req,name=currency_symbol,json=currencySymbol" query:"currency_symbol,required"` + SymbolPosition *int32 `form:"symbol_position,required" json:"symbol_position,required" protobuf:"varint,3,req,name=symbol_position,json=symbolPosition" query:"symbol_position,required"` + DecimalPlaces *int32 `form:"decimal_places,required" json:"decimal_places,required" protobuf:"varint,4,req,name=decimal_places,json=decimalPlaces" query:"decimal_places,required"` // 56x56 + DecimalSymbol *string `form:"decimal_symbol,required" json:"decimal_symbol,required" protobuf:"bytes,5,req,name=decimal_symbol,json=decimalSymbol" query:"decimal_symbol,required"` + Separator *string `form:"separator,required" json:"separator,required" protobuf:"bytes,6,req,name=separator" query:"separator,required"` + SeparatorIndex *string `form:"separator_index,required" json:"separator_index,required" protobuf:"bytes,7,req,name=separator_index,json=separatorIndex" query:"separator_index,required"` + Between *string `form:"between,required" json:"between,required" protobuf:"bytes,8,req,name=between" query:"between,required"` + MinPrice *string `form:"min_price" json:"min_price,omitempty" protobuf:"bytes,9,opt,name=min_price,json=minPrice" query:"min_price"` + MaxPrice *string `form:"max_price" json:"max_price,omitempty" protobuf:"bytes,10,opt,name=max_price,json=maxPrice" query:"max_price"` + } + + type CurrencyData struct { + Amount *string `form:"amount,required" json:"amount,required" protobuf:"bytes,1,req,name=amount" query:"amount,required"` + Currency *Currency `form:"currency,required" json:"currency,required" protobuf:"bytes,2,req,name=currency" query:"currency,required"` + } + + type ExchangeCurrencyRequest struct { + PromotionRegion *string `form:"promotion_region,required" json:"promotion_region,required" protobuf:"bytes,1,req,name=promotion_region,json=promotionRegion" query:"promotion_region,required"` + Currency *CurrencyData `form:"currency,required" json:"currency,required" protobuf:"bytes,2,req,name=currency" query:"currency,required"` + Version *int32 `json:"version,omitempty" path:"version" protobuf:"varint,100,opt,name=version"` + } + + z := new(ExchangeCurrencyRequest) + z.Currency = new(CurrencyData) + z.Currency.Currency = new(Currency) + z.PromotionRegion = proto.String("?") + z.Version = proto.Int32(-32) + z.Currency.Amount = proto.String("?") + z.Currency.Currency.CurrencyName = proto.String("?") + z.Currency.Currency.CurrencySymbol = proto.String("?") + z.Currency.Currency.SymbolPosition = proto.Int32(-32) + z.Currency.Currency.DecimalPlaces = proto.Int32(-32) + z.Currency.Currency.DecimalSymbol = proto.String("?") + z.Currency.Currency.Separator = proto.String("?") + z.Currency.Currency.Between = proto.String("?") + z.Currency.Currency.MinPrice = proto.String("?") + z.Currency.Currency.MaxPrice = proto.String("?") + + b, err := json.MarshalIndent(z, "", " ") + header := make(http.Header) + header.Set("Content-Type", "application/json;charset=utf-8") + req := newRequest("http://localhost", header, nil, bytes.NewReader(b)) + recv := new(ExchangeCurrencyRequest) + + err = DefaultBinder.Bind(req.Req, nil, recv) + if err != nil { + t.Error(err) + } +} // test: required func TestPathnameBUG2(t *testing.T) { @@ -1128,7 +1150,7 @@ func TestIssue26(t *testing.T) { assert.DeepEqual(t, recv, recv2) } -// FIXME: json unmarshal 后,默认值的问题 +// FIXME: after 'json unmarshal', the default value will change it //func TestDefault2(t *testing.T) { // type Recv struct { // X **struct { From d498b122fc67e85246099122bfc6ac6d9d1ee0e3 Mon Sep 17 00:00:00 2001 From: fgy Date: Tue, 23 May 2023 21:49:29 +0800 Subject: [PATCH 38/91] optimize: optimize performance --- .../binding/decoder/base_type_decoder.go | 13 +- .../decoder/customized_type_decoder.go | 11 +- pkg/app/server/binding/decoder/decoder.go | 5 +- pkg/app/server/binding/decoder/getter.go | 128 ++++++------- .../server/binding/decoder/getter_slice.go | 170 ++++++++++++++++++ .../binding/decoder/map_type_decoder.go | 11 +- .../binding/decoder/slice_type_decoder.go | 8 +- .../binding/decoder/struct_type_decoder.go | 11 +- pkg/app/server/binding/decoder/tag.go | 17 +- 9 files changed, 279 insertions(+), 95 deletions(-) create mode 100644 pkg/app/server/binding/decoder/getter_slice.go diff --git a/pkg/app/server/binding/decoder/base_type_decoder.go b/pkg/app/server/binding/decoder/base_type_decoder.go index 808b1ef58..48a7d309f 100644 --- a/pkg/app/server/binding/decoder/base_type_decoder.go +++ b/pkg/app/server/binding/decoder/base_type_decoder.go @@ -52,7 +52,7 @@ type fieldInfo struct { index int parentIndex []int fieldName string - tagInfos []TagInfo // query,param,header,respHeader ... + tagInfos []TagInfo // querySlice,param,headerSlice,respHeader ... fieldType reflect.Type // can not be pointer type } @@ -81,10 +81,9 @@ func (d *baseTypeFieldTextDecoder) Decode(req *bindRequest, params path1.PathPar if tagInfo.Key == headerTag { tagInfo.Value = utils.GetNormalizeHeaderKey(tagInfo.Value, req.Req.Header.IsDisableNormalizing()) } - ret := tagInfo.Getter(req, params, tagInfo.Value) + text = tagInfo.Getter(req, params, tagInfo.Value) defaultValue = tagInfo.Default - if len(ret) != 0 { - text = ret[0] + if len(text) != 0 { err = nil break } @@ -134,18 +133,24 @@ func getBaseTypeTextDecoder(field reflect.StructField, index int, tagInfos []Tag for idx, tagInfo := range tagInfos { switch tagInfo.Key { case pathTag: + tagInfos[idx].SliceGetter = pathSlice tagInfos[idx].Getter = path case formTag: + tagInfos[idx].SliceGetter = postFormSlice tagInfos[idx].Getter = postForm case queryTag: + tagInfos[idx].SliceGetter = querySlice tagInfos[idx].Getter = query case cookieTag: + tagInfos[idx].SliceGetter = cookieSlice tagInfos[idx].Getter = cookie case headerTag: + tagInfos[idx].SliceGetter = headerSlice tagInfos[idx].Getter = header case jsonTag: // do nothing case rawBodyTag: + tagInfos[idx].SliceGetter = rawBodySlice tagInfos[idx].Getter = rawBody case fileNameTag: // do nothing diff --git a/pkg/app/server/binding/decoder/customized_type_decoder.go b/pkg/app/server/binding/decoder/customized_type_decoder.go index 1588bd169..2255e573d 100644 --- a/pkg/app/server/binding/decoder/customized_type_decoder.go +++ b/pkg/app/server/binding/decoder/customized_type_decoder.go @@ -116,10 +116,9 @@ func (d *customizedFieldTextDecoder) Decode(req *bindRequest, params path1.PathP if tagInfo.Key == headerTag { tagInfo.Value = utils.GetNormalizeHeaderKey(tagInfo.Value, req.Req.Header.IsDisableNormalizing()) } - ret := tagInfo.Getter(req, params, tagInfo.Value) + text = tagInfo.Getter(req, params, tagInfo.Value) defaultValue = tagInfo.Default - if len(ret) != 0 { - text = ret[0] + if len(text) != 0 { break } } @@ -153,18 +152,24 @@ func getCustomizedFieldDecoder(field reflect.StructField, index int, tagInfos [] for idx, tagInfo := range tagInfos { switch tagInfo.Key { case pathTag: + tagInfos[idx].SliceGetter = pathSlice tagInfos[idx].Getter = path case formTag: + tagInfos[idx].SliceGetter = postFormSlice tagInfos[idx].Getter = postForm case queryTag: + tagInfos[idx].SliceGetter = querySlice tagInfos[idx].Getter = query case cookieTag: + tagInfos[idx].SliceGetter = cookieSlice tagInfos[idx].Getter = cookie case headerTag: + tagInfos[idx].SliceGetter = headerSlice tagInfos[idx].Getter = header case jsonTag: // do nothing case rawBodyTag: + tagInfos[idx].SliceGetter = rawBodySlice tagInfos[idx].Getter = rawBody case fileNameTag: // do nothing diff --git a/pkg/app/server/binding/decoder/decoder.go b/pkg/app/server/binding/decoder/decoder.go index c675db37c..c7cad2dca 100644 --- a/pkg/app/server/binding/decoder/decoder.go +++ b/pkg/app/server/binding/decoder/decoder.go @@ -42,13 +42,12 @@ package decoder import ( "fmt" + path1 "github.com/cloudwego/hertz/pkg/app/server/binding/path" + "github.com/cloudwego/hertz/pkg/protocol" "mime/multipart" "net/http" "net/url" "reflect" - - path1 "github.com/cloudwego/hertz/pkg/app/server/binding/path" - "github.com/cloudwego/hertz/pkg/protocol" ) var ( diff --git a/pkg/app/server/binding/decoder/getter.go b/pkg/app/server/binding/decoder/getter.go index bd638fd77..bbc7de2ca 100644 --- a/pkg/app/server/binding/decoder/getter.go +++ b/pkg/app/server/binding/decoder/getter.go @@ -41,130 +41,118 @@ package decoder import ( - "net/http" - "net/url" - path1 "github.com/cloudwego/hertz/pkg/app/server/binding/path" ) -type getter func(req *bindRequest, params path1.PathParam, key string, defaultValue ...string) (ret []string) +type getter func(req *bindRequest, params path1.PathParam, key string, defaultValue ...string) (ret string) -func path(req *bindRequest, params path1.PathParam, key string, defaultValue ...string) (ret []string) { - var value string +func path(req *bindRequest, params path1.PathParam, key string, defaultValue ...string) (ret string) { if params != nil { - value, _ = params.Get(key) + ret, _ = params.Get(key) } - if len(value) == 0 && len(defaultValue) != 0 { - value = defaultValue[0] - } - if len(value) != 0 { - ret = append(ret, value) + if len(ret) == 0 && len(defaultValue) != 0 { + ret = defaultValue[0] } - - return + return ret } -func postForm(req *bindRequest, params path1.PathParam, key string, defaultValue ...string) (ret []string) { - if req.Form == nil { - req.Form = make(url.Values) - req.Req.PostArgs().VisitAll(func(formKey, value []byte) { - keyStr := string(formKey) - values := req.Form[keyStr] - values = append(values, string(value)) - req.Form[keyStr] = values - }) +func postForm(req *bindRequest, params path1.PathParam, key string, defaultValue ...string) (ret string) { + if req.Form != nil { + if val, exist := req.Form[key]; exist { + ret = val[0] + } + } else { + if val := req.Req.PostArgs().Peek(key); val != nil { + ret = string(val) + } } - ret = req.Form[key] if len(ret) > 0 { return } - if req.MultipartForm == nil { - req.MultipartForm = make(url.Values) + if req.MultipartForm != nil { + if val, exist := req.MultipartForm[key]; exist { + ret = val[0] + } + } else { mf, err := req.Req.MultipartForm() if err == nil && mf.Value != nil { for k, v := range mf.Value { - if len(v) > 0 { - req.MultipartForm[k] = v + if k == key && len(v) > 0 { + ret = v[0] } } } } - ret = req.MultipartForm[key] - if len(ret) > 0 { - return - } if len(ret) == 0 && len(defaultValue) != 0 { - ret = append(ret, defaultValue...) + ret = defaultValue[0] } return } -func query(req *bindRequest, params path1.PathParam, key string, defaultValue ...string) (ret []string) { - if req.Query == nil { - req.Query = make(url.Values) - req.Req.URI().QueryArgs().VisitAll(func(queryKey, value []byte) { - keyStr := string(queryKey) - values := req.Query[keyStr] - values = append(values, string(value)) - req.Query[keyStr] = values - }) +func query(req *bindRequest, params path1.PathParam, key string, defaultValue ...string) (ret string) { + if req.Query != nil { + if val, exist := req.Query[key]; exist { + ret = val[0] + } + } else { + if val := req.Req.URI().QueryArgs().Peek(key); val != nil { + ret = string(val) + } } - ret = req.Query[key] if len(ret) == 0 && len(defaultValue) != 0 { - ret = append(ret, defaultValue...) + ret = defaultValue[0] } return } -func cookie(req *bindRequest, params path1.PathParam, key string, defaultValue ...string) (ret []string) { - if len(req.Cookie) == 0 { - req.Req.Header.VisitAllCookie(func(cookieKey, value []byte) { - req.Cookie = append(req.Cookie, &http.Cookie{ - Name: string(cookieKey), - Value: string(value), - }) - }) - } - for _, c := range req.Cookie { - if c.Name == key { - ret = append(ret, c.Value) +func cookie(req *bindRequest, params path1.PathParam, key string, defaultValue ...string) (ret string) { + if len(req.Cookie) != 0 { + for _, c := range req.Cookie { + if c.Name == key { + ret = c.Value + break + } + } + } else { + if val := req.Req.Header.Cookie(key); val != nil { + ret = string(val) } } + if len(ret) == 0 && len(defaultValue) != 0 { - ret = append(ret, defaultValue...) + ret = defaultValue[0] } return } -func header(req *bindRequest, params path1.PathParam, key string, defaultValue ...string) (ret []string) { - if req.Header == nil { - req.Header = make(http.Header) - req.Req.Header.VisitAll(func(headerKey, value []byte) { - keyStr := string(headerKey) - values := req.Header[keyStr] - values = append(values, string(value)) - req.Header[keyStr] = values - }) +func header(req *bindRequest, params path1.PathParam, key string, defaultValue ...string) (ret string) { + if req.Header != nil { + if val, exist := req.Header[key]; exist { + ret = val[0] + } + } else { + if val := req.Req.Header.Peek(key); val != nil { + ret = string(val) + } } - ret = req.Header[key] if len(ret) == 0 && len(defaultValue) != 0 { - ret = append(ret, defaultValue...) + ret = defaultValue[0] } return } -func rawBody(req *bindRequest, params path1.PathParam, key string, defaultValue ...string) (ret []string) { +func rawBody(req *bindRequest, params path1.PathParam, key string, defaultValue ...string) (ret string) { if req.Req.Header.ContentLength() > 0 { - ret = append(ret, string(req.Req.Body())) + ret = string(req.Req.Body()) } return } diff --git a/pkg/app/server/binding/decoder/getter_slice.go b/pkg/app/server/binding/decoder/getter_slice.go new file mode 100644 index 000000000..425bb529e --- /dev/null +++ b/pkg/app/server/binding/decoder/getter_slice.go @@ -0,0 +1,170 @@ +/* + * Copyright 2022 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * MIT License + * + * Copyright (c) 2019-present Fenny and Contributors + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + * + * This file may have been modified by CloudWeGo authors. All CloudWeGo + * Modifications are Copyright 2022 CloudWeGo Authors + */ + +package decoder + +import ( + "net/http" + "net/url" + + path1 "github.com/cloudwego/hertz/pkg/app/server/binding/path" +) + +type sliceGetter func(req *bindRequest, params path1.PathParam, key string, defaultValue ...string) (ret []string) + +func pathSlice(req *bindRequest, params path1.PathParam, key string, defaultValue ...string) (ret []string) { + var value string + if params != nil { + value, _ = params.Get(key) + } + + if len(value) == 0 && len(defaultValue) != 0 { + value = defaultValue[0] + } + if len(value) != 0 { + ret = append(ret, value) + } + + return +} + +func postFormSlice(req *bindRequest, params path1.PathParam, key string, defaultValue ...string) (ret []string) { + if req.Form == nil { + req.Form = make(url.Values) + req.Req.PostArgs().VisitAll(func(formKey, value []byte) { + keyStr := string(formKey) + values := req.Form[keyStr] + values = append(values, string(value)) + req.Form[keyStr] = values + }) + } + ret = req.Form[key] + if len(ret) > 0 { + return + } + + if req.MultipartForm == nil { + req.MultipartForm = make(url.Values) + mf, err := req.Req.MultipartForm() + if err == nil && mf.Value != nil { + for k, v := range mf.Value { + if len(v) > 0 { + req.MultipartForm[k] = v + } + } + } + } + ret = req.MultipartForm[key] + if len(ret) > 0 { + return + } + + if len(ret) == 0 && len(defaultValue) != 0 { + ret = append(ret, defaultValue...) + } + + return +} + +func querySlice(req *bindRequest, params path1.PathParam, key string, defaultValue ...string) (ret []string) { + if req.Query == nil { + req.Query = make(url.Values, 100) + req.Req.URI().QueryArgs().VisitAll(func(queryKey, value []byte) { + keyStr := string(queryKey) + values := req.Query[keyStr] + values = append(values, string(value)) + req.Query[keyStr] = values + }) + } + + ret = req.Query[key] + if len(ret) == 0 && len(defaultValue) != 0 { + ret = append(ret, defaultValue...) + } + + return +} + +func cookieSlice(req *bindRequest, params path1.PathParam, key string, defaultValue ...string) (ret []string) { + if len(req.Cookie) == 0 { + req.Req.Header.VisitAllCookie(func(cookieKey, value []byte) { + req.Cookie = append(req.Cookie, &http.Cookie{ + Name: string(cookieKey), + Value: string(value), + }) + }) + } + for _, c := range req.Cookie { + if c.Name == key { + ret = append(ret, c.Value) + } + } + if len(ret) == 0 && len(defaultValue) != 0 { + ret = append(ret, defaultValue...) + } + + return +} + +func headerSlice(req *bindRequest, params path1.PathParam, key string, defaultValue ...string) (ret []string) { + if req.Header == nil { + req.Header = make(http.Header) + req.Req.Header.VisitAll(func(headerKey, value []byte) { + keyStr := string(headerKey) + values := req.Header[keyStr] + values = append(values, string(value)) + req.Header[keyStr] = values + }) + } + + ret = req.Header[key] + if len(ret) == 0 && len(defaultValue) != 0 { + ret = append(ret, defaultValue...) + } + + return +} + +func rawBodySlice(req *bindRequest, params path1.PathParam, key string, defaultValue ...string) (ret []string) { + if req.Req.Header.ContentLength() > 0 { + ret = append(ret, string(req.Req.Body())) + } + return +} diff --git a/pkg/app/server/binding/decoder/map_type_decoder.go b/pkg/app/server/binding/decoder/map_type_decoder.go index ff563c871..47a5b5885 100644 --- a/pkg/app/server/binding/decoder/map_type_decoder.go +++ b/pkg/app/server/binding/decoder/map_type_decoder.go @@ -74,10 +74,9 @@ func (d *mapTypeFieldTextDecoder) Decode(req *bindRequest, params path1.PathPara if tagInfo.Key == headerTag { tagInfo.Value = utils.GetNormalizeHeaderKey(tagInfo.Value, req.Req.Header.IsDisableNormalizing()) } - ret := tagInfo.Getter(req, params, tagInfo.Value) + text = tagInfo.Getter(req, params, tagInfo.Value) defaultValue = tagInfo.Default - if len(ret) != 0 { - text = ret[0] + if len(text) != 0 { err = nil break } @@ -125,18 +124,24 @@ func getMapTypeTextDecoder(field reflect.StructField, index int, tagInfos []TagI for idx, tagInfo := range tagInfos { switch tagInfo.Key { case pathTag: + tagInfos[idx].SliceGetter = pathSlice tagInfos[idx].Getter = path case formTag: + tagInfos[idx].SliceGetter = postFormSlice tagInfos[idx].Getter = postForm case queryTag: + tagInfos[idx].SliceGetter = querySlice tagInfos[idx].Getter = query case cookieTag: + tagInfos[idx].SliceGetter = cookieSlice tagInfos[idx].Getter = cookie case headerTag: + tagInfos[idx].SliceGetter = headerSlice tagInfos[idx].Getter = header case jsonTag: // do nothing case rawBodyTag: + tagInfos[idx].SliceGetter = rawBodySlice tagInfos[idx].Getter = rawBody case fileNameTag: // do nothing diff --git a/pkg/app/server/binding/decoder/slice_type_decoder.go b/pkg/app/server/binding/decoder/slice_type_decoder.go index 13ecf612d..0cbac6f47 100644 --- a/pkg/app/server/binding/decoder/slice_type_decoder.go +++ b/pkg/app/server/binding/decoder/slice_type_decoder.go @@ -80,7 +80,7 @@ func (d *sliceTypeFieldTextDecoder) Decode(req *bindRequest, params path1.PathPa if tagInfo.Key == rawBodyTag { bindRawBody = true } - texts = tagInfo.Getter(req, params, tagInfo.Value) + texts = tagInfo.SliceGetter(req, params, tagInfo.Value) // todo: array/slice default value defaultValue = tagInfo.Default if len(texts) != 0 { @@ -175,18 +175,24 @@ func getSliceFieldDecoder(field reflect.StructField, index int, tagInfos []TagIn for idx, tagInfo := range tagInfos { switch tagInfo.Key { case pathTag: + tagInfos[idx].SliceGetter = pathSlice tagInfos[idx].Getter = path case formTag: + tagInfos[idx].SliceGetter = postFormSlice tagInfos[idx].Getter = postForm case queryTag: + tagInfos[idx].SliceGetter = querySlice tagInfos[idx].Getter = query case cookieTag: + tagInfos[idx].SliceGetter = cookieSlice tagInfos[idx].Getter = cookie case headerTag: + tagInfos[idx].SliceGetter = headerSlice tagInfos[idx].Getter = header case jsonTag: // do nothing case rawBodyTag: + tagInfos[idx].SliceGetter = rawBodySlice tagInfos[idx].Getter = rawBody case fileNameTag: // do nothing diff --git a/pkg/app/server/binding/decoder/struct_type_decoder.go b/pkg/app/server/binding/decoder/struct_type_decoder.go index 071c21f40..3ee6a1f83 100644 --- a/pkg/app/server/binding/decoder/struct_type_decoder.go +++ b/pkg/app/server/binding/decoder/struct_type_decoder.go @@ -50,10 +50,9 @@ func (d *structTypeFieldTextDecoder) Decode(req *bindRequest, params path1.PathP if tagInfo.Key == headerTag { tagInfo.Value = utils.GetNormalizeHeaderKey(tagInfo.Value, req.Req.Header.IsDisableNormalizing()) } - ret := tagInfo.Getter(req, params, tagInfo.Value) + text = tagInfo.Getter(req, params, tagInfo.Value) defaultValue = tagInfo.Default - if len(ret) != 0 { - text = ret[0] + if len(text) != 0 { err = nil break } @@ -100,18 +99,24 @@ func getStructTypeFieldDecoder(field reflect.StructField, index int, tagInfos [] for idx, tagInfo := range tagInfos { switch tagInfo.Key { case pathTag: + tagInfos[idx].SliceGetter = pathSlice tagInfos[idx].Getter = path case formTag: + tagInfos[idx].SliceGetter = postFormSlice tagInfos[idx].Getter = postForm case queryTag: + tagInfos[idx].SliceGetter = querySlice tagInfos[idx].Getter = query case cookieTag: + tagInfos[idx].SliceGetter = cookieSlice tagInfos[idx].Getter = cookie case headerTag: + tagInfos[idx].SliceGetter = headerSlice tagInfos[idx].Getter = header case jsonTag: // do nothing case rawBodyTag: + tagInfos[idx].SliceGetter = rawBodySlice tagInfos[idx].Getter = rawBody case fileNameTag: // do nothing diff --git a/pkg/app/server/binding/decoder/tag.go b/pkg/app/server/binding/decoder/tag.go index 20d893740..0c53791f8 100644 --- a/pkg/app/server/binding/decoder/tag.go +++ b/pkg/app/server/binding/decoder/tag.go @@ -41,14 +41,15 @@ const ( ) type TagInfo struct { - Key string - Value string - JSONName string - Required bool - Skip bool - Default string - Options []string - Getter getter + Key string + Value string + JSONName string + Required bool + Skip bool + Default string + Options []string + Getter getter + SliceGetter sliceGetter } func head(str, sep string) (head, tail string) { From 46bda3a941c00fb6cf76b7e6a514163902ec3a7a Mon Sep 17 00:00:00 2001 From: fgy Date: Wed, 24 May 2023 15:58:00 +0800 Subject: [PATCH 39/91] optimize: optimize performence --- pkg/app/server/binding/binder.go | 4 +- pkg/app/server/binding/binder_test.go | 13 +++---- pkg/app/server/binding/config.go | 6 +-- .../binding/decoder/base_type_decoder.go | 4 +- .../decoder/customized_type_decoder.go | 8 ++-- pkg/app/server/binding/decoder/decoder.go | 34 ++++++++++++---- pkg/app/server/binding/decoder/getter.go | 16 ++++---- .../server/binding/decoder/gjson_required.go | 4 +- .../binding/decoder/map_type_decoder.go | 4 +- .../binding/decoder/multipart_file_decoder.go | 6 +-- .../{getter_slice.go => slice_getter.go} | 16 ++++---- .../binding/decoder/slice_type_decoder.go | 6 +-- .../binding/decoder/struct_type_decoder.go | 4 +- pkg/app/server/binding/default.go | 4 +- pkg/app/server/binding/path/path.go | 22 ----------- pkg/app/server/binding/tagexpr_bind_test.go | 39 ++++++++++++++++++- 16 files changed, 110 insertions(+), 80 deletions(-) rename pkg/app/server/binding/decoder/{getter_slice.go => slice_getter.go} (84%) delete mode 100644 pkg/app/server/binding/path/path.go diff --git a/pkg/app/server/binding/binder.go b/pkg/app/server/binding/binder.go index 09ed5e49f..c99d95adb 100644 --- a/pkg/app/server/binding/binder.go +++ b/pkg/app/server/binding/binder.go @@ -41,13 +41,13 @@ package binding import ( - "github.com/cloudwego/hertz/pkg/app/server/binding/path" "github.com/cloudwego/hertz/pkg/protocol" + "github.com/cloudwego/hertz/pkg/route/param" ) type Binder interface { Name() string - Bind(*protocol.Request, path.PathParam, interface{}) error + Bind(*protocol.Request, param.Params, interface{}) error } var DefaultBinder Binder = &defaultBinder{} diff --git a/pkg/app/server/binding/binder_test.go b/pkg/app/server/binding/binder_test.go index ec4e60e13..8a82fb80d 100644 --- a/pkg/app/server/binding/binder_test.go +++ b/pkg/app/server/binding/binder_test.go @@ -46,7 +46,6 @@ import ( "reflect" "testing" - "github.com/cloudwego/hertz/pkg/app/server/binding/path" "github.com/cloudwego/hertz/pkg/common/test/assert" "github.com/cloudwego/hertz/pkg/protocol" req2 "github.com/cloudwego/hertz/pkg/protocol/http1/req" @@ -499,7 +498,7 @@ func TestBind_CustomizedTypeDecode(t *testing.T) { F ***CustomizedDecode } - err := RegTypeUnmarshal(reflect.TypeOf(CustomizedDecode{}), func(req *protocol.Request, params path.PathParam, text string) (reflect.Value, error) { + err := RegTypeUnmarshal(reflect.TypeOf(CustomizedDecode{}), func(req *protocol.Request, params param.Params, text string) (reflect.Value, error) { q1 := req.URI().QueryArgs().Peek("a") if len(q1) == 0 { return reflect.Value{}, fmt.Errorf("can be nil") @@ -670,14 +669,14 @@ func TestBind_FileSliceBind(t *testing.T) { func TestBind_AnonymousField(t *testing.T) { type nest struct { - n1 string `query:"n1"` // bind default value - N2 ***string `query:"n2"` // bind n2 value - string `query:"n3"` // bind default value + n1 string `query:"n1"` // bind default value + N2 ***string `query:"n2"` // bind n2 value + string `query:"n3"` // bind default value } var s struct { - s1 int `query:"s1"` // bind default value - int `query:"s2"` // bind default value + s1 int `query:"s1"` // bind default value + int `query:"s2"` // bind default value nest } req := newMockRequest(). diff --git a/pkg/app/server/binding/config.go b/pkg/app/server/binding/config.go index d9dd54e12..49c3a15a0 100644 --- a/pkg/app/server/binding/config.go +++ b/pkg/app/server/binding/config.go @@ -21,9 +21,9 @@ import ( "reflect" "github.com/cloudwego/hertz/pkg/app/server/binding/decoder" - path1 "github.com/cloudwego/hertz/pkg/app/server/binding/path" hjson "github.com/cloudwego/hertz/pkg/common/json" "github.com/cloudwego/hertz/pkg/protocol" + "github.com/cloudwego/hertz/pkg/route/param" ) // ResetJSONUnmarshaler reset the JSON Unmarshal function. @@ -49,11 +49,11 @@ func EnableStructFieldResolve(b bool) { } // RegTypeUnmarshal registers customized type unmarshaler. -func RegTypeUnmarshal(t reflect.Type, fn func(req *protocol.Request, params path1.PathParam, text string) (reflect.Value, error)) error { +func RegTypeUnmarshal(t reflect.Type, fn func(req *protocol.Request, params param.Params, text string) (reflect.Value, error)) error { return decoder.RegTypeUnmarshal(t, fn) } // MustRegTypeUnmarshal registers customized type unmarshaler. It will panic if exist error. -func MustRegTypeUnmarshal(t reflect.Type, fn func(req *protocol.Request, params path1.PathParam, text string) (reflect.Value, error)) { +func MustRegTypeUnmarshal(t reflect.Type, fn func(req *protocol.Request, params param.Params, text string) (reflect.Value, error)) { decoder.MustRegTypeUnmarshal(t, fn) } diff --git a/pkg/app/server/binding/decoder/base_type_decoder.go b/pkg/app/server/binding/decoder/base_type_decoder.go index 48a7d309f..c002c6ca7 100644 --- a/pkg/app/server/binding/decoder/base_type_decoder.go +++ b/pkg/app/server/binding/decoder/base_type_decoder.go @@ -44,8 +44,8 @@ import ( "fmt" "reflect" - path1 "github.com/cloudwego/hertz/pkg/app/server/binding/path" "github.com/cloudwego/hertz/pkg/common/utils" + "github.com/cloudwego/hertz/pkg/route/param" ) type fieldInfo struct { @@ -61,7 +61,7 @@ type baseTypeFieldTextDecoder struct { decoder TextDecoder } -func (d *baseTypeFieldTextDecoder) Decode(req *bindRequest, params path1.PathParam, reqValue reflect.Value) error { +func (d *baseTypeFieldTextDecoder) Decode(req *bindRequest, params param.Params, reqValue reflect.Value) error { var err error var text string var defaultValue string diff --git a/pkg/app/server/binding/decoder/customized_type_decoder.go b/pkg/app/server/binding/decoder/customized_type_decoder.go index 2255e573d..99bdbbdbb 100644 --- a/pkg/app/server/binding/decoder/customized_type_decoder.go +++ b/pkg/app/server/binding/decoder/customized_type_decoder.go @@ -45,13 +45,13 @@ import ( "reflect" "time" - path1 "github.com/cloudwego/hertz/pkg/app/server/binding/path" "github.com/cloudwego/hertz/pkg/common/utils" "github.com/cloudwego/hertz/pkg/protocol" + "github.com/cloudwego/hertz/pkg/route/param" ) func init() { - MustRegTypeUnmarshal(reflect.TypeOf(time.Time{}), func(req *protocol.Request, params path1.PathParam, text string) (reflect.Value, error) { + MustRegTypeUnmarshal(reflect.TypeOf(time.Time{}), func(req *protocol.Request, params param.Params, text string) (reflect.Value, error) { if text == "" { return reflect.ValueOf(time.Time{}), nil } @@ -63,7 +63,7 @@ func init() { }) } -type customizeDecodeFunc func(req *protocol.Request, params path1.PathParam, text string) (reflect.Value, error) +type customizeDecodeFunc func(req *protocol.Request, params param.Params, text string) (reflect.Value, error) var typeUnmarshalFuncs = make(map[reflect.Type]customizeDecodeFunc) @@ -105,7 +105,7 @@ type customizedFieldTextDecoder struct { decodeFunc customizeDecodeFunc } -func (d *customizedFieldTextDecoder) Decode(req *bindRequest, params path1.PathParam, reqValue reflect.Value) error { +func (d *customizedFieldTextDecoder) Decode(req *bindRequest, params param.Params, reqValue reflect.Value) error { var text string var defaultValue string for _, tagInfo := range d.tagInfos { diff --git a/pkg/app/server/binding/decoder/decoder.go b/pkg/app/server/binding/decoder/decoder.go index c7cad2dca..31dd9709e 100644 --- a/pkg/app/server/binding/decoder/decoder.go +++ b/pkg/app/server/binding/decoder/decoder.go @@ -42,12 +42,14 @@ package decoder import ( "fmt" - path1 "github.com/cloudwego/hertz/pkg/app/server/binding/path" - "github.com/cloudwego/hertz/pkg/protocol" "mime/multipart" "net/http" "net/url" "reflect" + "sync" + + "github.com/cloudwego/hertz/pkg/protocol" + "github.com/cloudwego/hertz/pkg/route/param" ) var ( @@ -55,6 +57,10 @@ var ( EnableStructFieldResolve = false ) +var bindRequestPool = sync.Pool{New: func() interface{} { + return &bindRequest{} +}} + type bindRequest struct { Req *protocol.Request Query url.Values @@ -64,11 +70,22 @@ type bindRequest struct { Cookie []*http.Cookie } +func (b *bindRequest) reset() { + b.Req = nil + b.Query = nil + b.Form = nil + b.MultipartForm = nil + b.Header = nil + if b.Cookie != nil { + b.Cookie = b.Cookie[:0] + } +} + type fieldDecoder interface { - Decode(req *bindRequest, params path1.PathParam, reqValue reflect.Value) error + Decode(req *bindRequest, params param.Params, reqValue reflect.Value) error } -type Decoder func(req *protocol.Request, params path1.PathParam, rv reflect.Value) error +type Decoder func(req *protocol.Request, params param.Params, rv reflect.Value) error func GetReqDecoder(rt reflect.Type) (Decoder, error) { var decoders []fieldDecoder @@ -94,16 +111,17 @@ func GetReqDecoder(rt reflect.Type) (Decoder, error) { } } - return func(req *protocol.Request, params path1.PathParam, rv reflect.Value) error { - bindReq := &bindRequest{ - Req: req, - } + return func(req *protocol.Request, params param.Params, rv reflect.Value) error { + bindReq := bindRequestPool.Get().(*bindRequest) + bindReq.Req = req for _, decoder := range decoders { err := decoder.Decode(bindReq, params, rv) if err != nil { return err } } + bindReq.reset() + bindRequestPool.Put(bindReq) return nil }, nil diff --git a/pkg/app/server/binding/decoder/getter.go b/pkg/app/server/binding/decoder/getter.go index bbc7de2ca..5e8908c38 100644 --- a/pkg/app/server/binding/decoder/getter.go +++ b/pkg/app/server/binding/decoder/getter.go @@ -41,12 +41,12 @@ package decoder import ( - path1 "github.com/cloudwego/hertz/pkg/app/server/binding/path" + "github.com/cloudwego/hertz/pkg/route/param" ) -type getter func(req *bindRequest, params path1.PathParam, key string, defaultValue ...string) (ret string) +type getter func(req *bindRequest, params param.Params, key string, defaultValue ...string) (ret string) -func path(req *bindRequest, params path1.PathParam, key string, defaultValue ...string) (ret string) { +func path(req *bindRequest, params param.Params, key string, defaultValue ...string) (ret string) { if params != nil { ret, _ = params.Get(key) } @@ -57,7 +57,7 @@ func path(req *bindRequest, params path1.PathParam, key string, defaultValue ... return ret } -func postForm(req *bindRequest, params path1.PathParam, key string, defaultValue ...string) (ret string) { +func postForm(req *bindRequest, params param.Params, key string, defaultValue ...string) (ret string) { if req.Form != nil { if val, exist := req.Form[key]; exist { ret = val[0] @@ -93,7 +93,7 @@ func postForm(req *bindRequest, params path1.PathParam, key string, defaultValue return } -func query(req *bindRequest, params path1.PathParam, key string, defaultValue ...string) (ret string) { +func query(req *bindRequest, params param.Params, key string, defaultValue ...string) (ret string) { if req.Query != nil { if val, exist := req.Query[key]; exist { ret = val[0] @@ -111,7 +111,7 @@ func query(req *bindRequest, params path1.PathParam, key string, defaultValue .. return } -func cookie(req *bindRequest, params path1.PathParam, key string, defaultValue ...string) (ret string) { +func cookie(req *bindRequest, params param.Params, key string, defaultValue ...string) (ret string) { if len(req.Cookie) != 0 { for _, c := range req.Cookie { if c.Name == key { @@ -132,7 +132,7 @@ func cookie(req *bindRequest, params path1.PathParam, key string, defaultValue . return } -func header(req *bindRequest, params path1.PathParam, key string, defaultValue ...string) (ret string) { +func header(req *bindRequest, params param.Params, key string, defaultValue ...string) (ret string) { if req.Header != nil { if val, exist := req.Header[key]; exist { ret = val[0] @@ -150,7 +150,7 @@ func header(req *bindRequest, params path1.PathParam, key string, defaultValue . return } -func rawBody(req *bindRequest, params path1.PathParam, key string, defaultValue ...string) (ret string) { +func rawBody(req *bindRequest, params param.Params, key string, defaultValue ...string) (ret string) { if req.Req.Header.ContentLength() > 0 { ret = string(req.Req.Body()) } diff --git a/pkg/app/server/binding/decoder/gjson_required.go b/pkg/app/server/binding/decoder/gjson_required.go index d53921efd..42c25425a 100644 --- a/pkg/app/server/binding/decoder/gjson_required.go +++ b/pkg/app/server/binding/decoder/gjson_required.go @@ -19,10 +19,10 @@ package decoder import ( - "github.com/cloudwego/hertz/internal/bytesconv" - "github.com/cloudwego/hertz/pkg/common/utils" "strings" + "github.com/cloudwego/hertz/internal/bytesconv" + "github.com/cloudwego/hertz/pkg/common/utils" "github.com/tidwall/gjson" ) diff --git a/pkg/app/server/binding/decoder/map_type_decoder.go b/pkg/app/server/binding/decoder/map_type_decoder.go index 47a5b5885..ae6ac11d5 100644 --- a/pkg/app/server/binding/decoder/map_type_decoder.go +++ b/pkg/app/server/binding/decoder/map_type_decoder.go @@ -45,16 +45,16 @@ import ( "reflect" "github.com/cloudwego/hertz/internal/bytesconv" - path1 "github.com/cloudwego/hertz/pkg/app/server/binding/path" hjson "github.com/cloudwego/hertz/pkg/common/json" "github.com/cloudwego/hertz/pkg/common/utils" + "github.com/cloudwego/hertz/pkg/route/param" ) type mapTypeFieldTextDecoder struct { fieldInfo } -func (d *mapTypeFieldTextDecoder) Decode(req *bindRequest, params path1.PathParam, reqValue reflect.Value) error { +func (d *mapTypeFieldTextDecoder) Decode(req *bindRequest, params param.Params, reqValue reflect.Value) error { var err error var text string var defaultValue string diff --git a/pkg/app/server/binding/decoder/multipart_file_decoder.go b/pkg/app/server/binding/decoder/multipart_file_decoder.go index 1c3a180d6..d4d7dcbda 100644 --- a/pkg/app/server/binding/decoder/multipart_file_decoder.go +++ b/pkg/app/server/binding/decoder/multipart_file_decoder.go @@ -20,7 +20,7 @@ import ( "fmt" "reflect" - path1 "github.com/cloudwego/hertz/pkg/app/server/binding/path" + "github.com/cloudwego/hertz/pkg/route/param" ) type fileTypeDecoder struct { @@ -28,7 +28,7 @@ type fileTypeDecoder struct { isRepeated bool } -func (d *fileTypeDecoder) Decode(req *bindRequest, params path1.PathParam, reqValue reflect.Value) error { +func (d *fileTypeDecoder) Decode(req *bindRequest, params param.Params, reqValue reflect.Value) error { fieldValue := GetFieldValue(reqValue, d.parentIndex) field := fieldValue.Field(d.index) @@ -72,7 +72,7 @@ func (d *fileTypeDecoder) Decode(req *bindRequest, params path1.PathParam, reqVa return nil } -func (d *fileTypeDecoder) fileSliceDecode(req *bindRequest, params path1.PathParam, reqValue reflect.Value) error { +func (d *fileTypeDecoder) fileSliceDecode(req *bindRequest, params param.Params, reqValue reflect.Value) error { fieldValue := GetFieldValue(reqValue, d.parentIndex) field := fieldValue.Field(d.index) // 如果没值,需要为其建一个值 diff --git a/pkg/app/server/binding/decoder/getter_slice.go b/pkg/app/server/binding/decoder/slice_getter.go similarity index 84% rename from pkg/app/server/binding/decoder/getter_slice.go rename to pkg/app/server/binding/decoder/slice_getter.go index 425bb529e..6bb50a66f 100644 --- a/pkg/app/server/binding/decoder/getter_slice.go +++ b/pkg/app/server/binding/decoder/slice_getter.go @@ -44,12 +44,12 @@ import ( "net/http" "net/url" - path1 "github.com/cloudwego/hertz/pkg/app/server/binding/path" + "github.com/cloudwego/hertz/pkg/route/param" ) -type sliceGetter func(req *bindRequest, params path1.PathParam, key string, defaultValue ...string) (ret []string) +type sliceGetter func(req *bindRequest, params param.Params, key string, defaultValue ...string) (ret []string) -func pathSlice(req *bindRequest, params path1.PathParam, key string, defaultValue ...string) (ret []string) { +func pathSlice(req *bindRequest, params param.Params, key string, defaultValue ...string) (ret []string) { var value string if params != nil { value, _ = params.Get(key) @@ -65,7 +65,7 @@ func pathSlice(req *bindRequest, params path1.PathParam, key string, defaultValu return } -func postFormSlice(req *bindRequest, params path1.PathParam, key string, defaultValue ...string) (ret []string) { +func postFormSlice(req *bindRequest, params param.Params, key string, defaultValue ...string) (ret []string) { if req.Form == nil { req.Form = make(url.Values) req.Req.PostArgs().VisitAll(func(formKey, value []byte) { @@ -103,7 +103,7 @@ func postFormSlice(req *bindRequest, params path1.PathParam, key string, default return } -func querySlice(req *bindRequest, params path1.PathParam, key string, defaultValue ...string) (ret []string) { +func querySlice(req *bindRequest, params param.Params, key string, defaultValue ...string) (ret []string) { if req.Query == nil { req.Query = make(url.Values, 100) req.Req.URI().QueryArgs().VisitAll(func(queryKey, value []byte) { @@ -122,7 +122,7 @@ func querySlice(req *bindRequest, params path1.PathParam, key string, defaultVal return } -func cookieSlice(req *bindRequest, params path1.PathParam, key string, defaultValue ...string) (ret []string) { +func cookieSlice(req *bindRequest, params param.Params, key string, defaultValue ...string) (ret []string) { if len(req.Cookie) == 0 { req.Req.Header.VisitAllCookie(func(cookieKey, value []byte) { req.Cookie = append(req.Cookie, &http.Cookie{ @@ -143,7 +143,7 @@ func cookieSlice(req *bindRequest, params path1.PathParam, key string, defaultVa return } -func headerSlice(req *bindRequest, params path1.PathParam, key string, defaultValue ...string) (ret []string) { +func headerSlice(req *bindRequest, params param.Params, key string, defaultValue ...string) (ret []string) { if req.Header == nil { req.Header = make(http.Header) req.Req.Header.VisitAll(func(headerKey, value []byte) { @@ -162,7 +162,7 @@ func headerSlice(req *bindRequest, params path1.PathParam, key string, defaultVa return } -func rawBodySlice(req *bindRequest, params path1.PathParam, key string, defaultValue ...string) (ret []string) { +func rawBodySlice(req *bindRequest, params param.Params, key string, defaultValue ...string) (ret []string) { if req.Req.Header.ContentLength() > 0 { ret = append(ret, string(req.Req.Body())) } diff --git a/pkg/app/server/binding/decoder/slice_type_decoder.go b/pkg/app/server/binding/decoder/slice_type_decoder.go index 0cbac6f47..3787c6912 100644 --- a/pkg/app/server/binding/decoder/slice_type_decoder.go +++ b/pkg/app/server/binding/decoder/slice_type_decoder.go @@ -46,9 +46,9 @@ import ( "reflect" "github.com/cloudwego/hertz/internal/bytesconv" - path1 "github.com/cloudwego/hertz/pkg/app/server/binding/path" hjson "github.com/cloudwego/hertz/pkg/common/json" "github.com/cloudwego/hertz/pkg/common/utils" + "github.com/cloudwego/hertz/pkg/route/param" ) type sliceTypeFieldTextDecoder struct { @@ -56,7 +56,7 @@ type sliceTypeFieldTextDecoder struct { isArray bool } -func (d *sliceTypeFieldTextDecoder) Decode(req *bindRequest, params path1.PathParam, reqValue reflect.Value) error { +func (d *sliceTypeFieldTextDecoder) Decode(req *bindRequest, params param.Params, reqValue reflect.Value) error { var err error var texts []string var defaultValue string @@ -221,7 +221,7 @@ func getSliceFieldDecoder(field reflect.StructField, index int, tagInfos []TagIn }}, nil } -func stringToValue(elemType reflect.Type, text string, req *bindRequest, params path1.PathParam) (v reflect.Value, err error) { +func stringToValue(elemType reflect.Type, text string, req *bindRequest, params param.Params) (v reflect.Value, err error) { v = reflect.New(elemType).Elem() if customizedFunc, exist := typeUnmarshalFuncs[elemType]; exist { val, err := customizedFunc(req.Req, params, text) diff --git a/pkg/app/server/binding/decoder/struct_type_decoder.go b/pkg/app/server/binding/decoder/struct_type_decoder.go index 3ee6a1f83..9ec17b60d 100644 --- a/pkg/app/server/binding/decoder/struct_type_decoder.go +++ b/pkg/app/server/binding/decoder/struct_type_decoder.go @@ -21,16 +21,16 @@ import ( "reflect" "github.com/cloudwego/hertz/internal/bytesconv" - path1 "github.com/cloudwego/hertz/pkg/app/server/binding/path" hjson "github.com/cloudwego/hertz/pkg/common/json" "github.com/cloudwego/hertz/pkg/common/utils" + "github.com/cloudwego/hertz/pkg/route/param" ) type structTypeFieldTextDecoder struct { fieldInfo } -func (d *structTypeFieldTextDecoder) Decode(req *bindRequest, params path1.PathParam, reqValue reflect.Value) error { +func (d *structTypeFieldTextDecoder) Decode(req *bindRequest, params param.Params, reqValue reflect.Value) error { var err error var text string var defaultValue string diff --git a/pkg/app/server/binding/default.go b/pkg/app/server/binding/default.go index 5841fe168..df436a694 100644 --- a/pkg/app/server/binding/default.go +++ b/pkg/app/server/binding/default.go @@ -67,10 +67,10 @@ import ( "github.com/cloudwego/hertz/internal/bytesconv" "github.com/cloudwego/hertz/pkg/app/server/binding/decoder" - "github.com/cloudwego/hertz/pkg/app/server/binding/path" hjson "github.com/cloudwego/hertz/pkg/common/json" "github.com/cloudwego/hertz/pkg/common/utils" "github.com/cloudwego/hertz/pkg/protocol" + "github.com/cloudwego/hertz/pkg/route/param" "github.com/go-playground/validator/v10" "google.golang.org/protobuf/proto" ) @@ -83,7 +83,7 @@ func (b *defaultBinder) Name() string { return "hertz" } -func (b *defaultBinder) Bind(req *protocol.Request, params path.PathParam, v interface{}) error { +func (b *defaultBinder) Bind(req *protocol.Request, params param.Params, v interface{}) error { err := b.preBindBody(req, v) if err != nil { return fmt.Errorf("bind body failed, err=%v", err) diff --git a/pkg/app/server/binding/path/path.go b/pkg/app/server/binding/path/path.go deleted file mode 100644 index b40432e04..000000000 --- a/pkg/app/server/binding/path/path.go +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Copyright 2022 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package path - -// PathParam parameter acquisition interface on the URL path -type PathParam interface { - Get(name string) (string, bool) -} diff --git a/pkg/app/server/binding/tagexpr_bind_test.go b/pkg/app/server/binding/tagexpr_bind_test.go index 0dcde177f..b9c18e5bc 100644 --- a/pkg/app/server/binding/tagexpr_bind_test.go +++ b/pkg/app/server/binding/tagexpr_bind_test.go @@ -47,6 +47,7 @@ import ( "time" "github.com/cloudwego/hertz/pkg/common/test/assert" + "github.com/cloudwego/hertz/pkg/route/param" "google.golang.org/protobuf/proto" ) @@ -478,7 +479,34 @@ func TestPath(t *testing.T) { req := newRequest("", nil, nil, nil) recv := new(Recv) - err := DefaultBinder.Bind(req.Req, new(testPathParams), recv) + params := param.Params{ + { + "a", + "a1", + }, + { + "b", + "-21", + }, + { + "c", + "31", + }, + { + "d", + "41", + }, + { + "y", + "y1", + }, + { + "name", + "henrylee2cn", + }, + } + + err := DefaultBinder.Bind(req.Req, params, recv) if err != nil { t.Error(err) } @@ -548,7 +576,14 @@ func TestDefault(t *testing.T) { req := newRequest("", header, nil, bodyReader) recv := new(Recv) - err := DefaultBinder.Bind(req.Req, new(testPathParams2), recv) + param2 := param.Params{ + { + "e", + "123", + }, + } + + err := DefaultBinder.Bind(req.Req, param2, recv) if err != nil { t.Error(err) } From f3b95f5cb5a15f744d52711b9135d32d94d8a5b6 Mon Sep 17 00:00:00 2001 From: fgy Date: Wed, 24 May 2023 16:00:47 +0800 Subject: [PATCH 40/91] ci: ci --- pkg/app/server/binding/tagexpr_bind_test.go | 35 ++------------------- 1 file changed, 3 insertions(+), 32 deletions(-) diff --git a/pkg/app/server/binding/tagexpr_bind_test.go b/pkg/app/server/binding/tagexpr_bind_test.go index b9c18e5bc..e5bb24a7d 100644 --- a/pkg/app/server/binding/tagexpr_bind_test.go +++ b/pkg/app/server/binding/tagexpr_bind_test.go @@ -443,27 +443,6 @@ func TestJSON(t *testing.T) { func TestNonstruct(t *testing.T) { } -type testPathParams struct{} - -func (testPathParams) Get(name string) (string, bool) { - switch name { - case "a": - return "a1", true - case "b": - return "-21", true - case "c": - return "31", true - case "d": - return "41", true - case "y": - return "y1", true - case "name": - return "henrylee2cn", true - default: - return "", false - } -} - func TestPath(t *testing.T) { type Recv struct { X **struct { @@ -518,17 +497,6 @@ func TestPath(t *testing.T) { assert.DeepEqual(t, (*int64)(nil), recv.Z) } -type testPathParams2 struct{} - -func (testPathParams2) Get(name string) (string, bool) { - switch name { - case "e": - return "123", true - default: - return "", false - } -} - // FIXME: 复杂类型的默认值,暂时先不做,低优 func TestDefault(t *testing.T) { //type S struct { @@ -988,6 +956,9 @@ func TestPathnameBUG(t *testing.T) { z.Currency.Currency.MaxPrice = proto.String("?") b, err := json.MarshalIndent(z, "", " ") + if err != nil { + t.Error(err) + } header := make(http.Header) header.Set("Content-Type", "application/json;charset=utf-8") req := newRequest("http://localhost", header, nil, bytes.NewReader(b)) From bc88f03679da799bc8bbfe9936d47c89f1f50715 Mon Sep 17 00:00:00 2001 From: fgy Date: Wed, 24 May 2023 17:32:15 +0800 Subject: [PATCH 41/91] optimzie: slice performance --- pkg/app/server/binding/decoder/slice_getter.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/app/server/binding/decoder/slice_getter.go b/pkg/app/server/binding/decoder/slice_getter.go index 6bb50a66f..d4e5c9dc5 100644 --- a/pkg/app/server/binding/decoder/slice_getter.go +++ b/pkg/app/server/binding/decoder/slice_getter.go @@ -105,7 +105,7 @@ func postFormSlice(req *bindRequest, params param.Params, key string, defaultVal func querySlice(req *bindRequest, params param.Params, key string, defaultValue ...string) (ret []string) { if req.Query == nil { - req.Query = make(url.Values, 100) + req.Query = make(url.Values) req.Req.URI().QueryArgs().VisitAll(func(queryKey, value []byte) { keyStr := string(queryKey) values := req.Query[keyStr] From eceae04a971a6707bd0016374528c5bea0d42e15 Mon Sep 17 00:00:00 2001 From: fgy Date: Wed, 24 May 2023 20:25:49 +0800 Subject: [PATCH 42/91] optimize: remove cache http info --- .../binding/decoder/base_type_decoder.go | 5 +- .../decoder/customized_type_decoder.go | 6 +- pkg/app/server/binding/decoder/decoder.go | 35 +------ pkg/app/server/binding/decoder/getter.go | 78 +++++---------- .../server/binding/decoder/gjson_required.go | 2 +- .../binding/decoder/map_type_decoder.go | 5 +- .../binding/decoder/multipart_file_decoder.go | 9 +- .../server/binding/decoder/slice_getter.go | 99 +++++++------------ .../binding/decoder/slice_type_decoder.go | 11 ++- .../server/binding/decoder/sonic_required.go | 9 +- .../binding/decoder/struct_type_decoder.go | 5 +- 11 files changed, 90 insertions(+), 174 deletions(-) diff --git a/pkg/app/server/binding/decoder/base_type_decoder.go b/pkg/app/server/binding/decoder/base_type_decoder.go index c002c6ca7..3fe659825 100644 --- a/pkg/app/server/binding/decoder/base_type_decoder.go +++ b/pkg/app/server/binding/decoder/base_type_decoder.go @@ -45,6 +45,7 @@ import ( "reflect" "github.com/cloudwego/hertz/pkg/common/utils" + "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/route/param" ) @@ -61,7 +62,7 @@ type baseTypeFieldTextDecoder struct { decoder TextDecoder } -func (d *baseTypeFieldTextDecoder) Decode(req *bindRequest, params param.Params, reqValue reflect.Value) error { +func (d *baseTypeFieldTextDecoder) Decode(req *protocol.Request, params param.Params, reqValue reflect.Value) error { var err error var text string var defaultValue string @@ -79,7 +80,7 @@ func (d *baseTypeFieldTextDecoder) Decode(req *bindRequest, params param.Params, continue } if tagInfo.Key == headerTag { - tagInfo.Value = utils.GetNormalizeHeaderKey(tagInfo.Value, req.Req.Header.IsDisableNormalizing()) + tagInfo.Value = utils.GetNormalizeHeaderKey(tagInfo.Value, req.Header.IsDisableNormalizing()) } text = tagInfo.Getter(req, params, tagInfo.Value) defaultValue = tagInfo.Default diff --git a/pkg/app/server/binding/decoder/customized_type_decoder.go b/pkg/app/server/binding/decoder/customized_type_decoder.go index 99bdbbdbb..3fb38c4e5 100644 --- a/pkg/app/server/binding/decoder/customized_type_decoder.go +++ b/pkg/app/server/binding/decoder/customized_type_decoder.go @@ -105,7 +105,7 @@ type customizedFieldTextDecoder struct { decodeFunc customizeDecodeFunc } -func (d *customizedFieldTextDecoder) Decode(req *bindRequest, params param.Params, reqValue reflect.Value) error { +func (d *customizedFieldTextDecoder) Decode(req *protocol.Request, params param.Params, reqValue reflect.Value) error { var text string var defaultValue string for _, tagInfo := range d.tagInfos { @@ -114,7 +114,7 @@ func (d *customizedFieldTextDecoder) Decode(req *bindRequest, params param.Param continue } if tagInfo.Key == headerTag { - tagInfo.Value = utils.GetNormalizeHeaderKey(tagInfo.Value, req.Req.Header.IsDisableNormalizing()) + tagInfo.Value = utils.GetNormalizeHeaderKey(tagInfo.Value, req.Header.IsDisableNormalizing()) } text = tagInfo.Getter(req, params, tagInfo.Value) defaultValue = tagInfo.Default @@ -126,7 +126,7 @@ func (d *customizedFieldTextDecoder) Decode(req *bindRequest, params param.Param text = defaultValue } - v, err := d.decodeFunc(req.Req, params, text) + v, err := d.decodeFunc(req, params, text) if err != nil { return err } diff --git a/pkg/app/server/binding/decoder/decoder.go b/pkg/app/server/binding/decoder/decoder.go index 31dd9709e..8f62e27d7 100644 --- a/pkg/app/server/binding/decoder/decoder.go +++ b/pkg/app/server/binding/decoder/decoder.go @@ -43,10 +43,7 @@ package decoder import ( "fmt" "mime/multipart" - "net/http" - "net/url" "reflect" - "sync" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/route/param" @@ -57,32 +54,8 @@ var ( EnableStructFieldResolve = false ) -var bindRequestPool = sync.Pool{New: func() interface{} { - return &bindRequest{} -}} - -type bindRequest struct { - Req *protocol.Request - Query url.Values - Form url.Values - MultipartForm url.Values - Header http.Header - Cookie []*http.Cookie -} - -func (b *bindRequest) reset() { - b.Req = nil - b.Query = nil - b.Form = nil - b.MultipartForm = nil - b.Header = nil - if b.Cookie != nil { - b.Cookie = b.Cookie[:0] - } -} - type fieldDecoder interface { - Decode(req *bindRequest, params param.Params, reqValue reflect.Value) error + Decode(req *protocol.Request, params param.Params, reqValue reflect.Value) error } type Decoder func(req *protocol.Request, params param.Params, rv reflect.Value) error @@ -112,16 +85,12 @@ func GetReqDecoder(rt reflect.Type) (Decoder, error) { } return func(req *protocol.Request, params param.Params, rv reflect.Value) error { - bindReq := bindRequestPool.Get().(*bindRequest) - bindReq.Req = req for _, decoder := range decoders { - err := decoder.Decode(bindReq, params, rv) + err := decoder.Decode(req, params, rv) if err != nil { return err } } - bindReq.reset() - bindRequestPool.Put(bindReq) return nil }, nil diff --git a/pkg/app/server/binding/decoder/getter.go b/pkg/app/server/binding/decoder/getter.go index 5e8908c38..1998b26ba 100644 --- a/pkg/app/server/binding/decoder/getter.go +++ b/pkg/app/server/binding/decoder/getter.go @@ -41,12 +41,13 @@ package decoder import ( + "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/route/param" ) -type getter func(req *bindRequest, params param.Params, key string, defaultValue ...string) (ret string) +type getter func(req *protocol.Request, params param.Params, key string, defaultValue ...string) (ret string) -func path(req *bindRequest, params param.Params, key string, defaultValue ...string) (ret string) { +func path(req *protocol.Request, params param.Params, key string, defaultValue ...string) (ret string) { if params != nil { ret, _ = params.Get(key) } @@ -57,31 +58,19 @@ func path(req *bindRequest, params param.Params, key string, defaultValue ...str return ret } -func postForm(req *bindRequest, params param.Params, key string, defaultValue ...string) (ret string) { - if req.Form != nil { - if val, exist := req.Form[key]; exist { - ret = val[0] - } - } else { - if val := req.Req.PostArgs().Peek(key); val != nil { - ret = string(val) - } +func postForm(req *protocol.Request, params param.Params, key string, defaultValue ...string) (ret string) { + if val := req.PostArgs().Peek(key); val != nil { + ret = string(val) } if len(ret) > 0 { return } - if req.MultipartForm != nil { - if val, exist := req.MultipartForm[key]; exist { - ret = val[0] - } - } else { - mf, err := req.Req.MultipartForm() - if err == nil && mf.Value != nil { - for k, v := range mf.Value { - if k == key && len(v) > 0 { - ret = v[0] - } + mf, err := req.MultipartForm() + if err == nil && mf.Value != nil { + for k, v := range mf.Value { + if k == key && len(v) > 0 { + ret = v[0] } } } @@ -93,15 +82,9 @@ func postForm(req *bindRequest, params param.Params, key string, defaultValue .. return } -func query(req *bindRequest, params param.Params, key string, defaultValue ...string) (ret string) { - if req.Query != nil { - if val, exist := req.Query[key]; exist { - ret = val[0] - } - } else { - if val := req.Req.URI().QueryArgs().Peek(key); val != nil { - ret = string(val) - } +func query(req *protocol.Request, params param.Params, key string, defaultValue ...string) (ret string) { + if val := req.URI().QueryArgs().Peek(key); val != nil { + ret = string(val) } if len(ret) == 0 && len(defaultValue) != 0 { @@ -111,18 +94,9 @@ func query(req *bindRequest, params param.Params, key string, defaultValue ...st return } -func cookie(req *bindRequest, params param.Params, key string, defaultValue ...string) (ret string) { - if len(req.Cookie) != 0 { - for _, c := range req.Cookie { - if c.Name == key { - ret = c.Value - break - } - } - } else { - if val := req.Req.Header.Cookie(key); val != nil { - ret = string(val) - } +func cookie(req *protocol.Request, params param.Params, key string, defaultValue ...string) (ret string) { + if val := req.Header.Cookie(key); val != nil { + ret = string(val) } if len(ret) == 0 && len(defaultValue) != 0 { @@ -132,15 +106,9 @@ func cookie(req *bindRequest, params param.Params, key string, defaultValue ...s return } -func header(req *bindRequest, params param.Params, key string, defaultValue ...string) (ret string) { - if req.Header != nil { - if val, exist := req.Header[key]; exist { - ret = val[0] - } - } else { - if val := req.Req.Header.Peek(key); val != nil { - ret = string(val) - } +func header(req *protocol.Request, params param.Params, key string, defaultValue ...string) (ret string) { + if val := req.Header.Peek(key); val != nil { + ret = string(val) } if len(ret) == 0 && len(defaultValue) != 0 { @@ -150,9 +118,9 @@ func header(req *bindRequest, params param.Params, key string, defaultValue ...s return } -func rawBody(req *bindRequest, params param.Params, key string, defaultValue ...string) (ret string) { - if req.Req.Header.ContentLength() > 0 { - ret = string(req.Req.Body()) +func rawBody(req *protocol.Request, params param.Params, key string, defaultValue ...string) (ret string) { + if req.Header.ContentLength() > 0 { + ret = string(req.Body()) } return } diff --git a/pkg/app/server/binding/decoder/gjson_required.go b/pkg/app/server/binding/decoder/gjson_required.go index 42c25425a..f6aac0a84 100644 --- a/pkg/app/server/binding/decoder/gjson_required.go +++ b/pkg/app/server/binding/decoder/gjson_required.go @@ -26,7 +26,7 @@ import ( "github.com/tidwall/gjson" ) -func checkRequireJSON2(req *bindRequest, tagInfo TagInfo) bool { +func checkRequireJSON2(req *protocol.Request, tagInfo TagInfo) bool { if !tagInfo.Required { return true } diff --git a/pkg/app/server/binding/decoder/map_type_decoder.go b/pkg/app/server/binding/decoder/map_type_decoder.go index ae6ac11d5..71fea1db5 100644 --- a/pkg/app/server/binding/decoder/map_type_decoder.go +++ b/pkg/app/server/binding/decoder/map_type_decoder.go @@ -47,6 +47,7 @@ import ( "github.com/cloudwego/hertz/internal/bytesconv" hjson "github.com/cloudwego/hertz/pkg/common/json" "github.com/cloudwego/hertz/pkg/common/utils" + "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/route/param" ) @@ -54,7 +55,7 @@ type mapTypeFieldTextDecoder struct { fieldInfo } -func (d *mapTypeFieldTextDecoder) Decode(req *bindRequest, params param.Params, reqValue reflect.Value) error { +func (d *mapTypeFieldTextDecoder) Decode(req *protocol.Request, params param.Params, reqValue reflect.Value) error { var err error var text string var defaultValue string @@ -72,7 +73,7 @@ func (d *mapTypeFieldTextDecoder) Decode(req *bindRequest, params param.Params, continue } if tagInfo.Key == headerTag { - tagInfo.Value = utils.GetNormalizeHeaderKey(tagInfo.Value, req.Req.Header.IsDisableNormalizing()) + tagInfo.Value = utils.GetNormalizeHeaderKey(tagInfo.Value, req.Header.IsDisableNormalizing()) } text = tagInfo.Getter(req, params, tagInfo.Value) defaultValue = tagInfo.Default diff --git a/pkg/app/server/binding/decoder/multipart_file_decoder.go b/pkg/app/server/binding/decoder/multipart_file_decoder.go index d4d7dcbda..db6d60d1a 100644 --- a/pkg/app/server/binding/decoder/multipart_file_decoder.go +++ b/pkg/app/server/binding/decoder/multipart_file_decoder.go @@ -20,6 +20,7 @@ import ( "fmt" "reflect" + "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/route/param" ) @@ -28,7 +29,7 @@ type fileTypeDecoder struct { isRepeated bool } -func (d *fileTypeDecoder) Decode(req *bindRequest, params param.Params, reqValue reflect.Value) error { +func (d *fileTypeDecoder) Decode(req *protocol.Request, params param.Params, reqValue reflect.Value) error { fieldValue := GetFieldValue(reqValue, d.parentIndex) field := fieldValue.Field(d.index) @@ -49,7 +50,7 @@ func (d *fileTypeDecoder) Decode(req *bindRequest, params param.Params, reqValue if len(fileName) == 0 { fileName = d.fieldName } - file, err := req.Req.FormFile(fileName) + file, err := req.FormFile(fileName) if err != nil { return fmt.Errorf("can not get file '%s', err: %v", fileName, err) } @@ -72,7 +73,7 @@ func (d *fileTypeDecoder) Decode(req *bindRequest, params param.Params, reqValue return nil } -func (d *fileTypeDecoder) fileSliceDecode(req *bindRequest, params param.Params, reqValue reflect.Value) error { +func (d *fileTypeDecoder) fileSliceDecode(req *protocol.Request, params param.Params, reqValue reflect.Value) error { fieldValue := GetFieldValue(reqValue, d.parentIndex) field := fieldValue.Field(d.index) // 如果没值,需要为其建一个值 @@ -102,7 +103,7 @@ func (d *fileTypeDecoder) fileSliceDecode(req *bindRequest, params param.Params, if len(fileName) == 0 { fileName = d.fieldName } - multipartForm, err := req.Req.MultipartForm() + multipartForm, err := req.MultipartForm() if err != nil { return fmt.Errorf("can not get multipartForm info, err: %v", err) } diff --git a/pkg/app/server/binding/decoder/slice_getter.go b/pkg/app/server/binding/decoder/slice_getter.go index d4e5c9dc5..7bf6dc27b 100644 --- a/pkg/app/server/binding/decoder/slice_getter.go +++ b/pkg/app/server/binding/decoder/slice_getter.go @@ -41,15 +41,14 @@ package decoder import ( - "net/http" - "net/url" - + "github.com/cloudwego/hertz/internal/bytesconv" + "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/route/param" ) -type sliceGetter func(req *bindRequest, params param.Params, key string, defaultValue ...string) (ret []string) +type sliceGetter func(req *protocol.Request, params param.Params, key string, defaultValue ...string) (ret []string) -func pathSlice(req *bindRequest, params param.Params, key string, defaultValue ...string) (ret []string) { +func pathSlice(req *protocol.Request, params param.Params, key string, defaultValue ...string) (ret []string) { var value string if params != nil { value, _ = params.Get(key) @@ -65,33 +64,24 @@ func pathSlice(req *bindRequest, params param.Params, key string, defaultValue . return } -func postFormSlice(req *bindRequest, params param.Params, key string, defaultValue ...string) (ret []string) { - if req.Form == nil { - req.Form = make(url.Values) - req.Req.PostArgs().VisitAll(func(formKey, value []byte) { - keyStr := string(formKey) - values := req.Form[keyStr] - values = append(values, string(value)) - req.Form[keyStr] = values - }) - } - ret = req.Form[key] +func postFormSlice(req *protocol.Request, params param.Params, key string, defaultValue ...string) (ret []string) { + req.PostArgs().VisitAll(func(formKey, value []byte) { + if bytesconv.B2s(formKey) == key { + ret = append(ret, string(value)) + } + }) if len(ret) > 0 { return } - if req.MultipartForm == nil { - req.MultipartForm = make(url.Values) - mf, err := req.Req.MultipartForm() - if err == nil && mf.Value != nil { - for k, v := range mf.Value { - if len(v) > 0 { - req.MultipartForm[k] = v - } + mf, err := req.MultipartForm() + if err == nil && mf.Value != nil { + for k, v := range mf.Value { + if k == key && len(v) > 0 { + ret = append(ret, v...) } } } - ret = req.MultipartForm[key] if len(ret) > 0 { return } @@ -103,18 +93,13 @@ func postFormSlice(req *bindRequest, params param.Params, key string, defaultVal return } -func querySlice(req *bindRequest, params param.Params, key string, defaultValue ...string) (ret []string) { - if req.Query == nil { - req.Query = make(url.Values) - req.Req.URI().QueryArgs().VisitAll(func(queryKey, value []byte) { - keyStr := string(queryKey) - values := req.Query[keyStr] - values = append(values, string(value)) - req.Query[keyStr] = values - }) - } +func querySlice(req *protocol.Request, params param.Params, key string, defaultValue ...string) (ret []string) { + req.URI().QueryArgs().VisitAll(func(queryKey, value []byte) { + if key == bytesconv.B2s(queryKey) { + ret = append(ret, string(value)) + } + }) - ret = req.Query[key] if len(ret) == 0 && len(defaultValue) != 0 { ret = append(ret, defaultValue...) } @@ -122,20 +107,13 @@ func querySlice(req *bindRequest, params param.Params, key string, defaultValue return } -func cookieSlice(req *bindRequest, params param.Params, key string, defaultValue ...string) (ret []string) { - if len(req.Cookie) == 0 { - req.Req.Header.VisitAllCookie(func(cookieKey, value []byte) { - req.Cookie = append(req.Cookie, &http.Cookie{ - Name: string(cookieKey), - Value: string(value), - }) - }) - } - for _, c := range req.Cookie { - if c.Name == key { - ret = append(ret, c.Value) +func cookieSlice(req *protocol.Request, params param.Params, key string, defaultValue ...string) (ret []string) { + req.Header.VisitAllCookie(func(cookieKey, value []byte) { + if bytesconv.B2s(cookieKey) == key { + ret = append(ret, string(value)) } - } + }) + if len(ret) == 0 && len(defaultValue) != 0 { ret = append(ret, defaultValue...) } @@ -143,18 +121,13 @@ func cookieSlice(req *bindRequest, params param.Params, key string, defaultValue return } -func headerSlice(req *bindRequest, params param.Params, key string, defaultValue ...string) (ret []string) { - if req.Header == nil { - req.Header = make(http.Header) - req.Req.Header.VisitAll(func(headerKey, value []byte) { - keyStr := string(headerKey) - values := req.Header[keyStr] - values = append(values, string(value)) - req.Header[keyStr] = values - }) - } +func headerSlice(req *protocol.Request, params param.Params, key string, defaultValue ...string) (ret []string) { + req.Header.VisitAll(func(headerKey, value []byte) { + if bytesconv.B2s(headerKey) == key { + ret = append(ret, string(value)) + } + }) - ret = req.Header[key] if len(ret) == 0 && len(defaultValue) != 0 { ret = append(ret, defaultValue...) } @@ -162,9 +135,9 @@ func headerSlice(req *bindRequest, params param.Params, key string, defaultValue return } -func rawBodySlice(req *bindRequest, params param.Params, key string, defaultValue ...string) (ret []string) { - if req.Req.Header.ContentLength() > 0 { - ret = append(ret, string(req.Req.Body())) +func rawBodySlice(req *protocol.Request, params param.Params, key string, defaultValue ...string) (ret []string) { + if req.Header.ContentLength() > 0 { + ret = append(ret, string(req.Body())) } return } diff --git a/pkg/app/server/binding/decoder/slice_type_decoder.go b/pkg/app/server/binding/decoder/slice_type_decoder.go index 3787c6912..63d5cea37 100644 --- a/pkg/app/server/binding/decoder/slice_type_decoder.go +++ b/pkg/app/server/binding/decoder/slice_type_decoder.go @@ -48,6 +48,7 @@ import ( "github.com/cloudwego/hertz/internal/bytesconv" hjson "github.com/cloudwego/hertz/pkg/common/json" "github.com/cloudwego/hertz/pkg/common/utils" + "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/route/param" ) @@ -56,7 +57,7 @@ type sliceTypeFieldTextDecoder struct { isArray bool } -func (d *sliceTypeFieldTextDecoder) Decode(req *bindRequest, params param.Params, reqValue reflect.Value) error { +func (d *sliceTypeFieldTextDecoder) Decode(req *protocol.Request, params param.Params, reqValue reflect.Value) error { var err error var texts []string var defaultValue string @@ -75,7 +76,7 @@ func (d *sliceTypeFieldTextDecoder) Decode(req *bindRequest, params param.Params continue } if tagInfo.Key == headerTag { - tagInfo.Value = utils.GetNormalizeHeaderKey(tagInfo.Value, req.Req.Header.IsDisableNormalizing()) + tagInfo.Value = utils.GetNormalizeHeaderKey(tagInfo.Value, req.Header.IsDisableNormalizing()) } if tagInfo.Key == rawBodyTag { bindRawBody = true @@ -126,7 +127,7 @@ func (d *sliceTypeFieldTextDecoder) Decode(req *bindRequest, params param.Params } // raw_body && []byte binding if bindRawBody && field.Type().Elem().Kind() == reflect.Uint8 { - reqValue.Field(d.index).Set(reflect.ValueOf(req.Req.Body())) + reqValue.Field(d.index).Set(reflect.ValueOf(req.Body())) return nil } @@ -221,10 +222,10 @@ func getSliceFieldDecoder(field reflect.StructField, index int, tagInfos []TagIn }}, nil } -func stringToValue(elemType reflect.Type, text string, req *bindRequest, params param.Params) (v reflect.Value, err error) { +func stringToValue(elemType reflect.Type, text string, req *protocol.Request, params param.Params) (v reflect.Value, err error) { v = reflect.New(elemType).Elem() if customizedFunc, exist := typeUnmarshalFuncs[elemType]; exist { - val, err := customizedFunc(req.Req, params, text) + val, err := customizedFunc(req, params, text) if err != nil { return reflect.Value{}, err } diff --git a/pkg/app/server/binding/decoder/sonic_required.go b/pkg/app/server/binding/decoder/sonic_required.go index 269747c6e..fcf922c65 100644 --- a/pkg/app/server/binding/decoder/sonic_required.go +++ b/pkg/app/server/binding/decoder/sonic_required.go @@ -26,22 +26,23 @@ import ( "github.com/bytedance/sonic" "github.com/cloudwego/hertz/internal/bytesconv" "github.com/cloudwego/hertz/pkg/common/utils" + "github.com/cloudwego/hertz/pkg/protocol" ) -func checkRequireJSON(req *bindRequest, tagInfo TagInfo) bool { +func checkRequireJSON(req *protocol.Request, tagInfo TagInfo) bool { if !tagInfo.Required { return true } - ct := bytesconv.B2s(req.Req.Header.ContentType()) + ct := bytesconv.B2s(req.Header.ContentType()) if utils.FilterContentType(ct) != "application/json" { return false } - node, _ := sonic.Get(req.Req.Body(), stringSliceForInterface(tagInfo.JSONName)...) + node, _ := sonic.Get(req.Body(), stringSliceForInterface(tagInfo.JSONName)...) if !node.Exists() { idx := strings.LastIndex(tagInfo.JSONName, ".") if idx > 0 { // There should be a superior if it is empty, it will report 'true' for required - node, _ := sonic.Get(req.Req.Body(), stringSliceForInterface(tagInfo.JSONName[:idx])...) + node, _ := sonic.Get(req.Body(), stringSliceForInterface(tagInfo.JSONName[:idx])...) if !node.Exists() { return true } diff --git a/pkg/app/server/binding/decoder/struct_type_decoder.go b/pkg/app/server/binding/decoder/struct_type_decoder.go index 9ec17b60d..7592a420e 100644 --- a/pkg/app/server/binding/decoder/struct_type_decoder.go +++ b/pkg/app/server/binding/decoder/struct_type_decoder.go @@ -23,6 +23,7 @@ import ( "github.com/cloudwego/hertz/internal/bytesconv" hjson "github.com/cloudwego/hertz/pkg/common/json" "github.com/cloudwego/hertz/pkg/common/utils" + "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/route/param" ) @@ -30,7 +31,7 @@ type structTypeFieldTextDecoder struct { fieldInfo } -func (d *structTypeFieldTextDecoder) Decode(req *bindRequest, params param.Params, reqValue reflect.Value) error { +func (d *structTypeFieldTextDecoder) Decode(req *protocol.Request, params param.Params, reqValue reflect.Value) error { var err error var text string var defaultValue string @@ -48,7 +49,7 @@ func (d *structTypeFieldTextDecoder) Decode(req *bindRequest, params param.Param continue } if tagInfo.Key == headerTag { - tagInfo.Value = utils.GetNormalizeHeaderKey(tagInfo.Value, req.Req.Header.IsDisableNormalizing()) + tagInfo.Value = utils.GetNormalizeHeaderKey(tagInfo.Value, req.Header.IsDisableNormalizing()) } text = tagInfo.Getter(req, params, tagInfo.Value) defaultValue = tagInfo.Default From 4ade7828102c09cf97ccc0667e6ae7c1c2a865b3 Mon Sep 17 00:00:00 2001 From: fgy Date: Wed, 24 May 2023 21:06:42 +0800 Subject: [PATCH 43/91] ci:ci --- pkg/app/context.go | 32 +++---- pkg/app/server/binding/binder.go | 6 +- pkg/app/server/binding/binder_test.go | 75 +++++++++-------- pkg/app/server/binding/tagexpr_bind_test.go | 92 ++++++++++----------- 4 files changed, 104 insertions(+), 101 deletions(-) diff --git a/pkg/app/context.go b/pkg/app/context.go index c60febe24..747b93b75 100644 --- a/pkg/app/context.go +++ b/pkg/app/context.go @@ -1130,24 +1130,24 @@ func (ctx *RequestContext) Cookie(key string) []byte { return ctx.Request.Header.Cookie(key) } -// SetCookie adds a Set-cookie header to the Response's headers. +// SetCookie adds a Set-Cookie header to the Response's headers. // // Parameter introduce: -// name and value is used to set cookie's name and value, eg. Set-cookie: name=value -// maxAge is use to set cookie's expiry date, eg. Set-cookie: name=value; max-age=1 -// path and domain is used to set the scope of a cookie, eg. Set-cookie: name=value;domain=localhost; path=/; -// secure and httpOnly is used to sent cookies securely; eg. Set-cookie: name=value;HttpOnly; secure; -// sameSite let servers specify whether/when cookies are sent with cross-site requests; eg. Set-cookie: name=value;HttpOnly; secure; SameSite=Lax; +// name and value is used to set cookie's name and value, eg. Set-Cookie: name=value +// maxAge is use to set cookie's expiry date, eg. Set-Cookie: name=value; max-age=1 +// path and domain is used to set the scope of a cookie, eg. Set-Cookie: name=value;domain=localhost; path=/; +// secure and httpOnly is used to sent cookies securely; eg. Set-Cookie: name=value;HttpOnly; secure; +// sameSite let servers specify whether/when cookies are sent with cross-site requests; eg. Set-Cookie: name=value;HttpOnly; secure; SameSite=Lax; // // For example: // 1. ctx.SetCookie("user", "hertz", 1, "/", "localhost",protocol.CookieSameSiteLaxMode, true, true) -// add response header ---> Set-cookie: user=hertz; max-age=1; domain=localhost; path=/; HttpOnly; secure; SameSite=Lax; +// add response header ---> Set-Cookie: user=hertz; max-age=1; domain=localhost; path=/; HttpOnly; secure; SameSite=Lax; // 2. ctx.SetCookie("user", "hertz", 10, "/", "localhost",protocol.CookieSameSiteLaxMode, false, false) -// add response header ---> Set-cookie: user=hertz; max-age=10; domain=localhost; path=/; SameSite=Lax; +// add response header ---> Set-Cookie: user=hertz; max-age=10; domain=localhost; path=/; SameSite=Lax; // 3. ctx.SetCookie("", "hertz", 10, "/", "localhost",protocol.CookieSameSiteLaxMode, false, false) -// add response header ---> Set-cookie: hertz; max-age=10; domain=localhost; path=/; SameSite=Lax; +// add response header ---> Set-Cookie: hertz; max-age=10; domain=localhost; path=/; SameSite=Lax; // 4. ctx.SetCookie("user", "", 10, "/", "localhost",protocol.CookieSameSiteLaxMode, false, false) -// add response header ---> Set-cookie: user=; max-age=10; domain=localhost; path=/; SameSite=Lax; +// add response header ---> Set-Cookie: user=; max-age=10; domain=localhost; path=/; SameSite=Lax; func (ctx *RequestContext) SetCookie(name, value string, maxAge int, path, domain string, sameSite protocol.CookieSameSite, secure, httpOnly bool) { if path == "" { path = "/" @@ -1223,10 +1223,10 @@ func (ctx *RequestContext) PostArgs() *protocol.Args { // For example: // // GET /path?id=1234&name=Manu&value= -// c.query("id") == "1234" -// c.query("name") == "Manu" -// c.query("value") == "" -// c.query("wtf") == "" +// c.Query("id") == "1234" +// c.Query("name") == "Manu" +// c.Query("value") == "" +// c.Query("wtf") == "" func (ctx *RequestContext) Query(key string) string { value, _ := ctx.GetQuery(key) return value @@ -1305,7 +1305,7 @@ func bodyAllowedForStatus(status int) bool { // BindAndValidate binds data from *RequestContext to obj and validates them if needed. // NOTE: obj should be a pointer. func (ctx *RequestContext) BindAndValidate(obj interface{}) error { - err := binding.DefaultBinder.Bind(&ctx.Request, ctx.Params, obj) + err := binding.DefaultBinder().Bind(&ctx.Request, ctx.Params, obj) if err != nil { return err } @@ -1316,7 +1316,7 @@ func (ctx *RequestContext) BindAndValidate(obj interface{}) error { // Bind binds data from *RequestContext to obj. // NOTE: obj should be a pointer. func (ctx *RequestContext) Bind(obj interface{}) error { - return binding.DefaultBinder.Bind(&ctx.Request, ctx.Params, obj) + return binding.DefaultBinder().Bind(&ctx.Request, ctx.Params, obj) } // Validate validates obj with "vd" tag diff --git a/pkg/app/server/binding/binder.go b/pkg/app/server/binding/binder.go index c99d95adb..665a9ad81 100644 --- a/pkg/app/server/binding/binder.go +++ b/pkg/app/server/binding/binder.go @@ -50,4 +50,8 @@ type Binder interface { Bind(*protocol.Request, param.Params, interface{}) error } -var DefaultBinder Binder = &defaultBinder{} +var defaultBind Binder = &defaultBinder{} + +func DefaultBinder() Binder { + return defaultBind +} diff --git a/pkg/app/server/binding/binder_test.go b/pkg/app/server/binding/binder_test.go index 8a82fb80d..d1de12ed6 100644 --- a/pkg/app/server/binding/binder_test.go +++ b/pkg/app/server/binding/binder_test.go @@ -124,7 +124,7 @@ func TestBind_BaseType(t *testing.T) { var result Req - err := DefaultBinder.Bind(req.Req, params, &result) + err := DefaultBinder().Bind(req.Req, params, &result) if err != nil { t.Error(err) } @@ -149,7 +149,7 @@ func TestBind_SliceType(t *testing.T) { var result Req - err := DefaultBinder.Bind(req.Req, nil, &result) + err := DefaultBinder().Bind(req.Req, nil, &result) if err != nil { t.Error(err) } @@ -192,7 +192,7 @@ func TestBind_StructType(t *testing.T) { req := newMockRequest().SetRequestURI("http://foobar.com?F1=f1&B1=b1").SetHeader("f2", "f2") - err := DefaultBinder.Bind(req.Req, nil, &result) + err := DefaultBinder().Bind(req.Req, nil, &result) if err != nil { t.Error(err) } @@ -232,7 +232,7 @@ func TestBind_PointerType(t *testing.T) { req := newMockRequest().SetRequestURI(fmt.Sprintf("http://foobar.com?F1=%s&B1=%s&B2=%s&B3=%s&B3=%s&B4=%d&B4=%d", F1, B1, B2, B3s[0], B3s[1], B4s[0], B4s[1])). SetHeader("f2", "f2") - err := DefaultBinder.Bind(req.Req, nil, &result) + err := DefaultBinder().Bind(req.Req, nil, &result) if err != nil { t.Error(err) } @@ -264,7 +264,7 @@ func TestBind_NestedStruct(t *testing.T) { result := Bar{} req := newMockRequest().SetRequestURI("http://foobar.com?F1=qwe") - err := DefaultBinder.Bind(req.Req, nil, &result) + err := DefaultBinder().Bind(req.Req, nil, &result) if err != nil { t.Error(err) } @@ -285,7 +285,7 @@ func TestBind_SliceStruct(t *testing.T) { B1s := []string{"1", "2", "3"} req := newMockRequest().SetRequestURI(fmt.Sprintf("http://foobar.com?F1={\"f1\":\"%s\"}&F1={\"f1\":\"%s\"}&F1={\"f1\":\"%s\"}", B1s[0], B1s[1], B1s[2])) - err := DefaultBinder.Bind(req.Req, nil, &result) + err := DefaultBinder().Bind(req.Req, nil, &result) if err != nil { t.Error(err) } @@ -300,7 +300,7 @@ func TestBind_MapType(t *testing.T) { req := newMockRequest(). SetJSONContentType(). SetBody([]byte(`{"j1":"j1", "j2":"j2"}`)) - err := DefaultBinder.Bind(req.Req, nil, &result) + err := DefaultBinder().Bind(req.Req, nil, &result) if err != nil { t.Fatal(err) } @@ -319,7 +319,7 @@ func TestBind_MapFieldType(t *testing.T) { SetJSONContentType(). SetBody([]byte(`{"j1":"j1", "j2":"j2"}`)) result := Foo{} - err := DefaultBinder.Bind(req.Req, nil, &result) + err := DefaultBinder().Bind(req.Req, nil, &result) if err != nil { t.Fatal(err) } @@ -334,7 +334,7 @@ func TestBind_UnexportedField(t *testing.T) { } req := newMockRequest(). SetRequestURI("http://foobar.com?a=1&b=2") - err := DefaultBinder.Bind(req.Req, nil, &s) + err := DefaultBinder().Bind(req.Req, nil, &s) if err != nil { t.Fatal(err) } @@ -358,7 +358,7 @@ func TestBind_NoTagField(t *testing.T) { Value: "b2", }) - err := DefaultBinder.Bind(req.Req, params, &s) + err := DefaultBinder().Bind(req.Req, params, &s) if err != nil { t.Fatal(err) } @@ -375,7 +375,7 @@ func TestBind_ZeroValueBind(t *testing.T) { req := newMockRequest(). SetRequestURI("http://foobar.com?a=&b") - err := DefaultBinder.Bind(req.Req, nil, &s) + err := DefaultBinder().Bind(req.Req, nil, &s) if err != nil { t.Fatal(err) } @@ -393,7 +393,7 @@ func TestBind_DefaultValueBind(t *testing.T) { req := newMockRequest(). SetRequestURI("http://foobar.com") - err := DefaultBinder.Bind(req.Req, nil, &s) + err := DefaultBinder().Bind(req.Req, nil, &s) if err != nil { t.Fatal(err) } @@ -406,7 +406,7 @@ func TestBind_DefaultValueBind(t *testing.T) { D [2]string `default:"qwe"` } - err = DefaultBinder.Bind(req.Req, nil, &d) + err = DefaultBinder().Bind(req.Req, nil, &d) if err == nil { t.Fatal("expected err") } @@ -420,7 +420,7 @@ func TestBind_RequiredBind(t *testing.T) { SetRequestURI("http://foobar.com"). SetHeader("A", "1") - err := DefaultBinder.Bind(req.Req, nil, &s) + err := DefaultBinder().Bind(req.Req, nil, &s) if err == nil { t.Fatal("expected error") } @@ -428,7 +428,7 @@ func TestBind_RequiredBind(t *testing.T) { var d struct { A int `query:"a,required" header:"A"` } - err = DefaultBinder.Bind(req.Req, nil, &d) + err = DefaultBinder().Bind(req.Req, nil, &d) if err != nil { t.Fatal(err) } @@ -450,7 +450,7 @@ func TestBind_TypedefType(t *testing.T) { } req := newMockRequest(). SetRequestURI("http://foobar.com?a=1&b=2") - err := DefaultBinder.Bind(req.Req, nil, &s) + err := DefaultBinder().Bind(req.Req, nil, &s) if err != nil { t.Fatal(err) } @@ -483,7 +483,7 @@ func TestBind_EnumBind(t *testing.T) { } req := newMockRequest(). SetRequestURI("http://foobar.com?a=0&b=2") - err := DefaultBinder.Bind(req.Req, nil, &s) + err := DefaultBinder().Bind(req.Req, nil, &s) if err != nil { t.Fatal(err) } @@ -508,7 +508,6 @@ func TestBind_CustomizedTypeDecode(t *testing.T) { } return reflect.ValueOf(val), nil }) - if err != nil { t.Fatal(err) } @@ -516,7 +515,7 @@ func TestBind_CustomizedTypeDecode(t *testing.T) { req := newMockRequest(). SetRequestURI("http://foobar.com?a=1&b=2") result := Foo{} - err = DefaultBinder.Bind(req.Req, nil, &result) + err = DefaultBinder().Bind(req.Req, nil, &result) if err != nil { t.Fatal(err) } @@ -527,7 +526,7 @@ func TestBind_CustomizedTypeDecode(t *testing.T) { } result2 := Bar{} - err = DefaultBinder.Bind(req.Req, nil, &result2) + err = DefaultBinder().Bind(req.Req, nil, &result2) if err != nil { t.Error(err) } @@ -549,7 +548,7 @@ func TestBind_JSON(t *testing.T) { SetJSONContentType(). SetBody([]byte(fmt.Sprintf(`{"j1":"j1", "j2":12, "j3":[%d, %d], "j4":["%s", "%s"]}`, J3s[0], J3s[1], J4s[0], J4s[1]))) var result Req - err := DefaultBinder.Bind(req.Req, nil, &result) + err := DefaultBinder().Bind(req.Req, nil, &result) if err != nil { t.Error(err) } @@ -578,7 +577,7 @@ func TestBind_ResetJSONUnmarshal(t *testing.T) { SetJSONContentType(). SetBody([]byte(fmt.Sprintf(`{"j1":"j1", "j2":12, "j3":[%d, %d], "j4":["%s", "%s"]}`, J3s[0], J3s[1], J4s[0], J4s[1]))) var result Req - err := DefaultBinder.Bind(req.Req, nil, &result) + err := DefaultBinder().Bind(req.Req, nil, &result) if err != nil { t.Error(err) } @@ -613,7 +612,7 @@ func TestBind_FileBind(t *testing.T) { // to parse multipart files req2 := req2.GetHTTP1Request(req.Req) _ = req2.String() - err := DefaultBinder.Bind(req.Req, nil, &s) + err := DefaultBinder().Bind(req.Req, nil, &s) if err != nil { t.Fatal(err) } @@ -645,7 +644,7 @@ func TestBind_FileSliceBind(t *testing.T) { // to parse multipart files req2 := req2.GetHTTP1Request(req.Req) _ = req2.String() - err := DefaultBinder.Bind(req.Req, nil, &s) + err := DefaultBinder().Bind(req.Req, nil, &s) if err != nil { t.Fatal(err) } @@ -669,19 +668,19 @@ func TestBind_FileSliceBind(t *testing.T) { func TestBind_AnonymousField(t *testing.T) { type nest struct { - n1 string `query:"n1"` // bind default value - N2 ***string `query:"n2"` // bind n2 value - string `query:"n3"` // bind default value + n1 string `query:"n1"` // bind default value + N2 ***string `query:"n2"` // bind n2 value + string `query:"n3"` // bind default value } var s struct { - s1 int `query:"s1"` // bind default value - int `query:"s2"` // bind default value + s1 int `query:"s1"` // bind default value + int `query:"s2"` // bind default value nest } req := newMockRequest(). SetRequestURI("http://foobar.com?s1=1&s2=2&n1=1&n2=2&n3=3") - err := DefaultBinder.Bind(req.Req, nil, &s) + err := DefaultBinder().Bind(req.Req, nil, &s) if err != nil { t.Fatal(err) } @@ -713,7 +712,7 @@ func TestBind_IgnoreField(t *testing.T) { var result Req - err := DefaultBinder.Bind(req.Req, params, &result) + err := DefaultBinder().Bind(req.Req, params, &result) if err != nil { t.Error(err) } @@ -747,7 +746,7 @@ func TestBind_DefaultTag(t *testing.T) { Value: "1", }) var result Req - err := DefaultBinder.Bind(req.Req, params, &result) + err := DefaultBinder().Bind(req.Req, params, &result) if err != nil { t.Error(err) } @@ -761,7 +760,7 @@ func TestBind_DefaultTag(t *testing.T) { EnableDefaultTag(true) }() result2 := Req2{} - err = DefaultBinder.Bind(req.Req, params, &result2) + err = DefaultBinder().Bind(req.Req, params, &result2) if err != nil { t.Error(err) } @@ -790,7 +789,7 @@ func TestBind_StructFieldResolve(t *testing.T) { defer func() { EnableStructFieldResolve(false) }() - err := DefaultBinder.Bind(req.Req, nil, &result) + err := DefaultBinder().Bind(req.Req, nil, &result) if err != nil { t.Error(err) } @@ -802,7 +801,7 @@ func TestBind_StructFieldResolve(t *testing.T) { SetHeaders("Header", "header"). SetPostArg("Form", "form"). SetUrlEncodeContentType() - err = DefaultBinder.Bind(req.Req, nil, &result) + err = DefaultBinder().Bind(req.Req, nil, &result) if err != nil { t.Error(err) } @@ -837,7 +836,7 @@ func TestBind_JSONRequiredField(t *testing.T) { SetJSONContentType(). SetBody(bodyBytes) var result Req - err := DefaultBinder.Bind(req.Req, nil, &result) + err := DefaultBinder().Bind(req.Req, nil, &result) if err == nil { t.Errorf("expected an error, but get nil") } @@ -857,7 +856,7 @@ func TestBind_JSONRequiredField(t *testing.T) { SetJSONContentType(). SetBody(bodyBytes) var result2 Req - err = DefaultBinder.Bind(req.Req, nil, &result2) + err = DefaultBinder().Bind(req.Req, nil, &result2) if err != nil { t.Error(err) } @@ -890,7 +889,7 @@ func Benchmark_Binding(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { var result Req - err := DefaultBinder.Bind(req.Req, params, &result) + err := DefaultBinder().Bind(req.Req, params, &result) if err != nil { b.Error(err) } diff --git a/pkg/app/server/binding/tagexpr_bind_test.go b/pkg/app/server/binding/tagexpr_bind_test.go index e5bb24a7d..ada72fe27 100644 --- a/pkg/app/server/binding/tagexpr_bind_test.go +++ b/pkg/app/server/binding/tagexpr_bind_test.go @@ -59,7 +59,7 @@ func TestRawBody(t *testing.T) { bodyBytes := []byte("raw_body.............") req := newRequest("", nil, nil, bytes.NewReader(bodyBytes)) recv := new(Recv) - err := DefaultBinder.Bind(req.Req, nil, recv) + err := DefaultBinder().Bind(req.Req, nil, recv) if err != nil { if err != nil { t.Error(err) @@ -89,7 +89,7 @@ func TestQueryString(t *testing.T) { } req := newRequest("http://localhost:8080/?a=a1&a=a2&b=b1&c=c1&c=c2&d=d1&d=d&f=qps&g=1002&g=1003&e=&e=2&y=y1", nil, nil, nil) recv := new(Recv) - err := DefaultBinder.Bind(req.Req, nil, recv) + err := DefaultBinder().Bind(req.Req, nil, recv) if err != nil { t.Error(err) } @@ -113,7 +113,7 @@ func TestGetBody(t *testing.T) { } req := newRequest("http://localhost:8080/", nil, nil, nil) recv := new(Recv) - err := DefaultBinder.Bind(req.Req, nil, recv) + err := DefaultBinder().Bind(req.Req, nil, recv) if err == nil { t.Fatalf("expected an error, but get nil") } @@ -133,7 +133,7 @@ func TestQueryNum(t *testing.T) { } req := newRequest("http://localhost:8080/?a=11&a=12&b=21&c=31&c=32&d=41&d=42&y=true", nil, nil, nil) recv := new(Recv) - err := DefaultBinder.Bind(req.Req, nil, recv) + err := DefaultBinder().Bind(req.Req, nil, recv) if err != nil { if err != nil { t.Error(err) @@ -169,7 +169,7 @@ func TestHeaderString(t *testing.T) { header.Add("X-Y", "y1") req := newRequest("", header, nil, nil) recv := new(Recv) - err := DefaultBinder.Bind(req.Req, nil, recv) + err := DefaultBinder().Bind(req.Req, nil, recv) if err != nil { if err != nil { t.Error(err) @@ -206,7 +206,7 @@ func TestHeaderNum(t *testing.T) { req := newRequest("", header, nil, nil) recv := new(Recv) - err := DefaultBinder.Bind(req.Req, nil, recv) + err := DefaultBinder().Bind(req.Req, nil, recv) if err != nil { t.Error(err) } @@ -242,7 +242,7 @@ func TestCookieString(t *testing.T) { req := newRequest("", nil, cookies, nil) recv := new(Recv) - err := DefaultBinder.Bind(req.Req, nil, recv) + err := DefaultBinder().Bind(req.Req, nil, recv) if err != nil { t.Error(err) } @@ -275,7 +275,7 @@ func TestCookieNum(t *testing.T) { req := newRequest("", nil, cookies, nil) recv := new(Recv) - err := DefaultBinder.Bind(req.Req, nil, recv) + err := DefaultBinder().Bind(req.Req, nil, recv) if err != nil { t.Error(err) } @@ -321,7 +321,7 @@ func TestFormString(t *testing.T) { header.Set("Content-Type", contentType) req := newRequest("", header, nil, bodyReader) recv := new(Recv) - err := DefaultBinder.Bind(req.Req, nil, recv) + err := DefaultBinder().Bind(req.Req, nil, recv) if err != nil { t.Error(err) } @@ -372,7 +372,7 @@ func TestFormNum(t *testing.T) { req := newRequest("", header, nil, bodyReader) recv := new(Recv) - err := DefaultBinder.Bind(req.Req, nil, recv) + err := DefaultBinder().Bind(req.Req, nil, recv) if err != nil { t.Error(err) } @@ -423,7 +423,7 @@ func TestJSON(t *testing.T) { req := newRequest("", header, nil, bodyReader) recv := new(Recv) - err := DefaultBinder.Bind(req.Req, nil, recv) + err := DefaultBinder().Bind(req.Req, nil, recv) if err == nil { t.Error("expected an error, but get nil") } @@ -460,32 +460,32 @@ func TestPath(t *testing.T) { params := param.Params{ { - "a", - "a1", + Key: "a", + Value: "a1", }, { - "b", - "-21", + Key: "b", + Value: "-21", }, { - "c", - "31", + Key: "c", + Value: "31", }, { - "d", - "41", + Key: "d", + Value: "41", }, { - "y", - "y1", + Key: "y", + Value: "y1", }, { - "name", - "henrylee2cn", + Key: "name", + Value: "henrylee2cn", }, } - err := DefaultBinder.Bind(req.Req, params, recv) + err := DefaultBinder().Bind(req.Req, params, recv) if err != nil { t.Error(err) } @@ -546,12 +546,12 @@ func TestDefault(t *testing.T) { param2 := param.Params{ { - "e", - "123", + Key: "e", + Value: "123", }, } - err := DefaultBinder.Bind(req.Req, param2, recv) + err := DefaultBinder().Bind(req.Req, param2, recv) if err != nil { t.Error(err) } @@ -604,7 +604,7 @@ func TestAuto(t *testing.T) { }, bodyReader) recv := new(Recv) - err = DefaultBinder.Bind(req.Req, nil, recv) + err = DefaultBinder().Bind(req.Req, nil, recv) if err != nil { t.Error(err) } @@ -625,7 +625,7 @@ func TestAuto(t *testing.T) { header.Set("Content-Type", contentType) req = newRequest("http://localhost/?"+query.Encode(), header, nil, bodyReader) recv = new(Recv) - err = DefaultBinder.Bind(req.Req, nil, recv) + err = DefaultBinder().Bind(req.Req, nil, recv) if err != nil { t.Error(err) } @@ -652,7 +652,7 @@ func TestTypeUnmarshal(t *testing.T) { req := newRequest("http://localhost/?"+query.Encode(), header, nil, bodyReader) recv := new(Recv) - err := DefaultBinder.Bind(req.Req, nil, recv) + err := DefaultBinder().Bind(req.Req, nil, recv) if err != nil { t.Error(err) } @@ -696,7 +696,7 @@ func TestOption(t *testing.T) { req := newRequest("", header, nil, bodyReader) recv := new(Recv) - err := DefaultBinder.Bind(req.Req, nil, recv) + err := DefaultBinder().Bind(req.Req, nil, recv) if err != nil { t.Error(err) } @@ -711,7 +711,7 @@ func TestOption(t *testing.T) { }`) req = newRequest("", header, nil, bodyReader) recv = new(Recv) - err = DefaultBinder.Bind(req.Req, nil, recv) + err = DefaultBinder().Bind(req.Req, nil, recv) assert.DeepEqual(t, err.Error(), "'C' field is a 'required' parameter, but the request does not have this parameter") assert.DeepEqual(t, 0, recv.X.C) assert.DeepEqual(t, 0, recv.X.D) @@ -722,7 +722,7 @@ func TestOption(t *testing.T) { }`) req = newRequest("", header, nil, bodyReader) recv = new(Recv) - err = DefaultBinder.Bind(req.Req, nil, recv) + err = DefaultBinder().Bind(req.Req, nil, recv) if err != nil { t.Error(err) } @@ -745,7 +745,7 @@ func TestOption(t *testing.T) { defer func() { EnableStructFieldResolve(false) }() - err = DefaultBinder.Bind(req.Req, nil, recv2) + err = DefaultBinder().Bind(req.Req, nil, recv2) assert.DeepEqual(t, err.Error(), "'X' field is a 'required' parameter, but the request does not have this parameter") assert.True(t, recv2.X == nil) assert.DeepEqual(t, "y1", recv2.Y) @@ -793,13 +793,13 @@ func TestQueryStringIssue(t *testing.T) { req := newRequest("http://localhost:8080/?name=test", nil, nil, nil) recv := new(Recv) - err := DefaultBinder.Bind(req.Req, nil, recv) + err := DefaultBinder().Bind(req.Req, nil, recv) if err != nil { t.Error(err) } assert.DeepEqual(t, "test", *recv.Name) // DIFF: the type with customized decoder must be a non-nil value - //assert.DeepEqual(t, (*Timestamp)(nil), recv.T) + // assert.DeepEqual(t, (*Timestamp)(nil), recv.T) } func TestQueryTypes(t *testing.T) { @@ -840,7 +840,7 @@ func TestQueryTypes(t *testing.T) { }, bodyReader) recv := new(Recv) - err = DefaultBinder.Bind(req.Req, nil, recv) + err = DefaultBinder().Bind(req.Req, nil, recv) if err != nil { t.Error(err) } @@ -863,7 +863,7 @@ func TestNoTagIssue(t *testing.T) { req := newRequest("http://localhost:8080/?x=11&x2=12&a=1&B=2", nil, nil, nil) recv := new(T) - err := DefaultBinder.Bind(req.Req, nil, recv) + err := DefaultBinder().Bind(req.Req, nil, recv) if err != nil { t.Error(err) } @@ -898,7 +898,7 @@ func TestRegTypeUnmarshal(t *testing.T) { defer func() { EnableStructFieldResolve(false) }() - err = DefaultBinder.Bind(req.Req, nil, recv) + err = DefaultBinder().Bind(req.Req, nil, recv) if err != nil { t.Error(err) } @@ -964,7 +964,7 @@ func TestPathnameBUG(t *testing.T) { req := newRequest("http://localhost", header, nil, bytes.NewReader(b)) recv := new(ExchangeCurrencyRequest) - err = DefaultBinder.Bind(req.Req, nil, recv) + err = DefaultBinder().Bind(req.Req, nil, recv) if err != nil { t.Error(err) } @@ -1021,7 +1021,7 @@ func TestPathnameBUG2(t *testing.T) { req := newRequest("http://localhost", header, nil, bytes.NewReader(b)) recv := new(CreateFreeShippingRequest) - err = DefaultBinder.Bind(req.Req, nil, recv) + err = DefaultBinder().Bind(req.Req, nil, recv) if err != nil { t.Error(err) } @@ -1069,7 +1069,7 @@ func TestRequiredBUG(t *testing.T) { req := newRequest("http://localhost", header, nil, bytes.NewReader(b)) recv := new(ExchangeCurrencyRequest) - err := DefaultBinder.Bind(req.Req, nil, recv) + err := DefaultBinder().Bind(req.Req, nil, recv) // no need for validate if err != nil { t.Error(err) @@ -1089,7 +1089,7 @@ func TestIssue25(t *testing.T) { req := newRequest("/1", header, cookies, nil) recv := new(Recv) - err := DefaultBinder.Bind(req.Req, nil, recv) + err := DefaultBinder().Bind(req.Req, nil, recv) if err != nil { t.Error(err) } @@ -1100,7 +1100,7 @@ func TestIssue25(t *testing.T) { cookies2 := []*http.Cookie{} req2 := newRequest("/2", header2, cookies2, nil) recv2 := new(Recv) - err2 := DefaultBinder.Bind(req2.Req, nil, recv2) + err2 := DefaultBinder().Bind(req2.Req, nil, recv2) if err2 != nil { t.Error(err2) } @@ -1149,7 +1149,7 @@ func TestIssue26(t *testing.T) { req := newRequest("/1", header, cookies, bytes.NewReader(b)) recv2 := new(Recv) - err = DefaultBinder.Bind(req.Req, nil, recv2) + err = DefaultBinder().Bind(req.Req, nil, recv2) if err != nil { t.Error(err) } @@ -1173,7 +1173,7 @@ func TestIssue26(t *testing.T) { // req := newRequest("", header, nil, bodyReader) // recv := new(Recv) // -// err := DefaultBinder.Bind(req.Req, nil, recv) +// err := DefaultBinder().Bind(req.Req, nil, recv) // if err != nil { // t.Error(err) // } From fb4680efe417bec6f9fe5feeaa48ab1380f65004 Mon Sep 17 00:00:00 2001 From: fgy Date: Fri, 26 May 2023 16:28:28 +0800 Subject: [PATCH 44/91] feat: more api --- go.mod | 3 +- go.sum | 56 ++---- licenses/LICENSE-validator.txt | 21 --- pkg/app/context.go | 56 +++++- pkg/app/context_test.go | 2 +- pkg/app/server/binding/binder.go | 7 + pkg/app/server/binding/binder_test.go | 41 ++++- pkg/app/server/binding/config.go | 63 ++++++- pkg/app/server/binding/decoder/decoder.go | 45 +++-- pkg/app/server/binding/decoder/tag.go | 39 +++- pkg/app/server/binding/default.go | 211 +++++++++++++++++++--- pkg/app/server/binding/validator.go | 22 +-- pkg/app/server/binding/validator_test.go | 26 +-- 13 files changed, 432 insertions(+), 160 deletions(-) delete mode 100644 licenses/LICENSE-validator.txt diff --git a/go.mod b/go.mod index 3e7b0d2e3..25f20f8a7 100644 --- a/go.mod +++ b/go.mod @@ -3,13 +3,12 @@ module github.com/cloudwego/hertz go 1.16 require ( + github.com/bytedance/go-tagexpr/v2 v2.9.2 github.com/bytedance/gopkg v0.0.0-20220413063733-65bf48ffb3a7 github.com/bytedance/mockey v1.2.1 github.com/bytedance/sonic v1.8.1 github.com/cloudwego/netpoll v0.4.2-0.20230807055039-52fd5fb7b00f github.com/fsnotify/fsnotify v1.5.4 - github.com/go-playground/assert/v2 v2.2.0 // indirect - github.com/go-playground/validator/v10 v10.11.1 github.com/tidwall/gjson v1.14.4 golang.org/x/sync v0.0.0-20210220032951-036812b2e83c golang.org/x/sys v0.0.0-20220412211240-33da011f77ad diff --git a/go.sum b/go.sum index 679c3c1e0..30da12c13 100644 --- a/go.sum +++ b/go.sum @@ -1,3 +1,5 @@ +github.com/bytedance/go-tagexpr/v2 v2.9.2 h1:QySJaAIQgOEDQBLS3x9BxOWrnhqu5sQ+f6HaZIxD39I= +github.com/bytedance/go-tagexpr/v2 v2.9.2/go.mod h1:5qsx05dYOiUXOUgnQ7w3Oz8BYs2qtM/bJokdLb79wRM= github.com/bytedance/gopkg v0.0.0-20220413063733-65bf48ffb3a7 h1:PtwsQyQJGxf8iaPptPNaduEIu9BnrNms+pcRdHAxZaM= github.com/bytedance/gopkg v0.0.0-20220413063733-65bf48ffb3a7/go.mod h1:2ZlV9BaUH4+NXIBF0aMdKKAnHTzqH+iMU4KUjAbL23Q= github.com/bytedance/mockey v1.2.1 h1:g84ngI88hz1DR4wZTL3yOuqlEcq67MretBfQUdXwrmw= @@ -8,44 +10,33 @@ github.com/bytedance/sonic v1.8.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZX github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams= github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk= -github.com/cloudwego/netpoll v0.4.2-0.20230807055039-52fd5fb7b00f h1:8iWPKjHdXl4tjcSxUJTavnhRL5JPupYvxbtsAlm2Igw= -github.com/cloudwego/netpoll v0.4.2-0.20230807055039-52fd5fb7b00f/go.mod h1:xVefXptcyheopwNDZjDPcfU6kIjZXZ4nY550k1yH9eQ= +github.com/cloudwego/netpoll v0.3.2 h1:/998ICrNMVBo4mlul4j7qcIeY7QnEfuCCPPwck9S3X4= +github.com/cloudwego/netpoll v0.3.2/go.mod h1:xVefXptcyheopwNDZjDPcfU6kIjZXZ4nY550k1yH9eQ= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/fsnotify/fsnotify v1.5.4 h1:jRbGcIw6P2Meqdwuo0H1p6JVLbL5DHKAKlYndzMwVZI= github.com/fsnotify/fsnotify v1.5.4/go.mod h1:OVB6XrOHzAwXMpEM7uPOzcehqUV2UqJxmVXmkdnm1bU= -github.com/go-playground/assert/v2 v2.0.1/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= -github.com/go-playground/assert/v2 v2.2.0 h1:JvknZsQTYeFEAhQwI4qEt9cyV5ONwRHC+lYKSsYSR8s= -github.com/go-playground/assert/v2 v2.2.0/go.mod h1:VDjEfimB/XKnb+ZQfWdccd7VUvScMdVu0Titje2rxJ4= -github.com/go-playground/locales v0.14.0 h1:u50s323jtVGugKlcYeyzC0etD1HifMjqmJqb8WugfUU= -github.com/go-playground/locales v0.14.0/go.mod h1:sawfccIbzZTqEDETgFXqTho0QybSa7l++s0DH+LDiLs= -github.com/go-playground/universal-translator v0.18.0 h1:82dyy6p4OuJq4/CByFNOn/jYrnRPArHwAcmLoJZxyho= -github.com/go-playground/universal-translator v0.18.0/go.mod h1:UvRDBj+xPUEGrFYl+lu/H90nyDXpg0fqeB/AQUGNTVA= -github.com/go-playground/validator/v10 v10.11.1 h1:prmOlTVv+YjZjmRmNSF3VmspqJIxJWXmqUsHwfTRRkQ= -github.com/go-playground/validator/v10 v10.11.1/go.mod h1:i+3WkQ1FvaUjjxh1kSvIA4dMGDBiPU55YFDl0WbKdWU= +github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.5.0 h1:LUVKkCeviFUMKqHa4tXIIij/lbhnMbP7Fn5wKdKkRh4= github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1 h1:EGx4pi6eqNxGaHF6qqu48+N2wcFQ5qg5FXgOdqsJ5d8= github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= +github.com/henrylee2cn/ameda v1.4.8/go.mod h1:liZulR8DgHxdK+MEwvZIylGnmcjzQ6N6f2PlWe7nEO4= +github.com/henrylee2cn/ameda v1.4.10 h1:JdvI2Ekq7tapdPsuhrc4CaFiqw6QXFvZIULWJgQyCAk= +github.com/henrylee2cn/ameda v1.4.10/go.mod h1:liZulR8DgHxdK+MEwvZIylGnmcjzQ6N6f2PlWe7nEO4= +github.com/henrylee2cn/goutil v0.0.0-20210127050712-89660552f6f8 h1:yE9ULgp02BhYIrO6sdV/FPe0xQM6fNHkVQW2IAymfM0= +github.com/henrylee2cn/goutil v0.0.0-20210127050712-89660552f6f8/go.mod h1:Nhe/DM3671a5udlv2AdV2ni/MZzgfv2qrPL5nIi3EGQ= github.com/jtolds/gls v4.20.0+incompatible h1:xdiiI2gbIgH/gLH7ADydsJ1uDOEzR8yvV7C0MuV77Wo= github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= github.com/klauspost/cpuid/v2 v2.0.9 h1:lgaqFMSdTdQYdZ04uHyN2d/eKdOMyi2YLSvlQIBFYa4= github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= -github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= -github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI= -github.com/kr/pretty v0.3.0/go.mod h1:640gp4NfQd8pI5XOwp5fnNeVWj67G7CFk/SaSQn7NBk= -github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= -github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= -github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= -github.com/leodido/go-urn v1.2.1 h1:BqpAaACuzVSgi/VLzGZIobT2z4v53pjosyNd9Yv6n/w= -github.com/leodido/go-urn v1.2.1/go.mod h1:zt4jvISO2HfUBqxjfIshjdMTYS56ZS/qv49ictyFfxY= -github.com/pkg/diff v0.0.0-20210226163009-20ebb0f2a09e/go.mod h1:pJLUxLENpZxwdsKMEsNbx1VGcRFpLqf3715MtcvvzbA= +github.com/nyaruka/phonenumbers v1.0.55 h1:bj0nTO88Y68KeUQ/n3Lo2KgK7lM1hF7L9NFuwcCl3yg= +github.com/nyaruka/phonenumbers v1.0.55/go.mod h1:sDaTZ/KPX5f8qyV9qN+hIm+4ZBARJrupC6LuhshJq1U= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/rogpeppe/go-internal v1.6.1/go.mod h1:xXDCJY+GAPziupqXw64V24skbSoqbTEfhy4qGm1nDQc= -github.com/rogpeppe/go-internal v1.8.0/go.mod h1:WmiCO8CzOY8rg0OYDC4/i/2WRWAB6poM+XZ2dLUbcbE= github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d h1:zE9ykElWQ6/NYmHa3jpm/yHnI4xSofP+UP6SpjHcSeM= github.com/smartystreets/assertions v0.0.0-20180927180507-b2de0cb4f26d/go.mod h1:OnSkiWE9lh6wB0YB77sQom3nweQdgAjqCqsofrRNTgc= github.com/smartystreets/goconvey v1.6.4 h1:fv0U8FUIMPNf1L9lnHLvLhgicrIVChEkdzIKYqbNC9s= @@ -53,12 +44,14 @@ github.com/smartystreets/goconvey v1.6.4/go.mod h1:syvi0/a8iFYH4r/RixwvyeAJjdLS9 github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= -github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +github.com/tidwall/gjson v1.9.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/gjson v1.14.4 h1:uo0p8EbA09J7RQaflQ1aBRffTR7xedD2bcIVSYxLnkM= github.com/tidwall/gjson v1.14.4/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= @@ -71,26 +64,14 @@ golang.org/x/arch v0.0.0-20201008161808-52c3e6f60cff/go.mod h1:flIaEI6LNU6xOCD5P golang.org/x/arch v0.0.0-20210923205945-b76863e36670 h1:18EFjUmQOcUvxNYSkA6jO9VAiXCnxFY6NyDX0bHDmkU= golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= -golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3 h1:0es+/5331RGQPcXlMfP+WrnIIS6dNnNRe0WB02W0F4M= -golang.org/x/crypto v0.0.0-20211215153901-e495a2d5b3d3/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= -golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c h1:5KslGYwFpkhGh+Q16bwMP3cOontH8FOep7tGV86Y7SQ= golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= -golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= -golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/sys v0.0.0-20210806184541-e5e7981a1069/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220110181412-a018aaa089fe/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220412211240-33da011f77ad h1:ntjMns5wyP/fN65tdBD4g8J5w8n015+iIIs9rtjXkY0= golang.org/x/sys v0.0.0-20220412211240-33da011f77ad/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= -golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= -golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= -golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk= -golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= -golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= golang.org/x/tools v0.0.0-20190328211700-ab21143f2384/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= @@ -98,11 +79,8 @@ google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp0 google.golang.org/protobuf v1.27.1 h1:SnqbnDw1V7RiZcXPx5MEeqPv2s79L9i7BJUlG/+RurQ= google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= -gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q= -gopkg.in/errgo.v2 v2.1.0/go.mod h1:hNsd1EY+bozCKY1Ytp96fpM3vjJbqLJn88ws8XvfDNI= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= -gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= diff --git a/licenses/LICENSE-validator.txt b/licenses/LICENSE-validator.txt deleted file mode 100644 index ab4304b3c..000000000 --- a/licenses/LICENSE-validator.txt +++ /dev/null @@ -1,21 +0,0 @@ -The MIT License (MIT) - -Copyright (c) 2015 Dean Karn - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. diff --git a/pkg/app/context.go b/pkg/app/context.go index 747b93b75..008ed49db 100644 --- a/pkg/app/context.go +++ b/pkg/app/context.go @@ -1305,12 +1305,7 @@ func bodyAllowedForStatus(status int) bool { // BindAndValidate binds data from *RequestContext to obj and validates them if needed. // NOTE: obj should be a pointer. func (ctx *RequestContext) BindAndValidate(obj interface{}) error { - err := binding.DefaultBinder().Bind(&ctx.Request, ctx.Params, obj) - if err != nil { - return err - } - err = binding.DefaultValidator.ValidateStruct(obj) - return err + return binding.DefaultBinder().BindAndValidate(&ctx.Request, ctx.Params, obj) } // Bind binds data from *RequestContext to obj. @@ -1322,7 +1317,54 @@ func (ctx *RequestContext) Bind(obj interface{}) error { // Validate validates obj with "vd" tag // NOTE: obj should be a pointer. func (ctx *RequestContext) Validate(obj interface{}) error { - return binding.DefaultValidator.ValidateStruct(obj) + return binding.DefaultValidator().ValidateStruct(obj) +} + +func (ctx *RequestContext) BindQuery(obj interface{}) error { + return binding.DefaultBinder().BindQuery(&ctx.Request, obj) +} + +func (ctx *RequestContext) BindHeader(obj interface{}) error { + return binding.DefaultBinder().BindHeader(&ctx.Request, obj) +} + +func (ctx *RequestContext) BindPath(obj interface{}) error { + return binding.DefaultBinder().BindPath(&ctx.Request, ctx.Params, obj) +} + +func (ctx *RequestContext) BindForm(obj interface{}) error { + return binding.DefaultBinder().BindForm(&ctx.Request, obj) +} + +func (ctx *RequestContext) BindJSON(obj interface{}) error { + return binding.DefaultBinder().BindJSON(&ctx.Request, obj) +} + +func (ctx *RequestContext) BindProtobuf(obj interface{}) error { + return binding.DefaultBinder().BindProtobuf(&ctx.Request, obj) +} + +func (ctx *RequestContext) BindByContentType(obj interface{}) error { + if bytesconv.B2s(ctx.Request.Method()) == consts.MethodGet { + return ctx.BindQuery(obj) + } + ct := utils.FilterContentType(bytesconv.B2s(ctx.Request.Header.ContentType())) + switch ct { + case "application/json": + return ctx.BindJSON(obj) + case "application/x-protobuf": + return ctx.BindProtobuf(obj) + case "application/xml", "text/xml": + return fmt.Errorf("unsupported bind content-type for '%s'", ct) + case "application/x-msgpack", "application/msgpack": + return fmt.Errorf("unsupported bind content-type for '%s'", ct) + case "application/x-yaml": + return fmt.Errorf("unsupported bind content-type for '%s'", ct) + case "application/toml": + return fmt.Errorf("unsupported bind content-type for '%s'", ct) + default: // case MIMEPOSTForm/MIMEMultipartPOSTForm + return ctx.BindForm(obj) + } } // VisitAllQueryArgs calls f for each existing query arg. diff --git a/pkg/app/context_test.go b/pkg/app/context_test.go index 2567a822d..88cd26338 100644 --- a/pkg/app/context_test.go +++ b/pkg/app/context_test.go @@ -1419,7 +1419,7 @@ func TestRequestContext_GetResponse(t *testing.T) { func TestBindAndValidate(t *testing.T) { type Test struct { A string `query:"a"` - B int `query:"b" validate:"gt=10"` + B int `query:"b" vd:"$>10"` } c := &RequestContext{} diff --git a/pkg/app/server/binding/binder.go b/pkg/app/server/binding/binder.go index 665a9ad81..742e3fe3d 100644 --- a/pkg/app/server/binding/binder.go +++ b/pkg/app/server/binding/binder.go @@ -48,6 +48,13 @@ import ( type Binder interface { Name() string Bind(*protocol.Request, param.Params, interface{}) error + BindAndValidate(*protocol.Request, param.Params, interface{}) error + BindQuery(*protocol.Request, interface{}) error + BindHeader(*protocol.Request, interface{}) error + BindPath(*protocol.Request, param.Params, interface{}) error + BindForm(*protocol.Request, interface{}) error + BindJSON(*protocol.Request, interface{}) error + BindProtobuf(*protocol.Request, interface{}) error } var defaultBind Binder = &defaultBinder{} diff --git a/pkg/app/server/binding/binder_test.go b/pkg/app/server/binding/binder_test.go index d1de12ed6..c98be2c14 100644 --- a/pkg/app/server/binding/binder_test.go +++ b/pkg/app/server/binding/binder_test.go @@ -563,7 +563,7 @@ func TestBind_JSON(t *testing.T) { } func TestBind_ResetJSONUnmarshal(t *testing.T) { - ResetStdJSONUnmarshaler() + UseStdJSONUnmarshaler() type Req struct { J1 string `json:"j1"` J2 int `json:"j2"` @@ -866,6 +866,45 @@ func TestBind_JSONRequiredField(t *testing.T) { assert.DeepEqual(t, 0, result2.N.N2.D) } +func TestValidate_MultipleValidate(t *testing.T) { + type Test1 struct { + A int `query:"a" vd:"$>10"` + } + req := newMockRequest(). + SetRequestURI("http://foobar.com?a=9") + var result Test1 + err := DefaultBinder().BindAndValidate(req.Req, nil, &result) + if err == nil { + t.Fatalf("expected an error, but get nil") + } +} + +func TestBind_BindQuery(t *testing.T) { + type Req struct { + Q1 int `query:"q1"` + Q2 int + Q3 string + Q4 string + Q5 []int + } + + req := newMockRequest(). + SetRequestURI("http://foobar.com?q1=1&Q2=2&Q3=3&Q4=4&Q5=51&Q5=52") + + var result Req + + err := DefaultBinder().BindQuery(req.Req, &result) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, 1, result.Q1) + assert.DeepEqual(t, 2, result.Q2) + assert.DeepEqual(t, "3", result.Q3) + assert.DeepEqual(t, "4", result.Q4) + assert.DeepEqual(t, 51, result.Q5[0]) + assert.DeepEqual(t, 52, result.Q5[1]) +} + func Benchmark_Binding(b *testing.B) { type Req struct { Version string `path:"v"` diff --git a/pkg/app/server/binding/config.go b/pkg/app/server/binding/config.go index 49c3a15a0..bdc78d1f7 100644 --- a/pkg/app/server/binding/config.go +++ b/pkg/app/server/binding/config.go @@ -20,20 +20,28 @@ import ( standardJson "encoding/json" "reflect" + "github.com/bytedance/go-tagexpr/v2/validator" "github.com/cloudwego/hertz/pkg/app/server/binding/decoder" hjson "github.com/cloudwego/hertz/pkg/common/json" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/route/param" ) -// ResetJSONUnmarshaler reset the JSON Unmarshal function. -func ResetJSONUnmarshaler(fn func(data []byte, v interface{}) error) { +// UseThirdPartyJSONUnmarshaler uses third-party json library for binding +// NOTE: +// +// UseThirdPartyJSONUnmarshaler will remain in effect once it has been called. +func UseThirdPartyJSONUnmarshaler(fn func(data []byte, v interface{}) error) { hjson.Unmarshal = fn } -// ResetStdJSONUnmarshaler uses "encoding/json" as the JSON Unmarshal function. -func ResetStdJSONUnmarshaler() { - ResetJSONUnmarshaler(standardJson.Unmarshal) +// UseStdJSONUnmarshaler uses encoding/json as json library +// NOTE: +// +// The current version uses encoding/json by default. +// UseStdJSONUnmarshaler will remain in effect once it has been called. +func UseStdJSONUnmarshaler() { + UseThirdPartyJSONUnmarshaler(standardJson.Unmarshal) } // EnableDefaultTag is used to enable or disable adding default tags to a field when it has no tag, it is true by default. @@ -57,3 +65,48 @@ func RegTypeUnmarshal(t reflect.Type, fn func(req *protocol.Request, params para func MustRegTypeUnmarshal(t reflect.Type, fn func(req *protocol.Request, params param.Params, text string) (reflect.Value, error)) { decoder.MustRegTypeUnmarshal(t, fn) } + +// ResetValidator reset a customized +func ResetValidator(v StructValidator, validatorTag string) { + defaultValidate = v + decoder.DefaultValidatorTag = validatorTag +} + +// MustRegValidateFunc registers validator function expression. +// NOTE: +// +// If force=true, allow to cover the existed same funcName. +// MustRegValidateFunc will remain in effect once it has been called. +func MustRegValidateFunc(funcName string, fn func(args ...interface{}) error, force ...bool) { + validator.MustRegFunc(funcName, fn, force...) +} + +// SetValidatorErrorFactory customizes the factory of validation error. +func SetValidatorErrorFactory(validatingErrFactory func(failField, msg string) error) { + if val, ok := DefaultValidator().(*defaultValidator); ok { + val.validate.SetErrorFactory(validatingErrFactory) + } else { + panic("customized validator can not use 'SetValidatorErrorFactory'") + } +} + +var enableDecoderUseNumber = false + +var enableDecoderDisallowUnknownFields = false + +// EnableDecoderUseNumber is used to call the UseNumber method on the JSON +// Decoder instance. UseNumber causes the Decoder to unmarshal a number into an +// interface{} as a Number instead of as a float64. +// NOTE: it is used for BindJSON(). +func EnableDecoderUseNumber(b bool) { + enableDecoderUseNumber = b +} + +// EnableDecoderDisallowUnknownFields is used to call the DisallowUnknownFields method +// on the JSON Decoder instance. DisallowUnknownFields causes the Decoder to +// return an error when the destination is a struct and the input contains object +// keys which do not match any non-ignored, exported fields in the destination. +// NOTE: it is used for BindJSON(). +func EnableDecoderDisallowUnknownFields(b bool) { + enableDecoderDisallowUnknownFields = b +} diff --git a/pkg/app/server/binding/decoder/decoder.go b/pkg/app/server/binding/decoder/decoder.go index 8f62e27d7..480a17a64 100644 --- a/pkg/app/server/binding/decoder/decoder.go +++ b/pkg/app/server/binding/decoder/decoder.go @@ -60,12 +60,13 @@ type fieldDecoder interface { type Decoder func(req *protocol.Request, params param.Params, rv reflect.Value) error -func GetReqDecoder(rt reflect.Type) (Decoder, error) { +func GetReqDecoder(rt reflect.Type, byTag string) (Decoder, bool, error) { var decoders []fieldDecoder + var needValidate bool el := rt.Elem() if el.Kind() != reflect.Struct { - return nil, fmt.Errorf("unsupported \"%s\" type binding", el.String()) + return nil, false, fmt.Errorf("unsupported \"%s\" type binding", el.String()) } for i := 0; i < el.NumField(); i++ { @@ -74,10 +75,11 @@ func GetReqDecoder(rt reflect.Type) (Decoder, error) { continue } - dec, err := getFieldDecoder(el.Field(i), i, []int{}, "") + dec, needValidate2, err := getFieldDecoder(el.Field(i), i, []int{}, "", "") if err != nil { - return nil, err + return nil, false, err } + needValidate = needValidate || needValidate2 if dec != nil { decoders = append(decoders, dec...) @@ -93,33 +95,39 @@ func GetReqDecoder(rt reflect.Type) (Decoder, error) { } return nil - }, nil + }, needValidate, nil } -func getFieldDecoder(field reflect.StructField, index int, parentIdx []int, parentJSONName string) ([]fieldDecoder, error) { +func getFieldDecoder(field reflect.StructField, index int, parentIdx []int, parentJSONName string, byTag string) ([]fieldDecoder, bool, error) { for field.Type.Kind() == reflect.Ptr { field.Type = field.Type.Elem() } if field.Type.Kind() != reflect.Struct && field.Anonymous { - return nil, nil + return nil, false, nil } - fieldTagInfos, newParentJSONName := lookupFieldTags(field, parentJSONName) + fieldTagInfos, newParentJSONName, needValidate := lookupFieldTags(field, parentJSONName) if len(fieldTagInfos) == 0 && EnableDefaultTag { fieldTagInfos = getDefaultFieldTags(field) } + if len(byTag) != 0 { + fieldTagInfos = getFieldTagInfoByTag(field, byTag) + } // customized type decoder has the highest priority if customizedFunc, exist := typeUnmarshalFuncs[field.Type]; exist { - return getCustomizedFieldDecoder(field, index, fieldTagInfos, parentIdx, customizedFunc) + dec, err := getCustomizedFieldDecoder(field, index, fieldTagInfos, parentIdx, customizedFunc) + return dec, needValidate, err } if field.Type.Kind() == reflect.Slice || field.Type.Kind() == reflect.Array { - return getSliceFieldDecoder(field, index, fieldTagInfos, parentIdx) + dec, err := getSliceFieldDecoder(field, index, fieldTagInfos, parentIdx) + return dec, needValidate, err } if field.Type.Kind() == reflect.Map { - return getMapTypeTextDecoder(field, index, fieldTagInfos, parentIdx) + dec, err := getMapTypeTextDecoder(field, index, fieldTagInfos, parentIdx) + return dec, needValidate, err } if field.Type.Kind() == reflect.Struct { @@ -128,12 +136,13 @@ func getFieldDecoder(field reflect.StructField, index int, parentIdx []int, pare // todo: more built-in common struct binding, ex. time... switch el { case reflect.TypeOf(multipart.FileHeader{}): - return getMultipartFileDecoder(field, index, fieldTagInfos, parentIdx) + dec, err := getMultipartFileDecoder(field, index, fieldTagInfos, parentIdx) + return dec, needValidate, err } if EnableStructFieldResolve { structFieldDecoder, err := getStructTypeFieldDecoder(field, index, fieldTagInfos, parentIdx) if err != nil { - return nil, err + return nil, needValidate, err } if structFieldDecoder != nil { decoders = append(decoders, structFieldDecoder...) @@ -150,17 +159,19 @@ func getFieldDecoder(field reflect.StructField, index int, parentIdx []int, pare idxes = append(idxes, parentIdx...) } idxes = append(idxes, index) - dec, err := getFieldDecoder(el.Field(i), i, idxes, newParentJSONName) + dec, needValidate2, err := getFieldDecoder(el.Field(i), i, idxes, newParentJSONName, byTag) + needValidate = needValidate || needValidate2 if err != nil { - return nil, err + return nil, false, err } if dec != nil { decoders = append(decoders, dec...) } } - return decoders, nil + return decoders, needValidate, nil } - return getBaseTypeTextDecoder(field, index, fieldTagInfos, parentIdx) + dec, err := getBaseTypeTextDecoder(field, index, fieldTagInfos, parentIdx) + return dec, needValidate, err } diff --git a/pkg/app/server/binding/decoder/tag.go b/pkg/app/server/binding/decoder/tag.go index 0c53791f8..f50f35f1c 100644 --- a/pkg/app/server/binding/decoder/tag.go +++ b/pkg/app/server/binding/decoder/tag.go @@ -32,6 +32,8 @@ const ( fileNameTag = "file_name" ) +var DefaultValidatorTag = "vd" + const ( defaultTag = "default" ) @@ -60,8 +62,12 @@ func head(str, sep string) (head, tail string) { return str[:idx], str[idx+len(sep):] } -func lookupFieldTags(field reflect.StructField, parentJSONName string) ([]TagInfo, string) { +func lookupFieldTags(field reflect.StructField, parentJSONName string) ([]TagInfo, string, bool) { var ret []string + var needValidate bool + if _, ok := field.Tag.Lookup(DefaultValidatorTag); ok { + needValidate = true + } tags := []string{pathTag, formTag, queryTag, cookieTag, headerTag, jsonTag, rawBodyTag, fileNameTag} for _, tag := range tags { if _, ok := field.Tag.Lookup(tag); ok { @@ -113,7 +119,7 @@ func lookupFieldTags(field reflect.StructField, parentJSONName string) ([]TagInf newParentJSONName = strings.TrimPrefix(parentJSONName+"."+field.Name, ".") } - return tagInfos, newParentJSONName + return tagInfos, newParentJSONName, needValidate } func getDefaultFieldTags(field reflect.StructField) (tagInfos []TagInfo) { @@ -129,3 +135,32 @@ func getDefaultFieldTags(field reflect.StructField) (tagInfos []TagInfo) { return } + +func getFieldTagInfoByTag(field reflect.StructField, tag string) []TagInfo { + var tagInfos []TagInfo + if content, ok := field.Tag.Lookup(tag); ok { + tagValue, opts := head(content, ",") + if len(tagValue) == 0 { + tagValue = field.Name + } + skip := false + if tagValue == "-" { + skip = true + } + var options []string + var opt string + var required bool + for len(opts) > 0 { + opt, opts = head(opts, ",") + options = append(options, opt) + if opt == requiredTagOpt { + required = true + } + } + tagInfos = append(tagInfos, TagInfo{Key: tag, Value: tagValue, Options: options, Required: required, Skip: skip}) + } else { + tagInfos = append(tagInfos, TagInfo{Key: tag, Value: field.Name}) + } + + return tagInfos +} diff --git a/pkg/app/server/binding/default.go b/pkg/app/server/binding/default.go index df436a694..8785cb9c4 100644 --- a/pkg/app/server/binding/default.go +++ b/pkg/app/server/binding/default.go @@ -62,27 +62,199 @@ package binding import ( "fmt" + "io" "reflect" "sync" + "github.com/bytedance/go-tagexpr/v2/validator" "github.com/cloudwego/hertz/internal/bytesconv" "github.com/cloudwego/hertz/pkg/app/server/binding/decoder" hjson "github.com/cloudwego/hertz/pkg/common/json" "github.com/cloudwego/hertz/pkg/common/utils" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/route/param" - "github.com/go-playground/validator/v10" "google.golang.org/protobuf/proto" ) +type decoderInfo struct { + decoder decoder.Decoder + needValidate bool +} + type defaultBinder struct { - decoderCache sync.Map + decoderCache sync.Map + queryDecoderCache sync.Map + formDecoderCache sync.Map + headerDecoderCache sync.Map + pathDecoderCache sync.Map +} + +func (b *defaultBinder) BindQuery(req *protocol.Request, v interface{}) error { + rv, typeID := valueAndTypeID(v) + if rv.Kind() != reflect.Ptr || rv.IsNil() { + return fmt.Errorf("receiver must be a non-nil pointer") + } + if rv.Elem().Kind() == reflect.Map { + return nil + } + cached, ok := b.queryDecoderCache.Load(typeID) + if ok { + // cached fieldDecoder, fast path + decoder := cached.(decoderInfo) + return decoder.decoder(req, nil, rv.Elem()) + } + + decoder, needValidate, err := decoder.GetReqDecoder(rv.Type(), "query") + if err != nil { + return err + } + + b.queryDecoderCache.Store(typeID, decoderInfo{decoder: decoder, needValidate: needValidate}) + return decoder(req, nil, rv.Elem()) +} + +func (b *defaultBinder) BindHeader(req *protocol.Request, v interface{}) error { + rv, typeID := valueAndTypeID(v) + if rv.Kind() != reflect.Ptr || rv.IsNil() { + return fmt.Errorf("receiver must be a non-nil pointer") + } + if rv.Elem().Kind() == reflect.Map { + return nil + } + cached, ok := b.headerDecoderCache.Load(typeID) + if ok { + // cached fieldDecoder, fast path + decoder := cached.(decoderInfo) + return decoder.decoder(req, nil, rv.Elem()) + } + + decoder, needValidate, err := decoder.GetReqDecoder(rv.Type(), "header") + if err != nil { + return err + } + + b.headerDecoderCache.Store(typeID, decoderInfo{decoder: decoder, needValidate: needValidate}) + return decoder(req, nil, rv.Elem()) +} + +func (b *defaultBinder) BindPath(req *protocol.Request, params param.Params, v interface{}) error { + rv, typeID := valueAndTypeID(v) + if rv.Kind() != reflect.Ptr || rv.IsNil() { + return fmt.Errorf("receiver must be a non-nil pointer") + } + if rv.Elem().Kind() == reflect.Map { + return nil + } + cached, ok := b.pathDecoderCache.Load(typeID) + if ok { + // cached fieldDecoder, fast path + decoder := cached.(decoderInfo) + return decoder.decoder(req, params, rv.Elem()) + } + + decoder, needValidate, err := decoder.GetReqDecoder(rv.Type(), "path") + if err != nil { + return err + } + + b.pathDecoderCache.Store(typeID, decoderInfo{decoder: decoder, needValidate: needValidate}) + return decoder(req, params, rv.Elem()) +} + +func (b *defaultBinder) BindForm(req *protocol.Request, v interface{}) error { + rv, typeID := valueAndTypeID(v) + if rv.Kind() != reflect.Ptr || rv.IsNil() { + return fmt.Errorf("receiver must be a non-nil pointer") + } + if rv.Elem().Kind() == reflect.Map { + return nil + } + cached, ok := b.formDecoderCache.Load(typeID) + if ok { + // cached fieldDecoder, fast path + decoder := cached.(decoderInfo) + return decoder.decoder(req, nil, rv.Elem()) + } + + decoder, needValidate, err := decoder.GetReqDecoder(rv.Type(), "form") + if err != nil { + return err + } + + b.formDecoderCache.Store(typeID, decoderInfo{decoder: decoder, needValidate: needValidate}) + return decoder(req, nil, rv.Elem()) +} + +func (b *defaultBinder) BindJSON(req *protocol.Request, v interface{}) error { + return decodeJSON(req.BodyStream(), v) +} + +func decodeJSON(r io.Reader, obj interface{}) error { + decoder := hjson.NewDecoder(r) + if enableDecoderUseNumber { + decoder.UseNumber() + } + if enableDecoderDisallowUnknownFields { + decoder.DisallowUnknownFields() + } + + return decoder.Decode(obj) +} + +func (b *defaultBinder) BindProtobuf(req *protocol.Request, v interface{}) error { + msg, ok := v.(proto.Message) + if !ok { + return fmt.Errorf("%s can not implement 'proto.Message'", v) + } + return proto.Unmarshal(req.Body(), msg) } func (b *defaultBinder) Name() string { return "hertz" } +func (b *defaultBinder) BindAndValidate(req *protocol.Request, params param.Params, v interface{}) error { + err := b.preBindBody(req, v) + if err != nil { + return fmt.Errorf("bind body failed, err=%v", err) + } + rv, typeID := valueAndTypeID(v) + if rv.Kind() != reflect.Ptr || rv.IsNil() { + return fmt.Errorf("receiver must be a non-nil pointer") + } + if rv.Elem().Kind() == reflect.Map { + return nil + } + cached, ok := b.decoderCache.Load(typeID) + if ok { + // cached fieldDecoder, fast path + decoder := cached.(decoderInfo) + err = decoder.decoder(req, params, rv.Elem()) + if err != nil { + return err + } + if decoder.needValidate { + err = DefaultValidator().ValidateStruct(rv.Elem()) + } + return err + } + + decoder, needValidate, err := decoder.GetReqDecoder(rv.Type(), "") + if err != nil { + return err + } + + b.decoderCache.Store(typeID, decoderInfo{decoder: decoder, needValidate: needValidate}) + err = decoder(req, params, rv.Elem()) + if err != nil { + return err + } + if needValidate { + err = DefaultValidator().ValidateStruct(rv.Elem()) + } + return err +} + func (b *defaultBinder) Bind(req *protocol.Request, params param.Params, v interface{}) error { err := b.preBindBody(req, v) if err != nil { @@ -98,16 +270,16 @@ func (b *defaultBinder) Bind(req *protocol.Request, params param.Params, v inter cached, ok := b.decoderCache.Load(typeID) if ok { // cached fieldDecoder, fast path - decoder := cached.(decoder.Decoder) - return decoder(req, params, rv.Elem()) + decoder := cached.(decoderInfo) + return decoder.decoder(req, params, rv.Elem()) } - decoder, err := decoder.GetReqDecoder(rv.Type()) + decoder, needValidate, err := decoder.GetReqDecoder(rv.Type(), "") if err != nil { return err } - b.decoderCache.Store(typeID, decoder) + b.decoderCache.Store(typeID, decoderInfo{decoder: decoder, needValidate: needValidate}) return decoder(req, params, rv.Elem()) } @@ -124,7 +296,6 @@ func (b *defaultBinder) preBindBody(req *protocol.Request, v interface{}) error ct := bytesconv.B2s(req.Header.ContentType()) switch utils.FilterContentType(ct) { case jsonContentType: - // todo: aligning the gin, add "EnableDecoderUseNumber"/"EnableDecoderDisallowUnknownFields" interface return hjson.Unmarshal(req.Body(), v) case protobufContentType: msg, ok := v.(proto.Message) @@ -141,7 +312,7 @@ var _ StructValidator = (*defaultValidator)(nil) type defaultValidator struct { once sync.Once - validate *validator.Validate + validate *validator.Validator } // ValidateStruct receives any kind of type, but only performed struct or pointer to struct type. @@ -149,35 +320,17 @@ func (v *defaultValidator) ValidateStruct(obj interface{}) error { if obj == nil { return nil } - - value := reflect.ValueOf(obj) - switch value.Kind() { - case reflect.Ptr: - return v.ValidateStruct(value.Elem().Interface()) - case reflect.Struct: - return v.validateStruct(obj) - default: - return nil - } -} - -// validateStruct receives struct type -func (v *defaultValidator) validateStruct(obj interface{}) error { v.lazyinit() - return v.validate.Struct(obj) + return v.validate.Validate(obj) } func (v *defaultValidator) lazyinit() { v.once.Do(func() { - v.validate = validator.New() - v.validate.SetTagName("validate") + v.validate = validator.Default() }) } -// Engine returns the underlying validator engine which powers the default -// Validator instance. This is useful if you want to register custom validations -// or struct level validations. See validator GoDoc for more info - -// https://pkg.go.dev/github.com/go-playground/validator/v10 +// Engine returns the underlying validator func (v *defaultValidator) Engine() interface{} { v.lazyinit() return v.validate diff --git a/pkg/app/server/binding/validator.go b/pkg/app/server/binding/validator.go index 3332752a8..1ede4deb0 100644 --- a/pkg/app/server/binding/validator.go +++ b/pkg/app/server/binding/validator.go @@ -40,25 +40,13 @@ package binding -// StructValidator is the minimal interface which needs to be implemented in -// order for it to be used as the validator engine for ensuring the correctness -// of the request. Hertz provides a default implementation for this using -// https://github.com/go-playground/validator/tree/v10.6.1. type StructValidator interface { - // ValidateStruct can receive any kind of type and it should never panic, even if the configuration is not right. - // If the received type is a slice|array, the validation should be performed travel on every element. - // If the received type is not a struct or slice|array, any validation should be skipped and nil must be returned. - // If the received type is a struct or pointer to a struct, the validation should be performed. - // If the struct is not valid or the validation itself fails, a descriptive error should be returned. - // Otherwise nil must be returned. ValidateStruct(interface{}) error - - // Engine returns the underlying validator engine which powers the - // StructValidator implementation. Engine() interface{} } -// DefaultValidator is the default validator which implements the StructValidator -// interface. It uses https://github.com/go-playground/validator/tree/v10.6.1 -// under the hood. -var DefaultValidator StructValidator = &defaultValidator{} +var defaultValidate StructValidator = &defaultValidator{} + +func DefaultValidator() StructValidator { + return defaultValidate +} diff --git a/pkg/app/server/binding/validator_test.go b/pkg/app/server/binding/validator_test.go index 3c7b5a292..05a59affa 100644 --- a/pkg/app/server/binding/validator_test.go +++ b/pkg/app/server/binding/validator_test.go @@ -17,31 +17,19 @@ package binding import ( - "fmt" + "testing" ) -func ExampleDefaultValidator_ValidateStruct() { +func Test_ValidateStruct(t *testing.T) { type User struct { - FirstName string `validate:"required"` - LastName string `validate:"required"` - Age uint8 `validate:"gte=0,lte=130"` - Email string `validate:"required,email"` - FavouriteColor string `validate:"iscolor"` + Age int `vd:"$>=0&&$<=130"` } user := &User{ - FirstName: "Hertz", - Age: 135, - Email: "hertz", - FavouriteColor: "sad", + Age: 135, } - err := DefaultValidator.ValidateStruct(user) - if err != nil { - fmt.Println(err) + err := DefaultValidator().ValidateStruct(user) + if err == nil { + t.Fatalf("expected an error, but got nil") } - // Output: - // Key: 'User.LastName' Error:Field validation for 'LastName' failed on the 'required' tag - // Key: 'User.Age' Error:Field validation for 'Age' failed on the 'lte' tag - // Key: 'User.Email' Error:Field validation for 'Email' failed on the 'email' tag - // Key: 'User.FavouriteColor' Error:Field validation for 'FavouriteColor' failed on the 'iscolor' tag } From 38795e5c32dba0c06a5a726c8a779e087cc70ba3 Mon Sep 17 00:00:00 2001 From: fgy Date: Fri, 26 May 2023 16:57:44 +0800 Subject: [PATCH 45/91] refactor: add internal --- pkg/app/server/binding/config.go | 2 +- pkg/app/server/binding/default.go | 2 +- .../server/binding/{ => internal}/decoder/base_type_decoder.go | 0 .../binding/{ => internal}/decoder/customized_type_decoder.go | 0 pkg/app/server/binding/{ => internal}/decoder/decoder.go | 0 pkg/app/server/binding/{ => internal}/decoder/getter.go | 0 pkg/app/server/binding/{ => internal}/decoder/gjson_required.go | 0 .../server/binding/{ => internal}/decoder/map_type_decoder.go | 0 .../binding/{ => internal}/decoder/multipart_file_decoder.go | 0 pkg/app/server/binding/{ => internal}/decoder/reflect.go | 0 pkg/app/server/binding/{ => internal}/decoder/slice_getter.go | 0 .../server/binding/{ => internal}/decoder/slice_type_decoder.go | 0 pkg/app/server/binding/{ => internal}/decoder/sonic_required.go | 0 .../binding/{ => internal}/decoder/struct_type_decoder.go | 0 pkg/app/server/binding/{ => internal}/decoder/tag.go | 0 pkg/app/server/binding/{ => internal}/decoder/text_decoder.go | 0 16 files changed, 2 insertions(+), 2 deletions(-) rename pkg/app/server/binding/{ => internal}/decoder/base_type_decoder.go (100%) rename pkg/app/server/binding/{ => internal}/decoder/customized_type_decoder.go (100%) rename pkg/app/server/binding/{ => internal}/decoder/decoder.go (100%) rename pkg/app/server/binding/{ => internal}/decoder/getter.go (100%) rename pkg/app/server/binding/{ => internal}/decoder/gjson_required.go (100%) rename pkg/app/server/binding/{ => internal}/decoder/map_type_decoder.go (100%) rename pkg/app/server/binding/{ => internal}/decoder/multipart_file_decoder.go (100%) rename pkg/app/server/binding/{ => internal}/decoder/reflect.go (100%) rename pkg/app/server/binding/{ => internal}/decoder/slice_getter.go (100%) rename pkg/app/server/binding/{ => internal}/decoder/slice_type_decoder.go (100%) rename pkg/app/server/binding/{ => internal}/decoder/sonic_required.go (100%) rename pkg/app/server/binding/{ => internal}/decoder/struct_type_decoder.go (100%) rename pkg/app/server/binding/{ => internal}/decoder/tag.go (100%) rename pkg/app/server/binding/{ => internal}/decoder/text_decoder.go (100%) diff --git a/pkg/app/server/binding/config.go b/pkg/app/server/binding/config.go index bdc78d1f7..b8d0cbcd2 100644 --- a/pkg/app/server/binding/config.go +++ b/pkg/app/server/binding/config.go @@ -21,7 +21,7 @@ import ( "reflect" "github.com/bytedance/go-tagexpr/v2/validator" - "github.com/cloudwego/hertz/pkg/app/server/binding/decoder" + "github.com/cloudwego/hertz/pkg/app/server/binding/internal/decoder" hjson "github.com/cloudwego/hertz/pkg/common/json" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/route/param" diff --git a/pkg/app/server/binding/default.go b/pkg/app/server/binding/default.go index 8785cb9c4..86ff5241b 100644 --- a/pkg/app/server/binding/default.go +++ b/pkg/app/server/binding/default.go @@ -68,7 +68,7 @@ import ( "github.com/bytedance/go-tagexpr/v2/validator" "github.com/cloudwego/hertz/internal/bytesconv" - "github.com/cloudwego/hertz/pkg/app/server/binding/decoder" + "github.com/cloudwego/hertz/pkg/app/server/binding/internal/decoder" hjson "github.com/cloudwego/hertz/pkg/common/json" "github.com/cloudwego/hertz/pkg/common/utils" "github.com/cloudwego/hertz/pkg/protocol" diff --git a/pkg/app/server/binding/decoder/base_type_decoder.go b/pkg/app/server/binding/internal/decoder/base_type_decoder.go similarity index 100% rename from pkg/app/server/binding/decoder/base_type_decoder.go rename to pkg/app/server/binding/internal/decoder/base_type_decoder.go diff --git a/pkg/app/server/binding/decoder/customized_type_decoder.go b/pkg/app/server/binding/internal/decoder/customized_type_decoder.go similarity index 100% rename from pkg/app/server/binding/decoder/customized_type_decoder.go rename to pkg/app/server/binding/internal/decoder/customized_type_decoder.go diff --git a/pkg/app/server/binding/decoder/decoder.go b/pkg/app/server/binding/internal/decoder/decoder.go similarity index 100% rename from pkg/app/server/binding/decoder/decoder.go rename to pkg/app/server/binding/internal/decoder/decoder.go diff --git a/pkg/app/server/binding/decoder/getter.go b/pkg/app/server/binding/internal/decoder/getter.go similarity index 100% rename from pkg/app/server/binding/decoder/getter.go rename to pkg/app/server/binding/internal/decoder/getter.go diff --git a/pkg/app/server/binding/decoder/gjson_required.go b/pkg/app/server/binding/internal/decoder/gjson_required.go similarity index 100% rename from pkg/app/server/binding/decoder/gjson_required.go rename to pkg/app/server/binding/internal/decoder/gjson_required.go diff --git a/pkg/app/server/binding/decoder/map_type_decoder.go b/pkg/app/server/binding/internal/decoder/map_type_decoder.go similarity index 100% rename from pkg/app/server/binding/decoder/map_type_decoder.go rename to pkg/app/server/binding/internal/decoder/map_type_decoder.go diff --git a/pkg/app/server/binding/decoder/multipart_file_decoder.go b/pkg/app/server/binding/internal/decoder/multipart_file_decoder.go similarity index 100% rename from pkg/app/server/binding/decoder/multipart_file_decoder.go rename to pkg/app/server/binding/internal/decoder/multipart_file_decoder.go diff --git a/pkg/app/server/binding/decoder/reflect.go b/pkg/app/server/binding/internal/decoder/reflect.go similarity index 100% rename from pkg/app/server/binding/decoder/reflect.go rename to pkg/app/server/binding/internal/decoder/reflect.go diff --git a/pkg/app/server/binding/decoder/slice_getter.go b/pkg/app/server/binding/internal/decoder/slice_getter.go similarity index 100% rename from pkg/app/server/binding/decoder/slice_getter.go rename to pkg/app/server/binding/internal/decoder/slice_getter.go diff --git a/pkg/app/server/binding/decoder/slice_type_decoder.go b/pkg/app/server/binding/internal/decoder/slice_type_decoder.go similarity index 100% rename from pkg/app/server/binding/decoder/slice_type_decoder.go rename to pkg/app/server/binding/internal/decoder/slice_type_decoder.go diff --git a/pkg/app/server/binding/decoder/sonic_required.go b/pkg/app/server/binding/internal/decoder/sonic_required.go similarity index 100% rename from pkg/app/server/binding/decoder/sonic_required.go rename to pkg/app/server/binding/internal/decoder/sonic_required.go diff --git a/pkg/app/server/binding/decoder/struct_type_decoder.go b/pkg/app/server/binding/internal/decoder/struct_type_decoder.go similarity index 100% rename from pkg/app/server/binding/decoder/struct_type_decoder.go rename to pkg/app/server/binding/internal/decoder/struct_type_decoder.go diff --git a/pkg/app/server/binding/decoder/tag.go b/pkg/app/server/binding/internal/decoder/tag.go similarity index 100% rename from pkg/app/server/binding/decoder/tag.go rename to pkg/app/server/binding/internal/decoder/tag.go diff --git a/pkg/app/server/binding/decoder/text_decoder.go b/pkg/app/server/binding/internal/decoder/text_decoder.go similarity index 100% rename from pkg/app/server/binding/decoder/text_decoder.go rename to pkg/app/server/binding/internal/decoder/text_decoder.go From 01b9109b0b85ea28247e87957caa0f4fbdf95300 Mon Sep 17 00:00:00 2001 From: fgy Date: Wed, 31 May 2023 18:33:58 +0800 Subject: [PATCH 46/91] feat: assign to gin --- pkg/app/context.go | 3 +++ pkg/app/context_test.go | 26 +++++++++++++++++++ .../binding/internal/decoder/decoder.go | 2 +- 3 files changed, 30 insertions(+), 1 deletion(-) diff --git a/pkg/app/context.go b/pkg/app/context.go index 008ed49db..96f1c583d 100644 --- a/pkg/app/context.go +++ b/pkg/app/context.go @@ -1333,6 +1333,9 @@ func (ctx *RequestContext) BindPath(obj interface{}) error { } func (ctx *RequestContext) BindForm(obj interface{}) error { + if len(ctx.Request.Body()) == 0 { + return fmt.Errorf("missing form body") + } return binding.DefaultBinder().BindForm(&ctx.Request, obj) } diff --git a/pkg/app/context_test.go b/pkg/app/context_test.go index 88cd26338..a7d29f5d8 100644 --- a/pkg/app/context_test.go +++ b/pkg/app/context_test.go @@ -1457,6 +1457,32 @@ func TestBindAndValidate(t *testing.T) { } } +func TestBindForm(t *testing.T) { + type Test struct { + A string + B int + } + + c := &RequestContext{} + c.Request.SetRequestURI("/foo/bar?a=123&b=11") + c.Request.SetBody([]byte("A=123&B=11")) + c.Request.Header.SetContentTypeBytes([]byte("application/x-www-form-urlencoded")) + + var req Test + err := c.BindForm(&req) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + assert.DeepEqual(t, "123", req.A) + assert.DeepEqual(t, 11, req.B) + + c.Request.SetBody([]byte("")) + err = c.BindForm(&req) + if err == nil { + t.Fatalf("expected error, but get nil") + } +} + func TestRequestContext_SetCookie(t *testing.T) { c := NewContext(0) c.SetCookie("user", "hertz", 1, "/", "localhost", protocol.CookieSameSiteLaxMode, true, true) diff --git a/pkg/app/server/binding/internal/decoder/decoder.go b/pkg/app/server/binding/internal/decoder/decoder.go index 480a17a64..6b50ce9dd 100644 --- a/pkg/app/server/binding/internal/decoder/decoder.go +++ b/pkg/app/server/binding/internal/decoder/decoder.go @@ -75,7 +75,7 @@ func GetReqDecoder(rt reflect.Type, byTag string) (Decoder, bool, error) { continue } - dec, needValidate2, err := getFieldDecoder(el.Field(i), i, []int{}, "", "") + dec, needValidate2, err := getFieldDecoder(el.Field(i), i, []int{}, "", byTag) if err != nil { return nil, false, err } From b773fe8db00a0e67168fe55ab3ebd8128bd2b5b8 Mon Sep 17 00:00:00 2001 From: fgy Date: Tue, 22 Aug 2023 19:42:21 +0800 Subject: [PATCH 47/91] feat: align gin for post-form --- pkg/app/server/binding/internal/decoder/getter.go | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/pkg/app/server/binding/internal/decoder/getter.go b/pkg/app/server/binding/internal/decoder/getter.go index 1998b26ba..570076cd3 100644 --- a/pkg/app/server/binding/internal/decoder/getter.go +++ b/pkg/app/server/binding/internal/decoder/getter.go @@ -75,6 +75,13 @@ func postForm(req *protocol.Request, params param.Params, key string, defaultVal } } + if len(ret) != 0 { + return + } + if val := req.URI().QueryArgs().Peek(key); val != nil { + ret = string(val) + } + if len(ret) == 0 && len(defaultValue) != 0 { ret = defaultValue[0] } From 44eb90e87b89520294a8d23df3f6464e05d9e544 Mon Sep 17 00:00:00 2001 From: fgy Date: Wed, 23 Aug 2023 15:44:41 +0800 Subject: [PATCH 48/91] feat: modify multi-pointer erorr --- pkg/app/server/binding/internal/decoder/decoder.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/app/server/binding/internal/decoder/decoder.go b/pkg/app/server/binding/internal/decoder/decoder.go index 6b50ce9dd..3ece4fb3e 100644 --- a/pkg/app/server/binding/internal/decoder/decoder.go +++ b/pkg/app/server/binding/internal/decoder/decoder.go @@ -66,7 +66,7 @@ func GetReqDecoder(rt reflect.Type, byTag string) (Decoder, bool, error) { el := rt.Elem() if el.Kind() != reflect.Struct { - return nil, false, fmt.Errorf("unsupported \"%s\" type binding", el.String()) + return nil, false, fmt.Errorf("unsupported \"%s\" type binding", rt.String()) } for i := 0; i < el.NumField(); i++ { From f4fcbee177a134feb3605929d0c2de79d2340c05 Mon Sep 17 00:00:00 2001 From: fgy Date: Wed, 23 Aug 2023 17:04:15 +0800 Subject: [PATCH 49/91] optimize: more comment --- .../binding/internal/decoder/base_type_decoder.go | 10 +++++----- .../internal/decoder/customized_type_decoder.go | 8 -------- pkg/app/server/binding/internal/decoder/decoder.go | 13 +++++++++++-- .../binding/internal/decoder/map_type_decoder.go | 2 +- .../binding/internal/decoder/slice_type_decoder.go | 6 +++--- pkg/app/server/binding/tagexpr_bind_test.go | 6 +++--- 6 files changed, 23 insertions(+), 22 deletions(-) diff --git a/pkg/app/server/binding/internal/decoder/base_type_decoder.go b/pkg/app/server/binding/internal/decoder/base_type_decoder.go index 3fe659825..099aea20f 100644 --- a/pkg/app/server/binding/internal/decoder/base_type_decoder.go +++ b/pkg/app/server/binding/internal/decoder/base_type_decoder.go @@ -53,8 +53,8 @@ type fieldInfo struct { index int parentIndex []int fieldName string - tagInfos []TagInfo // querySlice,param,headerSlice,respHeader ... - fieldType reflect.Type // can not be pointer type + tagInfos []TagInfo + fieldType reflect.Type } type baseTypeFieldTextDecoder struct { @@ -74,7 +74,7 @@ func (d *baseTypeFieldTextDecoder) Decode(req *protocol.Request, params param.Pa if found { err = nil } else { - err = fmt.Errorf("'%s' field is a 'required' parameter, but the request does not have this parameter", d.fieldName) + err = fmt.Errorf("'%s' field is a 'required' parameter, but the request body does not have this parameter '%s'", d.fieldName, tagInfo.JSONName) } } continue @@ -98,11 +98,11 @@ func (d *baseTypeFieldTextDecoder) Decode(req *protocol.Request, params param.Pa if len(text) == 0 && len(defaultValue) != 0 { text = defaultValue } - if text == "" { + if len(text) == 0 { return nil } - // get the non-nil value for the field + // get the non-nil value for the parent field reqValue = GetFieldValue(reqValue, d.parentIndex) field := reqValue.Field(d.index) if field.Kind() == reflect.Ptr { diff --git a/pkg/app/server/binding/internal/decoder/customized_type_decoder.go b/pkg/app/server/binding/internal/decoder/customized_type_decoder.go index 3fb38c4e5..dcfcd9d6f 100644 --- a/pkg/app/server/binding/internal/decoder/customized_type_decoder.go +++ b/pkg/app/server/binding/internal/decoder/customized_type_decoder.go @@ -79,14 +79,6 @@ func RegTypeUnmarshal(t reflect.Type, fn customizeDecodeFunc) error { case reflect.Ptr: return fmt.Errorf("registration type cannot be a pointer type") } - // test - //vv, err := fn(&protocol.Request{}, nil) - //if err != nil { - // return fmt.Errorf("test fail: %s", err) - //} - //if tt := vv.Type(); tt != t { - // return fmt.Errorf("test fail: expect return value type is %s, but got %s", t.String(), tt.String()) - //} typeUnmarshalFuncs[t] = fn return nil diff --git a/pkg/app/server/binding/internal/decoder/decoder.go b/pkg/app/server/binding/internal/decoder/decoder.go index 3ece4fb3e..fe651368b 100644 --- a/pkg/app/server/binding/internal/decoder/decoder.go +++ b/pkg/app/server/binding/internal/decoder/decoder.go @@ -102,10 +102,15 @@ func getFieldDecoder(field reflect.StructField, index int, parentIdx []int, pare for field.Type.Kind() == reflect.Ptr { field.Type = field.Type.Elem() } + // skip anonymous definitions, like: + // type A struct { + // string + // } if field.Type.Kind() != reflect.Struct && field.Anonymous { return nil, false, nil } + // JSONName is like 'a.b.c' for 'required validate' fieldTagInfos, newParentJSONName, needValidate := lookupFieldTags(field, parentJSONName) if len(fieldTagInfos) == 0 && EnableDefaultTag { fieldTagInfos = getDefaultFieldTags(field) @@ -120,26 +125,29 @@ func getFieldDecoder(field reflect.StructField, index int, parentIdx []int, pare return dec, needValidate, err } + // slice/array field decoder if field.Type.Kind() == reflect.Slice || field.Type.Kind() == reflect.Array { dec, err := getSliceFieldDecoder(field, index, fieldTagInfos, parentIdx) return dec, needValidate, err } + // map filed decoder if field.Type.Kind() == reflect.Map { dec, err := getMapTypeTextDecoder(field, index, fieldTagInfos, parentIdx) return dec, needValidate, err } + // struct field will be resolved recursively if field.Type.Kind() == reflect.Struct { var decoders []fieldDecoder el := field.Type // todo: more built-in common struct binding, ex. time... switch el { - case reflect.TypeOf(multipart.FileHeader{}): + case reflect.TypeOf(multipart.FileHeader{}): // file binding dec, err := getMultipartFileDecoder(field, index, fieldTagInfos, parentIdx) return dec, needValidate, err } - if EnableStructFieldResolve { + if EnableStructFieldResolve { // decode struct type separately structFieldDecoder, err := getStructTypeFieldDecoder(field, index, fieldTagInfos, parentIdx) if err != nil { return nil, needValidate, err @@ -172,6 +180,7 @@ func getFieldDecoder(field reflect.StructField, index int, parentIdx []int, pare return decoders, needValidate, nil } + // base type decoder dec, err := getBaseTypeTextDecoder(field, index, fieldTagInfos, parentIdx) return dec, needValidate, err } diff --git a/pkg/app/server/binding/internal/decoder/map_type_decoder.go b/pkg/app/server/binding/internal/decoder/map_type_decoder.go index 71fea1db5..e2f2e819a 100644 --- a/pkg/app/server/binding/internal/decoder/map_type_decoder.go +++ b/pkg/app/server/binding/internal/decoder/map_type_decoder.go @@ -91,7 +91,7 @@ func (d *mapTypeFieldTextDecoder) Decode(req *protocol.Request, params param.Par if len(text) == 0 && len(defaultValue) != 0 { text = defaultValue } - if text == "" { + if len(text) == 0 { return nil } diff --git a/pkg/app/server/binding/internal/decoder/slice_type_decoder.go b/pkg/app/server/binding/internal/decoder/slice_type_decoder.go index 63d5cea37..12d876f14 100644 --- a/pkg/app/server/binding/internal/decoder/slice_type_decoder.go +++ b/pkg/app/server/binding/internal/decoder/slice_type_decoder.go @@ -82,7 +82,6 @@ func (d *sliceTypeFieldTextDecoder) Decode(req *protocol.Request, params param.P bindRawBody = true } texts = tagInfo.SliceGetter(req, params, tagInfo.Value) - // todo: array/slice default value defaultValue = tagInfo.Default if len(texts) != 0 { err = nil @@ -133,7 +132,7 @@ func (d *sliceTypeFieldTextDecoder) Decode(req *protocol.Request, params param.P // handle internal multiple pointer, []**int var ptrDepth int - t := d.fieldType.Elem() + t := d.fieldType.Elem() // d.fieldType is non-pointer type for the field elemKind := t.Kind() for elemKind == reflect.Ptr { t = t.Elem() @@ -156,7 +155,7 @@ func (d *sliceTypeFieldTextDecoder) Decode(req *protocol.Request, params param.P // text[0] can be a complete json content for []Type. err = hjson.Unmarshal(bytesconv.S2b(texts[0]), reqValue.Field(d.index).Addr().Interface()) if err != nil { - return err + return fmt.Errorf("using '%s' to unmarshal type '%s' failed, %s", texts[0], reqValue.Field(d.index).Kind().String(), err.Error()) } } else { reqValue.Field(d.index).Set(ReferenceValue(field, parentPtrDepth)) @@ -205,6 +204,7 @@ func getSliceFieldDecoder(field reflect.StructField, index int, tagInfos []TagIn for field.Type.Kind() == reflect.Ptr { fieldType = field.Type.Elem() } + // fieldType.Elem() is the type for array/slice elem t := getElemType(fieldType.Elem()) if t == reflect.TypeOf(multipart.FileHeader{}) { return getMultipartFileDecoder(field, index, tagInfos, parentIdx) diff --git a/pkg/app/server/binding/tagexpr_bind_test.go b/pkg/app/server/binding/tagexpr_bind_test.go index ada72fe27..d3ca96246 100644 --- a/pkg/app/server/binding/tagexpr_bind_test.go +++ b/pkg/app/server/binding/tagexpr_bind_test.go @@ -117,7 +117,7 @@ func TestGetBody(t *testing.T) { if err == nil { t.Fatalf("expected an error, but get nil") } - assert.DeepEqual(t, err.Error(), "'E' field is a 'required' parameter, but the request does not have this parameter") + assert.DeepEqual(t, err.Error(), "'E' field is a 'required' parameter, but the request body does not have this parameter 'X.e'") } func TestQueryNum(t *testing.T) { @@ -427,7 +427,7 @@ func TestJSON(t *testing.T) { if err == nil { t.Error("expected an error, but get nil") } - assert.DeepEqual(t, err.Error(), "'Y' field is a 'required' parameter, but the request does not have this parameter") + assert.DeepEqual(t, err.Error(), "'Y' field is a 'required' parameter, but the request body does not have this parameter 'y'") assert.DeepEqual(t, []string{"a1", "a2"}, (**recv.X).A) assert.DeepEqual(t, int32(21), (**recv.X).B) assert.DeepEqual(t, &[]uint16{31, 32}, (**recv.X).C) @@ -712,7 +712,7 @@ func TestOption(t *testing.T) { req = newRequest("", header, nil, bodyReader) recv = new(Recv) err = DefaultBinder().Bind(req.Req, nil, recv) - assert.DeepEqual(t, err.Error(), "'C' field is a 'required' parameter, but the request does not have this parameter") + assert.DeepEqual(t, err.Error(), "'C' field is a 'required' parameter, but the request body does not have this parameter 'X.c'") assert.DeepEqual(t, 0, recv.X.C) assert.DeepEqual(t, 0, recv.X.D) assert.DeepEqual(t, "y1", recv.Y) From 655fe6c5e09b425dd979357403c9fe917deaa347 Mon Sep 17 00:00:00 2001 From: fgy Date: Thu, 24 Aug 2023 21:38:40 +0800 Subject: [PATCH 50/91] refactor: bind func signature --- pkg/app/context.go | 6 +- pkg/app/server/binding/binder.go | 6 +- pkg/app/server/binding/binder_test.go | 66 ++++++++++----------- pkg/app/server/binding/default.go | 47 +++++++++++---- pkg/app/server/binding/tagexpr_bind_test.go | 60 +++++++++---------- 5 files changed, 105 insertions(+), 80 deletions(-) diff --git a/pkg/app/context.go b/pkg/app/context.go index 96f1c583d..925ecf040 100644 --- a/pkg/app/context.go +++ b/pkg/app/context.go @@ -1305,13 +1305,13 @@ func bodyAllowedForStatus(status int) bool { // BindAndValidate binds data from *RequestContext to obj and validates them if needed. // NOTE: obj should be a pointer. func (ctx *RequestContext) BindAndValidate(obj interface{}) error { - return binding.DefaultBinder().BindAndValidate(&ctx.Request, ctx.Params, obj) + return binding.DefaultBinder().BindAndValidate(&ctx.Request, obj, ctx.Params) } // Bind binds data from *RequestContext to obj. // NOTE: obj should be a pointer. func (ctx *RequestContext) Bind(obj interface{}) error { - return binding.DefaultBinder().Bind(&ctx.Request, ctx.Params, obj) + return binding.DefaultBinder().Bind(&ctx.Request, obj, ctx.Params) } // Validate validates obj with "vd" tag @@ -1329,7 +1329,7 @@ func (ctx *RequestContext) BindHeader(obj interface{}) error { } func (ctx *RequestContext) BindPath(obj interface{}) error { - return binding.DefaultBinder().BindPath(&ctx.Request, ctx.Params, obj) + return binding.DefaultBinder().BindPath(&ctx.Request, obj, ctx.Params) } func (ctx *RequestContext) BindForm(obj interface{}) error { diff --git a/pkg/app/server/binding/binder.go b/pkg/app/server/binding/binder.go index 742e3fe3d..dc8951ef1 100644 --- a/pkg/app/server/binding/binder.go +++ b/pkg/app/server/binding/binder.go @@ -47,11 +47,11 @@ import ( type Binder interface { Name() string - Bind(*protocol.Request, param.Params, interface{}) error - BindAndValidate(*protocol.Request, param.Params, interface{}) error + Bind(*protocol.Request, interface{}, param.Params) error + BindAndValidate(*protocol.Request, interface{}, param.Params) error BindQuery(*protocol.Request, interface{}) error BindHeader(*protocol.Request, interface{}) error - BindPath(*protocol.Request, param.Params, interface{}) error + BindPath(*protocol.Request, interface{}, param.Params) error BindForm(*protocol.Request, interface{}) error BindJSON(*protocol.Request, interface{}) error BindProtobuf(*protocol.Request, interface{}) error diff --git a/pkg/app/server/binding/binder_test.go b/pkg/app/server/binding/binder_test.go index c98be2c14..2970bbb5f 100644 --- a/pkg/app/server/binding/binder_test.go +++ b/pkg/app/server/binding/binder_test.go @@ -124,7 +124,7 @@ func TestBind_BaseType(t *testing.T) { var result Req - err := DefaultBinder().Bind(req.Req, params, &result) + err := DefaultBinder().Bind(req.Req, &result, params) if err != nil { t.Error(err) } @@ -149,7 +149,7 @@ func TestBind_SliceType(t *testing.T) { var result Req - err := DefaultBinder().Bind(req.Req, nil, &result) + err := DefaultBinder().Bind(req.Req, &result, nil) if err != nil { t.Error(err) } @@ -192,7 +192,7 @@ func TestBind_StructType(t *testing.T) { req := newMockRequest().SetRequestURI("http://foobar.com?F1=f1&B1=b1").SetHeader("f2", "f2") - err := DefaultBinder().Bind(req.Req, nil, &result) + err := DefaultBinder().Bind(req.Req, &result, nil) if err != nil { t.Error(err) } @@ -232,7 +232,7 @@ func TestBind_PointerType(t *testing.T) { req := newMockRequest().SetRequestURI(fmt.Sprintf("http://foobar.com?F1=%s&B1=%s&B2=%s&B3=%s&B3=%s&B4=%d&B4=%d", F1, B1, B2, B3s[0], B3s[1], B4s[0], B4s[1])). SetHeader("f2", "f2") - err := DefaultBinder().Bind(req.Req, nil, &result) + err := DefaultBinder().Bind(req.Req, &result, nil) if err != nil { t.Error(err) } @@ -264,7 +264,7 @@ func TestBind_NestedStruct(t *testing.T) { result := Bar{} req := newMockRequest().SetRequestURI("http://foobar.com?F1=qwe") - err := DefaultBinder().Bind(req.Req, nil, &result) + err := DefaultBinder().Bind(req.Req, &result, nil) if err != nil { t.Error(err) } @@ -285,7 +285,7 @@ func TestBind_SliceStruct(t *testing.T) { B1s := []string{"1", "2", "3"} req := newMockRequest().SetRequestURI(fmt.Sprintf("http://foobar.com?F1={\"f1\":\"%s\"}&F1={\"f1\":\"%s\"}&F1={\"f1\":\"%s\"}", B1s[0], B1s[1], B1s[2])) - err := DefaultBinder().Bind(req.Req, nil, &result) + err := DefaultBinder().Bind(req.Req, &result, nil) if err != nil { t.Error(err) } @@ -300,7 +300,7 @@ func TestBind_MapType(t *testing.T) { req := newMockRequest(). SetJSONContentType(). SetBody([]byte(`{"j1":"j1", "j2":"j2"}`)) - err := DefaultBinder().Bind(req.Req, nil, &result) + err := DefaultBinder().Bind(req.Req, &result, nil) if err != nil { t.Fatal(err) } @@ -319,7 +319,7 @@ func TestBind_MapFieldType(t *testing.T) { SetJSONContentType(). SetBody([]byte(`{"j1":"j1", "j2":"j2"}`)) result := Foo{} - err := DefaultBinder().Bind(req.Req, nil, &result) + err := DefaultBinder().Bind(req.Req, &result, nil) if err != nil { t.Fatal(err) } @@ -334,7 +334,7 @@ func TestBind_UnexportedField(t *testing.T) { } req := newMockRequest(). SetRequestURI("http://foobar.com?a=1&b=2") - err := DefaultBinder().Bind(req.Req, nil, &s) + err := DefaultBinder().Bind(req.Req, &s, nil) if err != nil { t.Fatal(err) } @@ -358,7 +358,7 @@ func TestBind_NoTagField(t *testing.T) { Value: "b2", }) - err := DefaultBinder().Bind(req.Req, params, &s) + err := DefaultBinder().Bind(req.Req, &s, params) if err != nil { t.Fatal(err) } @@ -375,7 +375,7 @@ func TestBind_ZeroValueBind(t *testing.T) { req := newMockRequest(). SetRequestURI("http://foobar.com?a=&b") - err := DefaultBinder().Bind(req.Req, nil, &s) + err := DefaultBinder().Bind(req.Req, &s, nil) if err != nil { t.Fatal(err) } @@ -393,7 +393,7 @@ func TestBind_DefaultValueBind(t *testing.T) { req := newMockRequest(). SetRequestURI("http://foobar.com") - err := DefaultBinder().Bind(req.Req, nil, &s) + err := DefaultBinder().Bind(req.Req, &s, nil) if err != nil { t.Fatal(err) } @@ -406,7 +406,7 @@ func TestBind_DefaultValueBind(t *testing.T) { D [2]string `default:"qwe"` } - err = DefaultBinder().Bind(req.Req, nil, &d) + err = DefaultBinder().Bind(req.Req, &d, nil) if err == nil { t.Fatal("expected err") } @@ -420,7 +420,7 @@ func TestBind_RequiredBind(t *testing.T) { SetRequestURI("http://foobar.com"). SetHeader("A", "1") - err := DefaultBinder().Bind(req.Req, nil, &s) + err := DefaultBinder().Bind(req.Req, &s, nil) if err == nil { t.Fatal("expected error") } @@ -428,7 +428,7 @@ func TestBind_RequiredBind(t *testing.T) { var d struct { A int `query:"a,required" header:"A"` } - err = DefaultBinder().Bind(req.Req, nil, &d) + err = DefaultBinder().Bind(req.Req, &d, nil) if err != nil { t.Fatal(err) } @@ -450,7 +450,7 @@ func TestBind_TypedefType(t *testing.T) { } req := newMockRequest(). SetRequestURI("http://foobar.com?a=1&b=2") - err := DefaultBinder().Bind(req.Req, nil, &s) + err := DefaultBinder().Bind(req.Req, &s, nil) if err != nil { t.Fatal(err) } @@ -483,7 +483,7 @@ func TestBind_EnumBind(t *testing.T) { } req := newMockRequest(). SetRequestURI("http://foobar.com?a=0&b=2") - err := DefaultBinder().Bind(req.Req, nil, &s) + err := DefaultBinder().Bind(req.Req, &s, nil) if err != nil { t.Fatal(err) } @@ -515,7 +515,7 @@ func TestBind_CustomizedTypeDecode(t *testing.T) { req := newMockRequest(). SetRequestURI("http://foobar.com?a=1&b=2") result := Foo{} - err = DefaultBinder().Bind(req.Req, nil, &result) + err = DefaultBinder().Bind(req.Req, &result, nil) if err != nil { t.Fatal(err) } @@ -526,7 +526,7 @@ func TestBind_CustomizedTypeDecode(t *testing.T) { } result2 := Bar{} - err = DefaultBinder().Bind(req.Req, nil, &result2) + err = DefaultBinder().Bind(req.Req, &result2, nil) if err != nil { t.Error(err) } @@ -548,7 +548,7 @@ func TestBind_JSON(t *testing.T) { SetJSONContentType(). SetBody([]byte(fmt.Sprintf(`{"j1":"j1", "j2":12, "j3":[%d, %d], "j4":["%s", "%s"]}`, J3s[0], J3s[1], J4s[0], J4s[1]))) var result Req - err := DefaultBinder().Bind(req.Req, nil, &result) + err := DefaultBinder().Bind(req.Req, &result, nil) if err != nil { t.Error(err) } @@ -577,7 +577,7 @@ func TestBind_ResetJSONUnmarshal(t *testing.T) { SetJSONContentType(). SetBody([]byte(fmt.Sprintf(`{"j1":"j1", "j2":12, "j3":[%d, %d], "j4":["%s", "%s"]}`, J3s[0], J3s[1], J4s[0], J4s[1]))) var result Req - err := DefaultBinder().Bind(req.Req, nil, &result) + err := DefaultBinder().Bind(req.Req, &result, nil) if err != nil { t.Error(err) } @@ -612,7 +612,7 @@ func TestBind_FileBind(t *testing.T) { // to parse multipart files req2 := req2.GetHTTP1Request(req.Req) _ = req2.String() - err := DefaultBinder().Bind(req.Req, nil, &s) + err := DefaultBinder().Bind(req.Req, &s, nil) if err != nil { t.Fatal(err) } @@ -644,7 +644,7 @@ func TestBind_FileSliceBind(t *testing.T) { // to parse multipart files req2 := req2.GetHTTP1Request(req.Req) _ = req2.String() - err := DefaultBinder().Bind(req.Req, nil, &s) + err := DefaultBinder().Bind(req.Req, &s, nil) if err != nil { t.Fatal(err) } @@ -680,7 +680,7 @@ func TestBind_AnonymousField(t *testing.T) { } req := newMockRequest(). SetRequestURI("http://foobar.com?s1=1&s2=2&n1=1&n2=2&n3=3") - err := DefaultBinder().Bind(req.Req, nil, &s) + err := DefaultBinder().Bind(req.Req, &s, nil) if err != nil { t.Fatal(err) } @@ -712,7 +712,7 @@ func TestBind_IgnoreField(t *testing.T) { var result Req - err := DefaultBinder().Bind(req.Req, params, &result) + err := DefaultBinder().Bind(req.Req, &result, params) if err != nil { t.Error(err) } @@ -746,7 +746,7 @@ func TestBind_DefaultTag(t *testing.T) { Value: "1", }) var result Req - err := DefaultBinder().Bind(req.Req, params, &result) + err := DefaultBinder().Bind(req.Req, &result, params) if err != nil { t.Error(err) } @@ -760,7 +760,7 @@ func TestBind_DefaultTag(t *testing.T) { EnableDefaultTag(true) }() result2 := Req2{} - err = DefaultBinder().Bind(req.Req, params, &result2) + err = DefaultBinder().Bind(req.Req, &result2, params) if err != nil { t.Error(err) } @@ -789,7 +789,7 @@ func TestBind_StructFieldResolve(t *testing.T) { defer func() { EnableStructFieldResolve(false) }() - err := DefaultBinder().Bind(req.Req, nil, &result) + err := DefaultBinder().Bind(req.Req, &result, nil) if err != nil { t.Error(err) } @@ -801,7 +801,7 @@ func TestBind_StructFieldResolve(t *testing.T) { SetHeaders("Header", "header"). SetPostArg("Form", "form"). SetUrlEncodeContentType() - err = DefaultBinder().Bind(req.Req, nil, &result) + err = DefaultBinder().Bind(req.Req, &result, nil) if err != nil { t.Error(err) } @@ -836,7 +836,7 @@ func TestBind_JSONRequiredField(t *testing.T) { SetJSONContentType(). SetBody(bodyBytes) var result Req - err := DefaultBinder().Bind(req.Req, nil, &result) + err := DefaultBinder().Bind(req.Req, &result, nil) if err == nil { t.Errorf("expected an error, but get nil") } @@ -856,7 +856,7 @@ func TestBind_JSONRequiredField(t *testing.T) { SetJSONContentType(). SetBody(bodyBytes) var result2 Req - err = DefaultBinder().Bind(req.Req, nil, &result2) + err = DefaultBinder().Bind(req.Req, &result2, nil) if err != nil { t.Error(err) } @@ -873,7 +873,7 @@ func TestValidate_MultipleValidate(t *testing.T) { req := newMockRequest(). SetRequestURI("http://foobar.com?a=9") var result Test1 - err := DefaultBinder().BindAndValidate(req.Req, nil, &result) + err := DefaultBinder().BindAndValidate(req.Req, &result, nil) if err == nil { t.Fatalf("expected an error, but get nil") } @@ -928,7 +928,7 @@ func Benchmark_Binding(b *testing.B) { b.ResetTimer() for i := 0; i < b.N; i++ { var result Req - err := DefaultBinder().Bind(req.Req, params, &result) + err := DefaultBinder().Bind(req.Req, &result, params) if err != nil { b.Error(err) } diff --git a/pkg/app/server/binding/default.go b/pkg/app/server/binding/default.go index 86ff5241b..d5f759147 100644 --- a/pkg/app/server/binding/default.go +++ b/pkg/app/server/binding/default.go @@ -68,7 +68,7 @@ import ( "github.com/bytedance/go-tagexpr/v2/validator" "github.com/cloudwego/hertz/internal/bytesconv" - "github.com/cloudwego/hertz/pkg/app/server/binding/internal/decoder" + inDecoder "github.com/cloudwego/hertz/pkg/app/server/binding/internal/decoder" hjson "github.com/cloudwego/hertz/pkg/common/json" "github.com/cloudwego/hertz/pkg/common/utils" "github.com/cloudwego/hertz/pkg/protocol" @@ -77,7 +77,7 @@ import ( ) type decoderInfo struct { - decoder decoder.Decoder + decoder inDecoder.Decoder needValidate bool } @@ -89,6 +89,31 @@ type defaultBinder struct { pathDecoderCache sync.Map } +// BindAndValidate binds data from *protocol.Request to obj and validates them if needed. +// NOTE: +// +// obj should be a pointer. +func BindAndValidate(req *protocol.Request, obj interface{}, pathParams param.Params) error { + return DefaultBinder().BindAndValidate(req, obj, pathParams) +} + +// Bind binds data from *protocol.Request to obj. +// NOTE: +// +// obj should be a pointer. +func Bind(req *protocol.Request, obj interface{}, pathParams param.Params) error { + return DefaultBinder().Bind(req, obj, pathParams) +} + +// Validate validates obj with "vd" tag +// NOTE: +// +// obj should be a pointer. +// Validate should be called after Bind. +func Validate(obj interface{}) error { + return DefaultValidator().ValidateStruct(obj) +} + func (b *defaultBinder) BindQuery(req *protocol.Request, v interface{}) error { rv, typeID := valueAndTypeID(v) if rv.Kind() != reflect.Ptr || rv.IsNil() { @@ -104,7 +129,7 @@ func (b *defaultBinder) BindQuery(req *protocol.Request, v interface{}) error { return decoder.decoder(req, nil, rv.Elem()) } - decoder, needValidate, err := decoder.GetReqDecoder(rv.Type(), "query") + decoder, needValidate, err := inDecoder.GetReqDecoder(rv.Type(), "query") if err != nil { return err } @@ -128,7 +153,7 @@ func (b *defaultBinder) BindHeader(req *protocol.Request, v interface{}) error { return decoder.decoder(req, nil, rv.Elem()) } - decoder, needValidate, err := decoder.GetReqDecoder(rv.Type(), "header") + decoder, needValidate, err := inDecoder.GetReqDecoder(rv.Type(), "header") if err != nil { return err } @@ -137,7 +162,7 @@ func (b *defaultBinder) BindHeader(req *protocol.Request, v interface{}) error { return decoder(req, nil, rv.Elem()) } -func (b *defaultBinder) BindPath(req *protocol.Request, params param.Params, v interface{}) error { +func (b *defaultBinder) BindPath(req *protocol.Request, v interface{}, params param.Params) error { rv, typeID := valueAndTypeID(v) if rv.Kind() != reflect.Ptr || rv.IsNil() { return fmt.Errorf("receiver must be a non-nil pointer") @@ -152,7 +177,7 @@ func (b *defaultBinder) BindPath(req *protocol.Request, params param.Params, v i return decoder.decoder(req, params, rv.Elem()) } - decoder, needValidate, err := decoder.GetReqDecoder(rv.Type(), "path") + decoder, needValidate, err := inDecoder.GetReqDecoder(rv.Type(), "path") if err != nil { return err } @@ -176,7 +201,7 @@ func (b *defaultBinder) BindForm(req *protocol.Request, v interface{}) error { return decoder.decoder(req, nil, rv.Elem()) } - decoder, needValidate, err := decoder.GetReqDecoder(rv.Type(), "form") + decoder, needValidate, err := inDecoder.GetReqDecoder(rv.Type(), "form") if err != nil { return err } @@ -213,7 +238,7 @@ func (b *defaultBinder) Name() string { return "hertz" } -func (b *defaultBinder) BindAndValidate(req *protocol.Request, params param.Params, v interface{}) error { +func (b *defaultBinder) BindAndValidate(req *protocol.Request, v interface{}, params param.Params) error { err := b.preBindBody(req, v) if err != nil { return fmt.Errorf("bind body failed, err=%v", err) @@ -239,7 +264,7 @@ func (b *defaultBinder) BindAndValidate(req *protocol.Request, params param.Para return err } - decoder, needValidate, err := decoder.GetReqDecoder(rv.Type(), "") + decoder, needValidate, err := inDecoder.GetReqDecoder(rv.Type(), "") if err != nil { return err } @@ -255,7 +280,7 @@ func (b *defaultBinder) BindAndValidate(req *protocol.Request, params param.Para return err } -func (b *defaultBinder) Bind(req *protocol.Request, params param.Params, v interface{}) error { +func (b *defaultBinder) Bind(req *protocol.Request, v interface{}, params param.Params) error { err := b.preBindBody(req, v) if err != nil { return fmt.Errorf("bind body failed, err=%v", err) @@ -274,7 +299,7 @@ func (b *defaultBinder) Bind(req *protocol.Request, params param.Params, v inter return decoder.decoder(req, params, rv.Elem()) } - decoder, needValidate, err := decoder.GetReqDecoder(rv.Type(), "") + decoder, needValidate, err := inDecoder.GetReqDecoder(rv.Type(), "") if err != nil { return err } diff --git a/pkg/app/server/binding/tagexpr_bind_test.go b/pkg/app/server/binding/tagexpr_bind_test.go index d3ca96246..b8a46d37c 100644 --- a/pkg/app/server/binding/tagexpr_bind_test.go +++ b/pkg/app/server/binding/tagexpr_bind_test.go @@ -59,7 +59,7 @@ func TestRawBody(t *testing.T) { bodyBytes := []byte("raw_body.............") req := newRequest("", nil, nil, bytes.NewReader(bodyBytes)) recv := new(Recv) - err := DefaultBinder().Bind(req.Req, nil, recv) + err := DefaultBinder().Bind(req.Req, recv, nil) if err != nil { if err != nil { t.Error(err) @@ -89,7 +89,7 @@ func TestQueryString(t *testing.T) { } req := newRequest("http://localhost:8080/?a=a1&a=a2&b=b1&c=c1&c=c2&d=d1&d=d&f=qps&g=1002&g=1003&e=&e=2&y=y1", nil, nil, nil) recv := new(Recv) - err := DefaultBinder().Bind(req.Req, nil, recv) + err := DefaultBinder().Bind(req.Req, recv, nil) if err != nil { t.Error(err) } @@ -113,7 +113,7 @@ func TestGetBody(t *testing.T) { } req := newRequest("http://localhost:8080/", nil, nil, nil) recv := new(Recv) - err := DefaultBinder().Bind(req.Req, nil, recv) + err := DefaultBinder().Bind(req.Req, recv, nil) if err == nil { t.Fatalf("expected an error, but get nil") } @@ -133,7 +133,7 @@ func TestQueryNum(t *testing.T) { } req := newRequest("http://localhost:8080/?a=11&a=12&b=21&c=31&c=32&d=41&d=42&y=true", nil, nil, nil) recv := new(Recv) - err := DefaultBinder().Bind(req.Req, nil, recv) + err := DefaultBinder().Bind(req.Req, recv, nil) if err != nil { if err != nil { t.Error(err) @@ -169,7 +169,7 @@ func TestHeaderString(t *testing.T) { header.Add("X-Y", "y1") req := newRequest("", header, nil, nil) recv := new(Recv) - err := DefaultBinder().Bind(req.Req, nil, recv) + err := DefaultBinder().Bind(req.Req, recv, nil) if err != nil { if err != nil { t.Error(err) @@ -206,7 +206,7 @@ func TestHeaderNum(t *testing.T) { req := newRequest("", header, nil, nil) recv := new(Recv) - err := DefaultBinder().Bind(req.Req, nil, recv) + err := DefaultBinder().Bind(req.Req, recv, nil) if err != nil { t.Error(err) } @@ -242,7 +242,7 @@ func TestCookieString(t *testing.T) { req := newRequest("", nil, cookies, nil) recv := new(Recv) - err := DefaultBinder().Bind(req.Req, nil, recv) + err := DefaultBinder().Bind(req.Req, recv, nil) if err != nil { t.Error(err) } @@ -275,7 +275,7 @@ func TestCookieNum(t *testing.T) { req := newRequest("", nil, cookies, nil) recv := new(Recv) - err := DefaultBinder().Bind(req.Req, nil, recv) + err := DefaultBinder().Bind(req.Req, recv, nil) if err != nil { t.Error(err) } @@ -321,7 +321,7 @@ func TestFormString(t *testing.T) { header.Set("Content-Type", contentType) req := newRequest("", header, nil, bodyReader) recv := new(Recv) - err := DefaultBinder().Bind(req.Req, nil, recv) + err := DefaultBinder().Bind(req.Req, recv, nil) if err != nil { t.Error(err) } @@ -372,7 +372,7 @@ func TestFormNum(t *testing.T) { req := newRequest("", header, nil, bodyReader) recv := new(Recv) - err := DefaultBinder().Bind(req.Req, nil, recv) + err := DefaultBinder().Bind(req.Req, recv, nil) if err != nil { t.Error(err) } @@ -423,7 +423,7 @@ func TestJSON(t *testing.T) { req := newRequest("", header, nil, bodyReader) recv := new(Recv) - err := DefaultBinder().Bind(req.Req, nil, recv) + err := DefaultBinder().Bind(req.Req, recv, nil) if err == nil { t.Error("expected an error, but get nil") } @@ -485,7 +485,7 @@ func TestPath(t *testing.T) { }, } - err := DefaultBinder().Bind(req.Req, params, recv) + err := DefaultBinder().Bind(req.Req, recv, params) if err != nil { t.Error(err) } @@ -551,7 +551,7 @@ func TestDefault(t *testing.T) { }, } - err := DefaultBinder().Bind(req.Req, param2, recv) + err := DefaultBinder().Bind(req.Req, recv, param2) if err != nil { t.Error(err) } @@ -604,7 +604,7 @@ func TestAuto(t *testing.T) { }, bodyReader) recv := new(Recv) - err = DefaultBinder().Bind(req.Req, nil, recv) + err = DefaultBinder().Bind(req.Req, recv, nil) if err != nil { t.Error(err) } @@ -625,7 +625,7 @@ func TestAuto(t *testing.T) { header.Set("Content-Type", contentType) req = newRequest("http://localhost/?"+query.Encode(), header, nil, bodyReader) recv = new(Recv) - err = DefaultBinder().Bind(req.Req, nil, recv) + err = DefaultBinder().Bind(req.Req, recv, nil) if err != nil { t.Error(err) } @@ -652,7 +652,7 @@ func TestTypeUnmarshal(t *testing.T) { req := newRequest("http://localhost/?"+query.Encode(), header, nil, bodyReader) recv := new(Recv) - err := DefaultBinder().Bind(req.Req, nil, recv) + err := DefaultBinder().Bind(req.Req, recv, nil) if err != nil { t.Error(err) } @@ -696,7 +696,7 @@ func TestOption(t *testing.T) { req := newRequest("", header, nil, bodyReader) recv := new(Recv) - err := DefaultBinder().Bind(req.Req, nil, recv) + err := DefaultBinder().Bind(req.Req, recv, nil) if err != nil { t.Error(err) } @@ -711,7 +711,7 @@ func TestOption(t *testing.T) { }`) req = newRequest("", header, nil, bodyReader) recv = new(Recv) - err = DefaultBinder().Bind(req.Req, nil, recv) + err = DefaultBinder().Bind(req.Req, recv, nil) assert.DeepEqual(t, err.Error(), "'C' field is a 'required' parameter, but the request body does not have this parameter 'X.c'") assert.DeepEqual(t, 0, recv.X.C) assert.DeepEqual(t, 0, recv.X.D) @@ -722,7 +722,7 @@ func TestOption(t *testing.T) { }`) req = newRequest("", header, nil, bodyReader) recv = new(Recv) - err = DefaultBinder().Bind(req.Req, nil, recv) + err = DefaultBinder().Bind(req.Req, recv, nil) if err != nil { t.Error(err) } @@ -745,7 +745,7 @@ func TestOption(t *testing.T) { defer func() { EnableStructFieldResolve(false) }() - err = DefaultBinder().Bind(req.Req, nil, recv2) + err = DefaultBinder().Bind(req.Req, recv2, nil) assert.DeepEqual(t, err.Error(), "'X' field is a 'required' parameter, but the request does not have this parameter") assert.True(t, recv2.X == nil) assert.DeepEqual(t, "y1", recv2.Y) @@ -793,7 +793,7 @@ func TestQueryStringIssue(t *testing.T) { req := newRequest("http://localhost:8080/?name=test", nil, nil, nil) recv := new(Recv) - err := DefaultBinder().Bind(req.Req, nil, recv) + err := DefaultBinder().Bind(req.Req, recv, nil) if err != nil { t.Error(err) } @@ -840,7 +840,7 @@ func TestQueryTypes(t *testing.T) { }, bodyReader) recv := new(Recv) - err = DefaultBinder().Bind(req.Req, nil, recv) + err = DefaultBinder().Bind(req.Req, recv, nil) if err != nil { t.Error(err) } @@ -863,7 +863,7 @@ func TestNoTagIssue(t *testing.T) { req := newRequest("http://localhost:8080/?x=11&x2=12&a=1&B=2", nil, nil, nil) recv := new(T) - err := DefaultBinder().Bind(req.Req, nil, recv) + err := DefaultBinder().Bind(req.Req, recv, nil) if err != nil { t.Error(err) } @@ -898,7 +898,7 @@ func TestRegTypeUnmarshal(t *testing.T) { defer func() { EnableStructFieldResolve(false) }() - err = DefaultBinder().Bind(req.Req, nil, recv) + err = DefaultBinder().Bind(req.Req, recv, nil) if err != nil { t.Error(err) } @@ -964,7 +964,7 @@ func TestPathnameBUG(t *testing.T) { req := newRequest("http://localhost", header, nil, bytes.NewReader(b)) recv := new(ExchangeCurrencyRequest) - err = DefaultBinder().Bind(req.Req, nil, recv) + err = DefaultBinder().Bind(req.Req, recv, nil) if err != nil { t.Error(err) } @@ -1021,7 +1021,7 @@ func TestPathnameBUG2(t *testing.T) { req := newRequest("http://localhost", header, nil, bytes.NewReader(b)) recv := new(CreateFreeShippingRequest) - err = DefaultBinder().Bind(req.Req, nil, recv) + err = DefaultBinder().Bind(req.Req, recv, nil) if err != nil { t.Error(err) } @@ -1069,7 +1069,7 @@ func TestRequiredBUG(t *testing.T) { req := newRequest("http://localhost", header, nil, bytes.NewReader(b)) recv := new(ExchangeCurrencyRequest) - err := DefaultBinder().Bind(req.Req, nil, recv) + err := DefaultBinder().Bind(req.Req, recv, nil) // no need for validate if err != nil { t.Error(err) @@ -1089,7 +1089,7 @@ func TestIssue25(t *testing.T) { req := newRequest("/1", header, cookies, nil) recv := new(Recv) - err := DefaultBinder().Bind(req.Req, nil, recv) + err := DefaultBinder().Bind(req.Req, recv, nil) if err != nil { t.Error(err) } @@ -1100,7 +1100,7 @@ func TestIssue25(t *testing.T) { cookies2 := []*http.Cookie{} req2 := newRequest("/2", header2, cookies2, nil) recv2 := new(Recv) - err2 := DefaultBinder().Bind(req2.Req, nil, recv2) + err2 := DefaultBinder().Bind(req2.Req, recv2, nil) if err2 != nil { t.Error(err2) } @@ -1149,7 +1149,7 @@ func TestIssue26(t *testing.T) { req := newRequest("/1", header, cookies, bytes.NewReader(b)) recv2 := new(Recv) - err = DefaultBinder().Bind(req.Req, nil, recv2) + err = DefaultBinder().Bind(req.Req, recv2, nil) if err != nil { t.Error(err) } From ab16a83b3c4217cf0510fe71a17da91aeb99ece3 Mon Sep 17 00:00:00 2001 From: fgy Date: Thu, 24 Aug 2023 21:53:49 +0800 Subject: [PATCH 51/91] feat: add context bindXXX comment --- pkg/app/context.go | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/pkg/app/context.go b/pkg/app/context.go index 925ecf040..1e1a23f0a 100644 --- a/pkg/app/context.go +++ b/pkg/app/context.go @@ -1320,18 +1320,26 @@ func (ctx *RequestContext) Validate(obj interface{}) error { return binding.DefaultValidator().ValidateStruct(obj) } +// BindQuery binds query parameters from *RequestContext to obj with 'query' tag. It will only use 'query' tag for binding. +// NOTE: obj should be a pointer. func (ctx *RequestContext) BindQuery(obj interface{}) error { return binding.DefaultBinder().BindQuery(&ctx.Request, obj) } +// BindHeader binds header parameters from *RequestContext to obj with 'header' tag. It will only use 'header' tag for binding. +// NOTE: obj should be a pointer. func (ctx *RequestContext) BindHeader(obj interface{}) error { return binding.DefaultBinder().BindHeader(&ctx.Request, obj) } +// BindPath binds router parameters from *RequestContext to obj with 'path' tag. It will only use 'path' tag for binding. +// NOTE: obj should be a pointer. func (ctx *RequestContext) BindPath(obj interface{}) error { return binding.DefaultBinder().BindPath(&ctx.Request, obj, ctx.Params) } +// BindForm binds form parameters from *RequestContext to obj with 'form' tag. It will only use 'form' tag for binding. +// NOTE: obj should be a pointer. func (ctx *RequestContext) BindForm(obj interface{}) error { if len(ctx.Request.Body()) == 0 { return fmt.Errorf("missing form body") @@ -1339,14 +1347,20 @@ func (ctx *RequestContext) BindForm(obj interface{}) error { return binding.DefaultBinder().BindForm(&ctx.Request, obj) } +// BindJSON binds JSON body from *RequestContext. +// NOTE: obj should be a pointer. func (ctx *RequestContext) BindJSON(obj interface{}) error { return binding.DefaultBinder().BindJSON(&ctx.Request, obj) } +// BindProtobuf binds protobuf body from *RequestContext. +// NOTE: obj should be a pointer. func (ctx *RequestContext) BindProtobuf(obj interface{}) error { return binding.DefaultBinder().BindProtobuf(&ctx.Request, obj) } +// BindByContentType will select the binding type on the ContentType automatically. +// NOTE: obj should be a pointer. func (ctx *RequestContext) BindByContentType(obj interface{}) error { if bytesconv.B2s(ctx.Request.Method()) == consts.MethodGet { return ctx.BindQuery(obj) From f9be06f9f3df07b42d786c6a32f213cb6e847f6e Mon Sep 17 00:00:00 2001 From: fgy Date: Thu, 24 Aug 2023 22:03:35 +0800 Subject: [PATCH 52/91] refactor: BindByContentType --- pkg/app/context.go | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/pkg/app/context.go b/pkg/app/context.go index 1e1a23f0a..4132b80cb 100644 --- a/pkg/app/context.go +++ b/pkg/app/context.go @@ -1362,7 +1362,7 @@ func (ctx *RequestContext) BindProtobuf(obj interface{}) error { // BindByContentType will select the binding type on the ContentType automatically. // NOTE: obj should be a pointer. func (ctx *RequestContext) BindByContentType(obj interface{}) error { - if bytesconv.B2s(ctx.Request.Method()) == consts.MethodGet { + if ctx.Request.Header.IsGet() { return ctx.BindQuery(obj) } ct := utils.FilterContentType(bytesconv.B2s(ctx.Request.Header.ContentType())) @@ -1371,16 +1371,10 @@ func (ctx *RequestContext) BindByContentType(obj interface{}) error { return ctx.BindJSON(obj) case "application/x-protobuf": return ctx.BindProtobuf(obj) - case "application/xml", "text/xml": - return fmt.Errorf("unsupported bind content-type for '%s'", ct) - case "application/x-msgpack", "application/msgpack": - return fmt.Errorf("unsupported bind content-type for '%s'", ct) - case "application/x-yaml": - return fmt.Errorf("unsupported bind content-type for '%s'", ct) - case "application/toml": - return fmt.Errorf("unsupported bind content-type for '%s'", ct) - default: // case MIMEPOSTForm/MIMEMultipartPOSTForm + case "application/x-www-form-urlencoded", "multipart/form-data": return ctx.BindForm(obj) + default: // case MIMEPOSTForm/MIMEMultipartPOSTForm + return fmt.Errorf("unsupported bind content-type for '%s'", ct) } } From 79bb2b0e659bcf1f535533ab1c07d5e642ec32e1 Mon Sep 17 00:00:00 2001 From: fgy Date: Thu, 24 Aug 2023 22:09:25 +0800 Subject: [PATCH 53/91] feat: license to 2023 --- pkg/app/server/binding/binder.go | 4 ++-- pkg/app/server/binding/binder_test.go | 4 ++-- pkg/app/server/binding/config.go | 2 +- pkg/app/server/binding/default.go | 4 ++-- pkg/app/server/binding/internal/decoder/base_type_decoder.go | 4 ++-- .../binding/internal/decoder/customized_type_decoder.go | 4 ++-- pkg/app/server/binding/internal/decoder/decoder.go | 4 ++-- pkg/app/server/binding/internal/decoder/getter.go | 4 ++-- pkg/app/server/binding/internal/decoder/gjson_required.go | 2 +- pkg/app/server/binding/internal/decoder/map_type_decoder.go | 4 ++-- .../server/binding/internal/decoder/multipart_file_decoder.go | 2 +- pkg/app/server/binding/internal/decoder/reflect.go | 4 ++-- pkg/app/server/binding/internal/decoder/slice_getter.go | 4 ++-- pkg/app/server/binding/internal/decoder/slice_type_decoder.go | 4 ++-- pkg/app/server/binding/internal/decoder/sonic_required.go | 2 +- .../server/binding/internal/decoder/struct_type_decoder.go | 2 +- pkg/app/server/binding/internal/decoder/tag.go | 2 +- pkg/app/server/binding/internal/decoder/text_decoder.go | 4 ++-- pkg/app/server/binding/reflect.go | 4 ++-- pkg/app/server/binding/tagexpr_bind_test.go | 4 ++-- pkg/app/server/binding/validator.go | 4 ++-- pkg/app/server/binding/validator_test.go | 2 +- 22 files changed, 37 insertions(+), 37 deletions(-) diff --git a/pkg/app/server/binding/binder.go b/pkg/app/server/binding/binder.go index dc8951ef1..2e43daee3 100644 --- a/pkg/app/server/binding/binder.go +++ b/pkg/app/server/binding/binder.go @@ -1,5 +1,5 @@ /* - * Copyright 2022 CloudWeGo Authors + * Copyright 2023 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -35,7 +35,7 @@ * SOFTWARE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo - * Modifications are Copyright 2022 CloudWeGo Authors + * Modifications are Copyright 2023 CloudWeGo Authors */ package binding diff --git a/pkg/app/server/binding/binder_test.go b/pkg/app/server/binding/binder_test.go index 2970bbb5f..13c60233a 100644 --- a/pkg/app/server/binding/binder_test.go +++ b/pkg/app/server/binding/binder_test.go @@ -1,5 +1,5 @@ /* - * Copyright 2022 CloudWeGo Authors + * Copyright 2023 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -35,7 +35,7 @@ * SOFTWARE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo - * Modifications are Copyright 2022 CloudWeGo Authors + * Modifications are Copyright 2023 CloudWeGo Authors */ package binding diff --git a/pkg/app/server/binding/config.go b/pkg/app/server/binding/config.go index b8d0cbcd2..a3ae9dc30 100644 --- a/pkg/app/server/binding/config.go +++ b/pkg/app/server/binding/config.go @@ -1,5 +1,5 @@ /* - * Copyright 2022 CloudWeGo Authors + * Copyright 2023 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/pkg/app/server/binding/default.go b/pkg/app/server/binding/default.go index d5f759147..842416126 100644 --- a/pkg/app/server/binding/default.go +++ b/pkg/app/server/binding/default.go @@ -1,5 +1,5 @@ /* - * Copyright 2022 CloudWeGo Authors + * Copyright 2023 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -55,7 +55,7 @@ * THE SOFTWARE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo - * Modifications are Copyright 2022 CloudWeGo Authors + * Modifications are Copyright 2023 CloudWeGo Authors */ package binding diff --git a/pkg/app/server/binding/internal/decoder/base_type_decoder.go b/pkg/app/server/binding/internal/decoder/base_type_decoder.go index 099aea20f..342c52960 100644 --- a/pkg/app/server/binding/internal/decoder/base_type_decoder.go +++ b/pkg/app/server/binding/internal/decoder/base_type_decoder.go @@ -1,5 +1,5 @@ /* - * Copyright 2022 CloudWeGo Authors + * Copyright 2023 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -35,7 +35,7 @@ * SOFTWARE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo - * Modifications are Copyright 2022 CloudWeGo Authors + * Modifications are Copyright 2023 CloudWeGo Authors */ package decoder diff --git a/pkg/app/server/binding/internal/decoder/customized_type_decoder.go b/pkg/app/server/binding/internal/decoder/customized_type_decoder.go index dcfcd9d6f..9ea8f751e 100644 --- a/pkg/app/server/binding/internal/decoder/customized_type_decoder.go +++ b/pkg/app/server/binding/internal/decoder/customized_type_decoder.go @@ -1,5 +1,5 @@ /* - * Copyright 2022 CloudWeGo Authors + * Copyright 2023 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -35,7 +35,7 @@ * SOFTWARE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo - * Modifications are Copyright 2022 CloudWeGo Authors + * Modifications are Copyright 2023 CloudWeGo Authors */ package decoder diff --git a/pkg/app/server/binding/internal/decoder/decoder.go b/pkg/app/server/binding/internal/decoder/decoder.go index fe651368b..4100815a1 100644 --- a/pkg/app/server/binding/internal/decoder/decoder.go +++ b/pkg/app/server/binding/internal/decoder/decoder.go @@ -1,5 +1,5 @@ /* - * Copyright 2022 CloudWeGo Authors + * Copyright 2023 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -35,7 +35,7 @@ * SOFTWARE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo - * Modifications are Copyright 2022 CloudWeGo Authors + * Modifications are Copyright 2023 CloudWeGo Authors */ package decoder diff --git a/pkg/app/server/binding/internal/decoder/getter.go b/pkg/app/server/binding/internal/decoder/getter.go index 570076cd3..4c605657a 100644 --- a/pkg/app/server/binding/internal/decoder/getter.go +++ b/pkg/app/server/binding/internal/decoder/getter.go @@ -1,5 +1,5 @@ /* - * Copyright 2022 CloudWeGo Authors + * Copyright 2023 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -35,7 +35,7 @@ * SOFTWARE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo - * Modifications are Copyright 2022 CloudWeGo Authors + * Modifications are Copyright 2023 CloudWeGo Authors */ package decoder diff --git a/pkg/app/server/binding/internal/decoder/gjson_required.go b/pkg/app/server/binding/internal/decoder/gjson_required.go index f6aac0a84..d41427194 100644 --- a/pkg/app/server/binding/internal/decoder/gjson_required.go +++ b/pkg/app/server/binding/internal/decoder/gjson_required.go @@ -1,4 +1,4 @@ -// Copyright 2022 CloudWeGo Authors +// Copyright 2023 CloudWeGo Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pkg/app/server/binding/internal/decoder/map_type_decoder.go b/pkg/app/server/binding/internal/decoder/map_type_decoder.go index e2f2e819a..0b19a9ba8 100644 --- a/pkg/app/server/binding/internal/decoder/map_type_decoder.go +++ b/pkg/app/server/binding/internal/decoder/map_type_decoder.go @@ -1,5 +1,5 @@ /* - * Copyright 2022 CloudWeGo Authors + * Copyright 2023 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -35,7 +35,7 @@ * SOFTWARE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo - * Modifications are Copyright 2022 CloudWeGo Authors + * Modifications are Copyright 2023 CloudWeGo Authors */ package decoder diff --git a/pkg/app/server/binding/internal/decoder/multipart_file_decoder.go b/pkg/app/server/binding/internal/decoder/multipart_file_decoder.go index db6d60d1a..c37c0e292 100644 --- a/pkg/app/server/binding/internal/decoder/multipart_file_decoder.go +++ b/pkg/app/server/binding/internal/decoder/multipart_file_decoder.go @@ -1,5 +1,5 @@ /* - * Copyright 2022 CloudWeGo Authors + * Copyright 2023 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/pkg/app/server/binding/internal/decoder/reflect.go b/pkg/app/server/binding/internal/decoder/reflect.go index dba448fd6..d69c4b780 100644 --- a/pkg/app/server/binding/internal/decoder/reflect.go +++ b/pkg/app/server/binding/internal/decoder/reflect.go @@ -1,5 +1,5 @@ /* - * Copyright 2022 CloudWeGo Authors + * Copyright 2023 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -35,7 +35,7 @@ * SOFTWARE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo - * Modifications are Copyright 2022 CloudWeGo Authors + * Modifications are Copyright 2023 CloudWeGo Authors */ package decoder diff --git a/pkg/app/server/binding/internal/decoder/slice_getter.go b/pkg/app/server/binding/internal/decoder/slice_getter.go index 7bf6dc27b..27d2b4174 100644 --- a/pkg/app/server/binding/internal/decoder/slice_getter.go +++ b/pkg/app/server/binding/internal/decoder/slice_getter.go @@ -1,5 +1,5 @@ /* - * Copyright 2022 CloudWeGo Authors + * Copyright 2023 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -35,7 +35,7 @@ * SOFTWARE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo - * Modifications are Copyright 2022 CloudWeGo Authors + * Modifications are Copyright 2023 CloudWeGo Authors */ package decoder diff --git a/pkg/app/server/binding/internal/decoder/slice_type_decoder.go b/pkg/app/server/binding/internal/decoder/slice_type_decoder.go index 12d876f14..08c0ba416 100644 --- a/pkg/app/server/binding/internal/decoder/slice_type_decoder.go +++ b/pkg/app/server/binding/internal/decoder/slice_type_decoder.go @@ -1,5 +1,5 @@ /* - * Copyright 2022 CloudWeGo Authors + * Copyright 2023 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -35,7 +35,7 @@ * SOFTWARE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo - * Modifications are Copyright 2022 CloudWeGo Authors + * Modifications are Copyright 2023 CloudWeGo Authors */ package decoder diff --git a/pkg/app/server/binding/internal/decoder/sonic_required.go b/pkg/app/server/binding/internal/decoder/sonic_required.go index fcf922c65..af24ed18d 100644 --- a/pkg/app/server/binding/internal/decoder/sonic_required.go +++ b/pkg/app/server/binding/internal/decoder/sonic_required.go @@ -1,4 +1,4 @@ -// Copyright 2022 CloudWeGo Authors +// Copyright 2023 CloudWeGo Authors // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/pkg/app/server/binding/internal/decoder/struct_type_decoder.go b/pkg/app/server/binding/internal/decoder/struct_type_decoder.go index 7592a420e..6f81d731b 100644 --- a/pkg/app/server/binding/internal/decoder/struct_type_decoder.go +++ b/pkg/app/server/binding/internal/decoder/struct_type_decoder.go @@ -1,5 +1,5 @@ /* - * Copyright 2022 CloudWeGo Authors + * Copyright 2023 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/pkg/app/server/binding/internal/decoder/tag.go b/pkg/app/server/binding/internal/decoder/tag.go index f50f35f1c..0f754fbb1 100644 --- a/pkg/app/server/binding/internal/decoder/tag.go +++ b/pkg/app/server/binding/internal/decoder/tag.go @@ -1,5 +1,5 @@ /* - * Copyright 2022 CloudWeGo Authors + * Copyright 2023 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/pkg/app/server/binding/internal/decoder/text_decoder.go b/pkg/app/server/binding/internal/decoder/text_decoder.go index 157023f3c..349094272 100644 --- a/pkg/app/server/binding/internal/decoder/text_decoder.go +++ b/pkg/app/server/binding/internal/decoder/text_decoder.go @@ -1,5 +1,5 @@ /* - * Copyright 2022 CloudWeGo Authors + * Copyright 2023 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -35,7 +35,7 @@ * SOFTWARE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo - * Modifications are Copyright 2022 CloudWeGo Authors + * Modifications are Copyright 2023 CloudWeGo Authors */ package decoder diff --git a/pkg/app/server/binding/reflect.go b/pkg/app/server/binding/reflect.go index 4d2e7f33d..0b3be8f6b 100644 --- a/pkg/app/server/binding/reflect.go +++ b/pkg/app/server/binding/reflect.go @@ -1,5 +1,5 @@ /* - * Copyright 2022 CloudWeGo Authors + * Copyright 2023 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -35,7 +35,7 @@ * SOFTWARE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo - * Modifications are Copyright 2022 CloudWeGo Authors + * Modifications are Copyright 2023 CloudWeGo Authors */ package binding diff --git a/pkg/app/server/binding/tagexpr_bind_test.go b/pkg/app/server/binding/tagexpr_bind_test.go index b8a46d37c..a9729af8e 100644 --- a/pkg/app/server/binding/tagexpr_bind_test.go +++ b/pkg/app/server/binding/tagexpr_bind_test.go @@ -1,5 +1,5 @@ /* - * Copyright 2022 CloudWeGo Authors + * Copyright 2023 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -29,7 +29,7 @@ * limitations under the License. * * This file may have been modified by CloudWeGo authors. All CloudWeGo - * Modifications are Copyright 2022 CloudWeGo Authors + * Modifications are Copyright 2023 CloudWeGo Authors */ package binding diff --git a/pkg/app/server/binding/validator.go b/pkg/app/server/binding/validator.go index 1ede4deb0..910a1a02c 100644 --- a/pkg/app/server/binding/validator.go +++ b/pkg/app/server/binding/validator.go @@ -1,5 +1,5 @@ /* - * Copyright 2022 CloudWeGo Authors + * Copyright 2023 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -35,7 +35,7 @@ * THE SOFTWARE. * * This file may have been modified by CloudWeGo authors. All CloudWeGo - * Modifications are Copyright 2022 CloudWeGo Authors + * Modifications are Copyright 2023 CloudWeGo Authors */ package binding diff --git a/pkg/app/server/binding/validator_test.go b/pkg/app/server/binding/validator_test.go index 05a59affa..2f85716b5 100644 --- a/pkg/app/server/binding/validator_test.go +++ b/pkg/app/server/binding/validator_test.go @@ -1,5 +1,5 @@ /* - * Copyright 2022 CloudWeGo Authors + * Copyright 2023 CloudWeGo Authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. From 76e373a22ce22691897874b28b255e87cdd5d361 Mon Sep 17 00:00:00 2001 From: fgy Date: Thu, 24 Aug 2023 22:24:56 +0800 Subject: [PATCH 54/91] feat: use consts content-type --- pkg/app/context.go | 8 ++++---- pkg/app/server/binding/binder_test.go | 3 ++- pkg/app/server/binding/default.go | 10 +++------- .../server/binding/internal/decoder/gjson_required.go | 3 ++- .../server/binding/internal/decoder/sonic_required.go | 3 ++- pkg/app/server/binding/tagexpr_bind_test.go | 11 ++++++----- pkg/protocol/consts/headers.go | 2 ++ 7 files changed, 21 insertions(+), 19 deletions(-) diff --git a/pkg/app/context.go b/pkg/app/context.go index 4132b80cb..042e0f540 100644 --- a/pkg/app/context.go +++ b/pkg/app/context.go @@ -1367,13 +1367,13 @@ func (ctx *RequestContext) BindByContentType(obj interface{}) error { } ct := utils.FilterContentType(bytesconv.B2s(ctx.Request.Header.ContentType())) switch ct { - case "application/json": + case consts.MIMEApplicationJSON: return ctx.BindJSON(obj) - case "application/x-protobuf": + case consts.MIMEPROTOBUF: return ctx.BindProtobuf(obj) - case "application/x-www-form-urlencoded", "multipart/form-data": + case consts.MIMEApplicationHTMLForm, consts.MIMEMultipartPOSTForm: return ctx.BindForm(obj) - default: // case MIMEPOSTForm/MIMEMultipartPOSTForm + default: return fmt.Errorf("unsupported bind content-type for '%s'", ct) } } diff --git a/pkg/app/server/binding/binder_test.go b/pkg/app/server/binding/binder_test.go index 13c60233a..6abb06d56 100644 --- a/pkg/app/server/binding/binder_test.go +++ b/pkg/app/server/binding/binder_test.go @@ -48,6 +48,7 @@ import ( "github.com/cloudwego/hertz/pkg/common/test/assert" "github.com/cloudwego/hertz/pkg/protocol" + "github.com/cloudwego/hertz/pkg/protocol/consts" req2 "github.com/cloudwego/hertz/pkg/protocol/http1/req" "github.com/cloudwego/hertz/pkg/route/param" ) @@ -93,7 +94,7 @@ func (m *mockRequest) SetUrlEncodeContentType() *mockRequest { } func (m *mockRequest) SetJSONContentType() *mockRequest { - m.Req.Header.SetContentTypeBytes([]byte(jsonContentType)) + m.Req.Header.SetContentTypeBytes([]byte(consts.MIMEApplicationJSON)) return m } diff --git a/pkg/app/server/binding/default.go b/pkg/app/server/binding/default.go index 842416126..5faabf263 100644 --- a/pkg/app/server/binding/default.go +++ b/pkg/app/server/binding/default.go @@ -72,6 +72,7 @@ import ( hjson "github.com/cloudwego/hertz/pkg/common/json" "github.com/cloudwego/hertz/pkg/common/utils" "github.com/cloudwego/hertz/pkg/protocol" + "github.com/cloudwego/hertz/pkg/protocol/consts" "github.com/cloudwego/hertz/pkg/route/param" "google.golang.org/protobuf/proto" ) @@ -308,11 +309,6 @@ func (b *defaultBinder) Bind(req *protocol.Request, v interface{}, params param. return decoder(req, params, rv.Elem()) } -var ( - jsonContentType = "application/json" - protobufContentType = "application/x-protobuf" -) - // best effort binding func (b *defaultBinder) preBindBody(req *protocol.Request, v interface{}) error { if req.Header.ContentLength() <= 0 { @@ -320,9 +316,9 @@ func (b *defaultBinder) preBindBody(req *protocol.Request, v interface{}) error } ct := bytesconv.B2s(req.Header.ContentType()) switch utils.FilterContentType(ct) { - case jsonContentType: + case consts.MIMEApplicationJSON: return hjson.Unmarshal(req.Body(), v) - case protobufContentType: + case consts.MIMEPROTOBUF: msg, ok := v.(proto.Message) if !ok { return fmt.Errorf("%s can not implement 'proto.Message'", v) diff --git a/pkg/app/server/binding/internal/decoder/gjson_required.go b/pkg/app/server/binding/internal/decoder/gjson_required.go index d41427194..5fbfd4086 100644 --- a/pkg/app/server/binding/internal/decoder/gjson_required.go +++ b/pkg/app/server/binding/internal/decoder/gjson_required.go @@ -23,6 +23,7 @@ import ( "github.com/cloudwego/hertz/internal/bytesconv" "github.com/cloudwego/hertz/pkg/common/utils" + "github.com/cloudwego/hertz/pkg/protocol/consts" "github.com/tidwall/gjson" ) @@ -31,7 +32,7 @@ func checkRequireJSON2(req *protocol.Request, tagInfo TagInfo) bool { return true } ct := bytesconv.B2s(req.Req.Header.ContentType()) - if utils.FilterContentType(ct) != "application/json" { + if utils.FilterContentType(ct) != consts.MIMEApplicationJSON { return false } result := gjson.GetBytes(req.Req.Body(), tagInfo.JSONName) diff --git a/pkg/app/server/binding/internal/decoder/sonic_required.go b/pkg/app/server/binding/internal/decoder/sonic_required.go index af24ed18d..2aae0c3a4 100644 --- a/pkg/app/server/binding/internal/decoder/sonic_required.go +++ b/pkg/app/server/binding/internal/decoder/sonic_required.go @@ -27,6 +27,7 @@ import ( "github.com/cloudwego/hertz/internal/bytesconv" "github.com/cloudwego/hertz/pkg/common/utils" "github.com/cloudwego/hertz/pkg/protocol" + "github.com/cloudwego/hertz/pkg/protocol/consts" ) func checkRequireJSON(req *protocol.Request, tagInfo TagInfo) bool { @@ -34,7 +35,7 @@ func checkRequireJSON(req *protocol.Request, tagInfo TagInfo) bool { return true } ct := bytesconv.B2s(req.Header.ContentType()) - if utils.FilterContentType(ct) != "application/json" { + if utils.FilterContentType(ct) != consts.MIMEApplicationJSON { return false } node, _ := sonic.Get(req.Body(), stringSliceForInterface(tagInfo.JSONName)...) diff --git a/pkg/app/server/binding/tagexpr_bind_test.go b/pkg/app/server/binding/tagexpr_bind_test.go index a9729af8e..7cca8b0d3 100644 --- a/pkg/app/server/binding/tagexpr_bind_test.go +++ b/pkg/app/server/binding/tagexpr_bind_test.go @@ -47,6 +47,7 @@ import ( "time" "github.com/cloudwego/hertz/pkg/common/test/assert" + "github.com/cloudwego/hertz/pkg/protocol/consts" "github.com/cloudwego/hertz/pkg/route/param" "google.golang.org/protobuf/proto" ) @@ -419,7 +420,7 @@ func TestJSON(t *testing.T) { }`) header := make(http.Header) - header.Set("Content-Type", "application/json") + header.Set("Content-Type", consts.MIMEApplicationJSON) req := newRequest("", header, nil, bodyReader) recv := new(Recv) @@ -540,7 +541,7 @@ func TestDefault(t *testing.T) { // var nilMap map[string]string header := make(http.Header) - header.Set("Content-Type", "application/json") + header.Set("Content-Type", consts.MIMEApplicationJSON) req := newRequest("", header, nil, bodyReader) recv := new(Recv) @@ -684,7 +685,7 @@ func TestOption(t *testing.T) { Y string `json:"y"` } header := make(http.Header) - header.Set("Content-Type", "application/json") + header.Set("Content-Type", consts.MIMEApplicationJSON) bodyReader := strings.NewReader(`{ "X": { @@ -1140,7 +1141,7 @@ func TestIssue26(t *testing.T) { } header := make(http.Header) - header.Set("Content-Type", "application/json") + header.Set("Content-Type", consts.MIMEApplicationJSON) header.Set("A", "from header") cookies := []*http.Cookie{ {Name: "A", Value: "from cookie"}, @@ -1169,7 +1170,7 @@ func TestIssue26(t *testing.T) { // } // }`) // header := make(http.Header) -// header.Set("Content-Type", "application/json") +// header.Set("Content-Type", consts.MIMEApplicationJSON) // req := newRequest("", header, nil, bodyReader) // recv := new(Recv) // diff --git a/pkg/protocol/consts/headers.go b/pkg/protocol/consts/headers.go index 3c2b82e7e..e4b7b316b 100644 --- a/pkg/protocol/consts/headers.go +++ b/pkg/protocol/consts/headers.go @@ -96,6 +96,7 @@ const ( MIMETextHtml = "text/html" MIMETextCss = "text/css" MIMETextJavascript = "text/javascript" + MIMEMultipartPOSTForm = "multipart/form-data" // MIME application MIMEApplicationOctetStream = "application/octet-stream" @@ -121,6 +122,7 @@ const ( MIMEApplicationOpenXMLWord = "application/vnd.openxmlformats-officedocument.wordprocessingml.document" MIMEApplicationOpenXMLExcel = "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" MIMEApplicationOpenXMLPPT = "application/vnd.openxmlformats-officedocument.presentationml.presentation" + MIMEPROTOBUF = "application/x-protobuf" // MIME image MIMEImageJPEG = "image/jpeg" From d74c955309d9c58fcc36944740e9d4df06f4f03c Mon Sep 17 00:00:00 2001 From: fgy Date: Thu, 24 Aug 2023 22:27:31 +0800 Subject: [PATCH 55/91] refactor: json alias --- pkg/app/server/binding/config.go | 8 ++++---- pkg/app/server/binding/default.go | 6 +++--- .../server/binding/internal/decoder/map_type_decoder.go | 4 ++-- .../server/binding/internal/decoder/slice_type_decoder.go | 8 ++++---- 4 files changed, 13 insertions(+), 13 deletions(-) diff --git a/pkg/app/server/binding/config.go b/pkg/app/server/binding/config.go index a3ae9dc30..6401c9b5f 100644 --- a/pkg/app/server/binding/config.go +++ b/pkg/app/server/binding/config.go @@ -17,12 +17,12 @@ package binding import ( - standardJson "encoding/json" + stdJson "encoding/json" "reflect" "github.com/bytedance/go-tagexpr/v2/validator" "github.com/cloudwego/hertz/pkg/app/server/binding/internal/decoder" - hjson "github.com/cloudwego/hertz/pkg/common/json" + hJson "github.com/cloudwego/hertz/pkg/common/json" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/route/param" ) @@ -32,7 +32,7 @@ import ( // // UseThirdPartyJSONUnmarshaler will remain in effect once it has been called. func UseThirdPartyJSONUnmarshaler(fn func(data []byte, v interface{}) error) { - hjson.Unmarshal = fn + hJson.Unmarshal = fn } // UseStdJSONUnmarshaler uses encoding/json as json library @@ -41,7 +41,7 @@ func UseThirdPartyJSONUnmarshaler(fn func(data []byte, v interface{}) error) { // The current version uses encoding/json by default. // UseStdJSONUnmarshaler will remain in effect once it has been called. func UseStdJSONUnmarshaler() { - UseThirdPartyJSONUnmarshaler(standardJson.Unmarshal) + UseThirdPartyJSONUnmarshaler(stdJson.Unmarshal) } // EnableDefaultTag is used to enable or disable adding default tags to a field when it has no tag, it is true by default. diff --git a/pkg/app/server/binding/default.go b/pkg/app/server/binding/default.go index 5faabf263..f73fc8fee 100644 --- a/pkg/app/server/binding/default.go +++ b/pkg/app/server/binding/default.go @@ -69,7 +69,7 @@ import ( "github.com/bytedance/go-tagexpr/v2/validator" "github.com/cloudwego/hertz/internal/bytesconv" inDecoder "github.com/cloudwego/hertz/pkg/app/server/binding/internal/decoder" - hjson "github.com/cloudwego/hertz/pkg/common/json" + hJson "github.com/cloudwego/hertz/pkg/common/json" "github.com/cloudwego/hertz/pkg/common/utils" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/protocol/consts" @@ -216,7 +216,7 @@ func (b *defaultBinder) BindJSON(req *protocol.Request, v interface{}) error { } func decodeJSON(r io.Reader, obj interface{}) error { - decoder := hjson.NewDecoder(r) + decoder := hJson.NewDecoder(r) if enableDecoderUseNumber { decoder.UseNumber() } @@ -317,7 +317,7 @@ func (b *defaultBinder) preBindBody(req *protocol.Request, v interface{}) error ct := bytesconv.B2s(req.Header.ContentType()) switch utils.FilterContentType(ct) { case consts.MIMEApplicationJSON: - return hjson.Unmarshal(req.Body(), v) + return hJson.Unmarshal(req.Body(), v) case consts.MIMEPROTOBUF: msg, ok := v.(proto.Message) if !ok { diff --git a/pkg/app/server/binding/internal/decoder/map_type_decoder.go b/pkg/app/server/binding/internal/decoder/map_type_decoder.go index 0b19a9ba8..f0a717a05 100644 --- a/pkg/app/server/binding/internal/decoder/map_type_decoder.go +++ b/pkg/app/server/binding/internal/decoder/map_type_decoder.go @@ -45,7 +45,7 @@ import ( "reflect" "github.com/cloudwego/hertz/internal/bytesconv" - hjson "github.com/cloudwego/hertz/pkg/common/json" + hJson "github.com/cloudwego/hertz/pkg/common/json" "github.com/cloudwego/hertz/pkg/common/utils" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/route/param" @@ -113,7 +113,7 @@ func (d *mapTypeFieldTextDecoder) Decode(req *protocol.Request, params param.Par return nil } - err = hjson.Unmarshal(bytesconv.S2b(text), field.Addr().Interface()) + err = hJson.Unmarshal(bytesconv.S2b(text), field.Addr().Interface()) if err != nil { return fmt.Errorf("unable to decode '%s' as %s: %w", text, d.fieldType.Name(), err) } diff --git a/pkg/app/server/binding/internal/decoder/slice_type_decoder.go b/pkg/app/server/binding/internal/decoder/slice_type_decoder.go index 08c0ba416..e9c8c28d9 100644 --- a/pkg/app/server/binding/internal/decoder/slice_type_decoder.go +++ b/pkg/app/server/binding/internal/decoder/slice_type_decoder.go @@ -46,7 +46,7 @@ import ( "reflect" "github.com/cloudwego/hertz/internal/bytesconv" - hjson "github.com/cloudwego/hertz/pkg/common/json" + hJson "github.com/cloudwego/hertz/pkg/common/json" "github.com/cloudwego/hertz/pkg/common/utils" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/route/param" @@ -153,7 +153,7 @@ func (d *sliceTypeFieldTextDecoder) Decode(req *protocol.Request, params param.P return err } // text[0] can be a complete json content for []Type. - err = hjson.Unmarshal(bytesconv.S2b(texts[0]), reqValue.Field(d.index).Addr().Interface()) + err = hJson.Unmarshal(bytesconv.S2b(texts[0]), reqValue.Field(d.index).Addr().Interface()) if err != nil { return fmt.Errorf("using '%s' to unmarshal type '%s' failed, %s", texts[0], reqValue.Field(d.index).Kind().String(), err.Error()) } @@ -233,9 +233,9 @@ func stringToValue(elemType reflect.Type, text string, req *protocol.Request, pa } switch elemType.Kind() { case reflect.Struct: - err = hjson.Unmarshal(bytesconv.S2b(text), v.Addr().Interface()) + err = hJson.Unmarshal(bytesconv.S2b(text), v.Addr().Interface()) case reflect.Map: - err = hjson.Unmarshal(bytesconv.S2b(text), v.Addr().Interface()) + err = hJson.Unmarshal(bytesconv.S2b(text), v.Addr().Interface()) case reflect.Array, reflect.Slice: // do nothing default: From b4361e78e00b23a4dabcb3422e5123bcef4fc735 Mon Sep 17 00:00:00 2001 From: fgy Date: Thu, 24 Aug 2023 22:29:45 +0800 Subject: [PATCH 56/91] refactor: var location --- pkg/app/server/binding/config.go | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/pkg/app/server/binding/config.go b/pkg/app/server/binding/config.go index 6401c9b5f..fc100ddbc 100644 --- a/pkg/app/server/binding/config.go +++ b/pkg/app/server/binding/config.go @@ -27,6 +27,11 @@ import ( "github.com/cloudwego/hertz/pkg/route/param" ) +var ( + enableDecoderUseNumber = false + enableDecoderDisallowUnknownFields = false +) + // UseThirdPartyJSONUnmarshaler uses third-party json library for binding // NOTE: // @@ -90,10 +95,6 @@ func SetValidatorErrorFactory(validatingErrFactory func(failField, msg string) e } } -var enableDecoderUseNumber = false - -var enableDecoderDisallowUnknownFields = false - // EnableDecoderUseNumber is used to call the UseNumber method on the JSON // Decoder instance. UseNumber causes the Decoder to unmarshal a number into an // interface{} as a Number instead of as a float64. From 3dcd79a4f1cd5209977e7acd7cc7a2b7e93493cf Mon Sep 17 00:00:00 2001 From: fgy Date: Fri, 25 Aug 2023 16:58:28 +0800 Subject: [PATCH 57/91] feat: add setLooseMode --- pkg/app/server/binding/binder_test.go | 32 ++++++++++++++ pkg/app/server/binding/config.go | 10 +++++ .../internal/decoder/base_type_decoder.go | 7 +-- .../decoder/customized_type_decoder.go | 5 ++- .../server/binding/internal/decoder/getter.go | 43 ++++++++++--------- .../internal/decoder/map_type_decoder.go | 7 +-- .../internal/decoder/slice_type_decoder.go | 2 +- .../internal/decoder/struct_type_decoder.go | 7 +-- .../binding/internal/decoder/text_decoder.go | 20 +++++++-- pkg/app/server/binding/tagexpr_bind_test.go | 4 ++ 10 files changed, 100 insertions(+), 37 deletions(-) diff --git a/pkg/app/server/binding/binder_test.go b/pkg/app/server/binding/binder_test.go index 6abb06d56..60a1bc486 100644 --- a/pkg/app/server/binding/binder_test.go +++ b/pkg/app/server/binding/binder_test.go @@ -53,6 +53,10 @@ import ( "github.com/cloudwego/hertz/pkg/route/param" ) +func init() { + SetLooseZeroMode(true) +} + type mockRequest struct { Req *protocol.Request } @@ -906,6 +910,34 @@ func TestBind_BindQuery(t *testing.T) { assert.DeepEqual(t, 52, result.Q5[1]) } +func TestBind_LooseMode(t *testing.T) { + SetLooseZeroMode(false) + defer SetLooseZeroMode(true) + type Req struct { + ID int `query:"id"` + } + + req := newMockRequest(). + SetRequestURI("http://foobar.com?id=") + + var result Req + + err := DefaultBinder().Bind(req.Req, &result, nil) + if err == nil { + t.Fatal("expected err") + } + assert.DeepEqual(t, 0, result.ID) + + SetLooseZeroMode(true) + var result2 Req + + err = DefaultBinder().Bind(req.Req, &result2, nil) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, 0, result.ID) +} + func Benchmark_Binding(b *testing.B) { type Req struct { Version string `path:"v"` diff --git a/pkg/app/server/binding/config.go b/pkg/app/server/binding/config.go index fc100ddbc..e27283395 100644 --- a/pkg/app/server/binding/config.go +++ b/pkg/app/server/binding/config.go @@ -32,6 +32,16 @@ var ( enableDecoderDisallowUnknownFields = false ) +// SetLooseZeroMode if set to true, +// the empty string request parameter is bound to the zero value of parameter. +// NOTE: +// +// The default is false; +// Suitable for these parameter types: query/header/cookie/form . +func SetLooseZeroMode(enable bool) { + decoder.SetLooseZeroMode(enable) +} + // UseThirdPartyJSONUnmarshaler uses third-party json library for binding // NOTE: // diff --git a/pkg/app/server/binding/internal/decoder/base_type_decoder.go b/pkg/app/server/binding/internal/decoder/base_type_decoder.go index 342c52960..dd16b0b47 100644 --- a/pkg/app/server/binding/internal/decoder/base_type_decoder.go +++ b/pkg/app/server/binding/internal/decoder/base_type_decoder.go @@ -65,6 +65,7 @@ type baseTypeFieldTextDecoder struct { func (d *baseTypeFieldTextDecoder) Decode(req *protocol.Request, params param.Params, reqValue reflect.Value) error { var err error var text string + var exist bool var defaultValue string for _, tagInfo := range d.tagInfos { if tagInfo.Skip || tagInfo.Key == jsonTag || tagInfo.Key == fileNameTag { @@ -82,9 +83,9 @@ func (d *baseTypeFieldTextDecoder) Decode(req *protocol.Request, params param.Pa if tagInfo.Key == headerTag { tagInfo.Value = utils.GetNormalizeHeaderKey(tagInfo.Value, req.Header.IsDisableNormalizing()) } - text = tagInfo.Getter(req, params, tagInfo.Value) + text, exist = tagInfo.Getter(req, params, tagInfo.Value) defaultValue = tagInfo.Default - if len(text) != 0 { + if exist { err = nil break } @@ -98,7 +99,7 @@ func (d *baseTypeFieldTextDecoder) Decode(req *protocol.Request, params param.Pa if len(text) == 0 && len(defaultValue) != 0 { text = defaultValue } - if len(text) == 0 { + if !exist && len(text) == 0 { return nil } diff --git a/pkg/app/server/binding/internal/decoder/customized_type_decoder.go b/pkg/app/server/binding/internal/decoder/customized_type_decoder.go index 9ea8f751e..37966ad27 100644 --- a/pkg/app/server/binding/internal/decoder/customized_type_decoder.go +++ b/pkg/app/server/binding/internal/decoder/customized_type_decoder.go @@ -99,6 +99,7 @@ type customizedFieldTextDecoder struct { func (d *customizedFieldTextDecoder) Decode(req *protocol.Request, params param.Params, reqValue reflect.Value) error { var text string + var exist bool var defaultValue string for _, tagInfo := range d.tagInfos { if tagInfo.Skip || tagInfo.Key == jsonTag || tagInfo.Key == fileNameTag { @@ -108,9 +109,9 @@ func (d *customizedFieldTextDecoder) Decode(req *protocol.Request, params param. if tagInfo.Key == headerTag { tagInfo.Value = utils.GetNormalizeHeaderKey(tagInfo.Value, req.Header.IsDisableNormalizing()) } - text = tagInfo.Getter(req, params, tagInfo.Value) + text, exist = tagInfo.Getter(req, params, tagInfo.Value) defaultValue = tagInfo.Default - if len(text) != 0 { + if exist { break } } diff --git a/pkg/app/server/binding/internal/decoder/getter.go b/pkg/app/server/binding/internal/decoder/getter.go index 4c605657a..81f8202c8 100644 --- a/pkg/app/server/binding/internal/decoder/getter.go +++ b/pkg/app/server/binding/internal/decoder/getter.go @@ -45,24 +45,21 @@ import ( "github.com/cloudwego/hertz/pkg/route/param" ) -type getter func(req *protocol.Request, params param.Params, key string, defaultValue ...string) (ret string) +type getter func(req *protocol.Request, params param.Params, key string, defaultValue ...string) (ret string, exist bool) -func path(req *protocol.Request, params param.Params, key string, defaultValue ...string) (ret string) { +func path(req *protocol.Request, params param.Params, key string, defaultValue ...string) (ret string, exist bool) { if params != nil { - ret, _ = params.Get(key) + ret, exist = params.Get(key) } if len(ret) == 0 && len(defaultValue) != 0 { ret = defaultValue[0] } - return ret + return ret, exist } -func postForm(req *protocol.Request, params param.Params, key string, defaultValue ...string) (ret string) { - if val := req.PostArgs().Peek(key); val != nil { - ret = string(val) - } - if len(ret) > 0 { +func postForm(req *protocol.Request, params param.Params, key string, defaultValue ...string) (ret string, exist bool) { + if ret, exist = req.PostArgs().PeekExists(key); exist { return } @@ -76,22 +73,22 @@ func postForm(req *protocol.Request, params param.Params, key string, defaultVal } if len(ret) != 0 { - return + return ret, true } - if val := req.URI().QueryArgs().Peek(key); val != nil { - ret = string(val) + if ret, exist = req.URI().QueryArgs().PeekExists(key); exist { + return } if len(ret) == 0 && len(defaultValue) != 0 { ret = defaultValue[0] } - return + return ret, false } -func query(req *protocol.Request, params param.Params, key string, defaultValue ...string) (ret string) { - if val := req.URI().QueryArgs().Peek(key); val != nil { - ret = string(val) +func query(req *protocol.Request, params param.Params, key string, defaultValue ...string) (ret string, exist bool) { + if ret, exist = req.URI().QueryArgs().PeekExists(key); exist { + return } if len(ret) == 0 && len(defaultValue) != 0 { @@ -101,33 +98,37 @@ func query(req *protocol.Request, params param.Params, key string, defaultValue return } -func cookie(req *protocol.Request, params param.Params, key string, defaultValue ...string) (ret string) { +func cookie(req *protocol.Request, params param.Params, key string, defaultValue ...string) (ret string, exist bool) { if val := req.Header.Cookie(key); val != nil { ret = string(val) + return ret, true } if len(ret) == 0 && len(defaultValue) != 0 { ret = defaultValue[0] } - return + return ret, false } -func header(req *protocol.Request, params param.Params, key string, defaultValue ...string) (ret string) { +func header(req *protocol.Request, params param.Params, key string, defaultValue ...string) (ret string, exist bool) { if val := req.Header.Peek(key); val != nil { ret = string(val) + return ret, true } if len(ret) == 0 && len(defaultValue) != 0 { ret = defaultValue[0] } - return + return ret, false } -func rawBody(req *protocol.Request, params param.Params, key string, defaultValue ...string) (ret string) { +func rawBody(req *protocol.Request, params param.Params, key string, defaultValue ...string) (ret string, exist bool) { + exist = false if req.Header.ContentLength() > 0 { ret = string(req.Body()) + exist = true } return } diff --git a/pkg/app/server/binding/internal/decoder/map_type_decoder.go b/pkg/app/server/binding/internal/decoder/map_type_decoder.go index f0a717a05..9bb5180b0 100644 --- a/pkg/app/server/binding/internal/decoder/map_type_decoder.go +++ b/pkg/app/server/binding/internal/decoder/map_type_decoder.go @@ -58,6 +58,7 @@ type mapTypeFieldTextDecoder struct { func (d *mapTypeFieldTextDecoder) Decode(req *protocol.Request, params param.Params, reqValue reflect.Value) error { var err error var text string + var exist bool var defaultValue string for _, tagInfo := range d.tagInfos { if tagInfo.Skip || tagInfo.Key == jsonTag || tagInfo.Key == fileNameTag { @@ -75,9 +76,9 @@ func (d *mapTypeFieldTextDecoder) Decode(req *protocol.Request, params param.Par if tagInfo.Key == headerTag { tagInfo.Value = utils.GetNormalizeHeaderKey(tagInfo.Value, req.Header.IsDisableNormalizing()) } - text = tagInfo.Getter(req, params, tagInfo.Value) + text, exist = tagInfo.Getter(req, params, tagInfo.Value) defaultValue = tagInfo.Default - if len(text) != 0 { + if exist { err = nil break } @@ -91,7 +92,7 @@ func (d *mapTypeFieldTextDecoder) Decode(req *protocol.Request, params param.Par if len(text) == 0 && len(defaultValue) != 0 { text = defaultValue } - if len(text) == 0 { + if !exist && len(text) == 0 { return nil } diff --git a/pkg/app/server/binding/internal/decoder/slice_type_decoder.go b/pkg/app/server/binding/internal/decoder/slice_type_decoder.go index e9c8c28d9..be505f03c 100644 --- a/pkg/app/server/binding/internal/decoder/slice_type_decoder.go +++ b/pkg/app/server/binding/internal/decoder/slice_type_decoder.go @@ -155,7 +155,7 @@ func (d *sliceTypeFieldTextDecoder) Decode(req *protocol.Request, params param.P // text[0] can be a complete json content for []Type. err = hJson.Unmarshal(bytesconv.S2b(texts[0]), reqValue.Field(d.index).Addr().Interface()) if err != nil { - return fmt.Errorf("using '%s' to unmarshal type '%s' failed, %s", texts[0], reqValue.Field(d.index).Kind().String(), err.Error()) + return fmt.Errorf("using '%s' to unmarshal field '%s: %s' failed, %v", texts[0], d.fieldName, d.fieldType.String(), err) } } else { reqValue.Field(d.index).Set(ReferenceValue(field, parentPtrDepth)) diff --git a/pkg/app/server/binding/internal/decoder/struct_type_decoder.go b/pkg/app/server/binding/internal/decoder/struct_type_decoder.go index 6f81d731b..ce0a4b237 100644 --- a/pkg/app/server/binding/internal/decoder/struct_type_decoder.go +++ b/pkg/app/server/binding/internal/decoder/struct_type_decoder.go @@ -34,6 +34,7 @@ type structTypeFieldTextDecoder struct { func (d *structTypeFieldTextDecoder) Decode(req *protocol.Request, params param.Params, reqValue reflect.Value) error { var err error var text string + var exist bool var defaultValue string for _, tagInfo := range d.tagInfos { if tagInfo.Skip || tagInfo.Key == jsonTag || tagInfo.Key == fileNameTag { @@ -51,9 +52,9 @@ func (d *structTypeFieldTextDecoder) Decode(req *protocol.Request, params param. if tagInfo.Key == headerTag { tagInfo.Value = utils.GetNormalizeHeaderKey(tagInfo.Value, req.Header.IsDisableNormalizing()) } - text = tagInfo.Getter(req, params, tagInfo.Value) + text, exist = tagInfo.Getter(req, params, tagInfo.Value) defaultValue = tagInfo.Default - if len(text) != 0 { + if exist { err = nil break } @@ -67,7 +68,7 @@ func (d *structTypeFieldTextDecoder) Decode(req *protocol.Request, params param. if len(text) == 0 && len(defaultValue) != 0 { text = defaultValue } - if text == "" { + if !exist && len(text) == 0 { return nil } reqValue = GetFieldValue(reqValue, d.parentIndex) diff --git a/pkg/app/server/binding/internal/decoder/text_decoder.go b/pkg/app/server/binding/internal/decoder/text_decoder.go index 349094272..dc06647fc 100644 --- a/pkg/app/server/binding/internal/decoder/text_decoder.go +++ b/pkg/app/server/binding/internal/decoder/text_decoder.go @@ -46,6 +46,18 @@ import ( "strconv" ) +var looseZeroMode = false + +// SetLooseZeroMode if set to true, +// the empty string request parameter is bound to the zero value of parameter. +// NOTE: +// +// The default is false; +// Suitable for these parameter types: query/header/cookie/form . +func SetLooseZeroMode(enable bool) { + looseZeroMode = enable +} + type TextDecoder interface { UnmarshalString(s string, fieldValue reflect.Value) error } @@ -88,7 +100,7 @@ func SelectTextDecoder(rt reflect.Type) (TextDecoder, error) { type boolDecoder struct{} func (d *boolDecoder) UnmarshalString(s string, fieldValue reflect.Value) error { - if s == "" { + if s == "" && looseZeroMode { s = "false" } v, err := strconv.ParseBool(s) @@ -104,7 +116,7 @@ type floatDecoder struct { } func (d *floatDecoder) UnmarshalString(s string, fieldValue reflect.Value) error { - if s == "" { + if s == "" && looseZeroMode { s = "0.0" } v, err := strconv.ParseFloat(s, d.bitSize) @@ -120,7 +132,7 @@ type intDecoder struct { } func (d *intDecoder) UnmarshalString(s string, fieldValue reflect.Value) error { - if s == "" { + if s == "" && looseZeroMode { s = "0" } v, err := strconv.ParseInt(s, 10, d.bitSize) @@ -143,7 +155,7 @@ type uintDecoder struct { } func (d *uintDecoder) UnmarshalString(s string, fieldValue reflect.Value) error { - if s == "" { + if s == "" && looseZeroMode { s = "0" } v, err := strconv.ParseUint(s, 10, d.bitSize) diff --git a/pkg/app/server/binding/tagexpr_bind_test.go b/pkg/app/server/binding/tagexpr_bind_test.go index 7cca8b0d3..1e38b69f1 100644 --- a/pkg/app/server/binding/tagexpr_bind_test.go +++ b/pkg/app/server/binding/tagexpr_bind_test.go @@ -52,6 +52,10 @@ import ( "google.golang.org/protobuf/proto" ) +func init() { + SetLooseZeroMode(true) +} + func TestRawBody(t *testing.T) { type Recv struct { S []byte `raw_body:""` From 6c2c3982ec0469952d2973edad05468fd668a5aa Mon Sep 17 00:00:00 2001 From: fgy Date: Mon, 28 Aug 2023 15:22:36 +0800 Subject: [PATCH 58/91] feat: add non-struct bind --- pkg/app/server/binding/binder_test.go | 11 ++ pkg/app/server/binding/default.go | 130 +++++++++++++++++--- pkg/app/server/binding/tagexpr_bind_test.go | 40 +++++- 3 files changed, 160 insertions(+), 21 deletions(-) diff --git a/pkg/app/server/binding/binder_test.go b/pkg/app/server/binding/binder_test.go index 60a1bc486..892b615f1 100644 --- a/pkg/app/server/binding/binder_test.go +++ b/pkg/app/server/binding/binder_test.go @@ -938,6 +938,17 @@ func TestBind_LooseMode(t *testing.T) { assert.DeepEqual(t, 0, result.ID) } +func TestBind_NonStruct(t *testing.T) { + req := newMockRequest(). + SetRequestURI("http://foobar.com?id=1&id=2") + var id interface{} + err := DefaultBinder().Bind(req.Req, &id, nil) + if err != nil { + t.Error(err) + } + +} + func Benchmark_Binding(b *testing.B) { type Req struct { Version string `path:"v"` diff --git a/pkg/app/server/binding/default.go b/pkg/app/server/binding/default.go index f73fc8fee..ed42019f3 100644 --- a/pkg/app/server/binding/default.go +++ b/pkg/app/server/binding/default.go @@ -61,8 +61,10 @@ package binding import ( + stdJson "encoding/json" "fmt" "io" + "net/url" "reflect" "sync" @@ -120,8 +122,17 @@ func (b *defaultBinder) BindQuery(req *protocol.Request, v interface{}) error { if rv.Kind() != reflect.Ptr || rv.IsNil() { return fmt.Errorf("receiver must be a non-nil pointer") } - if rv.Elem().Kind() == reflect.Map { - return nil + rt := rv.Type() + for rt.Kind() == reflect.Ptr { + rt = rt.Elem() + } + if rt.Kind() != reflect.Struct { + return b.bindNonStruct(req, v) + } + + err := b.preBindBody(req, v) + if err != nil { + return fmt.Errorf("bind body failed, err=%v", err) } cached, ok := b.queryDecoderCache.Load(typeID) if ok { @@ -144,8 +155,17 @@ func (b *defaultBinder) BindHeader(req *protocol.Request, v interface{}) error { if rv.Kind() != reflect.Ptr || rv.IsNil() { return fmt.Errorf("receiver must be a non-nil pointer") } - if rv.Elem().Kind() == reflect.Map { - return nil + rt := rv.Type() + for rt.Kind() == reflect.Ptr { + rt = rt.Elem() + } + if rt.Kind() != reflect.Struct { + return b.bindNonStruct(req, v) + } + + err := b.preBindBody(req, v) + if err != nil { + return fmt.Errorf("bind body failed, err=%v", err) } cached, ok := b.headerDecoderCache.Load(typeID) if ok { @@ -168,8 +188,17 @@ func (b *defaultBinder) BindPath(req *protocol.Request, v interface{}, params pa if rv.Kind() != reflect.Ptr || rv.IsNil() { return fmt.Errorf("receiver must be a non-nil pointer") } - if rv.Elem().Kind() == reflect.Map { - return nil + rt := rv.Type() + for rt.Kind() == reflect.Ptr { + rt = rt.Elem() + } + if rt.Kind() != reflect.Struct { + return b.bindNonStruct(req, v) + } + + err := b.preBindBody(req, v) + if err != nil { + return fmt.Errorf("bind body failed, err=%v", err) } cached, ok := b.pathDecoderCache.Load(typeID) if ok { @@ -192,8 +221,17 @@ func (b *defaultBinder) BindForm(req *protocol.Request, v interface{}) error { if rv.Kind() != reflect.Ptr || rv.IsNil() { return fmt.Errorf("receiver must be a non-nil pointer") } - if rv.Elem().Kind() == reflect.Map { - return nil + rt := rv.Type() + for rt.Kind() == reflect.Ptr { + rt = rt.Elem() + } + if rt.Kind() != reflect.Struct { + return b.bindNonStruct(req, v) + } + + err := b.preBindBody(req, v) + if err != nil { + return fmt.Errorf("bind body failed, err=%v", err) } cached, ok := b.formDecoderCache.Load(typeID) if ok { @@ -240,16 +278,21 @@ func (b *defaultBinder) Name() string { } func (b *defaultBinder) BindAndValidate(req *protocol.Request, v interface{}, params param.Params) error { - err := b.preBindBody(req, v) - if err != nil { - return fmt.Errorf("bind body failed, err=%v", err) - } rv, typeID := valueAndTypeID(v) if rv.Kind() != reflect.Ptr || rv.IsNil() { return fmt.Errorf("receiver must be a non-nil pointer") } - if rv.Elem().Kind() == reflect.Map { - return nil + rt := rv.Type() + for rt.Kind() == reflect.Ptr { + rt = rt.Elem() + } + if rt.Kind() != reflect.Struct { + return b.bindNonStruct(req, v) + } + + err := b.preBindBody(req, v) + if err != nil { + return fmt.Errorf("bind body failed, err=%v", err) } cached, ok := b.decoderCache.Load(typeID) if ok { @@ -282,16 +325,21 @@ func (b *defaultBinder) BindAndValidate(req *protocol.Request, v interface{}, pa } func (b *defaultBinder) Bind(req *protocol.Request, v interface{}, params param.Params) error { - err := b.preBindBody(req, v) - if err != nil { - return fmt.Errorf("bind body failed, err=%v", err) - } rv, typeID := valueAndTypeID(v) if rv.Kind() != reflect.Ptr || rv.IsNil() { return fmt.Errorf("receiver must be a non-nil pointer") } - if rv.Elem().Kind() == reflect.Map { - return nil + rt := rv.Type() + for rt.Kind() == reflect.Ptr { + rt = rt.Elem() + } + if rt.Kind() != reflect.Struct { + return b.bindNonStruct(req, v) + } + + err := b.preBindBody(req, v) + if err != nil { + return fmt.Errorf("bind body failed, err=%v", err) } cached, ok := b.decoderCache.Load(typeID) if ok { @@ -329,6 +377,48 @@ func (b *defaultBinder) preBindBody(req *protocol.Request, v interface{}) error } } +func (b *defaultBinder) bindNonStruct(req *protocol.Request, v interface{}) (err error) { + ct := bytesconv.B2s(req.Header.ContentType()) + switch utils.FilterContentType(ct) { + case consts.MIMEApplicationJSON: + err = hJson.Unmarshal(req.Body(), v) + case consts.MIMEPROTOBUF: + msg, ok := v.(proto.Message) + if !ok { + return fmt.Errorf("%s can not implement 'proto.Message'", v) + } + err = proto.Unmarshal(req.Body(), msg) + case consts.MIMEMultipartPOSTForm: + form := make(url.Values) + mf, err := req.MultipartForm() + if err == nil && mf.Value != nil { + for k, v := range mf.Value { + for _, vv := range v { + form.Add(k, vv) + } + } + } + b, _ := stdJson.Marshal(form) + err = hJson.Unmarshal(b, v) + case consts.MIMEApplicationHTMLForm: + form := make(url.Values) + req.PostArgs().VisitAll(func(formKey, value []byte) { + form.Add(string(formKey), string(value)) + }) + b, _ := stdJson.Marshal(form) + err = hJson.Unmarshal(b, v) + default: + // using query to decode + query := make(url.Values) + req.URI().QueryArgs().VisitAll(func(queryKey, value []byte) { + query.Add(string(queryKey), string(value)) + }) + b, _ := stdJson.Marshal(query) + err = hJson.Unmarshal(b, v) + } + return +} + var _ StructValidator = (*defaultValidator)(nil) type defaultValidator struct { diff --git a/pkg/app/server/binding/tagexpr_bind_test.go b/pkg/app/server/binding/tagexpr_bind_test.go index 1e38b69f1..0f24f696a 100644 --- a/pkg/app/server/binding/tagexpr_bind_test.go +++ b/pkg/app/server/binding/tagexpr_bind_test.go @@ -444,8 +444,46 @@ func TestJSON(t *testing.T) { assert.DeepEqual(t, (int64)(6), *recv.Z) } -// unsupported non-struct func TestNonstruct(t *testing.T) { + bodyReader := strings.NewReader(`{ + "X": { + "a": ["a1","a2"], + "B": 21, + "C": [31,32], + "d": 41, + "e": "qps", + "f": 100 + }, + "Z": 6 + }`) + + header := make(http.Header) + header.Set("Content-Type", "application/json") + req := newRequest("", header, nil, bodyReader) + var recv interface{} + err := DefaultBinder().Bind(req.Req, &recv, nil) + if err != nil { + t.Error(err) + } + b, err := json.Marshal(recv) + if err != nil { + t.Error(err) + } + t.Logf("%s", b) + + bodyReader = strings.NewReader("b=334ddddd&token=yoMba34uspjVQEbhflgTRe2ceeDFUK32&type=url_verification") + header.Set("Content-Type", "application/x-www-form-urlencoded; charset=utf-8") + req = newRequest("", header, nil, bodyReader) + recv = nil + err = DefaultBinder().Bind(req.Req, &recv, nil) + if err != nil { + t.Error(err) + } + b, err = json.Marshal(recv) + if err != nil { + t.Error(err) + } + t.Logf("%s", b) } func TestPath(t *testing.T) { From 2ee7e1b354d576b1d4051c40bc748933e0ff26ea Mon Sep 17 00:00:00 2001 From: fgy Date: Mon, 28 Aug 2023 17:30:51 +0800 Subject: [PATCH 59/91] feat: more test coverage --- pkg/app/server/binding/binder_test.go | 129 +++++++++++++++ pkg/app/server/binding/default.go | 225 ++++++++------------------ pkg/app/server/binding/reflect.go | 16 ++ 3 files changed, 213 insertions(+), 157 deletions(-) diff --git a/pkg/app/server/binding/binder_test.go b/pkg/app/server/binding/binder_test.go index 892b615f1..8c4e123c1 100644 --- a/pkg/app/server/binding/binder_test.go +++ b/pkg/app/server/binding/binder_test.go @@ -949,6 +949,135 @@ func TestBind_NonStruct(t *testing.T) { } +func TestBind_BindTag(t *testing.T) { + type Req struct { + Query string + Header string + Path string + Form string + } + req := newMockRequest(). + SetRequestURI("http://foobar.com?Query=query"). + SetHeader("Header", "header"). + SetPostArg("Form", "form") + var params param.Params + params = append(params, param.Param{ + Key: "Path", + Value: "path", + }) + result := Req{} + + // test query tag + err := DefaultBinder().BindQuery(req.Req, &result) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, "query", result.Query) + + // test header tag + result = Req{} + err = DefaultBinder().BindHeader(req.Req, &result) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, "header", result.Header) + + // test form tag + result = Req{} + err = DefaultBinder().BindForm(req.Req, &result) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, "form", result.Form) + + // test path tag + result = Req{} + err = DefaultBinder().BindPath(req.Req, &result, params) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, "path", result.Path) + + // test json tag + req = newMockRequest(). + SetRequestURI("http://foobar.com"). + SetJSONContentType(). + SetBody([]byte("{\n \"Query\": \"query\",\n \"Path\": \"path\",\n \"Header\": \"header\",\n \"Form\": \"form\"\n}")) + result = Req{} + err = DefaultBinder().BindJSON(req.Req, &result) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, "form", result.Form) + assert.DeepEqual(t, "query", result.Query) + assert.DeepEqual(t, "header", result.Header) + assert.DeepEqual(t, "path", result.Path) +} + +func TestBind_BindAndValidate(t *testing.T) { + type Req struct { + ID int `query:"id" vd:"$>10"` + } + req := newMockRequest(). + SetRequestURI("http://foobar.com?id=12") + + // test bindAndValidate + var result Req + err := BindAndValidate(req.Req, &result, nil) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, 12, result.ID) + + // test bind + result = Req{} + err = Bind(req.Req, &result, nil) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, 12, result.ID) + + // test validate + req = newMockRequest(). + SetRequestURI("http://foobar.com?id=9") + result = Req{} + err = Bind(req.Req, &result, nil) + if err != nil { + t.Error(err) + } + err = Validate(result) + if err == nil { + t.Errorf("expect an error, but get nil") + } + assert.DeepEqual(t, 9, result.ID) + +} + +func TestBind_FastPath(t *testing.T) { + type Req struct { + ID int `query:"id"` + } + req := newMockRequest(). + SetRequestURI("http://foobar.com?id=12") + + // test bindAndValidate + var result Req + err := BindAndValidate(req.Req, &result, nil) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, 12, result.ID) + // execute multiple times, test cache + for i := 0; i < 10; i++ { + result = Req{} + err := BindAndValidate(req.Req, &result, nil) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, 12, result.ID) + } +} + func Benchmark_Binding(b *testing.B) { type Req struct { Version string `path:"v"` diff --git a/pkg/app/server/binding/default.go b/pkg/app/server/binding/default.go index ed42019f3..7c29faad7 100644 --- a/pkg/app/server/binding/default.go +++ b/pkg/app/server/binding/default.go @@ -61,6 +61,7 @@ package binding import ( + "bytes" stdJson "encoding/json" "fmt" "io" @@ -79,6 +80,13 @@ import ( "google.golang.org/protobuf/proto" ) +const ( + queryTag = "query" + headerTag = "header" + formTag = "form" + pathTag = "path" +) + type decoderInfo struct { decoder inDecoder.Decoder needValidate bool @@ -117,48 +125,27 @@ func Validate(obj interface{}) error { return DefaultValidator().ValidateStruct(obj) } -func (b *defaultBinder) BindQuery(req *protocol.Request, v interface{}) error { - rv, typeID := valueAndTypeID(v) - if rv.Kind() != reflect.Ptr || rv.IsNil() { - return fmt.Errorf("receiver must be a non-nil pointer") - } - rt := rv.Type() - for rt.Kind() == reflect.Ptr { - rt = rt.Elem() - } - if rt.Kind() != reflect.Struct { - return b.bindNonStruct(req, v) - } - - err := b.preBindBody(req, v) - if err != nil { - return fmt.Errorf("bind body failed, err=%v", err) - } - cached, ok := b.queryDecoderCache.Load(typeID) - if ok { - // cached fieldDecoder, fast path - decoder := cached.(decoderInfo) - return decoder.decoder(req, nil, rv.Elem()) - } - - decoder, needValidate, err := inDecoder.GetReqDecoder(rv.Type(), "query") - if err != nil { - return err +func (b *defaultBinder) tagCache(tag string) *sync.Map { + switch tag { + case queryTag: + return &b.queryDecoderCache + case headerTag: + return &b.headerDecoderCache + case formTag: + return &b.formDecoderCache + case pathTag: + return &b.pathDecoderCache + default: + return &b.decoderCache } - - b.queryDecoderCache.Store(typeID, decoderInfo{decoder: decoder, needValidate: needValidate}) - return decoder(req, nil, rv.Elem()) } -func (b *defaultBinder) BindHeader(req *protocol.Request, v interface{}) error { +func (b *defaultBinder) bindTag(req *protocol.Request, v interface{}, params param.Params, tag string) error { rv, typeID := valueAndTypeID(v) - if rv.Kind() != reflect.Ptr || rv.IsNil() { - return fmt.Errorf("receiver must be a non-nil pointer") - } - rt := rv.Type() - for rt.Kind() == reflect.Ptr { - rt = rt.Elem() + if err := checkPointer(rv); err != nil { + return err } + rt := dereferPointer(rv) if rt.Kind() != reflect.Struct { return b.bindNonStruct(req, v) } @@ -167,31 +154,29 @@ func (b *defaultBinder) BindHeader(req *protocol.Request, v interface{}) error { if err != nil { return fmt.Errorf("bind body failed, err=%v", err) } - cached, ok := b.headerDecoderCache.Load(typeID) + cache := b.tagCache(tag) + cached, ok := cache.Load(typeID) if ok { // cached fieldDecoder, fast path decoder := cached.(decoderInfo) - return decoder.decoder(req, nil, rv.Elem()) + return decoder.decoder(req, params, rv.Elem()) } - decoder, needValidate, err := inDecoder.GetReqDecoder(rv.Type(), "header") + decoder, needValidate, err := inDecoder.GetReqDecoder(rv.Type(), tag) if err != nil { return err } - b.headerDecoderCache.Store(typeID, decoderInfo{decoder: decoder, needValidate: needValidate}) - return decoder(req, nil, rv.Elem()) + cache.Store(typeID, decoderInfo{decoder: decoder, needValidate: needValidate}) + return decoder(req, params, rv.Elem()) } -func (b *defaultBinder) BindPath(req *protocol.Request, v interface{}, params param.Params) error { +func (b *defaultBinder) bindTagWithValidate(req *protocol.Request, v interface{}, params param.Params, tag string) error { rv, typeID := valueAndTypeID(v) - if rv.Kind() != reflect.Ptr || rv.IsNil() { - return fmt.Errorf("receiver must be a non-nil pointer") - } - rt := rv.Type() - for rt.Kind() == reflect.Ptr { - rt = rt.Elem() + if err := checkPointer(rv); err != nil { + return err } + rt := dereferPointer(rv) if rt.Kind() != reflect.Struct { return b.bindNonStruct(req, v) } @@ -200,57 +185,55 @@ func (b *defaultBinder) BindPath(req *protocol.Request, v interface{}, params pa if err != nil { return fmt.Errorf("bind body failed, err=%v", err) } - cached, ok := b.pathDecoderCache.Load(typeID) + cache := b.tagCache(tag) + cached, ok := cache.Load(typeID) if ok { // cached fieldDecoder, fast path decoder := cached.(decoderInfo) - return decoder.decoder(req, params, rv.Elem()) + err = decoder.decoder(req, params, rv.Elem()) + if err != nil { + return err + } + if decoder.needValidate { + err = DefaultValidator().ValidateStruct(rv.Elem()) + } + return err } - decoder, needValidate, err := inDecoder.GetReqDecoder(rv.Type(), "path") + decoder, needValidate, err := inDecoder.GetReqDecoder(rv.Type(), tag) if err != nil { return err } - b.pathDecoderCache.Store(typeID, decoderInfo{decoder: decoder, needValidate: needValidate}) - return decoder(req, params, rv.Elem()) -} - -func (b *defaultBinder) BindForm(req *protocol.Request, v interface{}) error { - rv, typeID := valueAndTypeID(v) - if rv.Kind() != reflect.Ptr || rv.IsNil() { - return fmt.Errorf("receiver must be a non-nil pointer") - } - rt := rv.Type() - for rt.Kind() == reflect.Ptr { - rt = rt.Elem() - } - if rt.Kind() != reflect.Struct { - return b.bindNonStruct(req, v) - } - - err := b.preBindBody(req, v) + cache.Store(typeID, decoderInfo{decoder: decoder, needValidate: needValidate}) + err = decoder(req, params, rv.Elem()) if err != nil { - return fmt.Errorf("bind body failed, err=%v", err) + return err } - cached, ok := b.formDecoderCache.Load(typeID) - if ok { - // cached fieldDecoder, fast path - decoder := cached.(decoderInfo) - return decoder.decoder(req, nil, rv.Elem()) + if needValidate { + err = DefaultValidator().ValidateStruct(rv.Elem()) } + return err +} - decoder, needValidate, err := inDecoder.GetReqDecoder(rv.Type(), "form") - if err != nil { - return err - } +func (b *defaultBinder) BindQuery(req *protocol.Request, v interface{}) error { + return b.bindTag(req, v, nil, queryTag) +} - b.formDecoderCache.Store(typeID, decoderInfo{decoder: decoder, needValidate: needValidate}) - return decoder(req, nil, rv.Elem()) +func (b *defaultBinder) BindHeader(req *protocol.Request, v interface{}) error { + return b.bindTag(req, v, nil, headerTag) +} + +func (b *defaultBinder) BindPath(req *protocol.Request, v interface{}, params param.Params) error { + return b.bindTag(req, v, params, pathTag) +} + +func (b *defaultBinder) BindForm(req *protocol.Request, v interface{}) error { + return b.bindTag(req, v, nil, formTag) } func (b *defaultBinder) BindJSON(req *protocol.Request, v interface{}) error { - return decodeJSON(req.BodyStream(), v) + return decodeJSON(bytes.NewReader(req.Body()), v) } func decodeJSON(r io.Reader, obj interface{}) error { @@ -278,83 +261,11 @@ func (b *defaultBinder) Name() string { } func (b *defaultBinder) BindAndValidate(req *protocol.Request, v interface{}, params param.Params) error { - rv, typeID := valueAndTypeID(v) - if rv.Kind() != reflect.Ptr || rv.IsNil() { - return fmt.Errorf("receiver must be a non-nil pointer") - } - rt := rv.Type() - for rt.Kind() == reflect.Ptr { - rt = rt.Elem() - } - if rt.Kind() != reflect.Struct { - return b.bindNonStruct(req, v) - } - - err := b.preBindBody(req, v) - if err != nil { - return fmt.Errorf("bind body failed, err=%v", err) - } - cached, ok := b.decoderCache.Load(typeID) - if ok { - // cached fieldDecoder, fast path - decoder := cached.(decoderInfo) - err = decoder.decoder(req, params, rv.Elem()) - if err != nil { - return err - } - if decoder.needValidate { - err = DefaultValidator().ValidateStruct(rv.Elem()) - } - return err - } - - decoder, needValidate, err := inDecoder.GetReqDecoder(rv.Type(), "") - if err != nil { - return err - } - - b.decoderCache.Store(typeID, decoderInfo{decoder: decoder, needValidate: needValidate}) - err = decoder(req, params, rv.Elem()) - if err != nil { - return err - } - if needValidate { - err = DefaultValidator().ValidateStruct(rv.Elem()) - } - return err + return b.bindTagWithValidate(req, v, params, "") } func (b *defaultBinder) Bind(req *protocol.Request, v interface{}, params param.Params) error { - rv, typeID := valueAndTypeID(v) - if rv.Kind() != reflect.Ptr || rv.IsNil() { - return fmt.Errorf("receiver must be a non-nil pointer") - } - rt := rv.Type() - for rt.Kind() == reflect.Ptr { - rt = rt.Elem() - } - if rt.Kind() != reflect.Struct { - return b.bindNonStruct(req, v) - } - - err := b.preBindBody(req, v) - if err != nil { - return fmt.Errorf("bind body failed, err=%v", err) - } - cached, ok := b.decoderCache.Load(typeID) - if ok { - // cached fieldDecoder, fast path - decoder := cached.(decoderInfo) - return decoder.decoder(req, params, rv.Elem()) - } - - decoder, needValidate, err := inDecoder.GetReqDecoder(rv.Type(), "") - if err != nil { - return err - } - - b.decoderCache.Store(typeID, decoderInfo{decoder: decoder, needValidate: needValidate}) - return decoder(req, params, rv.Elem()) + return b.bindTag(req, v, params, "") } // best effort binding diff --git a/pkg/app/server/binding/reflect.go b/pkg/app/server/binding/reflect.go index 0b3be8f6b..502de11d2 100644 --- a/pkg/app/server/binding/reflect.go +++ b/pkg/app/server/binding/reflect.go @@ -41,6 +41,7 @@ package binding import ( + "fmt" "reflect" "unsafe" ) @@ -55,3 +56,18 @@ type emptyInterface struct { typeID uintptr dataPtr unsafe.Pointer } + +func checkPointer(rv reflect.Value) error { + if rv.Kind() != reflect.Ptr || rv.IsNil() { + return fmt.Errorf("receiver must be a non-nil pointer") + } + return nil +} + +func dereferPointer(rv reflect.Value) reflect.Type { + rt := rv.Type() + for rt.Kind() == reflect.Ptr { + rt = rt.Elem() + } + return rt +} From 4f80efc1ca10602bc20a1fa0d2c1334966b431fb Mon Sep 17 00:00:00 2001 From: fgy Date: Mon, 28 Aug 2023 19:07:50 +0800 Subject: [PATCH 60/91] feat: more test coverage --- pkg/app/server/binding/binder_test.go | 82 +++++++++++- pkg/app/server/binding/testdata/hello.pb.go | 141 ++++++++++++++++++++ pkg/app/server/binding/testdata/hello.proto | 8 ++ 3 files changed, 230 insertions(+), 1 deletion(-) create mode 100644 pkg/app/server/binding/testdata/hello.pb.go create mode 100644 pkg/app/server/binding/testdata/hello.proto diff --git a/pkg/app/server/binding/binder_test.go b/pkg/app/server/binding/binder_test.go index 8c4e123c1..2076144c4 100644 --- a/pkg/app/server/binding/binder_test.go +++ b/pkg/app/server/binding/binder_test.go @@ -42,6 +42,8 @@ package binding import ( "fmt" + "github.com/cloudwego/hertz/pkg/app/server/binding/testdata" + "google.golang.org/protobuf/proto" "mime/multipart" "reflect" "testing" @@ -102,6 +104,11 @@ func (m *mockRequest) SetJSONContentType() *mockRequest { return m } +func (m *mockRequest) SetProtobufContentType() *mockRequest { + m.Req.Header.SetContentTypeBytes([]byte(consts.MIMEPROTOBUF)) + return m +} + func (m *mockRequest) SetBody(data []byte) *mockRequest { m.Req.SetBody(data) m.Req.Header.SetContentLength(len(data)) @@ -947,6 +954,10 @@ func TestBind_NonStruct(t *testing.T) { t.Error(err) } + err = DefaultBinder().BindAndValidate(req.Req, &id, nil) + if err != nil { + t.Error(err) + } } func TestBind_BindTag(t *testing.T) { @@ -1055,7 +1066,7 @@ func TestBind_BindAndValidate(t *testing.T) { func TestBind_FastPath(t *testing.T) { type Req struct { - ID int `query:"id"` + ID int `query:"id" vd:"$>10"` } req := newMockRequest(). SetRequestURI("http://foobar.com?id=12") @@ -1078,6 +1089,75 @@ func TestBind_FastPath(t *testing.T) { } } +func TestBind_NonPointer(t *testing.T) { + type Req struct { + ID int `query:"id" vd:"$>10"` + } + req := newMockRequest(). + SetRequestURI("http://foobar.com?id=12") + + // test bindAndValidate + var result Req + err := BindAndValidate(req.Req, result, nil) + if err == nil { + t.Error("expect an error, but get nil") + } + + err = Bind(req.Req, result, nil) + if err == nil { + t.Error("expect an error, but get nil") + } +} + +func TestBind_PreBind(t *testing.T) { + type Req struct { + Query string + Header string + Path string + Form string + } + // test json tag + req := newMockRequest(). + SetRequestURI("http://foobar.com"). + SetJSONContentType(). + SetBody([]byte("\n \"Query\": \"query\",\n \"Path\": \"path\",\n \"Header\": \"header\",\n \"Form\": \"form\"\n}")) + result := Req{} + err := DefaultBinder().Bind(req.Req, &result, nil) + if err == nil { + t.Error("expect an error, but get nil") + } + err = DefaultBinder().BindAndValidate(req.Req, &result, nil) + if err == nil { + t.Error("expect an error, but get nil") + } +} + +func TestBind_BindProtobuf(t *testing.T) { + data := testdata.HertzReq{Name: "hertz"} + body, err := proto.Marshal(&data) + if err != nil { + t.Fatal(err) + } + req := newMockRequest(). + SetRequestURI("http://foobar.com"). + SetProtobufContentType(). + SetBody(body) + + result := testdata.HertzReq{} + err = DefaultBinder().BindAndValidate(req.Req, &result, nil) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, "hertz", result.Name) + + result = testdata.HertzReq{} + err = DefaultBinder().BindProtobuf(req.Req, &result) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, "hertz", result.Name) +} + func Benchmark_Binding(b *testing.B) { type Req struct { Version string `path:"v"` diff --git a/pkg/app/server/binding/testdata/hello.pb.go b/pkg/app/server/binding/testdata/hello.pb.go new file mode 100644 index 000000000..c316985c5 --- /dev/null +++ b/pkg/app/server/binding/testdata/hello.pb.go @@ -0,0 +1,141 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.30.0 +// protoc v3.21.12 +// source: hello.proto + +package testdata + +import ( + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +type HertzReq struct { + state protoimpl.MessageState + sizeCache protoimpl.SizeCache + unknownFields protoimpl.UnknownFields + + Name string `protobuf:"bytes,1,opt,name=Name,proto3" json:"Name,omitempty"` +} + +func (x *HertzReq) Reset() { + *x = HertzReq{} + if protoimpl.UnsafeEnabled { + mi := &file_hello_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) + } +} + +func (x *HertzReq) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*HertzReq) ProtoMessage() {} + +func (x *HertzReq) ProtoReflect() protoreflect.Message { + mi := &file_hello_proto_msgTypes[0] + if protoimpl.UnsafeEnabled && x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use HertzReq.ProtoReflect.Descriptor instead. +func (*HertzReq) Descriptor() ([]byte, []int) { + return file_hello_proto_rawDescGZIP(), []int{0} +} + +func (x *HertzReq) GetName() string { + if x != nil { + return x.Name + } + return "" +} + +var File_hello_proto protoreflect.FileDescriptor + +var file_hello_proto_rawDesc = []byte{ + 0x0a, 0x0b, 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0x2e, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x12, 0x05, 0x68, + 0x65, 0x72, 0x74, 0x7a, 0x22, 0x1e, 0x0a, 0x08, 0x48, 0x65, 0x72, 0x74, 0x7a, 0x52, 0x65, 0x71, + 0x12, 0x12, 0x0a, 0x04, 0x4e, 0x61, 0x6d, 0x65, 0x18, 0x01, 0x20, 0x01, 0x28, 0x09, 0x52, 0x04, + 0x4e, 0x61, 0x6d, 0x65, 0x42, 0x0d, 0x5a, 0x0b, 0x68, 0x65, 0x72, 0x74, 0x7a, 0x2f, 0x68, 0x65, + 0x6c, 0x6c, 0x6f, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, +} + +var ( + file_hello_proto_rawDescOnce sync.Once + file_hello_proto_rawDescData = file_hello_proto_rawDesc +) + +func file_hello_proto_rawDescGZIP() []byte { + file_hello_proto_rawDescOnce.Do(func() { + file_hello_proto_rawDescData = protoimpl.X.CompressGZIP(file_hello_proto_rawDescData) + }) + return file_hello_proto_rawDescData +} + +var file_hello_proto_msgTypes = make([]protoimpl.MessageInfo, 1) +var file_hello_proto_goTypes = []interface{}{ + (*HertzReq)(nil), // 0: hertz.HertzReq +} +var file_hello_proto_depIdxs = []int32{ + 0, // [0:0] is the sub-list for method output_type + 0, // [0:0] is the sub-list for method input_type + 0, // [0:0] is the sub-list for extension type_name + 0, // [0:0] is the sub-list for extension extendee + 0, // [0:0] is the sub-list for field type_name +} + +func init() { file_hello_proto_init() } +func file_hello_proto_init() { + if File_hello_proto != nil { + return + } + if !protoimpl.UnsafeEnabled { + file_hello_proto_msgTypes[0].Exporter = func(v interface{}, i int) interface{} { + switch v := v.(*HertzReq); i { + case 0: + return &v.state + case 1: + return &v.sizeCache + case 2: + return &v.unknownFields + default: + return nil + } + } + } + type x struct{} + out := protoimpl.TypeBuilder{ + File: protoimpl.DescBuilder{ + GoPackagePath: reflect.TypeOf(x{}).PkgPath(), + RawDescriptor: file_hello_proto_rawDesc, + NumEnums: 0, + NumMessages: 1, + NumExtensions: 0, + NumServices: 0, + }, + GoTypes: file_hello_proto_goTypes, + DependencyIndexes: file_hello_proto_depIdxs, + MessageInfos: file_hello_proto_msgTypes, + }.Build() + File_hello_proto = out.File + file_hello_proto_rawDesc = nil + file_hello_proto_goTypes = nil + file_hello_proto_depIdxs = nil +} diff --git a/pkg/app/server/binding/testdata/hello.proto b/pkg/app/server/binding/testdata/hello.proto new file mode 100644 index 000000000..1ec7f83f5 --- /dev/null +++ b/pkg/app/server/binding/testdata/hello.proto @@ -0,0 +1,8 @@ +syntax = "proto3"; +package hertz; +option go_package = "hertz/hello"; + +message HertzReq { + string Name = 1; +} + From 40a7b05cb949c4dbc302056c36f9e0f5b78c0f92 Mon Sep 17 00:00:00 2001 From: fgy Date: Mon, 28 Aug 2023 19:31:06 +0800 Subject: [PATCH 61/91] feat: struct decode error to warn --- .../server/binding/internal/decoder/struct_type_decoder.go | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pkg/app/server/binding/internal/decoder/struct_type_decoder.go b/pkg/app/server/binding/internal/decoder/struct_type_decoder.go index ce0a4b237..70fa12080 100644 --- a/pkg/app/server/binding/internal/decoder/struct_type_decoder.go +++ b/pkg/app/server/binding/internal/decoder/struct_type_decoder.go @@ -18,6 +18,7 @@ package decoder import ( "fmt" + "github.com/cloudwego/hertz/pkg/common/hlog" "reflect" "github.com/cloudwego/hertz/internal/bytesconv" @@ -83,7 +84,8 @@ func (d *structTypeFieldTextDecoder) Decode(req *protocol.Request, params param. var vv reflect.Value vv, err := stringToValue(t, text, req, params) if err != nil { - return fmt.Errorf("unable to decode '%s' as %s: %w", text, d.fieldType.Name(), err) + hlog.Warnf("unable to decode '%s' as %s: %v, but it may not affect correctness, so skip it", text, d.fieldType.Name(), err) + return nil } field.Set(ReferenceValue(vv, ptrDepth)) return nil @@ -91,7 +93,7 @@ func (d *structTypeFieldTextDecoder) Decode(req *protocol.Request, params param. err = hjson.Unmarshal(bytesconv.S2b(text), field.Addr().Interface()) if err != nil { - return fmt.Errorf("unable to decode '%s' as %s: %w", text, d.fieldType.Name(), err) + hlog.Warnf("unable to decode '%s' as %s: %v, but it may not affect correctness, so skip it", text, d.fieldType.Name(), err) } return nil From 239833077b652f863ce5607cc804f904a32ec51c Mon Sep 17 00:00:00 2001 From: fgy Date: Mon, 28 Aug 2023 20:09:23 +0800 Subject: [PATCH 62/91] feat: more test coverage --- pkg/app/server/binding/binder_test.go | 75 +++++++++++++++++++++++++++ 1 file changed, 75 insertions(+) diff --git a/pkg/app/server/binding/binder_test.go b/pkg/app/server/binding/binder_test.go index 2076144c4..e42843fd8 100644 --- a/pkg/app/server/binding/binder_test.go +++ b/pkg/app/server/binding/binder_test.go @@ -45,6 +45,7 @@ import ( "github.com/cloudwego/hertz/pkg/app/server/binding/testdata" "google.golang.org/protobuf/proto" "mime/multipart" + "net/url" "reflect" "testing" @@ -1158,6 +1159,80 @@ func TestBind_BindProtobuf(t *testing.T) { assert.DeepEqual(t, "hertz", result.Name) } +func TestBind_PointerStruct(t *testing.T) { + EnableStructFieldResolve(true) + defer EnableStructFieldResolve(false) + type Foo struct { + F1 string `query:"F1"` + } + type Bar struct { + B1 **Foo `query:"B1,required"` + } + query := make(url.Values) + query.Add("B1", "{\n \"F1\": \"111\"\n}") + + var result Bar + req := newMockRequest(). + SetRequestURI(fmt.Sprintf("http://foobar.com?%s", query.Encode())) + + err := DefaultBinder().Bind(req.Req, &result, nil) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, "111", (**result.B1).F1) + + result = Bar{} + req = newMockRequest(). + SetRequestURI(fmt.Sprintf("http://foobar.com?%s&F1=222", query.Encode())) + err = DefaultBinder().Bind(req.Req, &result, nil) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, "222", (**result.B1).F1) +} + +func TestBind_StructRequired(t *testing.T) { + EnableStructFieldResolve(true) + defer EnableStructFieldResolve(false) + type Foo struct { + F1 string `query:"F1"` + } + type Bar struct { + B1 **Foo `query:"B1,required"` + } + + var result Bar + req := newMockRequest(). + SetRequestURI(fmt.Sprintf("http://foobar.com")) + + err := DefaultBinder().Bind(req.Req, &result, nil) + if err == nil { + t.Error("expect an error, but get nil") + } +} + +func TestBind_StructErrorToWarn(t *testing.T) { + EnableStructFieldResolve(true) + defer EnableStructFieldResolve(false) + type Foo struct { + F1 string `query:"F1"` + } + type Bar struct { + B1 **Foo `query:"B1,required"` + } + + var result Bar + req := newMockRequest(). + SetRequestURI(fmt.Sprintf("http://foobar.com?B1=111&F1=222")) + + err := DefaultBinder().Bind(req.Req, &result, nil) + // transfer 'unmarsahl err' to 'warn' + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, "222", (**result.B1).F1) +} + func Benchmark_Binding(b *testing.B) { type Req struct { Version string `path:"v"` From 1d23b7823a10abbc8b79bb14d18e6ff505b98c4d Mon Sep 17 00:00:00 2001 From: fgy Date: Mon, 28 Aug 2023 21:23:24 +0800 Subject: [PATCH 63/91] feat: support interface --- pkg/app/server/binding/binder_test.go | 97 ++++++++++++++++++- .../binding/internal/decoder/text_decoder.go | 16 +++ 2 files changed, 108 insertions(+), 5 deletions(-) diff --git a/pkg/app/server/binding/binder_test.go b/pkg/app/server/binding/binder_test.go index e42843fd8..969d21a52 100644 --- a/pkg/app/server/binding/binder_test.go +++ b/pkg/app/server/binding/binder_test.go @@ -41,6 +41,7 @@ package binding import ( + "encoding/json" "fmt" "github.com/cloudwego/hertz/pkg/app/server/binding/testdata" "google.golang.org/protobuf/proto" @@ -338,6 +339,24 @@ func TestBind_MapFieldType(t *testing.T) { } assert.DeepEqual(t, 1, len(***result.F1)) assert.DeepEqual(t, "f1", (***result.F1)["f1"]) + + type Foo2 struct { + F1 map[string]string `query:"f1" json:"f1"` + } + result2 := Foo2{} + err = DefaultBinder().Bind(req.Req, &result2, nil) + if err != nil { + t.Fatal(err) + } + assert.DeepEqual(t, 1, len(result2.F1)) + assert.DeepEqual(t, "f1", result2.F1["f1"]) + req = newMockRequest(). + SetRequestURI("http://foobar.com?f1={\"f1\":\"f1\"") + result2 = Foo2{} + err = DefaultBinder().Bind(req.Req, &result2, nil) + if err == nil { + t.Error(err) + } } func TestBind_UnexportedField(t *testing.T) { @@ -681,14 +700,14 @@ func TestBind_FileSliceBind(t *testing.T) { func TestBind_AnonymousField(t *testing.T) { type nest struct { - n1 string `query:"n1"` // bind default value - N2 ***string `query:"n2"` // bind n2 value - string `query:"n3"` // bind default value + n1 string `query:"n1"` // bind default value + N2 ***string `query:"n2"` // bind n2 value + string `query:"n3"` // bind default value } var s struct { - s1 int `query:"s1"` // bind default value - int `query:"s2"` // bind default value + s1 int `query:"s1"` // bind default value + int `query:"s2"` // bind default value nest } req := newMockRequest(). @@ -1209,6 +1228,18 @@ func TestBind_StructRequired(t *testing.T) { if err == nil { t.Error("expect an error, but get nil") } + + type Bar2 struct { + B1 **Foo `query:"B1"` + } + var result2 Bar2 + req = newMockRequest(). + SetRequestURI(fmt.Sprintf("http://foobar.com")) + + err = DefaultBinder().Bind(req.Req, &result2, nil) + if err != nil { + t.Error(err) + } } func TestBind_StructErrorToWarn(t *testing.T) { @@ -1231,6 +1262,62 @@ func TestBind_StructErrorToWarn(t *testing.T) { t.Error(err) } assert.DeepEqual(t, "222", (**result.B1).F1) + + type Bar2 struct { + B1 Foo `query:"B1,required"` + } + var result2 Bar2 + err = DefaultBinder().Bind(req.Req, &result2, nil) + // transfer 'unmarsahl err' to 'warn' + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, "222", result2.B1.F1) +} + +func TestBind_DisallowUnknownFieldsConfig(t *testing.T) { + EnableDecoderDisallowUnknownFields(true) + defer EnableDecoderDisallowUnknownFields(false) + type FooStructUseNumber struct { + Foo interface{} `json:"foo"` + } + req := newMockRequest(). + SetRequestURI("http://foobar.com"). + SetJSONContentType(). + SetBody([]byte(`{"foo": 123,"bar": "456"}`)) + var result FooStructUseNumber + + err := DefaultBinder().BindJSON(req.Req, &result) + if err == nil { + t.Errorf("expected an error, but get nil") + } +} + +func TestBind_UseNumberConfig(t *testing.T) { + EnableDecoderUseNumber(true) + defer EnableDecoderUseNumber(false) + type FooStructUseNumber struct { + Foo interface{} `json:"foo"` + } + req := newMockRequest(). + SetRequestURI("http://foobar.com"). + SetJSONContentType(). + SetBody([]byte(`{"foo": 123}`)) + var result FooStructUseNumber + + err := DefaultBinder().BindJSON(req.Req, &result) + if err != nil { + t.Error(err) + } + v, err := result.Foo.(json.Number).Int64() + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, int64(123), v) +} + +func TestBind_InterfaceType(t *testing.T) { + } func Benchmark_Binding(b *testing.B) { diff --git a/pkg/app/server/binding/internal/decoder/text_decoder.go b/pkg/app/server/binding/internal/decoder/text_decoder.go index dc06647fc..45b87332b 100644 --- a/pkg/app/server/binding/internal/decoder/text_decoder.go +++ b/pkg/app/server/binding/internal/decoder/text_decoder.go @@ -44,6 +44,9 @@ import ( "fmt" "reflect" "strconv" + + "github.com/cloudwego/hertz/internal/bytesconv" + hJson "github.com/cloudwego/hertz/pkg/common/json" ) var looseZeroMode = false @@ -92,6 +95,9 @@ func SelectTextDecoder(rt reflect.Type) (TextDecoder, error) { return &floatDecoder{bitSize: 32}, nil case reflect.Float64: return &floatDecoder{bitSize: 64}, nil + case reflect.Interface: + return &interfaceDecoder{}, nil + } return nil, fmt.Errorf("unsupported type " + rt.String()) @@ -165,3 +171,13 @@ func (d *uintDecoder) UnmarshalString(s string, fieldValue reflect.Value) error fieldValue.SetUint(v) return nil } + +type interfaceDecoder struct { +} + +func (d *interfaceDecoder) UnmarshalString(s string, fieldValue reflect.Value) error { + if s == "" && looseZeroMode { + s = "0" + } + return hJson.Unmarshal(bytesconv.S2b(s), fieldValue.Addr().Interface()) +} From cd687f8a9cc96f12746d3fd065cfbae06448e3d5 Mon Sep 17 00:00:00 2001 From: fgy Date: Mon, 28 Aug 2023 21:43:25 +0800 Subject: [PATCH 64/91] feat: more test coverage --- pkg/app/server/binding/binder_test.go | 44 +++++++++++++++++++++--- pkg/app/server/binding/validator_test.go | 30 ++++++++++++++++ 2 files changed, 69 insertions(+), 5 deletions(-) diff --git a/pkg/app/server/binding/binder_test.go b/pkg/app/server/binding/binder_test.go index 969d21a52..ab9383582 100644 --- a/pkg/app/server/binding/binder_test.go +++ b/pkg/app/server/binding/binder_test.go @@ -565,6 +565,18 @@ func TestBind_CustomizedTypeDecode(t *testing.T) { assert.DeepEqual(t, "1", (***(*result2.B).F).A) } +func TestBind_CustomizedTypeDecodeForPanic(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Errorf("expect a panic, but get nil") + } + }() + + MustRegTypeUnmarshal(reflect.TypeOf(string("")), func(req *protocol.Request, params param.Params, text string) (reflect.Value, error) { + return reflect.Value{}, nil + }) +} + func TestBind_JSON(t *testing.T) { type Req struct { J1 string `json:"j1"` @@ -700,14 +712,14 @@ func TestBind_FileSliceBind(t *testing.T) { func TestBind_AnonymousField(t *testing.T) { type nest struct { - n1 string `query:"n1"` // bind default value - N2 ***string `query:"n2"` // bind n2 value - string `query:"n3"` // bind default value + n1 string `query:"n1"` // bind default value + N2 ***string `query:"n2"` // bind n2 value + string `query:"n3"` // bind default value } var s struct { - s1 int `query:"s1"` // bind default value - int `query:"s2"` // bind default value + s1 int `query:"s1"` // bind default value + int `query:"s2"` // bind default value nest } req := newMockRequest(). @@ -1317,7 +1329,29 @@ func TestBind_UseNumberConfig(t *testing.T) { } func TestBind_InterfaceType(t *testing.T) { + type Bar struct { + B1 interface{} `query:"B1"` + } + + var result Bar + query := make(url.Values) + query.Add("B1", `{"B1":"111"}`) + req := newMockRequest(). + SetRequestURI(fmt.Sprintf("http://foobar.com?%s", query.Encode())) + err := DefaultBinder().Bind(req.Req, &result, nil) + if err != nil { + t.Error(err) + } + type Bar2 struct { + B2 *interface{} `query:"B1"` + } + + var result2 Bar2 + err = DefaultBinder().Bind(req.Req, &result2, nil) + if err != nil { + t.Error(err) + } } func Benchmark_Binding(b *testing.B) { diff --git a/pkg/app/server/binding/validator_test.go b/pkg/app/server/binding/validator_test.go index 2f85716b5..b02222bc6 100644 --- a/pkg/app/server/binding/validator_test.go +++ b/pkg/app/server/binding/validator_test.go @@ -17,6 +17,8 @@ package binding import ( + "fmt" + "github.com/cloudwego/hertz/pkg/common/test/assert" "testing" ) @@ -33,3 +35,31 @@ func Test_ValidateStruct(t *testing.T) { t.Fatalf("expected an error, but got nil") } } + +type mockValidator struct{} + +func (m *mockValidator) ValidateStruct(interface{}) error { + return fmt.Errorf("test mock") + +} + +func (m *mockValidator) Engine() interface{} { + return nil +} + +func Test_ResetValidatorConfig(t *testing.T) { + m := &mockValidator{} + ResetValidator(m, "vt") + type User struct { + Age int `vt:"$>=0&&$<=130"` + } + + user := &User{ + Age: 135, + } + err := DefaultValidator().ValidateStruct(user) + if err == nil { + t.Fatalf("expected an error, but got nil") + } + assert.DeepEqual(t, "test mock", err.Error()) +} From 5269fceaa711b777d06bc9869f3e0c2925249449 Mon Sep 17 00:00:00 2001 From: fgy Date: Tue, 29 Aug 2023 17:45:13 +0800 Subject: [PATCH 65/91] feat: more test coverage --- .../binding/internal/decoder/reflect.go | 5 ++ .../binding/internal/decoder/reflect_test.go | 73 +++++++++++++++++++ pkg/app/server/binding/reflect_test.go | 71 ++++++++++++++++++ 3 files changed, 149 insertions(+) create mode 100644 pkg/app/server/binding/internal/decoder/reflect_test.go create mode 100644 pkg/app/server/binding/reflect_test.go diff --git a/pkg/app/server/binding/internal/decoder/reflect.go b/pkg/app/server/binding/internal/decoder/reflect.go index d69c4b780..1f0b9996e 100644 --- a/pkg/app/server/binding/internal/decoder/reflect.go +++ b/pkg/app/server/binding/internal/decoder/reflect.go @@ -75,6 +75,11 @@ func GetNonNilReferenceValue(v reflect.Value) (reflect.Value, int) { } func GetFieldValue(reqValue reflect.Value, parentIndex []int) reflect.Value { + // reqVaule -> (***bar)(nil) need new a default + if reqValue.Kind() == reflect.Ptr && reqValue.IsNil() { + nonNilVal, ptrDepth := GetNonNilReferenceValue(reqValue) + reqValue = ReferenceValue(nonNilVal, ptrDepth) + } for _, idx := range parentIndex { if reqValue.Kind() == reflect.Ptr && reqValue.IsNil() { nonNilVal, ptrDepth := GetNonNilReferenceValue(reqValue) diff --git a/pkg/app/server/binding/internal/decoder/reflect_test.go b/pkg/app/server/binding/internal/decoder/reflect_test.go new file mode 100644 index 000000000..c33f82f14 --- /dev/null +++ b/pkg/app/server/binding/internal/decoder/reflect_test.go @@ -0,0 +1,73 @@ +package decoder + +import ( + "github.com/cloudwego/hertz/pkg/common/test/assert" + "reflect" + "testing" +) + +type foo struct { + F1 string +} + +type fooq struct { + F1 **string +} + +func Test_ReferenceValue(t *testing.T) { + foo1 := foo{F1: "f1"} + foo1Val := reflect.ValueOf(foo1) + foo1PointerVal := ReferenceValue(foo1Val, 5) + assert.DeepEqual(t, "f1", foo1.F1) + assert.DeepEqual(t, "f1", foo1Val.Field(0).Interface().(string)) + if foo1PointerVal.Kind() != reflect.Ptr { + t.Errorf("expect a pointer, but get nil") + } + assert.DeepEqual(t, "*****decoder.foo", foo1PointerVal.Type().String()) + + deFoo1PointerVal := ReferenceValue(foo1PointerVal, -5) + if deFoo1PointerVal.Kind() == reflect.Ptr { + t.Errorf("expect a non-pointer, but get a pointer") + } + assert.DeepEqual(t, "f1", deFoo1PointerVal.Field(0).Interface().(string)) +} + +func Test_GetNonNilReferenceValue(t *testing.T) { + foo1 := (****foo)(nil) + foo1Val := reflect.ValueOf(foo1) + foo1ValNonNil, ptrDepth := GetNonNilReferenceValue(foo1Val) + if !foo1ValNonNil.IsValid() { + t.Errorf("expect a valid value, but get nil") + } + if !foo1ValNonNil.CanSet() { + t.Errorf("expect can set value, but not") + } + + foo1ReferPointer := ReferenceValue(foo1ValNonNil, ptrDepth) + if foo1ReferPointer.Kind() != reflect.Ptr { + t.Errorf("expect a pointer, but get nil") + } +} + +func Test_GetFieldValue(t *testing.T) { + type bar struct { + B1 **fooq + } + bar1 := (***bar)(nil) + parentIdx := []int{0} + idx := 0 + + bar1Val := reflect.ValueOf(bar1) + parentFieldVal := GetFieldValue(bar1Val, parentIdx) + if parentFieldVal.Kind() == reflect.Ptr { + t.Errorf("expect a non-pointer, but get a pointer") + } + if !parentFieldVal.CanSet() { + t.Errorf("expect can set value, but not") + } + fooFieldVal := parentFieldVal.Field(idx) + assert.DeepEqual(t, "**string", fooFieldVal.Type().String()) + if !fooFieldVal.CanSet() { + t.Errorf("expect can set value, but not") + } +} diff --git a/pkg/app/server/binding/reflect_test.go b/pkg/app/server/binding/reflect_test.go new file mode 100644 index 000000000..c06df7df9 --- /dev/null +++ b/pkg/app/server/binding/reflect_test.go @@ -0,0 +1,71 @@ +package binding + +import ( + "reflect" + "testing" + + "github.com/cloudwego/hertz/pkg/common/test/assert" +) + +type foo struct { + f1 string +} + +func TestReflect_TypeID(t *testing.T) { + _, intType := valueAndTypeID(int(1)) + _, uintType := valueAndTypeID(uint(1)) + _, shouldBeIntType := valueAndTypeID(int(1)) + assert.DeepEqual(t, intType, shouldBeIntType) + assert.NotEqual(t, intType, uintType) + + foo1 := foo{f1: "1"} + foo2 := foo{f1: "2"} + _, foo1Type := valueAndTypeID(foo1) + _, foo2Type := valueAndTypeID(foo2) + _, foo2PointerType := valueAndTypeID(&foo2) + _, foo1PointerType := valueAndTypeID(&foo1) + assert.DeepEqual(t, foo1Type, foo2Type) + assert.NotEqual(t, foo1Type, foo2PointerType) + assert.DeepEqual(t, foo1PointerType, foo2PointerType) +} + +func TestReflect_CheckPointer(t *testing.T) { + foo1 := foo{} + foo1Val := reflect.ValueOf(foo1) + err := checkPointer(foo1Val) + if err == nil { + t.Errorf("expect an err, but get nil") + } + + foo2 := &foo{} + foo2Val := reflect.ValueOf(foo2) + err = checkPointer(foo2Val) + if err != nil { + t.Error(err) + } + + foo3 := (*foo)(nil) + foo3Val := reflect.ValueOf(foo3) + err = checkPointer(foo3Val) + if err == nil { + t.Errorf("expect an err, but get nil") + } +} + +func TestReflect_DereferPointer(t *testing.T) { + var foo1 ****foo + foo1Val := reflect.ValueOf(foo1) + rt := dereferPointer(foo1Val) + if rt.Kind() == reflect.Ptr { + t.Errorf("expect non-pointer type, but get pointer") + } + assert.DeepEqual(t, "foo", rt.Name()) + + var foo2 foo + foo2Val := reflect.ValueOf(foo2) + rt2 := dereferPointer(foo2Val) + if rt2.Kind() == reflect.Ptr { + t.Errorf("expect non-pointer type, but get pointer") + } + assert.DeepEqual(t, "foo", rt2.Name()) +} From d2ad53332b15aa1fc3b30eef71576e94729c4d2b Mon Sep 17 00:00:00 2001 From: fgy Date: Tue, 12 Sep 2023 17:40:26 +0800 Subject: [PATCH 66/91] refactor: rm old test file --- pkg/app/server/binding/binding_test.go | 0 pkg/app/server/binding/request_test.go | 235 ------------------------- 2 files changed, 235 deletions(-) delete mode 100644 pkg/app/server/binding/binding_test.go delete mode 100644 pkg/app/server/binding/request_test.go diff --git a/pkg/app/server/binding/binding_test.go b/pkg/app/server/binding/binding_test.go deleted file mode 100644 index e69de29bb..000000000 diff --git a/pkg/app/server/binding/request_test.go b/pkg/app/server/binding/request_test.go deleted file mode 100644 index b3bb70523..000000000 --- a/pkg/app/server/binding/request_test.go +++ /dev/null @@ -1,235 +0,0 @@ -/* - * Copyright 2022 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package binding - -import ( - "bytes" - "fmt" - "testing" - - "github.com/cloudwego/hertz/pkg/common/test/assert" - "github.com/cloudwego/hertz/pkg/protocol" - "github.com/cloudwego/hertz/pkg/protocol/consts" -) - -func TestGetQuery(t *testing.T) { - r := protocol.NewRequest("GET", "/foo", nil) - r.SetRequestURI("/foo/bar?para1=hertz¶2=query1¶2=query2¶3=1¶3=2") - - bindReq := bindRequest{ - req: r, - } - - values := bindReq.GetQuery() - - assert.DeepEqual(t, []string{"hertz"}, values["para1"]) - assert.DeepEqual(t, []string{"query1", "query2"}, values["para2"]) - assert.DeepEqual(t, []string{"1", "2"}, values["para3"]) -} - -func TestGetPostForm(t *testing.T) { - data := "a=aaa&b=b1&b=b2&c=ccc&d=100" - mr := bytes.NewBufferString(data) - - r := protocol.NewRequest("POST", "/foo", mr) - r.Header.SetContentTypeBytes([]byte(consts.MIMEApplicationHTMLForm)) - r.Header.SetContentLength(len(data)) - - bindReq := bindRequest{ - req: r, - } - - values, err := bindReq.GetPostForm() - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - assert.DeepEqual(t, []string{"aaa"}, values["a"]) - assert.DeepEqual(t, []string{"b1", "b2"}, values["b"]) - assert.DeepEqual(t, []string{"ccc"}, values["c"]) - assert.DeepEqual(t, []string{"100"}, values["d"]) -} - -func TestGetForm(t *testing.T) { - data := "a=aaa&b=b1&b=b2&c=ccc&d=100" - mr := bytes.NewBufferString(data) - - r := protocol.NewRequest("POST", "/foo", mr) - r.SetRequestURI("/foo/bar?para1=hertz¶2=query1¶2=query2¶3=1¶3=2") - r.Header.SetContentTypeBytes([]byte(consts.MIMEApplicationHTMLForm)) - r.Header.SetContentLength(len(data)) - - bindReq := bindRequest{ - req: r, - } - - values, err := bindReq.GetForm() - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - assert.DeepEqual(t, []string{"aaa"}, values["a"]) - assert.DeepEqual(t, []string{"b1", "b2"}, values["b"]) - assert.DeepEqual(t, []string{"ccc"}, values["c"]) - assert.DeepEqual(t, []string{"100"}, values["d"]) - assert.DeepEqual(t, []string{"hertz"}, values["para1"]) - assert.DeepEqual(t, []string{"query1", "query2"}, values["para2"]) - assert.DeepEqual(t, []string{"1", "2"}, values["para3"]) -} - -func TestGetCookies(t *testing.T) { - r := protocol.NewRequest("POST", "/foo", nil) - r.SetCookie("cookie1", "cookies1") - r.SetCookie("cookie2", "cookies2") - - bindReq := bindRequest{ - req: r, - } - - values := bindReq.GetCookies() - - assert.DeepEqual(t, "cookies1", values[0].Value) - assert.DeepEqual(t, "cookies2", values[1].Value) -} - -func TestGetHeader(t *testing.T) { - headers := map[string]string{ - "Header1": "headers1", - "Header2": "headers2", - } - - r := protocol.NewRequest("GET", "/foo", nil) - r.SetHeaders(headers) - r.SetHeader("Header3", "headers3") - - bindReq := bindRequest{ - req: r, - } - - values := bindReq.GetHeader() - - assert.DeepEqual(t, []string{"headers1"}, values["Header1"]) - assert.DeepEqual(t, []string{"headers2"}, values["Header2"]) - assert.DeepEqual(t, []string{"headers3"}, values["Header3"]) -} - -func TestGetMethod(t *testing.T) { - r := protocol.NewRequest("GET", "/foo", nil) - - bindReq := bindRequest{ - req: r, - } - - values := bindReq.GetMethod() - - assert.DeepEqual(t, "GET", values) -} - -func TestGetContentType(t *testing.T) { - data := "a=aaa&b=b1&b=b2&c=ccc&d=100" - mr := bytes.NewBufferString(data) - - r := protocol.NewRequest("POST", "/foo", mr) - r.Header.SetContentTypeBytes([]byte(consts.MIMEApplicationHTMLForm)) - r.Header.SetContentLength(len(data)) - - bindReq := bindRequest{ - req: r, - } - - values := bindReq.GetContentType() - - assert.DeepEqual(t, consts.MIMEApplicationHTMLForm, values) -} - -func TestGetBody(t *testing.T) { - data := "a=aaa&b=b1&b=b2&c=ccc&d=100" - mr := bytes.NewBufferString(data) - - r := protocol.NewRequest("POST", "/foo", mr) - r.Header.SetContentTypeBytes([]byte(consts.MIMEApplicationHTMLForm)) - r.Header.SetContentLength(len(data)) - - bindReq := bindRequest{ - req: r, - } - - values, err := bindReq.GetBody() - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - assert.DeepEqual(t, []byte("a=aaa&b=b1&b=b2&c=ccc&d=100"), values) -} - -func TestGetFileHeaders(t *testing.T) { - s := `------WebKitFormBoundaryJwfATyF8tmxSJnLg -Content-Disposition: form-data; name="f" - -fff -------WebKitFormBoundaryJwfATyF8tmxSJnLg -Content-Disposition: form-data; name="F1"; filename="TODO1" -Content-Type: application/octet-stream - -- SessionClient with referer and cookies support. -- Client with requests' pipelining support. -- ProxyHandler similar to FSHandler. -- WebSockets. See https://tools.ietf.org/html/rfc6455 . -- HTTP/2.0. See https://tools.ietf.org/html/rfc7540 . -------WebKitFormBoundaryJwfATyF8tmxSJnLg -Content-Disposition: form-data; name="F1"; filename="TODO2" -Content-Type: application/octet-stream - -- SessionClient with referer and cookies support. -- Client with requests' pipelining support. -- ProxyHandler similar to FSHandler. -- WebSockets. See https://tools.ietf.org/html/rfc6455 . -- HTTP/2.0. See https://tools.ietf.org/html/rfc7540 . -------WebKitFormBoundaryJwfATyF8tmxSJnLg -Content-Disposition: form-data; name="F2"; filename="TODO3" -Content-Type: application/octet-stream - -- SessionClient with referer and cookies support. -- Client with requests' pipelining support. -- ProxyHandler similar to FSHandler. -- WebSockets. See https://tools.ietf.org/html/rfc6455 . -- HTTP/2.0. See https://tools.ietf.org/html/rfc7540 . - -------WebKitFormBoundaryJwfATyF8tmxSJnLg-- -tailfoobar` - - mr := bytes.NewBufferString(s) - - r := protocol.NewRequest("POST", "/foo", mr) - r.Header.SetContentTypeBytes([]byte("multipart/form-data; boundary=----WebKitFormBoundaryJwfATyF8tmxSJnLg")) - r.Header.SetContentLength(len(s)) - - bindReq := bindRequest{ - req: r, - } - - values, err := bindReq.GetFileHeaders() - if err != nil { - t.Fatalf("unexpected error: %v", err) - } - - fmt.Printf("%v\n", values) - - assert.DeepEqual(t, "TODO1", values["F1"][0].Filename) - assert.DeepEqual(t, "TODO2", values["F1"][1].Filename) - assert.DeepEqual(t, "TODO3", values["F2"][0].Filename) -} From d79a390f307f0b59162054edd7626b0ba64954ee Mon Sep 17 00:00:00 2001 From: fgy Date: Tue, 12 Sep 2023 19:46:53 +0800 Subject: [PATCH 67/91] feat: add license --- .../binding/internal/decoder/reflect_test.go | 16 ++++++++++++++++ pkg/app/server/binding/reflect_test.go | 18 +++++++++++++++++- pkg/app/server/binding/testdata/hello.pb.go | 16 ++++++++++++++++ pkg/app/server/binding/testdata/hello.proto | 16 ++++++++++++++++ 4 files changed, 65 insertions(+), 1 deletion(-) diff --git a/pkg/app/server/binding/internal/decoder/reflect_test.go b/pkg/app/server/binding/internal/decoder/reflect_test.go index c33f82f14..36476e76c 100644 --- a/pkg/app/server/binding/internal/decoder/reflect_test.go +++ b/pkg/app/server/binding/internal/decoder/reflect_test.go @@ -1,3 +1,19 @@ +/* + * Copyright 2023 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package decoder import ( diff --git a/pkg/app/server/binding/reflect_test.go b/pkg/app/server/binding/reflect_test.go index c06df7df9..036eb7f35 100644 --- a/pkg/app/server/binding/reflect_test.go +++ b/pkg/app/server/binding/reflect_test.go @@ -1,3 +1,19 @@ +/* + * Copyright 2023 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package binding import ( @@ -60,7 +76,7 @@ func TestReflect_DereferPointer(t *testing.T) { t.Errorf("expect non-pointer type, but get pointer") } assert.DeepEqual(t, "foo", rt.Name()) - + var foo2 foo foo2Val := reflect.ValueOf(foo2) rt2 := dereferPointer(foo2Val) diff --git a/pkg/app/server/binding/testdata/hello.pb.go b/pkg/app/server/binding/testdata/hello.pb.go index c316985c5..8b4bce477 100644 --- a/pkg/app/server/binding/testdata/hello.pb.go +++ b/pkg/app/server/binding/testdata/hello.pb.go @@ -1,3 +1,19 @@ +/* + * Copyright 2023 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + // Code generated by protoc-gen-go. DO NOT EDIT. // versions: // protoc-gen-go v1.30.0 diff --git a/pkg/app/server/binding/testdata/hello.proto b/pkg/app/server/binding/testdata/hello.proto index 1ec7f83f5..e880c3fec 100644 --- a/pkg/app/server/binding/testdata/hello.proto +++ b/pkg/app/server/binding/testdata/hello.proto @@ -1,3 +1,19 @@ +/* + * Copyright 2023 CloudWeGo Authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + syntax = "proto3"; package hertz; option go_package = "hertz/hello"; From 5170a8921bc13003ccf656bd53590bb3803642ea Mon Sep 17 00:00:00 2001 From: fgy Date: Tue, 12 Sep 2023 19:56:42 +0800 Subject: [PATCH 68/91] fix: typo --- pkg/app/server/binding/internal/decoder/reflect.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/app/server/binding/internal/decoder/reflect.go b/pkg/app/server/binding/internal/decoder/reflect.go index 1f0b9996e..8d9b115e5 100644 --- a/pkg/app/server/binding/internal/decoder/reflect.go +++ b/pkg/app/server/binding/internal/decoder/reflect.go @@ -75,7 +75,7 @@ func GetNonNilReferenceValue(v reflect.Value) (reflect.Value, int) { } func GetFieldValue(reqValue reflect.Value, parentIndex []int) reflect.Value { - // reqVaule -> (***bar)(nil) need new a default + // reqValue -> (***bar)(nil) need new a default if reqValue.Kind() == reflect.Ptr && reqValue.IsNil() { nonNilVal, ptrDepth := GetNonNilReferenceValue(reqValue) reqValue = ReferenceValue(nonNilVal, ptrDepth) From e6961f1a808a75c0a79dad3b3ec8c4f8930d1304 Mon Sep 17 00:00:00 2001 From: fgy Date: Tue, 12 Sep 2023 20:08:35 +0800 Subject: [PATCH 69/91] fix: golong lint --- pkg/app/server/binding/binder_test.go | 11 +++++------ pkg/app/server/binding/default.go | 4 ++-- .../server/binding/internal/decoder/reflect_test.go | 3 ++- .../binding/internal/decoder/struct_type_decoder.go | 2 +- .../server/binding/internal/decoder/text_decoder.go | 3 +-- pkg/app/server/binding/validator_test.go | 4 ++-- 6 files changed, 13 insertions(+), 14 deletions(-) diff --git a/pkg/app/server/binding/binder_test.go b/pkg/app/server/binding/binder_test.go index ab9383582..0ba2f1e06 100644 --- a/pkg/app/server/binding/binder_test.go +++ b/pkg/app/server/binding/binder_test.go @@ -43,18 +43,18 @@ package binding import ( "encoding/json" "fmt" - "github.com/cloudwego/hertz/pkg/app/server/binding/testdata" - "google.golang.org/protobuf/proto" "mime/multipart" "net/url" "reflect" "testing" + "github.com/cloudwego/hertz/pkg/app/server/binding/testdata" "github.com/cloudwego/hertz/pkg/common/test/assert" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/protocol/consts" req2 "github.com/cloudwego/hertz/pkg/protocol/http1/req" "github.com/cloudwego/hertz/pkg/route/param" + "google.golang.org/protobuf/proto" ) func init() { @@ -1093,7 +1093,6 @@ func TestBind_BindAndValidate(t *testing.T) { t.Errorf("expect an error, but get nil") } assert.DeepEqual(t, 9, result.ID) - } func TestBind_FastPath(t *testing.T) { @@ -1234,7 +1233,7 @@ func TestBind_StructRequired(t *testing.T) { var result Bar req := newMockRequest(). - SetRequestURI(fmt.Sprintf("http://foobar.com")) + SetRequestURI("http://foobar.com") err := DefaultBinder().Bind(req.Req, &result, nil) if err == nil { @@ -1246,7 +1245,7 @@ func TestBind_StructRequired(t *testing.T) { } var result2 Bar2 req = newMockRequest(). - SetRequestURI(fmt.Sprintf("http://foobar.com")) + SetRequestURI("http://foobar.com") err = DefaultBinder().Bind(req.Req, &result2, nil) if err != nil { @@ -1266,7 +1265,7 @@ func TestBind_StructErrorToWarn(t *testing.T) { var result Bar req := newMockRequest(). - SetRequestURI(fmt.Sprintf("http://foobar.com?B1=111&F1=222")) + SetRequestURI("http://foobar.com?B1=111&F1=222") err := DefaultBinder().Bind(req.Req, &result, nil) // transfer 'unmarsahl err' to 'warn' diff --git a/pkg/app/server/binding/default.go b/pkg/app/server/binding/default.go index 7c29faad7..c8b68c545 100644 --- a/pkg/app/server/binding/default.go +++ b/pkg/app/server/binding/default.go @@ -301,8 +301,8 @@ func (b *defaultBinder) bindNonStruct(req *protocol.Request, v interface{}) (err err = proto.Unmarshal(req.Body(), msg) case consts.MIMEMultipartPOSTForm: form := make(url.Values) - mf, err := req.MultipartForm() - if err == nil && mf.Value != nil { + mf, err1 := req.MultipartForm() + if err1 == nil && mf.Value != nil { for k, v := range mf.Value { for _, vv := range v { form.Add(k, vv) diff --git a/pkg/app/server/binding/internal/decoder/reflect_test.go b/pkg/app/server/binding/internal/decoder/reflect_test.go index 36476e76c..23a588350 100644 --- a/pkg/app/server/binding/internal/decoder/reflect_test.go +++ b/pkg/app/server/binding/internal/decoder/reflect_test.go @@ -17,9 +17,10 @@ package decoder import ( - "github.com/cloudwego/hertz/pkg/common/test/assert" "reflect" "testing" + + "github.com/cloudwego/hertz/pkg/common/test/assert" ) type foo struct { diff --git a/pkg/app/server/binding/internal/decoder/struct_type_decoder.go b/pkg/app/server/binding/internal/decoder/struct_type_decoder.go index 70fa12080..7804aa65b 100644 --- a/pkg/app/server/binding/internal/decoder/struct_type_decoder.go +++ b/pkg/app/server/binding/internal/decoder/struct_type_decoder.go @@ -18,10 +18,10 @@ package decoder import ( "fmt" - "github.com/cloudwego/hertz/pkg/common/hlog" "reflect" "github.com/cloudwego/hertz/internal/bytesconv" + "github.com/cloudwego/hertz/pkg/common/hlog" hjson "github.com/cloudwego/hertz/pkg/common/json" "github.com/cloudwego/hertz/pkg/common/utils" "github.com/cloudwego/hertz/pkg/protocol" diff --git a/pkg/app/server/binding/internal/decoder/text_decoder.go b/pkg/app/server/binding/internal/decoder/text_decoder.go index 45b87332b..7224d61fc 100644 --- a/pkg/app/server/binding/internal/decoder/text_decoder.go +++ b/pkg/app/server/binding/internal/decoder/text_decoder.go @@ -172,8 +172,7 @@ func (d *uintDecoder) UnmarshalString(s string, fieldValue reflect.Value) error return nil } -type interfaceDecoder struct { -} +type interfaceDecoder struct{} func (d *interfaceDecoder) UnmarshalString(s string, fieldValue reflect.Value) error { if s == "" && looseZeroMode { diff --git a/pkg/app/server/binding/validator_test.go b/pkg/app/server/binding/validator_test.go index b02222bc6..3f11b634d 100644 --- a/pkg/app/server/binding/validator_test.go +++ b/pkg/app/server/binding/validator_test.go @@ -18,8 +18,9 @@ package binding import ( "fmt" - "github.com/cloudwego/hertz/pkg/common/test/assert" "testing" + + "github.com/cloudwego/hertz/pkg/common/test/assert" ) func Test_ValidateStruct(t *testing.T) { @@ -40,7 +41,6 @@ type mockValidator struct{} func (m *mockValidator) ValidateStruct(interface{}) error { return fmt.Errorf("test mock") - } func (m *mockValidator) Engine() interface{} { From 02f3bdf646082914102190fd5226777993879c1d Mon Sep 17 00:00:00 2001 From: fgy Date: Thu, 14 Sep 2023 16:48:42 +0800 Subject: [PATCH 70/91] feat: resolve struct by default --- pkg/app/server/binding/internal/decoder/decoder.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/app/server/binding/internal/decoder/decoder.go b/pkg/app/server/binding/internal/decoder/decoder.go index 4100815a1..6329035e2 100644 --- a/pkg/app/server/binding/internal/decoder/decoder.go +++ b/pkg/app/server/binding/internal/decoder/decoder.go @@ -51,7 +51,7 @@ import ( var ( EnableDefaultTag = true - EnableStructFieldResolve = false + EnableStructFieldResolve = true ) type fieldDecoder interface { From 2299988294cd5d26d8355dc5a9f952bd07ddab4a Mon Sep 17 00:00:00 2001 From: fgy Date: Tue, 19 Sep 2023 11:40:14 +0800 Subject: [PATCH 71/91] fix: gjson for windows --- .../server/binding/internal/decoder/gjson_required.go | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/pkg/app/server/binding/internal/decoder/gjson_required.go b/pkg/app/server/binding/internal/decoder/gjson_required.go index 5fbfd4086..88697e0f3 100644 --- a/pkg/app/server/binding/internal/decoder/gjson_required.go +++ b/pkg/app/server/binding/internal/decoder/gjson_required.go @@ -23,23 +23,24 @@ import ( "github.com/cloudwego/hertz/internal/bytesconv" "github.com/cloudwego/hertz/pkg/common/utils" + "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/protocol/consts" "github.com/tidwall/gjson" ) -func checkRequireJSON2(req *protocol.Request, tagInfo TagInfo) bool { +func checkRequireJSON(req *protocol.Request, tagInfo TagInfo) bool { if !tagInfo.Required { return true } - ct := bytesconv.B2s(req.Req.Header.ContentType()) + ct := bytesconv.B2s(req.Header.ContentType()) if utils.FilterContentType(ct) != consts.MIMEApplicationJSON { return false } - result := gjson.GetBytes(req.Req.Body(), tagInfo.JSONName) + result := gjson.GetBytes(req.Body(), tagInfo.JSONName) if !result.Exists() { idx := strings.LastIndex(tagInfo.JSONName, ".") // There should be a superior if it is empty, it will report 'true' for required - if idx > 0 && !gjson.GetBytes(req.Req.Body(), tagInfo.JSONName[:idx]).Exists() { + if idx > 0 && !gjson.GetBytes(req.Body(), tagInfo.JSONName[:idx]).Exists() { return true } return false From f2e61b904c6cd59cb973ad507079c738d994738a Mon Sep 17 00:00:00 2001 From: fgy Date: Tue, 19 Sep 2023 12:32:53 +0800 Subject: [PATCH 72/91] feat: reflect internal test --- ...flect_test.go => reflect_internal_test.go} | 46 +++++++++---------- 1 file changed, 23 insertions(+), 23 deletions(-) rename pkg/app/server/binding/{internal/decoder/reflect_test.go => reflect_internal_test.go} (61%) diff --git a/pkg/app/server/binding/internal/decoder/reflect_test.go b/pkg/app/server/binding/reflect_internal_test.go similarity index 61% rename from pkg/app/server/binding/internal/decoder/reflect_test.go rename to pkg/app/server/binding/reflect_internal_test.go index 23a588350..d0ecdeefe 100644 --- a/pkg/app/server/binding/internal/decoder/reflect_test.go +++ b/pkg/app/server/binding/reflect_internal_test.go @@ -1,29 +1,29 @@ -/* - * Copyright 2023 CloudWeGo Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ +// Copyright 2023 CloudWeGo Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// -package decoder +package binding import ( "reflect" "testing" + "github.com/cloudwego/hertz/pkg/app/server/binding/internal/decoder" "github.com/cloudwego/hertz/pkg/common/test/assert" ) -type foo struct { +type foo2 struct { F1 string } @@ -32,9 +32,9 @@ type fooq struct { } func Test_ReferenceValue(t *testing.T) { - foo1 := foo{F1: "f1"} + foo1 := foo2{F1: "f1"} foo1Val := reflect.ValueOf(foo1) - foo1PointerVal := ReferenceValue(foo1Val, 5) + foo1PointerVal := decoder.ReferenceValue(foo1Val, 5) assert.DeepEqual(t, "f1", foo1.F1) assert.DeepEqual(t, "f1", foo1Val.Field(0).Interface().(string)) if foo1PointerVal.Kind() != reflect.Ptr { @@ -42,7 +42,7 @@ func Test_ReferenceValue(t *testing.T) { } assert.DeepEqual(t, "*****decoder.foo", foo1PointerVal.Type().String()) - deFoo1PointerVal := ReferenceValue(foo1PointerVal, -5) + deFoo1PointerVal := decoder.ReferenceValue(foo1PointerVal, -5) if deFoo1PointerVal.Kind() == reflect.Ptr { t.Errorf("expect a non-pointer, but get a pointer") } @@ -52,7 +52,7 @@ func Test_ReferenceValue(t *testing.T) { func Test_GetNonNilReferenceValue(t *testing.T) { foo1 := (****foo)(nil) foo1Val := reflect.ValueOf(foo1) - foo1ValNonNil, ptrDepth := GetNonNilReferenceValue(foo1Val) + foo1ValNonNil, ptrDepth := decoder.GetNonNilReferenceValue(foo1Val) if !foo1ValNonNil.IsValid() { t.Errorf("expect a valid value, but get nil") } @@ -60,7 +60,7 @@ func Test_GetNonNilReferenceValue(t *testing.T) { t.Errorf("expect can set value, but not") } - foo1ReferPointer := ReferenceValue(foo1ValNonNil, ptrDepth) + foo1ReferPointer := decoder.ReferenceValue(foo1ValNonNil, ptrDepth) if foo1ReferPointer.Kind() != reflect.Ptr { t.Errorf("expect a pointer, but get nil") } @@ -75,7 +75,7 @@ func Test_GetFieldValue(t *testing.T) { idx := 0 bar1Val := reflect.ValueOf(bar1) - parentFieldVal := GetFieldValue(bar1Val, parentIdx) + parentFieldVal := decoder.GetFieldValue(bar1Val, parentIdx) if parentFieldVal.Kind() == reflect.Ptr { t.Errorf("expect a non-pointer, but get a pointer") } From 1a3de68517246fdf365ef323e7d4e29bd43e1eb1 Mon Sep 17 00:00:00 2001 From: fgy Date: Tue, 19 Sep 2023 12:40:46 +0800 Subject: [PATCH 73/91] feat: warn to info --- .../server/binding/internal/decoder/struct_type_decoder.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pkg/app/server/binding/internal/decoder/struct_type_decoder.go b/pkg/app/server/binding/internal/decoder/struct_type_decoder.go index 7804aa65b..e157b4a0a 100644 --- a/pkg/app/server/binding/internal/decoder/struct_type_decoder.go +++ b/pkg/app/server/binding/internal/decoder/struct_type_decoder.go @@ -84,7 +84,7 @@ func (d *structTypeFieldTextDecoder) Decode(req *protocol.Request, params param. var vv reflect.Value vv, err := stringToValue(t, text, req, params) if err != nil { - hlog.Warnf("unable to decode '%s' as %s: %v, but it may not affect correctness, so skip it", text, d.fieldType.Name(), err) + hlog.Infof("unable to decode '%s' as %s: %v, but it may not affect correctness, so skip it", text, d.fieldType.Name(), err) return nil } field.Set(ReferenceValue(vv, ptrDepth)) @@ -93,7 +93,7 @@ func (d *structTypeFieldTextDecoder) Decode(req *protocol.Request, params param. err = hjson.Unmarshal(bytesconv.S2b(text), field.Addr().Interface()) if err != nil { - hlog.Warnf("unable to decode '%s' as %s: %v, but it may not affect correctness, so skip it", text, d.fieldType.Name(), err) + hlog.Infof("unable to decode '%s' as %s: %v, but it may not affect correctness, so skip it", text, d.fieldType.Name(), err) } return nil From 022954f3732901128f52b0e987972a6108af8db2 Mon Sep 17 00:00:00 2001 From: fgy Date: Tue, 19 Sep 2023 12:42:24 +0800 Subject: [PATCH 74/91] fix: modify test --- pkg/app/server/binding/reflect_internal_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/app/server/binding/reflect_internal_test.go b/pkg/app/server/binding/reflect_internal_test.go index d0ecdeefe..a090b131f 100644 --- a/pkg/app/server/binding/reflect_internal_test.go +++ b/pkg/app/server/binding/reflect_internal_test.go @@ -40,7 +40,7 @@ func Test_ReferenceValue(t *testing.T) { if foo1PointerVal.Kind() != reflect.Ptr { t.Errorf("expect a pointer, but get nil") } - assert.DeepEqual(t, "*****decoder.foo", foo1PointerVal.Type().String()) + assert.DeepEqual(t, "*****decoder.foo2but it may", foo1PointerVal.Type().String()) deFoo1PointerVal := decoder.ReferenceValue(foo1PointerVal, -5) if deFoo1PointerVal.Kind() == reflect.Ptr { From df13ac43c7c1ca5f186b95b2ae322d747e48bbde Mon Sep 17 00:00:00 2001 From: fgy Date: Tue, 19 Sep 2023 12:50:34 +0800 Subject: [PATCH 75/91] feat: go mod tidy --- go.sum | 4 ++-- pkg/app/server/binding/reflect_internal_test.go | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/go.sum b/go.sum index 30da12c13..f86006603 100644 --- a/go.sum +++ b/go.sum @@ -10,8 +10,8 @@ github.com/bytedance/sonic v1.8.1/go.mod h1:i736AoUSYt75HyZLoJW9ERYxcy6eaN6h4BZX github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311 h1:qSGYFH7+jGhDF8vLC+iwCD4WpbV1EBDSzWkJODFLams= github.com/chenzhuoyu/base64x v0.0.0-20221115062448-fe3a3abad311/go.mod h1:b583jCggY9gE99b6G5LEC39OIiVsWj+R97kbl5odCEk= -github.com/cloudwego/netpoll v0.3.2 h1:/998ICrNMVBo4mlul4j7qcIeY7QnEfuCCPPwck9S3X4= -github.com/cloudwego/netpoll v0.3.2/go.mod h1:xVefXptcyheopwNDZjDPcfU6kIjZXZ4nY550k1yH9eQ= +github.com/cloudwego/netpoll v0.4.2-0.20230807055039-52fd5fb7b00f h1:8iWPKjHdXl4tjcSxUJTavnhRL5JPupYvxbtsAlm2Igw= +github.com/cloudwego/netpoll v0.4.2-0.20230807055039-52fd5fb7b00f/go.mod h1:xVefXptcyheopwNDZjDPcfU6kIjZXZ4nY550k1yH9eQ= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= diff --git a/pkg/app/server/binding/reflect_internal_test.go b/pkg/app/server/binding/reflect_internal_test.go index a090b131f..65dc68fc8 100644 --- a/pkg/app/server/binding/reflect_internal_test.go +++ b/pkg/app/server/binding/reflect_internal_test.go @@ -40,7 +40,7 @@ func Test_ReferenceValue(t *testing.T) { if foo1PointerVal.Kind() != reflect.Ptr { t.Errorf("expect a pointer, but get nil") } - assert.DeepEqual(t, "*****decoder.foo2but it may", foo1PointerVal.Type().String()) + assert.DeepEqual(t, "*****binding.foo2", foo1PointerVal.Type().String()) deFoo1PointerVal := decoder.ReferenceValue(foo1PointerVal, -5) if deFoo1PointerVal.Kind() == reflect.Ptr { From 20990ffcfcfb46bebf513ad737d66a88ddd6b4a1 Mon Sep 17 00:00:00 2001 From: fgy Date: Tue, 19 Sep 2023 14:17:23 +0800 Subject: [PATCH 76/91] fix: call function --- pkg/app/context.go | 6 +++--- pkg/app/server/binding/config.go | 2 +- pkg/app/server/binding/default.go | 2 +- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pkg/app/context.go b/pkg/app/context.go index 042e0f540..0c062c21c 100644 --- a/pkg/app/context.go +++ b/pkg/app/context.go @@ -1305,19 +1305,19 @@ func bodyAllowedForStatus(status int) bool { // BindAndValidate binds data from *RequestContext to obj and validates them if needed. // NOTE: obj should be a pointer. func (ctx *RequestContext) BindAndValidate(obj interface{}) error { - return binding.DefaultBinder().BindAndValidate(&ctx.Request, obj, ctx.Params) + return binding.BindAndValidate(&ctx.Request, obj, ctx.Params) } // Bind binds data from *RequestContext to obj. // NOTE: obj should be a pointer. func (ctx *RequestContext) Bind(obj interface{}) error { - return binding.DefaultBinder().Bind(&ctx.Request, obj, ctx.Params) + return binding.Bind(&ctx.Request, obj, ctx.Params) } // Validate validates obj with "vd" tag // NOTE: obj should be a pointer. func (ctx *RequestContext) Validate(obj interface{}) error { - return binding.DefaultValidator().ValidateStruct(obj) + return binding.Validate(obj) } // BindQuery binds query parameters from *RequestContext to obj with 'query' tag. It will only use 'query' tag for binding. diff --git a/pkg/app/server/binding/config.go b/pkg/app/server/binding/config.go index e27283395..6869a6f1a 100644 --- a/pkg/app/server/binding/config.go +++ b/pkg/app/server/binding/config.go @@ -81,7 +81,7 @@ func MustRegTypeUnmarshal(t reflect.Type, fn func(req *protocol.Request, params decoder.MustRegTypeUnmarshal(t, fn) } -// ResetValidator reset a customized +// ResetValidator reset a customized validator. func ResetValidator(v StructValidator, validatorTag string) { defaultValidate = v decoder.DefaultValidatorTag = validatorTag diff --git a/pkg/app/server/binding/default.go b/pkg/app/server/binding/default.go index c8b68c545..319e67ae5 100644 --- a/pkg/app/server/binding/default.go +++ b/pkg/app/server/binding/default.go @@ -251,7 +251,7 @@ func decodeJSON(r io.Reader, obj interface{}) error { func (b *defaultBinder) BindProtobuf(req *protocol.Request, v interface{}) error { msg, ok := v.(proto.Message) if !ok { - return fmt.Errorf("%s can not implement 'proto.Message'", v) + return fmt.Errorf("%s does not implement 'proto.Message'", v) } return proto.Unmarshal(req.Body(), msg) } From 8b4c98d7d674574350e645f5e2e83b4ecbaaad29 Mon Sep 17 00:00:00 2001 From: fgy Date: Tue, 19 Sep 2023 15:40:48 +0800 Subject: [PATCH 77/91] refactor: set binder and validator --- pkg/app/context.go | 40 +++++++++++++++++++ pkg/app/context_test.go | 30 ++++++++++++++ pkg/app/server/binding/binder.go | 4 +- pkg/app/server/binding/config.go | 6 --- pkg/app/server/binding/default.go | 13 +++++- .../binding/internal/decoder/decoder.go | 10 ++--- .../server/binding/internal/decoder/tag.go | 6 +-- pkg/app/server/binding/validator_test.go | 30 -------------- pkg/route/engine.go | 17 ++++++++ 9 files changed, 108 insertions(+), 48 deletions(-) diff --git a/pkg/app/context.go b/pkg/app/context.go index 0c062c21c..1bed69de8 100644 --- a/pkg/app/context.go +++ b/pkg/app/context.go @@ -233,6 +233,10 @@ type RequestContext struct { // clientIPFunc get form value by use custom function. formValueFunc FormValueFunc + + binder binding.Binder + validator binding.StructValidator + validateTag string } // Flush is the shortcut for ctx.Response.GetHijackWriter().Flush(). @@ -252,6 +256,15 @@ func (ctx *RequestContext) SetFormValueFunc(f FormValueFunc) { ctx.formValueFunc = f } +func (ctx *RequestContext) SetBinder(binder binding.Binder) { + ctx.binder = binder +} + +func (ctx *RequestContext) SetValidator(validator binding.StructValidator, tag string) { + ctx.validator = validator + ctx.validateTag = tag +} + func (ctx *RequestContext) GetTraceInfo() traceinfo.TraceInfo { return ctx.traceInfo } @@ -1305,36 +1318,54 @@ func bodyAllowedForStatus(status int) bool { // BindAndValidate binds data from *RequestContext to obj and validates them if needed. // NOTE: obj should be a pointer. func (ctx *RequestContext) BindAndValidate(obj interface{}) error { + if ctx.binder != nil { + return ctx.binder.BindAndValidate(&ctx.Request, obj, ctx.Params) + } return binding.BindAndValidate(&ctx.Request, obj, ctx.Params) } // Bind binds data from *RequestContext to obj. // NOTE: obj should be a pointer. func (ctx *RequestContext) Bind(obj interface{}) error { + if ctx.binder != nil { + return ctx.binder.Bind(&ctx.Request, obj, ctx.Params) + } return binding.Bind(&ctx.Request, obj, ctx.Params) } // Validate validates obj with "vd" tag // NOTE: obj should be a pointer. func (ctx *RequestContext) Validate(obj interface{}) error { + if ctx.validator != nil { + return ctx.validator.ValidateStruct(obj) + } return binding.Validate(obj) } // BindQuery binds query parameters from *RequestContext to obj with 'query' tag. It will only use 'query' tag for binding. // NOTE: obj should be a pointer. func (ctx *RequestContext) BindQuery(obj interface{}) error { + if ctx.binder != nil { + return ctx.binder.BindQuery(&ctx.Request, obj) + } return binding.DefaultBinder().BindQuery(&ctx.Request, obj) } // BindHeader binds header parameters from *RequestContext to obj with 'header' tag. It will only use 'header' tag for binding. // NOTE: obj should be a pointer. func (ctx *RequestContext) BindHeader(obj interface{}) error { + if ctx.binder != nil { + return ctx.binder.BindHeader(&ctx.Request, obj) + } return binding.DefaultBinder().BindHeader(&ctx.Request, obj) } // BindPath binds router parameters from *RequestContext to obj with 'path' tag. It will only use 'path' tag for binding. // NOTE: obj should be a pointer. func (ctx *RequestContext) BindPath(obj interface{}) error { + if ctx.binder != nil { + return ctx.binder.BindPath(&ctx.Request, obj, ctx.Params) + } return binding.DefaultBinder().BindPath(&ctx.Request, obj, ctx.Params) } @@ -1344,18 +1375,27 @@ func (ctx *RequestContext) BindForm(obj interface{}) error { if len(ctx.Request.Body()) == 0 { return fmt.Errorf("missing form body") } + if ctx.binder != nil { + return ctx.binder.BindForm(&ctx.Request, obj) + } return binding.DefaultBinder().BindForm(&ctx.Request, obj) } // BindJSON binds JSON body from *RequestContext. // NOTE: obj should be a pointer. func (ctx *RequestContext) BindJSON(obj interface{}) error { + if ctx.binder != nil { + return ctx.binder.BindJSON(&ctx.Request, obj) + } return binding.DefaultBinder().BindJSON(&ctx.Request, obj) } // BindProtobuf binds protobuf body from *RequestContext. // NOTE: obj should be a pointer. func (ctx *RequestContext) BindProtobuf(obj interface{}) error { + if ctx.binder != nil { + return ctx.binder.BindProtobuf(&ctx.Request, obj) + } return binding.DefaultBinder().BindProtobuf(&ctx.Request, obj) } diff --git a/pkg/app/context_test.go b/pkg/app/context_test.go index a7d29f5d8..482b08c00 100644 --- a/pkg/app/context_test.go +++ b/pkg/app/context_test.go @@ -33,6 +33,7 @@ import ( "github.com/cloudwego/hertz/internal/bytesconv" "github.com/cloudwego/hertz/internal/bytestr" + "github.com/cloudwego/hertz/pkg/app/server/binding" "github.com/cloudwego/hertz/pkg/app/server/render" errs "github.com/cloudwego/hertz/pkg/common/errors" "github.com/cloudwego/hertz/pkg/common/test/assert" @@ -873,6 +874,35 @@ func TestSetClientIPFunc(t *testing.T) { assert.DeepEqual(t, reflect.ValueOf(fn).Pointer(), reflect.ValueOf(defaultClientIP).Pointer()) } +type mockValidator struct{} + +func (m *mockValidator) ValidateStruct(interface{}) error { + return fmt.Errorf("test mock") +} + +func (m *mockValidator) Engine() interface{} { + return nil +} + +func TestSetValidator(t *testing.T) { + m := &mockValidator{} + c := NewContext(0) + c.SetValidator(m, "vt") + defer c.SetValidator(binding.DefaultValidator(), "vd") + type User struct { + Age int `vt:"$>=0&&$<=130"` + } + + user := &User{ + Age: 135, + } + err := c.Validate(user) + if err == nil { + t.Fatalf("expected an error, but got nil") + } + assert.DeepEqual(t, "test mock", err.Error()) +} + func TestGetQuery(t *testing.T) { c := NewContext(0) c.Request.SetRequestURI("http://aaa.com?a=1&b=") diff --git a/pkg/app/server/binding/binder.go b/pkg/app/server/binding/binder.go index 2e43daee3..e36acda10 100644 --- a/pkg/app/server/binding/binder.go +++ b/pkg/app/server/binding/binder.go @@ -55,9 +55,11 @@ type Binder interface { BindForm(*protocol.Request, interface{}) error BindJSON(*protocol.Request, interface{}) error BindProtobuf(*protocol.Request, interface{}) error + ValidateTag() string + SetValidateTag(string) } -var defaultBind Binder = &defaultBinder{} +var defaultBind Binder = &defaultBinder{validateTag: "vd"} func DefaultBinder() Binder { return defaultBind diff --git a/pkg/app/server/binding/config.go b/pkg/app/server/binding/config.go index 6869a6f1a..1ba2ba3eb 100644 --- a/pkg/app/server/binding/config.go +++ b/pkg/app/server/binding/config.go @@ -81,12 +81,6 @@ func MustRegTypeUnmarshal(t reflect.Type, fn func(req *protocol.Request, params decoder.MustRegTypeUnmarshal(t, fn) } -// ResetValidator reset a customized validator. -func ResetValidator(v StructValidator, validatorTag string) { - defaultValidate = v - decoder.DefaultValidatorTag = validatorTag -} - // MustRegValidateFunc registers validator function expression. // NOTE: // diff --git a/pkg/app/server/binding/default.go b/pkg/app/server/binding/default.go index 319e67ae5..63c8eb74d 100644 --- a/pkg/app/server/binding/default.go +++ b/pkg/app/server/binding/default.go @@ -93,6 +93,7 @@ type decoderInfo struct { } type defaultBinder struct { + validateTag string decoderCache sync.Map queryDecoderCache sync.Map formDecoderCache sync.Map @@ -162,7 +163,7 @@ func (b *defaultBinder) bindTag(req *protocol.Request, v interface{}, params par return decoder.decoder(req, params, rv.Elem()) } - decoder, needValidate, err := inDecoder.GetReqDecoder(rv.Type(), tag) + decoder, needValidate, err := inDecoder.GetReqDecoder(rv.Type(), tag, b.ValidateTag()) if err != nil { return err } @@ -200,7 +201,7 @@ func (b *defaultBinder) bindTagWithValidate(req *protocol.Request, v interface{} return err } - decoder, needValidate, err := inDecoder.GetReqDecoder(rv.Type(), tag) + decoder, needValidate, err := inDecoder.GetReqDecoder(rv.Type(), tag, b.ValidateTag()) if err != nil { return err } @@ -268,6 +269,14 @@ func (b *defaultBinder) Bind(req *protocol.Request, v interface{}, params param. return b.bindTag(req, v, params, "") } +func (b *defaultBinder) ValidateTag() string { + return b.validateTag +} + +func (b *defaultBinder) SetValidateTag(tag string) { + b.validateTag = tag +} + // best effort binding func (b *defaultBinder) preBindBody(req *protocol.Request, v interface{}) error { if req.Header.ContentLength() <= 0 { diff --git a/pkg/app/server/binding/internal/decoder/decoder.go b/pkg/app/server/binding/internal/decoder/decoder.go index 6329035e2..7da7be69a 100644 --- a/pkg/app/server/binding/internal/decoder/decoder.go +++ b/pkg/app/server/binding/internal/decoder/decoder.go @@ -60,7 +60,7 @@ type fieldDecoder interface { type Decoder func(req *protocol.Request, params param.Params, rv reflect.Value) error -func GetReqDecoder(rt reflect.Type, byTag string) (Decoder, bool, error) { +func GetReqDecoder(rt reflect.Type, byTag string, validateTag string) (Decoder, bool, error) { var decoders []fieldDecoder var needValidate bool @@ -75,7 +75,7 @@ func GetReqDecoder(rt reflect.Type, byTag string) (Decoder, bool, error) { continue } - dec, needValidate2, err := getFieldDecoder(el.Field(i), i, []int{}, "", byTag) + dec, needValidate2, err := getFieldDecoder(el.Field(i), i, []int{}, "", byTag, validateTag) if err != nil { return nil, false, err } @@ -98,7 +98,7 @@ func GetReqDecoder(rt reflect.Type, byTag string) (Decoder, bool, error) { }, needValidate, nil } -func getFieldDecoder(field reflect.StructField, index int, parentIdx []int, parentJSONName string, byTag string) ([]fieldDecoder, bool, error) { +func getFieldDecoder(field reflect.StructField, index int, parentIdx []int, parentJSONName string, byTag string, validateTag string) ([]fieldDecoder, bool, error) { for field.Type.Kind() == reflect.Ptr { field.Type = field.Type.Elem() } @@ -111,7 +111,7 @@ func getFieldDecoder(field reflect.StructField, index int, parentIdx []int, pare } // JSONName is like 'a.b.c' for 'required validate' - fieldTagInfos, newParentJSONName, needValidate := lookupFieldTags(field, parentJSONName) + fieldTagInfos, newParentJSONName, needValidate := lookupFieldTags(field, parentJSONName, validateTag) if len(fieldTagInfos) == 0 && EnableDefaultTag { fieldTagInfos = getDefaultFieldTags(field) } @@ -167,7 +167,7 @@ func getFieldDecoder(field reflect.StructField, index int, parentIdx []int, pare idxes = append(idxes, parentIdx...) } idxes = append(idxes, index) - dec, needValidate2, err := getFieldDecoder(el.Field(i), i, idxes, newParentJSONName, byTag) + dec, needValidate2, err := getFieldDecoder(el.Field(i), i, idxes, newParentJSONName, byTag, validateTag) needValidate = needValidate || needValidate2 if err != nil { return nil, false, err diff --git a/pkg/app/server/binding/internal/decoder/tag.go b/pkg/app/server/binding/internal/decoder/tag.go index 0f754fbb1..ee3ed388d 100644 --- a/pkg/app/server/binding/internal/decoder/tag.go +++ b/pkg/app/server/binding/internal/decoder/tag.go @@ -32,8 +32,6 @@ const ( fileNameTag = "file_name" ) -var DefaultValidatorTag = "vd" - const ( defaultTag = "default" ) @@ -62,10 +60,10 @@ func head(str, sep string) (head, tail string) { return str[:idx], str[idx+len(sep):] } -func lookupFieldTags(field reflect.StructField, parentJSONName string) ([]TagInfo, string, bool) { +func lookupFieldTags(field reflect.StructField, parentJSONName string, validateTag string) ([]TagInfo, string, bool) { var ret []string var needValidate bool - if _, ok := field.Tag.Lookup(DefaultValidatorTag); ok { + if _, ok := field.Tag.Lookup(validateTag); ok { needValidate = true } tags := []string{pathTag, formTag, queryTag, cookieTag, headerTag, jsonTag, rawBodyTag, fileNameTag} diff --git a/pkg/app/server/binding/validator_test.go b/pkg/app/server/binding/validator_test.go index 3f11b634d..2f85716b5 100644 --- a/pkg/app/server/binding/validator_test.go +++ b/pkg/app/server/binding/validator_test.go @@ -17,10 +17,7 @@ package binding import ( - "fmt" "testing" - - "github.com/cloudwego/hertz/pkg/common/test/assert" ) func Test_ValidateStruct(t *testing.T) { @@ -36,30 +33,3 @@ func Test_ValidateStruct(t *testing.T) { t.Fatalf("expected an error, but got nil") } } - -type mockValidator struct{} - -func (m *mockValidator) ValidateStruct(interface{}) error { - return fmt.Errorf("test mock") -} - -func (m *mockValidator) Engine() interface{} { - return nil -} - -func Test_ResetValidatorConfig(t *testing.T) { - m := &mockValidator{} - ResetValidator(m, "vt") - type User struct { - Age int `vt:"$>=0&&$<=130"` - } - - user := &User{ - Age: 135, - } - err := DefaultValidator().ValidateStruct(user) - if err == nil { - t.Fatalf("expected an error, but got nil") - } - assert.DeepEqual(t, "test mock", err.Error()) -} diff --git a/pkg/route/engine.go b/pkg/route/engine.go index bd8fbc1a9..42da1d662 100644 --- a/pkg/route/engine.go +++ b/pkg/route/engine.go @@ -59,6 +59,7 @@ import ( "github.com/cloudwego/hertz/internal/nocopy" internalStats "github.com/cloudwego/hertz/internal/stats" "github.com/cloudwego/hertz/pkg/app" + "github.com/cloudwego/hertz/pkg/app/server/binding" "github.com/cloudwego/hertz/pkg/app/server/render" "github.com/cloudwego/hertz/pkg/common/config" errs "github.com/cloudwego/hertz/pkg/common/errors" @@ -192,6 +193,11 @@ type Engine struct { // Custom Functions clientIPFunc app.ClientIP formValueFunc app.FormValueFunc + + // Custom Binder and Validator + binder binding.Binder + validator binding.StructValidator + validateTag string } func (engine *Engine) IsTraceEnable() bool { @@ -737,6 +743,8 @@ func (engine *Engine) allocateContext() *app.RequestContext { ctx.Response.SetMaxKeepBodySize(engine.options.MaxKeepBodySize) ctx.SetClientIPFunc(engine.clientIPFunc) ctx.SetFormValueFunc(engine.formValueFunc) + ctx.SetBinder(engine.binder) + ctx.SetValidator(engine.validator, engine.validateTag) return ctx } @@ -891,6 +899,15 @@ func (engine *Engine) SetClientIPFunc(f app.ClientIP) { engine.clientIPFunc = f } +func (engine *Engine) SetBinder(binder binding.Binder) { + engine.binder = binder +} + +func (engine *Engine) SetValidator(validator binding.StructValidator, tag string) { + engine.validator = validator + engine.validateTag = tag +} + func (engine *Engine) SetFormValueFunc(f app.FormValueFunc) { engine.formValueFunc = f } From 8cd821765a43c6314003acff3eba145750f868fe Mon Sep 17 00:00:00 2001 From: fgy Date: Tue, 19 Sep 2023 16:32:36 +0800 Subject: [PATCH 78/91] feat: add utils test --- pkg/common/utils/utils_test.go | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/pkg/common/utils/utils_test.go b/pkg/common/utils/utils_test.go index 231ad930f..7462868c7 100644 --- a/pkg/common/utils/utils_test.go +++ b/pkg/common/utils/utils_test.go @@ -136,3 +136,19 @@ func TestUtilsNextLine(t *testing.T) { _, _, sErr = NextLine(singleHeaderStr) assert.DeepEqual(t, errNeedMore, sErr) } + +func TestFilterContentType(t *testing.T) { + contentType := "text/plain; charset=utf-8" + contentType = FilterContentType(contentType) + assert.DeepEqual(t, "text/plain", contentType) +} + +func TestGetNormalizeHeaderKey(t *testing.T) { + key := "content-type" + key = GetNormalizeHeaderKey(key, false) + assert.DeepEqual(t, "Content-Type", key) + + key = "content-type" + key = GetNormalizeHeaderKey(key, true) + assert.DeepEqual(t, "content-type", key) +} From 4664a4c499a2316770470e182363590361a889af Mon Sep 17 00:00:00 2001 From: fgy Date: Tue, 19 Sep 2023 17:56:29 +0800 Subject: [PATCH 79/91] feat: context test --- pkg/app/context_test.go | 68 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 68 insertions(+) diff --git a/pkg/app/context_test.go b/pkg/app/context_test.go index 482b08c00..35470199c 100644 --- a/pkg/app/context_test.go +++ b/pkg/app/context_test.go @@ -1513,6 +1513,74 @@ func TestBindForm(t *testing.T) { } } +type mockBinder struct{} + +func (m *mockBinder) Name() string { + return "test binder" +} + +func (m *mockBinder) Bind(request *protocol.Request, i interface{}, params param.Params) error { + return nil +} + +func (m *mockBinder) BindAndValidate(request *protocol.Request, i interface{}, params param.Params) error { + return nil +} + +func (m *mockBinder) BindQuery(request *protocol.Request, i interface{}) error { + return nil +} + +func (m *mockBinder) BindHeader(request *protocol.Request, i interface{}) error { + return nil +} + +func (m *mockBinder) BindPath(request *protocol.Request, i interface{}, params param.Params) error { + return nil +} + +func (m *mockBinder) BindForm(request *protocol.Request, i interface{}) error { + return nil +} + +func (m *mockBinder) BindJSON(request *protocol.Request, i interface{}) error { + return nil +} + +func (m *mockBinder) BindProtobuf(request *protocol.Request, i interface{}) error { + return nil +} + +func (m *mockBinder) ValidateTag() string { + return "test" +} + +func (m *mockBinder) SetValidateTag(s string) {} + +func TestSetBinder(t *testing.T) { + mockBind := &mockBinder{} + c := NewContext(0) + c.SetBinder(mockBind) + defer c.SetBinder(binding.DefaultBinder()) + type T struct{} + req := T{} + err := c.Bind(&req) + assert.NotNil(t, err) + err = c.BindAndValidate(&req) + assert.NotNil(t, err) + err = c.BindProtobuf(&req) + assert.NotNil(t, err) + err = c.BindJSON(&req) + assert.NotNil(t, err) + err = c.BindForm(&req) + assert.NotNil(t, err) + err = c.BindPath(&req) + assert.NotNil(t, err) + err = c.BindQuery(&req) + assert.NotNil(t, err) + err = c.BindHeader(&req) +} + func TestRequestContext_SetCookie(t *testing.T) { c := NewContext(0) c.SetCookie("user", "hertz", 1, "/", "localhost", protocol.CookieSameSiteLaxMode, true, true) From 19c66a6b885503d766a907bba56be1e574e37a19 Mon Sep 17 00:00:00 2001 From: fgy Date: Wed, 20 Sep 2023 20:51:30 +0800 Subject: [PATCH 80/91] refactor: refactor config --- pkg/app/context.go | 53 ++--- pkg/app/context_test.go | 10 +- pkg/app/server/binding/binder.go | 4 +- pkg/app/server/binding/binder_test.go | 98 +++++----- pkg/app/server/binding/config.go | 161 ++++++++++------ pkg/app/server/binding/default.go | 63 ++++-- .../internal/decoder/base_type_decoder.go | 8 +- .../decoder/customized_type_decoder.go | 49 +---- .../binding/internal/decoder/decoder.go | 43 +++-- .../internal/decoder/map_type_decoder.go | 5 +- .../decoder/multipart_file_decoder.go | 3 +- .../internal/decoder/slice_type_decoder.go | 13 +- .../internal/decoder/struct_type_decoder.go | 5 +- .../server/binding/internal/decoder/tag.go | 4 +- .../binding/internal/decoder/text_decoder.go | 27 +-- pkg/app/server/binding/tagexpr_bind_test.go | 27 ++- pkg/app/server/hertz_test.go | 170 ++++++++++++++++ pkg/app/server/option.go | 29 +++ pkg/common/config/option.go | 4 + pkg/common/config/option_test.go | 4 + pkg/route/engine.go | 58 ++++-- pkg/route/engine_test.go | 181 ++++++++++++++++++ pkg/route/routes_test.go | 1 + 23 files changed, 722 insertions(+), 298 deletions(-) diff --git a/pkg/app/context.go b/pkg/app/context.go index 1bed69de8..0acda77cf 100644 --- a/pkg/app/context.go +++ b/pkg/app/context.go @@ -234,9 +234,8 @@ type RequestContext struct { // clientIPFunc get form value by use custom function. formValueFunc FormValueFunc - binder binding.Binder - validator binding.StructValidator - validateTag string + binder binding.Binder + validator binding.StructValidator } // Flush is the shortcut for ctx.Response.GetHijackWriter().Flush(). @@ -260,9 +259,8 @@ func (ctx *RequestContext) SetBinder(binder binding.Binder) { ctx.binder = binder } -func (ctx *RequestContext) SetValidator(validator binding.StructValidator, tag string) { +func (ctx *RequestContext) SetValidator(validator binding.StructValidator) { ctx.validator = validator - ctx.validateTag = tag } func (ctx *RequestContext) GetTraceInfo() traceinfo.TraceInfo { @@ -1318,55 +1316,37 @@ func bodyAllowedForStatus(status int) bool { // BindAndValidate binds data from *RequestContext to obj and validates them if needed. // NOTE: obj should be a pointer. func (ctx *RequestContext) BindAndValidate(obj interface{}) error { - if ctx.binder != nil { - return ctx.binder.BindAndValidate(&ctx.Request, obj, ctx.Params) - } - return binding.BindAndValidate(&ctx.Request, obj, ctx.Params) + return ctx.binder.BindAndValidate(&ctx.Request, obj, ctx.Params) } // Bind binds data from *RequestContext to obj. // NOTE: obj should be a pointer. func (ctx *RequestContext) Bind(obj interface{}) error { - if ctx.binder != nil { - return ctx.binder.Bind(&ctx.Request, obj, ctx.Params) - } - return binding.Bind(&ctx.Request, obj, ctx.Params) + return ctx.binder.Bind(&ctx.Request, obj, ctx.Params) } // Validate validates obj with "vd" tag // NOTE: obj should be a pointer. func (ctx *RequestContext) Validate(obj interface{}) error { - if ctx.validator != nil { - return ctx.validator.ValidateStruct(obj) - } - return binding.Validate(obj) + return ctx.validator.ValidateStruct(obj) } // BindQuery binds query parameters from *RequestContext to obj with 'query' tag. It will only use 'query' tag for binding. // NOTE: obj should be a pointer. func (ctx *RequestContext) BindQuery(obj interface{}) error { - if ctx.binder != nil { - return ctx.binder.BindQuery(&ctx.Request, obj) - } - return binding.DefaultBinder().BindQuery(&ctx.Request, obj) + return ctx.binder.BindQuery(&ctx.Request, obj) } // BindHeader binds header parameters from *RequestContext to obj with 'header' tag. It will only use 'header' tag for binding. // NOTE: obj should be a pointer. func (ctx *RequestContext) BindHeader(obj interface{}) error { - if ctx.binder != nil { - return ctx.binder.BindHeader(&ctx.Request, obj) - } - return binding.DefaultBinder().BindHeader(&ctx.Request, obj) + return ctx.binder.BindHeader(&ctx.Request, obj) } // BindPath binds router parameters from *RequestContext to obj with 'path' tag. It will only use 'path' tag for binding. // NOTE: obj should be a pointer. func (ctx *RequestContext) BindPath(obj interface{}) error { - if ctx.binder != nil { - return ctx.binder.BindPath(&ctx.Request, obj, ctx.Params) - } - return binding.DefaultBinder().BindPath(&ctx.Request, obj, ctx.Params) + return ctx.binder.BindPath(&ctx.Request, obj, ctx.Params) } // BindForm binds form parameters from *RequestContext to obj with 'form' tag. It will only use 'form' tag for binding. @@ -1375,28 +1355,19 @@ func (ctx *RequestContext) BindForm(obj interface{}) error { if len(ctx.Request.Body()) == 0 { return fmt.Errorf("missing form body") } - if ctx.binder != nil { - return ctx.binder.BindForm(&ctx.Request, obj) - } - return binding.DefaultBinder().BindForm(&ctx.Request, obj) + return ctx.binder.BindForm(&ctx.Request, obj) } // BindJSON binds JSON body from *RequestContext. // NOTE: obj should be a pointer. func (ctx *RequestContext) BindJSON(obj interface{}) error { - if ctx.binder != nil { - return ctx.binder.BindJSON(&ctx.Request, obj) - } - return binding.DefaultBinder().BindJSON(&ctx.Request, obj) + return ctx.binder.BindJSON(&ctx.Request, obj) } // BindProtobuf binds protobuf body from *RequestContext. // NOTE: obj should be a pointer. func (ctx *RequestContext) BindProtobuf(obj interface{}) error { - if ctx.binder != nil { - return ctx.binder.BindProtobuf(&ctx.Request, obj) - } - return binding.DefaultBinder().BindProtobuf(&ctx.Request, obj) + return ctx.binder.BindProtobuf(&ctx.Request, obj) } // BindByContentType will select the binding type on the ContentType automatically. diff --git a/pkg/app/context_test.go b/pkg/app/context_test.go index 35470199c..cba91a75a 100644 --- a/pkg/app/context_test.go +++ b/pkg/app/context_test.go @@ -887,8 +887,8 @@ func (m *mockValidator) Engine() interface{} { func TestSetValidator(t *testing.T) { m := &mockValidator{} c := NewContext(0) - c.SetValidator(m, "vt") - defer c.SetValidator(binding.DefaultValidator(), "vd") + c.SetValidator(m) + c.SetBinder(binding.NewDefaultBinder(&binding.BindConfig{ValidateTag: "vt"})) type User struct { Age int `vt:"$>=0&&$<=130"` } @@ -1551,12 +1551,6 @@ func (m *mockBinder) BindProtobuf(request *protocol.Request, i interface{}) erro return nil } -func (m *mockBinder) ValidateTag() string { - return "test" -} - -func (m *mockBinder) SetValidateTag(s string) {} - func TestSetBinder(t *testing.T) { mockBind := &mockBinder{} c := NewContext(0) diff --git a/pkg/app/server/binding/binder.go b/pkg/app/server/binding/binder.go index e36acda10..291d22a04 100644 --- a/pkg/app/server/binding/binder.go +++ b/pkg/app/server/binding/binder.go @@ -55,11 +55,9 @@ type Binder interface { BindForm(*protocol.Request, interface{}) error BindJSON(*protocol.Request, interface{}) error BindProtobuf(*protocol.Request, interface{}) error - ValidateTag() string - SetValidateTag(string) } -var defaultBind Binder = &defaultBinder{validateTag: "vd"} +var defaultBind = NewDefaultBinder(nil) func DefaultBinder() Binder { return defaultBind diff --git a/pkg/app/server/binding/binder_test.go b/pkg/app/server/binding/binder_test.go index 0ba2f1e06..696010d2a 100644 --- a/pkg/app/server/binding/binder_test.go +++ b/pkg/app/server/binding/binder_test.go @@ -57,10 +57,6 @@ import ( "google.golang.org/protobuf/proto" ) -func init() { - SetLooseZeroMode(true) -} - type mockRequest struct { Req *protocol.Request } @@ -407,7 +403,10 @@ func TestBind_ZeroValueBind(t *testing.T) { req := newMockRequest(). SetRequestURI("http://foobar.com?a=&b") - err := DefaultBinder().Bind(req.Req, &s, nil) + bindConfig := &BindConfig{} + bindConfig.LooseZeroMode = true + binder := NewDefaultBinder(bindConfig) + err := binder.Bind(req.Req, &s, nil) if err != nil { t.Fatal(err) } @@ -530,7 +529,8 @@ func TestBind_CustomizedTypeDecode(t *testing.T) { F ***CustomizedDecode } - err := RegTypeUnmarshal(reflect.TypeOf(CustomizedDecode{}), func(req *protocol.Request, params param.Params, text string) (reflect.Value, error) { + bindConfig := &BindConfig{} + err := bindConfig.RegTypeUnmarshal(reflect.TypeOf(CustomizedDecode{}), func(req *protocol.Request, params param.Params, text string) (reflect.Value, error) { q1 := req.URI().QueryArgs().Peek("a") if len(q1) == 0 { return reflect.Value{}, fmt.Errorf("can be nil") @@ -543,11 +543,12 @@ func TestBind_CustomizedTypeDecode(t *testing.T) { if err != nil { t.Fatal(err) } + binder := NewDefaultBinder(bindConfig) req := newMockRequest(). SetRequestURI("http://foobar.com?a=1&b=2") result := Foo{} - err = DefaultBinder().Bind(req.Req, &result, nil) + err = binder.Bind(req.Req, &result, nil) if err != nil { t.Fatal(err) } @@ -558,7 +559,7 @@ func TestBind_CustomizedTypeDecode(t *testing.T) { } result2 := Bar{} - err = DefaultBinder().Bind(req.Req, &result2, nil) + err = binder.Bind(req.Req, &result2, nil) if err != nil { t.Error(err) } @@ -572,9 +573,11 @@ func TestBind_CustomizedTypeDecodeForPanic(t *testing.T) { } }() - MustRegTypeUnmarshal(reflect.TypeOf(string("")), func(req *protocol.Request, params param.Params, text string) (reflect.Value, error) { + bindConfig := &BindConfig{} + bindConfig.MustRegTypeUnmarshal(reflect.TypeOf(string("")), func(req *protocol.Request, params param.Params, text string) (reflect.Value, error) { return reflect.Value{}, nil }) + } func TestBind_JSON(t *testing.T) { @@ -607,7 +610,9 @@ func TestBind_JSON(t *testing.T) { } func TestBind_ResetJSONUnmarshal(t *testing.T) { - UseStdJSONUnmarshaler() + bindConfig := &BindConfig{} + bindConfig.UseStdJSONUnmarshaler() + binder := NewDefaultBinder(bindConfig) type Req struct { J1 string `json:"j1"` J2 int `json:"j2"` @@ -621,7 +626,7 @@ func TestBind_ResetJSONUnmarshal(t *testing.T) { SetJSONContentType(). SetBody([]byte(fmt.Sprintf(`{"j1":"j1", "j2":12, "j3":[%d, %d], "j4":["%s", "%s"]}`, J3s[0], J3s[1], J4s[0], J4s[1]))) var result Req - err := DefaultBinder().Bind(req.Req, &result, nil) + err := binder.Bind(req.Req, &result, nil) if err != nil { t.Error(err) } @@ -799,12 +804,11 @@ func TestBind_DefaultTag(t *testing.T) { assert.DeepEqual(t, "header", result.Header) assert.DeepEqual(t, "form", result.Form) - EnableDefaultTag(false) - defer func() { - EnableDefaultTag(true) - }() + bindConfig := &BindConfig{} + bindConfig.EnableDefaultTag = false + binder := NewDefaultBinder(bindConfig) result2 := Req2{} - err = DefaultBinder().Bind(req.Req, &result2, params) + err = binder.Bind(req.Req, &result2, params) if err != nil { t.Error(err) } @@ -829,11 +833,10 @@ func TestBind_StructFieldResolve(t *testing.T) { SetPostArg("Form", "form"). SetUrlEncodeContentType() var result Req - EnableStructFieldResolve(true) - defer func() { - EnableStructFieldResolve(false) - }() - err := DefaultBinder().Bind(req.Req, &result, nil) + bindConfig := &BindConfig{} + bindConfig.EnableStructFieldResolve = true + binder := NewDefaultBinder(bindConfig) + err := binder.Bind(req.Req, &result, nil) if err != nil { t.Error(err) } @@ -950,8 +953,9 @@ func TestBind_BindQuery(t *testing.T) { } func TestBind_LooseMode(t *testing.T) { - SetLooseZeroMode(false) - defer SetLooseZeroMode(true) + bindConfig := &BindConfig{} + bindConfig.LooseZeroMode = false + binder := NewDefaultBinder(bindConfig) type Req struct { ID int `query:"id"` } @@ -961,16 +965,17 @@ func TestBind_LooseMode(t *testing.T) { var result Req - err := DefaultBinder().Bind(req.Req, &result, nil) + err := binder.Bind(req.Req, &result, nil) if err == nil { t.Fatal("expected err") } assert.DeepEqual(t, 0, result.ID) - SetLooseZeroMode(true) + bindConfig.LooseZeroMode = true + binder = NewDefaultBinder(bindConfig) var result2 Req - err = DefaultBinder().Bind(req.Req, &result2, nil) + err = binder.Bind(req.Req, &result2, nil) if err != nil { t.Error(err) } @@ -1190,8 +1195,9 @@ func TestBind_BindProtobuf(t *testing.T) { } func TestBind_PointerStruct(t *testing.T) { - EnableStructFieldResolve(true) - defer EnableStructFieldResolve(false) + bindConfig := &BindConfig{} + bindConfig.EnableStructFieldResolve = true + binder := NewDefaultBinder(bindConfig) type Foo struct { F1 string `query:"F1"` } @@ -1205,7 +1211,7 @@ func TestBind_PointerStruct(t *testing.T) { req := newMockRequest(). SetRequestURI(fmt.Sprintf("http://foobar.com?%s", query.Encode())) - err := DefaultBinder().Bind(req.Req, &result, nil) + err := binder.Bind(req.Req, &result, nil) if err != nil { t.Error(err) } @@ -1214,7 +1220,7 @@ func TestBind_PointerStruct(t *testing.T) { result = Bar{} req = newMockRequest(). SetRequestURI(fmt.Sprintf("http://foobar.com?%s&F1=222", query.Encode())) - err = DefaultBinder().Bind(req.Req, &result, nil) + err = binder.Bind(req.Req, &result, nil) if err != nil { t.Error(err) } @@ -1222,8 +1228,9 @@ func TestBind_PointerStruct(t *testing.T) { } func TestBind_StructRequired(t *testing.T) { - EnableStructFieldResolve(true) - defer EnableStructFieldResolve(false) + bindConfig := &BindConfig{} + bindConfig.EnableStructFieldResolve = true + binder := NewDefaultBinder(bindConfig) type Foo struct { F1 string `query:"F1"` } @@ -1235,7 +1242,7 @@ func TestBind_StructRequired(t *testing.T) { req := newMockRequest(). SetRequestURI("http://foobar.com") - err := DefaultBinder().Bind(req.Req, &result, nil) + err := binder.Bind(req.Req, &result, nil) if err == nil { t.Error("expect an error, but get nil") } @@ -1247,15 +1254,16 @@ func TestBind_StructRequired(t *testing.T) { req = newMockRequest(). SetRequestURI("http://foobar.com") - err = DefaultBinder().Bind(req.Req, &result2, nil) + err = binder.Bind(req.Req, &result2, nil) if err != nil { t.Error(err) } } func TestBind_StructErrorToWarn(t *testing.T) { - EnableStructFieldResolve(true) - defer EnableStructFieldResolve(false) + bindConfig := &BindConfig{} + bindConfig.EnableStructFieldResolve = true + binder := NewDefaultBinder(bindConfig) type Foo struct { F1 string `query:"F1"` } @@ -1267,7 +1275,7 @@ func TestBind_StructErrorToWarn(t *testing.T) { req := newMockRequest(). SetRequestURI("http://foobar.com?B1=111&F1=222") - err := DefaultBinder().Bind(req.Req, &result, nil) + err := binder.Bind(req.Req, &result, nil) // transfer 'unmarsahl err' to 'warn' if err != nil { t.Error(err) @@ -1278,7 +1286,7 @@ func TestBind_StructErrorToWarn(t *testing.T) { B1 Foo `query:"B1,required"` } var result2 Bar2 - err = DefaultBinder().Bind(req.Req, &result2, nil) + err = binder.Bind(req.Req, &result2, nil) // transfer 'unmarsahl err' to 'warn' if err != nil { t.Error(err) @@ -1287,8 +1295,9 @@ func TestBind_StructErrorToWarn(t *testing.T) { } func TestBind_DisallowUnknownFieldsConfig(t *testing.T) { - EnableDecoderDisallowUnknownFields(true) - defer EnableDecoderDisallowUnknownFields(false) + bindConfig := &BindConfig{} + bindConfig.EnableDecoderDisallowUnknownFields = true + binder := NewDefaultBinder(bindConfig) type FooStructUseNumber struct { Foo interface{} `json:"foo"` } @@ -1298,15 +1307,16 @@ func TestBind_DisallowUnknownFieldsConfig(t *testing.T) { SetBody([]byte(`{"foo": 123,"bar": "456"}`)) var result FooStructUseNumber - err := DefaultBinder().BindJSON(req.Req, &result) + err := binder.BindJSON(req.Req, &result) if err == nil { t.Errorf("expected an error, but get nil") } } func TestBind_UseNumberConfig(t *testing.T) { - EnableDecoderUseNumber(true) - defer EnableDecoderUseNumber(false) + bindConfig := &BindConfig{} + bindConfig.EnableDecoderUseNumber = true + binder := NewDefaultBinder(bindConfig) type FooStructUseNumber struct { Foo interface{} `json:"foo"` } @@ -1316,7 +1326,7 @@ func TestBind_UseNumberConfig(t *testing.T) { SetBody([]byte(`{"foo": 123}`)) var result FooStructUseNumber - err := DefaultBinder().BindJSON(req.Req, &result) + err := binder.BindJSON(req.Req, &result) if err != nil { t.Error(err) } diff --git a/pkg/app/server/binding/config.go b/pkg/app/server/binding/config.go index 1ba2ba3eb..103d2629d 100644 --- a/pkg/app/server/binding/config.go +++ b/pkg/app/server/binding/config.go @@ -18,35 +18,121 @@ package binding import ( stdJson "encoding/json" + "fmt" "reflect" + "time" "github.com/bytedance/go-tagexpr/v2/validator" - "github.com/cloudwego/hertz/pkg/app/server/binding/internal/decoder" + inDecoder "github.com/cloudwego/hertz/pkg/app/server/binding/internal/decoder" hJson "github.com/cloudwego/hertz/pkg/common/json" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/route/param" ) -var ( - enableDecoderUseNumber = false - enableDecoderDisallowUnknownFields = false -) +// BindConfig contains options for default bind behavior. +type BindConfig struct { + // LooseZeroMode if set to true, + // the empty string request parameter is bound to the zero value of parameter. + // NOTE: + // The default is false. + // Suitable for these parameter types: query/header/cookie/form . + LooseZeroMode bool + // EnableDefaultTag is used to add default tags to a field when it has no tag + // If is true, the field with no tag will be added default tags, for more automated binding. But there may be additional overhead. + // NOTE: + // The default is true. + EnableDefaultTag bool + // EnableStructFieldResolve is used to generate a separate decoder for a struct. + // If is true, the 'struct' field will get a single inDecoder.structTypeFieldTextDecoder, and use json.Unmarshal for decode it. + // It usually used to add json string to query parameter. + // NOTE: + // The default is true. + EnableStructFieldResolve bool + // EnableDecoderUseNumber is used to call the UseNumber method on the JSON + // Decoder instance. UseNumber causes the Decoder to unmarshal a number into an + // interface{} as a Number instead of as a float64. + // NOTE: + // The default is false. + // It is used for BindJSON(). + EnableDecoderUseNumber bool + // EnableDecoderDisallowUnknownFields is used to call the DisallowUnknownFields method + // on the JSON Decoder instance. DisallowUnknownFields causes the Decoder to + // return an error when the destination is a struct and the input contains object + // keys which do not match any non-ignored, exported fields in the destination. + // NOTE: + // The default is false. + // It is used for BindJSON(). + EnableDecoderDisallowUnknownFields bool + // ValidateTag is used to determine if a filed needs to be validated. + // NOTE: + // The default is "vd". + ValidateTag string + // TypeUnmarshalFuncs registers customized type unmarshaler. + // NOTE: + // time.Time is registered by default + TypeUnmarshalFuncs map[reflect.Type]inDecoder.CustomizeDecodeFunc + // Validator is used to validate for BindAndValidate() + Validator StructValidator +} -// SetLooseZeroMode if set to true, -// the empty string request parameter is bound to the zero value of parameter. -// NOTE: -// -// The default is false; -// Suitable for these parameter types: query/header/cookie/form . -func SetLooseZeroMode(enable bool) { - decoder.SetLooseZeroMode(enable) +func NewBindConfig() *BindConfig { + return &BindConfig{ + LooseZeroMode: false, + EnableDefaultTag: true, + EnableStructFieldResolve: true, + EnableDecoderUseNumber: false, + EnableDecoderDisallowUnknownFields: false, + ValidateTag: "vd", + TypeUnmarshalFuncs: make(map[reflect.Type]inDecoder.CustomizeDecodeFunc), + Validator: defaultValidate, + } +} + +// RegTypeUnmarshal registers customized type unmarshaler. +func (config *BindConfig) RegTypeUnmarshal(t reflect.Type, fn inDecoder.CustomizeDecodeFunc) error { + // check + switch t.Kind() { + case reflect.String, reflect.Bool, + reflect.Float32, reflect.Float64, + reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8, + reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8: + return fmt.Errorf("registration type cannot be a basic type") + case reflect.Ptr: + return fmt.Errorf("registration type cannot be a pointer type") + } + if config.TypeUnmarshalFuncs == nil { + config.TypeUnmarshalFuncs = make(map[reflect.Type]inDecoder.CustomizeDecodeFunc) + } + config.TypeUnmarshalFuncs[t] = fn + return nil +} + +// MustRegTypeUnmarshal registers customized type unmarshaler. It will panic if exist error. +func (config *BindConfig) MustRegTypeUnmarshal(t reflect.Type, fn func(req *protocol.Request, params param.Params, text string) (reflect.Value, error)) { + err := config.RegTypeUnmarshal(t, fn) + if err != nil { + panic(err) + } +} + +func (config *BindConfig) initTypeUnmarshal() { + config.MustRegTypeUnmarshal(reflect.TypeOf(time.Time{}), func(req *protocol.Request, params param.Params, text string) (reflect.Value, error) { + if text == "" { + return reflect.ValueOf(time.Time{}), nil + } + t, err := time.Parse(time.RFC3339, text) + if err != nil { + return reflect.Value{}, err + } + return reflect.ValueOf(t), nil + }) } // UseThirdPartyJSONUnmarshaler uses third-party json library for binding // NOTE: // // UseThirdPartyJSONUnmarshaler will remain in effect once it has been called. -func UseThirdPartyJSONUnmarshaler(fn func(data []byte, v interface{}) error) { +func (config *BindConfig) UseThirdPartyJSONUnmarshaler(fn func(data []byte, v interface{}) error) { hJson.Unmarshal = fn } @@ -55,63 +141,26 @@ func UseThirdPartyJSONUnmarshaler(fn func(data []byte, v interface{}) error) { // // The current version uses encoding/json by default. // UseStdJSONUnmarshaler will remain in effect once it has been called. -func UseStdJSONUnmarshaler() { - UseThirdPartyJSONUnmarshaler(stdJson.Unmarshal) -} - -// EnableDefaultTag is used to enable or disable adding default tags to a field when it has no tag, it is true by default. -// If is true, the field with no tag will be added default tags, for more automated parameter binding. But there may be additional overhead -func EnableDefaultTag(b bool) { - decoder.EnableDefaultTag = b -} - -// EnableStructFieldResolve to enable or disable the generation of a separate decoder for a struct, it is false by default. -// If is true, the 'struct' field will get a single decoder.structTypeFieldTextDecoder, and use json.Unmarshal for decode it. -func EnableStructFieldResolve(b bool) { - decoder.EnableStructFieldResolve = b -} - -// RegTypeUnmarshal registers customized type unmarshaler. -func RegTypeUnmarshal(t reflect.Type, fn func(req *protocol.Request, params param.Params, text string) (reflect.Value, error)) error { - return decoder.RegTypeUnmarshal(t, fn) +func (config *BindConfig) UseStdJSONUnmarshaler() { + config.UseThirdPartyJSONUnmarshaler(stdJson.Unmarshal) } -// MustRegTypeUnmarshal registers customized type unmarshaler. It will panic if exist error. -func MustRegTypeUnmarshal(t reflect.Type, fn func(req *protocol.Request, params param.Params, text string) (reflect.Value, error)) { - decoder.MustRegTypeUnmarshal(t, fn) -} +type ValidateConfig struct{} // MustRegValidateFunc registers validator function expression. // NOTE: // // If force=true, allow to cover the existed same funcName. // MustRegValidateFunc will remain in effect once it has been called. -func MustRegValidateFunc(funcName string, fn func(args ...interface{}) error, force ...bool) { +func (config *ValidateConfig) MustRegValidateFunc(funcName string, fn func(args ...interface{}) error, force ...bool) { validator.MustRegFunc(funcName, fn, force...) } // SetValidatorErrorFactory customizes the factory of validation error. -func SetValidatorErrorFactory(validatingErrFactory func(failField, msg string) error) { +func (config *ValidateConfig) SetValidatorErrorFactory(validatingErrFactory func(failField, msg string) error) { if val, ok := DefaultValidator().(*defaultValidator); ok { val.validate.SetErrorFactory(validatingErrFactory) } else { panic("customized validator can not use 'SetValidatorErrorFactory'") } } - -// EnableDecoderUseNumber is used to call the UseNumber method on the JSON -// Decoder instance. UseNumber causes the Decoder to unmarshal a number into an -// interface{} as a Number instead of as a float64. -// NOTE: it is used for BindJSON(). -func EnableDecoderUseNumber(b bool) { - enableDecoderUseNumber = b -} - -// EnableDecoderDisallowUnknownFields is used to call the DisallowUnknownFields method -// on the JSON Decoder instance. DisallowUnknownFields causes the Decoder to -// return an error when the destination is a struct and the input contains object -// keys which do not match any non-ignored, exported fields in the destination. -// NOTE: it is used for BindJSON(). -func EnableDecoderDisallowUnknownFields(b bool) { - enableDecoderDisallowUnknownFields = b -} diff --git a/pkg/app/server/binding/default.go b/pkg/app/server/binding/default.go index 63c8eb74d..cc241d820 100644 --- a/pkg/app/server/binding/default.go +++ b/pkg/app/server/binding/default.go @@ -93,7 +93,7 @@ type decoderInfo struct { } type defaultBinder struct { - validateTag string + config *BindConfig decoderCache sync.Map queryDecoderCache sync.Map formDecoderCache sync.Map @@ -101,6 +101,20 @@ type defaultBinder struct { pathDecoderCache sync.Map } +func NewDefaultBinder(config *BindConfig) Binder { + if config == nil { + bindConfig := NewBindConfig() + bindConfig.initTypeUnmarshal() + return &defaultBinder{ + config: bindConfig, + } + } + config.initTypeUnmarshal() + return &defaultBinder{ + config: config, + } +} + // BindAndValidate binds data from *protocol.Request to obj and validates them if needed. // NOTE: // @@ -163,7 +177,16 @@ func (b *defaultBinder) bindTag(req *protocol.Request, v interface{}, params par return decoder.decoder(req, params, rv.Elem()) } - decoder, needValidate, err := inDecoder.GetReqDecoder(rv.Type(), tag, b.ValidateTag()) + decodeConfig := &inDecoder.DecodeConfig{ + LooseZeroMode: b.config.LooseZeroMode, + EnableDefaultTag: b.config.EnableDefaultTag, + EnableStructFieldResolve: b.config.EnableStructFieldResolve, + EnableDecoderUseNumber: b.config.EnableDecoderUseNumber, + EnableDecoderDisallowUnknownFields: b.config.EnableDecoderDisallowUnknownFields, + ValidateTag: b.config.ValidateTag, + TypeUnmarshalFuncs: b.config.TypeUnmarshalFuncs, + } + decoder, needValidate, err := inDecoder.GetReqDecoder(rv.Type(), tag, decodeConfig) if err != nil { return err } @@ -196,12 +219,20 @@ func (b *defaultBinder) bindTagWithValidate(req *protocol.Request, v interface{} return err } if decoder.needValidate { - err = DefaultValidator().ValidateStruct(rv.Elem()) + err = b.config.Validator.ValidateStruct(rv.Elem()) } return err } - - decoder, needValidate, err := inDecoder.GetReqDecoder(rv.Type(), tag, b.ValidateTag()) + decodeConfig := &inDecoder.DecodeConfig{ + LooseZeroMode: b.config.LooseZeroMode, + EnableDefaultTag: b.config.EnableDefaultTag, + EnableStructFieldResolve: b.config.EnableStructFieldResolve, + EnableDecoderUseNumber: b.config.EnableDecoderUseNumber, + EnableDecoderDisallowUnknownFields: b.config.EnableDecoderDisallowUnknownFields, + ValidateTag: b.config.ValidateTag, + TypeUnmarshalFuncs: b.config.TypeUnmarshalFuncs, + } + decoder, needValidate, err := inDecoder.GetReqDecoder(rv.Type(), tag, decodeConfig) if err != nil { return err } @@ -212,7 +243,7 @@ func (b *defaultBinder) bindTagWithValidate(req *protocol.Request, v interface{} return err } if needValidate { - err = DefaultValidator().ValidateStruct(rv.Elem()) + err = b.config.Validator.ValidateStruct(rv.Elem()) } return err } @@ -234,15 +265,15 @@ func (b *defaultBinder) BindForm(req *protocol.Request, v interface{}) error { } func (b *defaultBinder) BindJSON(req *protocol.Request, v interface{}) error { - return decodeJSON(bytes.NewReader(req.Body()), v) + return b.decodeJSON(bytes.NewReader(req.Body()), v) } -func decodeJSON(r io.Reader, obj interface{}) error { +func (b *defaultBinder) decodeJSON(r io.Reader, obj interface{}) error { decoder := hJson.NewDecoder(r) - if enableDecoderUseNumber { + if b.config.EnableDecoderUseNumber { decoder.UseNumber() } - if enableDecoderDisallowUnknownFields { + if b.config.EnableDecoderDisallowUnknownFields { decoder.DisallowUnknownFields() } @@ -269,14 +300,6 @@ func (b *defaultBinder) Bind(req *protocol.Request, v interface{}, params param. return b.bindTag(req, v, params, "") } -func (b *defaultBinder) ValidateTag() string { - return b.validateTag -} - -func (b *defaultBinder) SetValidateTag(tag string) { - b.validateTag = tag -} - // best effort binding func (b *defaultBinder) preBindBody(req *protocol.Request, v interface{}) error { if req.Header.ContentLength() <= 0 { @@ -346,6 +369,10 @@ type defaultValidator struct { validate *validator.Validator } +func NewDefaultValidator(config *ValidateConfig) StructValidator { + return defaultValidate +} + // ValidateStruct receives any kind of type, but only performed struct or pointer to struct type. func (v *defaultValidator) ValidateStruct(obj interface{}) error { if obj == nil { diff --git a/pkg/app/server/binding/internal/decoder/base_type_decoder.go b/pkg/app/server/binding/internal/decoder/base_type_decoder.go index dd16b0b47..c3f4346f3 100644 --- a/pkg/app/server/binding/internal/decoder/base_type_decoder.go +++ b/pkg/app/server/binding/internal/decoder/base_type_decoder.go @@ -55,6 +55,7 @@ type fieldInfo struct { fieldName string tagInfos []TagInfo fieldType reflect.Type + config *DecodeConfig } type baseTypeFieldTextDecoder struct { @@ -114,7 +115,7 @@ func (d *baseTypeFieldTextDecoder) Decode(req *protocol.Request, params param.Pa ptrDepth++ } var vv reflect.Value - vv, err := stringToValue(t, text, req, params) + vv, err := stringToValue(t, text, req, params, d.config) if err != nil { return err } @@ -123,7 +124,7 @@ func (d *baseTypeFieldTextDecoder) Decode(req *protocol.Request, params param.Pa } // Non-pointer elems - err = d.decoder.UnmarshalString(text, field) + err = d.decoder.UnmarshalString(text, field, d.config.LooseZeroMode) if err != nil { return fmt.Errorf("unable to decode '%s' as %s: %w", text, d.fieldType.Name(), err) } @@ -131,7 +132,7 @@ func (d *baseTypeFieldTextDecoder) Decode(req *protocol.Request, params param.Pa return nil } -func getBaseTypeTextDecoder(field reflect.StructField, index int, tagInfos []TagInfo, parentIdx []int) ([]fieldDecoder, error) { +func getBaseTypeTextDecoder(field reflect.StructField, index int, tagInfos []TagInfo, parentIdx []int, config *DecodeConfig) ([]fieldDecoder, error) { for idx, tagInfo := range tagInfos { switch tagInfo.Key { case pathTag: @@ -177,6 +178,7 @@ func getBaseTypeTextDecoder(field reflect.StructField, index int, tagInfos []Tag fieldName: field.Name, tagInfos: tagInfos, fieldType: fieldType, + config: config, }, decoder: textDecoder, }}, nil diff --git a/pkg/app/server/binding/internal/decoder/customized_type_decoder.go b/pkg/app/server/binding/internal/decoder/customized_type_decoder.go index 37966ad27..b192b3bda 100644 --- a/pkg/app/server/binding/internal/decoder/customized_type_decoder.go +++ b/pkg/app/server/binding/internal/decoder/customized_type_decoder.go @@ -41,60 +41,18 @@ package decoder import ( - "fmt" "reflect" - "time" "github.com/cloudwego/hertz/pkg/common/utils" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/route/param" ) -func init() { - MustRegTypeUnmarshal(reflect.TypeOf(time.Time{}), func(req *protocol.Request, params param.Params, text string) (reflect.Value, error) { - if text == "" { - return reflect.ValueOf(time.Time{}), nil - } - t, err := time.Parse(time.RFC3339, text) - if err != nil { - return reflect.Value{}, err - } - return reflect.ValueOf(t), nil - }) -} - -type customizeDecodeFunc func(req *protocol.Request, params param.Params, text string) (reflect.Value, error) - -var typeUnmarshalFuncs = make(map[reflect.Type]customizeDecodeFunc) - -// RegTypeUnmarshal registers customized type unmarshaler. -func RegTypeUnmarshal(t reflect.Type, fn customizeDecodeFunc) error { - // check - switch t.Kind() { - case reflect.String, reflect.Bool, - reflect.Float32, reflect.Float64, - reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8, - reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8: - return fmt.Errorf("registration type cannot be a basic type") - case reflect.Ptr: - return fmt.Errorf("registration type cannot be a pointer type") - } - - typeUnmarshalFuncs[t] = fn - return nil -} - -// MustRegTypeUnmarshal registers customized type unmarshaler. It will panic if exist error. -func MustRegTypeUnmarshal(t reflect.Type, fn customizeDecodeFunc) { - err := RegTypeUnmarshal(t, fn) - if err != nil { - panic(err) - } -} +type CustomizeDecodeFunc func(req *protocol.Request, params param.Params, text string) (reflect.Value, error) type customizedFieldTextDecoder struct { fieldInfo - decodeFunc customizeDecodeFunc + decodeFunc CustomizeDecodeFunc } func (d *customizedFieldTextDecoder) Decode(req *protocol.Request, params param.Params, reqValue reflect.Value) error { @@ -141,7 +99,7 @@ func (d *customizedFieldTextDecoder) Decode(req *protocol.Request, params param. return nil } -func getCustomizedFieldDecoder(field reflect.StructField, index int, tagInfos []TagInfo, parentIdx []int, decodeFunc customizeDecodeFunc) ([]fieldDecoder, error) { +func getCustomizedFieldDecoder(field reflect.StructField, index int, tagInfos []TagInfo, parentIdx []int, decodeFunc CustomizeDecodeFunc, config *DecodeConfig) ([]fieldDecoder, error) { for idx, tagInfo := range tagInfos { switch tagInfo.Key { case pathTag: @@ -180,6 +138,7 @@ func getCustomizedFieldDecoder(field reflect.StructField, index int, tagInfos [] fieldName: field.Name, tagInfos: tagInfos, fieldType: fieldType, + config: config, }, decodeFunc: decodeFunc, }}, nil diff --git a/pkg/app/server/binding/internal/decoder/decoder.go b/pkg/app/server/binding/internal/decoder/decoder.go index 7da7be69a..425adfc6b 100644 --- a/pkg/app/server/binding/internal/decoder/decoder.go +++ b/pkg/app/server/binding/internal/decoder/decoder.go @@ -49,18 +49,23 @@ import ( "github.com/cloudwego/hertz/pkg/route/param" ) -var ( - EnableDefaultTag = true - EnableStructFieldResolve = true -) - type fieldDecoder interface { Decode(req *protocol.Request, params param.Params, reqValue reflect.Value) error } type Decoder func(req *protocol.Request, params param.Params, rv reflect.Value) error -func GetReqDecoder(rt reflect.Type, byTag string, validateTag string) (Decoder, bool, error) { +type DecodeConfig struct { + LooseZeroMode bool + EnableDefaultTag bool + EnableStructFieldResolve bool + EnableDecoderUseNumber bool + EnableDecoderDisallowUnknownFields bool + ValidateTag string + TypeUnmarshalFuncs map[reflect.Type]CustomizeDecodeFunc +} + +func GetReqDecoder(rt reflect.Type, byTag string, config *DecodeConfig) (Decoder, bool, error) { var decoders []fieldDecoder var needValidate bool @@ -75,7 +80,7 @@ func GetReqDecoder(rt reflect.Type, byTag string, validateTag string) (Decoder, continue } - dec, needValidate2, err := getFieldDecoder(el.Field(i), i, []int{}, "", byTag, validateTag) + dec, needValidate2, err := getFieldDecoder(el.Field(i), i, []int{}, "", byTag, config) if err != nil { return nil, false, err } @@ -98,7 +103,7 @@ func GetReqDecoder(rt reflect.Type, byTag string, validateTag string) (Decoder, }, needValidate, nil } -func getFieldDecoder(field reflect.StructField, index int, parentIdx []int, parentJSONName string, byTag string, validateTag string) ([]fieldDecoder, bool, error) { +func getFieldDecoder(field reflect.StructField, index int, parentIdx []int, parentJSONName string, byTag string, config *DecodeConfig) ([]fieldDecoder, bool, error) { for field.Type.Kind() == reflect.Ptr { field.Type = field.Type.Elem() } @@ -111,8 +116,8 @@ func getFieldDecoder(field reflect.StructField, index int, parentIdx []int, pare } // JSONName is like 'a.b.c' for 'required validate' - fieldTagInfos, newParentJSONName, needValidate := lookupFieldTags(field, parentJSONName, validateTag) - if len(fieldTagInfos) == 0 && EnableDefaultTag { + fieldTagInfos, newParentJSONName, needValidate := lookupFieldTags(field, parentJSONName, config) + if len(fieldTagInfos) == 0 && config.EnableDefaultTag { fieldTagInfos = getDefaultFieldTags(field) } if len(byTag) != 0 { @@ -120,20 +125,20 @@ func getFieldDecoder(field reflect.StructField, index int, parentIdx []int, pare } // customized type decoder has the highest priority - if customizedFunc, exist := typeUnmarshalFuncs[field.Type]; exist { - dec, err := getCustomizedFieldDecoder(field, index, fieldTagInfos, parentIdx, customizedFunc) + if customizedFunc, exist := config.TypeUnmarshalFuncs[field.Type]; exist { + dec, err := getCustomizedFieldDecoder(field, index, fieldTagInfos, parentIdx, customizedFunc, config) return dec, needValidate, err } // slice/array field decoder if field.Type.Kind() == reflect.Slice || field.Type.Kind() == reflect.Array { - dec, err := getSliceFieldDecoder(field, index, fieldTagInfos, parentIdx) + dec, err := getSliceFieldDecoder(field, index, fieldTagInfos, parentIdx, config) return dec, needValidate, err } // map filed decoder if field.Type.Kind() == reflect.Map { - dec, err := getMapTypeTextDecoder(field, index, fieldTagInfos, parentIdx) + dec, err := getMapTypeTextDecoder(field, index, fieldTagInfos, parentIdx, config) return dec, needValidate, err } @@ -144,11 +149,11 @@ func getFieldDecoder(field reflect.StructField, index int, parentIdx []int, pare // todo: more built-in common struct binding, ex. time... switch el { case reflect.TypeOf(multipart.FileHeader{}): // file binding - dec, err := getMultipartFileDecoder(field, index, fieldTagInfos, parentIdx) + dec, err := getMultipartFileDecoder(field, index, fieldTagInfos, parentIdx, config) return dec, needValidate, err } - if EnableStructFieldResolve { // decode struct type separately - structFieldDecoder, err := getStructTypeFieldDecoder(field, index, fieldTagInfos, parentIdx) + if config.EnableStructFieldResolve { // decode struct type separately + structFieldDecoder, err := getStructTypeFieldDecoder(field, index, fieldTagInfos, parentIdx, config) if err != nil { return nil, needValidate, err } @@ -167,7 +172,7 @@ func getFieldDecoder(field reflect.StructField, index int, parentIdx []int, pare idxes = append(idxes, parentIdx...) } idxes = append(idxes, index) - dec, needValidate2, err := getFieldDecoder(el.Field(i), i, idxes, newParentJSONName, byTag, validateTag) + dec, needValidate2, err := getFieldDecoder(el.Field(i), i, idxes, newParentJSONName, byTag, config) needValidate = needValidate || needValidate2 if err != nil { return nil, false, err @@ -181,6 +186,6 @@ func getFieldDecoder(field reflect.StructField, index int, parentIdx []int, pare } // base type decoder - dec, err := getBaseTypeTextDecoder(field, index, fieldTagInfos, parentIdx) + dec, err := getBaseTypeTextDecoder(field, index, fieldTagInfos, parentIdx, config) return dec, needValidate, err } diff --git a/pkg/app/server/binding/internal/decoder/map_type_decoder.go b/pkg/app/server/binding/internal/decoder/map_type_decoder.go index 9bb5180b0..34b1104fa 100644 --- a/pkg/app/server/binding/internal/decoder/map_type_decoder.go +++ b/pkg/app/server/binding/internal/decoder/map_type_decoder.go @@ -106,7 +106,7 @@ func (d *mapTypeFieldTextDecoder) Decode(req *protocol.Request, params param.Par ptrDepth++ } var vv reflect.Value - vv, err := stringToValue(t, text, req, params) + vv, err := stringToValue(t, text, req, params, d.config) if err != nil { return fmt.Errorf("unable to decode '%s' as %s: %w", text, d.fieldType.Name(), err) } @@ -122,7 +122,7 @@ func (d *mapTypeFieldTextDecoder) Decode(req *protocol.Request, params param.Par return nil } -func getMapTypeTextDecoder(field reflect.StructField, index int, tagInfos []TagInfo, parentIdx []int) ([]fieldDecoder, error) { +func getMapTypeTextDecoder(field reflect.StructField, index int, tagInfos []TagInfo, parentIdx []int, config *DecodeConfig) ([]fieldDecoder, error) { for idx, tagInfo := range tagInfos { switch tagInfo.Key { case pathTag: @@ -163,6 +163,7 @@ func getMapTypeTextDecoder(field reflect.StructField, index int, tagInfos []TagI fieldName: field.Name, tagInfos: tagInfos, fieldType: fieldType, + config: config, }, }}, nil } diff --git a/pkg/app/server/binding/internal/decoder/multipart_file_decoder.go b/pkg/app/server/binding/internal/decoder/multipart_file_decoder.go index c37c0e292..ae32dfea5 100644 --- a/pkg/app/server/binding/internal/decoder/multipart_file_decoder.go +++ b/pkg/app/server/binding/internal/decoder/multipart_file_decoder.go @@ -141,7 +141,7 @@ func (d *fileTypeDecoder) fileSliceDecode(req *protocol.Request, params param.Pa return nil } -func getMultipartFileDecoder(field reflect.StructField, index int, tagInfos []TagInfo, parentIdx []int) ([]fieldDecoder, error) { +func getMultipartFileDecoder(field reflect.StructField, index int, tagInfos []TagInfo, parentIdx []int, config *DecodeConfig) ([]fieldDecoder, error) { fieldType := field.Type for field.Type.Kind() == reflect.Ptr { fieldType = field.Type.Elem() @@ -158,6 +158,7 @@ func getMultipartFileDecoder(field reflect.StructField, index int, tagInfos []Ta fieldName: field.Name, tagInfos: tagInfos, fieldType: fieldType, + config: config, }, isRepeated: isRepeated, }}, nil diff --git a/pkg/app/server/binding/internal/decoder/slice_type_decoder.go b/pkg/app/server/binding/internal/decoder/slice_type_decoder.go index be505f03c..66b93ff13 100644 --- a/pkg/app/server/binding/internal/decoder/slice_type_decoder.go +++ b/pkg/app/server/binding/internal/decoder/slice_type_decoder.go @@ -142,7 +142,7 @@ func (d *sliceTypeFieldTextDecoder) Decode(req *protocol.Request, params param.P for idx, text := range texts { var vv reflect.Value - vv, err = stringToValue(t, text, req, params) + vv, err = stringToValue(t, text, req, params, d.config) if err != nil { break } @@ -164,7 +164,7 @@ func (d *sliceTypeFieldTextDecoder) Decode(req *protocol.Request, params param.P return nil } -func getSliceFieldDecoder(field reflect.StructField, index int, tagInfos []TagInfo, parentIdx []int) ([]fieldDecoder, error) { +func getSliceFieldDecoder(field reflect.StructField, index int, tagInfos []TagInfo, parentIdx []int, config *DecodeConfig) ([]fieldDecoder, error) { if !(field.Type.Kind() == reflect.Slice || field.Type.Kind() == reflect.Array) { return nil, fmt.Errorf("unexpected type %s, expected slice or array", field.Type.String()) } @@ -207,7 +207,7 @@ func getSliceFieldDecoder(field reflect.StructField, index int, tagInfos []TagIn // fieldType.Elem() is the type for array/slice elem t := getElemType(fieldType.Elem()) if t == reflect.TypeOf(multipart.FileHeader{}) { - return getMultipartFileDecoder(field, index, tagInfos, parentIdx) + return getMultipartFileDecoder(field, index, tagInfos, parentIdx, config) } return []fieldDecoder{&sliceTypeFieldTextDecoder{ @@ -217,14 +217,15 @@ func getSliceFieldDecoder(field reflect.StructField, index int, tagInfos []TagIn fieldName: field.Name, tagInfos: tagInfos, fieldType: fieldType, + config: config, }, isArray: isArray, }}, nil } -func stringToValue(elemType reflect.Type, text string, req *protocol.Request, params param.Params) (v reflect.Value, err error) { +func stringToValue(elemType reflect.Type, text string, req *protocol.Request, params param.Params, config *DecodeConfig) (v reflect.Value, err error) { v = reflect.New(elemType).Elem() - if customizedFunc, exist := typeUnmarshalFuncs[elemType]; exist { + if customizedFunc, exist := config.TypeUnmarshalFuncs[elemType]; exist { val, err := customizedFunc(req, params, text) if err != nil { return reflect.Value{}, err @@ -243,7 +244,7 @@ func stringToValue(elemType reflect.Type, text string, req *protocol.Request, pa if err != nil { return reflect.Value{}, fmt.Errorf("unsupported type %s for slice/array", elemType.String()) } - err = decoder.UnmarshalString(text, v) + err = decoder.UnmarshalString(text, v, config.LooseZeroMode) if err != nil { return reflect.Value{}, fmt.Errorf("unable to decode '%s' as %s: %w", text, elemType.String(), err) } diff --git a/pkg/app/server/binding/internal/decoder/struct_type_decoder.go b/pkg/app/server/binding/internal/decoder/struct_type_decoder.go index e157b4a0a..c00a633c7 100644 --- a/pkg/app/server/binding/internal/decoder/struct_type_decoder.go +++ b/pkg/app/server/binding/internal/decoder/struct_type_decoder.go @@ -82,7 +82,7 @@ func (d *structTypeFieldTextDecoder) Decode(req *protocol.Request, params param. ptrDepth++ } var vv reflect.Value - vv, err := stringToValue(t, text, req, params) + vv, err := stringToValue(t, text, req, params, d.config) if err != nil { hlog.Infof("unable to decode '%s' as %s: %v, but it may not affect correctness, so skip it", text, d.fieldType.Name(), err) return nil @@ -99,7 +99,7 @@ func (d *structTypeFieldTextDecoder) Decode(req *protocol.Request, params param. return nil } -func getStructTypeFieldDecoder(field reflect.StructField, index int, tagInfos []TagInfo, parentIdx []int) ([]fieldDecoder, error) { +func getStructTypeFieldDecoder(field reflect.StructField, index int, tagInfos []TagInfo, parentIdx []int, config *DecodeConfig) ([]fieldDecoder, error) { for idx, tagInfo := range tagInfos { switch tagInfo.Key { case pathTag: @@ -140,6 +140,7 @@ func getStructTypeFieldDecoder(field reflect.StructField, index int, tagInfos [] fieldName: field.Name, tagInfos: tagInfos, fieldType: fieldType, + config: config, }, }}, nil } diff --git a/pkg/app/server/binding/internal/decoder/tag.go b/pkg/app/server/binding/internal/decoder/tag.go index ee3ed388d..6df09aaa3 100644 --- a/pkg/app/server/binding/internal/decoder/tag.go +++ b/pkg/app/server/binding/internal/decoder/tag.go @@ -60,10 +60,10 @@ func head(str, sep string) (head, tail string) { return str[:idx], str[idx+len(sep):] } -func lookupFieldTags(field reflect.StructField, parentJSONName string, validateTag string) ([]TagInfo, string, bool) { +func lookupFieldTags(field reflect.StructField, parentJSONName string, config *DecodeConfig) ([]TagInfo, string, bool) { var ret []string var needValidate bool - if _, ok := field.Tag.Lookup(validateTag); ok { + if _, ok := field.Tag.Lookup(config.ValidateTag); ok { needValidate = true } tags := []string{pathTag, formTag, queryTag, cookieTag, headerTag, jsonTag, rawBodyTag, fileNameTag} diff --git a/pkg/app/server/binding/internal/decoder/text_decoder.go b/pkg/app/server/binding/internal/decoder/text_decoder.go index 7224d61fc..8b53c2bf5 100644 --- a/pkg/app/server/binding/internal/decoder/text_decoder.go +++ b/pkg/app/server/binding/internal/decoder/text_decoder.go @@ -49,20 +49,8 @@ import ( hJson "github.com/cloudwego/hertz/pkg/common/json" ) -var looseZeroMode = false - -// SetLooseZeroMode if set to true, -// the empty string request parameter is bound to the zero value of parameter. -// NOTE: -// -// The default is false; -// Suitable for these parameter types: query/header/cookie/form . -func SetLooseZeroMode(enable bool) { - looseZeroMode = enable -} - type TextDecoder interface { - UnmarshalString(s string, fieldValue reflect.Value) error + UnmarshalString(s string, fieldValue reflect.Value, looseZeroMode bool) error } func SelectTextDecoder(rt reflect.Type) (TextDecoder, error) { @@ -97,7 +85,6 @@ func SelectTextDecoder(rt reflect.Type) (TextDecoder, error) { return &floatDecoder{bitSize: 64}, nil case reflect.Interface: return &interfaceDecoder{}, nil - } return nil, fmt.Errorf("unsupported type " + rt.String()) @@ -105,7 +92,7 @@ func SelectTextDecoder(rt reflect.Type) (TextDecoder, error) { type boolDecoder struct{} -func (d *boolDecoder) UnmarshalString(s string, fieldValue reflect.Value) error { +func (d *boolDecoder) UnmarshalString(s string, fieldValue reflect.Value, looseZeroMode bool) error { if s == "" && looseZeroMode { s = "false" } @@ -121,7 +108,7 @@ type floatDecoder struct { bitSize int } -func (d *floatDecoder) UnmarshalString(s string, fieldValue reflect.Value) error { +func (d *floatDecoder) UnmarshalString(s string, fieldValue reflect.Value, looseZeroMode bool) error { if s == "" && looseZeroMode { s = "0.0" } @@ -137,7 +124,7 @@ type intDecoder struct { bitSize int } -func (d *intDecoder) UnmarshalString(s string, fieldValue reflect.Value) error { +func (d *intDecoder) UnmarshalString(s string, fieldValue reflect.Value, looseZeroMode bool) error { if s == "" && looseZeroMode { s = "0" } @@ -151,7 +138,7 @@ func (d *intDecoder) UnmarshalString(s string, fieldValue reflect.Value) error { type stringDecoder struct{} -func (d *stringDecoder) UnmarshalString(s string, fieldValue reflect.Value) error { +func (d *stringDecoder) UnmarshalString(s string, fieldValue reflect.Value, looseZeroMode bool) error { fieldValue.SetString(s) return nil } @@ -160,7 +147,7 @@ type uintDecoder struct { bitSize int } -func (d *uintDecoder) UnmarshalString(s string, fieldValue reflect.Value) error { +func (d *uintDecoder) UnmarshalString(s string, fieldValue reflect.Value, looseZeroMode bool) error { if s == "" && looseZeroMode { s = "0" } @@ -174,7 +161,7 @@ func (d *uintDecoder) UnmarshalString(s string, fieldValue reflect.Value) error type interfaceDecoder struct{} -func (d *interfaceDecoder) UnmarshalString(s string, fieldValue reflect.Value) error { +func (d *interfaceDecoder) UnmarshalString(s string, fieldValue reflect.Value, looseZeroMode bool) error { if s == "" && looseZeroMode { s = "0" } diff --git a/pkg/app/server/binding/tagexpr_bind_test.go b/pkg/app/server/binding/tagexpr_bind_test.go index 0f24f696a..e01c4b0ab 100644 --- a/pkg/app/server/binding/tagexpr_bind_test.go +++ b/pkg/app/server/binding/tagexpr_bind_test.go @@ -52,10 +52,6 @@ import ( "google.golang.org/protobuf/proto" ) -func init() { - SetLooseZeroMode(true) -} - func TestRawBody(t *testing.T) { type Recv struct { S []byte `raw_body:""` @@ -94,7 +90,10 @@ func TestQueryString(t *testing.T) { } req := newRequest("http://localhost:8080/?a=a1&a=a2&b=b1&c=c1&c=c2&d=d1&d=d&f=qps&g=1002&g=1003&e=&e=2&y=y1", nil, nil, nil) recv := new(Recv) - err := DefaultBinder().Bind(req.Req, recv, nil) + bindConfig := &BindConfig{} + bindConfig.LooseZeroMode = true + binder := NewDefaultBinder(bindConfig) + err := binder.Bind(req.Req, recv, nil) if err != nil { t.Error(err) } @@ -784,11 +783,10 @@ func TestOption(t *testing.T) { }`) req = newRequest("", header, nil, bodyReader) recv2 := new(Recv2) - EnableStructFieldResolve(true) - defer func() { - EnableStructFieldResolve(false) - }() - err = DefaultBinder().Bind(req.Req, recv2, nil) + bindConfig := &BindConfig{} + bindConfig.EnableStructFieldResolve = true + binder := NewDefaultBinder(bindConfig) + err = binder.Bind(req.Req, recv2, nil) assert.DeepEqual(t, err.Error(), "'X' field is a 'required' parameter, but the request does not have this parameter") assert.True(t, recv2.X == nil) assert.DeepEqual(t, "y1", recv2.Y) @@ -937,11 +935,10 @@ func TestRegTypeUnmarshal(t *testing.T) { req := newRequest("http://localhost:8080/?"+values.Encode(), nil, nil, nil) recv := new(T) - EnableStructFieldResolve(true) - defer func() { - EnableStructFieldResolve(false) - }() - err = DefaultBinder().Bind(req.Req, recv, nil) + bindConfig := &BindConfig{} + bindConfig.EnableStructFieldResolve = true + binder := NewDefaultBinder(bindConfig) + err = binder.Bind(req.Req, recv, nil) if err != nil { t.Error(err) } diff --git a/pkg/app/server/hertz_test.go b/pkg/app/server/hertz_test.go index a5cf7d350..d9cae06a7 100644 --- a/pkg/app/server/hertz_test.go +++ b/pkg/app/server/hertz_test.go @@ -34,6 +34,7 @@ import ( "github.com/cloudwego/hertz/pkg/app" c "github.com/cloudwego/hertz/pkg/app/client" + "github.com/cloudwego/hertz/pkg/app/server/binding" "github.com/cloudwego/hertz/pkg/app/server/registry" "github.com/cloudwego/hertz/pkg/common/config" errs "github.com/cloudwego/hertz/pkg/common/errors" @@ -47,6 +48,7 @@ import ( "github.com/cloudwego/hertz/pkg/protocol/consts" "github.com/cloudwego/hertz/pkg/protocol/http1/req" "github.com/cloudwego/hertz/pkg/protocol/http1/resp" + "github.com/cloudwego/hertz/pkg/route/param" ) func TestHertz_Run(t *testing.T) { @@ -820,3 +822,171 @@ func TestHertzDisableHeaderNamesNormalizing(t *testing.T) { assert.Nil(t, err) assert.DeepEqual(t, headerValue, res.Header.Get(headerName)) } + +func TestBindConfig(t *testing.T) { + type Req struct { + A int `query:"a"` + } + h := New( + WithHostPorts("localhost:9229"), + WithBindConfig(&binding.BindConfig{ + LooseZeroMode: true, + })) + h.GET("/bind", func(c context.Context, ctx *app.RequestContext) { + var req Req + err := ctx.BindAndValidate(&req) + if err != nil { + t.Fatal("unexpected error") + } + }) + + go h.Spin() + time.Sleep(100 * time.Millisecond) + hc := http.Client{Timeout: time.Second} + _, err := hc.Get("http://127.0.0.1:9229/bind?a=") + assert.Nil(t, err) + + h2 := New( + WithHostPorts("localhost:9230"), + WithBindConfig(&binding.BindConfig{ + LooseZeroMode: false, + })) + h2.GET("/bind", func(c context.Context, ctx *app.RequestContext) { + var req Req + err := ctx.BindAndValidate(&req) + if err == nil { + t.Fatal("expect an error") + } + }) + + go h2.Spin() + time.Sleep(100 * time.Millisecond) + + _, err = hc.Get("http://127.0.0.1:9230/bind?a=") + assert.Nil(t, err) + time.Sleep(100 * time.Millisecond) +} + +type mockBinder struct{} + +func (m *mockBinder) Name() string { + return "test binder" +} + +func (m *mockBinder) Bind(request *protocol.Request, i interface{}, params param.Params) error { + return nil +} + +func (m *mockBinder) BindAndValidate(request *protocol.Request, i interface{}, params param.Params) error { + return fmt.Errorf("test binder") +} + +func (m *mockBinder) BindQuery(request *protocol.Request, i interface{}) error { + return nil +} + +func (m *mockBinder) BindHeader(request *protocol.Request, i interface{}) error { + return nil +} + +func (m *mockBinder) BindPath(request *protocol.Request, i interface{}, params param.Params) error { + return nil +} + +func (m *mockBinder) BindForm(request *protocol.Request, i interface{}) error { + return nil +} + +func (m *mockBinder) BindJSON(request *protocol.Request, i interface{}) error { + return nil +} + +func (m *mockBinder) BindProtobuf(request *protocol.Request, i interface{}) error { + return nil +} + +func TestCustomBinder(t *testing.T) { + type Req struct { + A int `query:"a"` + } + h := New( + WithHostPorts("localhost:9229"), + WithCustomBinder(&mockBinder{})) + h.GET("/bind", func(c context.Context, ctx *app.RequestContext) { + var req Req + err := ctx.BindAndValidate(&req) + if err == nil { + t.Fatal("expect an error") + } + assert.DeepEqual(t, "test binder", err.Error()) + }) + + go h.Spin() + time.Sleep(100 * time.Millisecond) + hc := http.Client{Timeout: time.Second} + _, err := hc.Get("http://127.0.0.1:9229/bind?a=") + assert.Nil(t, err) + time.Sleep(100 * time.Millisecond) +} + +func TestValidateConfig(t *testing.T) { + type Req struct { + A int `query:"a" vd:"f($)"` + } + validateConfig := &binding.ValidateConfig{} + validateConfig.MustRegValidateFunc("f", func(args ...interface{}) error { + return fmt.Errorf("test validator") + }) + h := New( + WithHostPorts("localhost:9229"), + WithValidateConfig(validateConfig)) + h.GET("/bind", func(c context.Context, ctx *app.RequestContext) { + var req Req + err := ctx.BindAndValidate(&req) + if err == nil { + t.Fatal("expect an error") + } + assert.DeepEqual(t, "test validator", err.Error()) + }) + + go h.Spin() + time.Sleep(100 * time.Millisecond) + hc := http.Client{Timeout: time.Second} + _, err := hc.Get("http://127.0.0.1:9229/bind?a=2") + assert.Nil(t, err) + time.Sleep(100 * time.Millisecond) +} + +type mockValidator struct{} + +func (m *mockValidator) ValidateStruct(interface{}) error { + return fmt.Errorf("test mock validator") +} + +func (m *mockValidator) Engine() interface{} { + return nil +} + +func TestCustomValidator(t *testing.T) { + type Req struct { + A int `query:"a" vd:"f($)"` + } + h := New( + WithHostPorts("localhost:9229"), + WithCustomValidator(&mockValidator{})) + h.GET("/bind", func(c context.Context, ctx *app.RequestContext) { + var req Req + err := ctx.BindAndValidate(&req) + if err == nil { + t.Fatal("expect an error") + } + assert.DeepEqual(t, "test mock validator", err.Error()) + }) + + go h.Spin() + time.Sleep(100 * time.Millisecond) + hc := http.Client{Timeout: time.Second} + _, err := hc.Get("http://127.0.0.1:9229/bind?a=2") + assert.Nil(t, err) + time.Sleep(100 * time.Millisecond) +} diff --git a/pkg/app/server/option.go b/pkg/app/server/option.go index c9e3735be..df8392318 100644 --- a/pkg/app/server/option.go +++ b/pkg/app/server/option.go @@ -19,6 +19,7 @@ package server import ( "context" "crypto/tls" + "github.com/cloudwego/hertz/pkg/app/server/binding" "net" "strings" "time" @@ -347,6 +348,34 @@ func WithOnConnect(fn func(ctx context.Context, conn network.Conn) context.Conte }} } +// WithBindConfig sets bind config. +func WithBindConfig(bc *binding.BindConfig) config.Option { + return config.Option{F: func(o *config.Options) { + o.BindConfig = bc + }} +} + +// WithCustomBinder sets customized Binder. +func WithCustomBinder(b binding.Binder) config.Option { + return config.Option{F: func(o *config.Options) { + o.CustomBinder = b + }} +} + +// WithValidateConfig sets bind config. +func WithValidateConfig(vc *binding.ValidateConfig) config.Option { + return config.Option{F: func(o *config.Options) { + o.ValidateConfig = vc + }} +} + +// WithCustomValidator sets customized Binder. +func WithCustomValidator(b binding.StructValidator) config.Option { + return config.Option{F: func(o *config.Options) { + o.CustomValidator = b + }} +} + // WithDisableHeaderNamesNormalizing is used to set whether disable header names normalizing. func WithDisableHeaderNamesNormalizing(disable bool) config.Option { return config.Option{F: func(o *config.Options) { diff --git a/pkg/common/config/option.go b/pkg/common/config/option.go index 9ef7ddf42..89c028b93 100644 --- a/pkg/common/config/option.go +++ b/pkg/common/config/option.go @@ -72,6 +72,10 @@ type Options struct { Tracers []interface{} TraceLevel interface{} ListenConfig *net.ListenConfig + BindConfig interface{} + CustomBinder interface{} + ValidateConfig interface{} + CustomValidator interface{} // TransporterNewer is the function to create a transporter. TransporterNewer func(opt *Options) network.Transporter diff --git a/pkg/common/config/option_test.go b/pkg/common/config/option_test.go index 39d92d736..488913cc9 100644 --- a/pkg/common/config/option_test.go +++ b/pkg/common/config/option_test.go @@ -53,6 +53,10 @@ func TestDefaultOptions(t *testing.T) { assert.DeepEqual(t, []interface{}{}, options.Tracers) assert.DeepEqual(t, new(interface{}), options.TraceLevel) assert.DeepEqual(t, registry.NoopRegistry, options.Registry) + assert.Nil(t, options.BindConfig) + assert.Nil(t, options.CustomBinder) + assert.Nil(t, options.ValidateConfig) + assert.Nil(t, options.CustomValidator) assert.DeepEqual(t, false, options.DisableHeaderNamesNormalizing) } diff --git a/pkg/route/engine.go b/pkg/route/engine.go index 42da1d662..f8800938b 100644 --- a/pkg/route/engine.go +++ b/pkg/route/engine.go @@ -195,9 +195,8 @@ type Engine struct { formValueFunc app.FormValueFunc // Custom Binder and Validator - binder binding.Binder - validator binding.StructValidator - validateTag string + binder binding.Binder + validator binding.StructValidator } func (engine *Engine) IsTraceEnable() bool { @@ -557,6 +556,47 @@ func (engine *Engine) ServeStream(ctx context.Context, conn network.StreamConn) return errs.ErrNotSupportProtocol } +func (engine *Engine) initBinderAndValidator(opt *config.Options) { + // init validator + engine.validator = binding.DefaultValidator() + if opt.ValidateConfig != nil { + vConf, ok := opt.ValidateConfig.(*binding.ValidateConfig) + if !ok { + panic("validate config error") + } + engine.validator = binding.NewDefaultValidator(vConf) + } + if opt.CustomValidator != nil { + customValidator, ok := opt.CustomValidator.(binding.StructValidator) + if !ok { + panic("customized validator can not implement binding.StructValidator") + } + engine.validator = customValidator + } + + // Init binder. Due to the existence of the "BindAndValidate" interface, the Validator needs to be injected here. + defaultBindConfig := binding.NewBindConfig() + defaultBindConfig.Validator = engine.validator + engine.binder = binding.NewDefaultBinder(defaultBindConfig) + if opt.BindConfig != nil { + bConf, ok := opt.BindConfig.(*binding.BindConfig) + if !ok { + panic("bind config error") + } + if bConf.Validator != nil { + bConf.Validator = engine.validator + } + engine.binder = binding.NewDefaultBinder(bConf) + } + if opt.CustomBinder != nil { + customBinder, ok := opt.CustomBinder.(binding.Binder) + if !ok { + panic("customized binder can not implement binding.Binder") + } + engine.binder = customBinder + } +} + func NewEngine(opt *config.Options) *Engine { engine := &Engine{ trees: make(MethodTrees, 0, 9), @@ -572,6 +612,7 @@ func NewEngine(opt *config.Options) *Engine { enableTrace: true, options: opt, } + engine.initBinderAndValidator(opt) if opt.TransporterNewer != nil { engine.transport = opt.TransporterNewer(opt) } @@ -744,7 +785,7 @@ func (engine *Engine) allocateContext() *app.RequestContext { ctx.SetClientIPFunc(engine.clientIPFunc) ctx.SetFormValueFunc(engine.formValueFunc) ctx.SetBinder(engine.binder) - ctx.SetValidator(engine.validator, engine.validateTag) + ctx.SetValidator(engine.validator) return ctx } @@ -899,15 +940,6 @@ func (engine *Engine) SetClientIPFunc(f app.ClientIP) { engine.clientIPFunc = f } -func (engine *Engine) SetBinder(binder binding.Binder) { - engine.binder = binder -} - -func (engine *Engine) SetValidator(validator binding.StructValidator, tag string) { - engine.validator = validator - engine.validateTag = tag -} - func (engine *Engine) SetFormValueFunc(f app.FormValueFunc) { engine.formValueFunc = f } diff --git a/pkg/route/engine_test.go b/pkg/route/engine_test.go index 350da8177..0a0979a67 100644 --- a/pkg/route/engine_test.go +++ b/pkg/route/engine_test.go @@ -54,13 +54,16 @@ import ( "time" "github.com/cloudwego/hertz/pkg/app" + "github.com/cloudwego/hertz/pkg/app/server/binding" "github.com/cloudwego/hertz/pkg/common/config" errs "github.com/cloudwego/hertz/pkg/common/errors" "github.com/cloudwego/hertz/pkg/common/test/assert" "github.com/cloudwego/hertz/pkg/common/test/mock" "github.com/cloudwego/hertz/pkg/network" "github.com/cloudwego/hertz/pkg/network/standard" + "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/protocol/consts" + "github.com/cloudwego/hertz/pkg/route/param" ) func TestNew_Engine(t *testing.T) { @@ -623,3 +626,181 @@ func (f *fakeTransporter) ListenAndServe(onData network.OnData) error { // TODO implement me panic("implement me") } + +type mockBinder struct{} + +func (m *mockBinder) Name() string { + return "test binder" +} + +func (m *mockBinder) Bind(request *protocol.Request, i interface{}, params param.Params) error { + return nil +} + +func (m *mockBinder) BindAndValidate(request *protocol.Request, i interface{}, params param.Params) error { + return nil +} + +func (m *mockBinder) BindQuery(request *protocol.Request, i interface{}) error { + return nil +} + +func (m *mockBinder) BindHeader(request *protocol.Request, i interface{}) error { + return nil +} + +func (m *mockBinder) BindPath(request *protocol.Request, i interface{}, params param.Params) error { + return nil +} + +func (m *mockBinder) BindForm(request *protocol.Request, i interface{}) error { + return nil +} + +func (m *mockBinder) BindJSON(request *protocol.Request, i interface{}) error { + return nil +} + +func (m *mockBinder) BindProtobuf(request *protocol.Request, i interface{}) error { + return nil +} + +type mockValidator struct{} + +func (m *mockValidator) ValidateStruct(interface{}) error { + return fmt.Errorf("test mock") +} + +func (m *mockValidator) Engine() interface{} { + return nil +} + +type mockNonValidator struct{} + +func (m *mockNonValidator) ValidateStruct(interface{}) error { + return fmt.Errorf("test mock") +} + +func TestInitBinderAndValidator(t *testing.T) { + defer func() { + if r := recover(); r != nil { + t.Errorf("unexpect panic, %v", r) + } + }() + opt := config.NewOptions([]config.Option{}) + opt.BindConfig = &binding.BindConfig{ + EnableDefaultTag: true, + } + binder := &mockBinder{} + opt.CustomBinder = binder + opt.ValidateConfig = &binding.ValidateConfig{} + validator := &mockValidator{} + opt.CustomValidator = validator + NewEngine(opt) +} + +func TestInitBinderAndValidatorForPanic(t *testing.T) { + defer func() { + if r := recover(); r == nil { + t.Errorf("expect a panic, but get nil") + } + }() + opt := config.NewOptions([]config.Option{}) + opt.BindConfig = &binding.BindConfig{ + EnableDefaultTag: true, + } + binder := &mockBinder{} + opt.CustomBinder = binder + opt.ValidateConfig = &binding.ValidateConfig{} + nonValidator := &mockNonValidator{} + opt.CustomValidator = nonValidator + NewEngine(opt) +} + +func TestBindConfig(t *testing.T) { + type Req struct { + A int `query:"a"` + } + opt := config.NewOptions([]config.Option{}) + opt.BindConfig = &binding.BindConfig{LooseZeroMode: false} + e := NewEngine(opt) + e.GET("/bind", func(c context.Context, ctx *app.RequestContext) { + var req Req + err := ctx.BindAndValidate(&req) + if err == nil { + t.Fatal("expect an error") + } + }) + performRequest(e, "GET", "/bind?a=") + + opt.BindConfig = &binding.BindConfig{LooseZeroMode: true} + e = NewEngine(opt) + e.GET("/bind", func(c context.Context, ctx *app.RequestContext) { + var req Req + err := ctx.BindAndValidate(&req) + if err != nil { + t.Fatal("unexpected error") + } + assert.DeepEqual(t, 0, req.A) + }) + performRequest(e, "GET", "/bind?a=") +} + +func TestCustomBinder(t *testing.T) { + type Req struct { + A int `query:"a"` + } + opt := config.NewOptions([]config.Option{}) + opt.CustomBinder = &mockBinder{} + e := NewEngine(opt) + e.GET("/bind", func(c context.Context, ctx *app.RequestContext) { + var req Req + err := ctx.BindAndValidate(&req) + if err != nil { + t.Fatal("unexpected error") + } + assert.NotEqual(t, 2, req.A) + }) + performRequest(e, "GET", "/bind?a=2") +} + +func TestValidateConfig(t *testing.T) { + type Req struct { + A int `query:"a" vd:"f($)"` + } + opt := config.NewOptions([]config.Option{}) + validateConfig := &binding.ValidateConfig{} + validateConfig.MustRegValidateFunc("f", func(args ...interface{}) error { + return fmt.Errorf("test error") + }) + opt.ValidateConfig = validateConfig + e := NewEngine(opt) + e.GET("/validate", func(c context.Context, ctx *app.RequestContext) { + var req Req + err := ctx.BindAndValidate(&req) + assert.NotNil(t, err) + assert.DeepEqual(t, "test error", err.Error()) + }) + performRequest(e, "GET", "/validate?a=2") +} + +func TestCustomValidator(t *testing.T) { + type Req struct { + A int `query:"a" vd:"f($)"` + } + opt := config.NewOptions([]config.Option{}) + validateConfig := &binding.ValidateConfig{} + validateConfig.MustRegValidateFunc("f", func(args ...interface{}) error { + return fmt.Errorf("test error") + }) + opt.ValidateConfig = validateConfig + opt.CustomValidator = &mockValidator{} + e := NewEngine(opt) + e.GET("/validate", func(c context.Context, ctx *app.RequestContext) { + var req Req + err := ctx.BindAndValidate(&req) + assert.NotNil(t, err) + assert.DeepEqual(t, "test mock", err.Error()) + }) + performRequest(e, "GET", "/validate?a=2") +} diff --git a/pkg/route/routes_test.go b/pkg/route/routes_test.go index 1e76d673e..1d4e17fb8 100644 --- a/pkg/route/routes_test.go +++ b/pkg/route/routes_test.go @@ -68,6 +68,7 @@ func performRequest(e *Engine, method, path string, headers ...header) *httptest ctx.HTMLRender = e.htmlRender r := protocol.NewRequest(method, path, nil) + r.PostArgs() r.CopyTo(&ctx.Request) for _, v := range headers { ctx.Request.Header.Add(v.Key, v.Value) From 92ea388a5ac1c22043e3dba0bcf4feaf8c130519 Mon Sep 17 00:00:00 2001 From: fgy Date: Wed, 20 Sep 2023 20:54:45 +0800 Subject: [PATCH 81/91] fix: typo --- pkg/route/engine_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pkg/route/engine_test.go b/pkg/route/engine_test.go index 0a0979a67..b0fe53ddf 100644 --- a/pkg/route/engine_test.go +++ b/pkg/route/engine_test.go @@ -684,7 +684,7 @@ func (m *mockNonValidator) ValidateStruct(interface{}) error { func TestInitBinderAndValidator(t *testing.T) { defer func() { if r := recover(); r != nil { - t.Errorf("unexpect panic, %v", r) + t.Errorf("unexpected panic, %v", r) } }() opt := config.NewOptions([]config.Option{}) From 81aa5a89a6df55c3bb2e1512612e701cfcabdd4f Mon Sep 17 00:00:00 2001 From: fgy Date: Wed, 20 Sep 2023 20:59:23 +0800 Subject: [PATCH 82/91] refactor: default bind and validate --- pkg/app/server/binding/binder.go | 6 ------ pkg/app/server/binding/config.go | 4 ++++ pkg/app/server/binding/default.go | 12 ++++++++++++ pkg/app/server/binding/validator.go | 6 ------ 4 files changed, 16 insertions(+), 12 deletions(-) diff --git a/pkg/app/server/binding/binder.go b/pkg/app/server/binding/binder.go index 291d22a04..f97b80dbd 100644 --- a/pkg/app/server/binding/binder.go +++ b/pkg/app/server/binding/binder.go @@ -56,9 +56,3 @@ type Binder interface { BindJSON(*protocol.Request, interface{}) error BindProtobuf(*protocol.Request, interface{}) error } - -var defaultBind = NewDefaultBinder(nil) - -func DefaultBinder() Binder { - return defaultBind -} diff --git a/pkg/app/server/binding/config.go b/pkg/app/server/binding/config.go index 103d2629d..822e6a070 100644 --- a/pkg/app/server/binding/config.go +++ b/pkg/app/server/binding/config.go @@ -147,6 +147,10 @@ func (config *BindConfig) UseStdJSONUnmarshaler() { type ValidateConfig struct{} +func NewValidateConfig() *ValidateConfig { + return &ValidateConfig{} +} + // MustRegValidateFunc registers validator function expression. // NOTE: // diff --git a/pkg/app/server/binding/default.go b/pkg/app/server/binding/default.go index cc241d820..ccc9e628c 100644 --- a/pkg/app/server/binding/default.go +++ b/pkg/app/server/binding/default.go @@ -92,6 +92,12 @@ type decoderInfo struct { needValidate bool } +var defaultBind = NewDefaultBinder(nil) + +func DefaultBinder() Binder { + return defaultBind +} + type defaultBinder struct { config *BindConfig decoderCache sync.Map @@ -393,3 +399,9 @@ func (v *defaultValidator) Engine() interface{} { v.lazyinit() return v.validate } + +var defaultValidate = NewDefaultValidator(nil) + +func DefaultValidator() StructValidator { + return defaultValidate +} diff --git a/pkg/app/server/binding/validator.go b/pkg/app/server/binding/validator.go index 910a1a02c..0939b7aef 100644 --- a/pkg/app/server/binding/validator.go +++ b/pkg/app/server/binding/validator.go @@ -44,9 +44,3 @@ type StructValidator interface { ValidateStruct(interface{}) error Engine() interface{} } - -var defaultValidate StructValidator = &defaultValidator{} - -func DefaultValidator() StructValidator { - return defaultValidate -} From a81df8a73c169b0eadff0e4a529a19275fcc445a Mon Sep 17 00:00:00 2001 From: fgy Date: Wed, 20 Sep 2023 21:09:17 +0800 Subject: [PATCH 83/91] fix: route test --- pkg/app/server/binding/default.go | 2 +- pkg/route/engine.go | 4 ++-- pkg/route/engine_test.go | 4 ++-- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/pkg/app/server/binding/default.go b/pkg/app/server/binding/default.go index ccc9e628c..412d7aa2f 100644 --- a/pkg/app/server/binding/default.go +++ b/pkg/app/server/binding/default.go @@ -376,7 +376,7 @@ type defaultValidator struct { } func NewDefaultValidator(config *ValidateConfig) StructValidator { - return defaultValidate + return &defaultValidator{} } // ValidateStruct receives any kind of type, but only performed struct or pointer to struct type. diff --git a/pkg/route/engine.go b/pkg/route/engine.go index f8800938b..c96cfda8e 100644 --- a/pkg/route/engine.go +++ b/pkg/route/engine.go @@ -712,6 +712,8 @@ func (engine *Engine) recv(ctx *app.RequestContext) { // ServeHTTP makes the router implement the Handler interface. func (engine *Engine) ServeHTTP(c context.Context, ctx *app.RequestContext) { + ctx.SetBinder(engine.binder) + ctx.SetValidator(engine.validator) if engine.PanicHandler != nil { defer engine.recv(ctx) } @@ -784,8 +786,6 @@ func (engine *Engine) allocateContext() *app.RequestContext { ctx.Response.SetMaxKeepBodySize(engine.options.MaxKeepBodySize) ctx.SetClientIPFunc(engine.clientIPFunc) ctx.SetFormValueFunc(engine.formValueFunc) - ctx.SetBinder(engine.binder) - ctx.SetValidator(engine.validator) return ctx } diff --git a/pkg/route/engine_test.go b/pkg/route/engine_test.go index b0fe53ddf..ad25bf974 100644 --- a/pkg/route/engine_test.go +++ b/pkg/route/engine_test.go @@ -786,11 +786,11 @@ func TestValidateConfig(t *testing.T) { func TestCustomValidator(t *testing.T) { type Req struct { - A int `query:"a" vd:"f($)"` + A int `query:"a" vd:"d($)"` } opt := config.NewOptions([]config.Option{}) validateConfig := &binding.ValidateConfig{} - validateConfig.MustRegValidateFunc("f", func(args ...interface{}) error { + validateConfig.MustRegValidateFunc("d", func(args ...interface{}) error { return fmt.Errorf("test error") }) opt.ValidateConfig = validateConfig From 6994515a76f97d668bfb0ded5a100db0fa71828c Mon Sep 17 00:00:00 2001 From: fgy Date: Wed, 20 Sep 2023 21:22:21 +0800 Subject: [PATCH 84/91] fix: ci --- pkg/app/context_test.go | 1 + pkg/app/server/binding/binder_test.go | 1 - pkg/app/server/option.go | 2 +- 3 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pkg/app/context_test.go b/pkg/app/context_test.go index cba91a75a..caf9e2351 100644 --- a/pkg/app/context_test.go +++ b/pkg/app/context_test.go @@ -1573,6 +1573,7 @@ func TestSetBinder(t *testing.T) { err = c.BindQuery(&req) assert.NotNil(t, err) err = c.BindHeader(&req) + assert.NotNil(t, err) } func TestRequestContext_SetCookie(t *testing.T) { diff --git a/pkg/app/server/binding/binder_test.go b/pkg/app/server/binding/binder_test.go index 696010d2a..5f42b27d5 100644 --- a/pkg/app/server/binding/binder_test.go +++ b/pkg/app/server/binding/binder_test.go @@ -577,7 +577,6 @@ func TestBind_CustomizedTypeDecodeForPanic(t *testing.T) { bindConfig.MustRegTypeUnmarshal(reflect.TypeOf(string("")), func(req *protocol.Request, params param.Params, text string) (reflect.Value, error) { return reflect.Value{}, nil }) - } func TestBind_JSON(t *testing.T) { diff --git a/pkg/app/server/option.go b/pkg/app/server/option.go index df8392318..e3fba28cd 100644 --- a/pkg/app/server/option.go +++ b/pkg/app/server/option.go @@ -19,11 +19,11 @@ package server import ( "context" "crypto/tls" - "github.com/cloudwego/hertz/pkg/app/server/binding" "net" "strings" "time" + "github.com/cloudwego/hertz/pkg/app/server/binding" "github.com/cloudwego/hertz/pkg/app/server/registry" "github.com/cloudwego/hertz/pkg/common/config" "github.com/cloudwego/hertz/pkg/common/tracer" From ec867c7a89535f30d18d80302e079ff08045e0c6 Mon Sep 17 00:00:00 2001 From: fgy Date: Wed, 20 Sep 2023 21:30:10 +0800 Subject: [PATCH 85/91] fix: option test --- pkg/app/server/hertz_test.go | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/pkg/app/server/hertz_test.go b/pkg/app/server/hertz_test.go index d9cae06a7..6fc0308a8 100644 --- a/pkg/app/server/hertz_test.go +++ b/pkg/app/server/hertz_test.go @@ -695,7 +695,7 @@ type CloseWithoutResetBuffer interface { func TestOnprepare(t *testing.T) { h1 := New( - WithHostPorts("localhost:9229"), + WithHostPorts("localhost:9333"), WithOnConnect(func(ctx context.Context, conn network.Conn) context.Context { b, err := conn.Peek(3) assert.Nil(t, err) @@ -713,7 +713,7 @@ func TestOnprepare(t *testing.T) { go h1.Spin() time.Sleep(time.Second) - _, _, err := c.Get(context.Background(), nil, "http://127.0.0.1:9229/ping") + _, _, err := c.Get(context.Background(), nil, "http://127.0.0.1:9333/ping") assert.DeepEqual(t, "the server closed connection before returning the first response byte. Make sure the server returns 'Connection: close' response header before closing the connection", err.Error()) h2 := New( @@ -721,13 +721,13 @@ func TestOnprepare(t *testing.T) { conn.Close() return context.Background() }), - WithHostPorts("localhost:9230")) + WithHostPorts("localhost:9331")) h2.GET("/ping", func(ctx context.Context, c *app.RequestContext) { c.JSON(consts.StatusOK, utils.H{"ping": "pong"}) }) go h2.Spin() time.Sleep(time.Second) - _, _, err = c.Get(context.Background(), nil, "http://127.0.0.1:9230/ping") + _, _, err = c.Get(context.Background(), nil, "http://127.0.0.1:9331/ping") if err == nil { t.Fatalf("err should not be nil") } @@ -828,7 +828,7 @@ func TestBindConfig(t *testing.T) { A int `query:"a"` } h := New( - WithHostPorts("localhost:9229"), + WithHostPorts("localhost:9332"), WithBindConfig(&binding.BindConfig{ LooseZeroMode: true, })) @@ -843,11 +843,11 @@ func TestBindConfig(t *testing.T) { go h.Spin() time.Sleep(100 * time.Millisecond) hc := http.Client{Timeout: time.Second} - _, err := hc.Get("http://127.0.0.1:9229/bind?a=") + _, err := hc.Get("http://127.0.0.1:9332/bind?a=") assert.Nil(t, err) h2 := New( - WithHostPorts("localhost:9230"), + WithHostPorts("localhost:9448"), WithBindConfig(&binding.BindConfig{ LooseZeroMode: false, })) @@ -862,7 +862,7 @@ func TestBindConfig(t *testing.T) { go h2.Spin() time.Sleep(100 * time.Millisecond) - _, err = hc.Get("http://127.0.0.1:9230/bind?a=") + _, err = hc.Get("http://127.0.0.1:9448/bind?a=") assert.Nil(t, err) time.Sleep(100 * time.Millisecond) } @@ -910,7 +910,7 @@ func TestCustomBinder(t *testing.T) { A int `query:"a"` } h := New( - WithHostPorts("localhost:9229"), + WithHostPorts("localhost:9334"), WithCustomBinder(&mockBinder{})) h.GET("/bind", func(c context.Context, ctx *app.RequestContext) { var req Req @@ -924,7 +924,7 @@ func TestCustomBinder(t *testing.T) { go h.Spin() time.Sleep(100 * time.Millisecond) hc := http.Client{Timeout: time.Second} - _, err := hc.Get("http://127.0.0.1:9229/bind?a=") + _, err := hc.Get("http://127.0.0.1:9334/bind?a=") assert.Nil(t, err) time.Sleep(100 * time.Millisecond) } @@ -972,7 +972,7 @@ func TestCustomValidator(t *testing.T) { A int `query:"a" vd:"f($)"` } h := New( - WithHostPorts("localhost:9229"), + WithHostPorts("localhost:9555"), WithCustomValidator(&mockValidator{})) h.GET("/bind", func(c context.Context, ctx *app.RequestContext) { var req Req @@ -986,7 +986,7 @@ func TestCustomValidator(t *testing.T) { go h.Spin() time.Sleep(100 * time.Millisecond) hc := http.Client{Timeout: time.Second} - _, err := hc.Get("http://127.0.0.1:9229/bind?a=2") + _, err := hc.Get("http://127.0.0.1:9555/bind?a=2") assert.Nil(t, err) time.Sleep(100 * time.Millisecond) } From 3a7ebed111c1d0266f7e27fcab4156de934d708b Mon Sep 17 00:00:00 2001 From: fgy Date: Wed, 20 Sep 2023 21:43:25 +0800 Subject: [PATCH 86/91] feat: context test --- pkg/app/context_test.go | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/pkg/app/context_test.go b/pkg/app/context_test.go index caf9e2351..072408240 100644 --- a/pkg/app/context_test.go +++ b/pkg/app/context_test.go @@ -1454,6 +1454,8 @@ func TestBindAndValidate(t *testing.T) { c := &RequestContext{} c.Request.SetRequestURI("/foo/bar?a=123&b=11") + c.SetValidator(binding.DefaultValidator()) + c.SetBinder(binding.DefaultBinder()) var req Test err := c.BindAndValidate(&req) @@ -1494,6 +1496,8 @@ func TestBindForm(t *testing.T) { } c := &RequestContext{} + c.SetValidator(binding.DefaultValidator()) + c.SetBinder(binding.DefaultBinder()) c.Request.SetRequestURI("/foo/bar?a=123&b=11") c.Request.SetBody([]byte("A=123&B=11")) c.Request.Header.SetContentTypeBytes([]byte("application/x-www-form-urlencoded")) @@ -1555,7 +1559,6 @@ func TestSetBinder(t *testing.T) { mockBind := &mockBinder{} c := NewContext(0) c.SetBinder(mockBind) - defer c.SetBinder(binding.DefaultBinder()) type T struct{} req := T{} err := c.Bind(&req) From a6cfb63ba53804a53b964ee7f3f31b98a28e0fa8 Mon Sep 17 00:00:00 2001 From: fgy Date: Thu, 21 Sep 2023 12:50:02 +0800 Subject: [PATCH 87/91] fix: more copy --- pkg/app/context.go | 49 +++++++++++++++++++++++++++++++++-------- pkg/app/context_test.go | 22 ++++++++---------- 2 files changed, 49 insertions(+), 22 deletions(-) diff --git a/pkg/app/context.go b/pkg/app/context.go index 0acda77cf..a13744e65 100644 --- a/pkg/app/context.go +++ b/pkg/app/context.go @@ -743,6 +743,10 @@ func (ctx *RequestContext) Copy() *RequestContext { paramCopy := make([]param.Param, len(cp.Params)) copy(paramCopy, cp.Params) cp.Params = paramCopy + cp.clientIPFunc = ctx.clientIPFunc + cp.formValueFunc = ctx.formValueFunc + cp.binder = ctx.binder + cp.validator = ctx.validator return cp } @@ -1316,37 +1320,55 @@ func bodyAllowedForStatus(status int) bool { // BindAndValidate binds data from *RequestContext to obj and validates them if needed. // NOTE: obj should be a pointer. func (ctx *RequestContext) BindAndValidate(obj interface{}) error { - return ctx.binder.BindAndValidate(&ctx.Request, obj, ctx.Params) + if ctx.binder != nil { + return ctx.binder.BindAndValidate(&ctx.Request, obj, ctx.Params) + } + return binding.BindAndValidate(&ctx.Request, obj, ctx.Params) } // Bind binds data from *RequestContext to obj. // NOTE: obj should be a pointer. func (ctx *RequestContext) Bind(obj interface{}) error { - return ctx.binder.Bind(&ctx.Request, obj, ctx.Params) + if ctx.binder != nil { + return ctx.binder.Bind(&ctx.Request, obj, ctx.Params) + } + return binding.Bind(&ctx.Request, obj, ctx.Params) } // Validate validates obj with "vd" tag // NOTE: obj should be a pointer. func (ctx *RequestContext) Validate(obj interface{}) error { - return ctx.validator.ValidateStruct(obj) + if ctx.validator != nil { + return ctx.validator.ValidateStruct(obj) + } + return binding.Validate(obj) } // BindQuery binds query parameters from *RequestContext to obj with 'query' tag. It will only use 'query' tag for binding. // NOTE: obj should be a pointer. func (ctx *RequestContext) BindQuery(obj interface{}) error { - return ctx.binder.BindQuery(&ctx.Request, obj) + if ctx.binder != nil { + return ctx.binder.BindQuery(&ctx.Request, obj) + } + return binding.DefaultBinder().BindQuery(&ctx.Request, obj) } // BindHeader binds header parameters from *RequestContext to obj with 'header' tag. It will only use 'header' tag for binding. // NOTE: obj should be a pointer. func (ctx *RequestContext) BindHeader(obj interface{}) error { - return ctx.binder.BindHeader(&ctx.Request, obj) + if ctx.binder != nil { + return ctx.binder.BindHeader(&ctx.Request, obj) + } + return binding.DefaultBinder().BindHeader(&ctx.Request, obj) } // BindPath binds router parameters from *RequestContext to obj with 'path' tag. It will only use 'path' tag for binding. // NOTE: obj should be a pointer. func (ctx *RequestContext) BindPath(obj interface{}) error { - return ctx.binder.BindPath(&ctx.Request, obj, ctx.Params) + if ctx.binder != nil { + return ctx.binder.BindPath(&ctx.Request, obj, ctx.Params) + } + return binding.DefaultBinder().BindPath(&ctx.Request, obj, ctx.Params) } // BindForm binds form parameters from *RequestContext to obj with 'form' tag. It will only use 'form' tag for binding. @@ -1355,19 +1377,28 @@ func (ctx *RequestContext) BindForm(obj interface{}) error { if len(ctx.Request.Body()) == 0 { return fmt.Errorf("missing form body") } - return ctx.binder.BindForm(&ctx.Request, obj) + if ctx.binder != nil { + return ctx.binder.BindForm(&ctx.Request, obj) + } + return binding.DefaultBinder().BindForm(&ctx.Request, obj) } // BindJSON binds JSON body from *RequestContext. // NOTE: obj should be a pointer. func (ctx *RequestContext) BindJSON(obj interface{}) error { - return ctx.binder.BindJSON(&ctx.Request, obj) + if ctx.binder != nil { + return ctx.binder.BindJSON(&ctx.Request, obj) + } + return binding.DefaultBinder().BindJSON(&ctx.Request, obj) } // BindProtobuf binds protobuf body from *RequestContext. // NOTE: obj should be a pointer. func (ctx *RequestContext) BindProtobuf(obj interface{}) error { - return ctx.binder.BindProtobuf(&ctx.Request, obj) + if ctx.binder != nil { + return ctx.binder.BindProtobuf(&ctx.Request, obj) + } + return binding.DefaultBinder().BindProtobuf(&ctx.Request, obj) } // BindByContentType will select the binding type on the ContentType automatically. diff --git a/pkg/app/context_test.go b/pkg/app/context_test.go index 072408240..22e7f8608 100644 --- a/pkg/app/context_test.go +++ b/pkg/app/context_test.go @@ -1454,8 +1454,6 @@ func TestBindAndValidate(t *testing.T) { c := &RequestContext{} c.Request.SetRequestURI("/foo/bar?a=123&b=11") - c.SetValidator(binding.DefaultValidator()) - c.SetBinder(binding.DefaultBinder()) var req Test err := c.BindAndValidate(&req) @@ -1496,8 +1494,6 @@ func TestBindForm(t *testing.T) { } c := &RequestContext{} - c.SetValidator(binding.DefaultValidator()) - c.SetBinder(binding.DefaultBinder()) c.Request.SetRequestURI("/foo/bar?a=123&b=11") c.Request.SetBody([]byte("A=123&B=11")) c.Request.Header.SetContentTypeBytes([]byte("application/x-www-form-urlencoded")) @@ -1528,7 +1524,7 @@ func (m *mockBinder) Bind(request *protocol.Request, i interface{}, params param } func (m *mockBinder) BindAndValidate(request *protocol.Request, i interface{}, params param.Params) error { - return nil + return fmt.Errorf("test binder") } func (m *mockBinder) BindQuery(request *protocol.Request, i interface{}) error { @@ -1556,27 +1552,27 @@ func (m *mockBinder) BindProtobuf(request *protocol.Request, i interface{}) erro } func TestSetBinder(t *testing.T) { - mockBind := &mockBinder{} c := NewContext(0) - c.SetBinder(mockBind) + c.SetBinder(&mockBinder{}) type T struct{} req := T{} err := c.Bind(&req) - assert.NotNil(t, err) + assert.Nil(t, err) err = c.BindAndValidate(&req) assert.NotNil(t, err) + assert.DeepEqual(t, "test binder", err.Error()) err = c.BindProtobuf(&req) - assert.NotNil(t, err) + assert.Nil(t, err) err = c.BindJSON(&req) - assert.NotNil(t, err) + assert.Nil(t, err) err = c.BindForm(&req) assert.NotNil(t, err) err = c.BindPath(&req) - assert.NotNil(t, err) + assert.Nil(t, err) err = c.BindQuery(&req) - assert.NotNil(t, err) + assert.Nil(t, err) err = c.BindHeader(&req) - assert.NotNil(t, err) + assert.Nil(t, err) } func TestRequestContext_SetCookie(t *testing.T) { From e2ce1cec4697b32664ad3264714283f839ac7b0d Mon Sep 17 00:00:00 2001 From: fgy Date: Thu, 21 Sep 2023 21:36:39 +0800 Subject: [PATCH 88/91] fix: validate --- pkg/app/server/binding/default.go | 3 +++ pkg/app/server/hertz_test.go | 15 +++++++-------- pkg/app/server/option.go | 7 ------- pkg/common/config/option.go | 1 - pkg/common/config/option_test.go | 1 - pkg/route/engine.go | 24 +++++++++--------------- pkg/route/engine_test.go | 24 ++++++++++++------------ pkg/route/routes_test.go | 1 - 8 files changed, 31 insertions(+), 45 deletions(-) diff --git a/pkg/app/server/binding/default.go b/pkg/app/server/binding/default.go index 412d7aa2f..8324b3771 100644 --- a/pkg/app/server/binding/default.go +++ b/pkg/app/server/binding/default.go @@ -116,6 +116,9 @@ func NewDefaultBinder(config *BindConfig) Binder { } } config.initTypeUnmarshal() + if config.Validator == nil { + config.Validator = DefaultValidator() + } return &defaultBinder{ config: config, } diff --git a/pkg/app/server/hertz_test.go b/pkg/app/server/hertz_test.go index 6fc0308a8..c70ff7340 100644 --- a/pkg/app/server/hertz_test.go +++ b/pkg/app/server/hertz_test.go @@ -827,11 +827,11 @@ func TestBindConfig(t *testing.T) { type Req struct { A int `query:"a"` } + bindConfig := binding.NewBindConfig() + bindConfig.LooseZeroMode = true h := New( WithHostPorts("localhost:9332"), - WithBindConfig(&binding.BindConfig{ - LooseZeroMode: true, - })) + WithBindConfig(bindConfig)) h.GET("/bind", func(c context.Context, ctx *app.RequestContext) { var req Req err := ctx.BindAndValidate(&req) @@ -846,11 +846,11 @@ func TestBindConfig(t *testing.T) { _, err := hc.Get("http://127.0.0.1:9332/bind?a=") assert.Nil(t, err) + bindConfig = binding.NewBindConfig() + bindConfig.LooseZeroMode = false h2 := New( WithHostPorts("localhost:9448"), - WithBindConfig(&binding.BindConfig{ - LooseZeroMode: false, - })) + WithBindConfig(bindConfig)) h2.GET("/bind", func(c context.Context, ctx *app.RequestContext) { var req Req err := ctx.BindAndValidate(&req) @@ -938,8 +938,7 @@ func TestValidateConfig(t *testing.T) { return fmt.Errorf("test validator") }) h := New( - WithHostPorts("localhost:9229"), - WithValidateConfig(validateConfig)) + WithHostPorts("localhost:9229")) h.GET("/bind", func(c context.Context, ctx *app.RequestContext) { var req Req err := ctx.BindAndValidate(&req) diff --git a/pkg/app/server/option.go b/pkg/app/server/option.go index e3fba28cd..d94a7b0cc 100644 --- a/pkg/app/server/option.go +++ b/pkg/app/server/option.go @@ -362,13 +362,6 @@ func WithCustomBinder(b binding.Binder) config.Option { }} } -// WithValidateConfig sets bind config. -func WithValidateConfig(vc *binding.ValidateConfig) config.Option { - return config.Option{F: func(o *config.Options) { - o.ValidateConfig = vc - }} -} - // WithCustomValidator sets customized Binder. func WithCustomValidator(b binding.StructValidator) config.Option { return config.Option{F: func(o *config.Options) { diff --git a/pkg/common/config/option.go b/pkg/common/config/option.go index 89c028b93..d8e6de2d0 100644 --- a/pkg/common/config/option.go +++ b/pkg/common/config/option.go @@ -74,7 +74,6 @@ type Options struct { ListenConfig *net.ListenConfig BindConfig interface{} CustomBinder interface{} - ValidateConfig interface{} CustomValidator interface{} // TransporterNewer is the function to create a transporter. diff --git a/pkg/common/config/option_test.go b/pkg/common/config/option_test.go index 488913cc9..6ee0fee95 100644 --- a/pkg/common/config/option_test.go +++ b/pkg/common/config/option_test.go @@ -55,7 +55,6 @@ func TestDefaultOptions(t *testing.T) { assert.DeepEqual(t, registry.NoopRegistry, options.Registry) assert.Nil(t, options.BindConfig) assert.Nil(t, options.CustomBinder) - assert.Nil(t, options.ValidateConfig) assert.Nil(t, options.CustomValidator) assert.DeepEqual(t, false, options.DisableHeaderNamesNormalizing) } diff --git a/pkg/route/engine.go b/pkg/route/engine.go index c96cfda8e..ff8cf5a7e 100644 --- a/pkg/route/engine.go +++ b/pkg/route/engine.go @@ -559,13 +559,6 @@ func (engine *Engine) ServeStream(ctx context.Context, conn network.StreamConn) func (engine *Engine) initBinderAndValidator(opt *config.Options) { // init validator engine.validator = binding.DefaultValidator() - if opt.ValidateConfig != nil { - vConf, ok := opt.ValidateConfig.(*binding.ValidateConfig) - if !ok { - panic("validate config error") - } - engine.validator = binding.NewDefaultValidator(vConf) - } if opt.CustomValidator != nil { customValidator, ok := opt.CustomValidator.(binding.StructValidator) if !ok { @@ -574,6 +567,14 @@ func (engine *Engine) initBinderAndValidator(opt *config.Options) { engine.validator = customValidator } + if opt.CustomBinder != nil { + customBinder, ok := opt.CustomBinder.(binding.Binder) + if !ok { + panic("customized binder can not implement binding.Binder") + } + engine.binder = customBinder + return + } // Init binder. Due to the existence of the "BindAndValidate" interface, the Validator needs to be injected here. defaultBindConfig := binding.NewBindConfig() defaultBindConfig.Validator = engine.validator @@ -583,18 +584,11 @@ func (engine *Engine) initBinderAndValidator(opt *config.Options) { if !ok { panic("bind config error") } - if bConf.Validator != nil { + if bConf.Validator == nil { bConf.Validator = engine.validator } engine.binder = binding.NewDefaultBinder(bConf) } - if opt.CustomBinder != nil { - customBinder, ok := opt.CustomBinder.(binding.Binder) - if !ok { - panic("customized binder can not implement binding.Binder") - } - engine.binder = customBinder - } } func NewEngine(opt *config.Options) *Engine { diff --git a/pkg/route/engine_test.go b/pkg/route/engine_test.go index ad25bf974..65f4fb16a 100644 --- a/pkg/route/engine_test.go +++ b/pkg/route/engine_test.go @@ -688,12 +688,11 @@ func TestInitBinderAndValidator(t *testing.T) { } }() opt := config.NewOptions([]config.Option{}) - opt.BindConfig = &binding.BindConfig{ - EnableDefaultTag: true, - } + bindConfig := binding.NewBindConfig() + bindConfig.LooseZeroMode = true + opt.BindConfig = bindConfig binder := &mockBinder{} opt.CustomBinder = binder - opt.ValidateConfig = &binding.ValidateConfig{} validator := &mockValidator{} opt.CustomValidator = validator NewEngine(opt) @@ -706,12 +705,11 @@ func TestInitBinderAndValidatorForPanic(t *testing.T) { } }() opt := config.NewOptions([]config.Option{}) - opt.BindConfig = &binding.BindConfig{ - EnableDefaultTag: true, - } + bindConfig := binding.NewBindConfig() + bindConfig.LooseZeroMode = true + opt.BindConfig = bindConfig binder := &mockBinder{} opt.CustomBinder = binder - opt.ValidateConfig = &binding.ValidateConfig{} nonValidator := &mockNonValidator{} opt.CustomValidator = nonValidator NewEngine(opt) @@ -722,7 +720,9 @@ func TestBindConfig(t *testing.T) { A int `query:"a"` } opt := config.NewOptions([]config.Option{}) - opt.BindConfig = &binding.BindConfig{LooseZeroMode: false} + bindConfig := binding.NewBindConfig() + bindConfig.LooseZeroMode = false + opt.BindConfig = bindConfig e := NewEngine(opt) e.GET("/bind", func(c context.Context, ctx *app.RequestContext) { var req Req @@ -733,7 +733,9 @@ func TestBindConfig(t *testing.T) { }) performRequest(e, "GET", "/bind?a=") - opt.BindConfig = &binding.BindConfig{LooseZeroMode: true} + bindConfig = binding.NewBindConfig() + bindConfig.LooseZeroMode = true + opt.BindConfig = bindConfig e = NewEngine(opt) e.GET("/bind", func(c context.Context, ctx *app.RequestContext) { var req Req @@ -773,7 +775,6 @@ func TestValidateConfig(t *testing.T) { validateConfig.MustRegValidateFunc("f", func(args ...interface{}) error { return fmt.Errorf("test error") }) - opt.ValidateConfig = validateConfig e := NewEngine(opt) e.GET("/validate", func(c context.Context, ctx *app.RequestContext) { var req Req @@ -793,7 +794,6 @@ func TestCustomValidator(t *testing.T) { validateConfig.MustRegValidateFunc("d", func(args ...interface{}) error { return fmt.Errorf("test error") }) - opt.ValidateConfig = validateConfig opt.CustomValidator = &mockValidator{} e := NewEngine(opt) e.GET("/validate", func(c context.Context, ctx *app.RequestContext) { diff --git a/pkg/route/routes_test.go b/pkg/route/routes_test.go index 1d4e17fb8..1e76d673e 100644 --- a/pkg/route/routes_test.go +++ b/pkg/route/routes_test.go @@ -68,7 +68,6 @@ func performRequest(e *Engine, method, path string, headers ...header) *httptest ctx.HTMLRender = e.htmlRender r := protocol.NewRequest(method, path, nil) - r.PostArgs() r.CopyTo(&ctx.Request) for _, v := range headers { ctx.Request.Header.Add(v.Key, v.Value) From e4084df3e974d5379aa274903f8c498af9de41dd Mon Sep 17 00:00:00 2001 From: fgy Date: Thu, 21 Sep 2023 21:45:03 +0800 Subject: [PATCH 89/91] feat: enable config to disable --- pkg/app/server/binding/binder_test.go | 10 +++++----- pkg/app/server/binding/config.go | 20 +++++++++---------- pkg/app/server/binding/default.go | 8 ++++---- .../binding/internal/decoder/decoder.go | 8 ++++---- pkg/app/server/binding/tagexpr_bind_test.go | 4 ++-- 5 files changed, 25 insertions(+), 25 deletions(-) diff --git a/pkg/app/server/binding/binder_test.go b/pkg/app/server/binding/binder_test.go index 5f42b27d5..1d3e9a1ae 100644 --- a/pkg/app/server/binding/binder_test.go +++ b/pkg/app/server/binding/binder_test.go @@ -804,7 +804,7 @@ func TestBind_DefaultTag(t *testing.T) { assert.DeepEqual(t, "form", result.Form) bindConfig := &BindConfig{} - bindConfig.EnableDefaultTag = false + bindConfig.DisableDefaultTag = true binder := NewDefaultBinder(bindConfig) result2 := Req2{} err = binder.Bind(req.Req, &result2, params) @@ -833,7 +833,7 @@ func TestBind_StructFieldResolve(t *testing.T) { SetUrlEncodeContentType() var result Req bindConfig := &BindConfig{} - bindConfig.EnableStructFieldResolve = true + bindConfig.DisableStructFieldResolve = false binder := NewDefaultBinder(bindConfig) err := binder.Bind(req.Req, &result, nil) if err != nil { @@ -1195,7 +1195,7 @@ func TestBind_BindProtobuf(t *testing.T) { func TestBind_PointerStruct(t *testing.T) { bindConfig := &BindConfig{} - bindConfig.EnableStructFieldResolve = true + bindConfig.DisableStructFieldResolve = false binder := NewDefaultBinder(bindConfig) type Foo struct { F1 string `query:"F1"` @@ -1228,7 +1228,7 @@ func TestBind_PointerStruct(t *testing.T) { func TestBind_StructRequired(t *testing.T) { bindConfig := &BindConfig{} - bindConfig.EnableStructFieldResolve = true + bindConfig.DisableStructFieldResolve = false binder := NewDefaultBinder(bindConfig) type Foo struct { F1 string `query:"F1"` @@ -1261,7 +1261,7 @@ func TestBind_StructRequired(t *testing.T) { func TestBind_StructErrorToWarn(t *testing.T) { bindConfig := &BindConfig{} - bindConfig.EnableStructFieldResolve = true + bindConfig.DisableStructFieldResolve = false binder := NewDefaultBinder(bindConfig) type Foo struct { F1 string `query:"F1"` diff --git a/pkg/app/server/binding/config.go b/pkg/app/server/binding/config.go index 822e6a070..c122c54c6 100644 --- a/pkg/app/server/binding/config.go +++ b/pkg/app/server/binding/config.go @@ -37,17 +37,17 @@ type BindConfig struct { // The default is false. // Suitable for these parameter types: query/header/cookie/form . LooseZeroMode bool - // EnableDefaultTag is used to add default tags to a field when it has no tag - // If is true, the field with no tag will be added default tags, for more automated binding. But there may be additional overhead. + // DisableDefaultTag is used to add default tags to a field when it has no tag + // If is false, the field with no tag will be added default tags, for more automated binding. But there may be additional overhead. // NOTE: - // The default is true. - EnableDefaultTag bool - // EnableStructFieldResolve is used to generate a separate decoder for a struct. - // If is true, the 'struct' field will get a single inDecoder.structTypeFieldTextDecoder, and use json.Unmarshal for decode it. + // The default is false. + DisableDefaultTag bool + // DisableStructFieldResolve is used to generate a separate decoder for a struct. + // If is false, the 'struct' field will get a single inDecoder.structTypeFieldTextDecoder, and use json.Unmarshal for decode it. // It usually used to add json string to query parameter. // NOTE: - // The default is true. - EnableStructFieldResolve bool + // The default is false. + DisableStructFieldResolve bool // EnableDecoderUseNumber is used to call the UseNumber method on the JSON // Decoder instance. UseNumber causes the Decoder to unmarshal a number into an // interface{} as a Number instead of as a float64. @@ -78,8 +78,8 @@ type BindConfig struct { func NewBindConfig() *BindConfig { return &BindConfig{ LooseZeroMode: false, - EnableDefaultTag: true, - EnableStructFieldResolve: true, + DisableDefaultTag: false, + DisableStructFieldResolve: false, EnableDecoderUseNumber: false, EnableDecoderDisallowUnknownFields: false, ValidateTag: "vd", diff --git a/pkg/app/server/binding/default.go b/pkg/app/server/binding/default.go index 8324b3771..28bbc5311 100644 --- a/pkg/app/server/binding/default.go +++ b/pkg/app/server/binding/default.go @@ -188,8 +188,8 @@ func (b *defaultBinder) bindTag(req *protocol.Request, v interface{}, params par decodeConfig := &inDecoder.DecodeConfig{ LooseZeroMode: b.config.LooseZeroMode, - EnableDefaultTag: b.config.EnableDefaultTag, - EnableStructFieldResolve: b.config.EnableStructFieldResolve, + DisableDefaultTag: b.config.DisableDefaultTag, + DisableStructFieldResolve: b.config.DisableStructFieldResolve, EnableDecoderUseNumber: b.config.EnableDecoderUseNumber, EnableDecoderDisallowUnknownFields: b.config.EnableDecoderDisallowUnknownFields, ValidateTag: b.config.ValidateTag, @@ -234,8 +234,8 @@ func (b *defaultBinder) bindTagWithValidate(req *protocol.Request, v interface{} } decodeConfig := &inDecoder.DecodeConfig{ LooseZeroMode: b.config.LooseZeroMode, - EnableDefaultTag: b.config.EnableDefaultTag, - EnableStructFieldResolve: b.config.EnableStructFieldResolve, + DisableDefaultTag: b.config.DisableDefaultTag, + DisableStructFieldResolve: b.config.DisableStructFieldResolve, EnableDecoderUseNumber: b.config.EnableDecoderUseNumber, EnableDecoderDisallowUnknownFields: b.config.EnableDecoderDisallowUnknownFields, ValidateTag: b.config.ValidateTag, diff --git a/pkg/app/server/binding/internal/decoder/decoder.go b/pkg/app/server/binding/internal/decoder/decoder.go index 425adfc6b..0bd13442a 100644 --- a/pkg/app/server/binding/internal/decoder/decoder.go +++ b/pkg/app/server/binding/internal/decoder/decoder.go @@ -57,8 +57,8 @@ type Decoder func(req *protocol.Request, params param.Params, rv reflect.Value) type DecodeConfig struct { LooseZeroMode bool - EnableDefaultTag bool - EnableStructFieldResolve bool + DisableDefaultTag bool + DisableStructFieldResolve bool EnableDecoderUseNumber bool EnableDecoderDisallowUnknownFields bool ValidateTag string @@ -117,7 +117,7 @@ func getFieldDecoder(field reflect.StructField, index int, parentIdx []int, pare // JSONName is like 'a.b.c' for 'required validate' fieldTagInfos, newParentJSONName, needValidate := lookupFieldTags(field, parentJSONName, config) - if len(fieldTagInfos) == 0 && config.EnableDefaultTag { + if len(fieldTagInfos) == 0 && !config.DisableDefaultTag { fieldTagInfos = getDefaultFieldTags(field) } if len(byTag) != 0 { @@ -152,7 +152,7 @@ func getFieldDecoder(field reflect.StructField, index int, parentIdx []int, pare dec, err := getMultipartFileDecoder(field, index, fieldTagInfos, parentIdx, config) return dec, needValidate, err } - if config.EnableStructFieldResolve { // decode struct type separately + if !config.DisableStructFieldResolve { // decode struct type separately structFieldDecoder, err := getStructTypeFieldDecoder(field, index, fieldTagInfos, parentIdx, config) if err != nil { return nil, needValidate, err diff --git a/pkg/app/server/binding/tagexpr_bind_test.go b/pkg/app/server/binding/tagexpr_bind_test.go index e01c4b0ab..82221745c 100644 --- a/pkg/app/server/binding/tagexpr_bind_test.go +++ b/pkg/app/server/binding/tagexpr_bind_test.go @@ -784,7 +784,7 @@ func TestOption(t *testing.T) { req = newRequest("", header, nil, bodyReader) recv2 := new(Recv2) bindConfig := &BindConfig{} - bindConfig.EnableStructFieldResolve = true + bindConfig.DisableStructFieldResolve = false binder := NewDefaultBinder(bindConfig) err = binder.Bind(req.Req, recv2, nil) assert.DeepEqual(t, err.Error(), "'X' field is a 'required' parameter, but the request does not have this parameter") @@ -936,7 +936,7 @@ func TestRegTypeUnmarshal(t *testing.T) { recv := new(T) bindConfig := &BindConfig{} - bindConfig.EnableStructFieldResolve = true + bindConfig.DisableStructFieldResolve = false binder := NewDefaultBinder(bindConfig) err = binder.Bind(req.Req, recv, nil) if err != nil { From a6e415911f77027b053dd6bc32dea83e73bee3fd Mon Sep 17 00:00:00 2001 From: fgy Date: Fri, 22 Sep 2023 11:00:57 +0800 Subject: [PATCH 90/91] refactor: ctx.bind interface --- pkg/app/context.go | 59 ++++++++++++++++++---------------------------- 1 file changed, 23 insertions(+), 36 deletions(-) diff --git a/pkg/app/context.go b/pkg/app/context.go index a13744e65..c607ce73b 100644 --- a/pkg/app/context.go +++ b/pkg/app/context.go @@ -1317,58 +1317,54 @@ func bodyAllowedForStatus(status int) bool { return true } +func (ctx *RequestContext) getBinder() binding.Binder { + if ctx.binder != nil { + return ctx.binder + } + return binding.DefaultBinder() +} + +func (ctx *RequestContext) getValidator() binding.StructValidator { + if ctx.validator != nil { + return ctx.validator + } + return binding.DefaultValidator() +} + // BindAndValidate binds data from *RequestContext to obj and validates them if needed. // NOTE: obj should be a pointer. func (ctx *RequestContext) BindAndValidate(obj interface{}) error { - if ctx.binder != nil { - return ctx.binder.BindAndValidate(&ctx.Request, obj, ctx.Params) - } - return binding.BindAndValidate(&ctx.Request, obj, ctx.Params) + return ctx.getBinder().BindAndValidate(&ctx.Request, obj, ctx.Params) } // Bind binds data from *RequestContext to obj. // NOTE: obj should be a pointer. func (ctx *RequestContext) Bind(obj interface{}) error { - if ctx.binder != nil { - return ctx.binder.Bind(&ctx.Request, obj, ctx.Params) - } - return binding.Bind(&ctx.Request, obj, ctx.Params) + return ctx.getBinder().Bind(&ctx.Request, obj, ctx.Params) } // Validate validates obj with "vd" tag // NOTE: obj should be a pointer. func (ctx *RequestContext) Validate(obj interface{}) error { - if ctx.validator != nil { - return ctx.validator.ValidateStruct(obj) - } - return binding.Validate(obj) + return ctx.getValidator().ValidateStruct(obj) } // BindQuery binds query parameters from *RequestContext to obj with 'query' tag. It will only use 'query' tag for binding. // NOTE: obj should be a pointer. func (ctx *RequestContext) BindQuery(obj interface{}) error { - if ctx.binder != nil { - return ctx.binder.BindQuery(&ctx.Request, obj) - } - return binding.DefaultBinder().BindQuery(&ctx.Request, obj) + return ctx.getBinder().BindQuery(&ctx.Request, obj) } // BindHeader binds header parameters from *RequestContext to obj with 'header' tag. It will only use 'header' tag for binding. // NOTE: obj should be a pointer. func (ctx *RequestContext) BindHeader(obj interface{}) error { - if ctx.binder != nil { - return ctx.binder.BindHeader(&ctx.Request, obj) - } - return binding.DefaultBinder().BindHeader(&ctx.Request, obj) + return ctx.getBinder().BindHeader(&ctx.Request, obj) } // BindPath binds router parameters from *RequestContext to obj with 'path' tag. It will only use 'path' tag for binding. // NOTE: obj should be a pointer. func (ctx *RequestContext) BindPath(obj interface{}) error { - if ctx.binder != nil { - return ctx.binder.BindPath(&ctx.Request, obj, ctx.Params) - } - return binding.DefaultBinder().BindPath(&ctx.Request, obj, ctx.Params) + return ctx.getBinder().BindPath(&ctx.Request, obj, ctx.Params) } // BindForm binds form parameters from *RequestContext to obj with 'form' tag. It will only use 'form' tag for binding. @@ -1377,28 +1373,19 @@ func (ctx *RequestContext) BindForm(obj interface{}) error { if len(ctx.Request.Body()) == 0 { return fmt.Errorf("missing form body") } - if ctx.binder != nil { - return ctx.binder.BindForm(&ctx.Request, obj) - } - return binding.DefaultBinder().BindForm(&ctx.Request, obj) + return ctx.getBinder().BindForm(&ctx.Request, obj) } // BindJSON binds JSON body from *RequestContext. // NOTE: obj should be a pointer. func (ctx *RequestContext) BindJSON(obj interface{}) error { - if ctx.binder != nil { - return ctx.binder.BindJSON(&ctx.Request, obj) - } - return binding.DefaultBinder().BindJSON(&ctx.Request, obj) + return ctx.getBinder().BindJSON(&ctx.Request, obj) } // BindProtobuf binds protobuf body from *RequestContext. // NOTE: obj should be a pointer. func (ctx *RequestContext) BindProtobuf(obj interface{}) error { - if ctx.binder != nil { - return ctx.binder.BindProtobuf(&ctx.Request, obj) - } - return binding.DefaultBinder().BindProtobuf(&ctx.Request, obj) + return ctx.getBinder().BindProtobuf(&ctx.Request, obj) } // BindByContentType will select the binding type on the ContentType automatically. From 256f9bb5aa32b05fa865bc3b736670b8db11697b Mon Sep 17 00:00:00 2001 From: fgy Date: Fri, 22 Sep 2023 12:29:45 +0800 Subject: [PATCH 91/91] feat: remove normalize for header bind --- pkg/app/server/binding/binder_test.go | 74 +++++++++++++++++++ .../internal/decoder/base_type_decoder.go | 4 - .../decoder/customized_type_decoder.go | 4 - .../internal/decoder/map_type_decoder.go | 4 - .../internal/decoder/slice_type_decoder.go | 4 - .../internal/decoder/struct_type_decoder.go | 4 - pkg/common/utils/utils.go | 6 -- pkg/common/utils/utils_test.go | 10 --- 8 files changed, 74 insertions(+), 36 deletions(-) diff --git a/pkg/app/server/binding/binder_test.go b/pkg/app/server/binding/binder_test.go index 1d3e9a1ae..d106ed7ad 100644 --- a/pkg/app/server/binding/binder_test.go +++ b/pkg/app/server/binding/binder_test.go @@ -1362,6 +1362,80 @@ func TestBind_InterfaceType(t *testing.T) { } } +func Test_BindHeaderNormalize(t *testing.T) { + type Req struct { + Header string `header:"h"` + } + + req := newMockRequest(). + SetRequestURI("http://foobar.com"). + SetHeaders("h", "header") + var result Req + + err := DefaultBinder().Bind(req.Req, &result, nil) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, "header", result.Header) + req = newMockRequest(). + SetRequestURI("http://foobar.com"). + SetHeaders("H", "header") + err = DefaultBinder().Bind(req.Req, &result, nil) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, "header", result.Header) + + type Req2 struct { + Header string `header:"H"` + } + + req2 := newMockRequest(). + SetRequestURI("http://foobar.com"). + SetHeaders("h", "header") + var result2 Req2 + + err2 := DefaultBinder().Bind(req2.Req, &result2, nil) + if err != nil { + t.Error(err2) + } + assert.DeepEqual(t, "header", result2.Header) + req2 = newMockRequest(). + SetRequestURI("http://foobar.com"). + SetHeaders("H", "header") + err2 = DefaultBinder().Bind(req2.Req, &result2, nil) + if err2 != nil { + t.Error(err2) + } + assert.DeepEqual(t, "header", result2.Header) + + type Req3 struct { + Header string `header:"h"` + } + + // without normalize, the header key & tag key need to be consistent + req3 := newMockRequest(). + SetRequestURI("http://foobar.com") + req3.Req.Header.DisableNormalizing() + req3.SetHeaders("h", "header") + var result3 Req3 + err3 := DefaultBinder().Bind(req3.Req, &result3, nil) + if err3 != nil { + t.Error(err3) + } + assert.DeepEqual(t, "header", result3.Header) + req3 = newMockRequest(). + SetRequestURI("http://foobar.com") + req3.Req.Header.DisableNormalizing() + req3.SetHeaders("H", "header") + result3 = Req3{} + err3 = DefaultBinder().Bind(req3.Req, &result3, nil) + if err3 != nil { + t.Error(err3) + } + assert.DeepEqual(t, "", result3.Header) +} + func Benchmark_Binding(b *testing.B) { type Req struct { Version string `path:"v"` diff --git a/pkg/app/server/binding/internal/decoder/base_type_decoder.go b/pkg/app/server/binding/internal/decoder/base_type_decoder.go index c3f4346f3..ece04f737 100644 --- a/pkg/app/server/binding/internal/decoder/base_type_decoder.go +++ b/pkg/app/server/binding/internal/decoder/base_type_decoder.go @@ -44,7 +44,6 @@ import ( "fmt" "reflect" - "github.com/cloudwego/hertz/pkg/common/utils" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/route/param" ) @@ -81,9 +80,6 @@ func (d *baseTypeFieldTextDecoder) Decode(req *protocol.Request, params param.Pa } continue } - if tagInfo.Key == headerTag { - tagInfo.Value = utils.GetNormalizeHeaderKey(tagInfo.Value, req.Header.IsDisableNormalizing()) - } text, exist = tagInfo.Getter(req, params, tagInfo.Value) defaultValue = tagInfo.Default if exist { diff --git a/pkg/app/server/binding/internal/decoder/customized_type_decoder.go b/pkg/app/server/binding/internal/decoder/customized_type_decoder.go index b192b3bda..8bf0f0121 100644 --- a/pkg/app/server/binding/internal/decoder/customized_type_decoder.go +++ b/pkg/app/server/binding/internal/decoder/customized_type_decoder.go @@ -43,7 +43,6 @@ package decoder import ( "reflect" - "github.com/cloudwego/hertz/pkg/common/utils" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/route/param" ) @@ -64,9 +63,6 @@ func (d *customizedFieldTextDecoder) Decode(req *protocol.Request, params param. defaultValue = tagInfo.Default continue } - if tagInfo.Key == headerTag { - tagInfo.Value = utils.GetNormalizeHeaderKey(tagInfo.Value, req.Header.IsDisableNormalizing()) - } text, exist = tagInfo.Getter(req, params, tagInfo.Value) defaultValue = tagInfo.Default if exist { diff --git a/pkg/app/server/binding/internal/decoder/map_type_decoder.go b/pkg/app/server/binding/internal/decoder/map_type_decoder.go index 34b1104fa..31fe85a1b 100644 --- a/pkg/app/server/binding/internal/decoder/map_type_decoder.go +++ b/pkg/app/server/binding/internal/decoder/map_type_decoder.go @@ -46,7 +46,6 @@ import ( "github.com/cloudwego/hertz/internal/bytesconv" hJson "github.com/cloudwego/hertz/pkg/common/json" - "github.com/cloudwego/hertz/pkg/common/utils" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/route/param" ) @@ -73,9 +72,6 @@ func (d *mapTypeFieldTextDecoder) Decode(req *protocol.Request, params param.Par } continue } - if tagInfo.Key == headerTag { - tagInfo.Value = utils.GetNormalizeHeaderKey(tagInfo.Value, req.Header.IsDisableNormalizing()) - } text, exist = tagInfo.Getter(req, params, tagInfo.Value) defaultValue = tagInfo.Default if exist { diff --git a/pkg/app/server/binding/internal/decoder/slice_type_decoder.go b/pkg/app/server/binding/internal/decoder/slice_type_decoder.go index 66b93ff13..fc5c9814f 100644 --- a/pkg/app/server/binding/internal/decoder/slice_type_decoder.go +++ b/pkg/app/server/binding/internal/decoder/slice_type_decoder.go @@ -47,7 +47,6 @@ import ( "github.com/cloudwego/hertz/internal/bytesconv" hJson "github.com/cloudwego/hertz/pkg/common/json" - "github.com/cloudwego/hertz/pkg/common/utils" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/route/param" ) @@ -75,9 +74,6 @@ func (d *sliceTypeFieldTextDecoder) Decode(req *protocol.Request, params param.P } continue } - if tagInfo.Key == headerTag { - tagInfo.Value = utils.GetNormalizeHeaderKey(tagInfo.Value, req.Header.IsDisableNormalizing()) - } if tagInfo.Key == rawBodyTag { bindRawBody = true } diff --git a/pkg/app/server/binding/internal/decoder/struct_type_decoder.go b/pkg/app/server/binding/internal/decoder/struct_type_decoder.go index c00a633c7..3030f2ac6 100644 --- a/pkg/app/server/binding/internal/decoder/struct_type_decoder.go +++ b/pkg/app/server/binding/internal/decoder/struct_type_decoder.go @@ -23,7 +23,6 @@ import ( "github.com/cloudwego/hertz/internal/bytesconv" "github.com/cloudwego/hertz/pkg/common/hlog" hjson "github.com/cloudwego/hertz/pkg/common/json" - "github.com/cloudwego/hertz/pkg/common/utils" "github.com/cloudwego/hertz/pkg/protocol" "github.com/cloudwego/hertz/pkg/route/param" ) @@ -50,9 +49,6 @@ func (d *structTypeFieldTextDecoder) Decode(req *protocol.Request, params param. } continue } - if tagInfo.Key == headerTag { - tagInfo.Value = utils.GetNormalizeHeaderKey(tagInfo.Value, req.Header.IsDisableNormalizing()) - } text, exist = tagInfo.Getter(req, params, tagInfo.Value) defaultValue = tagInfo.Default if exist { diff --git a/pkg/common/utils/utils.go b/pkg/common/utils/utils.go index 8cd76fe61..68778a468 100644 --- a/pkg/common/utils/utils.go +++ b/pkg/common/utils/utils.go @@ -82,12 +82,6 @@ func CaseInsensitiveCompare(a, b []byte) bool { return true } -func GetNormalizeHeaderKey(key string, disableNormalizing bool) string { - keyBytes := []byte(key) - NormalizeHeaderKey(keyBytes, disableNormalizing) - return string(keyBytes) -} - func NormalizeHeaderKey(b []byte, disableNormalizing bool) { if disableNormalizing { return diff --git a/pkg/common/utils/utils_test.go b/pkg/common/utils/utils_test.go index 7462868c7..92873b51d 100644 --- a/pkg/common/utils/utils_test.go +++ b/pkg/common/utils/utils_test.go @@ -142,13 +142,3 @@ func TestFilterContentType(t *testing.T) { contentType = FilterContentType(contentType) assert.DeepEqual(t, "text/plain", contentType) } - -func TestGetNormalizeHeaderKey(t *testing.T) { - key := "content-type" - key = GetNormalizeHeaderKey(key, false) - assert.DeepEqual(t, "Content-Type", key) - - key = "content-type" - key = GetNormalizeHeaderKey(key, true) - assert.DeepEqual(t, "content-type", key) -}