Skip to content

Commit bb24a1c

Browse files
spring-projects#93 - Adding support for optimistic locking based on @Version column
1 parent edc6d9b commit bb24a1c

File tree

8 files changed

+224
-18
lines changed

8 files changed

+224
-18
lines changed

src/main/java/org/springframework/data/r2dbc/core/DatabaseClient.java

+3-2
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
* to create an instance.
4949
*
5050
* @author Mark Paluch
51+
* @author Bogdan Ilchyshyn
5152
*/
5253
public interface DatabaseClient {
5354

@@ -719,9 +720,9 @@ interface TypedUpdateSpec<T> {
719720
*
720721
* @param objectToUpdate the object of which the attributes will provide the values for the update and the primary
721722
* key. Must not be {@literal null}.
722-
* @return a {@link UpdateSpec} for further configuration of the update. Guaranteed to be not {@literal null}.
723+
* @return a {@link UpdateMatchingSpec} for further configuration of the update. Guaranteed to be not {@literal null}.
723724
*/
724-
UpdateSpec using(T objectToUpdate);
725+
UpdateMatchingSpec using(T objectToUpdate);
725726

726727
/**
727728
* Use the given {@code tableName} as update target.

src/main/java/org/springframework/data/r2dbc/core/DefaultDatabaseClient.java

+24-7
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@
6969
* Default implementation of {@link DatabaseClient}.
7070
*
7171
* @author Mark Paluch
72+
* @author Bogdan Ilchyshyn
7273
*/
7374
class DefaultDatabaseClient implements DatabaseClient, ConnectionAccessor {
7475

@@ -1196,7 +1197,7 @@ public <T> TypedUpdateSpec<T> table(Class<T> table) {
11961197

11971198
assertRegularClass(table);
11981199

1199-
return new DefaultTypedUpdateSpec<>(table, null, null);
1200+
return new DefaultTypedUpdateSpec<>(table, null, null, null);
12001201
}
12011202
}
12021203

@@ -1275,32 +1276,43 @@ private UpdatedRowsFetchSpec exchange(SqlIdentifier table) {
12751276
}
12761277
}
12771278

1278-
class DefaultTypedUpdateSpec<T> implements TypedUpdateSpec<T>, UpdateSpec {
1279+
class DefaultTypedUpdateSpec<T> implements TypedUpdateSpec<T>, UpdateMatchingSpec {
12791280

12801281
private final Class<T> typeToUpdate;
12811282
private final @Nullable SqlIdentifier table;
12821283
private final @Nullable T objectToUpdate;
1284+
private final @Nullable Criteria where;
12831285

1284-
DefaultTypedUpdateSpec(Class<T> typeToUpdate, @Nullable SqlIdentifier table, @Nullable T objectToUpdate) {
1286+
DefaultTypedUpdateSpec(Class<T> typeToUpdate, @Nullable SqlIdentifier table, @Nullable T objectToUpdate,
1287+
@Nullable Criteria where) {
12851288
this.typeToUpdate = typeToUpdate;
12861289
this.table = table;
12871290
this.objectToUpdate = objectToUpdate;
1291+
this.where = where;
12881292
}
12891293

12901294
@Override
1291-
public UpdateSpec using(T objectToUpdate) {
1295+
public UpdateMatchingSpec using(T objectToUpdate) {
12921296

12931297
Assert.notNull(objectToUpdate, "Object to update must not be null");
12941298

1295-
return new DefaultTypedUpdateSpec<>(this.typeToUpdate, this.table, objectToUpdate);
1299+
return new DefaultTypedUpdateSpec<>(this.typeToUpdate, this.table, objectToUpdate, this.where);
12961300
}
12971301

12981302
@Override
12991303
public TypedUpdateSpec<T> table(SqlIdentifier tableName) {
13001304

13011305
Assert.notNull(tableName, "Table name must not be null!");
13021306

1303-
return new DefaultTypedUpdateSpec<>(this.typeToUpdate, tableName, this.objectToUpdate);
1307+
return new DefaultTypedUpdateSpec<>(this.typeToUpdate, tableName, this.objectToUpdate, this.where);
1308+
}
1309+
1310+
@Override
1311+
public UpdateSpec matching(Criteria criteria) {
1312+
1313+
Assert.notNull(criteria, "Criteria must not be null!");
1314+
1315+
return new DefaultTypedUpdateSpec<>(this.typeToUpdate, this.table, this.objectToUpdate, criteria);
13041316
}
13051317

13061318
@Override
@@ -1343,8 +1355,13 @@ private UpdatedRowsFetchSpec exchange(SqlIdentifier table) {
13431355
}
13441356
}
13451357

1358+
Criteria updateCriteria = Criteria.where(dataAccessStrategy.toSql(ids.get(0))).is(id);
1359+
if (this.where != null) {
1360+
updateCriteria = updateCriteria.and(this.where);
1361+
}
1362+
13461363
PreparedOperation<?> operation = mapper.getMappedObject(
1347-
mapper.createUpdate(table, update).withCriteria(Criteria.where(dataAccessStrategy.toSql(ids.get(0))).is(id)));
1364+
mapper.createUpdate(table, update).withCriteria(updateCriteria));
13481365

13491366
return exchangeUpdate(operation);
13501367
}

src/main/java/org/springframework/data/r2dbc/core/R2dbcEntityTemplate.java

+84-8
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import io.r2dbc.spi.Row;
1919
import io.r2dbc.spi.RowMetadata;
20+
import org.springframework.dao.OptimisticLockingFailureException;
2021
import reactor.core.publisher.Flux;
2122
import reactor.core.publisher.Mono;
2223

@@ -30,10 +31,12 @@
3031
import org.springframework.beans.BeansException;
3132
import org.springframework.beans.factory.BeanFactory;
3233
import org.springframework.beans.factory.BeanFactoryAware;
34+
import org.springframework.core.convert.ConversionService;
3335
import org.springframework.dao.DataAccessException;
3436
import org.springframework.dao.TransientDataAccessResourceException;
3537
import org.springframework.data.mapping.IdentifierAccessor;
3638
import org.springframework.data.mapping.MappingException;
39+
import org.springframework.data.mapping.PersistentPropertyAccessor;
3740
import org.springframework.data.mapping.context.MappingContext;
3841
import org.springframework.data.projection.ProjectionInformation;
3942
import org.springframework.data.projection.SpelAwareProxyProjectionFactory;
@@ -58,6 +61,7 @@
5861
* prepared in an application context and given to services as bean reference.
5962
*
6063
* @author Mark Paluch
64+
* @author Bogdan Ilchyshyn
6165
* @since 1.1
6266
*/
6367
public class R2dbcEntityTemplate implements R2dbcEntityOperations, BeanFactoryAware {
@@ -372,6 +376,8 @@ <T> Mono<T> doInsert(T entity, SqlIdentifier tableName) {
372376

373377
RelationalPersistentEntity<T> persistentEntity = getRequiredEntity(entity);
374378

379+
setVersionIfNecessary(persistentEntity, entity);
380+
375381
return this.databaseClient.insert() //
376382
.into(persistentEntity.getType()) //
377383
.table(tableName).using(entity) //
@@ -380,6 +386,19 @@ <T> Mono<T> doInsert(T entity, SqlIdentifier tableName) {
380386
.defaultIfEmpty(entity);
381387
}
382388

389+
private <T> void setVersionIfNecessary(RelationalPersistentEntity<T> persistentEntity, T entity) {
390+
RelationalPersistentProperty versionProperty = persistentEntity.getVersionProperty();
391+
if (versionProperty == null) {
392+
return;
393+
}
394+
395+
Class<?> versionPropertyType = versionProperty.getType();
396+
Long version = versionPropertyType.isPrimitive() ? 1L : 0L;
397+
ConversionService conversionService = this.dataAccessStrategy.getConverter().getConversionService();
398+
PersistentPropertyAccessor<?> propertyAccessor = persistentEntity.getPropertyAccessor(entity);
399+
propertyAccessor.setProperty(versionProperty, conversionService.convert(version, versionPropertyType));
400+
}
401+
383402
/*
384403
* (non-Javadoc)
385404
* @see org.springframework.data.r2dbc.core.R2dbcEntityOperations#update(java.lang.Object)
@@ -391,21 +410,78 @@ public <T> Mono<T> update(T entity) throws DataAccessException {
391410

392411
RelationalPersistentEntity<T> persistentEntity = getRequiredEntity(entity);
393412

394-
return this.databaseClient.update() //
413+
DatabaseClient.UpdateMatchingSpec updateMatchingSpec = this.databaseClient.update() //
395414
.table(persistentEntity.getType()) //
396-
.table(persistentEntity.getTableName()).using(entity) //
397-
.fetch().rowsUpdated().handle((rowsUpdated, sink) -> {
415+
.table(persistentEntity.getTableName()) //
416+
.using(entity);
417+
418+
DatabaseClient.UpdateSpec updateSpec = updateMatchingSpec;
419+
if (persistentEntity.hasVersionProperty()) {
420+
updateSpec = updateMatchingSpec.matching(createMatchingVersionCriteria(entity, persistentEntity));
421+
incrementVersion(entity, persistentEntity);
422+
}
423+
424+
return updateSpec.fetch() //
425+
.rowsUpdated() //
426+
.flatMap(rowsUpdated -> rowsUpdated == 0
427+
? handleMissingUpdate(entity, persistentEntity) : Mono.just(entity));
428+
}
398429

399-
if (rowsUpdated == 0) {
400-
sink.error(new TransientDataAccessResourceException(
401-
String.format("Failed to update table [%s]. Row with Id [%s] does not exist.",
402-
persistentEntity.getTableName(), persistentEntity.getIdentifierAccessor(entity).getIdentifier())));
430+
private <T> Mono<? extends T> handleMissingUpdate(T entity, RelationalPersistentEntity<T> persistentEntity) {
431+
if (!persistentEntity.hasVersionProperty()) {
432+
return Mono.error(new TransientDataAccessResourceException(
433+
formatTransientEntityExceptionMessage(entity, persistentEntity)));
434+
}
435+
436+
return doCount(getByIdQuery(entity, persistentEntity), entity.getClass(), persistentEntity.getTableName())
437+
.map(count -> {
438+
if (count == 0) {
439+
throw new TransientDataAccessResourceException(
440+
formatTransientEntityExceptionMessage(entity, persistentEntity));
403441
} else {
404-
sink.next(entity);
442+
throw new OptimisticLockingFailureException(
443+
formatOptimisticLockingExceptionMessage(entity, persistentEntity));
405444
}
406445
});
407446
}
408447

448+
private <T> String formatOptimisticLockingExceptionMessage(T entity, RelationalPersistentEntity<T> persistentEntity) {
449+
return String.format("Failed to update table [%s]. Version does not match for row with Id [%s].",
450+
persistentEntity.getTableName(), persistentEntity.getIdentifierAccessor(entity).getIdentifier());
451+
}
452+
453+
private <T> String formatTransientEntityExceptionMessage(T entity, RelationalPersistentEntity<T> persistentEntity) {
454+
return String.format("Failed to update table [%s]. Row with Id [%s] does not exist.",
455+
persistentEntity.getTableName(), persistentEntity.getIdentifierAccessor(entity).getIdentifier());
456+
}
457+
458+
private <T> void incrementVersion(T entity, RelationalPersistentEntity<T> persistentEntity) {
459+
PersistentPropertyAccessor<?> propertyAccessor = persistentEntity.getPropertyAccessor(entity);
460+
RelationalPersistentProperty versionProperty = persistentEntity.getVersionProperty();
461+
462+
ConversionService conversionService = this.dataAccessStrategy.getConverter().getConversionService();
463+
Object currentVersionValue = propertyAccessor.getProperty(versionProperty);
464+
long newVersionValue = 1L;
465+
if (currentVersionValue != null) {
466+
newVersionValue = conversionService.convert(currentVersionValue, Long.class) + 1;
467+
}
468+
Class<?> versionPropertyType = versionProperty.getType();
469+
propertyAccessor.setProperty(versionProperty, conversionService.convert(newVersionValue, versionPropertyType));
470+
}
471+
472+
private <T> Criteria createMatchingVersionCriteria(T entity, RelationalPersistentEntity<T> persistentEntity) {
473+
PersistentPropertyAccessor<?> propertyAccessor = persistentEntity.getPropertyAccessor(entity);
474+
RelationalPersistentProperty versionProperty = persistentEntity.getVersionProperty();
475+
476+
Object version = propertyAccessor.getProperty(versionProperty);
477+
Criteria.CriteriaStep versionColumn = Criteria.where(dataAccessStrategy.toSql(versionProperty.getColumnName()));
478+
if (version == null) {
479+
return versionColumn.isNull();
480+
} else {
481+
return versionColumn.is(version);
482+
}
483+
}
484+
409485
/*
410486
* (non-Javadoc)
411487
* @see org.springframework.data.r2dbc.core.R2dbcEntityOperations#delete(java.lang.Object)

src/test/java/org/springframework/data/r2dbc/repository/support/AbstractSimpleR2dbcRepositoryIntegrationTests.java

+101-1
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,11 @@
3333

3434
import org.junit.Before;
3535
import org.junit.Test;
36-
3736
import org.springframework.beans.factory.annotation.Autowired;
3837
import org.springframework.dao.DataAccessException;
38+
import org.springframework.dao.OptimisticLockingFailureException;
3939
import org.springframework.data.annotation.Id;
40+
import org.springframework.data.annotation.Version;
4041
import org.springframework.data.domain.Persistable;
4142
import org.springframework.data.r2dbc.convert.MappingR2dbcConverter;
4243
import org.springframework.data.r2dbc.core.DatabaseClient;
@@ -53,6 +54,7 @@
5354
* Abstract integration tests for {@link SimpleR2dbcRepository} to be ran against various databases.
5455
*
5556
* @author Mark Paluch
57+
* @author Bogdan Ilchyshyn
5658
*/
5759
public abstract class AbstractSimpleR2dbcRepositoryIntegrationTests extends R2dbcIntegrationTestSupport {
5860

@@ -117,6 +119,42 @@ public void shouldSaveNewObject() {
117119
assertThat(map).containsEntry("name", "SCHAUFELRADBAGGER").containsEntry("manual", 12).containsKey("id");
118120
}
119121

122+
@Test
123+
public void shouldSaveNewObjectAndSetVersionIfWrapperVersionPropertyExists() {
124+
125+
LegoSetVersionable legoSet = new LegoSetVersionable(null, "SCHAUFELRADBAGGER", 12, null);
126+
127+
repository.save(legoSet) //
128+
.as(StepVerifier::create) //
129+
.consumeNextWith(actual -> assertThat(actual.getVersion()).isEqualTo(0)) //
130+
.verifyComplete();
131+
132+
Map<String, Object> map = jdbc.queryForMap("SELECT * FROM legoset");
133+
assertThat(map) //
134+
.containsEntry("name", "SCHAUFELRADBAGGER") //
135+
.containsEntry("manual", 12) //
136+
.containsEntry("version", 0) //
137+
.containsKey("id");
138+
}
139+
140+
@Test
141+
public void shouldSaveNewObjectAndSetVersionIfPrimitiveVersionPropertyExists() {
142+
143+
LegoSetPrimitiveVersionable legoSet = new LegoSetPrimitiveVersionable(null, "SCHAUFELRADBAGGER", 12, -1);
144+
145+
repository.save(legoSet) //
146+
.as(StepVerifier::create) //
147+
.consumeNextWith(actual -> assertThat(actual.getVersion()).isEqualTo(1)) //
148+
.verifyComplete();
149+
150+
Map<String, Object> map = jdbc.queryForMap("SELECT * FROM legoset");
151+
assertThat(map) //
152+
.containsEntry("name", "SCHAUFELRADBAGGER") //
153+
.containsEntry("manual", 12) //
154+
.containsEntry("version", 1) //
155+
.containsKey("id");
156+
}
157+
120158
@Test
121159
public void shouldUpdateObject() {
122160

@@ -135,6 +173,44 @@ public void shouldUpdateObject() {
135173
assertThat(map).containsEntry("name", "SCHAUFELRADBAGGER").containsEntry("manual", 14).containsKey("id");
136174
}
137175

176+
@Test
177+
public void shouldUpdateVersionableObjectAndIncreaseVersion() {
178+
179+
jdbc.execute("INSERT INTO legoset (name, manual, version) VALUES('SCHAUFELRADBAGGER', 12, 42)");
180+
Integer id = jdbc.queryForObject("SELECT id FROM legoset", Integer.class);
181+
182+
LegoSetVersionable legoSet = new LegoSetVersionable(id, "SCHAUFELRADBAGGER", 12, 42);
183+
legoSet.setManual(14);
184+
185+
repository.save(legoSet) //
186+
.as(StepVerifier::create) //
187+
.expectNextCount(1) //
188+
.verifyComplete();
189+
190+
assertThat(legoSet.getVersion()).isEqualTo(43);
191+
192+
Map<String, Object> map = jdbc.queryForMap("SELECT * FROM legoset");
193+
assertThat(map)
194+
.containsEntry("name", "SCHAUFELRADBAGGER") //
195+
.containsEntry("manual", 14) //
196+
.containsEntry("version", 43) //
197+
.containsKey("id");
198+
}
199+
200+
@Test
201+
public void shouldFailWithOptimistickLockingWhenVersionDoesNotMatchOnUpdate() {
202+
203+
jdbc.execute("INSERT INTO legoset (name, manual, version) VALUES('SCHAUFELRADBAGGER', 12, 42)");
204+
Integer id = jdbc.queryForObject("SELECT id FROM legoset", Integer.class);
205+
206+
LegoSetVersionable legoSet = new LegoSetVersionable(id, "SCHAUFELRADBAGGER", 12, 0);
207+
208+
repository.save(legoSet) //
209+
.as(StepVerifier::create) //
210+
.expectError(OptimisticLockingFailureException.class) //
211+
.verify();
212+
}
213+
138214
@Test
139215
public void shouldSaveObjectsUsingIterable() {
140216

@@ -392,4 +468,28 @@ public boolean isNew() {
392468
return true;
393469
}
394470
}
471+
472+
@Data
473+
@Table("legoset")
474+
@NoArgsConstructor
475+
static class LegoSetVersionable extends LegoSet {
476+
@Version Integer version;
477+
478+
public LegoSetVersionable(Integer id, String name, Integer manual, Integer version) {
479+
super(id, name, manual);
480+
this.version = version;
481+
}
482+
}
483+
484+
@Data
485+
@Table("legoset")
486+
@NoArgsConstructor
487+
static class LegoSetPrimitiveVersionable extends LegoSet {
488+
@Version int version;
489+
490+
public LegoSetPrimitiveVersionable(Integer id, String name, Integer manual, int version) {
491+
super(id, name, manual);
492+
this.version = version;
493+
}
494+
}
395495
}

0 commit comments

Comments
 (0)