@@ -42,7 +42,8 @@ func run(pass *analysis.Pass) (interface{}, error) {
42
42
var rangeNode ast.Node
43
43
44
44
// Check runs for test functions only
45
- if ! isTestFunction (funcDecl ) {
45
+ isTest , testVar := isTestFunction (funcDecl )
46
+ if ! isTest {
46
47
return
47
48
}
48
49
@@ -53,16 +54,19 @@ func run(pass *analysis.Pass) (interface{}, error) {
53
54
ast .Inspect (v , func (n ast.Node ) bool {
54
55
// Check if the test method is calling t.parallel
55
56
if ! funcHasParallelMethod {
56
- funcHasParallelMethod = methodParallelIsCalledInTestFunction (n )
57
+ funcHasParallelMethod = methodParallelIsCalledInTestFunction (n , testVar )
57
58
}
58
59
59
60
// Check if the t.Run within the test function is calling t.parallel
60
- if methodRunIsCalledInTestFunction (n ) {
61
+ if methodRunIsCalledInTestFunction (n , testVar ) {
62
+ // n is a call to t.Run; find out the name of the subtest's *testing.T parameter.
63
+ innerTestVar := getRunCallbackParameterName (n )
64
+
61
65
hasParallel := false
62
66
numberOfTestRun ++
63
67
ast .Inspect (v , func (p ast.Node ) bool {
64
68
if ! hasParallel {
65
- hasParallel = methodParallelIsCalledInTestFunction (p )
69
+ hasParallel = methodParallelIsCalledInTestFunction (p , innerTestVar )
66
70
}
67
71
return true
68
72
})
@@ -81,12 +85,15 @@ func run(pass *analysis.Pass) (interface{}, error) {
81
85
// nolint: gocritic
82
86
switch r := n .(type ) {
83
87
case * ast.ExprStmt :
84
- if methodRunIsCalledInRangeStatement (r .X ) {
88
+ if methodRunIsCalledInRangeStatement (r .X , testVar ) {
89
+ // r.X is a call to t.Run; find out the name of the subtest's *testing.T parameter.
90
+ innerTestVar := getRunCallbackParameterName (r .X )
91
+
85
92
rangeStatementOverTestCasesExists = true
86
93
testRunLoopIdentifier = methodRunFirstArgumentObjectName (r .X )
87
94
88
95
if ! rangeStatementHasParallelMethod {
89
- rangeStatementHasParallelMethod = methodParallelIsCalledInMethodRun (r .X )
96
+ rangeStatementHasParallelMethod = methodParallelIsCalledInMethodRun (r .X , innerTestVar )
90
97
}
91
98
}
92
99
}
@@ -165,7 +172,7 @@ func getLeftAndRightIdentifier(s ast.Stmt) (string, string) {
165
172
return leftIdentifier , rightIdentifier
166
173
}
167
174
168
- func methodParallelIsCalledInMethodRun (node ast.Node ) bool {
175
+ func methodParallelIsCalledInMethodRun (node ast.Node , testVar string ) bool {
169
176
var methodParallelCalled bool
170
177
// nolint: gocritic
171
178
switch callExp := node .(type ) {
@@ -174,7 +181,7 @@ func methodParallelIsCalledInMethodRun(node ast.Node) bool {
174
181
if ! methodParallelCalled {
175
182
ast .Inspect (arg , func (n ast.Node ) bool {
176
183
if ! methodParallelCalled {
177
- methodParallelCalled = methodParallelIsCalledInRunMethod (n )
184
+ methodParallelCalled = methodParallelIsCalledInRunMethod (n , testVar )
178
185
return true
179
186
}
180
187
return false
@@ -185,32 +192,61 @@ func methodParallelIsCalledInMethodRun(node ast.Node) bool {
185
192
return methodParallelCalled
186
193
}
187
194
188
- func methodParallelIsCalledInRunMethod (node ast.Node ) bool {
189
- return exprCallHasMethod (node , "Parallel" )
195
+ func methodParallelIsCalledInRunMethod (node ast.Node , testVar string ) bool {
196
+ return exprCallHasMethod (node , testVar , "Parallel" )
190
197
}
191
198
192
- func methodParallelIsCalledInTestFunction (node ast.Node ) bool {
193
- return exprCallHasMethod (node , "Parallel" )
199
+ func methodParallelIsCalledInTestFunction (node ast.Node , testVar string ) bool {
200
+ return exprCallHasMethod (node , testVar , "Parallel" )
194
201
}
195
202
196
- func methodRunIsCalledInRangeStatement (node ast.Node ) bool {
197
- return exprCallHasMethod (node , "Run" )
203
+ func methodRunIsCalledInRangeStatement (node ast.Node , testVar string ) bool {
204
+ return exprCallHasMethod (node , testVar , "Run" )
198
205
}
199
206
200
- func methodRunIsCalledInTestFunction (node ast.Node ) bool {
201
- return exprCallHasMethod (node , "Run" )
207
+ func methodRunIsCalledInTestFunction (node ast.Node , testVar string ) bool {
208
+ return exprCallHasMethod (node , testVar , "Run" )
202
209
}
203
- func exprCallHasMethod (node ast.Node , methodName string ) bool {
210
+ func exprCallHasMethod (node ast.Node , receiverName , methodName string ) bool {
204
211
// nolint: gocritic
205
212
switch n := node .(type ) {
206
213
case * ast.CallExpr :
207
214
if fun , ok := n .Fun .(* ast.SelectorExpr ); ok {
208
- return fun .Sel .Name == methodName
215
+ if receiver , ok := fun .X .(* ast.Ident ); ok {
216
+ return receiver .Name == receiverName && fun .Sel .Name == methodName
217
+ }
209
218
}
210
219
}
211
220
return false
212
221
}
213
222
223
+ // In an expression of the form t.Run(x, func(q *testing.T) {...}), return the
224
+ // value "q". In _most_ code, the name is probably t, but we shouldn't just
225
+ // assume.
226
+ func getRunCallbackParameterName (node ast.Node ) string {
227
+ if n , ok := node .(* ast.CallExpr ); ok {
228
+ if len (n .Args ) < 2 {
229
+ // We want argument #2, but this call doesn't have two
230
+ // arguments. Maybe it's not really t.Run.
231
+ return ""
232
+ }
233
+ funcArg := n .Args [1 ]
234
+ if fun , ok := funcArg .(* ast.FuncLit ); ok {
235
+ if len (fun .Type .Params .List ) < 1 {
236
+ // Subtest function doesn't have any parameters.
237
+ return ""
238
+ }
239
+ firstArg := fun .Type .Params .List [0 ]
240
+ // We'll assume firstArg.Type is *testing.T.
241
+ if len (firstArg .Names ) < 1 {
242
+ return ""
243
+ }
244
+ return firstArg .Names [0 ].Name
245
+ }
246
+ }
247
+ return ""
248
+ }
249
+
214
250
// Gets the object name `tc` from method t.Run(tc.Foo, func(t *testing.T)
215
251
func methodRunFirstArgumentObjectName (node ast.Node ) string {
216
252
// nolint: gocritic
@@ -227,30 +263,31 @@ func methodRunFirstArgumentObjectName(node ast.Node) string {
227
263
return ""
228
264
}
229
265
230
- // Checks if the function has the param type *testing.T)
231
- func isTestFunction (funcDecl * ast.FuncDecl ) bool {
266
+ // Checks if the function has the param type *testing.T; if it does, then the
267
+ // parameter name is returned, too.
268
+ func isTestFunction (funcDecl * ast.FuncDecl ) (bool , string ) {
232
269
testMethodPackageType := "testing"
233
270
testMethodStruct := "T"
234
271
testPrefix := "Test"
235
272
236
273
if ! strings .HasPrefix (funcDecl .Name .Name , testPrefix ) {
237
- return false
274
+ return false , ""
238
275
}
239
276
240
277
if funcDecl .Type .Params != nil && len (funcDecl .Type .Params .List ) != 1 {
241
- return false
278
+ return false , ""
242
279
}
243
280
244
281
param := funcDecl .Type .Params .List [0 ]
245
282
if starExp , ok := param .Type .(* ast.StarExpr ); ok {
246
283
if selectExpr , ok := starExp .X .(* ast.SelectorExpr ); ok {
247
284
if selectExpr .Sel .Name == testMethodStruct {
248
285
if s , ok := selectExpr .X .(* ast.Ident ); ok {
249
- return s .Name == testMethodPackageType
286
+ return s .Name == testMethodPackageType , param . Names [ 0 ]. Name
250
287
}
251
288
}
252
289
}
253
290
}
254
291
255
- return false
292
+ return false , ""
256
293
}
0 commit comments