Skip to content

Commit e57b5f1

Browse files
committed
Register runtime hints for @SQL scripts
SqlScriptsTestExecutionListener now implements AotTestExecutionListener in order to register run-time hints for SQL scripts used by test classes and test methods annotated with @SQL if the configured or detected SQL scripts are classpath resources. Closes gh-29027
1 parent e85e769 commit e57b5f1

File tree

2 files changed

+61
-14
lines changed

2 files changed

+61
-14
lines changed

spring-test/src/main/java/org/springframework/test/context/jdbc/SqlScriptsTestExecutionListener.java

Lines changed: 55 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,14 +17,17 @@
1717
package org.springframework.test.context.jdbc;
1818

1919
import java.lang.reflect.Method;
20+
import java.util.Arrays;
2021
import java.util.List;
2122
import java.util.Set;
23+
import java.util.stream.Stream;
2224

2325
import javax.sql.DataSource;
2426

2527
import org.apache.commons.logging.Log;
2628
import org.apache.commons.logging.LogFactory;
2729

30+
import org.springframework.aot.hint.RuntimeHints;
2831
import org.springframework.context.ApplicationContext;
2932
import org.springframework.core.annotation.AnnotatedElementUtils;
3033
import org.springframework.core.io.ByteArrayResource;
@@ -35,6 +38,7 @@
3538
import org.springframework.lang.Nullable;
3639
import org.springframework.test.context.TestContext;
3740
import org.springframework.test.context.TestContextAnnotationUtils;
41+
import org.springframework.test.context.aot.AotTestExecutionListener;
3842
import org.springframework.test.context.jdbc.Sql.ExecutionPhase;
3943
import org.springframework.test.context.jdbc.SqlConfig.ErrorMode;
4044
import org.springframework.test.context.jdbc.SqlConfig.TransactionMode;
@@ -52,9 +56,11 @@
5256
import org.springframework.util.ClassUtils;
5357
import org.springframework.util.ObjectUtils;
5458
import org.springframework.util.ReflectionUtils;
55-
import org.springframework.util.ResourceUtils;
59+
import org.springframework.util.ReflectionUtils.MethodFilter;
5660
import org.springframework.util.StringUtils;
5761

62+
import static org.springframework.util.ResourceUtils.CLASSPATH_URL_PREFIX;
63+
5864
/**
5965
* {@code TestExecutionListener} that provides support for executing SQL
6066
* {@link Sql#scripts scripts} and inlined {@link Sql#statements statements}
@@ -90,18 +96,22 @@
9096
* @since 4.1
9197
* @see Sql
9298
* @see SqlConfig
99+
* @see SqlMergeMode
93100
* @see SqlGroup
94101
* @see org.springframework.test.context.transaction.TestContextTransactionUtils
95102
* @see org.springframework.test.context.transaction.TransactionalTestExecutionListener
96103
* @see org.springframework.jdbc.datasource.init.ResourceDatabasePopulator
97104
* @see org.springframework.jdbc.datasource.init.ScriptUtils
98105
*/
99-
public class SqlScriptsTestExecutionListener extends AbstractTestExecutionListener {
106+
public class SqlScriptsTestExecutionListener extends AbstractTestExecutionListener implements AotTestExecutionListener {
100107

101108
private static final String SLASH = "/";
102109

103110
private static final Log logger = LogFactory.getLog(SqlScriptsTestExecutionListener.class);
104111

112+
private static final MethodFilter sqlMethodFilter = ReflectionUtils.USER_DECLARED_METHODS
113+
.and(method -> AnnotatedElementUtils.hasAnnotation(method, Sql.class));
114+
105115

106116
/**
107117
* Returns {@code 5000}.
@@ -129,6 +139,21 @@ public void afterTestMethod(TestContext testContext) {
129139
executeSqlScripts(testContext, ExecutionPhase.AFTER_TEST_METHOD);
130140
}
131141

142+
/**
143+
* Process the supplied test class and its methods and register run-time
144+
* hints for any SQL scripts configured or detected as classpath resources
145+
* via {@link Sql @Sql}.
146+
* @since 6.0
147+
*/
148+
@Override
149+
public void processAheadOfTime(Class<?> testClass, RuntimeHints runtimeHints, ClassLoader classLoader) {
150+
getSqlAnnotationsFor(testClass).forEach(sql ->
151+
registerClasspathResources(runtimeHints, getScripts(sql, testClass, null, true)));
152+
getSqlMethods(testClass).forEach(testMethod ->
153+
getSqlAnnotationsFor(testMethod).forEach(sql ->
154+
registerClasspathResources(runtimeHints, getScripts(sql, testClass, testMethod, false))));
155+
}
156+
132157
/**
133158
* Execute SQL scripts configured via {@link Sql @Sql} for the supplied
134159
* {@link TestContext} and {@link ExecutionPhase}.
@@ -226,8 +251,7 @@ private void executeSqlScripts(
226251
mergedSqlConfig, executionPhase, testContext));
227252
}
228253

229-
String[] scripts = getScripts(sql, testContext, classLevel);
230-
scripts = TestContextResourceUtils.convertToClasspathResourcePaths(testContext.getTestClass(), scripts);
254+
String[] scripts = getScripts(sql, testContext.getTestClass(), testContext.getTestMethod(), classLevel);
231255
List<Resource> scriptResources = TestContextResourceUtils.convertToResourceList(
232256
testContext.getApplicationContext(), scripts);
233257
for (String stmt : sql.statements()) {
@@ -321,31 +345,29 @@ private DataSource getDataSourceFromTransactionManager(PlatformTransactionManage
321345
return null;
322346
}
323347

324-
private String[] getScripts(Sql sql, TestContext testContext, boolean classLevel) {
348+
private String[] getScripts(Sql sql, Class<?> testClass, Method testMethod, boolean classLevel) {
325349
String[] scripts = sql.scripts();
326350
if (ObjectUtils.isEmpty(scripts) && ObjectUtils.isEmpty(sql.statements())) {
327-
scripts = new String[] {detectDefaultScript(testContext, classLevel)};
351+
scripts = new String[] {detectDefaultScript(testClass, testMethod, classLevel)};
328352
}
329-
return scripts;
353+
return TestContextResourceUtils.convertToClasspathResourcePaths(testClass, scripts);
330354
}
331355

332356
/**
333357
* Detect a default SQL script by implementing the algorithm defined in
334358
* {@link Sql#scripts}.
335359
*/
336-
private String detectDefaultScript(TestContext testContext, boolean classLevel) {
337-
Class<?> clazz = testContext.getTestClass();
338-
Method method = testContext.getTestMethod();
360+
private String detectDefaultScript(Class<?> testClass, Method testMethod, boolean classLevel) {
339361
String elementType = (classLevel ? "class" : "method");
340-
String elementName = (classLevel ? clazz.getName() : method.toString());
362+
String elementName = (classLevel ? testClass.getName() : testMethod.toString());
341363

342-
String resourcePath = ClassUtils.convertClassNameToResourcePath(clazz.getName());
364+
String resourcePath = ClassUtils.convertClassNameToResourcePath(testClass.getName());
343365
if (!classLevel) {
344-
resourcePath += "." + method.getName();
366+
resourcePath += "." + testMethod.getName();
345367
}
346368
resourcePath += ".sql";
347369

348-
String prefixedResourcePath = ResourceUtils.CLASSPATH_URL_PREFIX + SLASH + resourcePath;
370+
String prefixedResourcePath = CLASSPATH_URL_PREFIX + SLASH + resourcePath;
349371
ClassPathResource classPathResource = new ClassPathResource(resourcePath);
350372

351373
if (classPathResource.exists()) {
@@ -364,4 +386,23 @@ private String detectDefaultScript(TestContext testContext, boolean classLevel)
364386
}
365387
}
366388

389+
private Stream<Method> getSqlMethods(Class<?> testClass) {
390+
return Arrays.stream(ReflectionUtils.getUniqueDeclaredMethods(testClass, sqlMethodFilter));
391+
}
392+
393+
private void registerClasspathResources(RuntimeHints runtimeHints, String... locations) {
394+
Arrays.stream(locations)
395+
.filter(location -> location.startsWith(CLASSPATH_URL_PREFIX))
396+
.map(this::cleanClasspathResource)
397+
.forEach(runtimeHints.resources()::registerPattern);
398+
}
399+
400+
private String cleanClasspathResource(String location) {
401+
location = location.substring(CLASSPATH_URL_PREFIX.length());
402+
if (!location.startsWith(SLASH)) {
403+
location = SLASH + location;
404+
}
405+
return location;
406+
}
407+
367408
}

spring-test/src/test/java/org/springframework/test/context/aot/TestContextAotGeneratorTests.java

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,12 @@ private static void assertRuntimeHints(RuntimeHints runtimeHints) {
191191
// @WebAppConfiguration(value = ...)
192192
assertThat(resource().forResource("/META-INF/web-resources/resources/Spring.js")).accepts(runtimeHints);
193193
assertThat(resource().forResource("/META-INF/web-resources/WEB-INF/views/home.jsp")).accepts(runtimeHints);
194+
195+
// @Sql(scripts = ...)
196+
assertThat(resource().forResource("/org/springframework/test/context/jdbc/schema.sql"))
197+
.accepts(runtimeHints);
198+
assertThat(resource().forResource("/org/springframework/test/context/aot/samples/jdbc/SqlScriptsSpringJupiterTests.test.sql"))
199+
.accepts(runtimeHints);
194200
}
195201

196202
private static void assertReflectionRegistered(RuntimeHints runtimeHints, String type, MemberCategory memberCategory) {

0 commit comments

Comments
 (0)