From b03f050866ecfce662ee5fb591cdcf536ea4ab3e Mon Sep 17 00:00:00 2001 From: jchadwick-buf <116005195+jchadwick-buf@users.noreply.github.com> Date: Mon, 23 Sep 2024 12:30:55 -0400 Subject: [PATCH] Implement predefined field constraints (#61) Like protovalidate-go and protovalidate-java, we need to adjust the code to handle dynamic descriptor sets more robustly, since we need to jump between resolving the protovalidate standard rules and the predefined rule extensions. This necessitates a couple of additions to the API surface, namely `ValidatorFactory::SetMessageFactory` and `ValidatorFactory::SetAllowUnknownFields`, which control instantiation of unknown dynamic types and whether or not to ignore unresolved rules, respectively. Like other protovalidate runtimes, we will default to failing compilation when unknown predefined rules are encountered. This should not break existing users but will prevent silently incorrect behavior. TODO: - [x] Skip reparse when there are no unknown fields—this way we can avoid pessimizing the common case - [x] Add an option to fail when unknown rule fields cannot be resolved. - [x] Update for protobuf changes in https://github.com/bufbuild/protovalidate/pull/246. This will depend on https://github.com/bufbuild/protovalidate/pull/246. 
--- Makefile | 2 +- bazel/deps.bzl | 6 +- buf/validate/conformance/runner.cc | 5 +- buf/validate/conformance/runner.h | 13 +++- buf/validate/conformance/runner_main.cc | 7 +- buf/validate/internal/BUILD.bazel | 15 ++++ buf/validate/internal/cel_constraint_rules.cc | 40 +++++++++- buf/validate/internal/cel_constraint_rules.h | 10 ++- buf/validate/internal/cel_rules.h | 39 ++++++++-- buf/validate/internal/constraint_rules.h | 2 +- buf/validate/internal/constraints.cc | 4 - buf/validate/internal/constraints_test.cc | 2 +- buf/validate/internal/field_rules.cc | 77 ++++++++++++++++--- buf/validate/internal/field_rules.h | 19 ++++- buf/validate/internal/message_factory.cc | 31 ++++++++ buf/validate/internal/message_factory.h | 43 +++++++++++ buf/validate/internal/message_rules.cc | 7 +- buf/validate/internal/message_rules.h | 3 + buf/validate/validator.cc | 21 +++-- buf/validate/validator.h | 17 +++- 20 files changed, 306 insertions(+), 57 deletions(-) create mode 100644 buf/validate/internal/message_factory.cc create mode 100644 buf/validate/internal/message_factory.h diff --git a/Makefile b/Makefile index 92aa713..e0bb809 100644 --- a/Makefile +++ b/Makefile @@ -11,7 +11,7 @@ COPYRIGHT_YEARS := 2023 LICENSE_IGNORE := -e internal/testdata/ LICENSE_HEADER_VERSION := 0294fdbe1ce8649ebaf5e87e8cdd588e33730bbb # NOTE: Keep this version in sync with the version in `/bazel/deps.bzl`. -PROTOVALIDATE_VERSION ?= v0.7.1 +PROTOVALIDATE_VERSION ?= v0.8.1 # Set to use a different compiler. For example, `GO=go1.18rc1 make test`. GO ?= go diff --git a/bazel/deps.bzl b/bazel/deps.bzl index 5b6780a..82a7167 100644 --- a/bazel/deps.bzl +++ b/bazel/deps.bzl @@ -65,10 +65,10 @@ _dependencies = { }, # NOTE: Keep Version in sync with `/Makefile`. 
"com_github_bufbuild_protovalidate": { - "sha256": "ccb3952c38397d2cb53fe841af66b05fc012dd17fa754cbe35d9abb547cdf92d", - "strip_prefix": "protovalidate-0.7.1", + "sha256": "c637c8cbaf71b6dc38171e47c2c736581b4cfef385984083561480367659d14f", + "strip_prefix": "protovalidate-0.8.1", "urls": [ - "https://github.com/bufbuild/protovalidate/archive/v0.7.1.tar.gz", + "https://github.com/bufbuild/protovalidate/archive/v0.8.1.tar.gz", ], }, } diff --git a/buf/validate/conformance/runner.cc b/buf/validate/conformance/runner.cc index e1b0d13..8508a46 100644 --- a/buf/validate/conformance/runner.cc +++ b/buf/validate/conformance/runner.cc @@ -20,8 +20,7 @@ namespace buf::validate::conformance { harness::TestConformanceResponse TestRunner::runTest( - const harness::TestConformanceRequest& request, - const google::protobuf::DescriptorPool* descriptorPool) { + const harness::TestConformanceRequest& request) { harness::TestConformanceResponse response; for (const auto& tc : request.cases()) { auto& result = response.mutable_results()->operator[](tc.first); @@ -32,7 +31,7 @@ harness::TestConformanceResponse TestRunner::runTest( *result.mutable_unexpected_error() = "could not parse type url " + dyn.type_url(); continue; } - const auto* desc = descriptorPool->FindMessageTypeByName(dyn.type_url().substr(pos + 1)); + const auto* desc = descriptorPool_->FindMessageTypeByName(dyn.type_url().substr(pos + 1)); if (desc == nullptr) { *result.mutable_unexpected_error() = "could not find descriptor for type " + dyn.type_url(); } else { diff --git a/buf/validate/conformance/runner.h b/buf/validate/conformance/runner.h index 0f60c9f..339afa3 100644 --- a/buf/validate/conformance/runner.h +++ b/buf/validate/conformance/runner.h @@ -22,17 +22,22 @@ namespace buf::validate::conformance { class TestRunner { public: - explicit TestRunner() : validatorFactory_(ValidatorFactory::New().value()) {} + explicit TestRunner( + const google::protobuf::DescriptorPool* descriptorPool = + 
google::protobuf::DescriptorPool::generated_pool()) + : descriptorPool_(descriptorPool), validatorFactory_(ValidatorFactory::New().value()) { + validatorFactory_->SetMessageFactory(&messageFactory_, descriptorPool_); + validatorFactory_->SetAllowUnknownFields(false); + } - harness::TestConformanceResponse runTest( - const harness::TestConformanceRequest& request, - const google::protobuf::DescriptorPool* descriptorPool); + harness::TestConformanceResponse runTest(const harness::TestConformanceRequest& request); harness::TestResult runTestCase( const google::protobuf::Descriptor* desc, const google::protobuf::Any& dyn); harness::TestResult runTestCase(const google::protobuf::Message& message); private: google::protobuf::DynamicMessageFactory messageFactory_; + const google::protobuf::DescriptorPool* descriptorPool_; std::unique_ptr validatorFactory_; google::protobuf::Arena arena_; }; diff --git a/buf/validate/conformance/runner_main.cc b/buf/validate/conformance/runner_main.cc index db51941..acd421d 100644 --- a/buf/validate/conformance/runner_main.cc +++ b/buf/validate/conformance/runner_main.cc @@ -16,14 +16,15 @@ #include "buf/validate/conformance/runner.h" int main(int argc, char** argv) { - google::protobuf::DescriptorPool descriptorPool; - buf::validate::conformance::TestRunner runner; + google::protobuf::DescriptorPool descriptorPool{ + google::protobuf::DescriptorPool::generated_pool()}; buf::validate::conformance::harness::TestConformanceRequest request; request.ParseFromIstream(&std::cin); for (const auto& file : request.fdset().file()) { descriptorPool.BuildFile(file); } - auto response = runner.runTest(request, &descriptorPool); + buf::validate::conformance::TestRunner runner{&descriptorPool}; + auto response = runner.runTest(request); response.SerializeToOstream(&std::cout); return 0; } diff --git a/buf/validate/internal/BUILD.bazel b/buf/validate/internal/BUILD.bazel index e65d06d..3cd409e 100644 --- a/buf/validate/internal/BUILD.bazel +++ 
b/buf/validate/internal/BUILD.bazel @@ -23,7 +23,11 @@ cc_library( "@com_google_cel_cpp//eval/public:activation", "@com_google_cel_cpp//eval/public:cel_expression", "@com_google_cel_cpp//eval/public/structs:cel_proto_wrapper", + "@com_google_cel_cpp//eval/public/containers:field_access", + "@com_google_cel_cpp//eval/public/containers:field_backed_list_impl", + "@com_google_cel_cpp//eval/public/containers:field_backed_map_impl", "@com_google_cel_cpp//parser", + "@com_google_cel_cpp//base:value" ], ) @@ -44,15 +48,26 @@ cc_library( deps = [ "@com_google_absl//absl/status", "@com_google_protobuf//:protobuf", + ":message_factory", ], ) +cc_library( + name = "message_factory", + srcs = ["message_factory.cc"], + hdrs = ["message_factory.h"], + deps = [ + "@com_google_protobuf//:protobuf", + ] +) + cc_library( name = "message_rules", srcs = ["message_rules.cc"], hdrs = ["message_rules.h"], deps = [ ":field_rules", + ":message_factory", "@com_github_bufbuild_protovalidate//proto/protovalidate/buf/validate:validate_proto_cc", "@com_google_absl//absl/status", "@com_google_cel_cpp//eval/public:cel_expression", diff --git a/buf/validate/internal/cel_constraint_rules.cc b/buf/validate/internal/cel_constraint_rules.cc index 0469fd0..bd467d3 100644 --- a/buf/validate/internal/cel_constraint_rules.cc +++ b/buf/validate/internal/cel_constraint_rules.cc @@ -14,6 +14,10 @@ #include "buf/validate/internal/cel_constraint_rules.h" +#include "base/values/struct_value.h" +#include "eval/public/containers/field_access.h" +#include "eval/public/containers/field_backed_list_impl.h" +#include "eval/public/containers/field_backed_map_impl.h" #include "eval/public/structs/cel_proto_wrapper.h" #include "parser/parser.h" @@ -57,10 +61,31 @@ absl::Status ProcessConstraint( return absl::OkStatus(); } +cel::runtime::CelValue ProtoFieldToCelValue( + const google::protobuf::Message* message, + const google::protobuf::FieldDescriptor* field, + google::protobuf::Arena* arena) { + if (field->is_map()) { 
+ return cel::runtime::CelValue::CreateMap( + google::protobuf::Arena::Create( + arena, message, field, arena)); + } else if (field->is_repeated()) { + return cel::runtime::CelValue::CreateList( + google::protobuf::Arena::Create( + arena, message, field, arena)); + } else if (cel::runtime::CelValue result; + cel::runtime::CreateValueFromSingleField(message, field, arena, &result).ok()) { + return result; + } + return cel::runtime::CelValue::CreateNull(); +} + } // namespace absl::Status CelConstraintRules::Add( - google::api::expr::runtime::CelExpressionBuilder& builder, Constraint constraint) { + google::api::expr::runtime::CelExpressionBuilder& builder, + Constraint constraint, + const google::protobuf::FieldDescriptor* rule) { auto pexpr_or = cel::parser::Parse(constraint.expression()); if (!pexpr_or.ok()) { return pexpr_or.status(); @@ -71,7 +96,7 @@ absl::Status CelConstraintRules::Add( return expr_or.status(); } std::unique_ptr expr = std::move(expr_or).value(); - exprs_.emplace_back(CompiledConstraint{std::move(constraint), std::move(expr)}); + exprs_.emplace_back(CompiledConstraint{std::move(constraint), std::move(expr), rule}); return absl::OkStatus(); } @@ -79,12 +104,13 @@ absl::Status CelConstraintRules::Add( google::api::expr::runtime::CelExpressionBuilder& builder, std::string_view id, std::string_view message, - std::string_view expression) { + std::string_view expression, + const google::protobuf::FieldDescriptor* rule) { Constraint constraint; *constraint.mutable_id() = id; *constraint.mutable_message() = message; *constraint.mutable_expression() = expression; - return Add(builder, constraint); + return Add(builder, constraint, rule); } absl::Status CelConstraintRules::ValidateCel( @@ -94,11 +120,17 @@ absl::Status CelConstraintRules::ValidateCel( activation.InsertValue("rules", rules_); activation.InsertValue("now", cel::runtime::CelValue::CreateTimestamp(absl::Now())); absl::Status status = absl::OkStatus(); + for (const auto& expr : exprs_) { + 
if (rules_.IsMessage() && expr.rule) { + activation.InsertValue( + "rule", ProtoFieldToCelValue(rules_.MessageOrDie(), expr.rule, ctx.arena)); + } status = ProcessConstraint(ctx, fieldName, activation, expr); if (ctx.shouldReturn(status)) { break; } + activation.RemoveValueEntry("rule"); } activation.RemoveValueEntry("rules"); return status; diff --git a/buf/validate/internal/cel_constraint_rules.h b/buf/validate/internal/cel_constraint_rules.h index e88e25d..8e68baa 100644 --- a/buf/validate/internal/cel_constraint_rules.h +++ b/buf/validate/internal/cel_constraint_rules.h @@ -16,7 +16,7 @@ #include -#include "buf/validate/expression.pb.h" +#include "buf/validate/validate.pb.h" #include "buf/validate/internal/constraint_rules.h" #include "eval/public/activation.h" #include "eval/public/cel_expression.h" @@ -28,6 +28,7 @@ namespace buf::validate::internal { struct CompiledConstraint { buf::validate::Constraint constraint; std::unique_ptr expr; + const google::protobuf::FieldDescriptor* rule; }; // An abstract base class for constraint with rules that are compiled into CEL expressions. @@ -38,12 +39,15 @@ class CelConstraintRules : public ConstraintRules { using Base::Base; absl::Status Add( - google::api::expr::runtime::CelExpressionBuilder& builder, Constraint constraint); + google::api::expr::runtime::CelExpressionBuilder& builder, + Constraint constraint, + const google::protobuf::FieldDescriptor* rule); absl::Status Add( google::api::expr::runtime::CelExpressionBuilder& builder, std::string_view id, std::string_view message, - std::string_view expression); + std::string_view expression, + const google::protobuf::FieldDescriptor* rule); [[nodiscard]] const std::vector& getExprs() const { return exprs_; } // Validate all the cel rules given the activation that already has 'this' bound. 
diff --git a/buf/validate/internal/cel_rules.h b/buf/validate/internal/cel_rules.h index aede7ee..1f22ee3 100644 --- a/buf/validate/internal/cel_rules.h +++ b/buf/validate/internal/cel_rules.h @@ -16,6 +16,7 @@ #include "absl/status/status.h" #include "buf/validate/internal/cel_constraint_rules.h" +#include "buf/validate/internal/message_factory.h" #include "buf/validate/validate.pb.h" #include "google/protobuf/arena.h" #include "google/protobuf/descriptor.h" @@ -24,22 +25,48 @@ namespace buf::validate::internal { template absl::Status BuildCelRules( + std::unique_ptr& messageFactory, + bool allowUnknownFields, google::protobuf::Arena* arena, google::api::expr::runtime::CelExpressionBuilder& builder, const R& rules, CelConstraintRules& result) { - result.setRules(&rules, arena); // Look for constraints on the set fields. std::vector fields; - R::GetReflection()->ListFields(rules, &fields); + google::protobuf::Message* reparsedRules{}; + if (messageFactory && rules.unknown_fields().field_count() > 0) { + reparsedRules = messageFactory->messageFactory() + ->GetPrototype(messageFactory->descriptorPool()->FindMessageTypeByName( + rules.GetTypeName())) + ->New(arena); + if (!Reparse(*messageFactory, rules, reparsedRules)) { + reparsedRules = nullptr; + } + } + if (reparsedRules) { + if (!allowUnknownFields && + !reparsedRules->GetReflection()->GetUnknownFields(*reparsedRules).empty()) { + return absl::FailedPreconditionError( + absl::StrCat("unknown constraints in ", reparsedRules->GetTypeName())); + } + result.setRules(reparsedRules, arena); + reparsedRules->GetReflection()->ListFields(*reparsedRules, &fields); + } else { + if (!allowUnknownFields && !R::GetReflection()->GetUnknownFields(rules).empty()) { + return absl::FailedPreconditionError( + absl::StrCat("unknown constraints in ", rules.GetTypeName())); + } + result.setRules(&rules, arena); + R::GetReflection()->ListFields(rules, &fields); + } for (const auto* field : fields) { - if 
(!field->options().HasExtension(buf::validate::priv::field)) { + if (!field->options().HasExtension(buf::validate::predefined)) { continue; } - const auto& fieldLvl = field->options().GetExtension(buf::validate::priv::field); + const auto& fieldLvl = field->options().GetExtension(buf::validate::predefined); for (const auto& constraint : fieldLvl.cel()) { - auto status = - result.Add(builder, constraint.id(), constraint.message(), constraint.expression()); + auto status = result.Add( + builder, constraint.id(), constraint.message(), constraint.expression(), field); if (!status.ok()) { return status; } diff --git a/buf/validate/internal/constraint_rules.h b/buf/validate/internal/constraint_rules.h index 143030f..36bb349 100644 --- a/buf/validate/internal/constraint_rules.h +++ b/buf/validate/internal/constraint_rules.h @@ -15,7 +15,7 @@ #pragma once #include "absl/status/status.h" -#include "buf/validate/expression.pb.h" +#include "buf/validate/validate.pb.h" #include "eval/public/cel_value.h" #include "google/protobuf/arena.h" #include "google/protobuf/message.h" diff --git a/buf/validate/internal/constraints.cc b/buf/validate/internal/constraints.cc index 75ef886..d696608 100644 --- a/buf/validate/internal/constraints.cc +++ b/buf/validate/internal/constraints.cc @@ -16,8 +16,6 @@ #include "absl/status/statusor.h" #include "buf/validate/internal/extra_func.h" -#include "buf/validate/priv/private.pb.h" -#include "buf/validate/validate.pb.h" #include "eval/public/builtin_func_registrar.h" #include "eval/public/cel_expr_builder_factory.h" #include "eval/public/cel_value.h" @@ -25,8 +23,6 @@ #include "eval/public/containers/field_backed_list_impl.h" #include "eval/public/containers/field_backed_map_impl.h" #include "eval/public/structs/cel_proto_wrapper.h" -#include "google/protobuf/any.pb.h" -#include "google/protobuf/descriptor.pb.h" #include "google/protobuf/dynamic_message.h" #include "google/protobuf/util/message_differencer.h" diff --git 
a/buf/validate/internal/constraints_test.cc b/buf/validate/internal/constraints_test.cc index 762b9ef..0365b16 100644 --- a/buf/validate/internal/constraints_test.cc +++ b/buf/validate/internal/constraints_test.cc @@ -48,7 +48,7 @@ class ExpressionTest : public testing::Test { constraint.set_expression(std::move(expr)); constraint.set_message(std::move(message)); constraint.set_id(std::move(id)); - return constraints_->Add(*builder_, constraint); + return constraints_->Add(*builder_, constraint, nullptr); } absl::Status Validate( diff --git a/buf/validate/internal/field_rules.cc b/buf/validate/internal/field_rules.cc index ab43ca5..85c96c0 100644 --- a/buf/validate/internal/field_rules.cc +++ b/buf/validate/internal/field_rules.cc @@ -19,6 +19,8 @@ namespace buf::validate::internal { absl::StatusOr> NewFieldRules( + std::unique_ptr& messageFactory, + bool allowUnknownFields, google::protobuf::Arena* arena, google::api::expr::runtime::CelExpressionBuilder& builder, const google::protobuf::FieldDescriptor* field, @@ -30,6 +32,8 @@ absl::StatusOr> NewFieldRules( switch (fieldLvl.type_case()) { case FieldConstraints::kBool: rules_or = NewScalarFieldRules( + messageFactory, + allowUnknownFields, arena, builder, field, @@ -40,6 +44,8 @@ absl::StatusOr> NewFieldRules( break; case FieldConstraints::kFloat: rules_or = NewScalarFieldRules( + messageFactory, + allowUnknownFields, arena, builder, field, @@ -50,6 +56,8 @@ absl::StatusOr> NewFieldRules( break; case FieldConstraints::kDouble: rules_or = NewScalarFieldRules( + messageFactory, + allowUnknownFields, arena, builder, field, @@ -60,6 +68,8 @@ absl::StatusOr> NewFieldRules( break; case FieldConstraints::kInt32: rules_or = NewScalarFieldRules( + messageFactory, + allowUnknownFields, arena, builder, field, @@ -70,6 +80,8 @@ absl::StatusOr> NewFieldRules( break; case FieldConstraints::kInt64: rules_or = NewScalarFieldRules( + messageFactory, + allowUnknownFields, arena, builder, field, @@ -80,6 +92,8 @@ absl::StatusOr> 
NewFieldRules( break; case FieldConstraints::kUint32: rules_or = NewScalarFieldRules( + messageFactory, + allowUnknownFields, arena, builder, field, @@ -90,6 +104,8 @@ absl::StatusOr> NewFieldRules( break; case FieldConstraints::kUint64: rules_or = NewScalarFieldRules( + messageFactory, + allowUnknownFields, arena, builder, field, @@ -100,6 +116,8 @@ absl::StatusOr> NewFieldRules( break; case FieldConstraints::kSint32: rules_or = NewScalarFieldRules( + messageFactory, + allowUnknownFields, arena, builder, field, @@ -109,6 +127,8 @@ absl::StatusOr> NewFieldRules( break; case FieldConstraints::kSint64: rules_or = NewScalarFieldRules( + messageFactory, + allowUnknownFields, arena, builder, field, @@ -118,6 +138,8 @@ absl::StatusOr> NewFieldRules( break; case FieldConstraints::kFixed32: rules_or = NewScalarFieldRules( + messageFactory, + allowUnknownFields, arena, builder, field, @@ -127,6 +149,8 @@ absl::StatusOr> NewFieldRules( break; case FieldConstraints::kFixed64: rules_or = NewScalarFieldRules( + messageFactory, + allowUnknownFields, arena, builder, field, @@ -136,6 +160,8 @@ absl::StatusOr> NewFieldRules( break; case FieldConstraints::kSfixed32: rules_or = NewScalarFieldRules( + messageFactory, + allowUnknownFields, arena, builder, field, @@ -145,6 +171,8 @@ absl::StatusOr> NewFieldRules( break; case FieldConstraints::kSfixed64: rules_or = NewScalarFieldRules( + messageFactory, + allowUnknownFields, arena, builder, field, @@ -154,6 +182,8 @@ absl::StatusOr> NewFieldRules( break; case FieldConstraints::kString: rules_or = NewScalarFieldRules( + messageFactory, + allowUnknownFields, arena, builder, field, @@ -164,6 +194,8 @@ absl::StatusOr> NewFieldRules( break; case FieldConstraints::kBytes: rules_or = NewScalarFieldRules( + messageFactory, + allowUnknownFields, arena, builder, field, @@ -176,6 +208,8 @@ absl::StatusOr> NewFieldRules( rules_or = std::make_unique(field, fieldLvl); auto status = BuildScalarFieldRules( *rules_or.value(), + messageFactory, + 
allowUnknownFields, arena, builder, field, @@ -194,7 +228,8 @@ absl::StatusOr> NewFieldRules( return absl::InvalidArgumentError("duration field validator on non-duration field"); } else { auto result = std::make_unique(field, fieldLvl); - auto status = BuildCelRules(arena, builder, fieldLvl.duration(), *result); + auto status = BuildCelRules( + messageFactory, allowUnknownFields, arena, builder, fieldLvl.duration(), *result); if (!status.ok()) { rules_or = status; } else { @@ -209,7 +244,8 @@ absl::StatusOr> NewFieldRules( return absl::InvalidArgumentError("timestamp field validator on non-timestamp field"); } else { auto result = std::make_unique(field, fieldLvl); - auto status = BuildCelRules(arena, builder, fieldLvl.timestamp(), *result); + auto status = BuildCelRules( + messageFactory, allowUnknownFields, arena, builder, fieldLvl.timestamp(), *result); if (!status.ok()) { rules_or = status; } else { @@ -225,14 +261,21 @@ absl::StatusOr> NewFieldRules( } else { std::unique_ptr items; if (fieldLvl.repeated().has_items()) { - auto items_or = NewFieldRules(arena, builder, field, fieldLvl.repeated().items()); + auto items_or = NewFieldRules( + messageFactory, + allowUnknownFields, + arena, + builder, + field, + fieldLvl.repeated().items()); if (!items_or.ok()) { return items_or.status(); } items = std::move(items_or).value(); } auto result = std::make_unique(field, fieldLvl, std::move(items)); - auto status = BuildCelRules(arena, builder, fieldLvl.repeated(), *result); + auto status = BuildCelRules( + messageFactory, allowUnknownFields, arena, builder, fieldLvl.repeated(), *result); if (!status.ok()) { rules_or = status; } else { @@ -244,19 +287,30 @@ absl::StatusOr> NewFieldRules( if (!field->is_map()) { return absl::InvalidArgumentError("map field validator on non-map field"); } else { - auto keyRulesOr = - NewFieldRules(arena, builder, field->message_type()->field(0), fieldLvl.map().keys()); + auto keyRulesOr = NewFieldRules( + messageFactory, + 
allowUnknownFields, + arena, + builder, + field->message_type()->field(0), + fieldLvl.map().keys()); if (!keyRulesOr.ok()) { return keyRulesOr.status(); } - auto valueRulesOr = - NewFieldRules(arena, builder, field->message_type()->field(1), fieldLvl.map().values()); + auto valueRulesOr = NewFieldRules( + messageFactory, + allowUnknownFields, + arena, + builder, + field->message_type()->field(1), + fieldLvl.map().values()); if (!valueRulesOr.ok()) { return valueRulesOr.status(); } auto result = std::make_unique( field, fieldLvl, std::move(keyRulesOr).value(), std::move(valueRulesOr).value()); - auto status = BuildCelRules(arena, builder, fieldLvl.map(), *result); + auto status = BuildCelRules( + messageFactory, allowUnknownFields, arena, builder, fieldLvl.map(), *result); if (!status.ok()) { rules_or = status; } else { @@ -270,7 +324,8 @@ absl::StatusOr> NewFieldRules( return absl::InvalidArgumentError("any field validator on non-any field"); } else { auto result = std::make_unique(field, fieldLvl, &fieldLvl.any()); - auto status = BuildCelRules(arena, builder, fieldLvl.any(), *result); + auto status = BuildCelRules( + messageFactory, allowUnknownFields, arena, builder, fieldLvl.any(), *result); if (!status.ok()) { rules_or = status; } else { @@ -287,7 +342,7 @@ absl::StatusOr> NewFieldRules( } if (rules_or.ok()) { for (const auto& constraint : fieldLvl.cel()) { - auto status = rules_or.value()->Add(builder, constraint); + auto status = rules_or.value()->Add(builder, constraint, nullptr); if (!status.ok()) { return status; } diff --git a/buf/validate/internal/field_rules.h b/buf/validate/internal/field_rules.h index 9724047..108f42f 100644 --- a/buf/validate/internal/field_rules.h +++ b/buf/validate/internal/field_rules.h @@ -25,6 +25,8 @@ namespace buf::validate::internal { template absl::Status BuildScalarFieldRules( FieldConstraintRules& result, + std::unique_ptr& messageFactory, + bool allowUnknownFields, google::protobuf::Arena* arena, 
google::api::expr::runtime::CelExpressionBuilder& builder, const google::protobuf::FieldDescriptor* field, @@ -47,11 +49,13 @@ absl::Status BuildScalarFieldRules( google::protobuf::FieldDescriptor::TypeName(expectedType))); } } - return BuildCelRules(arena, builder, rules, result); + return BuildCelRules(messageFactory, allowUnknownFields, arena, builder, rules, result); } template absl::StatusOr> NewScalarFieldRules( + std::unique_ptr& messageFactory, + bool allowUnknownFields, google::protobuf::Arena* arena, google::api::expr::runtime::CelExpressionBuilder& builder, const google::protobuf::FieldDescriptor* field, @@ -61,7 +65,16 @@ absl::StatusOr> NewScalarFieldRules( std::string_view wrapperName = "") { auto result = std::make_unique(field, fieldLvl); auto status = BuildScalarFieldRules( - *result, arena, builder, field, fieldLvl, rules, expectedType, wrapperName); + *result, + messageFactory, + allowUnknownFields, + arena, + builder, + field, + fieldLvl, + rules, + expectedType, + wrapperName); if (!status.ok()) { return status; } @@ -69,6 +82,8 @@ absl::StatusOr> NewScalarFieldRules( } absl::StatusOr> NewFieldRules( + std::unique_ptr& messageFactory, + bool allowUnknownFields, google::protobuf::Arena* arena, google::api::expr::runtime::CelExpressionBuilder& builder, const google::protobuf::FieldDescriptor* field, diff --git a/buf/validate/internal/message_factory.cc b/buf/validate/internal/message_factory.cc new file mode 100644 index 0000000..6ed2e87 --- /dev/null +++ b/buf/validate/internal/message_factory.cc @@ -0,0 +1,31 @@ +// Copyright 2023 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "buf/validate/internal/message_factory.h" + +namespace buf::validate::internal { + +bool Reparse( + MessageFactory& messageFactory, + const google::protobuf::Message& from, + google::protobuf::Message* to) { + std::string serialized; + from.SerializeToString(&serialized); + google::protobuf::io::CodedInputStream input( + reinterpret_cast(serialized.c_str()), static_cast(serialized.size())); + input.SetExtensionRegistry(messageFactory.descriptorPool(), messageFactory.messageFactory()); + return to->ParseFromCodedStream(&input); +} + +} // namespace buf::validate::internal diff --git a/buf/validate/internal/message_factory.h b/buf/validate/internal/message_factory.h new file mode 100644 index 0000000..a70d0c2 --- /dev/null +++ b/buf/validate/internal/message_factory.h @@ -0,0 +1,43 @@ +// Copyright 2023 Buf Technologies, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" + +namespace buf::validate::internal { + +struct MessageFactory { + public: + MessageFactory( + google::protobuf::MessageFactory* messageFactory, + const google::protobuf::DescriptorPool* descriptorPool) + : messageFactory_(messageFactory), descriptorPool_(descriptorPool) {} + + google::protobuf::MessageFactory* messageFactory() { return messageFactory_; } + + const google::protobuf::DescriptorPool* descriptorPool() { return descriptorPool_; } + + private: + google::protobuf::MessageFactory* messageFactory_; + const google::protobuf::DescriptorPool* descriptorPool_; +}; + +bool Reparse( + MessageFactory& messageFactory, + const google::protobuf::Message& from, + google::protobuf::Message* to); + +} // namespace buf::validate::internal diff --git a/buf/validate/internal/message_rules.cc b/buf/validate/internal/message_rules.cc index 989fb8f..34364d4 100644 --- a/buf/validate/internal/message_rules.cc +++ b/buf/validate/internal/message_rules.cc @@ -23,7 +23,7 @@ absl::StatusOr> BuildMessageRules( const MessageConstraints& constraints) { auto result = std::make_unique(); for (const auto& constraint : constraints.cel()) { - if (auto status = result->Add(builder, constraint); !status.ok()) { + if (auto status = result->Add(builder, constraint, nullptr); !status.ok()) { return status; } } @@ -31,6 +31,8 @@ absl::StatusOr> BuildMessageRules( } Constraints NewMessageConstraints( + std::unique_ptr& messageFactory, + bool allowUnknownFields, google::protobuf::Arena* arena, google::api::expr::runtime::CelExpressionBuilder& builder, const google::protobuf::Descriptor* descriptor) { @@ -53,7 +55,8 @@ Constraints NewMessageConstraints( continue; } const auto& fieldLvl = field->options().GetExtension(buf::validate::field); - auto rules_or = NewFieldRules(arena, builder, field, fieldLvl); + auto rules_or = + NewFieldRules(messageFactory, allowUnknownFields, arena, builder, field, 
fieldLvl); if (!rules_or.ok()) { return rules_or.status(); } diff --git a/buf/validate/internal/message_rules.h b/buf/validate/internal/message_rules.h index 2f89682..438a042 100644 --- a/buf/validate/internal/message_rules.h +++ b/buf/validate/internal/message_rules.h @@ -16,6 +16,7 @@ #include "absl/status/statusor.h" #include "buf/validate/internal/constraints.h" +#include "buf/validate/internal/message_factory.h" #include "buf/validate/validate.pb.h" #include "eval/public/cel_expression.h" #include "google/protobuf/arena.h" @@ -26,6 +27,8 @@ namespace buf::validate::internal { using Constraints = absl::StatusOr>>; Constraints NewMessageConstraints( + std::unique_ptr& messageFactory, + bool allowUnknownFields, google::protobuf::Arena* arena, google::api::expr::runtime::CelExpressionBuilder& builder, const google::protobuf::Descriptor* descriptor); diff --git a/buf/validate/validator.cc b/buf/validate/validator.cc index 2bb0852..cef4a75 100644 --- a/buf/validate/validator.cc +++ b/buf/validate/validator.cc @@ -46,12 +46,13 @@ absl::Status Validator::ValidateFields( } if (field->options().HasExtension(validate::field)) { const auto& fieldExt = field->options().GetExtension(validate::field); - if (fieldExt.ignore() == IGNORE_ALWAYS || - fieldExt.skipped() || - (fieldExt.has_repeated() && (fieldExt.repeated().items().ignore() == IGNORE_ALWAYS || - fieldExt.repeated().items().skipped())) || - (fieldExt.has_map() && (fieldExt.map().values().ignore() == IGNORE_ALWAYS || - fieldExt.map().values().skipped()))) { + if (fieldExt.ignore() == IGNORE_ALWAYS || fieldExt.skipped() || + (fieldExt.has_repeated() && + (fieldExt.repeated().items().ignore() == IGNORE_ALWAYS || + fieldExt.repeated().items().skipped())) || + (fieldExt.has_map() && + (fieldExt.map().values().ignore() == IGNORE_ALWAYS || + fieldExt.map().values().skipped()))) { continue; } } @@ -138,7 +139,9 @@ absl::Status ValidatorFactory::Add(const google::protobuf::Descriptor* desc) { return iter->second.status(); } 
auto status = - constraints_.emplace(desc, internal::NewMessageConstraints(&arena_, *builder_, desc)) + constraints_ + .emplace( + desc, internal::NewMessageConstraints(messageFactory_, allowUnknownFields_, &arena_, *builder_, desc)) .first->second.status(); if (!status.ok()) { return status; @@ -174,7 +177,9 @@ const internal::Constraints* ValidatorFactory::GetMessageConstraints( if (iter != constraints_.end()) { return &iter->second; } - return &constraints_.emplace(desc, internal::NewMessageConstraints(&arena_, *builder_, desc)) + return &constraints_ + .emplace( + desc, internal::NewMessageConstraints(messageFactory_, allowUnknownFields_, &arena_, *builder_, desc)) .first->second; } diff --git a/buf/validate/validator.h b/buf/validate/validator.h index 0620f63..921fdcd 100644 --- a/buf/validate/validator.h +++ b/buf/validate/validator.h @@ -17,9 +17,10 @@ #include #include -#include "buf/validate/expression.pb.h" +#include "buf/validate/validate.pb.h" #include "buf/validate/internal/constraints.h" #include "buf/validate/internal/message_rules.h" +#include "buf/validate/internal/message_factory.h" #include "eval/public/cel_expression.h" #include "google/protobuf/message.h" @@ -94,10 +95,24 @@ class ValidatorFactory { disableLazyLoading_ = disable; } + /// Set message factory and descriptor pool. This is used for re-parsing unknown fields. + /// The provided messageFactory and descriptorPool must outlive the ValidatorFactory. + void SetMessageFactory(google::protobuf::MessageFactory *messageFactory, + const google::protobuf::DescriptorPool *descriptorPool) { + messageFactory_ = std::make_unique(messageFactory, descriptorPool); + } + + /// Set whether or not unknown constraint fields will be tolerated. Defaults to false. 
+ void SetAllowUnknownFields(bool allowUnknownFields) { + allowUnknownFields_ = allowUnknownFields; + } + private: friend class Validator; google::protobuf::Arena arena_; absl::Mutex mutex_; + std::unique_ptr messageFactory_; + bool allowUnknownFields_; absl::flat_hash_map constraints_ ABSL_GUARDED_BY(mutex_); std::unique_ptr builder_