Skip to content

Commit 3db1721

Browse files
committed
Introduce AotTestMappings and AotTestMappingsCodeGenerator
TestContextAotGenerator now uses AotTestMappingsCodeGenerator to generate a AotTestMappings__Generated.java class which is loaded in AotTestMappings via reflection in order to retrieve access to ApplicationContextIntializers generated during AOT processing. Furthermore, the processAheadOfTimeAndGenerateAotTestMappings() method in TestContextAotGeneratorTests now performs a rather extensive test including: - emulating TestClassScanner to find test classes - processing all test classes and generating ApplicationContextIntializers - generating mappings for AotTestMappings - compiling all generated code - loading AotTestMappings - using AotTestMappings to instantiate the generated ApplicationContextIntializers - using AotRuntimeContextLoader to load the AOT-optimized ApplicationContext - asserting the behavior of the loaded ApplicationContext See spring-projectsgh-28205 Closes spring-projectsgh-28204
1 parent 7f1bbea commit 3db1721

File tree

6 files changed

+288
-21
lines changed

6 files changed

+288
-21
lines changed
Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
/*
2+
* Copyright 2002-2022 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.test.context.aot;
18+
19+
import java.lang.reflect.Method;
20+
import java.util.Map;
21+
import java.util.function.Supplier;
22+
23+
import org.springframework.context.ApplicationContextInitializer;
24+
import org.springframework.context.support.GenericApplicationContext;
25+
import org.springframework.util.ClassUtils;
26+
import org.springframework.util.ReflectionUtils;
27+
28+
/**
29+
* {@code AotTestMappings} provides mappings from test classes to AOT-optimized
30+
* context initializers.
31+
*
32+
* <p>If a test class is not {@linkplain #isSupportedTestClass(Class) supported} in
33+
* AOT mode, {@link #getContextInitializer(Class)} will return {@code null}.
34+
*
35+
* <p>Reflectively accesses {@link #GENERATED_MAPPINGS_CLASS_NAME} generated by
36+
* the {@link TestContextAotGenerator} to retrieve the mappings generated during
37+
* AOT processing.
38+
*
39+
* @author Sam Brannen
40+
* @author Stephane Nicoll
41+
* @since 6.0
42+
*/
43+
public class AotTestMappings {
44+
45+
static final String GENERATED_MAPPINGS_CLASS_NAME = AotTestMappings.class.getName() + "__Generated";
46+
47+
static final String GENERATED_MAPPINGS_METHOD_NAME = "getContextInitializers";
48+
49+
private final Map<String, Supplier<ApplicationContextInitializer<GenericApplicationContext>>> contextInitializers;
50+
51+
52+
public AotTestMappings() {
53+
this(GENERATED_MAPPINGS_CLASS_NAME);
54+
}
55+
56+
AotTestMappings(String initializerClassName) {
57+
this(loadContextInitializersMap(initializerClassName));
58+
}
59+
60+
AotTestMappings(Map<String, Supplier<ApplicationContextInitializer<GenericApplicationContext>>> contextInitializers) {
61+
this.contextInitializers = contextInitializers;
62+
}
63+
64+
65+
/**
66+
* Determine if the specified test class has an AOT-optimized application context
67+
* initializer.
68+
* <p>If this method returns {@code true}, {@link #getContextInitializer(Class)}
69+
* should not return {@code null}.
70+
*/
71+
public boolean isSupportedTestClass(Class<?> testClass) {
72+
return this.contextInitializers.containsKey(testClass.getName());
73+
}
74+
75+
/**
76+
* Get the AOT {@link ApplicationContextInitializer} for the specified test class.
77+
* @return the AOT context initializer, or {@code null} if there is no AOT context
78+
* initializer for the specified test class
79+
* @see #isSupportedTestClass(Class)
80+
*/
81+
public ApplicationContextInitializer<GenericApplicationContext> getContextInitializer(Class<?> testClass) {
82+
Supplier<ApplicationContextInitializer<GenericApplicationContext>> supplier =
83+
this.contextInitializers.get(testClass.getName());
84+
return (supplier != null ? supplier.get() : null);
85+
}
86+
87+
88+
@SuppressWarnings({ "rawtypes", "unchecked" })
89+
private static Map<String, Supplier<ApplicationContextInitializer<GenericApplicationContext>>>
90+
loadContextInitializersMap(String className) {
91+
92+
String methodName = GENERATED_MAPPINGS_METHOD_NAME;
93+
94+
try {
95+
Class<?> clazz = ClassUtils.forName(className, null);
96+
Method method = ReflectionUtils.findMethod(clazz, methodName);
97+
if (method == null) {
98+
throw new IllegalStateException("No %s() method found in %s".formatted(methodName, clazz.getName()));
99+
}
100+
return (Map<String, Supplier<ApplicationContextInitializer<GenericApplicationContext>>>)
101+
ReflectionUtils.invokeMethod(method, null);
102+
}
103+
catch (IllegalStateException ex) {
104+
throw ex;
105+
}
106+
catch (Exception ex) {
107+
throw new IllegalStateException("Failed to load %s() method in %s".formatted(methodName, className), ex);
108+
}
109+
}
110+
111+
}
Lines changed: 113 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,113 @@
1+
/*
2+
* Copyright 2002-2022 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.test.context.aot;
18+
19+
import java.util.HashMap;
20+
import java.util.List;
21+
import java.util.Map;
22+
import java.util.function.Supplier;
23+
24+
import javax.lang.model.element.Modifier;
25+
26+
import org.apache.commons.logging.Log;
27+
import org.apache.commons.logging.LogFactory;
28+
29+
import org.springframework.aot.generate.GeneratedClass;
30+
import org.springframework.aot.generate.GeneratedClasses;
31+
import org.springframework.context.ApplicationContextInitializer;
32+
import org.springframework.context.support.GenericApplicationContext;
33+
import org.springframework.javapoet.ClassName;
34+
import org.springframework.javapoet.CodeBlock;
35+
import org.springframework.javapoet.MethodSpec;
36+
import org.springframework.javapoet.ParameterizedTypeName;
37+
import org.springframework.javapoet.TypeName;
38+
import org.springframework.javapoet.TypeSpec;
39+
import org.springframework.util.MultiValueMap;
40+
41+
/**
42+
* Internal code generator for mappings used by {@link AotTestMappings}.
43+
*
44+
* @author Sam Brannen
45+
* @since 6.0
46+
*/
47+
class AotTestMappingsCodeGenerator {
48+
49+
private static final Log logger = LogFactory.getLog(AotTestMappingsCodeGenerator.class);
50+
51+
private final MultiValueMap<ClassName, Class<?>> classNameMappings;
52+
private final GeneratedClass generatedClass;
53+
54+
55+
AotTestMappingsCodeGenerator(MultiValueMap<ClassName, Class<?>> classNameMappings,
56+
GeneratedClasses generatedClasses) {
57+
58+
this.classNameMappings = classNameMappings;
59+
this.generatedClass = generatedClasses.addForFeature("Generated", this::generateType);
60+
}
61+
62+
63+
GeneratedClass getGeneratedClass() {
64+
return this.generatedClass;
65+
}
66+
67+
private void generateType(TypeSpec.Builder type) {
68+
if (logger.isDebugEnabled()) {
69+
logger.debug("Generating AOT test mappings in " + this.generatedClass.getName().reflectionName());
70+
}
71+
type.addJavadoc("Generated mappings for {@link $T}.", AotTestMappings.class);
72+
type.addModifiers(Modifier.PUBLIC);
73+
type.addMethod(generateMappingMethod());
74+
}
75+
76+
private MethodSpec generateMappingMethod() {
77+
// Map<String, Supplier<ApplicationContextInitializer<GenericApplicationContext>>>
78+
ParameterizedTypeName aciType = ParameterizedTypeName.get(
79+
ClassName.get(ApplicationContextInitializer.class),
80+
ClassName.get(GenericApplicationContext.class));
81+
ParameterizedTypeName supplierType = ParameterizedTypeName.get(
82+
ClassName.get(Supplier.class),
83+
aciType);
84+
TypeName mapType = ParameterizedTypeName.get(
85+
ClassName.get(Map.class),
86+
ClassName.get(String.class),
87+
supplierType);
88+
89+
MethodSpec.Builder method = MethodSpec.methodBuilder("getContextInitializers");
90+
method.addModifiers(Modifier.PUBLIC, Modifier.STATIC);
91+
method.returns(mapType);
92+
method.addCode(generateMappingCode(mapType));
93+
return method.build();
94+
}
95+
96+
private CodeBlock generateMappingCode(TypeName mapType) {
97+
CodeBlock.Builder code = CodeBlock.builder();
98+
code.addStatement("$T map = new $T<>()", mapType, HashMap.class);
99+
this.classNameMappings.forEach((className, testClasses) -> {
100+
List<String> testClassNames = testClasses.stream().map(Class::getName).toList();
101+
if (logger.isDebugEnabled()) {
102+
String contextInitializerName = className.reflectionName();
103+
logger.debug("Generating mapping from AOT context initializer [%s] to test classes %s"
104+
.formatted(contextInitializerName, testClassNames));
105+
}
106+
testClassNames.forEach(testClassName ->
107+
code.addStatement("map.put($S, () -> new $T())", testClassName, className));
108+
});
109+
code.addStatement("return map");
110+
return code.build();
111+
}
112+
113+
}

spring-test/src/main/java/org/springframework/test/context/aot/TestContextAotGenerator.java

Lines changed: 28 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,12 @@
2424

2525
import org.springframework.aot.generate.ClassNameGenerator;
2626
import org.springframework.aot.generate.DefaultGenerationContext;
27+
import org.springframework.aot.generate.GeneratedClasses;
2728
import org.springframework.aot.generate.GeneratedFiles;
2829
import org.springframework.aot.generate.GenerationContext;
30+
import org.springframework.aot.hint.MemberCategory;
2931
import org.springframework.aot.hint.RuntimeHints;
32+
import org.springframework.aot.hint.TypeReference;
3033
import org.springframework.context.ApplicationContext;
3134
import org.springframework.context.ApplicationContextInitializer;
3235
import org.springframework.context.aot.ApplicationContextAotGenerator;
@@ -51,7 +54,7 @@
5154
*/
5255
class TestContextAotGenerator {
5356

54-
private static final Log logger = LogFactory.getLog(TestClassScanner.class);
57+
private static final Log logger = LogFactory.getLog(TestContextAotGenerator.class);
5558

5659
private final ApplicationContextAotGenerator aotGenerator = new ApplicationContextAotGenerator();
5760

@@ -97,21 +100,23 @@ public final RuntimeHints getRuntimeHints() {
97100
* @throws TestContextAotException if an error occurs during AOT processing
98101
*/
99102
public void processAheadOfTime(Stream<Class<?>> testClasses) throws TestContextAotException {
100-
MultiValueMap<MergedContextConfiguration, Class<?>> map = new LinkedMultiValueMap<>();
101-
testClasses.forEach(testClass -> map.add(buildMergedContextConfiguration(testClass), testClass));
103+
MultiValueMap<ClassName, Class<?>> classNameMappings = new LinkedMultiValueMap<>();
104+
MultiValueMap<MergedContextConfiguration, Class<?>> mergedConfigMappings = new LinkedMultiValueMap<>();
105+
testClasses.forEach(testClass -> mergedConfigMappings.add(buildMergedContextConfiguration(testClass), testClass));
102106

103-
map.forEach((mergedConfig, classes) -> {
104-
// System.err.println(mergedConfig + " -> " + classes);
107+
mergedConfigMappings.forEach((mergedConfig, classes) -> {
105108
if (logger.isDebugEnabled()) {
106-
logger.debug("Generating AOT artifacts for test classes [%s]"
109+
logger.debug("Generating AOT artifacts for test classes %s"
107110
.formatted(classes.stream().map(Class::getCanonicalName).toList()));
108111
}
109112
try {
110113
// Use first test class discovered for a given unique MergedContextConfiguration.
111114
Class<?> testClass = classes.get(0);
112115
DefaultGenerationContext generationContext = createGenerationContext(testClass);
113116
ClassName className = processAheadOfTime(mergedConfig, generationContext);
114-
// TODO Store ClassName in a map analogous to TestContextAotProcessor in Spring Native.
117+
Assert.state(!classNameMappings.containsKey(className),
118+
() -> "ClassName [%s] already encountered".formatted(className.reflectionName()));
119+
classNameMappings.addAll(className, classes);
115120
generationContext.writeGeneratedContent();
116121
}
117122
catch (Exception ex) {
@@ -121,6 +126,8 @@ public void processAheadOfTime(Stream<Class<?>> testClasses) throws TestContextA
121126
}
122127
}
123128
});
129+
130+
generateAotTestMappings(classNameMappings);
124131
}
125132

126133
/**
@@ -203,4 +210,18 @@ private String nextTestContextId() {
203210
return "TestContext%03d_".formatted(this.sequence.incrementAndGet());
204211
}
205212

213+
private void generateAotTestMappings(MultiValueMap<ClassName, Class<?>> classNameMappings) {
214+
ClassNameGenerator classNameGenerator = new ClassNameGenerator(AotTestMappings.class);
215+
DefaultGenerationContext generationContext =
216+
new DefaultGenerationContext(classNameGenerator, this.generatedFiles, this.runtimeHints);
217+
GeneratedClasses generatedClasses = generationContext.getGeneratedClasses();
218+
219+
AotTestMappingsCodeGenerator codeGenerator =
220+
new AotTestMappingsCodeGenerator(classNameMappings, generatedClasses);
221+
generationContext.writeGeneratedContent();
222+
String className = codeGenerator.getGeneratedClass().getName().reflectionName();
223+
this.runtimeHints.reflection().registerType(TypeReference.of(className),
224+
builder -> builder.withMembers(MemberCategory.INVOKE_PUBLIC_METHODS));
225+
}
226+
206227
}

spring-test/src/test/java/org/springframework/test/context/aot/AbstractAotTests.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@
2828
abstract class AbstractAotTests {
2929

3030
static final String[] expectedSourceFilesForBasicSpringTests = {
31+
// Global
32+
"org/springframework/test/context/aot/AotTestMappings__Generated.java",
3133
// BasicSpringJupiterSharedConfigTests
3234
"org/springframework/context/event/DefaultEventListenerFactory__TestContext001_BeanDefinitions.java",
3335
"org/springframework/context/event/EventListenerMethodProcessor__TestContext001_BeanDefinitions.java",

spring-test/src/test/java/org/springframework/test/context/aot/AotSmokeTests.java

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,11 @@ void scanClassPathThenGenerateSourceFilesAndCompileThem() {
4949
List<String> sourceFiles = generatedFiles.getGeneratedFiles(Kind.SOURCE).keySet().stream().toList();
5050
assertThat(sourceFiles).containsExactlyInAnyOrder(expectedSourceFilesForBasicSpringTests);
5151

52-
TestCompiler.forSystem().withFiles(generatedFiles).compile(compiled -> {
53-
// just make sure compilation completes without errors
54-
});
52+
TestCompiler.forSystem().withFiles(generatedFiles)
53+
// .printFiles(System.out)
54+
.compile(compiled -> {
55+
// just make sure compilation completes without errors
56+
});
5557
}
5658

5759
}

0 commit comments

Comments
 (0)