Commit 7f57957

Add support to reprepare cached queries.

We now reprepare cached queries that were invalidated due to e.g. schema changes. [closes #382]

Signed-off-by: Mark Paluch <[email protected]>

1 parent 40c2be0   commit 7f57957

9 files changed: +384 −85 lines changed
src/main/java/io/r2dbc/postgresql/BoundedStatementCache.java  +22 −1

@@ -92,9 +92,30 @@ public void put(Binding binding, String sql, String name) {

         Map.Entry<CacheKey, String> lastAccessedStatement = getAndRemoveEldest();
         ExceptionFactory factory = ExceptionFactory.withSql(lastAccessedStatement.getKey().sql);
+        String statementName = lastAccessedStatement.getValue();

+        close(lastAccessedStatement, factory, statementName);
+    }
+
+    @Override
+    public void evict(String name) {
+
+        synchronized (this.cache) {
+
+            List<CacheKey> toRemove = new ArrayList<>();
+            for (Map.Entry<CacheKey, String> entry : this.cache.entrySet()) {
+                if (entry.getKey().sql.equals(name)) {
+                    toRemove.add(entry.getKey());
+                }
+            }
+
+            toRemove.forEach(this.cache::remove);
+        }
+    }
+
+    private void close(Map.Entry<CacheKey, String> lastAccessedStatement, ExceptionFactory factory, String statementName) {
         ExtendedQueryMessageFlow
-            .closeStatement(this.client, lastAccessedStatement.getValue())
+            .closeStatement(this.client, statementName)
             .handle(factory::handleErrorResponse)
             .subscribe(it -> {
             }, err -> LOGGER.warn(String.format("Cannot close statement %s (%s)", lastAccessedStatement.getValue(), lastAccessedStatement.getKey().sql), err));
src/main/java/io/r2dbc/postgresql/DisabledStatementCache.java  +4 −0

@@ -39,6 +39,10 @@ public boolean requiresPrepare(Binding binding, String sql) {
     public void put(Binding binding, String sql, String name) {
     }

+    @Override
+    public void evict(String sql) {
+    }
+
     @Override
     public String toString() {
         return "DisabledStatementCache";
src/main/java/io/r2dbc/postgresql/ExtendedFlowDelegate.java  +191 −52

@@ -19,6 +19,7 @@
 import io.netty.buffer.ByteBuf;
 import io.netty.util.ReferenceCountUtil;
 import io.netty.util.ReferenceCounted;
+import io.r2dbc.postgresql.api.ErrorDetails;
 import io.r2dbc.postgresql.client.Binding;
 import io.r2dbc.postgresql.client.Client;
 import io.r2dbc.postgresql.client.ExtendedQueryMessageFlow;
@@ -44,14 +45,18 @@
 import io.r2dbc.postgresql.message.frontend.Sync;
 import io.r2dbc.postgresql.util.Operators;
 import reactor.core.publisher.Flux;
-import reactor.core.publisher.Mono;
 import reactor.core.publisher.Sinks;
 import reactor.core.publisher.SynchronousSink;
+import reactor.util.annotation.Nullable;
 import reactor.util.concurrent.Queues;

 import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
 import java.util.List;
 import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.BiConsumer;
 import java.util.function.Predicate;

 import static io.r2dbc.postgresql.message.frontend.Execute.NO_LIMIT;
@@ -86,91 +91,79 @@ public static Flux<BackendMessage> runQuery(ConnectionResources resources, Excep
         StatementCache cache = resources.getStatementCache();
         Client client = resources.getClient();

-        String name = cache.getName(binding, query);
         String portal = resources.getPortalNameSupplier().get();
-        boolean prepareRequired = cache.requiresPrepare(binding, query);
-
-        List<FrontendMessage.DirectEncoder> messagesToSend = new ArrayList<>(6);
-
-        if (prepareRequired) {
-            messagesToSend.add(new Parse(name, binding.getParameterTypes(), query));
-        }
-
-        Bind bind = new Bind(portal, binding.getParameterFormats(), values, ExtendedQueryMessageFlow.resultFormat(resources.getConfiguration().isForceBinary()), name);
-
-        messagesToSend.add(bind);
-        messagesToSend.add(new Describe(portal, PORTAL));

         Flux<BackendMessage> exchange;
         boolean compatibilityMode = resources.getConfiguration().isCompatibilityMode();
         boolean implicitTransactions = resources.getClient().getTransactionStatus() == TransactionStatus.IDLE;

+        ExtendedFlowOperator operator = new ExtendedFlowOperator(query, binding, cache, values, portal, resources.getConfiguration().isForceBinary());
+
         if (compatibilityMode) {

             if (fetchSize == NO_LIMIT || implicitTransactions) {
-                exchange = fetchAll(messagesToSend, client, portal);
+                exchange = fetchAll(operator, client, portal);
             } else {
-                exchange = fetchCursoredWithSync(messagesToSend, client, portal, fetchSize);
+                exchange = fetchCursoredWithSync(operator, client, portal, fetchSize);
             }
         } else {

             if (fetchSize == NO_LIMIT) {
-                exchange = fetchAll(messagesToSend, client, portal);
+                exchange = fetchAll(operator, client, portal);
             } else {
-                exchange = fetchCursoredWithFlush(messagesToSend, client, portal, fetchSize);
+                exchange = fetchCursoredWithFlush(operator, client, portal, fetchSize);
             }
         }

-        if (prepareRequired) {
-
-            exchange = exchange.doOnNext(message -> {
+        exchange = exchange.doOnNext(message -> {

-                if (message == ParseComplete.INSTANCE) {
-                    cache.put(binding, query, name);
-                }
-            });
-        }
+            if (message == ParseComplete.INSTANCE) {
+                operator.hydrateStatementCache();
+            }
+        });

         return exchange.doOnSubscribe(it -> QueryLogger.logQuery(client.getContext(), query)).doOnDiscard(ReferenceCounted.class, ReferenceCountUtil::release).filter(RESULT_FRAME_FILTER).handle(factory::handleErrorResponse);
     }

     /**
      * Execute the query and indicate to fetch all rows with the {@link Execute} message.
      *
-     * @param messagesToSend the initial bind flow
-     * @param client         client to use
-     * @param portal         the portal
+     * @param operator the flow operator
+     * @param client   client to use
+     * @param portal   the portal
      * @return the resulting message stream
      */
-    private static Flux<BackendMessage> fetchAll(List<FrontendMessage.DirectEncoder> messagesToSend, Client client, String portal) {
+    private static Flux<BackendMessage> fetchAll(ExtendedFlowOperator operator, Client client, String portal) {

-        messagesToSend.add(new Execute(portal, NO_LIMIT));
-        messagesToSend.add(new Close(portal, PORTAL));
-        messagesToSend.add(Sync.INSTANCE);
+        Sinks.Many<FrontendMessage> requests = Sinks.many().unicast().onBackpressureBuffer(Queues.<FrontendMessage>small().get());
+        MessageFactory factory = () -> operator.getMessages(Arrays.asList(new Execute(portal, NO_LIMIT), new Close(portal, PORTAL), Sync.INSTANCE));

-        return client.exchange(Mono.just(new CompositeFrontendMessage(messagesToSend)))
+        return client.exchange(operator.takeUntil(), Flux.<FrontendMessage>just(new CompositeFrontendMessage(factory.createMessages())).concatWith(requests.asFlux()))
+            .handle(handleReprepare(requests, operator, factory))
+            .doFinally(ignore -> operator.close(requests))
             .as(Operators::discardOnCancel);
     }

     /**
      * Execute a chunked query and indicate to fetch rows in chunks with the {@link Execute} message.
      *
-     * @param messagesToSend the messages to send
-     * @param client         client to use
-     * @param portal         the portal
-     * @param fetchSize      fetch size per roundtrip
+     * @param operator  the flow operator
+     * @param client    client to use
+     * @param portal    the portal
+     * @param fetchSize fetch size per roundtrip
      * @return the resulting message stream
      */
-    private static Flux<BackendMessage> fetchCursoredWithSync(List<FrontendMessage.DirectEncoder> messagesToSend, Client client, String portal, int fetchSize) {
+    private static Flux<BackendMessage> fetchCursoredWithSync(ExtendedFlowOperator operator, Client client, String portal, int fetchSize) {

         Sinks.Many<FrontendMessage> requests = Sinks.many().unicast().onBackpressureBuffer(Queues.<FrontendMessage>small().get());
         AtomicBoolean isCanceled = new AtomicBoolean(false);
         AtomicBoolean done = new AtomicBoolean(false);

-        messagesToSend.add(new Execute(portal, fetchSize));
-        messagesToSend.add(Sync.INSTANCE);
+        MessageFactory factory = () -> operator.getMessages(Arrays.asList(new Execute(portal, fetchSize), Sync.INSTANCE));
+        Predicate<BackendMessage> takeUntil = operator.takeUntil();

-        return client.exchange(it -> done.get() && it instanceof ReadyForQuery, Flux.<FrontendMessage>just(new CompositeFrontendMessage(messagesToSend)).concatWith(requests.asFlux()))
+        return client.exchange(it -> done.get() && takeUntil.test(it), Flux.<FrontendMessage>just(new CompositeFrontendMessage(factory.createMessages())).concatWith(requests.asFlux()))
+            .handle(handleReprepare(requests, operator, factory))
             .handle((BackendMessage message, SynchronousSink<BackendMessage> sink) -> {

                 if (message instanceof CommandComplete) {
@@ -209,29 +202,29 @@ private static Flux<BackendMessage> fetchCursoredWithSync(List<FrontendMessage.D
                 } else {
                     sink.next(message);
                 }
-            }).doFinally(ignore -> requests.emitComplete(Sinks.EmitFailureHandler.FAIL_FAST))
+            }).doFinally(ignore -> operator.close(requests))
             .as(flux -> Operators.discardOnCancel(flux, () -> isCanceled.set(true)));
     }

     /**
      * Execute a contiguous query and indicate to fetch rows in chunks with the {@link Execute} message. Uses {@link Flush}-based synchronization that creates a cursor. Note that flushing keeps the
      * cursor open even with implicit transactions and this method may not work with newer pgpool implementations.
      *
-     * @param messagesToSend the messages to send
-     * @param client         client to use
-     * @param portal         the portal
-     * @param fetchSize      fetch size per roundtrip
+     * @param operator  the flow operator
+     * @param client    client to use
+     * @param portal    the portal
+     * @param fetchSize fetch size per roundtrip
      * @return the resulting message stream
      */
-    private static Flux<BackendMessage> fetchCursoredWithFlush(List<FrontendMessage.DirectEncoder> messagesToSend, Client client, String portal, int fetchSize) {
+    private static Flux<BackendMessage> fetchCursoredWithFlush(ExtendedFlowOperator operator, Client client, String portal, int fetchSize) {

         Sinks.Many<FrontendMessage> requests = Sinks.many().unicast().onBackpressureBuffer(Queues.<FrontendMessage>small().get());
         AtomicBoolean isCanceled = new AtomicBoolean(false);

-        messagesToSend.add(new Execute(portal, fetchSize));
-        messagesToSend.add(Flush.INSTANCE);
+        MessageFactory factory = () -> operator.getMessages(Arrays.asList(new Execute(portal, fetchSize), Flush.INSTANCE));

-        return client.exchange(Flux.<FrontendMessage>just(new CompositeFrontendMessage(messagesToSend)).concatWith(requests.asFlux()))
+        return client.exchange(operator.takeUntil(), Flux.<FrontendMessage>just(new CompositeFrontendMessage(factory.createMessages())).concatWith(requests.asFlux()))
+            .handle(handleReprepare(requests, operator, factory))
             .handle((BackendMessage message, SynchronousSink<BackendMessage> sink) -> {

                 if (message instanceof CommandComplete) {
@@ -255,8 +248,154 @@ private static Flux<BackendMessage> fetchCursoredWithFlush(List<FrontendMessage.
                 } else {
                     sink.next(message);
                 }
-            }).doFinally(ignore -> requests.emitComplete(Sinks.EmitFailureHandler.FAIL_FAST))
+            }).doFinally(ignore -> operator.close(requests))
             .as(flux -> Operators.discardOnCancel(flux, () -> isCanceled.set(true)));
     }

+    private static BiConsumer<BackendMessage, SynchronousSink<BackendMessage>> handleReprepare(Sinks.Many<FrontendMessage> requests, ExtendedFlowOperator operator, MessageFactory messageFactory) {
+
+        AtomicBoolean reprepared = new AtomicBoolean();
+
+        return (message, sink) -> {
+
+            if (message instanceof ErrorResponse && requiresReprepare((ErrorResponse) message) && reprepared.compareAndSet(false, true)) {
+
+                operator.evictCachedStatement();
+
+                List<FrontendMessage.DirectEncoder> messages = messageFactory.createMessages();
+                if (!messages.contains(Sync.INSTANCE)) {
+                    messages.add(0, Sync.INSTANCE);
+                }
+                requests.emitNext(new CompositeFrontendMessage(messages), Sinks.EmitFailureHandler.FAIL_FAST);
+            } else {
+                sink.next(message);
+            }
+        };
+    }
+
+    private static boolean requiresReprepare(ErrorResponse errorResponse) {
+
+        ErrorDetails details = new ErrorDetails(errorResponse.getFields());
+        String code = details.getCode();
+
+        // "prepared statement \"S_2\" does not exist"
+        // INVALID_SQL_STATEMENT_NAME
+        if ("26000".equals(code)) {
+            return true;
+        }
+        // NOT_IMPLEMENTED
+
+        if (!"0A000".equals(code)) {
+            return false;
+        }
+
+        String routine = details.getRoutine().orElse(null);
+        // "cached plan must not change result type"
+        return "RevalidateCachedQuery".equals(routine) // 9.2+
+            || "RevalidateCachedPlan".equals(routine); // <= 9.1
+    }
+
+    interface MessageFactory {
+
+        List<FrontendMessage.DirectEncoder> createMessages();
+
+    }
+
+    /**
+     * Operator to encapsulate common activity around the extended flow. Subclasses {@link AtomicInteger} to capture the number of ReadyForQuery frames.
+     */
+    static class ExtendedFlowOperator extends AtomicInteger {
+
+        private final String sql;
+
+        private final Binding binding;
+
+        @Nullable
+        private volatile String name;
+
+        private final StatementCache cache;
+
+        private final List<ByteBuf> values;
+
+        private final String portal;
+
+        private final boolean forceBinary;
+
+        public ExtendedFlowOperator(String sql, Binding binding, StatementCache cache, List<ByteBuf> values, String portal, boolean forceBinary) {
+            this.sql = sql;
+            this.binding = binding;
+            this.cache = cache;
+            this.values = values;
+            this.portal = portal;
+            this.forceBinary = forceBinary;
+            set(1);
+        }
+
+        public void close(Sinks.Many<FrontendMessage> requests) {
+            requests.emitComplete(Sinks.EmitFailureHandler.FAIL_FAST);
+            this.values.forEach(ReferenceCountUtil::release);
+        }
+
+        public void evictCachedStatement() {
+
+            incrementAndGet();
+
+            synchronized (this) {
+                this.name = null;
+            }
+            this.cache.evict(this.sql);
+        }
+
+        public void hydrateStatementCache() {
+            this.cache.put(this.binding, this.sql, getStatementName());
+        }
+
+        public Predicate<BackendMessage> takeUntil() {
+            return m -> {
+
+                if (m instanceof ReadyForQuery) {
+                    return decrementAndGet() <= 0;
+                }
+
+                return false;
+            };
+        }
+
+        private boolean isPrepareRequired() {
+            return this.cache.requiresPrepare(this.binding, this.sql);
+        }
+
+        public String getStatementName() {
+            synchronized (this) {
+
+                if (this.name == null) {
+                    this.name = this.cache.getName(this.binding, this.sql);
+                }
+                return this.name;
+            }
+        }
+
+        public List<FrontendMessage.DirectEncoder> getMessages(Collection<FrontendMessage.DirectEncoder> append) {
+            List<FrontendMessage.DirectEncoder> messagesToSend = new ArrayList<>(6);
+
+            if (isPrepareRequired()) {
+                messagesToSend.add(new Parse(getStatementName(), this.binding.getParameterTypes(), this.sql));
+            }
+
+            for (ByteBuf value : this.values) {
+                value.readerIndex(0);
+                value.touch("ExtendedFlowOperator").retain();
+            }
+
+            Bind bind = new Bind(this.portal, this.binding.getParameterFormats(), this.values, ExtendedQueryMessageFlow.resultFormat(this.forceBinary), getStatementName());
+
+            messagesToSend.add(bind);
+            messagesToSend.add(new Describe(this.portal, PORTAL));
+            messagesToSend.addAll(append);
+
+            return messagesToSend;
+        }
+
+    }
+
 }
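A note on the bookkeeping above: in the extended protocol the backend discards incoming messages after an ErrorResponse until it receives a Sync, and every Sync is answered with a ReadyForQuery. Re-queuing the Parse/Bind/Describe/Execute flow therefore introduces one additional ReadyForQuery per reprepare, which is why ExtendedFlowOperator starts its inherited AtomicInteger at 1 and increments it in evictCachedStatement(). A simplified fragment modelling that counting logic, reusing the message types imported in this file (names are illustrative):

    // Simplified model of the ReadyForQuery accounting in ExtendedFlowOperator:
    // start expecting one ReadyForQuery (for the initial Sync); each reprepare queues
    // another Sync and therefore one more ReadyForQuery before the exchange may complete.
    AtomicInteger pendingReadyForQuery = new AtomicInteger(1);

    Runnable onReprepare = pendingReadyForQuery::incrementAndGet; // called when an extra Sync is queued

    Predicate<BackendMessage> takeUntil = message -> {
        if (message instanceof ReadyForQuery) {
            return pendingReadyForQuery.decrementAndGet() <= 0;   // complete only on the last one
        }
        return false;
    };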

src/main/java/io/r2dbc/postgresql/IndefiniteStatementCache.java  +5 −0

@@ -70,6 +70,11 @@ public void put(Binding binding, String sql, String name) {
         typedMap.put(binding.getParameterTypes(), name);
     }

+    @Override
+    public void evict(String sql) {
+        this.cache.remove(sql);
+    }
+
     private Map<int[], String> getTypeMap(String sql) {

         return this.cache.computeIfAbsent(sql, ignore -> new TreeMap<>((o1, o2) -> {
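The StatementCache interface itself is among the nine changed files but is not shown on this page; judging from the @Override methods above, the new contract presumably looks like the following sketch (signature inferred from the implementations, Javadoc wording assumed):

    // Presumed addition to io.r2dbc.postgresql.StatementCache (file not shown in this excerpt):

    /**
     * Evict all prepared statements that were cached for the given SQL string,
     * for example after the server reported that the cached plan is no longer valid.
     *
     * @param sql the SQL text whose cached statements should be discarded
     */
    void evict(String sql);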
