Skip to content

Commit 379f32e

Browse files
committed
Add support to reprepare cached queries.
We now reprepare cached queries that were invalidated due to e.g. schema changes. [closes #382] Signed-off-by: Mark Paluch <[email protected]>
1 parent dd07bc4 commit 379f32e

9 files changed

+385
-85
lines changed

src/main/java/io/r2dbc/postgresql/BoundedStatementCache.java

+22-1
Original file line numberDiff line numberDiff line change
@@ -92,9 +92,30 @@ public void put(Binding binding, String sql, String name) {
9292

9393
Map.Entry<CacheKey, String> lastAccessedStatement = getAndRemoveEldest();
9494
ExceptionFactory factory = ExceptionFactory.withSql(lastAccessedStatement.getKey().sql);
95+
String statementName = lastAccessedStatement.getValue();
9596

97+
close(lastAccessedStatement, factory, statementName);
98+
}
99+
100+
@Override
101+
public void evict(String name) {
102+
103+
synchronized (this.cache) {
104+
105+
List<CacheKey> toRemove = new ArrayList<>();
106+
for (Map.Entry<CacheKey, String> entry : this.cache.entrySet()) {
107+
if (entry.getKey().sql.equals(name)) {
108+
toRemove.add(entry.getKey());
109+
}
110+
}
111+
112+
toRemove.forEach(this.cache::remove);
113+
}
114+
}
115+
116+
private void close(Map.Entry<CacheKey, String> lastAccessedStatement, ExceptionFactory factory, String statementName) {
96117
ExtendedQueryMessageFlow
97-
.closeStatement(this.client, lastAccessedStatement.getValue())
118+
.closeStatement(this.client, statementName)
98119
.handle(factory::handleErrorResponse)
99120
.subscribe(it -> {
100121
}, err -> LOGGER.warn(String.format("Cannot close statement %s (%s)", lastAccessedStatement.getValue(), lastAccessedStatement.getKey().sql), err));

src/main/java/io/r2dbc/postgresql/DisabledStatementCache.java

+4
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,10 @@ public boolean requiresPrepare(Binding binding, String sql) {
3939
public void put(Binding binding, String sql, String name) {
4040
}
4141

42+
@Override
43+
public void evict(String sql) {
44+
}
45+
4246
@Override
4347
public String toString() {
4448
return "DisabledStatementCache";

src/main/java/io/r2dbc/postgresql/ExtendedFlowDelegate.java

+192-52
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import io.netty.buffer.ByteBuf;
2020
import io.netty.util.ReferenceCountUtil;
2121
import io.netty.util.ReferenceCounted;
22+
import io.r2dbc.postgresql.api.ErrorDetails;
2223
import io.r2dbc.postgresql.client.Binding;
2324
import io.r2dbc.postgresql.client.Client;
2425
import io.r2dbc.postgresql.client.ExtendedQueryMessageFlow;
@@ -45,14 +46,18 @@
4546
import io.r2dbc.postgresql.util.Operators;
4647
import reactor.core.publisher.Flux;
4748
import reactor.core.publisher.FluxSink;
48-
import reactor.core.publisher.Mono;
4949
import reactor.core.publisher.SynchronousSink;
5050
import reactor.core.publisher.UnicastProcessor;
51+
import reactor.util.annotation.Nullable;
5152
import reactor.util.concurrent.Queues;
5253

5354
import java.util.ArrayList;
55+
import java.util.Arrays;
56+
import java.util.Collection;
5457
import java.util.List;
5558
import java.util.concurrent.atomic.AtomicBoolean;
59+
import java.util.concurrent.atomic.AtomicInteger;
60+
import java.util.function.BiConsumer;
5661
import java.util.function.Predicate;
5762

5863
import static io.r2dbc.postgresql.message.frontend.Execute.NO_LIMIT;
@@ -87,92 +92,81 @@ public static Flux<BackendMessage> runQuery(ConnectionResources resources, Excep
8792
StatementCache cache = resources.getStatementCache();
8893
Client client = resources.getClient();
8994

90-
String name = cache.getName(binding, query);
9195
String portal = resources.getPortalNameSupplier().get();
92-
boolean prepareRequired = cache.requiresPrepare(binding, query);
93-
94-
List<FrontendMessage.DirectEncoder> messagesToSend = new ArrayList<>(6);
95-
96-
if (prepareRequired) {
97-
messagesToSend.add(new Parse(name, binding.getParameterTypes(), query));
98-
}
99-
100-
Bind bind = new Bind(portal, binding.getParameterFormats(), values, ExtendedQueryMessageFlow.resultFormat(resources.getConfiguration().isForceBinary()), name);
101-
102-
messagesToSend.add(bind);
103-
messagesToSend.add(new Describe(portal, PORTAL));
10496

10597
Flux<BackendMessage> exchange;
10698
boolean compatibilityMode = resources.getConfiguration().isCompatibilityMode();
10799
boolean implicitTransactions = resources.getClient().getTransactionStatus() == TransactionStatus.IDLE;
108100

101+
ExtendedFlowOperator operator = new ExtendedFlowOperator(query, binding, cache, values, portal, resources.getConfiguration().isForceBinary());
102+
109103
if (compatibilityMode) {
110104

111105
if (fetchSize == NO_LIMIT || implicitTransactions) {
112-
exchange = fetchAll(messagesToSend, client, portal);
106+
exchange = fetchAll(operator, client, portal);
113107
} else {
114-
exchange = fetchCursoredWithSync(messagesToSend, client, portal, fetchSize);
108+
exchange = fetchCursoredWithSync(operator, client, portal, fetchSize);
115109
}
116110
} else {
117111

118112
if (fetchSize == NO_LIMIT) {
119-
exchange = fetchAll(messagesToSend, client, portal);
113+
exchange = fetchAll(operator, client, portal);
120114
} else {
121-
exchange = fetchCursoredWithFlush(messagesToSend, client, portal, fetchSize);
115+
exchange = fetchCursoredWithFlush(operator, client, portal, fetchSize);
122116
}
123117
}
124118

125-
if (prepareRequired) {
126-
127-
exchange = exchange.doOnNext(message -> {
119+
exchange = exchange.doOnNext(message -> {
128120

129-
if (message == ParseComplete.INSTANCE) {
130-
cache.put(binding, query, name);
131-
}
132-
});
133-
}
121+
if (message == ParseComplete.INSTANCE) {
122+
operator.hydrateStatementCache();
123+
}
124+
});
134125

135126
return exchange.doOnSubscribe(it -> QueryLogger.logQuery(client.getContext(), query)).doOnDiscard(ReferenceCounted.class, ReferenceCountUtil::release).filter(RESULT_FRAME_FILTER).handle(factory::handleErrorResponse);
136127
}
137128

138129
/**
139130
* Execute the query and indicate to fetch all rows with the {@link Execute} message.
140131
*
141-
* @param messagesToSend the initial bind flow
142-
* @param client client to use
143-
* @param portal the portal
132+
* @param operator the flow operator
133+
* @param client client to use
134+
* @param portal the portal
144135
* @return the resulting message stream
145136
*/
146-
private static Flux<BackendMessage> fetchAll(List<FrontendMessage.DirectEncoder> messagesToSend, Client client, String portal) {
137+
private static Flux<BackendMessage> fetchAll(ExtendedFlowOperator operator, Client client, String portal) {
147138

148-
messagesToSend.add(new Execute(portal, NO_LIMIT));
149-
messagesToSend.add(new Close(portal, PORTAL));
150-
messagesToSend.add(Sync.INSTANCE);
139+
UnicastProcessor<FrontendMessage> requestsProcessor = UnicastProcessor.create(Queues.<FrontendMessage>small().get());
140+
FluxSink<FrontendMessage> requestsSink = requestsProcessor.sink();
141+
MessageFactory factory = () -> operator.getMessages(Arrays.asList(new Execute(portal, NO_LIMIT), new Close(portal, PORTAL), Sync.INSTANCE));
151142

152-
return client.exchange(Mono.just(new CompositeFrontendMessage(messagesToSend)))
143+
return client.exchange(operator.takeUntil(), Flux.<FrontendMessage>just(new CompositeFrontendMessage(factory.createMessages())).concatWith(requestsProcessor))
144+
.handle(handleReprepare(requestsSink, operator, factory))
145+
.doFinally(ignore -> operator.close(requestsSink))
153146
.as(Operators::discardOnCancel);
154147
}
155148

156149
/**
157150
* Execute a chunked query and indicate to fetch rows in chunks with the {@link Execute} message.
158151
*
159-
* @param messagesToSend the messages to send
160-
* @param client client to use
161-
* @param portal the portal
162-
* @param fetchSize fetch size per roundtrip
152+
* @param operator the flow operator
153+
* @param client client to use
154+
* @param portal the portal
155+
* @param fetchSize fetch size per roundtrip
163156
* @return the resulting message stream
164157
*/
165-
private static Flux<BackendMessage> fetchCursoredWithSync(List<FrontendMessage.DirectEncoder> messagesToSend, Client client, String portal, int fetchSize) {
158+
private static Flux<BackendMessage> fetchCursoredWithSync(ExtendedFlowOperator operator, Client client, String portal, int fetchSize) {
166159

167160
UnicastProcessor<FrontendMessage> requestsProcessor = UnicastProcessor.create(Queues.<FrontendMessage>small().get());
168161
FluxSink<FrontendMessage> requestsSink = requestsProcessor.sink();
169162
AtomicBoolean isCanceled = new AtomicBoolean(false);
170163
AtomicBoolean done = new AtomicBoolean(false);
171164

172-
messagesToSend.add(new Execute(portal, fetchSize));
173-
messagesToSend.add(Sync.INSTANCE);
165+
MessageFactory factory = () -> operator.getMessages(Arrays.asList(new Execute(portal, fetchSize), Sync.INSTANCE));
166+
Predicate<BackendMessage> takeUntil = operator.takeUntil();
174167

175-
return client.exchange(it -> done.get() && it instanceof ReadyForQuery, Flux.<FrontendMessage>just(new CompositeFrontendMessage(messagesToSend)).concatWith(requestsProcessor))
168+
return client.exchange(it -> done.get() && takeUntil.test(it), Flux.<FrontendMessage>just(new CompositeFrontendMessage(factory.createMessages())).concatWith(requestsProcessor))
169+
.handle(handleReprepare(requestsSink, operator, factory))
176170
.handle((BackendMessage message, SynchronousSink<BackendMessage> sink) -> {
177171

178172
if (message instanceof CommandComplete) {
@@ -211,30 +205,30 @@ private static Flux<BackendMessage> fetchCursoredWithSync(List<FrontendMessage.D
211205
} else {
212206
sink.next(message);
213207
}
214-
}).doFinally(ignore -> requestsSink.complete())
208+
}).doFinally(ignore -> operator.close(requestsSink))
215209
.as(flux -> Operators.discardOnCancel(flux, () -> isCanceled.set(true)));
216210
}
217211

218212
/**
219213
* Execute a contiguous query and indicate to fetch rows in chunks with the {@link Execute} message. Uses {@link Flush}-based synchronization that creates a cursor. Note that flushing keeps the
220214
* cursor open even with implicit transactions and this method may not work with newer pgpool implementations.
221215
*
222-
* @param messagesToSend the messages to send
223-
* @param client client to use
224-
* @param portal the portal
225-
* @param fetchSize fetch size per roundtrip
216+
* @param operator the flow operator
217+
* @param client client to use
218+
* @param portal the portal
219+
* @param fetchSize fetch size per roundtrip
226220
* @return the resulting message stream
227221
*/
228-
private static Flux<BackendMessage> fetchCursoredWithFlush(List<FrontendMessage.DirectEncoder> messagesToSend, Client client, String portal, int fetchSize) {
222+
private static Flux<BackendMessage> fetchCursoredWithFlush(ExtendedFlowOperator operator, Client client, String portal, int fetchSize) {
229223

230224
UnicastProcessor<FrontendMessage> requestsProcessor = UnicastProcessor.create(Queues.<FrontendMessage>small().get());
231225
FluxSink<FrontendMessage> requestsSink = requestsProcessor.sink();
232226
AtomicBoolean isCanceled = new AtomicBoolean(false);
233227

234-
messagesToSend.add(new Execute(portal, fetchSize));
235-
messagesToSend.add(Flush.INSTANCE);
228+
MessageFactory factory = () -> operator.getMessages(Arrays.asList(new Execute(portal, fetchSize), Flush.INSTANCE));
236229

237-
return client.exchange(Flux.<FrontendMessage>just(new CompositeFrontendMessage(messagesToSend)).concatWith(requestsProcessor))
230+
return client.exchange(operator.takeUntil(), Flux.<FrontendMessage>just(new CompositeFrontendMessage(factory.createMessages())).concatWith(requestsProcessor))
231+
.handle(handleReprepare(requestsSink, operator, factory))
238232
.handle((BackendMessage message, SynchronousSink<BackendMessage> sink) -> {
239233

240234
if (message instanceof CommandComplete) {
@@ -258,8 +252,154 @@ private static Flux<BackendMessage> fetchCursoredWithFlush(List<FrontendMessage.
258252
} else {
259253
sink.next(message);
260254
}
261-
}).doFinally(ignore -> requestsSink.complete())
255+
}).doFinally(ignore -> operator.close(requestsSink))
262256
.as(flux -> Operators.discardOnCancel(flux, () -> isCanceled.set(true)));
263257
}
264258

259+
private static BiConsumer<BackendMessage, SynchronousSink<BackendMessage>> handleReprepare(FluxSink<FrontendMessage> requests, ExtendedFlowOperator operator, MessageFactory messageFactory) {
260+
261+
AtomicBoolean reprepared = new AtomicBoolean();
262+
263+
return (message, sink) -> {
264+
265+
if (message instanceof ErrorResponse && requiresReprepare((ErrorResponse) message) && reprepared.compareAndSet(false, true)) {
266+
267+
operator.evictCachedStatement();
268+
269+
List<FrontendMessage.DirectEncoder> messages = messageFactory.createMessages();
270+
if (!messages.contains(Sync.INSTANCE)) {
271+
messages.add(0, Sync.INSTANCE);
272+
}
273+
requests.next(new CompositeFrontendMessage(messages));
274+
} else {
275+
sink.next(message);
276+
}
277+
};
278+
}
279+
280+
private static boolean requiresReprepare(ErrorResponse errorResponse) {
281+
282+
ErrorDetails details = new ErrorDetails(errorResponse.getFields());
283+
String code = details.getCode();
284+
285+
// "prepared statement \"S_2\" does not exist"
286+
// INVALID_SQL_STATEMENT_NAME
287+
if ("26000".equals(code)) {
288+
return true;
289+
}
290+
// NOT_IMPLEMENTED
291+
292+
if (!"0A000".equals(code)) {
293+
return false;
294+
}
295+
296+
String routine = details.getRoutine().orElse(null);
297+
// "cached plan must not change result type"
298+
return "RevalidateCachedQuery".equals(routine) // 9.2+
299+
|| "RevalidateCachedPlan".equals(routine); // <= 9.1
300+
}
301+
302+
interface MessageFactory {
303+
304+
List<FrontendMessage.DirectEncoder> createMessages();
305+
306+
}
307+
308+
/**
309+
* Operator to encapsulate common activity around the extended flow. Subclasses {@link AtomicInteger} to capture the number of ReadyForQuery frames.
310+
*/
311+
static class ExtendedFlowOperator extends AtomicInteger {
312+
313+
private final String sql;
314+
315+
private final Binding binding;
316+
317+
@Nullable
318+
private volatile String name;
319+
320+
private final StatementCache cache;
321+
322+
private final List<ByteBuf> values;
323+
324+
private final String portal;
325+
326+
private final boolean forceBinary;
327+
328+
public ExtendedFlowOperator(String sql, Binding binding, StatementCache cache, List<ByteBuf> values, String portal, boolean forceBinary) {
329+
this.sql = sql;
330+
this.binding = binding;
331+
this.cache = cache;
332+
this.values = values;
333+
this.portal = portal;
334+
this.forceBinary = forceBinary;
335+
set(1);
336+
}
337+
338+
public void close(FluxSink<FrontendMessage> requests) {
339+
requests.complete();
340+
this.values.forEach(ReferenceCountUtil::release);
341+
}
342+
343+
public void evictCachedStatement() {
344+
345+
incrementAndGet();
346+
347+
synchronized (this) {
348+
this.name = null;
349+
}
350+
this.cache.evict(this.sql);
351+
}
352+
353+
public void hydrateStatementCache() {
354+
this.cache.put(this.binding, this.sql, getStatementName());
355+
}
356+
357+
public Predicate<BackendMessage> takeUntil() {
358+
return m -> {
359+
360+
if (m instanceof ReadyForQuery) {
361+
return decrementAndGet() <= 0;
362+
}
363+
364+
return false;
365+
};
366+
}
367+
368+
private boolean isPrepareRequired() {
369+
return this.cache.requiresPrepare(this.binding, this.sql);
370+
}
371+
372+
public String getStatementName() {
373+
synchronized (this) {
374+
375+
if (this.name == null) {
376+
this.name = this.cache.getName(this.binding, this.sql);
377+
}
378+
return this.name;
379+
}
380+
}
381+
382+
public List<FrontendMessage.DirectEncoder> getMessages(Collection<FrontendMessage.DirectEncoder> append) {
383+
List<FrontendMessage.DirectEncoder> messagesToSend = new ArrayList<>(6);
384+
385+
if (isPrepareRequired()) {
386+
messagesToSend.add(new Parse(getStatementName(), this.binding.getParameterTypes(), this.sql));
387+
}
388+
389+
for (ByteBuf value : this.values) {
390+
value.readerIndex(0);
391+
value.touch("ExtendedFlowOperator").retain();
392+
}
393+
394+
Bind bind = new Bind(this.portal, this.binding.getParameterFormats(), this.values, ExtendedQueryMessageFlow.resultFormat(this.forceBinary), getStatementName());
395+
396+
messagesToSend.add(bind);
397+
messagesToSend.add(new Describe(this.portal, PORTAL));
398+
messagesToSend.addAll(append);
399+
400+
return messagesToSend;
401+
}
402+
403+
}
404+
265405
}

0 commit comments

Comments
 (0)