Skip to content

Commit fa6eb32

Browse files
authored
Add option to use inaccessible accumulator var (#1097)
Add an option to use '@Result' as the accumulator variable for builtin comprehensions. The current default "__result__" is accesible in the CEL Source, allowing for expressions to type check but lead to unexpected or incorrect results. '@Result' isn't a normally accessible identifier in the source expression so a bit safer as a default.
1 parent 7c5909e commit fa6eb32

File tree

11 files changed

+331
-38
lines changed

11 files changed

+331
-38
lines changed

cel/cel_test.go

+3-12
Original file line numberDiff line numberDiff line change
@@ -777,18 +777,9 @@ func TestMacroInterop(t *testing.T) {
777777
}
778778

779779
func TestMacroModern(t *testing.T) {
780-
existsOneMacro := ReceiverMacro("exists_one", 2,
781-
func(mef MacroExprFactory, iterRange celast.Expr, args []celast.Expr) (celast.Expr, *Error) {
782-
return parser.MakeExistsOne(mef, iterRange, args)
783-
})
784-
transformMacro := ReceiverMacro("transform", 2,
785-
func(mef MacroExprFactory, iterRange celast.Expr, args []celast.Expr) (celast.Expr, *Error) {
786-
return parser.MakeMap(mef, iterRange, args)
787-
})
788-
filterMacro := ReceiverMacro("filter", 2,
789-
func(mef MacroExprFactory, iterRange celast.Expr, args []celast.Expr) (celast.Expr, *Error) {
790-
return parser.MakeFilter(mef, iterRange, args)
791-
})
780+
existsOneMacro := ReceiverMacro("exists_one", 2, parser.MakeExistsOne)
781+
transformMacro := ReceiverMacro("transform", 2, parser.MakeMap)
782+
filterMacro := ReceiverMacro("filter", 2, parser.MakeFilter)
792783
pairMacro := GlobalMacro("pair", 2,
793784
func(mef MacroExprFactory, iterRange celast.Expr, args []celast.Expr) (celast.Expr, *Error) {
794785
return mef.NewMap(mef.NewMapEntry(args[0], args[1], false)), nil

cel/options.go

+9
Original file line numberDiff line numberDiff line change
@@ -664,6 +664,15 @@ func ParserExpressionSizeLimit(limit int) EnvOption {
664664
}
665665
}
666666

667+
// EnableHiddenAccumulatorName sets the parser to use the identifier '@result' for accumulators
668+
// which is not normally accessible from CEL source.
669+
func EnableHiddenAccumulatorName(enabled bool) EnvOption {
670+
return func(e *Env) (*Env, error) {
671+
e.prsrOpts = append(e.prsrOpts, parser.EnableHiddenAccumulatorName(enabled))
672+
return e, nil
673+
}
674+
}
675+
667676
func maybeInteropProvider(provider any) (types.Provider, error) {
668677
switch p := provider.(type) {
669678
case types.Provider:

checker/cost.go

+5-1
Original file line numberDiff line numberDiff line change
@@ -672,9 +672,13 @@ func (c *coster) addPath(e ast.Expr, path []string) {
672672
c.exprPath[e.ID()] = path
673673
}
674674

675+
func isAccumulatorVar(name string) bool {
676+
return name == parser.AccumulatorName || name == parser.HiddenAccumulatorName
677+
}
678+
675679
func (c *coster) newAstNode(e ast.Expr) *astNode {
676680
path := c.getPath(e)
677-
if len(path) > 0 && path[0] == parser.AccumulatorName {
681+
if len(path) > 0 && isAccumulatorVar(path[0]) {
678682
// only provide paths to root vars; omit accumulator vars
679683
path = nil
680684
}

common/ast/factory.go

+22-3
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@ type ExprFactory interface {
4343
//comprehension.
4444
NewAccuIdent(id int64) Expr
4545

46+
// AccuIdentName reports the name of the accumulator variable to be used within a comprehension.
47+
AccuIdentName() string
48+
4649
// NewLiteral creates an Expr value representing a literal value, such as a string or integer.
4750
NewLiteral(id int64, value ref.Val) Expr
4851

@@ -78,11 +81,23 @@ type ExprFactory interface {
7881
isExprFactory()
7982
}
8083

81-
type baseExprFactory struct{}
84+
type baseExprFactory struct {
85+
accumulatorName string
86+
}
8287

8388
// NewExprFactory creates an ExprFactory instance.
8489
func NewExprFactory() ExprFactory {
85-
return &baseExprFactory{}
90+
return &baseExprFactory{
91+
"__result__",
92+
}
93+
}
94+
95+
// NewExprFactoryWithAccumulator creates an ExprFactory instance with a custom
96+
// accumulator identifier name.
97+
func NewExprFactoryWithAccumulator(id string) ExprFactory {
98+
return &baseExprFactory{
99+
id,
100+
}
86101
}
87102

88103
func (fac *baseExprFactory) NewCall(id int64, function string, args ...Expr) Expr {
@@ -138,7 +153,11 @@ func (fac *baseExprFactory) NewIdent(id int64, name string) Expr {
138153
}
139154

140155
func (fac *baseExprFactory) NewAccuIdent(id int64) Expr {
141-
return fac.NewIdent(id, "__result__")
156+
return fac.NewIdent(id, fac.AccuIdentName())
157+
}
158+
159+
func (fac *baseExprFactory) AccuIdentName() string {
160+
return fac.accumulatorName
142161
}
143162

144163
func (fac *baseExprFactory) NewLiteral(id int64, value ref.Val) Expr {

ext/comprehensions.go

+14-14
Original file line numberDiff line numberDiff line change
@@ -249,7 +249,7 @@ func quantifierAll(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (
249249
target,
250250
iterVar1,
251251
iterVar2,
252-
parser.AccumulatorName,
252+
mef.AccuIdentName(),
253253
/*accuInit=*/ mef.NewLiteral(types.True),
254254
/*condition=*/ mef.NewCall(operators.NotStrictlyFalse, mef.NewAccuIdent()),
255255
/*step=*/ mef.NewCall(operators.LogicalAnd, mef.NewAccuIdent(), args[2]),
@@ -267,7 +267,7 @@ func quantifierExists(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr
267267
target,
268268
iterVar1,
269269
iterVar2,
270-
parser.AccumulatorName,
270+
mef.AccuIdentName(),
271271
/*accuInit=*/ mef.NewLiteral(types.False),
272272
/*condition=*/ mef.NewCall(operators.NotStrictlyFalse, mef.NewCall(operators.LogicalNot, mef.NewAccuIdent())),
273273
/*step=*/ mef.NewCall(operators.LogicalOr, mef.NewAccuIdent(), args[2]),
@@ -285,7 +285,7 @@ func quantifierExistsOne(mef cel.MacroExprFactory, target ast.Expr, args []ast.E
285285
target,
286286
iterVar1,
287287
iterVar2,
288-
parser.AccumulatorName,
288+
mef.AccuIdentName(),
289289
/*accuInit=*/ mef.NewLiteral(types.Int(0)),
290290
/*condition=*/ mef.NewLiteral(types.True),
291291
/*step=*/ mef.NewCall(operators.Conditional, args[2],
@@ -311,18 +311,18 @@ func transformList(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (
311311
transform = args[2]
312312
}
313313

314-
// __result__ = __result__ + [transform]
314+
// accumulator = accumulator + [transform]
315315
step := mef.NewCall(operators.Add, mef.NewAccuIdent(), mef.NewList(transform))
316316
if filter != nil {
317-
// __result__ = (filter) ? __result__ + [transform] : __result__
317+
// accumulator = (filter) ? accumulator + [transform] : accumulator
318318
step = mef.NewCall(operators.Conditional, filter, step, mef.NewAccuIdent())
319319
}
320320

321321
return mef.NewComprehensionTwoVar(
322322
target,
323323
iterVar1,
324324
iterVar2,
325-
parser.AccumulatorName,
325+
mef.AccuIdentName(),
326326
/*accuInit=*/ mef.NewList(),
327327
/*condition=*/ mef.NewLiteral(types.True),
328328
step,
@@ -346,17 +346,17 @@ func transformMap(mef cel.MacroExprFactory, target ast.Expr, args []ast.Expr) (a
346346
transform = args[2]
347347
}
348348

349-
// __result__ = cel.@mapInsert(__result__, iterVar1, transform)
349+
// accumulator = cel.@mapInsert(accumulator, iterVar1, transform)
350350
step := mef.NewCall(mapInsert, mef.NewAccuIdent(), mef.NewIdent(iterVar1), transform)
351351
if filter != nil {
352-
// __result__ = (filter) ? cel.@mapInsert(__result__, iterVar1, transform) : __result__
352+
// accumulator = (filter) ? cel.@mapInsert(accumulator, iterVar1, transform) : accumulator
353353
step = mef.NewCall(operators.Conditional, filter, step, mef.NewAccuIdent())
354354
}
355355
return mef.NewComprehensionTwoVar(
356356
target,
357357
iterVar1,
358358
iterVar2,
359-
parser.AccumulatorName,
359+
mef.AccuIdentName(),
360360
/*accuInit=*/ mef.NewMap(),
361361
/*condition=*/ mef.NewLiteral(types.True),
362362
step,
@@ -380,17 +380,17 @@ func transformMapEntry(mef cel.MacroExprFactory, target ast.Expr, args []ast.Exp
380380
transform = args[2]
381381
}
382382

383-
// __result__ = cel.@mapInsert(__result__, transform)
383+
// accumulator = cel.@mapInsert(accumulator, transform)
384384
step := mef.NewCall(mapInsert, mef.NewAccuIdent(), transform)
385385
if filter != nil {
386-
// __result__ = (filter) ? cel.@mapInsert(__result__, transform) : __result__
386+
// accumulator = (filter) ? cel.@mapInsert(accumulator, transform) : accumulator
387387
step = mef.NewCall(operators.Conditional, filter, step, mef.NewAccuIdent())
388388
}
389389
return mef.NewComprehensionTwoVar(
390390
target,
391391
iterVar1,
392392
iterVar2,
393-
parser.AccumulatorName,
393+
mef.AccuIdentName(),
394394
/*accuInit=*/ mef.NewMap(),
395395
/*condition=*/ mef.NewLiteral(types.True),
396396
step,
@@ -410,10 +410,10 @@ func extractIterVars(mef cel.MacroExprFactory, arg0, arg1 ast.Expr) (string, str
410410
if iterVar1 == iterVar2 {
411411
return "", "", mef.NewError(arg1.ID(), fmt.Sprintf("duplicate variable name: %s", iterVar1))
412412
}
413-
if iterVar1 == parser.AccumulatorName {
413+
if iterVar1 == mef.AccuIdentName() || iterVar1 == parser.AccumulatorName {
414414
return "", "", mef.NewError(arg0.ID(), "iteration variable overwrites accumulator variable")
415415
}
416-
if iterVar2 == parser.AccumulatorName {
416+
if iterVar2 == mef.AccuIdentName() || iterVar2 == parser.AccumulatorName {
417417
return "", "", mef.NewError(arg1.ID(), "iteration variable overwrites accumulator variable")
418418
}
419419
return iterVar1, iterVar2, nil

ext/comprehensions_test.go

+3-1
Original file line numberDiff line numberDiff line change
@@ -621,7 +621,9 @@ func testCompreEnv(t *testing.T, opts ...cel.EnvOption) *cel.Env {
621621
Lists(),
622622
Strings(),
623623
cel.OptionalTypes(),
624-
cel.EnableMacroCallTracking()}
624+
cel.EnableMacroCallTracking(),
625+
cel.EnableHiddenAccumulatorName(true),
626+
}
625627
env, err := cel.NewEnv(append(baseOpts, opts...)...)
626628
if err != nil {
627629
t.Fatalf("cel.NewEnv(TwoVarComprehensions()) failed: %v", err)

parser/helper.go

+5
Original file line numberDiff line numberDiff line change
@@ -470,6 +470,11 @@ func (e *exprHelper) NewAccuIdent() ast.Expr {
470470
return e.exprFactory.NewAccuIdent(e.nextMacroID())
471471
}
472472

473+
// AccuIdentName implements the ExprHelper interface method.
474+
func (e *exprHelper) AccuIdentName() string {
475+
return e.exprFactory.AccuIdentName()
476+
}
477+
473478
// NewGlobalCall implements the ExprHelper interface method.
474479
func (e *exprHelper) NewCall(function string, args ...ast.Expr) ast.Expr {
475480
return e.exprFactory.NewCall(e.nextMacroID(), function, args...)

parser/macro.go

+17-6
Original file line numberDiff line numberDiff line change
@@ -225,6 +225,9 @@ type ExprHelper interface {
225225
// NewAccuIdent returns an accumulator identifier for use with comprehension results.
226226
NewAccuIdent() ast.Expr
227227

228+
// AccuIdentName returns the name of the accumulator identifier.
229+
AccuIdentName() string
230+
228231
// NewCall creates a function call Expr value for a global (free) function.
229232
NewCall(function string, args ...ast.Expr) ast.Expr
230233

@@ -298,6 +301,11 @@ var (
298301
// AccumulatorName is the traditional variable name assigned to the fold accumulator variable.
299302
const AccumulatorName = "__result__"
300303

304+
// HiddenAccumulatorName is a proposed update to the default fold accumlator variable.
305+
// @result is not normally accessible from source, preventing accidental or intentional collisions
306+
// in user expressions.
307+
const HiddenAccumulatorName = "@result"
308+
301309
type quantifierKind int
302310

303311
const (
@@ -342,7 +350,8 @@ func MakeMap(eh ExprHelper, target ast.Expr, args []ast.Expr) (ast.Expr, *common
342350
if !found {
343351
return nil, eh.NewError(args[0].ID(), "argument is not an identifier")
344352
}
345-
if v == AccumulatorName {
353+
accu := eh.AccuIdentName()
354+
if v == accu || v == AccumulatorName {
346355
return nil, eh.NewError(args[0].ID(), "iteration variable overwrites accumulator variable")
347356
}
348357

@@ -364,7 +373,7 @@ func MakeMap(eh ExprHelper, target ast.Expr, args []ast.Expr) (ast.Expr, *common
364373
if filter != nil {
365374
step = eh.NewCall(operators.Conditional, filter, step, eh.NewAccuIdent())
366375
}
367-
return eh.NewComprehension(target, v, AccumulatorName, init, condition, step, eh.NewAccuIdent()), nil
376+
return eh.NewComprehension(target, v, accu, init, condition, step, eh.NewAccuIdent()), nil
368377
}
369378

370379
// MakeFilter expands the input call arguments into a comprehension which produces a list which contains
@@ -375,7 +384,8 @@ func MakeFilter(eh ExprHelper, target ast.Expr, args []ast.Expr) (ast.Expr, *com
375384
if !found {
376385
return nil, eh.NewError(args[0].ID(), "argument is not an identifier")
377386
}
378-
if v == AccumulatorName {
387+
accu := eh.AccuIdentName()
388+
if v == accu || v == AccumulatorName {
379389
return nil, eh.NewError(args[0].ID(), "iteration variable overwrites accumulator variable")
380390
}
381391

@@ -384,7 +394,7 @@ func MakeFilter(eh ExprHelper, target ast.Expr, args []ast.Expr) (ast.Expr, *com
384394
condition := eh.NewLiteral(types.True)
385395
step := eh.NewCall(operators.Add, eh.NewAccuIdent(), eh.NewList(args[0]))
386396
step = eh.NewCall(operators.Conditional, filter, step, eh.NewAccuIdent())
387-
return eh.NewComprehension(target, v, AccumulatorName, init, condition, step, eh.NewAccuIdent()), nil
397+
return eh.NewComprehension(target, v, accu, init, condition, step, eh.NewAccuIdent()), nil
388398
}
389399

390400
// MakeHas expands the input call arguments into a presence test, e.g. has(<operand>.field)
@@ -401,7 +411,8 @@ func makeQuantifier(kind quantifierKind, eh ExprHelper, target ast.Expr, args []
401411
if !found {
402412
return nil, eh.NewError(args[0].ID(), "argument must be a simple name")
403413
}
404-
if v == AccumulatorName {
414+
accu := eh.AccuIdentName()
415+
if v == accu || v == AccumulatorName {
405416
return nil, eh.NewError(args[0].ID(), "iteration variable overwrites accumulator variable")
406417
}
407418

@@ -431,7 +442,7 @@ func makeQuantifier(kind quantifierKind, eh ExprHelper, target ast.Expr, args []
431442
default:
432443
return nil, eh.NewError(args[0].ID(), fmt.Sprintf("unrecognized quantifier '%v'", kind))
433444
}
434-
return eh.NewComprehension(target, v, AccumulatorName, init, condition, step, result), nil
445+
return eh.NewComprehension(target, v, accu, init, condition, step, result), nil
435446
}
436447

437448
func extractIdent(e ast.Expr) (string, bool) {

parser/options.go

+13
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ type options struct {
2727
enableOptionalSyntax bool
2828
enableVariadicOperatorASTs bool
2929
enableIdentEscapeSyntax bool
30+
enableHiddenAccumulatorName bool
3031
}
3132

3233
// Option configures the behavior of the parser.
@@ -137,6 +138,18 @@ func EnableIdentEscapeSyntax(enableIdentEscapeSyntax bool) Option {
137138
}
138139
}
139140

141+
// EnableHiddenAccumulatorName uses an accumulator variable name that is not a
142+
// normally accessible identifier in source for comprehension macros. Compatibility notes:
143+
// with this option enabled, a parsed AST would be semantically the same as if disabled, but would
144+
// have different internal identifiers in any of the built-in comprehension sub-expressions. When
145+
// disabled, it is possible but almost certainly a logic error to access the accumulator variable.
146+
func EnableHiddenAccumulatorName(enabled bool) Option {
147+
return func(opts *options) error {
148+
opts.enableHiddenAccumulatorName = enabled
149+
return nil
150+
}
151+
}
152+
140153
// EnableVariadicOperatorASTs enables a compact representation of chained like-kind commutative
141154
// operators. e.g. `a || b || c || d` -> `call(op='||', args=[a, b, c, d])`
142155
//

parser/parser.go

+5-1
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,11 @@ func mustNewParser(opts ...Option) *Parser {
8989
// Parse parses the expression represented by source and returns the result.
9090
func (p *Parser) Parse(source common.Source) (*ast.AST, *common.Errors) {
9191
errs := common.NewErrors(source)
92-
fac := ast.NewExprFactory()
92+
accu := AccumulatorName
93+
if p.enableHiddenAccumulatorName {
94+
accu = HiddenAccumulatorName
95+
}
96+
fac := ast.NewExprFactoryWithAccumulator(accu)
9397
impl := parser{
9498
errors: &parseErrors{errs},
9599
exprFactory: fac,

0 commit comments

Comments
 (0)