Skip to content

Commit

Permalink
refactor: validate tag config
Browse files Browse the repository at this point in the history
  • Loading branch information
FGYFFFF committed Sep 22, 2023
1 parent dbaf4a1 commit f694ddd
Show file tree
Hide file tree
Showing 7 changed files with 87 additions and 11 deletions.
6 changes: 5 additions & 1 deletion pkg/app/context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -884,11 +884,15 @@ func (m *mockValidator) Engine() interface{} {
return nil
}

func (m *mockValidator) ValidateTag() string {
return "vt"
}

func TestSetValidator(t *testing.T) {
m := &mockValidator{}
c := NewContext(0)
c.SetValidator(m)
c.SetBinder(binding.NewDefaultBinder(&binding.BindConfig{ValidateTag: "vt"}))
c.SetBinder(binding.NewDefaultBinder(&binding.BindConfig{Validator: m}))
type User struct {
Age int `vt:"$>=0&&$<=130"`
}
Expand Down
5 changes: 0 additions & 5 deletions pkg/app/server/binding/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,6 @@ type BindConfig struct {
// The default is false.
// It is used for BindJSON().
EnableDecoderDisallowUnknownFields bool
// ValidateTag is used to determine if a filed needs to be validated.
// NOTE:
// The default is "vd".
ValidateTag string
// TypeUnmarshalFuncs registers customized type unmarshaler.
// NOTE:
// time.Time is registered by default
Expand All @@ -82,7 +78,6 @@ func NewBindConfig() *BindConfig {
DisableStructFieldResolve: false,
EnableDecoderUseNumber: false,
EnableDecoderDisallowUnknownFields: false,
ValidateTag: "vd",
TypeUnmarshalFuncs: make(map[reflect.Type]inDecoder.CustomizeDecodeFunc),
Validator: defaultValidate,
}
Expand Down
23 changes: 18 additions & 5 deletions pkg/app/server/binding/default.go
Original file line number Diff line number Diff line change
Expand Up @@ -185,14 +185,17 @@ func (b *defaultBinder) bindTag(req *protocol.Request, v interface{}, params par
decoder := cached.(decoderInfo)
return decoder.decoder(req, params, rv.Elem())
}

validateTag := "vd"
if len(b.config.Validator.ValidateTag()) != 0 {
validateTag = b.config.Validator.ValidateTag()
}
decodeConfig := &inDecoder.DecodeConfig{
LooseZeroMode: b.config.LooseZeroMode,
DisableDefaultTag: b.config.DisableDefaultTag,
DisableStructFieldResolve: b.config.DisableStructFieldResolve,
EnableDecoderUseNumber: b.config.EnableDecoderUseNumber,
EnableDecoderDisallowUnknownFields: b.config.EnableDecoderDisallowUnknownFields,
ValidateTag: b.config.ValidateTag,
ValidateTag: validateTag,
TypeUnmarshalFuncs: b.config.TypeUnmarshalFuncs,
}
decoder, needValidate, err := inDecoder.GetReqDecoder(rv.Type(), tag, decodeConfig)
Expand Down Expand Up @@ -232,13 +235,17 @@ func (b *defaultBinder) bindTagWithValidate(req *protocol.Request, v interface{}
}
return err
}
validateTag := "vd"
if len(b.config.Validator.ValidateTag()) != 0 {
validateTag = b.config.Validator.ValidateTag()
}
decodeConfig := &inDecoder.DecodeConfig{
LooseZeroMode: b.config.LooseZeroMode,
DisableDefaultTag: b.config.DisableDefaultTag,
DisableStructFieldResolve: b.config.DisableStructFieldResolve,
EnableDecoderUseNumber: b.config.EnableDecoderUseNumber,
EnableDecoderDisallowUnknownFields: b.config.EnableDecoderDisallowUnknownFields,
ValidateTag: b.config.ValidateTag,
ValidateTag: validateTag,
TypeUnmarshalFuncs: b.config.TypeUnmarshalFuncs,
}
decoder, needValidate, err := inDecoder.GetReqDecoder(rv.Type(), tag, decodeConfig)
Expand Down Expand Up @@ -374,7 +381,8 @@ func (b *defaultBinder) bindNonStruct(req *protocol.Request, v interface{}) (err
var _ StructValidator = (*defaultValidator)(nil)

type defaultValidator struct {
validate *validator.Validator
validateTag string
validate *validator.Validator
}

func NewDefaultValidator(config *ValidateConfig) StructValidator {
Expand All @@ -387,7 +395,8 @@ func NewDefaultValidator(config *ValidateConfig) StructValidator {
vd.SetErrorFactory(config.ErrFactory)
}
return &defaultValidator{
validate: vd,
validateTag: validateTag,
validate: vd,
}
}

Expand Down Expand Up @@ -424,6 +433,10 @@ func (v *defaultValidator) Engine() interface{} {
return v.validate
}

func (v *defaultValidator) ValidateTag() string {
return v.validateTag
}

var defaultValidate = NewDefaultValidator(NewValidateConfig())

func DefaultValidator() StructValidator {
Expand Down
1 change: 1 addition & 0 deletions pkg/app/server/binding/validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,5 @@ package binding
type StructValidator interface {
ValidateStruct(interface{}) error
Engine() interface{}
ValidateTag() string
}
29 changes: 29 additions & 0 deletions pkg/app/server/binding/validator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,32 @@ func Test_ValidateStruct(t *testing.T) {
t.Fatalf("expected an error, but got nil")
}
}

func Test_ValidateTag(t *testing.T) {
type User struct {
Age int `query:"age" vt:"$>=0&&$<=130"`
}

user := &User{
Age: 135,
}
validateConfig := NewValidateConfig()
validateConfig.ValidateTag = "vt"
vd := NewDefaultValidator(validateConfig)
err := vd.ValidateStruct(user)
if err == nil {
t.Fatalf("expected an error, but got nil")
}

bindConfig := NewBindConfig()
bindConfig.Validator = vd
binder := NewDefaultBinder(bindConfig)
user = &User{}
req := newMockRequest().
SetRequestURI("http://foobar.com?age=135").
SetHeaders("h", "header")
err = binder.BindAndValidate(req.Req, user, nil)
if err == nil {
t.Fatalf("expected an error, but got nil")
}
}
30 changes: 30 additions & 0 deletions pkg/app/server/hertz_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -966,6 +966,10 @@ func (m *mockValidator) Engine() interface{} {
return nil
}

func (m *mockValidator) ValidateTag() string {
return "vd"
}

func TestCustomValidator(t *testing.T) {
type Req struct {
A int `query:"a" vd:"f($)"`
Expand Down Expand Up @@ -1036,3 +1040,29 @@ func TestValidateConfigSetSetErrorFactory(t *testing.T) {
assert.Nil(t, err)
time.Sleep(100 * time.Millisecond)
}

func TestValidateConfigAndBindConfig(t *testing.T) {
type Req struct {
A int `query:"a" vt:"$>=0&&$<=130"`
}
validateConfig := binding.NewValidateConfig()
validateConfig.ValidateTag = "vt"
h := New(
WithHostPorts("localhost:9876"),
WithValidateConfig(validateConfig))
h.GET("/bind", func(c context.Context, ctx *app.RequestContext) {
var req Req
err := ctx.BindAndValidate(&req)
if err == nil {
t.Fatal("expect an error")
}
t.Log(err)
})

go h.Spin()
time.Sleep(100 * time.Millisecond)
hc := http.Client{Timeout: time.Second}
_, err := hc.Get("http://127.0.0.1:9876/bind?a=135")
assert.Nil(t, err)
time.Sleep(100 * time.Millisecond)
}
4 changes: 4 additions & 0 deletions pkg/route/engine_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -675,6 +675,10 @@ func (m *mockValidator) Engine() interface{} {
return nil
}

func (m *mockValidator) ValidateTag() string {
return "vd"
}

type mockNonValidator struct{}

func (m *mockNonValidator) ValidateStruct(interface{}) error {
Expand Down

0 comments on commit f694ddd

Please sign in to comment.