Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: validate config #955

Merged
merged 4 commits into from
Sep 26, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion pkg/app/context_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -884,11 +884,15 @@ func (m *mockValidator) Engine() interface{} {
return nil
}

func (m *mockValidator) ValidateTag() string {
return "vt"
}

func TestSetValidator(t *testing.T) {
m := &mockValidator{}
c := NewContext(0)
c.SetValidator(m)
c.SetBinder(binding.NewDefaultBinder(&binding.BindConfig{ValidateTag: "vt"}))
c.SetBinder(binding.NewDefaultBinder(&binding.BindConfig{Validator: m}))
type User struct {
Age int `vt:"$>=0&&$<=130"`
}
Expand Down
55 changes: 55 additions & 0 deletions pkg/app/server/binding/binder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1436,6 +1436,61 @@ func Test_BindHeaderNormalize(t *testing.T) {
assert.DeepEqual(t, "", result3.Header)
}

type ValidateError struct {
ErrType, FailField, Msg string
}

// Error implements error interface.
func (e *ValidateError) Error() string {
if e.Msg != "" {
return e.ErrType + ": expr_path=" + e.FailField + ", cause=" + e.Msg
}
return e.ErrType + ": expr_path=" + e.FailField + ", cause=invalid"
}

func Test_ValidatorErrorFactory(t *testing.T) {
type TestBind struct {
A string `query:"a,required"`
}

r := protocol.NewRequest("GET", "/foo", nil)
r.SetRequestURI("/foo/bar?b=20")
CustomValidateErrFunc := func(failField, msg string) error {
err := ValidateError{
ErrType: "validateErr",
FailField: "[validateFailField]: " + failField,
Msg: "[validateErrMsg]: " + msg,
}

return &err
}

validateConfig := NewValidateConfig()
validateConfig.SetValidatorErrorFactory(CustomValidateErrFunc)
validator := NewValidator(validateConfig)

var req TestBind
err := Bind(r, &req, nil)
if err == nil {
t.Fatalf("unexpected nil, expected an error")
}

type TestValidate struct {
B int `query:"b" vd:"$>100"`
}

var reqValidate TestValidate
err = Bind(r, &reqValidate, nil)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
err = validator.ValidateStruct(&reqValidate)
if err == nil {
t.Fatalf("unexpected nil, expected an error")
}
assert.DeepEqual(t, "validateErr: expr_path=[validateFailField]: B, cause=[validateErrMsg]: ", err.Error())
}

func Benchmark_Binding(b *testing.B) {
type Req struct {
Version string `path:"v"`
Expand Down
29 changes: 15 additions & 14 deletions pkg/app/server/binding/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import (
"reflect"
"time"

"github.com/bytedance/go-tagexpr/v2/validator"
exprValidator "github.com/bytedance/go-tagexpr/v2/validator"
inDecoder "github.com/cloudwego/hertz/pkg/app/server/binding/internal/decoder"
hJson "github.com/cloudwego/hertz/pkg/common/json"
"github.com/cloudwego/hertz/pkg/protocol"
Expand Down Expand Up @@ -63,10 +63,6 @@ type BindConfig struct {
// The default is false.
// It is used for BindJSON().
EnableDecoderDisallowUnknownFields bool
// ValidateTag is used to determine if a filed needs to be validated.
// NOTE:
// The default is "vd".
ValidateTag string
// TypeUnmarshalFuncs registers customized type unmarshaler.
// NOTE:
// time.Time is registered by default
Expand All @@ -82,7 +78,6 @@ func NewBindConfig() *BindConfig {
DisableStructFieldResolve: false,
EnableDecoderUseNumber: false,
EnableDecoderDisallowUnknownFields: false,
ValidateTag: "vd",
TypeUnmarshalFuncs: make(map[reflect.Type]inDecoder.CustomizeDecodeFunc),
Validator: defaultValidate,
}
Expand Down Expand Up @@ -145,7 +140,12 @@ func (config *BindConfig) UseStdJSONUnmarshaler() {
config.UseThirdPartyJSONUnmarshaler(stdJson.Unmarshal)
}

type ValidateConfig struct{}
type ValidateErrFactory func(fieldSelector, msg string) error

type ValidateConfig struct {
ValidateTag string
ErrFactory ValidateErrFactory
}

func NewValidateConfig() *ValidateConfig {
return &ValidateConfig{}
Expand All @@ -157,14 +157,15 @@ func NewValidateConfig() *ValidateConfig {
// If force=true, allow to cover the existed same funcName.
// MustRegValidateFunc will remain in effect once it has been called.
func (config *ValidateConfig) MustRegValidateFunc(funcName string, fn func(args ...interface{}) error, force ...bool) {
validator.MustRegFunc(funcName, fn, force...)
exprValidator.MustRegFunc(funcName, fn, force...)
}

// SetValidatorErrorFactory customizes the factory of validation error.
func (config *ValidateConfig) SetValidatorErrorFactory(validatingErrFactory func(failField, msg string) error) {
if val, ok := DefaultValidator().(*defaultValidator); ok {
val.validate.SetErrorFactory(validatingErrFactory)
} else {
panic("customized validator can not use 'SetValidatorErrorFactory'")
}
func (config *ValidateConfig) SetValidatorErrorFactory(errFactory ValidateErrFactory) {
config.ErrFactory = errFactory
}

// SetValidatorTag customizes the factory of validation error.
func (config *ValidateConfig) SetValidatorTag(tag string) {
config.ValidateTag = tag
}
85 changes: 60 additions & 25 deletions pkg/app/server/binding/default.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ import (
"reflect"
"sync"

"github.com/bytedance/go-tagexpr/v2/validator"
exprValidator "github.com/bytedance/go-tagexpr/v2/validator"
"github.com/cloudwego/hertz/internal/bytesconv"
inDecoder "github.com/cloudwego/hertz/pkg/app/server/binding/internal/decoder"
hJson "github.com/cloudwego/hertz/pkg/common/json"
Expand All @@ -81,10 +81,11 @@ import (
)

const (
queryTag = "query"
headerTag = "header"
formTag = "form"
pathTag = "path"
queryTag = "query"
headerTag = "header"
formTag = "form"
pathTag = "path"
defaultValidateTag = "vd"
)

type decoderInfo struct {
Expand Down Expand Up @@ -185,14 +186,17 @@ func (b *defaultBinder) bindTag(req *protocol.Request, v interface{}, params par
decoder := cached.(decoderInfo)
return decoder.decoder(req, params, rv.Elem())
}

validateTag := defaultValidateTag
if len(b.config.Validator.ValidateTag()) != 0 {
validateTag = b.config.Validator.ValidateTag()
}
decodeConfig := &inDecoder.DecodeConfig{
LooseZeroMode: b.config.LooseZeroMode,
DisableDefaultTag: b.config.DisableDefaultTag,
DisableStructFieldResolve: b.config.DisableStructFieldResolve,
EnableDecoderUseNumber: b.config.EnableDecoderUseNumber,
EnableDecoderDisallowUnknownFields: b.config.EnableDecoderDisallowUnknownFields,
ValidateTag: b.config.ValidateTag,
ValidateTag: validateTag,
TypeUnmarshalFuncs: b.config.TypeUnmarshalFuncs,
}
decoder, needValidate, err := inDecoder.GetReqDecoder(rv.Type(), tag, decodeConfig)
Expand Down Expand Up @@ -232,13 +236,17 @@ func (b *defaultBinder) bindTagWithValidate(req *protocol.Request, v interface{}
}
return err
}
validateTag := defaultValidateTag
if len(b.config.Validator.ValidateTag()) != 0 {
validateTag = b.config.Validator.ValidateTag()
}
decodeConfig := &inDecoder.DecodeConfig{
LooseZeroMode: b.config.LooseZeroMode,
DisableDefaultTag: b.config.DisableDefaultTag,
DisableStructFieldResolve: b.config.DisableStructFieldResolve,
EnableDecoderUseNumber: b.config.EnableDecoderUseNumber,
EnableDecoderDisallowUnknownFields: b.config.EnableDecoderDisallowUnknownFields,
ValidateTag: b.config.ValidateTag,
ValidateTag: validateTag,
TypeUnmarshalFuncs: b.config.TypeUnmarshalFuncs,
}
decoder, needValidate, err := inDecoder.GetReqDecoder(rv.Type(), tag, decodeConfig)
Expand Down Expand Up @@ -371,39 +379,66 @@ func (b *defaultBinder) bindNonStruct(req *protocol.Request, v interface{}) (err
return
}

var _ StructValidator = (*defaultValidator)(nil)
var _ StructValidator = (*validator)(nil)

type validator struct {
validateTag string
validate *exprValidator.Validator
}

func NewValidator(config *ValidateConfig) StructValidator {
validateTag := defaultValidateTag
if config != nil && len(config.ValidateTag) != 0 {
validateTag = config.ValidateTag
}
vd := exprValidator.New(validateTag).SetErrorFactory(defaultValidateErrorFactory)
if config != nil && config.ErrFactory != nil {
vd.SetErrorFactory(config.ErrFactory)
}
return &validator{
validateTag: validateTag,
validate: vd,
}
}

// Error validate error
type validateError struct {
FailPath, Msg string
}

type defaultValidator struct {
once sync.Once
validate *validator.Validator
// Error implements error interface.
func (e *validateError) Error() string {
if e.Msg != "" {
welkeyever marked this conversation as resolved.
Show resolved Hide resolved
return e.Msg
}
return "invalid parameter: " + e.FailPath
}

func NewDefaultValidator(config *ValidateConfig) StructValidator {
return &defaultValidator{}
func defaultValidateErrorFactory(failPath, msg string) error {
return &validateError{
FailPath: failPath,
Msg: msg,
}
}

// ValidateStruct receives any kind of type, but only performed struct or pointer to struct type.
func (v *defaultValidator) ValidateStruct(obj interface{}) error {
func (v *validator) ValidateStruct(obj interface{}) error {
if obj == nil {
return nil
}
v.lazyinit()
return v.validate.Validate(obj)
}

func (v *defaultValidator) lazyinit() {
v.once.Do(func() {
v.validate = validator.Default()
})
}

// Engine returns the underlying validator
func (v *defaultValidator) Engine() interface{} {
v.lazyinit()
func (v *validator) Engine() interface{} {
return v.validate
}

var defaultValidate = NewDefaultValidator(nil)
func (v *validator) ValidateTag() string {
return v.validateTag
}

var defaultValidate = NewValidator(NewValidateConfig())

func DefaultValidator() StructValidator {
return defaultValidate
Expand Down
1 change: 1 addition & 0 deletions pkg/app/server/binding/validator.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,5 @@ package binding
type StructValidator interface {
ValidateStruct(interface{}) error
Engine() interface{}
ValidateTag() string
}
29 changes: 29 additions & 0 deletions pkg/app/server/binding/validator_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,32 @@ func Test_ValidateStruct(t *testing.T) {
t.Fatalf("expected an error, but got nil")
}
}

func Test_ValidateTag(t *testing.T) {
type User struct {
Age int `query:"age" vt:"$>=0&&$<=130"`
}

user := &User{
Age: 135,
}
validateConfig := NewValidateConfig()
validateConfig.ValidateTag = "vt"
vd := NewValidator(validateConfig)
err := vd.ValidateStruct(user)
if err == nil {
t.Fatalf("expected an error, but got nil")
}

bindConfig := NewBindConfig()
bindConfig.Validator = vd
binder := NewDefaultBinder(bindConfig)
user = &User{}
req := newMockRequest().
SetRequestURI("http://foobar.com?age=135").
SetHeaders("h", "header")
err = binder.BindAndValidate(req.Req, user, nil)
if err == nil {
t.Fatalf("expected an error, but got nil")
}
}
Loading