Skip to content

Commit fd7d24c

Browse files
committed
Cancel cursor fetching if the outer stream gets canceled.
[resolves #536] Signed-off-by: Mark Paluch <[email protected]>
1 parent f511bd7 commit fd7d24c

File tree

4 files changed

+81
-31
lines changed

4 files changed

+81
-31
lines changed

src/main/java/io/r2dbc/postgresql/ExtendedFlowDelegate.java

+23-23
Original file line numberDiff line numberDiff line change
@@ -76,21 +76,21 @@ class ExtendedFlowDelegate {
7676
* Execute the {@code Parse/Bind/Describe/Execute/Sync} portion of the <a href="https://www.postgresql.org/docs/current/static/protocol-flow.html#PROTOCOL-FLOW-EXT-QUERY">Extended query</a>
7777
* message flow.
7878
*
79-
* @param resources the {@link ConnectionResources} providing access to the {@link Client}
80-
* @param factory the {@link ExceptionFactory}
81-
* @param query the query to execute
82-
* @param binding the {@link Binding} to bind
83-
* @param values the binding values
84-
* @param fetchSize the fetch size to apply. Use a single {@link Execute} with fetch all if {@code fetchSize} is zero. Otherwise, perform multiple roundtrips with smaller
85-
* {@link Execute} sizes.
79+
* @param resources the {@link ConnectionResources} providing access to the {@link Client}
80+
* @param factory the {@link ExceptionFactory}
81+
* @param query the query to execute
82+
* @param binding the {@link Binding} to bind
83+
* @param values the binding values
84+
* @param fetchSize the fetch size to apply. Use a single {@link Execute} with fetch all if {@code fetchSize} is zero. Otherwise, perform multiple roundtrips with smaller
85+
* {@link Execute} sizes.
86+
* @param isCanceled whether the conversation is canceled
8687
* @return the messages received in response to the exchange
8788
* @throws IllegalArgumentException if {@code bindings}, {@code client}, {@code portalNameSupplier}, or {@code statementName} is {@code null}
8889
*/
89-
public static Flux<BackendMessage> runQuery(ConnectionResources resources, ExceptionFactory factory, String query, Binding binding, List<ByteBuf> values, int fetchSize) {
90+
public static Flux<BackendMessage> runQuery(ConnectionResources resources, ExceptionFactory factory, String query, Binding binding, List<ByteBuf> values, int fetchSize, AtomicBoolean isCanceled) {
9091

9192
StatementCache cache = resources.getStatementCache();
9293
Client client = resources.getClient();
93-
9494
String portal = resources.getPortalNameSupplier().get();
9595

9696
Flux<BackendMessage> exchange;
@@ -104,14 +104,14 @@ public static Flux<BackendMessage> runQuery(ConnectionResources resources, Excep
104104
if (fetchSize == NO_LIMIT || implicitTransactions) {
105105
exchange = fetchAll(operator, client, portal);
106106
} else {
107-
exchange = fetchCursoredWithSync(operator, client, portal, fetchSize);
107+
exchange = fetchCursoredWithSync(operator, client, portal, fetchSize, isCanceled);
108108
}
109109
} else {
110110

111111
if (fetchSize == NO_LIMIT) {
112112
exchange = fetchAll(operator, client, portal);
113113
} else {
114-
exchange = fetchCursoredWithFlush(operator, client, portal, fetchSize);
114+
exchange = fetchCursoredWithFlush(operator, client, portal, fetchSize, isCanceled);
115115
}
116116
}
117117

@@ -147,16 +147,16 @@ private static Flux<BackendMessage> fetchAll(ExtendedFlowOperator operator, Clie
147147
/**
148148
* Execute a chunked query and indicate to fetch rows in chunks with the {@link Execute} message.
149149
*
150-
* @param operator the flow operator
151-
* @param client client to use
152-
* @param portal the portal
153-
* @param fetchSize fetch size per roundtrip
150+
* @param operator the flow operator
151+
* @param client client to use
152+
* @param portal the portal
153+
* @param fetchSize fetch size per roundtrip
154+
* @param isCanceled whether the conversation is canceled
154155
* @return the resulting message stream
155156
*/
156-
private static Flux<BackendMessage> fetchCursoredWithSync(ExtendedFlowOperator operator, Client client, String portal, int fetchSize) {
157+
private static Flux<BackendMessage> fetchCursoredWithSync(ExtendedFlowOperator operator, Client client, String portal, int fetchSize, AtomicBoolean isCanceled) {
157158

158159
Sinks.Many<FrontendMessage> requests = Sinks.many().unicast().onBackpressureBuffer(Queues.<FrontendMessage>small().get());
159-
AtomicBoolean isCanceled = new AtomicBoolean(false);
160160
AtomicBoolean done = new AtomicBoolean(false);
161161

162162
MessageFactory factory = () -> operator.getMessages(Arrays.asList(new Execute(portal, fetchSize), Sync.INSTANCE));
@@ -210,16 +210,16 @@ private static Flux<BackendMessage> fetchCursoredWithSync(ExtendedFlowOperator o
210210
* Execute a contiguous query and indicate to fetch rows in chunks with the {@link Execute} message. Uses {@link Flush}-based synchronization that creates a cursor. Note that flushing keeps the
211211
* cursor open even with implicit transactions and this method may not work with newer pgpool implementations.
212212
*
213-
* @param operator the flow operator
214-
* @param client client to use
215-
* @param portal the portal
216-
* @param fetchSize fetch size per roundtrip
213+
* @param operator the flow operator
214+
* @param client client to use
215+
* @param portal the portal
216+
* @param fetchSize fetch size per roundtrip
217+
* @param isCanceled whether the conversation is canceled
217218
* @return the resulting message stream
218219
*/
219-
private static Flux<BackendMessage> fetchCursoredWithFlush(ExtendedFlowOperator operator, Client client, String portal, int fetchSize) {
220+
private static Flux<BackendMessage> fetchCursoredWithFlush(ExtendedFlowOperator operator, Client client, String portal, int fetchSize, AtomicBoolean isCanceled) {
220221

221222
Sinks.Many<FrontendMessage> requests = Sinks.many().unicast().onBackpressureBuffer(Queues.<FrontendMessage>small().get());
222-
AtomicBoolean isCanceled = new AtomicBoolean(false);
223223

224224
MessageFactory factory = () -> operator.getMessages(Arrays.asList(new Execute(portal, fetchSize), Flush.INSTANCE));
225225

src/main/java/io/r2dbc/postgresql/PostgresqlResult.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ public Mono<Long> getRowsUpdated() {
109109
public <T> Flux<T> map(BiFunction<Row, RowMetadata, ? extends T> f) {
110110
Assert.requireNonNull(f, "f must not be null");
111111

112-
return this.messages
112+
return (Flux<T>) this.messages
113113
.handle((message, sink) -> {
114114

115115
try {

src/main/java/io/r2dbc/postgresql/PostgresqlStatement.java

+17-5
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
import java.util.Iterator;
4545
import java.util.List;
4646
import java.util.NoSuchElementException;
47+
import java.util.concurrent.CompletableFuture;
4748
import java.util.concurrent.atomic.AtomicBoolean;
4849
import java.util.function.Predicate;
4950

@@ -199,6 +200,9 @@ private int getIdentifierIndex(String identifier) {
199200
private Flux<io.r2dbc.postgresql.api.PostgresqlResult> execute(String sql) {
200201
ExceptionFactory factory = ExceptionFactory.withSql(sql);
201202

203+
CompletableFuture<Void> onCancel = new CompletableFuture<>();
204+
AtomicBoolean canceled = new AtomicBoolean();
205+
202206
if (this.parsedSql.getParameterCount() != 0) {
203207
// Extended query protocol
204208
if (this.bindings.size() == 0) {
@@ -213,17 +217,22 @@ private Flux<io.r2dbc.postgresql.api.PostgresqlResult> execute(String sql) {
213217
if (this.bindings.size() == 1) {
214218

215219
Binding binding = this.bindings.peekFirst();
216-
Flux<BackendMessage> messages = collectBindingParameters(binding).flatMapMany(values -> ExtendedFlowDelegate.runQuery(this.resources, factory, sql, binding, values, fetchSize));
220+
Flux<BackendMessage> messages = collectBindingParameters(binding).flatMapMany(values -> ExtendedFlowDelegate.runQuery(this.resources, factory, sql, binding, values, fetchSize,
221+
new AtomicBoolean()));
217222
return Flux.just(PostgresqlResult.toResult(this.resources, messages, factory));
218223
}
219224

220225
Iterator<Binding> iterator = this.bindings.iterator();
221226
Sinks.Many<Binding> bindings = Sinks.many().unicast().onBackpressureBuffer();
222-
AtomicBoolean canceled = new AtomicBoolean();
227+
228+
onCancel.whenComplete((unused, throwable) -> {
229+
clearBindings(iterator, canceled);
230+
});
231+
223232
return bindings.asFlux()
224233
.map(it -> {
225234
Flux<BackendMessage> messages =
226-
collectBindingParameters(it).flatMapMany(values -> ExtendedFlowDelegate.runQuery(this.resources, factory, sql, it, values, this.fetchSize)).doOnComplete(() -> tryNextBinding(iterator, bindings, canceled));
235+
collectBindingParameters(it).flatMapMany(values -> ExtendedFlowDelegate.runQuery(this.resources, factory, sql, it, values, this.fetchSize, canceled)).doOnComplete(() -> tryNextBinding(iterator, bindings, canceled));
227236

228237
return PostgresqlResult.toResult(this.resources, messages, factory);
229238
})
@@ -237,15 +246,18 @@ private Flux<io.r2dbc.postgresql.api.PostgresqlResult> execute(String sql) {
237246
Flux<BackendMessage> exchange;
238247
// Simple Query protocol
239248
if (this.fetchSize != NO_LIMIT) {
240-
exchange = ExtendedFlowDelegate.runQuery(this.resources, factory, sql, Binding.EMPTY, Collections.emptyList(), this.fetchSize);
249+
exchange = ExtendedFlowDelegate.runQuery(this.resources, factory, sql, Binding.EMPTY, Collections.emptyList(), this.fetchSize, canceled);
241250
} else {
242251
exchange = SimpleQueryMessageFlow.exchange(this.resources.getClient(), sql);
243252
}
244253

245254
return exchange.windowUntil(WINDOW_UNTIL)
246255
.doOnDiscard(ReferenceCounted.class, ReferenceCountUtil::release) // ensure release of rows within WindowPredicate
247256
.map(messages -> PostgresqlResult.toResult(this.resources, messages, factory))
248-
.as(Operators::discardOnCancel);
257+
.as(source -> Operators.discardOnCancel(source, () -> {
258+
canceled.set(true);
259+
onCancel.complete(null);
260+
}));
249261
}
250262

251263
private static void tryNextBinding(Iterator<Binding> iterator, Sinks.Many<Binding> bindingSink, AtomicBoolean canceled) {

src/test/java/io/r2dbc/postgresql/PostgresCancelIntegrationTests.java

+40-2
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
import org.junit.jupiter.api.RepeatedTest;
2424
import org.junit.jupiter.api.Test;
2525
import org.junit.jupiter.api.TestInstance;
26+
import org.junit.jupiter.api.Timeout;
27+
import org.springframework.jdbc.core.JdbcOperations;
2628
import reactor.core.publisher.Mono;
2729
import reactor.test.StepVerifier;
2830

@@ -42,12 +44,18 @@ void setUp() {
4244

4345
super.setUp();
4446

45-
SERVER.getJdbcOperations().execute("DROP TABLE IF EXISTS insert_test;");
46-
SERVER.getJdbcOperations().execute("CREATE TABLE insert_test\n" +
47+
JdbcOperations jdbc = SERVER.getJdbcOperations();
48+
jdbc.execute("DROP TABLE IF EXISTS insert_test;");
49+
jdbc.execute("CREATE TABLE insert_test\n" +
4750
"(\n" +
4851
" id SERIAL PRIMARY KEY,\n" +
4952
" value CHAR(1) NOT NULL\n" +
5053
");");
54+
55+
56+
jdbc.execute("DROP TABLE IF EXISTS lots_of_data;");
57+
jdbc.execute("CREATE TABLE lots_of_data AS \n"
58+
+ " SELECT i FROM generate_series(1,200000) as i;");
5159
}
5260

5361
@AfterAll
@@ -111,4 +119,34 @@ void cancelRequest() {
111119
.verify(Duration.ofSeconds(5));
112120
}
113121

122+
@Timeout(10)
123+
@RepeatedTest(20)
124+
void shouldCancelParametrizedWithFetchSize() {
125+
126+
this.connection.createStatement("SELECT * FROM lots_of_data WHERE $1 = $1 ORDER BY i")
127+
.fetchSize(10)
128+
.bind(0, 1)
129+
.execute()
130+
.flatMap(r -> r.map((row, meta) -> row.get(0, Integer.class)))
131+
.as(StepVerifier::create)
132+
.expectNext(1)
133+
.expectNextCount(5)
134+
.thenCancel()
135+
.verify(Duration.ofSeconds(5));
136+
}
137+
138+
@Timeout(10)
139+
@RepeatedTest(20)
140+
void shouldCancelSimpleWithFetchSize() {
141+
142+
this.connection.createStatement("SELECT * FROM lots_of_data ORDER BY i")
143+
.fetchSize(10)
144+
.execute()
145+
.flatMap(r -> r.map((row, meta) -> row.get(0, Integer.class)))
146+
.as(StepVerifier::create)
147+
.expectNext(1)
148+
.expectNextCount(5)
149+
.thenCancel()
150+
.verify(Duration.ofSeconds(5));
151+
}
114152
}

0 commit comments

Comments
 (0)