From 1b2d3f163f4d4809d220e822511adc6221c97e50 Mon Sep 17 00:00:00 2001 From: Dmitriy Tverdiakov Date: Mon, 26 Jun 2023 21:43:45 +0100 Subject: [PATCH 1/4] Improve explicit transaction terminated state handling An explicit transaction is considered terminated as soon as a failure occurs. Any usage of the terminated transaction and any of its results must be stopped and the transaction must be closed explicitly. This update aims to ensure that the driver does not send further Bolt messages in regards to the terminated transaction. In addition, a new `TransactionTerminatedException` class has been introduced. It is a subclass of the previously used `ClientException`, making this exception more specific. The exception will contain a non-null `code` if it is created based on the server's response. It will not have a code if it is generated by the driver. Depending on a failure cause, the result handles may emit other exceptions respectively, matching the driver's existing behaviour. --- .../TransactionTerminatedException.java | 67 ++++++ .../driver/internal/InternalTransaction.java | 16 ++ .../internal/async/NetworkConnection.java | 176 ++++++++-------- .../driver/internal/async/NetworkSession.java | 6 +- .../internal/async/UnmanagedTransaction.java | 75 ++++--- .../async/connection/DirectConnection.java | 25 +-- .../async/connection/RoutingConnection.java | 20 +- .../inbound/InboundMessageDispatcher.java | 31 +-- .../handlers/ResetResponseHandler.java | 12 ++ .../pulln/AutoPullResponseHandler.java | 20 +- .../pulln/BasicPullResponseHandler.java | 65 ++++-- .../reactive/InternalReactiveTransaction.java | 11 - .../InternalReactiveTransaction.java | 12 +- .../neo4j/driver/internal/spi/Connection.java | 11 +- .../neo4j/driver/internal/util/ErrorUtil.java | 5 +- .../driver/integration/SessionResetIT.java | 16 +- .../driver/integration/TransactionIT.java | 150 +++++++++++++- .../integration/UnmanagedTransactionIT.java | 15 +- .../reactive/ReactiveTransactionIT.java | 195 ++++++++++++++++++ .../internal/async/NetworkConnectionTest.java | 139 +------------ .../internal/async/NetworkSessionTest.java | 8 +- .../async/UnmanagedTransactionTest.java | 69 +++---- .../connection/DecoratedConnectionTest.java | 34 +-- .../connection/RoutingConnectionTest.java | 38 ---- .../InternalReactiveTransactionTest.java | 67 ------ .../util/FailingConnectionDriverFactory.java | 29 +-- .../org/neo4j/driver/testutil/TestUtil.java | 2 +- .../backend/messages/requests/StartTest.java | 8 + 28 files changed, 759 insertions(+), 563 deletions(-) create mode 100644 driver/src/main/java/org/neo4j/driver/exceptions/TransactionTerminatedException.java create mode 100644 driver/src/test/java/org/neo4j/driver/integration/reactive/ReactiveTransactionIT.java delete mode 100644 driver/src/test/java/org/neo4j/driver/internal/reactive/InternalReactiveTransactionTest.java diff --git a/driver/src/main/java/org/neo4j/driver/exceptions/TransactionTerminatedException.java b/driver/src/main/java/org/neo4j/driver/exceptions/TransactionTerminatedException.java new file mode 100644 index 0000000000..e791d5f46b --- /dev/null +++ b/driver/src/main/java/org/neo4j/driver/exceptions/TransactionTerminatedException.java @@ -0,0 +1,67 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.exceptions; + +import java.io.Serial; + +/** + * Indicates that the transaction has been terminated. + *

+ * Any usage of the terminated transaction and any of its results must be stopped and the transaction must be closed + * explicitly. Moreover, any error in the transaction result(s) should be considered as a transaction termination and + * must be handled in the same way. + *

+ * The exception will contain a non-null {@link #code()} if it is created based on the server's response. It will not + * have a code if it is generated by the driver. + * + * @since 5.11 + */ +public class TransactionTerminatedException extends ClientException { + @Serial + private static final long serialVersionUID = 7639191706067500206L; + + /** + * Creates a new instance. + * + * @param message the message + */ + public TransactionTerminatedException(String message) { + super(message); + } + + /** + * Creates a new instance. + * + * @param code the code + * @param message the message + */ + public TransactionTerminatedException(String code, String message) { + super(code, message); + } + + /** + * Creates a new instance. + * + * @param message the message + * @param cause the cause + */ + public TransactionTerminatedException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/driver/src/main/java/org/neo4j/driver/internal/InternalTransaction.java b/driver/src/main/java/org/neo4j/driver/internal/InternalTransaction.java index 37e7b3a6ef..d7bcbea0e3 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/InternalTransaction.java +++ b/driver/src/main/java/org/neo4j/driver/internal/InternalTransaction.java @@ -66,6 +66,22 @@ public boolean isOpen() { return tx.isOpen(); } + /** + * THIS IS A PRIVATE API + *

+ * Terminates the transaction by sending the Bolt {@code RESET} message and waiting for its response as long as the + * transaction has not already been terminated, is not closed or closing. + * + * @since 5.11 + * @throws org.neo4j.driver.exceptions.ClientException if the transaction is closed or is closing + * @see org.neo4j.driver.exceptions.TransactionTerminatedException + */ + public void terminate() { + Futures.blockingGet( + tx.terminateAsync(), + () -> terminateConnectionOnThreadInterrupt("Thread interrupted while terminating the transaction")); + } + private void terminateConnectionOnThreadInterrupt(String reason) { tx.connection().terminateAndRelease(reason); } diff --git a/driver/src/main/java/org/neo4j/driver/internal/async/NetworkConnection.java b/driver/src/main/java/org/neo4j/driver/internal/async/NetworkConnection.java index 5e4fb4883d..eccea79483 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/async/NetworkConnection.java +++ b/driver/src/main/java/org/neo4j/driver/internal/async/NetworkConnection.java @@ -22,6 +22,7 @@ import static org.neo4j.driver.internal.async.connection.ChannelAttributes.poolId; import static org.neo4j.driver.internal.async.connection.ChannelAttributes.setTerminationReason; import static org.neo4j.driver.internal.util.Futures.asCompletionStage; +import static org.neo4j.driver.internal.util.LockUtil.executeWithLock; import io.netty.channel.Channel; import io.netty.channel.ChannelHandler; @@ -29,7 +30,9 @@ import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionStage; import java.util.concurrent.TimeUnit; -import java.util.concurrent.atomic.AtomicReference; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantLock; +import java.util.function.Consumer; import org.neo4j.driver.Logger; import org.neo4j.driver.Logging; import org.neo4j.driver.internal.BoltServerAddress; @@ -41,7 +44,12 @@ import org.neo4j.driver.internal.handlers.ResetResponseHandler; import org.neo4j.driver.internal.messaging.BoltProtocol; import org.neo4j.driver.internal.messaging.Message; +import org.neo4j.driver.internal.messaging.request.DiscardAllMessage; +import org.neo4j.driver.internal.messaging.request.DiscardMessage; +import org.neo4j.driver.internal.messaging.request.PullAllMessage; +import org.neo4j.driver.internal.messaging.request.PullMessage; import org.neo4j.driver.internal.messaging.request.ResetMessage; +import org.neo4j.driver.internal.messaging.request.RunWithMetadataMessage; import org.neo4j.driver.internal.metrics.ListenerEvent; import org.neo4j.driver.internal.metrics.MetricsListener; import org.neo4j.driver.internal.spi.Connection; @@ -53,6 +61,7 @@ */ public class NetworkConnection implements Connection { private final Logger log; + private final Lock lock; private final Channel channel; private final InboundMessageDispatcher messageDispatcher; private final String serverAgent; @@ -61,12 +70,13 @@ public class NetworkConnection implements Connection { private final ExtendedChannelPool channelPool; private final CompletableFuture releaseFuture; private final Clock clock; - - private final AtomicReference status = new AtomicReference<>(Status.OPEN); private final MetricsListener metricsListener; private final ListenerEvent inUseEvent; private final Long connectionReadTimeout; + + private Status status = Status.OPEN; + private UnmanagedTransaction transaction; private ChannelHandler connectionReadTimeoutHandler; public NetworkConnection( @@ -76,6 +86,7 @@ public NetworkConnection( MetricsListener metricsListener, Logging logging) { this.log = logging.getLog(getClass()); + this.lock = new ReentrantLock(); this.channel = channel; this.messageDispatcher = ChannelAttributes.messageDispatcher(channel); this.serverAgent = ChannelAttributes.serverAgent(channel); @@ -93,7 +104,7 @@ public NetworkConnection( @Override public boolean isOpen() { - return status.get() == Status.OPEN; + return executeWithLock(lock, () -> status == Status.OPEN); } @Override @@ -110,52 +121,31 @@ public void disableAutoRead() { } } - @Override - public void flush() { - if (verifyOpen(null, null)) { - flushInEventLoop(); - } - } - @Override public void write(Message message, ResponseHandler handler) { - if (verifyOpen(handler, null)) { + if (verifyOpen(handler)) { writeMessageInEventLoop(message, handler, false); } } - @Override - public void write(Message message1, ResponseHandler handler1, Message message2, ResponseHandler handler2) { - if (verifyOpen(handler1, handler2)) { - writeMessagesInEventLoop(message1, handler1, message2, handler2, false); - } - } - @Override public void writeAndFlush(Message message, ResponseHandler handler) { - if (verifyOpen(handler, null)) { + if (verifyOpen(handler)) { writeMessageInEventLoop(message, handler, true); } } @Override - public void writeAndFlush(Message message1, ResponseHandler handler1, Message message2, ResponseHandler handler2) { - if (verifyOpen(handler1, handler2)) { - writeMessagesInEventLoop(message1, handler1, message2, handler2, true); - } - } - - @Override - public CompletionStage reset() { - CompletableFuture result = new CompletableFuture<>(); - ResetResponseHandler handler = new ResetResponseHandler(messageDispatcher, result); + public CompletionStage reset(Throwable throwable) { + var result = new CompletableFuture(); + var handler = new ResetResponseHandler(messageDispatcher, result, throwable); writeResetMessageIfNeeded(handler, true); return result; } @Override public CompletionStage release() { - if (status.compareAndSet(Status.OPEN, Status.RELEASED)) { + if (executeWithLock(lock, () -> updateStateIfOpen(Status.RELEASED))) { ChannelReleasingResetResponseHandler handler = new ChannelReleasingResetResponseHandler( channel, channelPool, messageDispatcher, clock, releaseFuture); @@ -167,7 +157,7 @@ public CompletionStage release() { @Override public void terminateAndRelease(String reason) { - if (status.compareAndSet(Status.OPEN, Status.TERMINATED)) { + if (executeWithLock(lock, () -> updateStateIfOpen(Status.TERMINATED))) { setTerminationReason(channel, reason); asCompletionStage(channel.close()) .exceptionally(throwable -> null) @@ -194,6 +184,25 @@ public BoltProtocol protocol() { return protocol; } + @Override + public void bindTransaction(UnmanagedTransaction transaction) { + executeWithLock(lock, () -> { + if (this.transaction != null) { + throw new IllegalStateException("transaction is already set"); + } + this.transaction = transaction; + }); + } + + private boolean updateStateIfOpen(Status newStatus) { + if (Status.OPEN.equals(status)) { + status = newStatus; + return true; + } else { + return false; + } + } + private void writeResetMessageIfNeeded(ResponseHandler resetHandler, boolean isSessionReset) { channel.eventLoop().execute(() -> { if (isSessionReset && !isOpen()) { @@ -208,73 +217,49 @@ private void writeResetMessageIfNeeded(ResponseHandler resetHandler, boolean isS }); } - private void flushInEventLoop() { - channel.eventLoop().execute(() -> { - channel.flush(); - registerConnectionReadTimeout(channel); - }); - } - private void writeMessageInEventLoop(Message message, ResponseHandler handler, boolean flush) { - channel.eventLoop().execute(() -> { - messageDispatcher.enqueue(handler); - - if (flush) { - channel.writeAndFlush(message).addListener(future -> registerConnectionReadTimeout(channel)); - } else { - channel.write(message, channel.voidPromise()); - } - }); - } - - private void writeMessagesInEventLoop( - Message message1, ResponseHandler handler1, Message message2, ResponseHandler handler2, boolean flush) { - channel.eventLoop().execute(() -> { - messageDispatcher.enqueue(handler1); - messageDispatcher.enqueue(handler2); - - channel.write(message1, channel.voidPromise()); - - if (flush) { - channel.writeAndFlush(message2).addListener(future -> registerConnectionReadTimeout(channel)); - } else { - channel.write(message2, channel.voidPromise()); - } - }); + channel.eventLoop() + .execute(() -> transactionTerminationAwareExecutor(message).accept(causeOfTermination -> { + if (causeOfTermination == null) { + messageDispatcher.enqueue(handler); + + if (flush) { + channel.writeAndFlush(message) + .addListener(future -> registerConnectionReadTimeout(channel)); + } else { + channel.write(message, channel.voidPromise()); + } + } else { + handler.onFailure(causeOfTermination); + } + })); } private void setAutoRead(boolean value) { channel.config().setAutoRead(value); } - private boolean verifyOpen(ResponseHandler handler1, ResponseHandler handler2) { - Status connectionStatus = this.status.get(); - switch (connectionStatus) { - case OPEN: - return true; - case RELEASED: + private boolean verifyOpen(ResponseHandler handler) { + var connectionStatus = executeWithLock(lock, () -> status); + return switch (connectionStatus) { + case OPEN -> true; + case RELEASED -> { Exception error = new IllegalStateException("Connection has been released to the pool and can't be used"); - if (handler1 != null) { - handler1.onFailure(error); + if (handler != null) { + handler.onFailure(error); } - if (handler2 != null) { - handler2.onFailure(error); - } - return false; - case TERMINATED: + yield false; + } + case TERMINATED -> { Exception terminatedError = new IllegalStateException("Connection has been terminated and can't be used"); - if (handler1 != null) { - handler1.onFailure(terminatedError); - } - if (handler2 != null) { - handler2.onFailure(terminatedError); + if (handler != null) { + handler.onFailure(terminatedError); } - return false; - default: - throw new IllegalStateException("Unknown status: " + connectionStatus); - } + yield false; + } + }; } private void registerConnectionReadTimeout(Channel channel) { @@ -295,6 +280,25 @@ private void registerConnectionReadTimeout(Channel channel) { } } + private Consumer> transactionTerminationAwareExecutor(Message message) { + var result = (Consumer>) consumer -> consumer.accept(null); + if (isQueryMessage(message)) { + var transaction = executeWithLock(lock, () -> this.transaction); + if (transaction != null) { + result = transaction::executeWithLockedState; + } + } + return result; + } + + private boolean isQueryMessage(Message message) { + return message instanceof RunWithMetadataMessage + || message instanceof PullMessage + || message instanceof PullAllMessage + || message instanceof DiscardMessage + || message instanceof DiscardAllMessage; + } + private enum Status { OPEN, RELEASED, diff --git a/driver/src/main/java/org/neo4j/driver/internal/async/NetworkSession.java b/driver/src/main/java/org/neo4j/driver/internal/async/NetworkSession.java index f57df5218e..e422b33a39 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/async/NetworkSession.java +++ b/driver/src/main/java/org/neo4j/driver/internal/async/NetworkSession.java @@ -31,6 +31,7 @@ import java.util.concurrent.CompletionException; import java.util.concurrent.CompletionStage; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; import org.neo4j.driver.AccessMode; import org.neo4j.driver.AuthToken; import org.neo4j.driver.Bookmark; @@ -172,17 +173,18 @@ public CompletionStage beginTransactionAsync( } public CompletionStage resetAsync() { + var terminationException = new AtomicReference(); return existingTransactionOrNull() .thenAccept(tx -> { if (tx != null) { - tx.markTerminated(null); + terminationException.set(tx.markTerminated(null)); } }) .thenCompose(ignore -> connectionStage) .thenCompose(connection -> { if (connection != null) { // there exists an active connection, send a RESET message over it - return connection.reset(); + return connection.reset(terminationException.get()); } return completedWithNull(); }); diff --git a/driver/src/main/java/org/neo4j/driver/internal/async/UnmanagedTransaction.java b/driver/src/main/java/org/neo4j/driver/internal/async/UnmanagedTransaction.java index 494f845bfa..7156b0c978 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/async/UnmanagedTransaction.java +++ b/driver/src/main/java/org/neo4j/driver/internal/async/UnmanagedTransaction.java @@ -18,6 +18,7 @@ */ package org.neo4j.driver.internal.async; +import static java.util.concurrent.CompletableFuture.completedFuture; import static org.neo4j.driver.internal.util.Futures.asCompletionException; import static org.neo4j.driver.internal.util.Futures.combineErrors; import static org.neo4j.driver.internal.util.Futures.completedWithNull; @@ -44,6 +45,7 @@ import org.neo4j.driver.exceptions.AuthorizationExpiredException; import org.neo4j.driver.exceptions.ClientException; import org.neo4j.driver.exceptions.ConnectionReadTimeoutException; +import org.neo4j.driver.exceptions.TransactionTerminatedException; import org.neo4j.driver.internal.DatabaseBookmark; import org.neo4j.driver.internal.cursor.AsyncResultCursor; import org.neo4j.driver.internal.cursor.RxResultCursor; @@ -73,6 +75,8 @@ private enum State { ROLLED_BACK } + public static final String EXPLICITLY_TERMINATED_MSG = + "The transaction has been explicitly terminated by the driver"; protected static final String CANT_COMMIT_COMMITTED_MSG = "Can't commit, transaction has been committed"; protected static final String CANT_ROLLBACK_COMMITTED_MSG = "Can't rollback, transaction has been committed"; protected static final String CANT_COMMIT_ROLLED_BACK_MSG = "Can't commit, transaction has been rolled back"; @@ -93,7 +97,7 @@ private enum State { private CompletableFuture commitFuture; private CompletableFuture rollbackFuture; private Throwable causeOfTermination; - private CompletionStage interruptStage; + private CompletionStage terminationStage; private final NotificationConfig notificationConfig; public UnmanagedTransaction( @@ -116,6 +120,8 @@ protected UnmanagedTransaction( this.resultCursors = resultCursors; this.fetchSize = fetchSize; this.notificationConfig = notificationConfig; + + connection.bindTransaction(this); } public CompletionStage beginAsync( @@ -176,16 +182,18 @@ public boolean isOpen() { return OPEN_STATES.contains(executeWithLock(lock, () -> state)); } - public void markTerminated(Throwable cause) { - executeWithLock(lock, () -> { + public Throwable markTerminated(Throwable cause) { + return executeWithLock(lock, () -> { if (state == State.TERMINATED) { - if (causeOfTermination != null && cause != null) { + if (cause != null) { addSuppressedWhenNotCaptured(causeOfTermination, cause); } } else { state = State.TERMINATED; - causeOfTermination = cause; + causeOfTermination = + cause != null ? cause : new TransactionTerminatedException(EXPLICITLY_TERMINATED_MSG); } + return causeOfTermination; }); } @@ -203,6 +211,32 @@ public Connection connection() { return connection; } + /** + * Locks the transaction state and executes the supplied {@link Consumer} with a cause of termination if the + * transaction is terminated. + * + * @param causeOfTerminationConsumer the consumer accepting + */ + public void executeWithLockedState(Consumer causeOfTerminationConsumer) { + executeWithLock(lock, () -> causeOfTerminationConsumer.accept(causeOfTermination)); + } + + public CompletionStage terminateAsync() { + return executeWithLock(lock, () -> { + if (!isOpen() || commitFuture != null || rollbackFuture != null) { + return failedFuture(new ClientException("Can't terminate closed or closing transaction")); + } else { + if (state == State.TERMINATED) { + return terminationStage != null ? terminationStage : completedFuture(null); + } else { + var terminationException = markTerminated(null); + terminationStage = connection.reset(terminationException); + return terminationStage; + } + } + }); + } + private void ensureCanRunQueries() { executeWithLock(lock, () -> { if (state == State.COMMITTED) { @@ -210,10 +244,14 @@ private void ensureCanRunQueries() { } else if (state == State.ROLLED_BACK) { throw new ClientException("Cannot run more queries in this transaction, it has been rolled back"); } else if (state == State.TERMINATED) { - throw new ClientException( - "Cannot run more queries in this transaction, " - + "it has either experienced an fatal error or was explicitly terminated", - causeOfTermination); + if (causeOfTermination instanceof TransactionTerminatedException transactionTerminatedException) { + throw transactionTerminatedException; + } else { + throw new TransactionTerminatedException( + "Cannot run more queries in this transaction, " + + "it has either experienced an fatal error or was explicitly terminated", + causeOfTermination); + } } }); } @@ -222,7 +260,7 @@ private CompletionStage doCommitAsync(Throwable cursorFailure) { ClientException exception = executeWithLock( lock, () -> state == State.TERMINATED - ? new ClientException( + ? new TransactionTerminatedException( "Transaction can't be committed. " + "It has been rolled back either because of an error or explicit termination", cursorFailure != causeOfTermination ? causeOfTermination : null) @@ -318,21 +356,4 @@ private CompletionStage closeAsync(boolean commit, boolean completeWithNul return stage; } - - /** - * Marks transaction as terminated and sends {@code RESET} message over allocated connection. - *

- * THIS METHOD IS NOT PART OF PUBLIC API. This method may be changed or removed at any moment in time. - * - * @return {@code RESET} response stage - */ - public CompletionStage interruptAsync() { - return executeWithLock(lock, () -> { - if (interruptStage == null) { - markTerminated(null); - interruptStage = connection.reset(); - } - return interruptStage; - }); - } } diff --git a/driver/src/main/java/org/neo4j/driver/internal/async/connection/DirectConnection.java b/driver/src/main/java/org/neo4j/driver/internal/async/connection/DirectConnection.java index 010618b4a8..1af7b6e407 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/async/connection/DirectConnection.java +++ b/driver/src/main/java/org/neo4j/driver/internal/async/connection/DirectConnection.java @@ -23,6 +23,7 @@ import org.neo4j.driver.internal.BoltServerAddress; import org.neo4j.driver.internal.DatabaseName; import org.neo4j.driver.internal.DirectConnectionProvider; +import org.neo4j.driver.internal.async.UnmanagedTransaction; import org.neo4j.driver.internal.messaging.BoltProtocol; import org.neo4j.driver.internal.messaging.Message; import org.neo4j.driver.internal.spi.Connection; @@ -68,24 +69,14 @@ public void write(Message message, ResponseHandler handler) { delegate.write(message, handler); } - @Override - public void write(Message message1, ResponseHandler handler1, Message message2, ResponseHandler handler2) { - delegate.write(message1, handler1, message2, handler2); - } - @Override public void writeAndFlush(Message message, ResponseHandler handler) { delegate.writeAndFlush(message, handler); } @Override - public void writeAndFlush(Message message1, ResponseHandler handler1, Message message2, ResponseHandler handler2) { - delegate.writeAndFlush(message1, handler1, message2, handler2); - } - - @Override - public CompletionStage reset() { - return delegate.reset(); + public CompletionStage reset(Throwable throwable) { + return delegate.reset(throwable); } @Override @@ -113,6 +104,11 @@ public BoltProtocol protocol() { return delegate.protocol(); } + @Override + public void bindTransaction(UnmanagedTransaction transaction) { + delegate.bindTransaction(transaction); + } + @Override public AccessMode mode() { return mode; @@ -127,9 +123,4 @@ public DatabaseName databaseName() { public String impersonatedUser() { return impersonatedUser; } - - @Override - public void flush() { - delegate.flush(); - } } diff --git a/driver/src/main/java/org/neo4j/driver/internal/async/connection/RoutingConnection.java b/driver/src/main/java/org/neo4j/driver/internal/async/connection/RoutingConnection.java index fec880ace0..0ee584ab35 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/async/connection/RoutingConnection.java +++ b/driver/src/main/java/org/neo4j/driver/internal/async/connection/RoutingConnection.java @@ -23,6 +23,7 @@ import org.neo4j.driver.internal.BoltServerAddress; import org.neo4j.driver.internal.DatabaseName; import org.neo4j.driver.internal.RoutingErrorHandler; +import org.neo4j.driver.internal.async.UnmanagedTransaction; import org.neo4j.driver.internal.handlers.RoutingResponseHandler; import org.neo4j.driver.internal.messaging.BoltProtocol; import org.neo4j.driver.internal.messaging.Message; @@ -67,25 +68,14 @@ public void write(Message message, ResponseHandler handler) { delegate.write(message, newRoutingResponseHandler(handler)); } - @Override - public void write(Message message1, ResponseHandler handler1, Message message2, ResponseHandler handler2) { - delegate.write(message1, newRoutingResponseHandler(handler1), message2, newRoutingResponseHandler(handler2)); - } - @Override public void writeAndFlush(Message message, ResponseHandler handler) { delegate.writeAndFlush(message, newRoutingResponseHandler(handler)); } @Override - public void writeAndFlush(Message message1, ResponseHandler handler1, Message message2, ResponseHandler handler2) { - delegate.writeAndFlush( - message1, newRoutingResponseHandler(handler1), message2, newRoutingResponseHandler(handler2)); - } - - @Override - public CompletionStage reset() { - return delegate.reset(); + public CompletionStage reset(Throwable throwable) { + return delegate.reset(throwable); } @Override @@ -119,8 +109,8 @@ public BoltProtocol protocol() { } @Override - public void flush() { - delegate.flush(); + public void bindTransaction(UnmanagedTransaction transaction) { + delegate.bindTransaction(transaction); } @Override diff --git a/driver/src/main/java/org/neo4j/driver/internal/async/inbound/InboundMessageDispatcher.java b/driver/src/main/java/org/neo4j/driver/internal/async/inbound/InboundMessageDispatcher.java index 20c9a030b9..307e105126 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/async/inbound/InboundMessageDispatcher.java +++ b/driver/src/main/java/org/neo4j/driver/internal/async/inbound/InboundMessageDispatcher.java @@ -28,6 +28,8 @@ import java.util.Arrays; import java.util.LinkedList; import java.util.Map; +import java.util.Objects; +import java.util.Optional; import java.util.Queue; import org.neo4j.driver.Logger; import org.neo4j.driver.Logging; @@ -146,22 +148,25 @@ public void handleFailureMessage(String code, String message) { public void handleIgnoredMessage() { log.debug("S: IGNORED"); - ResponseHandler handler = removeHandler(); - - Throwable error; - if (currentError != null) { - error = currentError; - } else { - log.warn( - "Received IGNORED message for handler %s but error is missing and RESET is not in progress. " - + "Current handlers %s", - handler, handlers); - - error = new ClientException("Database ignored the request"); - } + var handler = removeHandler(); + var error = Objects.requireNonNullElseGet(currentError, () -> getPendingResetHandler() + .flatMap(ResetResponseHandler::throwable) + .orElseGet(() -> { + log.warn( + "Received IGNORED message for handler %s but error is missing and RESET is not in progress. Current handlers %s", + handler, handlers); + return new ClientException("Database ignored the request"); + })); handler.onFailure(error); } + private Optional getPendingResetHandler() { + return handlers.stream() + .filter(h -> h instanceof ResetResponseHandler) + .map(h -> (ResetResponseHandler) h) + .findFirst(); + } + public void handleChannelInactive(Throwable cause) { // report issue if the connection has not been terminated as a result of a graceful shutdown request from its // parent pool diff --git a/driver/src/main/java/org/neo4j/driver/internal/handlers/ResetResponseHandler.java b/driver/src/main/java/org/neo4j/driver/internal/handlers/ResetResponseHandler.java index 82cd04d207..af89eb3e4f 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/handlers/ResetResponseHandler.java +++ b/driver/src/main/java/org/neo4j/driver/internal/handlers/ResetResponseHandler.java @@ -19,6 +19,7 @@ package org.neo4j.driver.internal.handlers; import java.util.Map; +import java.util.Optional; import java.util.concurrent.CompletableFuture; import org.neo4j.driver.Value; import org.neo4j.driver.internal.async.inbound.InboundMessageDispatcher; @@ -27,14 +28,21 @@ public class ResetResponseHandler implements ResponseHandler { private final InboundMessageDispatcher messageDispatcher; private final CompletableFuture completionFuture; + private final Throwable throwable; public ResetResponseHandler(InboundMessageDispatcher messageDispatcher) { this(messageDispatcher, null); } public ResetResponseHandler(InboundMessageDispatcher messageDispatcher, CompletableFuture completionFuture) { + this(messageDispatcher, completionFuture, null); + } + + public ResetResponseHandler( + InboundMessageDispatcher messageDispatcher, CompletableFuture completionFuture, Throwable throwable) { this.messageDispatcher = messageDispatcher; this.completionFuture = completionFuture; + this.throwable = throwable; } @Override @@ -52,6 +60,10 @@ public final void onRecord(Value[] fields) { throw new UnsupportedOperationException(); } + public Optional throwable() { + return Optional.ofNullable(throwable); + } + private void resetCompleted(boolean success) { messageDispatcher.clearCurrentError(); if (completionFuture != null) { diff --git a/driver/src/main/java/org/neo4j/driver/internal/handlers/pulln/AutoPullResponseHandler.java b/driver/src/main/java/org/neo4j/driver/internal/handlers/pulln/AutoPullResponseHandler.java index 84038d2cb7..ed5952431a 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/handlers/pulln/AutoPullResponseHandler.java +++ b/driver/src/main/java/org/neo4j/driver/internal/handlers/pulln/AutoPullResponseHandler.java @@ -145,12 +145,14 @@ public synchronized CompletionStage consumeAsync() { if (isDone()) { return completedWithValueIfNoFailure(summary); } else { - cancel(); - if (summaryFuture == null) { - summaryFuture = new CompletableFuture<>(); + var future = summaryFuture; + if (future == null) { + future = new CompletableFuture<>(); + summaryFuture = future; } + cancel(); - return summaryFuture; + return future; } } @@ -172,12 +174,14 @@ private synchronized CompletionStage pullAllAsync() { if (isDone()) { return completedWithValueIfNoFailure(summary); } else { - request(UNLIMITED_FETCH_SIZE); - if (summaryFuture == null) { - summaryFuture = new CompletableFuture<>(); + var future = summaryFuture; + if (future == null) { + future = new CompletableFuture<>(); + summaryFuture = future; } + request(UNLIMITED_FETCH_SIZE); - return summaryFuture; + return future; } } diff --git a/driver/src/main/java/org/neo4j/driver/internal/handlers/pulln/BasicPullResponseHandler.java b/driver/src/main/java/org/neo4j/driver/internal/handlers/pulln/BasicPullResponseHandler.java index 5b37957244..64056e4b20 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/handlers/pulln/BasicPullResponseHandler.java +++ b/driver/src/main/java/org/neo4j/driver/internal/handlers/pulln/BasicPullResponseHandler.java @@ -43,6 +43,7 @@ * Provides basic handling of pull responses from sever. The state is managed by {@link State}. */ public class BasicPullResponseHandler implements PullResponseHandler { + private static final Runnable NO_OP_RUNNABLE = () -> {}; private final Query query; protected final RunResponseHandler runResponseHandler; protected final MetadataExtractor metadataExtractor; @@ -163,15 +164,33 @@ record = new InternalRecord(runResponseHandler.queryKeys(), fields); } @Override - public synchronized void request(long size) { - assertRecordAndSummaryConsumerInstalled(); - state.request(this, size); + public void request(long size) { + Runnable postAction; + synchronized (this) { + assertRecordAndSummaryConsumerInstalled(); + postAction = state.request(this, size); + if (syncSignals) { + postAction.run(); + } + } + if (!syncSignals) { + postAction.run(); + } } @Override public synchronized void cancel() { - assertRecordAndSummaryConsumerInstalled(); - state.cancel(this); + Runnable postAction; + synchronized (this) { + assertRecordAndSummaryConsumerInstalled(); + postAction = state.cancel(this); + if (syncSignals) { + postAction.run(); + } + } + if (!syncSignals) { + postAction.run(); + } } protected void writePull(long n) { @@ -285,15 +304,15 @@ void onRecord(BasicPullResponseHandler context, Value[] fields) { } @Override - void request(BasicPullResponseHandler context, long n) { + Runnable request(BasicPullResponseHandler context, long n) { context.state(STREAMING_STATE); - context.writePull(n); + return () -> context.writePull(n); } @Override - void cancel(BasicPullResponseHandler context) { + Runnable cancel(BasicPullResponseHandler context) { context.state(CANCELLED_STATE); - context.discardAll(); + return context::discardAll; } }, STREAMING_STATE { @@ -317,14 +336,16 @@ void onRecord(BasicPullResponseHandler context, Value[] fields) { } @Override - void request(BasicPullResponseHandler context, long n) { + Runnable request(BasicPullResponseHandler context, long n) { context.state(STREAMING_STATE); context.addToRequest(n); + return NO_OP_RUNNABLE; } @Override - void cancel(BasicPullResponseHandler context) { + Runnable cancel(BasicPullResponseHandler context) { context.state(CANCELLED_STATE); + return NO_OP_RUNNABLE; } }, CANCELLED_STATE { @@ -349,13 +370,15 @@ void onRecord(BasicPullResponseHandler context, Value[] fields) { } @Override - void request(BasicPullResponseHandler context, long n) { + Runnable request(BasicPullResponseHandler context, long n) { context.state(CANCELLED_STATE); + return NO_OP_RUNNABLE; } @Override - void cancel(BasicPullResponseHandler context) { + Runnable cancel(BasicPullResponseHandler context) { context.state(CANCELLED_STATE); + return NO_OP_RUNNABLE; } }, SUCCEEDED_STATE { @@ -375,13 +398,15 @@ void onRecord(BasicPullResponseHandler context, Value[] fields) { } @Override - void request(BasicPullResponseHandler context, long n) { + Runnable request(BasicPullResponseHandler context, long n) { context.state(SUCCEEDED_STATE); + return NO_OP_RUNNABLE; } @Override - void cancel(BasicPullResponseHandler context) { + Runnable cancel(BasicPullResponseHandler context) { context.state(SUCCEEDED_STATE); + return NO_OP_RUNNABLE; } }, FAILURE_STATE { @@ -401,13 +426,15 @@ void onRecord(BasicPullResponseHandler context, Value[] fields) { } @Override - void request(BasicPullResponseHandler context, long n) { + Runnable request(BasicPullResponseHandler context, long n) { context.state(FAILURE_STATE); + return NO_OP_RUNNABLE; } @Override - void cancel(BasicPullResponseHandler context) { + Runnable cancel(BasicPullResponseHandler context) { context.state(FAILURE_STATE); + return NO_OP_RUNNABLE; } }; @@ -417,8 +444,8 @@ void cancel(BasicPullResponseHandler context) { abstract void onRecord(BasicPullResponseHandler context, Value[] fields); - abstract void request(BasicPullResponseHandler context, long n); + abstract Runnable request(BasicPullResponseHandler context, long n); - abstract void cancel(BasicPullResponseHandler context); + abstract Runnable cancel(BasicPullResponseHandler context); } } diff --git a/driver/src/main/java/org/neo4j/driver/internal/reactive/InternalReactiveTransaction.java b/driver/src/main/java/org/neo4j/driver/internal/reactive/InternalReactiveTransaction.java index c5a089cded..1f7153099d 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/reactive/InternalReactiveTransaction.java +++ b/driver/src/main/java/org/neo4j/driver/internal/reactive/InternalReactiveTransaction.java @@ -60,17 +60,6 @@ public Publisher run(Query query) { .map(InternalReactiveResult::new)); } - /** - * Marks transaction as terminated and sends {@code RESET} message over allocated connection. - *

- * THIS METHOD IS NOT PART OF PUBLIC API. This method may be changed or removed at any moment in time. - * - * @return {@code RESET} response publisher - */ - public Publisher interrupt() { - return publisherToFlowPublisher(Mono.fromCompletionStage(tx.interruptAsync())); - } - @Override public Publisher commit() { return publisherToFlowPublisher(doCommit()); diff --git a/driver/src/main/java/org/neo4j/driver/internal/reactivestreams/InternalReactiveTransaction.java b/driver/src/main/java/org/neo4j/driver/internal/reactivestreams/InternalReactiveTransaction.java index b91419e6cf..40400ec491 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/reactivestreams/InternalReactiveTransaction.java +++ b/driver/src/main/java/org/neo4j/driver/internal/reactivestreams/InternalReactiveTransaction.java @@ -60,14 +60,16 @@ public Publisher run(Query query) { } /** - * Marks transaction as terminated and sends {@code RESET} message over allocated connection. + * THIS IS A PRIVATE API *

- * THIS METHOD IS NOT PART OF PUBLIC API. This method may be changed or removed at any moment in time. + * Terminates the transaction by sending the Bolt {@code RESET} message and waiting for its response as long as the + * transaction has not already been terminated, is not closed or closing. * - * @return {@code RESET} response publisher + * @return completion publisher (the {@code RESET} completion publisher if the message was sent) + * @since 5.11 */ - public Publisher interrupt() { - return Mono.fromCompletionStage(tx.interruptAsync()); + public Publisher terminate() { + return Mono.fromCompletionStage(tx.terminateAsync()); } @Override diff --git a/driver/src/main/java/org/neo4j/driver/internal/spi/Connection.java b/driver/src/main/java/org/neo4j/driver/internal/spi/Connection.java index 54fe232953..17cdb318f5 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/spi/Connection.java +++ b/driver/src/main/java/org/neo4j/driver/internal/spi/Connection.java @@ -24,6 +24,7 @@ import org.neo4j.driver.AccessMode; import org.neo4j.driver.internal.BoltServerAddress; import org.neo4j.driver.internal.DatabaseName; +import org.neo4j.driver.internal.async.UnmanagedTransaction; import org.neo4j.driver.internal.messaging.BoltProtocol; import org.neo4j.driver.internal.messaging.Message; @@ -36,13 +37,9 @@ public interface Connection { void write(Message message, ResponseHandler handler); - void write(Message message1, ResponseHandler handler1, Message message2, ResponseHandler handler2); - void writeAndFlush(Message message, ResponseHandler handler); - void writeAndFlush(Message message1, ResponseHandler handler1, Message message2, ResponseHandler handler2); - - CompletionStage reset(); + CompletionStage reset(Throwable throwable); CompletionStage release(); @@ -54,6 +51,8 @@ public interface Connection { BoltProtocol protocol(); + void bindTransaction(UnmanagedTransaction transaction); + default AccessMode mode() { throw new UnsupportedOperationException(format("%s does not support access mode.", getClass())); } @@ -65,6 +64,4 @@ default DatabaseName databaseName() { default String impersonatedUser() { throw new UnsupportedOperationException(format("%s does not support impersonated user.", getClass())); } - - void flush(); } diff --git a/driver/src/main/java/org/neo4j/driver/internal/util/ErrorUtil.java b/driver/src/main/java/org/neo4j/driver/internal/util/ErrorUtil.java index 4a2f2a18f5..f569ab1c18 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/util/ErrorUtil.java +++ b/driver/src/main/java/org/neo4j/driver/internal/util/ErrorUtil.java @@ -32,6 +32,7 @@ import org.neo4j.driver.exceptions.SecurityException; import org.neo4j.driver.exceptions.ServiceUnavailableException; import org.neo4j.driver.exceptions.TokenExpiredException; +import org.neo4j.driver.exceptions.TransactionTerminatedException; import org.neo4j.driver.exceptions.TransientException; public final class ErrorUtil { @@ -72,6 +73,8 @@ public static Neo4jException newNeo4jError(String code, String message) { } else { if (code.equalsIgnoreCase("Neo.ClientError.Database.DatabaseNotFound")) { return new FatalDiscoveryException(code, message); + } else if (code.equalsIgnoreCase("Neo.ClientError.Transaction.Terminated")) { + return new TransactionTerminatedException(code, message); } else { return new ClientException(code, message); } @@ -80,7 +83,7 @@ public static Neo4jException newNeo4jError(String code, String message) { // Since 5.0 these 2 errors have been moved to ClientError class. // This mapping is required if driver is connection to earlier server versions. if ("Neo.TransientError.Transaction.Terminated".equals(code)) { - return new ClientException("Neo.ClientError.Transaction.Terminated", message); + return new TransactionTerminatedException("Neo.ClientError.Transaction.Terminated", message); } else if ("Neo.TransientError.Transaction.LockClientStopped".equals(code)) { return new ClientException("Neo.ClientError.Transaction.LockClientStopped", message); } else { diff --git a/driver/src/test/java/org/neo4j/driver/integration/SessionResetIT.java b/driver/src/test/java/org/neo4j/driver/integration/SessionResetIT.java index 950f0e1655..5eea5beb12 100644 --- a/driver/src/test/java/org/neo4j/driver/integration/SessionResetIT.java +++ b/driver/src/test/java/org/neo4j/driver/integration/SessionResetIT.java @@ -24,7 +24,6 @@ import static java.util.concurrent.TimeUnit.MINUTES; import static java.util.concurrent.TimeUnit.SECONDS; import static java.util.stream.Collectors.toList; -import static org.hamcrest.CoreMatchers.containsString; import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.CoreMatchers.instanceOf; import static org.hamcrest.CoreMatchers.startsWith; @@ -66,6 +65,7 @@ import org.neo4j.driver.exceptions.ClientException; import org.neo4j.driver.exceptions.Neo4jException; import org.neo4j.driver.exceptions.ServiceUnavailableException; +import org.neo4j.driver.exceptions.TransactionTerminatedException; import org.neo4j.driver.exceptions.TransientException; import org.neo4j.driver.internal.InternalSession; import org.neo4j.driver.testutil.DatabaseExtension; @@ -158,17 +158,16 @@ void shouldAllowMoreTxAfterSessionReset() { @Test void shouldMarkTxAsFailedAndDisallowRunAfterSessionReset() { // Given - try (InternalSession session = (InternalSession) neo4j.driver().session()) { - Transaction tx = session.beginTransaction(); + try (var session = (InternalSession) neo4j.driver().session()) { + var tx = session.beginTransaction(); // When reset the state of this session session.reset(); // Then - Exception e = assertThrows(Exception.class, () -> { + assertThrows(TransactionTerminatedException.class, () -> { tx.run("RETURN 1"); tx.commit(); }); - assertThat(e.getMessage(), startsWith("Cannot run more queries in this transaction")); } } @@ -282,12 +281,11 @@ void shouldBeAbleToRunMoreQueriesAfterResetOnNoErrorState() { @Test void shouldHandleResetBeforeRun() { - try (InternalSession session = (InternalSession) neo4j.driver().session(); - Transaction tx = session.beginTransaction()) { + try (var session = (InternalSession) neo4j.driver().session(); + var tx = session.beginTransaction()) { session.reset(); - ClientException e = assertThrows(ClientException.class, () -> tx.run("CREATE (n:FirstNode)")); - assertThat(e.getMessage(), containsString("Cannot run more queries in this transaction")); + assertThrows(TransactionTerminatedException.class, () -> tx.run("CREATE (n:FirstNode)")); } } diff --git a/driver/src/test/java/org/neo4j/driver/integration/TransactionIT.java b/driver/src/test/java/org/neo4j/driver/integration/TransactionIT.java index 64b9e8b372..e369e0d021 100644 --- a/driver/src/test/java/org/neo4j/driver/integration/TransactionIT.java +++ b/driver/src/test/java/org/neo4j/driver/integration/TransactionIT.java @@ -25,6 +25,7 @@ import static org.hamcrest.junit.MatcherAssert.assertThat; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.neo4j.driver.internal.logging.DevNullLogging.DEV_NULL_LOGGING; @@ -35,8 +36,11 @@ import java.util.List; import java.util.Map; import java.util.function.Consumer; +import java.util.stream.LongStream; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; import org.neo4j.driver.Config; import org.neo4j.driver.Driver; import org.neo4j.driver.Record; @@ -46,6 +50,8 @@ import org.neo4j.driver.Value; import org.neo4j.driver.exceptions.ClientException; import org.neo4j.driver.exceptions.ServiceUnavailableException; +import org.neo4j.driver.exceptions.TransactionTerminatedException; +import org.neo4j.driver.internal.InternalTransaction; import org.neo4j.driver.internal.security.SecurityPlanImpl; import org.neo4j.driver.internal.util.io.ChannelTrackingDriverFactory; import org.neo4j.driver.testutil.ParallelizableIT; @@ -143,7 +149,7 @@ void shouldFailToRollbackAfterTxIsCommitted() { } @Test - void shouldFailToCommitAfterCommit() throws Throwable { + void shouldFailToCommitAfterCommit() { Transaction tx = session.beginTransaction(); tx.run("CREATE (:MyLabel)"); tx.commit(); @@ -153,7 +159,7 @@ void shouldFailToCommitAfterCommit() throws Throwable { } @Test - void shouldFailToRollbackAfterRollback() throws Throwable { + void shouldFailToRollbackAfterRollback() { Transaction tx = session.beginTransaction(); tx.run("CREATE (:MyLabel)"); tx.rollback(); @@ -434,6 +440,146 @@ void shouldRollbackWhenOneOfQueriesFails() { assertEquals(0, countNodesByLabel("Node4")); } + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void shouldPreventPullAfterTransactionTermination(boolean iterate) { + // Given + var tx = session.beginTransaction(); + var streamSize = Config.defaultConfig().fetchSize() + 1; + var result0 = tx.run("UNWIND range(1, $limit) AS x RETURN x", Map.of("limit", streamSize)); + var result1 = tx.run("UNWIND range(1, $limit) AS x RETURN x", Map.of("limit", streamSize)); + + // When + var terminationException = assertThrows(ClientException.class, () -> tx.run("invalid")); + assertEquals(terminationException.code(), "Neo.ClientError.Statement.SyntaxError"); + + // Then + for (var result : List.of(result0, result1)) { + var exception = assertThrows(ClientException.class, () -> { + if (iterate) { + LongStream.range(0, streamSize).forEach(ignored -> result.next()); + } else { + result.list(); + } + }); + assertEquals(terminationException, exception); + } + tx.close(); + } + + @Test + void shouldPreventDiscardAfterTransactionTermination() { + // Given + var tx = session.beginTransaction(); + var streamSize = Config.defaultConfig().fetchSize() + 1; + var result0 = tx.run("UNWIND range(1, $limit) AS x RETURN x", Map.of("limit", streamSize)); + var result1 = tx.run("UNWIND range(1, $limit) AS x RETURN x", Map.of("limit", streamSize)); + + // When + var terminationException = assertThrows(ClientException.class, () -> tx.run("invalid")); + assertEquals(terminationException.code(), "Neo.ClientError.Statement.SyntaxError"); + + // Then + for (var result : List.of(result0, result1)) { + var exception = assertThrows(ClientException.class, result::consume); + assertEquals(terminationException, exception); + } + tx.close(); + } + + @Test + void shouldPreventRunAfterTransactionTermination() { + // Given + var tx = session.beginTransaction(); + var terminationException = assertThrows(ClientException.class, () -> tx.run("invalid")); + assertEquals(terminationException.code(), "Neo.ClientError.Statement.SyntaxError"); + + // When + var exception = assertThrows(TransactionTerminatedException.class, () -> tx.run("RETURN 1")); + + // Then + assertEquals(terminationException, exception.getCause()); + tx.close(); + } + + @ParameterizedTest + @ValueSource(booleans = {true, false}) + void shouldPreventPullAfterDriverTransactionTermination(boolean iterate) { + // Given + var tx = (InternalTransaction) session.beginTransaction(); + var streamSize = Config.defaultConfig().fetchSize() + 1; + var result0 = tx.run("UNWIND range(1, $limit) AS x RETURN x", Map.of("limit", streamSize)); + var result1 = tx.run("UNWIND range(1, $limit) AS x RETURN x", Map.of("limit", streamSize)); + + // When + tx.terminate(); + + // Then + for (var result : List.of(result0, result1)) { + assertThrows(TransactionTerminatedException.class, () -> { + if (iterate) { + LongStream.range(0, streamSize).forEach(ignored -> result.next()); + } else { + result.list(); + } + }); + } + tx.close(); + } + + @Test + void shouldPreventDiscardAfterDriverTransactionTermination() { + // Given + var tx = (InternalTransaction) session.beginTransaction(); + var streamSize = Config.defaultConfig().fetchSize() + 1; + var result0 = tx.run("UNWIND range(1, $limit) AS x RETURN x", Map.of("limit", streamSize)); + var result1 = tx.run("UNWIND range(1, $limit) AS x RETURN x", Map.of("limit", streamSize)); + + // When + tx.terminate(); + + // Then + for (var result : List.of(result0, result1)) { + assertThrows(TransactionTerminatedException.class, result::consume); + } + tx.close(); + } + + @Test + void shouldPreventRunAfterDriverTransactionTermination() { + // Given + var tx = (InternalTransaction) session.beginTransaction(); + var streamSize = Config.defaultConfig().fetchSize() + 1; + var result = tx.run("UNWIND range(1, $limit) AS x RETURN x", Map.of("limit", streamSize)); + result.next(); + + // When + tx.terminate(); + + // Then + assertThrows(TransactionTerminatedException.class, () -> tx.run("UNWIND range(0, 5) AS x RETURN x")); + // the result handle has the pending error + assertThrows(TransactionTerminatedException.class, tx::close); + // all errors have been surfaced + tx.close(); + } + + @Test + void shouldTerminateTransactionAndHandleFailureResponseOrPreventFurtherPulls() { + // Given + var tx = (InternalTransaction) session.beginTransaction(); + var streamSize = Config.defaultConfig().fetchSize() + 1; + var result = tx.run("UNWIND range(1, $limit) AS x RETURN x", Map.of("limit", streamSize)); + + // When + tx.terminate(); + + // Then + assertThrows(TransactionTerminatedException.class, () -> LongStream.range(0, streamSize) + .forEach(ignored -> assertNotNull(result.next()))); + tx.close(); + } + private void shouldRunAndCloseAfterAction(Consumer txConsumer, boolean isCommit) { // When try (Transaction tx = session.beginTransaction()) { diff --git a/driver/src/test/java/org/neo4j/driver/integration/UnmanagedTransactionIT.java b/driver/src/test/java/org/neo4j/driver/integration/UnmanagedTransactionIT.java index 75d8d735de..50c482e072 100644 --- a/driver/src/test/java/org/neo4j/driver/integration/UnmanagedTransactionIT.java +++ b/driver/src/test/java/org/neo4j/driver/integration/UnmanagedTransactionIT.java @@ -25,6 +25,7 @@ import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.mock; import static org.neo4j.driver.Values.parameters; import static org.neo4j.driver.internal.logging.DevNullLogging.DEV_NULL_LOGGING; import static org.neo4j.driver.testutil.TestUtil.await; @@ -46,7 +47,9 @@ import org.neo4j.driver.TransactionConfig; import org.neo4j.driver.async.ResultCursor; import org.neo4j.driver.exceptions.ClientException; +import org.neo4j.driver.exceptions.Neo4jException; import org.neo4j.driver.exceptions.ServiceUnavailableException; +import org.neo4j.driver.exceptions.TransactionTerminatedException; import org.neo4j.driver.internal.InternalDriver; import org.neo4j.driver.internal.async.NetworkSession; import org.neo4j.driver.internal.async.UnmanagedTransaction; @@ -111,11 +114,11 @@ void shouldFailToCommitAfterRollback() { @Test void shouldFailToCommitAfterTermination() { - UnmanagedTransaction tx = beginTransaction(); + var tx = beginTransaction(); tx.markTerminated(null); - ClientException e = assertThrows(ClientException.class, () -> await(tx.commitAsync())); + var e = assertThrows(TransactionTerminatedException.class, () -> await(tx.commitAsync())); assertThat(e.getMessage(), startsWith("Transaction can't be committed")); } @@ -154,10 +157,12 @@ void shouldRollbackAfterTermination() { void shouldFailToRunQueryWhenTerminated() { UnmanagedTransaction tx = beginTransaction(); txRun(tx, "CREATE (:MyLabel)"); - tx.markTerminated(null); + var terminationException = mock(Neo4jException.class); + tx.markTerminated(terminationException); - ClientException e = assertThrows(ClientException.class, () -> txRun(tx, "CREATE (:MyOtherLabel)")); + var e = assertThrows(TransactionTerminatedException.class, () -> txRun(tx, "CREATE (:MyOtherLabel)")); assertThat(e.getMessage(), startsWith("Cannot run more queries in this transaction")); + assertEquals(e.getCause(), terminationException); } @Test @@ -166,7 +171,7 @@ void shouldBePossibleToRunMoreTransactionsAfterOneIsTerminated() { tx1.markTerminated(null); // commit should fail, make session forget about this transaction and release the connection to the pool - ClientException e = assertThrows(ClientException.class, () -> await(tx1.commitAsync())); + var e = assertThrows(TransactionTerminatedException.class, () -> await(tx1.commitAsync())); assertThat(e.getMessage(), startsWith("Transaction can't be committed")); await(session.beginTransactionAsync(TransactionConfig.empty()) diff --git a/driver/src/test/java/org/neo4j/driver/integration/reactive/ReactiveTransactionIT.java b/driver/src/test/java/org/neo4j/driver/integration/reactive/ReactiveTransactionIT.java new file mode 100644 index 0000000000..dcc27b9731 --- /dev/null +++ b/driver/src/test/java/org/neo4j/driver/integration/reactive/ReactiveTransactionIT.java @@ -0,0 +1,195 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.integration.reactive; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.util.List; +import java.util.Map; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.neo4j.driver.Config; +import org.neo4j.driver.exceptions.ClientException; +import org.neo4j.driver.exceptions.TransactionTerminatedException; +import org.neo4j.driver.internal.reactivestreams.InternalReactiveTransaction; +import org.neo4j.driver.reactivestreams.ReactiveSession; +import org.neo4j.driver.testutil.DatabaseExtension; +import org.neo4j.driver.testutil.ParallelizableIT; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +@ParallelizableIT +class ReactiveTransactionIT { + @RegisterExtension + static final DatabaseExtension neo4j = new DatabaseExtension(); + + @Test + void shouldPreventPullAfterTransactionTermination() { + // Given + var session = neo4j.driver().session(ReactiveSession.class); + var tx = Mono.fromDirect(session.beginTransaction()).block(); + var streamSize = Config.defaultConfig().fetchSize() + 1; + var result0 = Mono.fromDirect(tx.run("UNWIND range(1, $limit) AS x RETURN x", Map.of("limit", streamSize))) + .block(); + var result1 = Mono.fromDirect(tx.run("UNWIND range(1, $limit) AS x RETURN x", Map.of("limit", streamSize))) + .block(); + + // When + var terminationException = assertThrows( + ClientException.class, () -> Mono.fromDirect(tx.run("invalid")).block()); + assertEquals(terminationException.code(), "Neo.ClientError.Statement.SyntaxError"); + + // Then + for (var result : List.of(result0, result1)) { + var exception = assertThrows( + ClientException.class, () -> Flux.from(result.records()).blockFirst()); + assertEquals(terminationException, exception); + } + Mono.fromDirect(tx.close()).block(); + } + + @Test + void shouldPreventDiscardAfterTransactionTermination() { + // Given + var session = neo4j.driver().session(ReactiveSession.class); + var tx = Mono.fromDirect(session.beginTransaction()).block(); + var streamSize = Config.defaultConfig().fetchSize() + 1; + var result0 = Mono.fromDirect(tx.run("UNWIND range(1, $limit) AS x RETURN x", Map.of("limit", streamSize))) + .block(); + var result1 = Mono.fromDirect(tx.run("UNWIND range(1, $limit) AS x RETURN x", Map.of("limit", streamSize))) + .block(); + + // When + var terminationException = assertThrows( + ClientException.class, () -> Mono.fromDirect(tx.run("invalid")).block()); + assertEquals(terminationException.code(), "Neo.ClientError.Statement.SyntaxError"); + + // Then + for (var result : List.of(result0, result1)) { + var exception = assertThrows(ClientException.class, () -> Mono.fromDirect(result.consume()) + .block()); + assertEquals(terminationException, exception); + } + Mono.fromDirect(tx.close()).block(); + } + + @Test + void shouldPreventRunAfterTransactionTermination() { + // Given + var session = neo4j.driver().session(ReactiveSession.class); + var tx = Mono.fromDirect(session.beginTransaction()).block(); + var terminationException = assertThrows( + ClientException.class, () -> Mono.fromDirect(tx.run("invalid")).block()); + assertEquals(terminationException.code(), "Neo.ClientError.Statement.SyntaxError"); + + // When + var exception = assertThrows(TransactionTerminatedException.class, () -> Mono.fromDirect(tx.run("RETURN 1")) + .block()); + + // Then + assertEquals(terminationException, exception.getCause()); + Mono.fromDirect(tx.close()).block(); + } + + @Test + void shouldPreventPullAfterDriverTransactionTermination() { + // Given + var session = neo4j.driver().session(ReactiveSession.class); + var tx = (InternalReactiveTransaction) + Mono.fromDirect(session.beginTransaction()).block(); + var streamSize = Config.defaultConfig().fetchSize() + 1; + var result0 = Mono.fromDirect(tx.run("UNWIND range(1, $limit) AS x RETURN x", Map.of("limit", streamSize))) + .block(); + var result1 = Mono.fromDirect(tx.run("UNWIND range(1, $limit) AS x RETURN x", Map.of("limit", streamSize))) + .block(); + + // When + Mono.fromDirect(tx.terminate()).block(); + + // Then + for (var result : List.of(result0, result1)) { + assertThrows(TransactionTerminatedException.class, () -> Flux.from(result.records()) + .blockFirst()); + } + Mono.fromDirect(tx.close()).block(); + } + + @Test + void shouldPreventDiscardAfterDriverTransactionTermination() { + // Given + var session = neo4j.driver().session(ReactiveSession.class); + var tx = (InternalReactiveTransaction) + Mono.fromDirect(session.beginTransaction()).block(); + var streamSize = Config.defaultConfig().fetchSize() + 1; + var result0 = Mono.fromDirect(tx.run("UNWIND range(1, $limit) AS x RETURN x", Map.of("limit", streamSize))) + .block(); + var result1 = Mono.fromDirect(tx.run("UNWIND range(1, $limit) AS x RETURN x", Map.of("limit", streamSize))) + .block(); + + // When + Mono.fromDirect(tx.terminate()).block(); + + // Then + for (var result : List.of(result0, result1)) { + assertThrows(TransactionTerminatedException.class, () -> Mono.fromDirect(result.consume()) + .block()); + } + Mono.fromDirect(tx.close()).block(); + } + + @Test + void shouldPreventRunAfterDriverTransactionTermination() { + // Given + var session = neo4j.driver().session(ReactiveSession.class); + var tx = (InternalReactiveTransaction) + Mono.fromDirect(session.beginTransaction()).block(); + var streamSize = Config.defaultConfig().fetchSize() + 1; + Mono.fromDirect(tx.run("UNWIND range(1, $limit) AS x RETURN x", Map.of("limit", streamSize))) + .block(); + + // When + Mono.fromDirect(tx.terminate()).block(); + + // Then + assertThrows( + TransactionTerminatedException.class, () -> Mono.fromDirect(tx.run("UNWIND range(0, 5) AS x RETURN x")) + .block()); + Mono.fromDirect(tx.close()).block(); + } + + @Test + void shouldTerminateTransactionAndHandleFailureResponseOrPreventFurtherPulls() { + // Given + var session = neo4j.driver().session(ReactiveSession.class); + var tx = (InternalReactiveTransaction) + Mono.fromDirect(session.beginTransaction()).block(); + var streamSize = Config.defaultConfig().fetchSize() + 1; + var result = Mono.fromDirect(tx.run("UNWIND range(1, $limit) AS x RETURN x", Map.of("limit", streamSize))) + .block(); + + // When + Mono.fromDirect(tx.terminate()).block(); + + // Then + assertThrows(TransactionTerminatedException.class, () -> Flux.from(result.records()) + .blockLast()); + Mono.fromDirect(tx.close()).block(); + } +} diff --git a/driver/src/test/java/org/neo4j/driver/internal/async/NetworkConnectionTest.java b/driver/src/test/java/org/neo4j/driver/internal/async/NetworkConnectionTest.java index 3e0036cb45..107b2348a7 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/async/NetworkConnectionTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/async/NetworkConnectionTest.java @@ -107,14 +107,6 @@ void shouldWriteInEventLoopThread() throws Exception { "WriteSingleMessage", connection -> connection.write( RunWithMetadataMessage.unmanagedTxRunMessage(new Query("RETURN 1")), NO_OP_HANDLER)); - - testWriteInEventLoop( - "WriteMultipleMessages", - connection -> connection.write( - RunWithMetadataMessage.unmanagedTxRunMessage(new Query("RETURN 1")), - NO_OP_HANDLER, - PULL_ALL, - NO_OP_HANDLER)); } @Test @@ -123,14 +115,6 @@ void shouldWriteAndFlushInEventLoopThread() throws Exception { "WriteAndFlushSingleMessage", connection -> connection.writeAndFlush( RunWithMetadataMessage.unmanagedTxRunMessage(new Query("RETURN 1")), NO_OP_HANDLER)); - - testWriteInEventLoop( - "WriteAndFlushMultipleMessages", - connection -> connection.writeAndFlush( - RunWithMetadataMessage.unmanagedTxRunMessage(new Query("RETURN 1")), - NO_OP_HANDLER, - PULL_ALL, - NO_OP_HANDLER)); } @Test @@ -138,19 +122,6 @@ void shouldWriteForceReleaseInEventLoopThread() throws Exception { testWriteInEventLoop("ReleaseTestEventLoop", NetworkConnection::release); } - @Test - void shouldFlushInEventLoopThread() throws Exception { - EmbeddedChannel channel = spy(new EmbeddedChannel()); - initializeEventLoop(channel, "Flush"); - ChannelAttributes.setProtocolVersion(channel, DEFAULT_TEST_PROTOCOL_VERSION); - - NetworkConnection connection = newConnection(channel); - connection.flush(); - - shutdownEventLoop(); - verify(channel).flush(); - } - @Test void shouldEnableAutoReadWhenReleased() { EmbeddedChannel channel = newChannel(); @@ -189,20 +160,6 @@ void shouldWriteSingleMessage() { assertEquals(PULL_ALL, single(channel.outboundMessages())); } - @Test - void shouldWriteMultipleMessage() { - EmbeddedChannel channel = newChannel(); - NetworkConnection connection = newConnection(channel); - - connection.write(PULL_ALL, NO_OP_HANDLER, RESET, NO_OP_HANDLER); - - assertEquals(0, channel.outboundMessages().size()); - channel.flushOutbound(); - assertEquals(2, channel.outboundMessages().size()); - assertEquals(PULL_ALL, channel.outboundMessages().poll()); - assertEquals(RESET, channel.outboundMessages().poll()); - } - @Test void shouldWriteAndFlushSingleMessage() { EmbeddedChannel channel = newChannel(); @@ -216,20 +173,6 @@ void shouldWriteAndFlushSingleMessage() { assertEquals(PULL_ALL, single(channel.outboundMessages())); } - @Test - void shouldWriteAndFlushMultipleMessage() { - EmbeddedChannel channel = newChannel(); - NetworkConnection connection = newConnection(channel); - - connection.writeAndFlush(PULL_ALL, NO_OP_HANDLER, RESET, NO_OP_HANDLER); - channel.runPendingTasks(); // writeAndFlush is scheduled to execute in the event loop thread, trigger its - // execution - - assertEquals(2, channel.outboundMessages().size()); - assertEquals(PULL_ALL, channel.outboundMessages().poll()); - assertEquals(RESET, channel.outboundMessages().poll()); - } - @Test void shouldNotWriteSingleMessageWhenReleased() { ResponseHandler handler = mock(ResponseHandler.class); @@ -243,24 +186,6 @@ void shouldNotWriteSingleMessageWhenReleased() { assertConnectionReleasedError(failureCaptor.getValue()); } - @Test - void shouldNotWriteMultipleMessagesWhenReleased() { - ResponseHandler runHandler = mock(ResponseHandler.class); - ResponseHandler pullAllHandler = mock(ResponseHandler.class); - NetworkConnection connection = newConnection(newChannel()); - - connection.release(); - connection.write( - RunWithMetadataMessage.unmanagedTxRunMessage(new Query("RETURN 1")), - runHandler, - PULL_ALL, - pullAllHandler); - - ArgumentCaptor failureCaptor = ArgumentCaptor.forClass(IllegalStateException.class); - verify(runHandler).onFailure(failureCaptor.capture()); - assertConnectionReleasedError(failureCaptor.getValue()); - } - @Test void shouldNotWriteAndFlushSingleMessageWhenReleased() { ResponseHandler handler = mock(ResponseHandler.class); @@ -274,24 +199,6 @@ void shouldNotWriteAndFlushSingleMessageWhenReleased() { assertConnectionReleasedError(failureCaptor.getValue()); } - @Test - void shouldNotWriteAndFlushMultipleMessagesWhenReleased() { - ResponseHandler runHandler = mock(ResponseHandler.class); - ResponseHandler pullAllHandler = mock(ResponseHandler.class); - NetworkConnection connection = newConnection(newChannel()); - - connection.release(); - connection.writeAndFlush( - RunWithMetadataMessage.unmanagedTxRunMessage(new Query("RETURN 1")), - runHandler, - PULL_ALL, - pullAllHandler); - - ArgumentCaptor failureCaptor = ArgumentCaptor.forClass(IllegalStateException.class); - verify(runHandler).onFailure(failureCaptor.capture()); - assertConnectionReleasedError(failureCaptor.getValue()); - } - @Test void shouldNotWriteSingleMessageWhenTerminated() { ResponseHandler handler = mock(ResponseHandler.class); @@ -305,24 +212,6 @@ void shouldNotWriteSingleMessageWhenTerminated() { assertConnectionTerminatedError(failureCaptor.getValue()); } - @Test - void shouldNotWriteMultipleMessagesWhenTerminated() { - ResponseHandler runHandler = mock(ResponseHandler.class); - ResponseHandler pullAllHandler = mock(ResponseHandler.class); - NetworkConnection connection = newConnection(newChannel()); - - connection.terminateAndRelease("42"); - connection.write( - RunWithMetadataMessage.unmanagedTxRunMessage(new Query("RETURN 1")), - runHandler, - PULL_ALL, - pullAllHandler); - - ArgumentCaptor failureCaptor = ArgumentCaptor.forClass(IllegalStateException.class); - verify(runHandler).onFailure(failureCaptor.capture()); - assertConnectionTerminatedError(failureCaptor.getValue()); - } - @Test void shouldNotWriteAndFlushSingleMessageWhenTerminated() { ResponseHandler handler = mock(ResponseHandler.class); @@ -336,24 +225,6 @@ void shouldNotWriteAndFlushSingleMessageWhenTerminated() { assertConnectionTerminatedError(failureCaptor.getValue()); } - @Test - void shouldNotWriteAndFlushMultipleMessagesWhenTerminated() { - ResponseHandler runHandler = mock(ResponseHandler.class); - ResponseHandler pullAllHandler = mock(ResponseHandler.class); - NetworkConnection connection = newConnection(newChannel()); - - connection.terminateAndRelease("42"); - connection.writeAndFlush( - RunWithMetadataMessage.unmanagedTxRunMessage(new Query("RETURN 1")), - runHandler, - PULL_ALL, - pullAllHandler); - - ArgumentCaptor failureCaptor = ArgumentCaptor.forClass(IllegalStateException.class); - verify(runHandler).onFailure(failureCaptor.capture()); - assertConnectionTerminatedError(failureCaptor.getValue()); - } - @Test void shouldReturnServerAgentWhenCreated() { EmbeddedChannel channel = newChannel(); @@ -503,7 +374,7 @@ void shouldSendResetMessageWhenReset() { EmbeddedChannel channel = newChannel(); NetworkConnection connection = newConnection(channel); - connection.reset(); + connection.reset(null); channel.runPendingTasks(); assertEquals(1, channel.outboundMessages().size()); @@ -515,7 +386,7 @@ void shouldCompleteResetFutureWhenSuccessResponseArrives() { EmbeddedChannel channel = newChannel(); NetworkConnection connection = newConnection(channel); - CompletableFuture resetFuture = connection.reset().toCompletableFuture(); + CompletableFuture resetFuture = connection.reset(null).toCompletableFuture(); channel.runPendingTasks(); assertFalse(resetFuture.isDone()); @@ -529,7 +400,7 @@ void shouldCompleteResetFutureWhenFailureResponseArrives() { EmbeddedChannel channel = newChannel(); NetworkConnection connection = newConnection(channel); - CompletableFuture resetFuture = connection.reset().toCompletableFuture(); + CompletableFuture resetFuture = connection.reset(null).toCompletableFuture(); channel.runPendingTasks(); assertFalse(resetFuture.isDone()); @@ -546,7 +417,7 @@ void shouldDoNothingInResetWhenClosed() { connection.release(); channel.runPendingTasks(); - CompletableFuture resetFuture = connection.reset().toCompletableFuture(); + CompletableFuture resetFuture = connection.reset(null).toCompletableFuture(); channel.runPendingTasks(); assertEquals(1, channel.outboundMessages().size()); @@ -561,7 +432,7 @@ void shouldEnableAutoReadWhenDoingReset() { channel.config().setAutoRead(false); NetworkConnection connection = newConnection(channel); - connection.reset(); + connection.reset(null); channel.runPendingTasks(); assertTrue(channel.config().isAutoRead()); diff --git a/driver/src/test/java/org/neo4j/driver/internal/async/NetworkSessionTest.java b/driver/src/test/java/org/neo4j/driver/internal/async/NetworkSessionTest.java index 7a3715bb87..2c069bbbe1 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/async/NetworkSessionTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/async/NetworkSessionTest.java @@ -328,11 +328,11 @@ void connectionShouldBeResetAfterSessionReset() { run(session, query); InOrder connectionInOrder = inOrder(connection); - connectionInOrder.verify(connection, never()).reset(); + connectionInOrder.verify(connection, never()).reset(null); connectionInOrder.verify(connection).release(); await(session.resetAsync()); - connectionInOrder.verify(connection).reset(); + connectionInOrder.verify(connection).reset(null); connectionInOrder.verify(connection, never()).release(); } @@ -467,11 +467,11 @@ void shouldMarkTransactionAsTerminatedAndThenResetConnectionOnReset() { UnmanagedTransaction tx = beginTransaction(session); assertTrue(tx.isOpen()); - verify(connection, never()).reset(); + verify(connection, never()).reset(null); await(session.resetAsync()); - verify(connection).reset(); + verify(connection).reset(any()); } private static ResultCursor run(NetworkSession session, String query) { diff --git a/driver/src/test/java/org/neo4j/driver/internal/async/UnmanagedTransactionTest.java b/driver/src/test/java/org/neo4j/driver/internal/async/UnmanagedTransactionTest.java index 658755e1d3..54d3c62ea6 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/async/UnmanagedTransactionTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/async/UnmanagedTransactionTest.java @@ -71,6 +71,7 @@ import org.neo4j.driver.exceptions.ClientException; import org.neo4j.driver.exceptions.ConnectionReadTimeoutException; import org.neo4j.driver.exceptions.Neo4jException; +import org.neo4j.driver.exceptions.TransactionTerminatedException; import org.neo4j.driver.internal.FailableCursor; import org.neo4j.driver.internal.InternalBookmark; import org.neo4j.driver.internal.messaging.BoltProtocol; @@ -130,7 +131,6 @@ void shouldOnlyQueueMessagesWhenNoBookmarkGiven() { beginTx(connection, Collections.emptySet()); verifyBeginTx(connection); - verify(connection, never()).writeAndFlush(any(), any(), any(), any()); } @Test @@ -141,7 +141,6 @@ void shouldFlushWhenBookmarkGiven() { beginTx(connection, bookmarks); verifyBeginTx(connection); - verify(connection, never()).write(any(), any(), any(), any()); } @Test @@ -153,7 +152,7 @@ void shouldBeOpenAfterConstruction() { @Test void shouldBeClosedWhenMarkedAsTerminated() { - UnmanagedTransaction tx = beginTx(connectionMock()); + var tx = beginTx(connectionMock()); tx.markTerminated(null); @@ -162,7 +161,7 @@ void shouldBeClosedWhenMarkedAsTerminated() { @Test void shouldBeClosedWhenMarkedTerminatedAndClosed() { - UnmanagedTransaction tx = beginTx(connectionMock()); + var tx = beginTx(connectionMock()); tx.markTerminated(null); await(tx.closeAsync()); @@ -201,12 +200,12 @@ void shouldNotReleaseConnectionWhenBeginSucceeds() { @Test void shouldReleaseConnectionWhenTerminatedAndCommitted() { - Connection connection = connectionMock(); - UnmanagedTransaction tx = new UnmanagedTransaction(connection, (ignored) -> {}, UNLIMITED_FETCH_SIZE, null); + var connection = connectionMock(); + var tx = new UnmanagedTransaction(connection, (ignored) -> {}, UNLIMITED_FETCH_SIZE, null); tx.markTerminated(null); - assertThrows(ClientException.class, () -> await(tx.commitAsync())); + assertThrows(TransactionTerminatedException.class, () -> await(tx.commitAsync())); assertFalse(tx.isOpen()); verify(connection).release(); @@ -214,30 +213,28 @@ void shouldReleaseConnectionWhenTerminatedAndCommitted() { @Test void shouldNotCreateCircularExceptionWhenTerminationCauseEqualsToCursorFailure() { - Connection connection = connectionMock(); - ClientException terminationCause = new ClientException("Custom exception"); - ResultCursorsHolder resultCursorsHolder = mockResultCursorWith(terminationCause); - UnmanagedTransaction tx = - new UnmanagedTransaction(connection, (ignored) -> {}, UNLIMITED_FETCH_SIZE, resultCursorsHolder, null); + var connection = connectionMock(); + var terminationCause = new ClientException("Custom exception"); + var resultCursorsHolder = mockResultCursorWith(terminationCause); + var tx = new UnmanagedTransaction(connection, (ignored) -> {}, UNLIMITED_FETCH_SIZE, resultCursorsHolder, null); tx.markTerminated(terminationCause); - ClientException e = assertThrows(ClientException.class, () -> await(tx.commitAsync())); + var e = assertThrows(ClientException.class, () -> await(tx.commitAsync())); assertNoCircularReferences(e); assertEquals(terminationCause, e); } @Test void shouldNotCreateCircularExceptionWhenTerminationCauseDifferentFromCursorFailure() { - Connection connection = connectionMock(); - ClientException terminationCause = new ClientException("Custom exception"); - ResultCursorsHolder resultCursorsHolder = mockResultCursorWith(new ClientException("Cursor error")); - UnmanagedTransaction tx = - new UnmanagedTransaction(connection, (ignored) -> {}, UNLIMITED_FETCH_SIZE, resultCursorsHolder, null); + var connection = connectionMock(); + var terminationCause = new ClientException("Custom exception"); + var resultCursorsHolder = mockResultCursorWith(new ClientException("Cursor error")); + var tx = new UnmanagedTransaction(connection, (ignored) -> {}, UNLIMITED_FETCH_SIZE, resultCursorsHolder, null); tx.markTerminated(terminationCause); - ClientException e = assertThrows(ClientException.class, () -> await(tx.commitAsync())); + var e = assertThrows(ClientException.class, () -> await(tx.commitAsync())); assertNoCircularReferences(e); assertEquals(1, e.getSuppressed().length); @@ -247,13 +244,13 @@ void shouldNotCreateCircularExceptionWhenTerminationCauseDifferentFromCursorFail @Test void shouldNotCreateCircularExceptionWhenTerminatedWithoutFailure() { - Connection connection = connectionMock(); - ClientException terminationCause = new ClientException("Custom exception"); - UnmanagedTransaction tx = new UnmanagedTransaction(connection, (ignored) -> {}, UNLIMITED_FETCH_SIZE, null); + var connection = connectionMock(); + var terminationCause = new ClientException("Custom exception"); + var tx = new UnmanagedTransaction(connection, (ignored) -> {}, UNLIMITED_FETCH_SIZE, null); tx.markTerminated(terminationCause); - ClientException e = assertThrows(ClientException.class, () -> await(tx.commitAsync())); + var e = assertThrows(TransactionTerminatedException.class, () -> await(tx.commitAsync())); assertNoCircularReferences(e); assertEquals(terminationCause, e.getCause()); @@ -261,8 +258,8 @@ void shouldNotCreateCircularExceptionWhenTerminatedWithoutFailure() { @Test void shouldReleaseConnectionWhenTerminatedAndRolledBack() { - Connection connection = connectionMock(); - UnmanagedTransaction tx = new UnmanagedTransaction(connection, (ignored) -> {}, UNLIMITED_FETCH_SIZE, null); + var connection = connectionMock(); + var tx = new UnmanagedTransaction(connection, (ignored) -> {}, UNLIMITED_FETCH_SIZE, null); tx.markTerminated(null); await(tx.rollbackAsync()); @@ -271,7 +268,7 @@ void shouldReleaseConnectionWhenTerminatedAndRolledBack() { } @Test - void shouldReleaseConnectionWhenClose() throws Throwable { + void shouldReleaseConnectionWhenClose() { Connection connection = connectionMock(); UnmanagedTransaction tx = new UnmanagedTransaction(connection, (ignored) -> {}, UNLIMITED_FETCH_SIZE, null); @@ -432,34 +429,34 @@ void shouldReturnCompletedWithNullStageOnClosingInactiveTransactionExceptCommitt } @Test - void shouldInterruptOnInterruptAsync() { + void shouldTerminateOnTerminateAsync() { // Given - Connection connection = connectionMock(BoltProtocolV4.INSTANCE); - UnmanagedTransaction tx = beginTx(connection); + var connection = connectionMock(BoltProtocolV4.INSTANCE); + var tx = beginTx(connection); // When - await(tx.interruptAsync()); + await(tx.terminateAsync()); // Then - then(connection).should().reset(); + then(connection).should().reset(any()); } @Test - void shouldServeTheSameStageOnInterruptAsync() { + void shouldServeTheSameStageOnTerminateAsync() { // Given Connection connection = connectionMock(BoltProtocolV4.INSTANCE); UnmanagedTransaction tx = beginTx(connection); // When - CompletionStage stage0 = tx.interruptAsync(); - CompletionStage stage1 = tx.interruptAsync(); + CompletionStage stage0 = tx.terminateAsync(); + CompletionStage stage1 = tx.terminateAsync(); // Then assertEquals(stage0, stage1); } @Test - void shouldHandleInterruptionWhenAlreadyInterrupted() throws ExecutionException, InterruptedException { + void shouldHandleTerminationWhenAlreadyTerminated() throws ExecutionException, InterruptedException { // Given var connection = connectionMock(BoltProtocolV4.INSTANCE); var exception = new Neo4jException("message"); @@ -473,7 +470,7 @@ void shouldHandleInterruptionWhenAlreadyInterrupted() throws ExecutionException, } catch (ExecutionException e) { actualException = e.getCause(); } - tx.interruptAsync().toCompletableFuture().get(); + tx.terminateAsync().toCompletableFuture().get(); // Then assertEquals(exception, actualException); diff --git a/driver/src/test/java/org/neo4j/driver/internal/async/connection/DecoratedConnectionTest.java b/driver/src/test/java/org/neo4j/driver/internal/async/connection/DecoratedConnectionTest.java index 227cd66e23..313e7ad465 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/async/connection/DecoratedConnectionTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/async/connection/DecoratedConnectionTest.java @@ -84,21 +84,6 @@ void shouldDelegateWrite() { verify(mockConnection).write(message, handler); } - @Test - void shouldDelegateWriteTwoMessages() { - Connection mockConnection = mock(Connection.class); - DirectConnection connection = newConnection(mockConnection); - - Message message1 = mock(Message.class); - ResponseHandler handler1 = mock(ResponseHandler.class); - Message message2 = mock(Message.class); - ResponseHandler handler2 = mock(ResponseHandler.class); - - connection.write(message1, handler1, message2, handler2); - - verify(mockConnection).write(message1, handler1, message2, handler2); - } - @Test void shouldDelegateWriteAndFlush() { Connection mockConnection = mock(Connection.class); @@ -112,29 +97,14 @@ void shouldDelegateWriteAndFlush() { verify(mockConnection).writeAndFlush(message, handler); } - @Test - void shouldDelegateWriteAndFlush1() { - Connection mockConnection = mock(Connection.class); - DirectConnection connection = newConnection(mockConnection); - - Message message1 = mock(Message.class); - ResponseHandler handler1 = mock(ResponseHandler.class); - Message message2 = mock(Message.class); - ResponseHandler handler2 = mock(ResponseHandler.class); - - connection.writeAndFlush(message1, handler1, message2, handler2); - - verify(mockConnection).writeAndFlush(message1, handler1, message2, handler2); - } - @Test void shouldDelegateReset() { Connection mockConnection = mock(Connection.class); DirectConnection connection = newConnection(mockConnection); - connection.reset(); + connection.reset(null); - verify(mockConnection).reset(); + verify(mockConnection).reset(null); } @Test diff --git a/driver/src/test/java/org/neo4j/driver/internal/async/connection/RoutingConnectionTest.java b/driver/src/test/java/org/neo4j/driver/internal/async/connection/RoutingConnectionTest.java index f8a3bd4341..0efef85ce2 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/async/connection/RoutingConnectionTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/async/connection/RoutingConnectionTest.java @@ -28,7 +28,6 @@ import static org.mockito.Mockito.verify; import static org.neo4j.driver.AccessMode.READ; import static org.neo4j.driver.internal.DatabaseNameUtil.defaultDatabase; -import static org.neo4j.driver.internal.messaging.request.DiscardAllMessage.DISCARD_ALL; import static org.neo4j.driver.internal.messaging.request.PullAllMessage.PULL_ALL; import org.junit.jupiter.api.Test; @@ -49,16 +48,6 @@ void shouldWrapHandlersWhenWritingAndFlushingSingleMessage() { testHandlersWrappingWithSingleMessage(true); } - @Test - void shouldWrapHandlersWhenWritingMultipleMessages() { - testHandlersWrappingWithMultipleMessages(false); - } - - @Test - void shouldWrapHandlersWhenWritingAndFlushingMultipleMessages() { - testHandlersWrappingWithMultipleMessages(true); - } - @Test void shouldReturnServerAgent() { // given @@ -99,31 +88,4 @@ private static void testHandlersWrappingWithSingleMessage(boolean flush) { assertThat(handlerCaptor.getValue(), instanceOf(RoutingResponseHandler.class)); } - - private static void testHandlersWrappingWithMultipleMessages(boolean flush) { - Connection connection = mock(Connection.class); - RoutingErrorHandler errorHandler = mock(RoutingErrorHandler.class); - RoutingConnection routingConnection = - new RoutingConnection(connection, defaultDatabase(), READ, null, errorHandler); - - if (flush) { - routingConnection.writeAndFlush( - PULL_ALL, mock(ResponseHandler.class), DISCARD_ALL, mock(ResponseHandler.class)); - } else { - routingConnection.write(PULL_ALL, mock(ResponseHandler.class), DISCARD_ALL, mock(ResponseHandler.class)); - } - - ArgumentCaptor handlerCaptor1 = ArgumentCaptor.forClass(ResponseHandler.class); - ArgumentCaptor handlerCaptor2 = ArgumentCaptor.forClass(ResponseHandler.class); - - if (flush) { - verify(connection) - .writeAndFlush(eq(PULL_ALL), handlerCaptor1.capture(), eq(DISCARD_ALL), handlerCaptor2.capture()); - } else { - verify(connection).write(eq(PULL_ALL), handlerCaptor1.capture(), eq(DISCARD_ALL), handlerCaptor2.capture()); - } - - assertThat(handlerCaptor1.getValue(), instanceOf(RoutingResponseHandler.class)); - assertThat(handlerCaptor2.getValue(), instanceOf(RoutingResponseHandler.class)); - } } diff --git a/driver/src/test/java/org/neo4j/driver/internal/reactive/InternalReactiveTransactionTest.java b/driver/src/test/java/org/neo4j/driver/internal/reactive/InternalReactiveTransactionTest.java deleted file mode 100644 index a4ab8117d0..0000000000 --- a/driver/src/test/java/org/neo4j/driver/internal/reactive/InternalReactiveTransactionTest.java +++ /dev/null @@ -1,67 +0,0 @@ -/* - * Copyright (c) "Neo4j" - * Neo4j Sweden AB [http://neo4j.com] - * - * This file is part of Neo4j. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.neo4j.driver.internal.reactive; - -import static java.util.concurrent.CompletableFuture.completedFuture; -import static org.mockito.BDDMockito.given; -import static org.mockito.BDDMockito.then; -import static org.mockito.Mockito.mock; -import static org.neo4j.driver.internal.util.Futures.failedFuture; -import static reactor.adapter.JdkFlowAdapter.flowPublisherToFlux; - -import org.junit.jupiter.api.Test; -import org.neo4j.driver.internal.async.UnmanagedTransaction; -import reactor.test.StepVerifier; - -public class InternalReactiveTransactionTest { - private InternalReactiveTransaction tx; - - @Test - void shouldDelegateInterrupt() { - // Given - UnmanagedTransaction utx = mock(UnmanagedTransaction.class); - given(utx.interruptAsync()).willReturn(completedFuture(null)); - tx = new InternalReactiveTransaction(utx); - - // When - StepVerifier.create(flowPublisherToFlux(tx.interrupt())) - .expectComplete() - .verify(); - - // Then - then(utx).should().interruptAsync(); - } - - @Test - void shouldDelegateInterruptAndReportError() { - // Given - UnmanagedTransaction utx = mock(UnmanagedTransaction.class); - RuntimeException e = mock(RuntimeException.class); - given(utx.interruptAsync()).willReturn(failedFuture(e)); - tx = new InternalReactiveTransaction(utx); - - // When - StepVerifier.create(flowPublisherToFlux(tx.interrupt())) - .expectErrorMatches(ar -> ar == e) - .verify(); - - // Then - then(utx).should().interruptAsync(); - } -} diff --git a/driver/src/test/java/org/neo4j/driver/internal/util/FailingConnectionDriverFactory.java b/driver/src/test/java/org/neo4j/driver/internal/util/FailingConnectionDriverFactory.java index f3ed13395a..ae13f7616b 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/util/FailingConnectionDriverFactory.java +++ b/driver/src/test/java/org/neo4j/driver/internal/util/FailingConnectionDriverFactory.java @@ -28,6 +28,7 @@ import org.neo4j.driver.Config; import org.neo4j.driver.internal.BoltServerAddress; import org.neo4j.driver.internal.DriverFactory; +import org.neo4j.driver.internal.async.UnmanagedTransaction; import org.neo4j.driver.internal.cluster.RoutingContext; import org.neo4j.driver.internal.messaging.BoltProtocol; import org.neo4j.driver.internal.messaging.Message; @@ -133,14 +134,6 @@ public void write(Message message, ResponseHandler handler) { delegate.write(message, handler); } - @Override - public void write(Message message1, ResponseHandler handler1, Message message2, ResponseHandler handler2) { - if (tryFail(handler1, handler2)) { - return; - } - delegate.write(message1, handler1, message2, handler2); - } - @Override public void writeAndFlush(Message message, ResponseHandler handler) { if (tryFail(handler, null)) { @@ -150,17 +143,8 @@ public void writeAndFlush(Message message, ResponseHandler handler) { } @Override - public void writeAndFlush( - Message message1, ResponseHandler handler1, Message message2, ResponseHandler handler2) { - if (tryFail(handler1, handler2)) { - return; - } - delegate.writeAndFlush(message1, handler1, message2, handler2); - } - - @Override - public CompletionStage reset() { - return delegate.reset(); + public CompletionStage reset(Throwable throwable) { + return delegate.reset(throwable); } @Override @@ -189,11 +173,8 @@ public BoltProtocol protocol() { } @Override - public void flush() { - if (tryFail(null, null)) { - return; - } - delegate.flush(); + public void bindTransaction(UnmanagedTransaction transaction) { + delegate.bindTransaction(transaction); } private boolean tryFail(ResponseHandler handler1, ResponseHandler handler2) { diff --git a/driver/src/test/java/org/neo4j/driver/testutil/TestUtil.java b/driver/src/test/java/org/neo4j/driver/testutil/TestUtil.java index 516f1c918b..d98a42259d 100644 --- a/driver/src/test/java/org/neo4j/driver/testutil/TestUtil.java +++ b/driver/src/test/java/org/neo4j/driver/testutil/TestUtil.java @@ -466,7 +466,7 @@ public static Connection connectionMock(String databaseName, AccessMode mode, Bo setupSuccessResponse(connection, RollbackMessage.class); setupSuccessResponse(connection, BeginMessage.class); when(connection.release()).thenReturn(completedWithNull()); - when(connection.reset()).thenReturn(completedWithNull()); + when(connection.reset(any())).thenReturn(completedWithNull()); } else { throw new IllegalArgumentException("Unsupported bolt protocol version: " + version); } diff --git a/testkit-backend/src/main/java/neo4j/org/testkit/backend/messages/requests/StartTest.java b/testkit-backend/src/main/java/neo4j/org/testkit/backend/messages/requests/StartTest.java index a43895431c..89717dcf6d 100644 --- a/testkit-backend/src/main/java/neo4j/org/testkit/backend/messages/requests/StartTest.java +++ b/testkit-backend/src/main/java/neo4j/org/testkit/backend/messages/requests/StartTest.java @@ -151,6 +151,14 @@ public class StartTest implements TestkitRequest { REACTIVE_LEGACY_SKIP_PATTERN_TO_REASON.put( "^.*\\.Routing[^.]+\\.test_should_fail_when_writing_on_unexpectedly_interrupting_writer_on_run_using_tx_run$", skipMessage); + REACTIVE_LEGACY_SKIP_PATTERN_TO_REASON.put( + "^.*\\.TestTxRun\\.test_should_prevent_pull_after_tx_termination_on_run$", skipMessage); + REACTIVE_LEGACY_SKIP_PATTERN_TO_REASON.put( + "^.*\\.TestTxRun\\.test_should_prevent_discard_after_tx_termination_on_run$", skipMessage); + REACTIVE_LEGACY_SKIP_PATTERN_TO_REASON.put( + "^.*\\.TestTxRun\\.test_should_prevent_run_after_tx_termination_on_run$", skipMessage); + REACTIVE_LEGACY_SKIP_PATTERN_TO_REASON.put( + "^.*\\.TestTxRun\\.test_should_prevent_run_after_tx_termination_on_pull$", skipMessage); skipMessage = "Does not support multiple concurrent result streams on session level"; REACTIVE_LEGACY_SKIP_PATTERN_TO_REASON.put("^.*\\.TestSessionRun\\.test_iteration_nested$", skipMessage); REACTIVE_LEGACY_SKIP_PATTERN_TO_REASON.put("^.*\\.TestSessionRun\\.test_partial_iteration$", skipMessage); From 964d5b2a12e59cb9e84962af2f2fd97e2fe64896 Mon Sep 17 00:00:00 2001 From: Dmitriy Tverdiakov Date: Wed, 5 Jul 2023 11:34:17 +0100 Subject: [PATCH 2/4] Remove UnmanagedTransaction from NetworkConnection --- .../internal/async/NetworkConnection.java | 23 +++++++------- .../TerminationAwareStateLockingExecutor.java | 31 +++++++++++++++++++ .../internal/async/UnmanagedTransaction.java | 13 +++----- .../async/connection/DirectConnection.java | 6 ++-- .../async/connection/RoutingConnection.java | 6 ++-- .../neo4j/driver/internal/spi/Connection.java | 4 +-- .../util/FailingConnectionDriverFactory.java | 6 ++-- 7 files changed, 57 insertions(+), 32 deletions(-) create mode 100644 driver/src/main/java/org/neo4j/driver/internal/async/TerminationAwareStateLockingExecutor.java diff --git a/driver/src/main/java/org/neo4j/driver/internal/async/NetworkConnection.java b/driver/src/main/java/org/neo4j/driver/internal/async/NetworkConnection.java index eccea79483..6812475ad6 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/async/NetworkConnection.java +++ b/driver/src/main/java/org/neo4j/driver/internal/async/NetworkConnection.java @@ -32,7 +32,6 @@ import java.util.concurrent.TimeUnit; import java.util.concurrent.locks.Lock; import java.util.concurrent.locks.ReentrantLock; -import java.util.function.Consumer; import org.neo4j.driver.Logger; import org.neo4j.driver.Logging; import org.neo4j.driver.internal.BoltServerAddress; @@ -76,7 +75,7 @@ public class NetworkConnection implements Connection { private final Long connectionReadTimeout; private Status status = Status.OPEN; - private UnmanagedTransaction transaction; + private TerminationAwareStateLockingExecutor terminationAwareStateLockingExecutor; private ChannelHandler connectionReadTimeoutHandler; public NetworkConnection( @@ -185,12 +184,12 @@ public BoltProtocol protocol() { } @Override - public void bindTransaction(UnmanagedTransaction transaction) { + public void bindTerminationAwareStateLockingExecutor(TerminationAwareStateLockingExecutor executor) { executeWithLock(lock, () -> { - if (this.transaction != null) { - throw new IllegalStateException("transaction is already set"); + if (this.terminationAwareStateLockingExecutor != null) { + throw new IllegalStateException("terminationAwareStateLockingExecutor is already set"); } - this.transaction = transaction; + this.terminationAwareStateLockingExecutor = terminationAwareStateLockingExecutor; }); } @@ -219,7 +218,7 @@ private void writeResetMessageIfNeeded(ResponseHandler resetHandler, boolean isS private void writeMessageInEventLoop(Message message, ResponseHandler handler, boolean flush) { channel.eventLoop() - .execute(() -> transactionTerminationAwareExecutor(message).accept(causeOfTermination -> { + .execute(() -> transactionTerminationAwareExecutor(message).execute(causeOfTermination -> { if (causeOfTermination == null) { messageDispatcher.enqueue(handler); @@ -280,12 +279,12 @@ private void registerConnectionReadTimeout(Channel channel) { } } - private Consumer> transactionTerminationAwareExecutor(Message message) { - var result = (Consumer>) consumer -> consumer.accept(null); + private TerminationAwareStateLockingExecutor transactionTerminationAwareExecutor(Message message) { + var result = (TerminationAwareStateLockingExecutor) consumer -> consumer.accept(null); if (isQueryMessage(message)) { - var transaction = executeWithLock(lock, () -> this.transaction); - if (transaction != null) { - result = transaction::executeWithLockedState; + var lockingExecutor = executeWithLock(lock, () -> this.terminationAwareStateLockingExecutor); + if (lockingExecutor != null) { + result = lockingExecutor; } } return result; diff --git a/driver/src/main/java/org/neo4j/driver/internal/async/TerminationAwareStateLockingExecutor.java b/driver/src/main/java/org/neo4j/driver/internal/async/TerminationAwareStateLockingExecutor.java new file mode 100644 index 0000000000..5f937d997c --- /dev/null +++ b/driver/src/main/java/org/neo4j/driver/internal/async/TerminationAwareStateLockingExecutor.java @@ -0,0 +1,31 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.internal.async; + +import java.util.function.Consumer; + +@FunctionalInterface +public interface TerminationAwareStateLockingExecutor { + /** + * Locks the state and executes the supplied {@link Consumer} with a cause of termination if the state is terminated. + * + * @param causeOfTerminationConsumer the consumer accepting + */ + void execute(Consumer causeOfTerminationConsumer); +} diff --git a/driver/src/main/java/org/neo4j/driver/internal/async/UnmanagedTransaction.java b/driver/src/main/java/org/neo4j/driver/internal/async/UnmanagedTransaction.java index 7156b0c978..4a8858d971 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/async/UnmanagedTransaction.java +++ b/driver/src/main/java/org/neo4j/driver/internal/async/UnmanagedTransaction.java @@ -52,7 +52,7 @@ import org.neo4j.driver.internal.messaging.BoltProtocol; import org.neo4j.driver.internal.spi.Connection; -public class UnmanagedTransaction { +public class UnmanagedTransaction implements TerminationAwareStateLockingExecutor { private enum State { /** * The transaction is running with no explicit success or failure marked @@ -121,7 +121,7 @@ protected UnmanagedTransaction( this.fetchSize = fetchSize; this.notificationConfig = notificationConfig; - connection.bindTransaction(this); + connection.bindTerminationAwareStateLockingExecutor(this); } public CompletionStage beginAsync( @@ -211,13 +211,8 @@ public Connection connection() { return connection; } - /** - * Locks the transaction state and executes the supplied {@link Consumer} with a cause of termination if the - * transaction is terminated. - * - * @param causeOfTerminationConsumer the consumer accepting - */ - public void executeWithLockedState(Consumer causeOfTerminationConsumer) { + @Override + public void execute(Consumer causeOfTerminationConsumer) { executeWithLock(lock, () -> causeOfTerminationConsumer.accept(causeOfTermination)); } diff --git a/driver/src/main/java/org/neo4j/driver/internal/async/connection/DirectConnection.java b/driver/src/main/java/org/neo4j/driver/internal/async/connection/DirectConnection.java index 1af7b6e407..ff3d01ff44 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/async/connection/DirectConnection.java +++ b/driver/src/main/java/org/neo4j/driver/internal/async/connection/DirectConnection.java @@ -23,7 +23,7 @@ import org.neo4j.driver.internal.BoltServerAddress; import org.neo4j.driver.internal.DatabaseName; import org.neo4j.driver.internal.DirectConnectionProvider; -import org.neo4j.driver.internal.async.UnmanagedTransaction; +import org.neo4j.driver.internal.async.TerminationAwareStateLockingExecutor; import org.neo4j.driver.internal.messaging.BoltProtocol; import org.neo4j.driver.internal.messaging.Message; import org.neo4j.driver.internal.spi.Connection; @@ -105,8 +105,8 @@ public BoltProtocol protocol() { } @Override - public void bindTransaction(UnmanagedTransaction transaction) { - delegate.bindTransaction(transaction); + public void bindTerminationAwareStateLockingExecutor(TerminationAwareStateLockingExecutor executor) { + delegate.bindTerminationAwareStateLockingExecutor(executor); } @Override diff --git a/driver/src/main/java/org/neo4j/driver/internal/async/connection/RoutingConnection.java b/driver/src/main/java/org/neo4j/driver/internal/async/connection/RoutingConnection.java index 0ee584ab35..77ee8d0a16 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/async/connection/RoutingConnection.java +++ b/driver/src/main/java/org/neo4j/driver/internal/async/connection/RoutingConnection.java @@ -23,7 +23,7 @@ import org.neo4j.driver.internal.BoltServerAddress; import org.neo4j.driver.internal.DatabaseName; import org.neo4j.driver.internal.RoutingErrorHandler; -import org.neo4j.driver.internal.async.UnmanagedTransaction; +import org.neo4j.driver.internal.async.TerminationAwareStateLockingExecutor; import org.neo4j.driver.internal.handlers.RoutingResponseHandler; import org.neo4j.driver.internal.messaging.BoltProtocol; import org.neo4j.driver.internal.messaging.Message; @@ -109,8 +109,8 @@ public BoltProtocol protocol() { } @Override - public void bindTransaction(UnmanagedTransaction transaction) { - delegate.bindTransaction(transaction); + public void bindTerminationAwareStateLockingExecutor(TerminationAwareStateLockingExecutor executor) { + delegate.bindTerminationAwareStateLockingExecutor(executor); } @Override diff --git a/driver/src/main/java/org/neo4j/driver/internal/spi/Connection.java b/driver/src/main/java/org/neo4j/driver/internal/spi/Connection.java index 17cdb318f5..1d07d9b8e3 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/spi/Connection.java +++ b/driver/src/main/java/org/neo4j/driver/internal/spi/Connection.java @@ -24,7 +24,7 @@ import org.neo4j.driver.AccessMode; import org.neo4j.driver.internal.BoltServerAddress; import org.neo4j.driver.internal.DatabaseName; -import org.neo4j.driver.internal.async.UnmanagedTransaction; +import org.neo4j.driver.internal.async.TerminationAwareStateLockingExecutor; import org.neo4j.driver.internal.messaging.BoltProtocol; import org.neo4j.driver.internal.messaging.Message; @@ -51,7 +51,7 @@ public interface Connection { BoltProtocol protocol(); - void bindTransaction(UnmanagedTransaction transaction); + void bindTerminationAwareStateLockingExecutor(TerminationAwareStateLockingExecutor executor); default AccessMode mode() { throw new UnsupportedOperationException(format("%s does not support access mode.", getClass())); diff --git a/driver/src/test/java/org/neo4j/driver/internal/util/FailingConnectionDriverFactory.java b/driver/src/test/java/org/neo4j/driver/internal/util/FailingConnectionDriverFactory.java index ae13f7616b..94f8fa7f1d 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/util/FailingConnectionDriverFactory.java +++ b/driver/src/test/java/org/neo4j/driver/internal/util/FailingConnectionDriverFactory.java @@ -28,7 +28,7 @@ import org.neo4j.driver.Config; import org.neo4j.driver.internal.BoltServerAddress; import org.neo4j.driver.internal.DriverFactory; -import org.neo4j.driver.internal.async.UnmanagedTransaction; +import org.neo4j.driver.internal.async.TerminationAwareStateLockingExecutor; import org.neo4j.driver.internal.cluster.RoutingContext; import org.neo4j.driver.internal.messaging.BoltProtocol; import org.neo4j.driver.internal.messaging.Message; @@ -173,8 +173,8 @@ public BoltProtocol protocol() { } @Override - public void bindTransaction(UnmanagedTransaction transaction) { - delegate.bindTransaction(transaction); + public void bindTerminationAwareStateLockingExecutor(TerminationAwareStateLockingExecutor executor) { + delegate.bindTerminationAwareStateLockingExecutor(executor); } private boolean tryFail(ResponseHandler handler1, ResponseHandler handler2) { From 32a7108ba6091d03fd39b74018e4c254b964f88a Mon Sep 17 00:00:00 2001 From: Dmitriy Tverdiakov Date: Wed, 5 Jul 2023 11:36:24 +0100 Subject: [PATCH 3/4] Refactoring --- .../org/neo4j/driver/internal/async/NetworkConnection.java | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/driver/src/main/java/org/neo4j/driver/internal/async/NetworkConnection.java b/driver/src/main/java/org/neo4j/driver/internal/async/NetworkConnection.java index 6812475ad6..1edc0ff835 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/async/NetworkConnection.java +++ b/driver/src/main/java/org/neo4j/driver/internal/async/NetworkConnection.java @@ -218,7 +218,7 @@ private void writeResetMessageIfNeeded(ResponseHandler resetHandler, boolean isS private void writeMessageInEventLoop(Message message, ResponseHandler handler, boolean flush) { channel.eventLoop() - .execute(() -> transactionTerminationAwareExecutor(message).execute(causeOfTermination -> { + .execute(() -> terminationAwareStateLockingExecutor(message).execute(causeOfTermination -> { if (causeOfTermination == null) { messageDispatcher.enqueue(handler); @@ -279,7 +279,7 @@ private void registerConnectionReadTimeout(Channel channel) { } } - private TerminationAwareStateLockingExecutor transactionTerminationAwareExecutor(Message message) { + private TerminationAwareStateLockingExecutor terminationAwareStateLockingExecutor(Message message) { var result = (TerminationAwareStateLockingExecutor) consumer -> consumer.accept(null); if (isQueryMessage(message)) { var lockingExecutor = executeWithLock(lock, () -> this.terminationAwareStateLockingExecutor); From 13b5ad03b8ccf95dc72a9493f71763106a1c69ba Mon Sep 17 00:00:00 2001 From: Dmitriy Tverdiakov Date: Wed, 5 Jul 2023 12:29:39 +0100 Subject: [PATCH 4/4] Add more tests --- .../internal/async/NetworkConnection.java | 2 +- .../internal/async/NetworkConnectionTest.java | 126 ++++++++++++++++++ 2 files changed, 127 insertions(+), 1 deletion(-) diff --git a/driver/src/main/java/org/neo4j/driver/internal/async/NetworkConnection.java b/driver/src/main/java/org/neo4j/driver/internal/async/NetworkConnection.java index 1edc0ff835..342c71a055 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/async/NetworkConnection.java +++ b/driver/src/main/java/org/neo4j/driver/internal/async/NetworkConnection.java @@ -189,7 +189,7 @@ public void bindTerminationAwareStateLockingExecutor(TerminationAwareStateLockin if (this.terminationAwareStateLockingExecutor != null) { throw new IllegalStateException("terminationAwareStateLockingExecutor is already set"); } - this.terminationAwareStateLockingExecutor = terminationAwareStateLockingExecutor; + this.terminationAwareStateLockingExecutor = executor; }); } diff --git a/driver/src/test/java/org/neo4j/driver/internal/async/NetworkConnectionTest.java b/driver/src/test/java/org/neo4j/driver/internal/async/NetworkConnectionTest.java index 107b2348a7..7fabe97b60 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/async/NetworkConnectionTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/async/NetworkConnectionTest.java @@ -23,8 +23,11 @@ import static org.hamcrest.junit.MatcherAssert.assertThat; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; +import static org.mockito.BDDMockito.then; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.spy; @@ -43,6 +46,7 @@ import io.netty.channel.DefaultEventLoop; import io.netty.channel.EventLoop; import io.netty.channel.embedded.EmbeddedChannel; +import java.util.List; import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionStage; @@ -53,13 +57,21 @@ import java.util.function.Consumer; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; import org.mockito.ArgumentCaptor; import org.neo4j.driver.Query; +import org.neo4j.driver.exceptions.Neo4jException; import org.neo4j.driver.internal.BoltServerAddress; import org.neo4j.driver.internal.async.connection.ChannelAttributes; import org.neo4j.driver.internal.async.inbound.InboundMessageDispatcher; import org.neo4j.driver.internal.async.pool.ExtendedChannelPool; import org.neo4j.driver.internal.handlers.NoOpResponseHandler; +import org.neo4j.driver.internal.messaging.Message; +import org.neo4j.driver.internal.messaging.request.DiscardAllMessage; +import org.neo4j.driver.internal.messaging.request.DiscardMessage; +import org.neo4j.driver.internal.messaging.request.PullAllMessage; +import org.neo4j.driver.internal.messaging.request.PullMessage; import org.neo4j.driver.internal.messaging.request.RunWithMetadataMessage; import org.neo4j.driver.internal.metrics.DevNullMetricsListener; import org.neo4j.driver.internal.spi.ResponseHandler; @@ -438,6 +450,120 @@ void shouldEnableAutoReadWhenDoingReset() { assertTrue(channel.config().isAutoRead()); } + @Test + void shouldRejectBindingTerminationAwareStateLockingExecutorTwice() { + var channel = newChannel(); + var connection = newConnection(channel); + var lockingExecutor = mock(TerminationAwareStateLockingExecutor.class); + connection.bindTerminationAwareStateLockingExecutor(lockingExecutor); + + assertThrows( + IllegalStateException.class, + () -> connection.bindTerminationAwareStateLockingExecutor(lockingExecutor)); + } + + @ParameterizedTest + @MethodSource("queryMessages") + void shouldPreventDispatchingQueryMessagesOnTermination(QueryMessage queryMessage) { + // Given + var channel = newChannel(); + var connection = newConnection(channel); + var lockingExecutor = mock(TerminationAwareStateLockingExecutor.class); + var error = mock(Neo4jException.class); + doAnswer(invocationOnMock -> { + @SuppressWarnings("unchecked") + var consumer = (Consumer) invocationOnMock.getArguments()[0]; + consumer.accept(error); + return null; + }) + .when(lockingExecutor) + .execute(any()); + connection.bindTerminationAwareStateLockingExecutor(lockingExecutor); + var handler = mock(ResponseHandler.class); + + // When + if (queryMessage.flush()) { + connection.writeAndFlush(queryMessage.message(), handler); + } else { + connection.write(queryMessage.message(), handler); + } + channel.runPendingTasks(); + + // Then + assertTrue(channel.outboundMessages().isEmpty()); + then(lockingExecutor).should().execute(any()); + then(handler).should().onFailure(error); + } + + @ParameterizedTest + @MethodSource("queryMessages") + void shouldDispatchingQueryMessagesWhenNotTerminated(QueryMessage queryMessage) { + // Given + var channel = newChannel(); + var connection = newConnection(channel); + var lockingExecutor = mock(TerminationAwareStateLockingExecutor.class); + doAnswer(invocationOnMock -> { + @SuppressWarnings("unchecked") + var consumer = (Consumer) invocationOnMock.getArguments()[0]; + consumer.accept(null); + return null; + }) + .when(lockingExecutor) + .execute(any()); + connection.bindTerminationAwareStateLockingExecutor(lockingExecutor); + var handler = mock(ResponseHandler.class); + + // When + if (queryMessage.flush()) { + connection.writeAndFlush(queryMessage.message(), handler); + } else { + connection.write(queryMessage.message(), handler); + channel.flushOutbound(); + } + channel.runPendingTasks(); + + // Then + assertEquals(1, channel.outboundMessages().size()); + then(lockingExecutor).should().execute(any()); + } + + @ParameterizedTest + @MethodSource("queryMessages") + void shouldDispatchingQueryMessagesWhenExecutorAbsent(QueryMessage queryMessage) { + // Given + var channel = newChannel(); + var connection = newConnection(channel); + var handler = mock(ResponseHandler.class); + + // When + if (queryMessage.flush()) { + connection.writeAndFlush(queryMessage.message(), handler); + } else { + connection.write(queryMessage.message(), handler); + channel.flushOutbound(); + } + channel.runPendingTasks(); + + // Then + assertEquals(1, channel.outboundMessages().size()); + } + + static List queryMessages() { + return List.of( + new QueryMessage(false, mock(RunWithMetadataMessage.class)), + new QueryMessage(true, mock(RunWithMetadataMessage.class)), + new QueryMessage(false, mock(PullMessage.class)), + new QueryMessage(true, mock(PullMessage.class)), + new QueryMessage(false, mock(PullAllMessage.class)), + new QueryMessage(true, mock(PullAllMessage.class)), + new QueryMessage(false, mock(DiscardMessage.class)), + new QueryMessage(true, mock(DiscardMessage.class)), + new QueryMessage(false, mock(DiscardAllMessage.class)), + new QueryMessage(true, mock(DiscardAllMessage.class))); + } + + private record QueryMessage(boolean flush, Message message) {} + private void testWriteInEventLoop(String threadName, Consumer action) throws Exception { EmbeddedChannel channel = spy(new EmbeddedChannel()); initializeEventLoop(channel, threadName);