Skip to content

Commit 91fb306

Browse files
authored
Update PruneAst to support constants of optional type (#1109)
* Update PruneAst and unparser to support optional types constants. Signed-off-by: Dennis Buduev <[email protected]> * recursive definitions Signed-off-by: Dennis Buduev <[email protected]> * add prune tests for <list>.last() Signed-off-by: Dennis Buduev <[email protected]> * clean-up Signed-off-by: Dennis Buduev <[email protected]> --------- Signed-off-by: Dennis Buduev <[email protected]> Co-authored-by: Dennis Buduev <[email protected]>
1 parent 33a7f97 commit 91fb306

File tree

3 files changed

+71
-3
lines changed

3 files changed

+71
-3
lines changed

interpreter/prune.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ func PruneAst(expr ast.Expr, macroCalls map[int64]ast.Expr, state EvalState) *as
8888

8989
func (p *astPruner) maybeCreateLiteral(id int64, val ref.Val) (ast.Expr, bool) {
9090
switch v := val.(type) {
91-
case types.Bool, types.Bytes, types.Double, types.Int, types.Null, types.String, types.Uint:
91+
case types.Bool, types.Bytes, types.Double, types.Int, types.Null, types.String, types.Uint, *types.Optional:
9292
p.state.SetValue(id, val)
9393
return p.NewLiteral(id, val), true
9494
case types.Duration:

interpreter/prune_test.go

+44
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import (
2323
"github.com/google/cel-go/common/decls"
2424
"github.com/google/cel-go/common/types"
2525
"github.com/google/cel-go/common/types/ref"
26+
"github.com/google/cel-go/common/types/traits"
2627
"github.com/google/cel-go/parser"
2728
"github.com/google/cel-go/test"
2829

@@ -216,6 +217,36 @@ var testCases = []testInfo{
216217
expr: `a.?b`,
217218
out: `a.?b`,
218219
},
220+
{
221+
in: partialActivation(map[string]any{"a": map[string]any{"b": 10}}),
222+
expr: `a.?b`,
223+
out: `optional.of(10)`,
224+
},
225+
{
226+
in: partialActivation(map[string]any{"a": map[string]any{"b": 10}}),
227+
expr: `a[?"b"]`,
228+
out: `optional.of(10)`,
229+
},
230+
{
231+
in: unknownActivation(),
232+
expr: `{'b': optional.of(10)}.?b`,
233+
out: `optional.of(optional.of(10))`,
234+
},
235+
{
236+
in: partialActivation(map[string]any{"a": map[string]any{}}),
237+
expr: `a.?b`,
238+
out: `optional.none()`,
239+
},
240+
{
241+
in: unknownActivation(),
242+
expr: `[10].last()`,
243+
out: "optional.of(10)",
244+
},
245+
{
246+
in: unknownActivation(),
247+
expr: `[].last()`,
248+
out: "optional.none()",
249+
},
219250
{
220251
in: unknownActivation("a"),
221252
expr: `a[?"b"]`,
@@ -561,5 +592,18 @@ func optionalDecls(t *testing.T) []*decls.FunctionDecl {
561592
types.NewTypeParamType("K"),
562593
}, optionalType),
563594
),
595+
funcDecl(t, "last", decls.Overload("list_last", []*types.Type{paramType}, optionalType,
596+
decls.UnaryBinding(func(v ref.Val) ref.Val {
597+
list := v.(traits.Lister)
598+
sz := list.Size().Value().(int64)
599+
600+
if sz == 0 {
601+
return types.OptionalNone
602+
}
603+
604+
return types.OptionalOf(list.Get(types.Int(sz - 1)))
605+
}),
606+
),
607+
),
564608
}
565609
}

parser/unparser.go

+26-2
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@ import (
2424
"github.com/google/cel-go/common/ast"
2525
"github.com/google/cel-go/common/operators"
2626
"github.com/google/cel-go/common/types"
27+
"github.com/google/cel-go/common/types/ref"
2728
)
2829

2930
// Unparse takes an input expression and source position information and generates a human-readable
@@ -273,8 +274,17 @@ func (un *unparser) visitCallUnary(expr ast.Expr) error {
273274
return un.visitMaybeNested(args[0], nested)
274275
}
275276

276-
func (un *unparser) visitConst(expr ast.Expr) error {
277-
val := expr.AsLiteral()
277+
func (un *unparser) visitConstVal(val ref.Val) error {
278+
optional := false
279+
if optVal, ok := val.(*types.Optional); ok {
280+
if !optVal.HasValue() {
281+
un.str.WriteString("optional.none()")
282+
return nil
283+
}
284+
optional = true
285+
un.str.WriteString("optional.of(")
286+
val = optVal.GetValue()
287+
}
278288
switch val := val.(type) {
279289
case types.Bool:
280290
un.str.WriteString(strconv.FormatBool(bool(val)))
@@ -303,7 +313,21 @@ func (un *unparser) visitConst(expr ast.Expr) error {
303313
ui := strconv.FormatUint(uint64(val), 10)
304314
un.str.WriteString(ui)
305315
un.str.WriteString("u")
316+
case *types.Optional:
317+
if err := un.visitConstVal(val); err != nil {
318+
return err
319+
}
306320
default:
321+
return errors.New("unsupported constant")
322+
}
323+
if optional {
324+
un.str.WriteString(")")
325+
}
326+
return nil
327+
}
328+
func (un *unparser) visitConst(expr ast.Expr) error {
329+
val := expr.AsLiteral()
330+
if err := un.visitConstVal(val); err != nil {
307331
return fmt.Errorf("unsupported constant: %v", expr)
308332
}
309333
return nil

0 commit comments

Comments
 (0)