diff --git a/src/main/java/org/springframework/data/r2dbc/core/DatabaseClient.java b/src/main/java/org/springframework/data/r2dbc/core/DatabaseClient.java index 6e090499..ab62c594 100644 --- a/src/main/java/org/springframework/data/r2dbc/core/DatabaseClient.java +++ b/src/main/java/org/springframework/data/r2dbc/core/DatabaseClient.java @@ -48,6 +48,7 @@ * to create an instance. * * @author Mark Paluch + * @author Bogdan Ilchyshyn */ public interface DatabaseClient { @@ -719,9 +720,9 @@ interface TypedUpdateSpec { * * @param objectToUpdate the object of which the attributes will provide the values for the update and the primary * key. Must not be {@literal null}. - * @return a {@link UpdateSpec} for further configuration of the update. Guaranteed to be not {@literal null}. + * @return a {@link UpdateMatchingSpec} for further configuration of the update. Guaranteed to be not {@literal null}. */ - UpdateSpec using(T objectToUpdate); + UpdateMatchingSpec using(T objectToUpdate); /** * Use the given {@code tableName} as update target. diff --git a/src/main/java/org/springframework/data/r2dbc/core/DefaultDatabaseClient.java b/src/main/java/org/springframework/data/r2dbc/core/DefaultDatabaseClient.java index 5f0fc0f8..7add9c3d 100644 --- a/src/main/java/org/springframework/data/r2dbc/core/DefaultDatabaseClient.java +++ b/src/main/java/org/springframework/data/r2dbc/core/DefaultDatabaseClient.java @@ -69,6 +69,7 @@ * Default implementation of {@link DatabaseClient}. * * @author Mark Paluch + * @author Bogdan Ilchyshyn */ class DefaultDatabaseClient implements DatabaseClient, ConnectionAccessor { @@ -1196,7 +1197,7 @@ public TypedUpdateSpec table(Class table) { assertRegularClass(table); - return new DefaultTypedUpdateSpec<>(table, null, null); + return new DefaultTypedUpdateSpec<>(table, null, null, null); } } @@ -1275,24 +1276,27 @@ private UpdatedRowsFetchSpec exchange(SqlIdentifier table) { } } - class DefaultTypedUpdateSpec implements TypedUpdateSpec, UpdateSpec { + class DefaultTypedUpdateSpec implements TypedUpdateSpec, UpdateMatchingSpec { private final Class typeToUpdate; private final @Nullable SqlIdentifier table; private final @Nullable T objectToUpdate; + private final @Nullable Criteria where; - DefaultTypedUpdateSpec(Class typeToUpdate, @Nullable SqlIdentifier table, @Nullable T objectToUpdate) { + DefaultTypedUpdateSpec(Class typeToUpdate, @Nullable SqlIdentifier table, @Nullable T objectToUpdate, + @Nullable Criteria where) { this.typeToUpdate = typeToUpdate; this.table = table; this.objectToUpdate = objectToUpdate; + this.where = where; } @Override - public UpdateSpec using(T objectToUpdate) { + public UpdateMatchingSpec using(T objectToUpdate) { Assert.notNull(objectToUpdate, "Object to update must not be null"); - return new DefaultTypedUpdateSpec<>(this.typeToUpdate, this.table, objectToUpdate); + return new DefaultTypedUpdateSpec<>(this.typeToUpdate, this.table, objectToUpdate, this.where); } @Override @@ -1300,7 +1304,15 @@ public TypedUpdateSpec table(SqlIdentifier tableName) { Assert.notNull(tableName, "Table name must not be null!"); - return new DefaultTypedUpdateSpec<>(this.typeToUpdate, tableName, this.objectToUpdate); + return new DefaultTypedUpdateSpec<>(this.typeToUpdate, tableName, this.objectToUpdate, this.where); + } + + @Override + public UpdateSpec matching(Criteria criteria) { + + Assert.notNull(criteria, "Criteria must not be null!"); + + return new DefaultTypedUpdateSpec<>(this.typeToUpdate, this.table, this.objectToUpdate, criteria); } @Override @@ -1343,8 +1355,13 @@ private UpdatedRowsFetchSpec exchange(SqlIdentifier table) { } } + Criteria updateCriteria = Criteria.where(dataAccessStrategy.toSql(ids.get(0))).is(id); + if (this.where != null) { + updateCriteria = updateCriteria.and(this.where); + } + PreparedOperation operation = mapper.getMappedObject( - mapper.createUpdate(table, update).withCriteria(Criteria.where(dataAccessStrategy.toSql(ids.get(0))).is(id))); + mapper.createUpdate(table, update).withCriteria(updateCriteria)); return exchangeUpdate(operation); } diff --git a/src/main/java/org/springframework/data/r2dbc/core/R2dbcEntityTemplate.java b/src/main/java/org/springframework/data/r2dbc/core/R2dbcEntityTemplate.java index 1c9a0877..2e6cea6b 100644 --- a/src/main/java/org/springframework/data/r2dbc/core/R2dbcEntityTemplate.java +++ b/src/main/java/org/springframework/data/r2dbc/core/R2dbcEntityTemplate.java @@ -17,6 +17,7 @@ import io.r2dbc.spi.Row; import io.r2dbc.spi.RowMetadata; +import org.springframework.dao.OptimisticLockingFailureException; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -30,10 +31,12 @@ import org.springframework.beans.BeansException; import org.springframework.beans.factory.BeanFactory; import org.springframework.beans.factory.BeanFactoryAware; +import org.springframework.core.convert.ConversionService; import org.springframework.dao.DataAccessException; import org.springframework.dao.TransientDataAccessResourceException; import org.springframework.data.mapping.IdentifierAccessor; import org.springframework.data.mapping.MappingException; +import org.springframework.data.mapping.PersistentPropertyAccessor; import org.springframework.data.mapping.context.MappingContext; import org.springframework.data.projection.ProjectionInformation; import org.springframework.data.projection.SpelAwareProxyProjectionFactory; @@ -58,6 +61,7 @@ * prepared in an application context and given to services as bean reference. * * @author Mark Paluch + * @author Bogdan Ilchyshyn * @since 1.1 */ public class R2dbcEntityTemplate implements R2dbcEntityOperations, BeanFactoryAware { @@ -372,6 +376,8 @@ Mono doInsert(T entity, SqlIdentifier tableName) { RelationalPersistentEntity persistentEntity = getRequiredEntity(entity); + setVersionIfNecessary(persistentEntity, entity); + return this.databaseClient.insert() // .into(persistentEntity.getType()) // .table(tableName).using(entity) // @@ -380,6 +386,19 @@ Mono doInsert(T entity, SqlIdentifier tableName) { .defaultIfEmpty(entity); } + private void setVersionIfNecessary(RelationalPersistentEntity persistentEntity, T entity) { + RelationalPersistentProperty versionProperty = persistentEntity.getVersionProperty(); + if (versionProperty == null) { + return; + } + + Class versionPropertyType = versionProperty.getType(); + Long version = versionPropertyType.isPrimitive() ? 1L : 0L; + ConversionService conversionService = this.dataAccessStrategy.getConverter().getConversionService(); + PersistentPropertyAccessor propertyAccessor = persistentEntity.getPropertyAccessor(entity); + propertyAccessor.setProperty(versionProperty, conversionService.convert(version, versionPropertyType)); + } + /* * (non-Javadoc) * @see org.springframework.data.r2dbc.core.R2dbcEntityOperations#update(java.lang.Object) @@ -391,21 +410,78 @@ public Mono update(T entity) throws DataAccessException { RelationalPersistentEntity persistentEntity = getRequiredEntity(entity); - return this.databaseClient.update() // + DatabaseClient.UpdateMatchingSpec updateMatchingSpec = this.databaseClient.update() // .table(persistentEntity.getType()) // - .table(persistentEntity.getTableName()).using(entity) // - .fetch().rowsUpdated().handle((rowsUpdated, sink) -> { + .table(persistentEntity.getTableName()) // + .using(entity); + + DatabaseClient.UpdateSpec updateSpec = updateMatchingSpec; + if (persistentEntity.hasVersionProperty()) { + updateSpec = updateMatchingSpec.matching(createMatchingVersionCriteria(entity, persistentEntity)); + incrementVersion(entity, persistentEntity); + } + + return updateSpec.fetch() // + .rowsUpdated() // + .flatMap(rowsUpdated -> rowsUpdated == 0 + ? handleMissingUpdate(entity, persistentEntity) : Mono.just(entity)); + } - if (rowsUpdated == 0) { - sink.error(new TransientDataAccessResourceException( - String.format("Failed to update table [%s]. Row with Id [%s] does not exist.", - persistentEntity.getTableName(), persistentEntity.getIdentifierAccessor(entity).getIdentifier()))); + private Mono handleMissingUpdate(T entity, RelationalPersistentEntity persistentEntity) { + if (!persistentEntity.hasVersionProperty()) { + return Mono.error(new TransientDataAccessResourceException( + formatTransientEntityExceptionMessage(entity, persistentEntity))); + } + + return doCount(getByIdQuery(entity, persistentEntity), entity.getClass(), persistentEntity.getTableName()) + .map(count -> { + if (count == 0) { + throw new TransientDataAccessResourceException( + formatTransientEntityExceptionMessage(entity, persistentEntity)); } else { - sink.next(entity); + throw new OptimisticLockingFailureException( + formatOptimisticLockingExceptionMessage(entity, persistentEntity)); } }); } + private String formatOptimisticLockingExceptionMessage(T entity, RelationalPersistentEntity persistentEntity) { + return String.format("Failed to update table [%s]. Version does not match for row with Id [%s].", + persistentEntity.getTableName(), persistentEntity.getIdentifierAccessor(entity).getIdentifier()); + } + + private String formatTransientEntityExceptionMessage(T entity, RelationalPersistentEntity persistentEntity) { + return String.format("Failed to update table [%s]. Row with Id [%s] does not exist.", + persistentEntity.getTableName(), persistentEntity.getIdentifierAccessor(entity).getIdentifier()); + } + + private void incrementVersion(T entity, RelationalPersistentEntity persistentEntity) { + PersistentPropertyAccessor propertyAccessor = persistentEntity.getPropertyAccessor(entity); + RelationalPersistentProperty versionProperty = persistentEntity.getVersionProperty(); + + ConversionService conversionService = this.dataAccessStrategy.getConverter().getConversionService(); + Object currentVersionValue = propertyAccessor.getProperty(versionProperty); + long newVersionValue = 1L; + if (currentVersionValue != null) { + newVersionValue = conversionService.convert(currentVersionValue, Long.class) + 1; + } + Class versionPropertyType = versionProperty.getType(); + propertyAccessor.setProperty(versionProperty, conversionService.convert(newVersionValue, versionPropertyType)); + } + + private Criteria createMatchingVersionCriteria(T entity, RelationalPersistentEntity persistentEntity) { + PersistentPropertyAccessor propertyAccessor = persistentEntity.getPropertyAccessor(entity); + RelationalPersistentProperty versionProperty = persistentEntity.getVersionProperty(); + + Object version = propertyAccessor.getProperty(versionProperty); + Criteria.CriteriaStep versionColumn = Criteria.where(dataAccessStrategy.toSql(versionProperty.getColumnName())); + if (version == null) { + return versionColumn.isNull(); + } else { + return versionColumn.is(version); + } + } + /* * (non-Javadoc) * @see org.springframework.data.r2dbc.core.R2dbcEntityOperations#delete(java.lang.Object) diff --git a/src/test/java/org/springframework/data/r2dbc/repository/support/AbstractSimpleR2dbcRepositoryIntegrationTests.java b/src/test/java/org/springframework/data/r2dbc/repository/support/AbstractSimpleR2dbcRepositoryIntegrationTests.java index cfe2094e..2d4bc4a1 100644 --- a/src/test/java/org/springframework/data/r2dbc/repository/support/AbstractSimpleR2dbcRepositoryIntegrationTests.java +++ b/src/test/java/org/springframework/data/r2dbc/repository/support/AbstractSimpleR2dbcRepositoryIntegrationTests.java @@ -33,10 +33,11 @@ import org.junit.Before; import org.junit.Test; - import org.springframework.beans.factory.annotation.Autowired; import org.springframework.dao.DataAccessException; +import org.springframework.dao.OptimisticLockingFailureException; import org.springframework.data.annotation.Id; +import org.springframework.data.annotation.Version; import org.springframework.data.domain.Persistable; import org.springframework.data.r2dbc.convert.MappingR2dbcConverter; import org.springframework.data.r2dbc.core.DatabaseClient; @@ -53,6 +54,7 @@ * Abstract integration tests for {@link SimpleR2dbcRepository} to be ran against various databases. * * @author Mark Paluch + * @author Bogdan Ilchyshyn */ public abstract class AbstractSimpleR2dbcRepositoryIntegrationTests extends R2dbcIntegrationTestSupport { @@ -117,6 +119,42 @@ public void shouldSaveNewObject() { assertThat(map).containsEntry("name", "SCHAUFELRADBAGGER").containsEntry("manual", 12).containsKey("id"); } + @Test + public void shouldSaveNewObjectAndSetVersionIfWrapperVersionPropertyExists() { + + LegoSetVersionable legoSet = new LegoSetVersionable(null, "SCHAUFELRADBAGGER", 12, null); + + repository.save(legoSet) // + .as(StepVerifier::create) // + .consumeNextWith(actual -> assertThat(actual.getVersion()).isEqualTo(0)) // + .verifyComplete(); + + Map map = jdbc.queryForMap("SELECT * FROM legoset"); + assertThat(map) // + .containsEntry("name", "SCHAUFELRADBAGGER") // + .containsEntry("manual", 12) // + .containsEntry("version", 0) // + .containsKey("id"); + } + + @Test + public void shouldSaveNewObjectAndSetVersionIfPrimitiveVersionPropertyExists() { + + LegoSetPrimitiveVersionable legoSet = new LegoSetPrimitiveVersionable(null, "SCHAUFELRADBAGGER", 12, -1); + + repository.save(legoSet) // + .as(StepVerifier::create) // + .consumeNextWith(actual -> assertThat(actual.getVersion()).isEqualTo(1)) // + .verifyComplete(); + + Map map = jdbc.queryForMap("SELECT * FROM legoset"); + assertThat(map) // + .containsEntry("name", "SCHAUFELRADBAGGER") // + .containsEntry("manual", 12) // + .containsEntry("version", 1) // + .containsKey("id"); + } + @Test public void shouldUpdateObject() { @@ -135,6 +173,44 @@ public void shouldUpdateObject() { assertThat(map).containsEntry("name", "SCHAUFELRADBAGGER").containsEntry("manual", 14).containsKey("id"); } + @Test + public void shouldUpdateVersionableObjectAndIncreaseVersion() { + + jdbc.execute("INSERT INTO legoset (name, manual, version) VALUES('SCHAUFELRADBAGGER', 12, 42)"); + Integer id = jdbc.queryForObject("SELECT id FROM legoset", Integer.class); + + LegoSetVersionable legoSet = new LegoSetVersionable(id, "SCHAUFELRADBAGGER", 12, 42); + legoSet.setManual(14); + + repository.save(legoSet) // + .as(StepVerifier::create) // + .expectNextCount(1) // + .verifyComplete(); + + assertThat(legoSet.getVersion()).isEqualTo(43); + + Map map = jdbc.queryForMap("SELECT * FROM legoset"); + assertThat(map) + .containsEntry("name", "SCHAUFELRADBAGGER") // + .containsEntry("manual", 14) // + .containsEntry("version", 43) // + .containsKey("id"); + } + + @Test + public void shouldFailWithOptimistickLockingWhenVersionDoesNotMatchOnUpdate() { + + jdbc.execute("INSERT INTO legoset (name, manual, version) VALUES('SCHAUFELRADBAGGER', 12, 42)"); + Integer id = jdbc.queryForObject("SELECT id FROM legoset", Integer.class); + + LegoSetVersionable legoSet = new LegoSetVersionable(id, "SCHAUFELRADBAGGER", 12, 0); + + repository.save(legoSet) // + .as(StepVerifier::create) // + .expectError(OptimisticLockingFailureException.class) // + .verify(); + } + @Test public void shouldSaveObjectsUsingIterable() { @@ -392,4 +468,28 @@ public boolean isNew() { return true; } } + + @Data + @Table("legoset") + @NoArgsConstructor + static class LegoSetVersionable extends LegoSet { + @Version Integer version; + + public LegoSetVersionable(Integer id, String name, Integer manual, Integer version) { + super(id, name, manual); + this.version = version; + } + } + + @Data + @Table("legoset") + @NoArgsConstructor + static class LegoSetPrimitiveVersionable extends LegoSet { + @Version int version; + + public LegoSetPrimitiveVersionable(Integer id, String name, Integer manual, int version) { + super(id, name, manual); + this.version = version; + } + } } diff --git a/src/test/java/org/springframework/data/r2dbc/testing/H2TestSupport.java b/src/test/java/org/springframework/data/r2dbc/testing/H2TestSupport.java index c6202ecd..50e1f74a 100644 --- a/src/test/java/org/springframework/data/r2dbc/testing/H2TestSupport.java +++ b/src/test/java/org/springframework/data/r2dbc/testing/H2TestSupport.java @@ -27,11 +27,13 @@ * Utility class for testing against H2. * * @author Mark Paluch + * @author Bogdan Ilchyshyn */ public class H2TestSupport { public static String CREATE_TABLE_LEGOSET = "CREATE TABLE legoset (\n" // + " id integer CONSTRAINT id PRIMARY KEY,\n" // + + " version integer NULL,\n" // + " name varchar(255) NOT NULL,\n" // + " manual integer NULL\n," // + " cert bytea NULL\n" // @@ -39,6 +41,7 @@ public class H2TestSupport { public static String CREATE_TABLE_LEGOSET_WITH_ID_GENERATION = "CREATE TABLE legoset (\n" // + " id serial CONSTRAINT id PRIMARY KEY,\n" // + + " version integer NULL,\n" // + " name varchar(255) NOT NULL,\n" // + " manual integer NULL\n" // + ");"; diff --git a/src/test/java/org/springframework/data/r2dbc/testing/MySqlTestSupport.java b/src/test/java/org/springframework/data/r2dbc/testing/MySqlTestSupport.java index ee57f0b5..ee299522 100644 --- a/src/test/java/org/springframework/data/r2dbc/testing/MySqlTestSupport.java +++ b/src/test/java/org/springframework/data/r2dbc/testing/MySqlTestSupport.java @@ -35,6 +35,7 @@ * Utility class for testing against MySQL. * * @author Mark Paluch + * @author Bogdan Ilchyshyn */ public class MySqlTestSupport { @@ -42,6 +43,7 @@ public class MySqlTestSupport { public static String CREATE_TABLE_LEGOSET = "CREATE TABLE legoset (\n" // + " id integer PRIMARY KEY,\n" // + + " version integer NULL,\n" // + " name varchar(255) NOT NULL,\n" // + " manual integer NULL\n," // + " cert varbinary(255) NULL\n" // @@ -49,6 +51,7 @@ public class MySqlTestSupport { public static String CREATE_TABLE_LEGOSET_WITH_ID_GENERATION = "CREATE TABLE legoset (\n" // + " id integer AUTO_INCREMENT PRIMARY KEY,\n" // + + " version integer NULL,\n" // + " name varchar(255) NOT NULL,\n" // + " manual integer NULL\n" // + ") ENGINE=InnoDB;"; diff --git a/src/test/java/org/springframework/data/r2dbc/testing/PostgresTestSupport.java b/src/test/java/org/springframework/data/r2dbc/testing/PostgresTestSupport.java index 2152e47a..37b4075d 100644 --- a/src/test/java/org/springframework/data/r2dbc/testing/PostgresTestSupport.java +++ b/src/test/java/org/springframework/data/r2dbc/testing/PostgresTestSupport.java @@ -18,6 +18,7 @@ * * @author Mark Paluch * @author Jens Schauder + * @author Bogdan Ilchyshyn */ public class PostgresTestSupport { @@ -25,6 +26,7 @@ public class PostgresTestSupport { public static String CREATE_TABLE_LEGOSET = "CREATE TABLE legoset (\n" // + " id integer CONSTRAINT id PRIMARY KEY,\n" // + + " version integer NULL,\n" // + " name varchar(255) NOT NULL,\n" // + " manual integer NULL\n," // + " cert bytea NULL\n" // @@ -32,6 +34,7 @@ public class PostgresTestSupport { public static String CREATE_TABLE_LEGOSET_WITH_ID_GENERATION = "CREATE TABLE legoset (\n" // + " id serial CONSTRAINT id PRIMARY KEY,\n" // + + " version integer NULL,\n" // + " name varchar(255) NOT NULL,\n" // + " manual integer NULL\n" // + ");"; diff --git a/src/test/java/org/springframework/data/r2dbc/testing/SqlServerTestSupport.java b/src/test/java/org/springframework/data/r2dbc/testing/SqlServerTestSupport.java index 6ab05695..5a49d2fc 100644 --- a/src/test/java/org/springframework/data/r2dbc/testing/SqlServerTestSupport.java +++ b/src/test/java/org/springframework/data/r2dbc/testing/SqlServerTestSupport.java @@ -12,11 +12,13 @@ * Utility class for testing against Microsoft SQL Server. * * @author Mark Paluch + * @author Bogdan Ilchyshyn */ public class SqlServerTestSupport { public static String CREATE_TABLE_LEGOSET = "CREATE TABLE legoset (\n" // + " id integer PRIMARY KEY,\n" // + + " version integer NULL,\n" // + " name varchar(255) NOT NULL,\n" // + " manual integer NULL\n," // + " cert varbinary(255) NULL\n" // @@ -24,6 +26,7 @@ public class SqlServerTestSupport { public static String CREATE_TABLE_LEGOSET_WITH_ID_GENERATION = "CREATE TABLE legoset (\n" // + " id integer IDENTITY(1,1) PRIMARY KEY,\n" // + + " version integer NULL,\n" // + " name varchar(255) NOT NULL,\n" // + " manual integer NULL\n" // + ");";