Skip to content

Commit 5d1264e

Browse files
authored
Merge pull request #7 from moricho/namedtrun
Supports named t.Run functions
2 parents 4ff8963 + b499386 commit 5d1264e

File tree

5 files changed

+270
-75
lines changed

5 files changed

+270
-75
lines changed

pkg/ssafunc/ssafunc.go

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
package ssafunc
2+
3+
import (
4+
"go/types"
5+
6+
"github.com/gostaticanalysis/analysisutil"
7+
"github.com/moricho/tparallel/pkg/ssainstr"
8+
"golang.org/x/tools/go/ssa"
9+
)
10+
11+
// IsDeferCalled returns whether the given ssa.Function calls `defer`
12+
func IsDeferCalled(f *ssa.Function) bool {
13+
for _, block := range f.Blocks {
14+
for _, instr := range block.Instrs {
15+
switch instr.(type) {
16+
case *ssa.Defer:
17+
return true
18+
}
19+
}
20+
}
21+
return false
22+
}
23+
24+
// IsCalled returns whether the given ssa.Function calls `fn` func
25+
func IsCalled(f *ssa.Function, fn *types.Func) bool {
26+
block := f.Blocks[0]
27+
for _, instr := range block.Instrs {
28+
called := analysisutil.Called(instr, nil, fn)
29+
if _, ok := ssainstr.LookupCalled(instr, fn); ok || called {
30+
return true
31+
}
32+
}
33+
return false
34+
}

pkg/ssainstr/ssainstr.go

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
package ssainstr
2+
3+
import (
4+
"go/types"
5+
6+
"github.com/gostaticanalysis/analysisutil"
7+
"golang.org/x/tools/go/ssa"
8+
)
9+
10+
// LookupCalled looks up ssa.Instruction that call the `fn` func in the given instr
11+
func LookupCalled(instr ssa.Instruction, fn *types.Func) ([]ssa.Instruction, bool) {
12+
instrs := []ssa.Instruction{}
13+
14+
call, ok := instr.(ssa.CallInstruction)
15+
if !ok {
16+
return instrs, false
17+
}
18+
19+
ssaCall := call.Value()
20+
if ssaCall == nil {
21+
return instrs, false
22+
}
23+
common := ssaCall.Common()
24+
if common == nil {
25+
return instrs, false
26+
}
27+
val := common.Value
28+
29+
called := false
30+
switch fnval := val.(type) {
31+
case *ssa.Function:
32+
for _, block := range fnval.Blocks {
33+
for _, instr := range block.Instrs {
34+
if analysisutil.Called(instr, nil, fn) {
35+
called = true
36+
instrs = append(instrs, instr)
37+
}
38+
}
39+
}
40+
}
41+
42+
return instrs, called
43+
}
44+
45+
// HasArgs returns whether the given ssa.Instruction has `typ` type args
46+
func HasArgs(instr ssa.Instruction, typ types.Type) bool {
47+
call, ok := instr.(ssa.CallInstruction)
48+
if !ok {
49+
return false
50+
}
51+
52+
ssaCall := call.Value()
53+
if ssaCall == nil {
54+
return false
55+
}
56+
57+
for _, arg := range ssaCall.Call.Args {
58+
if types.Identical(arg.Type(), typ) {
59+
return true
60+
}
61+
}
62+
return false
63+
}

testdata/src/test/named_trun_test.go

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
package test
2+
3+
import "testing"
4+
5+
func tRun1(t *testing.T) {
6+
t.Parallel()
7+
8+
t.Run("Named_tRun1_Sub1", func(t *testing.T) {
9+
call("Named_tRun1_Sub")
10+
})
11+
12+
t.Run("Named_tRun1_Sub2", func(t *testing.T) {
13+
call("Named_tRun1_Sub2")
14+
})
15+
}
16+
17+
func tRun2(t *testing.T) {
18+
t.Run("Named_tRun2_Sub", func(t *testing.T) {
19+
call("Named_tRun2_Sub")
20+
})
21+
}
22+
23+
func tRun3(t *testing.T) {
24+
t.Run("Named_tRun3_Sub", func(t *testing.T) {
25+
t.Parallel()
26+
call("Named_tRun3_Sub")
27+
})
28+
}
29+
30+
func tRun4(t *testing.T) {
31+
t.Run("Named_tRun4_Sub", func(t *testing.T) {
32+
t.Parallel()
33+
call("Named_tRun4_Sub")
34+
})
35+
}
36+
37+
func tRun5(t *testing.T) {
38+
t.Parallel()
39+
40+
t.Run("Named_tRun5_Sub1", func(t *testing.T) {
41+
t.Parallel()
42+
call("Named_tRun5_Sub1")
43+
})
44+
45+
t.Run("Named_tRun5_Sub2", func(t *testing.T) {
46+
t.Parallel()
47+
call("Named_tRun5_Sub2")
48+
})
49+
}
50+
51+
func tRun6(t *testing.T) {
52+
t.Run("Named_tRun6_Sub1", func(t *testing.T) {
53+
t.Parallel()
54+
call("Named_tRun6_Sub1")
55+
})
56+
}
57+
58+
func Test_Named_tRun1(t *testing.T) { // want "Test_Named_tRun1's subtests should call t.Parallel"
59+
call("Test_Named_tRun1")
60+
61+
tRun1(t)
62+
}
63+
64+
func Test_Named_tRun2(t *testing.T) { // want "Test_Named_tRun2's subtests should call t.Parallel"
65+
t.Parallel()
66+
67+
call("Test_Named_tRun2")
68+
69+
tRun2(t)
70+
}
71+
72+
func Test_Named_tRun3(t *testing.T) { // want "Test_Named_tRun3 should call t.Parallel on the top level as well as its subtests"
73+
call("Test_Named_tRun3")
74+
75+
tRun3(t)
76+
}
77+
78+
func Test_Named_tRun4(t *testing.T) { // OK
79+
t.Parallel()
80+
81+
call("Test_Named_tRun4")
82+
83+
tRun4(t)
84+
}
85+
86+
func Test_Named_tRun5(t *testing.T) { // OK
87+
call("Test_Named_tRun5")
88+
89+
tRun5(t)
90+
}
91+
92+
func Test_Named_tRun6(t *testing.T) { // OK
93+
t.Parallel()
94+
call("Test_Named_tRun6")
95+
96+
tRun6(t)
97+
98+
t.Run("Named_tRun6_Sub2", func(t *testing.T) {
99+
t.Parallel()
100+
call("Named_tRun6_Sub2")
101+
})
102+
103+
}

testmap.go

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
package tparallel
2+
3+
import (
4+
"go/types"
5+
"strings"
6+
7+
"github.com/gostaticanalysis/analysisutil"
8+
"golang.org/x/tools/go/analysis/passes/buildssa"
9+
"golang.org/x/tools/go/ssa"
10+
11+
"github.com/moricho/tparallel/pkg/ssainstr"
12+
)
13+
14+
// getTestMap gets a set of a top-level test and its sub-tests
15+
func getTestMap(ssaanalyzer *buildssa.SSA, testTyp types.Type) map[*ssa.Function][]*ssa.Function {
16+
testMap := map[*ssa.Function][]*ssa.Function{}
17+
18+
trun := analysisutil.MethodOf(testTyp, "Run")
19+
for _, f := range ssaanalyzer.SrcFuncs {
20+
if !strings.HasPrefix(f.Name(), "Test") || !(f.Parent() == (*ssa.Function)(nil)) {
21+
continue
22+
}
23+
testMap[f] = []*ssa.Function{}
24+
for _, block := range f.Blocks {
25+
for _, instr := range block.Instrs {
26+
called := analysisutil.Called(instr, nil, trun)
27+
28+
if !called && ssainstr.HasArgs(instr, types.NewPointer(testTyp)) {
29+
if instrs, ok := ssainstr.LookupCalled(instr, trun); ok {
30+
for _, v := range instrs {
31+
testMap[f] = appendTestMap(testMap[f], v)
32+
}
33+
}
34+
} else if called {
35+
testMap[f] = appendTestMap(testMap[f], instr)
36+
}
37+
}
38+
}
39+
}
40+
41+
return testMap
42+
}
43+
44+
// appendTestMap converts ssa.Instruction to ssa.Function and append it to a given sub-test slice
45+
func appendTestMap(subtests []*ssa.Function, instr ssa.Instruction) []*ssa.Function {
46+
call, ok := instr.(ssa.CallInstruction)
47+
if !ok {
48+
return subtests
49+
}
50+
51+
ssaCall := call.Value()
52+
for _, arg := range ssaCall.Call.Args {
53+
switch arg := arg.(type) {
54+
case *ssa.Function:
55+
subtests = append(subtests, arg)
56+
case *ssa.MakeClosure:
57+
fn, _ := arg.Fn.(*ssa.Function)
58+
subtests = append(subtests, fn)
59+
}
60+
}
61+
62+
return subtests
63+
}

tparallel.go

Lines changed: 7 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@ package tparallel
22

33
import (
44
"go/types"
5-
"strings"
65

76
"github.com/gostaticanalysis/analysisutil"
87
"golang.org/x/tools/go/analysis"
98
"golang.org/x/tools/go/analysis/passes/buildssa"
10-
"golang.org/x/tools/go/ssa"
9+
10+
"github.com/moricho/tparallel/pkg/ssafunc"
1111
)
1212

1313
const doc = "tparallel detects inappropriate usage of t.Parallel() method in your Go test codes."
@@ -38,20 +38,19 @@ func run(pass *analysis.Pass) (interface{}, error) {
3838
c, _, _ := types.LookupFieldOrMethod(testTyp, true, testPkg, "Cleanup")
3939
cleanup, _ := c.(*types.Func)
4040

41-
testMap := getTestMap(ssaanalyzer, testTyp) // {Test1: [TestSub1, TestSub2], Test2: [TestSub1, TestSub2, TestSub3], ...}
41+
testMap := getTestMap(ssaanalyzer, testTyp) // ex. {Test1: [TestSub1, TestSub2], Test2: [TestSub1, TestSub2, TestSub3], ...}
4242
for top, subs := range testMap {
43-
isParallelTop := isCalled(top, parallel)
44-
43+
isParallelTop := ssafunc.IsCalled(top, parallel)
4544
isPararellSub := false
4645
for _, sub := range subs {
47-
isPararellSub = isCalled(sub, parallel)
46+
isPararellSub = ssafunc.IsCalled(sub, parallel)
4847
if isPararellSub {
4948
break
5049
}
5150
}
5251

53-
if isDeferCalled(top) {
54-
useCleanup := isCalled(top, cleanup)
52+
if ssafunc.IsDeferCalled(top) {
53+
useCleanup := ssafunc.IsCalled(top, cleanup)
5554
if isPararellSub && !useCleanup {
5655
pass.Reportf(top.Pos(), "%s should use t.Cleanup instead of defer", top.Name())
5756
}
@@ -68,70 +67,3 @@ func run(pass *analysis.Pass) (interface{}, error) {
6867

6968
return nil, nil
7069
}
71-
72-
func isDeferCalled(f *ssa.Function) bool {
73-
for _, block := range f.Blocks {
74-
for _, instr := range block.Instrs {
75-
switch instr.(type) {
76-
case *ssa.Defer:
77-
return true
78-
}
79-
}
80-
}
81-
return false
82-
}
83-
84-
func isCalled(f *ssa.Function, typ *types.Func) bool {
85-
block := f.Blocks[0]
86-
for _, instr := range block.Instrs {
87-
called := analysisutil.Called(instr, nil, typ)
88-
if called {
89-
return true
90-
}
91-
}
92-
return false
93-
}
94-
95-
// getTestMap gets a set of a top-level test and its sub-tests
96-
func getTestMap(ssaanalyzer *buildssa.SSA, testTyp types.Type) map[*ssa.Function][]*ssa.Function {
97-
testMap := map[*ssa.Function][]*ssa.Function{}
98-
99-
trun := analysisutil.MethodOf(testTyp, "Run")
100-
for _, f := range ssaanalyzer.SrcFuncs {
101-
if !strings.HasPrefix(f.Name(), "Test") || !(f.Parent() == (*ssa.Function)(nil)) {
102-
continue
103-
}
104-
testMap[f] = []*ssa.Function{}
105-
for _, block := range f.Blocks {
106-
for _, instr := range block.Instrs {
107-
called := analysisutil.Called(instr, nil, trun)
108-
if called {
109-
testMap[f] = appendTestMap(testMap[f], instr)
110-
}
111-
}
112-
}
113-
}
114-
115-
return testMap
116-
}
117-
118-
// appendTestMap converts ssa.Instruction to ssa.Function and append it to a given sub-test slice
119-
func appendTestMap(subtests []*ssa.Function, instr ssa.Instruction) []*ssa.Function {
120-
call, ok := instr.(ssa.CallInstruction)
121-
if !ok {
122-
return subtests
123-
}
124-
125-
ssaCall := call.Value()
126-
for _, arg := range ssaCall.Call.Args {
127-
switch arg := arg.(type) {
128-
case *ssa.Function:
129-
subtests = append(subtests, arg)
130-
case *ssa.MakeClosure:
131-
fn, _ := arg.Fn.(*ssa.Function)
132-
subtests = append(subtests, fn)
133-
}
134-
}
135-
136-
return subtests
137-
}

0 commit comments

Comments
 (0)