Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
175 changes: 175 additions & 0 deletions cel/cel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2234,6 +2234,62 @@ func TestContextProto(t *testing.T) {
}
}

func TestContextProtoJSONFieldNames(t *testing.T) {
descriptor := new(proto3pb.TestAllTypes).ProtoReflect().Descriptor()
env := testEnv(t, JSONFieldNames(true), DeclareContextProto(descriptor))
expression := `
singleInt64 == 1
&& singleDouble == 1.0
&& singleBool == true
&& singleString == ''
&& singleNestedMessage == google.expr.proto3.test.TestAllTypes.NestedMessage{}
&& standaloneEnum == google.expr.proto3.test.TestAllTypes.NestedEnum.FOO
&& singleDuration == duration('5s')
&& singleTimestamp == timestamp(63154820)
&& singleAny == null
&& singleUint32Wrapper == null
&& singleUint64Wrapper == 0u
&& repeatedInt32 == [1,2]
&& mapStringString == {'': ''}
&& mapInt64NestedType == {0 : google.expr.proto3.test.NestedTestAllTypes{}}`
ast, iss := env.Compile(expression)
if iss.Err() != nil {
t.Fatalf("env.Compile(%s) failed: %s", expression, iss.Err())
}
prg, err := env.Program(ast)
if err != nil {
t.Fatalf("env.Program() failed: %v", err)
}
in := &proto3pb.TestAllTypes{
SingleInt64: 1,
SingleDouble: 1.0,
SingleBool: true,
NestedType: &proto3pb.TestAllTypes_SingleNestedMessage{
SingleNestedMessage: &proto3pb.TestAllTypes_NestedMessage{},
},
StandaloneEnum: proto3pb.TestAllTypes_FOO,
SingleDuration: &durationpb.Duration{Seconds: 5},
SingleTimestamp: &timestamppb.Timestamp{
Seconds: 63154820,
},
SingleUint64Wrapper: wrapperspb.UInt64(0),
RepeatedInt32: []int32{1, 2},
MapStringString: map[string]string{"": ""},
MapInt64NestedType: map[int64]*proto3pb.NestedTestAllTypes{0: {}},
}
vars, err := ContextProtoVars(in, types.JSONFieldNames(true))
if err != nil {
t.Fatalf("ContextProtoVars(%v) failed: %v", in, err)
}
out, _, err := prg.Eval(vars)
if err != nil {
t.Fatalf("prg.Eval() failed: %v", err)
}
if out.Equal(types.True) != types.True {
t.Errorf("prg.Eval() got %v, wanted true", out)
}
}

func TestRegexOptimizer(t *testing.T) {
var stringTests = []struct {
expr string
Expand Down Expand Up @@ -3607,6 +3663,125 @@ func TestAstProgramNilValue(t *testing.T) {
}
}

func TestJSONFieldNames(t *testing.T) {
tests := []struct {
name string
expr string
jsonFieldNames bool
}{
{
name: "proto simple field",
expr: `msg.single_int32 == 1`,
},
{
name: "proto map field",
expr: `msg.map_string_string['key'] == 'value'`,
},
{
name: "json simple field",
expr: `msg.singleInt32 == 1`,
jsonFieldNames: true,
},
{
name: "json repeated field",
expr: `msg.mapStringString['key'] == 'value'`,
jsonFieldNames: true,
},
{
name: "message with json field",
expr: `TestAllTypes{singleInt32: 1} != msg`,
jsonFieldNames: true,
},
{
name: "message with json field and proto fallback",
expr: `dyn(TestAllTypes{singleInt32: 2}).single_int32 == 2`,
jsonFieldNames: true,
},
{
name: "json with proto fallback",
expr: `dyn(msg).single_int32 == dyn(msg).singleInt32`,
jsonFieldNames: true,
},
}
msg := &proto3pb.TestAllTypes{
SingleInt32: 1,
MapStringString: map[string]string{
"key": "value",
},
}
for _, tst := range tests {
tc := tst
t.Run(tc.name, func(t *testing.T) {
env, err := NewEnv(
JSONFieldNames(tc.jsonFieldNames),
Types(msg),
Container(string(msg.ProtoReflect().Descriptor().ParentFile().Package())),
Variable("msg", ObjectType(string(msg.ProtoReflect().Descriptor().FullName()))),
)
if err != nil {
t.Fatalf("NewEnv() failed: %v", err)
}
ast, iss := env.Compile(tc.expr)
if iss.Err() != nil {
t.Fatalf("env.Compile() failed: %v", iss.Err())
}
prg, err := env.Program(ast)
if err != nil {
t.Fatalf("env.Program() failed: %v", err)
}
out, _, err := prg.Eval(map[string]any{"msg": msg})
if err != nil {
t.Fatalf("prg.Eval() failed: %v", err)
}
if out != types.True {
t.Errorf("prg.Eval() got %v, wanted 'true'", out)
}

if tc.jsonFieldNames {
noJSONEnv, err := env.Extend(JSONFieldNames(false))
if err != nil {
t.Fatalf("env.Extend() failed: %v", err)
}
_, err = noJSONEnv.Program(ast)
if err == nil {
t.Fatal("env with json disabled allowed program with json extension to be planned")
}
} else {
jsonEnv, err := env.Extend(JSONFieldNames(true))
if err != nil {
t.Fatalf("env.Extend() failed: %v", err)
}
prg, err = jsonEnv.Program(ast)
if err != nil {
t.Fatalf("env.Program() failed: %v", err)
}
out, _, err := prg.Eval(map[string]any{"msg": msg})
if err != nil {
t.Fatalf("prg.Eval() failed: %v", err)
}
if out != types.True {
t.Errorf("prg.Eval() got %v, wanted 'true'", out)
}
}
})
}
}

func TestJSONFieldNamesInvalidProvider(t *testing.T) {
type wrapperRegistry struct {
*types.Registry
}
reg, err := types.NewProtoRegistry(types.JSONFieldNames(true))
if err != nil {
t.Fatalf("types.NewProtoRegistry() failed: %v", err)
}
wrapped := wrapperRegistry{Registry: reg}
_, err = NewEnv(CustomTypeProvider(wrapped), CustomTypeAdapter(reg), JSONFieldNames(true))
if err == nil {
t.Error("NewEnv() created a CEL environment successfully despite incompatible configs")
}
}

// TODO: ideally testCostEstimator and testRuntimeCostEstimator would be shared in a test fixtures package
type testCostEstimator struct {
hints map[string]uint64
Expand Down
39 changes: 27 additions & 12 deletions cel/env.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,16 @@ func (e *Env) ToConfig(name string) (*env.Config, error) {
conf.AddImports(env.NewImport(typeName))
}

// Serialize features
for featID, enabled := range e.features {
featName, found := featureNameByID(featID)
if !found {
// If the feature isn't named, it isn't intended to be publicly exposed
continue
}
conf.AddFeatures(env.NewFeature(featName, enabled))
}

libOverloads := map[string][]string{}
for libName, lib := range e.libraries {
// Track the options which have been configured by a library and
Expand Down Expand Up @@ -244,7 +254,7 @@ func (e *Env) ToConfig(name string) (*env.Config, error) {
fields := e.contextProto.Fields()
for i := 0; i < fields.Len(); i++ {
field := fields.Get(i)
variable, err := fieldToVariable(field)
variable, err := fieldToVariable(field, e.HasFeature(featureJSONFieldNames))
if err != nil {
return nil, fmt.Errorf("could not serialize context field variable %q, reason: %w", field.FullName(), err)
}
Expand Down Expand Up @@ -279,16 +289,6 @@ func (e *Env) ToConfig(name string) (*env.Config, error) {
}
}

// Serialize features
for featID, enabled := range e.features {
featName, found := featureNameByID(featID)
if !found {
// If the feature isn't named, it isn't intended to be publicly exposed
continue
}
conf.AddFeatures(env.NewFeature(featName, enabled))
}

for id, val := range e.limits {
limitName, found := limitNameByID(id)
if !found || val == 0 {
Expand Down Expand Up @@ -361,7 +361,7 @@ func NewEnv(opts ...EnvOption) (*Env, error) {
// See the EnvOption helper functions for the options that can be used to configure the
// environment.
func NewCustomEnv(opts ...EnvOption) (*Env, error) {
registry, err := types.NewRegistry()
registry, err := types.NewProtoRegistry()
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -554,6 +554,7 @@ func (e *Env) Extend(opts ...EnvOption) (*Env, error) {
}
validatorsCopy := make([]ASTValidator, len(e.validators))
copy(validatorsCopy, e.validators)

costOptsCopy := make([]checker.CostOption, len(e.costOptions))
copy(costOptsCopy, e.costOptions)

Expand Down Expand Up @@ -847,6 +848,18 @@ func (e *Env) configure(opts []EnvOption) (*Env, error) {
return nil, err
}

// Enable JSON field names is using a proto-based *types.Registry
if e.HasFeature(featureJSONFieldNames) {
reg, isReg := e.provider.(*types.Registry)
if !isReg {
return nil, fmt.Errorf("JSONFieldNames() option is only compatible with *types.Registry providers")
}
err := reg.WithJSONFieldNames(true)
if err != nil {
return nil, err
}
}

// Ensure that the checker init happens eagerly rather than lazily.
if e.HasFeature(featureEagerlyValidateDeclarations) {
_, err := e.initChecker()
Expand All @@ -865,6 +878,8 @@ func (e *Env) initChecker() (*checker.Env, error) {
chkOpts = append(chkOpts,
checker.CrossTypeNumericComparisons(
e.HasFeature(featureCrossTypeNumericComparisons)))
chkOpts = append(chkOpts,
checker.JSONFieldNames(e.HasFeature(featureJSONFieldNames)))

ce, err := checker.NewEnv(e.Container, e.provider, chkOpts...)
if err != nil {
Expand Down
34 changes: 29 additions & 5 deletions cel/env_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -198,9 +198,9 @@ func TestEnvPartialVarsError(t *testing.T) {
}

func TestTypeProviderInterop(t *testing.T) {
reg, err := types.NewRegistry(&proto3pb.TestAllTypes{})
reg, err := types.NewProtoRegistry(types.ProtoTypeDefs(&proto3pb.TestAllTypes{}))
if err != nil {
t.Fatalf("types.NewRegistry() failed: %v", err)
t.Fatalf("types.NewProtoRegistry() failed: %v", err)
}
tests := []struct {
name string
Expand Down Expand Up @@ -399,6 +399,14 @@ func TestEnvToConfig(t *testing.T) {
env.NewMemberOverload("string_last", env.NewTypeDesc("string"), []*env.TypeDesc{}, env.NewTypeDesc("string")),
)),
},
{
name: "json field names",
opts: []EnvOption{
JSONFieldNames(true),
},
want: env.NewConfig("json field names").
AddFeatures(env.NewFeature("cel.feature.json_field_names", true)),
},
{
name: "context proto - with extra variable",
opts: []EnvOption{
Expand Down Expand Up @@ -495,7 +503,7 @@ func TestEnvFromConfig(t *testing.T) {
{
name: "std env - imports",
beforeOpts: []EnvOption{Types(&proto3pb.TestAllTypes{})},
conf: env.NewConfig("std env - context proto").
conf: env.NewConfig("std env - imports").
AddImports(env.NewImport("google.expr.proto3.test.TestAllTypes")),
exprs: []exprCase{
{
Expand All @@ -520,6 +528,22 @@ func TestEnvFromConfig(t *testing.T) {
},
},
},
{
name: "std env - context proto w/ json field names",
beforeOpts: []EnvOption{Types(&proto3pb.TestAllTypes{})},
conf: env.NewConfig("std env - context proto w/ json field names").
SetContainer("google.expr.proto3.test").
SetContextVariable(env.NewContextVariable("google.expr.proto3.test.TestAllTypes")).
AddFeatures(env.NewFeature("cel.feature.json_field_names", true)),
exprs: []exprCase{
{
name: "field select literal",
in: mustContextProto(t, &proto3pb.TestAllTypes{SingleInt64: 10}, types.JSONFieldNames(true)),
expr: "TestAllTypes{singleInt64: singleInt64}.singleInt64",
out: types.Int(10),
},
},
},
{
name: "custom env - variables",
beforeOpts: []EnvOption{Types(&proto3pb.TestAllTypes{})},
Expand Down Expand Up @@ -1154,9 +1178,9 @@ func BenchmarkEnvExtendEagerDecls(b *testing.B) {
}
}

func mustContextProto(t *testing.T, pb proto.Message) Activation {
func mustContextProto(t *testing.T, pb proto.Message, opts ...types.RegistryOption) Activation {
t.Helper()
ctx, err := ContextProtoVars(pb)
ctx, err := ContextProtoVars(pb, opts...)
if err != nil {
t.Fatalf("ContextProtoVars() failed: %v", err)
}
Expand Down
Loading