From 4eda0ac5edc61525d1df86499e4dcd0cdb2771bb Mon Sep 17 00:00:00 2001 From: Chris Roche Date: Thu, 13 Apr 2023 11:46:00 -0700 Subject: [PATCH] simplify evaluators --- go/evaluator.go | 129 ++++++++++---------- go/expressions.go | 18 +-- go/gen/example/v1/validations.pb.go | 64 +++++----- go/registry.go | 134 ++++++++++++--------- go/validator.go | 2 +- proto/private/example/v1/validations.proto | 6 +- 6 files changed, 182 insertions(+), 171 deletions(-) diff --git a/go/evaluator.go b/go/evaluator.go index 10268b43..96b57cd8 100644 --- a/go/evaluator.go +++ b/go/evaluator.go @@ -15,29 +15,40 @@ package protovalidate import ( - validatev2 "github.com/bufbuild/protovalidate/go/gen/buf/validate" + "github.com/bufbuild/protovalidate/go/gen/buf/validate" "google.golang.org/protobuf/reflect/protoreflect" ) type evaluator interface { - evaluate(msg protoreflect.Message, failFast bool) error + evaluate(val protoreflect.Value, failFast bool) error } -type messageEvaluator struct { - err error - constraints []evaluator +type messageEvaluator interface { + evaluateMessage(msg protoreflect.Message, failFast bool) error } -func (m *messageEvaluator) evaluate(msg protoreflect.Message, failFast bool) error { - if err := m.err; err != nil { - return err +type constraintsEval []compiledExpression + +func (c constraintsEval) evaluate(val protoreflect.Value, failFast bool) error { + iface := val.Interface() + if msg, ok := iface.(protoreflect.Message); ok { + return c.eval(msg.Interface(), failFast) } + return c.eval(iface, failFast) +} + +func (c constraintsEval) evaluateMessage(msg protoreflect.Message, failFast bool) error { + return c.eval(msg.Interface(), failFast) +} + +func (c constraintsEval) eval(val any, failFast bool) error { + binding := namedBinding{name: "this", val: val} var ( err error ok bool ) - for _, constraint := range m.constraints { - evalErr := constraint.evaluate(msg, failFast) + for _, expr := range c { + evalErr := expr.eval(binding) if err, ok = mergeErrors(err, evalErr, failFast); !ok { break } @@ -45,81 +56,69 @@ func (m *messageEvaluator) evaluate(msg protoreflect.Message, failFast bool) err return err } -type messageExpressionEvaluator struct { - exprs []compiledExpression -} - -func (m messageExpressionEvaluator) evaluate(msg protoreflect.Message, failFast bool) error { - binding := namedBinding{name: "this", val: msg.Interface()} - return evalExprs(m.exprs, binding, failFast) +type messageEval struct { + err error + constraints []messageEvaluator } -type fieldExpressionEvaluator struct { - field protoreflect.FieldDescriptor - exprs []compiledExpression +func (m *messageEval) evaluate(val protoreflect.Value, failFast bool) error { + return m.evaluateMessage(val.Message(), failFast) } -func (f fieldExpressionEvaluator) evaluate(msg protoreflect.Message, failFast bool) error { - binding := namedBinding{name: "this", val: msg.Get(f.field)} - err := evalExprs(f.exprs, binding, failFast) - if valErr, ok := err.(*ValidationError); ok { - valErr.prefixPaths(string(f.field.Name()), ".") +func (m *messageEval) evaluateMessage(msg protoreflect.Message, failFast bool) error { + if err := m.err; err != nil { + return err } - return err -} - -type messageFieldEvaluator struct { - field protoreflect.FieldDescriptor - embeddedMessageEvaluator -} - -func (m messageFieldEvaluator) evaluate(msg protoreflect.Message, failFast bool) error { - fldMsg := msg.Get(m.field).Message() - err := m.embeddedMessageEvaluator.evaluate(fldMsg, failFast) - if valErr, ok := err.(*ValidationError); ok { - valErr.prefixPaths(string(m.field.FullName()), ".") + var ( + err error + ok bool + ) + for _, constraint := range m.constraints { + evalErr := constraint.evaluateMessage(msg, failFast) + if err, ok = mergeErrors(err, evalErr, failFast); !ok { + break + } } return err } -type embeddedMessageEvaluator struct { - required bool - skipped bool - - msgEval *messageEvaluator - exprs []compiledExpression +type singularFieldEval struct { + required bool + ignoreEmpty bool + field protoreflect.FieldDescriptor + constraints []evaluator } -func (e embeddedMessageEvaluator) evaluate(msg protoreflect.Message, failFast bool) error { - if e.required && !msg.IsValid() { - return &ValidationError{Violations: []*validatev2.Violation{{ +func (f singularFieldEval) evaluateMessage(msg protoreflect.Message, failFast bool) (err error) { + if f.ignoreEmpty && !msg.Has(f.field) { + return nil + } + defer func() { + if valErr, ok := err.(*ValidationError); ok { + valErr.prefixPaths(string(f.field.Name()), ".") + } + }() + if f.required && !msg.Has(f.field) { + return &ValidationError{Violations: []*validate.Violation{{ ConstraintId: "required", Message: "value is required", }}} } - - var ( - err error - ok bool - ) - if !e.skipped && msg.IsValid() { - evalErr := e.msgEval.evaluate(msg, failFast) - err, ok = mergeErrors(err, evalErr, failFast) - if !ok { + val := msg.Get(f.field) + ok := false + for _, constraint := range f.constraints { + evalErr := constraint.evaluate(val, failFast) + if err, ok = mergeErrors(err, evalErr, failFast); !ok { return err } } - - binding := namedBinding{name: "this", val: msg.Interface()} - evalErr := evalExprs(e.exprs, binding, failFast) - err, _ = mergeErrors(err, evalErr, failFast) return err } var ( - _ evaluator = (*messageEvaluator)(nil) - _ evaluator = messageExpressionEvaluator{} - _ evaluator = fieldExpressionEvaluator{} - _ evaluator = messageFieldEvaluator{} - _ evaluator = embeddedMessageEvaluator{} + _ evaluator = constraintsEval(nil) + _ messageEvaluator = constraintsEval(nil) + _ evaluator = (*messageEval)(nil) + _ messageEvaluator = (*messageEval)(nil) + _ messageEvaluator = singularFieldEval{} ) diff --git a/go/expressions.go b/go/expressions.go index 02d5f7b6..395a2ad6 100644 --- a/go/expressions.go +++ b/go/expressions.go @@ -22,29 +22,15 @@ import ( "github.com/google/cel-go/interpreter" ) -type Expression interface { +type expression interface { GetId() string GetMessage() string GetExpression() string } -func evalExprs(exprs []compiledExpression, binding interpreter.Activation, failFast bool) error { - var ( - err error - ok bool - ) - for _, expr := range exprs { - evalErr := expr.eval(binding) - if err, ok = mergeErrors(err, evalErr, failFast); !ok { - break - } - } - return err -} - type compiledExpression struct { program cel.Program - source Expression + source expression } func (expr compiledExpression) eval(bindings interpreter.Activation) error { diff --git a/go/gen/example/v1/validations.pb.go b/go/gen/example/v1/validations.pb.go index 5a017c0d..383f1a7c 100644 --- a/go/gen/example/v1/validations.pb.go +++ b/go/gen/example/v1/validations.pb.go @@ -263,36 +263,42 @@ var file_example_v1_validations_proto_rawDesc = []byte{ 0x1a, 0x38, 0x0a, 0x07, 0x79, 0x5f, 0x67, 0x74, 0x5f, 0x34, 0x32, 0x1a, 0x2d, 0x74, 0x68, 0x69, 0x73, 0x2e, 0x79, 0x20, 0x3e, 0x20, 0x34, 0x32, 0x20, 0x3f, 0x20, 0x27, 0x27, 0x3a, 0x20, 0x27, 0x79, 0x20, 0x6d, 0x75, 0x73, 0x74, 0x20, 0x62, 0x65, 0x20, 0x67, 0x72, 0x65, 0x61, 0x74, 0x65, - 0x72, 0x20, 0x74, 0x68, 0x61, 0x6e, 0x20, 0x34, 0x32, 0x27, 0x22, 0xa2, 0x01, 0x0a, 0x0d, 0x53, + 0x72, 0x20, 0x74, 0x68, 0x61, 0x6e, 0x20, 0x34, 0x32, 0x27, 0x22, 0xfa, 0x01, 0x0a, 0x0d, 0x53, 0x65, 0x6c, 0x66, 0x52, 0x65, 0x63, 0x75, 0x72, 0x73, 0x69, 0x76, 0x65, 0x12, 0x0c, 0x0a, 0x01, - 0x78, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x01, 0x78, 0x12, 0x31, 0x0a, 0x06, 0x74, 0x75, - 0x72, 0x74, 0x6c, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x65, 0x78, 0x61, - 0x6d, 0x70, 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x53, 0x65, 0x6c, 0x66, 0x52, 0x65, 0x63, 0x75, - 0x72, 0x73, 0x69, 0x76, 0x65, 0x52, 0x06, 0x74, 0x75, 0x72, 0x74, 0x6c, 0x65, 0x3a, 0x50, 0xfa, - 0xf7, 0x18, 0x4c, 0x1a, 0x4a, 0x0a, 0x0e, 0x75, 0x6e, 0x69, 0x71, 0x75, 0x65, 0x5f, 0x74, 0x75, - 0x72, 0x74, 0x6c, 0x65, 0x73, 0x12, 0x1f, 0x61, 0x64, 0x6a, 0x61, 0x63, 0x65, 0x6e, 0x74, 0x20, - 0x74, 0x75, 0x72, 0x74, 0x6c, 0x65, 0x73, 0x20, 0x6d, 0x75, 0x73, 0x74, 0x20, 0x62, 0x65, 0x20, - 0x75, 0x6e, 0x69, 0x71, 0x75, 0x65, 0x1a, 0x17, 0x74, 0x68, 0x69, 0x73, 0x2e, 0x78, 0x20, 0x21, - 0x3d, 0x20, 0x74, 0x68, 0x69, 0x73, 0x2e, 0x74, 0x75, 0x72, 0x74, 0x6c, 0x65, 0x2e, 0x78, 0x22, - 0x3a, 0x0a, 0x0e, 0x4c, 0x6f, 0x6f, 0x70, 0x52, 0x65, 0x63, 0x75, 0x72, 0x73, 0x69, 0x76, 0x65, - 0x41, 0x12, 0x28, 0x0a, 0x01, 0x62, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x65, - 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x4c, 0x6f, 0x6f, 0x70, 0x52, 0x65, - 0x63, 0x75, 0x72, 0x73, 0x69, 0x76, 0x65, 0x42, 0x52, 0x01, 0x62, 0x22, 0x3a, 0x0a, 0x0e, 0x4c, - 0x6f, 0x6f, 0x70, 0x52, 0x65, 0x63, 0x75, 0x72, 0x73, 0x69, 0x76, 0x65, 0x42, 0x12, 0x28, 0x0a, - 0x01, 0x61, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x65, 0x78, 0x61, 0x6d, 0x70, - 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x4c, 0x6f, 0x6f, 0x70, 0x52, 0x65, 0x63, 0x75, 0x72, 0x73, - 0x69, 0x76, 0x65, 0x41, 0x52, 0x01, 0x61, 0x42, 0xaa, 0x01, 0x0a, 0x0e, 0x63, 0x6f, 0x6d, 0x2e, - 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x42, 0x10, 0x56, 0x61, 0x6c, 0x69, - 0x64, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x50, 0x72, 0x6f, 0x74, 0x6f, 0x50, 0x01, 0x5a, 0x3d, - 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, 0x6f, 0x6d, 0x2f, 0x62, 0x75, 0x66, 0x62, 0x75, - 0x69, 0x6c, 0x64, 0x2f, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, - 0x65, 0x2f, 0x67, 0x6f, 0x2f, 0x67, 0x65, 0x6e, 0x2f, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, - 0x2f, 0x76, 0x31, 0x3b, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x76, 0x31, 0xa2, 0x02, 0x03, - 0x45, 0x58, 0x58, 0xaa, 0x02, 0x0a, 0x45, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2e, 0x56, 0x31, - 0xca, 0x02, 0x0a, 0x45, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x5c, 0x56, 0x31, 0xe2, 0x02, 0x16, - 0x45, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x5c, 0x56, 0x31, 0x5c, 0x47, 0x50, 0x42, 0x4d, 0x65, - 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0xea, 0x02, 0x0b, 0x45, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, - 0x3a, 0x3a, 0x56, 0x31, 0x62, 0x06, 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x33, + 0x78, 0x18, 0x01, 0x20, 0x01, 0x28, 0x05, 0x52, 0x01, 0x78, 0x12, 0x88, 0x01, 0x0a, 0x06, 0x74, + 0x75, 0x72, 0x74, 0x6c, 0x65, 0x18, 0x02, 0x20, 0x01, 0x28, 0x0b, 0x32, 0x19, 0x2e, 0x65, 0x78, + 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x53, 0x65, 0x6c, 0x66, 0x52, 0x65, 0x63, + 0x75, 0x72, 0x73, 0x69, 0x76, 0x65, 0x42, 0x55, 0xfa, 0xf7, 0x18, 0x51, 0xba, 0x01, 0x4e, 0x0a, + 0x14, 0x6e, 0x6f, 0x6e, 0x5f, 0x7a, 0x65, 0x72, 0x6f, 0x5f, 0x62, 0x61, 0x62, 0x79, 0x5f, 0x74, + 0x75, 0x72, 0x74, 0x6c, 0x65, 0x12, 0x2a, 0x65, 0x6d, 0x62, 0x65, 0x64, 0x64, 0x65, 0x64, 0x20, + 0x74, 0x75, 0x72, 0x74, 0x6c, 0x65, 0x27, 0x73, 0x20, 0x78, 0x20, 0x76, 0x61, 0x6c, 0x75, 0x65, + 0x20, 0x6d, 0x75, 0x73, 0x74, 0x20, 0x6e, 0x6f, 0x74, 0x20, 0x62, 0x65, 0x20, 0x7a, 0x65, 0x72, + 0x6f, 0x1a, 0x0a, 0x74, 0x68, 0x69, 0x73, 0x2e, 0x78, 0x20, 0x3e, 0x20, 0x30, 0x52, 0x06, 0x74, + 0x75, 0x72, 0x74, 0x6c, 0x65, 0x3a, 0x50, 0xfa, 0xf7, 0x18, 0x4c, 0x1a, 0x4a, 0x0a, 0x0e, 0x75, + 0x6e, 0x69, 0x71, 0x75, 0x65, 0x5f, 0x74, 0x75, 0x72, 0x74, 0x6c, 0x65, 0x73, 0x12, 0x1f, 0x61, + 0x64, 0x6a, 0x61, 0x63, 0x65, 0x6e, 0x74, 0x20, 0x74, 0x75, 0x72, 0x74, 0x6c, 0x65, 0x73, 0x20, + 0x6d, 0x75, 0x73, 0x74, 0x20, 0x62, 0x65, 0x20, 0x75, 0x6e, 0x69, 0x71, 0x75, 0x65, 0x1a, 0x17, + 0x74, 0x68, 0x69, 0x73, 0x2e, 0x78, 0x20, 0x21, 0x3d, 0x20, 0x74, 0x68, 0x69, 0x73, 0x2e, 0x74, + 0x75, 0x72, 0x74, 0x6c, 0x65, 0x2e, 0x78, 0x22, 0x3a, 0x0a, 0x0e, 0x4c, 0x6f, 0x6f, 0x70, 0x52, + 0x65, 0x63, 0x75, 0x72, 0x73, 0x69, 0x76, 0x65, 0x41, 0x12, 0x28, 0x0a, 0x01, 0x62, 0x18, 0x01, + 0x20, 0x01, 0x28, 0x0b, 0x32, 0x1a, 0x2e, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2e, 0x76, + 0x31, 0x2e, 0x4c, 0x6f, 0x6f, 0x70, 0x52, 0x65, 0x63, 0x75, 0x72, 0x73, 0x69, 0x76, 0x65, 0x42, + 0x52, 0x01, 0x62, 0x22, 0x3a, 0x0a, 0x0e, 0x4c, 0x6f, 0x6f, 0x70, 0x52, 0x65, 0x63, 0x75, 0x72, + 0x73, 0x69, 0x76, 0x65, 0x42, 0x12, 0x28, 0x0a, 0x01, 0x61, 0x18, 0x01, 0x20, 0x01, 0x28, 0x0b, + 0x32, 0x1a, 0x2e, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2e, 0x76, 0x31, 0x2e, 0x4c, 0x6f, + 0x6f, 0x70, 0x52, 0x65, 0x63, 0x75, 0x72, 0x73, 0x69, 0x76, 0x65, 0x41, 0x52, 0x01, 0x61, 0x42, + 0xaa, 0x01, 0x0a, 0x0e, 0x63, 0x6f, 0x6d, 0x2e, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2e, + 0x76, 0x31, 0x42, 0x10, 0x56, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x73, 0x50, + 0x72, 0x6f, 0x74, 0x6f, 0x50, 0x01, 0x5a, 0x3d, 0x67, 0x69, 0x74, 0x68, 0x75, 0x62, 0x2e, 0x63, + 0x6f, 0x6d, 0x2f, 0x62, 0x75, 0x66, 0x62, 0x75, 0x69, 0x6c, 0x64, 0x2f, 0x70, 0x72, 0x6f, 0x74, + 0x6f, 0x76, 0x61, 0x6c, 0x69, 0x64, 0x61, 0x74, 0x65, 0x2f, 0x67, 0x6f, 0x2f, 0x67, 0x65, 0x6e, + 0x2f, 0x65, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2f, 0x76, 0x31, 0x3b, 0x65, 0x78, 0x61, 0x6d, + 0x70, 0x6c, 0x65, 0x76, 0x31, 0xa2, 0x02, 0x03, 0x45, 0x58, 0x58, 0xaa, 0x02, 0x0a, 0x45, 0x78, + 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x2e, 0x56, 0x31, 0xca, 0x02, 0x0a, 0x45, 0x78, 0x61, 0x6d, 0x70, + 0x6c, 0x65, 0x5c, 0x56, 0x31, 0xe2, 0x02, 0x16, 0x45, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x5c, + 0x56, 0x31, 0x5c, 0x47, 0x50, 0x42, 0x4d, 0x65, 0x74, 0x61, 0x64, 0x61, 0x74, 0x61, 0xea, 0x02, + 0x0b, 0x45, 0x78, 0x61, 0x6d, 0x70, 0x6c, 0x65, 0x3a, 0x3a, 0x56, 0x31, 0x62, 0x06, 0x70, 0x72, + 0x6f, 0x74, 0x6f, 0x33, } var ( diff --git a/go/registry.go b/go/registry.go index ccd51a90..36545be7 100644 --- a/go/registry.go +++ b/go/registry.go @@ -26,10 +26,10 @@ import ( type registry struct { baseEnv *cel.Env - sf singleFlight[protoreflect.FullName, *messageEvaluator] - messageEvaluators syncMap[protoreflect.FullName, *messageEvaluator] - inflightEvaluators syncMap[protoreflect.FullName, *messageEvaluator] - // standardConstraints syncMap[protoreflect.FullName, cel.Program] + sf singleFlight[protoreflect.FullName, *messageEval] + messageEvaluators syncMap[protoreflect.FullName, *messageEval] + inflightEvaluators syncMap[protoreflect.FullName, *messageEval] + // standardConstraints syncMap[protoreflect.FullName, compiledExpression] } func newRegistry() *registry { @@ -40,18 +40,18 @@ func newRegistry() *registry { return ®istry{baseEnv: env} } -func (r *registry) loadOrBuild(desc protoreflect.MessageDescriptor) *messageEvaluator { +func (r *registry) loadOrBuild(desc protoreflect.MessageDescriptor) *messageEval { if eval, ok := r.messageEvaluators.Load(desc.FullName()); ok { return eval } - eval, _, _ := r.sf.Do(desc.FullName(), func() (*messageEvaluator, error) { - return r.build(desc), nil + eval, _, _ := r.sf.Do(desc.FullName(), func() (*messageEval, error) { + return r.buildMessage(desc), nil }) return eval } -func (r *registry) inflightLoadOrBuild(desc protoreflect.MessageDescriptor) *messageEvaluator { +func (r *registry) inflightLoadOrBuild(desc protoreflect.MessageDescriptor) *messageEval { eval, ok := r.inflightEvaluators.Load(desc.FullName()) if ok { return eval @@ -59,13 +59,18 @@ func (r *registry) inflightLoadOrBuild(desc protoreflect.MessageDescriptor) *mes return r.loadOrBuild(desc) } -func (r *registry) build(desc protoreflect.MessageDescriptor) *messageEvaluator { +func (r *registry) buildMessage(desc protoreflect.MessageDescriptor) *messageEval { fullName := desc.FullName() - msgEval := &messageEvaluator{} + msgEval := &messageEval{} defer r.messageEvaluators.Store(fullName, msgEval) r.inflightEvaluators.Store(fullName, msgEval) defer r.inflightEvaluators.Delete(fullName) - if r.buildMessageExpressions(msgEval, desc); msgEval.err != nil { + + constraints, _ := proto.GetExtension(desc.Options(), validate.E_Message).(*validate.MessageConstraints) + if constraints.GetDisabled() { + return msgEval + } + if r.buildMessageExpressions(msgEval, desc, constraints); msgEval.err != nil { return msgEval } if r.buildFields(msgEval, desc); msgEval.err != nil { @@ -74,13 +79,17 @@ func (r *registry) build(desc protoreflect.MessageDescriptor) *messageEvaluator return msgEval } -func (r *registry) buildMessageExpressions(msgEval *messageEvaluator, desc protoreflect.MessageDescriptor) { - constraints, ok := proto.GetExtension(desc.Options(), validate.E_Message).(*validate.MessageConstraints) - if !ok || constraints.GetDisabled() || len(constraints.GetCel()) == 0 { +func (r *registry) buildMessageExpressions( + msgEval *messageEval, + desc protoreflect.MessageDescriptor, + constraints *validate.MessageConstraints, +) { + exprs := constraints.GetCel() + if len(exprs) == 0 { return } - - compiledExprs, err := r.compileExprs(constraints.GetCel(), + compiledExprs, err := r.compileExprs( + exprs, cel.TypeDescs(desc.ParentFile()), cel.Variable("this", cel.ObjectType(string(desc.FullName()))), ) @@ -88,16 +97,13 @@ func (r *registry) buildMessageExpressions(msgEval *messageEvaluator, desc proto msgEval.err = err return } - msgEval.constraints = append( msgEval.constraints, - messageExpressionEvaluator{ - exprs: compiledExprs, - }, + constraintsEval(compiledExprs), ) } -func (r *registry) buildFields(msgEval *messageEvaluator, desc protoreflect.MessageDescriptor) { +func (r *registry) buildFields(msgEval *messageEval, desc protoreflect.MessageDescriptor) { fields := desc.Fields() for i := 0; i < fields.Len(); i++ { fdesc := fields.Get(i) @@ -110,10 +116,12 @@ func (r *registry) buildFields(msgEval *messageEvaluator, desc protoreflect.Mess msgEval.err = CompilationError{ cause: fmt.Errorf("repeated field %s is currently unsupported", fdesc.FullName()), } - case fdesc.Kind() == protoreflect.MessageKind: - r.buildMessageField(msgEval, fdesc) + case fdesc.Kind() == protoreflect.GroupKind: + msgEval.err = CompilationError{ + cause: fmt.Errorf("group field %s is currently unsupported", fdesc.FullName()), + } default: - r.buildScalarField(msgEval, fdesc) + r.buildSingularField(msgEval, fdesc) } if msgEval.err != nil { return @@ -121,56 +129,64 @@ func (r *registry) buildFields(msgEval *messageEvaluator, desc protoreflect.Mess } } -func (r *registry) buildMessageField(msgEval *messageEvaluator, fdesc protoreflect.FieldDescriptor) { - fldEval := messageFieldEvaluator{field: fdesc} +func (r *registry) buildSingularField(msgEval *messageEval, fdesc protoreflect.FieldDescriptor) { + fldEval := singularFieldEval{ + field: fdesc, + } - fldEval.msgEval = r.inflightLoadOrBuild(fdesc.Message()) - if err := fldEval.msgEval.err; err != nil { + constraints, _ := proto.GetExtension(fdesc.Options(), validate.E_Field).(*validate.FieldConstraints) + fldEval.required = constraints.GetMessage().GetRequired() + + if consEval, err := r.buildFieldExpressions(constraints.GetCel(), fdesc); err != nil { msgEval.err = CompilationError{ - cause: fmt.Errorf("failed to compile embedded type %s: %w", - fdesc.Message().FullName(), err), + cause: fmt.Errorf("failed to compile constraints for %s: %w", + fdesc.FullName(), err), } return + } else if len(consEval) > 0 { + fldEval.constraints = append(fldEval.constraints, consEval) } - constraints, _ := proto.GetExtension(fdesc.Options(), validate.E_Field).(*validate.FieldConstraints) - fldEval.required = constraints.GetMessage().GetRequired() - fldEval.skipped = constraints.GetMessage().GetSkipped() - // TODO: check for WKTs + // TODO: standard constraints - compiledExprs, err := r.compileExprs(constraints.GetCel(), - cel.TypeDescs(fdesc.ParentFile()), - cel.Variable("this", cel.ObjectType(string(fdesc.Message().FullName()))), - ) - if err != nil { - msgEval.err = err - return + if fdesc.Kind() == protoreflect.MessageKind && !constraints.GetMessage().GetSkipped() { + embedEval := r.inflightLoadOrBuild(fdesc.Message()) + if err := embedEval.err; err != nil { + msgEval.err = CompilationError{ + cause: fmt.Errorf("failed to compile embedded type %s for %s: %w", + fdesc.Message().FullName(), fdesc.FullName(), err), + } + return + } + fldEval.ignoreEmpty = true // unset messages aren't evaluated + fldEval.constraints = append(fldEval.constraints, embedEval) } - fldEval.exprs = compiledExprs - msgEval.constraints = append(msgEval.constraints, fldEval) -} -func (r *registry) buildScalarField(msgEval *messageEvaluator, fdesc protoreflect.FieldDescriptor) { - fldEval := fieldExpressionEvaluator{ - field: fdesc, + if len(fldEval.constraints) > 0 || fldEval.required { + msgEval.constraints = append(msgEval.constraints, fldEval) } +} - // TODO: standard constraints - - constraints, _ := proto.GetExtension(fdesc.Options(), validate.E_Field).(*validate.FieldConstraints) - compiledExprs, err := r.compileExprs(constraints.GetCel(), - cel.Variable("this", protoKindsToCEL[fdesc.Kind()]), - ) - if err != nil { - msgEval.err = err - return +func (r *registry) buildFieldExpressions(exprs []*validate.Constraint, fdesc protoreflect.FieldDescriptor) (constraintsEval, error) { + if len(exprs) == 0 { + return nil, nil } - fldEval.exprs = compiledExprs - if len(fldEval.exprs) > 0 { - msgEval.constraints = append(msgEval.constraints, fldEval) + var opts []cel.EnvOption + if fdesc.Kind() == protoreflect.MessageKind { + opts = []cel.EnvOption{ + cel.TypeDescs(fdesc.ParentFile()), + cel.Variable("this", cel.ObjectType(string(fdesc.Message().FullName()))), + } + } else { + opts = []cel.EnvOption{ + cel.Variable("this", protoKindsToCEL[fdesc.Kind()]), + } } + + compiled, err := r.compileExprs(exprs, opts...) + return compiled, err } func (r *registry) compileExprs(exprs []*validate.Constraint, envOpts ...cel.EnvOption) ([]compiledExpression, error) { diff --git a/go/validator.go b/go/validator.go index 5eec5e14..0290fd14 100644 --- a/go/validator.go +++ b/go/validator.go @@ -36,7 +36,7 @@ func New(options ...ValidatorOption) *Validator { func (v *Validator) Validate(msg proto.Message) error { refl := msg.ProtoReflect() eval := v.registry.loadOrBuild(refl.Descriptor()) - return eval.evaluate(refl, v.failFast) + return eval.evaluateMessage(refl, v.failFast) } type config struct { diff --git a/proto/private/example/v1/validations.proto b/proto/private/example/v1/validations.proto index 904d8c04..0346c421 100644 --- a/proto/private/example/v1/validations.proto +++ b/proto/private/example/v1/validations.proto @@ -53,7 +53,11 @@ message SelfRecursive { }; int32 x = 1; - SelfRecursive turtle = 2; + SelfRecursive turtle = 2 [(buf.validate.field).cel = { + id: "non_zero_baby_turtle", + message: "embedded turtle's x value must not be zero", + expression: "this.x > 0", + }]; } message LoopRecursiveA {