Skip to content

Commit f628c60

Browse files
committed
Merge pull request #28132 from ctailor2
* pr/28132: Polish "Allow batch update to take a KeyHolder" Allow batch update to take a KeyHolder Closes gh-28132
2 parents 056de7e + c21a9b9 commit f628c60

File tree

5 files changed

+252
-51
lines changed

5 files changed

+252
-51
lines changed

spring-jdbc/src/main/java/org/springframework/jdbc/core/JdbcOperations.java

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -997,6 +997,29 @@ <T> List<T> queryForList(String sql, Object[] args, int[] argTypes, Class<T> ele
997997
*/
998998
int[] batchUpdate(String sql, BatchPreparedStatementSetter pss) throws DataAccessException;
999999

1000+
/**
1001+
* Issue multiple update statements on a single PreparedStatement,
1002+
* using batch updates and a BatchPreparedStatementSetter to set values.
1003+
* Generated keys will be put into the given KeyHolder.
1004+
* <p>Note that the given PreparedStatementCreator has to create a statement
1005+
* with activated extraction of generated keys (a JDBC 3.0 feature). This can
1006+
* either be done directly or through using a PreparedStatementCreatorFactory.
1007+
* <p>Will fall back to separate updates on a single PreparedStatement
1008+
* if the JDBC driver does not support batch updates.
1009+
* @param psc a callback that creates a PreparedStatement given a Connection
1010+
* @param pss object to set parameters on the PreparedStatement
1011+
* created by this method
1012+
* @param generatedKeyHolder a KeyHolder that will hold the generated keys
1013+
* @return an array of the number of rows affected by each statement
1014+
* (may also contain special JDBC-defined negative values for affected rows such as
1015+
* {@link java.sql.Statement#SUCCESS_NO_INFO}/{@link java.sql.Statement#EXECUTE_FAILED})
1016+
* @throws DataAccessException if there is any problem issuing the update
1017+
* @since 6.1
1018+
* @see org.springframework.jdbc.support.GeneratedKeyHolder
1019+
*/
1020+
int[] batchUpdate(PreparedStatementCreator psc, BatchPreparedStatementSetter pss,
1021+
KeyHolder generatedKeyHolder) throws DataAccessException;
1022+
10001023
/**
10011024
* Execute a batch using the supplied SQL statement with the batch of supplied arguments.
10021025
* @param sql the SQL statement to execute

spring-jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java

Lines changed: 79 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -996,21 +996,10 @@ public int update(final PreparedStatementCreator psc, final KeyHolder generatedK
996996

997997
return updateCount(execute(psc, ps -> {
998998
int rows = ps.executeUpdate();
999-
List<Map<String, Object>> generatedKeys = generatedKeyHolder.getKeyList();
1000-
generatedKeys.clear();
1001-
ResultSet keys = ps.getGeneratedKeys();
1002-
if (keys != null) {
1003-
try {
1004-
RowMapperResultSetExtractor<Map<String, Object>> rse =
1005-
new RowMapperResultSetExtractor<>(getColumnMapRowMapper(), 1);
1006-
generatedKeys.addAll(result(rse.extractData(keys)));
1007-
}
1008-
finally {
1009-
JdbcUtils.closeResultSet(keys);
1010-
}
1011-
}
999+
generatedKeyHolder.getKeyList().clear();
1000+
storeGeneratedKeys(generatedKeyHolder, ps, 1);
10121001
if (logger.isTraceEnabled()) {
1013-
logger.trace("SQL update affected " + rows + " rows and returned " + generatedKeys.size() + " keys");
1002+
logger.trace("SQL update affected " + rows + " rows and returned " + generatedKeyHolder.getKeyList().size() + " keys");
10141003
}
10151004
return rows;
10161005
}, true));
@@ -1031,6 +1020,16 @@ public int update(String sql, @Nullable Object... args) throws DataAccessExcepti
10311020
return update(sql, newArgPreparedStatementSetter(args));
10321021
}
10331022

1023+
@Override
1024+
public int[] batchUpdate(final PreparedStatementCreator psc, final BatchPreparedStatementSetter pss,
1025+
final KeyHolder generatedKeyHolder) throws DataAccessException {
1026+
1027+
int[] result = execute(psc, getPreparedStatementCallback(pss, generatedKeyHolder));
1028+
1029+
Assert.state(result != null, "No result array");
1030+
return result;
1031+
}
1032+
10341033
@Override
10351034
public int[] batchUpdate(String sql, final BatchPreparedStatementSetter pss) throws DataAccessException {
10361035
if (logger.isDebugEnabled()) {
@@ -1041,43 +1040,7 @@ public int[] batchUpdate(String sql, final BatchPreparedStatementSetter pss) thr
10411040
return new int[0];
10421041
}
10431042

1044-
int[] result = execute(sql, (PreparedStatementCallback<int[]>) ps -> {
1045-
try {
1046-
InterruptibleBatchPreparedStatementSetter ipss =
1047-
(pss instanceof InterruptibleBatchPreparedStatementSetter ibpss ? ibpss : null);
1048-
if (JdbcUtils.supportsBatchUpdates(ps.getConnection())) {
1049-
for (int i = 0; i < batchSize; i++) {
1050-
pss.setValues(ps, i);
1051-
if (ipss != null && ipss.isBatchExhausted(i)) {
1052-
break;
1053-
}
1054-
ps.addBatch();
1055-
}
1056-
return ps.executeBatch();
1057-
}
1058-
else {
1059-
List<Integer> rowsAffected = new ArrayList<>();
1060-
for (int i = 0; i < batchSize; i++) {
1061-
pss.setValues(ps, i);
1062-
if (ipss != null && ipss.isBatchExhausted(i)) {
1063-
break;
1064-
}
1065-
rowsAffected.add(ps.executeUpdate());
1066-
}
1067-
int[] rowsAffectedArray = new int[rowsAffected.size()];
1068-
for (int i = 0; i < rowsAffectedArray.length; i++) {
1069-
rowsAffectedArray[i] = rowsAffected.get(i);
1070-
}
1071-
return rowsAffectedArray;
1072-
}
1073-
}
1074-
finally {
1075-
if (pss instanceof ParameterDisposer parameterDisposer) {
1076-
parameterDisposer.cleanupParameters();
1077-
}
1078-
}
1079-
});
1080-
1043+
int[] result = execute(sql, getPreparedStatementCallback(pss, null));
10811044
Assert.state(result != null, "No result array");
10821045
return result;
10831046
}
@@ -1604,6 +1567,71 @@ private static int updateCount(@Nullable Integer result) {
16041567
return result;
16051568
}
16061569

1570+
private void storeGeneratedKeys(KeyHolder generatedKeyHolder, PreparedStatement ps, int rowsExpected) throws SQLException {
1571+
List<Map<String, Object>> generatedKeys = generatedKeyHolder.getKeyList();
1572+
ResultSet keys = ps.getGeneratedKeys();
1573+
if (keys != null) {
1574+
try {
1575+
RowMapperResultSetExtractor<Map<String, Object>> rse =
1576+
new RowMapperResultSetExtractor<>(getColumnMapRowMapper(), rowsExpected);
1577+
generatedKeys.addAll(result(rse.extractData(keys)));
1578+
}
1579+
finally {
1580+
JdbcUtils.closeResultSet(keys);
1581+
}
1582+
}
1583+
}
1584+
1585+
private PreparedStatementCallback<int[]> getPreparedStatementCallback(BatchPreparedStatementSetter pss, @Nullable KeyHolder generatedKeyHolder) {
1586+
return ps -> {
1587+
try {
1588+
int batchSize = pss.getBatchSize();
1589+
InterruptibleBatchPreparedStatementSetter ipss =
1590+
(pss instanceof InterruptibleBatchPreparedStatementSetter ibpss ? ibpss : null);
1591+
if (generatedKeyHolder != null) {
1592+
generatedKeyHolder.getKeyList().clear();
1593+
}
1594+
if (JdbcUtils.supportsBatchUpdates(ps.getConnection())) {
1595+
for (int i = 0; i < batchSize; i++) {
1596+
pss.setValues(ps, i);
1597+
if (ipss != null && ipss.isBatchExhausted(i)) {
1598+
break;
1599+
}
1600+
ps.addBatch();
1601+
}
1602+
int[] results = ps.executeBatch();
1603+
if (generatedKeyHolder != null) {
1604+
storeGeneratedKeys(generatedKeyHolder, ps, batchSize);
1605+
}
1606+
return results;
1607+
}
1608+
else {
1609+
List<Integer> rowsAffected = new ArrayList<>();
1610+
for (int i = 0; i < batchSize; i++) {
1611+
pss.setValues(ps, i);
1612+
if (ipss != null && ipss.isBatchExhausted(i)) {
1613+
break;
1614+
}
1615+
rowsAffected.add(ps.executeUpdate());
1616+
if (generatedKeyHolder != null) {
1617+
storeGeneratedKeys(generatedKeyHolder, ps, 1);
1618+
}
1619+
}
1620+
int[] rowsAffectedArray = new int[rowsAffected.size()];
1621+
for (int i = 0; i < rowsAffectedArray.length; i++) {
1622+
rowsAffectedArray[i] = rowsAffected.get(i);
1623+
}
1624+
return rowsAffectedArray;
1625+
}
1626+
}
1627+
finally {
1628+
if (pss instanceof ParameterDisposer parameterDisposer) {
1629+
parameterDisposer.cleanupParameters();
1630+
}
1631+
}
1632+
};
1633+
}
1634+
16071635

16081636
/**
16091637
* Invocation handler that suppresses close calls on JDBC Connections.

spring-jdbc/src/main/java/org/springframework/jdbc/core/namedparam/NamedParameterJdbcOperations.java

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -555,4 +555,37 @@ int update(String sql, SqlParameterSource paramSource, KeyHolder generatedKeyHol
555555
*/
556556
int[] batchUpdate(String sql, SqlParameterSource[] batchArgs);
557557

558+
/**
559+
* Execute a batch using the supplied SQL statement with the batch of supplied
560+
* arguments, returning generated keys.
561+
* @param sql the SQL statement to execute
562+
* @param batchArgs the array of {@link SqlParameterSource} containing the batch of
563+
* arguments for the query
564+
* @param generatedKeyHolder a {@link KeyHolder} that will hold the generated keys
565+
* @return an array containing the numbers of rows affected by each update in the batch
566+
* (may also contain special JDBC-defined negative values for affected rows such as
567+
* {@link java.sql.Statement#SUCCESS_NO_INFO}/{@link java.sql.Statement#EXECUTE_FAILED})
568+
* @throws DataAccessException if there is any problem issuing the update
569+
* @since 6.1
570+
* @see org.springframework.jdbc.support.GeneratedKeyHolder
571+
*/
572+
int[] batchUpdate(String sql, SqlParameterSource[] batchArgs, KeyHolder generatedKeyHolder);
573+
574+
/**
575+
* Execute a batch using the supplied SQL statement with the batch of supplied arguments,
576+
* returning generated keys.
577+
* @param sql the SQL statement to execute
578+
* @param batchArgs the array of {@link SqlParameterSource} containing the batch of
579+
* arguments for the query
580+
* @param generatedKeyHolder a {@link KeyHolder} that will hold the generated keys
581+
* @param keyColumnNames names of the columns that will have keys generated for them
582+
* @return an array containing the numbers of rows affected by each update in the batch
583+
* (may also contain special JDBC-defined negative values for affected rows such as
584+
* {@link java.sql.Statement#SUCCESS_NO_INFO}/{@link java.sql.Statement#EXECUTE_FAILED})
585+
* @throws DataAccessException if there is any problem issuing the update
586+
* @since 6.1
587+
* @see org.springframework.jdbc.support.GeneratedKeyHolder
588+
*/
589+
int[] batchUpdate(String sql, SqlParameterSource[] batchArgs, KeyHolder generatedKeyHolder,
590+
String[] keyColumnNames);
558591
}

spring-jdbc/src/main/java/org/springframework/jdbc/core/namedparam/NamedParameterJdbcTemplate.java

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,44 @@ public int getBatchSize() {
394394
});
395395
}
396396

397+
@Override
398+
public int[] batchUpdate(String sql, SqlParameterSource[] batchArgs, KeyHolder generatedKeyHolder) {
399+
return batchUpdate(sql, batchArgs, generatedKeyHolder, null);
400+
}
401+
402+
@Override
403+
public int[] batchUpdate(String sql, SqlParameterSource[] batchArgs, KeyHolder generatedKeyHolder,
404+
@Nullable String[] keyColumnNames) {
405+
406+
if (batchArgs.length == 0) {
407+
return new int[0];
408+
}
409+
410+
ParsedSql parsedSql = getParsedSql(sql);
411+
SqlParameterSource paramSource = batchArgs[0];
412+
PreparedStatementCreatorFactory pscf = getPreparedStatementCreatorFactory(parsedSql, paramSource);
413+
if (keyColumnNames != null) {
414+
pscf.setGeneratedKeysColumnNames(keyColumnNames);
415+
}
416+
else {
417+
pscf.setReturnGeneratedKeys(true);
418+
}
419+
Object[] params = NamedParameterUtils.buildValueArray(parsedSql, paramSource, null);
420+
PreparedStatementCreator psc = pscf.newPreparedStatementCreator(params);
421+
return getJdbcOperations().batchUpdate(psc, new BatchPreparedStatementSetter() {
422+
@Override
423+
public void setValues(PreparedStatement ps, int i) throws SQLException {
424+
Object[] values = NamedParameterUtils.buildValueArray(parsedSql, batchArgs[i], null);
425+
pscf.newPreparedStatementSetter(values).setValues(ps);
426+
}
427+
428+
@Override
429+
public int getBatchSize() {
430+
return batchArgs.length;
431+
}
432+
}, generatedKeyHolder);
433+
}
434+
397435

398436
/**
399437
* Build a {@link PreparedStatementCreator} based on the given SQL and named parameters.

spring-jdbc/src/test/java/org/springframework/jdbc/core/JdbcTemplateTests.java

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@
4747
import org.springframework.jdbc.core.support.AbstractInterruptibleBatchPreparedStatementSetter;
4848
import org.springframework.jdbc.datasource.ConnectionProxy;
4949
import org.springframework.jdbc.datasource.SingleConnectionDataSource;
50+
import org.springframework.jdbc.support.GeneratedKeyHolder;
51+
import org.springframework.jdbc.support.KeyHolder;
5052
import org.springframework.jdbc.support.SQLErrorCodeSQLExceptionTranslator;
5153
import org.springframework.jdbc.support.SQLStateSQLExceptionTranslator;
5254
import org.springframework.util.LinkedCaseInsensitiveMap;
@@ -1104,6 +1106,83 @@ public void testEquallyNamedColumn() throws SQLException {
11041106
assertThat(map.get("x")).isEqualTo("first value");
11051107
}
11061108

1109+
@Test
1110+
void testBatchUpdateReturnsGeneratedKeys_whenDatabaseSupportsBatchUpdates() throws SQLException {
1111+
final int[] rowsAffected = new int[] {1, 2};
1112+
given(this.preparedStatement.executeBatch()).willReturn(rowsAffected);
1113+
DatabaseMetaData databaseMetaData = mock(DatabaseMetaData.class);
1114+
given(databaseMetaData.supportsBatchUpdates()).willReturn(true);
1115+
given(this.connection.getMetaData()).willReturn(databaseMetaData);
1116+
ResultSet generatedKeysResultSet = mock(ResultSet.class);
1117+
ResultSetMetaData rsmd = mock(ResultSetMetaData.class);
1118+
given(rsmd.getColumnCount()).willReturn(1);
1119+
given(rsmd.getColumnLabel(1)).willReturn("someId");
1120+
given(generatedKeysResultSet.getMetaData()).willReturn(rsmd);
1121+
given(generatedKeysResultSet.getObject(1)).willReturn(123, 456);
1122+
given(generatedKeysResultSet.next()).willReturn(true, true, false);
1123+
given(this.preparedStatement.getGeneratedKeys()).willReturn(generatedKeysResultSet);
1124+
1125+
int[] values = new int[]{100, 200};
1126+
BatchPreparedStatementSetter bpss = new BatchPreparedStatementSetter() {
1127+
@Override
1128+
public void setValues(PreparedStatement ps, int i) throws SQLException {
1129+
ps.setObject(i, values[i]);
1130+
}
1131+
1132+
@Override
1133+
public int getBatchSize() {
1134+
return 2;
1135+
}
1136+
};
1137+
1138+
KeyHolder keyHolder = new GeneratedKeyHolder();
1139+
this.template.batchUpdate(con -> con.prepareStatement(""), bpss, keyHolder);
1140+
1141+
assertThat(keyHolder.getKeyList()).containsExactly(
1142+
Collections.singletonMap("someId", 123),
1143+
Collections.singletonMap("someId", 456));
1144+
}
1145+
1146+
@Test
1147+
void testBatchUpdateReturnsGeneratedKeys_whenDatabaseDoesNotSupportBatchUpdates() throws SQLException {
1148+
final int[] rowsAffected = new int[] {1, 2};
1149+
given(this.preparedStatement.executeBatch()).willReturn(rowsAffected);
1150+
DatabaseMetaData databaseMetaData = mock(DatabaseMetaData.class);
1151+
given(databaseMetaData.supportsBatchUpdates()).willReturn(false);
1152+
given(this.connection.getMetaData()).willReturn(databaseMetaData);
1153+
ResultSetMetaData rsmd = mock(ResultSetMetaData.class);
1154+
given(rsmd.getColumnCount()).willReturn(1);
1155+
given(rsmd.getColumnLabel(1)).willReturn("someId");
1156+
ResultSet generatedKeysResultSet1 = mock(ResultSet.class);
1157+
given(generatedKeysResultSet1.getMetaData()).willReturn(rsmd);
1158+
given(generatedKeysResultSet1.getObject(1)).willReturn(123);
1159+
given(generatedKeysResultSet1.next()).willReturn(true, false);
1160+
ResultSet generatedKeysResultSet2 = mock(ResultSet.class);
1161+
given(generatedKeysResultSet2.getMetaData()).willReturn(rsmd);
1162+
given(generatedKeysResultSet2.getObject(1)).willReturn(456);
1163+
given(generatedKeysResultSet2.next()).willReturn(true, false);
1164+
given(this.preparedStatement.getGeneratedKeys()).willReturn(generatedKeysResultSet1, generatedKeysResultSet2);
1165+
1166+
int[] values = new int[]{100, 200};
1167+
BatchPreparedStatementSetter bpss = new BatchPreparedStatementSetter() {
1168+
@Override
1169+
public void setValues(PreparedStatement ps, int i) throws SQLException {
1170+
ps.setObject(i, values[i]);
1171+
}
1172+
1173+
@Override
1174+
public int getBatchSize() {
1175+
return 2;
1176+
}
1177+
};
1178+
1179+
KeyHolder keyHolder = new GeneratedKeyHolder();
1180+
this.template.batchUpdate(con -> con.prepareStatement(""), bpss, keyHolder);
1181+
1182+
assertThat(keyHolder.getKeyList()).containsExactly(
1183+
Collections.singletonMap("someId", 123),
1184+
Collections.singletonMap("someId", 456));
1185+
}
11071186

11081187
private void mockDatabaseMetaData(boolean supportsBatchUpdates) throws SQLException {
11091188
DatabaseMetaData databaseMetaData = mock();

0 commit comments

Comments
 (0)