Skip to content

Commit

Permalink
feat: add struct binding
Browse files Browse the repository at this point in the history
  • Loading branch information
FGYFFFF committed Jan 17, 2023
1 parent 3b8a6a8 commit ae0cf9d
Show file tree
Hide file tree
Showing 8 changed files with 641 additions and 86 deletions.
71 changes: 52 additions & 19 deletions pkg/app/server/binding_v2/base_type_decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,15 +5,17 @@ import (
"reflect"

"github.com/cloudwego/hertz/pkg/app/server/binding_v2/text_decoder"
"github.com/cloudwego/hertz/pkg/common/utils"
"github.com/cloudwego/hertz/pkg/protocol"
)

type baseTypeFieldTextDecoder struct {
index int
fieldName string
tagInfos []TagInfo // query,param,header,respHeader ...
fieldType reflect.Type
decoder text_decoder.TextDecoder
index int
parentIndex []int
fieldName string
tagInfos []TagInfo // query,param,header,respHeader ...
fieldType reflect.Type
decoder text_decoder.TextDecoder
}

func (d *baseTypeFieldTextDecoder) Decode(req *protocol.Request, params PathParams, reqValue reflect.Value) error {
Expand All @@ -24,6 +26,11 @@ func (d *baseTypeFieldTextDecoder) Decode(req *protocol.Request, params PathPara
if tagInfo.Key == jsonTag {
continue
}
if tagInfo.Key == headerTag {
tmp := []byte(tagInfo.Value)
utils.NormalizeHeaderKey(tmp, req.Header.IsDisableNormalizing())
tagInfo.Value = string(tmp)
}
ret := tagInfo.Getter(req, params, tagInfo.Value)
defaultValue = tagInfo.Default
if len(ret) != 0 {
Expand All @@ -40,18 +47,43 @@ func (d *baseTypeFieldTextDecoder) Decode(req *protocol.Request, params PathPara
}

var err error
// 找到该 field 的父 struct 的 reflect.Value
for _, idx := range d.parentIndex {
if reqValue.Kind() == reflect.Ptr && reqValue.IsNil() {
nonNilVal, ptrDepth := GetNonNilReferenceValue(reqValue)
reqValue.Set(ReferenceValue(nonNilVal, ptrDepth))
}
for reqValue.Kind() == reflect.Ptr {
reqValue = reqValue.Elem()
}
reqValue = reqValue.Field(idx)
}

// 父 struct 有可能也是一个指针,所以需要再处理一次才能得到最终的父Value(非nil的reflect.Value)
for reqValue.Kind() == reflect.Ptr {
if reqValue.IsNil() {
nonNilVal, ptrDepth := GetNonNilReferenceValue(reqValue)
reqValue.Set(ReferenceValue(nonNilVal, ptrDepth))
}
reqValue = reqValue.Elem()
}

// Pointer support for struct elems
// 根据最终的 Struct,获取对应 field 的 reflect.Value
field := reqValue.Field(d.index)
if field.Kind() == reflect.Ptr {
elem := reflect.New(d.fieldType)
err = d.decoder.UnmarshalString(text, elem.Elem())
// 如果是指针则新建一个reflect.Value,然后赋值给指针
t := field.Type()
var ptrDepth int
for t.Kind() == reflect.Ptr {
t = t.Elem()
ptrDepth++
}
var vv reflect.Value
vv, err := stringToValue(t, text)
if err != nil {
return fmt.Errorf("unable to decode '%s' as %s: %w", text, d.fieldType.Name(), err)
return err
}

field.Set(elem)

field.Set(ReferenceValue(vv, ptrDepth))
return nil
}

Expand All @@ -64,7 +96,7 @@ func (d *baseTypeFieldTextDecoder) Decode(req *protocol.Request, params PathPara
return nil
}

func getBaseTypeTextDecoder(field reflect.StructField, index int, tagInfos []TagInfo) ([]decoder, error) {
func getBaseTypeTextDecoder(field reflect.StructField, index int, tagInfos []TagInfo, parentIdx []int) ([]decoder, error) {
for idx, tagInfo := range tagInfos {
switch tagInfo.Key {
case pathTag:
Expand All @@ -86,7 +118,7 @@ func getBaseTypeTextDecoder(field reflect.StructField, index int, tagInfos []Tag
}

fieldType := field.Type
if field.Type.Kind() == reflect.Ptr {
for field.Type.Kind() == reflect.Ptr {
fieldType = field.Type.Elem()
}

Expand All @@ -96,11 +128,12 @@ func getBaseTypeTextDecoder(field reflect.StructField, index int, tagInfos []Tag
}

fieldDecoder := &baseTypeFieldTextDecoder{
index: index,
fieldName: field.Name,
tagInfos: tagInfos,
decoder: textDecoder,
fieldType: fieldType,
index: index,
parentIndex: parentIdx,
fieldName: field.Name,
tagInfos: tagInfos,
decoder: textDecoder,
fieldType: fieldType,
}

return []decoder{fieldDecoder}, nil
Expand Down
3 changes: 3 additions & 0 deletions pkg/app/server/binding_v2/binder.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ func (b *Bind) Bind(req *protocol.Request, params PathParams, v interface{}) err
if rv.Kind() != reflect.Pointer || rv.IsNil() {
return fmt.Errorf("receiver must be a non-nil pointer")
}
if rv.Kind() == reflect.Map {
return nil
}
cached, ok := b.decoderCache.Load(typeID)
if ok {
// cached decoder, fast path
Expand Down
Loading

0 comments on commit ae0cf9d

Please sign in to comment.