diff --git a/src/main/java/io/r2dbc/postgresql/client/ReactorNettyClient.java b/src/main/java/io/r2dbc/postgresql/client/ReactorNettyClient.java index 594f528e..2c3d27b0 100644 --- a/src/main/java/io/r2dbc/postgresql/client/ReactorNettyClient.java +++ b/src/main/java/io/r2dbc/postgresql/client/ReactorNettyClient.java @@ -29,7 +29,6 @@ import io.netty.handler.codec.LengthFieldBasedFrameDecoder; import io.netty.handler.logging.LogLevel; import io.netty.handler.logging.LoggingHandler; -import io.netty.util.ReferenceCountUtil; import io.netty.util.internal.logging.InternalLogger; import io.netty.util.internal.logging.InternalLoggerFactory; import io.r2dbc.postgresql.message.backend.BackendKeyData; @@ -72,9 +71,9 @@ import java.util.Queue; import java.util.StringJoiner; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Consumer; +import java.util.function.Function; import java.util.function.Predicate; import java.util.function.Supplier; @@ -99,9 +98,9 @@ public final class ReactorNettyClient implements Client { private final Connection connection; - private final EmitterProcessor requestProcessor = EmitterProcessor.create(false); + private final EmitterProcessor> requestProcessor = EmitterProcessor.create(false); - private final FluxSink requests = this.requestProcessor.sink(); + private final FluxSink> requests = this.requestProcessor.sink(); private final Queue responseReceivers = Queues.unbounded().get(); @@ -154,6 +153,7 @@ private ReactorNettyClient(Connection connection) { .then(); Mono request = this.requestProcessor + .concatMap(Function.identity()) .flatMap(message -> { if (DEBUG_ENABLED) { logger.debug("Request: {}", message); @@ -178,34 +178,18 @@ public Flux exchange(Predicate takeUntil, Publis return Flux .create(sink -> { - - final AtomicInteger once = new AtomicInteger(); - - Flux.from(requests) - .subscribe(message -> { - - if (!isConnected()) { - ReferenceCountUtil.safeRelease(message); - sink.error(new PostgresConnectionClosedException("Cannot exchange messages because the connection is closed")); - return; - } - - if (once.get() == 0 && once.compareAndSet(0, 1)) { - synchronized (this) { - this.responseReceivers.add(new ResponseReceiver(sink, takeUntil)); - this.requests.next(message); - } - } else { - this.requests.next(message); - } - - }, this.requests::error, () -> { - + if (!isConnected()) { + sink.error(new PostgresConnectionClosedException("Cannot exchange messages because the connection is closed")); + return; + } + synchronized (this) { + this.responseReceivers.add(new ResponseReceiver(sink, takeUntil)); + this.requests.next(Flux.from(requests).doOnNext(m -> { if (!isConnected()) { sink.error(new PostgresConnectionClosedException("Cannot exchange messages because the connection is closed")); } - }); - + })); + } }); } @@ -213,7 +197,7 @@ public Flux exchange(Predicate takeUntil, Publis public void send(FrontendMessage message) { Assert.requireNonNull(message, "requests must not be null"); - this.requests.next(message); + this.requests.next(Mono.just(message)); } private Mono resumeError(Throwable throwable) { @@ -482,9 +466,9 @@ private ResponseReceiver(FluxSink sink, Predicate requestProcessor; + private final EmitterProcessor> requestProcessor; - private EnsureSubscribersCompleteChannelHandler(EmitterProcessor requestProcessor) { + private EnsureSubscribersCompleteChannelHandler(EmitterProcessor> requestProcessor) { this.requestProcessor = requestProcessor; } diff --git a/src/test/java/io/r2dbc/postgresql/client/ReactorNettyClientIntegrationTests.java b/src/test/java/io/r2dbc/postgresql/client/ReactorNettyClientIntegrationTests.java index 84e25cb9..efbae2c8 100644 --- a/src/test/java/io/r2dbc/postgresql/client/ReactorNettyClientIntegrationTests.java +++ b/src/test/java/io/r2dbc/postgresql/client/ReactorNettyClientIntegrationTests.java @@ -23,13 +23,20 @@ import io.r2dbc.postgresql.PostgresqlConnectionFactory; import io.r2dbc.postgresql.api.PostgresqlConnection; import io.r2dbc.postgresql.authentication.PasswordAuthenticationHandler; +import io.r2dbc.postgresql.message.Format; import io.r2dbc.postgresql.message.backend.BackendMessage; +import io.r2dbc.postgresql.message.backend.BindComplete; import io.r2dbc.postgresql.message.backend.CommandComplete; import io.r2dbc.postgresql.message.backend.DataRow; import io.r2dbc.postgresql.message.backend.NotificationResponse; import io.r2dbc.postgresql.message.backend.RowDescription; +import io.r2dbc.postgresql.message.frontend.Bind; +import io.r2dbc.postgresql.message.frontend.Describe; +import io.r2dbc.postgresql.message.frontend.Execute; +import io.r2dbc.postgresql.message.frontend.ExecutionType; import io.r2dbc.postgresql.message.frontend.FrontendMessage; import io.r2dbc.postgresql.message.frontend.Query; +import io.r2dbc.postgresql.message.frontend.Sync; import io.r2dbc.postgresql.util.PostgresqlServerExtension; import io.r2dbc.spi.R2dbcNonTransientResourceException; import io.r2dbc.spi.R2dbcPermissionDeniedException; @@ -62,6 +69,7 @@ import java.util.function.Function; import java.util.stream.IntStream; +import static io.r2dbc.postgresql.type.PostgresqlObjectId.INT4; import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.assertj.core.api.Assertions.fail; @@ -274,6 +282,54 @@ void parallelExchange() { .verifyComplete(); } + @Test + void parallelExchangeExtendedFlow() { + ExtendedQueryMessageFlow.parse(this.client, "S_1", "SELECT $1", Arrays.asList(INT4.getObjectId())) + .as(StepVerifier::create) + .verifyComplete(); + + this.client.exchange(Mono.just(Sync.INSTANCE)) + .as(StepVerifier::create) + .verifyComplete(); + + FrontendMessage bind1 = new Bind( + "P_1", + Collections.singletonList(Format.FORMAT_BINARY), + Collections.singletonList(this.client.getByteBufAllocator().buffer(4).writeInt(42)), + Collections.singletonList(Format.FORMAT_BINARY), + "S_1" + ); + Describe describe1 = new Describe("P_1", ExecutionType.PORTAL); + Execute execute1 = new Execute("P_1", Integer.MAX_VALUE); + FrontendMessage bind2 = new Bind( + "P_2", + Collections.singletonList(Format.FORMAT_BINARY), + Collections.singletonList(this.client.getByteBufAllocator().buffer(4).writeInt(42)), + Collections.singletonList(Format.FORMAT_BINARY), + "S_1" + ); + Describe describe2 = new Describe("P_2", ExecutionType.PORTAL); + Execute execute2 = new Execute("P_2", Integer.MAX_VALUE); + + + Flux flow1 = Flux.just(bind1, describe1, execute1, Sync.INSTANCE).delayElements(Duration.ofMillis(10)); + Flux flow2 = Flux.just(bind2, describe2, execute2, Sync.INSTANCE).delayElements(Duration.ofMillis(20)); + + this.datarowCleanup(Flux.zip(this.client.exchange(flow1), this.client.exchange(flow2)) + .flatMapIterable(t -> Arrays.asList(t.getT1(), t.getT2())) + ) + .as(StepVerifier::create) + .assertNext(message -> assertThat(message).isInstanceOf(BindComplete.class)) + .assertNext(message -> assertThat(message).isInstanceOf(BindComplete.class)) + .assertNext(message -> assertThat(message).isInstanceOf(RowDescription.class)) + .assertNext(message -> assertThat(message).isInstanceOf(RowDescription.class)) + .assertNext(message -> assertThat(message).isInstanceOf(DataRow.class)) + .assertNext(message -> assertThat(message).isInstanceOf(DataRow.class)) + .expectNext(new CommandComplete("SELECT", null, 1)) + .expectNext(new CommandComplete("SELECT", null, 1)) + .verifyComplete(); + } + @Test void timeoutTest() { PostgresqlConnectionFactory postgresqlConnectionFactory = new PostgresqlConnectionFactory(PostgresqlConnectionConfiguration.builder()