Skip to content

Commit 8e6797d

Browse files
orange-buffaloschauder
authored andcommitted
#93 - Adding support for optimistic locking based on @Version column.
Original pull request: #314.
1 parent 75e2ba3 commit 8e6797d

File tree

8 files changed

+227
-20
lines changed

8 files changed

+227
-20
lines changed

src/main/java/org/springframework/data/r2dbc/core/DatabaseClient.java

+3-2
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,7 @@
4848
* to create an instance.
4949
*
5050
* @author Mark Paluch
51+
* @author Bogdan Ilchyshyn
5152
*/
5253
public interface DatabaseClient {
5354

@@ -729,9 +730,9 @@ interface TypedUpdateSpec<T> {
729730
*
730731
* @param objectToUpdate the object of which the attributes will provide the values for the update and the primary
731732
* key. Must not be {@literal null}.
732-
* @return a {@link UpdateSpec} for further configuration of the update. Guaranteed to be not {@literal null}.
733+
* @return a {@link UpdateMatchingSpec} for further configuration of the update. Guaranteed to be not {@literal null}.
733734
*/
734-
UpdateSpec using(T objectToUpdate);
735+
UpdateMatchingSpec using(T objectToUpdate);
735736

736737
/**
737738
* Use the given {@code tableName} as update target.

src/main/java/org/springframework/data/r2dbc/core/DefaultDatabaseClient.java

+27-9
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@
4444
import org.apache.commons.logging.Log;
4545
import org.apache.commons.logging.LogFactory;
4646
import org.reactivestreams.Publisher;
47-
4847
import org.springframework.dao.DataAccessException;
4948
import org.springframework.dao.InvalidDataAccessApiUsageException;
5049
import org.springframework.data.domain.Pageable;
@@ -59,6 +58,7 @@
5958
import org.springframework.data.r2dbc.mapping.SettableValue;
6059
import org.springframework.data.r2dbc.query.Update;
6160
import org.springframework.data.r2dbc.support.R2dbcExceptionTranslator;
61+
import org.springframework.data.relational.core.query.Criteria;
6262
import org.springframework.data.relational.core.query.CriteriaDefinition;
6363
import org.springframework.data.relational.core.sql.SqlIdentifier;
6464
import org.springframework.lang.Nullable;
@@ -70,6 +70,7 @@
7070
*
7171
* @author Mark Paluch
7272
* @author Mingyuan Wu
73+
* @author Bogdan Ilchyshyn
7374
*/
7475
class DefaultDatabaseClient implements DatabaseClient, ConnectionAccessor {
7576

@@ -1198,7 +1199,7 @@ public <T> TypedUpdateSpec<T> table(Class<T> table) {
11981199

11991200
assertRegularClass(table);
12001201

1201-
return new DefaultTypedUpdateSpec<>(table, null, null);
1202+
return new DefaultTypedUpdateSpec<>(table, null, null, null);
12021203
}
12031204
}
12041205

@@ -1287,32 +1288,43 @@ private UpdatedRowsFetchSpec exchange(SqlIdentifier table) {
12871288
}
12881289
}
12891290

1290-
class DefaultTypedUpdateSpec<T> implements TypedUpdateSpec<T>, UpdateSpec {
1291+
class DefaultTypedUpdateSpec<T> implements TypedUpdateSpec<T>, UpdateMatchingSpec {
12911292

12921293
private final Class<T> typeToUpdate;
12931294
private final @Nullable SqlIdentifier table;
12941295
private final @Nullable T objectToUpdate;
1296+
private final @Nullable CriteriaDefinition where;
12951297

1296-
DefaultTypedUpdateSpec(Class<T> typeToUpdate, @Nullable SqlIdentifier table, @Nullable T objectToUpdate) {
1298+
DefaultTypedUpdateSpec(Class<T> typeToUpdate, @Nullable SqlIdentifier table, @Nullable T objectToUpdate,
1299+
@Nullable CriteriaDefinition where) {
12971300
this.typeToUpdate = typeToUpdate;
12981301
this.table = table;
12991302
this.objectToUpdate = objectToUpdate;
1303+
this.where = where;
13001304
}
13011305

13021306
@Override
1303-
public UpdateSpec using(T objectToUpdate) {
1307+
public UpdateMatchingSpec using(T objectToUpdate) {
13041308

13051309
Assert.notNull(objectToUpdate, "Object to update must not be null");
13061310

1307-
return new DefaultTypedUpdateSpec<>(this.typeToUpdate, this.table, objectToUpdate);
1311+
return new DefaultTypedUpdateSpec<>(this.typeToUpdate, this.table, objectToUpdate, this.where);
13081312
}
13091313

13101314
@Override
13111315
public TypedUpdateSpec<T> table(SqlIdentifier tableName) {
13121316

13131317
Assert.notNull(tableName, "Table name must not be null!");
13141318

1315-
return new DefaultTypedUpdateSpec<>(this.typeToUpdate, tableName, this.objectToUpdate);
1319+
return new DefaultTypedUpdateSpec<>(this.typeToUpdate, tableName, this.objectToUpdate, this.where);
1320+
}
1321+
1322+
@Override
1323+
public UpdateSpec matching(CriteriaDefinition criteria) {
1324+
1325+
Assert.notNull(criteria, "Criteria must not be null!");
1326+
1327+
return new DefaultTypedUpdateSpec<>(this.typeToUpdate, this.table, this.objectToUpdate, criteria);
13161328
}
13171329

13181330
@Override
@@ -1356,8 +1368,14 @@ private UpdatedRowsFetchSpec exchange(SqlIdentifier table) {
13561368
}
13571369
}
13581370

1359-
PreparedOperation<?> operation = mapper.getMappedObject(mapper.createUpdate(table, update).withCriteria(
1360-
org.springframework.data.relational.core.query.Criteria.where(dataAccessStrategy.toSql(ids.get(0))).is(id)));
1371+
Criteria updateCriteria = org.springframework.data.relational.core.query.Criteria
1372+
.where(dataAccessStrategy.toSql(ids.get(0))).is(id);
1373+
if (this.where != null) {
1374+
updateCriteria = updateCriteria.and(this.where);
1375+
}
1376+
1377+
PreparedOperation<?> operation = mapper
1378+
.getMappedObject(mapper.createUpdate(table, update).withCriteria(updateCriteria));
13611379

13621380
return exchangeUpdate(operation);
13631381
}

src/main/java/org/springframework/data/r2dbc/core/R2dbcEntityTemplate.java

+84-8
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import io.r2dbc.spi.Row;
1919
import io.r2dbc.spi.RowMetadata;
20+
import org.springframework.dao.OptimisticLockingFailureException;
2021
import reactor.core.publisher.Flux;
2122
import reactor.core.publisher.Mono;
2223

@@ -30,10 +31,12 @@
3031
import org.springframework.beans.BeansException;
3132
import org.springframework.beans.factory.BeanFactory;
3233
import org.springframework.beans.factory.BeanFactoryAware;
34+
import org.springframework.core.convert.ConversionService;
3335
import org.springframework.dao.DataAccessException;
3436
import org.springframework.dao.TransientDataAccessResourceException;
3537
import org.springframework.data.mapping.IdentifierAccessor;
3638
import org.springframework.data.mapping.MappingException;
39+
import org.springframework.data.mapping.PersistentPropertyAccessor;
3740
import org.springframework.data.mapping.context.MappingContext;
3841
import org.springframework.data.projection.ProjectionInformation;
3942
import org.springframework.data.projection.SpelAwareProxyProjectionFactory;
@@ -59,6 +62,7 @@
5962
* prepared in an application context and given to services as bean reference.
6063
*
6164
* @author Mark Paluch
65+
* @author Bogdan Ilchyshyn
6266
* @since 1.1
6367
*/
6468
public class R2dbcEntityTemplate implements R2dbcEntityOperations, BeanFactoryAware {
@@ -373,6 +377,8 @@ <T> Mono<T> doInsert(T entity, SqlIdentifier tableName) {
373377

374378
RelationalPersistentEntity<T> persistentEntity = getRequiredEntity(entity);
375379

380+
setVersionIfNecessary(persistentEntity, entity);
381+
376382
return this.databaseClient.insert() //
377383
.into(persistentEntity.getType()) //
378384
.table(tableName).using(entity) //
@@ -381,6 +387,19 @@ <T> Mono<T> doInsert(T entity, SqlIdentifier tableName) {
381387
.defaultIfEmpty(entity);
382388
}
383389

390+
private <T> void setVersionIfNecessary(RelationalPersistentEntity<T> persistentEntity, T entity) {
391+
RelationalPersistentProperty versionProperty = persistentEntity.getVersionProperty();
392+
if (versionProperty == null) {
393+
return;
394+
}
395+
396+
Class<?> versionPropertyType = versionProperty.getType();
397+
Long version = versionPropertyType.isPrimitive() ? 1L : 0L;
398+
ConversionService conversionService = this.dataAccessStrategy.getConverter().getConversionService();
399+
PersistentPropertyAccessor<?> propertyAccessor = persistentEntity.getPropertyAccessor(entity);
400+
propertyAccessor.setProperty(versionProperty, conversionService.convert(version, versionPropertyType));
401+
}
402+
384403
/*
385404
* (non-Javadoc)
386405
* @see org.springframework.data.r2dbc.core.R2dbcEntityOperations#update(java.lang.Object)
@@ -392,21 +411,78 @@ public <T> Mono<T> update(T entity) throws DataAccessException {
392411

393412
RelationalPersistentEntity<T> persistentEntity = getRequiredEntity(entity);
394413

395-
return this.databaseClient.update() //
414+
DatabaseClient.UpdateMatchingSpec updateMatchingSpec = this.databaseClient.update() //
396415
.table(persistentEntity.getType()) //
397-
.table(persistentEntity.getTableName()).using(entity) //
398-
.fetch().rowsUpdated().handle((rowsUpdated, sink) -> {
416+
.table(persistentEntity.getTableName()) //
417+
.using(entity);
418+
419+
DatabaseClient.UpdateSpec updateSpec = updateMatchingSpec;
420+
if (persistentEntity.hasVersionProperty()) {
421+
updateSpec = updateMatchingSpec.matching(createMatchingVersionCriteria(entity, persistentEntity));
422+
incrementVersion(entity, persistentEntity);
423+
}
424+
425+
return updateSpec.fetch() //
426+
.rowsUpdated() //
427+
.flatMap(rowsUpdated -> rowsUpdated == 0
428+
? handleMissingUpdate(entity, persistentEntity) : Mono.just(entity));
429+
}
399430

400-
if (rowsUpdated == 0) {
401-
sink.error(new TransientDataAccessResourceException(
402-
String.format("Failed to update table [%s]. Row with Id [%s] does not exist.",
403-
persistentEntity.getTableName(), persistentEntity.getIdentifierAccessor(entity).getIdentifier())));
431+
private <T> Mono<? extends T> handleMissingUpdate(T entity, RelationalPersistentEntity<T> persistentEntity) {
432+
if (!persistentEntity.hasVersionProperty()) {
433+
return Mono.error(new TransientDataAccessResourceException(
434+
formatTransientEntityExceptionMessage(entity, persistentEntity)));
435+
}
436+
437+
return doCount(getByIdQuery(entity, persistentEntity), entity.getClass(), persistentEntity.getTableName())
438+
.map(count -> {
439+
if (count == 0) {
440+
throw new TransientDataAccessResourceException(
441+
formatTransientEntityExceptionMessage(entity, persistentEntity));
404442
} else {
405-
sink.next(entity);
443+
throw new OptimisticLockingFailureException(
444+
formatOptimisticLockingExceptionMessage(entity, persistentEntity));
406445
}
407446
});
408447
}
409448

449+
private <T> String formatOptimisticLockingExceptionMessage(T entity, RelationalPersistentEntity<T> persistentEntity) {
450+
return String.format("Failed to update table [%s]. Version does not match for row with Id [%s].",
451+
persistentEntity.getTableName(), persistentEntity.getIdentifierAccessor(entity).getIdentifier());
452+
}
453+
454+
private <T> String formatTransientEntityExceptionMessage(T entity, RelationalPersistentEntity<T> persistentEntity) {
455+
return String.format("Failed to update table [%s]. Row with Id [%s] does not exist.",
456+
persistentEntity.getTableName(), persistentEntity.getIdentifierAccessor(entity).getIdentifier());
457+
}
458+
459+
private <T> void incrementVersion(T entity, RelationalPersistentEntity<T> persistentEntity) {
460+
PersistentPropertyAccessor<?> propertyAccessor = persistentEntity.getPropertyAccessor(entity);
461+
RelationalPersistentProperty versionProperty = persistentEntity.getVersionProperty();
462+
463+
ConversionService conversionService = this.dataAccessStrategy.getConverter().getConversionService();
464+
Object currentVersionValue = propertyAccessor.getProperty(versionProperty);
465+
long newVersionValue = 1L;
466+
if (currentVersionValue != null) {
467+
newVersionValue = conversionService.convert(currentVersionValue, Long.class) + 1;
468+
}
469+
Class<?> versionPropertyType = versionProperty.getType();
470+
propertyAccessor.setProperty(versionProperty, conversionService.convert(newVersionValue, versionPropertyType));
471+
}
472+
473+
private <T> Criteria createMatchingVersionCriteria(T entity, RelationalPersistentEntity<T> persistentEntity) {
474+
PersistentPropertyAccessor<?> propertyAccessor = persistentEntity.getPropertyAccessor(entity);
475+
RelationalPersistentProperty versionProperty = persistentEntity.getVersionProperty();
476+
477+
Object version = propertyAccessor.getProperty(versionProperty);
478+
Criteria.CriteriaStep versionColumn = Criteria.where(dataAccessStrategy.toSql(versionProperty.getColumnName()));
479+
if (version == null) {
480+
return versionColumn.isNull();
481+
} else {
482+
return versionColumn.is(version);
483+
}
484+
}
485+
410486
/*
411487
* (non-Javadoc)
412488
* @see org.springframework.data.r2dbc.core.R2dbcEntityOperations#delete(java.lang.Object)

0 commit comments

Comments
 (0)