 import io.netty.buffer.ByteBuf;
 import io.netty.util.ReferenceCountUtil;
 import io.netty.util.ReferenceCounted;
+import io.r2dbc.postgresql.api.ErrorDetails;
 import io.r2dbc.postgresql.client.Binding;
 import io.r2dbc.postgresql.client.Client;
 import io.r2dbc.postgresql.client.ExtendedQueryMessageFlow;
 import io.r2dbc.postgresql.message.frontend.Sync;
 import io.r2dbc.postgresql.util.Operators;
 import reactor.core.publisher.Flux;
-import reactor.core.publisher.Mono;
 import reactor.core.publisher.Sinks;
 import reactor.core.publisher.SynchronousSink;
+import reactor.util.annotation.Nullable;
 import reactor.util.concurrent.Queues;

 import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
 import java.util.List;
 import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.function.BiConsumer;
 import java.util.function.Predicate;

 import static io.r2dbc.postgresql.message.frontend.Execute.NO_LIMIT;
@@ -86,91 +91,79 @@ public static Flux<BackendMessage> runQuery(ConnectionResources resources, Excep
        StatementCache cache = resources.getStatementCache();
        Client client = resources.getClient();

-        String name = cache.getName(binding, query);
        String portal = resources.getPortalNameSupplier().get();
-        boolean prepareRequired = cache.requiresPrepare(binding, query);
-
-        List<FrontendMessage.DirectEncoder> messagesToSend = new ArrayList<>(6);
-
-        if (prepareRequired) {
-            messagesToSend.add(new Parse(name, binding.getParameterTypes(), query));
-        }
-
-        Bind bind = new Bind(portal, binding.getParameterFormats(), values, ExtendedQueryMessageFlow.resultFormat(resources.getConfiguration().isForceBinary()), name);
-
-        messagesToSend.add(bind);
-        messagesToSend.add(new Describe(portal, PORTAL));

        Flux<BackendMessage> exchange;
        boolean compatibilityMode = resources.getConfiguration().isCompatibilityMode();
        boolean implicitTransactions = resources.getClient().getTransactionStatus() == TransactionStatus.IDLE;

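+        // Build the Parse/Bind/Describe sequence lazily through the operator so it can be re-created and replayed if the cached prepared statement turns out to be invalid.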
+        ExtendedFlowOperator operator = new ExtendedFlowOperator(query, binding, cache, values, portal, resources.getConfiguration().isForceBinary());
+
        if (compatibilityMode) {

            if (fetchSize == NO_LIMIT || implicitTransactions) {
-                exchange = fetchAll(messagesToSend, client, portal);
+                exchange = fetchAll(operator, client, portal);
            } else {
-                exchange = fetchCursoredWithSync(messagesToSend, client, portal, fetchSize);
+                exchange = fetchCursoredWithSync(operator, client, portal, fetchSize);
            }
        } else {

            if (fetchSize == NO_LIMIT) {
-                exchange = fetchAll(messagesToSend, client, portal);
+                exchange = fetchAll(operator, client, portal);
            } else {
-                exchange = fetchCursoredWithFlush(messagesToSend, client, portal, fetchSize);
+                exchange = fetchCursoredWithFlush(operator, client, portal, fetchSize);
            }
        }

-        if (prepareRequired) {
-
-            exchange = exchange.doOnNext(message -> {
+        exchange = exchange.doOnNext(message -> {

-                if (message == ParseComplete.INSTANCE) {
-                    cache.put(binding, query, name);
-                }
-            });
-        }
+            if (message == ParseComplete.INSTANCE) {
+                operator.hydrateStatementCache();
+            }
+        });

        return exchange.doOnSubscribe(it -> QueryLogger.logQuery(client.getContext(), query)).doOnDiscard(ReferenceCounted.class, ReferenceCountUtil::release).filter(RESULT_FRAME_FILTER).handle(factory::handleErrorResponse);
    }

    /**
     * Execute the query and indicate to fetch all rows with the {@link Execute} message.
     *
-     * @param messagesToSend the initial bind flow
-     * @param client         client to use
-     * @param portal         the portal
+     * @param operator the flow operator
+     * @param client   client to use
+     * @param portal   the portal
     * @return the resulting message stream
     */
-    private static Flux<BackendMessage> fetchAll(List<FrontendMessage.DirectEncoder> messagesToSend, Client client, String portal) {
+    private static Flux<BackendMessage> fetchAll(ExtendedFlowOperator operator, Client client, String portal) {

-        messagesToSend.add(new Execute(portal, NO_LIMIT));
-        messagesToSend.add(new Close(portal, PORTAL));
-        messagesToSend.add(Sync.INSTANCE);
+        Sinks.Many<FrontendMessage> requests = Sinks.many().unicast().onBackpressureBuffer(Queues.<FrontendMessage>small().get());
+        MessageFactory factory = () -> operator.getMessages(Arrays.asList(new Execute(portal, NO_LIMIT), new Close(portal, PORTAL), Sync.INSTANCE));

-        return client.exchange(Mono.just(new CompositeFrontendMessage(messagesToSend)))
+        return client.exchange(operator.takeUntil(), Flux.<FrontendMessage>just(new CompositeFrontendMessage(factory.createMessages())).concatWith(requests.asFlux()))
+            .handle(handleReprepare(requests, operator, factory))
+            .doFinally(ignore -> operator.close(requests))
            .as(Operators::discardOnCancel);
    }

    /**
     * Execute a chunked query and indicate to fetch rows in chunks with the {@link Execute} message.
     *
-     * @param messagesToSend the messages to send
-     * @param client         client to use
-     * @param portal         the portal
-     * @param fetchSize      fetch size per roundtrip
+     * @param operator  the flow operator
+     * @param client    client to use
+     * @param portal    the portal
+     * @param fetchSize fetch size per roundtrip
     * @return the resulting message stream
     */
-    private static Flux<BackendMessage> fetchCursoredWithSync(List<FrontendMessage.DirectEncoder> messagesToSend, Client client, String portal, int fetchSize) {
+    private static Flux<BackendMessage> fetchCursoredWithSync(ExtendedFlowOperator operator, Client client, String portal, int fetchSize) {

        Sinks.Many<FrontendMessage> requests = Sinks.many().unicast().onBackpressureBuffer(Queues.<FrontendMessage>small().get());
        AtomicBoolean isCanceled = new AtomicBoolean(false);
        AtomicBoolean done = new AtomicBoolean(false);

-        messagesToSend.add(new Execute(portal, fetchSize));
-        messagesToSend.add(Sync.INSTANCE);
+        MessageFactory factory = () -> operator.getMessages(Arrays.asList(new Execute(portal, fetchSize), Sync.INSTANCE));
+        Predicate<BackendMessage> takeUntil = operator.takeUntil();

-        return client.exchange(it -> done.get() && it instanceof ReadyForQuery, Flux.<FrontendMessage>just(new CompositeFrontendMessage(messagesToSend)).concatWith(requests.asFlux()))
+        return client.exchange(it -> done.get() && takeUntil.test(it), Flux.<FrontendMessage>just(new CompositeFrontendMessage(factory.createMessages())).concatWith(requests.asFlux()))
+            .handle(handleReprepare(requests, operator, factory))
            .handle((BackendMessage message, SynchronousSink<BackendMessage> sink) -> {

                if (message instanceof CommandComplete) {
@@ -209,29 +202,29 @@ private static Flux<BackendMessage> fetchCursoredWithSync(List<FrontendMessage.D
                } else {
                    sink.next(message);
                }
-            }).doFinally(ignore -> requests.emitComplete(Sinks.EmitFailureHandler.FAIL_FAST))
+            }).doFinally(ignore -> operator.close(requests))
            .as(flux -> Operators.discardOnCancel(flux, () -> isCanceled.set(true)));
    }

    /**
     * Execute a contiguous query and indicate to fetch rows in chunks with the {@link Execute} message. Uses {@link Flush}-based synchronization that creates a cursor. Note that flushing keeps the
     * cursor open even with implicit transactions and this method may not work with newer pgpool implementations.
     *
-     * @param messagesToSend the messages to send
-     * @param client         client to use
-     * @param portal         the portal
-     * @param fetchSize      fetch size per roundtrip
+     * @param operator  the flow operator
+     * @param client    client to use
+     * @param portal    the portal
+     * @param fetchSize fetch size per roundtrip
     * @return the resulting message stream
     */
-    private static Flux<BackendMessage> fetchCursoredWithFlush(List<FrontendMessage.DirectEncoder> messagesToSend, Client client, String portal, int fetchSize) {
+    private static Flux<BackendMessage> fetchCursoredWithFlush(ExtendedFlowOperator operator, Client client, String portal, int fetchSize) {

        Sinks.Many<FrontendMessage> requests = Sinks.many().unicast().onBackpressureBuffer(Queues.<FrontendMessage>small().get());
        AtomicBoolean isCanceled = new AtomicBoolean(false);

-        messagesToSend.add(new Execute(portal, fetchSize));
-        messagesToSend.add(Flush.INSTANCE);
+        MessageFactory factory = () -> operator.getMessages(Arrays.asList(new Execute(portal, fetchSize), Flush.INSTANCE));

-        return client.exchange(Flux.<FrontendMessage>just(new CompositeFrontendMessage(messagesToSend)).concatWith(requests.asFlux()))
+        return client.exchange(operator.takeUntil(), Flux.<FrontendMessage>just(new CompositeFrontendMessage(factory.createMessages())).concatWith(requests.asFlux()))
+            .handle(handleReprepare(requests, operator, factory))
            .handle((BackendMessage message, SynchronousSink<BackendMessage> sink) -> {

                if (message instanceof CommandComplete) {
@@ -255,8 +248,154 @@ private static Flux<BackendMessage> fetchCursoredWithFlush(List<FrontendMessage.
                } else {
                    sink.next(message);
                }
-            }).doFinally(ignore -> requests.emitComplete(Sinks.EmitFailureHandler.FAIL_FAST))
+            }).doFinally(ignore -> operator.close(requests))
            .as(flux -> Operators.discardOnCancel(flux, () -> isCanceled.set(true)));
    }

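+    /**
+     * Intercept {@link ErrorResponse} frames that indicate the server-side prepared statement is no longer usable: evict the cached statement and replay the
+     * re-created message sequence through the {@code requests} sink (at most once per exchange). All other messages are passed through.
+     */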
+    private static BiConsumer<BackendMessage, SynchronousSink<BackendMessage>> handleReprepare(Sinks.Many<FrontendMessage> requests, ExtendedFlowOperator operator, MessageFactory messageFactory) {
+
+        AtomicBoolean reprepared = new AtomicBoolean();
+
+        return (message, sink) -> {
+
+            if (message instanceof ErrorResponse && requiresReprepare((ErrorResponse) message) && reprepared.compareAndSet(false, true)) {
+
+                operator.evictCachedStatement();
+
+                List<FrontendMessage.DirectEncoder> messages = messageFactory.createMessages();
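+                // After an error the backend discards frontend messages until it sees a Sync, so lead with one if the replayed flow (e.g. Flush-based fetching) does not already contain it.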
+                if (!messages.contains(Sync.INSTANCE)) {
+                    messages.add(0, Sync.INSTANCE);
+                }
+                requests.emitNext(new CompositeFrontendMessage(messages), Sinks.EmitFailureHandler.FAIL_FAST);
+            } else {
+                sink.next(message);
+            }
+        };
+    }
+
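+    /**
+     * Check whether the {@link ErrorResponse} reports that the prepared statement is gone ({@code 26000}) or that its cached plan became stale ({@code 0A000} from plan revalidation),
+     * in which case the statement has to be re-prepared.
+     */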
+    private static boolean requiresReprepare(ErrorResponse errorResponse) {
+
+        ErrorDetails details = new ErrorDetails(errorResponse.getFields());
+        String code = details.getCode();
+
+        // "prepared statement \"S_2\" does not exist"
+        // INVALID_SQL_STATEMENT_NAME
+        if ("26000".equals(code)) {
+            return true;
+        }
+        // NOT_IMPLEMENTED
+
+        if (!"0A000".equals(code)) {
+            return false;
+        }
+
+        String routine = details.getRoutine().orElse(null);
+        // "cached plan must not change result type"
+        return "RevalidateCachedQuery".equals(routine) // 9.2+
+            || "RevalidateCachedPlan".equals(routine); // <= 9.1
+    }
+
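+    /**
+     * Deferred factory for the frontend messages of a flow so the same sequence can be re-created when the statement needs to be re-prepared.
+     */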
+    interface MessageFactory {
+
+        List<FrontendMessage.DirectEncoder> createMessages();
+
+    }
+
+    /**
+     * Operator to encapsulate common activity around the extended flow. Subclasses {@link AtomicInteger} to capture the number of ReadyForQuery frames.
+     */
+    static class ExtendedFlowOperator extends AtomicInteger {
+
+        private final String sql;
+
+        private final Binding binding;
+
+        @Nullable
+        private volatile String name;
+
+        private final StatementCache cache;
+
+        private final List<ByteBuf> values;
+
+        private final String portal;
+
+        private final boolean forceBinary;
+
+        public ExtendedFlowOperator(String sql, Binding binding, StatementCache cache, List<ByteBuf> values, String portal, boolean forceBinary) {
+            this.sql = sql;
+            this.binding = binding;
+            this.cache = cache;
+            this.values = values;
+            this.portal = portal;
+            this.forceBinary = forceBinary;
+            set(1);
+        }
+
+        public void close(Sinks.Many<FrontendMessage> requests) {
+            requests.emitComplete(Sinks.EmitFailureHandler.FAIL_FAST);
+            this.values.forEach(ReferenceCountUtil::release);
+        }
+
+        public void evictCachedStatement() {
+
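+            // The replayed flow carries an additional Sync, so expect one more ReadyForQuery frame.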
+            incrementAndGet();
+
+            synchronized (this) {
+                this.name = null;
+            }
+            this.cache.evict(this.sql);
+        }
+
+        public void hydrateStatementCache() {
+            this.cache.put(this.binding, this.sql, getStatementName());
+        }
+
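+        /**
+         * Completion predicate for the exchange: the {@link AtomicInteger} value is the number of ReadyForQuery frames still expected before the conversation ends.
+         */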
+        public Predicate<BackendMessage> takeUntil() {
+            return m -> {
+
+                if (m instanceof ReadyForQuery) {
+                    return decrementAndGet() <= 0;
+                }
+
+                return false;
+            };
+        }
+
+        private boolean isPrepareRequired() {
+            return this.cache.requiresPrepare(this.binding, this.sql);
+        }
+
+        public String getStatementName() {
+            synchronized (this) {
+
+                if (this.name == null) {
+                    this.name = this.cache.getName(this.binding, this.sql);
+                }
+                return this.name;
+            }
+        }
+
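+        /**
+         * Create the Parse/Bind/Describe sequence followed by the given trailing messages. Parse is only emitted when the statement is not cached yet; parameter buffers are reset and retained so
+         * the Bind message can be encoded again on a replay.
+         */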
+        public List<FrontendMessage.DirectEncoder> getMessages(Collection<FrontendMessage.DirectEncoder> append) {
+            List<FrontendMessage.DirectEncoder> messagesToSend = new ArrayList<>(6);
+
+            if (isPrepareRequired()) {
+                messagesToSend.add(new Parse(getStatementName(), this.binding.getParameterTypes(), this.sql));
+            }
+
+            for (ByteBuf value : this.values) {
+                value.readerIndex(0);
+                value.touch("ExtendedFlowOperator").retain();
+            }
+
+            Bind bind = new Bind(this.portal, this.binding.getParameterFormats(), this.values, ExtendedQueryMessageFlow.resultFormat(this.forceBinary), getStatementName());
+
+            messagesToSend.add(bind);
+            messagesToSend.add(new Describe(this.portal, PORTAL));
+            messagesToSend.addAll(append);
+
+            return messagesToSend;
+        }
+
+    }
+
 }