17
17
package org .springframework .test .context .jdbc ;
18
18
19
19
import java .lang .reflect .Method ;
20
+ import java .util .Arrays ;
20
21
import java .util .List ;
21
22
import java .util .Set ;
23
+ import java .util .stream .Stream ;
22
24
23
25
import javax .sql .DataSource ;
24
26
25
27
import org .apache .commons .logging .Log ;
26
28
import org .apache .commons .logging .LogFactory ;
27
29
30
+ import org .springframework .aot .hint .RuntimeHints ;
28
31
import org .springframework .context .ApplicationContext ;
29
32
import org .springframework .core .annotation .AnnotatedElementUtils ;
30
33
import org .springframework .core .io .ByteArrayResource ;
35
38
import org .springframework .lang .Nullable ;
36
39
import org .springframework .test .context .TestContext ;
37
40
import org .springframework .test .context .TestContextAnnotationUtils ;
41
+ import org .springframework .test .context .aot .AotTestExecutionListener ;
38
42
import org .springframework .test .context .jdbc .Sql .ExecutionPhase ;
39
43
import org .springframework .test .context .jdbc .SqlConfig .ErrorMode ;
40
44
import org .springframework .test .context .jdbc .SqlConfig .TransactionMode ;
52
56
import org .springframework .util .ClassUtils ;
53
57
import org .springframework .util .ObjectUtils ;
54
58
import org .springframework .util .ReflectionUtils ;
55
- import org .springframework .util .ResourceUtils ;
59
+ import org .springframework .util .ReflectionUtils . MethodFilter ;
56
60
import org .springframework .util .StringUtils ;
57
61
62
+ import static org .springframework .util .ResourceUtils .CLASSPATH_URL_PREFIX ;
63
+
58
64
/**
59
65
* {@code TestExecutionListener} that provides support for executing SQL
60
66
* {@link Sql#scripts scripts} and inlined {@link Sql#statements statements}
90
96
* @since 4.1
91
97
* @see Sql
92
98
* @see SqlConfig
99
+ * @see SqlMergeMode
93
100
* @see SqlGroup
94
101
* @see org.springframework.test.context.transaction.TestContextTransactionUtils
95
102
* @see org.springframework.test.context.transaction.TransactionalTestExecutionListener
96
103
* @see org.springframework.jdbc.datasource.init.ResourceDatabasePopulator
97
104
* @see org.springframework.jdbc.datasource.init.ScriptUtils
98
105
*/
99
- public class SqlScriptsTestExecutionListener extends AbstractTestExecutionListener {
106
+ public class SqlScriptsTestExecutionListener extends AbstractTestExecutionListener implements AotTestExecutionListener {
100
107
101
108
private static final String SLASH = "/" ;
102
109
103
110
private static final Log logger = LogFactory .getLog (SqlScriptsTestExecutionListener .class );
104
111
112
+ private static final MethodFilter sqlMethodFilter = ReflectionUtils .USER_DECLARED_METHODS
113
+ .and (method -> AnnotatedElementUtils .hasAnnotation (method , Sql .class ));
114
+
105
115
106
116
/**
107
117
* Returns {@code 5000}.
@@ -129,6 +139,21 @@ public void afterTestMethod(TestContext testContext) {
129
139
executeSqlScripts (testContext , ExecutionPhase .AFTER_TEST_METHOD );
130
140
}
131
141
142
+ /**
143
+ * Process the supplied test class and its methods and register run-time
144
+ * hints for any SQL scripts configured or detected as classpath resources
145
+ * via {@link Sql @Sql}.
146
+ * @since 6.0
147
+ */
148
+ @ Override
149
+ public void processAheadOfTime (Class <?> testClass , RuntimeHints runtimeHints , ClassLoader classLoader ) {
150
+ getSqlAnnotationsFor (testClass ).forEach (sql ->
151
+ registerClasspathResources (runtimeHints , getScripts (sql , testClass , null , true )));
152
+ getSqlMethods (testClass ).forEach (testMethod ->
153
+ getSqlAnnotationsFor (testMethod ).forEach (sql ->
154
+ registerClasspathResources (runtimeHints , getScripts (sql , testClass , testMethod , false ))));
155
+ }
156
+
132
157
/**
133
158
* Execute SQL scripts configured via {@link Sql @Sql} for the supplied
134
159
* {@link TestContext} and {@link ExecutionPhase}.
@@ -226,8 +251,7 @@ private void executeSqlScripts(
226
251
mergedSqlConfig , executionPhase , testContext ));
227
252
}
228
253
229
- String [] scripts = getScripts (sql , testContext , classLevel );
230
- scripts = TestContextResourceUtils .convertToClasspathResourcePaths (testContext .getTestClass (), scripts );
254
+ String [] scripts = getScripts (sql , testContext .getTestClass (), testContext .getTestMethod (), classLevel );
231
255
List <Resource > scriptResources = TestContextResourceUtils .convertToResourceList (
232
256
testContext .getApplicationContext (), scripts );
233
257
for (String stmt : sql .statements ()) {
@@ -321,31 +345,29 @@ private DataSource getDataSourceFromTransactionManager(PlatformTransactionManage
321
345
return null ;
322
346
}
323
347
324
- private String [] getScripts (Sql sql , TestContext testContext , boolean classLevel ) {
348
+ private String [] getScripts (Sql sql , Class <?> testClass , Method testMethod , boolean classLevel ) {
325
349
String [] scripts = sql .scripts ();
326
350
if (ObjectUtils .isEmpty (scripts ) && ObjectUtils .isEmpty (sql .statements ())) {
327
- scripts = new String [] {detectDefaultScript (testContext , classLevel )};
351
+ scripts = new String [] {detectDefaultScript (testClass , testMethod , classLevel )};
328
352
}
329
- return scripts ;
353
+ return TestContextResourceUtils . convertToClasspathResourcePaths ( testClass , scripts ) ;
330
354
}
331
355
332
356
/**
333
357
* Detect a default SQL script by implementing the algorithm defined in
334
358
* {@link Sql#scripts}.
335
359
*/
336
- private String detectDefaultScript (TestContext testContext , boolean classLevel ) {
337
- Class <?> clazz = testContext .getTestClass ();
338
- Method method = testContext .getTestMethod ();
360
+ private String detectDefaultScript (Class <?> testClass , Method testMethod , boolean classLevel ) {
339
361
String elementType = (classLevel ? "class" : "method" );
340
- String elementName = (classLevel ? clazz .getName () : method .toString ());
362
+ String elementName = (classLevel ? testClass .getName () : testMethod .toString ());
341
363
342
- String resourcePath = ClassUtils .convertClassNameToResourcePath (clazz .getName ());
364
+ String resourcePath = ClassUtils .convertClassNameToResourcePath (testClass .getName ());
343
365
if (!classLevel ) {
344
- resourcePath += "." + method .getName ();
366
+ resourcePath += "." + testMethod .getName ();
345
367
}
346
368
resourcePath += ".sql" ;
347
369
348
- String prefixedResourcePath = ResourceUtils . CLASSPATH_URL_PREFIX + SLASH + resourcePath ;
370
+ String prefixedResourcePath = CLASSPATH_URL_PREFIX + SLASH + resourcePath ;
349
371
ClassPathResource classPathResource = new ClassPathResource (resourcePath );
350
372
351
373
if (classPathResource .exists ()) {
@@ -364,4 +386,23 @@ private String detectDefaultScript(TestContext testContext, boolean classLevel)
364
386
}
365
387
}
366
388
389
+ private Stream <Method > getSqlMethods (Class <?> testClass ) {
390
+ return Arrays .stream (ReflectionUtils .getUniqueDeclaredMethods (testClass , sqlMethodFilter ));
391
+ }
392
+
393
+ private void registerClasspathResources (RuntimeHints runtimeHints , String ... locations ) {
394
+ Arrays .stream (locations )
395
+ .filter (location -> location .startsWith (CLASSPATH_URL_PREFIX ))
396
+ .map (this ::cleanClasspathResource )
397
+ .forEach (runtimeHints .resources ()::registerPattern );
398
+ }
399
+
400
+ private String cleanClasspathResource (String location ) {
401
+ location = location .substring (CLASSPATH_URL_PREFIX .length ());
402
+ if (!location .startsWith (SLASH )) {
403
+ location = SLASH + location ;
404
+ }
405
+ return location ;
406
+ }
407
+
367
408
}
0 commit comments