Skip to content

Commit c2fcb9d

Browse files
spring-projects#93 - Adding support for optimistic locking based on @Version column
1 parent e56f126 commit c2fcb9d

File tree

8 files changed

+196
-13
lines changed

8 files changed

+196
-13
lines changed

src/main/java/org/springframework/data/r2dbc/core/DatabaseClient.java

+3-2
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@
4646
* to create an instance.
4747
*
4848
* @author Mark Paluch
49+
* @author Bogdan Ilchyshyn
4950
*/
5051
public interface DatabaseClient {
5152

@@ -707,9 +708,9 @@ interface TypedUpdateSpec<T> {
707708
*
708709
* @param objectToUpdate the object of which the attributes will provide the values for the update and the primary
709710
* key. Must not be {@literal null}.
710-
* @return a {@link UpdateSpec} for further configuration of the update. Guaranteed to be not {@literal null}.
711+
* @return a {@link UpdateMatchingSpec} for further configuration of the update. Guaranteed to be not {@literal null}.
711712
*/
712-
UpdateSpec using(T objectToUpdate);
713+
UpdateMatchingSpec using(T objectToUpdate);
713714

714715
/**
715716
* Use the given {@code tableName} as update target.

src/main/java/org/springframework/data/r2dbc/core/DefaultDatabaseClient.java

+24-7
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,7 @@
6969
* Default implementation of {@link DatabaseClient}.
7070
*
7171
* @author Mark Paluch
72+
* @author Bogdan Ilchyshyn
7273
*/
7374
class DefaultDatabaseClient implements DatabaseClient, ConnectionAccessor {
7475

@@ -1188,7 +1189,7 @@ public <T> TypedUpdateSpec<T> table(Class<T> table) {
11881189

11891190
assertRegularClass(table);
11901191

1191-
return new DefaultTypedUpdateSpec<>(table, null, null);
1192+
return new DefaultTypedUpdateSpec<>(table, null, null, null);
11921193
}
11931194
}
11941195

@@ -1262,32 +1263,43 @@ private UpdatedRowsFetchSpec exchange(SqlIdentifier table) {
12621263
}
12631264
}
12641265

1265-
class DefaultTypedUpdateSpec<T> implements TypedUpdateSpec<T>, UpdateSpec {
1266+
class DefaultTypedUpdateSpec<T> implements TypedUpdateSpec<T>, UpdateMatchingSpec {
12661267

12671268
private final @Nullable Class<T> typeToUpdate;
12681269
private final @Nullable SqlIdentifier table;
12691270
private final T objectToUpdate;
1271+
private final @Nullable Criteria where;
12701272

1271-
DefaultTypedUpdateSpec(@Nullable Class<T> typeToUpdate, @Nullable SqlIdentifier table, T objectToUpdate) {
1273+
DefaultTypedUpdateSpec(@Nullable Class<T> typeToUpdate, @Nullable SqlIdentifier table, T objectToUpdate,
1274+
@Nullable Criteria where) {
12721275
this.typeToUpdate = typeToUpdate;
12731276
this.table = table;
12741277
this.objectToUpdate = objectToUpdate;
1278+
this.where = where;
12751279
}
12761280

12771281
@Override
1278-
public UpdateSpec using(T objectToUpdate) {
1282+
public UpdateMatchingSpec using(T objectToUpdate) {
12791283

12801284
Assert.notNull(objectToUpdate, "Object to update must not be null");
12811285

1282-
return new DefaultTypedUpdateSpec<>(this.typeToUpdate, this.table, objectToUpdate);
1286+
return new DefaultTypedUpdateSpec<>(this.typeToUpdate, this.table, objectToUpdate, this.where);
12831287
}
12841288

12851289
@Override
12861290
public TypedUpdateSpec<T> table(SqlIdentifier tableName) {
12871291

12881292
Assert.notNull(tableName, "Table name must not be null!");
12891293

1290-
return new DefaultTypedUpdateSpec<>(this.typeToUpdate, tableName, this.objectToUpdate);
1294+
return new DefaultTypedUpdateSpec<>(this.typeToUpdate, tableName, this.objectToUpdate, this.where);
1295+
}
1296+
1297+
@Override
1298+
public UpdateSpec matching(Criteria criteria) {
1299+
1300+
Assert.notNull(criteria, "Criteria must not be null!");
1301+
1302+
return new DefaultTypedUpdateSpec<>(this.typeToUpdate, this.table, this.objectToUpdate, criteria);
12911303
}
12921304

12931305
@Override
@@ -1330,8 +1342,13 @@ private UpdatedRowsFetchSpec exchange(SqlIdentifier table) {
13301342
}
13311343
}
13321344

1345+
Criteria updateCriteria = Criteria.where(dataAccessStrategy.toSql(ids.get(0))).is(id);
1346+
if (this.where != null) {
1347+
updateCriteria = updateCriteria.and(this.where);
1348+
}
1349+
13331350
PreparedOperation<?> operation = mapper.getMappedObject(
1334-
mapper.createUpdate(table, update).withCriteria(Criteria.where(dataAccessStrategy.toSql(ids.get(0))).is(id)));
1351+
mapper.createUpdate(table, update).withCriteria(updateCriteria));
13351352

13361353
return exchangeUpdate(operation);
13371354
}

src/main/java/org/springframework/data/r2dbc/core/R2dbcEntityTemplate.java

+56-3
Original file line numberDiff line numberDiff line change
@@ -30,10 +30,12 @@
3030
import org.springframework.beans.BeansException;
3131
import org.springframework.beans.factory.BeanFactory;
3232
import org.springframework.beans.factory.BeanFactoryAware;
33+
import org.springframework.core.convert.ConversionService;
3334
import org.springframework.dao.DataAccessException;
3435
import org.springframework.dao.TransientDataAccessResourceException;
3536
import org.springframework.data.mapping.IdentifierAccessor;
3637
import org.springframework.data.mapping.MappingException;
38+
import org.springframework.data.mapping.PersistentPropertyAccessor;
3739
import org.springframework.data.mapping.context.MappingContext;
3840
import org.springframework.data.projection.ProjectionInformation;
3941
import org.springframework.data.projection.SpelAwareProxyProjectionFactory;
@@ -58,6 +60,7 @@
5860
* prepared in an application context and given to services as bean reference.
5961
*
6062
* @author Mark Paluch
63+
* @author Bogdan Ilchyshyn
6164
* @since 1.1
6265
*/
6366
public class R2dbcEntityTemplate implements R2dbcEntityOperations, BeanFactoryAware {
@@ -372,6 +375,8 @@ <T> Mono<T> doInsert(T entity, SqlIdentifier tableName) {
372375

373376
RelationalPersistentEntity<T> persistentEntity = getRequiredEntity(entity);
374377

378+
setVersionIfNecessary(persistentEntity, entity);
379+
375380
return this.databaseClient.insert() //
376381
.into(persistentEntity.getType()) //
377382
.table(tableName).using(entity) //
@@ -380,6 +385,19 @@ <T> Mono<T> doInsert(T entity, SqlIdentifier tableName) {
380385
.defaultIfEmpty(entity);
381386
}
382387

388+
private <T> void setVersionIfNecessary(RelationalPersistentEntity<T> persistentEntity, T entity) {
389+
RelationalPersistentProperty versionProperty = persistentEntity.getVersionProperty();
390+
if (versionProperty == null) {
391+
return;
392+
}
393+
394+
Class<?> versionPropertyType = versionProperty.getType();
395+
Long version = versionPropertyType.isPrimitive() ? 1L : 0L;
396+
ConversionService conversionService = this.dataAccessStrategy.getConverter().getConversionService();
397+
PersistentPropertyAccessor<?> propertyAccessor = persistentEntity.getPropertyAccessor(entity);
398+
propertyAccessor.setProperty(versionProperty, conversionService.convert(version, versionPropertyType));
399+
}
400+
383401
/*
384402
* (non-Javadoc)
385403
* @see org.springframework.data.r2dbc.core.R2dbcEntityOperations#update(java.lang.Object)
@@ -391,11 +409,19 @@ public <T> Mono<T> update(T entity) throws DataAccessException {
391409

392410
RelationalPersistentEntity<T> persistentEntity = getRequiredEntity(entity);
393411

394-
return this.databaseClient.update() //
412+
DatabaseClient.UpdateMatchingSpec updateMatchingSpec = this.databaseClient.update() //
395413
.table(persistentEntity.getType()) //
396-
.table(persistentEntity.getTableName()).using(entity) //
397-
.fetch().rowsUpdated().handle((rowsUpdated, sink) -> {
414+
.table(persistentEntity.getTableName()) //
415+
.using(entity);
416+
417+
DatabaseClient.UpdateSpec updateSpec = updateMatchingSpec;
418+
if (persistentEntity.hasVersionProperty()) {
419+
updateSpec = updateMatchingSpec.matching(createMatchingVersionCriteria(entity, persistentEntity));
420+
incrementVersion(entity, persistentEntity);
421+
}
398422

423+
return updateSpec.fetch() //
424+
.rowsUpdated().handle((rowsUpdated, sink) -> {
399425
if (rowsUpdated == 0) {
400426
sink.error(new TransientDataAccessResourceException(
401427
String.format("Failed to update table [%s]. Row with Id [%s] does not exist.",
@@ -406,6 +432,33 @@ public <T> Mono<T> update(T entity) throws DataAccessException {
406432
});
407433
}
408434

435+
private <T> void incrementVersion(T entity, RelationalPersistentEntity<T> persistentEntity) {
436+
PersistentPropertyAccessor<?> propertyAccessor = persistentEntity.getPropertyAccessor(entity);
437+
RelationalPersistentProperty versionProperty = persistentEntity.getVersionProperty();
438+
439+
ConversionService conversionService = this.dataAccessStrategy.getConverter().getConversionService();
440+
Object currentVersionValue = propertyAccessor.getProperty(versionProperty);
441+
long newVersionValue = 1L;
442+
if (currentVersionValue != null) {
443+
newVersionValue = conversionService.convert(currentVersionValue, Long.class) + 1;
444+
}
445+
Class<?> versionPropertyType = versionProperty.getType();
446+
propertyAccessor.setProperty(versionProperty, conversionService.convert(newVersionValue, versionPropertyType));
447+
}
448+
449+
private <T> Criteria createMatchingVersionCriteria(T entity, RelationalPersistentEntity<T> persistentEntity) {
450+
PersistentPropertyAccessor<?> propertyAccessor = persistentEntity.getPropertyAccessor(entity);
451+
RelationalPersistentProperty versionProperty = persistentEntity.getVersionProperty();
452+
453+
Object version = propertyAccessor.getProperty(versionProperty);
454+
Criteria.CriteriaStep versionColumn = Criteria.where(dataAccessStrategy.toSql(versionProperty.getColumnName()));
455+
if (version == null) {
456+
return versionColumn.isNull();
457+
} else {
458+
return versionColumn.is(version);
459+
}
460+
}
461+
409462
/*
410463
* (non-Javadoc)
411464
* @see org.springframework.data.r2dbc.core.R2dbcEntityOperations#delete(java.lang.Object)

src/test/java/org/springframework/data/r2dbc/repository/support/AbstractSimpleR2dbcRepositoryIntegrationTests.java

+101-1
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,11 @@
3333

3434
import org.junit.Before;
3535
import org.junit.Test;
36-
3736
import org.springframework.beans.factory.annotation.Autowired;
3837
import org.springframework.dao.DataAccessException;
38+
import org.springframework.dao.OptimisticLockingFailureException;
3939
import org.springframework.data.annotation.Id;
40+
import org.springframework.data.annotation.Version;
4041
import org.springframework.data.domain.Persistable;
4142
import org.springframework.data.r2dbc.convert.MappingR2dbcConverter;
4243
import org.springframework.data.r2dbc.core.DatabaseClient;
@@ -53,6 +54,7 @@
5354
* Abstract integration tests for {@link SimpleR2dbcRepository} to be ran against various databases.
5455
*
5556
* @author Mark Paluch
57+
* @author Bogdan Ilchyshyn
5658
*/
5759
public abstract class AbstractSimpleR2dbcRepositoryIntegrationTests extends R2dbcIntegrationTestSupport {
5860

@@ -117,6 +119,42 @@ public void shouldSaveNewObject() {
117119
assertThat(map).containsEntry("name", "SCHAUFELRADBAGGER").containsEntry("manual", 12).containsKey("id");
118120
}
119121

122+
@Test
123+
public void shouldSaveNewObjectAndSetVersionIfWrapperVersionPropertyExists() {
124+
125+
LegoSetVersionable legoSet = new LegoSetVersionable(null, "SCHAUFELRADBAGGER", 12, null);
126+
127+
repository.save(legoSet) //
128+
.as(StepVerifier::create) //
129+
.consumeNextWith(actual -> assertThat(actual.getVersion()).isEqualTo(0)) //
130+
.verifyComplete();
131+
132+
Map<String, Object> map = jdbc.queryForMap("SELECT * FROM legoset");
133+
assertThat(map) //
134+
.containsEntry("name", "SCHAUFELRADBAGGER") //
135+
.containsEntry("manual", 12) //
136+
.containsEntry("version", 0) //
137+
.containsKey("id");
138+
}
139+
140+
@Test
141+
public void shouldSaveNewObjectAndSetVersionIfPrimitiveVersionPropertyExists() {
142+
143+
LegoSetPrimitiveVersionable legoSet = new LegoSetPrimitiveVersionable(null, "SCHAUFELRADBAGGER", 12, -1);
144+
145+
repository.save(legoSet) //
146+
.as(StepVerifier::create) //
147+
.consumeNextWith(actual -> assertThat(actual.getVersion()).isEqualTo(1)) //
148+
.verifyComplete();
149+
150+
Map<String, Object> map = jdbc.queryForMap("SELECT * FROM legoset");
151+
assertThat(map) //
152+
.containsEntry("name", "SCHAUFELRADBAGGER") //
153+
.containsEntry("manual", 12) //
154+
.containsEntry("version", 1) //
155+
.containsKey("id");
156+
}
157+
120158
@Test
121159
public void shouldUpdateObject() {
122160

@@ -135,6 +173,44 @@ public void shouldUpdateObject() {
135173
assertThat(map).containsEntry("name", "SCHAUFELRADBAGGER").containsEntry("manual", 14).containsKey("id");
136174
}
137175

176+
@Test
177+
public void shouldUpdateVersionableObjectAndIncreaseVersion() {
178+
179+
jdbc.execute("INSERT INTO legoset (name, manual, version) VALUES('SCHAUFELRADBAGGER', 12, 42)");
180+
Integer id = jdbc.queryForObject("SELECT id FROM legoset", Integer.class);
181+
182+
LegoSetVersionable legoSet = new LegoSetVersionable(id, "SCHAUFELRADBAGGER", 12, 42);
183+
legoSet.setManual(14);
184+
185+
repository.save(legoSet) //
186+
.as(StepVerifier::create) //
187+
.expectNextCount(1) //
188+
.verifyComplete();
189+
190+
assertThat(legoSet.getVersion()).isEqualTo(43);
191+
192+
Map<String, Object> map = jdbc.queryForMap("SELECT * FROM legoset");
193+
assertThat(map)
194+
.containsEntry("name", "SCHAUFELRADBAGGER") //
195+
.containsEntry("manual", 14) //
196+
.containsEntry("version", 43) //
197+
.containsKey("id");
198+
}
199+
200+
@Test
201+
public void shouldFailWithOptimistickLockingWhenVersionDoesNotMatchOnUpdate() {
202+
203+
jdbc.execute("INSERT INTO legoset (name, manual, version) VALUES('SCHAUFELRADBAGGER', 12, 42)");
204+
Integer id = jdbc.queryForObject("SELECT id FROM legoset", Integer.class);
205+
206+
LegoSetVersionable legoSet = new LegoSetVersionable(id, "SCHAUFELRADBAGGER", 12, 0);
207+
208+
repository.save(legoSet) //
209+
.as(StepVerifier::create) //
210+
.expectError(OptimisticLockingFailureException.class) //
211+
.verify();
212+
}
213+
138214
@Test
139215
public void shouldSaveObjectsUsingIterable() {
140216

@@ -392,4 +468,28 @@ public boolean isNew() {
392468
return true;
393469
}
394470
}
471+
472+
@Data
473+
@Table("legoset")
474+
@NoArgsConstructor
475+
static class LegoSetVersionable extends LegoSet {
476+
@Version Integer version;
477+
478+
public LegoSetVersionable(Integer id, String name, Integer manual, Integer version) {
479+
super(id, name, manual);
480+
this.version = version;
481+
}
482+
}
483+
484+
@Data
485+
@Table("legoset")
486+
@NoArgsConstructor
487+
static class LegoSetPrimitiveVersionable extends LegoSet {
488+
@Version int version;
489+
490+
public LegoSetPrimitiveVersionable(Integer id, String name, Integer manual, int version) {
491+
super(id, name, manual);
492+
this.version = version;
493+
}
494+
}
395495
}

src/test/java/org/springframework/data/r2dbc/testing/H2TestSupport.java

+3
Original file line numberDiff line numberDiff line change
@@ -27,18 +27,21 @@
2727
* Utility class for testing against H2.
2828
*
2929
* @author Mark Paluch
30+
* @author Bogdan Ilchyshyn
3031
*/
3132
public class H2TestSupport {
3233

3334
public static String CREATE_TABLE_LEGOSET = "CREATE TABLE legoset (\n" //
3435
+ " id integer CONSTRAINT id PRIMARY KEY,\n" //
36+
+ " version integer NULL,\n" //
3537
+ " name varchar(255) NOT NULL,\n" //
3638
+ " manual integer NULL\n," //
3739
+ " cert bytea NULL\n" //
3840
+ ");";
3941

4042
public static String CREATE_TABLE_LEGOSET_WITH_ID_GENERATION = "CREATE TABLE legoset (\n" //
4143
+ " id serial CONSTRAINT id PRIMARY KEY,\n" //
44+
+ " version integer NULL,\n" //
4245
+ " name varchar(255) NOT NULL,\n" //
4346
+ " manual integer NULL\n" //
4447
+ ");";

src/test/java/org/springframework/data/r2dbc/testing/MySqlTestSupport.java

+3
Original file line numberDiff line numberDiff line change
@@ -35,20 +35,23 @@
3535
* Utility class for testing against MySQL.
3636
*
3737
* @author Mark Paluch
38+
* @author Bogdan Ilchyshyn
3839
*/
3940
public class MySqlTestSupport {
4041

4142
private static ExternalDatabase testContainerDatabase;
4243

4344
public static String CREATE_TABLE_LEGOSET = "CREATE TABLE legoset (\n" //
4445
+ " id integer PRIMARY KEY,\n" //
46+
+ " version integer NULL,\n" //
4547
+ " name varchar(255) NOT NULL,\n" //
4648
+ " manual integer NULL\n," //
4749
+ " cert varbinary(255) NULL\n" //
4850
+ ") ENGINE=InnoDB;";
4951

5052
public static String CREATE_TABLE_LEGOSET_WITH_ID_GENERATION = "CREATE TABLE legoset (\n" //
5153
+ " id integer AUTO_INCREMENT PRIMARY KEY,\n" //
54+
+ " version integer NULL,\n" //
5255
+ " name varchar(255) NOT NULL,\n" //
5356
+ " manual integer NULL\n" //
5457
+ ") ENGINE=InnoDB;";

0 commit comments

Comments
 (0)