Skip to content

Commit 0fa7bb4

Browse files
spring-projects#93 - Adding support for optimistic locking based on @Version column
1 parent e56f126 commit 0fa7bb4

File tree

8 files changed

+224
-18
lines changed

8 files changed

+224
-18
lines changed

src/main/java/org/springframework/data/r2dbc/core/DatabaseClient.java

+3-2
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
* to create an instance.
4747
*
4848
* @author Mark Paluch
49+
* @author Bogdan Ilchyshyn
4950
*/
5051
public interface DatabaseClient {
5152

@@ -707,9 +708,9 @@ interface TypedUpdateSpec<T> {
707708
*
708709
* @param objectToUpdate the object of which the attributes will provide the values for the update and the primary
709710
* key. Must not be {@literal null}.
710-
* @return a {@link UpdateSpec} for further configuration of the update. Guaranteed to be not {@literal null}.
711+
* @return a {@link UpdateMatchingSpec} for further configuration of the update. Guaranteed to be not {@literal null}.
711712
*/
712-
UpdateSpec using(T objectToUpdate);
713+
UpdateMatchingSpec using(T objectToUpdate);
713714

714715
/**
715716
* Use the given {@code tableName} as update target.

src/main/java/org/springframework/data/r2dbc/core/DefaultDatabaseClient.java

+24-7
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@
6969
* Default implementation of {@link DatabaseClient}.
7070
*
7171
* @author Mark Paluch
72+
* @author Bogdan Ilchyshyn
7273
*/
7374
class DefaultDatabaseClient implements DatabaseClient, ConnectionAccessor {
7475

@@ -1188,7 +1189,7 @@ public <T> TypedUpdateSpec<T> table(Class<T> table) {
11881189

11891190
assertRegularClass(table);
11901191

1191-
return new DefaultTypedUpdateSpec<>(table, null, null);
1192+
return new DefaultTypedUpdateSpec<>(table, null, null, null);
11921193
}
11931194
}
11941195

@@ -1262,32 +1263,43 @@ private UpdatedRowsFetchSpec exchange(SqlIdentifier table) {
12621263
}
12631264
}
12641265

1265-
class DefaultTypedUpdateSpec<T> implements TypedUpdateSpec<T>, UpdateSpec {
1266+
class DefaultTypedUpdateSpec<T> implements TypedUpdateSpec<T>, UpdateMatchingSpec {
12661267

12671268
private final @Nullable Class<T> typeToUpdate;
12681269
private final @Nullable SqlIdentifier table;
12691270
private final T objectToUpdate;
1271+
private final @Nullable Criteria where;
12701272

1271-
DefaultTypedUpdateSpec(@Nullable Class<T> typeToUpdate, @Nullable SqlIdentifier table, T objectToUpdate) {
1273+
DefaultTypedUpdateSpec(@Nullable Class<T> typeToUpdate, @Nullable SqlIdentifier table, T objectToUpdate,
1274+
@Nullable Criteria where) {
12721275
this.typeToUpdate = typeToUpdate;
12731276
this.table = table;
12741277
this.objectToUpdate = objectToUpdate;
1278+
this.where = where;
12751279
}
12761280

12771281
@Override
1278-
public UpdateSpec using(T objectToUpdate) {
1282+
public UpdateMatchingSpec using(T objectToUpdate) {
12791283

12801284
Assert.notNull(objectToUpdate, "Object to update must not be null");
12811285

1282-
return new DefaultTypedUpdateSpec<>(this.typeToUpdate, this.table, objectToUpdate);
1286+
return new DefaultTypedUpdateSpec<>(this.typeToUpdate, this.table, objectToUpdate, this.where);
12831287
}
12841288

12851289
@Override
12861290
public TypedUpdateSpec<T> table(SqlIdentifier tableName) {
12871291

12881292
Assert.notNull(tableName, "Table name must not be null!");
12891293

1290-
return new DefaultTypedUpdateSpec<>(this.typeToUpdate, tableName, this.objectToUpdate);
1294+
return new DefaultTypedUpdateSpec<>(this.typeToUpdate, tableName, this.objectToUpdate, this.where);
1295+
}
1296+
1297+
@Override
1298+
public UpdateSpec matching(Criteria criteria) {
1299+
1300+
Assert.notNull(criteria, "Criteria must not be null!");
1301+
1302+
return new DefaultTypedUpdateSpec<>(this.typeToUpdate, this.table, this.objectToUpdate, criteria);
12911303
}
12921304

12931305
@Override
@@ -1330,8 +1342,13 @@ private UpdatedRowsFetchSpec exchange(SqlIdentifier table) {
13301342
}
13311343
}
13321344

1345+
Criteria updateCriteria = Criteria.where(dataAccessStrategy.toSql(ids.get(0))).is(id);
1346+
if (this.where != null) {
1347+
updateCriteria = updateCriteria.and(this.where);
1348+
}
1349+
13331350
PreparedOperation<?> operation = mapper.getMappedObject(
1334-
mapper.createUpdate(table, update).withCriteria(Criteria.where(dataAccessStrategy.toSql(ids.get(0))).is(id)));
1351+
mapper.createUpdate(table, update).withCriteria(updateCriteria));
13351352

13361353
return exchangeUpdate(operation);
13371354
}

src/main/java/org/springframework/data/r2dbc/core/R2dbcEntityTemplate.java

+84-8
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import io.r2dbc.spi.Row;
1919
import io.r2dbc.spi.RowMetadata;
20+
import org.springframework.dao.OptimisticLockingFailureException;
2021
import reactor.core.publisher.Flux;
2122
import reactor.core.publisher.Mono;
2223

@@ -30,10 +31,12 @@
3031
import org.springframework.beans.BeansException;
3132
import org.springframework.beans.factory.BeanFactory;
3233
import org.springframework.beans.factory.BeanFactoryAware;
34+
import org.springframework.core.convert.ConversionService;
3335
import org.springframework.dao.DataAccessException;
3436
import org.springframework.dao.TransientDataAccessResourceException;
3537
import org.springframework.data.mapping.IdentifierAccessor;
3638
import org.springframework.data.mapping.MappingException;
39+
import org.springframework.data.mapping.PersistentPropertyAccessor;
3740
import org.springframework.data.mapping.context.MappingContext;
3841
import org.springframework.data.projection.ProjectionInformation;
3942
import org.springframework.data.projection.SpelAwareProxyProjectionFactory;
@@ -58,6 +61,7 @@
5861
* prepared in an application context and given to services as bean reference.
5962
*
6063
* @author Mark Paluch
64+
* @author Bogdan Ilchyshyn
6165
* @since 1.1
6266
*/
6367
public class R2dbcEntityTemplate implements R2dbcEntityOperations, BeanFactoryAware {
@@ -372,6 +376,8 @@ <T> Mono<T> doInsert(T entity, SqlIdentifier tableName) {
372376

373377
RelationalPersistentEntity<T> persistentEntity = getRequiredEntity(entity);
374378

379+
setVersionIfNecessary(persistentEntity, entity);
380+
375381
return this.databaseClient.insert() //
376382
.into(persistentEntity.getType()) //
377383
.table(tableName).using(entity) //
@@ -380,6 +386,19 @@ <T> Mono<T> doInsert(T entity, SqlIdentifier tableName) {
380386
.defaultIfEmpty(entity);
381387
}
382388

389+
private <T> void setVersionIfNecessary(RelationalPersistentEntity<T> persistentEntity, T entity) {
390+
RelationalPersistentProperty versionProperty = persistentEntity.getVersionProperty();
391+
if (versionProperty == null) {
392+
return;
393+
}
394+
395+
Class<?> versionPropertyType = versionProperty.getType();
396+
Long version = versionPropertyType.isPrimitive() ? 1L : 0L;
397+
ConversionService conversionService = this.dataAccessStrategy.getConverter().getConversionService();
398+
PersistentPropertyAccessor<?> propertyAccessor = persistentEntity.getPropertyAccessor(entity);
399+
propertyAccessor.setProperty(versionProperty, conversionService.convert(version, versionPropertyType));
400+
}
401+
383402
/*
384403
* (non-Javadoc)
385404
* @see org.springframework.data.r2dbc.core.R2dbcEntityOperations#update(java.lang.Object)
@@ -391,21 +410,78 @@ public <T> Mono<T> update(T entity) throws DataAccessException {
391410

392411
RelationalPersistentEntity<T> persistentEntity = getRequiredEntity(entity);
393412

394-
return this.databaseClient.update() //
413+
DatabaseClient.UpdateMatchingSpec updateMatchingSpec = this.databaseClient.update() //
395414
.table(persistentEntity.getType()) //
396-
.table(persistentEntity.getTableName()).using(entity) //
397-
.fetch().rowsUpdated().handle((rowsUpdated, sink) -> {
415+
.table(persistentEntity.getTableName()) //
416+
.using(entity);
417+
418+
DatabaseClient.UpdateSpec updateSpec = updateMatchingSpec;
419+
if (persistentEntity.hasVersionProperty()) {
420+
updateSpec = updateMatchingSpec.matching(createMatchingVersionCriteria(entity, persistentEntity));
421+
incrementVersion(entity, persistentEntity);
422+
}
423+
424+
return updateSpec.fetch() //
425+
.rowsUpdated() //
426+
.flatMap(rowsUpdated -> rowsUpdated == 0
427+
? handleMissingUpdate(entity, persistentEntity) : Mono.just(entity));
428+
}
398429

399-
if (rowsUpdated == 0) {
400-
sink.error(new TransientDataAccessResourceException(
401-
String.format("Failed to update table [%s]. Row with Id [%s] does not exist.",
402-
persistentEntity.getTableName(), persistentEntity.getIdentifierAccessor(entity).getIdentifier())));
430+
private <T> Mono<? extends T> handleMissingUpdate(T entity, RelationalPersistentEntity<T> persistentEntity) {
431+
if (!persistentEntity.hasVersionProperty()) {
432+
return Mono.error(new TransientDataAccessResourceException(
433+
formatTransientEntityExceptionMessage(entity, persistentEntity)));
434+
}
435+
436+
return doCount(getByIdQuery(entity, persistentEntity), entity.getClass(), persistentEntity.getTableName())
437+
.map(count -> {
438+
if (count == 0) {
439+
throw new TransientDataAccessResourceException(
440+
formatTransientEntityExceptionMessage(entity, persistentEntity));
403441
} else {
404-
sink.next(entity);
442+
throw new OptimisticLockingFailureException(
443+
formatOptimisticLockingExceptionMessage(entity, persistentEntity));
405444
}
406445
});
407446
}
408447

448+
private <T> String formatOptimisticLockingExceptionMessage(T entity, RelationalPersistentEntity<T> persistentEntity) {
449+
return String.format("Failed to update table [%s]. Version does not match for row with Id [%s].",
450+
persistentEntity.getTableName(), persistentEntity.getIdentifierAccessor(entity).getIdentifier());
451+
}
452+
453+
private <T> String formatTransientEntityExceptionMessage(T entity, RelationalPersistentEntity<T> persistentEntity) {
454+
return String.format("Failed to update table [%s]. Row with Id [%s] does not exist.",
455+
persistentEntity.getTableName(), persistentEntity.getIdentifierAccessor(entity).getIdentifier());
456+
}
457+
458+
private <T> void incrementVersion(T entity, RelationalPersistentEntity<T> persistentEntity) {
459+
PersistentPropertyAccessor<?> propertyAccessor = persistentEntity.getPropertyAccessor(entity);
460+
RelationalPersistentProperty versionProperty = persistentEntity.getVersionProperty();
461+
462+
ConversionService conversionService = this.dataAccessStrategy.getConverter().getConversionService();
463+
Object currentVersionValue = propertyAccessor.getProperty(versionProperty);
464+
long newVersionValue = 1L;
465+
if (currentVersionValue != null) {
466+
newVersionValue = conversionService.convert(currentVersionValue, Long.class) + 1;
467+
}
468+
Class<?> versionPropertyType = versionProperty.getType();
469+
propertyAccessor.setProperty(versionProperty, conversionService.convert(newVersionValue, versionPropertyType));
470+
}
471+
472+
private <T> Criteria createMatchingVersionCriteria(T entity, RelationalPersistentEntity<T> persistentEntity) {
473+
PersistentPropertyAccessor<?> propertyAccessor = persistentEntity.getPropertyAccessor(entity);
474+
RelationalPersistentProperty versionProperty = persistentEntity.getVersionProperty();
475+
476+
Object version = propertyAccessor.getProperty(versionProperty);
477+
Criteria.CriteriaStep versionColumn = Criteria.where(dataAccessStrategy.toSql(versionProperty.getColumnName()));
478+
if (version == null) {
479+
return versionColumn.isNull();
480+
} else {
481+
return versionColumn.is(version);
482+
}
483+
}
484+
409485
/*
410486
* (non-Javadoc)
411487
* @see org.springframework.data.r2dbc.core.R2dbcEntityOperations#delete(java.lang.Object)

src/test/java/org/springframework/data/r2dbc/repository/support/AbstractSimpleR2dbcRepositoryIntegrationTests.java

+101-1
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,11 @@
3333

3434
import org.junit.Before;
3535
import org.junit.Test;
36-
3736
import org.springframework.beans.factory.annotation.Autowired;
3837
import org.springframework.dao.DataAccessException;
38+
import org.springframework.dao.OptimisticLockingFailureException;
3939
import org.springframework.data.annotation.Id;
40+
import org.springframework.data.annotation.Version;
4041
import org.springframework.data.domain.Persistable;
4142
import org.springframework.data.r2dbc.convert.MappingR2dbcConverter;
4243
import org.springframework.data.r2dbc.core.DatabaseClient;
@@ -53,6 +54,7 @@
5354
* Abstract integration tests for {@link SimpleR2dbcRepository} to be ran against various databases.
5455
*
5556
* @author Mark Paluch
57+
* @author Bogdan Ilchyshyn
5658
*/
5759
public abstract class AbstractSimpleR2dbcRepositoryIntegrationTests extends R2dbcIntegrationTestSupport {
5860

@@ -117,6 +119,42 @@ public void shouldSaveNewObject() {
117119
assertThat(map).containsEntry("name", "SCHAUFELRADBAGGER").containsEntry("manual", 12).containsKey("id");
118120
}
119121

122+
@Test
123+
public void shouldSaveNewObjectAndSetVersionIfWrapperVersionPropertyExists() {
124+
125+
LegoSetVersionable legoSet = new LegoSetVersionable(null, "SCHAUFELRADBAGGER", 12, null);
126+
127+
repository.save(legoSet) //
128+
.as(StepVerifier::create) //
129+
.consumeNextWith(actual -> assertThat(actual.getVersion()).isEqualTo(0)) //
130+
.verifyComplete();
131+
132+
Map<String, Object> map = jdbc.queryForMap("SELECT * FROM legoset");
133+
assertThat(map) //
134+
.containsEntry("name", "SCHAUFELRADBAGGER") //
135+
.containsEntry("manual", 12) //
136+
.containsEntry("version", 0) //
137+
.containsKey("id");
138+
}
139+
140+
@Test
141+
public void shouldSaveNewObjectAndSetVersionIfPrimitiveVersionPropertyExists() {
142+
143+
LegoSetPrimitiveVersionable legoSet = new LegoSetPrimitiveVersionable(null, "SCHAUFELRADBAGGER", 12, -1);
144+
145+
repository.save(legoSet) //
146+
.as(StepVerifier::create) //
147+
.consumeNextWith(actual -> assertThat(actual.getVersion()).isEqualTo(1)) //
148+
.verifyComplete();
149+
150+
Map<String, Object> map = jdbc.queryForMap("SELECT * FROM legoset");
151+
assertThat(map) //
152+
.containsEntry("name", "SCHAUFELRADBAGGER") //
153+
.containsEntry("manual", 12) //
154+
.containsEntry("version", 1) //
155+
.containsKey("id");
156+
}
157+
120158
@Test
121159
public void shouldUpdateObject() {
122160

@@ -135,6 +173,44 @@ public void shouldUpdateObject() {
135173
assertThat(map).containsEntry("name", "SCHAUFELRADBAGGER").containsEntry("manual", 14).containsKey("id");
136174
}
137175

176+
@Test
177+
public void shouldUpdateVersionableObjectAndIncreaseVersion() {
178+
179+
jdbc.execute("INSERT INTO legoset (name, manual, version) VALUES('SCHAUFELRADBAGGER', 12, 42)");
180+
Integer id = jdbc.queryForObject("SELECT id FROM legoset", Integer.class);
181+
182+
LegoSetVersionable legoSet = new LegoSetVersionable(id, "SCHAUFELRADBAGGER", 12, 42);
183+
legoSet.setManual(14);
184+
185+
repository.save(legoSet) //
186+
.as(StepVerifier::create) //
187+
.expectNextCount(1) //
188+
.verifyComplete();
189+
190+
assertThat(legoSet.getVersion()).isEqualTo(43);
191+
192+
Map<String, Object> map = jdbc.queryForMap("SELECT * FROM legoset");
193+
assertThat(map)
194+
.containsEntry("name", "SCHAUFELRADBAGGER") //
195+
.containsEntry("manual", 14) //
196+
.containsEntry("version", 43) //
197+
.containsKey("id");
198+
}
199+
200+
@Test
201+
public void shouldFailWithOptimistickLockingWhenVersionDoesNotMatchOnUpdate() {
202+
203+
jdbc.execute("INSERT INTO legoset (name, manual, version) VALUES('SCHAUFELRADBAGGER', 12, 42)");
204+
Integer id = jdbc.queryForObject("SELECT id FROM legoset", Integer.class);
205+
206+
LegoSetVersionable legoSet = new LegoSetVersionable(id, "SCHAUFELRADBAGGER", 12, 0);
207+
208+
repository.save(legoSet) //
209+
.as(StepVerifier::create) //
210+
.expectError(OptimisticLockingFailureException.class) //
211+
.verify();
212+
}
213+
138214
@Test
139215
public void shouldSaveObjectsUsingIterable() {
140216

@@ -392,4 +468,28 @@ public boolean isNew() {
392468
return true;
393469
}
394470
}
471+
472+
@Data
473+
@Table("legoset")
474+
@NoArgsConstructor
475+
static class LegoSetVersionable extends LegoSet {
476+
@Version Integer version;
477+
478+
public LegoSetVersionable(Integer id, String name, Integer manual, Integer version) {
479+
super(id, name, manual);
480+
this.version = version;
481+
}
482+
}
483+
484+
@Data
485+
@Table("legoset")
486+
@NoArgsConstructor
487+
static class LegoSetPrimitiveVersionable extends LegoSet {
488+
@Version int version;
489+
490+
public LegoSetPrimitiveVersionable(Integer id, String name, Integer manual, int version) {
491+
super(id, name, manual);
492+
this.version = version;
493+
}
494+
}
395495
}

0 commit comments

Comments
 (0)