diff --git a/cel/cel_test.go b/cel/cel_test.go index 9dd3fd869..976ff8ded 100644 --- a/cel/cel_test.go +++ b/cel/cel_test.go @@ -2234,6 +2234,62 @@ func TestContextProto(t *testing.T) { } } +func TestContextProtoJSONFieldNames(t *testing.T) { + descriptor := new(proto3pb.TestAllTypes).ProtoReflect().Descriptor() + env := testEnv(t, JSONFieldNames(true), DeclareContextProto(descriptor)) + expression := ` + singleInt64 == 1 + && singleDouble == 1.0 + && singleBool == true + && singleString == '' + && singleNestedMessage == google.expr.proto3.test.TestAllTypes.NestedMessage{} + && standaloneEnum == google.expr.proto3.test.TestAllTypes.NestedEnum.FOO + && singleDuration == duration('5s') + && singleTimestamp == timestamp(63154820) + && singleAny == null + && singleUint32Wrapper == null + && singleUint64Wrapper == 0u + && repeatedInt32 == [1,2] + && mapStringString == {'': ''} + && mapInt64NestedType == {0 : google.expr.proto3.test.NestedTestAllTypes{}}` + ast, iss := env.Compile(expression) + if iss.Err() != nil { + t.Fatalf("env.Compile(%s) failed: %s", expression, iss.Err()) + } + prg, err := env.Program(ast) + if err != nil { + t.Fatalf("env.Program() failed: %v", err) + } + in := &proto3pb.TestAllTypes{ + SingleInt64: 1, + SingleDouble: 1.0, + SingleBool: true, + NestedType: &proto3pb.TestAllTypes_SingleNestedMessage{ + SingleNestedMessage: &proto3pb.TestAllTypes_NestedMessage{}, + }, + StandaloneEnum: proto3pb.TestAllTypes_FOO, + SingleDuration: &durationpb.Duration{Seconds: 5}, + SingleTimestamp: ×tamppb.Timestamp{ + Seconds: 63154820, + }, + SingleUint64Wrapper: wrapperspb.UInt64(0), + RepeatedInt32: []int32{1, 2}, + MapStringString: map[string]string{"": ""}, + MapInt64NestedType: map[int64]*proto3pb.NestedTestAllTypes{0: {}}, + } + vars, err := ContextProtoVars(in, types.JSONFieldNames(true)) + if err != nil { + t.Fatalf("ContextProtoVars(%v) failed: %v", in, err) + } + out, _, err := prg.Eval(vars) + if err != nil { + t.Fatalf("prg.Eval() failed: %v", err) + } + if out.Equal(types.True) != types.True { + t.Errorf("prg.Eval() got %v, wanted true", out) + } +} + func TestRegexOptimizer(t *testing.T) { var stringTests = []struct { expr string @@ -3607,6 +3663,125 @@ func TestAstProgramNilValue(t *testing.T) { } } +func TestJSONFieldNames(t *testing.T) { + tests := []struct { + name string + expr string + jsonFieldNames bool + }{ + { + name: "proto simple field", + expr: `msg.single_int32 == 1`, + }, + { + name: "proto map field", + expr: `msg.map_string_string['key'] == 'value'`, + }, + { + name: "json simple field", + expr: `msg.singleInt32 == 1`, + jsonFieldNames: true, + }, + { + name: "json repeated field", + expr: `msg.mapStringString['key'] == 'value'`, + jsonFieldNames: true, + }, + { + name: "message with json field", + expr: `TestAllTypes{singleInt32: 1} != msg`, + jsonFieldNames: true, + }, + { + name: "message with json field and proto fallback", + expr: `dyn(TestAllTypes{singleInt32: 2}).single_int32 == 2`, + jsonFieldNames: true, + }, + { + name: "json with proto fallback", + expr: `dyn(msg).single_int32 == dyn(msg).singleInt32`, + jsonFieldNames: true, + }, + } + msg := &proto3pb.TestAllTypes{ + SingleInt32: 1, + MapStringString: map[string]string{ + "key": "value", + }, + } + for _, tst := range tests { + tc := tst + t.Run(tc.name, func(t *testing.T) { + env, err := NewEnv( + JSONFieldNames(tc.jsonFieldNames), + Types(msg), + Container(string(msg.ProtoReflect().Descriptor().ParentFile().Package())), + Variable("msg", ObjectType(string(msg.ProtoReflect().Descriptor().FullName()))), + ) + if err != nil { + t.Fatalf("NewEnv() failed: %v", err) + } + ast, iss := env.Compile(tc.expr) + if iss.Err() != nil { + t.Fatalf("env.Compile() failed: %v", iss.Err()) + } + prg, err := env.Program(ast) + if err != nil { + t.Fatalf("env.Program() failed: %v", err) + } + out, _, err := prg.Eval(map[string]any{"msg": msg}) + if err != nil { + t.Fatalf("prg.Eval() failed: %v", err) + } + if out != types.True { + t.Errorf("prg.Eval() got %v, wanted 'true'", out) + } + + if tc.jsonFieldNames { + noJSONEnv, err := env.Extend(JSONFieldNames(false)) + if err != nil { + t.Fatalf("env.Extend() failed: %v", err) + } + _, err = noJSONEnv.Program(ast) + if err == nil { + t.Fatal("env with json disabled allowed program with json extension to be planned") + } + } else { + jsonEnv, err := env.Extend(JSONFieldNames(true)) + if err != nil { + t.Fatalf("env.Extend() failed: %v", err) + } + prg, err = jsonEnv.Program(ast) + if err != nil { + t.Fatalf("env.Program() failed: %v", err) + } + out, _, err := prg.Eval(map[string]any{"msg": msg}) + if err != nil { + t.Fatalf("prg.Eval() failed: %v", err) + } + if out != types.True { + t.Errorf("prg.Eval() got %v, wanted 'true'", out) + } + } + }) + } +} + +func TestJSONFieldNamesInvalidProvider(t *testing.T) { + type wrapperRegistry struct { + *types.Registry + } + reg, err := types.NewProtoRegistry(types.JSONFieldNames(true)) + if err != nil { + t.Fatalf("types.NewProtoRegistry() failed: %v", err) + } + wrapped := wrapperRegistry{Registry: reg} + _, err = NewEnv(CustomTypeProvider(wrapped), CustomTypeAdapter(reg), JSONFieldNames(true)) + if err == nil { + t.Error("NewEnv() created a CEL environment successfully despite incompatible configs") + } +} + // TODO: ideally testCostEstimator and testRuntimeCostEstimator would be shared in a test fixtures package type testCostEstimator struct { hints map[string]uint64 diff --git a/cel/env.go b/cel/env.go index d3295b6a5..e2de2ff6f 100644 --- a/cel/env.go +++ b/cel/env.go @@ -184,6 +184,16 @@ func (e *Env) ToConfig(name string) (*env.Config, error) { conf.AddImports(env.NewImport(typeName)) } + // Serialize features + for featID, enabled := range e.features { + featName, found := featureNameByID(featID) + if !found { + // If the feature isn't named, it isn't intended to be publicly exposed + continue + } + conf.AddFeatures(env.NewFeature(featName, enabled)) + } + libOverloads := map[string][]string{} for libName, lib := range e.libraries { // Track the options which have been configured by a library and @@ -244,7 +254,7 @@ func (e *Env) ToConfig(name string) (*env.Config, error) { fields := e.contextProto.Fields() for i := 0; i < fields.Len(); i++ { field := fields.Get(i) - variable, err := fieldToVariable(field) + variable, err := fieldToVariable(field, e.HasFeature(featureJSONFieldNames)) if err != nil { return nil, fmt.Errorf("could not serialize context field variable %q, reason: %w", field.FullName(), err) } @@ -279,16 +289,6 @@ func (e *Env) ToConfig(name string) (*env.Config, error) { } } - // Serialize features - for featID, enabled := range e.features { - featName, found := featureNameByID(featID) - if !found { - // If the feature isn't named, it isn't intended to be publicly exposed - continue - } - conf.AddFeatures(env.NewFeature(featName, enabled)) - } - for id, val := range e.limits { limitName, found := limitNameByID(id) if !found || val == 0 { @@ -361,7 +361,7 @@ func NewEnv(opts ...EnvOption) (*Env, error) { // See the EnvOption helper functions for the options that can be used to configure the // environment. func NewCustomEnv(opts ...EnvOption) (*Env, error) { - registry, err := types.NewRegistry() + registry, err := types.NewProtoRegistry() if err != nil { return nil, err } @@ -554,6 +554,7 @@ func (e *Env) Extend(opts ...EnvOption) (*Env, error) { } validatorsCopy := make([]ASTValidator, len(e.validators)) copy(validatorsCopy, e.validators) + costOptsCopy := make([]checker.CostOption, len(e.costOptions)) copy(costOptsCopy, e.costOptions) @@ -847,6 +848,18 @@ func (e *Env) configure(opts []EnvOption) (*Env, error) { return nil, err } + // Enable JSON field names is using a proto-based *types.Registry + if e.HasFeature(featureJSONFieldNames) { + reg, isReg := e.provider.(*types.Registry) + if !isReg { + return nil, fmt.Errorf("JSONFieldNames() option is only compatible with *types.Registry providers") + } + err := reg.WithJSONFieldNames(true) + if err != nil { + return nil, err + } + } + // Ensure that the checker init happens eagerly rather than lazily. if e.HasFeature(featureEagerlyValidateDeclarations) { _, err := e.initChecker() @@ -865,6 +878,8 @@ func (e *Env) initChecker() (*checker.Env, error) { chkOpts = append(chkOpts, checker.CrossTypeNumericComparisons( e.HasFeature(featureCrossTypeNumericComparisons))) + chkOpts = append(chkOpts, + checker.JSONFieldNames(e.HasFeature(featureJSONFieldNames))) ce, err := checker.NewEnv(e.Container, e.provider, chkOpts...) if err != nil { diff --git a/cel/env_test.go b/cel/env_test.go index 38322caa1..54f65b575 100644 --- a/cel/env_test.go +++ b/cel/env_test.go @@ -198,9 +198,9 @@ func TestEnvPartialVarsError(t *testing.T) { } func TestTypeProviderInterop(t *testing.T) { - reg, err := types.NewRegistry(&proto3pb.TestAllTypes{}) + reg, err := types.NewProtoRegistry(types.ProtoTypeDefs(&proto3pb.TestAllTypes{})) if err != nil { - t.Fatalf("types.NewRegistry() failed: %v", err) + t.Fatalf("types.NewProtoRegistry() failed: %v", err) } tests := []struct { name string @@ -399,6 +399,14 @@ func TestEnvToConfig(t *testing.T) { env.NewMemberOverload("string_last", env.NewTypeDesc("string"), []*env.TypeDesc{}, env.NewTypeDesc("string")), )), }, + { + name: "json field names", + opts: []EnvOption{ + JSONFieldNames(true), + }, + want: env.NewConfig("json field names"). + AddFeatures(env.NewFeature("cel.feature.json_field_names", true)), + }, { name: "context proto - with extra variable", opts: []EnvOption{ @@ -495,7 +503,7 @@ func TestEnvFromConfig(t *testing.T) { { name: "std env - imports", beforeOpts: []EnvOption{Types(&proto3pb.TestAllTypes{})}, - conf: env.NewConfig("std env - context proto"). + conf: env.NewConfig("std env - imports"). AddImports(env.NewImport("google.expr.proto3.test.TestAllTypes")), exprs: []exprCase{ { @@ -520,6 +528,22 @@ func TestEnvFromConfig(t *testing.T) { }, }, }, + { + name: "std env - context proto w/ json field names", + beforeOpts: []EnvOption{Types(&proto3pb.TestAllTypes{})}, + conf: env.NewConfig("std env - context proto w/ json field names"). + SetContainer("google.expr.proto3.test"). + SetContextVariable(env.NewContextVariable("google.expr.proto3.test.TestAllTypes")). + AddFeatures(env.NewFeature("cel.feature.json_field_names", true)), + exprs: []exprCase{ + { + name: "field select literal", + in: mustContextProto(t, &proto3pb.TestAllTypes{SingleInt64: 10}, types.JSONFieldNames(true)), + expr: "TestAllTypes{singleInt64: singleInt64}.singleInt64", + out: types.Int(10), + }, + }, + }, { name: "custom env - variables", beforeOpts: []EnvOption{Types(&proto3pb.TestAllTypes{})}, @@ -1154,9 +1178,9 @@ func BenchmarkEnvExtendEagerDecls(b *testing.B) { } } -func mustContextProto(t *testing.T, pb proto.Message) Activation { +func mustContextProto(t *testing.T, pb proto.Message, opts ...types.RegistryOption) Activation { t.Helper() - ctx, err := ContextProtoVars(pb) + ctx, err := ContextProtoVars(pb, opts...) if err != nil { t.Fatalf("ContextProtoVars() failed: %v", err) } diff --git a/cel/options.go b/cel/options.go index 0287da9e8..d7d2ab034 100644 --- a/cel/options.go +++ b/cel/options.go @@ -71,12 +71,16 @@ const ( // Enable escape syntax for field identifiers (`). featureIdentEscapeSyntax + + // Enable accessing fields by JSON names within protobuf messages + featureJSONFieldNames ) var featureIDsToNames = map[int]string{ featureEnableMacroCallTracking: "cel.feature.macro_call_tracking", featureCrossTypeNumericComparisons: "cel.feature.cross_type_numeric_comparisons", featureIdentEscapeSyntax: "cel.feature.backtick_escape_syntax", + featureJSONFieldNames: "cel.feature.json_field_names", } func featureNameByID(id int) (string, bool) { @@ -309,9 +313,9 @@ func Abbrevs(qualifiedNames ...string) EnvOption { } } -// customTypeRegistry is an internal-only interface containing the minimum methods required to support +// protoTypeRegistry is an internal-only interface containing the minimum methods required to support // custom types. It is a subset of methods from ref.TypeRegistry. -type customTypeRegistry interface { +type protoTypeRegistry interface { RegisterDescriptor(protoreflect.FileDescriptor) error RegisterType(...ref.Type) error } @@ -328,7 +332,7 @@ type customTypeRegistry interface { // Note: This option must be specified after the CustomTypeProvider option when used together. func Types(addTypes ...any) EnvOption { return func(e *Env) (*Env, error) { - reg, isReg := e.provider.(customTypeRegistry) + reg, isReg := e.provider.(protoTypeRegistry) if !isReg { return nil, fmt.Errorf("custom types not supported by provider: %T", e.provider) } @@ -365,7 +369,7 @@ func Types(addTypes ...any) EnvOption { // extension or by re-using the same EnvOption with another NewEnv() call. func TypeDescs(descs ...any) EnvOption { return func(e *Env) (*Env, error) { - reg, isReg := e.provider.(customTypeRegistry) + reg, isReg := e.provider.(protoTypeRegistry) if !isReg { return nil, fmt.Errorf("custom types not supported by provider: %T", e.provider) } @@ -413,7 +417,7 @@ func TypeDescs(descs ...any) EnvOption { } } -func registerFileSet(reg customTypeRegistry, fileSet *descpb.FileDescriptorSet) error { +func registerFileSet(reg protoTypeRegistry, fileSet *descpb.FileDescriptorSet) error { files, err := protodesc.NewFiles(fileSet) if err != nil { return fmt.Errorf("protodesc.NewFiles(%v) failed: %v", fileSet, err) @@ -421,7 +425,7 @@ func registerFileSet(reg customTypeRegistry, fileSet *descpb.FileDescriptorSet) return registerFiles(reg, files) } -func registerFiles(reg customTypeRegistry, files *protoregistry.Files) error { +func registerFiles(reg protoTypeRegistry, files *protoregistry.Files) error { var err error files.RangeFiles(func(fd protoreflect.FileDescriptor) bool { err = reg.RegisterDescriptor(fd) @@ -430,6 +434,15 @@ func registerFiles(reg customTypeRegistry, files *protoregistry.Files) error { return err } +// JSONFieldNames supports accessing protocol buffer fields by json-name. +// +// Enabling JSON field name support will create a copy of the types.Registry with fields indexed +// by JSON name, and whether JSON name or Proto-style names are supported will be inferred from +// the AST extensions metadata. +func JSONFieldNames(enabled bool) EnvOption { + return features(featureJSONFieldNames, enabled) +} + // ProgramOption is a functional interface for configuring evaluation bindings and behaviors. type ProgramOption func(p *prog) (*prog, error) @@ -557,6 +570,17 @@ func configToEnvOptions(config *env.Config, provider types.Provider, optFactorie envOpts = append(envOpts, Abbrevs(imp.Name)) } + // Configure features and common limits. + for _, feat := range config.Features { + // Note, if a feature is not found, it is skipped as it is possible the feature + // is not intended to be supported publicly. In the future, a refinement of + // to this strategy to report unrecognized features and validators should probably + // be covered as a standard ConfigOptionFactory + if id, found := featureIDByName(feat.Name); found { + envOpts = append(envOpts, features(id, feat.Enabled)) + } + } + // Configure the context variable declaration if config.ContextVariable != nil { typeName := config.ContextVariable.TypeName @@ -598,17 +622,6 @@ func configToEnvOptions(config *env.Config, provider types.Provider, optFactorie envOpts = append(envOpts, FunctionDecls(funcs...)) } - // Configure features and common limits. - for _, feat := range config.Features { - // Note, if a feature is not found, it is skipped as it is possible the feature - // is not intended to be supported publicly. In the future, a refinement of - // to this strategy to report unrecognized features and validators should probably - // be covered as a standard ConfigOptionFactory - if id, found := featureIDByName(feat.Name); found { - envOpts = append(envOpts, features(id, feat.Enabled)) - } - } - for _, limit := range config.Limits { if id, found := limitIDByName(limit.Name); found { envOpts = append(envOpts, setLimit(id, limit.Value)) @@ -767,8 +780,11 @@ func fieldToCELType(field protoreflect.FieldDescriptor) (*Type, error) { return nil, fmt.Errorf("field %s type %s not implemented", field.FullName(), field.Kind().String()) } -func fieldToVariable(field protoreflect.FieldDescriptor) (*decls.VariableDecl, error) { +func fieldToVariable(field protoreflect.FieldDescriptor, jsonFieldNames bool) (*decls.VariableDecl, error) { name := string(field.Name()) + if jsonFieldNames { + name = field.JSONName() + } if field.IsMap() { mapKey := field.MapKey() mapValue := field.MapValue() @@ -799,6 +815,8 @@ func fieldToVariable(field protoreflect.FieldDescriptor) (*decls.VariableDecl, e // DeclareContextProto returns an option to extend CEL environment with declarations from the given context proto. // Each field of the proto defines a variable of the same name in the environment. // https://github.com/google/cel-spec/blob/master/doc/langdef.md#evaluation-environment +// +// If using JSONFieldNames(), ensure that the option is set before DeclareContextProto is provided. func DeclareContextProto(descriptor protoreflect.MessageDescriptor) EnvOption { return func(e *Env) (*Env, error) { if e.contextProto != nil { @@ -808,9 +826,10 @@ func DeclareContextProto(descriptor protoreflect.MessageDescriptor) EnvOption { e.contextProto = descriptor fields := descriptor.Fields() vars := make([]*decls.VariableDecl, 0, fields.Len()) + jsonFieldNames := e.HasFeature(featureJSONFieldNames) for i := 0; i < fields.Len(); i++ { field := fields.Get(i) - variable, err := fieldToVariable(field) + variable, err := fieldToVariable(field, jsonFieldNames) if err != nil { return nil, err } @@ -829,11 +848,15 @@ func DeclareContextProto(descriptor protoreflect.MessageDescriptor) EnvOption { // // Consider using with `DeclareContextProto` to simplify variable type declarations and publishing when using // protocol buffers. -func ContextProtoVars(ctx proto.Message) (Activation, error) { +// +// Use the types.JSONFieldNames(true) option to populate the context proto vars using the JSON field names. +func ContextProtoVars(ctx proto.Message, opts ...types.RegistryOption) (Activation, error) { if ctx == nil || !ctx.ProtoReflect().IsValid() { return interpreter.EmptyActivation(), nil } - reg, err := types.NewRegistry(ctx) + regOpts := []types.RegistryOption{types.ProtoTypeDefs(ctx)} + regOpts = append(regOpts, opts...) + reg, err := types.NewProtoRegistry(regOpts...) if err != nil { return nil, err } @@ -843,15 +866,19 @@ func ContextProtoVars(ctx proto.Message) (Activation, error) { vars := make(map[string]any, fields.Len()) for i := 0; i < fields.Len(); i++ { field := fields.Get(i) - sft, found := reg.FindStructFieldType(typeName, field.TextName()) + fieldName := field.TextName() + if reg.JSONFieldNames() { + fieldName = field.JSONName() + } + sft, found := reg.FindStructFieldType(typeName, fieldName) if !found { - return nil, fmt.Errorf("no such field: %s", field.TextName()) + return nil, fmt.Errorf("no such field: %s", fieldName) } fieldVal, err := sft.GetFrom(ctx) if err != nil { return nil, err } - vars[field.TextName()] = fieldVal + vars[fieldName] = fieldVal } return NewActivation(vars) } diff --git a/cel/program.go b/cel/program.go index 00577e05d..c46d694e4 100644 --- a/cel/program.go +++ b/cel/program.go @@ -219,6 +219,12 @@ func newProgram(e *Env, a *ast.AST, opts []ProgramOption) (Program, error) { attrFactorOpts := []interpreter.AttrFactoryOption{ interpreter.EnableErrorOnBadPresenceTest(p.HasFeature(featureEnableErrorOnBadPresenceTest)), } + if a.SourceInfo().HasExtension("json_name", ast.NewExtensionVersion(1, 1)) { + if !e.HasFeature(featureJSONFieldNames) { + return nil, errors.New("the AST extension 'json_name' requires the option cel.JSONFieldNames(true)") + } + } + // Configure the type provider, considering whether the AST indicates whether it supports JSON field names if p.evalOpts&OptPartialEval == OptPartialEval { attrFactory = interpreter.NewPartialAttributeFactory(e.Container, e.adapter, e.provider, attrFactorOpts...) } else { diff --git a/checker/checker_test.go b/checker/checker_test.go index 623e1f709..3e86ec577 100644 --- a/checker/checker_test.go +++ b/checker/checker_test.go @@ -2544,7 +2544,7 @@ func TestCheck(t *testing.T) { reg, err := types.NewProtoRegistry( types.JSONFieldNames(tc.env.jsonFieldNames), - types.ProtoTypes(&proto2pb.TestAllTypes{}, &proto3pb.TestAllTypes{}), + types.ProtoTypeDefs(&proto2pb.TestAllTypes{}, &proto3pb.TestAllTypes{}), ) if err != nil { t.Fatalf("types.NewProtoRegistry() failed: %v", err) @@ -2654,9 +2654,9 @@ func BenchmarkCheck(b *testing.B) { if len(errors.GetErrors()) > 0 { b.Fatalf("Unexpected parse errors: %v", errors.ToDisplayString()) } - reg, err := types.NewRegistry(&proto2pb.TestAllTypes{}, &proto3pb.TestAllTypes{}) + reg, err := types.NewProtoRegistry(types.ProtoTypeDefs(&proto2pb.TestAllTypes{}, &proto3pb.TestAllTypes{})) if err != nil { - b.Fatalf("types.NewRegistry() failed: %v", err) + b.Fatalf("types.NewProtoRegistry() failed: %v", err) } if tc.env.optionalSyntax { if err := reg.RegisterType(types.OptionalType); err != nil { @@ -2723,9 +2723,9 @@ func BenchmarkCheck(b *testing.B) { } func TestAddDuplicateDeclarations(t *testing.T) { - reg, err := types.NewRegistry(&proto2pb.TestAllTypes{}, &proto3pb.TestAllTypes{}) + reg, err := types.NewProtoRegistry(types.ProtoTypeDefs(&proto2pb.TestAllTypes{}, &proto3pb.TestAllTypes{})) if err != nil { - t.Fatalf("types.NewRegistry() failed: %v", err) + t.Fatalf("types.NewProtoRegistry() failed: %v", err) } env, err := NewEnv(containers.DefaultContainer, reg, CrossTypeNumericComparisons(true)) if err != nil { @@ -2742,9 +2742,9 @@ func TestAddDuplicateDeclarations(t *testing.T) { } func TestAddEquivalentDeclarations(t *testing.T) { - reg, err := types.NewRegistry(&proto2pb.TestAllTypes{}, &proto3pb.TestAllTypes{}) + reg, err := types.NewProtoRegistry(types.ProtoTypeDefs(&proto2pb.TestAllTypes{}, &proto3pb.TestAllTypes{})) if err != nil { - t.Fatalf("types.NewRegistry() failed: %v", err) + t.Fatalf("types.NewProtoRegistry() failed: %v", err) } env, err := NewEnv(containers.DefaultContainer, reg, CrossTypeNumericComparisons(true)) if err != nil { diff --git a/checker/cost_test.go b/checker/cost_test.go index f667ebe0e..10ccdcde1 100644 --- a/checker/cost_test.go +++ b/checker/cost_test.go @@ -757,9 +757,9 @@ func TestCost(t *testing.T) { if len(errs.GetErrors()) != 0 { t.Fatalf("parser.Parse(%v) failed: %v", tc.expr, errs.ToDisplayString()) } - reg, err := types.NewRegistry(&proto3pb.TestAllTypes{}) + reg, err := types.NewProtoRegistry(types.ProtoTypeDefs(&proto3pb.TestAllTypes{})) if err != nil { - t.Fatalf("types.NewRegistry(...) failed: %v", err) + t.Fatalf("types.NewProtoRegistry(...) failed: %v", err) } e, err := NewEnv(containers.DefaultContainer, reg) diff --git a/common/ast/ast.go b/common/ast/ast.go index aae2a83e9..3ae2e1063 100644 --- a/common/ast/ast.go +++ b/common/ast/ast.go @@ -438,6 +438,17 @@ func (s *SourceInfo) Extensions() []Extension { return s.extensions } +// HasExtension returns whether the source info contains the extension which satisfies the minimum version requirement. +// +// For an extension to be considered 'present' it must have the same major version as the minVersion and a minor version +// at least as great as the lowest minor version specified. +func (s *SourceInfo) HasExtension(id string, minVersion ExtensionVersion) bool { + for _, ext := range s.Extensions() { + return ext.ID == id && ext.Version.Major == minVersion.Major && ext.Version.Minor >= minVersion.Minor + } + return false +} + // AddExtension adds an extension record into the SourceInfo. func (s *SourceInfo) AddExtension(ext Extension) { if s == nil { diff --git a/common/ast/ast_test.go b/common/ast/ast_test.go index 17c4c82c2..d8aac7ec8 100644 --- a/common/ast/ast_test.go +++ b/common/ast/ast_test.go @@ -418,6 +418,20 @@ func TestHeights(t *testing.T) { } } +func TestHasExtension(t *testing.T) { + info := ast.NewSourceInfo(common.NewStringSource("true", "test-only")) + info.AddExtension(ast.NewExtension("json_name", ast.NewExtensionVersion(1, 1), ast.ComponentRuntime)) + if !info.HasExtension("json_name", ast.NewExtensionVersion(1, 0)) { + t.Error("info.HasExtension('json_name', 1.0) returned false for v1.1") + } + if info.HasExtension("json_name", ast.NewExtensionVersion(2, 1)) { + t.Error("info.HasExtension() returned true for v2.1 when v1.1 configured") + } + if info.HasExtension("unrelated", ast.NewExtensionVersion(0, 0)) { + t.Error("info.HasExtension() returned true for unrelated extensions not set on AST") + } +} + func mockRelativeSource(t testing.TB, text string, lineOffsets []int32, baseLocation common.Location) common.Source { t.Helper() return &mockSource{ diff --git a/common/ast/navigable_test.go b/common/ast/navigable_test.go index bd8fa8ac1..da86e3437 100644 --- a/common/ast/navigable_test.go +++ b/common/ast/navigable_test.go @@ -580,7 +580,7 @@ func mustTypeCheck(t testing.TB, expr string, opts ...any) *ast.AST { t.Fatalf("mustTypeCheck() failed with invalid option type: %T", v) } } - regOpts = append(regOpts, types.ProtoTypes(&proto3pb.TestAllTypes{})) + regOpts = append(regOpts, types.ProtoTypeDefs(&proto3pb.TestAllTypes{})) reg := newTestRegistry(t, regOpts...) env := newTestEnv(t, containers.DefaultContainer, reg, chkOpts...) checked, iss := checker.Check(parsed, exprSrc, env) diff --git a/common/env/env_test.go b/common/env/env_test.go index 683ca3cf4..bfb981249 100644 --- a/common/env/env_test.go +++ b/common/env/env_test.go @@ -654,9 +654,9 @@ func TestVariableAsCELVariable(t *testing.T) { }, } - tp, err := types.NewRegistry() + tp, err := types.NewProtoRegistry() if err != nil { - t.Fatalf("types.NewRegistry() failed: %v", err) + t.Fatalf("types.NewProtoRegistry() failed: %v", err) } tp.RegisterType(types.NewOpaqueType("set", types.NewTypeParamType("T"))) for _, tst := range tests { @@ -842,9 +842,9 @@ func TestFunctionAsCELFunction(t *testing.T) { types.NewTypeParamType("T"))), }, } - tp, err := types.NewRegistry() + tp, err := types.NewProtoRegistry() if err != nil { - t.Fatalf("types.NewRegistry() failed: %v", err) + t.Fatalf("types.NewProtoRegistry() failed: %v", err) } tp.RegisterType(types.NewOpaqueType("set", types.NewTypeParamType("T"))) for _, tst := range tests { @@ -953,9 +953,9 @@ func TestTypeDescAsCELTypeErrors(t *testing.T) { want: errors.New("undefined type"), }, } - tp, err := types.NewRegistry() + tp, err := types.NewProtoRegistry() if err != nil { - t.Fatalf("types.NewRegistry() failed: %v", err) + t.Fatalf("types.NewProtoRegistry() failed: %v", err) } tp.RegisterType(types.NewOpaqueType("set", types.NewTypeParamType("T"))) for _, tst := range tests { diff --git a/common/types/map_test.go b/common/types/map_test.go index 5b9898c27..81989120d 100644 --- a/common/types/map_test.go +++ b/common/types/map_test.go @@ -44,7 +44,7 @@ type testStruct struct { } func TestMapContains(t *testing.T) { - reg := newTestRegistry(t, ProtoTypes(&proto3pb.TestAllTypes{})) + reg := newTestRegistry(t, ProtoTypeDefs(&proto3pb.TestAllTypes{})) reflectMap := reg.NativeToValue(map[any]any{ int64(1): "hello", uint64(2): "world", @@ -582,7 +582,7 @@ func TestMapIsZeroValue(t *testing.T) { "hello": "world", }, } - reg := newTestRegistry(t, ProtoTypes(msg)) + reg := newTestRegistry(t, ProtoTypeDefs(msg)) obj := reg.NativeToValue(msg).(traits.Indexer) tests := []struct { @@ -749,7 +749,7 @@ func TestProtoMap(t *testing.T) { "welcome": "back", } msg := &proto3pb.TestAllTypes{MapStringString: strMap} - reg := newTestRegistry(t, ProtoTypes(msg)) + reg := newTestRegistry(t, ProtoTypeDefs(msg)) obj := reg.NativeToValue(msg).(traits.Indexer) // Test a simple proto map of string string. @@ -850,7 +850,7 @@ func TestProtoMapGet(t *testing.T) { "welcome": "back", } msg := &proto3pb.TestAllTypes{MapStringString: strMap} - reg := newTestRegistry(t, ProtoTypes(msg)) + reg := newTestRegistry(t, ProtoTypeDefs(msg)) obj := reg.NativeToValue(msg).(traits.Indexer) field := obj.Get(String("map_string_string")) mapVal, ok := field.(traits.Mapper) @@ -890,7 +890,7 @@ func TestProtoMapConvertToNative(t *testing.T) { "welcome": "back", } msg := &proto3pb.TestAllTypes{MapStringString: strMap} - reg := newTestRegistry(t, ProtoTypes(msg)) + reg := newTestRegistry(t, ProtoTypeDefs(msg)) obj := reg.NativeToValue(msg).(traits.Indexer) // Test a simple proto map of string string. field := obj.Get(String("map_string_string")) @@ -974,7 +974,7 @@ func TestProtoMapConvertToNative_NestedProto(t *testing.T) { }, } msg := &proto3pb.TestAllTypes{MapInt64NestedType: nestedTypeMap} - reg := newTestRegistry(t, ProtoTypes(msg)) + reg := newTestRegistry(t, ProtoTypeDefs(msg)) obj := reg.NativeToValue(msg).(traits.Indexer) // Test a simple proto map of string string. field := obj.Get(String("map_int64_nested_type")) diff --git a/common/types/object_test.go b/common/types/object_test.go index 1f89cd165..b2e2207ea 100644 --- a/common/types/object_test.go +++ b/common/types/object_test.go @@ -69,7 +69,7 @@ func TestProtoObjectConvertToNative(t *testing.T) { outValue map[string]any }{ { - opts: []RegistryOption{ProtoTypes(&exprpb.Expr{}), JSONFieldNames(true)}, + opts: []RegistryOption{ProtoTypeDefs(&exprpb.Expr{}), JSONFieldNames(true)}, fieldMap: func(reg *Registry) map[string]ref.Val { return map[string]ref.Val{ "expr": reg.NativeToValue(msg.GetExpr()), @@ -93,7 +93,7 @@ func TestProtoObjectConvertToNative(t *testing.T) { }, }, { - opts: []RegistryOption{ProtoTypes(&exprpb.Expr{}), JSONFieldNames(false)}, + opts: []RegistryOption{ProtoTypeDefs(&exprpb.Expr{}), JSONFieldNames(false)}, fieldMap: func(reg *Registry) map[string]ref.Val { return map[string]ref.Val{ "expr": reg.NativeToValue(msg.GetExpr()), @@ -186,7 +186,7 @@ func TestProtoObjectIsSet(t *testing.T) { LineOffsets: []int32{1, 2, 3}, }, } - reg := newTestRegistry(t, ProtoTypes(msg)) + reg := newTestRegistry(t, ProtoTypeDefs(msg)) objVal := reg.NativeToValue(msg).(*protoObj) if objVal.IsSet(String("source_info")) != True { t.Error("got 'source_info' not set, wanted set") @@ -203,7 +203,7 @@ func TestProtoObjectIsSet(t *testing.T) { } func TestProtoObjectIsZeroValue(t *testing.T) { - reg := newTestRegistry(t, ProtoTypes(&exprpb.ParsedExpr{})) + reg := newTestRegistry(t, ProtoTypeDefs(&exprpb.ParsedExpr{})) emptyObj := reg.NativeToValue(&exprpb.ParsedExpr{}) pb, ok := emptyObj.(traits.Zeroer) if !ok { @@ -225,7 +225,7 @@ func TestProtoObjectGet(t *testing.T) { LineOffsets: []int32{1, 2, 3}, }, } - reg := newTestRegistry(t, ProtoTypes(msg)) + reg := newTestRegistry(t, ProtoTypeDefs(msg)) objVal := reg.NativeToValue(msg).(*protoObj) if objVal.Get(String("source_info")).Equal(reg.NativeToValue(msg.GetSourceInfo())) != True { t.Error("could not get 'source_info'") @@ -247,7 +247,7 @@ func TestProtoObjectConvertToType(t *testing.T) { LineOffsets: []int32{1, 2, 3}, }, } - reg := newTestRegistry(t, ProtoTypes(msg)) + reg := newTestRegistry(t, ProtoTypeDefs(msg)) objVal := reg.NativeToValue(msg) tv := objVal.Type().(ref.Val) if objVal.ConvertToType(TypeType).Equal(tv) != True { diff --git a/common/types/provider.go b/common/types/provider.go index ebfd66dcb..13680a231 100644 --- a/common/types/provider.go +++ b/common/types/provider.go @@ -96,7 +96,7 @@ type Registry struct { // provider which can create new instances of the provided message or any // message that proto depends upon in its FileDescriptor. func NewRegistry(types ...proto.Message) (*Registry, error) { - return NewProtoRegistry(ProtoTypes(types...)) + return NewProtoRegistry(ProtoTypeDefs(types...)) } // RegistryOption configures the behavior of the registry. @@ -105,23 +105,13 @@ type RegistryOption func(r *Registry) (*Registry, error) // JSONFieldNames configures JSON field name support within the protobuf types in the registry. func JSONFieldNames(enabled bool) RegistryOption { return func(r *Registry) (*Registry, error) { - if enabled != r.pbdb.JSONFieldNames() { - newDB := pb.NewDb(pb.JSONFieldNames(enabled)) - files := r.pbdb.FileDescriptions() - for _, fd := range files { - _, err := newDB.RegisterDescriptor(fd.FileDescriptor()) - if err != nil { - return nil, err - } - } - r.pbdb = newDB - } - return r, nil + err := r.WithJSONFieldNames(enabled) + return r, err } } -// ProtoTypes creates a RegistryOption which registers the individual proto messages with the registry. -func ProtoTypes(types ...proto.Message) RegistryOption { +// ProtoTypeDefs creates a RegistryOption which registers the individual proto messages with the registry. +func ProtoTypeDefs(types ...proto.Message) RegistryOption { return func(r *Registry) (*Registry, error) { for _, msgType := range types { err := r.RegisterMessage(msgType) @@ -191,6 +181,28 @@ func (p *Registry) Copy() *Registry { return copy } +// JSONFieldNames returns whether json field names are enabled in this registry. +func (p *Registry) JSONFieldNames() bool { + return p.pbdb.JSONFieldNames() +} + +// WithJSONFieldNames configures the registry with the JSON field name support enabled or disabled. +func (p *Registry) WithJSONFieldNames(enabled bool) error { + if enabled == p.pbdb.JSONFieldNames() { + return nil + } + newDB := pb.NewDb(pb.JSONFieldNames(enabled)) + files := p.pbdb.FileDescriptions() + for _, fd := range files { + _, err := newDB.RegisterDescriptor(fd.FileDescriptor()) + if err != nil { + return err + } + } + p.pbdb = newDB + return nil +} + // EnumValue returns the numeric value of the given enum value name. func (p *Registry) EnumValue(enumName string) ref.Val { enumVal, found := p.pbdb.DescribeEnum(enumName) diff --git a/common/types/provider_test.go b/common/types/provider_test.go index 9189df449..552ac1836 100644 --- a/common/types/provider_test.go +++ b/common/types/provider_test.go @@ -162,7 +162,7 @@ func TestRegistryFindStructFieldNames(t *testing.T) { tc := tst t.Run(fmt.Sprintf("%s", tc.typeName), func(t *testing.T) { reg := newTestRegistry(t, - ProtoTypes(&exprpb.Decl{}, &exprpb.Reference{}), + ProtoTypeDefs(&exprpb.Decl{}, &exprpb.Reference{}), JSONFieldNames(tc.jsonFieldNames)) fields, _ := reg.FindStructFieldNames(tc.typeName) sort.Strings(fields) @@ -299,7 +299,7 @@ func TestRegistryFindStructFieldType(t *testing.T) { } func TestRegistryNewValue(t *testing.T) { - reg := newTestRegistry(t, ProtoTypes(&proto3pb.TestAllTypes{}, &exprpb.SourceInfo{})) + reg := newTestRegistry(t, ProtoTypeDefs(&proto3pb.TestAllTypes{}, &exprpb.SourceInfo{})) tests := []struct { typeName string fields map[string]ref.Val @@ -427,7 +427,7 @@ func TestRegistryNewValue(t *testing.T) { } func TestRegistryNewValueErrors(t *testing.T) { - reg := newTestRegistry(t, ProtoTypes(&proto3pb.TestAllTypes{}, &exprpb.SourceInfo{})) + reg := newTestRegistry(t, ProtoTypeDefs(&proto3pb.TestAllTypes{}, &exprpb.SourceInfo{})) tests := []struct { typeName string fields map[string]ref.Val @@ -504,7 +504,7 @@ func TestRegistryNewValueErrors(t *testing.T) { } func TestRegistryGetters(t *testing.T) { - reg := newTestRegistry(t, ProtoTypes(&exprpb.ParsedExpr{})) + reg := newTestRegistry(t, ProtoTypeDefs(&exprpb.ParsedExpr{})) if sourceInfo := reg.NewValue( "google.api.expr.v1alpha1.SourceInfo", map[string]ref.Val{ @@ -540,7 +540,7 @@ func TestRegistryGetters(t *testing.T) { } func TestConvertToNative(t *testing.T) { - reg := newTestRegistry(t, ProtoTypes(&exprpb.ParsedExpr{})) + reg := newTestRegistry(t, ProtoTypeDefs(&exprpb.ParsedExpr{})) // Core type conversion tests. expectValueToNative(t, True, true) @@ -605,7 +605,7 @@ func TestConvertToNative(t *testing.T) { } func TestNativeToValue_Any(t *testing.T) { - reg := newTestRegistry(t, ProtoTypes(&exprpb.ParsedExpr{})) + reg := newTestRegistry(t, ProtoTypeDefs(&exprpb.ParsedExpr{})) // NullValue anyValue, err := NullValue.ConvertToNative(anyValueType) if err != nil { @@ -666,7 +666,7 @@ func TestNativeToValue_Any(t *testing.T) { } func TestNativeToValue_Json(t *testing.T) { - reg := newTestRegistry(t, ProtoTypes(&exprpb.ParsedExpr{})) + reg := newTestRegistry(t, ProtoTypeDefs(&exprpb.ParsedExpr{})) // Json primitive conversion test. expectNativeToValue(t, structpb.NewBoolValue(false), False) expectNativeToValue(t, structpb.NewNumberValue(1.1), Double(1.1)) @@ -856,7 +856,7 @@ func expectValueToNative(t *testing.T, in ref.Val, out any) { func expectNativeToValue(t *testing.T, in any, out ref.Val) { t.Helper() - reg := newTestRegistry(t, ProtoTypes(&exprpb.ParsedExpr{})) + reg := newTestRegistry(t, ProtoTypeDefs(&exprpb.ParsedExpr{})) if val := reg.NativeToValue(in); IsError(val) { t.Error(val) } else { diff --git a/interpreter/attributes_test.go b/interpreter/attributes_test.go index 0e16f208f..dd879717b 100644 --- a/interpreter/attributes_test.go +++ b/interpreter/attributes_test.go @@ -397,7 +397,7 @@ func TestAttributesConditionalAttrFalseBranch(t *testing.T) { } func TestAttributesOptional(t *testing.T) { - reg := newTestRegistry(t, &proto3pb.TestAllTypes{}) + reg := newTestRegistry(t, types.ProtoTypeDefs(&proto3pb.TestAllTypes{})) cont, err := containers.NewContainer(containers.Name("ns")) if err != nil { t.Fatalf("") @@ -778,7 +778,7 @@ func BenchmarkResolverFieldQualifier(b *testing.B) { }, }, } - reg := newTestRegistry(b, msg) + reg := newTestRegistry(b, types.ProtoTypeDefs(msg)) attrs := NewAttributeFactory(containers.DefaultContainer, reg, reg) vars, _ := NewActivation(map[string]any{ "msg": msg, diff --git a/interpreter/interpreter_test.go b/interpreter/interpreter_test.go index 400ac6b0e..651bc2362 100644 --- a/interpreter/interpreter_test.go +++ b/interpreter/interpreter_test.go @@ -25,8 +25,6 @@ import ( "testing" "time" - "google.golang.org/protobuf/proto" - "github.com/google/cel-go/checker" "github.com/google/cel-go/common" "github.com/google/cel-go/common/ast" @@ -54,7 +52,7 @@ type testCase struct { expr string container string abbrevs []string - types []proto.Message + typeOpts []types.RegistryOption vars []*decls.VariableDecl funcs []*decls.FunctionDecl attrs AttributeFactory @@ -600,7 +598,7 @@ func testData(t testing.TB) []testCase { { name: "literal_pb3_msg", container: "google.api.expr", - types: []proto.Message{&exprpb.Expr{}}, + typeOpts: []types.RegistryOption{types.ProtoTypeDefs(&exprpb.Expr{})}, expr: `v1alpha1.Expr{ id: 1, const_expr: v1alpha1.Constant{ @@ -616,7 +614,7 @@ func testData(t testing.TB) []testCase { { name: "literal_pb_enum", container: "google.expr.proto3.test", - types: []proto.Message{&proto3pb.TestAllTypes{}}, + typeOpts: []types.RegistryOption{types.ProtoTypeDefs(&proto3pb.TestAllTypes{})}, expr: `TestAllTypes{ repeated_nested_enum: [ 0, @@ -637,7 +635,7 @@ func testData(t testing.TB) []testCase { { name: "literal_pb_wrapper_assign", container: "google.expr.proto3.test", - types: []proto.Message{&proto3pb.TestAllTypes{}}, + typeOpts: []types.RegistryOption{types.ProtoTypeDefs(&proto3pb.TestAllTypes{})}, expr: `TestAllTypes{ single_int64_wrapper: 10, single_int32_wrapper: TestAllTypes{}.single_int32_wrapper, @@ -649,7 +647,7 @@ func testData(t testing.TB) []testCase { { name: "literal_pb_wrapper_assign_roundtrip", container: "google.expr.proto3.test", - types: []proto.Message{&proto3pb.TestAllTypes{}}, + typeOpts: []types.RegistryOption{types.ProtoTypeDefs(&proto3pb.TestAllTypes{})}, expr: `TestAllTypes{ single_int32_wrapper: TestAllTypes{}.single_int32_wrapper, }.single_int32_wrapper == null`, @@ -658,7 +656,7 @@ func testData(t testing.TB) []testCase { { name: "literal_pb_list_assign_null_wrapper", container: "google.expr.proto3.test", - types: []proto.Message{&proto3pb.TestAllTypes{}}, + typeOpts: []types.RegistryOption{types.ProtoTypeDefs(&proto3pb.TestAllTypes{})}, expr: `TestAllTypes{ repeated_int32: [123, 456, TestAllTypes{}.single_int32_wrapper], }`, @@ -667,7 +665,7 @@ func testData(t testing.TB) []testCase { { name: "literal_pb_map_assign_null_entry_value", container: "google.expr.proto3.test", - types: []proto.Message{&proto3pb.TestAllTypes{}}, + typeOpts: []types.RegistryOption{types.ProtoTypeDefs(&proto3pb.TestAllTypes{})}, expr: `TestAllTypes{ map_string_string: { 'hello': 'world', @@ -679,7 +677,7 @@ func testData(t testing.TB) []testCase { { name: "unset_wrapper_access", container: "google.expr.proto3.test", - types: []proto.Message{&proto3pb.TestAllTypes{}}, + typeOpts: []types.RegistryOption{types.ProtoTypeDefs(&proto3pb.TestAllTypes{})}, expr: `TestAllTypes{}.single_string_wrapper`, out: types.NullValue, }, @@ -803,7 +801,7 @@ func testData(t testing.TB) []testCase { { name: "macro_has_pb2_field_undefined", container: "google.expr.proto2.test", - types: []proto.Message{&proto2pb.TestAllTypes{}}, + typeOpts: []types.RegistryOption{types.ProtoTypeDefs(&proto2pb.TestAllTypes{})}, unchecked: true, expr: `has(TestAllTypes{}.invalid_field)`, err: "no such field 'invalid_field'", @@ -811,7 +809,7 @@ func testData(t testing.TB) []testCase { { name: "macro_has_pb2_field", container: "google.expr.proto2.test", - types: []proto.Message{&proto2pb.TestAllTypes{}}, + typeOpts: []types.RegistryOption{types.ProtoTypeDefs(&proto2pb.TestAllTypes{})}, vars: []*decls.VariableDecl{ decls.NewVariable("pb2", types.NewObjectType("google.expr.proto2.test.TestAllTypes")), }, @@ -836,8 +834,37 @@ func testData(t testing.TB) []testCase { && !has(pb2.map_string_string)`, }, { - name: "macro_has_pb3_field", - types: []proto.Message{&proto3pb.TestAllTypes{}}, + name: "macro_has_pb2_field_json", + container: "google.expr.proto2.test", + typeOpts: []types.RegistryOption{types.JSONFieldNames(true), types.ProtoTypeDefs(&proto2pb.TestAllTypes{})}, + vars: []*decls.VariableDecl{ + decls.NewVariable("pb2", types.NewObjectType("google.expr.proto2.test.TestAllTypes")), + }, + in: map[string]any{ + "pb2": &proto2pb.TestAllTypes{ + RepeatedBool: []bool{false}, + MapInt64NestedType: map[int64]*proto2pb.NestedTestAllTypes{ + 1: {}, + }, + MapStringString: map[string]string{}, + }, + }, + expr: `has(TestAllTypes{standaloneEnum: TestAllTypes.NestedEnum.BAR}.standaloneEnum) + && has(TestAllTypes{standaloneEnum: TestAllTypes.NestedEnum.FOO}.standaloneEnum) + && !has(TestAllTypes{singleNestedEnum: TestAllTypes.NestedEnum.FOO}.singleNestedMessage) + && has(TestAllTypes{singleNestedEnum: TestAllTypes.NestedEnum.FOO}.singleNestedEnum) + && !has(TestAllTypes{}.singleNestedMessage) + && has(TestAllTypes{singleNestedMessage: TestAllTypes.NestedMessage{}}.singleNestedMessage) + && !has(TestAllTypes{}.standaloneEnum) + && !has(pb2.singleInt64) + && has(pb2.repeatedBool) + && !has(pb2.repeatedInt32) + && has(pb2.mapInt64NestedType) + && !has(pb2.mapStringString)`, + }, + { + name: "macro_has_pb3_field", + typeOpts: []types.RegistryOption{types.ProtoTypeDefs(&proto3pb.TestAllTypes{})}, vars: []*decls.VariableDecl{ decls.NewVariable("pb3", types.NewObjectType("google.expr.proto3.test.TestAllTypes")), }, @@ -864,6 +891,35 @@ func testData(t testing.TB) []testCase { && has(pb3.map_int64_nested_type) && !has(pb3.map_string_string)`, }, + { + name: "macro_has_pb3_field_json", + typeOpts: []types.RegistryOption{types.JSONFieldNames(true), types.ProtoTypeDefs(&proto3pb.TestAllTypes{})}, + vars: []*decls.VariableDecl{ + decls.NewVariable("pb3", types.NewObjectType("google.expr.proto3.test.TestAllTypes")), + }, + container: "google.expr.proto3.test", + in: map[string]any{ + "pb3": &proto3pb.TestAllTypes{ + RepeatedBool: []bool{false}, + MapInt64NestedType: map[int64]*proto3pb.NestedTestAllTypes{ + 1: {}, + }, + MapStringString: map[string]string{}, + }, + }, + expr: `has(TestAllTypes{standaloneEnum: TestAllTypes.NestedEnum.BAR}.standaloneEnum) + && !has(TestAllTypes{standaloneEnum: TestAllTypes.NestedEnum.FOO}.standaloneEnum) + && !has(TestAllTypes{singleNestedEnum: TestAllTypes.NestedEnum.FOO}.singleNestedMessage) + && has(TestAllTypes{singleNestedEnum: TestAllTypes.NestedEnum.FOO}.singleNestedEnum) + && !has(TestAllTypes{}.singleNestedMessage) + && has(TestAllTypes{singleNestedMessage: TestAllTypes.NestedMessage{}}.singleNestedMessage) + && !has(TestAllTypes{}.standaloneEnum) + && !has(pb3.singleInt64) + && has(pb3.repeatedBool) + && !has(pb3.repeatedInt32) + && has(pb3.mapInt64NestedType) + && !has(pb3.mapStringString)`, + }, { name: "macro_map", expr: `[1, 2, 3].map(x, x * 2) == [2, 4, 6]`, @@ -907,9 +963,9 @@ func testData(t testing.TB) []testCase { progErr: "unexpected ): `)k.*`", }, { - name: "nested_proto_field", - expr: `pb3.single_nested_message.bb`, - types: []proto.Message{&proto3pb.TestAllTypes{}}, + name: "nested_proto_field", + expr: `pb3.single_nested_message.bb`, + typeOpts: []types.RegistryOption{types.ProtoTypeDefs(&proto3pb.TestAllTypes{})}, vars: []*decls.VariableDecl{ decls.NewVariable("pb3", types.NewObjectType("google.expr.proto3.test.TestAllTypes")), @@ -926,9 +982,9 @@ func testData(t testing.TB) []testCase { out: types.Int(1234), }, { - name: "nested_proto_field_with_index", - expr: `pb3.map_int64_nested_type[0].child.payload.single_int32 == 1`, - types: []proto.Message{&proto3pb.TestAllTypes{}}, + name: "nested_proto_field_with_index", + expr: `pb3.map_int64_nested_type[0].child.payload.single_int32 == 1`, + typeOpts: []types.RegistryOption{types.ProtoTypeDefs(&proto3pb.TestAllTypes{})}, vars: []*decls.VariableDecl{ decls.NewVariable("pb3", types.NewObjectType("google.expr.proto3.test.TestAllTypes")), @@ -1175,7 +1231,7 @@ func testData(t testing.TB) []testCase { && pb3.repeated_nested_enum[0] == test.TestAllTypes.NestedEnum.BAR && json.list[0] == 'world'`, container: "google.expr.proto3", - types: []proto.Message{&proto3pb.TestAllTypes{}}, + typeOpts: []types.RegistryOption{types.ProtoTypeDefs(&proto3pb.TestAllTypes{})}, vars: []*decls.VariableDecl{ decls.NewVariable("a.b", types.NewMapType(types.StringType, types.BoolType)), decls.NewVariable("pb3", types.NewObjectType("google.expr.proto3.test.TestAllTypes")), @@ -1217,7 +1273,7 @@ func testData(t testing.TB) []testCase { && a.single_double == 6.4 && a.single_bool && "empty" == a.single_string`, - types: []proto.Message{&proto2pb.TestAllTypes{}}, + typeOpts: []types.RegistryOption{types.ProtoTypeDefs(&proto2pb.TestAllTypes{})}, in: map[string]any{ "a": &proto2pb.TestAllTypes{}, }, @@ -1232,8 +1288,8 @@ func testData(t testing.TB) []testCase { && has(a.single_int64_wrapper) && a.single_int64_wrapper == 0 && has(a.single_string_wrapper) && a.single_string_wrapper == "hello" && a.single_int64_wrapper == Int32Value{value: 0}`, - types: []proto.Message{&proto3pb.TestAllTypes{}}, - abbrevs: []string{"google.protobuf.Int32Value"}, + typeOpts: []types.RegistryOption{types.ProtoTypeDefs(&proto3pb.TestAllTypes{})}, + abbrevs: []string{"google.protobuf.Int32Value"}, vars: []*decls.VariableDecl{ decls.NewVariable("a", types.NewObjectType("google.expr.proto3.test.TestAllTypes")), }, @@ -1248,7 +1304,7 @@ func testData(t testing.TB) []testCase { name: "select_pb3_compare", expr: `a.single_uint64 > 3u`, container: "google.expr.proto3.test", - types: []proto.Message{&proto3pb.TestAllTypes{}}, + typeOpts: []types.RegistryOption{types.ProtoTypeDefs(&proto3pb.TestAllTypes{})}, vars: []*decls.VariableDecl{ decls.NewVariable("a", types.NewObjectType("google.expr.proto3.test.TestAllTypes")), }, @@ -1263,7 +1319,7 @@ func testData(t testing.TB) []testCase { name: "select_custom_pb3_compare", expr: `a.bb > 100`, container: "google.expr.proto3.test", - types: []proto.Message{&proto3pb.TestAllTypes_NestedMessage{}}, + typeOpts: []types.RegistryOption{types.ProtoTypeDefs(&proto3pb.TestAllTypes_NestedMessage{})}, vars: []*decls.VariableDecl{ decls.NewVariable("a", types.NewObjectType("google.expr.proto3.test.TestAllTypes.NestedMessage")), @@ -1271,8 +1327,8 @@ func testData(t testing.TB) []testCase { attrs: &custAttrFactory{ AttributeFactory: NewAttributeFactory( testContainer("google.expr.proto3.test"), - newTestRegistry(t, &proto3pb.TestAllTypes_NestedMessage{}), - newTestRegistry(t, &proto3pb.TestAllTypes_NestedMessage{}), + newTestRegistry(t, types.ProtoTypeDefs(&proto3pb.TestAllTypes_NestedMessage{})), + newTestRegistry(t, types.ProtoTypeDefs(&proto3pb.TestAllTypes_NestedMessage{})), ), }, in: map[string]any{ @@ -1286,7 +1342,7 @@ func testData(t testing.TB) []testCase { name: "select_custom_pb3_optional_field", expr: `a.?bb`, container: "google.expr.proto3.test", - types: []proto.Message{&proto3pb.TestAllTypes_NestedMessage{}}, + typeOpts: []types.RegistryOption{types.ProtoTypeDefs(&proto3pb.TestAllTypes_NestedMessage{})}, vars: []*decls.VariableDecl{ decls.NewVariable("a", types.NewObjectType("google.expr.proto3.test.TestAllTypes.NestedMessage")), @@ -1294,8 +1350,8 @@ func testData(t testing.TB) []testCase { attrs: &custAttrFactory{ AttributeFactory: NewAttributeFactory( testContainer("google.expr.proto3.test"), - newTestRegistry(t, &proto3pb.TestAllTypes_NestedMessage{}), - newTestRegistry(t, &proto3pb.TestAllTypes_NestedMessage{}), + newTestRegistry(t, types.ProtoTypeDefs(&proto3pb.TestAllTypes_NestedMessage{})), + newTestRegistry(t, types.ProtoTypeDefs(&proto3pb.TestAllTypes_NestedMessage{})), ), }, in: map[string]any{ @@ -1345,7 +1401,7 @@ func testData(t testing.TB) []testCase { { name: "select_empty_repeated_nested", expr: `TestAllTypes{}.repeated_nested_message.size() == 0`, - types: []proto.Message{&proto3pb.TestAllTypes{}}, + typeOpts: []types.RegistryOption{types.ProtoTypeDefs(&proto3pb.TestAllTypes{})}, container: "google.expr.proto3.test", out: types.True, }, @@ -1428,7 +1484,7 @@ func testData(t testing.TB) []testCase { name: "literal_pb_optional_field", expr: `TestAllTypes{?single_int32: {'value': 1}.?value, ?single_string: {}.?missing}`, container: "google.expr.proto3.test", - types: []proto.Message{&proto3pb.TestAllTypes{}}, + typeOpts: []types.RegistryOption{types.ProtoTypeDefs(&proto3pb.TestAllTypes{})}, out: &proto3pb.TestAllTypes{ SingleInt32: 1, }, @@ -1437,7 +1493,7 @@ func testData(t testing.TB) []testCase { name: "literal_pb_optional_field_bad_init", expr: `TestAllTypes{?single_int32: 1}`, container: "google.expr.proto3.test", - types: []proto.Message{&proto3pb.TestAllTypes{}}, + typeOpts: []types.RegistryOption{types.ProtoTypeDefs(&proto3pb.TestAllTypes{})}, unchecked: true, err: `cannot initialize optional entry 'single_int32' from non-optional`, }, @@ -1745,9 +1801,9 @@ func TestInterpreter(t *testing.T) { func TestInterpreter_ProtoAttributeOpt(t *testing.T) { inst, _, err := program(t, &testCase{ - name: "nested_proto_field_with_index", - expr: `pb3.map_int64_nested_type[0].child.payload.single_int32`, - types: []proto.Message{&proto3pb.TestAllTypes{}}, + name: "nested_proto_field_with_index", + expr: `pb3.map_int64_nested_type[0].child.payload.single_int32`, + typeOpts: []types.RegistryOption{types.ProtoTypeDefs(&proto3pb.TestAllTypes{})}, vars: []*decls.VariableDecl{ decls.NewVariable("pb3", types.NewObjectType("google.expr.proto3.test.TestAllTypes")), @@ -1806,7 +1862,7 @@ func TestInterpreter_ExhaustiveConditionalExpr(t *testing.T) { parsed := testMustParse(t, `a ? b < 1.0 : c == ['hello']`) state := NewEvalState() cont := containers.DefaultContainer - reg := newTestRegistry(t, &exprpb.ParsedExpr{}) + reg := newTestRegistry(t, types.ProtoTypeDefs(&exprpb.ParsedExpr{})) attrs := NewAttributeFactory(cont, reg, reg) intr := newStandardInterpreter(t, cont, reg, reg, attrs) interpretable, _ := intr.NewInterpretable(parsed, ExhaustiveEval(), @@ -1908,7 +1964,7 @@ func TestInterpreter_ExhaustiveLogicalOrEquals(t *testing.T) { // Operator "==" is at Expr 4, should be evaluated though "a" is true parsed := testMustParse(t, `a || b == "b"`) state := NewEvalState() - reg := newTestRegistry(t, &exprpb.Expr{}) + reg := newTestRegistry(t, types.ProtoTypeDefs(&exprpb.Expr{})) cont := testContainer("test") attrs := NewAttributeFactory(cont, reg, reg) interp := newStandardInterpreter(t, cont, reg, reg, attrs) @@ -1944,7 +2000,7 @@ func TestInterpreter_SetProto2PrimitiveFields(t *testing.T) { }`) parsed := testMustParse(t, src) cont := testContainer("google.expr.proto2.test") - reg := newTestRegistry(t, &proto2pb.TestAllTypes{}) + reg := newTestRegistry(t, types.ProtoTypeDefs(&proto2pb.TestAllTypes{})) env := newTestEnv(t, cont, reg) env.AddIdents( decls.NewVariable("input", @@ -2229,8 +2285,8 @@ func program(t testing.TB, tst *testCase, opts ...PlannerOption) (Interpretable, var reg *types.Registry var env *checker.Env reg = newTestRegistry(t) - if tst.types != nil { - reg = newTestRegistry(t, tst.types...) + if tst.typeOpts != nil { + reg = newTestRegistry(t, tst.typeOpts...) } env = newTestEnv(t, cont, reg) attrs := NewAttributeFactory(cont, reg, reg) @@ -2361,11 +2417,11 @@ func newTestEnv(t testing.TB, cont *containers.Container, reg *types.Registry) * return env } -func newTestRegistry(t testing.TB, msgs ...proto.Message) *types.Registry { +func newTestRegistry(t testing.TB, opts ...types.RegistryOption) *types.Registry { t.Helper() - reg, err := types.NewRegistry(msgs...) + reg, err := types.NewProtoRegistry(opts...) if err != nil { - t.Fatalf("types.NewRegistry(%v) failed: %v", msgs, err) + t.Fatalf("types.NewProtoRegistry() failed: %v", err) } return reg } diff --git a/interpreter/prune_test.go b/interpreter/prune_test.go index e333c1ef5..24639322d 100644 --- a/interpreter/prune_test.go +++ b/interpreter/prune_test.go @@ -491,7 +491,7 @@ func TestPrune(t *testing.T) { t.Fatalf("Parse(%q) failed: %v", tst.expr, iss.ToDisplayString()) } state := NewEvalState() - reg := newTestRegistry(t, &proto3pb.TestAllTypes{}) + reg := newTestRegistry(t, types.ProtoTypeDefs(&proto3pb.TestAllTypes{})) attrs := NewPartialAttributeFactory(containers.DefaultContainer, reg, reg) dispatcher := NewDispatcher() addFunctionBindings(t, dispatcher) diff --git a/interpreter/runtimecost_test.go b/interpreter/runtimecost_test.go index 54e10f142..a19308c66 100644 --- a/interpreter/runtimecost_test.go +++ b/interpreter/runtimecost_test.go @@ -123,7 +123,7 @@ func computeCost(t *testing.T, expr string, vars []*decls.VariableDecl, ctx Acti } cont := containers.DefaultContainer - reg := newTestRegistry(t, &proto3pb.TestAllTypes{}) + reg := newTestRegistry(t, types.ProtoTypeDefs(&proto3pb.TestAllTypes{})) attrs := NewAttributeFactory(cont, reg, reg) env := newTestEnv(t, cont, reg) err = env.AddIdents(vars...)