diff --git a/pom.xml b/pom.xml index 2f3b6ed4..0a181e3a 100644 --- a/pom.xml +++ b/pom.xml @@ -7,7 +7,7 @@ org.springframework.data spring-data-r2dbc - 1.0.0.BUILD-SNAPSHOT + 1.0.0.gh-61-SNAPSHOT Spring Data R2DBC Spring Data module for R2DBC. diff --git a/src/main/java/org/springframework/data/r2dbc/config/AbstractR2dbcConfiguration.java b/src/main/java/org/springframework/data/r2dbc/config/AbstractR2dbcConfiguration.java index 34e1b515..d32a74df 100644 --- a/src/main/java/org/springframework/data/r2dbc/config/AbstractR2dbcConfiguration.java +++ b/src/main/java/org/springframework/data/r2dbc/config/AbstractR2dbcConfiguration.java @@ -30,6 +30,7 @@ import org.springframework.data.r2dbc.function.DatabaseClient; import org.springframework.data.r2dbc.function.DefaultReactiveDataAccessStrategy; import org.springframework.data.r2dbc.function.ReactiveDataAccessStrategy; +import org.springframework.data.r2dbc.function.convert.MappingR2dbcConverter; import org.springframework.data.r2dbc.function.convert.R2dbcCustomConversions; import org.springframework.data.r2dbc.support.R2dbcExceptionTranslator; import org.springframework.data.r2dbc.support.SqlErrorCodeR2dbcExceptionTranslator; @@ -118,8 +119,8 @@ public RelationalMappingContext r2dbcMappingContext(Optional nam } /** - * Creates a {@link ReactiveDataAccessStrategy} using the configured {@link #r2dbcMappingContext(Optional, R2dbcCustomConversions)} - * RelationalMappingContext}. + * Creates a {@link ReactiveDataAccessStrategy} using the configured + * {@link #r2dbcMappingContext(Optional, R2dbcCustomConversions)} RelationalMappingContext}. * * @param mappingContext the configured {@link RelationalMappingContext}. * @param r2dbcCustomConversions customized R2DBC conversions. @@ -134,7 +135,7 @@ public ReactiveDataAccessStrategy reactiveDataAccessStrategy(RelationalMappingCo Assert.notNull(mappingContext, "MappingContext must not be null!"); - BasicRelationalConverter converter = new BasicRelationalConverter(mappingContext, r2dbcCustomConversions); + MappingR2dbcConverter converter = new MappingR2dbcConverter(mappingContext, r2dbcCustomConversions); return new DefaultReactiveDataAccessStrategy(getDialect(connectionFactory()), converter); } diff --git a/src/main/java/org/springframework/data/r2dbc/function/BindableOperation.java b/src/main/java/org/springframework/data/r2dbc/function/BindableOperation.java index 29c3c0b1..8038c3b0 100644 --- a/src/main/java/org/springframework/data/r2dbc/function/BindableOperation.java +++ b/src/main/java/org/springframework/data/r2dbc/function/BindableOperation.java @@ -45,12 +45,12 @@ public interface BindableOperation extends QueryOperation { * @see Statement#bind * @see Statement#bindNull */ - default void bind(Statement statement, SettableValue value) { + default void bind(Statement statement, String identifier, SettableValue value) { if (value.getValue() == null) { - bindNull(statement, value.getIdentifier().toString(), value.getType()); + bindNull(statement, identifier, value.getType()); } else { - bind(statement, value.getIdentifier().toString(), value.getValue()); + bind(statement, identifier, value.getValue()); } } diff --git a/src/main/java/org/springframework/data/r2dbc/function/DefaultDatabaseClient.java b/src/main/java/org/springframework/data/r2dbc/function/DefaultDatabaseClient.java index 6d0ed6f9..8da6d444 100644 --- a/src/main/java/org/springframework/data/r2dbc/function/DefaultDatabaseClient.java +++ b/src/main/java/org/springframework/data/r2dbc/function/DefaultDatabaseClient.java @@ -47,12 +47,14 @@ import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; import org.reactivestreams.Publisher; + import org.springframework.dao.DataAccessException; import org.springframework.data.domain.Pageable; import org.springframework.data.domain.Sort; import org.springframework.data.r2dbc.UncategorizedR2dbcException; import org.springframework.data.r2dbc.function.connectionfactory.ConnectionProxy; import org.springframework.data.r2dbc.function.convert.ColumnMapRowMapper; +import org.springframework.data.r2dbc.function.convert.OutboundRow; import org.springframework.data.r2dbc.function.convert.SettableValue; import org.springframework.data.r2dbc.support.R2dbcExceptionTranslator; import org.springframework.jdbc.core.SqlProvider; @@ -365,8 +367,10 @@ FetchSpec exchange(String sql, BiFunction mappingFun public ExecuteSpecSupport bind(int index, Object value) { + Assert.notNull(value, () -> String.format("Value at index %d must not be null. Use bindNull(…) instead.", index)); + Map byIndex = new LinkedHashMap<>(this.byIndex); - byIndex.put(index, new SettableValue(index, value, null)); + byIndex.put(index, new SettableValue(value, value.getClass())); return createInstance(byIndex, this.byName, this.sqlSupplier); } @@ -374,7 +378,7 @@ public ExecuteSpecSupport bind(int index, Object value) { public ExecuteSpecSupport bindNull(int index, Class type) { Map byIndex = new LinkedHashMap<>(this.byIndex); - byIndex.put(index, new SettableValue(index, null, type)); + byIndex.put(index, new SettableValue(null, type)); return createInstance(byIndex, this.byName, this.sqlSupplier); } @@ -382,9 +386,11 @@ public ExecuteSpecSupport bindNull(int index, Class type) { public ExecuteSpecSupport bind(String name, Object value) { Assert.hasText(name, "Parameter name must not be null or empty!"); + Assert.notNull(value, + () -> String.format("Value for parameter %s must not be null. Use bindNull(…) instead.", name)); Map byName = new LinkedHashMap<>(this.byName); - byName.put(name, new SettableValue(name, value, null)); + byName.put(name, new SettableValue(value, value.getClass())); return createInstance(this.byIndex, byName, this.sqlSupplier); } @@ -394,7 +400,7 @@ public ExecuteSpecSupport bindNull(String name, Class type) { Assert.hasText(name, "Parameter name must not be null or empty!"); Map byName = new LinkedHashMap<>(this.byName); - byName.put(name, new SettableValue(name, null, type)); + byName.put(name, new SettableValue(null, type)); return createInstance(this.byIndex, byName, this.sqlSupplier); } @@ -832,9 +838,11 @@ class DefaultGenericInsertSpec implements GenericInsertSpec { public GenericInsertSpec value(String field, Object value) { Assert.notNull(field, "Field must not be null!"); + Assert.notNull(value, + () -> String.format("Value for field %s must not be null. Use nullValue(…) instead.", field)); Map byName = new LinkedHashMap<>(this.byName); - byName.put(field, new SettableValue(field, value, null)); + byName.put(field, new SettableValue(value, value.getClass())); return new DefaultGenericInsertSpec<>(this.table, byName, this.mappingFunction); } @@ -845,7 +853,7 @@ public GenericInsertSpec nullValue(String field, Class type) { Assert.notNull(field, "Field must not be null!"); Map byName = new LinkedHashMap<>(this.byName); - byName.put(field, new SettableValue(field, null, type)); + byName.put(field, new SettableValue(null, type)); return new DefaultGenericInsertSpec<>(this.table, byName, this.mappingFunction); } @@ -885,7 +893,7 @@ private FetchSpec exchange(BiFunction mappingFunctio Statement statement = it.createStatement(sql).returnGeneratedValues(); - byName.forEach((k, v) -> bindableInsert.bind(statement, v)); + byName.forEach((k, v) -> bindableInsert.bind(statement, k, v)); return statement; }; @@ -989,12 +997,16 @@ public Mono rowsUpdated() { private FetchSpec exchange(Object toInsert, BiFunction mappingFunction) { - List insertValues = dataAccessStrategy.getValuesToInsert(toInsert); + OutboundRow outboundRow = dataAccessStrategy.getOutboundRow(toInsert); + Set columns = new LinkedHashSet<>(); - for (SettableValue insertValue : insertValues) { - columns.add(insertValue.getIdentifier().toString()); - } + outboundRow.forEach((k, v) -> { + + if (v.hasValue()) { + columns.add(k); + } + }); BindableOperation bindableInsert = dataAccessStrategy.insertAndReturnGeneratedKeys(table, columns); @@ -1008,9 +1020,11 @@ private FetchSpec exchange(Object toInsert, BiFunction { + if (v.hasValue()) { + bindableInsert.bind(statement, k, v); + } + }); return statement; }; diff --git a/src/main/java/org/springframework/data/r2dbc/function/DefaultReactiveDataAccessStrategy.java b/src/main/java/org/springframework/data/r2dbc/function/DefaultReactiveDataAccessStrategy.java index eefc3d97..ca482d1f 100644 --- a/src/main/java/org/springframework/data/r2dbc/function/DefaultReactiveDataAccessStrategy.java +++ b/src/main/java/org/springframework/data/r2dbc/function/DefaultReactiveDataAccessStrategy.java @@ -19,7 +19,6 @@ import io.r2dbc.spi.RowMetadata; import io.r2dbc.spi.Statement; -import java.lang.reflect.Array; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; @@ -31,13 +30,11 @@ import java.util.function.BiFunction; import java.util.function.Function; -import org.springframework.dao.InvalidDataAccessApiUsageException; import org.springframework.dao.InvalidDataAccessResourceUsageException; import org.springframework.data.convert.CustomConversions.StoreConversions; import org.springframework.data.domain.Pageable; import org.springframework.data.domain.Sort; import org.springframework.data.domain.Sort.Order; -import org.springframework.data.mapping.PersistentPropertyAccessor; import org.springframework.data.mapping.context.MappingContext; import org.springframework.data.r2dbc.dialect.ArrayColumns; import org.springframework.data.r2dbc.dialect.BindMarker; @@ -45,11 +42,12 @@ import org.springframework.data.r2dbc.dialect.BindMarkersFactory; import org.springframework.data.r2dbc.dialect.Dialect; import org.springframework.data.r2dbc.function.convert.EntityRowMapper; +import org.springframework.data.r2dbc.function.convert.MappingR2dbcConverter; +import org.springframework.data.r2dbc.function.convert.OutboundRow; +import org.springframework.data.r2dbc.function.convert.R2dbcConverter; import org.springframework.data.r2dbc.function.convert.R2dbcCustomConversions; import org.springframework.data.r2dbc.function.convert.SettableValue; import org.springframework.data.r2dbc.support.StatementRenderUtil; -import org.springframework.data.relational.core.conversion.BasicRelationalConverter; -import org.springframework.data.relational.core.conversion.RelationalConverter; import org.springframework.data.relational.core.mapping.RelationalMappingContext; import org.springframework.data.relational.core.mapping.RelationalPersistentEntity; import org.springframework.data.relational.core.mapping.RelationalPersistentProperty; @@ -58,7 +56,6 @@ import org.springframework.data.relational.core.sql.SelectBuilder.SelectFromAndOrderBy; import org.springframework.data.relational.core.sql.StatementBuilder; import org.springframework.data.relational.core.sql.Table; -import org.springframework.data.util.TypeInformation; import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.util.ClassUtils; @@ -72,7 +69,7 @@ public class DefaultReactiveDataAccessStrategy implements ReactiveDataAccessStrategy { private final Dialect dialect; - private final RelationalConverter relationalConverter; + private final R2dbcConverter converter; private final MappingContext, ? extends RelationalPersistentProperty> mappingContext; /** @@ -84,7 +81,7 @@ public DefaultReactiveDataAccessStrategy(Dialect dialect) { this(dialect, createConverter(dialect)); } - private static BasicRelationalConverter createConverter(Dialect dialect) { + private static R2dbcConverter createConverter(Dialect dialect) { Assert.notNull(dialect, "Dialect must not be null"); @@ -94,11 +91,11 @@ private static BasicRelationalConverter createConverter(Dialect dialect) { RelationalMappingContext context = new RelationalMappingContext(); context.setSimpleTypeHolder(customConversions.getSimpleTypeHolder()); - return new BasicRelationalConverter(context, customConversions); + return new MappingR2dbcConverter(context, customConversions); } - public RelationalConverter getRelationalConverter() { - return relationalConverter; + public R2dbcConverter getConverter() { + return converter; } public MappingContext, ? extends RelationalPersistentProperty> getMappingContext() { @@ -106,19 +103,19 @@ public RelationalConverter getRelationalConverter() { } /** - * Creates a new {@link DefaultReactiveDataAccessStrategy} given {@link Dialect} and {@link RelationalConverter}. + * Creates a new {@link DefaultReactiveDataAccessStrategy} given {@link Dialect} and {@link R2dbcConverter}. * * @param dialect the {@link Dialect} to use. * @param converter must not be {@literal null}. */ @SuppressWarnings("unchecked") - public DefaultReactiveDataAccessStrategy(Dialect dialect, RelationalConverter converter) { + public DefaultReactiveDataAccessStrategy(Dialect dialect, R2dbcConverter converter) { Assert.notNull(dialect, "Dialect must not be null"); Assert.notNull(converter, "RelationalConverter must not be null"); - this.relationalConverter = converter; - this.mappingContext = (MappingContext, ? extends RelationalPersistentProperty>) relationalConverter + this.converter = converter; + this.mappingContext = (MappingContext, ? extends RelationalPersistentProperty>) this.converter .getMappingContext(); this.dialect = dialect; } @@ -146,55 +143,47 @@ public List getAllColumns(Class typeToRead) { /* * (non-Javadoc) - * @see org.springframework.data.r2dbc.function.ReactiveDataAccessStrategy#getValuesToInsert(java.lang.Object) + * @see org.springframework.data.r2dbc.function.ReactiveDataAccessStrategy#getOutboundRow(java.lang.Object) */ - @Override - public List getValuesToInsert(Object object) { + public OutboundRow getOutboundRow(Object object) { + + Assert.notNull(object, "Entity object must not be null!"); - Class userClass = ClassUtils.getUserClass(object); + OutboundRow row = new OutboundRow(); - RelationalPersistentEntity entity = getRequiredPersistentEntity(userClass); - PersistentPropertyAccessor propertyAccessor = entity.getPropertyAccessor(object); + converter.write(object, row); - List values = new ArrayList<>(); + RelationalPersistentEntity entity = getRequiredPersistentEntity(ClassUtils.getUserClass(object)); for (RelationalPersistentProperty property : entity) { - Object value = getWriteValue(propertyAccessor, property); + SettableValue value = row.get(property.getColumnName()); + if (shouldConvertArrayValue(property, value)) { - if (value == null) { - continue; + SettableValue writeValue = getArrayValue(value, property); + row.put(property.getColumnName(), writeValue); } - - values.add(new SettableValue(property.getColumnName(), value, property.getType())); } - return values; + return row; } - /* - * (non-Javadoc) - * @see org.springframework.data.r2dbc.function.ReactiveDataAccessStrategy#getColumnsToUpdate(java.lang.Object) - */ - public Map getColumnsToUpdate(Object object) { - - Assert.notNull(object, "Entity object must not be null!"); - - Class userClass = ClassUtils.getUserClass(object); - RelationalPersistentEntity entity = getRequiredPersistentEntity(userClass); + private boolean shouldConvertArrayValue(RelationalPersistentProperty property, SettableValue value) { + return value != null && value.hasValue() && property.isCollectionLike(); + } - Map update = new LinkedHashMap<>(); + private SettableValue getArrayValue(SettableValue value, RelationalPersistentProperty property) { - PersistentPropertyAccessor propertyAccessor = entity.getPropertyAccessor(object); + ArrayColumns arrayColumns = dialect.getArraySupport(); - for (RelationalPersistentProperty property : entity) { + if (!arrayColumns.isSupported()) { - Object writeValue = getWriteValue(propertyAccessor, property); - - update.put(property.getColumnName(), new SettableValue(property.getColumnName(), writeValue, property.getType())); + throw new InvalidDataAccessResourceUsageException( + "Dialect " + dialect.getClass().getName() + " does not support array columns"); } - return update; + return new SettableValue(converter.getArrayValue(arrayColumns, property, value.getValue()), + property.getActualType()); } /* @@ -232,8 +221,7 @@ public Sort getMappedSort(Class typeToRead, Sort sort) { @SuppressWarnings("unchecked") @Override public BiFunction getRowMapper(Class typeToRead) { - return new EntityRowMapper((RelationalPersistentEntity) getRequiredPersistentEntity(typeToRead), - relationalConverter); + return new EntityRowMapper<>(typeToRead, converter); } /* @@ -263,48 +251,6 @@ private RelationalPersistentEntity getPersistentEntity(Class typeToRead) { return mappingContext.getPersistentEntity(typeToRead); } - @SuppressWarnings("unchecked") - private Object getWriteValue(PersistentPropertyAccessor propertyAccessor, RelationalPersistentProperty property) { - - TypeInformation type = property.getTypeInformation(); - Object value = propertyAccessor.getProperty(property); - - if (type.isCollectionLike()) { - - RelationalPersistentEntity nestedEntity = mappingContext - .getPersistentEntity(type.getRequiredActualType().getType()); - - if (nestedEntity != null) { - throw new InvalidDataAccessApiUsageException("Nested entities are not supported"); - } - - ArrayColumns arrayColumns = dialect.getArraySupport(); - - if (!arrayColumns.isSupported()) { - - throw new InvalidDataAccessResourceUsageException( - "Dialect " + dialect.getClass().getName() + " does not support array columns"); - } - - return getArrayValue(arrayColumns, property, value); - } - - return value; - } - - private Object getArrayValue(ArrayColumns arrayColumns, RelationalPersistentProperty property, Object value) { - - Class targetType = arrayColumns.getArrayType(property.getActualType()); - - if (!property.isArray() || !property.getActualType().equals(targetType)) { - - Object zeroLengthArray = Array.newInstance(targetType, 0); - return relationalConverter.getConversionService().convert(value, zeroLengthArray.getClass()); - } - - return value; - } - /* * (non-Javadoc) * @see org.springframework.data.r2dbc.function.ReactiveDataAccessStrategy#insertAndReturnGeneratedKeys(java.lang.String, java.util.Set) diff --git a/src/main/java/org/springframework/data/r2dbc/function/MapBindParameterSource.java b/src/main/java/org/springframework/data/r2dbc/function/MapBindParameterSource.java index 09de5d88..a6d79f4d 100644 --- a/src/main/java/org/springframework/data/r2dbc/function/MapBindParameterSource.java +++ b/src/main/java/org/springframework/data/r2dbc/function/MapBindParameterSource.java @@ -65,7 +65,7 @@ MapBindParameterSource addValue(String paramName, Object value) { Assert.notNull(paramName, "Parameter name must not be null!"); Assert.notNull(value, "Value must not be null!"); - this.values.put(paramName, new SettableValue(paramName, value, value.getClass())); + this.values.put(paramName, new SettableValue(value, value.getClass())); return this; } diff --git a/src/main/java/org/springframework/data/r2dbc/function/ReactiveDataAccessStrategy.java b/src/main/java/org/springframework/data/r2dbc/function/ReactiveDataAccessStrategy.java index 2d3b1b5b..aff9f794 100644 --- a/src/main/java/org/springframework/data/r2dbc/function/ReactiveDataAccessStrategy.java +++ b/src/main/java/org/springframework/data/r2dbc/function/ReactiveDataAccessStrategy.java @@ -20,13 +20,14 @@ import io.r2dbc.spi.Statement; import java.util.List; -import java.util.Map; import java.util.Set; import java.util.function.BiFunction; import org.springframework.data.domain.Pageable; import org.springframework.data.domain.Sort; import org.springframework.data.r2dbc.dialect.BindMarkersFactory; +import org.springframework.data.r2dbc.function.convert.OutboundRow; +import org.springframework.data.r2dbc.function.convert.R2dbcConverter; import org.springframework.data.r2dbc.function.convert.SettableValue; /** @@ -46,18 +47,12 @@ public interface ReactiveDataAccessStrategy { List getAllColumns(Class typeToRead); /** - * @param object - * @return {@link SettableValue} that represent an {@code INSERT} of {@code object}. - */ - List getValuesToInsert(Object object); - - /** - * Returns a {@link Map} that maps column names to a {@link SettableValue} value. + * Returns a {@link OutboundRow} that maps column names to a {@link SettableValue} value. * * @param object must not be {@literal null}. * @return */ - Map getColumnsToUpdate(Object object); + OutboundRow getOutboundRow(Object object); /** * Map the {@link Sort} object to apply field name mapping using {@link Class the type to read}. @@ -84,6 +79,13 @@ public interface ReactiveDataAccessStrategy { */ BindMarkersFactory getBindMarkersFactory(); + /** + * Returns the {@link R2dbcConverter}. + * + * @return the {@link R2dbcConverter}. + */ + R2dbcConverter getConverter(); + // ------------------------------------------------------------------------- // Methods creating SQL operations. // Subject to be moved into a SQL creation DSL. diff --git a/src/main/java/org/springframework/data/r2dbc/function/convert/EntityRowMapper.java b/src/main/java/org/springframework/data/r2dbc/function/convert/EntityRowMapper.java index f84f0bf2..dfb6e119 100644 --- a/src/main/java/org/springframework/data/r2dbc/function/convert/EntityRowMapper.java +++ b/src/main/java/org/springframework/data/r2dbc/function/convert/EntityRowMapper.java @@ -17,23 +17,9 @@ import io.r2dbc.spi.Row; import io.r2dbc.spi.RowMetadata; -import lombok.NonNull; -import lombok.RequiredArgsConstructor; -import java.sql.ResultSet; import java.util.function.BiFunction; -import org.springframework.data.mapping.MappingException; -import org.springframework.data.mapping.PersistentProperty; -import org.springframework.data.mapping.PersistentPropertyAccessor; -import org.springframework.data.mapping.PreferredConstructor.Parameter; -import org.springframework.data.mapping.model.ConvertingPropertyAccessor; -import org.springframework.data.mapping.model.ParameterValueProvider; -import org.springframework.data.relational.core.conversion.RelationalConverter; -import org.springframework.data.relational.core.mapping.RelationalPersistentEntity; -import org.springframework.data.relational.core.mapping.RelationalPersistentProperty; -import org.springframework.lang.Nullable; - /** * Maps a {@link io.r2dbc.spi.Row} to an entity of type {@code T}, including entities referenced. * @@ -42,12 +28,12 @@ */ public class EntityRowMapper implements BiFunction { - private final RelationalPersistentEntity entity; - private final RelationalConverter converter; + private final Class typeRoRead; + private final R2dbcConverter converter; - public EntityRowMapper(RelationalPersistentEntity entity, RelationalConverter converter) { + public EntityRowMapper(Class typeRoRead, R2dbcConverter converter) { - this.entity = entity; + this.typeRoRead = typeRoRead; this.converter = converter; } @@ -57,110 +43,6 @@ public EntityRowMapper(RelationalPersistentEntity entity, RelationalConverter */ @Override public T apply(Row row, RowMetadata metadata) { - - T result = createInstance(row, "", entity); - - ConvertingPropertyAccessor propertyAccessor = new ConvertingPropertyAccessor<>( - entity.getPropertyAccessor(result), converter.getConversionService()); - - for (RelationalPersistentProperty property : entity) { - - if (entity.isConstructorArgument(property)) { - continue; - } - - if (property.isMap()) { - throw new UnsupportedOperationException(); - } else { - propertyAccessor.setProperty(property, readFrom(row, property, "")); - } - } - - return result; - } - - /** - * Read a single value or a complete Entity from the {@link ResultSet} passed as an argument. - * - * @param row the {@link Row} to extract the value from. Must not be {@literal null}. - * @param property the {@link RelationalPersistentProperty} for which the value is intended. Must not be - * {@literal null}. - * @param prefix to be used for all column names accessed by this method. Must not be {@literal null}. - * @return the value read from the {@link ResultSet}. May be {@literal null}. - */ - private Object readFrom(Row row, RelationalPersistentProperty property, String prefix) { - - try { - - if (property.isEntity()) { - return readEntityFrom(row, property); - } - - Object value = row.get(prefix + property.getColumnName()); - return converter.readValue(value, property.getTypeInformation()); - - } catch (Exception o_O) { - throw new MappingException(String.format("Could not read property %s from result set!", property), o_O); - } - } - - private S readEntityFrom(Row row, PersistentProperty property) { - - String prefix = property.getName() + "_"; - - RelationalPersistentEntity entity = (RelationalPersistentEntity) converter.getMappingContext() - .getRequiredPersistentEntity(property.getActualType()); - - if (readFrom(row, entity.getRequiredIdProperty(), prefix) == null) { - return null; - } - - S instance = createInstance(row, prefix, entity); - - PersistentPropertyAccessor accessor = entity.getPropertyAccessor(instance); - ConvertingPropertyAccessor propertyAccessor = new ConvertingPropertyAccessor<>(accessor, - converter.getConversionService()); - - for (RelationalPersistentProperty p : entity) { - if (!entity.isConstructorArgument(property)) { - propertyAccessor.setProperty(p, readFrom(row, p, prefix)); - } - } - - return instance; - } - - private S createInstance(Row row, String prefix, RelationalPersistentEntity entity) { - - RowParameterValueProvider rowParameterValueProvider = new RowParameterValueProvider(row, entity, converter, prefix); - - return converter.createInstance(entity, rowParameterValueProvider::getParameterValue); - } - - @RequiredArgsConstructor - private static class RowParameterValueProvider implements ParameterValueProvider { - - private final @NonNull Row resultSet; - private final @NonNull RelationalPersistentEntity entity; - private final @NonNull RelationalConverter converter; - private final @NonNull String prefix; - - /* - * (non-Javadoc) - * @see org.springframework.data.mapping.model.ParameterValueProvider#getParameterValue(org.springframework.data.mapping.PreferredConstructor.Parameter) - */ - @Override - @Nullable - public T getParameterValue(Parameter parameter) { - - RelationalPersistentProperty property = entity.getRequiredPersistentProperty(parameter.getName()); - String column = prefix + property.getColumnName(); - - try { - return converter.getConversionService().convert(resultSet.get(column), parameter.getType().getType()); - } catch (Exception o_O) { - throw new MappingException(String.format("Couldn't read column %s from Row.", column), o_O); - } - } + return converter.read(typeRoRead, row); } } diff --git a/src/main/java/org/springframework/data/r2dbc/function/convert/MappingR2dbcConverter.java b/src/main/java/org/springframework/data/r2dbc/function/convert/MappingR2dbcConverter.java index 58f337e6..e24c2dbb 100644 --- a/src/main/java/org/springframework/data/r2dbc/function/convert/MappingR2dbcConverter.java +++ b/src/main/java/org/springframework/data/r2dbc/function/convert/MappingR2dbcConverter.java @@ -18,18 +18,32 @@ import io.r2dbc.spi.ColumnMetadata; import io.r2dbc.spi.Row; import io.r2dbc.spi.RowMetadata; +import lombok.NonNull; +import lombok.RequiredArgsConstructor; +import java.lang.reflect.Array; +import java.util.Collections; import java.util.LinkedHashMap; import java.util.Map; import java.util.function.BiFunction; import org.springframework.core.convert.ConversionService; +import org.springframework.dao.InvalidDataAccessApiUsageException; +import org.springframework.data.convert.CustomConversions; +import org.springframework.data.mapping.MappingException; +import org.springframework.data.mapping.PersistentProperty; import org.springframework.data.mapping.PersistentPropertyAccessor; +import org.springframework.data.mapping.PreferredConstructor.Parameter; import org.springframework.data.mapping.context.MappingContext; +import org.springframework.data.mapping.model.ConvertingPropertyAccessor; +import org.springframework.data.mapping.model.ParameterValueProvider; +import org.springframework.data.r2dbc.dialect.ArrayColumns; import org.springframework.data.relational.core.conversion.BasicRelationalConverter; import org.springframework.data.relational.core.conversion.RelationalConverter; import org.springframework.data.relational.core.mapping.RelationalPersistentEntity; import org.springframework.data.relational.core.mapping.RelationalPersistentProperty; +import org.springframework.data.util.TypeInformation; +import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.util.ClassUtils; @@ -38,9 +52,7 @@ * * @author Mark Paluch */ -public class MappingR2dbcConverter { - - private final RelationalConverter relationalConverter; +public class MappingR2dbcConverter extends BasicRelationalConverter implements R2dbcConverter { /** * Creates a new {@link MappingR2dbcConverter} given {@link MappingContext}. @@ -49,21 +61,159 @@ public class MappingR2dbcConverter { */ public MappingR2dbcConverter( MappingContext, ? extends RelationalPersistentProperty> context) { - this(new BasicRelationalConverter(context)); + super(context, new R2dbcCustomConversions(CustomConversions.StoreConversions.NONE, Collections.emptyList())); + } + + /** + * Creates a new {@link MappingR2dbcConverter} given {@link MappingContext} and {@link CustomConversions}. + * + * @param context must not be {@literal null}. + */ + public MappingR2dbcConverter( + MappingContext, ? extends RelationalPersistentProperty> context, + CustomConversions conversions) { + super(context, conversions); + } + + // ---------------------------------- + // Entity reading + // ---------------------------------- + + @Override + public R read(Class type, Row row) { + return read(getRequiredPersistentEntity(type), row); + } + + private R read(RelationalPersistentEntity entity, Row row) { + + R result = createInstance(row, "", entity); + + ConvertingPropertyAccessor propertyAccessor = new ConvertingPropertyAccessor<>( + entity.getPropertyAccessor(result), getConversionService()); + + for (RelationalPersistentProperty property : entity) { + + if (entity.isConstructorArgument(property)) { + continue; + } + + propertyAccessor.setProperty(property, readFrom(row, property, "")); + } + + return result; } /** - * Creates a new {@link MappingR2dbcConverter} given {@link RelationalConverter}. + * Read a single value or a complete Entity from the {@link Row} passed as an argument. * - * @param converter must not be {@literal null}. + * @param row the {@link Row} to extract the value from. Must not be {@literal null}. + * @param property the {@link RelationalPersistentProperty} for which the value is intended. Must not be + * {@literal null}. + * @param prefix to be used for all column names accessed by this method. Must not be {@literal null}. + * @return the value read from the {@link Row}. May be {@literal null}. */ - public MappingR2dbcConverter(RelationalConverter converter) { + private Object readFrom(Row row, RelationalPersistentProperty property, String prefix) { + + try { + + if (property.isEntity()) { + return readEntityFrom(row, property); + } + + Object value = row.get(prefix + property.getColumnName()); + return readValue(value, property.getTypeInformation()); + + } catch (Exception o_O) { + throw new MappingException(String.format("Could not read property %s from result set!", property), o_O); + } + } + + private S readEntityFrom(Row row, PersistentProperty property) { + + String prefix = property.getName() + "_"; + + RelationalPersistentEntity entity = (RelationalPersistentEntity) getMappingContext() + .getRequiredPersistentEntity(property.getActualType()); + + if (readFrom(row, entity.getRequiredIdProperty(), prefix) == null) { + return null; + } + + S instance = createInstance(row, prefix, entity); + + PersistentPropertyAccessor accessor = entity.getPropertyAccessor(instance); + ConvertingPropertyAccessor propertyAccessor = new ConvertingPropertyAccessor<>(accessor, getConversionService()); + + for (RelationalPersistentProperty p : entity) { + if (!entity.isConstructorArgument(property)) { + propertyAccessor.setProperty(p, readFrom(row, p, prefix)); + } + } - Assert.notNull(converter, "RelationalConverter must not be null!"); + return instance; + } + + private S createInstance(Row row, String prefix, RelationalPersistentEntity entity) { + + RowParameterValueProvider rowParameterValueProvider = new RowParameterValueProvider(row, entity, this, prefix); + + return createInstance(entity, rowParameterValueProvider::getParameterValue); + } + + // ---------------------------------- + // Entity writing + // ---------------------------------- + + @Override + public void write(Object source, OutboundRow sink) { + + Class userClass = ClassUtils.getUserClass(source); + RelationalPersistentEntity entity = getRequiredPersistentEntity(userClass); + + PersistentPropertyAccessor propertyAccessor = entity.getPropertyAccessor(source); + + for (RelationalPersistentProperty property : entity) { + + Object writeValue = getWriteValue(propertyAccessor, property); + + sink.put(property.getColumnName(), new SettableValue(writeValue, property.getType())); + } + + } + + @SuppressWarnings("unchecked") + private Object getWriteValue(PersistentPropertyAccessor propertyAccessor, RelationalPersistentProperty property) { + + TypeInformation type = property.getTypeInformation(); + Object value = propertyAccessor.getProperty(property); + + RelationalPersistentEntity nestedEntity = getMappingContext() + .getPersistentEntity(type.getRequiredActualType().getType()); - this.relationalConverter = converter; + if (nestedEntity != null) { + throw new InvalidDataAccessApiUsageException("Nested entities are not supported"); + } + + return value; + } + + public Object getArrayValue(ArrayColumns arrayColumns, RelationalPersistentProperty property, Object value) { + + Class targetType = arrayColumns.getArrayType(property.getActualType()); + + if (!property.isArray() || !property.getActualType().equals(targetType)) { + + Object zeroLengthArray = Array.newInstance(targetType, 0); + return getConversionService().convert(value, zeroLengthArray.getClass()); + } + + return value; } + // ---------------------------------- + // Id handling + // ---------------------------------- + /** * Returns a {@link java.util.function.Function} that populates the id property of the {@code object} from a * {@link Row}. @@ -113,7 +263,7 @@ private boolean potentiallySetId(Row row, RowMetadata metadata, PersistentProper if (generatedIdValue != null) { - ConversionService conversionService = relationalConverter.getConversionService(); + ConversionService conversionService = getConversionService(); propertyAccessor.setProperty(idProperty, conversionService.convert(generatedIdValue, idProperty.getType())); return true; } @@ -121,6 +271,11 @@ private boolean potentiallySetId(Row row, RowMetadata metadata, PersistentProper return false; } + @SuppressWarnings("unchecked") + private RelationalPersistentEntity getRequiredPersistentEntity(Class type) { + return (RelationalPersistentEntity) getMappingContext().getRequiredPersistentEntity(type); + } + private static Map createMetadataMap(RowMetadata metadata) { Map columns = new LinkedHashMap<>(); @@ -132,7 +287,30 @@ private static Map createMetadataMap(RowMetadata metadat return columns; } - public MappingContext, ? extends RelationalPersistentProperty> getMappingContext() { - return relationalConverter.getMappingContext(); + @RequiredArgsConstructor + private static class RowParameterValueProvider implements ParameterValueProvider { + + private final @NonNull Row resultSet; + private final @NonNull RelationalPersistentEntity entity; + private final @NonNull RelationalConverter converter; + private final @NonNull String prefix; + + /* + * (non-Javadoc) + * @see org.springframework.data.mapping.model.ParameterValueProvider#getParameterValue(org.springframework.data.mapping.PreferredConstructor.Parameter) + */ + @Override + @Nullable + public T getParameterValue(Parameter parameter) { + + RelationalPersistentProperty property = entity.getRequiredPersistentProperty(parameter.getName()); + String column = prefix + property.getColumnName(); + + try { + return converter.getConversionService().convert(resultSet.get(column), parameter.getType().getType()); + } catch (Exception o_O) { + throw new MappingException(String.format("Couldn't read column %s from Row.", column), o_O); + } + } } } diff --git a/src/main/java/org/springframework/data/r2dbc/function/convert/OutboundRow.java b/src/main/java/org/springframework/data/r2dbc/function/convert/OutboundRow.java new file mode 100644 index 00000000..1ea2cc5e --- /dev/null +++ b/src/main/java/org/springframework/data/r2dbc/function/convert/OutboundRow.java @@ -0,0 +1,235 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.r2dbc.function.convert; + +import io.r2dbc.spi.Row; + +import java.util.Collection; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.Set; +import java.util.function.BiConsumer; + +import org.springframework.util.Assert; + +/** + * Representation of a {@link Row} to be written through a {@code INSERT} or {@code UPDATE} statement. + * + * @author Mark Paluch + * @see SettableValue + */ +public class OutboundRow implements Map { + + private final Map rowAsMap; + + /** + * Creates an empty {@link OutboundRow} instance. + */ + public OutboundRow() { + rowAsMap = new LinkedHashMap<>(); + } + + /** + * Creates a new {@link OutboundRow} from a {@link Map}. + * + * @param map the map used to initialize the {@link OutboundRow}. + */ + public OutboundRow(Map map) { + + Assert.notNull(map, "Map must not be null"); + + rowAsMap = new LinkedHashMap<>(map); + } + + /** + * Create a {@link OutboundRow} instance initialized with the given key/value pair. + * + * @param key key. + * @param value value. + */ + public OutboundRow(String key, SettableValue value) { + rowAsMap = new LinkedHashMap<>(); + rowAsMap.put(key, value); + } + + /** + * Put the given key/value pair into this {@link OutboundRow} and return this. Useful for chaining puts in a single + * expression: + * + *
+	 * row.append("a", 1).append("b", 2)}
+	 * 
+ * + * @param key key. + * @param value value. + * @return this + */ + public OutboundRow append(String key, SettableValue value) { + rowAsMap.put(key, value); + return this; + } + + /* + * (non-Javadoc) + * @see java.util.Map#size() + */ + @Override + public int size() { + return rowAsMap.size(); + } + + /* + * (non-Javadoc) + * @see java.util.Map#isEmpty() + */ + @Override + public boolean isEmpty() { + return rowAsMap.isEmpty(); + } + + /* + * (non-Javadoc) + * @see java.util.Map#containsKey(java.lang.Object) + */ + @Override + public boolean containsKey(Object key) { + return rowAsMap.containsKey(key); + } + + /* + * (non-Javadoc) + * @see java.util.Map#containsValue(java.lang.Object) + */ + @Override + public boolean containsValue(Object value) { + return rowAsMap.containsValue(value); + } + + /* + * (non-Javadoc) + * @see java.util.Map#get(java.lang.Object) + */ + @Override + public SettableValue get(Object key) { + return rowAsMap.get(key); + } + + /* + * (non-Javadoc) + * @see java.util.Map#put(java.lang.Object, java.lang.Object) + */ + @Override + public SettableValue put(String key, SettableValue value) { + return rowAsMap.put(key, value); + } + + /* + * (non-Javadoc) + * @see java.util.Map#remove(java.lang.Object) + */ + @Override + public SettableValue remove(Object key) { + return rowAsMap.remove(key); + } + + /* + * (non-Javadoc) + * @see java.util.Map#putAll(java.util.Map) + */ + @Override + public void putAll(Map m) { + rowAsMap.putAll(m); + } + + /* + * (non-Javadoc) + * @see java.util.Map#clear() + */ + @Override + public void clear() { + rowAsMap.clear(); + } + + /* + * (non-Javadoc) + * @see java.util.Map#keySet() + */ + @Override + public Set keySet() { + return rowAsMap.keySet(); + } + + /* + * (non-Javadoc) + * @see java.util.Map#values() + */ + @Override + public Collection values() { + return rowAsMap.values(); + } + + /* + * (non-Javadoc) + * @see java.util.Map#entrySet() + */ + @Override + public Set> entrySet() { + return rowAsMap.entrySet(); + } + + /* + * (non-Javadoc) + * @see java.lang.Object#equals(java.lang.Object) + */ + @Override + public boolean equals(final Object o) { + + if (this == o) { + return true; + } + + if (o == null || getClass() != o.getClass()) { + return false; + } + + OutboundRow row = (OutboundRow) o; + + return rowAsMap.equals(row.rowAsMap); + } + + /* + * (non-Javadoc) + * @see java.lang.Object#hashCode() + */ + @Override + public int hashCode() { + return rowAsMap.hashCode(); + } + + /* + * (non-Javadoc) + * @see java.lang.Object#toString() + */ + @Override + public String toString() { + return "OutboundRow[" + rowAsMap + "]"; + } + + @Override + public void forEach(BiConsumer action) { + rowAsMap.forEach(action); + } +} diff --git a/src/main/java/org/springframework/data/r2dbc/function/convert/R2dbcConverter.java b/src/main/java/org/springframework/data/r2dbc/function/convert/R2dbcConverter.java new file mode 100644 index 00000000..6cf89316 --- /dev/null +++ b/src/main/java/org/springframework/data/r2dbc/function/convert/R2dbcConverter.java @@ -0,0 +1,73 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.r2dbc.function.convert; + +import io.r2dbc.spi.Row; +import io.r2dbc.spi.RowMetadata; + +import java.util.function.BiFunction; + +import org.springframework.core.convert.ConversionService; +import org.springframework.data.convert.EntityReader; +import org.springframework.data.convert.EntityWriter; +import org.springframework.data.mapping.context.MappingContext; +import org.springframework.data.r2dbc.dialect.ArrayColumns; +import org.springframework.data.relational.core.conversion.RelationalConverter; +import org.springframework.data.relational.core.mapping.RelationalPersistentEntity; +import org.springframework.data.relational.core.mapping.RelationalPersistentProperty; + +/** + * Central R2DBC specific converter interface. + * + * @author Mark Paluch + * @see EntityReader + */ +public interface R2dbcConverter + extends EntityReader, EntityWriter, RelationalConverter { + + /** + * Returns the underlying {@link MappingContext} used by the converter. + * + * @return never {@literal null} + */ + MappingContext, ? extends RelationalPersistentProperty> getMappingContext(); + + /** + * Returns the underlying {@link ConversionService} used by the converter. + * + * @return never {@literal null}. + */ + ConversionService getConversionService(); + + /** + * Convert a {@code value} into an array representation according to {@link ArrayColumns}. + * + * @param arrayColumns dialect-specific array handling configuration. + * @param property + * @param value + * @return + */ + Object getArrayValue(ArrayColumns arrayColumns, RelationalPersistentProperty property, Object value); + + /** + * Returns a {@link java.util.function.Function} that populates the id property of the {@code object} from a + * {@link Row}. + * + * @param object must not be {@literal null}. + * @return + */ + BiFunction populateIdIfNecessary(T object); +} diff --git a/src/main/java/org/springframework/data/r2dbc/function/convert/SettableValue.java b/src/main/java/org/springframework/data/r2dbc/function/convert/SettableValue.java index fe4e8118..edce1794 100644 --- a/src/main/java/org/springframework/data/r2dbc/function/convert/SettableValue.java +++ b/src/main/java/org/springframework/data/r2dbc/function/convert/SettableValue.java @@ -15,57 +15,87 @@ */ package org.springframework.data.r2dbc.function.convert; +import java.util.Objects; + import org.springframework.lang.Nullable; +import org.springframework.util.Assert; /** * A database value that can be set in a statement. * * @author Mark Paluch + * @see OutboundRow */ public class SettableValue { - private final Object identifier; private final @Nullable Object value; private final Class type; /** - * Create a {@link SettableValue} using an integer index. + * Create a {@link SettableValue}. * - * @param index * @param value * @param type */ - public SettableValue(int index, @Nullable Object value, Class type) { + public SettableValue(@Nullable Object value, Class type) { + + Assert.notNull(type, "Type must not be null"); - this.identifier = index; this.value = value; this.type = type; } /** - * Create a {@link SettableValue} using a {@link String} identifier. + * Returns the column value. Can be {@literal null}. * - * @param identifier - * @param value - * @param type + * @return the column value. Can be {@literal null}. + * @see #hasValue() */ - public SettableValue(String identifier, @Nullable Object value, Class type) { - - this.identifier = identifier; - this.value = value; - this.type = type; - } - - public Object getIdentifier() { - return identifier; - } - @Nullable public Object getValue() { return value; } + /** + * Returns the column value type. Must be also present if the {@code value} is {@literal null}. + * + * @return the column value type + */ public Class getType() { return type; } + + /** + * Returns whether this {@link SettableValue} has a value. + * + * @return whether this {@link SettableValue} has a value. {@literal false} if {@link #getValue()} is {@literal null}. + */ + public boolean hasValue() { + return value != null; + } + + @Override + public boolean equals(Object o) { + if (this == o) + return true; + if (!(o instanceof SettableValue)) + return false; + SettableValue value1 = (SettableValue) o; + return Objects.equals(value, value1.value) && Objects.equals(type, value1.type); + } + + @Override + public int hashCode() { + return Objects.hash(value, type); + } + + @Override + public String toString() { + final StringBuffer sb = new StringBuffer(); + sb.append(getClass().getSimpleName()); + sb.append(" [value=").append(value); + sb.append(", type=").append(type); + sb.append(']'); + return sb.toString(); + } } diff --git a/src/main/java/org/springframework/data/r2dbc/repository/query/AbstractR2dbcQuery.java b/src/main/java/org/springframework/data/r2dbc/repository/query/AbstractR2dbcQuery.java index 5cb02f92..dc15e384 100644 --- a/src/main/java/org/springframework/data/r2dbc/repository/query/AbstractR2dbcQuery.java +++ b/src/main/java/org/springframework/data/r2dbc/repository/query/AbstractR2dbcQuery.java @@ -19,12 +19,13 @@ import reactor.core.publisher.Mono; import org.reactivestreams.Publisher; + import org.springframework.core.convert.converter.Converter; import org.springframework.data.convert.EntityInstantiators; import org.springframework.data.r2dbc.function.DatabaseClient; import org.springframework.data.r2dbc.function.DatabaseClient.GenericExecuteSpec; import org.springframework.data.r2dbc.function.FetchSpec; -import org.springframework.data.r2dbc.function.convert.MappingR2dbcConverter; +import org.springframework.data.r2dbc.function.convert.R2dbcConverter; import org.springframework.data.r2dbc.repository.query.R2dbcQueryExecution.ResultProcessingConverter; import org.springframework.data.r2dbc.repository.query.R2dbcQueryExecution.ResultProcessingExecution; import org.springframework.data.relational.repository.query.RelationalParameterAccessor; @@ -44,7 +45,7 @@ public abstract class AbstractR2dbcQuery implements RepositoryQuery { private final R2dbcQueryMethod method; private final DatabaseClient databaseClient; - private final MappingR2dbcConverter converter; + private final R2dbcConverter converter; private final EntityInstantiators instantiators; /** @@ -54,11 +55,11 @@ public abstract class AbstractR2dbcQuery implements RepositoryQuery { * @param databaseClient must not be {@literal null}. * @param converter must not be {@literal null}. */ - public AbstractR2dbcQuery(R2dbcQueryMethod method, DatabaseClient databaseClient, MappingR2dbcConverter converter) { + public AbstractR2dbcQuery(R2dbcQueryMethod method, DatabaseClient databaseClient, R2dbcConverter converter) { Assert.notNull(method, "R2dbcQueryMethod must not be null!"); Assert.notNull(databaseClient, "DatabaseClient must not be null!"); - Assert.notNull(converter, "MappingR2dbcConverter must not be null!"); + Assert.notNull(converter, "R2dbcConverter must not be null!"); this.method = method; this.databaseClient = databaseClient; diff --git a/src/main/java/org/springframework/data/r2dbc/repository/query/StringBasedR2dbcQuery.java b/src/main/java/org/springframework/data/r2dbc/repository/query/StringBasedR2dbcQuery.java index 11b21092..bfa034d3 100644 --- a/src/main/java/org/springframework/data/r2dbc/repository/query/StringBasedR2dbcQuery.java +++ b/src/main/java/org/springframework/data/r2dbc/repository/query/StringBasedR2dbcQuery.java @@ -17,7 +17,7 @@ import org.springframework.data.r2dbc.function.DatabaseClient; import org.springframework.data.r2dbc.function.DatabaseClient.BindSpec; -import org.springframework.data.r2dbc.function.convert.MappingR2dbcConverter; +import org.springframework.data.r2dbc.function.convert.R2dbcConverter; import org.springframework.data.relational.repository.query.RelationalParameterAccessor; import org.springframework.data.repository.query.Parameter; import org.springframework.data.repository.query.Parameters; @@ -46,9 +46,8 @@ public class StringBasedR2dbcQuery extends AbstractR2dbcQuery { * @param expressionParser must not be {@literal null}. * @param evaluationContextProvider must not be {@literal null}. */ - public StringBasedR2dbcQuery(R2dbcQueryMethod queryMethod, DatabaseClient databaseClient, - MappingR2dbcConverter converter, SpelExpressionParser expressionParser, - QueryMethodEvaluationContextProvider evaluationContextProvider) { + public StringBasedR2dbcQuery(R2dbcQueryMethod queryMethod, DatabaseClient databaseClient, R2dbcConverter converter, + SpelExpressionParser expressionParser, QueryMethodEvaluationContextProvider evaluationContextProvider) { this(queryMethod.getRequiredAnnotatedQuery(), queryMethod, databaseClient, converter, expressionParser, evaluationContextProvider); @@ -65,7 +64,7 @@ public StringBasedR2dbcQuery(R2dbcQueryMethod queryMethod, DatabaseClient databa * @param evaluationContextProvider must not be {@literal null}. */ public StringBasedR2dbcQuery(String query, R2dbcQueryMethod method, DatabaseClient databaseClient, - MappingR2dbcConverter converter, SpelExpressionParser expressionParser, + R2dbcConverter converter, SpelExpressionParser expressionParser, QueryMethodEvaluationContextProvider evaluationContextProvider) { super(method, databaseClient, converter); diff --git a/src/main/java/org/springframework/data/r2dbc/repository/support/R2dbcRepositoryFactory.java b/src/main/java/org/springframework/data/r2dbc/repository/support/R2dbcRepositoryFactory.java index 78e80d12..0c82c087 100644 --- a/src/main/java/org/springframework/data/r2dbc/repository/support/R2dbcRepositoryFactory.java +++ b/src/main/java/org/springframework/data/r2dbc/repository/support/R2dbcRepositoryFactory.java @@ -25,11 +25,10 @@ import org.springframework.data.projection.ProjectionFactory; import org.springframework.data.r2dbc.function.DatabaseClient; import org.springframework.data.r2dbc.function.ReactiveDataAccessStrategy; -import org.springframework.data.r2dbc.function.convert.MappingR2dbcConverter; +import org.springframework.data.r2dbc.function.convert.R2dbcConverter; import org.springframework.data.r2dbc.repository.R2dbcRepository; import org.springframework.data.r2dbc.repository.query.R2dbcQueryMethod; import org.springframework.data.r2dbc.repository.query.StringBasedR2dbcQuery; -import org.springframework.data.relational.core.conversion.BasicRelationalConverter; import org.springframework.data.relational.core.mapping.RelationalPersistentEntity; import org.springframework.data.relational.core.mapping.RelationalPersistentProperty; import org.springframework.data.relational.repository.query.RelationalEntityInformation; @@ -56,28 +55,25 @@ public class R2dbcRepositoryFactory extends ReactiveRepositoryFactorySupport { private static final SpelExpressionParser EXPRESSION_PARSER = new SpelExpressionParser(); private final DatabaseClient databaseClient; - private final MappingContext, RelationalPersistentProperty> mappingContext; - private final MappingR2dbcConverter converter; + private final MappingContext, ? extends RelationalPersistentProperty> mappingContext; + private final R2dbcConverter converter; private final ReactiveDataAccessStrategy dataAccessStrategy; /** * Creates a new {@link R2dbcRepositoryFactory} given {@link DatabaseClient} and {@link MappingContext}. * * @param databaseClient must not be {@literal null}. - * @param mappingContext must not be {@literal null}. + * @param dataAccessStrategy must not be {@literal null}. */ - public R2dbcRepositoryFactory(DatabaseClient databaseClient, - MappingContext, RelationalPersistentProperty> mappingContext, - ReactiveDataAccessStrategy dataAccessStrategy) { + public R2dbcRepositoryFactory(DatabaseClient databaseClient, ReactiveDataAccessStrategy dataAccessStrategy) { Assert.notNull(databaseClient, "DatabaseClient must not be null!"); - Assert.notNull(mappingContext, "MappingContext must not be null!"); Assert.notNull(dataAccessStrategy, "ReactiveDataAccessStrategy must not be null!"); this.databaseClient = databaseClient; - this.mappingContext = mappingContext; + this.converter = dataAccessStrategy.getConverter(); + this.mappingContext = this.converter.getMappingContext(); this.dataAccessStrategy = dataAccessStrategy; - this.converter = new MappingR2dbcConverter(new BasicRelationalConverter(mappingContext)); } /* @@ -140,7 +136,7 @@ private static class R2dbcQueryLookupStrategy implements QueryLookupStrategy { private final DatabaseClient databaseClient; private final QueryMethodEvaluationContextProvider evaluationContextProvider; - private final MappingR2dbcConverter converter; + private final R2dbcConverter converter; /* * (non-Javadoc) @@ -163,7 +159,6 @@ public RepositoryQuery resolveQuery(Method method, RepositoryMetadata metadata, } throw new UnsupportedOperationException("Query derivation not yet supported!"); - } } } diff --git a/src/main/java/org/springframework/data/r2dbc/repository/support/R2dbcRepositoryFactoryBean.java b/src/main/java/org/springframework/data/r2dbc/repository/support/R2dbcRepositoryFactoryBean.java index 766787ca..506f7ee5 100644 --- a/src/main/java/org/springframework/data/r2dbc/repository/support/R2dbcRepositoryFactoryBean.java +++ b/src/main/java/org/springframework/data/r2dbc/repository/support/R2dbcRepositoryFactoryBean.java @@ -20,9 +20,6 @@ import org.springframework.data.mapping.context.MappingContext; import org.springframework.data.r2dbc.function.DatabaseClient; import org.springframework.data.r2dbc.function.ReactiveDataAccessStrategy; -import org.springframework.data.relational.core.mapping.RelationalMappingContext; -import org.springframework.data.relational.core.mapping.RelationalPersistentEntity; -import org.springframework.data.relational.core.mapping.RelationalPersistentProperty; import org.springframework.data.repository.Repository; import org.springframework.data.repository.core.support.RepositoryFactoryBeanSupport; import org.springframework.data.repository.core.support.RepositoryFactorySupport; @@ -41,7 +38,6 @@ public class R2dbcRepositoryFactoryBean, S, ID exten extends RepositoryFactoryBeanSupport { private @Nullable DatabaseClient client; - private @Nullable MappingContext, RelationalPersistentProperty> mappingContext; private @Nullable ReactiveDataAccessStrategy dataAccessStrategy; private boolean mappingContextConfigured = false; @@ -69,14 +65,11 @@ public void setDatabaseClient(@Nullable DatabaseClient client) { * @see org.springframework.data.repository.core.support.RepositoryFactoryBeanSupport#setMappingContext(org.springframework.data.mapping.context.MappingContext) */ @Override - @SuppressWarnings("unchecked") protected void setMappingContext(@Nullable MappingContext mappingContext) { super.setMappingContext(mappingContext); if (mappingContext != null) { - - this.mappingContext = (MappingContext, RelationalPersistentProperty>) mappingContext; this.mappingContextConfigured = true; } } @@ -91,19 +84,19 @@ public void setDataAccessStrategy(@Nullable ReactiveDataAccessStrategy dataAcces */ @Override protected final RepositoryFactorySupport createRepositoryFactory() { - return getFactoryInstance(client, this.mappingContext); + return getFactoryInstance(client, dataAccessStrategy); } /** * Creates and initializes a {@link RepositoryFactorySupport} instance. * * @param client must not be {@literal null}. - * @param mappingContext must not be {@literal null}. + * @param dataAccessStrategy must not be {@literal null}. * @return new instance of {@link RepositoryFactorySupport}. */ protected RepositoryFactorySupport getFactoryInstance(DatabaseClient client, - MappingContext, RelationalPersistentProperty> mappingContext) { - return new R2dbcRepositoryFactory(client, mappingContext, dataAccessStrategy); + ReactiveDataAccessStrategy dataAccessStrategy) { + return new R2dbcRepositoryFactory(client, dataAccessStrategy); } /* @@ -117,7 +110,7 @@ public void afterPropertiesSet() { Assert.state(dataAccessStrategy != null, "ReactiveDataAccessStrategy must not be null!"); if (!mappingContextConfigured) { - setMappingContext(new RelationalMappingContext()); + setMappingContext(dataAccessStrategy.getConverter().getMappingContext()); } super.afterPropertiesSet(); diff --git a/src/main/java/org/springframework/data/r2dbc/repository/support/SimpleR2dbcRepository.java b/src/main/java/org/springframework/data/r2dbc/repository/support/SimpleR2dbcRepository.java index c6b99f5c..4b28127c 100644 --- a/src/main/java/org/springframework/data/r2dbc/repository/support/SimpleR2dbcRepository.java +++ b/src/main/java/org/springframework/data/r2dbc/repository/support/SimpleR2dbcRepository.java @@ -15,7 +15,6 @@ */ package org.springframework.data.r2dbc.repository.support; -import io.r2dbc.spi.Statement; import lombok.NonNull; import lombok.RequiredArgsConstructor; import reactor.core.publisher.Flux; @@ -26,18 +25,16 @@ import java.util.List; import java.util.Map; import java.util.Set; -import java.util.function.BiConsumer; import org.reactivestreams.Publisher; import org.springframework.data.r2dbc.dialect.BindMarker; import org.springframework.data.r2dbc.dialect.BindMarkers; import org.springframework.data.r2dbc.function.BindIdOperation; -import org.springframework.data.r2dbc.function.BindableOperation; import org.springframework.data.r2dbc.function.DatabaseClient; import org.springframework.data.r2dbc.function.DatabaseClient.GenericExecuteSpec; import org.springframework.data.r2dbc.function.ReactiveDataAccessStrategy; -import org.springframework.data.r2dbc.function.convert.MappingR2dbcConverter; +import org.springframework.data.r2dbc.function.convert.R2dbcConverter; import org.springframework.data.r2dbc.function.convert.SettableValue; import org.springframework.data.relational.core.sql.Conditions; import org.springframework.data.relational.core.sql.Expression; @@ -61,7 +58,7 @@ public class SimpleR2dbcRepository implements ReactiveCrudRepository entity; private final @NonNull DatabaseClient databaseClient; - private final @NonNull MappingR2dbcConverter converter; + private final @NonNull R2dbcConverter converter; private final @NonNull ReactiveDataAccessStrategy accessStrategy; /* (non-Javadoc) @@ -82,7 +79,7 @@ public Mono save(S objectToSave) { } Object id = entity.getRequiredId(objectToSave); - Map columns = accessStrategy.getColumnsToUpdate(objectToSave); + Map columns = accessStrategy.getOutboundRow(objectToSave); columns.remove(getIdColumnName()); // do not update the Id column. String idColumnName = getIdColumnName(); BindIdOperation update = accessStrategy.updateById(entity.getTableName(), columns.keySet(), idColumnName); @@ -90,7 +87,10 @@ public Mono save(S objectToSave) { GenericExecuteSpec exec = databaseClient.execute().sql(update); BindSpecAdapter wrapper = BindSpecAdapter.create(exec); - columns.forEach(bind(update, wrapper)); + columns.forEach((k, v) -> { + update.bind(wrapper, k, v); + + }); update.bindId(wrapper, id); return wrapper.getBoundOperation().as(entity.getJavaType()) // @@ -236,11 +236,8 @@ public Flux findAllById(Publisher idPublisher) { } Table table = Table.create(entity.getTableName()); - Select select = StatementBuilder - .select(table.columns(columns)) - .from(table) - .where(Conditions.in(table.column(idColumnName), markers)) - .build(); + Select select = StatementBuilder.select(table.columns(columns)).from(table) + .where(Conditions.in(table.column(idColumnName), markers)).build(); GenericExecuteSpec executeSpec = databaseClient.execute().sql(SqlRenderer.toString(select)); @@ -368,9 +365,4 @@ private String getIdColumnName() { .getRequiredIdProperty() // .getColumnName(); } - - private BiConsumer bind(BindableOperation operation, Statement statement) { - - return (k, v) -> operation.bind(statement, v); - } } diff --git a/src/test/java/org/springframework/data/r2dbc/function/DefaultReactiveDataAccessStrategyUnitTests.java b/src/test/java/org/springframework/data/r2dbc/function/DefaultReactiveDataAccessStrategyUnitTests.java index 2f1e51ec..59690d5b 100644 --- a/src/test/java/org/springframework/data/r2dbc/function/DefaultReactiveDataAccessStrategyUnitTests.java +++ b/src/test/java/org/springframework/data/r2dbc/function/DefaultReactiveDataAccessStrategyUnitTests.java @@ -12,6 +12,7 @@ import java.util.Map; import org.junit.Test; + import org.springframework.data.r2dbc.dialect.PostgresDialect; import org.springframework.data.r2dbc.function.convert.SettableValue; @@ -67,7 +68,7 @@ public void shouldRenderDeleteByIdInQuery() { public void shouldUpdateArray() { Map columnsToUpdate = strategy - .getColumnsToUpdate(new WithCollectionTypes(new String[] { "one", "two" }, null)); + .getOutboundRow(new WithCollectionTypes(new String[] { "one", "two" }, null)); Object stringArray = columnsToUpdate.get("string_array").getValue(); @@ -79,7 +80,7 @@ public void shouldUpdateArray() { public void shouldConvertListToArray() { Map columnsToUpdate = strategy - .getColumnsToUpdate(new WithCollectionTypes(null, Arrays.asList("one", "two"))); + .getOutboundRow(new WithCollectionTypes(null, Arrays.asList("one", "two"))); Object stringArray = columnsToUpdate.get("string_collection").getValue(); diff --git a/src/test/java/org/springframework/data/r2dbc/function/convert/EntityRowMapperUnitTests.java b/src/test/java/org/springframework/data/r2dbc/function/convert/EntityRowMapperUnitTests.java index 91dceff2..307d710e 100644 --- a/src/test/java/org/springframework/data/r2dbc/function/convert/EntityRowMapperUnitTests.java +++ b/src/test/java/org/springframework/data/r2dbc/function/convert/EntityRowMapperUnitTests.java @@ -13,9 +13,9 @@ import org.junit.Test; import org.junit.runner.RunWith; import org.mockito.junit.MockitoJUnitRunner; + import org.springframework.data.r2dbc.dialect.PostgresDialect; import org.springframework.data.r2dbc.function.DefaultReactiveDataAccessStrategy; -import org.springframework.data.relational.core.mapping.RelationalPersistentEntity; /** * Unit tests for {@link EntityRowMapper}. @@ -101,11 +101,8 @@ public void shouldConvertArrayToBoxedArray() { assertThat(result.boxedIntegers).contains(3, 11); } - @SuppressWarnings("unchecked") private EntityRowMapper getRowMapper(Class type) { - RelationalPersistentEntity entity = (RelationalPersistentEntity) strategy.getMappingContext() - .getRequiredPersistentEntity(type); - return new EntityRowMapper<>(entity, strategy.getRelationalConverter()); + return new EntityRowMapper<>(type, strategy.getConverter()); } static class SimpleEntity { diff --git a/src/test/java/org/springframework/data/r2dbc/function/convert/MappingR2dbcConverterUnitTests.java b/src/test/java/org/springframework/data/r2dbc/function/convert/MappingR2dbcConverterUnitTests.java new file mode 100644 index 00000000..dbc9e4e7 --- /dev/null +++ b/src/test/java/org/springframework/data/r2dbc/function/convert/MappingR2dbcConverterUnitTests.java @@ -0,0 +1,53 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.r2dbc.function.convert; + +import static org.assertj.core.api.Assertions.*; + +import lombok.AllArgsConstructor; + +import org.junit.Test; + +import org.springframework.data.annotation.Id; +import org.springframework.data.relational.core.mapping.RelationalMappingContext; + +/** + * Unit tests for {@link MappingR2dbcConverter}. + * + * @author Mark Paluch + */ +public class MappingR2dbcConverterUnitTests { + + MappingR2dbcConverter converter = new MappingR2dbcConverter(new RelationalMappingContext()); + + @Test // gh-61 + public void shouldIncludeAllPropertiesInOutboundRow() { + + OutboundRow row = new OutboundRow(); + + converter.write(new Person("id", "Walter", "White"), row); + + assertThat(row).containsEntry("id", new SettableValue("id", String.class)); + assertThat(row).containsEntry("firstname", new SettableValue("Walter", String.class)); + assertThat(row).containsEntry("lastname", new SettableValue("White", String.class)); + } + + @AllArgsConstructor + static class Person { + @Id String id; + String firstname, lastname; + } +} diff --git a/src/test/java/org/springframework/data/r2dbc/repository/AbstractR2dbcRepositoryIntegrationTests.java b/src/test/java/org/springframework/data/r2dbc/repository/AbstractR2dbcRepositoryIntegrationTests.java index ded79a63..7f92d93f 100644 --- a/src/test/java/org/springframework/data/r2dbc/repository/AbstractR2dbcRepositoryIntegrationTests.java +++ b/src/test/java/org/springframework/data/r2dbc/repository/AbstractR2dbcRepositoryIntegrationTests.java @@ -34,15 +34,16 @@ import org.junit.Before; import org.junit.Test; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.dao.DataAccessException; import org.springframework.data.annotation.Id; import org.springframework.data.r2dbc.dialect.Database; import org.springframework.data.r2dbc.function.DefaultReactiveDataAccessStrategy; import org.springframework.data.r2dbc.function.TransactionalDatabaseClient; +import org.springframework.data.r2dbc.function.convert.MappingR2dbcConverter; import org.springframework.data.r2dbc.repository.support.R2dbcRepositoryFactory; import org.springframework.data.r2dbc.testing.R2dbcIntegrationTestSupport; -import org.springframework.data.relational.core.conversion.BasicRelationalConverter; import org.springframework.data.relational.core.mapping.RelationalMappingContext; import org.springframework.data.relational.core.mapping.Table; import org.springframework.data.repository.NoRepositoryBean; @@ -161,11 +162,11 @@ public void shouldInsertItemsTransactional() { Database database = Database.findDatabase(createConnectionFactory()).get(); DefaultReactiveDataAccessStrategy dataAccessStrategy = new DefaultReactiveDataAccessStrategy( - database.defaultDialect(), new BasicRelationalConverter(mappingContext)); + database.defaultDialect(), new MappingR2dbcConverter(mappingContext)); TransactionalDatabaseClient client = TransactionalDatabaseClient.builder() .connectionFactory(createConnectionFactory()).dataAccessStrategy(dataAccessStrategy).build(); - LegoSetRepository transactionalRepository = new R2dbcRepositoryFactory(client, mappingContext, dataAccessStrategy) + LegoSetRepository transactionalRepository = new R2dbcRepositoryFactory(client, dataAccessStrategy) .getRepository(getRepositoryInterfaceType()); LegoSet legoSet1 = new LegoSet(null, "SCHAUFELRADBAGGER", 12); diff --git a/src/test/java/org/springframework/data/r2dbc/repository/config/R2dbcRepositoriesRegistrarTests.java b/src/test/java/org/springframework/data/r2dbc/repository/config/R2dbcRepositoriesRegistrarTests.java index d40e212b..db97506e 100644 --- a/src/test/java/org/springframework/data/r2dbc/repository/config/R2dbcRepositoriesRegistrarTests.java +++ b/src/test/java/org/springframework/data/r2dbc/repository/config/R2dbcRepositoriesRegistrarTests.java @@ -19,11 +19,14 @@ import org.junit.Test; import org.junit.runner.RunWith; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.ApplicationContext; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; +import org.springframework.data.r2dbc.dialect.PostgresDialect; import org.springframework.data.r2dbc.function.DatabaseClient; +import org.springframework.data.r2dbc.function.DefaultReactiveDataAccessStrategy; import org.springframework.data.r2dbc.function.ReactiveDataAccessStrategy; import org.springframework.test.context.ContextConfiguration; import org.springframework.test.context.junit4.SpringRunner; @@ -48,7 +51,7 @@ public DatabaseClient databaseClient() { @Bean public ReactiveDataAccessStrategy reactiveDataAccessStrategy() { - return mock(ReactiveDataAccessStrategy.class); + return new DefaultReactiveDataAccessStrategy(new PostgresDialect()); } } diff --git a/src/test/java/org/springframework/data/r2dbc/repository/query/StringBasedR2dbcQueryUnitTests.java b/src/test/java/org/springframework/data/r2dbc/repository/query/StringBasedR2dbcQueryUnitTests.java index 9af26b2c..bd2aa716 100644 --- a/src/test/java/org/springframework/data/r2dbc/repository/query/StringBasedR2dbcQueryUnitTests.java +++ b/src/test/java/org/springframework/data/r2dbc/repository/query/StringBasedR2dbcQueryUnitTests.java @@ -26,12 +26,12 @@ import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; + import org.springframework.data.projection.ProjectionFactory; import org.springframework.data.projection.SpelAwareProxyProjectionFactory; import org.springframework.data.r2dbc.function.DatabaseClient; import org.springframework.data.r2dbc.function.DatabaseClient.GenericExecuteSpec; import org.springframework.data.r2dbc.function.convert.MappingR2dbcConverter; -import org.springframework.data.relational.core.conversion.BasicRelationalConverter; import org.springframework.data.relational.core.mapping.RelationalMappingContext; import org.springframework.data.repository.Repository; import org.springframework.data.repository.core.RepositoryMetadata; @@ -62,7 +62,7 @@ public class StringBasedR2dbcQueryUnitTests { public void setUp() { this.mappingContext = new RelationalMappingContext(); - this.converter = new MappingR2dbcConverter(new BasicRelationalConverter(this.mappingContext)); + this.converter = new MappingR2dbcConverter(this.mappingContext); this.metadata = AbstractRepositoryMetadata.getMetadata(SampleRepository.class); this.factory = new SpelAwareProxyProjectionFactory(); diff --git a/src/test/java/org/springframework/data/r2dbc/repository/support/AbstractSimpleR2dbcRepositoryIntegrationTests.java b/src/test/java/org/springframework/data/r2dbc/repository/support/AbstractSimpleR2dbcRepositoryIntegrationTests.java index b1c916d4..29e731e9 100644 --- a/src/test/java/org/springframework/data/r2dbc/repository/support/AbstractSimpleR2dbcRepositoryIntegrationTests.java +++ b/src/test/java/org/springframework/data/r2dbc/repository/support/AbstractSimpleR2dbcRepositoryIntegrationTests.java @@ -34,6 +34,7 @@ import org.junit.Before; import org.junit.Test; + import org.springframework.beans.factory.annotation.Autowired; import org.springframework.dao.DataAccessException; import org.springframework.data.annotation.Id; @@ -41,7 +42,6 @@ import org.springframework.data.r2dbc.function.ReactiveDataAccessStrategy; import org.springframework.data.r2dbc.function.convert.MappingR2dbcConverter; import org.springframework.data.r2dbc.testing.R2dbcIntegrationTestSupport; -import org.springframework.data.relational.core.conversion.BasicRelationalConverter; import org.springframework.data.relational.core.mapping.RelationalMappingContext; import org.springframework.data.relational.core.mapping.RelationalPersistentEntity; import org.springframework.data.relational.core.mapping.Table; @@ -74,7 +74,7 @@ public void before() { (RelationalPersistentEntity) mappingContext.getRequiredPersistentEntity(LegoSet.class)); this.repository = new SimpleR2dbcRepository<>(entityInformation, databaseClient, - new MappingR2dbcConverter(new BasicRelationalConverter(mappingContext)), strategy); + new MappingR2dbcConverter(mappingContext), strategy); this.jdbc = createJdbcTemplate(createDataSource()); try { diff --git a/src/test/java/org/springframework/data/r2dbc/repository/support/R2dbcRepositoryFactoryUnitTests.java b/src/test/java/org/springframework/data/r2dbc/repository/support/R2dbcRepositoryFactoryUnitTests.java index 3efc75cf..a1119355 100644 --- a/src/test/java/org/springframework/data/r2dbc/repository/support/R2dbcRepositoryFactoryUnitTests.java +++ b/src/test/java/org/springframework/data/r2dbc/repository/support/R2dbcRepositoryFactoryUnitTests.java @@ -23,9 +23,11 @@ import org.junit.runner.RunWith; import org.mockito.Mock; import org.mockito.junit.MockitoJUnitRunner; + import org.springframework.data.mapping.context.MappingContext; import org.springframework.data.r2dbc.function.DatabaseClient; import org.springframework.data.r2dbc.function.ReactiveDataAccessStrategy; +import org.springframework.data.r2dbc.function.convert.R2dbcConverter; import org.springframework.data.relational.core.mapping.RelationalPersistentEntity; import org.springframework.data.relational.repository.query.RelationalEntityInformation; import org.springframework.data.relational.repository.support.MappingRelationalEntityInformation; @@ -40,31 +42,32 @@ public class R2dbcRepositoryFactoryUnitTests { @Mock DatabaseClient databaseClient; + @Mock R2dbcConverter r2dbcConverter; + @Mock ReactiveDataAccessStrategy dataAccessStrategy; @Mock @SuppressWarnings("rawtypes") MappingContext mappingContext; @Mock @SuppressWarnings("rawtypes") RelationalPersistentEntity entity; - @Mock ReactiveDataAccessStrategy dataAccessStrategy; @Before @SuppressWarnings("unchecked") public void before() { when(mappingContext.getRequiredPersistentEntity(Person.class)).thenReturn(entity); + when(dataAccessStrategy.getConverter()).thenReturn(r2dbcConverter); + when(r2dbcConverter.getMappingContext()).thenReturn(mappingContext); } @Test - @SuppressWarnings("unchecked") public void usesMappingRelationalEntityInformationIfMappingContextSet() { - R2dbcRepositoryFactory factory = new R2dbcRepositoryFactory(databaseClient, mappingContext, dataAccessStrategy); + R2dbcRepositoryFactory factory = new R2dbcRepositoryFactory(databaseClient, dataAccessStrategy); RelationalEntityInformation entityInformation = factory.getEntityInformation(Person.class); assertThat(entityInformation).isInstanceOf(MappingRelationalEntityInformation.class); } @Test - @SuppressWarnings("unchecked") public void createsRepositoryWithIdTypeLong() { - R2dbcRepositoryFactory factory = new R2dbcRepositoryFactory(databaseClient, mappingContext, dataAccessStrategy); + R2dbcRepositoryFactory factory = new R2dbcRepositoryFactory(databaseClient, dataAccessStrategy); MyPersonRepository repository = factory.getRepository(MyPersonRepository.class); assertThat(repository).isNotNull();