Skip to content

Commit 45c4980

Browse files
Support for feature flags and validators in env.Config (#1132)
* Support for feature flags and validators in env.Config * Minor update to documentation
1 parent 4b27149 commit 45c4980

File tree

10 files changed

+615
-42
lines changed

10 files changed

+615
-42
lines changed

cel/env.go

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -263,6 +263,25 @@ func (e *Env) ToConfig(name string) (*env.Config, error) {
263263
}
264264
}
265265

266+
// Serialize validators
267+
for _, val := range e.Validators() {
268+
// Only add configurable validators to the env.Config as all others are
269+
// expected to be implicitly enabled via extension libraries.
270+
if confVal, ok := val.(ConfigurableASTValidator); ok {
271+
conf.AddValidators(confVal.ToConfig())
272+
}
273+
}
274+
275+
// Serialize features
276+
for featID, enabled := range e.features {
277+
featName, found := featureNameByID(featID)
278+
if !found {
279+
// If the feature isn't named, it isn't intended to be publicly exposed
280+
continue
281+
}
282+
conf.AddFeatures(env.NewFeature(featName, enabled))
283+
}
284+
266285
return conf, nil
267286
}
268287

@@ -541,7 +560,7 @@ func (e *Env) Functions() map[string]*decls.FunctionDecl {
541560

542561
// Variables returns the set of variables associated with the environment.
543562
func (e *Env) Variables() []*decls.VariableDecl {
544-
return e.variables
563+
return e.variables[:]
545564
}
546565

547566
// HasValidator returns whether a specific ASTValidator has been configured in the environment.
@@ -554,6 +573,11 @@ func (e *Env) HasValidator(name string) bool {
554573
return false
555574
}
556575

576+
// Validators returns the set of ASTValidators configured on the environment.
577+
func (e *Env) Validators() []ASTValidator {
578+
return e.validators[:]
579+
}
580+
557581
// Parse parses the input expression value `txt` to a Ast and/or a set of Issues.
558582
//
559583
// This form of Parse creates a Source value for the input `txt` and forwards to the

cel/env_test.go

Lines changed: 202 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,8 @@ import (
2424
"testing"
2525

2626
"github.com/google/cel-go/common"
27+
"github.com/google/cel-go/common/ast"
28+
"github.com/google/cel-go/common/decls"
2729
"github.com/google/cel-go/common/env"
2830
"github.com/google/cel-go/common/operators"
2931
"github.com/google/cel-go/common/types"
@@ -401,6 +403,30 @@ func TestEnvToConfig(t *testing.T) {
401403
},
402404
want: env.NewConfig("context proto").SetContextVariable(env.NewContextVariable("google.expr.proto3.test.TestAllTypes")),
403405
},
406+
{
407+
name: "feature flags",
408+
opts: []EnvOption{
409+
DefaultUTCTimeZone(false),
410+
EnableMacroCallTracking(),
411+
},
412+
want: env.NewConfig("feature flags").AddFeatures(
413+
env.NewFeature("cel.feature.macro_call_tracking", true),
414+
),
415+
},
416+
{
417+
name: "validators",
418+
opts: []EnvOption{
419+
ExtendedValidations(),
420+
ASTValidators(ValidateComprehensionNestingLimit(1)),
421+
},
422+
want: env.NewConfig("validators").AddValidators(
423+
env.NewValidator("cel.validator.duration"),
424+
env.NewValidator("cel.validator.timestamp"),
425+
env.NewValidator("cel.validator.matches"),
426+
env.NewValidator("cel.validator.homogeneous_literals"),
427+
env.NewValidator("cel.validator.comprehension_nesting_limit").SetConfig(map[string]any{"limit": 1}),
428+
),
429+
},
404430
}
405431

406432
for _, tst := range tests {
@@ -430,11 +456,12 @@ func TestEnvFromConfig(t *testing.T) {
430456
out ref.Val
431457
}
432458
tests := []struct {
433-
name string
434-
beforeOpts []EnvOption
435-
afterOpts []EnvOption
436-
conf *env.Config
437-
exprs []exprCase
459+
name string
460+
beforeOpts []EnvOption
461+
afterOpts []EnvOption
462+
conf *env.Config
463+
confHandlers []ConfigOptionFactory
464+
exprs []exprCase
438465
}{
439466
{
440467
name: "std env",
@@ -617,18 +644,138 @@ func TestEnvFromConfig(t *testing.T) {
617644
},
618645
},
619646
},
647+
{
648+
name: "extensions - config factory",
649+
conf: env.NewConfig("extensions").
650+
AddExtensions(env.NewExtension("plus", math.MaxUint32)),
651+
confHandlers: []ConfigOptionFactory{
652+
func(a any) (EnvOption, bool) {
653+
ext, ok := a.(*env.Extension)
654+
if !ok || ext.Name != "plus" {
655+
return nil, false
656+
}
657+
return Function("plus", Overload("plus_int_int", []*Type{IntType, IntType}, IntType,
658+
decls.BinaryBinding(func(lhs, rhs ref.Val) ref.Val {
659+
l := lhs.(types.Int)
660+
r := rhs.(types.Int)
661+
return l + r
662+
}))), true
663+
},
664+
},
665+
exprs: []exprCase{
666+
{
667+
name: "plus",
668+
expr: "plus(1, 2)",
669+
out: types.Int(3),
670+
},
671+
},
672+
},
673+
{
674+
name: "features",
675+
conf: env.NewConfig("features").
676+
AddVariables(
677+
env.NewVariable("m",
678+
env.NewTypeDesc("map", env.NewTypeDesc("string"), env.NewTypeDesc("string")))).
679+
AddFeatures(
680+
env.NewFeature("cel.feature.backtick_escape_syntax", true),
681+
env.NewFeature("cel.feature.unknown_feature_name", true)),
682+
exprs: []exprCase{
683+
{
684+
name: "optional key",
685+
expr: "m.`key-name` == 'value'",
686+
in: map[string]any{"m": map[string]string{"key-name": "value"}},
687+
out: types.True,
688+
},
689+
},
690+
},
691+
{
692+
name: "validators",
693+
conf: env.NewConfig("validators").
694+
AddVariables(
695+
env.NewVariable("m",
696+
env.NewTypeDesc("map", env.NewTypeDesc("string"), env.NewTypeDesc("string"))),
697+
).
698+
AddValidators(
699+
env.NewValidator(durationValidatorName),
700+
env.NewValidator(timestampValidatorName),
701+
env.NewValidator(regexValidatorName),
702+
env.NewValidator(homogeneousValidatorName),
703+
env.NewValidator(nestingLimitValidatorName).SetConfig(map[string]any{"limit": 0}),
704+
),
705+
exprs: []exprCase{
706+
{
707+
name: "bad duration",
708+
expr: "duration('1')",
709+
iss: errors.New("invalid duration"),
710+
},
711+
{
712+
name: "bad timestamp",
713+
expr: "timestamp('1')",
714+
iss: errors.New("invalid timestamp"),
715+
},
716+
{
717+
name: "bad regex",
718+
expr: "'hello'.matches('?^()')",
719+
iss: errors.New("invalid matches"),
720+
},
721+
{
722+
name: "mixed type list",
723+
expr: "[1, 2.0]",
724+
iss: errors.New("expected type 'int'"),
725+
},
726+
{
727+
name: "disabled comprehension",
728+
expr: "[1, 2].exists(x, x % 2 == 0)",
729+
iss: errors.New("comprehension exceeds nesting limit"),
730+
},
731+
},
732+
},
733+
{
734+
name: "validators - config factory",
735+
conf: env.NewConfig("validators").
736+
AddValidators(
737+
env.NewValidator("cel.validators.return_type").SetConfig(map[string]any{"type_name": "string"}),
738+
),
739+
confHandlers: []ConfigOptionFactory{
740+
func(a any) (EnvOption, bool) {
741+
val, ok := a.(*env.Validator)
742+
if !ok || val.Name != "cel.validators.return_type" {
743+
return nil, false
744+
}
745+
typeName, found := val.ConfigValue("type_name")
746+
if !found {
747+
return func(*Env) (*Env, error) {
748+
return nil, fmt.Errorf("invalid validator: %s missing config parameter 'type_name'", val.Name)
749+
}, true
750+
}
751+
return func(e *Env) (*Env, error) {
752+
t, err := env.NewTypeDesc(typeName.(string)).AsCELType(e.CELTypeProvider())
753+
if err != nil {
754+
return nil, err
755+
}
756+
return ASTValidators(returnTypeValidator{returnType: t})(e)
757+
}, true
758+
},
759+
},
760+
exprs: []exprCase{
761+
{
762+
name: "string - ok",
763+
expr: "'hello'",
764+
out: types.String("hello"),
765+
},
766+
{
767+
name: "int - error",
768+
expr: "1",
769+
iss: errors.New("unsupported return type: int, want string"),
770+
},
771+
},
772+
},
620773
}
621774
for _, tst := range tests {
622775
tc := tst
623776
t.Run(tc.name, func(t *testing.T) {
624777
opts := tc.beforeOpts
625-
opts = append(opts, FromConfig(tc.conf, func(elem any) (EnvOption, bool) {
626-
if ext, ok := elem.(*env.Extension); ok && ext.Name == "optional" {
627-
ver, _ := ext.GetVersion()
628-
return OptionalTypes(OptionalTypesVersion(ver)), true
629-
}
630-
return nil, false
631-
}))
778+
opts = append(opts, FromConfig(tc.conf, tc.confHandlers...))
632779
opts = append(opts, tc.afterOpts...)
633780
var e *Env
634781
var err error
@@ -679,6 +826,16 @@ func TestEnvFromConfigErrors(t *testing.T) {
679826
conf *env.Config
680827
want error
681828
}{
829+
{
830+
name: "bad container",
831+
conf: env.NewConfig("bad container").SetContainer(".hello.world"),
832+
want: errors.New("container name must not contain"),
833+
},
834+
{
835+
name: "colliding imports",
836+
conf: env.NewConfig("colliding imports").AddImports(env.NewImport("pkg.ImportName"), env.NewImport("pkg2.ImportName")),
837+
want: errors.New("abbreviation collides"),
838+
},
682839
{
683840
name: "invalid subset",
684841
conf: env.NewConfig("invalid subset").SetStdLib(env.NewLibrarySubset().SetDisableMacros(true)),
@@ -707,9 +864,21 @@ func TestEnvFromConfigErrors(t *testing.T) {
707864
{
708865
name: "unrecognized extension",
709866
conf: env.NewConfig("unrecognized extension").
710-
AddExtensions(env.NewExtension("optional", math.MaxUint32)),
867+
AddExtensions(env.NewExtension("unrecognized", math.MaxUint32)),
711868
want: errors.New("unrecognized extension"),
712869
},
870+
{
871+
name: "invalid validator config",
872+
conf: env.NewConfig("invalid validator config").
873+
AddValidators(env.NewValidator("cel.validator.comprehension_nesting_limit")),
874+
want: errors.New("invalid validator"),
875+
},
876+
{
877+
name: "invalid validator config type",
878+
conf: env.NewConfig("invalid validator config").
879+
AddValidators(env.NewValidator("cel.validator.comprehension_nesting_limit").SetConfig(map[string]any{"limit": 2.0})),
880+
want: errors.New("invalid validator"),
881+
},
713882
}
714883
for _, tst := range tests {
715884
tc := tst
@@ -829,6 +998,26 @@ func mustContextProto(t *testing.T, pb proto.Message) Activation {
829998
return ctx
830999
}
8311000

1001+
type returnTypeValidator struct {
1002+
returnType *Type
1003+
}
1004+
1005+
func (returnTypeValidator) Name() string {
1006+
return "cel.validators.return_type"
1007+
}
1008+
1009+
func (v returnTypeValidator) Validate(_ *Env, c ValidatorConfig, a *ast.AST, iss *Issues) {
1010+
if a.GetType(a.Expr().ID()) != v.returnType {
1011+
iss.ReportErrorAtID(a.Expr().ID(),
1012+
"unsupported return type: %s, want %s",
1013+
a.GetType(a.Expr().ID()), v.returnType.TypeName())
1014+
}
1015+
}
1016+
1017+
func (v returnTypeValidator) ToConfig() *env.Validator {
1018+
return env.NewValidator(v.Name()).SetConfig(map[string]any{"type_name": v.returnType.TypeName()})
1019+
}
1020+
8321021
type customLegacyProvider struct {
8331022
provider ref.TypeProvider
8341023
}

0 commit comments

Comments
 (0)