Skip to content

#93 - Adding support for optimistic locking based on @Version column #314

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
* to create an instance.
*
* @author Mark Paluch
* @author Bogdan Ilchyshyn
*/
public interface DatabaseClient {

Expand Down Expand Up @@ -719,9 +720,9 @@ interface TypedUpdateSpec<T> {
*
* @param objectToUpdate the object of which the attributes will provide the values for the update and the primary
* key. Must not be {@literal null}.
* @return a {@link UpdateSpec} for further configuration of the update. Guaranteed to be not {@literal null}.
* @return a {@link UpdateMatchingSpec} for further configuration of the update. Guaranteed to be not {@literal null}.
*/
UpdateSpec using(T objectToUpdate);
UpdateMatchingSpec using(T objectToUpdate);

/**
* Use the given {@code tableName} as update target.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
* Default implementation of {@link DatabaseClient}.
*
* @author Mark Paluch
* @author Bogdan Ilchyshyn
*/
class DefaultDatabaseClient implements DatabaseClient, ConnectionAccessor {

Expand Down Expand Up @@ -1196,7 +1197,7 @@ public <T> TypedUpdateSpec<T> table(Class<T> table) {

assertRegularClass(table);

return new DefaultTypedUpdateSpec<>(table, null, null);
return new DefaultTypedUpdateSpec<>(table, null, null, null);
}
}

Expand Down Expand Up @@ -1275,32 +1276,43 @@ private UpdatedRowsFetchSpec exchange(SqlIdentifier table) {
}
}

class DefaultTypedUpdateSpec<T> implements TypedUpdateSpec<T>, UpdateSpec {
class DefaultTypedUpdateSpec<T> implements TypedUpdateSpec<T>, UpdateMatchingSpec {

private final Class<T> typeToUpdate;
private final @Nullable SqlIdentifier table;
private final @Nullable T objectToUpdate;
private final @Nullable Criteria where;

DefaultTypedUpdateSpec(Class<T> typeToUpdate, @Nullable SqlIdentifier table, @Nullable T objectToUpdate) {
DefaultTypedUpdateSpec(Class<T> typeToUpdate, @Nullable SqlIdentifier table, @Nullable T objectToUpdate,
@Nullable Criteria where) {
this.typeToUpdate = typeToUpdate;
this.table = table;
this.objectToUpdate = objectToUpdate;
this.where = where;
}

@Override
public UpdateSpec using(T objectToUpdate) {
public UpdateMatchingSpec using(T objectToUpdate) {

Assert.notNull(objectToUpdate, "Object to update must not be null");

return new DefaultTypedUpdateSpec<>(this.typeToUpdate, this.table, objectToUpdate);
return new DefaultTypedUpdateSpec<>(this.typeToUpdate, this.table, objectToUpdate, this.where);
}

@Override
public TypedUpdateSpec<T> table(SqlIdentifier tableName) {

Assert.notNull(tableName, "Table name must not be null!");

return new DefaultTypedUpdateSpec<>(this.typeToUpdate, tableName, this.objectToUpdate);
return new DefaultTypedUpdateSpec<>(this.typeToUpdate, tableName, this.objectToUpdate, this.where);
}

@Override
public UpdateSpec matching(Criteria criteria) {

Assert.notNull(criteria, "Criteria must not be null!");

return new DefaultTypedUpdateSpec<>(this.typeToUpdate, this.table, this.objectToUpdate, criteria);
}

@Override
Expand Down Expand Up @@ -1343,8 +1355,13 @@ private UpdatedRowsFetchSpec exchange(SqlIdentifier table) {
}
}

Criteria updateCriteria = Criteria.where(dataAccessStrategy.toSql(ids.get(0))).is(id);
if (this.where != null) {
updateCriteria = updateCriteria.and(this.where);
}

PreparedOperation<?> operation = mapper.getMappedObject(
mapper.createUpdate(table, update).withCriteria(Criteria.where(dataAccessStrategy.toSql(ids.get(0))).is(id)));
mapper.createUpdate(table, update).withCriteria(updateCriteria));

return exchangeUpdate(operation);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import io.r2dbc.spi.Row;
import io.r2dbc.spi.RowMetadata;
import org.springframework.dao.OptimisticLockingFailureException;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

Expand All @@ -30,10 +31,12 @@
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.BeanFactory;
import org.springframework.beans.factory.BeanFactoryAware;
import org.springframework.core.convert.ConversionService;
import org.springframework.dao.DataAccessException;
import org.springframework.dao.TransientDataAccessResourceException;
import org.springframework.data.mapping.IdentifierAccessor;
import org.springframework.data.mapping.MappingException;
import org.springframework.data.mapping.PersistentPropertyAccessor;
import org.springframework.data.mapping.context.MappingContext;
import org.springframework.data.projection.ProjectionInformation;
import org.springframework.data.projection.SpelAwareProxyProjectionFactory;
Expand All @@ -58,6 +61,7 @@
* prepared in an application context and given to services as bean reference.
*
* @author Mark Paluch
* @author Bogdan Ilchyshyn
* @since 1.1
*/
public class R2dbcEntityTemplate implements R2dbcEntityOperations, BeanFactoryAware {
Expand Down Expand Up @@ -372,6 +376,8 @@ <T> Mono<T> doInsert(T entity, SqlIdentifier tableName) {

RelationalPersistentEntity<T> persistentEntity = getRequiredEntity(entity);

setVersionIfNecessary(persistentEntity, entity);

return this.databaseClient.insert() //
.into(persistentEntity.getType()) //
.table(tableName).using(entity) //
Expand All @@ -380,6 +386,19 @@ <T> Mono<T> doInsert(T entity, SqlIdentifier tableName) {
.defaultIfEmpty(entity);
}

private <T> void setVersionIfNecessary(RelationalPersistentEntity<T> persistentEntity, T entity) {
RelationalPersistentProperty versionProperty = persistentEntity.getVersionProperty();
if (versionProperty == null) {
return;
}

Class<?> versionPropertyType = versionProperty.getType();
Long version = versionPropertyType.isPrimitive() ? 1L : 0L;
ConversionService conversionService = this.dataAccessStrategy.getConverter().getConversionService();
PersistentPropertyAccessor<?> propertyAccessor = persistentEntity.getPropertyAccessor(entity);
propertyAccessor.setProperty(versionProperty, conversionService.convert(version, versionPropertyType));
}

/*
* (non-Javadoc)
* @see org.springframework.data.r2dbc.core.R2dbcEntityOperations#update(java.lang.Object)
Expand All @@ -391,21 +410,78 @@ public <T> Mono<T> update(T entity) throws DataAccessException {

RelationalPersistentEntity<T> persistentEntity = getRequiredEntity(entity);

return this.databaseClient.update() //
DatabaseClient.UpdateMatchingSpec updateMatchingSpec = this.databaseClient.update() //
.table(persistentEntity.getType()) //
.table(persistentEntity.getTableName()).using(entity) //
.fetch().rowsUpdated().handle((rowsUpdated, sink) -> {
.table(persistentEntity.getTableName()) //
.using(entity);

DatabaseClient.UpdateSpec updateSpec = updateMatchingSpec;
if (persistentEntity.hasVersionProperty()) {
updateSpec = updateMatchingSpec.matching(createMatchingVersionCriteria(entity, persistentEntity));
incrementVersion(entity, persistentEntity);
}

return updateSpec.fetch() //
.rowsUpdated() //
.flatMap(rowsUpdated -> rowsUpdated == 0
? handleMissingUpdate(entity, persistentEntity) : Mono.just(entity));
}

if (rowsUpdated == 0) {
sink.error(new TransientDataAccessResourceException(
String.format("Failed to update table [%s]. Row with Id [%s] does not exist.",
persistentEntity.getTableName(), persistentEntity.getIdentifierAccessor(entity).getIdentifier())));
private <T> Mono<? extends T> handleMissingUpdate(T entity, RelationalPersistentEntity<T> persistentEntity) {
if (!persistentEntity.hasVersionProperty()) {
return Mono.error(new TransientDataAccessResourceException(
formatTransientEntityExceptionMessage(entity, persistentEntity)));
}

return doCount(getByIdQuery(entity, persistentEntity), entity.getClass(), persistentEntity.getTableName())
.map(count -> {
if (count == 0) {
throw new TransientDataAccessResourceException(
formatTransientEntityExceptionMessage(entity, persistentEntity));
} else {
sink.next(entity);
throw new OptimisticLockingFailureException(
formatOptimisticLockingExceptionMessage(entity, persistentEntity));
}
});
}

private <T> String formatOptimisticLockingExceptionMessage(T entity, RelationalPersistentEntity<T> persistentEntity) {
return String.format("Failed to update table [%s]. Version does not match for row with Id [%s].",
persistentEntity.getTableName(), persistentEntity.getIdentifierAccessor(entity).getIdentifier());
}

private <T> String formatTransientEntityExceptionMessage(T entity, RelationalPersistentEntity<T> persistentEntity) {
return String.format("Failed to update table [%s]. Row with Id [%s] does not exist.",
persistentEntity.getTableName(), persistentEntity.getIdentifierAccessor(entity).getIdentifier());
}

private <T> void incrementVersion(T entity, RelationalPersistentEntity<T> persistentEntity) {
PersistentPropertyAccessor<?> propertyAccessor = persistentEntity.getPropertyAccessor(entity);
RelationalPersistentProperty versionProperty = persistentEntity.getVersionProperty();

ConversionService conversionService = this.dataAccessStrategy.getConverter().getConversionService();
Object currentVersionValue = propertyAccessor.getProperty(versionProperty);
long newVersionValue = 1L;
if (currentVersionValue != null) {
newVersionValue = conversionService.convert(currentVersionValue, Long.class) + 1;
}
Class<?> versionPropertyType = versionProperty.getType();
propertyAccessor.setProperty(versionProperty, conversionService.convert(newVersionValue, versionPropertyType));
}

private <T> Criteria createMatchingVersionCriteria(T entity, RelationalPersistentEntity<T> persistentEntity) {
PersistentPropertyAccessor<?> propertyAccessor = persistentEntity.getPropertyAccessor(entity);
RelationalPersistentProperty versionProperty = persistentEntity.getVersionProperty();

Object version = propertyAccessor.getProperty(versionProperty);
Criteria.CriteriaStep versionColumn = Criteria.where(dataAccessStrategy.toSql(versionProperty.getColumnName()));
if (version == null) {
return versionColumn.isNull();
} else {
return versionColumn.is(version);
}
}

/*
* (non-Javadoc)
* @see org.springframework.data.r2dbc.core.R2dbcEntityOperations#delete(java.lang.Object)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,11 @@

import org.junit.Before;
import org.junit.Test;

import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.dao.DataAccessException;
import org.springframework.dao.OptimisticLockingFailureException;
import org.springframework.data.annotation.Id;
import org.springframework.data.annotation.Version;
import org.springframework.data.domain.Persistable;
import org.springframework.data.r2dbc.convert.MappingR2dbcConverter;
import org.springframework.data.r2dbc.core.DatabaseClient;
Expand All @@ -53,6 +54,7 @@
* Abstract integration tests for {@link SimpleR2dbcRepository} to be ran against various databases.
*
* @author Mark Paluch
* @author Bogdan Ilchyshyn
*/
public abstract class AbstractSimpleR2dbcRepositoryIntegrationTests extends R2dbcIntegrationTestSupport {

Expand Down Expand Up @@ -117,6 +119,42 @@ public void shouldSaveNewObject() {
assertThat(map).containsEntry("name", "SCHAUFELRADBAGGER").containsEntry("manual", 12).containsKey("id");
}

@Test
public void shouldSaveNewObjectAndSetVersionIfWrapperVersionPropertyExists() {

LegoSetVersionable legoSet = new LegoSetVersionable(null, "SCHAUFELRADBAGGER", 12, null);

repository.save(legoSet) //
.as(StepVerifier::create) //
.consumeNextWith(actual -> assertThat(actual.getVersion()).isEqualTo(0)) //
.verifyComplete();

Map<String, Object> map = jdbc.queryForMap("SELECT * FROM legoset");
assertThat(map) //
.containsEntry("name", "SCHAUFELRADBAGGER") //
.containsEntry("manual", 12) //
.containsEntry("version", 0) //
.containsKey("id");
}

@Test
public void shouldSaveNewObjectAndSetVersionIfPrimitiveVersionPropertyExists() {

LegoSetPrimitiveVersionable legoSet = new LegoSetPrimitiveVersionable(null, "SCHAUFELRADBAGGER", 12, -1);

repository.save(legoSet) //
.as(StepVerifier::create) //
.consumeNextWith(actual -> assertThat(actual.getVersion()).isEqualTo(1)) //
.verifyComplete();

Map<String, Object> map = jdbc.queryForMap("SELECT * FROM legoset");
assertThat(map) //
.containsEntry("name", "SCHAUFELRADBAGGER") //
.containsEntry("manual", 12) //
.containsEntry("version", 1) //
.containsKey("id");
}

@Test
public void shouldUpdateObject() {

Expand All @@ -135,6 +173,44 @@ public void shouldUpdateObject() {
assertThat(map).containsEntry("name", "SCHAUFELRADBAGGER").containsEntry("manual", 14).containsKey("id");
}

@Test
public void shouldUpdateVersionableObjectAndIncreaseVersion() {

jdbc.execute("INSERT INTO legoset (name, manual, version) VALUES('SCHAUFELRADBAGGER', 12, 42)");
Integer id = jdbc.queryForObject("SELECT id FROM legoset", Integer.class);

LegoSetVersionable legoSet = new LegoSetVersionable(id, "SCHAUFELRADBAGGER", 12, 42);
legoSet.setManual(14);

repository.save(legoSet) //
.as(StepVerifier::create) //
.expectNextCount(1) //
.verifyComplete();

assertThat(legoSet.getVersion()).isEqualTo(43);

Map<String, Object> map = jdbc.queryForMap("SELECT * FROM legoset");
assertThat(map)
.containsEntry("name", "SCHAUFELRADBAGGER") //
.containsEntry("manual", 14) //
.containsEntry("version", 43) //
.containsKey("id");
}

@Test
public void shouldFailWithOptimistickLockingWhenVersionDoesNotMatchOnUpdate() {

jdbc.execute("INSERT INTO legoset (name, manual, version) VALUES('SCHAUFELRADBAGGER', 12, 42)");
Integer id = jdbc.queryForObject("SELECT id FROM legoset", Integer.class);

LegoSetVersionable legoSet = new LegoSetVersionable(id, "SCHAUFELRADBAGGER", 12, 0);

repository.save(legoSet) //
.as(StepVerifier::create) //
.expectError(OptimisticLockingFailureException.class) //
.verify();
}

@Test
public void shouldSaveObjectsUsingIterable() {

Expand Down Expand Up @@ -392,4 +468,28 @@ public boolean isNew() {
return true;
}
}

@Data
@Table("legoset")
@NoArgsConstructor
static class LegoSetVersionable extends LegoSet {
@Version Integer version;

public LegoSetVersionable(Integer id, String name, Integer manual, Integer version) {
super(id, name, manual);
this.version = version;
}
}

@Data
@Table("legoset")
@NoArgsConstructor
static class LegoSetPrimitiveVersionable extends LegoSet {
@Version int version;

public LegoSetPrimitiveVersionable(Integer id, String name, Integer manual, int version) {
super(id, name, manual);
this.version = version;
}
}
}
Loading