Skip to content

Commit 8ea8ed9

Browse files
authored
Merge pull request #15 from drystone/loop-vars
2 parents f435dce + 0fabdff commit 8ea8ed9

File tree

2 files changed

+40
-70
lines changed

2 files changed

+40
-70
lines changed

pkg/paralleltest/paralleltest.go

+39-69
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package paralleltest
22

33
import (
44
"go/ast"
5+
"go/types"
56
"strings"
67

78
"golang.org/x/tools/go/analysis"
@@ -34,9 +35,8 @@ func run(pass *analysis.Pass) (interface{}, error) {
3435
funcDecl := node.(*ast.FuncDecl)
3536
var funcHasParallelMethod,
3637
rangeStatementOverTestCasesExists,
37-
rangeStatementHasParallelMethod,
38-
testLoopVariableReinitialised bool
39-
var testRunLoopIdentifier string
38+
rangeStatementHasParallelMethod bool
39+
var loopVariableUsedInRun *string
4040
var numberOfTestRun int
4141
var positionOfTestRunNode []ast.Node
4242
var rangeNode ast.Node
@@ -81,6 +81,13 @@ func run(pass *analysis.Pass) (interface{}, error) {
8181
case *ast.RangeStmt:
8282
rangeNode = v
8383

84+
var loopVars []types.Object
85+
for _, expr := range []ast.Expr{v.Key, v.Value} {
86+
if id, ok := expr.(*ast.Ident); ok {
87+
loopVars = append(loopVars, pass.TypesInfo.ObjectOf(id))
88+
}
89+
}
90+
8491
ast.Inspect(v, func(n ast.Node) bool {
8592
// nolint: gocritic
8693
switch r := n.(type) {
@@ -90,26 +97,20 @@ func run(pass *analysis.Pass) (interface{}, error) {
9097
innerTestVar := getRunCallbackParameterName(r.X)
9198

9299
rangeStatementOverTestCasesExists = true
93-
testRunLoopIdentifier = methodRunFirstArgumentObjectName(r.X)
94100

95101
if !rangeStatementHasParallelMethod {
96102
rangeStatementHasParallelMethod = methodParallelIsCalledInMethodRun(r.X, innerTestVar)
97103
}
104+
105+
if loopVariableUsedInRun == nil {
106+
if run, ok := r.X.(*ast.CallExpr); ok {
107+
loopVariableUsedInRun = loopVarReferencedInRun(run, loopVars, pass.TypesInfo)
108+
}
109+
}
98110
}
99111
}
100112
return true
101113
})
102-
103-
// Check for the range loop value identifier re assignment
104-
// More info here https://gist.github.com/kunwardeep/80c2e9f3d3256c894898bae82d9f75d0
105-
if rangeStatementOverTestCasesExists {
106-
var rangeValueIdentifier string
107-
if i, ok := v.Value.(*ast.Ident); ok {
108-
rangeValueIdentifier = i.Name
109-
}
110-
111-
testLoopVariableReinitialised = testCaseLoopVariableReinitialised(v.Body.List, rangeValueIdentifier, testRunLoopIdentifier)
112-
}
113114
}
114115
}
115116

@@ -120,12 +121,8 @@ func run(pass *analysis.Pass) (interface{}, error) {
120121
if rangeStatementOverTestCasesExists && rangeNode != nil {
121122
if !rangeStatementHasParallelMethod {
122123
pass.Reportf(rangeNode.Pos(), "Range statement for test %s missing the call to method parallel in test Run\n", funcDecl.Name.Name)
123-
} else {
124-
if testRunLoopIdentifier == "" {
125-
pass.Reportf(rangeNode.Pos(), "Range statement for test %s does not use range value in test Run\n", funcDecl.Name.Name)
126-
} else if !testLoopVariableReinitialised {
127-
pass.Reportf(rangeNode.Pos(), "Range statement for test %s does not reinitialise the variable %s\n", funcDecl.Name.Name, testRunLoopIdentifier)
128-
}
124+
} else if loopVariableUsedInRun != nil {
125+
pass.Reportf(rangeNode.Pos(), "Range statement for test %s does not reinitialise the variable %s\n", funcDecl.Name.Name, *loopVariableUsedInRun)
129126
}
130127
}
131128

@@ -140,38 +137,6 @@ func run(pass *analysis.Pass) (interface{}, error) {
140137
return nil, nil
141138
}
142139

143-
func testCaseLoopVariableReinitialised(statements []ast.Stmt, rangeValueIdentifier string, testRunLoopIdentifier string) bool {
144-
if len(statements) > 1 {
145-
for _, s := range statements {
146-
leftIdentifier, rightIdentifier := getLeftAndRightIdentifier(s)
147-
if leftIdentifier == testRunLoopIdentifier && rightIdentifier == rangeValueIdentifier {
148-
return true
149-
}
150-
}
151-
}
152-
return false
153-
}
154-
155-
// Return the left hand side and the right hand side identifiers name
156-
func getLeftAndRightIdentifier(s ast.Stmt) (string, string) {
157-
var leftIdentifier, rightIdentifier string
158-
// nolint: gocritic
159-
switch v := s.(type) {
160-
case *ast.AssignStmt:
161-
if len(v.Rhs) == 1 {
162-
if i, ok := v.Rhs[0].(*ast.Ident); ok {
163-
rightIdentifier = i.Name
164-
}
165-
}
166-
if len(v.Lhs) == 1 {
167-
if i, ok := v.Lhs[0].(*ast.Ident); ok {
168-
leftIdentifier = i.Name
169-
}
170-
}
171-
}
172-
return leftIdentifier, rightIdentifier
173-
}
174-
175140
func methodParallelIsCalledInMethodRun(node ast.Node, testVar string) bool {
176141
var methodParallelCalled bool
177142
// nolint: gocritic
@@ -247,22 +212,6 @@ func getRunCallbackParameterName(node ast.Node) string {
247212
return ""
248213
}
249214

250-
// Gets the object name `tc` from method t.Run(tc.Foo, func(t *testing.T)
251-
func methodRunFirstArgumentObjectName(node ast.Node) string {
252-
// nolint: gocritic
253-
switch n := node.(type) {
254-
case *ast.CallExpr:
255-
for _, arg := range n.Args {
256-
if s, ok := arg.(*ast.SelectorExpr); ok {
257-
if i, ok := s.X.(*ast.Ident); ok {
258-
return i.Name
259-
}
260-
}
261-
}
262-
}
263-
return ""
264-
}
265-
266215
// Checks if the function has the param type *testing.T; if it does, then the
267216
// parameter name is returned, too.
268217
func isTestFunction(funcDecl *ast.FuncDecl) (bool, string) {
@@ -291,3 +240,24 @@ func isTestFunction(funcDecl *ast.FuncDecl) (bool, string) {
291240

292241
return false, ""
293242
}
243+
244+
func loopVarReferencedInRun(call *ast.CallExpr, vars []types.Object, typeInfo *types.Info) (found *string) {
245+
if len(call.Args) != 2 {
246+
return
247+
}
248+
249+
ast.Inspect(call.Args[1], func(n ast.Node) bool {
250+
ident, ok := n.(*ast.Ident)
251+
if !ok {
252+
return true
253+
}
254+
for _, o := range vars {
255+
if typeInfo.ObjectOf(ident) == o {
256+
found = &ident.Name
257+
}
258+
}
259+
return true
260+
})
261+
262+
return
263+
}

pkg/paralleltest/testdata/src/t/t_test.go

+1-1
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ func TestFunctionRangeNotUsingRangeValueInTDotRun(t *testing.T) {
8181
testCases := []struct {
8282
name string
8383
}{{name: "foo"}}
84-
for _, tc := range testCases { // want "Range statement for test TestFunctionRangeNotUsingRangeValueInTDotRun does not use range value in test Run"
84+
for _, tc := range testCases { // want "Range statement for test TestFunctionRangeNotUsingRangeValueInTDotRun does not reinitialise the variable tc"
8585
t.Run("tc.name", func(t *testing.T) {
8686
t.Parallel()
8787
fmt.Println(tc.name)

0 commit comments

Comments
 (0)