@@ -2,6 +2,7 @@ package paralleltest
2
2
3
3
import (
4
4
"go/ast"
5
+ "go/types"
5
6
"strings"
6
7
7
8
"golang.org/x/tools/go/analysis"
@@ -34,9 +35,8 @@ func run(pass *analysis.Pass) (interface{}, error) {
34
35
funcDecl := node .(* ast.FuncDecl )
35
36
var funcHasParallelMethod ,
36
37
rangeStatementOverTestCasesExists ,
37
- rangeStatementHasParallelMethod ,
38
- testLoopVariableReinitialised bool
39
- var testRunLoopIdentifier string
38
+ rangeStatementHasParallelMethod bool
39
+ var loopVariableUsedInRun * string
40
40
var numberOfTestRun int
41
41
var positionOfTestRunNode []ast.Node
42
42
var rangeNode ast.Node
@@ -81,6 +81,13 @@ func run(pass *analysis.Pass) (interface{}, error) {
81
81
case * ast.RangeStmt :
82
82
rangeNode = v
83
83
84
+ var loopVars []types.Object
85
+ for _ , expr := range []ast.Expr {v .Key , v .Value } {
86
+ if id , ok := expr .(* ast.Ident ); ok {
87
+ loopVars = append (loopVars , pass .TypesInfo .ObjectOf (id ))
88
+ }
89
+ }
90
+
84
91
ast .Inspect (v , func (n ast.Node ) bool {
85
92
// nolint: gocritic
86
93
switch r := n .(type ) {
@@ -90,26 +97,20 @@ func run(pass *analysis.Pass) (interface{}, error) {
90
97
innerTestVar := getRunCallbackParameterName (r .X )
91
98
92
99
rangeStatementOverTestCasesExists = true
93
- testRunLoopIdentifier = methodRunFirstArgumentObjectName (r .X )
94
100
95
101
if ! rangeStatementHasParallelMethod {
96
102
rangeStatementHasParallelMethod = methodParallelIsCalledInMethodRun (r .X , innerTestVar )
97
103
}
104
+
105
+ if loopVariableUsedInRun == nil {
106
+ if run , ok := r .X .(* ast.CallExpr ); ok {
107
+ loopVariableUsedInRun = loopVarReferencedInRun (run , loopVars , pass .TypesInfo )
108
+ }
109
+ }
98
110
}
99
111
}
100
112
return true
101
113
})
102
-
103
- // Check for the range loop value identifier re assignment
104
- // More info here https://gist.github.com/kunwardeep/80c2e9f3d3256c894898bae82d9f75d0
105
- if rangeStatementOverTestCasesExists {
106
- var rangeValueIdentifier string
107
- if i , ok := v .Value .(* ast.Ident ); ok {
108
- rangeValueIdentifier = i .Name
109
- }
110
-
111
- testLoopVariableReinitialised = testCaseLoopVariableReinitialised (v .Body .List , rangeValueIdentifier , testRunLoopIdentifier )
112
- }
113
114
}
114
115
}
115
116
@@ -120,12 +121,8 @@ func run(pass *analysis.Pass) (interface{}, error) {
120
121
if rangeStatementOverTestCasesExists && rangeNode != nil {
121
122
if ! rangeStatementHasParallelMethod {
122
123
pass .Reportf (rangeNode .Pos (), "Range statement for test %s missing the call to method parallel in test Run\n " , funcDecl .Name .Name )
123
- } else {
124
- if testRunLoopIdentifier == "" {
125
- pass .Reportf (rangeNode .Pos (), "Range statement for test %s does not use range value in test Run\n " , funcDecl .Name .Name )
126
- } else if ! testLoopVariableReinitialised {
127
- pass .Reportf (rangeNode .Pos (), "Range statement for test %s does not reinitialise the variable %s\n " , funcDecl .Name .Name , testRunLoopIdentifier )
128
- }
124
+ } else if loopVariableUsedInRun != nil {
125
+ pass .Reportf (rangeNode .Pos (), "Range statement for test %s does not reinitialise the variable %s\n " , funcDecl .Name .Name , * loopVariableUsedInRun )
129
126
}
130
127
}
131
128
@@ -140,38 +137,6 @@ func run(pass *analysis.Pass) (interface{}, error) {
140
137
return nil , nil
141
138
}
142
139
143
- func testCaseLoopVariableReinitialised (statements []ast.Stmt , rangeValueIdentifier string , testRunLoopIdentifier string ) bool {
144
- if len (statements ) > 1 {
145
- for _ , s := range statements {
146
- leftIdentifier , rightIdentifier := getLeftAndRightIdentifier (s )
147
- if leftIdentifier == testRunLoopIdentifier && rightIdentifier == rangeValueIdentifier {
148
- return true
149
- }
150
- }
151
- }
152
- return false
153
- }
154
-
155
- // Return the left hand side and the right hand side identifiers name
156
- func getLeftAndRightIdentifier (s ast.Stmt ) (string , string ) {
157
- var leftIdentifier , rightIdentifier string
158
- // nolint: gocritic
159
- switch v := s .(type ) {
160
- case * ast.AssignStmt :
161
- if len (v .Rhs ) == 1 {
162
- if i , ok := v .Rhs [0 ].(* ast.Ident ); ok {
163
- rightIdentifier = i .Name
164
- }
165
- }
166
- if len (v .Lhs ) == 1 {
167
- if i , ok := v .Lhs [0 ].(* ast.Ident ); ok {
168
- leftIdentifier = i .Name
169
- }
170
- }
171
- }
172
- return leftIdentifier , rightIdentifier
173
- }
174
-
175
140
func methodParallelIsCalledInMethodRun (node ast.Node , testVar string ) bool {
176
141
var methodParallelCalled bool
177
142
// nolint: gocritic
@@ -247,22 +212,6 @@ func getRunCallbackParameterName(node ast.Node) string {
247
212
return ""
248
213
}
249
214
250
- // Gets the object name `tc` from method t.Run(tc.Foo, func(t *testing.T)
251
- func methodRunFirstArgumentObjectName (node ast.Node ) string {
252
- // nolint: gocritic
253
- switch n := node .(type ) {
254
- case * ast.CallExpr :
255
- for _ , arg := range n .Args {
256
- if s , ok := arg .(* ast.SelectorExpr ); ok {
257
- if i , ok := s .X .(* ast.Ident ); ok {
258
- return i .Name
259
- }
260
- }
261
- }
262
- }
263
- return ""
264
- }
265
-
266
215
// Checks if the function has the param type *testing.T; if it does, then the
267
216
// parameter name is returned, too.
268
217
func isTestFunction (funcDecl * ast.FuncDecl ) (bool , string ) {
@@ -291,3 +240,24 @@ func isTestFunction(funcDecl *ast.FuncDecl) (bool, string) {
291
240
292
241
return false , ""
293
242
}
243
+
244
+ func loopVarReferencedInRun (call * ast.CallExpr , vars []types.Object , typeInfo * types.Info ) (found * string ) {
245
+ if len (call .Args ) != 2 {
246
+ return
247
+ }
248
+
249
+ ast .Inspect (call .Args [1 ], func (n ast.Node ) bool {
250
+ ident , ok := n .(* ast.Ident )
251
+ if ! ok {
252
+ return true
253
+ }
254
+ for _ , o := range vars {
255
+ if typeInfo .ObjectOf (ident ) == o {
256
+ found = & ident .Name
257
+ }
258
+ }
259
+ return true
260
+ })
261
+
262
+ return
263
+ }
0 commit comments