diff --git a/pkg/app/server/binding/binder_test.go b/pkg/app/server/binding/binder_test.go index c971776e3..2a66229c7 100644 --- a/pkg/app/server/binding/binder_test.go +++ b/pkg/app/server/binding/binder_test.go @@ -47,6 +47,7 @@ import ( "net/url" "reflect" "testing" + "time" "github.com/cloudwego/hertz/pkg/app/server/binding/testdata" "github.com/cloudwego/hertz/pkg/common/test/assert" @@ -526,7 +527,7 @@ type CustomizedDecode struct { func TestBind_CustomizedTypeDecode(t *testing.T) { type Foo struct { - F ***CustomizedDecode + F ***CustomizedDecode `query:"a"` } bindConfig := &BindConfig{} @@ -1491,6 +1492,29 @@ func Test_ValidatorErrorFactory(t *testing.T) { assert.DeepEqual(t, "validateErr: expr_path=[validateFailField]: B, cause=[validateErrMsg]: ", err.Error()) } +// Test_Issue964 used to the cover issue for time.Time +func Test_Issue964(t *testing.T) { + type CreateReq struct { + StartAt *time.Time `json:"startAt"` + } + r := newMockRequest().SetBody([]byte("{\n \"startAt\": \"2006-01-02T15:04:05+07:00\"\n}")).SetJSONContentType() + var req CreateReq + err := DefaultBinder().BindAndValidate(r.Req, &req, nil) + if err != nil { + t.Error(err) + } + assert.DeepEqual(t, "2006-01-02 15:04:05 +0700 +0700", req.StartAt.String()) + r = newMockRequest() + req = CreateReq{} + err = DefaultBinder().BindAndValidate(r.Req, &req, nil) + if err != nil { + t.Error(err) + } + if req.StartAt != nil { + t.Error("expected nil") + } +} + func Benchmark_Binding(b *testing.B) { type Req struct { Version string `path:"v"` diff --git a/pkg/app/server/binding/internal/decoder/customized_type_decoder.go b/pkg/app/server/binding/internal/decoder/customized_type_decoder.go index 8bf0f0121..19efa46ae 100644 --- a/pkg/app/server/binding/internal/decoder/customized_type_decoder.go +++ b/pkg/app/server/binding/internal/decoder/customized_type_decoder.go @@ -69,6 +69,9 @@ func (d *customizedFieldTextDecoder) Decode(req *protocol.Request, params param. break } } + if !exist { + return nil + } if len(text) == 0 && len(defaultValue) != 0 { text = defaultValue } @@ -77,6 +80,9 @@ func (d *customizedFieldTextDecoder) Decode(req *protocol.Request, params param. if err != nil { return err } + if !v.IsValid() { + return nil + } reqValue = GetFieldValue(reqValue, d.parentIndex) field := reqValue.Field(d.index)