Skip to content

Commit ef63ac9

Browse files
committed
Allow ArgumentsProvider implementations to use constructor injection
Issue: #4018
1 parent 54a6cb1 commit ef63ac9

File tree

6 files changed

+149
-32
lines changed

6 files changed

+149
-32
lines changed

documentation/src/docs/asciidoc/user-guide/writing-tests.adoc

+9
Original file line numberDiff line numberDiff line change
@@ -1984,6 +1984,15 @@ If you wish to implement a custom `ArgumentsProvider` that also consumes an anno
19841984
(like built-in providers such as `{ValueArgumentsProvider}` or `{CsvArgumentsProvider}`),
19851985
you have the possibility to extend the `{AnnotationBasedArgumentsProvider}` class.
19861986

1987+
Moreover, `ArgumentsProvider` implementations may declare constructor parameters in case
1988+
they need to be resolved by a registered `ParameterResolver` as demonstrated in the
1989+
following example.
1990+
1991+
[source,java,indent=0]
1992+
----
1993+
include::{testDir}/example/ParameterizedTestDemo.java[tags=ArgumentsProviderWithConstructorInjection_example]
1994+
----
1995+
19871996
[[writing-tests-parameterized-repeatable-sources]]
19881997
===== Multiple sources using repeatable annotations
19891998
Repeatable annotations provide a convenient way to specify multiple sources from

documentation/src/test/java/example/ParameterizedTestDemo.java

+23
Original file line numberDiff line numberDiff line change
@@ -348,6 +348,29 @@ public Stream<? extends Arguments> provideArguments(ExtensionContext context) {
348348
}
349349
// end::ArgumentsProvider_example[]
350350

351+
@ParameterizedTest
352+
@ArgumentsSource(MyArgumentsProviderWithConstructorInjection.class)
353+
void testWithArgumentsSourceWithConstructorInjection(String argument) {
354+
assertNotNull(argument);
355+
}
356+
357+
static
358+
// tag::ArgumentsProviderWithConstructorInjection_example[]
359+
public class MyArgumentsProviderWithConstructorInjection implements ArgumentsProvider {
360+
361+
private final TestInfo testInfo;
362+
363+
public MyArgumentsProviderWithConstructorInjection(TestInfo testInfo) {
364+
this.testInfo = testInfo;
365+
}
366+
367+
@Override
368+
public Stream<? extends Arguments> provideArguments(ExtensionContext context) {
369+
return Stream.of(Arguments.of(testInfo.getDisplayName()));
370+
}
371+
}
372+
// end::ArgumentsProviderWithConstructorInjection_example[]
373+
351374
// tag::ParameterResolver_example[]
352375
@BeforeEach
353376
void beforeEach(TestInfo testInfo) {

junit-jupiter-params/src/main/java/org/junit/jupiter/params/ParameterizedTestExtension.java

+31-14
Original file line numberDiff line numberDiff line change
@@ -13,8 +13,11 @@
1313
import static org.junit.platform.commons.support.AnnotationSupport.findAnnotation;
1414
import static org.junit.platform.commons.support.AnnotationSupport.findRepeatableAnnotations;
1515
import static org.junit.platform.commons.support.AnnotationSupport.isAnnotated;
16+
import static org.junit.platform.commons.util.CollectionUtils.getFirstElement;
1617

18+
import java.lang.reflect.Constructor;
1719
import java.lang.reflect.Method;
20+
import java.util.Optional;
1821
import java.util.concurrent.atomic.AtomicLong;
1922
import java.util.stream.Stream;
2023

@@ -27,9 +30,11 @@
2730
import org.junit.jupiter.params.provider.ArgumentsSource;
2831
import org.junit.jupiter.params.support.AnnotationConsumerInitializer;
2932
import org.junit.platform.commons.JUnitException;
30-
import org.junit.platform.commons.support.ReflectionSupport;
33+
import org.junit.platform.commons.PreconditionViolationException;
34+
import org.junit.platform.commons.support.ModifierSupport;
3135
import org.junit.platform.commons.util.ExceptionUtils;
3236
import org.junit.platform.commons.util.Preconditions;
37+
import org.junit.platform.commons.util.ReflectionUtils;
3338

3439
/**
3540
* @since 5.0
@@ -84,7 +89,7 @@ public Stream<TestTemplateInvocationContext> provideTestTemplateInvocationContex
8489
return findRepeatableAnnotations(templateMethod, ArgumentsSource.class)
8590
.stream()
8691
.map(ArgumentsSource::value)
87-
.map(this::instantiateArgumentsProvider)
92+
.map(clazz -> instantiateArgumentsProvider(clazz, extensionContext))
8893
.map(provider -> AnnotationConsumerInitializer.initialize(templateMethod, provider))
8994
.flatMap(provider -> arguments(provider, extensionContext))
9095
.map(arguments -> {
@@ -97,20 +102,32 @@ public Stream<TestTemplateInvocationContext> provideTestTemplateInvocationContex
97102
// @formatter:on
98103
}
99104

100-
@SuppressWarnings("ConstantConditions")
101-
private ArgumentsProvider instantiateArgumentsProvider(Class<? extends ArgumentsProvider> clazz) {
105+
private ArgumentsProvider instantiateArgumentsProvider(Class<? extends ArgumentsProvider> clazz,
106+
ExtensionContext extensionContext) {
107+
return extensionContext.getExecutableInvoker().invoke(findConstructor(ArgumentsProvider.class, clazz));
108+
}
109+
110+
@SuppressWarnings("unchecked")
111+
private static <T> Constructor<? extends T> findConstructor(Class<T> spiClass, Class<? extends T> clazz) {
112+
Optional<Constructor<?>> defaultConstructor = getFirstElement(
113+
ReflectionUtils.findConstructors(clazz, it -> it.getParameterCount() == 0));
114+
if (defaultConstructor.isPresent()) {
115+
return (Constructor<? extends T>) defaultConstructor.get();
116+
}
117+
if (ModifierSupport.isNotStatic(clazz)) {
118+
String message = String.format("The %s [%s] must be either a top-level class or a static nested class",
119+
spiClass.getSimpleName(), clazz.getName());
120+
throw new JUnitException(message);
121+
}
102122
try {
103-
return ReflectionSupport.newInstance(clazz);
123+
return ReflectionUtils.getDeclaredConstructor(clazz);
104124
}
105-
catch (Exception ex) {
106-
if (ex instanceof NoSuchMethodException) {
107-
String message = String.format("Failed to find a no-argument constructor for ArgumentsProvider [%s]. "
108-
+ "Please ensure that a no-argument constructor exists and "
109-
+ "that the class is either a top-level class or a static nested class",
110-
clazz.getName());
111-
throw new JUnitException(message, ex);
112-
}
113-
throw ex;
125+
catch (PreconditionViolationException ex) {
126+
String message = String.format(
127+
"Failed to find constructor for %s [%s]. "
128+
+ "Please ensure that a no-argument or a single constructor exists.",
129+
spiClass.getSimpleName(), clazz.getName());
130+
throw new JUnitException(message);
114131
}
115132
}
116133

junit-jupiter-params/src/main/java/org/junit/jupiter/params/provider/ArgumentsProvider.java

+3-1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
import org.apiguardian.api.API;
1818
import org.junit.jupiter.api.extension.ExtensionContext;
19+
import org.junit.jupiter.api.extension.ParameterResolver;
1920

2021
/**
2122
* An {@code ArgumentsProvider} is responsible for {@linkplain #provideArguments
@@ -25,7 +26,8 @@
2526
* <p>An {@code ArgumentsProvider} can be registered via the
2627
* {@link ArgumentsSource @ArgumentsSource} annotation.
2728
*
28-
* <p>Implementations must provide a no-args constructor.
29+
* <p>Implementations must provide a no-args constructor or a single unambiguous
30+
* constructor to use {@linkplain ParameterResolver parameter resolution}.
2931
*
3032
* @since 5.0
3133
* @see org.junit.jupiter.params.ParameterizedTest

jupiter-tests/src/test/java/org/junit/jupiter/params/ParameterizedTestExtensionTests.java

+31-17
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
import java.io.FileNotFoundException;
2121
import java.lang.reflect.AnnotatedElement;
22+
import java.lang.reflect.Constructor;
2223
import java.lang.reflect.Method;
2324
import java.util.Arrays;
2425
import java.util.Map;
@@ -40,6 +41,7 @@
4041
import org.junit.jupiter.params.provider.ArgumentsSource;
4142
import org.junit.platform.commons.JUnitException;
4243
import org.junit.platform.commons.PreconditionViolationException;
44+
import org.junit.platform.commons.util.ReflectionUtils;
4345
import org.junit.platform.engine.support.store.NamespacedHierarchicalStore;
4446

4547
/**
@@ -150,11 +152,14 @@ void throwsExceptionWhenArgumentsProviderIsNotStatic() {
150152

151153
var exception = assertThrows(JUnitException.class, stream::toArray);
152154

153-
assertArgumentsProviderInstantiationException(exception, NonStaticArgumentsProvider.class);
155+
assertThat(exception) //
156+
.hasMessage(String.format(
157+
"The ArgumentsProvider [%s] must be either a top-level class or a static nested class",
158+
NonStaticArgumentsProvider.class.getName()));
154159
}
155160

156161
@Test
157-
void throwsExceptionWhenArgumentsProviderDoesNotContainNoArgumentConstructor() {
162+
void throwsExceptionWhenArgumentsProviderDoesNotContainUnambiguousConstructor() {
158163
var extensionContextWithAnnotatedTestMethod = getExtensionContextReturningSingleMethod(
159164
new MissingNoArgumentsConstructorArgumentsProviderTestCase());
160165

@@ -163,15 +168,11 @@ void throwsExceptionWhenArgumentsProviderDoesNotContainNoArgumentConstructor() {
163168

164169
var exception = assertThrows(JUnitException.class, stream::toArray);
165170

166-
assertArgumentsProviderInstantiationException(exception, MissingNoArgumentsConstructorArgumentsProvider.class);
167-
}
168-
169-
private <T> void assertArgumentsProviderInstantiationException(JUnitException exception, Class<T> clazz) {
170-
assertThat(exception).hasMessage(
171-
String.format("Failed to find a no-argument constructor for ArgumentsProvider [%s]. "
172-
+ "Please ensure that a no-argument constructor exists and "
173-
+ "that the class is either a top-level class or a static nested class",
174-
clazz.getName()));
171+
String className = AmbiguousConstructorArgumentsProvider.class.getName();
172+
assertThat(exception) //
173+
.hasMessage(String.format("Failed to find constructor for ArgumentsProvider [%s]. "
174+
+ "Please ensure that a no-argument or a single constructor exists.",
175+
className));
175176
}
176177

177178
private ExtensionContext getExtensionContextReturningSingleMethod(Object testCase) {
@@ -277,7 +278,17 @@ public ExecutionMode getExecutionMode() {
277278

278279
@Override
279280
public ExecutableInvoker getExecutableInvoker() {
280-
return null;
281+
return new ExecutableInvoker() {
282+
@Override
283+
public Object invoke(Method method, Object target) {
284+
return null;
285+
}
286+
287+
@Override
288+
public <T> T invoke(Constructor<T> constructor, Object outerInstance) {
289+
return ReflectionUtils.newInstance(constructor);
290+
}
291+
};
281292
}
282293
};
283294
}
@@ -334,30 +345,33 @@ public Stream<? extends Arguments> provideArguments(ExtensionContext context) {
334345
static class MissingNoArgumentsConstructorArgumentsProviderTestCase {
335346

336347
@ParameterizedTest
337-
@ArgumentsSource(MissingNoArgumentsConstructorArgumentsProvider.class)
348+
@ArgumentsSource(AmbiguousConstructorArgumentsProvider.class)
338349
void method() {
339350
}
340351
}
341352

342353
static class EmptyDisplayNameProviderTestCase {
343354

344355
@ParameterizedTest(name = "")
345-
@ArgumentsSource(MissingNoArgumentsConstructorArgumentsProvider.class)
356+
@ArgumentsSource(AmbiguousConstructorArgumentsProvider.class)
346357
void method() {
347358
}
348359
}
349360

350361
static class DefaultDisplayNameProviderTestCase {
351362

352363
@ParameterizedTest
353-
@ArgumentsSource(MissingNoArgumentsConstructorArgumentsProvider.class)
364+
@ArgumentsSource(AmbiguousConstructorArgumentsProvider.class)
354365
void method() {
355366
}
356367
}
357368

358-
static class MissingNoArgumentsConstructorArgumentsProvider implements ArgumentsProvider {
369+
static class AmbiguousConstructorArgumentsProvider implements ArgumentsProvider {
370+
371+
AmbiguousConstructorArgumentsProvider(String parameter) {
372+
}
359373

360-
MissingNoArgumentsConstructorArgumentsProvider(String parameter) {
374+
AmbiguousConstructorArgumentsProvider(int parameter) {
361375
}
362376

363377
@Override

jupiter-tests/src/test/java/org/junit/jupiter/params/ParameterizedTestIntegrationTests.java

+52
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
import java.lang.annotation.Retention;
4343
import java.lang.annotation.RetentionPolicy;
4444
import java.lang.annotation.Target;
45+
import java.lang.reflect.Constructor;
4546
import java.util.ArrayList;
4647
import java.util.Arrays;
4748
import java.util.Collection;
@@ -84,6 +85,8 @@
8485
import org.junit.jupiter.api.extension.ExtensionContext;
8586
import org.junit.jupiter.api.extension.ParameterContext;
8687
import org.junit.jupiter.api.extension.ParameterResolutionException;
88+
import org.junit.jupiter.api.extension.ParameterResolver;
89+
import org.junit.jupiter.api.extension.RegisterExtension;
8790
import org.junit.jupiter.engine.JupiterTestEngine;
8891
import org.junit.jupiter.params.ParameterizedTestIntegrationTests.RepeatableSourcesTestCase.Action;
8992
import org.junit.jupiter.params.aggregator.AggregateWith;
@@ -1206,6 +1209,17 @@ void executesTwoIterationsBasedOnIterationAndUniqueIdSelector() {
12061209
.haveExactly(1, event(test(), displayName("[3] argument=5"), finishedWithFailure()));
12071210
}
12081211

1212+
@Nested
1213+
class SpiParameterInjectionIntegrationTests {
1214+
1215+
@Test
1216+
void injectsParametersIntoArgumentsProviderConstructor() {
1217+
execute(SpiParameterInjectionTestCase.class, "argumentsProviderWithConstructorParameter", String.class) //
1218+
.testEvents() //
1219+
.assertStatistics(it -> it.succeeded(1));
1220+
}
1221+
}
1222+
12091223
// -------------------------------------------------------------------------
12101224

12111225
static class TestCase {
@@ -1307,6 +1321,7 @@ void testWithThreeIterations(int argument) {
13071321
}
13081322
}
13091323

1324+
@SuppressWarnings("JUnitMalformedDeclaration")
13101325
static class NullSourceTestCase {
13111326

13121327
@ParameterizedTest
@@ -1342,6 +1357,7 @@ void testWithNullSourceForPrimitive(int argument) {
13421357

13431358
}
13441359

1360+
@SuppressWarnings("JUnitMalformedDeclaration")
13451361
static class EmptySourceTestCase {
13461362

13471363
@ParameterizedTest
@@ -1497,6 +1513,7 @@ void testWithEmptySourceForUnsupportedReferenceType(Integer argument) {
14971513

14981514
}
14991515

1516+
@SuppressWarnings("JUnitMalformedDeclaration")
15001517
static class NullAndEmptySourceTestCase {
15011518

15021519
@ParameterizedTest
@@ -1538,6 +1555,7 @@ void testWithNullAndEmptySourceForTwoDimensionalStringArray(String[][] argument)
15381555

15391556
}
15401557

1558+
@SuppressWarnings("JUnitMalformedDeclaration")
15411559
@TestMethodOrder(OrderAnnotation.class)
15421560
static class MethodSourceTestCase {
15431561

@@ -2119,6 +2137,40 @@ void testWithRepeatableArgumentsSource(String argument) {
21192137
}
21202138
}
21212139

2140+
static class SpiParameterInjectionTestCase {
2141+
2142+
@RegisterExtension
2143+
static final ParameterResolver spiParameterResolver = new ParameterResolver() {
2144+
2145+
@Override
2146+
public boolean supportsParameter(ParameterContext parameterContext, ExtensionContext extensionContext)
2147+
throws ParameterResolutionException {
2148+
return parameterContext.getDeclaringExecutable() instanceof Constructor //
2149+
&& String.class.equals(parameterContext.getParameter().getType());
2150+
}
2151+
2152+
@Override
2153+
public Object resolveParameter(ParameterContext parameterContext, ExtensionContext extensionContext)
2154+
throws ParameterResolutionException {
2155+
return "resolved value";
2156+
}
2157+
};
2158+
2159+
@ParameterizedTest
2160+
@ArgumentsSource(ArgumentsProviderWithConstructorParameter.class)
2161+
void argumentsProviderWithConstructorParameter(String argument) {
2162+
assertEquals("resolved value", argument);
2163+
}
2164+
2165+
record ArgumentsProviderWithConstructorParameter(String value) implements ArgumentsProvider {
2166+
2167+
@Override
2168+
public Stream<? extends Arguments> provideArguments(ExtensionContext context) {
2169+
return Stream.of(arguments(value));
2170+
}
2171+
}
2172+
}
2173+
21222174
private static class TwoSingleStringArgumentsProvider implements ArgumentsProvider {
21232175

21242176
@Override

0 commit comments

Comments
 (0)