Skip to content

Commit 92c106d

Browse files
graememorganError Prone Team
authored andcommitted
Encourage when/thenReturn over doReturn/when.
I've been very over-conservative about the heuristics on what to avoid, but nonetheless here's a huge sample flume: unknown commit PiperOrigin-RevId: 627396784
1 parent 07c1a7c commit 92c106d

File tree

4 files changed

+342
-0
lines changed

4 files changed

+342
-0
lines changed
Lines changed: 198 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
/*
2+
* Copyright 2024 The Error Prone Authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package com.google.errorprone.bugpatterns;
17+
18+
import static com.google.common.collect.Iterables.getLast;
19+
import static com.google.errorprone.BugPattern.SeverityLevel.WARNING;
20+
import static com.google.errorprone.matchers.Description.NO_MATCH;
21+
import static com.google.errorprone.matchers.Matchers.instanceMethod;
22+
import static com.google.errorprone.matchers.Matchers.staticMethod;
23+
import static com.google.errorprone.util.ASTHelpers.getReceiver;
24+
import static com.google.errorprone.util.ASTHelpers.getStartPosition;
25+
import static com.google.errorprone.util.ASTHelpers.getSymbol;
26+
import static com.google.errorprone.util.ASTHelpers.hasAnnotation;
27+
import static com.google.errorprone.util.ASTHelpers.isSameType;
28+
import static java.lang.String.format;
29+
30+
import com.google.common.collect.ImmutableMap;
31+
import com.google.common.collect.ImmutableSet;
32+
import com.google.errorprone.BugPattern;
33+
import com.google.errorprone.VisitorState;
34+
import com.google.errorprone.bugpatterns.BugChecker.CompilationUnitTreeMatcher;
35+
import com.google.errorprone.fixes.SuggestedFix;
36+
import com.google.errorprone.fixes.SuggestedFixes;
37+
import com.google.errorprone.matchers.Description;
38+
import com.google.errorprone.matchers.Matcher;
39+
import com.sun.source.tree.AssignmentTree;
40+
import com.sun.source.tree.CompilationUnitTree;
41+
import com.sun.source.tree.ExpressionTree;
42+
import com.sun.source.tree.MethodInvocationTree;
43+
import com.sun.source.tree.Tree;
44+
import com.sun.source.tree.VariableTree;
45+
import com.sun.source.util.TreePath;
46+
import com.sun.source.util.TreePathScanner;
47+
import com.sun.tools.javac.code.Symbol.VarSymbol;
48+
49+
/** A BugPattern; see the summary. */
50+
@BugPattern(
51+
severity = WARNING,
52+
summary = "Prefer using when/thenReturn over doReturn/when for additional type safety.")
53+
public final class MockitoDoSetup extends BugChecker implements CompilationUnitTreeMatcher {
54+
@Override
55+
public Description matchCompilationUnit(CompilationUnitTree tree, VisitorState state) {
56+
ImmutableSet<VarSymbol> spies = findSpies(state);
57+
new SuppressibleTreePathScanner<Void, Void>(state) {
58+
59+
@Override
60+
public Void visitMethodInvocation(MethodInvocationTree tree, Void unused) {
61+
handle(tree);
62+
return super.visitMethodInvocation(tree, null);
63+
}
64+
65+
private void handle(MethodInvocationTree tree) {
66+
if (!DO_STUBBER.matches(tree, state)) {
67+
return;
68+
}
69+
TreePath whenPath = getCurrentPath().getParentPath().getParentPath();
70+
Tree whenCall = whenPath.getLeaf();
71+
if (!(whenCall instanceof MethodInvocationTree)
72+
|| !INSTANCE_WHEN.matches((MethodInvocationTree) whenCall, state)) {
73+
return;
74+
}
75+
if (isSpy(((MethodInvocationTree) whenCall).getArguments().get(0))) {
76+
return;
77+
}
78+
Tree mockedMethod = whenPath.getParentPath().getParentPath().getLeaf();
79+
80+
if (!(mockedMethod instanceof MethodInvocationTree)) {
81+
return;
82+
}
83+
if (isSameType(
84+
getSymbol((MethodInvocationTree) mockedMethod).getReturnType(),
85+
state.getSymtab().voidType,
86+
state)) {
87+
return;
88+
}
89+
90+
SuggestedFix.Builder fix = SuggestedFix.builder();
91+
var when = SuggestedFixes.qualifyStaticImport("org.mockito.Mockito.when", fix, state);
92+
fix.replace(((MethodInvocationTree) whenCall).getMethodSelect(), when)
93+
.replace(state.getEndPosition(whenCall) - 1, state.getEndPosition(whenCall), "")
94+
.postfixWith(
95+
mockedMethod,
96+
format(
97+
").%s(%s)",
98+
NAME_MAPPINGS.get(getSymbol(tree).getSimpleName().toString()),
99+
getParameterSource(tree, state)));
100+
101+
state.reportMatch(describeMatch(tree, fix.build()));
102+
}
103+
104+
private boolean isSpy(ExpressionTree tree) {
105+
var symbol = getSymbol(tree);
106+
return symbol != null
107+
&& (spies.contains(symbol) || hasAnnotation(symbol, "org.mockito.Spy", state));
108+
}
109+
}.scan(state.getPath(), null);
110+
return NO_MATCH;
111+
}
112+
113+
private static String getParameterSource(MethodInvocationTree tree, VisitorState state) {
114+
return state
115+
.getSourceCode()
116+
.subSequence(
117+
getStartPosition(tree.getArguments().get(0)),
118+
state.getEndPosition(getLast(tree.getArguments())))
119+
.toString();
120+
}
121+
122+
private static ImmutableSet<VarSymbol> findSpies(VisitorState state) {
123+
// NOTES: This is extremely conservative in at least two ways.
124+
// 1) We ignore an entire mock if _any_ method is mocked to throw, not just the relevant method.
125+
// 2) We could still refactor if the thenThrow comes _after_, or if the _only_ call is
126+
// thenThrow.
127+
ImmutableSet.Builder<VarSymbol> spiesOrThrows = ImmutableSet.builder();
128+
new TreePathScanner<Void, Void>() {
129+
@Override
130+
public Void visitVariable(VariableTree tree, Void unused) {
131+
if (tree.getInitializer() != null && SPY.matches(tree.getInitializer(), state)) {
132+
spiesOrThrows.add(getSymbol(tree));
133+
}
134+
return super.visitVariable(tree, null);
135+
}
136+
137+
@Override
138+
public Void visitMethodInvocation(MethodInvocationTree tree, Void unused) {
139+
if (DO_THROW.matches(tree, state)) {
140+
var whenCall = getCurrentPath().getParentPath().getParentPath().getLeaf();
141+
if ((whenCall instanceof MethodInvocationTree)
142+
&& INSTANCE_WHEN.matches((MethodInvocationTree) whenCall, state)) {
143+
var whenTarget = getSymbol(((MethodInvocationTree) whenCall).getArguments().get(0));
144+
if (whenTarget instanceof VarSymbol) {
145+
spiesOrThrows.add((VarSymbol) whenTarget);
146+
}
147+
}
148+
}
149+
if (THEN_THROW.matches(tree, state)) {
150+
var receiver = getReceiver(tree);
151+
if (STATIC_WHEN.matches(receiver, state)) {
152+
var mock = getReceiver(((MethodInvocationTree) receiver).getArguments().get(0));
153+
var mockSymbol = getSymbol(mock);
154+
if (mockSymbol instanceof VarSymbol) {
155+
spiesOrThrows.add((VarSymbol) mockSymbol);
156+
}
157+
}
158+
}
159+
return super.visitMethodInvocation(tree, null);
160+
}
161+
162+
@Override
163+
public Void visitAssignment(AssignmentTree tree, Void unused) {
164+
if (SPY.matches(tree.getExpression(), state)) {
165+
var symbol = getSymbol(tree.getVariable());
166+
if (symbol instanceof VarSymbol) {
167+
spiesOrThrows.add((VarSymbol) symbol);
168+
}
169+
}
170+
return super.visitAssignment(tree, null);
171+
}
172+
}.scan(state.getPath().getCompilationUnit(), null);
173+
return spiesOrThrows.build();
174+
}
175+
176+
private static final ImmutableMap<String, String> NAME_MAPPINGS =
177+
ImmutableMap.of(
178+
"doAnswer", "thenAnswer",
179+
"doReturn", "thenReturn",
180+
"doThrow", "thenThrow");
181+
private static final Matcher<ExpressionTree> DO_STUBBER =
182+
staticMethod().onClass("org.mockito.Mockito").namedAnyOf(NAME_MAPPINGS.keySet());
183+
184+
private static final Matcher<ExpressionTree> INSTANCE_WHEN =
185+
instanceMethod().onDescendantOf("org.mockito.stubbing.Stubber").named("when");
186+
187+
private static final Matcher<ExpressionTree> SPY =
188+
staticMethod().onClass("org.mockito.Mockito").named("spy");
189+
190+
private static final Matcher<ExpressionTree> DO_THROW =
191+
staticMethod().onClass("org.mockito.Mockito").named("doThrow");
192+
193+
private static final Matcher<ExpressionTree> STATIC_WHEN =
194+
staticMethod().onClass("org.mockito.Mockito").named("when");
195+
196+
private static final Matcher<ExpressionTree> THEN_THROW =
197+
instanceMethod().onDescendantOf("org.mockito.stubbing.OngoingStubbing").named("thenThrow");
198+
}

core/src/main/java/com/google/errorprone/scanner/BuiltInCheckerSuppliers.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,7 @@
244244
import com.google.errorprone.bugpatterns.MixedDescriptors;
245245
import com.google.errorprone.bugpatterns.MixedMutabilityReturnType;
246246
import com.google.errorprone.bugpatterns.MockNotUsedInProduction;
247+
import com.google.errorprone.bugpatterns.MockitoDoSetup;
247248
import com.google.errorprone.bugpatterns.MockitoUsage;
248249
import com.google.errorprone.bugpatterns.ModifiedButNotUsed;
249250
import com.google.errorprone.bugpatterns.ModifyCollectionInEnhancedForLoop;
@@ -1169,6 +1170,7 @@ public static ScannerSupplier warningChecks() {
11691170
MissingBraces.class,
11701171
MissingDefault.class,
11711172
MixedArrayDimensions.class,
1173+
MockitoDoSetup.class,
11721174
MoreThanOneQualifier.class,
11731175
MultiVariableDeclaration.class,
11741176
MultipleTopLevelClasses.class,
Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
/*
2+
* Copyright 2024 The Error Prone Authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
package com.google.errorprone.bugpatterns;
17+
18+
import com.google.errorprone.BugCheckerRefactoringTestHelper;
19+
import org.junit.Test;
20+
import org.junit.runner.RunWith;
21+
import org.junit.runners.JUnit4;
22+
23+
@RunWith(JUnit4.class)
24+
public final class MockitoDoSetupTest {
25+
private final BugCheckerRefactoringTestHelper helper =
26+
BugCheckerRefactoringTestHelper.newInstance(MockitoDoSetup.class, getClass());
27+
28+
@Test
29+
public void happy() {
30+
helper
31+
.addInputLines(
32+
"Test.java",
33+
"import org.mockito.Mockito;",
34+
"public class Test {",
35+
" public int test(Test test) {",
36+
" Mockito.doReturn(1).when(test).test(null);",
37+
" return 1;",
38+
" }",
39+
"}")
40+
.addOutputLines(
41+
"Test.java",
42+
"import static org.mockito.Mockito.when;",
43+
"import org.mockito.Mockito;",
44+
"public class Test {",
45+
" public int test(Test test) {",
46+
" when(test.test(null)).thenReturn(1);",
47+
" return 1;",
48+
" }",
49+
"}")
50+
.doTest();
51+
}
52+
53+
@Test
54+
public void ignoresSpiesCreatedByAnnotation() {
55+
helper
56+
.addInputLines(
57+
"Test.java",
58+
"import org.mockito.Mockito;",
59+
"public class Test {",
60+
" @org.mockito.Spy Test test;",
61+
" public int test() {",
62+
" Mockito.doReturn(1).when(test).test();",
63+
" return 1;",
64+
" }",
65+
"}")
66+
.expectUnchanged()
67+
.doTest();
68+
}
69+
70+
@Test
71+
public void ignoresSpiesCreatedByStaticMethod() {
72+
helper
73+
.addInputLines(
74+
"Test.java",
75+
"import org.mockito.Mockito;",
76+
"public class Test {",
77+
" Test test = Mockito.spy(Test.class);",
78+
" public int test() {",
79+
" Mockito.doReturn(1).when(test).test();",
80+
" return 1;",
81+
" }",
82+
"}")
83+
.expectUnchanged()
84+
.doTest();
85+
}
86+
87+
@Test
88+
public void ignoresMocksConfiguredToThrow_viaThenThrow() {
89+
helper
90+
.addInputLines(
91+
"Test.java",
92+
"import org.mockito.Mockito;",
93+
"public class Test {",
94+
" public int test(Test test) {",
95+
" Mockito.doReturn(1).when(test).test(null);",
96+
" Mockito.when(test.test(null)).thenThrow(new Exception());",
97+
" return 1;",
98+
" }",
99+
"}")
100+
.expectUnchanged()
101+
.doTest();
102+
}
103+
104+
@Test
105+
public void ignoresMocksConfiguredToThrow_viaDoThrow() {
106+
helper
107+
.addInputLines(
108+
"Test.java",
109+
"import org.mockito.Mockito;",
110+
"public class Test {",
111+
" public int test(Test test) {",
112+
" Mockito.doReturn(1).when(test).test(null);",
113+
" Mockito.doThrow(new Exception()).when(test).test(null);",
114+
" return 1;",
115+
" }",
116+
"}")
117+
.expectUnchanged()
118+
.doTest();
119+
}
120+
}

docs/bugpattern/MockitoDoSetup.md

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
Prefer using the format
2+
3+
```java
4+
when(mock.mockedMethod(...)).thenReturn(returnValue);
5+
```
6+
7+
to initialise mocks, rather than,
8+
9+
```java
10+
doReturn(returnValue).when(mock).mockedMethod(...);
11+
```
12+
13+
Mockito recommends the `when`/`thenReturn` syntax as it is both more readable
14+
and provides type-safety: the return type of the stubbed method is checked
15+
against the stubbed value at compile time.
16+
17+
There are certain situations where `doReturn` is required:
18+
19+
* Overriding previous stubbing where the method will *throw*, as `when` makes
20+
an actual method call.
21+
* Overriding a `spy` where the method call where calling the spied method
22+
brings undesired side-effects.

0 commit comments

Comments
 (0)