19
19
import io .netty .buffer .ByteBuf ;
20
20
import io .netty .util .ReferenceCountUtil ;
21
21
import io .netty .util .ReferenceCounted ;
22
+ import io .r2dbc .postgresql .api .ErrorDetails ;
22
23
import io .r2dbc .postgresql .client .Binding ;
23
24
import io .r2dbc .postgresql .client .Client ;
24
25
import io .r2dbc .postgresql .client .ExtendedQueryMessageFlow ;
45
46
import io .r2dbc .postgresql .util .Operators ;
46
47
import reactor .core .publisher .Flux ;
47
48
import reactor .core .publisher .FluxSink ;
48
- import reactor .core .publisher .Mono ;
49
49
import reactor .core .publisher .SynchronousSink ;
50
50
import reactor .core .publisher .UnicastProcessor ;
51
+ import reactor .util .annotation .Nullable ;
51
52
import reactor .util .concurrent .Queues ;
52
53
53
54
import java .util .ArrayList ;
55
+ import java .util .Arrays ;
56
+ import java .util .Collection ;
54
57
import java .util .List ;
55
58
import java .util .concurrent .atomic .AtomicBoolean ;
59
+ import java .util .concurrent .atomic .AtomicInteger ;
60
+ import java .util .function .BiConsumer ;
56
61
import java .util .function .Predicate ;
57
62
58
63
import static io .r2dbc .postgresql .message .frontend .Execute .NO_LIMIT ;
@@ -87,92 +92,81 @@ public static Flux<BackendMessage> runQuery(ConnectionResources resources, Excep
87
92
StatementCache cache = resources .getStatementCache ();
88
93
Client client = resources .getClient ();
89
94
90
- String name = cache .getName (binding , query );
91
95
String portal = resources .getPortalNameSupplier ().get ();
92
- boolean prepareRequired = cache .requiresPrepare (binding , query );
93
-
94
- List <FrontendMessage .DirectEncoder > messagesToSend = new ArrayList <>(6 );
95
-
96
- if (prepareRequired ) {
97
- messagesToSend .add (new Parse (name , binding .getParameterTypes (), query ));
98
- }
99
-
100
- Bind bind = new Bind (portal , binding .getParameterFormats (), values , ExtendedQueryMessageFlow .resultFormat (resources .getConfiguration ().isForceBinary ()), name );
101
-
102
- messagesToSend .add (bind );
103
- messagesToSend .add (new Describe (portal , PORTAL ));
104
96
105
97
Flux <BackendMessage > exchange ;
106
98
boolean compatibilityMode = resources .getConfiguration ().isCompatibilityMode ();
107
99
boolean implicitTransactions = resources .getClient ().getTransactionStatus () == TransactionStatus .IDLE ;
108
100
101
+ ExtendedFlowOperator operator = new ExtendedFlowOperator (query , binding , cache , values , portal , resources .getConfiguration ().isForceBinary ());
102
+
109
103
if (compatibilityMode ) {
110
104
111
105
if (fetchSize == NO_LIMIT || implicitTransactions ) {
112
- exchange = fetchAll (messagesToSend , client , portal );
106
+ exchange = fetchAll (operator , client , portal );
113
107
} else {
114
- exchange = fetchCursoredWithSync (messagesToSend , client , portal , fetchSize );
108
+ exchange = fetchCursoredWithSync (operator , client , portal , fetchSize );
115
109
}
116
110
} else {
117
111
118
112
if (fetchSize == NO_LIMIT ) {
119
- exchange = fetchAll (messagesToSend , client , portal );
113
+ exchange = fetchAll (operator , client , portal );
120
114
} else {
121
- exchange = fetchCursoredWithFlush (messagesToSend , client , portal , fetchSize );
115
+ exchange = fetchCursoredWithFlush (operator , client , portal , fetchSize );
122
116
}
123
117
}
124
118
125
- if (prepareRequired ) {
126
-
127
- exchange = exchange .doOnNext (message -> {
119
+ exchange = exchange .doOnNext (message -> {
128
120
129
- if (message == ParseComplete .INSTANCE ) {
130
- cache .put (binding , query , name );
131
- }
132
- });
133
- }
121
+ if (message == ParseComplete .INSTANCE ) {
122
+ operator .hydrateStatementCache ();
123
+ }
124
+ });
134
125
135
126
return exchange .doOnSubscribe (it -> QueryLogger .logQuery (client .getContext (), query )).doOnDiscard (ReferenceCounted .class , ReferenceCountUtil ::release ).filter (RESULT_FRAME_FILTER ).handle (factory ::handleErrorResponse );
136
127
}
137
128
138
129
/**
139
130
* Execute the query and indicate to fetch all rows with the {@link Execute} message.
140
131
*
141
- * @param messagesToSend the initial bind flow
142
- * @param client client to use
143
- * @param portal the portal
132
+ * @param operator the flow operator
133
+ * @param client client to use
134
+ * @param portal the portal
144
135
* @return the resulting message stream
145
136
*/
146
- private static Flux <BackendMessage > fetchAll (List < FrontendMessage . DirectEncoder > messagesToSend , Client client , String portal ) {
137
+ private static Flux <BackendMessage > fetchAll (ExtendedFlowOperator operator , Client client , String portal ) {
147
138
148
- messagesToSend . add ( new Execute ( portal , NO_LIMIT ));
149
- messagesToSend . add ( new Close ( portal , PORTAL ) );
150
- messagesToSend . add ( Sync .INSTANCE );
139
+ UnicastProcessor < FrontendMessage > requestsProcessor = UnicastProcessor . create ( Queues .< FrontendMessage > small (). get ( ));
140
+ FluxSink < FrontendMessage > requestsSink = requestsProcessor . sink ( );
141
+ MessageFactory factory = () -> operator . getMessages ( Arrays . asList ( new Execute ( portal , NO_LIMIT ), new Close ( portal , PORTAL ), Sync .INSTANCE ) );
151
142
152
- return client .exchange (Mono .just (new CompositeFrontendMessage (messagesToSend )))
143
+ return client .exchange (operator .takeUntil (), Flux .<FrontendMessage >just (new CompositeFrontendMessage (factory .createMessages ())).concatWith (requestsProcessor ))
144
+ .handle (handleReprepare (requestsSink , operator , factory ))
145
+ .doFinally (ignore -> operator .close (requestsSink ))
153
146
.as (Operators ::discardOnCancel );
154
147
}
155
148
156
149
/**
157
150
* Execute a chunked query and indicate to fetch rows in chunks with the {@link Execute} message.
158
151
*
159
- * @param messagesToSend the messages to send
160
- * @param client client to use
161
- * @param portal the portal
162
- * @param fetchSize fetch size per roundtrip
152
+ * @param operator the flow operator
153
+ * @param client client to use
154
+ * @param portal the portal
155
+ * @param fetchSize fetch size per roundtrip
163
156
* @return the resulting message stream
164
157
*/
165
- private static Flux <BackendMessage > fetchCursoredWithSync (List < FrontendMessage . DirectEncoder > messagesToSend , Client client , String portal , int fetchSize ) {
158
+ private static Flux <BackendMessage > fetchCursoredWithSync (ExtendedFlowOperator operator , Client client , String portal , int fetchSize ) {
166
159
167
160
UnicastProcessor <FrontendMessage > requestsProcessor = UnicastProcessor .create (Queues .<FrontendMessage >small ().get ());
168
161
FluxSink <FrontendMessage > requestsSink = requestsProcessor .sink ();
169
162
AtomicBoolean isCanceled = new AtomicBoolean (false );
170
163
AtomicBoolean done = new AtomicBoolean (false );
171
164
172
- messagesToSend . add ( new Execute (portal , fetchSize ));
173
- messagesToSend . add ( Sync . INSTANCE );
165
+ MessageFactory factory = () -> operator . getMessages ( Arrays . asList ( new Execute (portal , fetchSize ), Sync . INSTANCE ));
166
+ Predicate < BackendMessage > takeUntil = operator . takeUntil ( );
174
167
175
- return client .exchange (it -> done .get () && it instanceof ReadyForQuery , Flux .<FrontendMessage >just (new CompositeFrontendMessage (messagesToSend )).concatWith (requestsProcessor ))
168
+ return client .exchange (it -> done .get () && takeUntil .test (it ), Flux .<FrontendMessage >just (new CompositeFrontendMessage (factory .createMessages ())).concatWith (requestsProcessor ))
169
+ .handle (handleReprepare (requestsSink , operator , factory ))
176
170
.handle ((BackendMessage message , SynchronousSink <BackendMessage > sink ) -> {
177
171
178
172
if (message instanceof CommandComplete ) {
@@ -211,30 +205,30 @@ private static Flux<BackendMessage> fetchCursoredWithSync(List<FrontendMessage.D
211
205
} else {
212
206
sink .next (message );
213
207
}
214
- }).doFinally (ignore -> requestsSink . complete ( ))
208
+ }).doFinally (ignore -> operator . close ( requestsSink ))
215
209
.as (flux -> Operators .discardOnCancel (flux , () -> isCanceled .set (true )));
216
210
}
217
211
218
212
/**
219
213
* Execute a contiguous query and indicate to fetch rows in chunks with the {@link Execute} message. Uses {@link Flush}-based synchronization that creates a cursor. Note that flushing keeps the
220
214
* cursor open even with implicit transactions and this method may not work with newer pgpool implementations.
221
215
*
222
- * @param messagesToSend the messages to send
223
- * @param client client to use
224
- * @param portal the portal
225
- * @param fetchSize fetch size per roundtrip
216
+ * @param operator the flow operator
217
+ * @param client client to use
218
+ * @param portal the portal
219
+ * @param fetchSize fetch size per roundtrip
226
220
* @return the resulting message stream
227
221
*/
228
- private static Flux <BackendMessage > fetchCursoredWithFlush (List < FrontendMessage . DirectEncoder > messagesToSend , Client client , String portal , int fetchSize ) {
222
+ private static Flux <BackendMessage > fetchCursoredWithFlush (ExtendedFlowOperator operator , Client client , String portal , int fetchSize ) {
229
223
230
224
UnicastProcessor <FrontendMessage > requestsProcessor = UnicastProcessor .create (Queues .<FrontendMessage >small ().get ());
231
225
FluxSink <FrontendMessage > requestsSink = requestsProcessor .sink ();
232
226
AtomicBoolean isCanceled = new AtomicBoolean (false );
233
227
234
- messagesToSend .add (new Execute (portal , fetchSize ));
235
- messagesToSend .add (Flush .INSTANCE );
228
+ MessageFactory factory = () -> operator .getMessages (Arrays .asList (new Execute (portal , fetchSize ), Flush .INSTANCE ));
236
229
237
- return client .exchange (Flux .<FrontendMessage >just (new CompositeFrontendMessage (messagesToSend )).concatWith (requestsProcessor ))
230
+ return client .exchange (operator .takeUntil (), Flux .<FrontendMessage >just (new CompositeFrontendMessage (factory .createMessages ())).concatWith (requestsProcessor ))
231
+ .handle (handleReprepare (requestsSink , operator , factory ))
238
232
.handle ((BackendMessage message , SynchronousSink <BackendMessage > sink ) -> {
239
233
240
234
if (message instanceof CommandComplete ) {
@@ -258,8 +252,154 @@ private static Flux<BackendMessage> fetchCursoredWithFlush(List<FrontendMessage.
258
252
} else {
259
253
sink .next (message );
260
254
}
261
- }).doFinally (ignore -> requestsSink . complete ( ))
255
+ }).doFinally (ignore -> operator . close ( requestsSink ))
262
256
.as (flux -> Operators .discardOnCancel (flux , () -> isCanceled .set (true )));
263
257
}
264
258
259
+ private static BiConsumer <BackendMessage , SynchronousSink <BackendMessage >> handleReprepare (FluxSink <FrontendMessage > requests , ExtendedFlowOperator operator , MessageFactory messageFactory ) {
260
+
261
+ AtomicBoolean reprepared = new AtomicBoolean ();
262
+
263
+ return (message , sink ) -> {
264
+
265
+ if (message instanceof ErrorResponse && requiresReprepare ((ErrorResponse ) message ) && reprepared .compareAndSet (false , true )) {
266
+
267
+ operator .evictCachedStatement ();
268
+
269
+ List <FrontendMessage .DirectEncoder > messages = messageFactory .createMessages ();
270
+ if (!messages .contains (Sync .INSTANCE )) {
271
+ messages .add (0 , Sync .INSTANCE );
272
+ }
273
+ requests .next (new CompositeFrontendMessage (messages ));
274
+ } else {
275
+ sink .next (message );
276
+ }
277
+ };
278
+ }
279
+
280
+ private static boolean requiresReprepare (ErrorResponse errorResponse ) {
281
+
282
+ ErrorDetails details = new ErrorDetails (errorResponse .getFields ());
283
+ String code = details .getCode ();
284
+
285
+ // "prepared statement \"S_2\" does not exist"
286
+ // INVALID_SQL_STATEMENT_NAME
287
+ if ("26000" .equals (code )) {
288
+ return true ;
289
+ }
290
+ // NOT_IMPLEMENTED
291
+
292
+ if (!"0A000" .equals (code )) {
293
+ return false ;
294
+ }
295
+
296
+ String routine = details .getRoutine ().orElse (null );
297
+ // "cached plan must not change result type"
298
+ return "RevalidateCachedQuery" .equals (routine ) // 9.2+
299
+ || "RevalidateCachedPlan" .equals (routine ); // <= 9.1
300
+ }
301
+
302
+ interface MessageFactory {
303
+
304
+ List <FrontendMessage .DirectEncoder > createMessages ();
305
+
306
+ }
307
+
308
/**
 * Operator to encapsulate common activity around the extended flow. Subclasses {@link AtomicInteger} to capture the number of ReadyForQuery frames.
 */
static class ExtendedFlowOperator extends AtomicInteger {

    // SQL text used as the statement-cache key.
    private final String sql;

    private final Binding binding;

    // Lazily resolved server-side statement name; guarded by synchronized(this), cleared on evict.
    @Nullable
    private volatile String name;

    private final StatementCache cache;

    // Encoded parameter values; retained per send in getMessages() and released in close().
    private final List<ByteBuf> values;

    private final String portal;

    private final boolean forceBinary;

    public ExtendedFlowOperator(String sql, Binding binding, StatementCache cache, List<ByteBuf> values, String portal, boolean forceBinary) {
        this.sql = sql;
        this.binding = binding;
        this.cache = cache;
        this.values = values;
        this.portal = portal;
        this.forceBinary = forceBinary;
        // A plain exchange expects exactly one ReadyForQuery frame; see takeUntil().
        set(1);
    }

    /**
     * Terminate the request stream and release the retained parameter buffers.
     *
     * @param requests the frontend request sink to complete
     */
    public void close(FluxSink<FrontendMessage> requests) {
        requests.complete();
        this.values.forEach(ReferenceCountUtil::release);
    }

    /**
     * Evict the cached statement after a re-prepare-triggering error and forget the
     * resolved statement name so it is looked up again.
     */
    public void evictCachedStatement() {

        // Re-preparing replays the batch, so one additional ReadyForQuery frame is
        // expected before the exchange completes — NOTE(review): presumably paired with
        // the Sync injected by the reprepare handler; confirm against the caller.
        incrementAndGet();

        synchronized (this) {
            this.name = null;
        }
        this.cache.evict(this.sql);
    }

    // Record the (now successfully parsed) statement in the cache.
    public void hydrateStatementCache() {
        this.cache.put(this.binding, this.sql, getStatementName());
    }

    /**
     * Predicate signalling the end of the exchange: the ReadyForQuery frame whose
     * arrival drops the expected-frame counter to zero.
     *
     * @return take-until predicate for {@code client.exchange}
     */
    public Predicate<BackendMessage> takeUntil() {
        return m -> {

            if (m instanceof ReadyForQuery) {
                return decrementAndGet() <= 0;
            }

            return false;
        };
    }

    // Whether a Parse message must precede Bind (statement not yet server-side prepared).
    private boolean isPrepareRequired() {
        return this.cache.requiresPrepare(this.binding, this.sql);
    }

    /**
     * Resolve (and memoize) the server-side statement name from the cache.
     *
     * @return the statement name
     */
    public String getStatementName() {
        synchronized (this) {

            if (this.name == null) {
                this.name = this.cache.getName(this.binding, this.sql);
            }
            return this.name;
        }
    }

    /**
     * Build the frontend message batch: optional Parse, then Bind and Describe,
     * followed by the caller-supplied trailer (Execute/Close/Sync or Flush).
     *
     * @param append messages to append after Bind/Describe
     * @return mutable list of messages to send
     */
    public List<FrontendMessage.DirectEncoder> getMessages(Collection<FrontendMessage.DirectEncoder> append) {
        List<FrontendMessage.DirectEncoder> messagesToSend = new ArrayList<>(6);

        if (isPrepareRequired()) {
            messagesToSend.add(new Parse(getStatementName(), this.binding.getParameterTypes(), this.sql));
        }

        // The batch may be re-sent on re-prepare: rewind each buffer and take an extra
        // refcount so encoding the resent Bind does not read exhausted/released buffers.
        for (ByteBuf value : this.values) {
            value.readerIndex(0);
            value.touch("ExtendedFlowOperator").retain();
        }

        Bind bind = new Bind(this.portal, this.binding.getParameterFormats(), this.values, ExtendedQueryMessageFlow.resultFormat(this.forceBinary), getStatementName());

        messagesToSend.add(bind);
        messagesToSend.add(new Describe(this.portal, PORTAL));
        messagesToSend.addAll(append);

        return messagesToSend;
    }

}
404
+
265
405
}
0 commit comments