@@ -24,6 +24,8 @@ import (
24
24
"testing"
25
25
26
26
"github.com/google/cel-go/common"
27
+ "github.com/google/cel-go/common/ast"
28
+ "github.com/google/cel-go/common/decls"
27
29
"github.com/google/cel-go/common/env"
28
30
"github.com/google/cel-go/common/operators"
29
31
"github.com/google/cel-go/common/types"
@@ -401,6 +403,30 @@ func TestEnvToConfig(t *testing.T) {
401
403
},
402
404
want : env .NewConfig ("context proto" ).SetContextVariable (env .NewContextVariable ("google.expr.proto3.test.TestAllTypes" )),
403
405
},
406
+ {
407
+ name : "feature flags" ,
408
+ opts : []EnvOption {
409
+ DefaultUTCTimeZone (false ),
410
+ EnableMacroCallTracking (),
411
+ },
412
+ want : env .NewConfig ("feature flags" ).AddFeatures (
413
+ env .NewFeature ("cel.feature.macro_call_tracking" , true ),
414
+ ),
415
+ },
416
+ {
417
+ name : "validators" ,
418
+ opts : []EnvOption {
419
+ ExtendedValidations (),
420
+ ASTValidators (ValidateComprehensionNestingLimit (1 )),
421
+ },
422
+ want : env .NewConfig ("validators" ).AddValidators (
423
+ env .NewValidator ("cel.validator.duration" ),
424
+ env .NewValidator ("cel.validator.timestamp" ),
425
+ env .NewValidator ("cel.validator.matches" ),
426
+ env .NewValidator ("cel.validator.homogeneous_literals" ),
427
+ env .NewValidator ("cel.validator.comprehension_nesting_limit" ).SetConfig (map [string ]any {"limit" : 1 }),
428
+ ),
429
+ },
404
430
}
405
431
406
432
for _ , tst := range tests {
@@ -430,11 +456,12 @@ func TestEnvFromConfig(t *testing.T) {
430
456
out ref.Val
431
457
}
432
458
tests := []struct {
433
- name string
434
- beforeOpts []EnvOption
435
- afterOpts []EnvOption
436
- conf * env.Config
437
- exprs []exprCase
459
+ name string
460
+ beforeOpts []EnvOption
461
+ afterOpts []EnvOption
462
+ conf * env.Config
463
+ confHandlers []ConfigOptionFactory
464
+ exprs []exprCase
438
465
}{
439
466
{
440
467
name : "std env" ,
@@ -617,18 +644,138 @@ func TestEnvFromConfig(t *testing.T) {
617
644
},
618
645
},
619
646
},
647
+ {
648
+ name : "extensions - config factory" ,
649
+ conf : env .NewConfig ("extensions" ).
650
+ AddExtensions (env .NewExtension ("plus" , math .MaxUint32 )),
651
+ confHandlers : []ConfigOptionFactory {
652
+ func (a any ) (EnvOption , bool ) {
653
+ ext , ok := a .(* env.Extension )
654
+ if ! ok || ext .Name != "plus" {
655
+ return nil , false
656
+ }
657
+ return Function ("plus" , Overload ("plus_int_int" , []* Type {IntType , IntType }, IntType ,
658
+ decls .BinaryBinding (func (lhs , rhs ref.Val ) ref.Val {
659
+ l := lhs .(types.Int )
660
+ r := rhs .(types.Int )
661
+ return l + r
662
+ }))), true
663
+ },
664
+ },
665
+ exprs : []exprCase {
666
+ {
667
+ name : "plus" ,
668
+ expr : "plus(1, 2)" ,
669
+ out : types .Int (3 ),
670
+ },
671
+ },
672
+ },
673
+ {
674
+ name : "features" ,
675
+ conf : env .NewConfig ("features" ).
676
+ AddVariables (
677
+ env .NewVariable ("m" ,
678
+ env .NewTypeDesc ("map" , env .NewTypeDesc ("string" ), env .NewTypeDesc ("string" )))).
679
+ AddFeatures (
680
+ env .NewFeature ("cel.feature.backtick_escape_syntax" , true ),
681
+ env .NewFeature ("cel.feature.unknown_feature_name" , true )),
682
+ exprs : []exprCase {
683
+ {
684
+ name : "optional key" ,
685
+ expr : "m.`key-name` == 'value'" ,
686
+ in : map [string ]any {"m" : map [string ]string {"key-name" : "value" }},
687
+ out : types .True ,
688
+ },
689
+ },
690
+ },
691
+ {
692
+ name : "validators" ,
693
+ conf : env .NewConfig ("validators" ).
694
+ AddVariables (
695
+ env .NewVariable ("m" ,
696
+ env .NewTypeDesc ("map" , env .NewTypeDesc ("string" ), env .NewTypeDesc ("string" ))),
697
+ ).
698
+ AddValidators (
699
+ env .NewValidator (durationValidatorName ),
700
+ env .NewValidator (timestampValidatorName ),
701
+ env .NewValidator (regexValidatorName ),
702
+ env .NewValidator (homogeneousValidatorName ),
703
+ env .NewValidator (nestingLimitValidatorName ).SetConfig (map [string ]any {"limit" : 0 }),
704
+ ),
705
+ exprs : []exprCase {
706
+ {
707
+ name : "bad duration" ,
708
+ expr : "duration('1')" ,
709
+ iss : errors .New ("invalid duration" ),
710
+ },
711
+ {
712
+ name : "bad timestamp" ,
713
+ expr : "timestamp('1')" ,
714
+ iss : errors .New ("invalid timestamp" ),
715
+ },
716
+ {
717
+ name : "bad regex" ,
718
+ expr : "'hello'.matches('?^()')" ,
719
+ iss : errors .New ("invalid matches" ),
720
+ },
721
+ {
722
+ name : "mixed type list" ,
723
+ expr : "[1, 2.0]" ,
724
+ iss : errors .New ("expected type 'int'" ),
725
+ },
726
+ {
727
+ name : "disabled comprehension" ,
728
+ expr : "[1, 2].exists(x, x % 2 == 0)" ,
729
+ iss : errors .New ("comprehension exceeds nesting limit" ),
730
+ },
731
+ },
732
+ },
733
+ {
734
+ name : "validators - config factory" ,
735
+ conf : env .NewConfig ("validators" ).
736
+ AddValidators (
737
+ env .NewValidator ("cel.validators.return_type" ).SetConfig (map [string ]any {"type_name" : "string" }),
738
+ ),
739
+ confHandlers : []ConfigOptionFactory {
740
+ func (a any ) (EnvOption , bool ) {
741
+ val , ok := a .(* env.Validator )
742
+ if ! ok || val .Name != "cel.validators.return_type" {
743
+ return nil , false
744
+ }
745
+ typeName , found := val .ConfigValue ("type_name" )
746
+ if ! found {
747
+ return func (* Env ) (* Env , error ) {
748
+ return nil , fmt .Errorf ("invalid validator: %s missing config parameter 'type_name'" , val .Name )
749
+ }, true
750
+ }
751
+ return func (e * Env ) (* Env , error ) {
752
+ t , err := env .NewTypeDesc (typeName .(string )).AsCELType (e .CELTypeProvider ())
753
+ if err != nil {
754
+ return nil , err
755
+ }
756
+ return ASTValidators (returnTypeValidator {returnType : t })(e )
757
+ }, true
758
+ },
759
+ },
760
+ exprs : []exprCase {
761
+ {
762
+ name : "string - ok" ,
763
+ expr : "'hello'" ,
764
+ out : types .String ("hello" ),
765
+ },
766
+ {
767
+ name : "int - error" ,
768
+ expr : "1" ,
769
+ iss : errors .New ("unsupported return type: int, want string" ),
770
+ },
771
+ },
772
+ },
620
773
}
621
774
for _ , tst := range tests {
622
775
tc := tst
623
776
t .Run (tc .name , func (t * testing.T ) {
624
777
opts := tc .beforeOpts
625
- opts = append (opts , FromConfig (tc .conf , func (elem any ) (EnvOption , bool ) {
626
- if ext , ok := elem .(* env.Extension ); ok && ext .Name == "optional" {
627
- ver , _ := ext .GetVersion ()
628
- return OptionalTypes (OptionalTypesVersion (ver )), true
629
- }
630
- return nil , false
631
- }))
778
+ opts = append (opts , FromConfig (tc .conf , tc .confHandlers ... ))
632
779
opts = append (opts , tc .afterOpts ... )
633
780
var e * Env
634
781
var err error
@@ -679,6 +826,16 @@ func TestEnvFromConfigErrors(t *testing.T) {
679
826
conf * env.Config
680
827
want error
681
828
}{
829
+ {
830
+ name : "bad container" ,
831
+ conf : env .NewConfig ("bad container" ).SetContainer (".hello.world" ),
832
+ want : errors .New ("container name must not contain" ),
833
+ },
834
+ {
835
+ name : "colliding imports" ,
836
+ conf : env .NewConfig ("colliding imports" ).AddImports (env .NewImport ("pkg.ImportName" ), env .NewImport ("pkg2.ImportName" )),
837
+ want : errors .New ("abbreviation collides" ),
838
+ },
682
839
{
683
840
name : "invalid subset" ,
684
841
conf : env .NewConfig ("invalid subset" ).SetStdLib (env .NewLibrarySubset ().SetDisableMacros (true )),
@@ -707,9 +864,21 @@ func TestEnvFromConfigErrors(t *testing.T) {
707
864
{
708
865
name : "unrecognized extension" ,
709
866
conf : env .NewConfig ("unrecognized extension" ).
710
- AddExtensions (env .NewExtension ("optional " , math .MaxUint32 )),
867
+ AddExtensions (env .NewExtension ("unrecognized " , math .MaxUint32 )),
711
868
want : errors .New ("unrecognized extension" ),
712
869
},
870
+ {
871
+ name : "invalid validator config" ,
872
+ conf : env .NewConfig ("invalid validator config" ).
873
+ AddValidators (env .NewValidator ("cel.validator.comprehension_nesting_limit" )),
874
+ want : errors .New ("invalid validator" ),
875
+ },
876
+ {
877
+ name : "invalid validator config type" ,
878
+ conf : env .NewConfig ("invalid validator config" ).
879
+ AddValidators (env .NewValidator ("cel.validator.comprehension_nesting_limit" ).SetConfig (map [string ]any {"limit" : 2.0 })),
880
+ want : errors .New ("invalid validator" ),
881
+ },
713
882
}
714
883
for _ , tst := range tests {
715
884
tc := tst
@@ -829,6 +998,26 @@ func mustContextProto(t *testing.T, pb proto.Message) Activation {
829
998
return ctx
830
999
}
831
1000
1001
+ type returnTypeValidator struct {
1002
+ returnType * Type
1003
+ }
1004
+
1005
+ func (returnTypeValidator ) Name () string {
1006
+ return "cel.validators.return_type"
1007
+ }
1008
+
1009
+ func (v returnTypeValidator ) Validate (_ * Env , c ValidatorConfig , a * ast.AST , iss * Issues ) {
1010
+ if a .GetType (a .Expr ().ID ()) != v .returnType {
1011
+ iss .ReportErrorAtID (a .Expr ().ID (),
1012
+ "unsupported return type: %s, want %s" ,
1013
+ a .GetType (a .Expr ().ID ()), v .returnType .TypeName ())
1014
+ }
1015
+ }
1016
+
1017
+ func (v returnTypeValidator ) ToConfig () * env.Validator {
1018
+ return env .NewValidator (v .Name ()).SetConfig (map [string ]any {"type_name" : v .returnType .TypeName ()})
1019
+ }
1020
+
832
1021
type customLegacyProvider struct {
833
1022
provider ref.TypeProvider
834
1023
}
0 commit comments