Skip to content

Commit e2a5792

Browse files
authored
Merge pull request #42 from kunwardeep/issue-40
Added new test to validate issue 40
2 parents 9db3346 + ff09094 commit e2a5792

File tree

3 files changed

+266
-118
lines changed

3 files changed

+266
-118
lines changed

.github/workflows/test.yml

+12
Original file line numberDiff line numberDiff line change
@@ -23,3 +23,15 @@ jobs:
2323
run: go build -v
2424
- name: Test
2525
run: go test -v -race ./...
26+
27+
lint:
28+
name: Lint
29+
runs-on: ubuntu-latest
30+
steps:
31+
- name: Checkout source
32+
uses: actions/checkout@v3
33+
- name: golangci-lint
34+
uses: golangci/golangci-lint-action@v3
35+
with:
36+
version: latest
37+
args: --timeout=5m

pkg/paralleltest/paralleltest.go

+183-106
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ import (
1212

1313
const Doc = `check that tests use t.Parallel() method
1414
It also checks that the t.Parallel is used if multiple tests cases are run as part of single test.
15-
As part of ensuring parallel tests works as expected it checks for reinitialising of the range value
15+
As part of ensuring parallel tests works as expected it checks for reinitializing of the range value
1616
over the test cases.(https://tinyurl.com/y6555cy6)`
1717

1818
func NewAnalyzer() *analysis.Analyzer {
@@ -46,138 +46,183 @@ func newParallelAnalyzer() *parallelAnalyzer {
4646
return a
4747
}
4848

49-
func (a *parallelAnalyzer) run(pass *analysis.Pass) (interface{}, error) {
50-
inspector := inspector.New(pass.Files)
49+
type testFunctionAnalysis struct {
50+
funcHasParallelMethod,
51+
funcCantParallelMethod,
52+
rangeStatementOverTestCasesExists,
53+
rangeStatementHasParallelMethod,
54+
rangeStatementCantParallelMethod bool
55+
loopVariableUsedInRun *string
56+
numberOfTestRun int
57+
positionOfTestRunNode []ast.Node
58+
rangeNode ast.Node
59+
}
5160

52-
nodeFilter := []ast.Node{
53-
(*ast.FuncDecl)(nil),
54-
}
61+
type testRunAnalysis struct {
62+
hasParallel bool
63+
cantParallel bool
64+
numberOfTestRun int
65+
positionOfTestRunNode []ast.Node
66+
}
5567

56-
inspector.Preorder(nodeFilter, func(node ast.Node) {
57-
funcDecl := node.(*ast.FuncDecl)
58-
var funcHasParallelMethod,
59-
funcCantParallelMethod,
60-
rangeStatementOverTestCasesExists,
61-
rangeStatementHasParallelMethod,
62-
rangeStatementCantParallelMethod bool
63-
var loopVariableUsedInRun *string
64-
var numberOfTestRun int
65-
var positionOfTestRunNode []ast.Node
66-
var rangeNode ast.Node
67-
68-
// Check runs for test functions only
69-
isTest, testVar := isTestFunction(funcDecl)
70-
if !isTest {
71-
return
72-
}
68+
func (a *parallelAnalyzer) analyzeTestRun(pass *analysis.Pass, n ast.Node, testVar string) testRunAnalysis {
69+
var analysis testRunAnalysis
7370

74-
for _, l := range funcDecl.Body.List {
75-
switch v := l.(type) {
71+
if methodRunIsCalledInTestFunction(n, testVar) {
72+
innerTestVar := getRunCallbackParameterName(n)
73+
analysis.numberOfTestRun++
7674

77-
case *ast.ExprStmt:
78-
ast.Inspect(v, func(n ast.Node) bool {
79-
// Check if the test method is calling t.Parallel
80-
if !funcHasParallelMethod {
81-
funcHasParallelMethod = methodParallelIsCalledInTestFunction(n, testVar)
75+
if callExpr, ok := n.(*ast.CallExpr); ok && len(callExpr.Args) > 1 {
76+
if funcLit, ok := callExpr.Args[1].(*ast.FuncLit); ok {
77+
ast.Inspect(funcLit, func(p ast.Node) bool {
78+
if !analysis.hasParallel {
79+
analysis.hasParallel = methodParallelIsCalledInTestFunction(p, innerTestVar)
8280
}
83-
84-
// Check if the test calls t.Setenv, cannot be used in parallel tests or tests with parallel ancestors
85-
if !funcCantParallelMethod {
86-
funcCantParallelMethod = methodSetenvIsCalledInTestFunction(n, testVar)
81+
if !analysis.cantParallel {
82+
analysis.cantParallel = methodSetenvIsCalledInTestFunction(p, innerTestVar)
8783
}
88-
89-
// Check if the t.Run within the test function is calling t.Parallel
90-
if methodRunIsCalledInTestFunction(n, testVar) {
91-
// n is a call to t.Run; find out the name of the subtest's *testing.T parameter.
92-
innerTestVar := getRunCallbackParameterName(n)
93-
94-
hasParallel := false
95-
cantParallel := false
96-
numberOfTestRun++
97-
ast.Inspect(v, func(p ast.Node) bool {
98-
if !hasParallel {
99-
hasParallel = methodParallelIsCalledInTestFunction(p, innerTestVar)
100-
}
101-
if !cantParallel {
102-
cantParallel = methodSetenvIsCalledInTestFunction(p, innerTestVar)
84+
return true
85+
})
86+
} else if ident, ok := callExpr.Args[1].(*ast.Ident); ok {
87+
foundFunc := false
88+
for _, file := range pass.Files {
89+
for _, decl := range file.Decls {
90+
if funcDecl, ok := decl.(*ast.FuncDecl); ok && funcDecl.Name.Name == ident.Name {
91+
foundFunc = true
92+
isReceivingTestContext, testParamName := isFunctionReceivingTestContext(funcDecl)
93+
if isReceivingTestContext {
94+
ast.Inspect(funcDecl, func(p ast.Node) bool {
95+
if !analysis.hasParallel {
96+
analysis.hasParallel = methodParallelIsCalledInTestFunction(p, testParamName)
97+
}
98+
return true
99+
})
103100
}
104-
return true
105-
})
106-
if !hasParallel && !cantParallel {
107-
positionOfTestRunNode = append(positionOfTestRunNode, n)
108101
}
109102
}
110-
return true
111-
})
103+
}
104+
if !foundFunc {
105+
analysis.hasParallel = false
106+
}
107+
}
108+
}
112109

113-
// Check if the range over testcases is calling t.Parallel
114-
case *ast.RangeStmt:
115-
rangeNode = v
110+
if !analysis.hasParallel && !analysis.cantParallel {
111+
analysis.positionOfTestRunNode = append(analysis.positionOfTestRunNode, n)
112+
}
113+
}
116114

117-
var loopVars []types.Object
118-
for _, expr := range []ast.Expr{v.Key, v.Value} {
119-
if id, ok := expr.(*ast.Ident); ok {
120-
loopVars = append(loopVars, pass.TypesInfo.ObjectOf(id))
121-
}
122-
}
115+
return analysis
116+
}
123117

124-
ast.Inspect(v, func(n ast.Node) bool {
125-
// nolint: gocritic
126-
switch r := n.(type) {
127-
case *ast.ExprStmt:
128-
if methodRunIsCalledInRangeStatement(r.X, testVar) {
129-
// r.X is a call to t.Run; find out the name of the subtest's *testing.T parameter.
130-
innerTestVar := getRunCallbackParameterName(r.X)
118+
func (a *parallelAnalyzer) analyzeTestFunction(pass *analysis.Pass, funcDecl *ast.FuncDecl) {
119+
var analysis testFunctionAnalysis
131120

132-
rangeStatementOverTestCasesExists = true
121+
// Check runs for test functions only
122+
isTest, testVar := isTestFunction(funcDecl)
123+
if !isTest {
124+
return
125+
}
133126

134-
if !rangeStatementHasParallelMethod {
135-
rangeStatementHasParallelMethod = methodParallelIsCalledInMethodRun(r.X, innerTestVar)
136-
}
127+
for _, l := range funcDecl.Body.List {
128+
switch v := l.(type) {
129+
case *ast.ExprStmt:
130+
ast.Inspect(v, func(n ast.Node) bool {
131+
if !analysis.funcHasParallelMethod {
132+
analysis.funcHasParallelMethod = methodParallelIsCalledInTestFunction(n, testVar)
133+
}
134+
if !analysis.funcCantParallelMethod {
135+
analysis.funcCantParallelMethod = methodSetenvIsCalledInTestFunction(n, testVar)
136+
}
137+
runAnalysis := a.analyzeTestRun(pass, n, testVar)
138+
analysis.numberOfTestRun += runAnalysis.numberOfTestRun
139+
analysis.positionOfTestRunNode = append(analysis.positionOfTestRunNode, runAnalysis.positionOfTestRunNode...)
140+
return true
141+
})
142+
143+
case *ast.RangeStmt:
144+
analysis.rangeNode = v
145+
146+
var loopVars []types.Object
147+
for _, expr := range []ast.Expr{v.Key, v.Value} {
148+
if id, ok := expr.(*ast.Ident); ok {
149+
loopVars = append(loopVars, pass.TypesInfo.ObjectOf(id))
150+
}
151+
}
152+
153+
ast.Inspect(v, func(n ast.Node) bool {
154+
if r, ok := n.(*ast.ExprStmt); ok {
155+
if methodRunIsCalledInRangeStatement(r.X, testVar) {
156+
innerTestVar := getRunCallbackParameterName(r.X)
157+
analysis.rangeStatementOverTestCasesExists = true
137158

138-
if !rangeStatementCantParallelMethod {
139-
rangeStatementCantParallelMethod = methodSetenvIsCalledInMethodRun(r.X, innerTestVar)
159+
if !analysis.rangeStatementHasParallelMethod {
160+
analysis.rangeStatementHasParallelMethod = methodParallelIsCalledInMethodRun(r.X, innerTestVar)
161+
}
162+
if !analysis.rangeStatementCantParallelMethod {
163+
analysis.rangeStatementCantParallelMethod = methodSetenvIsCalledInMethodRun(r.X, innerTestVar)
164+
}
165+
if !a.ignoreLoopVar && analysis.loopVariableUsedInRun == nil {
166+
if run, ok := r.X.(*ast.CallExpr); ok {
167+
analysis.loopVariableUsedInRun = loopVarReferencedInRun(run, loopVars, pass.TypesInfo)
140168
}
169+
}
141170

142-
if !a.ignoreLoopVar && loopVariableUsedInRun == nil {
143-
if run, ok := r.X.(*ast.CallExpr); ok {
144-
loopVariableUsedInRun = loopVarReferencedInRun(run, loopVars, pass.TypesInfo)
145-
}
171+
// Check nested test runs
172+
if callExpr, ok := r.X.(*ast.CallExpr); ok && len(callExpr.Args) > 1 {
173+
if funcLit, ok := callExpr.Args[1].(*ast.FuncLit); ok {
174+
ast.Inspect(funcLit, func(p ast.Node) bool {
175+
runAnalysis := a.analyzeTestRun(pass, p, innerTestVar)
176+
analysis.numberOfTestRun += runAnalysis.numberOfTestRun
177+
analysis.positionOfTestRunNode = append(analysis.positionOfTestRunNode, runAnalysis.positionOfTestRunNode...)
178+
return true
179+
})
146180
}
147181
}
148182
}
149-
return true
150-
})
151-
}
183+
}
184+
return true
185+
})
152186
}
187+
}
153188

154-
// Descendents which call Setenv, also prevent tests from calling Parallel
155-
if rangeStatementCantParallelMethod {
156-
funcCantParallelMethod = true
157-
}
189+
if analysis.rangeStatementCantParallelMethod {
190+
analysis.funcCantParallelMethod = true
191+
}
158192

159-
if !a.ignoreMissing && !funcHasParallelMethod && !funcCantParallelMethod {
160-
pass.Reportf(node.Pos(), "Function %s missing the call to method parallel\n", funcDecl.Name.Name)
161-
}
193+
if !a.ignoreMissing && !analysis.funcHasParallelMethod && !analysis.funcCantParallelMethod {
194+
pass.Reportf(funcDecl.Pos(), "Function %s missing the call to method parallel\n", funcDecl.Name.Name)
195+
}
162196

163-
if rangeStatementOverTestCasesExists && rangeNode != nil {
164-
if !rangeStatementHasParallelMethod && !rangeStatementCantParallelMethod {
165-
if !a.ignoreMissing && !a.ignoreMissingSubtests {
166-
pass.Reportf(rangeNode.Pos(), "Range statement for test %s missing the call to method parallel in test Run\n", funcDecl.Name.Name)
167-
}
168-
} else if loopVariableUsedInRun != nil {
169-
pass.Reportf(rangeNode.Pos(), "Range statement for test %s does not reinitialise the variable %s\n", funcDecl.Name.Name, *loopVariableUsedInRun)
197+
if analysis.rangeStatementOverTestCasesExists && analysis.rangeNode != nil {
198+
if !analysis.rangeStatementHasParallelMethod && !analysis.rangeStatementCantParallelMethod {
199+
if !a.ignoreMissing && !a.ignoreMissingSubtests {
200+
pass.Reportf(analysis.rangeNode.Pos(), "Range statement for test %s missing the call to method parallel in test Run\n", funcDecl.Name.Name)
170201
}
202+
} else if analysis.loopVariableUsedInRun != nil && !a.ignoreLoopVar {
203+
pass.Reportf(analysis.rangeNode.Pos(), "Range statement for test %s does not reinitialise the variable %s\n", funcDecl.Name.Name, *analysis.loopVariableUsedInRun)
171204
}
205+
}
172206

173-
// Check if the t.Run is more than one as there is no point making one test parallel
174-
if !a.ignoreMissing && !a.ignoreMissingSubtests {
175-
if numberOfTestRun > 1 && len(positionOfTestRunNode) > 0 {
176-
for _, n := range positionOfTestRunNode {
177-
pass.Reportf(n.Pos(), "Function %s missing the call to method parallel in the test run\n", funcDecl.Name.Name)
178-
}
207+
if !a.ignoreMissing && !a.ignoreMissingSubtests {
208+
if analysis.numberOfTestRun > 1 && len(analysis.positionOfTestRunNode) > 0 {
209+
for _, n := range analysis.positionOfTestRunNode {
210+
pass.Reportf(n.Pos(), "Function %s missing the call to method parallel in the test run\n", funcDecl.Name.Name)
179211
}
180212
}
213+
}
214+
}
215+
216+
func (a *parallelAnalyzer) run(pass *analysis.Pass) (interface{}, error) {
217+
inspector := inspector.New(pass.Files)
218+
219+
nodeFilter := []ast.Node{
220+
(*ast.FuncDecl)(nil),
221+
}
222+
223+
inspector.Preorder(nodeFilter, func(node ast.Node) {
224+
funcDecl := node.(*ast.FuncDecl)
225+
a.analyzeTestFunction(pass, funcDecl)
181226
})
182227

183228
return nil, nil
@@ -267,8 +312,38 @@ func getRunCallbackParameterName(node ast.Node) string {
267312
return ""
268313
}
269314

270-
// Checks if the function has the param type *testing.T; if it does, then the
271-
// parameter name is returned, too.
315+
// isFunctionReceivingTestContext checks if a function declaration receives a *testing.T parameter
316+
// Returns (true, paramName) if it does, (false, "") if it doesn't
317+
func isFunctionReceivingTestContext(funcDecl *ast.FuncDecl) (bool, string) {
318+
testMethodPackageType := "testing"
319+
testMethodStruct := "T"
320+
321+
if funcDecl.Type.Params != nil && len(funcDecl.Type.Params.List) != 1 {
322+
return false, ""
323+
}
324+
325+
param := funcDecl.Type.Params.List[0]
326+
if starExp, ok := param.Type.(*ast.StarExpr); ok {
327+
if selectExpr, ok := starExp.X.(*ast.SelectorExpr); ok {
328+
if selectExpr.Sel.Name == testMethodStruct {
329+
if s, ok := selectExpr.X.(*ast.Ident); ok {
330+
if len(param.Names) > 0 {
331+
return s.Name == testMethodPackageType, param.Names[0].Name
332+
}
333+
}
334+
}
335+
}
336+
}
337+
338+
return false, ""
339+
}
340+
341+
// isTestFunction checks if a function declaration is a test function
342+
// A test function must:
343+
// 1. Start with "Test"
344+
// 2. Have exactly one parameter
345+
// 3. Have that parameter be of type *testing.T
346+
// Returns (true, paramName) if it is a test function, (false, "") if it isn't
272347
func isTestFunction(funcDecl *ast.FuncDecl) (bool, string) {
273348
testMethodPackageType := "testing"
274349
testMethodStruct := "T"
@@ -298,6 +373,8 @@ func isTestFunction(funcDecl *ast.FuncDecl) (bool, string) {
298373
return false, ""
299374
}
300375

376+
// loopVarReferencedInRun checks if a loop variable is referenced within a test run
377+
// This is important for detecting potential race conditions in parallel tests
301378
func loopVarReferencedInRun(call *ast.CallExpr, vars []types.Object, typeInfo *types.Info) (found *string) {
302379
if len(call.Args) != 2 {
303380
return

0 commit comments

Comments
 (0)