diff --git a/pom.xml b/pom.xml
index 6e43082e..e1b445b2 100644
--- a/pom.xml
+++ b/pom.xml
@@ -7,7 +7,7 @@
org.springframework.dataspring-data-r2dbc
- 1.1.0.BUILD-SNAPSHOT
+ 1.1.0.gh-189-SNAPSHOTSpring Data R2DBCSpring Data module for R2DBC
diff --git a/src/main/asciidoc/new-features.adoc b/src/main/asciidoc/new-features.adoc
index 6f468840..07d4504e 100644
--- a/src/main/asciidoc/new-features.adoc
+++ b/src/main/asciidoc/new-features.adoc
@@ -5,7 +5,8 @@
== What's New in Spring Data R2DBC 1.1.0 RELEASE
* Introduction of `R2dbcEntityTemplate` for entity-oriented operations.
-* Support interface projections with `DatabaseClient.as(…)`
+* Support interface projections with `DatabaseClient.as(…)`.
+* <>.
[[new-features.1-0-0-RELEASE]]
== What's New in Spring Data R2DBC 1.0.0 RELEASE
diff --git a/src/main/asciidoc/reference/r2dbc-sql.adoc b/src/main/asciidoc/reference/r2dbc-sql.adoc
index 58c0ce07..7a750855 100644
--- a/src/main/asciidoc/reference/r2dbc-sql.adoc
+++ b/src/main/asciidoc/reference/r2dbc-sql.adoc
@@ -134,7 +134,7 @@ In JDBC, the actual drivers translate `?` bind markers to database-native marker
Spring Data R2DBC lets you use native bind markers or named bind markers with the `:name` syntax.
-Named parameter support leverages a `R2dbcDialect` instance to expand named parameters to native bind markers at the time of query execution, which gives you a certain degree of query portability across various database vendors.
+Named parameter support leverages a `R2dbcDialect` instance to expand named parameters to native bind markers at the time of query execution, which gives you a certain degree of query portability across various database vendors.
****
The query-preprocessor unrolls named `Collection` parameters into a series of bind markers to remove the need of dynamic query creation based on the number of arguments.
@@ -159,7 +159,7 @@ tuples.add(new Object[] {"John", 35});
tuples.add(new Object[] {"Ann", 50});
db.execute("SELECT id, name, state FROM table WHERE (name, age) IN (:tuples)")
- .bind("tuples", tuples);
+ .bind("tuples", tuples)
----
====
@@ -171,6 +171,38 @@ The following example shows a simpler variant using `IN` predicates:
[source,java]
----
db.execute("SELECT id, name, state FROM table WHERE age IN (:ages)")
- .bind("ages", Arrays.asList(35, 50));
+ .bind("ages", Arrays.asList(35, 50))
----
====
+
+[[r2dbc.datbaseclient.filter]]
+== Statement Filters
+
+You can register a `Statement` filter (`StatementFilterFunction`) through `DatabaseClient` to intercept and modify statements in their execution, as the following example shows:
+
+====
+[source,java]
+----
+db.execute("INSERT INTO table (name, state) VALUES(:name, :state)")
+ .filter((s, next) -> next.execute(s.returnGeneratedValues("id")))
+ .bind("name", …)
+ .bind("state", …)
+----
+====
+
+`DatabaseClient` exposes also simplified `filter(…)` overload accepting `UnaryOperator`:
+
+====
+[source,java]
+----
+db.execute("INSERT INTO table (name, state) VALUES(:name, :state)")
+ .filter(s -> s.returnGeneratedValues("id"))
+ .bind("name", …)
+ .bind("state", …)
+
+db.execute("SELECT id, name, state FROM table")
+ .filter(s -> s.fetchSize(25))
+----
+====
+
+`StatementFilterFunction` allow filtering of the executed `Statement` and filtering of `Result` objects.
diff --git a/src/main/java/org/springframework/data/r2dbc/core/DatabaseClient.java b/src/main/java/org/springframework/data/r2dbc/core/DatabaseClient.java
index 7fde95f6..6e090499 100644
--- a/src/main/java/org/springframework/data/r2dbc/core/DatabaseClient.java
+++ b/src/main/java/org/springframework/data/r2dbc/core/DatabaseClient.java
@@ -18,6 +18,7 @@
import io.r2dbc.spi.ConnectionFactory;
import io.r2dbc.spi.Row;
import io.r2dbc.spi.RowMetadata;
+import io.r2dbc.spi.Statement;
import reactor.core.publisher.Mono;
import java.util.Arrays;
@@ -37,6 +38,7 @@
import org.springframework.data.r2dbc.query.Update;
import org.springframework.data.r2dbc.support.R2dbcExceptionTranslator;
import org.springframework.data.relational.core.sql.SqlIdentifier;
+import org.springframework.util.Assert;
/**
* A non-blocking, reactive client for performing database calls requests with Reactive Streams back pressure. Provides
@@ -142,6 +144,16 @@ interface Builder {
*/
Builder exceptionTranslator(R2dbcExceptionTranslator exceptionTranslator);
+ /**
+ * Configures a {@link ExecuteFunction} to execute {@link Statement} objects.
+ *
+ * @param executeFunction must not be {@literal null}.
+ * @return {@code this} {@link Builder}.
+ * @since 1.1
+ * @see Statement#execute()
+ */
+ Builder executeFunction(ExecuteFunction executeFunction);
+
/**
* Configures a {@link ReactiveDataAccessStrategy}.
*
@@ -186,7 +198,7 @@ interface Builder {
/**
* Contract for specifying a SQL call along with options leading to the exchange.
*/
- interface GenericExecuteSpec extends BindSpec {
+ interface GenericExecuteSpec extends BindSpec, StatementFilterSpec {
/**
* Define the target type the result should be mapped to.
@@ -231,7 +243,7 @@ interface GenericExecuteSpec extends BindSpec {
/**
* Contract for specifying a SQL call along with options leading to the exchange.
*/
- interface TypedExecuteSpec extends BindSpec> {
+ interface TypedExecuteSpec extends BindSpec>, StatementFilterSpec> {
/**
* Define the target type the result should be mapped to.
@@ -866,4 +878,31 @@ interface BindSpec> {
*/
S bindNull(String name, Class> type);
}
+
+ /**
+ * Contract for applying a {@link StatementFilterFunction}.
+ *
+ * @since 1.1
+ */
+ interface StatementFilterSpec> {
+
+ /**
+ * Add the given filter to the end of the filter chain.
+ *
+ * @param filter the filter to be added to the chain.
+ */
+ default S filter(Function super Statement, ? extends Statement> filter) {
+
+ Assert.notNull(filter, "Statement FilterFunction must not be null!");
+
+ return filter((statement, next) -> next.execute(filter.apply(statement)));
+ }
+
+ /**
+ * Add the given filter to the end of the filter chain.
+ *
+ * @param filter the filter to be added to the chain.
+ */
+ S filter(StatementFilterFunction filter);
+ }
}
diff --git a/src/main/java/org/springframework/data/r2dbc/core/DefaultDatabaseClient.java b/src/main/java/org/springframework/data/r2dbc/core/DefaultDatabaseClient.java
index 7dc37bfe..f4204c9c 100644
--- a/src/main/java/org/springframework/data/r2dbc/core/DefaultDatabaseClient.java
+++ b/src/main/java/org/springframework/data/r2dbc/core/DefaultDatabaseClient.java
@@ -78,6 +78,8 @@ class DefaultDatabaseClient implements DatabaseClient, ConnectionAccessor {
private final R2dbcExceptionTranslator exceptionTranslator;
+ private final ExecuteFunction executeFunction;
+
private final ReactiveDataAccessStrategy dataAccessStrategy;
private final boolean namedParameters;
@@ -87,11 +89,12 @@ class DefaultDatabaseClient implements DatabaseClient, ConnectionAccessor {
private final ProjectionFactory projectionFactory;
DefaultDatabaseClient(ConnectionFactory connector, R2dbcExceptionTranslator exceptionTranslator,
- ReactiveDataAccessStrategy dataAccessStrategy, boolean namedParameters, ProjectionFactory projectionFactory,
- DefaultDatabaseClientBuilder builder) {
+ ExecuteFunction executeFunction, ReactiveDataAccessStrategy dataAccessStrategy, boolean namedParameters,
+ ProjectionFactory projectionFactory, DefaultDatabaseClientBuilder builder) {
this.connector = connector;
this.exceptionTranslator = exceptionTranslator;
+ this.executeFunction = executeFunction;
this.dataAccessStrategy = dataAccessStrategy;
this.namedParameters = namedParameters;
this.projectionFactory = projectionFactory;
@@ -264,25 +267,17 @@ protected DataAccessException translateException(String task, @Nullable String s
* Customization hook.
*/
protected DefaultTypedExecuteSpec createTypedExecuteSpec(Map byIndex,
- Map byName, Supplier sqlSupplier, Class typeToRead) {
- return new DefaultTypedExecuteSpec<>(byIndex, byName, sqlSupplier, typeToRead);
- }
-
- /**
- * Customization hook.
- */
- protected DefaultTypedExecuteSpec createTypedExecuteSpec(Map byIndex,
- Map byName, Supplier sqlSupplier,
- BiFunction mappingFunction) {
- return new DefaultTypedExecuteSpec<>(byIndex, byName, sqlSupplier, mappingFunction);
+ Map byName, Supplier sqlSupplier, StatementFilterFunction filterFunction,
+ Class typeToRead) {
+ return new DefaultTypedExecuteSpec<>(byIndex, byName, sqlSupplier, filterFunction, typeToRead);
}
/**
* Customization hook.
*/
protected ExecuteSpecSupport createGenericExecuteSpec(Map byIndex,
- Map byName, Supplier sqlSupplier) {
- return new DefaultGenericExecuteSpec(byIndex, byName, sqlSupplier);
+ Map byName, Supplier sqlSupplier, StatementFilterFunction filterFunction) {
+ return new DefaultGenericExecuteSpec(byIndex, byName, sqlSupplier, filterFunction);
}
/**
@@ -327,26 +322,30 @@ class ExecuteSpecSupport {
final Map byIndex;
final Map byName;
final Supplier sqlSupplier;
+ final StatementFilterFunction filterFunction;
ExecuteSpecSupport(Supplier sqlSupplier) {
this.byIndex = Collections.emptyMap();
this.byName = Collections.emptyMap();
this.sqlSupplier = sqlSupplier;
+ this.filterFunction = StatementFilterFunctions.empty();
}
ExecuteSpecSupport(Map byIndex, Map byName,
- Supplier sqlSupplier) {
+ Supplier sqlSupplier, StatementFilterFunction filterFunction) {
+
this.byIndex = byIndex;
this.byName = byName;
this.sqlSupplier = sqlSupplier;
+ this.filterFunction = filterFunction;
}
FetchSpec exchange(Supplier sqlSupplier, BiFunction mappingFunction) {
String sql = getRequiredSql(sqlSupplier);
- Function executeFunction = it -> {
+ Function statementFactory = it -> {
if (logger.isDebugEnabled()) {
logger.debug("Executing SQL statement [" + sql + "]");
@@ -404,7 +403,7 @@ FetchSpec exchange(Supplier sqlSupplier, BiFunction> resultFunction = toExecuteFunction(sql, executeFunction);
+ Function> resultFunction = toFunction(sql, filterFunction, statementFactory);
return new DefaultSqlResult<>(DefaultDatabaseClient.this, //
sql, //
@@ -426,7 +425,7 @@ public ExecuteSpecSupport bind(int index, Object value) {
byIndex.put(index, SettableValue.fromOrEmpty(value, value.getClass()));
}
- return createInstance(byIndex, this.byName, this.sqlSupplier);
+ return createInstance(byIndex, this.byName, this.sqlSupplier, this.filterFunction);
}
public ExecuteSpecSupport bindNull(int index, Class> type) {
@@ -436,7 +435,7 @@ public ExecuteSpecSupport bindNull(int index, Class> type) {
Map byIndex = new LinkedHashMap<>(this.byIndex);
byIndex.put(index, SettableValue.empty(type));
- return createInstance(byIndex, this.byName, this.sqlSupplier);
+ return createInstance(byIndex, this.byName, this.sqlSupplier, this.filterFunction);
}
public ExecuteSpecSupport bind(String name, Object value) {
@@ -455,7 +454,7 @@ public ExecuteSpecSupport bind(String name, Object value) {
byName.put(name, SettableValue.fromOrEmpty(value, value.getClass()));
}
- return createInstance(this.byIndex, byName, this.sqlSupplier);
+ return createInstance(this.byIndex, byName, this.sqlSupplier, this.filterFunction);
}
public ExecuteSpecSupport bindNull(String name, Class> type) {
@@ -466,7 +465,14 @@ public ExecuteSpecSupport bindNull(String name, Class> type) {
Map byName = new LinkedHashMap<>(this.byName);
byName.put(name, SettableValue.empty(type));
- return createInstance(this.byIndex, byName, this.sqlSupplier);
+ return createInstance(this.byIndex, byName, this.sqlSupplier, this.filterFunction);
+ }
+
+ public ExecuteSpecSupport filter(StatementFilterFunction filter) {
+
+ Assert.notNull(filter, "Statement FilterFunction must not be null!");
+
+ return createInstance(this.byIndex, byName, this.sqlSupplier, this.filterFunction.andThen(filter));
}
private void assertNotPreparedOperation() {
@@ -476,8 +482,8 @@ private void assertNotPreparedOperation() {
}
protected ExecuteSpecSupport createInstance(Map byIndex, Map byName,
- Supplier sqlSupplier) {
- return new ExecuteSpecSupport(byIndex, byName, sqlSupplier);
+ Supplier sqlSupplier, StatementFilterFunction filterFunction) {
+ return new ExecuteSpecSupport(byIndex, byName, sqlSupplier, filterFunction);
}
}
@@ -487,8 +493,8 @@ protected ExecuteSpecSupport createInstance(Map byIndex,
protected class DefaultGenericExecuteSpec extends ExecuteSpecSupport implements GenericExecuteSpec {
DefaultGenericExecuteSpec(Map byIndex, Map byName,
- Supplier sqlSupplier) {
- super(byIndex, byName, sqlSupplier);
+ Supplier sqlSupplier, StatementFilterFunction filterFunction) {
+ super(byIndex, byName, sqlSupplier, filterFunction);
}
DefaultGenericExecuteSpec(Supplier sqlSupplier) {
@@ -500,7 +506,7 @@ public TypedExecuteSpec as(Class resultType) {
Assert.notNull(resultType, "Result type must not be null!");
- return createTypedExecuteSpec(this.byIndex, this.byName, this.sqlSupplier, resultType);
+ return createTypedExecuteSpec(this.byIndex, this.byName, this.sqlSupplier, this.filterFunction, resultType);
}
@Override
@@ -549,10 +555,15 @@ public DefaultGenericExecuteSpec bindNull(String name, Class> type) {
return (DefaultGenericExecuteSpec) super.bindNull(name, type);
}
+ @Override
+ public DefaultGenericExecuteSpec filter(StatementFilterFunction filter) {
+ return (DefaultGenericExecuteSpec) super.filter(filter);
+ }
+
@Override
protected ExecuteSpecSupport createInstance(Map byIndex, Map byName,
- Supplier sqlSupplier) {
- return createGenericExecuteSpec(byIndex, byName, sqlSupplier);
+ Supplier sqlSupplier, StatementFilterFunction filterFunction) {
+ return createGenericExecuteSpec(byIndex, byName, sqlSupplier, filterFunction);
}
}
@@ -562,13 +573,13 @@ protected ExecuteSpecSupport createInstance(Map byIndex,
@SuppressWarnings("unchecked")
protected class DefaultTypedExecuteSpec extends ExecuteSpecSupport implements TypedExecuteSpec {
- private final @Nullable Class typeToRead;
+ private final Class typeToRead;
private final BiFunction mappingFunction;
DefaultTypedExecuteSpec(Map byIndex, Map byName,
- Supplier sqlSupplier, Class typeToRead) {
+ Supplier sqlSupplier, StatementFilterFunction filterFunction, Class typeToRead) {
- super(byIndex, byName, sqlSupplier);
+ super(byIndex, byName, sqlSupplier, filterFunction);
this.typeToRead = typeToRead;
@@ -580,21 +591,12 @@ protected class DefaultTypedExecuteSpec extends ExecuteSpecSupport implements
}
}
- DefaultTypedExecuteSpec(Map byIndex, Map byName,
- Supplier sqlSupplier, BiFunction mappingFunction) {
-
- super(byIndex, byName, sqlSupplier);
-
- this.typeToRead = null;
- this.mappingFunction = mappingFunction;
- }
-
@Override
public TypedExecuteSpec as(Class resultType) {
Assert.notNull(resultType, "Result type must not be null!");
- return createTypedExecuteSpec(this.byIndex, this.byName, this.sqlSupplier, resultType);
+ return createTypedExecuteSpec(this.byIndex, this.byName, this.sqlSupplier, this.filterFunction, resultType);
}
@Override
@@ -643,10 +645,15 @@ public DefaultTypedExecuteSpec bindNull(String name, Class> type) {
return (DefaultTypedExecuteSpec) super.bindNull(name, type);
}
+ @Override
+ public DefaultTypedExecuteSpec filter(StatementFilterFunction filter) {
+ return (DefaultTypedExecuteSpec) super.filter(filter);
+ }
+
@Override
protected DefaultTypedExecuteSpec createInstance(Map byIndex,
- Map byName, Supplier sqlSupplier) {
- return createTypedExecuteSpec(byIndex, byName, sqlSupplier, this.typeToRead);
+ Map byName, Supplier sqlSupplier, StatementFilterFunction filterFunction) {
+ return createTypedExecuteSpec(byIndex, byName, sqlSupplier, filterFunction, this.typeToRead);
}
}
@@ -691,8 +698,8 @@ private abstract class DefaultSelectSpecSupport {
this.page = Pageable.unpaged();
}
- DefaultSelectSpecSupport(SqlIdentifier table, List projectedFields, Criteria criteria, Sort sort,
- Pageable page) {
+ DefaultSelectSpecSupport(SqlIdentifier table, List projectedFields, @Nullable Criteria criteria,
+ Sort sort, Pageable page) {
this.table = table;
this.projectedFields = projectedFields;
this.criteria = criteria;
@@ -735,7 +742,8 @@ FetchSpec execute(PreparedOperation> preparedOperation, BiFunction selectFunction = wrapPreparedOperation(sql, preparedOperation);
- Function> resultFunction = DefaultDatabaseClient.toExecuteFunction(sql, selectFunction);
+ Function> resultFunction = toFunction(sql, StatementFilterFunctions.empty(),
+ selectFunction);
return new DefaultSqlResult<>(DefaultDatabaseClient.this, //
sql, //
@@ -745,13 +753,13 @@ FetchSpec execute(PreparedOperation> preparedOperation, BiFunction projectedFields,
- Criteria criteria, Sort sort, Pageable page);
+ @Nullable Criteria criteria, Sort sort, Pageable page);
}
private class DefaultGenericSelectSpec extends DefaultSelectSpecSupport implements GenericSelectSpec {
- DefaultGenericSelectSpec(SqlIdentifier table, List projectedFields, Criteria criteria, Sort sort,
- Pageable page) {
+ DefaultGenericSelectSpec(SqlIdentifier table, List projectedFields, @Nullable Criteria criteria,
+ Sort sort, Pageable page) {
super(table, projectedFields, criteria, sort, page);
}
@@ -834,7 +842,7 @@ private FetchSpec exchange(BiFunction mappingFunctio
@Override
protected DefaultGenericSelectSpec createInstance(SqlIdentifier table, List projectedFields,
- Criteria criteria, Sort sort, Pageable page) {
+ @Nullable Criteria criteria, Sort sort, Pageable page) {
return new DefaultGenericSelectSpec(table, projectedFields, criteria, sort, page);
}
}
@@ -856,8 +864,8 @@ private class DefaultTypedSelectSpec extends DefaultSelectSpecSupport impleme
this.mappingFunction = dataAccessStrategy.getRowMapper(typeToRead);
}
- DefaultTypedSelectSpec(SqlIdentifier table, List projectedFields, Criteria criteria, Sort sort,
- Pageable page, @Nullable Class typeToRead, BiFunction mappingFunction) {
+ DefaultTypedSelectSpec(SqlIdentifier table, List projectedFields, @Nullable Criteria criteria,
+ Sort sort, Pageable page, Class typeToRead, BiFunction mappingFunction) {
super(table, projectedFields, criteria, sort, page);
@@ -948,7 +956,7 @@ private FetchSpec exchange(BiFunction mappingFunctio
@Override
protected DefaultTypedSelectSpec createInstance(SqlIdentifier table, List projectedFields,
- Criteria criteria, Sort sort, Pageable page) {
+ @Nullable Criteria criteria, Sort sort, Pageable page) {
return new DefaultTypedSelectSpec<>(table, projectedFields, criteria, sort, page, this.typeToRead,
this.mappingFunction);
}
@@ -1196,11 +1204,11 @@ class DefaultGenericUpdateSpec implements GenericUpdateSpec, UpdateMatchingSpec
private final @Nullable Class> typeToUpdate;
private final @Nullable SqlIdentifier table;
- private final Update assignments;
- private final Criteria where;
+ private final @Nullable Update assignments;
+ private final @Nullable Criteria where;
- DefaultGenericUpdateSpec(@Nullable Class> typeToUpdate, @Nullable SqlIdentifier table, Update assignments,
- Criteria where) {
+ DefaultGenericUpdateSpec(@Nullable Class> typeToUpdate, @Nullable SqlIdentifier table,
+ @Nullable Update assignments, @Nullable Criteria where) {
this.typeToUpdate = typeToUpdate;
this.table = table;
this.assignments = assignments;
@@ -1229,6 +1237,7 @@ public UpdatedRowsFetchSpec fetch() {
SqlIdentifier table;
if (StringUtils.isEmpty(this.table)) {
+ Assert.state(this.typeToUpdate != null, "Type to update must not be null!");
table = dataAccessStrategy.getTableName(this.typeToUpdate);
} else {
table = this.table;
@@ -1250,6 +1259,7 @@ private UpdatedRowsFetchSpec exchange(SqlIdentifier table) {
mapper = mapper.forType(this.typeToUpdate);
}
+ Assert.state(this.assignments != null, "Update assignments must not be null!");
StatementMapper.UpdateSpec update = mapper.createUpdate(table, this.assignments);
if (this.where != null) {
@@ -1264,11 +1274,11 @@ private UpdatedRowsFetchSpec exchange(SqlIdentifier table) {
class DefaultTypedUpdateSpec implements TypedUpdateSpec, UpdateSpec {
- private final @Nullable Class typeToUpdate;
+ private final Class typeToUpdate;
private final @Nullable SqlIdentifier table;
- private final T objectToUpdate;
+ private final @Nullable T objectToUpdate;
- DefaultTypedUpdateSpec(@Nullable Class typeToUpdate, @Nullable SqlIdentifier table, T objectToUpdate) {
+ DefaultTypedUpdateSpec(Class typeToUpdate, @Nullable SqlIdentifier table, @Nullable T objectToUpdate) {
this.typeToUpdate = typeToUpdate;
this.table = table;
this.objectToUpdate = objectToUpdate;
@@ -1363,9 +1373,9 @@ class DefaultDeleteSpec implements DeleteMatchingSpec, TypedDeleteSpec {
private final @Nullable Class typeToDelete;
private final @Nullable SqlIdentifier table;
- private final Criteria where;
+ private final @Nullable Criteria where;
- DefaultDeleteSpec(@Nullable Class typeToDelete, @Nullable SqlIdentifier table, Criteria where) {
+ DefaultDeleteSpec(@Nullable Class typeToDelete, @Nullable SqlIdentifier table, @Nullable Criteria where) {
this.typeToDelete = typeToDelete;
this.table = table;
this.where = where;
@@ -1393,6 +1403,7 @@ public UpdatedRowsFetchSpec fetch() {
SqlIdentifier table;
if (StringUtils.isEmpty(this.table)) {
+ Assert.state(this.typeToDelete != null, "Type to delete must not be null!");
table = dataAccessStrategy.getTableName(this.typeToDelete);
} else {
table = this.table;
@@ -1432,7 +1443,8 @@ private FetchSpec exchangeInsert(BiFunction mappingF
String sql = getRequiredSql(operation);
Function insertFunction = wrapPreparedOperation(sql, operation)
.andThen(statement -> statement.returnGeneratedValues());
- Function> resultFunction = toExecuteFunction(sql, insertFunction);
+ Function> resultFunction = toFunction(sql, StatementFilterFunctions.empty(),
+ insertFunction);
return new DefaultSqlResult<>(this, //
sql, //
@@ -1445,7 +1457,8 @@ private UpdatedRowsFetchSpec exchangeUpdate(PreparedOperation> operation) {
String sql = getRequiredSql(operation);
Function executeFunction = wrapPreparedOperation(sql, operation);
- Function> resultFunction = toExecuteFunction(sql, executeFunction);
+ Function> resultFunction = toFunction(sql, StatementFilterFunctions.empty(),
+ executeFunction);
return new DefaultSqlResult<>(this, //
sql, //
@@ -1476,12 +1489,15 @@ private Function wrapPreparedOperation(String sql, Prepar
};
}
- private static Function> toExecuteFunction(String sql,
- Function executeFunction) {
+ private Function> toFunction(String sql, StatementFilterFunction filterFunction,
+ Function statementFactory) {
return it -> {
- Flux from = Flux.defer(() -> executeFunction.apply(it).execute()).cast(Result.class);
+ Flux from = Flux.defer(() -> {
+ Statement statement = statementFactory.apply(it);
+ return filterFunction.filter(statement, executeFunction);
+ }).cast(Result.class);
return from.checkpoint("SQL \"" + sql + "\" [DatabaseClient]");
};
}
@@ -1576,9 +1592,7 @@ public Object invoke(Object proxy, Method method, Object[] args) throws Throwabl
// Invoke method on target Connection.
try {
- Object retVal = method.invoke(this.target, args);
-
- return retVal;
+ return method.invoke(this.target, args);
} catch (InvocationTargetException ex) {
throw ex.getTargetException();
}
diff --git a/src/main/java/org/springframework/data/r2dbc/core/DefaultDatabaseClientBuilder.java b/src/main/java/org/springframework/data/r2dbc/core/DefaultDatabaseClientBuilder.java
index c3a186da..5f08ab95 100644
--- a/src/main/java/org/springframework/data/r2dbc/core/DefaultDatabaseClientBuilder.java
+++ b/src/main/java/org/springframework/data/r2dbc/core/DefaultDatabaseClientBuilder.java
@@ -17,6 +17,7 @@
package org.springframework.data.r2dbc.core;
import io.r2dbc.spi.ConnectionFactory;
+import io.r2dbc.spi.Statement;
import java.util.function.Consumer;
@@ -40,6 +41,8 @@ class DefaultDatabaseClientBuilder implements DatabaseClient.Builder {
private @Nullable R2dbcExceptionTranslator exceptionTranslator;
+ private ExecuteFunction executeFunction = Statement::execute;
+
private ReactiveDataAccessStrategy accessStrategy;
private boolean namedParameters = true;
@@ -54,6 +57,7 @@ class DefaultDatabaseClientBuilder implements DatabaseClient.Builder {
this.connectionFactory = other.connectionFactory;
this.exceptionTranslator = other.exceptionTranslator;
+ this.executeFunction = other.executeFunction;
this.accessStrategy = other.accessStrategy;
this.namedParameters = other.namedParameters;
this.projectionFactory = other.projectionFactory;
@@ -85,6 +89,19 @@ public Builder exceptionTranslator(R2dbcExceptionTranslator exceptionTranslator)
return this;
}
+ /*
+ * (non-Javadoc)
+ * @see org.springframework.data.r2dbc.function.DatabaseClient.Builder#executeFunction(org.springframework.data.r2dbc.core.ExecuteFunction)
+ */
+ @Override
+ public Builder executeFunction(ExecuteFunction executeFunction) {
+
+ Assert.notNull(executeFunction, "ExecuteFunction must not be null!");
+
+ this.executeFunction = executeFunction;
+ return this;
+ }
+
/*
* (non-Javadoc)
* @see org.springframework.data.r2dbc.function.DatabaseClient.Builder#dataAccessStrategy(org.springframework.data.r2dbc.function.ReactiveDataAccessStrategy)
@@ -143,8 +160,8 @@ public DatabaseClient build() {
accessStrategy = new DefaultReactiveDataAccessStrategy(dialect);
}
- return new DefaultDatabaseClient(this.connectionFactory, exceptionTranslator, accessStrategy, namedParameters,
- projectionFactory, new DefaultDatabaseClientBuilder(this));
+ return new DefaultDatabaseClient(this.connectionFactory, exceptionTranslator, executeFunction, accessStrategy,
+ namedParameters, projectionFactory, new DefaultDatabaseClientBuilder(this));
}
/*
diff --git a/src/main/java/org/springframework/data/r2dbc/core/ExecuteFunction.java b/src/main/java/org/springframework/data/r2dbc/core/ExecuteFunction.java
new file mode 100644
index 00000000..773916da
--- /dev/null
+++ b/src/main/java/org/springframework/data/r2dbc/core/ExecuteFunction.java
@@ -0,0 +1,46 @@
+/*
+ * Copyright 2020 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.data.r2dbc.core;
+
+import io.r2dbc.spi.Result;
+import io.r2dbc.spi.Statement;
+
+import java.util.function.BiFunction;
+
+import org.reactivestreams.Publisher;
+
+/**
+ * Represents a function that executes a {@link io.r2dbc.spi.Statement} for a (delayed) {@link io.r2dbc.spi.Result}
+ * stream.
+ *
+ * Note that discarded {@link Result} objects must be consumed according to the R2DBC spec via either
+ * {@link Result#getRowsUpdated()} or {@link Result#map(BiFunction)}.
+ *
+ * @author Mark Paluch
+ * @since 1.1
+ * @see Statement#execute()
+ */
+@FunctionalInterface
+public interface ExecuteFunction {
+
+ /**
+ * Execute the given {@link Statement} for a stream of {@link Result}s.
+ *
+ * @param statement the request to execute.
+ * @return the delayed result stream.
+ */
+ Publisher extends Result> execute(Statement statement);
+}
diff --git a/src/main/java/org/springframework/data/r2dbc/core/StatementFilterFunction.java b/src/main/java/org/springframework/data/r2dbc/core/StatementFilterFunction.java
new file mode 100644
index 00000000..520b7ab6
--- /dev/null
+++ b/src/main/java/org/springframework/data/r2dbc/core/StatementFilterFunction.java
@@ -0,0 +1,64 @@
+/*
+ * Copyright 2020 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.data.r2dbc.core;
+
+import io.r2dbc.spi.Result;
+import io.r2dbc.spi.Statement;
+
+import org.reactivestreams.Publisher;
+
+import org.springframework.util.Assert;
+
+/**
+ * Represents a function that filters an {@link ExecuteFunction execute function}.
+ *
+ * The filter is executed when a {@link org.reactivestreams.Subscriber} subscribes to the {@link Publisher} returned by
+ * the {@link DatabaseClient}.
+ *
+ * @author Mark Paluch
+ * @since 1.1
+ * @see ExecuteFunction
+ */
+@FunctionalInterface
+public interface StatementFilterFunction {
+
+ /**
+ * Apply this filter to the given {@link Statement} and {@link ExecuteFunction}.
+ *
+ * The given {@link ExecuteFunction} represents the next entity in the chain, to be invoked via
+ * {@link ExecuteFunction#execute(Statement)} invoked} in order to proceed with the exchange, or not invoked to
+ * shortcut the chain.
+ *
+ * @param statement the current {@link Statement}.
+ * @param next the next exchange function in the chain.
+ * @return the filtered {@link Result}s.
+ */
+ Publisher extends Result> filter(Statement statement, ExecuteFunction next);
+
+ /**
+ * Return a composed filter function that first applies this filter, and then applies the given {@code "after"}
+ * filter.
+ *
+ * @param afterFilter the filter to apply after this filter.
+ * @return the composed filter.
+ */
+ default StatementFilterFunction andThen(StatementFilterFunction afterFilter) {
+
+ Assert.notNull(afterFilter, "StatementFilterFunction must not be null");
+
+ return (request, next) -> filter(request, afterRequest -> afterFilter.filter(afterRequest, next));
+ }
+}
diff --git a/src/main/java/org/springframework/data/r2dbc/core/StatementFilterFunctions.java b/src/main/java/org/springframework/data/r2dbc/core/StatementFilterFunctions.java
new file mode 100644
index 00000000..e9788992
--- /dev/null
+++ b/src/main/java/org/springframework/data/r2dbc/core/StatementFilterFunctions.java
@@ -0,0 +1,46 @@
+/*
+ * Copyright 2020 the original author or authors.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.springframework.data.r2dbc.core;
+
+import io.r2dbc.spi.Result;
+import io.r2dbc.spi.Statement;
+
+import org.reactivestreams.Publisher;
+
+/**
+ * Collection of default {@link StatementFilterFunction}s.
+ *
+ * @author Mark Paluch
+ * @since 1.1
+ */
+enum StatementFilterFunctions implements StatementFilterFunction {
+
+ EMPTY_FILTER;
+
+ @Override
+ public Publisher extends Result> filter(Statement statement, ExecuteFunction next) {
+ return next.execute(statement);
+ }
+
+ /**
+ * Return an empty {@link StatementFilterFunction} that delegates to {@link ExecuteFunction}.
+ *
+ * @return an empty {@link StatementFilterFunction} that delegates to {@link ExecuteFunction}.
+ */
+ public static StatementFilterFunction empty() {
+ return EMPTY_FILTER;
+ }
+}
diff --git a/src/test/java/org/springframework/data/r2dbc/core/DefaultDatabaseClientUnitTests.java b/src/test/java/org/springframework/data/r2dbc/core/DefaultDatabaseClientUnitTests.java
index 5c88fec9..ba8b4372 100644
--- a/src/test/java/org/springframework/data/r2dbc/core/DefaultDatabaseClientUnitTests.java
+++ b/src/test/java/org/springframework/data/r2dbc/core/DefaultDatabaseClientUnitTests.java
@@ -37,48 +37,51 @@
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
+import org.mockito.InOrder;
import org.mockito.Mock;
+import org.mockito.Mockito;
import org.mockito.junit.MockitoJUnitRunner;
import org.reactivestreams.Publisher;
import org.reactivestreams.Subscription;
-
import org.springframework.beans.factory.annotation.Value;
import org.springframework.data.annotation.Id;
import org.springframework.data.projection.SpelAwareProxyProjectionFactory;
import org.springframework.data.r2dbc.dialect.PostgresDialect;
import org.springframework.data.r2dbc.mapping.SettableValue;
-import org.springframework.data.r2dbc.support.R2dbcExceptionTranslator;
+import org.springframework.lang.Nullable;
/**
* Unit tests for {@link DefaultDatabaseClient}.
*
* @author Mark Paluch
* @author Ferdinand Jacobs
+ * @author Jens Schauder
*/
@RunWith(MockitoJUnitRunner.class)
public class DefaultDatabaseClientUnitTests {
- @Mock ConnectionFactory connectionFactory;
@Mock Connection connection;
- @Mock R2dbcExceptionTranslator translator;
+ private DatabaseClient.Builder databaseClientBuilder;
@Before
public void before() {
+
+ ConnectionFactory connectionFactory = Mockito.mock(ConnectionFactory.class);
+
when(connectionFactory.create()).thenReturn((Publisher) Mono.just(connection));
when(connection.close()).thenReturn(Mono.empty());
+
+ databaseClientBuilder = DatabaseClient.builder() //
+ .connectionFactory(connectionFactory) //
+ .dataAccessStrategy(new DefaultReactiveDataAccessStrategy(PostgresDialect.INSTANCE));
}
@Test // gh-48
public void shouldCloseConnectionOnlyOnce() {
- DefaultDatabaseClient databaseClient = (DefaultDatabaseClient) DatabaseClient.builder()
- .connectionFactory(connectionFactory)
- .dataAccessStrategy(new DefaultReactiveDataAccessStrategy(PostgresDialect.INSTANCE))
- .exceptionTranslator(translator).build();
+ DefaultDatabaseClient databaseClient = (DefaultDatabaseClient) databaseClientBuilder.build();
- Flux