Skip to content

Commit 003e476

Browse files
authored
ruleguard: implement SinkType filter (#384)
1 parent 3c6c7c9 commit 003e476

File tree

12 files changed

+295
-8
lines changed

12 files changed

+295
-8
lines changed

analyzer/testdata/src/filtertest/f1.go

+120
Original file line numberDiff line numberDiff line change
@@ -955,3 +955,123 @@ func detectGlobal() {
955955
print(globalVar)
956956
}
957957
}
958+
959+
func detectSinkType() {
960+
// Call argument context.
961+
_ = acceptReader(newIface("sink is io.Reader").(*bytes.Buffer)) // want `true`
962+
_ = acceptReader(newIface("sink is io.Reader").(io.Reader)) // want `true`
963+
_ = acceptReader((newIface("sink is io.Reader").(io.Reader))) // want `true`
964+
_ = acceptBuffer(newIface("sink is io.Reader").(*bytes.Buffer))
965+
_ = acceptReaderVariadic(10, newIface("sink is io.Reader").(*bytes.Buffer)) // want `true`
966+
_ = acceptReaderVariadic(10, newIface("sink is io.Reader").(*bytes.Buffer)) // want `true`
967+
_ = acceptReaderVariadic(10, nil, newIface("sink is io.Reader").(*bytes.Buffer)) // want `true`
968+
_ = acceptReaderVariadic(10, newIface("sink is io.Reader").([]io.Reader)...)
969+
_ = acceptWriterVariadic(10, newIface("sink is io.Reader").(*bytes.Buffer))
970+
_ = acceptWriterVariadic(10, nil, newIface("sink is io.Reader").(*bytes.Buffer))
971+
_ = acceptWriterVariadic(10, nil, nil, newIface("sink is io.Reader").(*bytes.Buffer))
972+
_ = acceptVariadic(10, newIface("sink is io.Reader").(*bytes.Buffer))
973+
_ = acceptVariadic(10, nil, newIface("sink is io.Reader").(*bytes.Buffer))
974+
_ = acceptVariadic(10, nil, nil, newIface("sink is io.Reader").(*bytes.Buffer))
975+
fmt.Println(newIface("sink is interface{}").(int)) // want `true`
976+
fmt.Println(1, newIface("sink is interface{}").(io.Reader)) // want `true`
977+
978+
// Type conversion context.
979+
_ = io.Reader(newIface("sink is io.Reader").(*bytes.Buffer)) // want `true`
980+
_ = io.Writer(newIface("sink is io.Reader").(*bytes.Buffer))
981+
982+
// Return stmt context.
983+
{
984+
_ = func() (io.Reader, io.Writer) {
985+
return newIface("sink is io.Reader").(*bytes.Buffer), nil // want `true`
986+
}
987+
_ = func() (io.Reader, io.Writer) {
988+
return nil, newIface("sink is io.Reader").(*bytes.Buffer)
989+
}
990+
_ = func() (io.Writer, io.Reader) {
991+
return nil, newIface("sink is io.Reader").(*bytes.Buffer) // want `true`
992+
}
993+
}
994+
995+
// Assignment context.
996+
{
997+
var r io.Reader = (newIface("sink is io.Reader").(*bytes.Buffer)) // want `true`
998+
var _ io.Reader = newIface("sink is io.Reader").(*bytes.Buffer) // want `true`
999+
var w io.Writer = newIface("sink is io.Reader").(*bytes.Buffer)
1000+
x := newIface("sink is io.Reader").(*bytes.Buffer)
1001+
_ = r
1002+
_ = w
1003+
_ = x
1004+
var readers map[string]io.Reader
1005+
readers["foo"] = newIface("sink is io.Reader").(*bytes.Buffer) // want `true`
1006+
var writers map[string]io.Writer
1007+
writers["foo"] = newIface("sink is io.Reader").(*bytes.Buffer)
1008+
var foo exampleStruct
1009+
foo.r = newIface("sink is io.Reader").(*bytes.Buffer) // want `true`
1010+
foo.buf = newIface("sink is io.Reader").(*bytes.Buffer)
1011+
foo.w = newIface("sink is io.Reader").(*bytes.Buffer)
1012+
}
1013+
1014+
// Index expr context
1015+
{
1016+
var readerKeys map[io.Reader]string
1017+
readerKeys[newIface("sink is io.Reader").(*bytes.Buffer)] = "ok" // want `true`
1018+
readerKeys[(newIface("sink is io.Reader").(*bytes.Buffer))] = "ok" // want `true`
1019+
var writerKeys map[io.Writer]string
1020+
writerKeys[newIface("sink is io.Reader").(*bytes.Buffer)] = "ok"
1021+
writerKeys[(newIface("sink is io.Reader").(*bytes.Buffer))] = "ok"
1022+
}
1023+
1024+
// Composite lit element context.
1025+
_ = []io.Reader{
1026+
newIface("sink is io.Reader").(*bytes.Buffer), // want `true`
1027+
}
1028+
_ = []io.Reader{
1029+
10: newIface("sink is io.Reader").(*bytes.Buffer), // want `true`
1030+
}
1031+
_ = [10]io.Reader{
1032+
4: newIface("sink is io.Reader").(*bytes.Buffer), // want `true`
1033+
}
1034+
_ = map[string]io.Reader{
1035+
"foo": newIface("sink is io.Reader").(*bytes.Buffer), // want `true`
1036+
}
1037+
_ = map[io.Reader]string{
1038+
newIface("sink is io.Reader").(*bytes.Buffer): "foo", // want `true`
1039+
}
1040+
_ = map[io.Reader]string{
1041+
(newIface("sink is io.Reader").(*bytes.Buffer)): "foo", // want `true`
1042+
}
1043+
_ = []io.Writer{
1044+
(newIface("sink is io.Reader").(*bytes.Buffer)),
1045+
}
1046+
_ = exampleStruct{
1047+
w: newIface("sink is io.Reader").(*bytes.Buffer),
1048+
r: newIface("sink is io.Reader").(*bytes.Buffer), // want `true`
1049+
}
1050+
_ = []interface{}{
1051+
newIface("sink is interface{}").(*bytes.Buffer), // want `true`
1052+
newIface("sink is interface{}").(int), // want `true`
1053+
}
1054+
}
1055+
1056+
func detectSinkType2() io.Reader {
1057+
return newIface("sink is io.Reader").(*bytes.Buffer) // want `true`
1058+
}
1059+
1060+
func detectSinkType3() io.Writer {
1061+
return newIface("sink is io.Reader").(*bytes.Buffer)
1062+
}
1063+
1064+
func newIface(key string) interface{} { return nil }
1065+
1066+
func acceptReaderVariadic(a int, r ...io.Reader) int { return 0 }
1067+
func acceptWriterVariadic(a int, r ...io.Writer) int { return 0 }
1068+
func acceptVariadic(a int, r ...interface{}) int { return 0 }
1069+
func acceptReader(r io.Reader) int { return 0 }
1070+
func acceptWriter(r io.Writer) int { return 0 }
1071+
func acceptBuffer(b *bytes.Buffer) int { return 0 }
1072+
1073+
type exampleStruct struct {
1074+
r io.Reader
1075+
w io.Writer
1076+
buf *bytes.Buffer
1077+
}

analyzer/testdata/src/filtertest/rules.go

+8
Original file line numberDiff line numberDiff line change
@@ -274,4 +274,12 @@ func testRules(m dsl.Matcher) {
274274
`$x := time.Now().String()`).
275275
Where(m["x"].Object.IsGlobal()).
276276
Report(`global var`)
277+
278+
m.Match(`newIface("sink is io.Reader").($_)`).
279+
Where(m["$$"].SinkType.Is(`io.Reader`)).
280+
Report(`true`)
281+
282+
m.Match(`newIface("sink is interface{}").($_)`).
283+
Where(m["$$"].SinkType.Is(`interface{}`)).
284+
Report(`true`)
277285
}

go.mod

+1-1
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ go 1.17
55
require (
66
github.com/go-toolsmith/astcopy v1.0.0
77
github.com/google/go-cmp v0.5.6
8-
github.com/quasilyte/go-ruleguard/dsl v0.3.18
8+
github.com/quasilyte/go-ruleguard/dsl v0.3.19
99
github.com/quasilyte/go-ruleguard/rules v0.0.0-20211022131956-028d6511ab71
1010
github.com/quasilyte/gogrep v0.0.0-20220120141003-628d8b3623b5
1111
github.com/quasilyte/stdinfo v0.0.0-20220114132959-f7386bf02567

go.sum

+2-4
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,8 @@ github.com/google/go-cmp v0.5.6 h1:BKbKCqvP6I+rmFHt06ZmyQtvB8xAkWdhFyr0ZUNZcxQ=
1010
github.com/google/go-cmp v0.5.6/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
1111
github.com/quasilyte/go-ruleguard v0.3.1-0.20210203134552-1b5a410e1cc8/go.mod h1:KsAh3x0e7Fkpgs+Q9pNLS5XpFSvYCEVl5gP9Pp1xp30=
1212
github.com/quasilyte/go-ruleguard/dsl v0.3.0/go.mod h1:KeCP03KrjuSO0H1kTuZQCWlQPulDV6YMIXmpQss17rU=
13-
github.com/quasilyte/go-ruleguard/dsl v0.3.17 h1:L5xf3nifnRIdYe9vyMuY2sDnZHIgQol/fDq74FQz7ZY=
14-
github.com/quasilyte/go-ruleguard/dsl v0.3.17/go.mod h1:KeCP03KrjuSO0H1kTuZQCWlQPulDV6YMIXmpQss17rU=
15-
github.com/quasilyte/go-ruleguard/dsl v0.3.18 h1:gzHcFxmTwhn+ZKZd6nGw7JyjoDcYuwcA+TY5MNn9oMk=
16-
github.com/quasilyte/go-ruleguard/dsl v0.3.18/go.mod h1:KeCP03KrjuSO0H1kTuZQCWlQPulDV6YMIXmpQss17rU=
13+
github.com/quasilyte/go-ruleguard/dsl v0.3.19 h1:5+KTKb2YREUYiqZFEIuifFyBxlcCUPWgNZkWy71XS0Q=
14+
github.com/quasilyte/go-ruleguard/dsl v0.3.19/go.mod h1:KeCP03KrjuSO0H1kTuZQCWlQPulDV6YMIXmpQss17rU=
1715
github.com/quasilyte/go-ruleguard/rules v0.0.0-20201231183845-9e62ed36efe1/go.mod h1:7JTjp89EGyU1d6XfBiXihJNG37wB2VRkd125Q1u7Plc=
1816
github.com/quasilyte/go-ruleguard/rules v0.0.0-20211022131956-028d6511ab71 h1:CNooiryw5aisadVfzneSZPswRWvnVW8hF1bS/vo8ReI=
1917
github.com/quasilyte/go-ruleguard/rules v0.0.0-20211022131956-028d6511ab71/go.mod h1:4cgAphtvu7Ftv7vOT2ZOYhC6CvBxZixcasr8qIOTA50=

ruleguard/filters.go

+136
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99

1010
"github.com/quasilyte/gogrep"
1111
"github.com/quasilyte/gogrep/nodetag"
12+
"golang.org/x/tools/go/ast/astutil"
1213

1314
"github.com/quasilyte/go-ruleguard/internal/xtypes"
1415
"github.com/quasilyte/go-ruleguard/ruleguard/quasigo"
@@ -303,6 +304,21 @@ func makeTypesIdenticalFilter(src, lhsVarname, rhsVarname string) filterFunc {
303304
}
304305
}
305306

307+
func makeRootSinkTypeIsFilter(src string, pat *typematch.Pattern) filterFunc {
308+
return func(params *filterParams) matchFilterResult {
309+
// TODO(quasilyte): add variadic support?
310+
e, ok := params.match.Node().(ast.Expr)
311+
if ok {
312+
parent, kv := findSinkRoot(params)
313+
typ := findSinkType(params, parent, kv, e)
314+
if pat.MatchIdentical(params.typematchState, typ) {
315+
return filterSuccess
316+
}
317+
}
318+
return filterFailure(src)
319+
}
320+
}
321+
306322
func makeTypeIsFilter(src, varname string, underlying bool, pat *typematch.Pattern) filterFunc {
307323
if underlying {
308324
return func(params *filterParams) matchFilterResult {
@@ -660,3 +676,123 @@ func typeHasPointers(typ types.Type) bool {
660676
return true
661677
}
662678
}
679+
680+
func findSinkRoot(params *filterParams) (ast.Node, *ast.KeyValueExpr) {
681+
for i := 1; i < params.nodePath.Len(); i++ {
682+
switch n := params.nodePath.NthParent(i).(type) {
683+
case *ast.ParenExpr:
684+
// Skip and continue.
685+
continue
686+
case *ast.KeyValueExpr:
687+
return params.nodePath.NthParent(i + 1).(ast.Expr), n
688+
default:
689+
return n, nil
690+
}
691+
}
692+
return nil, nil
693+
}
694+
695+
func findContainingFunc(params *filterParams) *types.Signature {
696+
for i := 2; i < params.nodePath.Len(); i++ {
697+
switch n := params.nodePath.NthParent(i).(type) {
698+
case *ast.FuncDecl:
699+
fn, ok := params.ctx.Types.TypeOf(n.Name).(*types.Signature)
700+
if ok {
701+
return fn
702+
}
703+
case *ast.FuncLit:
704+
fn, ok := params.ctx.Types.TypeOf(n.Type).(*types.Signature)
705+
if ok {
706+
return fn
707+
}
708+
}
709+
}
710+
return nil
711+
}
712+
713+
func findSinkType(params *filterParams, parent ast.Node, kv *ast.KeyValueExpr, e ast.Expr) types.Type {
714+
switch parent := parent.(type) {
715+
case *ast.ValueSpec:
716+
return params.ctx.Types.TypeOf(parent.Type)
717+
718+
case *ast.ReturnStmt:
719+
for i, result := range parent.Results {
720+
if astutil.Unparen(result) != e {
721+
continue
722+
}
723+
sig := findContainingFunc(params)
724+
if sig == nil {
725+
break
726+
}
727+
return sig.Results().At(i).Type()
728+
}
729+
730+
case *ast.IndexExpr:
731+
if astutil.Unparen(parent.Index) == e {
732+
switch typ := params.ctx.Types.TypeOf(parent.X).Underlying().(type) {
733+
case *types.Map:
734+
return typ.Key()
735+
case *types.Slice, *types.Array:
736+
return nil // TODO: some untyped int type?
737+
}
738+
}
739+
740+
case *ast.AssignStmt:
741+
if parent.Tok != token.ASSIGN || len(parent.Lhs) != len(parent.Rhs) {
742+
break
743+
}
744+
for i, rhs := range parent.Rhs {
745+
if rhs == e {
746+
return params.ctx.Types.TypeOf(parent.Lhs[i])
747+
}
748+
}
749+
750+
case *ast.CompositeLit:
751+
switch typ := params.ctx.Types.TypeOf(parent).Underlying().(type) {
752+
case *types.Slice:
753+
return typ.Elem()
754+
case *types.Array:
755+
return typ.Elem()
756+
case *types.Map:
757+
if astutil.Unparen(kv.Key) == e {
758+
return typ.Key()
759+
}
760+
return typ.Elem()
761+
case *types.Struct:
762+
fieldName, ok := kv.Key.(*ast.Ident)
763+
if !ok {
764+
break
765+
}
766+
for i := 0; i < typ.NumFields(); i++ {
767+
field := typ.Field(i)
768+
if field.Name() == fieldName.String() {
769+
return field.Type()
770+
}
771+
}
772+
}
773+
774+
case *ast.CallExpr:
775+
switch typ := params.ctx.Types.TypeOf(parent.Fun).(type) {
776+
case *types.Signature:
777+
// A function call argument.
778+
for i, arg := range parent.Args {
779+
if astutil.Unparen(arg) != e {
780+
continue
781+
}
782+
isVariadicArg := (i >= typ.Params().Len()-1) && typ.Variadic()
783+
if isVariadicArg && !parent.Ellipsis.IsValid() {
784+
return typ.Params().At(typ.Params().Len() - 1).Type().(*types.Slice).Elem()
785+
}
786+
if i < typ.Params().Len() {
787+
return typ.Params().At(i).Type()
788+
}
789+
break
790+
}
791+
default:
792+
// Probably a type cast.
793+
return typ
794+
}
795+
}
796+
797+
return invalidType
798+
}

ruleguard/gorule.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ func (params *filterParams) typeofNode(n ast.Node) types.Type {
110110
if typ := params.ctx.Types.TypeOf(e); typ != nil {
111111
return typ
112112
}
113-
return types.Typ[types.Invalid]
113+
return invalidType
114114
}
115115

116116
func mergeRuleSets(toMerge []*goRuleSet) (*goRuleSet, error) {

ruleguard/ir/filter_op.gen.go

+4
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

ruleguard/ir/gen_filter_op.go

+1
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,7 @@ func main() {
8686
{name: "Int", comment: "$Value holds an int64 constant", valueType: "int64", flags: flagIsBasicLit},
8787

8888
{name: "RootNodeParentIs", comment: "m[`$$`].Node.Parent().Is($Args[0])"},
89+
{name: "RootSinkTypeIs", comment: "m[`$$`].SinkType.Is($Args[0])"},
8990
}
9091

9192
var buf bytes.Buffer

ruleguard/ir_loader.go

+12
Original file line numberDiff line numberDiff line change
@@ -608,6 +608,18 @@ func (l *irLoader) newFilter(filter ir.FilterExpr, info *filterInfo) (matchFilte
608608
}
609609
result.fn = makeNodeIsFilter(result.src, filter.Value.(string), tag)
610610

611+
case ir.FilterRootSinkTypeIsOp:
612+
typeString := l.unwrapStringExpr(filter.Args[0])
613+
if typeString == "" {
614+
return result, l.errorf(filter.Line, nil, "expected a non-empty string argument")
615+
}
616+
ctx := typematch.Context{Itab: l.itab}
617+
pat, err := typematch.Parse(&ctx, typeString)
618+
if err != nil {
619+
return result, l.errorf(filter.Line, err, "parse type expr")
620+
}
621+
result.fn = makeRootSinkTypeIsFilter(result.src, pat)
622+
611623
case ir.FilterVarTypeHasPointersOp:
612624
result.fn = makeTypeHasPointersFilter(result.src, filter.Value.(string))
613625

ruleguard/irconv/irconv.go

+6
Original file line numberDiff line numberDiff line change
@@ -746,6 +746,12 @@ func (conv *converter) convertFilterExprImpl(e ast.Expr) ir.FilterExpr {
746746
return ir.FilterExpr{Op: ir.FilterVarObjectIsOp, Value: op.varName, Args: args}
747747
case "Object.IsGlobal":
748748
return ir.FilterExpr{Op: ir.FilterVarObjectIsGlobalOp, Value: op.varName}
749+
case "SinkType.Is":
750+
if op.varName != "$$" {
751+
// TODO: remove this restriction.
752+
panic(conv.errorf(e.Args[0], "sink type is only implemented for $$ var"))
753+
}
754+
return ir.FilterExpr{Op: ir.FilterRootSinkTypeIsOp, Value: op.varName, Args: args}
749755
case "Type.HasPointers":
750756
return ir.FilterExpr{Op: ir.FilterVarTypeHasPointersOp, Value: op.varName}
751757
case "Type.Is":

ruleguard/nodepath.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@ func (p nodePath) Current() ast.Node {
3131
}
3232

3333
func (p nodePath) NthParent(n int) ast.Node {
34-
index := len(p.stack) - n - 1
35-
if index >= 0 {
34+
index := uint(len(p.stack) - n - 1)
35+
if index < uint(len(p.stack)) {
3636
return p.stack[index]
3737
}
3838
return nil

ruleguard/utils.go

+2
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ import (
1111
"strings"
1212
)
1313

14+
var invalidType = types.Typ[types.Invalid]
15+
1416
func regexpHasCaptureGroups(pattern string) bool {
1517
// regexp.Compile() uses syntax.Perl flags, so
1618
// we use the same flags here.

0 commit comments

Comments
 (0)