|
19 | 19 |
|
20 | 20 | import java.sql.ResultSet;
|
21 | 21 | import java.sql.SQLException;
|
| 22 | +import java.util.Arrays; |
22 | 23 | import java.util.Collections;
|
23 | 24 | import java.util.List;
|
| 25 | +import java.util.Map; |
| 26 | +import java.util.Objects; |
24 | 27 | import java.util.Optional;
|
| 28 | +import java.util.Set; |
| 29 | +import java.util.stream.Collectors; |
| 30 | +import java.util.stream.IntStream; |
| 31 | +import java.util.stream.LongStream; |
25 | 32 |
|
26 | 33 | import org.springframework.dao.EmptyResultDataAccessException;
|
27 | 34 | import org.springframework.dao.OptimisticLockingFailureException;
|
|
37 | 44 | import org.springframework.data.relational.core.query.Query;
|
38 | 45 | import org.springframework.data.relational.core.sql.LockMode;
|
39 | 46 | import org.springframework.data.relational.core.sql.SqlIdentifier;
|
| 47 | +import org.springframework.data.util.Pair; |
40 | 48 | import org.springframework.jdbc.core.RowMapper;
|
41 | 49 | import org.springframework.jdbc.core.namedparam.MapSqlParameterSource;
|
42 | 50 | import org.springframework.jdbc.core.namedparam.NamedParameterJdbcOperations;
|
@@ -102,31 +110,35 @@ public DefaultDataAccessStrategy(SqlGeneratorSource sqlGeneratorSource, Relation
|
102 | 110 | @Override
|
103 | 111 | public <T> Object insert(T instance, Class<T> domainType, Identifier identifier, IdValueSource idValueSource) {
|
104 | 112 |
|
105 |
| - SqlIdentifierParameterSource parameterSource = sqlParametersFactory.forInsert(instance, domainType, identifier, |
106 |
| - idValueSource); |
| 113 | + RelationalPersistentEntity<?> persistentEntity = context.getRequiredPersistentEntity(domainType); |
| 114 | + |
| 115 | + Optional<Long> idFromSequence = getIdFromSequenceIfAnyDefined(idValueSource, persistentEntity); |
| 116 | + |
| 117 | + SqlIdentifierParameterSource parameterSource = idFromSequence |
| 118 | + .map(it -> sqlParametersFactory.forInsert(instance, domainType, identifier, it)) |
| 119 | + .orElseGet(() -> sqlParametersFactory.forInsert(instance, domainType, identifier, idValueSource)); |
107 | 120 |
|
108 | 121 | String insertSql = sql(domainType).getInsert(parameterSource.getIdentifiers());
|
109 | 122 |
|
110 |
| - return insertStrategyFactory.insertStrategy(idValueSource, getIdColumn(domainType)).execute(insertSql, |
111 |
| - parameterSource); |
112 |
| - } |
| 123 | + Object idAfterExecute = insertStrategyFactory.insertStrategy(idValueSource, getIdColumn(domainType)) |
| 124 | + .execute(insertSql, parameterSource); |
| 125 | + |
| 126 | + return idFromSequence.map(it -> (Object) it).orElse(idAfterExecute); |
| 127 | + } |
113 | 128 |
|
114 | 129 | @Override
|
115 | 130 | public <T> Object[] insert(List<InsertSubject<T>> insertSubjects, Class<T> domainType, IdValueSource idValueSource) {
|
116 | 131 |
|
117 | 132 | Assert.notEmpty(insertSubjects, "Batch insert must contain at least one InsertSubject");
|
118 |
| - SqlIdentifierParameterSource[] sqlParameterSources = insertSubjects.stream() |
119 |
| - .map(insertSubject -> sqlParametersFactory.forInsert(insertSubject.getInstance(), domainType, |
120 |
| - insertSubject.getIdentifier(), idValueSource)) |
121 |
| - .toArray(SqlIdentifierParameterSource[]::new); |
122 | 133 |
|
123 |
| - String insertSql = sql(domainType).getInsert(sqlParameterSources[0].getIdentifiers()); |
| 134 | + if (IdValueSource.SEQUENCE.equals(idValueSource)) { |
| 135 | + return executeBatchInsertWithSequenceAsIdSource(insertSubjects, domainType, idValueSource); |
| 136 | + } else { |
| 137 | + return executeBatchInsert(insertSubjects, domainType, idValueSource); |
| 138 | + } |
| 139 | + } |
124 | 140 |
|
125 |
| - return insertStrategyFactory.batchInsertStrategy(idValueSource, getIdColumn(domainType)).execute(insertSql, |
126 |
| - sqlParameterSources); |
127 |
| - } |
128 |
| - |
129 |
| - @Override |
| 141 | + @Override |
130 | 142 | public <S> boolean update(S instance, Class<S> domainType) {
|
131 | 143 |
|
132 | 144 | SqlIdentifierParameterSource parameterSource = sqlParametersFactory.forUpdate(instance, domainType);
|
@@ -446,4 +458,70 @@ private Class<?> getBaseType(PersistentPropertyPath<RelationalPersistentProperty
|
446 | 458 | return baseProperty.getOwner().getType();
|
447 | 459 | }
|
448 | 460 |
|
| 461 | + private <T> Object[] executeBatchInsert(List<InsertSubject<T>> insertSubjects, Class<T> domainType, IdValueSource idValueSource) { |
| 462 | + SqlIdentifierParameterSource[] sqlParameterSources = insertSubjects |
| 463 | + .stream() |
| 464 | + .map(insertSubject -> sqlParametersFactory.forInsert( |
| 465 | + insertSubject.getInstance(), domainType, |
| 466 | + insertSubject.getIdentifier(), idValueSource) |
| 467 | + ) |
| 468 | + .toArray(SqlIdentifierParameterSource[]::new); |
| 469 | + |
| 470 | + String insertSql = sql(domainType).getInsert(sqlParameterSources[0].getIdentifiers()); |
| 471 | + |
| 472 | + return insertStrategyFactory.batchInsertStrategy(idValueSource, getIdColumn(domainType)) |
| 473 | + .execute(insertSql, sqlParameterSources); |
| 474 | + } |
| 475 | + |
| 476 | + private <T> Object[] executeBatchInsertWithSequenceAsIdSource(List<InsertSubject<T>> insertSubjects, Class<T> domainType, IdValueSource idValueSource) { |
| 477 | + List<Pair<Long, SqlIdentifierParameterSource>> sqlParameterSources = createBatchParameterSourcesWithSequence(insertSubjects, domainType, |
| 478 | + context.getPersistentEntity(domainType).getIdTargetSequence() |
| 479 | + ); |
| 480 | + |
| 481 | + String insertSql = sql(domainType).getInsert(sqlParameterSources.get(0).getSecond().getIdentifiers()); |
| 482 | + |
| 483 | + insertStrategyFactory.batchInsertStrategy(idValueSource, getIdColumn(domainType)) |
| 484 | + .execute(insertSql, sqlParameterSources.stream() |
| 485 | + .map(Pair::getSecond) |
| 486 | + .toArray(SqlIdentifierParameterSource[]::new)); |
| 487 | + |
| 488 | + return sqlParameterSources.stream().map(Pair::getFirst).toArray(Object[]::new); |
| 489 | + } |
| 490 | + |
| 491 | + private <T> List<Pair<Long, SqlIdentifierParameterSource>> createBatchParameterSourcesWithSequence(List<InsertSubject<T>> insertSubjects, Class<T> domainType, Optional<String> idTargetSequence) { |
| 492 | + List<Pair<Long, SqlIdentifierParameterSource>> sqlParameterSources; |
| 493 | + int subjectsSize = insertSubjects.size(); |
| 494 | + |
| 495 | + List<Long> generatedIds = getMultipleIdsFromSequence(idTargetSequence.get(), subjectsSize); |
| 496 | + |
| 497 | + sqlParameterSources = IntStream |
| 498 | + .range(0, subjectsSize) |
| 499 | + .mapToObj(index -> { |
| 500 | + InsertSubject<T> subject = insertSubjects.get(index); |
| 501 | + Long generatedId = generatedIds.get(index); |
| 502 | + return Pair.of(generatedId, sqlParametersFactory.forInsert( |
| 503 | + subject.getInstance(), domainType, |
| 504 | + subject.getIdentifier(), generatedId |
| 505 | + )); |
| 506 | + }) |
| 507 | + .collect(Collectors.toList()); |
| 508 | + return sqlParameterSources; |
| 509 | + } |
| 510 | + |
| 511 | + private Optional<Long> getIdFromSequenceIfAnyDefined(IdValueSource idValueSource, RelationalPersistentEntity<?> persistentEntity) { |
| 512 | + if (IdValueSource.SEQUENCE.equals(idValueSource) && persistentEntity.getIdTargetSequence().isPresent()) { |
| 513 | + String nextSequenceValueSelect = insertStrategyFactory.getDialect().nextValueFromSequenceSelect(persistentEntity.getIdTargetSequence().get()); |
| 514 | + return Optional.of(operations.queryForObject(nextSequenceValueSelect, Map.of(), (rs, rowNum) -> rs.getLong(1))); |
| 515 | + } |
| 516 | + return Optional.empty(); |
| 517 | + } |
| 518 | + |
| 519 | + private List<Long> getMultipleIdsFromSequence(String sequenceName, Integer requiredIds) { |
| 520 | + String nextSequenceValueSelect = insertStrategyFactory.getDialect().nextValueFromSequenceSelect(sequenceName); |
| 521 | + |
| 522 | + return IntStream.range(0, requiredIds) |
| 523 | + .mapToObj(operand -> operations.queryForObject(nextSequenceValueSelect, Map.of(), (rs, rowNum) -> rs.getLong(1))) |
| 524 | + .collect(Collectors.toList()); |
| 525 | + } |
| 526 | + |
449 | 527 | }
|
0 commit comments