diff --git a/pom.xml b/pom.xml index 03f3665d..5be8a917 100644 --- a/pom.xml +++ b/pom.xml @@ -7,7 +7,7 @@ org.springframework.data spring-data-r2dbc - 1.0.0.BUILD-SNAPSHOT + 1.0.0.gh-73-SNAPSHOT Spring Data R2DBC Spring Data module for R2DBC. diff --git a/src/main/java/org/springframework/data/r2dbc/dialect/IndexedBindMarker.java b/src/main/java/org/springframework/data/r2dbc/dialect/IndexedBindMarker.java index b7d84d65..c54da6dd 100644 --- a/src/main/java/org/springframework/data/r2dbc/dialect/IndexedBindMarker.java +++ b/src/main/java/org/springframework/data/r2dbc/dialect/IndexedBindMarker.java @@ -20,11 +20,11 @@ /** * A single indexed bind marker. */ -class IndexedBindMarker implements BindMarker { +public class IndexedBindMarker implements BindMarker { private final String placeholder; - private int index; + private final int index; IndexedBindMarker(String placeholder, int index) { this.placeholder = placeholder; @@ -57,4 +57,10 @@ public void bind(Statement statement, Object value) { public void bindNull(Statement statement, Class valueType) { statement.bindNull(this.index, valueType); } + + + public int getIndex() { + return index; + } + } diff --git a/src/main/java/org/springframework/data/r2dbc/function/BindIdOperation.java b/src/main/java/org/springframework/data/r2dbc/function/BindIdOperation.java deleted file mode 100644 index 71f437ca..00000000 --- a/src/main/java/org/springframework/data/r2dbc/function/BindIdOperation.java +++ /dev/null @@ -1,32 +0,0 @@ -package org.springframework.data.r2dbc.function; - -import io.r2dbc.spi.Statement; - -/** - * Extension to {@link BindableOperation} for operations that allow parameter substitution for a single {@code id} - * column that accepts either a single value or multiple values, depending on the underlying operation. - * - * @author Mark Paluch - * @see Statement#bind - * @see Statement#bindNull - */ -public interface BindIdOperation extends BindableOperation { - - /** - * Bind the given {@code value} to the {@link Statement} using the underlying binding strategy. - * - * @param statement the statement to bind the value to. - * @param value the actual value. Must not be {@literal null}. - * @see Statement#bind - */ - void bindId(Statement statement, Object value); - - /** - * Bind the given {@code values} to the {@link Statement} using the underlying binding strategy. - * - * @param statement the statement to bind the value to. - * @param values the actual values. - * @see Statement#bind - */ - void bindIds(Statement statement, Iterable values); -} diff --git a/src/main/java/org/springframework/data/r2dbc/function/Bindings.java b/src/main/java/org/springframework/data/r2dbc/function/Bindings.java new file mode 100644 index 00000000..a882156d --- /dev/null +++ b/src/main/java/org/springframework/data/r2dbc/function/Bindings.java @@ -0,0 +1,104 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.r2dbc.function; + +import io.r2dbc.spi.Statement; +import lombok.RequiredArgsConstructor; + +import java.util.List; + +import org.springframework.data.r2dbc.domain.SettableValue; + +/** + * @author Jens Schauder + */ +public class Bindings { + + private final List bindings; + + public Bindings(List bindings) { + this.bindings = bindings; + } + + public void apply(Statement statement) { + bindings.forEach(sb -> sb.bindTo(statement)); + } + + @RequiredArgsConstructor + public static abstract class SingleBinding { + + final T identifier; + final SettableValue value; + + public abstract void bindTo(Statement statement); + + public abstract boolean isIndexed(); + + public final boolean isNamed() { + return !isIndexed(); + } + } + + + public static class IndexedSingleBinding extends SingleBinding { + + public IndexedSingleBinding(Integer identifier, SettableValue value) { + super(identifier, value); + } + + @Override + public void bindTo(Statement statement) { + + if (value.isEmpty()) { + statement.bindNull((int) identifier, value.getType()); + } else { + statement.bind((int) identifier, value.getValue()); + } + } + + @Override + public boolean isIndexed() { + return true; + } + } + + public static class NamedExpandedSingleBinding extends SingleBinding { + + private final BindableOperation operation; + + public NamedExpandedSingleBinding(String identifier, SettableValue value, BindableOperation operation) { + + super(identifier, value); + + this.operation = operation; + } + + @Override + public void bindTo(Statement statement) { + + if (value != null) { + operation.bind(statement, identifier, value); + } else { + operation.bindNull(statement, identifier, value.getType()); + } + } + + @Override + public boolean isIndexed() { + return false; + } + } +} diff --git a/src/main/java/org/springframework/data/r2dbc/function/DatabaseClient.java b/src/main/java/org/springframework/data/r2dbc/function/DatabaseClient.java index 7062fca5..df95db78 100644 --- a/src/main/java/org/springframework/data/r2dbc/function/DatabaseClient.java +++ b/src/main/java/org/springframework/data/r2dbc/function/DatabaseClient.java @@ -26,6 +26,7 @@ import java.util.function.Supplier; import org.reactivestreams.Publisher; + import org.springframework.data.domain.Pageable; import org.springframework.data.domain.Sort; import org.springframework.data.r2dbc.support.R2dbcExceptionTranslator; @@ -137,6 +138,9 @@ interface Builder { * Contract for specifying a SQL call along with options leading to the exchange. The SQL string can contain either * native parameter bind markers (e.g. {@literal $1, $2} for Postgres, {@literal @P0, @P1} for SQL Server) or named * parameters (e.g. {@literal :foo, :bar}) when {@link NamedParameterExpander} is enabled. + *

+ * Accepts {@link PreparedOperation} as SQL and binding {@link Supplier}. + *

* * @see NamedParameterExpander * @see DatabaseClient.Builder#namedParameters(NamedParameterExpander) @@ -156,6 +160,7 @@ interface SqlSpec { * * @param sqlSupplier must not be {@literal null}. * @return a new {@link GenericExecuteSpec}. + * @see PreparedOperation */ GenericExecuteSpec sql(Supplier sqlSupplier); } diff --git a/src/main/java/org/springframework/data/r2dbc/function/DefaultDatabaseClient.java b/src/main/java/org/springframework/data/r2dbc/function/DefaultDatabaseClient.java index b7c52268..2b5cda4a 100644 --- a/src/main/java/org/springframework/data/r2dbc/function/DefaultDatabaseClient.java +++ b/src/main/java/org/springframework/data/r2dbc/function/DefaultDatabaseClient.java @@ -37,7 +37,6 @@ import java.util.LinkedHashSet; import java.util.List; import java.util.Map; -import java.util.Set; import java.util.concurrent.atomic.AtomicBoolean; import java.util.function.BiFunction; import java.util.function.Function; @@ -49,6 +48,7 @@ import org.reactivestreams.Publisher; import org.springframework.dao.DataAccessException; +import org.springframework.dao.InvalidDataAccessApiUsageException; import org.springframework.data.domain.Pageable; import org.springframework.data.domain.Sort; import org.springframework.data.r2dbc.UncategorizedR2dbcException; @@ -57,6 +57,7 @@ import org.springframework.data.r2dbc.function.connectionfactory.ConnectionProxy; import org.springframework.data.r2dbc.function.convert.ColumnMapRowMapper; import org.springframework.data.r2dbc.support.R2dbcExceptionTranslator; +import org.springframework.data.relational.core.sql.Insert; import org.springframework.jdbc.core.SqlProvider; import org.springframework.lang.Nullable; import org.springframework.util.Assert; @@ -331,32 +332,15 @@ protected String getSql() { FetchSpec exchange(String sql, BiFunction mappingFunction) { - Function executeFunction = it -> { + PreparedOperation pop; - if (logger.isDebugEnabled()) { - logger.debug("Executing SQL statement [" + sql + "]"); - } - - BindableOperation operation = namedParameters.expand(sql, dataAccessStrategy.getBindMarkersFactory(), - new MapBindParameterSource(byName)); - - Statement statement = it.createStatement(operation.toQuery()); - - byName.forEach((name, o) -> { - - if (o.getValue() != null) { - operation.bind(statement, name, o.getValue()); - } else { - operation.bindNull(statement, name, o.getType()); - } - }); - - bindByIndex(statement, byIndex); - - return statement; - }; + if (sqlSupplier instanceof PreparedOperation) { + pop = ((PreparedOperation) sqlSupplier); + } else { + pop = new ExpandedPreparedOperation(sql, namedParameters, dataAccessStrategy, byName, byIndex); + } - Function> resultFunction = it -> Flux.from(executeFunction.apply(it).execute()); + Function> resultFunction = it -> Flux.from(pop.createBoundStatement(it).execute()); return new DefaultSqlResult<>(DefaultDatabaseClient.this, // sql, // @@ -367,6 +351,7 @@ FetchSpec exchange(String sql, BiFunction mappingFun public ExecuteSpecSupport bind(int index, Object value) { + assertNotPreparedOperation(); Assert.notNull(value, () -> String.format("Value at index %d must not be null. Use bindNull(…) instead.", index)); Map byIndex = new LinkedHashMap<>(this.byIndex); @@ -377,6 +362,8 @@ public ExecuteSpecSupport bind(int index, Object value) { public ExecuteSpecSupport bindNull(int index, Class type) { + assertNotPreparedOperation(); + Map byIndex = new LinkedHashMap<>(this.byIndex); byIndex.put(index, SettableValue.empty(type)); @@ -385,6 +372,8 @@ public ExecuteSpecSupport bindNull(int index, Class type) { public ExecuteSpecSupport bind(String name, Object value) { + assertNotPreparedOperation(); + Assert.hasText(name, "Parameter name must not be null or empty!"); Assert.notNull(value, () -> String.format("Value for parameter %s must not be null. Use bindNull(…) instead.", name)); @@ -397,6 +386,7 @@ public ExecuteSpecSupport bind(String name, Object value) { public ExecuteSpecSupport bindNull(String name, Class type) { + assertNotPreparedOperation(); Assert.hasText(name, "Parameter name must not be null or empty!"); Map byName = new LinkedHashMap<>(this.byName); @@ -405,6 +395,12 @@ public ExecuteSpecSupport bindNull(String name, Class type) { return createInstance(this.byIndex, byName, this.sqlSupplier); } + private void assertNotPreparedOperation() { + if (sqlSupplier instanceof PreparedOperation) { + throw new InvalidDataAccessApiUsageException("Cannot add bindings to a PreparedOperation"); + } + } + protected ExecuteSpecSupport createInstance(Map byIndex, Map byName, Supplier sqlSupplier) { return new ExecuteSpecSupport(byIndex, byName, sqlSupplier); @@ -882,26 +878,15 @@ private FetchSpec exchange(BiFunction mappingFunctio throw new IllegalStateException("Insert fields is empty!"); } - BindableOperation bindableInsert = dataAccessStrategy.insertAndReturnGeneratedKeys(table, byName.keySet()); - - String sql = bindableInsert.toQuery(); - Function insertFunction = it -> { - - if (logger.isDebugEnabled()) { - logger.debug("Executing SQL statement [" + sql + "]"); - } - - Statement statement = it.createStatement(sql).returnGeneratedValues(); - - byName.forEach((k, v) -> bindableInsert.bind(statement, k, v)); + PreparedOperation operation = dataAccessStrategy.getStatements().insert(table, Collections.emptyList(), + it -> { + byName.forEach(it::bind); + }); - return statement; - }; - - Function> resultFunction = it -> Flux.from(insertFunction.apply(it).execute()); + Function> resultFunction = it -> Flux.from(operation.createBoundStatement(it).execute()); return new DefaultSqlResult<>(DefaultDatabaseClient.this, // - sql, // + operation.toQuery(), // resultFunction, // it -> resultFunction.apply(it).flatMap(Result::getRowsUpdated).next(), // mappingFunction); @@ -999,40 +984,20 @@ private FetchSpec exchange(Object toInsert, BiFunction columns = new LinkedHashSet<>(); - - outboundRow.forEach((k, v) -> { - - if (v.hasValue()) { - columns.add(k); - } - }); - - BindableOperation bindableInsert = dataAccessStrategy.insertAndReturnGeneratedKeys(table, columns); + PreparedOperation operation = dataAccessStrategy.getStatements().insert(table, Collections.emptyList(), + it -> { + outboundRow.forEach((k, v) -> { - String sql = bindableInsert.toQuery(); + if (v.hasValue()) { + it.bind(k, v); + } + }); + }); - Function insertFunction = it -> { - - if (logger.isDebugEnabled()) { - logger.debug("Executing SQL statement [" + sql + "]"); - } - - Statement statement = it.createStatement(sql).returnGeneratedValues(); - - outboundRow.forEach((k, v) -> { - if (v.hasValue()) { - bindableInsert.bind(statement, k, v); - } - }); - - return statement; - }; - - Function> resultFunction = it -> Flux.from(insertFunction.apply(it).execute()); + Function> resultFunction = it -> Flux.from(operation.createBoundStatement(it).execute()); return new DefaultSqlResult<>(DefaultDatabaseClient.this, // - sql, // + operation.toQuery(), // resultFunction, // it -> resultFunction // .apply(it) // diff --git a/src/main/java/org/springframework/data/r2dbc/function/DefaultReactiveDataAccessStrategy.java b/src/main/java/org/springframework/data/r2dbc/function/DefaultReactiveDataAccessStrategy.java index d36ea69c..17f3ddbd 100644 --- a/src/main/java/org/springframework/data/r2dbc/function/DefaultReactiveDataAccessStrategy.java +++ b/src/main/java/org/springframework/data/r2dbc/function/DefaultReactiveDataAccessStrategy.java @@ -17,14 +17,11 @@ import io.r2dbc.spi.Row; import io.r2dbc.spi.RowMetadata; -import io.r2dbc.spi.Statement; import java.util.ArrayList; import java.util.Collection; import java.util.Collections; -import java.util.LinkedHashMap; import java.util.List; -import java.util.Map; import java.util.OptionalLong; import java.util.Set; import java.util.function.BiFunction; @@ -37,8 +34,6 @@ import org.springframework.data.domain.Sort.Order; import org.springframework.data.mapping.context.MappingContext; import org.springframework.data.r2dbc.dialect.ArrayColumns; -import org.springframework.data.r2dbc.dialect.BindMarker; -import org.springframework.data.r2dbc.dialect.BindMarkers; import org.springframework.data.r2dbc.dialect.BindMarkersFactory; import org.springframework.data.r2dbc.dialect.Dialect; import org.springframework.data.r2dbc.domain.OutboundRow; @@ -53,13 +48,17 @@ import org.springframework.data.relational.core.mapping.RelationalPersistentProperty; import org.springframework.data.relational.core.sql.Expression; import org.springframework.data.relational.core.sql.OrderByField; +import org.springframework.data.relational.core.sql.Select; import org.springframework.data.relational.core.sql.SelectBuilder.SelectFromAndOrderBy; import org.springframework.data.relational.core.sql.StatementBuilder; import org.springframework.data.relational.core.sql.Table; +import org.springframework.data.relational.core.sql.render.NamingStrategies; +import org.springframework.data.relational.core.sql.render.RenderContext; +import org.springframework.data.relational.core.sql.render.RenderNamingStrategy; +import org.springframework.data.relational.core.sql.render.SelectRenderContext; import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.util.ClassUtils; -import org.springframework.util.StringUtils; /** * Default {@link ReactiveDataAccessStrategy} implementation. @@ -71,6 +70,7 @@ public class DefaultReactiveDataAccessStrategy implements ReactiveDataAccessStra private final Dialect dialect; private final R2dbcConverter converter; private final MappingContext, ? extends RelationalPersistentProperty> mappingContext; + private final StatementFactory statements; /** * Creates a new {@link DefaultReactiveDataAccessStrategy} given {@link Dialect}. @@ -118,6 +118,30 @@ public DefaultReactiveDataAccessStrategy(Dialect dialect, R2dbcConverter convert this.mappingContext = (MappingContext, ? extends RelationalPersistentProperty>) this.converter .getMappingContext(); this.dialect = dialect; + + RenderContext renderContext = new RenderContext() { + @Override + public RenderNamingStrategy getNamingStrategy() { + return NamingStrategies.asIs(); + } + + @Override + public SelectRenderContext getSelect() { + return new SelectRenderContext() { + @Override + public Function afterSelectList() { + return it -> ""; + } + + @Override + public Function afterOrderBy(boolean hasOrderBy) { + return it -> ""; + } + }; + } + }; + + this.statements = new DefaultStatementFactory(this.dialect, renderContext); } /* @@ -218,7 +242,6 @@ public Sort getMappedSort(Class typeToRead, Sort sort) { * (non-Javadoc) * @see org.springframework.data.r2dbc.function.ReactiveDataAccessStrategy#getRowMapper(java.lang.Class) */ - @SuppressWarnings("unchecked") @Override public BiFunction getRowMapper(Class typeToRead) { return new EntityRowMapper<>(typeToRead, converter); @@ -233,6 +256,15 @@ public String getTableName(Class type) { return getRequiredPersistentEntity(type).getTableName(); } + /* + * (non-Javadoc) + * @see org.springframework.data.r2dbc.function.ReactiveDataAccessStrategy#getStatements() + */ + @Override + public StatementFactory getStatements() { + return this.statements; + } + /* * (non-Javadoc) * @see org.springframework.data.r2dbc.function.ReactiveDataAccessStrategy#getBindMarkersFactory() @@ -251,15 +283,6 @@ private RelationalPersistentEntity getPersistentEntity(Class typeToRead) { return mappingContext.getPersistentEntity(typeToRead); } - /* - * (non-Javadoc) - * @see org.springframework.data.r2dbc.function.ReactiveDataAccessStrategy#insertAndReturnGeneratedKeys(java.lang.String, java.util.Set) - */ - @Override - public BindableOperation insertAndReturnGeneratedKeys(String table, Set columns) { - return new DefaultBindableInsert(dialect.getBindMarkersFactory().create(), table, columns); - } - /* * (non-Javadoc) * @see org.springframework.data.r2dbc.function.ReactiveDataAccessStrategy#select(java.lang.String, java.util.Set, org.springframework.data.domain.Sort, org.springframework.data.domain.Pageable) @@ -290,6 +313,7 @@ public String select(String tableName, Set columns, Sort sort, Pageable offset = OptionalLong.of(page.getOffset()); } + // See https://github.com/spring-projects/spring-data-r2dbc/issues/55 return StatementRenderUtil.render(selectBuilder.build(), limit, offset, this.dialect); } @@ -310,304 +334,4 @@ private Collection createOrderByFields(Table table, Sort return fields; } - - /* - * (non-Javadoc) - * @see org.springframework.data.r2dbc.function.ReactiveDataAccessStrategy#updateById(java.lang.String, java.util.Set, java.lang.String) - */ - @Override - public BindIdOperation updateById(String table, Set columns, String idColumn) { - return new DefaultBindableUpdate(dialect.getBindMarkersFactory().create(), table, columns, idColumn); - } - - /* - * (non-Javadoc) - * @see org.springframework.data.r2dbc.function.ReactiveDataAccessStrategy#deleteById(java.lang.String, java.lang.String) - */ - @Override - public BindIdOperation deleteById(String table, String idColumn) { - - return new DefaultBindIdOperation(dialect.getBindMarkersFactory().create(), - marker -> String.format("DELETE FROM %s WHERE %s = %s", table, idColumn, marker.getPlaceholder())); - } - - /* - * (non-Javadoc) - * @see org.springframework.data.r2dbc.function.ReactiveDataAccessStrategy#deleteByIdIn(java.lang.String, java.lang.String) - */ - @Override - public BindIdOperation deleteByIdIn(String table, String idColumn) { - - String query = String.format("DELETE FROM %s", table); - return new DefaultBindIdIn(dialect.getBindMarkersFactory().create(), query, idColumn); - } - - /** - * Default {@link BindableOperation} implementation for a {@code INSERT} operation. - */ - static class DefaultBindableInsert implements BindableOperation { - - private final Map markers = new LinkedHashMap<>(); - private final String query; - - DefaultBindableInsert(BindMarkers bindMarkers, String table, Collection columns) { - - StringBuilder builder = new StringBuilder(); - List placeholders = new ArrayList<>(columns.size()); - - for (String column : columns) { - BindMarker marker = markers.computeIfAbsent(column, bindMarkers::next); - placeholders.add(marker.getPlaceholder()); - } - - String columnsString = StringUtils.collectionToDelimitedString(columns, ", "); - String placeholdersString = StringUtils.collectionToDelimitedString(placeholders, ", "); - - builder.append("INSERT INTO ").append(table).append(" (").append(columnsString).append(")").append(" VALUES(") - .append(placeholdersString).append(")"); - - this.query = builder.toString(); - } - - /* - * (non-Javadoc) - * @see org.springframework.data.r2dbc.function.BindableOperation#bind(io.r2dbc.spi.Statement, java.lang.String, java.lang.Object) - */ - @Override - public void bind(Statement statement, String identifier, Object value) { - markers.get(identifier).bind(statement, value); - } - - /* - * (non-Javadoc) - * @see org.springframework.data.r2dbc.function.BindableOperation#bindNull(io.r2dbc.spi.Statement, java.lang.String, java.lang.Class) - */ - @Override - public void bindNull(Statement statement, String identifier, Class valueType) { - markers.get(identifier).bindNull(statement, valueType); - } - - /* - * (non-Javadoc) - * @see org.springframework.data.r2dbc.function.QueryOperation#toQuery() - */ - @Override - public String toQuery() { - return this.query; - } - } - - /** - * Default {@link BindIdOperation} implementation for a {@code UPDATE} operation using a single key. - */ - static class DefaultBindableUpdate implements BindIdOperation { - - private final Map markers = new LinkedHashMap<>(); - private final BindMarker idMarker; - private final String query; - - DefaultBindableUpdate(BindMarkers bindMarkers, String tableName, Set columns, String idColumnName) { - - this.idMarker = bindMarkers.next(); - - StringBuilder setClause = new StringBuilder(); - - for (String column : columns) { - - BindMarker marker = markers.computeIfAbsent(column, bindMarkers::next); - - if (setClause.length() != 0) { - setClause.append(", "); - } - - setClause.append(column).append(" = ").append(marker.getPlaceholder()); - } - - this.query = String.format("UPDATE %s SET %s WHERE %s = %s", tableName, setClause, idColumnName, - idMarker.getPlaceholder()); - } - - /* - * (non-Javadoc) - * @see org.springframework.data.r2dbc.function.BindableOperation#bind(io.r2dbc.spi.Statement, java.lang.String, java.lang.Object) - */ - @Override - public void bind(Statement statement, String identifier, Object value) { - markers.get(identifier).bind(statement, value); - } - - /* - * (non-Javadoc) - * @see org.springframework.data.r2dbc.function.BindableOperation#bindNull(io.r2dbc.spi.Statement, java.lang.String, java.lang.Class) - */ - @Override - public void bindNull(Statement statement, String identifier, Class valueType) { - markers.get(identifier).bindNull(statement, valueType); - } - - /* - * (non-Javadoc) - * @see org.springframework.data.r2dbc.function.BindIdOperation#bindId(io.r2dbc.spi.Statement, java.lang.Object) - */ - @Override - public void bindId(Statement statement, Object value) { - idMarker.bind(statement, value); - } - - /* - * (non-Javadoc) - * @see org.springframework.data.r2dbc.function.BindIdOperation#bindIds(io.r2dbc.spi.Statement, java.lang.Iterable) - */ - @Override - public void bindIds(Statement statement, Iterable values) { - throw new UnsupportedOperationException(); - } - - /* - * (non-Javadoc) - * @see org.springframework.data.r2dbc.function.QueryOperation#toQuery() - */ - @Override - public String toQuery() { - return this.query; - } - } - - /** - * Default {@link BindIdOperation} implementation for a {@code SELECT} or {@code DELETE} operation using a single key - * in the {@code WHERE} predicate. - */ - static class DefaultBindIdOperation implements BindIdOperation { - - private final BindMarker idMarker; - private final String query; - - DefaultBindIdOperation(BindMarkers bindMarkers, Function queryFunction) { - - this.idMarker = bindMarkers.next(); - this.query = queryFunction.apply(this.idMarker); - } - - /* - * (non-Javadoc) - * @see org.springframework.data.r2dbc.function.BindableOperation#bind(io.r2dbc.spi.Statement, java.lang.String, java.lang.Object) - */ - @Override - public void bind(Statement statement, String identifier, Object value) { - throw new UnsupportedOperationException(); - } - - /* - * (non-Javadoc) - * @see org.springframework.data.r2dbc.function.BindableOperation#bindNull(io.r2dbc.spi.Statement, java.lang.String, java.lang.Class) - */ - @Override - public void bindNull(Statement statement, String identifier, Class valueType) { - throw new UnsupportedOperationException(); - } - - /* - * (non-Javadoc) - * @see org.springframework.data.r2dbc.function.BindIdOperation#bindId(io.r2dbc.spi.Statement, java.lang.Object) - */ - @Override - public void bindId(Statement statement, Object value) { - idMarker.bind(statement, value); - } - - /* - * (non-Javadoc) - * @see org.springframework.data.r2dbc.function.BindIdOperation#bindIds(io.r2dbc.spi.Statement, java.lang.Iterable) - */ - @Override - public void bindIds(Statement statement, Iterable values) { - throw new UnsupportedOperationException(); - } - - /* - * (non-Javadoc) - * @see org.springframework.data.r2dbc.function.QueryOperation#toQuery() - */ - @Override - public String toQuery() { - return this.query; - } - } - - /** - * Default {@link BindIdOperation} implementation for a {@code SELECT … WHERE id IN (…)} or - * {@code DELETE … WHERE id IN (…)}. - */ - static class DefaultBindIdIn implements BindIdOperation { - - private final List markers = new ArrayList<>(); - private final BindMarkers bindMarkers; - private final String baseQuery; - private final String idColumnName; - - DefaultBindIdIn(BindMarkers bindMarkers, String baseQuery, String idColumnName) { - - this.bindMarkers = bindMarkers; - this.baseQuery = baseQuery; - this.idColumnName = idColumnName; - } - - /* - * (non-Javadoc) - * @see org.springframework.data.r2dbc.function.BindableOperation#bind(io.r2dbc.spi.Statement, java.lang.String, java.lang.Object) - */ - @Override - public void bind(Statement statement, String identifier, Object value) { - throw new UnsupportedOperationException(); - } - - /* - * (non-Javadoc) - * @see org.springframework.data.r2dbc.function.BindableOperation#bindNull(io.r2dbc.spi.Statement, java.lang.String, java.lang.Class) - */ - @Override - public void bindNull(Statement statement, String identifier, Class valueType) { - throw new UnsupportedOperationException(); - } - - /* - * (non-Javadoc) - * @see org.springframework.data.r2dbc.function.BindIdOperation#bindId(io.r2dbc.spi.Statement, java.lang.Object) - */ - @Override - public void bindId(Statement statement, Object value) { - - BindMarker bindMarker = bindMarkers.next(); - markers.add(bindMarker.getPlaceholder()); - bindMarker.bind(statement, value); - } - - /* - * (non-Javadoc) - * @see org.springframework.data.r2dbc.function.BindIdOperation#bindIds(io.r2dbc.spi.Statement, java.lang.Iterable) - */ - @Override - public void bindIds(Statement statement, Iterable values) { - - for (Object value : values) { - bindId(statement, value); - } - } - - /* - * (non-Javadoc) - * @see org.springframework.data.r2dbc.function.QueryOperation#toQuery() - */ - @Override - public String toQuery() { - - if (this.markers.isEmpty()) { - throw new UnsupportedOperationException(); - } - - String in = StringUtils.collectionToDelimitedString(this.markers, ", "); - - return String.format("%s WHERE %s IN (%s)", this.baseQuery, this.idColumnName, in); - } - } } diff --git a/src/main/java/org/springframework/data/r2dbc/function/DefaultStatementFactory.java b/src/main/java/org/springframework/data/r2dbc/function/DefaultStatementFactory.java new file mode 100644 index 00000000..d9fcc61e --- /dev/null +++ b/src/main/java/org/springframework/data/r2dbc/function/DefaultStatementFactory.java @@ -0,0 +1,558 @@ +/* + * Copyright 2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.springframework.data.r2dbc.function; + +import io.r2dbc.spi.Connection; +import io.r2dbc.spi.Statement; +import lombok.Getter; +import lombok.RequiredArgsConstructor; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.BiConsumer; +import java.util.function.BiFunction; +import java.util.function.Consumer; +import java.util.function.Function; + +import org.jetbrains.annotations.NotNull; +import org.springframework.dao.InvalidDataAccessApiUsageException; +import org.springframework.data.r2dbc.dialect.BindMarker; +import org.springframework.data.r2dbc.dialect.BindMarkers; +import org.springframework.data.r2dbc.dialect.Dialect; +import org.springframework.data.r2dbc.dialect.IndexedBindMarker; +import org.springframework.data.r2dbc.domain.SettableValue; +import org.springframework.data.relational.core.sql.*; +import org.springframework.data.relational.core.sql.render.RenderContext; +import org.springframework.data.relational.core.sql.render.SqlRenderer; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; + +/** + * Default {@link StatementFactory} implementation. + * + * @author Mark Paluch + */ +@RequiredArgsConstructor +class DefaultStatementFactory implements StatementFactory { + + private final Dialect dialect; + private final RenderContext renderContext; + + /* + * (non-Javadoc) + * @see org.springframework.data.r2dbc.function.StatementFactory#select(java.lang.String, java.util.Collection, java.util.function.Consumer) + */ + @Override + public PreparedOperation select(String tableName, Collection columnNames, + Consumer binderConsumer); + + /** + * Creates a {@link Insert} statement. + * + * @param tableName must not be {@literal null} or empty. + * @param generatedKeysNames must not be {@literal null}. + * @param binderConsumer customizer for bindings. Supports only + * {@link StatementBinderBuilder#bind(String, SettableValue)} bindings. + * @return the {@link PreparedOperation} to update values in {@code tableName} assigning bound values. + */ + PreparedOperation insert(String tableName, Collection generatedKeysNames, + Consumer binderConsumer); + + /** + * Creates a {@link Update} statement. + * + * @param tableName must not be {@literal null} or empty. + * @param binderConsumer customizer for bindings. + * @return the {@link PreparedOperation} to update values in {@code tableName} assigning bound values. + */ + PreparedOperation update(String tableName, Consumer binderConsumer); + + /** + * Creates a {@link Delete} statement. + * + * @param tableName must not be {@literal null} or empty. + * @param binderConsumer customizer for bindings. Supports only + * {@link StatementBinderBuilder#filterBy(String, SettableValue)} bindings. + * @return the {@link PreparedOperation} to delete rows from {@code tableName}. + */ + PreparedOperation delete(String tableName, Consumer binderConsumer); + + /** + * Binder to specify parameter bindings by name. Bindings match to equals comparisons. + */ + interface StatementBinderBuilder { + + /** + * Bind the given Id {@code value} to this builder using the underlying binding strategy to express a filter + * condition. {@link Collection} type values translate to {@code IN} matching. + * + * @param identifier named identifier that is considered by the underlying binding strategy. + * @param settable must not be {@literal null}. Use {@link SettableValue#empty(Class)} for {@code NULL} values. + */ + void filterBy(String identifier, SettableValue settable); + + /** + * Bind the given {@code value} to this builder using the underlying binding strategy. + * + * @param identifier named identifier that is considered by the underlying binding strategy. + * @param settable must not be {@literal null}. Use {@link SettableValue#empty(Class)} for {@code NULL} values. + */ + void bind(String identifier, SettableValue settable); + } +} diff --git a/src/main/java/org/springframework/data/r2dbc/repository/support/SimpleR2dbcRepository.java b/src/main/java/org/springframework/data/r2dbc/repository/support/SimpleR2dbcRepository.java index 19636168..cda23d9a 100644 --- a/src/main/java/org/springframework/data/r2dbc/repository/support/SimpleR2dbcRepository.java +++ b/src/main/java/org/springframework/data/r2dbc/repository/support/SimpleR2dbcRepository.java @@ -20,29 +20,25 @@ import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; -import java.util.ArrayList; +import java.util.Collections; import java.util.LinkedHashSet; -import java.util.List; import java.util.Map; import java.util.Set; import org.reactivestreams.Publisher; -import org.springframework.data.r2dbc.dialect.BindMarker; -import org.springframework.data.r2dbc.dialect.BindMarkers; import org.springframework.data.r2dbc.domain.SettableValue; -import org.springframework.data.r2dbc.function.BindIdOperation; import org.springframework.data.r2dbc.function.DatabaseClient; -import org.springframework.data.r2dbc.function.DatabaseClient.GenericExecuteSpec; +import org.springframework.data.r2dbc.function.PreparedOperation; import org.springframework.data.r2dbc.function.ReactiveDataAccessStrategy; +import org.springframework.data.r2dbc.function.StatementFactory; import org.springframework.data.r2dbc.function.convert.R2dbcConverter; -import org.springframework.data.relational.core.sql.Conditions; -import org.springframework.data.relational.core.sql.Expression; +import org.springframework.data.relational.core.sql.Delete; import org.springframework.data.relational.core.sql.Functions; -import org.springframework.data.relational.core.sql.SQL; import org.springframework.data.relational.core.sql.Select; import org.springframework.data.relational.core.sql.StatementBuilder; import org.springframework.data.relational.core.sql.Table; +import org.springframework.data.relational.core.sql.Update; import org.springframework.data.relational.core.sql.render.SqlRenderer; import org.springframework.data.relational.repository.query.RelationalEntityInformation; import org.springframework.data.repository.reactive.ReactiveCrudRepository; @@ -83,15 +79,13 @@ public Mono save(S objectToSave) { Map columns = accessStrategy.getOutboundRow(objectToSave); columns.remove(getIdColumnName()); // do not update the Id column. String idColumnName = getIdColumnName(); - BindIdOperation update = accessStrategy.updateById(entity.getTableName(), columns.keySet(), idColumnName); - GenericExecuteSpec exec = databaseClient.execute().sql(update); - - BindSpecAdapter wrapper = BindSpecAdapter.create(exec); - columns.forEach((k, v) -> update.bind(wrapper, k, v)); - update.bindId(wrapper, id); + PreparedOperation operation = accessStrategy.getStatements().update(entity.getTableName(), binder -> { + columns.forEach(binder::bind); + binder.filterBy(idColumnName, SettableValue.from(id)); + }); - return wrapper.getBoundOperation().as(entity.getJavaType()) // + return databaseClient.execute().sql(operation).as(entity.getJavaType()) // .then() // .thenReturn(objectToSave); } @@ -129,18 +123,14 @@ public Mono findById(ID id) { Set columns = new LinkedHashSet<>(accessStrategy.getAllColumns(entity.getJavaType())); String idColumnName = getIdColumnName(); - BindMarkers bindMarkers = accessStrategy.getBindMarkersFactory().create(); - BindMarker bindMarker = bindMarkers.next("id"); + StatementFactory statements; - Table table = Table.create(entity.getTableName()); - Select select = StatementBuilder // - .select(table.columns(columns)) // - .from(table) // - .where(Conditions.isEqual(table.column(idColumnName), SQL.bindMarker(bindMarker.getPlaceholder()))) // - .build(); + PreparedOperation operation = accessStrategy.getStatements().select(entity.getTableName(), + Collections.singleton(idColumnName), binder -> { + binder.filterBy(idColumnName, SettableValue.from(id)); + }); - return databaseClient.execute().sql(SqlRenderer.toString(select)) // - .bind(0, id) // + return databaseClient.execute().sql(operation) // .map((r, md) -> r) // .first() // .hasElement(); @@ -225,25 +209,12 @@ public Flux findAllById(Publisher idPublisher) { Set columns = new LinkedHashSet<>(accessStrategy.getAllColumns(entity.getJavaType())); String idColumnName = getIdColumnName(); - BindMarkers bindMarkers = accessStrategy.getBindMarkersFactory().create(); - - List markers = new ArrayList<>(); - - for (int i = 0; i < ids.size(); i++) { - markers.add(SQL.bindMarker(bindMarkers.next("id").getPlaceholder())); - } - - Table table = Table.create(entity.getTableName()); - Select select = StatementBuilder.select(table.columns(columns)).from(table) - .where(Conditions.in(table.column(idColumnName), markers)).build(); - - GenericExecuteSpec executeSpec = databaseClient.execute().sql(SqlRenderer.toString(select)); + PreparedOperation select = statements.select("foo", Arrays.asList("bar", "baz"), it -> {}); + + assertThat(select.getSource()).isInstanceOf(Select.class); + assertThat(select.toQuery()).isEqualTo("SELECT foo.bar, foo.baz FROM foo"); + + select.createBoundStatement(connectionMock); + + verifyZeroInteractions(statementMock); + } + + @Test + public void shouldToQuerySimpleSelectWithSimpleFilter() { + + PreparedOperation select = statements.select("foo", Arrays.asList("bar", "baz"), it -> { + it.filterBy("doe", SettableValue.from("John")); + it.filterBy("baz", SettableValue.from("Jake")); + }); + + assertThat(select.getSource()).isInstanceOf(Select.class); + assertThat(select.toQuery()).isEqualTo("SELECT foo.bar, foo.baz FROM foo WHERE foo.doe = $1 AND foo.baz = $2"); + + select.createBoundStatement(connectionMock); + + verify(statementMock).bind(0, "John"); + verify(statementMock).bind(1, "Jake"); + verifyNoMoreInteractions(statementMock); + } + + @Test + public void shouldToQuerySimpleSelectWithNullFilter() { + + PreparedOperation select = statements.select("foo", Arrays.asList("bar", "baz"), it -> { + it.filterBy("doe", SettableValue.from(Arrays.asList("John", "Jake"))); + }); + + assertThat(select.getSource()).isInstanceOf(Select.class); + assertThat(select.toQuery()).isEqualTo("SELECT foo.bar, foo.baz FROM foo WHERE foo.doe IN ($1, $2)"); + + select.createBoundStatement(connectionMock); + verify(statementMock).bind(0, "John"); + verify(statementMock).bind(1, "Jake"); + verifyNoMoreInteractions(statementMock); + } + + @Test + public void shouldFailInsertToQueryingWithoutValueBindings() { + + assertThatThrownBy(() -> statements.insert("foo", Collections.emptyList(), it -> {})) + .isInstanceOf(IllegalStateException.class); + } + + @Test + public void shouldToQuerySimpleInsert() { + + PreparedOperation insert = statements.insert("foo", Collections.emptyList(), it -> { + it.bind("bar", SettableValue.from("Foo")); + }); + + assertThat(insert.getSource()).isInstanceOf(Insert.class); + assertThat(insert.toQuery()).isEqualTo("INSERT INTO foo (bar) VALUES ($1)"); + + insert.createBoundStatement(connectionMock); + verify(statementMock).bind(0, "Foo"); + verify(statementMock).returnGeneratedValues(any(String[].class)); + verifyNoMoreInteractions(statementMock); + } + + @Test + public void shouldFailUpdateToQueryingWithoutValueBindings() { + + assertThatThrownBy(() -> statements.update("foo", it -> it.filterBy("foo", SettableValue.empty(Object.class)))) + .isInstanceOf(IllegalStateException.class); + } + + @Test + public void shouldToQuerySimpleUpdate() { + + PreparedOperation update = statements.update("foo", it -> { + it.bind("bar", SettableValue.from("Foo")); + }); + + assertThat(update.getSource()).isInstanceOf(Update.class); + assertThat(update.toQuery()).isEqualTo("UPDATE foo SET bar = $1"); + + update.createBoundStatement(connectionMock); + verify(statementMock).bind(0, "Foo"); + verifyNoMoreInteractions(statementMock); + } + + @Test + public void shouldToQueryNullUpdate() { + + PreparedOperation update = statements.update("foo", it -> { + it.bind("bar", SettableValue.empty(String.class)); + }); + + assertThat(update.getSource()).isInstanceOf(Update.class); + assertThat(update.toQuery()).isEqualTo("UPDATE foo SET bar = $1"); + + update.createBoundStatement(connectionMock); + verify(statementMock).bindNull(0, String.class); + + verifyNoMoreInteractions(statementMock); + } + + @Test + public void shouldToQueryUpdateWithFilter() { + + PreparedOperation update = statements.update("foo", it -> { + it.bind("bar", SettableValue.from("Foo")); + it.filterBy("baz", SettableValue.from("Baz")); + }); + + assertThat(update.getSource()).isInstanceOf(Update.class); + assertThat(update.toQuery()).isEqualTo("UPDATE foo SET bar = $1 WHERE foo.baz = $2"); + + update.createBoundStatement(connectionMock); + verify(statementMock).bind(0, "Foo"); + verify(statementMock).bind(1, "Baz"); + verifyNoMoreInteractions(statementMock); + } + + @Test + public void shouldToQuerySimpleDeleteWithSimpleFilter() { + + PreparedOperation delete = statements.delete("foo", it -> { + it.filterBy("doe", SettableValue.from("John")); + }); + + assertThat(delete.getSource()).isInstanceOf(Delete.class); + assertThat(delete.toQuery()).isEqualTo("DELETE FROM foo WHERE foo.doe = $1"); + + delete.createBoundStatement(connectionMock); + verify(statementMock).bind(0, "John"); + verifyNoMoreInteractions(statementMock); + } + + @Test + public void shouldToQuerySimpleDeleteWithMultipleFilters() { + + PreparedOperation delete = statements.delete("foo", it -> { + it.filterBy("doe", SettableValue.from("John")); + it.filterBy("baz", SettableValue.from("Jake")); + }); + + assertThat(delete.getSource()).isInstanceOf(Delete.class); + assertThat(delete.toQuery()).isEqualTo("DELETE FROM foo WHERE foo.doe = $1 AND foo.baz = $2"); + + delete.createBoundStatement(connectionMock); + verify(statementMock).bind(0, "John"); + verify(statementMock).bind(1, "Jake"); + verifyNoMoreInteractions(statementMock); + } + + @Test + public void shouldToQuerySimpleDeleteWithNullFilter() { + + PreparedOperation delete = statements.delete("foo", it -> { + it.filterBy("doe", SettableValue.empty(String.class)); + }); + + assertThat(delete.getSource()).isInstanceOf(Delete.class); + assertThat(delete.toQuery()).isEqualTo("DELETE FROM foo WHERE foo.doe IS NULL"); + + delete.createBoundStatement(connectionMock); + verifyZeroInteractions(statementMock); + } +}