Skip to content

Commit 2f7606a

Browse files
Cost tracking for two-variable comprehensions and bindings (#1104)
* Updates to the cost estimators to support bind and two-var comprehensions * Consolidation of local variables
1 parent 7621362 commit 2f7606a

File tree

8 files changed

+942
-323
lines changed

8 files changed

+942
-323
lines changed

checker/cost.go

+490-159
Large diffs are not rendered by default.

checker/cost_test.go

+76-15
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ func TestCost(t *testing.T) {
4040
nestedMap := types.NewMapType(types.StringType, allMap)
4141

4242
zeroCost := CostEstimate{}
43-
oneCost := CostEstimate{Min: 1, Max: 1}
43+
oneCost := FixedCostEstimate(1)
4444
cases := []struct {
4545
name string
4646
expr string
@@ -255,6 +255,11 @@ func TestCost(t *testing.T) {
255255
expr: `size("123")`,
256256
wanted: oneCost,
257257
},
258+
{
259+
name: "bytes size",
260+
expr: `size(b"123")`,
261+
wanted: oneCost,
262+
},
258263
{
259264
name: "bytes to string conversion",
260265
vars: []*decls.VariableDecl{decls.NewVariable("input", types.BytesType)},
@@ -462,6 +467,36 @@ func TestCost(t *testing.T) {
462467
},
463468
wanted: CostEstimate{Min: 5, Max: 5},
464469
},
470+
{
471+
name: "list size from concat",
472+
expr: `([x, y] + list1 + list2).size()`,
473+
vars: []*decls.VariableDecl{
474+
decls.NewVariable("x", types.IntType),
475+
decls.NewVariable("y", types.IntType),
476+
decls.NewVariable("list1", types.NewListType(types.IntType)),
477+
decls.NewVariable("list2", types.NewListType(types.IntType)),
478+
},
479+
hints: map[string]uint64{
480+
"list1": 10,
481+
"list2": 20,
482+
},
483+
wanted: CostEstimate{Min: 17, Max: 17},
484+
},
485+
{
486+
name: "list cost tracking through comprehension",
487+
expr: `[list1, list2].exists(l, l.exists(v, v.startsWith('hi')))`,
488+
vars: []*decls.VariableDecl{
489+
decls.NewVariable("list1", types.NewListType(types.StringType)),
490+
decls.NewVariable("list2", types.NewListType(types.StringType)),
491+
},
492+
hints: map[string]uint64{
493+
"list1": 10,
494+
"list1.@items": 64,
495+
"list2": 20,
496+
"list2.@items": 128,
497+
},
498+
wanted: CostEstimate{Min: 21, Max: 265},
499+
},
465500
{
466501
name: "str endsWith equality",
467502
expr: `str1.endsWith("abcdefghijklmnopqrstuvwxyz") == str2.endsWith("abcdefghijklmnopqrstuvwxyz")`,
@@ -539,27 +574,37 @@ func TestCost(t *testing.T) {
539574
wanted: CostEstimate{Min: 61, Max: 61},
540575
},
541576
{
542-
name: "nested array selection",
577+
name: "nested map selection",
543578
expr: `{'a': [1,2], 'b': [1,2], 'c': [1,2], 'd': [1,2], 'e': [1,2]}.b`,
544579
wanted: CostEstimate{Min: 81, Max: 81},
545580
},
546581
{
547-
// Estimated cost does not track the sizes of nested aggregate types
548-
// (lists, maps, ...) and so assumes a worst case cost when an
549-
// expression applies a comprehension to a nested aggregated type,
550-
// even if the size information is available.
551-
// TODO: This should be fixed.
552582
name: "comprehension on nested list",
583+
expr: `[[1, 1], [2, 2], [3, 3], [4, 4], [5, 5]].all(y, y.all(y, y == 1))`,
584+
wanted: CostEstimate{Min: 76, Max: 136},
585+
},
586+
{
587+
name: "comprehension on transformed nested list",
553588
expr: `[1,2,3,4,5].map(x, [x, x]).all(y, y.all(y, y == 1))`,
554-
wanted: CostEstimate{Min: 157, Max: 18446744073709551615},
589+
wanted: CostEstimate{Min: 157, Max: 217},
555590
},
556591
{
557-
// Make sure we're accounting for not just the iteration range size,
558-
// but also the overall comprehension size. The chained map calls
559-
// will treat the result of one map as the iteration range of the other,
560-
// so they're planned in reverse; however, the `+` should verify that
561-
// the comprehension result has a size.
562-
name: "comprehension size",
592+
name: "comprehension on nested literal list",
593+
expr: `["a", "ab", "abc", "abcd", "abcde"].map(x, [x, x]).all(y, y.all(y, y.startsWith('a')))`,
594+
wanted: CostEstimate{Min: 157, Max: 217},
595+
},
596+
{
597+
name: "comprehension on nested variable list",
598+
expr: `input.map(x, [x, x]).all(y, y.all(y, y.startsWith('a')))`,
599+
vars: []*decls.VariableDecl{decls.NewVariable("input", types.NewListType(types.StringType))},
600+
hints: map[string]uint64{
601+
"input": 5,
602+
"input.@items": 10,
603+
},
604+
wanted: CostEstimate{Min: 13, Max: 208},
605+
},
606+
{
607+
name: "comprehension chaining with concat",
563608
expr: `[1,2,3,4,5].map(x, x).map(x, x) + [1]`,
564609
wanted: CostEstimate{Min: 173, Max: 173},
565610
},
@@ -568,9 +613,25 @@ func TestCost(t *testing.T) {
568613
expr: `[1,2,3].all(i, i in [1,2,3].map(j, j + j))`,
569614
wanted: CostEstimate{Min: 20, Max: 230},
570615
},
616+
{
617+
name: "nested dyn comprehension",
618+
expr: `dyn([1,2,3]).all(i, i in dyn([1,2,3]).map(j, j + j))`,
619+
wanted: CostEstimate{Min: 21, Max: 234},
620+
},
621+
{
622+
name: "literal map access",
623+
expr: `{'hello': 'hi'}['hello'] != {'hello': 'bye'}['hello']`,
624+
wanted: CostEstimate{Min: 63, Max: 63},
625+
},
626+
{
627+
name: "literal list access",
628+
expr: `['hello', 'hi'][0] != ['hello', 'bye'][1]`,
629+
wanted: CostEstimate{Min: 23, Max: 23},
630+
},
571631
}
572632

573-
for _, tc := range cases {
633+
for _, tst := range cases {
634+
tc := tst
574635
t.Run(tc.name, func(t *testing.T) {
575636
if tc.hints == nil {
576637
tc.hints = map[string]uint64{}

ext/bindings_test.go

+75-30
Original file line numberDiff line numberDiff line change
@@ -20,56 +20,101 @@ import (
2020
"testing"
2121

2222
"github.com/google/cel-go/cel"
23+
"github.com/google/cel-go/checker"
2324
"github.com/google/cel-go/common/ast"
2425
"github.com/google/cel-go/common/operators"
2526
"github.com/google/cel-go/common/types"
2627
"github.com/google/cel-go/common/types/ref"
2728
)
2829

2930
var bindingTests = []struct {
30-
expr string
31-
parseOnly bool
31+
name string
32+
expr string
33+
vars []cel.EnvOption
34+
in map[string]any
35+
hints map[string]uint64
36+
estimatedCost checker.CostEstimate
37+
actualCost uint64
3238
}{
33-
{expr: `cel.bind(a, 'hell' + 'o' + '!', [a, a, a].join(', ')) ==
34-
['hell' + 'o' + '!', 'hell' + 'o' + '!', 'hell' + 'o' + '!'].join(', ')`},
35-
// Variable shadowing
36-
{expr: `cel.bind(a,
37-
cel.bind(a, 'world', a + '!'),
38-
'hello ' + a) == 'hello ' + 'world' + '!'`},
39+
{
40+
name: "single bind",
41+
expr: `cel.bind(a, 'hell' + 'o' + '!', "%s, %s, %s".format([a, a, a])) ==
42+
'hello!, hello!, hello' + '!'`,
43+
estimatedCost: checker.CostEstimate{Min: 30, Max: 32},
44+
actualCost: 32,
45+
},
46+
{
47+
name: "multiple binds",
48+
expr: `cel.bind(a, 'hello!',
49+
cel.bind(b, 'goodbye',
50+
a + ' and, ' + b)) == 'hello! and, goodbye'`,
51+
estimatedCost: checker.CostEstimate{Min: 27, Max: 28},
52+
actualCost: 28,
53+
},
54+
{
55+
name: "shadow binds",
56+
expr: `cel.bind(a,
57+
cel.bind(a, 'world', a + '!'),
58+
'hello ' + a) == 'hello ' + 'world' + '!'`,
59+
estimatedCost: checker.CostEstimate{Min: 30, Max: 31},
60+
actualCost: 31,
61+
},
62+
{
63+
name: "nested bind with int list",
64+
expr: `cel.bind(a, x,
65+
cel.bind(b, a[0],
66+
cel.bind(c, a[1], b + c))) == 10`,
67+
vars: []cel.EnvOption{cel.Variable("x", cel.ListType(cel.IntType))},
68+
in: map[string]any{
69+
"x": []int64{3, 7},
70+
},
71+
hints: map[string]uint64{
72+
"x": 3,
73+
},
74+
estimatedCost: checker.CostEstimate{Min: 39, Max: 39},
75+
actualCost: 39,
76+
},
77+
{
78+
name: "nested bind with string list",
79+
expr: `cel.bind(a, x,
80+
cel.bind(b, a[0],
81+
cel.bind(c, a[1], b + c))) == "threeseven"`,
82+
vars: []cel.EnvOption{cel.Variable("x", cel.ListType(cel.StringType))},
83+
in: map[string]any{
84+
"x": []string{"three", "seven"},
85+
},
86+
hints: map[string]uint64{
87+
"x": 3,
88+
"x.@items": 10,
89+
},
90+
estimatedCost: checker.CostEstimate{Min: 38, Max: 40},
91+
actualCost: 39,
92+
},
3993
}
4094

4195
func TestBindings(t *testing.T) {
42-
env, err := cel.NewEnv(Bindings(BindingsVersion(0)), Strings())
43-
if err != nil {
44-
t.Fatalf("cel.NewEnv(Bindings(), Strings()) failed: %v", err)
45-
}
46-
for i, tst := range bindingTests {
96+
for _, tst := range bindingTests {
4797
tc := tst
48-
t.Run(fmt.Sprintf("[%d]", i), func(t *testing.T) {
98+
t.Run(tc.name, func(t *testing.T) {
4999
var asts []*cel.Ast
100+
opts := append([]cel.EnvOption{Bindings(BindingsVersion(0)), Strings()}, tc.vars...)
101+
env, err := cel.NewEnv(opts...)
102+
if err != nil {
103+
t.Fatalf("cel.NewEnv(Bindings(), Strings()) failed: %v", err)
104+
}
50105
pAst, iss := env.Parse(tc.expr)
51106
if iss.Err() != nil {
52107
t.Fatalf("env.Parse(%v) failed: %v", tc.expr, iss.Err())
53108
}
54109
asts = append(asts, pAst)
55-
if !tc.parseOnly {
56-
cAst, iss := env.Check(pAst)
57-
if iss.Err() != nil {
58-
t.Fatalf("env.Check(%v) failed: %v", tc.expr, iss.Err())
59-
}
60-
asts = append(asts, cAst)
110+
cAst, iss := env.Check(pAst)
111+
if iss.Err() != nil {
112+
t.Fatalf("env.Check(%v) failed: %v", tc.expr, iss.Err())
61113
}
114+
testCheckCost(t, env, cAst, tc.hints, tc.estimatedCost)
115+
asts = append(asts, cAst)
62116
for _, ast := range asts {
63-
prg, err := env.Program(ast)
64-
if err != nil {
65-
t.Fatal(err)
66-
}
67-
out, _, err := prg.Eval(cel.NoVars())
68-
if err != nil {
69-
t.Fatal(err)
70-
} else if out.Value() != true {
71-
t.Errorf("got %v, wanted true for expr: %s", out.Value(), tc.expr)
72-
}
117+
testEvalWithCost(t, env, ast, tc.in, tc.actualCost)
73118
}
74119
})
75120
}

0 commit comments

Comments
 (0)