diff --git a/spring-jdbc/src/main/java/org/springframework/jdbc/core/JdbcOperations.java b/spring-jdbc/src/main/java/org/springframework/jdbc/core/JdbcOperations.java index 7910753806aa..4bd49112fde8 100644 --- a/spring-jdbc/src/main/java/org/springframework/jdbc/core/JdbcOperations.java +++ b/spring-jdbc/src/main/java/org/springframework/jdbc/core/JdbcOperations.java @@ -990,6 +990,27 @@ List queryForList(String sql, Object[] args, int[] argTypes, Class ele */ int[] batchUpdate(String sql, BatchPreparedStatementSetter pss) throws DataAccessException; + /** + * Issue multiple update statements on a single PreparedStatement, + * using batch updates and a BatchPreparedStatementSetter to set values. + * Generated keys will be put into the given KeyHolder. + *

Note that the given PreparedStatementCreator has to create a statement + * with activated extraction of generated keys (a JDBC 3.0 feature). This can + * either be done directly or through using a PreparedStatementCreatorFactory. + *

Will fall back to separate updates on a single PreparedStatement + * if the JDBC driver does not support batch updates. + * @param psc a callback that creates a PreparedStatement given a Connection + * @param pss object to set parameters on the PreparedStatement + * created by this method + * @param generatedKeyHolder a KeyHolder that will hold the generated keys + * @return an array of the number of rows affected by each statement + * (may also contain special JDBC-defined negative values for affected rows such as + * {@link java.sql.Statement#SUCCESS_NO_INFO}/{@link java.sql.Statement#EXECUTE_FAILED}) + * @throws DataAccessException if there is any problem issuing the update + * @see org.springframework.jdbc.support.GeneratedKeyHolder + */ + int[] batchUpdate(PreparedStatementCreator psc, BatchPreparedStatementSetter pss, KeyHolder generatedKeyHolder) throws DataAccessException; + /** * Execute a batch using the supplied SQL statement with the batch of supplied arguments. * @param sql the SQL statement to execute diff --git a/spring-jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java b/spring-jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java index 179a6b79e166..cd7d45306e0f 100644 --- a/spring-jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java +++ b/spring-jdbc/src/main/java/org/springframework/jdbc/core/JdbcTemplate.java @@ -990,21 +990,10 @@ public int update(final PreparedStatementCreator psc, final KeyHolder generatedK return updateCount(execute(psc, ps -> { int rows = ps.executeUpdate(); - List> generatedKeys = generatedKeyHolder.getKeyList(); - generatedKeys.clear(); - ResultSet keys = ps.getGeneratedKeys(); - if (keys != null) { - try { - RowMapperResultSetExtractor> rse = - new RowMapperResultSetExtractor<>(getColumnMapRowMapper(), 1); - generatedKeys.addAll(result(rse.extractData(keys))); - } - finally { - JdbcUtils.closeResultSet(keys); - } - } + generatedKeyHolder.getKeyList().clear(); + storeGeneratedKeys(generatedKeyHolder, ps, 1); if (logger.isTraceEnabled()) { - logger.trace("SQL update affected " + rows + " rows and returned " + generatedKeys.size() + " keys"); + logger.trace("SQL update affected " + rows + " rows and returned " + generatedKeyHolder.getKeyList().size() + " keys"); } return rows; }, true)); @@ -1025,50 +1014,21 @@ public int update(String sql, @Nullable Object... args) throws DataAccessExcepti return update(sql, newArgPreparedStatementSetter(args)); } + @Override + public int[] batchUpdate(final PreparedStatementCreator psc, final BatchPreparedStatementSetter pss, final KeyHolder generatedKeyHolder) throws DataAccessException { + int[] result = execute(psc, getPreparedStatementCallback(pss, generatedKeyHolder)); + + Assert.state(result != null, "No result array"); + return result; + } + @Override public int[] batchUpdate(String sql, final BatchPreparedStatementSetter pss) throws DataAccessException { if (logger.isDebugEnabled()) { logger.debug("Executing SQL batch update [" + sql + "]"); } - int[] result = execute(sql, (PreparedStatementCallback) ps -> { - try { - int batchSize = pss.getBatchSize(); - InterruptibleBatchPreparedStatementSetter ipss = - (pss instanceof InterruptibleBatchPreparedStatementSetter ? - (InterruptibleBatchPreparedStatementSetter) pss : null); - if (JdbcUtils.supportsBatchUpdates(ps.getConnection())) { - for (int i = 0; i < batchSize; i++) { - pss.setValues(ps, i); - if (ipss != null && ipss.isBatchExhausted(i)) { - break; - } - ps.addBatch(); - } - return ps.executeBatch(); - } - else { - List rowsAffected = new ArrayList<>(); - for (int i = 0; i < batchSize; i++) { - pss.setValues(ps, i); - if (ipss != null && ipss.isBatchExhausted(i)) { - break; - } - rowsAffected.add(ps.executeUpdate()); - } - int[] rowsAffectedArray = new int[rowsAffected.size()]; - for (int i = 0; i < rowsAffectedArray.length; i++) { - rowsAffectedArray[i] = rowsAffected.get(i); - } - return rowsAffectedArray; - } - } - finally { - if (pss instanceof ParameterDisposer) { - ((ParameterDisposer) pss).cleanupParameters(); - } - } - }); + int[] result = execute(sql, getPreparedStatementCallback(pss, null)); Assert.state(result != null, "No result array"); return result; @@ -1567,6 +1527,72 @@ private static int updateCount(@Nullable Integer result) { return result; } + private void storeGeneratedKeys(KeyHolder generatedKeyHolder, PreparedStatement ps, int rowsExpected) throws SQLException { + List> generatedKeys = generatedKeyHolder.getKeyList(); + ResultSet keys = ps.getGeneratedKeys(); + if (keys != null) { + try { + RowMapperResultSetExtractor> rse = + new RowMapperResultSetExtractor<>(getColumnMapRowMapper(), rowsExpected); + generatedKeys.addAll(result(rse.extractData(keys))); + } + finally { + JdbcUtils.closeResultSet(keys); + } + } + } + + private PreparedStatementCallback getPreparedStatementCallback(BatchPreparedStatementSetter pss, @Nullable KeyHolder generatedKeyHolder) { + return ps -> { + try { + int batchSize = pss.getBatchSize(); + InterruptibleBatchPreparedStatementSetter ipss = + (pss instanceof InterruptibleBatchPreparedStatementSetter ? + (InterruptibleBatchPreparedStatementSetter) pss : null); + if (generatedKeyHolder != null) { + generatedKeyHolder.getKeyList().clear(); + } + if (JdbcUtils.supportsBatchUpdates(ps.getConnection())) { + for (int i = 0; i < batchSize; i++) { + pss.setValues(ps, i); + if (ipss != null && ipss.isBatchExhausted(i)) { + break; + } + ps.addBatch(); + } + int[] results = ps.executeBatch(); + if (generatedKeyHolder != null) { + storeGeneratedKeys(generatedKeyHolder, ps, batchSize); + } + return results; + } + else { + List rowsAffected = new ArrayList<>(); + for (int i = 0; i < batchSize; i++) { + pss.setValues(ps, i); + if (ipss != null && ipss.isBatchExhausted(i)) { + break; + } + rowsAffected.add(ps.executeUpdate()); + if (generatedKeyHolder != null) { + storeGeneratedKeys(generatedKeyHolder, ps, 1); + } + } + int[] rowsAffectedArray = new int[rowsAffected.size()]; + for (int i = 0; i < rowsAffectedArray.length; i++) { + rowsAffectedArray[i] = rowsAffected.get(i); + } + return rowsAffectedArray; + } + } + finally { + if (pss instanceof ParameterDisposer) { + ((ParameterDisposer) pss).cleanupParameters(); + } + } + }; + } + /** * Invocation handler that suppresses close calls on JDBC Connections. diff --git a/spring-jdbc/src/main/java/org/springframework/jdbc/core/namedparam/NamedParameterJdbcOperations.java b/spring-jdbc/src/main/java/org/springframework/jdbc/core/namedparam/NamedParameterJdbcOperations.java index b308e06f735b..e0a51e9f106b 100644 --- a/spring-jdbc/src/main/java/org/springframework/jdbc/core/namedparam/NamedParameterJdbcOperations.java +++ b/spring-jdbc/src/main/java/org/springframework/jdbc/core/namedparam/NamedParameterJdbcOperations.java @@ -549,4 +549,34 @@ int update(String sql, SqlParameterSource paramSource, KeyHolder generatedKeyHol */ int[] batchUpdate(String sql, SqlParameterSource[] batchArgs); + /** + * Execute a batch using the supplied SQL statement with the batch of supplied arguments, + * returning generated keys. + * @param sql the SQL statement to execute + * @param batchArgs the array of {@link SqlParameterSource} containing the batch of + * arguments for the query + * @param generatedKeyHolder a {@link KeyHolder} that will hold the generated keys + * @return an array containing the numbers of rows affected by each update in the batch + * (may also contain special JDBC-defined negative values for affected rows such as + * {@link java.sql.Statement#SUCCESS_NO_INFO}/{@link java.sql.Statement#EXECUTE_FAILED}) + * @throws DataAccessException if there is any problem issuing the update + * @see org.springframework.jdbc.support.GeneratedKeyHolder + */ + int[] batchUpdate(String sql, SqlParameterSource[] batchArgs, KeyHolder generatedKeyHolder); + + /** + * Execute a batch using the supplied SQL statement with the batch of supplied arguments, + * returning generated keys. + * @param sql the SQL statement to execute + * @param batchArgs the array of {@link SqlParameterSource} containing the batch of + * arguments for the query + * @param generatedKeyHolder a {@link KeyHolder} that will hold the generated keys + * @param keyColumnNames names of the columns that will have keys generated for them + * @return an array containing the numbers of rows affected by each update in the batch + * (may also contain special JDBC-defined negative values for affected rows such as + * {@link java.sql.Statement#SUCCESS_NO_INFO}/{@link java.sql.Statement#EXECUTE_FAILED}) + * @throws DataAccessException if there is any problem issuing the update + * @see org.springframework.jdbc.support.GeneratedKeyHolder + */ + int[] batchUpdate(String sql, SqlParameterSource[] batchArgs, KeyHolder generatedKeyHolder, String[] keyColumnNames); } diff --git a/spring-jdbc/src/main/java/org/springframework/jdbc/core/namedparam/NamedParameterJdbcTemplate.java b/spring-jdbc/src/main/java/org/springframework/jdbc/core/namedparam/NamedParameterJdbcTemplate.java index ef7b6567dfcc..94779631f15e 100644 --- a/spring-jdbc/src/main/java/org/springframework/jdbc/core/namedparam/NamedParameterJdbcTemplate.java +++ b/spring-jdbc/src/main/java/org/springframework/jdbc/core/namedparam/NamedParameterJdbcTemplate.java @@ -385,6 +385,44 @@ public int getBatchSize() { }); } + @Override + public int[] batchUpdate(String sql, SqlParameterSource[] batchArgs, KeyHolder generatedKeyHolder) { + return batchUpdate(sql, batchArgs, generatedKeyHolder, null); + } + + @Override + public int[] batchUpdate(String sql, SqlParameterSource[] batchArgs, KeyHolder generatedKeyHolder, String[] keyColumnNames) { + if (batchArgs.length == 0) { + return new int[0]; + } + + ParsedSql parsedSql = getParsedSql(sql); + SqlParameterSource paramSource = batchArgs[0]; + PreparedStatementCreatorFactory pscf = getPreparedStatementCreatorFactory(parsedSql, paramSource); + if (keyColumnNames != null) { + pscf.setGeneratedKeysColumnNames(keyColumnNames); + } + else { + pscf.setReturnGeneratedKeys(true); + } + Object[] params = NamedParameterUtils.buildValueArray(parsedSql, paramSource, null); + PreparedStatementCreator psc = pscf.newPreparedStatementCreator(params); + return getJdbcOperations().batchUpdate( + psc, + new BatchPreparedStatementSetter() { + @Override + public void setValues(PreparedStatement ps, int i) throws SQLException { + Object[] values = NamedParameterUtils.buildValueArray(parsedSql, batchArgs[i], null); + pscf.newPreparedStatementSetter(values).setValues(ps); + } + @Override + public int getBatchSize() { + return batchArgs.length; + } + }, + generatedKeyHolder); + } + /** * Build a {@link PreparedStatementCreator} based on the given SQL and named parameters. diff --git a/spring-jdbc/src/test/java/org/springframework/jdbc/core/JdbcTemplateTests.java b/spring-jdbc/src/test/java/org/springframework/jdbc/core/JdbcTemplateTests.java index 456d59dd5bfd..044076b4622a 100644 --- a/spring-jdbc/src/test/java/org/springframework/jdbc/core/JdbcTemplateTests.java +++ b/spring-jdbc/src/test/java/org/springframework/jdbc/core/JdbcTemplateTests.java @@ -47,6 +47,8 @@ import org.springframework.jdbc.core.support.AbstractInterruptibleBatchPreparedStatementSetter; import org.springframework.jdbc.datasource.ConnectionProxy; import org.springframework.jdbc.datasource.SingleConnectionDataSource; +import org.springframework.jdbc.support.GeneratedKeyHolder; +import org.springframework.jdbc.support.KeyHolder; import org.springframework.jdbc.support.SQLErrorCodeSQLExceptionTranslator; import org.springframework.jdbc.support.SQLStateSQLExceptionTranslator; import org.springframework.util.LinkedCaseInsensitiveMap; @@ -1085,6 +1087,83 @@ public void testEquallyNamedColumn() throws SQLException { assertThat(map.get("x")).isEqualTo("first value"); } + @Test + void testBatchUpdateReturnsGeneratedKeys_whenDatabaseSupportsBatchUpdates() throws SQLException { + final int[] rowsAffected = new int[] {1, 2}; + given(this.preparedStatement.executeBatch()).willReturn(rowsAffected); + DatabaseMetaData databaseMetaData = mock(DatabaseMetaData.class); + given(databaseMetaData.supportsBatchUpdates()).willReturn(true); + given(this.connection.getMetaData()).willReturn(databaseMetaData); + ResultSet generatedKeysResultSet = mock(ResultSet.class); + ResultSetMetaData rsmd = mock(ResultSetMetaData.class); + given(rsmd.getColumnCount()).willReturn(1); + given(rsmd.getColumnLabel(1)).willReturn("someId"); + given(generatedKeysResultSet.getMetaData()).willReturn(rsmd); + given(generatedKeysResultSet.getObject(1)).willReturn(123, 456); + given(generatedKeysResultSet.next()).willReturn(true, true, false); + given(this.preparedStatement.getGeneratedKeys()).willReturn(generatedKeysResultSet); + + int[] values = new int[]{100, 200}; + BatchPreparedStatementSetter bpss = new BatchPreparedStatementSetter() { + @Override + public void setValues(PreparedStatement ps, int i) throws SQLException { + ps.setObject(i, values[i]); + } + + @Override + public int getBatchSize() { + return 2; + } + }; + + KeyHolder keyHolder = new GeneratedKeyHolder(); + this.template.batchUpdate(con -> con.prepareStatement(""), bpss, keyHolder); + + assertThat(keyHolder.getKeyList()).containsExactly( + Collections.singletonMap("someId", 123), + Collections.singletonMap("someId", 456)); + } + + @Test + void testBatchUpdateReturnsGeneratedKeys_whenDatabaseDoesNotSupportBatchUpdates() throws SQLException { + final int[] rowsAffected = new int[] {1, 2}; + given(this.preparedStatement.executeBatch()).willReturn(rowsAffected); + DatabaseMetaData databaseMetaData = mock(DatabaseMetaData.class); + given(databaseMetaData.supportsBatchUpdates()).willReturn(false); + given(this.connection.getMetaData()).willReturn(databaseMetaData); + ResultSetMetaData rsmd = mock(ResultSetMetaData.class); + given(rsmd.getColumnCount()).willReturn(1); + given(rsmd.getColumnLabel(1)).willReturn("someId"); + ResultSet generatedKeysResultSet1 = mock(ResultSet.class); + given(generatedKeysResultSet1.getMetaData()).willReturn(rsmd); + given(generatedKeysResultSet1.getObject(1)).willReturn(123); + given(generatedKeysResultSet1.next()).willReturn(true, false); + ResultSet generatedKeysResultSet2 = mock(ResultSet.class); + given(generatedKeysResultSet2.getMetaData()).willReturn(rsmd); + given(generatedKeysResultSet2.getObject(1)).willReturn(456); + given(generatedKeysResultSet2.next()).willReturn(true, false); + given(this.preparedStatement.getGeneratedKeys()).willReturn(generatedKeysResultSet1, generatedKeysResultSet2); + + int[] values = new int[]{100, 200}; + BatchPreparedStatementSetter bpss = new BatchPreparedStatementSetter() { + @Override + public void setValues(PreparedStatement ps, int i) throws SQLException { + ps.setObject(i, values[i]); + } + + @Override + public int getBatchSize() { + return 2; + } + }; + + KeyHolder keyHolder = new GeneratedKeyHolder(); + this.template.batchUpdate(con -> con.prepareStatement(""), bpss, keyHolder); + + assertThat(keyHolder.getKeyList()).containsExactly( + Collections.singletonMap("someId", 123), + Collections.singletonMap("someId", 456)); + } private void mockDatabaseMetaData(boolean supportsBatchUpdates) throws SQLException { DatabaseMetaData databaseMetaData = mock(DatabaseMetaData.class);