Skip to content

Commit 948c9b8

Browse files
committed
Revise internals of AOT testing support
1 parent e7a297a commit 948c9b8

File tree

9 files changed

+177
-108
lines changed

9 files changed

+177
-108
lines changed

spring-test/src/main/java/org/springframework/test/context/aot/AotTestAttributesFactory.java

Lines changed: 3 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -16,16 +16,11 @@
1616

1717
package org.springframework.test.context.aot;
1818

19-
import java.lang.reflect.Method;
20-
import java.util.Collections;
2119
import java.util.Map;
2220
import java.util.concurrent.ConcurrentHashMap;
2321

2422
import org.springframework.aot.AotDetector;
2523
import org.springframework.lang.Nullable;
26-
import org.springframework.util.Assert;
27-
import org.springframework.util.ClassUtils;
28-
import org.springframework.util.ReflectionUtils;
2924

3025
/**
3126
* Factory for {@link AotTestAttributes}.
@@ -64,7 +59,7 @@ static Map<String, String> getAttributes() {
6459
}
6560

6661
/**
67-
* Reset AOT test attributes.
62+
* Reset the factory.
6863
* <p>Only for internal use.
6964
*/
7065
static void reset() {
@@ -73,23 +68,11 @@ static void reset() {
7368
}
7469
}
7570

76-
@SuppressWarnings({ "rawtypes", "unchecked" })
71+
@SuppressWarnings("unchecked")
7772
private static Map<String, String> loadAttributesMap() {
7873
String className = AotTestAttributesCodeGenerator.GENERATED_ATTRIBUTES_CLASS_NAME;
7974
String methodName = AotTestAttributesCodeGenerator.GENERATED_ATTRIBUTES_METHOD_NAME;
80-
try {
81-
Class<?> clazz = ClassUtils.forName(className, null);
82-
Method method = ReflectionUtils.findMethod(clazz, methodName);
83-
Assert.state(method != null, () -> "No %s() method found in %s".formatted(methodName, clazz.getName()));
84-
Map<String, String> attributes = (Map<String, String>) ReflectionUtils.invokeMethod(method, null);
85-
return Collections.unmodifiableMap(attributes);
86-
}
87-
catch (IllegalStateException ex) {
88-
throw ex;
89-
}
90-
catch (Exception ex) {
91-
throw new IllegalStateException("Failed to invoke %s() method on %s".formatted(methodName, className), ex);
92-
}
75+
return GeneratedMapUtils.loadMap(className, methodName);
9376
}
9477

9578
}

spring-test/src/main/java/org/springframework/test/context/aot/AotTestContextInitializers.java

Lines changed: 7 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -16,50 +16,35 @@
1616

1717
package org.springframework.test.context.aot;
1818

19-
import java.lang.reflect.Method;
2019
import java.util.Map;
2120
import java.util.function.Supplier;
2221

22+
import org.springframework.aot.AotDetector;
2323
import org.springframework.context.ApplicationContextInitializer;
2424
import org.springframework.context.ConfigurableApplicationContext;
2525
import org.springframework.lang.Nullable;
26-
import org.springframework.util.Assert;
27-
import org.springframework.util.ClassUtils;
28-
import org.springframework.util.ReflectionUtils;
2926

3027
/**
3128
* {@code AotTestContextInitializers} provides mappings from test classes to
3229
* AOT-optimized context initializers.
3330
*
34-
* <p>If a test class is not {@linkplain #isSupportedTestClass(Class) supported} in
35-
* AOT mode, {@link #getContextInitializer(Class)} will return {@code null}.
31+
* <p>Intended solely for internal use within the framework.
3632
*
37-
* <p>Reflectively accesses {@link #GENERATED_MAPPINGS_CLASS_NAME} generated by
38-
* the {@link TestContextAotGenerator} to retrieve the mappings generated during
39-
* AOT processing.
33+
* <p>If we are not running in {@linkplain AotDetector#useGeneratedArtifacts()
34+
* AOT mode} or if a test class is not {@linkplain #isSupportedTestClass(Class)
35+
* supported} in AOT mode, {@link #getContextInitializer(Class)} will return
36+
* {@code null}.
4037
*
4138
* @author Sam Brannen
42-
* @author Stephane Nicoll
4339
* @since 6.0
4440
*/
4541
public class AotTestContextInitializers {
4642

47-
// TODO Add support in ClassNameGenerator for supplying a predefined class name.
48-
// There is a similar issue in Spring Boot where code relies on a generated name.
49-
// Ideally we would generate a class named: org.springframework.test.context.aot.GeneratedAotTestContextInitializers
50-
static final String GENERATED_MAPPINGS_CLASS_NAME = AotTestContextInitializers.class.getName() + "__Generated";
51-
52-
static final String GENERATED_MAPPINGS_METHOD_NAME = "getContextInitializers";
53-
5443
private final Map<String, Supplier<ApplicationContextInitializer<ConfigurableApplicationContext>>> contextInitializers;
5544

5645

5746
public AotTestContextInitializers() {
58-
this(GENERATED_MAPPINGS_CLASS_NAME);
59-
}
60-
61-
AotTestContextInitializers(String initializerClassName) {
62-
this(loadContextInitializersMap(initializerClassName));
47+
this(AotTestContextInitializersFactory.getContextInitializers());
6348
}
6449

6550
AotTestContextInitializers(Map<String, Supplier<ApplicationContextInitializer<ConfigurableApplicationContext>>> contextInitializers) {
@@ -90,26 +75,4 @@ public ApplicationContextInitializer<ConfigurableApplicationContext> getContextI
9075
return (supplier != null ? supplier.get() : null);
9176
}
9277

93-
94-
@SuppressWarnings({ "rawtypes", "unchecked" })
95-
private static Map<String, Supplier<ApplicationContextInitializer<ConfigurableApplicationContext>>>
96-
loadContextInitializersMap(String className) {
97-
98-
String methodName = GENERATED_MAPPINGS_METHOD_NAME;
99-
100-
try {
101-
Class<?> clazz = ClassUtils.forName(className, null);
102-
Method method = ReflectionUtils.findMethod(clazz, methodName);
103-
Assert.state(method != null, () -> "No %s() method found in %s".formatted(methodName, clazz.getName()));
104-
return (Map<String, Supplier<ApplicationContextInitializer<ConfigurableApplicationContext>>>)
105-
ReflectionUtils.invokeMethod(method, null);
106-
}
107-
catch (IllegalStateException ex) {
108-
throw ex;
109-
}
110-
catch (Exception ex) {
111-
throw new IllegalStateException("Failed to invoke %s() method in %s".formatted(methodName, className), ex);
112-
}
113-
}
114-
11578
}

spring-test/src/main/java/org/springframework/test/context/aot/AotTestContextInitializersCodeGenerator.java

Lines changed: 11 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,15 @@ class AotTestContextInitializersCodeGenerator {
6161
private static final TypeName CONTEXT_SUPPLIER_MAP = ParameterizedTypeName
6262
.get(ClassName.get(Map.class), ClassName.get(String.class), CONTEXT_INITIALIZER_SUPPLIER);
6363

64+
private static final String GENERATED_SUFFIX = "Generated";
65+
66+
// TODO Add support in ClassNameGenerator for supplying a predefined class name.
67+
// There is a similar issue in Spring Boot where code relies on a generated name.
68+
// Ideally we would generate a class named: org.springframework.test.context.aot.GeneratedAotTestContextInitializers
69+
static final String GENERATED_MAPPINGS_CLASS_NAME = AotTestContextInitializers.class.getName() + "__" + GENERATED_SUFFIX;
70+
71+
static final String GENERATED_MAPPINGS_METHOD_NAME = "getContextInitializers";
72+
6473

6574
private final MultiValueMap<ClassName, Class<?>> initializerClassMappings;
6675

@@ -71,7 +80,7 @@ class AotTestContextInitializersCodeGenerator {
7180
GeneratedClasses generatedClasses) {
7281

7382
this.initializerClassMappings = initializerClassMappings;
74-
this.generatedClass = generatedClasses.addForFeature("Generated", this::generateType);
83+
this.generatedClass = generatedClasses.addForFeature(GENERATED_SUFFIX, this::generateType);
7584
}
7685

7786

@@ -88,7 +97,7 @@ private void generateType(TypeSpec.Builder type) {
8897
}
8998

9099
private MethodSpec generateMappingMethod() {
91-
MethodSpec.Builder method = MethodSpec.methodBuilder(AotTestContextInitializers.GENERATED_MAPPINGS_METHOD_NAME);
100+
MethodSpec.Builder method = MethodSpec.methodBuilder(GENERATED_MAPPINGS_METHOD_NAME);
92101
method.addModifiers(Modifier.PUBLIC, Modifier.STATIC);
93102
method.returns(CONTEXT_SUPPLIER_MAP);
94103
method.addCode(generateMappingCode());
Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
/*
2+
* Copyright 2002-2022 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.test.context.aot;
18+
19+
import java.util.Map;
20+
import java.util.function.Supplier;
21+
22+
import org.springframework.aot.AotDetector;
23+
import org.springframework.context.ApplicationContextInitializer;
24+
import org.springframework.context.ConfigurableApplicationContext;
25+
import org.springframework.lang.Nullable;
26+
27+
/**
28+
* Factory for {@link AotTestContextInitializers}.
29+
*
30+
* @author Sam Brannen
31+
* @since 6.0
32+
*/
33+
final class AotTestContextInitializersFactory {
34+
35+
@Nullable
36+
private static volatile Map<String, Supplier<ApplicationContextInitializer<ConfigurableApplicationContext>>> contextInitializers;
37+
38+
39+
private AotTestContextInitializersFactory() {
40+
}
41+
42+
/**
43+
* Get the underlying map.
44+
* <p>If the map is not already loaded, this method loads the map from the
45+
* generated class when running in {@linkplain AotDetector#useGeneratedArtifacts()
46+
* AOT execution mode} and otherwise creates an immutable, empty map.
47+
*/
48+
static Map<String, Supplier<ApplicationContextInitializer<ConfigurableApplicationContext>>> getContextInitializers() {
49+
Map<String, Supplier<ApplicationContextInitializer<ConfigurableApplicationContext>>> initializers = contextInitializers;
50+
if (initializers == null) {
51+
synchronized (AotTestContextInitializersFactory.class) {
52+
initializers = contextInitializers;
53+
if (initializers == null) {
54+
initializers = (AotDetector.useGeneratedArtifacts() ? loadContextInitializersMap() : Map.of());
55+
contextInitializers = initializers;
56+
}
57+
}
58+
}
59+
return initializers;
60+
}
61+
62+
/**
63+
* Reset the factory.
64+
* <p>Only for internal use.
65+
*/
66+
static void reset() {
67+
synchronized (AotTestContextInitializersFactory.class) {
68+
contextInitializers = null;
69+
}
70+
}
71+
72+
@SuppressWarnings("unchecked")
73+
private static Map<String, Supplier<ApplicationContextInitializer<ConfigurableApplicationContext>>> loadContextInitializersMap() {
74+
String className = AotTestContextInitializersCodeGenerator.GENERATED_MAPPINGS_CLASS_NAME;
75+
String methodName = AotTestContextInitializersCodeGenerator.GENERATED_MAPPINGS_METHOD_NAME;
76+
return GeneratedMapUtils.loadMap(className, methodName);
77+
}
78+
79+
}
Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
/*
2+
* Copyright 2002-2022 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.test.context.aot;
18+
19+
import java.lang.reflect.Method;
20+
import java.util.Collections;
21+
import java.util.Map;
22+
23+
import org.springframework.util.Assert;
24+
import org.springframework.util.ClassUtils;
25+
import org.springframework.util.ReflectionUtils;
26+
27+
/**
28+
* Utilities for loading generated maps.
29+
*
30+
* @author Sam Brannen
31+
* @author Stephane Nicoll
32+
* @since 6.0
33+
*/
34+
final class GeneratedMapUtils {
35+
36+
private GeneratedMapUtils() {
37+
}
38+
39+
/**
40+
* Load a generated map.
41+
* @param className the name of the class in which the static method resides
42+
* @param methodName the name of the static method to invoke
43+
* @return an unmodifiable map retrieved from a static method
44+
*/
45+
@SuppressWarnings({ "rawtypes", "unchecked" })
46+
static Map loadMap(String className, String methodName) {
47+
try {
48+
Class<?> clazz = ClassUtils.forName(className, null);
49+
Method method = ReflectionUtils.findMethod(clazz, methodName);
50+
Assert.state(method != null, () -> "No %s() method found in %s".formatted(methodName, className));
51+
Map map = (Map) ReflectionUtils.invokeMethod(method, null);
52+
return Collections.unmodifiableMap(map);
53+
}
54+
catch (IllegalStateException ex) {
55+
throw ex;
56+
}
57+
catch (Exception ex) {
58+
throw new IllegalStateException("Failed to invoke %s() method on %s".formatted(methodName, className), ex);
59+
}
60+
}
61+
62+
}

spring-test/src/main/java/org/springframework/test/context/aot/TestContextAotGenerator.java

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import org.apache.commons.logging.Log;
2525
import org.apache.commons.logging.LogFactory;
2626

27+
import org.springframework.aot.AotDetector;
2728
import org.springframework.aot.generate.ClassNameGenerator;
2829
import org.springframework.aot.generate.DefaultGenerationContext;
2930
import org.springframework.aot.generate.GeneratedClasses;
@@ -109,9 +110,9 @@ public final RuntimeHints getRuntimeHints() {
109110
* @throws TestContextAotException if an error occurs during AOT processing
110111
*/
111112
public void processAheadOfTime(Stream<Class<?>> testClasses) throws TestContextAotException {
113+
Assert.state(!AotDetector.useGeneratedArtifacts(), "Cannot perform AOT processing during AOT run-time execution");
112114
try {
113-
// Make sure AOT attributes are cleared before processing
114-
AotTestAttributesFactory.reset();
115+
resetAotFactories();
115116

116117
MultiValueMap<MergedContextConfiguration, Class<?>> mergedConfigMappings = new LinkedMultiValueMap<>();
117118
testClasses.forEach(testClass -> mergedConfigMappings.add(buildMergedContextConfiguration(testClass), testClass));
@@ -121,11 +122,15 @@ public void processAheadOfTime(Stream<Class<?>> testClasses) throws TestContextA
121122
generateAotTestAttributes();
122123
}
123124
finally {
124-
// Clear AOT attributes after processing
125-
AotTestAttributesFactory.reset();
125+
resetAotFactories();
126126
}
127127
}
128128

129+
private void resetAotFactories() {
130+
AotTestAttributesFactory.reset();
131+
AotTestContextInitializersFactory.reset();
132+
}
133+
129134
private MultiValueMap<ClassName, Class<?>> processAheadOfTime(MultiValueMap<MergedContextConfiguration, Class<?>> mergedConfigMappings) {
130135
MultiValueMap<ClassName, Class<?>> initializerClassMappings = new LinkedMultiValueMap<>();
131136
mergedConfigMappings.forEach((mergedConfig, testClasses) -> {

spring-test/src/main/java/org/springframework/test/context/cache/DefaultCacheAwareContextLoaderDelegate.java

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@
1919
import org.apache.commons.logging.Log;
2020
import org.apache.commons.logging.LogFactory;
2121

22-
import org.springframework.aot.AotDetector;
2322
import org.springframework.context.ApplicationContext;
2423
import org.springframework.context.ApplicationContextInitializer;
2524
import org.springframework.context.ConfigurableApplicationContext;
@@ -56,8 +55,7 @@ public class DefaultCacheAwareContextLoaderDelegate implements CacheAwareContext
5655
*/
5756
static final ContextCache defaultContextCache = new DefaultContextCache();
5857

59-
@Nullable
60-
private final AotTestContextInitializers aotTestContextInitializers = getAotTestContextInitializers();
58+
private final AotTestContextInitializers aotTestContextInitializers = new AotTestContextInitializers();
6159

6260
private final ContextCache contextCache;
6361

@@ -200,21 +198,7 @@ private ContextLoader getContextLoader(MergedContextConfiguration mergedConfig)
200198
* Determine if we are running in AOT mode for the supplied test class.
201199
*/
202200
private boolean runningInAotMode(Class<?> testClass) {
203-
return (this.aotTestContextInitializers != null &&
204-
this.aotTestContextInitializers.isSupportedTestClass(testClass));
205-
}
206-
207-
@Nullable
208-
private static AotTestContextInitializers getAotTestContextInitializers() {
209-
if (AotDetector.useGeneratedArtifacts()) {
210-
try {
211-
return new AotTestContextInitializers();
212-
}
213-
catch (Exception ex) {
214-
throw new IllegalStateException("Failed to instantiate AotTestContextInitializers", ex);
215-
}
216-
}
217-
return null;
201+
return this.aotTestContextInitializers.isSupportedTestClass(testClass);
218202
}
219203

220204
}

0 commit comments

Comments
 (0)