Skip to content

Commit 67eeca8

Browse files
committed
DATAJDBC-219 - Add support for optimistic locking.
The @Version annotation is now evaluated properly and an OptimisticLockingFailureException is thrown on updates when the version has been incremented in the meantime. The @Version field can be of type Long, Integer or Short (or their primitive counterparts).
1 parent 78f26aa commit 67eeca8

9 files changed

+285
-8
lines changed

spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/DefaultDataAccessStrategy.java

+73-1
Original file line numberDiff line numberDiff line change
@@ -26,16 +26,20 @@
2626
import java.util.Map;
2727
import java.util.function.Predicate;
2828

29+
import org.springframework.core.convert.ConversionService;
2930
import org.springframework.dao.DataRetrievalFailureException;
3031
import org.springframework.dao.EmptyResultDataAccessException;
3132
import org.springframework.dao.InvalidDataAccessApiUsageException;
33+
import org.springframework.dao.OptimisticLockingFailureException;
3234
import org.springframework.data.jdbc.core.convert.JdbcConverter;
3335
import org.springframework.data.jdbc.core.convert.JdbcValue;
3436
import org.springframework.data.jdbc.support.JdbcUtil;
3537
import org.springframework.data.mapping.PersistentProperty;
3638
import org.springframework.data.mapping.PersistentPropertyAccessor;
3739
import org.springframework.data.mapping.PersistentPropertyPath;
3840
import org.springframework.data.mapping.PropertyHandler;
41+
import org.springframework.data.mapping.model.ConvertingPropertyAccessor;
42+
import org.springframework.data.relational.core.conversion.RelationalConverter;
3943
import org.springframework.data.relational.core.mapping.RelationalMappingContext;
4044
import org.springframework.data.relational.core.mapping.RelationalPersistentEntity;
4145
import org.springframework.data.relational.core.mapping.RelationalPersistentProperty;
@@ -48,6 +52,7 @@
4852
import org.springframework.jdbc.support.KeyHolder;
4953
import org.springframework.lang.Nullable;
5054
import org.springframework.util.Assert;
55+
import static org.springframework.data.jdbc.core.SqlGenerator.*;
5156

5257
/**
5358
* The default {@link DataAccessStrategy} is to generate SQL statements based on meta data from the entity.
@@ -56,6 +61,7 @@
5661
* @author Mark Paluch
5762
* @author Thomas Lang
5863
* @author Bastian Wilhelm
64+
* @author Tom Hombergs
5965
*/
6066
public class DefaultDataAccessStrategy implements DataAccessStrategy {
6167

@@ -124,6 +130,12 @@ public <T> Object insert(T instance, Class<T> domainType, Identifier identifier)
124130
KeyHolder holder = new GeneratedKeyHolder();
125131
RelationalPersistentEntity<T> persistentEntity = getRequiredPersistentEntity(domainType);
126132

133+
if (persistentEntity.hasVersionProperty()) {
134+
135+
Number newVersion = getNextVersion(instance, persistentEntity, converter.getConversionService());
136+
setVersion(instance, persistentEntity, newVersion);
137+
}
138+
127139
MapSqlParameterSource parameterSource = getParameterSource(instance, persistentEntity, "",
128140
PersistentProperty::isIdProperty);
129141

@@ -151,13 +163,46 @@ public <T> Object insert(T instance, Class<T> domainType, Identifier identifier)
151163
*/
152164
@Override
153165
public <S> boolean update(S instance, Class<S> domainType) {
166+
RelationalPersistentEntity<S> persistentEntity = getRequiredPersistentEntity(domainType);
167+
168+
if (persistentEntity.hasVersionProperty()) {
169+
return updateWithVersion(instance, domainType);
170+
} else {
171+
return updateWithoutVersion(instance, domainType);
172+
}
173+
}
174+
175+
private <S> boolean updateWithoutVersion(S instance, Class<S> domainType) {
154176

155177
RelationalPersistentEntity<S> persistentEntity = getRequiredPersistentEntity(domainType);
156178

157179
return operations.update(sql(domainType).getUpdate(),
158180
getParameterSource(instance, persistentEntity, "", Predicates.includeAll())) != 0;
159181
}
160182

183+
private <S> boolean updateWithVersion(S instance, Class<S> domainType) {
184+
185+
RelationalPersistentEntity<S> persistentEntity = getRequiredPersistentEntity(domainType);
186+
187+
Number oldVersion = getVersion(instance, persistentEntity, converter.getConversionService());
188+
Number newVersion = getNextVersion(instance, persistentEntity, converter.getConversionService());
189+
setVersion(instance, persistentEntity, newVersion);
190+
191+
MapSqlParameterSource parameterSource = getParameterSource(instance, persistentEntity, "", Predicates.includeAll());
192+
parameterSource.addValue(VERSION_PARAMETER, oldVersion);
193+
int affectedRows = operations.update(sql(domainType).getUpdateWithVersion(),
194+
parameterSource);
195+
196+
if (affectedRows == 0) {
197+
// reverting version update on entity
198+
setVersion(instance, persistentEntity, oldVersion);
199+
throw new OptimisticLockingFailureException(
200+
String.format("Optimistic lock exception on saving entity of type %s.", persistentEntity.getName()));
201+
}
202+
203+
return true;
204+
}
205+
161206
/*
162207
* (non-Javadoc)
163208
* @see org.springframework.data.jdbc.core.DataAccessStrategy#delete(java.lang.Object, java.lang.Class)
@@ -354,7 +399,7 @@ private <S, ID> ID getIdValueOrNull(S instance, RelationalPersistentEntity<S> pe
354399
}
355400

356401
private static <S, ID> boolean isIdPropertyNullOrScalarZero(@Nullable ID idValue,
357-
RelationalPersistentEntity<S> persistentEntity) {
402+
RelationalPersistentEntity<S> persistentEntity) {
358403

359404
RelationalPersistentProperty idProperty = persistentEntity.getIdProperty();
360405
return idValue == null //
@@ -481,4 +526,31 @@ static Predicate<RelationalPersistentProperty> includeAll() {
481526
return it -> false;
482527
}
483528
}
529+
@Nullable
530+
private <T> Number getVersion(T instance, RelationalPersistentEntity<T> entity, ConversionService conversionService) {
531+
RelationalPersistentProperty versionProperty = entity.getRequiredVersionProperty();
532+
PersistentPropertyAccessor<T> propertyAccessor = entity.getPropertyAccessor(instance);
533+
ConvertingPropertyAccessor<T> convertingPropertyAccessor = new ConvertingPropertyAccessor<>(propertyAccessor, conversionService);
534+
return convertingPropertyAccessor.getProperty(versionProperty, Number.class);
535+
}
536+
537+
private <T> Number getNextVersion(T instance, RelationalPersistentEntity<T> entity, ConversionService conversionService) {
538+
Number version = getVersion(instance, entity, conversionService);
539+
Class<?> versionType = entity.getRequiredVersionProperty().getType();
540+
if (versionType == Integer.class || versionType == int.class) {
541+
return version == null ? 1 : version.intValue() + 1;
542+
} else if (versionType == Long.class || versionType == long.class) {
543+
return version == null ? 1L : version.longValue() + 1;
544+
} else if (versionType == Short.class || versionType == short.class) {
545+
return version == null ? (short) 1 : (short) (version.shortValue() + 1);
546+
}
547+
throw new IllegalStateException(String.format("Entity '%s' has version property of invalid type '%s'.", entity.getType().getName(), entity.getVersionProperty().getType().getName()));
548+
}
549+
550+
private <T> void setVersion(T instance, RelationalPersistentEntity<T> entity, Number newVersion) {
551+
RelationalPersistentProperty versionProperty = entity.getRequiredVersionProperty();
552+
PersistentPropertyAccessor<T> accessor = versionProperty.getOwner().getPropertyAccessor(instance);
553+
accessor.setProperty(versionProperty, newVersion);
554+
}
555+
484556
}

spring-data-jdbc/src/main/java/org/springframework/data/jdbc/core/SqlGenerator.java

+23-3
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,12 @@
4545
* @author Yoichi Imai
4646
* @author Bastian Wilhelm
4747
* @author Oleksandr Kucher
48+
* @author Tom Hombergs
4849
*/
4950
class SqlGenerator {
5051

52+
static final String VERSION_PARAMETER = "___oldOptimisticLockingVersion";
53+
5154
private final RelationalPersistentEntity<?> entity;
5255
private final RelationalMappingContext context;
5356
private final List<String> columnNames = new ArrayList<>();
@@ -62,6 +65,7 @@ class SqlGenerator {
6265
private final Lazy<String> countSql = Lazy.of(this::createCountSql);
6366

6467
private final Lazy<String> updateSql = Lazy.of(this::createUpdateSql);
68+
private final Lazy<String> updateWithVersionSql = Lazy.of(this::createUpdateWithVersionSql);
6569

6670
private final Lazy<String> deleteByIdSql = Lazy.of(this::createDeleteSql);
6771
private final Lazy<String> deleteByListSql = Lazy.of(this::createDeleteByListSql);
@@ -176,6 +180,10 @@ String getUpdate() {
176180
return updateSql.get();
177181
}
178182

183+
String getUpdateWithVersion() {
184+
return updateWithVersionSql.get();
185+
}
186+
179187
String getCount() {
180188
return countSql.get();
181189
}
@@ -343,8 +351,8 @@ private String createInsertSql(Set<String> additionalColumns) {
343351
String tableColumns = String.join(", ", columnNamesForInsert);
344352

345353
String parameterNames = columnNamesForInsert.stream()//
346-
.map(this::columnNameToParameterName)
347-
.map(n -> String.format(":%s", n))//
354+
.map(this::columnNameToParameterName) //
355+
.map(n -> String.format(":%s", n)) //
348356
.collect(Collectors.joining(", "));
349357

350358
return String.format(insertTemplate, entity.getTableName(), tableColumns, parameterNames);
@@ -369,6 +377,18 @@ private String createUpdateSql() {
369377
);
370378
}
371379

380+
private String createUpdateWithVersionSql() {
381+
String whereConditionTemplate = " AND %s = :%s";
382+
383+
String whereCondition = String.format( //
384+
whereConditionTemplate, //
385+
entity.getVersionProperty().getColumnName(), //
386+
VERSION_PARAMETER //
387+
);
388+
389+
return createUpdateSql() + whereCondition;
390+
}
391+
372392
private String createDeleteSql() {
373393
return String.format("DELETE FROM %s WHERE %s = :id", entity.getTableName(), entity.getIdColumn());
374394
}
@@ -458,7 +478,7 @@ private String cascadeConditions(String innerCondition, PersistentPropertyPath<R
458478
);
459479
}
460480

461-
private String columnNameToParameterName(String columnName){
481+
private String columnNameToParameterName(String columnName) {
462482
return parameterPattern.matcher(columnName).replaceAll("");
463483
}
464484
}

spring-data-jdbc/src/test/java/org/springframework/data/jdbc/core/JdbcAggregateTemplateIntegrationTests.java

+155
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import java.util.HashSet;
2626
import java.util.List;
2727
import java.util.Set;
28+
import java.util.function.Function;
2829

2930
import org.assertj.core.api.SoftAssertions;
3031
import org.junit.Assume;
@@ -37,7 +38,9 @@
3738
import org.springframework.context.annotation.Bean;
3839
import org.springframework.context.annotation.Configuration;
3940
import org.springframework.context.annotation.Import;
41+
import org.springframework.dao.OptimisticLockingFailureException;
4042
import org.springframework.data.annotation.Id;
43+
import org.springframework.data.annotation.Version;
4144
import org.springframework.data.jdbc.testing.DatabaseProfileValueSource;
4245
import org.springframework.data.jdbc.testing.TestConfiguration;
4346
import org.springframework.data.relational.core.conversion.RelationalConverter;
@@ -58,6 +61,7 @@
5861
* @author Jens Schauder
5962
* @author Thomas Lang
6063
* @author Mark Paluch
64+
* @author Tom Hombergs
6165
*/
6266
@ContextConfiguration
6367
@Transactional
@@ -434,6 +438,60 @@ public void saveAndLoadAnEntityWithByteArray() {
434438
assertThat(reloaded.binaryData).isEqualTo(new byte[] { 1, 23, 42 });
435439
}
436440

441+
@Test // DATAJDBC-219
442+
public void saveAndUpdateAggregateWithLongVersion() {
443+
saveAndUpdateAggregateWithVersion(new AggregateWithLongVersion(), Number::longValue);
444+
}
445+
446+
@Test // DATAJDBC-219
447+
public void saveAndUpdateAggregateWithPrimitiveLongVersion() {
448+
saveAndUpdateAggregateWithVersion(new AggregateWithPrimitiveLongVersion(), Number::longValue);
449+
}
450+
451+
@Test // DATAJDBC-219
452+
public void saveAndUpdateAggregateWithIntegerVersion() {
453+
saveAndUpdateAggregateWithVersion(new AggregateWithIntegerVersion(), Number::intValue);
454+
}
455+
456+
@Test // DATAJDBC-219
457+
public void saveAndUpdateAggregateWithPrimitiveIntegerVersion() {
458+
saveAndUpdateAggregateWithVersion(new AggregateWithPrimitiveIntegerVersion(), Number::intValue);
459+
}
460+
461+
@Test // DATAJDBC-219
462+
public void saveAndUpdateAggregateWithShortVersion() {
463+
saveAndUpdateAggregateWithVersion(new AggregateWithShortVersion(), Number::shortValue);
464+
}
465+
466+
@Test // DATAJDBC-219
467+
public void saveAndUpdateAggregateWithPrimitiveShortVersion() {
468+
saveAndUpdateAggregateWithVersion(new AggregateWithPrimitiveShortVersion(), Number::shortValue);
469+
}
470+
471+
private <T extends Number> void saveAndUpdateAggregateWithVersion(VersionedAggregate aggregate, Function<Number, T> toConcreteNumber) {
472+
473+
template.save(aggregate);
474+
475+
VersionedAggregate reloadedAggregate = template.findById(aggregate.getId(), aggregate.getClass());
476+
assertThat(reloadedAggregate.getVersion()).isEqualTo(toConcreteNumber.apply(1))
477+
.withFailMessage("version field should initially have the value 1");
478+
template.save(reloadedAggregate);
479+
480+
VersionedAggregate updatedAggregate = template.findById(aggregate.getId(), aggregate.getClass());
481+
assertThat(updatedAggregate.getVersion()).isEqualTo(toConcreteNumber.apply(2))
482+
.withFailMessage("version field should increment by one with each save");
483+
484+
reloadedAggregate.setVersion(toConcreteNumber.apply(1));
485+
assertThatThrownBy(() -> template.save(reloadedAggregate))
486+
.hasRootCauseInstanceOf(OptimisticLockingFailureException.class)
487+
.withFailMessage("saving an aggregate with an outdated version should raise an exception");
488+
489+
reloadedAggregate.setVersion(toConcreteNumber.apply(3));
490+
assertThatThrownBy(() -> template.save(reloadedAggregate))
491+
.hasRootCauseInstanceOf(OptimisticLockingFailureException.class)
492+
.withFailMessage("saving an aggregate with a future version should raise an exception");
493+
}
494+
437495
private static void assumeNot(String dbProfileName) {
438496

439497
Assume.assumeTrue("true"
@@ -522,6 +580,103 @@ static class ElementNoId {
522580
private String content;
523581
}
524582

583+
@Data
584+
static abstract class VersionedAggregate {
585+
586+
@Id private Long id;
587+
588+
abstract Number getVersion();
589+
590+
abstract void setVersion(Number newVersion);
591+
}
592+
593+
@Data
594+
@Table("VERSIONED_AGGREGATE")
595+
static class AggregateWithLongVersion extends VersionedAggregate {
596+
597+
@Version private Long version;
598+
599+
@Override
600+
void setVersion(Number newVersion) {
601+
this.version = (Long) newVersion;
602+
}
603+
}
604+
605+
@Table("VERSIONED_AGGREGATE")
606+
static class AggregateWithPrimitiveLongVersion extends VersionedAggregate {
607+
608+
@Version private long version;
609+
610+
@Override
611+
void setVersion(Number newVersion) {
612+
this.version = (long) newVersion;
613+
}
614+
615+
@Override
616+
Number getVersion(){
617+
return this.version;
618+
}
619+
}
620+
621+
@Data
622+
@Table("VERSIONED_AGGREGATE")
623+
static class AggregateWithIntegerVersion extends VersionedAggregate {
624+
625+
@Version private Integer version;
626+
627+
@Override
628+
void setVersion(Number newVersion) {
629+
this.version = (Integer) newVersion;
630+
}
631+
}
632+
633+
@Table("VERSIONED_AGGREGATE")
634+
static class AggregateWithPrimitiveIntegerVersion extends VersionedAggregate {
635+
636+
@Version private int version;
637+
638+
@Override
639+
void setVersion(Number newVersion) {
640+
this.version = (int) newVersion;
641+
}
642+
643+
@Override
644+
Number getVersion(){
645+
return this.version;
646+
}
647+
}
648+
649+
@Data
650+
@Table("VERSIONED_AGGREGATE")
651+
static class AggregateWithShortVersion extends VersionedAggregate {
652+
653+
@Version private Short version;
654+
655+
@Override
656+
void setVersion(Number newVersion) {
657+
this.version = (Short) newVersion;
658+
}
659+
}
660+
661+
@Table("VERSIONED_AGGREGATE")
662+
static class AggregateWithPrimitiveShortVersion extends VersionedAggregate {
663+
664+
@Version
665+
private short version;
666+
667+
@Override
668+
void setVersion(Number newVersion) {
669+
this.version = (short) newVersion;
670+
}
671+
672+
@Override
673+
Number getVersion(){
674+
return this.version;
675+
}
676+
677+
}
678+
679+
525680
@Configuration
526681
@Import(TestConfiguration.class)
527682
static class Config {

0 commit comments

Comments
 (0)