Skip to content

Commit 25aa528

Browse files
committed
Fix for concurrent exchanging on one client
1 parent fc1c7f3 commit 25aa528

File tree

2 files changed

+72
-32
lines changed

2 files changed

+72
-32
lines changed

src/main/java/io/r2dbc/postgresql/client/ReactorNettyClient.java

+16-32
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,6 @@
2929
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
3030
import io.netty.handler.logging.LogLevel;
3131
import io.netty.handler.logging.LoggingHandler;
32-
import io.netty.util.ReferenceCountUtil;
3332
import io.netty.util.internal.logging.InternalLogger;
3433
import io.netty.util.internal.logging.InternalLoggerFactory;
3534
import io.r2dbc.postgresql.message.backend.BackendKeyData;
@@ -72,9 +71,9 @@
7271
import java.util.Queue;
7372
import java.util.StringJoiner;
7473
import java.util.concurrent.atomic.AtomicBoolean;
75-
import java.util.concurrent.atomic.AtomicInteger;
7674
import java.util.concurrent.atomic.AtomicReference;
7775
import java.util.function.Consumer;
76+
import java.util.function.Function;
7877
import java.util.function.Predicate;
7978
import java.util.function.Supplier;
8079

@@ -99,9 +98,9 @@ public final class ReactorNettyClient implements Client {
9998

10099
private final Connection connection;
101100

102-
private final EmitterProcessor<FrontendMessage> requestProcessor = EmitterProcessor.create(false);
101+
private final EmitterProcessor<Publisher<FrontendMessage>> requestProcessor = EmitterProcessor.create(false);
103102

104-
private final FluxSink<FrontendMessage> requests = this.requestProcessor.sink();
103+
private final FluxSink<Publisher<FrontendMessage>> requests = this.requestProcessor.sink();
105104

106105
private final Queue<ResponseReceiver> responseReceivers = Queues.<ResponseReceiver>unbounded().get();
107106

@@ -154,6 +153,7 @@ private ReactorNettyClient(Connection connection) {
154153
.then();
155154

156155
Mono<Void> request = this.requestProcessor
156+
.concatMap(Function.identity())
157157
.flatMap(message -> {
158158
if (DEBUG_ENABLED) {
159159
logger.debug("Request: {}", message);
@@ -178,42 +178,26 @@ public Flux<BackendMessage> exchange(Predicate<BackendMessage> takeUntil, Publis
178178

179179
return Flux
180180
.create(sink -> {
181-
182-
final AtomicInteger once = new AtomicInteger();
183-
184-
Flux.from(requests)
185-
.subscribe(message -> {
186-
187-
if (!isConnected()) {
188-
ReferenceCountUtil.safeRelease(message);
189-
sink.error(new PostgresConnectionClosedException("Cannot exchange messages because the connection is closed"));
190-
return;
191-
}
192-
193-
if (once.get() == 0 && once.compareAndSet(0, 1)) {
194-
synchronized (this) {
195-
this.responseReceivers.add(new ResponseReceiver(sink, takeUntil));
196-
this.requests.next(message);
197-
}
198-
} else {
199-
this.requests.next(message);
200-
}
201-
202-
}, this.requests::error, () -> {
203-
181+
if (!isConnected()) {
182+
sink.error(new PostgresConnectionClosedException("Cannot exchange messages because the connection is closed"));
183+
return;
184+
}
185+
synchronized (this) {
186+
this.responseReceivers.add(new ResponseReceiver(sink, takeUntil));
187+
this.requests.next(Flux.from(requests).doOnNext(m -> {
204188
if (!isConnected()) {
205189
sink.error(new PostgresConnectionClosedException("Cannot exchange messages because the connection is closed"));
206190
}
207-
});
208-
191+
}));
192+
}
209193
});
210194
}
211195

212196
@Override
213197
public void send(FrontendMessage message) {
214198
Assert.requireNonNull(message, "requests must not be null");
215199

216-
this.requests.next(message);
200+
this.requests.next(Mono.just(message));
217201
}
218202

219203
private Mono<Void> resumeError(Throwable throwable) {
@@ -482,9 +466,9 @@ private ResponseReceiver(FluxSink<BackendMessage> sink, Predicate<BackendMessage
482466

483467
private final class EnsureSubscribersCompleteChannelHandler extends ChannelDuplexHandler {
484468

485-
private final EmitterProcessor<FrontendMessage> requestProcessor;
469+
private final EmitterProcessor<Publisher<FrontendMessage>> requestProcessor;
486470

487-
private EnsureSubscribersCompleteChannelHandler(EmitterProcessor<FrontendMessage> requestProcessor) {
471+
private EnsureSubscribersCompleteChannelHandler(EmitterProcessor<Publisher<FrontendMessage>> requestProcessor) {
488472
this.requestProcessor = requestProcessor;
489473
}
490474

src/test/java/io/r2dbc/postgresql/client/ReactorNettyClientIntegrationTests.java

+56
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,20 @@
2323
import io.r2dbc.postgresql.PostgresqlConnectionFactory;
2424
import io.r2dbc.postgresql.api.PostgresqlConnection;
2525
import io.r2dbc.postgresql.authentication.PasswordAuthenticationHandler;
26+
import io.r2dbc.postgresql.message.Format;
2627
import io.r2dbc.postgresql.message.backend.BackendMessage;
28+
import io.r2dbc.postgresql.message.backend.BindComplete;
2729
import io.r2dbc.postgresql.message.backend.CommandComplete;
2830
import io.r2dbc.postgresql.message.backend.DataRow;
2931
import io.r2dbc.postgresql.message.backend.NotificationResponse;
3032
import io.r2dbc.postgresql.message.backend.RowDescription;
33+
import io.r2dbc.postgresql.message.frontend.Bind;
34+
import io.r2dbc.postgresql.message.frontend.Describe;
35+
import io.r2dbc.postgresql.message.frontend.Execute;
36+
import io.r2dbc.postgresql.message.frontend.ExecutionType;
3137
import io.r2dbc.postgresql.message.frontend.FrontendMessage;
3238
import io.r2dbc.postgresql.message.frontend.Query;
39+
import io.r2dbc.postgresql.message.frontend.Sync;
3340
import io.r2dbc.postgresql.util.PostgresqlServerExtension;
3441
import io.r2dbc.spi.R2dbcNonTransientResourceException;
3542
import io.r2dbc.spi.R2dbcPermissionDeniedException;
@@ -62,6 +69,7 @@
6269
import java.util.function.Function;
6370
import java.util.stream.IntStream;
6471

72+
import static io.r2dbc.postgresql.type.PostgresqlObjectId.INT4;
6573
import static org.assertj.core.api.Assertions.assertThat;
6674
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
6775
import static org.assertj.core.api.Assertions.fail;
@@ -274,6 +282,54 @@ void parallelExchange() {
274282
.verifyComplete();
275283
}
276284

285+
@Test
286+
void parallelExchangeExtendedFlow() {
287+
ExtendedQueryMessageFlow.parse(this.client, "S_1", "SELECT $1", Arrays.asList(INT4.getObjectId()))
288+
.as(StepVerifier::create)
289+
.verifyComplete();
290+
291+
this.client.exchange(Mono.just(Sync.INSTANCE))
292+
.as(StepVerifier::create)
293+
.verifyComplete();
294+
295+
FrontendMessage bind1 = new Bind(
296+
"P_1",
297+
Collections.singletonList(Format.FORMAT_BINARY),
298+
Collections.singletonList(this.client.getByteBufAllocator().buffer(4).writeInt(42)),
299+
Collections.singletonList(Format.FORMAT_BINARY),
300+
"S_1"
301+
);
302+
Describe describe1 = new Describe("P_1", ExecutionType.PORTAL);
303+
Execute execute1 = new Execute("P_1", Integer.MAX_VALUE);
304+
FrontendMessage bind2 = new Bind(
305+
"P_2",
306+
Collections.singletonList(Format.FORMAT_BINARY),
307+
Collections.singletonList(this.client.getByteBufAllocator().buffer(4).writeInt(42)),
308+
Collections.singletonList(Format.FORMAT_BINARY),
309+
"S_1"
310+
);
311+
Describe describe2 = new Describe("P_2", ExecutionType.PORTAL);
312+
Execute execute2 = new Execute("P_2", Integer.MAX_VALUE);
313+
314+
315+
Flux<FrontendMessage> flow1 = Flux.just(bind1, describe1, execute1, Sync.INSTANCE).delayElements(Duration.ofMillis(10));
316+
Flux<FrontendMessage> flow2 = Flux.just(bind2, describe2, execute2, Sync.INSTANCE).delayElements(Duration.ofMillis(20));
317+
318+
this.datarowCleanup(Flux.zip(this.client.exchange(flow1), this.client.exchange(flow2))
319+
.flatMapIterable(t -> Arrays.asList(t.getT1(), t.getT2()))
320+
)
321+
.as(StepVerifier::create)
322+
.assertNext(message -> assertThat(message).isInstanceOf(BindComplete.class))
323+
.assertNext(message -> assertThat(message).isInstanceOf(BindComplete.class))
324+
.assertNext(message -> assertThat(message).isInstanceOf(RowDescription.class))
325+
.assertNext(message -> assertThat(message).isInstanceOf(RowDescription.class))
326+
.assertNext(message -> assertThat(message).isInstanceOf(DataRow.class))
327+
.assertNext(message -> assertThat(message).isInstanceOf(DataRow.class))
328+
.expectNext(new CommandComplete("SELECT", null, 1))
329+
.expectNext(new CommandComplete("SELECT", null, 1))
330+
.verifyComplete();
331+
}
332+
277333
@Test
278334
void timeoutTest() {
279335
PostgresqlConnectionFactory postgresqlConnectionFactory = new PostgresqlConnectionFactory(PostgresqlConnectionConfiguration.builder()

0 commit comments

Comments
 (0)