diff --git a/testkit-backend/src/main/java/neo4j/org/testkit/backend/ResponseQueueHanlder.java b/testkit-backend/src/main/java/neo4j/org/testkit/backend/ResponseQueueHanlder.java new file mode 100644 index 0000000000..f660e72c61 --- /dev/null +++ b/testkit-backend/src/main/java/neo4j/org/testkit/backend/ResponseQueueHanlder.java @@ -0,0 +1,54 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package neo4j.org.testkit.backend; + +import java.util.ArrayDeque; +import java.util.Queue; +import java.util.function.Consumer; +import neo4j.org.testkit.backend.messages.responses.TestkitResponse; + +public class ResponseQueueHanlder { + private final Consumer responseWriter; + private final Queue responseQueue = new ArrayDeque<>(); + private boolean responseReady; + + ResponseQueueHanlder(Consumer responseWriter) { + this.responseWriter = responseWriter; + } + + public synchronized void setResponseReadyAndDispatchFirst() { + responseReady = true; + dispatchFirst(); + } + + public synchronized void offerAndDispatchFirst(TestkitResponse response) { + responseQueue.offer(response); + if (responseReady) { + dispatchFirst(); + } + } + + private synchronized void dispatchFirst() { + var response = responseQueue.poll(); + if (response != null) { + responseReady = false; + responseWriter.accept(response); + } + } +} diff --git a/testkit-backend/src/main/java/neo4j/org/testkit/backend/Runner.java b/testkit-backend/src/main/java/neo4j/org/testkit/backend/Runner.java index 8bde35e7fa..05cfba2d01 100644 --- a/testkit-backend/src/main/java/neo4j/org/testkit/backend/Runner.java +++ b/testkit-backend/src/main/java/neo4j/org/testkit/backend/Runner.java @@ -49,10 +49,14 @@ public static void main(String[] args) throws InterruptedException { .childHandler(new ChannelInitializer() { @Override protected void initChannel(SocketChannel channel) { + var responseQueueHanlder = new ResponseQueueHanlder(channel::writeAndFlush); channel.pipeline().addLast(new TestkitMessageInboundHandler()); channel.pipeline().addLast(new TestkitMessageOutboundHandler()); - channel.pipeline().addLast(new TestkitRequestResponseMapperHandler(logging)); - channel.pipeline().addLast(new TestkitRequestProcessorHandler(backendMode, logging)); + channel.pipeline() + .addLast(new TestkitRequestResponseMapperHandler(logging, responseQueueHanlder)); + channel.pipeline() + .addLast(new TestkitRequestProcessorHandler( + backendMode, logging, responseQueueHanlder)); } }); var server = bootstrap.bind().sync(); diff --git a/testkit-backend/src/main/java/neo4j/org/testkit/backend/channel/handler/TestkitRequestProcessorHandler.java b/testkit-backend/src/main/java/neo4j/org/testkit/backend/channel/handler/TestkitRequestProcessorHandler.java index 1d21eaff06..310b99283a 100644 --- a/testkit-backend/src/main/java/neo4j/org/testkit/backend/channel/handler/TestkitRequestProcessorHandler.java +++ b/testkit-backend/src/main/java/neo4j/org/testkit/backend/channel/handler/TestkitRequestProcessorHandler.java @@ -30,6 +30,7 @@ import java.util.function.BiFunction; import neo4j.org.testkit.backend.CustomDriverError; import neo4j.org.testkit.backend.FrontendError; +import neo4j.org.testkit.backend.ResponseQueueHanlder; import neo4j.org.testkit.backend.TestkitState; import neo4j.org.testkit.backend.messages.requests.TestkitRequest; import neo4j.org.testkit.backend.messages.responses.BackendError; @@ -47,9 +48,11 @@ public class TestkitRequestProcessorHandler extends ChannelInboundHandlerAdapter private final BiFunction> processorImpl; // Some requests require multiple threads private final Executor requestExecutorService = Executors.newFixedThreadPool(10); + private final ResponseQueueHanlder responseQueueHanlder; private Channel channel; - public TestkitRequestProcessorHandler(BackendMode backendMode, Logging logging) { + public TestkitRequestProcessorHandler( + BackendMode backendMode, Logging logging, ResponseQueueHanlder responseQueueHanlder) { switch (backendMode) { case ASYNC -> processorImpl = TestkitRequest::processAsync; case REACTIVE_LEGACY -> processorImpl = @@ -59,6 +62,7 @@ public TestkitRequestProcessorHandler(BackendMode backendMode, Logging logging) default -> processorImpl = TestkitRequestProcessorHandler::wrapSyncRequest; } testkitState = new TestkitState(this::writeAndFlush, logging); + this.responseQueueHanlder = responseQueueHanlder; } @Override @@ -74,14 +78,14 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) { requestExecutorService.execute(() -> { try { var request = (TestkitRequest) msg; - var responseStage = processorImpl.apply(request, testkitState); - responseStage.whenComplete((response, throwable) -> { - if (throwable != null) { - ctx.writeAndFlush(createErrorResponse(throwable)); - } else if (response != null) { - ctx.writeAndFlush(response); - } - }); + processorImpl + .apply(request, testkitState) + .exceptionally(this::createErrorResponse) + .whenComplete((response, ignored) -> { + if (response != null) { + responseQueueHanlder.offerAndDispatchFirst(response); + } + }); } catch (Throwable throwable) { exceptionCaught(ctx, throwable); } @@ -101,7 +105,8 @@ private static CompletionStage wrapSyncRequest( @Override public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { - ctx.writeAndFlush(createErrorResponse(cause)); + var response = createErrorResponse(cause); + responseQueueHanlder.offerAndDispatchFirst(response); } private TestkitResponse createErrorResponse(Throwable throwable) { @@ -165,7 +170,7 @@ private void writeAndFlush(TestkitResponse response) { if (channel == null) { throw new IllegalStateException("Called before channel is initialized"); } - channel.writeAndFlush(response); + responseQueueHanlder.offerAndDispatchFirst(response); } public enum BackendMode { diff --git a/testkit-backend/src/main/java/neo4j/org/testkit/backend/channel/handler/TestkitRequestResponseMapperHandler.java b/testkit-backend/src/main/java/neo4j/org/testkit/backend/channel/handler/TestkitRequestResponseMapperHandler.java index 1ebd45b6a9..7d2a7e9a4e 100644 --- a/testkit-backend/src/main/java/neo4j/org/testkit/backend/channel/handler/TestkitRequestResponseMapperHandler.java +++ b/testkit-backend/src/main/java/neo4j/org/testkit/backend/channel/handler/TestkitRequestResponseMapperHandler.java @@ -23,6 +23,7 @@ import io.netty.channel.ChannelDuplexHandler; import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelPromise; +import neo4j.org.testkit.backend.ResponseQueueHanlder; import neo4j.org.testkit.backend.messages.TestkitModule; import neo4j.org.testkit.backend.messages.requests.TestkitRequest; import neo4j.org.testkit.backend.messages.responses.TestkitResponse; @@ -32,17 +33,19 @@ public class TestkitRequestResponseMapperHandler extends ChannelDuplexHandler { private final Logger log; private final ObjectMapper objectMapper = newObjectMapper(); + private final ResponseQueueHanlder responseQueueHanlder; - public TestkitRequestResponseMapperHandler(Logging logging) { + public TestkitRequestResponseMapperHandler(Logging logging, ResponseQueueHanlder responseQueueHanlder) { log = logging.getLog(getClass()); + this.responseQueueHanlder = responseQueueHanlder; } @Override public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { var testkitMessage = (String) msg; log.debug("Inbound Testkit message '%s'", testkitMessage.trim()); - TestkitRequest testkitRequest; - testkitRequest = objectMapper.readValue(testkitMessage, TestkitRequest.class); + responseQueueHanlder.setResponseReadyAndDispatchFirst(); + var testkitRequest = objectMapper.readValue(testkitMessage, TestkitRequest.class); ctx.fireChannelRead(testkitRequest); }