Skip to content

Commit

Permalink
feat: add file bind
Browse files Browse the repository at this point in the history
  • Loading branch information
FGYFFFF committed Apr 3, 2023
1 parent 7d9319a commit 72eb7ed
Show file tree
Hide file tree
Showing 9 changed files with 286 additions and 16 deletions.
4 changes: 3 additions & 1 deletion pkg/app/server/binding/base_type_decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ func (d *baseTypeFieldTextDecoder) Decode(req *protocol.Request, params PathPara
var text string
var defaultValue string
for _, tagInfo := range d.tagInfos {
if tagInfo.Key == jsonTag {
if tagInfo.Key == jsonTag || tagInfo.Key == fileNameTag {
continue
}
if tagInfo.Key == headerTag {
Expand Down Expand Up @@ -139,6 +139,8 @@ func getBaseTypeTextDecoder(field reflect.StructField, index int, tagInfos []Tag
// do nothing
case rawBodyTag:
tagInfo.Getter = RawBody
case fileNameTag:
// do nothing
default:
}
}
Expand Down
86 changes: 83 additions & 3 deletions pkg/app/server/binding/binder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,10 +42,12 @@ package binding

import (
"fmt"
"mime/multipart"
"testing"

"github.com/cloudwego/hertz/pkg/common/test/assert"
"github.com/cloudwego/hertz/pkg/protocol"
req2 "github.com/cloudwego/hertz/pkg/protocol/http1/req"
"github.com/cloudwego/hertz/pkg/route/param"
)

Expand All @@ -64,6 +66,11 @@ func (m *mockRequest) SetRequestURI(uri string) *mockRequest {
return m
}

func (m *mockRequest) SetFile(param, fileName string) *mockRequest {
m.Req.SetFile(param, fileName)
return m
}

func (m *mockRequest) SetHeader(key, value string) *mockRequest {
m.Req.Header.Set(key, value)
return m
Expand Down Expand Up @@ -128,7 +135,7 @@ func TestBind_BaseType(t *testing.T) {

func TestBind_SliceType(t *testing.T) {
type Req struct {
ID []int `query:"id"`
ID *[]int `query:"id"`
Str [3]string `query:"str"`
Byte []byte `query:"b"`
}
Expand All @@ -145,9 +152,9 @@ func TestBind_SliceType(t *testing.T) {
if err != nil {
t.Error(err)
}
assert.DeepEqual(t, 3, len(result.ID))
assert.DeepEqual(t, 3, len(*result.ID))
for idx, val := range IDs {
assert.DeepEqual(t, val, result.ID[idx])
assert.DeepEqual(t, val, (*result.ID)[idx])
}
assert.DeepEqual(t, 3, len(result.Str))
for idx, val := range Strs {
Expand Down Expand Up @@ -578,6 +585,79 @@ func TestBind_ResetJSONUnmarshal(t *testing.T) {
}
}

func TestBind_FileBind(t *testing.T) {
type Nest struct {
N multipart.FileHeader `file_name:"d"`
}

var s struct {
A *multipart.FileHeader `file_name:"a"`
B *multipart.FileHeader `form:"b"`
C multipart.FileHeader
D **Nest `file_name:"d"`
}
fileName := "binder_test.go"
req := newMockRequest().
SetRequestURI("http://foobar.com").
SetFile("a", fileName).
SetFile("b", fileName).
SetFile("C", fileName).
SetFile("d", fileName)
req2 := req2.GetHTTP1Request(req.Req)
req2.String()
err := DefaultBinder.Bind(req.Req, nil, &s)
if err != nil {
t.Fatal(err)
}
assert.DeepEqual(t, fileName, s.A.Filename)
assert.DeepEqual(t, fileName, s.B.Filename)
assert.DeepEqual(t, fileName, s.C.Filename)
assert.DeepEqual(t, fileName, (**s.D).N.Filename)
}

func TestBind_FileSliceBind(t *testing.T) {
type Nest struct {
N *[]*multipart.FileHeader `form:"b"`
}
var s struct {
A []multipart.FileHeader `form:"a"`
B [3]multipart.FileHeader `form:"b"`
C []*multipart.FileHeader `form:"b"`
D Nest
}
fileName := "binder_test.go"
req := newMockRequest().
SetRequestURI("http://foobar.com").
SetFile("a", fileName).
SetFile("a", fileName).
SetFile("a", fileName).
SetFile("b", fileName).
SetFile("b", fileName).
SetFile("b", fileName)
req2 := req2.GetHTTP1Request(req.Req)
req2.String()
err := DefaultBinder.Bind(req.Req, nil, &s)
if err != nil {
t.Fatal(err)
}
assert.DeepEqual(t, 3, len(s.A))
for _, file := range s.A {
assert.DeepEqual(t, fileName, file.Filename)
}
assert.DeepEqual(t, 3, len(s.B))
for _, file := range s.B {
assert.DeepEqual(t, fileName, file.Filename)
}
assert.DeepEqual(t, 3, len(s.C))
for _, file := range s.C {
assert.DeepEqual(t, fileName, file.Filename)
}
assert.DeepEqual(t, 3, len(*s.D.N))
for _, file := range *s.D.N {
assert.DeepEqual(t, fileName, file.Filename)
}
}

func Benchmark_Binding(b *testing.B) {
type Req struct {
Version string `path:"v"`
Expand Down
6 changes: 6 additions & 0 deletions pkg/app/server/binding/decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ package binding

import (
"fmt"
"mime/multipart"
"reflect"

"github.com/cloudwego/hertz/pkg/protocol"
Expand Down Expand Up @@ -126,6 +127,11 @@ func getFieldDecoder(field reflect.StructField, index int, parentIdx []int) ([]d
if field.Type.Kind() == reflect.Struct {
var decoders []decoder
el := field.Type
// todo: built-in bindings for some common structs, code need to be optimized
switch el {
case reflect.TypeOf(multipart.FileHeader{}):
return getMultipartFileDecoder(field, index, fieldTagInfos, parentIdx)
}

for i := 0; i < el.NumField(); i++ {
if !el.Field(i).IsExported() {
Expand Down
5 changes: 5 additions & 0 deletions pkg/app/server/binding/getter.go
Original file line number Diff line number Diff line change
Expand Up @@ -155,3 +155,8 @@ func RawBody(req *protocol.Request, params PathParams, key string, defaultValue
}
return
}

func FileName(req *protocol.Request, params PathParams, key string, defaultValue ...string) (ret []string) {
// do nothing
return
}
4 changes: 3 additions & 1 deletion pkg/app/server/binding/map_type_decoder.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ func (d *mapTypeFieldTextDecoder) Decode(req *protocol.Request, params PathParam
var text string
var defaultValue string
for _, tagInfo := range d.tagInfos {
if tagInfo.Key == jsonTag {
if tagInfo.Key == jsonTag || tagInfo.Key == fileNameTag {
continue
}
if tagInfo.Key == headerTag {
Expand Down Expand Up @@ -120,6 +120,8 @@ func getMapTypeTextDecoder(field reflect.StructField, index int, tagInfos []TagI
// do nothing
case rawBodyTag:
tagInfo.Getter = RawBody
case fileNameTag:
// do nothing
default:
}
}
Expand Down
148 changes: 148 additions & 0 deletions pkg/app/server/binding/multipart_file_decoder.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
package binding

import (
"fmt"
"github.com/cloudwego/hertz/pkg/protocol"
"reflect"
)

type fileTypeDecoder struct {
fieldInfo
isRepeated bool
}

func (d *fileTypeDecoder) Decode(req *protocol.Request, params PathParams, reqValue reflect.Value) error {
fieldValue := GetFieldValue(reqValue, d.parentIndex)
field := fieldValue.Field(d.index)

if d.isRepeated {
return d.fileSliceDecode(req, params, reqValue)
}
var fileName string
// file_name > form > fieldName
for _, tagInfo := range d.tagInfos {
if tagInfo.Key == fileNameTag {
fileName = tagInfo.Value
break
}
if tagInfo.Key == formTag {
fileName = tagInfo.Value
}
}
if len(fileName) == 0 {
fileName = d.fieldName
}
file, err := req.FormFile(fileName)
if err != nil {
return fmt.Errorf("can not get file '%s', err: %v", fileName, err)
}
if field.Kind() == reflect.Ptr {
t := field.Type()
var ptrDepth int
for t.Kind() == reflect.Ptr {
t = t.Elem()
ptrDepth++
}
v := reflect.New(t).Elem()
v.Set(reflect.ValueOf(*file))
field.Set(ReferenceValue(v, ptrDepth))
return nil
}

// Non-pointer elems
field.Set(reflect.ValueOf(*file))

return nil
}

func (d *fileTypeDecoder) fileSliceDecode(req *protocol.Request, params PathParams, reqValue reflect.Value) error {
fieldValue := GetFieldValue(reqValue, d.parentIndex)
field := fieldValue.Field(d.index)
// 如果没值,需要为其建一个值
if field.Kind() == reflect.Ptr {
if field.IsNil() {
nonNilVal, ptrDepth := GetNonNilReferenceValue(field)
field.Set(ReferenceValue(nonNilVal, ptrDepth))
}
}
var parentPtrDepth int
for field.Kind() == reflect.Ptr {
field = field.Elem()
parentPtrDepth++
}

var fileName string
// file_name > form > fieldName
for _, tagInfo := range d.tagInfos {
if tagInfo.Key == fileNameTag {
fileName = tagInfo.Value
break
}
if tagInfo.Key == formTag {
fileName = tagInfo.Value
}
}
if len(fileName) == 0 {
fileName = d.fieldName
}
multipartForm, err := req.MultipartForm()
if err != nil {
return fmt.Errorf("can not get multipartForm info, err: %v", err)
}
files, exist := multipartForm.File[fileName]
if !exist {
return fmt.Errorf("the file '%s' is not existed", fileName)
}

if field.Kind() == reflect.Array {
if len(files) != field.Len() {
return fmt.Errorf("the numbers(%d) of file '%s' does not match the length(%d) of %s", len(files), fileName, field.Len(), field.Type().String())
}
} else {
// slice need creating enough capacity
field = reflect.MakeSlice(field.Type(), len(files), len(files))
}

// handle multiple pointer
var ptrDepth int
t := d.fieldType.Elem()
elemKind := t.Kind()
for elemKind == reflect.Ptr {
t = t.Elem()
elemKind = t.Kind()
ptrDepth++
}

for idx, file := range files {
v := reflect.New(t).Elem()
v.Set(reflect.ValueOf(*file))
field.Index(idx).Set(ReferenceValue(v, ptrDepth))
}
fieldValue.Field(d.index).Set(ReferenceValue(field, parentPtrDepth))

return nil
}

func getMultipartFileDecoder(field reflect.StructField, index int, tagInfos []TagInfo, parentIdx []int) ([]decoder, error) {
fieldType := field.Type
for field.Type.Kind() == reflect.Ptr {
fieldType = field.Type.Elem()
}
isRepeated := false
if fieldType.Kind() == reflect.Array || fieldType.Kind() == reflect.Slice {
isRepeated = true
}

fieldDecoder := &fileTypeDecoder{
fieldInfo: fieldInfo{
index: index,
parentIndex: parentIdx,
fieldName: field.Name,
tagInfos: tagInfos,
fieldType: fieldType,
},
isRepeated: isRepeated,
}

return []decoder{fieldDecoder}, nil
}
8 changes: 8 additions & 0 deletions pkg/app/server/binding/reflect.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,3 +111,11 @@ func GetFieldValue(reqValue reflect.Value, parentIndex []int) reflect.Value {

return reqValue
}

func getElemType(t reflect.Type) reflect.Type {
for t.Kind() == reflect.Ptr {
t = t.Elem()
}

return t
}
Loading

0 comments on commit 72eb7ed

Please sign in to comment.