Skip to content

Commit 9855c70

Browse files
Support for splitting nested branching operators within policies (#1136)
* Support for splitting nested branching operators within policies * Introduce an ast.Heights() helper * Updated tests and expanded flattening to all calls * Added test case for comprehension pruning during unnest
1 parent fad0c1b commit 9855c70

File tree

10 files changed

+644
-64
lines changed

10 files changed

+644
-64
lines changed

common/ast/ast.go

Lines changed: 78 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,13 @@ func MaxID(a *AST) int64 {
160160
return visitor.maxID + 1
161161
}
162162

163+
// Heights computes the heights of all AST expressions and returns a map from expression id to height.
164+
func Heights(a *AST) map[int64]int {
165+
visitor := make(heightVisitor)
166+
PostOrderVisit(a.Expr(), visitor)
167+
return visitor
168+
}
169+
163170
// NewSourceInfo creates a simple SourceInfo object from an input common.Source value.
164171
func NewSourceInfo(src common.Source) *SourceInfo {
165172
var lineOffsets []int32
@@ -455,3 +462,74 @@ func (v *maxIDVisitor) VisitEntryExpr(e EntryExpr) {
455462
v.maxID = e.ID()
456463
}
457464
}
465+
466+
type heightVisitor map[int64]int
467+
468+
// VisitExpr computes the height of a given node as the max height of its children plus one.
469+
//
470+
// Identifiers and literals are treated as having a height of zero.
471+
func (hv heightVisitor) VisitExpr(e Expr) {
472+
// default includes IdentKind, LiteralKind
473+
hv[e.ID()] = 0
474+
switch e.Kind() {
475+
case SelectKind:
476+
hv[e.ID()] = 1 + hv[e.AsSelect().Operand().ID()]
477+
case CallKind:
478+
c := e.AsCall()
479+
height := hv.maxHeight(c.Args()...)
480+
if c.IsMemberFunction() {
481+
tHeight := hv[c.Target().ID()]
482+
if tHeight > height {
483+
height = tHeight
484+
}
485+
}
486+
hv[e.ID()] = 1 + height
487+
case ListKind:
488+
l := e.AsList()
489+
hv[e.ID()] = 1 + hv.maxHeight(l.Elements()...)
490+
case MapKind:
491+
m := e.AsMap()
492+
hv[e.ID()] = 1 + hv.maxEntryHeight(m.Entries()...)
493+
case StructKind:
494+
s := e.AsStruct()
495+
hv[e.ID()] = 1 + hv.maxEntryHeight(s.Fields()...)
496+
case ComprehensionKind:
497+
comp := e.AsComprehension()
498+
hv[e.ID()] = 1 + hv.maxHeight(comp.IterRange(), comp.AccuInit(), comp.LoopCondition(), comp.LoopStep(), comp.Result())
499+
}
500+
}
501+
502+
// VisitEntryExpr computes the max height of a map or struct entry and associates the height with the entry id.
503+
func (hv heightVisitor) VisitEntryExpr(e EntryExpr) {
504+
hv[e.ID()] = 0
505+
switch e.Kind() {
506+
case MapEntryKind:
507+
me := e.AsMapEntry()
508+
hv[e.ID()] = hv.maxHeight(me.Value(), me.Key())
509+
case StructFieldKind:
510+
sf := e.AsStructField()
511+
hv[e.ID()] = hv[sf.Value().ID()]
512+
}
513+
}
514+
515+
func (hv heightVisitor) maxHeight(exprs ...Expr) int {
516+
max := 0
517+
for _, e := range exprs {
518+
h := hv[e.ID()]
519+
if h > max {
520+
max = h
521+
}
522+
}
523+
return max
524+
}
525+
526+
func (hv heightVisitor) maxEntryHeight(entries ...EntryExpr) int {
527+
max := 0
528+
for _, e := range entries {
529+
h := hv[e.ID()]
530+
if h > max {
531+
max = h
532+
}
533+
}
534+
return max
535+
}

common/ast/ast_test.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,31 @@ func TestMaxID(t *testing.T) {
339339
}
340340
}
341341

342+
func TestHeights(t *testing.T) {
343+
tests := []struct {
344+
expr string
345+
height int
346+
}{
347+
{`'a' == 'b'`, 1},
348+
{`'a'.size()`, 1},
349+
{`[1, 2].size()`, 2},
350+
{`size('a')`, 1},
351+
{`has({'a': 1}.a)`, 2},
352+
{`{'a': 1}`, 1},
353+
{`{'a': 1}['a']`, 2},
354+
{`[1, 2, 3].exists(i, i % 2 == 1)`, 4},
355+
{`google.expr.proto3.test.TestAllTypes{}`, 1},
356+
{`google.expr.proto3.test.TestAllTypes{repeated_int32: [1, 2]}`, 2},
357+
}
358+
for _, tst := range tests {
359+
checked := mustTypeCheck(t, tst.expr)
360+
maxHeight := ast.Heights(checked)[checked.Expr().ID()]
361+
if maxHeight != tst.height {
362+
t.Errorf("ast.Heights(%q) got max height %d, wanted %d", tst.expr, maxHeight, tst.height)
363+
}
364+
}
365+
}
366+
342367
func mockRelativeSource(t testing.TB, text string, lineOffsets []int32, baseLocation common.Location) common.Source {
343368
t.Helper()
344369
return &mockSource{

common/ast/navigable.go

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,8 +237,13 @@ func visit(expr Expr, visitor Visitor, order visitOrder, depth, maxDepth int) {
237237
case StructKind:
238238
s := expr.AsStruct()
239239
for _, f := range s.Fields() {
240-
visitor.VisitEntryExpr(f)
240+
if order == preOrder {
241+
visitor.VisitEntryExpr(f)
242+
}
241243
visit(f.AsStructField().Value(), visitor, order, depth+1, maxDepth)
244+
if order == postOrder {
245+
visitor.VisitEntryExpr(f)
246+
}
242247
}
243248
}
244249
if order == postOrder {

policy/compiler.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,7 +198,8 @@ func Compile(env *cel.Env, p *Policy, opts ...CompilerOption) (*cel.Ast, *cel.Is
198198
if iss.Err() != nil {
199199
return nil, iss
200200
}
201-
composer := NewRuleComposer(env, p)
201+
// An error cannot happen when composing without supplying options
202+
composer, _ := NewRuleComposer(env)
202203
return composer.Compose(rule)
203204
}
204205

policy/compiler_test.go

Lines changed: 125 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -31,16 +31,64 @@ import (
3131

3232
func TestCompile(t *testing.T) {
3333
for _, tst := range policyTests {
34-
t.Run(tst.name, func(t *testing.T) {
35-
r := newRunner(t, tst.name, tst.expr, tst.parseOpts, tst.envOpts...)
34+
tc := tst
35+
t.Run(tc.name, func(t *testing.T) {
36+
r := newRunner(tc.name, tc.expr, tc.parseOpts)
37+
env, ast, iss := r.compile(t, tc.envOpts, []CompilerOption{})
38+
if iss.Err() != nil {
39+
t.Fatalf("Compile(%s) failed: %v", r.name, iss.Err())
40+
}
41+
r.setup(t, env, ast)
42+
r.run(t)
43+
})
44+
}
45+
}
46+
47+
func TestRuleComposerError(t *testing.T) {
48+
env, err := cel.NewEnv()
49+
if err != nil {
50+
t.Fatalf("NewEnv() failed: %v", err)
51+
}
52+
_, err = NewRuleComposer(env, ExpressionUnnestHeight(-1))
53+
if err == nil || !strings.Contains(err.Error(), "invalid unnest") {
54+
t.Errorf("NewRuleComposer() got %v, wanted 'invalid unnest'", err)
55+
}
56+
}
57+
58+
func TestRuleComposerUnnest(t *testing.T) {
59+
for _, tst := range composerUnnestTests {
60+
tc := tst
61+
t.Run(tc.name, func(t *testing.T) {
62+
r := newRunner(tc.name, tc.expr, []ParserOption{})
63+
env, rule, iss := r.compileRule(t)
64+
if iss.Err() != nil {
65+
t.Fatalf("CompileRule() failed: %v", iss.Err())
66+
}
67+
rc, err := NewRuleComposer(env, tc.composerOpts...)
68+
if err != nil {
69+
t.Fatalf("NewRuleComposer() failed: %v", err)
70+
}
71+
ast, iss := rc.Compose(rule)
72+
if iss.Err() != nil {
73+
t.Fatalf("Compose(rule) failed: %v", iss.Err())
74+
}
75+
unparsed, err := cel.AstToString(ast)
76+
if err != nil {
77+
t.Fatalf("cel.AstToString() failed: %v", err)
78+
}
79+
if normalize(unparsed) != normalize(tc.composed) {
80+
t.Errorf("cel.AstToString() got %s, wanted %s", unparsed, tc.composed)
81+
}
82+
r.setup(t, env, ast)
3683
r.run(t)
3784
})
3885
}
3986
}
4087

4188
func TestCompileError(t *testing.T) {
4289
for _, tst := range policyErrorTests {
43-
_, _, iss := compile(t, tst.name, []ParserOption{}, []cel.EnvOption{}, tst.compilerOpts)
90+
policy := parsePolicy(t, tst.name, []ParserOption{})
91+
_, _, iss := compile(t, tst.name, policy, []cel.EnvOption{}, tst.compilerOpts)
4492
if iss.Err() == nil {
4593
t.Fatalf("compile(%s) did not error, wanted %s", tst.name, tst.err)
4694
}
@@ -98,7 +146,8 @@ func TestMaxNestedExpressions_Error(t *testing.T) {
98146
wantError := `ERROR: testdata/required_labels/policy.yaml:15:8: error configuring compiler option: nested expression limit must be non-negative, non-zero value: -1
99147
| name: "required_labels"
100148
| .......^`
101-
_, _, iss := compile(t, policyName, []ParserOption{}, []cel.EnvOption{}, []CompilerOption{MaxNestedExpressions(-1)})
149+
policy := parsePolicy(t, policyName, []ParserOption{})
150+
_, _, iss := compile(t, policyName, policy, []cel.EnvOption{}, []CompilerOption{MaxNestedExpressions(-1)})
102151
if iss.Err() == nil {
103152
t.Fatalf("compile(%s) did not error, wanted %s", policyName, wantError)
104153
}
@@ -109,55 +158,40 @@ func TestMaxNestedExpressions_Error(t *testing.T) {
109158

110159
func BenchmarkCompile(b *testing.B) {
111160
for _, tst := range policyTests {
112-
r := newRunner(b, tst.name, tst.expr, tst.parseOpts, tst.envOpts...)
161+
r := newRunner(tst.name, tst.expr, tst.parseOpts)
162+
env, ast, iss := r.compile(b, tst.envOpts, []CompilerOption{})
163+
if iss.Err() != nil {
164+
b.Fatalf("Compile() failed: %v", iss.Err())
165+
}
166+
r.setup(b, env, ast)
113167
r.bench(b)
114168
}
115169
}
116170

117-
func newRunner(t testing.TB, name, expr string, parseOpts []ParserOption, opts ...cel.EnvOption) *runner {
118-
r := &runner{
171+
func newRunner(name, expr string, parseOpts []ParserOption, opts ...cel.EnvOption) *runner {
172+
return &runner{
119173
name: name,
120-
envOpts: opts,
121174
parseOpts: parseOpts,
122175
expr: expr}
123-
r.setup(t)
124-
return r
125176
}
126177

127178
type runner struct {
128-
name string
129-
envOpts []cel.EnvOption
130-
parseOpts []ParserOption
131-
compilerOpts []CompilerOption
132-
env *cel.Env
133-
expr string
134-
prg cel.Program
179+
name string
180+
parseOpts []ParserOption
181+
env *cel.Env
182+
expr string
183+
prg cel.Program
135184
}
136185

137-
func mustCompileExpr(t testing.TB, env *cel.Env, expr string) *cel.Ast {
138-
t.Helper()
139-
out, iss := env.Compile(expr)
140-
if iss.Err() != nil {
141-
t.Fatalf("env.Compile(%s) failed: %v", expr, iss.Err())
142-
}
143-
return out
186+
func (r *runner) compile(t testing.TB, envOpts []cel.EnvOption, compilerOpts []CompilerOption) (*cel.Env, *cel.Ast, *cel.Issues) {
187+
policy := parsePolicy(t, r.name, r.parseOpts)
188+
return compile(t, r.name, policy, envOpts, compilerOpts)
144189
}
145190

146-
func compile(t testing.TB, name string, parseOpts []ParserOption, envOpts []cel.EnvOption, compilerOpts []CompilerOption) (*cel.Env, *cel.Ast, *cel.Issues) {
191+
func (r *runner) compileRule(t testing.TB) (*cel.Env, *CompiledRule, *cel.Issues) {
147192
t.Helper()
148-
config := readPolicyConfig(t, fmt.Sprintf("testdata/%s/config.yaml", name))
149-
srcFile := readPolicy(t, fmt.Sprintf("testdata/%s/policy.yaml", name))
150-
parser, err := NewParser(parseOpts...)
151-
if err != nil {
152-
t.Fatalf("NewParser() failed: %v", err)
153-
}
154-
policy, iss := parser.Parse(srcFile)
155-
if iss.Err() != nil {
156-
t.Fatalf("Parse() failed: %v", iss.Err())
157-
}
158-
if policy.name.Value != name {
159-
t.Errorf("policy name is %v, wanted %q", policy.name, name)
160-
}
193+
config := readPolicyConfig(t, fmt.Sprintf("testdata/%s/config.yaml", r.name))
194+
policy := parsePolicy(t, r.name, r.parseOpts)
161195
env, err := cel.NewCustomEnv(
162196
cel.OptionalTypes(),
163197
cel.EnableMacroCallTracking(),
@@ -166,26 +200,17 @@ func compile(t testing.TB, name string, parseOpts []ParserOption, envOpts []cel.
166200
if err != nil {
167201
t.Fatalf("cel.NewEnv() failed: %v", err)
168202
}
169-
// Configure any custom environment options.
170-
env, err = env.Extend(envOpts...)
171-
if err != nil {
172-
t.Fatalf("env.Extend() with env options %v, failed: %v", config, err)
173-
}
174203
// Configure declarations
175204
env, err = env.Extend(FromConfig(config))
176205
if err != nil {
177206
t.Fatalf("env.Extend() with config options %v, failed: %v", config, err)
178207
}
179-
ast, iss := Compile(env, policy, compilerOpts...)
180-
return env, ast, iss
208+
rule, iss := CompileRule(env, policy)
209+
return env, rule, iss
181210
}
182211

183-
func (r *runner) setup(t testing.TB) {
212+
func (r *runner) setup(t testing.TB, env *cel.Env, ast *cel.Ast) {
184213
t.Helper()
185-
env, ast, iss := compile(t, r.name, r.parseOpts, r.envOpts, r.compilerOpts)
186-
if iss.Err() != nil {
187-
t.Fatalf("Compile(%s) failed: %v", r.name, iss.Err())
188-
}
189214
pExpr, err := cel.AstToString(ast)
190215
if err != nil {
191216
t.Fatalf("cel.AstToString() failed: %v", err)
@@ -323,6 +348,56 @@ func (r *runner) eval(t testing.TB, expr string) ref.Val {
323348
return out
324349
}
325350

351+
func mustCompileExpr(t testing.TB, env *cel.Env, expr string) *cel.Ast {
352+
t.Helper()
353+
out, iss := env.Compile(expr)
354+
if iss.Err() != nil {
355+
t.Fatalf("env.Compile(%s) failed: %v", expr, iss.Err())
356+
}
357+
return out
358+
}
359+
360+
func parsePolicy(t testing.TB, name string, parseOpts []ParserOption) *Policy {
361+
t.Helper()
362+
srcFile := readPolicy(t, fmt.Sprintf("testdata/%s/policy.yaml", name))
363+
parser, err := NewParser(parseOpts...)
364+
if err != nil {
365+
t.Fatalf("NewParser() failed: %v", err)
366+
}
367+
policy, iss := parser.Parse(srcFile)
368+
if iss.Err() != nil {
369+
t.Fatalf("Parse() failed: %v", iss.Err())
370+
}
371+
if policy.name.Value != name {
372+
t.Errorf("policy name is %v, wanted %q", policy.name, name)
373+
}
374+
return policy
375+
}
376+
377+
func compile(t testing.TB, name string, policy *Policy, envOpts []cel.EnvOption, compilerOpts []CompilerOption) (*cel.Env, *cel.Ast, *cel.Issues) {
378+
config := readPolicyConfig(t, fmt.Sprintf("testdata/%s/config.yaml", name))
379+
env, err := cel.NewCustomEnv(
380+
cel.OptionalTypes(),
381+
cel.EnableMacroCallTracking(),
382+
cel.ExtendedValidations(),
383+
ext.Bindings())
384+
if err != nil {
385+
t.Fatalf("cel.NewEnv() failed: %v", err)
386+
}
387+
// Configure any custom environment options.
388+
env, err = env.Extend(envOpts...)
389+
if err != nil {
390+
t.Fatalf("env.Extend() with env options %v, failed: %v", config, err)
391+
}
392+
// Configure declarations
393+
env, err = env.Extend(FromConfig(config))
394+
if err != nil {
395+
t.Fatalf("env.Extend() with config options %v, failed: %v", config, err)
396+
}
397+
ast, iss := Compile(env, policy, compilerOpts...)
398+
return env, ast, iss
399+
}
400+
326401
func normalize(s string) string {
327402
return strings.ReplaceAll(
328403
strings.ReplaceAll(

0 commit comments

Comments
 (0)