From a3f3c7bfea22df8082f3e651a2ecb1a6595bc08b Mon Sep 17 00:00:00 2001 From: Tristan Swadell Date: Fri, 8 May 2026 19:43:06 -0700 Subject: [PATCH] Create a `proto_to_predicate` compiler for converting proto messages into CEL expressions PiperOrigin-RevId: 912808489 --- common/expr_factory.h | 25 +- parser/macro_expr_factory_test.cc | 50 ++++ tools/BUILD | 45 +++ tools/proto_to_predicate.cc | 445 ++++++++++++++++++++++++++++++ tools/proto_to_predicate.h | 35 +++ tools/proto_to_predicate_test.cc | 269 ++++++++++++++++++ 6 files changed, 868 insertions(+), 1 deletion(-) create mode 100644 tools/proto_to_predicate.cc create mode 100644 tools/proto_to_predicate.h create mode 100644 tools/proto_to_predicate_test.cc diff --git a/common/expr_factory.h b/common/expr_factory.h index b9769b457..1a4c0fcd6 100644 --- a/common/expr_factory.h +++ b/common/expr_factory.h @@ -352,7 +352,30 @@ class ExprFactory { return expr; } - private: + template ::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>, + typename = std::enable_if_t::value>> + Expr NewBind(NextIdFunc next_id, BindVar bind_var, BindExpr bind_expr, + RestExpr rest_expr) { + Expr expr; + expr.set_id(next_id()); + auto& comprehension_expr = expr.mutable_comprehension_expr(); + comprehension_expr.set_iter_var("#unused"); + comprehension_expr.set_iter_range( + NewList(next_id(), std::vector{})); + comprehension_expr.set_accu_var(bind_var); + comprehension_expr.set_accu_init(std::move(bind_expr)); + comprehension_expr.set_loop_condition(NewBoolConst(next_id(), false)); + comprehension_expr.set_loop_step(NewIdent(next_id(), bind_var)); + comprehension_expr.set_result(std::move(rest_expr)); + return expr; + } + + protected: friend class MacroExprFactory; friend class ParserMacroExprFactory; diff --git a/parser/macro_expr_factory_test.cc b/parser/macro_expr_factory_test.cc index 489538be1..98ed7d9dc 100644 --- a/parser/macro_expr_factory_test.cc +++ b/parser/macro_expr_factory_test.cc @@ -39,6 +39,7 @@ class TestMacroExprFactory final : public MacroExprFactory { return NewUnspecified(NextId()); } + using MacroExprFactory::NewBind; using MacroExprFactory::NewBoolConst; using MacroExprFactory::NewCall; using MacroExprFactory::NewComprehension; @@ -69,6 +70,8 @@ class TestMacroExprFactory final : public MacroExprFactory { namespace { +using ::testing::IsEmpty; + TEST(MacroExprFactory, CopyUnspecified) { TestMacroExprFactory factory; EXPECT_EQ(factory.Copy(factory.NewUnspecified()), factory.NewUnspecified(2)); @@ -147,5 +150,52 @@ TEST(MacroExprFactory, CopyComprehension) { factory.NewIdent(11, "foo"), factory.NewIdent(12, "bar"))); } +TEST(MacroExprFactory, NewBind) { + TestMacroExprFactory factory; + Expr bind_expr = factory.NewIdent(10, "x"); + Expr rest_expr = factory.NewIdent(20, "y"); + + auto next_id = [id = 100]() mutable { return id++; }; + + Expr expr = + factory.NewBind(next_id, "a", std::move(bind_expr), std::move(rest_expr)); + + EXPECT_EQ(expr.id(), 100); + ASSERT_TRUE(expr.has_comprehension_expr()); + + const auto& comp = expr.comprehension_expr(); + EXPECT_EQ(comp.iter_var(), "#unused"); + + ASSERT_TRUE(comp.has_iter_range()); + EXPECT_EQ(comp.iter_range().id(), 101); + EXPECT_EQ(comp.iter_range().kind_case(), ExprKindCase::kListExpr); + EXPECT_THAT(comp.iter_range().list_expr().elements(), IsEmpty()); + + EXPECT_EQ(comp.accu_var(), "a"); + + ASSERT_TRUE(comp.has_accu_init()); + Expr expected_bind_expr; + expected_bind_expr.set_id(10); + expected_bind_expr.mutable_ident_expr().set_name("x"); + EXPECT_EQ(comp.accu_init(), expected_bind_expr); + + ASSERT_TRUE(comp.has_loop_condition()); + EXPECT_EQ(comp.loop_condition().id(), 102); + EXPECT_EQ(comp.loop_condition().kind_case(), ExprKindCase::kConstant); + EXPECT_TRUE(comp.loop_condition().const_expr().has_bool_value()); + EXPECT_FALSE(comp.loop_condition().const_expr().bool_value()); + + ASSERT_TRUE(comp.has_loop_step()); + EXPECT_EQ(comp.loop_step().id(), 103); + EXPECT_EQ(comp.loop_step().kind_case(), ExprKindCase::kIdentExpr); + EXPECT_EQ(comp.loop_step().ident_expr().name(), "a"); + + ASSERT_TRUE(comp.has_result()); + Expr expected_rest_expr; + expected_rest_expr.set_id(20); + expected_rest_expr.mutable_ident_expr().set_name("y"); + EXPECT_EQ(comp.result(), expected_rest_expr); +} + } // namespace } // namespace cel diff --git a/tools/BUILD b/tools/BUILD index ceb2befc5..1eae8c377 100644 --- a/tools/BUILD +++ b/tools/BUILD @@ -204,6 +204,51 @@ cc_library( ], ) +cc_library( + name = "proto_to_predicate", + srcs = ["proto_to_predicate.cc"], + hdrs = ["proto_to_predicate.h"], + deps = [ + "//common:ast", + "//common:expr", + "//common:expr_factory", + "@com_google_absl//absl/status", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings", + "@com_google_absl//absl/types:span", + "@com_google_protobuf//:protobuf", + ], +) + +cc_test( + name = "proto_to_predicate_test", + srcs = ["proto_to_predicate_test.cc"], + deps = [ + ":cel_unparser", + ":proto_to_predicate", + "//common:ast", + "//common:ast_proto", + "//compiler", + "//env", + "//env:config", + "//env:env_runtime", + "//env:env_std_extensions", + "//env:env_yaml", + "//env:runtime_std_extensions", + "//eval/testutil:test_message_cc_proto", + "//extensions/protobuf:value", + "//internal:status_macros", + "//internal:testing", + "//internal:testing_descriptor_pool", + "//runtime", + "//runtime:activation", + "@com_google_absl//absl/status:statusor", + "@com_google_absl//absl/strings:string_view", + "@com_google_absl//absl/types:span", + "@com_google_protobuf//:protobuf", + ], +) + cc_test( name = "descriptor_pool_builder_test", srcs = ["descriptor_pool_builder_test.cc"], diff --git a/tools/proto_to_predicate.cc b/tools/proto_to_predicate.cc new file mode 100644 index 000000000..8cda36a02 --- /dev/null +++ b/tools/proto_to_predicate.cc @@ -0,0 +1,445 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tools/proto_to_predicate.h" + +#include +#include +#include +#include + +#include "absl/status/status.h" +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "common/ast.h" +#include "common/expr.h" +#include "common/expr_factory.h" +#include "google/protobuf/descriptor.h" +#include "google/protobuf/message.h" +#include "google/protobuf/reflection.h" + +namespace cel::tools { +namespace { + +using ::google::protobuf::FieldDescriptor; +using ::google::protobuf::Message; +using ::google::protobuf::Reflection; + +using FieldPath = std::vector; + +class PredicateBuilder : public ExprFactory { + public: + explicit PredicateBuilder(absl::string_view input_name) + : ExprFactory(), input_name_(input_name), id_(1) {} + + absl::StatusOr Build(const Message& message) { + std::vector predicates; + FieldPath path; + + auto status = Walk(message, path, predicates); + if (!status.ok()) { + return status; + } + + if (predicates.empty()) { + return Ast(NewBoolConst(NextId(), true), std::move(source_info_)); + } + + Expr root = FoldBinaryOp("_&&_", predicates); + return Ast(std::move(root), std::move(source_info_)); + } + + absl::StatusOr Build(absl::Span messages) { + if (messages.empty()) { + return Ast(NewBoolConst(NextId(), true), std::move(source_info_)); + } + + std::vector message_asts; + for (const auto* message : messages) { + std::vector predicates; + FieldPath path; + + auto status = Walk(*message, path, predicates); + if (!status.ok()) { + return status; + } + + if (predicates.empty()) { + message_asts.push_back(NewBoolConst(NextId(), true)); + } else { + message_asts.push_back(FoldBinaryOp("_&&_", predicates)); + } + } + + Expr root = FoldBinaryOp("_||_", message_asts); + return Ast(std::move(root), std::move(source_info_)); + } + + private: + ExprId NextId() { return id_++; } + + // --------------------------------------------------------------------------- + // Path construction + // --------------------------------------------------------------------------- + + Expr BuildPath(const FieldPath& path) { + Expr e = NewIdent(NextId(), input_name_); + for (const auto* f : path) { + e = NewSelect(NextId(), std::move(e), f->name()); + } + return e; + } + + // --------------------------------------------------------------------------- + // Field value extraction + // --------------------------------------------------------------------------- + + // Converts a singular field value to a CEL constant expression. + Expr PrimitiveToExpr(ExprId expr_id, const Message& message, + const Reflection* reflection, + const FieldDescriptor* field) { + switch (field->cpp_type()) { + case FieldDescriptor::CPPTYPE_INT32: + return NewIntConst(expr_id, reflection->GetInt32(message, field)); + case FieldDescriptor::CPPTYPE_INT64: + return NewIntConst(expr_id, reflection->GetInt64(message, field)); + case FieldDescriptor::CPPTYPE_UINT32: + return NewUintConst(expr_id, reflection->GetUInt32(message, field)); + case FieldDescriptor::CPPTYPE_UINT64: + return NewUintConst(expr_id, reflection->GetUInt64(message, field)); + case FieldDescriptor::CPPTYPE_DOUBLE: + return NewDoubleConst(expr_id, reflection->GetDouble(message, field)); + case FieldDescriptor::CPPTYPE_FLOAT: + return NewDoubleConst(expr_id, reflection->GetFloat(message, field)); + case FieldDescriptor::CPPTYPE_BOOL: + return NewBoolConst(expr_id, reflection->GetBool(message, field)); + case FieldDescriptor::CPPTYPE_ENUM: + return NewIntConst(expr_id, reflection->GetEnumValue(message, field)); + case FieldDescriptor::CPPTYPE_STRING: { + std::string str_val = reflection->GetString(message, field); + if (field->type() == FieldDescriptor::TYPE_BYTES) { + return NewBytesConst(expr_id, std::move(str_val)); + } + return NewStringConst(expr_id, std::move(str_val)); + } + default: + // Message is handled elsewhere + break; + } + return NewNullConst(expr_id); + } + + Expr PrimitiveToExpr(const Message& message, const Reflection* reflection, + const FieldDescriptor* field) { + return PrimitiveToExpr(NextId(), message, reflection, field); + } + + // Converts a repeated field element to a CEL constant expression. + Expr RepeatedPrimitiveToExpr(const Message& message, + const Reflection* reflection, + const FieldDescriptor* field, int index) { + const ExprId id = NextId(); + switch (field->cpp_type()) { + case FieldDescriptor::CPPTYPE_INT32: + return NewIntConst(id, + reflection->GetRepeatedInt32(message, field, index)); + case FieldDescriptor::CPPTYPE_INT64: + return NewIntConst(id, + reflection->GetRepeatedInt64(message, field, index)); + case FieldDescriptor::CPPTYPE_UINT32: + return NewUintConst( + id, reflection->GetRepeatedUInt32(message, field, index)); + case FieldDescriptor::CPPTYPE_UINT64: + return NewUintConst( + id, reflection->GetRepeatedUInt64(message, field, index)); + case FieldDescriptor::CPPTYPE_DOUBLE: + return NewDoubleConst( + id, reflection->GetRepeatedDouble(message, field, index)); + case FieldDescriptor::CPPTYPE_FLOAT: + return NewDoubleConst( + id, reflection->GetRepeatedFloat(message, field, index)); + case FieldDescriptor::CPPTYPE_BOOL: + return NewBoolConst(id, + reflection->GetRepeatedBool(message, field, index)); + case FieldDescriptor::CPPTYPE_ENUM: + return NewIntConst( + id, reflection->GetRepeatedEnumValue(message, field, index)); + case FieldDescriptor::CPPTYPE_STRING: { + std::string str_val = + reflection->GetRepeatedString(message, field, index); + if (field->type() == FieldDescriptor::TYPE_BYTES) { + return NewBytesConst(id, std::move(str_val)); + } + return NewStringConst(id, std::move(str_val)); + } + default: + break; + } + return NewNullConst(id); + } + + // --------------------------------------------------------------------------- + // Expression construction helpers + // --------------------------------------------------------------------------- + + // Creates a binary operator call: `lhs rhs`. + Expr ConstructBinaryOp(absl::string_view op, Expr lhs, Expr rhs) { + std::vector args; + args.reserve(2); + args.push_back(std::move(lhs)); + args.push_back(std::move(rhs)); + return NewCall(NextId(), op, std::move(args)); + } + + Expr ConstructEquality(Expr lhs, Expr rhs) { + return ConstructBinaryOp("_==_", std::move(lhs), std::move(rhs)); + } + + // Left-folds a vector of expressions with a binary operator. + // Requires: `exprs` is non-empty. + Expr FoldBinaryOp(absl::string_view op, std::vector& exprs) { + Expr root = std::move(exprs[0]); + for (size_t i = 1; i < exprs.size(); ++i) { + root = ConstructBinaryOp(op, std::move(root), std::move(exprs[i])); + } + return root; + } + + // --------------------------------------------------------------------------- + // Map literal construction + // --------------------------------------------------------------------------- + + Expr ConstructMapLiteral(ExprId expr_id, const Reflection* reflection, + int size, const Message& message, + const FieldDescriptor* field) { + const FieldDescriptor* const key_field = + field->message_type()->FindFieldByName("key"); + const FieldDescriptor* const value_field = + field->message_type()->FindFieldByName("value"); + + std::vector entries; + for (int i = 0; i < size; ++i) { + const Message& entry_msg = + reflection->GetRepeatedMessage(message, field, i); + const Reflection* const entry_ref = entry_msg.GetReflection(); + entries.push_back(NewMapEntry( + NextId(), PrimitiveToExpr(entry_msg, entry_ref, key_field), + PrimitiveToExpr(entry_msg, entry_ref, value_field))); + } + return NewMap(expr_id, std::move(entries)); + } + + // --------------------------------------------------------------------------- + // Map field predicate (extracted from Walk for readability) + // --------------------------------------------------------------------------- + + // Builds the predicate for a map field to assert that all values in the + // input field map are present in the literal map specified in the proto. + // cel.bind( + // map_val, , + // .all(k, v, + // map_val[?k].optMap(val, val == v).orValue(false))) + void WalkMapField(const Reflection* reflection, const Message& message, + const FieldDescriptor* field, const FieldPath& path, + int size, std::vector& predicates) { + auto next_id = [&]() { return NextId(); }; + + // + const ExprId literal_map_id = NextId(); + Expr literal_map = + ConstructMapLiteral(literal_map_id, reflection, size, message, field); + + // Reusable sub-expressions constructed exactly once to avoid heap + // re-allocations. + + // map_val[?k] + const ExprId opt_map_select_key_id = NextId(); + Expr opt_select_map_key_expr = + NewCall(opt_map_select_key_id, "_[?_]", + std::vector{NewIdent(opt_map_select_key_id, "map_val"), + NewIdent(opt_map_select_key_id, "k")}); + + // val == v + const ExprId map_val_equals_v_id = NextId(); + Expr map_val_equals_v_expr = + NewCall(map_val_equals_v_id, "_==_", + std::vector{NewIdent(map_val_equals_v_id, "val"), + NewIdent(map_val_equals_v_id, "v")}); + + // val = map_val[?k].value(), val == v + Expr optmap_comp = + NewBind(next_id, "val", + NewMemberCall(NextId(), "value", opt_select_map_key_expr, + std::vector{}), + map_val_equals_v_expr); + const ExprId optmap_comp_id = optmap_comp.id(); + + // Expanded ternary for optMap: + // map_val[?k].hasValue() + // ? optional.of() + // : optional.none() + Expr expanded_optmap = + NewCall(NextId(), "_?_:_", + std::vector{ + NewMemberCall(NextId(), "hasValue", opt_select_map_key_expr, + std::vector{}), + NewCall(NextId(), "optional.of", + std::vector{std::move(optmap_comp)}), + NewCall(NextId(), "optional.none", std::vector{})}); + + // optMap macro tracker + source_info_.mutable_macro_calls()[optmap_comp_id] = + NewMemberCall(NextId(), "optMap", std::move(opt_select_map_key_expr), + std::vector{NewIdent(NextId(), "val"), + std::move(map_val_equals_v_expr)}); + + // .orValue(false) + Expr condition_expanded = + NewMemberCall(NextId(), "orValue", std::move(expanded_optmap), + std::vector{NewBoolConst(NextId(), false)}); + + // .all comprehension expansion + Expr step_all = NewCall(NextId(), "_&&_", + std::vector{NewIdent(NextId(), AccuVarName()), + std::move(condition_expanded)}); + Expr condition_not_strictly_false = + NewCall(NextId(), "@not_strictly_false", + std::vector{NewIdent(NextId(), AccuVarName())}); + + Expr literal_map_copy1 = literal_map; + Expr literal_map_copy2 = literal_map; + + Expr all_comp = NewComprehension( + NextId(), "k", "v", std::move(literal_map_copy1), AccuVarName(), + NewBoolConst(NextId(), true), std::move(condition_not_strictly_false), + std::move(step_all), NewIdent(NextId(), AccuVarName())); + + // all macro tracker + Expr all_condition_macro = + NewMemberCall(NextId(), "orValue", NewUnspecified(optmap_comp_id), + std::vector{NewBoolConst(NextId(), false)}); + source_info_.mutable_macro_calls()[all_comp.id()] = NewMemberCall( + NextId(), "all", std::move(literal_map_copy2), + std::vector{NewIdent(NextId(), "k"), NewIdent(NextId(), "v"), + std::move(all_condition_macro)}); + + // cel.bind comprehension expansion + const ExprId all_comp_id = all_comp.id(); + Expr bind_comp = + NewBind(next_id, "map_val", BuildPath(path), std::move(all_comp)); + const ExprId bind_comp_id = bind_comp.id(); + + // bind macro tracker + source_info_.mutable_macro_calls()[bind_comp_id] = NewMemberCall( + NextId(), "bind", NewIdent(NextId(), "cel"), + std::vector{NewIdent(NextId(), "map_val"), BuildPath(path), + NewUnspecified(all_comp_id)}); + + predicates.push_back(std::move(bind_comp)); + } + + // --------------------------------------------------------------------------- + // Repeated field predicate (extracted from Walk for readability) + // --------------------------------------------------------------------------- + + // Builds: sets.contains(.field, []) + absl::Status WalkRepeatedField(const Reflection* reflection, + const Message& message, + const FieldDescriptor* field, + const FieldPath& path, int size, + std::vector& predicates) { + if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { + return absl::UnimplementedError( + "Repeated messages are not fully supported yet."); + } + + std::vector elements; + elements.reserve(size); + for (int i = 0; i < size; ++i) { + elements.push_back(NewListElement( + RepeatedPrimitiveToExpr(message, reflection, field, i))); + } + Expr literal_list = NewList(NextId(), std::move(elements)); + + std::vector contains_args; + contains_args.push_back(BuildPath(path)); + contains_args.push_back(std::move(literal_list)); + predicates.push_back( + NewCall(NextId(), "sets.contains", std::move(contains_args))); + + return absl::OkStatus(); + } + + // --------------------------------------------------------------------------- + // Recursive message walk + // --------------------------------------------------------------------------- + + absl::Status Walk(const Message& message, FieldPath& path, + std::vector& predicates) { + const Reflection* const reflection = message.GetReflection(); + std::vector fields; + reflection->ListFields(message, &fields); + + for (const auto* field : fields) { + path.push_back(field); + + if (field->is_map()) { + const int size = reflection->FieldSize(message, field); + if (size > 0) { + WalkMapField(reflection, message, field, path, size, predicates); + } + } else if (field->is_repeated()) { + const int size = reflection->FieldSize(message, field); + if (size > 0) { + auto status = WalkRepeatedField(reflection, message, field, path, + size, predicates); + if (!status.ok()) return status; + } + } else if (field->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { + const Message& sub_message = reflection->GetMessage(message, field); + auto status = Walk(sub_message, path, predicates); + if (!status.ok()) return status; + } else { + // Primitive field: input.field == + predicates.push_back(ConstructEquality( + BuildPath(path), PrimitiveToExpr(message, reflection, field))); + } + + path.pop_back(); + } + return absl::OkStatus(); + } + + absl::string_view input_name_; + ExprId id_; + SourceInfo source_info_; +}; + +} // namespace + +absl::StatusOr ProtocolBufferToPredicateAst( + const ::google::protobuf::Message& message, absl::string_view input_name) { + PredicateBuilder builder(input_name); + return builder.Build(message); +} + +absl::StatusOr ProtocolBufferToPredicateAst( + absl::Span messages, + absl::string_view input_name) { + PredicateBuilder builder(input_name); + return builder.Build(messages); +} + +} // namespace cel::tools diff --git a/tools/proto_to_predicate.h b/tools/proto_to_predicate.h new file mode 100644 index 000000000..61c9a7106 --- /dev/null +++ b/tools/proto_to_predicate.h @@ -0,0 +1,35 @@ +// Copyright 2026 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#ifndef THIRD_PARTY_CEL_CPP_TOOLS_PROTO_TO_PREDICATE_H_ +#define THIRD_PARTY_CEL_CPP_TOOLS_PROTO_TO_PREDICATE_H_ + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "common/ast.h" +#include "google/protobuf/message.h" + +namespace cel::tools { + +absl::StatusOr ProtocolBufferToPredicateAst( + const ::google::protobuf::Message& message, absl::string_view input_name); + +absl::StatusOr ProtocolBufferToPredicateAst( + absl::Span messages, + absl::string_view input_name); + +} // namespace cel::tools + +#endif // THIRD_PARTY_CEL_CPP_TOOLS_PROTO_TO_PREDICATE_H_ diff --git a/tools/proto_to_predicate_test.cc b/tools/proto_to_predicate_test.cc new file mode 100644 index 000000000..cd638025d --- /dev/null +++ b/tools/proto_to_predicate_test.cc @@ -0,0 +1,269 @@ +// Copyright 2024 Google LLC +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// https://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "tools/proto_to_predicate.h" + +#include +#include +#include + +#include "absl/status/statusor.h" +#include "absl/strings/string_view.h" +#include "absl/types/span.h" +#include "common/ast.h" +#include "common/ast_proto.h" +#include "compiler/compiler.h" +#include "env/config.h" +#include "env/env.h" +#include "env/env_runtime.h" +#include "env/env_std_extensions.h" +#include "env/env_yaml.h" +#include "env/runtime_std_extensions.h" +#include "eval/testutil/test_message.pb.h" +#include "extensions/protobuf/value.h" +#include "internal/status_macros.h" +#include "internal/testing.h" +#include "internal/testing_descriptor_pool.h" +#include "runtime/activation.h" +#include "runtime/runtime.h" +#include "tools/cel_unparser.h" + +namespace cel::tools { +namespace { + +using ::absl_testing::IsOk; +using ::google::api::expr::runtime::TestMessage; + +constexpr absl::string_view kEnvYaml = R"( +name: "test" +extensions: + - name: "sets" + - name: "two-var-comprehensions" + - name: "bindings" + - name: "optional" +variables: + - name: "input" + type: "eval.testutil.TestMessage" +)"; + +absl::StatusOr> CreateCompiler() { + CEL_ASSIGN_OR_RETURN(cel::Config config, + cel::EnvConfigFromYaml(std::string(kEnvYaml))); + cel::Env env; + cel::RegisterStandardExtensions(env); + env.SetConfig(config); + env.SetDescriptorPool(cel::internal::GetSharedTestingDescriptorPool()); + return env.NewCompiler(); +} + +absl::StatusOr EvaluatePredicate(const cel::Ast& ast, + const TestMessage& input) { + auto descriptor_pool = cel::internal::GetSharedTestingDescriptorPool(); + + CEL_ASSIGN_OR_RETURN(cel::Config config, + cel::EnvConfigFromYaml(std::string(kEnvYaml))); + + cel::EnvRuntime env_runtime; + env_runtime.SetDescriptorPool(descriptor_pool); + cel::RegisterStandardExtensions(env_runtime); + env_runtime.SetConfig(config); + + CEL_ASSIGN_OR_RETURN(std::unique_ptr runtime, + env_runtime.NewRuntime()); + CEL_ASSIGN_OR_RETURN(std::unique_ptr program, + runtime->CreateProgram(std::make_unique(ast))); + + google::protobuf::Arena arena; + cel::Activation activation; + CEL_ASSIGN_OR_RETURN( + cel::Value val, cel::extensions::ProtoMessageToValue( + input, descriptor_pool.get(), + google::protobuf::MessageFactory::generated_factory(), &arena)); + activation.InsertOrAssignValue("input", val); + + CEL_ASSIGN_OR_RETURN(cel::Value result, + program->Evaluate(&arena, activation)); + if (!result.IsBool()) { + return absl::InvalidArgumentError( + "Predicate evaluate result must be a boolean value."); + } + return result.GetBool(); +} + +TEST(ProtocolBufferToPredicateAstTest, PrimitivesTest) { + TestMessage msg; + msg.set_int32_value(42); + msg.set_string_value("hello"); + + ASSERT_OK_AND_ASSIGN(cel::Ast ast, + ProtocolBufferToPredicateAst(msg, "input")); + ASSERT_OK_AND_ASSIGN(auto compiler, CreateCompiler()); + + auto result_or = + compiler->GetTypeChecker().Check(std::make_unique(ast)); + ASSERT_THAT(result_or, IsOk()); + EXPECT_TRUE(result_or->IsValid()); + + cel::expr::ParsedExpr parsed_expr; + ASSERT_THAT(cel::AstToParsedExpr(ast, &parsed_expr), IsOk()); + ASSERT_OK_AND_ASSIGN(auto unparsed, google::api::expr::Unparse(parsed_expr)); + EXPECT_EQ(unparsed, + "input.int32_value == 42 && input.string_value == \"hello\""); +} + +TEST(ProtocolBufferToPredicateAstTest, RepeatedFieldTest) { + TestMessage msg; + msg.add_int32_list(1); + msg.add_int32_list(2); + + ASSERT_OK_AND_ASSIGN(cel::Ast ast, + ProtocolBufferToPredicateAst(msg, "input")); + ASSERT_OK_AND_ASSIGN(auto compiler, CreateCompiler()); + + auto result_or = + compiler->GetTypeChecker().Check(std::make_unique(ast)); + ASSERT_THAT(result_or, IsOk()); + EXPECT_TRUE(result_or->IsValid()); + + cel::expr::ParsedExpr parsed_expr; + ASSERT_THAT(cel::AstToParsedExpr(ast, &parsed_expr), IsOk()); + ASSERT_OK_AND_ASSIGN(auto unparsed, google::api::expr::Unparse(parsed_expr)); + EXPECT_EQ(unparsed, "sets.contains(input.int32_list, [1, 2])"); +} + +TEST(ProtocolBufferToPredicateAstTest, MapFieldTest) { + TestMessage msg; + auto& map = *msg.mutable_string_int32_map(); + map["foo"] = 1; + map["bar"] = 2; + + ASSERT_OK_AND_ASSIGN(cel::Ast ast, + ProtocolBufferToPredicateAst(msg, "input")); + ASSERT_OK_AND_ASSIGN(auto compiler, CreateCompiler()); + + auto result_or = + compiler->GetTypeChecker().Check(std::make_unique(ast)); + ASSERT_THAT(result_or, IsOk()); + EXPECT_TRUE(result_or->IsValid()); + + cel::expr::ParsedExpr parsed_expr; + ASSERT_THAT(cel::AstToParsedExpr(ast, &parsed_expr), IsOk()); + ASSERT_OK_AND_ASSIGN(auto unparsed, google::api::expr::Unparse(parsed_expr)); + EXPECT_THAT( + unparsed, + testing::AnyOf( + testing::Eq("cel.bind(map_val, input.string_int32_map, " + "{\"bar\": 2, \"foo\": 1}.all(k, v, " + "map_val[?k].optMap(val, val == v).orValue(false)))"), + testing::Eq("cel.bind(map_val, input.string_int32_map, " + "{\"foo\": 1, \"bar\": 2}.all(k, v, " + "map_val[?k].optMap(val, val == v).orValue(false)))"))); +} + +TEST(ProtocolBufferToPredicateAstTest, MultipleMessagesTest) { + TestMessage msg1; + msg1.set_int32_value(42); + + TestMessage msg2; + msg2.set_int32_value(41); + msg2.set_string_value("hello"); + + std::vector messages = {&msg1, &msg2}; + + ASSERT_OK_AND_ASSIGN(auto ast, ProtocolBufferToPredicateAst( + absl::MakeSpan(messages), "input")); + + cel::expr::ParsedExpr parsed_expr; + ASSERT_THAT(cel::AstToParsedExpr(ast, &parsed_expr), IsOk()); + ASSERT_OK_AND_ASSIGN(auto unparsed, google::api::expr::Unparse(parsed_expr)); + EXPECT_EQ(unparsed, + "input.int32_value == 42 || input.int32_value == 41 && " + "input.string_value == \"hello\""); +} + +TEST(ProtocolBufferToPredicateAstTest, ListFieldEvalTest) { + // Expected template proto: list has exactly [1, 2]. + TestMessage expected_msg; + expected_msg.add_int32_list(1); + expected_msg.add_int32_list(2); + + ASSERT_OK_AND_ASSIGN(cel::Ast ast, + ProtocolBufferToPredicateAst(expected_msg, "input")); + + // Positive case: input list satisfies sets.contains criteria. + { + TestMessage input; + input.add_int32_list(1); + input.add_int32_list(2); + + ASSERT_OK_AND_ASSIGN(bool eval_result, EvaluatePredicate(ast, input)); + EXPECT_TRUE(eval_result); + } + + // Negative case: input list does not satisfy criteria (subset mismatch). + { + TestMessage input; + input.add_int32_list(1); + input.add_int32_list(3); + + ASSERT_OK_AND_ASSIGN(bool eval_result, EvaluatePredicate(ast, input)); + EXPECT_FALSE(eval_result); + } +} + +TEST(ProtocolBufferToPredicateAstTest, MapFieldEvalTest) { + // Expected template proto: map has exactly {"bar": 2, "foo": 1}. + TestMessage expected_msg; + auto& expected_map = *expected_msg.mutable_string_int32_map(); + expected_map["foo"] = 1; + expected_map["bar"] = 2; + + ASSERT_OK_AND_ASSIGN(cel::Ast ast, + ProtocolBufferToPredicateAst(expected_msg, "input")); + + // Positive case: input map satisfies expected key-value checks. + { + TestMessage input; + auto& map = *input.mutable_string_int32_map(); + map["foo"] = 1; + map["bar"] = 2; + + ASSERT_OK_AND_ASSIGN(bool eval_result, EvaluatePredicate(ast, input)); + EXPECT_TRUE(eval_result); + } + + // Negative case: input map has same keys but incorrect value. + { + TestMessage input; + auto& map = *input.mutable_string_int32_map(); + map["foo"] = 1; + map["bar"] = 3; // Should be 2 + + ASSERT_OK_AND_ASSIGN(bool eval_result, EvaluatePredicate(ast, input)); + EXPECT_FALSE(eval_result); + } + + // Negative case: input map is missing one expected key. + { + TestMessage input; + auto& map = *input.mutable_string_int32_map(); + map["foo"] = 1; + + ASSERT_OK_AND_ASSIGN(bool eval_result, EvaluatePredicate(ast, input)); + EXPECT_FALSE(eval_result); + } +} + +} // namespace +} // namespace cel::tools