Skip to content

Commit 78db5dd

Browse files
ctailor2snicoll
authored andcommitted
Allow batch update to take a KeyHolder
See gh-28132
1 parent 056de7e commit 78db5dd

File tree

5 files changed

+245
-51
lines changed

5 files changed

+245
-51
lines changed

spring-jdbc/src/main/java/org/springframework/jdbc/core/JdbcOperations.java

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -997,6 +997,27 @@ <T> List<T> queryForList(String sql, Object[] args, int[] argTypes, Class<T> ele
997997
*/
998998
int[] batchUpdate(String sql, BatchPreparedStatementSetter pss) throws DataAccessException;
999999

1000+
/**
1001+
* Issue multiple update statements on a single PreparedStatement,
1002+
* using batch updates and a BatchPreparedStatementSetter to set values.
1003+
* Generated keys will be put into the given KeyHolder.
1004+
* <p>Note that the given PreparedStatementCreator has to create a statement
1005+
* with activated extraction of generated keys (a JDBC 3.0 feature). This can
1006+
* either be done directly or through using a PreparedStatementCreatorFactory.
1007+
* <p>Will fall back to separate updates on a single PreparedStatement
1008+
* if the JDBC driver does not support batch updates.
1009+
* @param psc a callback that creates a PreparedStatement given a Connection
1010+
* @param pss object to set parameters on the PreparedStatement
1011+
* created by this method
1012+
* @param generatedKeyHolder a KeyHolder that will hold the generated keys
1013+
* @return an array of the number of rows affected by each statement
1014+
* (may also contain special JDBC-defined negative values for affected rows such as
1015+
* {@link java.sql.Statement#SUCCESS_NO_INFO}/{@link java.sql.Statement#EXECUTE_FAILED})
1016+
* @throws DataAccessException if there is any problem issuing the update
1017+
* @see org.springframework.jdbc.support.GeneratedKeyHolder
1018+
*/
1019+
int[] batchUpdate(PreparedStatementCreator psc, BatchPreparedStatementSetter pss, KeyHolder generatedKeyHolder) throws DataAccessException;
1020+
10001021
/**
10011022
* Execute a batch using the supplied SQL statement with the batch of supplied arguments.
10021023
* @param sql the SQL statement to execute

spring-jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java

Lines changed: 77 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -996,21 +996,10 @@ public int update(final PreparedStatementCreator psc, final KeyHolder generatedK
996996

997997
return updateCount(execute(psc, ps -> {
998998
int rows = ps.executeUpdate();
999-
List<Map<String, Object>> generatedKeys = generatedKeyHolder.getKeyList();
1000-
generatedKeys.clear();
1001-
ResultSet keys = ps.getGeneratedKeys();
1002-
if (keys != null) {
1003-
try {
1004-
RowMapperResultSetExtractor<Map<String, Object>> rse =
1005-
new RowMapperResultSetExtractor<>(getColumnMapRowMapper(), 1);
1006-
generatedKeys.addAll(result(rse.extractData(keys)));
1007-
}
1008-
finally {
1009-
JdbcUtils.closeResultSet(keys);
1010-
}
1011-
}
999+
generatedKeyHolder.getKeyList().clear();
1000+
storeGeneratedKeys(generatedKeyHolder, ps, 1);
10121001
if (logger.isTraceEnabled()) {
1013-
logger.trace("SQL update affected " + rows + " rows and returned " + generatedKeys.size() + " keys");
1002+
logger.trace("SQL update affected " + rows + " rows and returned " + generatedKeyHolder.getKeyList().size() + " keys");
10141003
}
10151004
return rows;
10161005
}, true));
@@ -1031,6 +1020,14 @@ public int update(String sql, @Nullable Object... args) throws DataAccessExcepti
10311020
return update(sql, newArgPreparedStatementSetter(args));
10321021
}
10331022

1023+
@Override
1024+
public int[] batchUpdate(final PreparedStatementCreator psc, final BatchPreparedStatementSetter pss, final KeyHolder generatedKeyHolder) throws DataAccessException {
1025+
int[] result = execute(psc, getPreparedStatementCallback(pss, generatedKeyHolder));
1026+
1027+
Assert.state(result != null, "No result array");
1028+
return result;
1029+
}
1030+
10341031
@Override
10351032
public int[] batchUpdate(String sql, final BatchPreparedStatementSetter pss) throws DataAccessException {
10361033
if (logger.isDebugEnabled()) {
@@ -1041,43 +1038,7 @@ public int[] batchUpdate(String sql, final BatchPreparedStatementSetter pss) thr
10411038
return new int[0];
10421039
}
10431040

1044-
int[] result = execute(sql, (PreparedStatementCallback<int[]>) ps -> {
1045-
try {
1046-
InterruptibleBatchPreparedStatementSetter ipss =
1047-
(pss instanceof InterruptibleBatchPreparedStatementSetter ibpss ? ibpss : null);
1048-
if (JdbcUtils.supportsBatchUpdates(ps.getConnection())) {
1049-
for (int i = 0; i < batchSize; i++) {
1050-
pss.setValues(ps, i);
1051-
if (ipss != null && ipss.isBatchExhausted(i)) {
1052-
break;
1053-
}
1054-
ps.addBatch();
1055-
}
1056-
return ps.executeBatch();
1057-
}
1058-
else {
1059-
List<Integer> rowsAffected = new ArrayList<>();
1060-
for (int i = 0; i < batchSize; i++) {
1061-
pss.setValues(ps, i);
1062-
if (ipss != null && ipss.isBatchExhausted(i)) {
1063-
break;
1064-
}
1065-
rowsAffected.add(ps.executeUpdate());
1066-
}
1067-
int[] rowsAffectedArray = new int[rowsAffected.size()];
1068-
for (int i = 0; i < rowsAffectedArray.length; i++) {
1069-
rowsAffectedArray[i] = rowsAffected.get(i);
1070-
}
1071-
return rowsAffectedArray;
1072-
}
1073-
}
1074-
finally {
1075-
if (pss instanceof ParameterDisposer parameterDisposer) {
1076-
parameterDisposer.cleanupParameters();
1077-
}
1078-
}
1079-
});
1080-
1041+
int[] result = execute(sql, getPreparedStatementCallback(pss, null));
10811042
Assert.state(result != null, "No result array");
10821043
return result;
10831044
}
@@ -1604,6 +1565,71 @@ private static int updateCount(@Nullable Integer result) {
16041565
return result;
16051566
}
16061567

1568+
private void storeGeneratedKeys(KeyHolder generatedKeyHolder, PreparedStatement ps, int rowsExpected) throws SQLException {
1569+
List<Map<String, Object>> generatedKeys = generatedKeyHolder.getKeyList();
1570+
ResultSet keys = ps.getGeneratedKeys();
1571+
if (keys != null) {
1572+
try {
1573+
RowMapperResultSetExtractor<Map<String, Object>> rse =
1574+
new RowMapperResultSetExtractor<>(getColumnMapRowMapper(), rowsExpected);
1575+
generatedKeys.addAll(result(rse.extractData(keys)));
1576+
}
1577+
finally {
1578+
JdbcUtils.closeResultSet(keys);
1579+
}
1580+
}
1581+
}
1582+
1583+
private PreparedStatementCallback<int[]> getPreparedStatementCallback(BatchPreparedStatementSetter pss, @Nullable KeyHolder generatedKeyHolder) {
1584+
return ps -> {
1585+
try {
1586+
int batchSize = pss.getBatchSize();
1587+
InterruptibleBatchPreparedStatementSetter ipss =
1588+
(pss instanceof InterruptibleBatchPreparedStatementSetter ibpss ? ibpss : null);
1589+
if (generatedKeyHolder != null) {
1590+
generatedKeyHolder.getKeyList().clear();
1591+
}
1592+
if (JdbcUtils.supportsBatchUpdates(ps.getConnection())) {
1593+
for (int i = 0; i < batchSize; i++) {
1594+
pss.setValues(ps, i);
1595+
if (ipss != null && ipss.isBatchExhausted(i)) {
1596+
break;
1597+
}
1598+
ps.addBatch();
1599+
}
1600+
int[] results = ps.executeBatch();
1601+
if (generatedKeyHolder != null) {
1602+
storeGeneratedKeys(generatedKeyHolder, ps, batchSize);
1603+
}
1604+
return results;
1605+
}
1606+
else {
1607+
List<Integer> rowsAffected = new ArrayList<>();
1608+
for (int i = 0; i < batchSize; i++) {
1609+
pss.setValues(ps, i);
1610+
if (ipss != null && ipss.isBatchExhausted(i)) {
1611+
break;
1612+
}
1613+
rowsAffected.add(ps.executeUpdate());
1614+
if (generatedKeyHolder != null) {
1615+
storeGeneratedKeys(generatedKeyHolder, ps, 1);
1616+
}
1617+
}
1618+
int[] rowsAffectedArray = new int[rowsAffected.size()];
1619+
for (int i = 0; i < rowsAffectedArray.length; i++) {
1620+
rowsAffectedArray[i] = rowsAffected.get(i);
1621+
}
1622+
return rowsAffectedArray;
1623+
}
1624+
}
1625+
finally {
1626+
if (pss instanceof ParameterDisposer parameterDisposer) {
1627+
parameterDisposer.cleanupParameters();
1628+
}
1629+
}
1630+
};
1631+
}
1632+
16071633

16081634
/**
16091635
* Invocation handler that suppresses close calls on JDBC Connections.

spring-jdbc/src/main/java/org/springframework/jdbc/core/namedparam/NamedParameterJdbcOperations.java

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -555,4 +555,34 @@ int update(String sql, SqlParameterSource paramSource, KeyHolder generatedKeyHol
555555
*/
556556
int[] batchUpdate(String sql, SqlParameterSource[] batchArgs);
557557

558+
/**
559+
* Execute a batch using the supplied SQL statement with the batch of supplied arguments,
560+
* returning generated keys.
561+
* @param sql the SQL statement to execute
562+
* @param batchArgs the array of {@link SqlParameterSource} containing the batch of
563+
* arguments for the query
564+
* @param generatedKeyHolder a {@link KeyHolder} that will hold the generated keys
565+
* @return an array containing the numbers of rows affected by each update in the batch
566+
* (may also contain special JDBC-defined negative values for affected rows such as
567+
* {@link java.sql.Statement#SUCCESS_NO_INFO}/{@link java.sql.Statement#EXECUTE_FAILED})
568+
* @throws DataAccessException if there is any problem issuing the update
569+
* @see org.springframework.jdbc.support.GeneratedKeyHolder
570+
*/
571+
int[] batchUpdate(String sql, SqlParameterSource[] batchArgs, KeyHolder generatedKeyHolder);
572+
573+
/**
574+
* Execute a batch using the supplied SQL statement with the batch of supplied arguments,
575+
* returning generated keys.
576+
* @param sql the SQL statement to execute
577+
* @param batchArgs the array of {@link SqlParameterSource} containing the batch of
578+
* arguments for the query
579+
* @param generatedKeyHolder a {@link KeyHolder} that will hold the generated keys
580+
* @param keyColumnNames names of the columns that will have keys generated for them
581+
* @return an array containing the numbers of rows affected by each update in the batch
582+
* (may also contain special JDBC-defined negative values for affected rows such as
583+
* {@link java.sql.Statement#SUCCESS_NO_INFO}/{@link java.sql.Statement#EXECUTE_FAILED})
584+
* @throws DataAccessException if there is any problem issuing the update
585+
* @see org.springframework.jdbc.support.GeneratedKeyHolder
586+
*/
587+
int[] batchUpdate(String sql, SqlParameterSource[] batchArgs, KeyHolder generatedKeyHolder, String[] keyColumnNames);
558588
}

spring-jdbc/src/main/java/org/springframework/jdbc/core/namedparam/NamedParameterJdbcTemplate.java

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -394,6 +394,44 @@ public int getBatchSize() {
394394
});
395395
}
396396

397+
@Override
398+
public int[] batchUpdate(String sql, SqlParameterSource[] batchArgs, KeyHolder generatedKeyHolder) {
399+
return batchUpdate(sql, batchArgs, generatedKeyHolder, null);
400+
}
401+
402+
@Override
403+
public int[] batchUpdate(String sql, SqlParameterSource[] batchArgs, KeyHolder generatedKeyHolder, String[] keyColumnNames) {
404+
if (batchArgs.length == 0) {
405+
return new int[0];
406+
}
407+
408+
ParsedSql parsedSql = getParsedSql(sql);
409+
SqlParameterSource paramSource = batchArgs[0];
410+
PreparedStatementCreatorFactory pscf = getPreparedStatementCreatorFactory(parsedSql, paramSource);
411+
if (keyColumnNames != null) {
412+
pscf.setGeneratedKeysColumnNames(keyColumnNames);
413+
}
414+
else {
415+
pscf.setReturnGeneratedKeys(true);
416+
}
417+
Object[] params = NamedParameterUtils.buildValueArray(parsedSql, paramSource, null);
418+
PreparedStatementCreator psc = pscf.newPreparedStatementCreator(params);
419+
return getJdbcOperations().batchUpdate(
420+
psc,
421+
new BatchPreparedStatementSetter() {
422+
@Override
423+
public void setValues(PreparedStatement ps, int i) throws SQLException {
424+
Object[] values = NamedParameterUtils.buildValueArray(parsedSql, batchArgs[i], null);
425+
pscf.newPreparedStatementSetter(values).setValues(ps);
426+
}
427+
@Override
428+
public int getBatchSize() {
429+
return batchArgs.length;
430+
}
431+
},
432+
generatedKeyHolder);
433+
}
434+
397435

398436
/**
399437
* Build a {@link PreparedStatementCreator} based on the given SQL and named parameters.

spring-jdbc/src/test/java/org/springframework/jdbc/core/JdbcTemplateTests.java

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@
4747
import org.springframework.jdbc.core.support.AbstractInterruptibleBatchPreparedStatementSetter;
4848
import org.springframework.jdbc.datasource.ConnectionProxy;
4949
import org.springframework.jdbc.datasource.SingleConnectionDataSource;
50+
import org.springframework.jdbc.support.GeneratedKeyHolder;
51+
import org.springframework.jdbc.support.KeyHolder;
5052
import org.springframework.jdbc.support.SQLErrorCodeSQLExceptionTranslator;
5153
import org.springframework.jdbc.support.SQLStateSQLExceptionTranslator;
5254
import org.springframework.util.LinkedCaseInsensitiveMap;
@@ -1104,6 +1106,83 @@ public void testEquallyNamedColumn() throws SQLException {
11041106
assertThat(map.get("x")).isEqualTo("first value");
11051107
}
11061108

1109+
@Test
1110+
void testBatchUpdateReturnsGeneratedKeys_whenDatabaseSupportsBatchUpdates() throws SQLException {
1111+
final int[] rowsAffected = new int[] {1, 2};
1112+
given(this.preparedStatement.executeBatch()).willReturn(rowsAffected);
1113+
DatabaseMetaData databaseMetaData = mock(DatabaseMetaData.class);
1114+
given(databaseMetaData.supportsBatchUpdates()).willReturn(true);
1115+
given(this.connection.getMetaData()).willReturn(databaseMetaData);
1116+
ResultSet generatedKeysResultSet = mock(ResultSet.class);
1117+
ResultSetMetaData rsmd = mock(ResultSetMetaData.class);
1118+
given(rsmd.getColumnCount()).willReturn(1);
1119+
given(rsmd.getColumnLabel(1)).willReturn("someId");
1120+
given(generatedKeysResultSet.getMetaData()).willReturn(rsmd);
1121+
given(generatedKeysResultSet.getObject(1)).willReturn(123, 456);
1122+
given(generatedKeysResultSet.next()).willReturn(true, true, false);
1123+
given(this.preparedStatement.getGeneratedKeys()).willReturn(generatedKeysResultSet);
1124+
1125+
int[] values = new int[]{100, 200};
1126+
BatchPreparedStatementSetter bpss = new BatchPreparedStatementSetter() {
1127+
@Override
1128+
public void setValues(PreparedStatement ps, int i) throws SQLException {
1129+
ps.setObject(i, values[i]);
1130+
}
1131+
1132+
@Override
1133+
public int getBatchSize() {
1134+
return 2;
1135+
}
1136+
};
1137+
1138+
KeyHolder keyHolder = new GeneratedKeyHolder();
1139+
this.template.batchUpdate(con -> con.prepareStatement(""), bpss, keyHolder);
1140+
1141+
assertThat(keyHolder.getKeyList()).containsExactly(
1142+
Collections.singletonMap("someId", 123),
1143+
Collections.singletonMap("someId", 456));
1144+
}
1145+
1146+
@Test
1147+
void testBatchUpdateReturnsGeneratedKeys_whenDatabaseDoesNotSupportBatchUpdates() throws SQLException {
1148+
final int[] rowsAffected = new int[] {1, 2};
1149+
given(this.preparedStatement.executeBatch()).willReturn(rowsAffected);
1150+
DatabaseMetaData databaseMetaData = mock(DatabaseMetaData.class);
1151+
given(databaseMetaData.supportsBatchUpdates()).willReturn(false);
1152+
given(this.connection.getMetaData()).willReturn(databaseMetaData);
1153+
ResultSetMetaData rsmd = mock(ResultSetMetaData.class);
1154+
given(rsmd.getColumnCount()).willReturn(1);
1155+
given(rsmd.getColumnLabel(1)).willReturn("someId");
1156+
ResultSet generatedKeysResultSet1 = mock(ResultSet.class);
1157+
given(generatedKeysResultSet1.getMetaData()).willReturn(rsmd);
1158+
given(generatedKeysResultSet1.getObject(1)).willReturn(123);
1159+
given(generatedKeysResultSet1.next()).willReturn(true, false);
1160+
ResultSet generatedKeysResultSet2 = mock(ResultSet.class);
1161+
given(generatedKeysResultSet2.getMetaData()).willReturn(rsmd);
1162+
given(generatedKeysResultSet2.getObject(1)).willReturn(456);
1163+
given(generatedKeysResultSet2.next()).willReturn(true, false);
1164+
given(this.preparedStatement.getGeneratedKeys()).willReturn(generatedKeysResultSet1, generatedKeysResultSet2);
1165+
1166+
int[] values = new int[]{100, 200};
1167+
BatchPreparedStatementSetter bpss = new BatchPreparedStatementSetter() {
1168+
@Override
1169+
public void setValues(PreparedStatement ps, int i) throws SQLException {
1170+
ps.setObject(i, values[i]);
1171+
}
1172+
1173+
@Override
1174+
public int getBatchSize() {
1175+
return 2;
1176+
}
1177+
};
1178+
1179+
KeyHolder keyHolder = new GeneratedKeyHolder();
1180+
this.template.batchUpdate(con -> con.prepareStatement(""), bpss, keyHolder);
1181+
1182+
assertThat(keyHolder.getKeyList()).containsExactly(
1183+
Collections.singletonMap("someId", 123),
1184+
Collections.singletonMap("someId", 456));
1185+
}
11071186

11081187
private void mockDatabaseMetaData(boolean supportsBatchUpdates) throws SQLException {
11091188
DatabaseMetaData databaseMetaData = mock();

0 commit comments

Comments
 (0)