Skip to content

Commit 08e0baa

Browse files
committed
Honor @⁠Primary for test Bean Overrides such as @⁠MockitoBean
Spring Boot has honored @⁠Primary for selecting which candidate bean @⁠MockBean and @⁠SpyBean should mock or spy since Spring Boot 1.4.3; however, the support for @⁠Primary was not ported from Spring Boot to Spring Framework's new Bean Overrides feature in the TestContext framework. To address that, this commit introduces support for @⁠Primary for selecting bean overrides -- for example, for annotations such as @⁠TestBean, @⁠MockitoBean, and @⁠MockitoSpyBean. See spring-projects/spring-boot#7621 Closes gh-33819
1 parent 9166688 commit 08e0baa

File tree

4 files changed

+106
-22
lines changed

4 files changed

+106
-22
lines changed

spring-test/src/main/java/org/springframework/test/context/bean/override/BeanOverrideBeanFactoryPostProcessor.java

+59-14
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
import org.springframework.beans.factory.BeanFactory;
2727
import org.springframework.beans.factory.BeanFactoryUtils;
2828
import org.springframework.beans.factory.FactoryBean;
29+
import org.springframework.beans.factory.NoUniqueBeanDefinitionException;
2930
import org.springframework.beans.factory.config.BeanDefinition;
3031
import org.springframework.beans.factory.config.BeanFactoryPostProcessor;
3132
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
@@ -219,29 +220,43 @@ else if (Boolean.getBoolean(AbstractAotProcessor.AOT_PROCESSING)) {
219220
*/
220221
private void wrapBean(ConfigurableListableBeanFactory beanFactory, BeanOverrideHandler handler) {
221222
String beanName = handler.getBeanName();
223+
ResolvableType beanType = handler.getBeanType();
224+
222225
if (beanName == null) {
226+
// We are wrapping an existing bean by-type.
223227
Set<String> candidateNames = getExistingBeanNamesByType(beanFactory, handler, true);
224228
int candidateCount = candidateNames.size();
225-
if (candidateCount != 1) {
226-
Field field = handler.getField();
227-
throw new IllegalStateException("""
228-
Unable to select a bean to override by wrapping: found %d bean instances of type %s \
229-
(as required by annotated field '%s.%s')%s"""
230-
.formatted(candidateCount, handler.getBeanType(),
231-
field.getDeclaringClass().getSimpleName(), field.getName(),
232-
(candidateCount > 0 ? ": " + candidateNames : "")));
229+
if (candidateCount == 1) {
230+
beanName = candidateNames.iterator().next();
231+
}
232+
else {
233+
String primaryCandidate = determinePrimaryCandidate(beanFactory, candidateNames, beanType.toClass());
234+
if (primaryCandidate != null) {
235+
beanName = primaryCandidate;
236+
}
237+
else {
238+
Field field = handler.getField();
239+
throw new IllegalStateException("""
240+
Unable to select a bean to override by wrapping: found %d bean instances of type %s \
241+
(as required by annotated field '%s.%s')%s"""
242+
.formatted(candidateCount, beanType, field.getDeclaringClass().getSimpleName(),
243+
field.getName(), (candidateCount > 0 ? ": " + candidateNames : "")));
244+
245+
}
233246
}
234-
beanName = BeanFactoryUtils.transformedBeanName(candidateNames.iterator().next());
247+
beanName = BeanFactoryUtils.transformedBeanName(beanName);
235248
}
236249
else {
250+
// We are wrapping an existing bean by-name.
237251
Set<String> candidates = getExistingBeanNamesByType(beanFactory, handler, false);
238252
if (!candidates.contains(beanName)) {
239253
throw new IllegalStateException("""
240254
Unable to override bean by wrapping: there is no existing bean \
241255
with name [%s] and type [%s]."""
242-
.formatted(beanName, handler.getBeanType()));
256+
.formatted(beanName, beanType));
243257
}
244258
}
259+
245260
validateBeanDefinition(beanFactory, beanName);
246261
this.beanOverrideRegistry.registerBeanOverrideHandler(handler, beanName);
247262
}
@@ -250,26 +265,32 @@ private void wrapBean(ConfigurableListableBeanFactory beanFactory, BeanOverrideH
250265
private String getBeanNameForType(ConfigurableListableBeanFactory beanFactory, BeanOverrideHandler handler,
251266
boolean requireExistingBean) {
252267

268+
Field field = handler.getField();
269+
ResolvableType beanType = handler.getBeanType();
270+
253271
Set<String> candidateNames = getExistingBeanNamesByType(beanFactory, handler, true);
254272
int candidateCount = candidateNames.size();
255273
if (candidateCount == 1) {
256274
return candidateNames.iterator().next();
257275
}
258276
else if (candidateCount == 0) {
259277
if (requireExistingBean) {
260-
Field field = handler.getField();
261278
throw new IllegalStateException(
262279
"Unable to override bean: no beans of type %s (as required by annotated field '%s.%s')"
263-
.formatted(handler.getBeanType(), field.getDeclaringClass().getSimpleName(), field.getName()));
280+
.formatted(beanType, field.getDeclaringClass().getSimpleName(), field.getName()));
264281
}
265282
return null;
266283
}
267284

268-
Field field = handler.getField();
285+
String primaryCandidate = determinePrimaryCandidate(beanFactory, candidateNames, beanType.toClass());
286+
if (primaryCandidate != null) {
287+
return primaryCandidate;
288+
}
289+
269290
throw new IllegalStateException("""
270291
Unable to select a bean to override: found %s beans of type %s \
271292
(as required by annotated field '%s.%s'): %s"""
272-
.formatted(candidateCount, handler.getBeanType(), field.getDeclaringClass().getSimpleName(),
293+
.formatted(candidateCount, beanType, field.getDeclaringClass().getSimpleName(),
273294
field.getName(), candidateNames));
274295
}
275296

@@ -310,6 +331,30 @@ private Set<String> getExistingBeanNamesByType(ConfigurableListableBeanFactory b
310331
return beanNames;
311332
}
312333

334+
@Nullable
335+
private static String determinePrimaryCandidate(
336+
ConfigurableListableBeanFactory beanFactory, Set<String> candidateBeanNames, Class<?> beanType) {
337+
338+
if (candidateBeanNames.isEmpty()) {
339+
return null;
340+
}
341+
342+
String primaryBeanName = null;
343+
for (String candidateBeanName : candidateBeanNames) {
344+
if (beanFactory.containsBeanDefinition(candidateBeanName)) {
345+
BeanDefinition beanDefinition = beanFactory.getBeanDefinition(candidateBeanName);
346+
if (beanDefinition.isPrimary()) {
347+
if (primaryBeanName != null) {
348+
throw new NoUniqueBeanDefinitionException(beanType, candidateBeanNames.size(),
349+
"more than one 'primary' bean found among candidates: " + candidateBeanNames);
350+
}
351+
primaryBeanName = candidateBeanName;
352+
}
353+
}
354+
}
355+
return primaryBeanName;
356+
}
357+
313358
/**
314359
* Create a pseudo-{@link BeanDefinition} for the supplied {@link BeanOverrideHandler},
315360
* whose {@linkplain RootBeanDefinition#getTargetType() target type} and

spring-test/src/test/java/org/springframework/test/context/bean/override/BeanOverrideBeanFactoryPostProcessorTests.java

+47
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626

2727
import org.springframework.beans.BeanWrapper;
2828
import org.springframework.beans.factory.FactoryBean;
29+
import org.springframework.beans.factory.NoUniqueBeanDefinitionException;
2930
import org.springframework.beans.factory.annotation.Qualifier;
3031
import org.springframework.beans.factory.config.BeanDefinition;
3132
import org.springframework.beans.factory.config.BeanFactoryPostProcessor;
@@ -42,6 +43,7 @@
4243
import org.springframework.util.Assert;
4344

4445
import static org.assertj.core.api.Assertions.assertThat;
46+
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
4547
import static org.assertj.core.api.Assertions.assertThatIllegalStateException;
4648
import static org.assertj.core.api.Assertions.assertThatNoException;
4749
import static org.mockito.Mockito.mock;
@@ -167,6 +169,41 @@ void replaceBeanByTypeWithMultipleMatchesAndFieldNameAsFallbackQualifierMatches(
167169
assertThat(context.getBean("counter")).isSameAs(42);
168170
}
169171

172+
@Test // gh-33819
173+
void replaceBeanByTypeWithMultipleCandidatesAndOnePrimary() {
174+
AnnotationConfigApplicationContext context = createContext(TestBeanByTypeTestCase.class);
175+
context.registerBean("description1", String.class, () -> "one");
176+
RootBeanDefinition beanDefinition2 = new RootBeanDefinition(String.class);
177+
beanDefinition2.getConstructorArgumentValues().addIndexedArgumentValue(0, "two");
178+
beanDefinition2.setPrimary(true);
179+
context.registerBeanDefinition("description2", beanDefinition2);
180+
context.refresh();
181+
182+
assertThat(context.getBean("description1", String.class)).isEqualTo("one");
183+
assertThat(context.getBean("description2", String.class)).isEqualTo("overridden");
184+
assertThat(context.getBean(String.class)).isEqualTo("overridden");
185+
}
186+
187+
@Test // gh-33819
188+
void replaceBeanByTypeWithMultipleCandidatesAndMultiplePrimaryBeansFails() {
189+
AnnotationConfigApplicationContext context = createContext(TestBeanByTypeTestCase.class);
190+
191+
RootBeanDefinition beanDefinition1 = new RootBeanDefinition(String.class);
192+
beanDefinition1.getConstructorArgumentValues().addIndexedArgumentValue(0, "one");
193+
beanDefinition1.setPrimary(true);
194+
context.registerBeanDefinition("description1", beanDefinition1);
195+
196+
RootBeanDefinition beanDefinition2 = new RootBeanDefinition(String.class);
197+
beanDefinition2.getConstructorArgumentValues().addIndexedArgumentValue(0, "two");
198+
beanDefinition2.setPrimary(true);
199+
context.registerBeanDefinition("description2", beanDefinition2);
200+
201+
assertThatExceptionOfType(NoUniqueBeanDefinitionException.class)
202+
.isThrownBy(context::refresh)
203+
.withMessage("No qualifying bean of type 'java.lang.String' available: " +
204+
"more than one 'primary' bean found among candidates: [description1, description2]");
205+
}
206+
170207
@Test
171208
void createOrReplaceBeanByNameWithMatchingBeanDefinition() {
172209
AnnotationConfigApplicationContext context = createContext(CaseByNameWithReplaceOrCreateStrategy.class);
@@ -428,6 +465,16 @@ static String descriptionBean() {
428465
}
429466
}
430467

468+
static class TestBeanByTypeTestCase {
469+
470+
@TestBean
471+
String description;
472+
473+
static String description() {
474+
return "overridden";
475+
}
476+
}
477+
431478
static class TestFactoryBean implements FactoryBean<Object> {
432479

433480
@Override

spring-test/src/test/java/org/springframework/test/context/bean/override/mockito/integration/MockitoBeanWithMultipleExistingBeansAndOnePrimaryIntegrationTests.java

-4
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
package org.springframework.test.context.bean.override.mockito.integration;
1818

19-
import org.junit.jupiter.api.Disabled;
2019
import org.junit.jupiter.api.Test;
2120
import org.junit.jupiter.api.extension.ExtendWith;
2221
import org.mockito.MockingDetails;
@@ -27,7 +26,6 @@
2726
import org.springframework.context.annotation.Configuration;
2827
import org.springframework.context.annotation.Import;
2928
import org.springframework.context.annotation.Primary;
30-
import org.springframework.test.context.aot.DisabledInAotMode;
3129
import org.springframework.test.context.bean.override.example.ExampleGenericServiceCaller;
3230
import org.springframework.test.context.bean.override.example.IntegerExampleGenericService;
3331
import org.springframework.test.context.bean.override.example.StringExampleGenericService;
@@ -49,8 +47,6 @@
4947
* @see MockitoBeanWithMultipleExistingBeansAndExplicitBeanNameIntegrationTests
5048
* @see MockitoBeanWithMultipleExistingBeansAndExplicitQualifierIntegrationTests
5149
*/
52-
@Disabled("Disabled until @Primary is supported for BeanOverrideStrategy.REPLACE_OR_CREATE")
53-
@DisabledInAotMode
5450
@ExtendWith(SpringExtension.class)
5551
class MockitoBeanWithMultipleExistingBeansAndOnePrimaryIntegrationTests {
5652

spring-test/src/test/java/org/springframework/test/context/bean/override/mockito/integration/MockitoSpyBeanWithMultipleExistingBeansAndOnePrimaryIntegrationTests.java

-4
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@
1616

1717
package org.springframework.test.context.bean.override.mockito.integration;
1818

19-
import org.junit.jupiter.api.Disabled;
2019
import org.junit.jupiter.api.Test;
2120
import org.junit.jupiter.api.extension.ExtendWith;
2221
import org.mockito.MockingDetails;
@@ -27,7 +26,6 @@
2726
import org.springframework.context.annotation.Configuration;
2827
import org.springframework.context.annotation.Import;
2928
import org.springframework.context.annotation.Primary;
30-
import org.springframework.test.context.aot.DisabledInAotMode;
3129
import org.springframework.test.context.bean.override.example.ExampleGenericServiceCaller;
3230
import org.springframework.test.context.bean.override.example.IntegerExampleGenericService;
3331
import org.springframework.test.context.bean.override.example.StringExampleGenericService;
@@ -48,8 +46,6 @@
4846
* @see MockitoSpyBeanWithMultipleExistingBeansAndExplicitBeanNameIntegrationTests
4947
* @see MockitoSpyBeanWithMultipleExistingBeansAndExplicitQualifierIntegrationTests
5048
*/
51-
@Disabled("Disabled until @Primary is supported for @MockitoSpyBean")
52-
@DisabledInAotMode
5349
@ExtendWith(SpringExtension.class)
5450
class MockitoSpyBeanWithMultipleExistingBeansAndOnePrimaryIntegrationTests {
5551

0 commit comments

Comments
 (0)