diff --git a/driver/src/main/java/org/neo4j/driver/internal/async/NetworkConnection.java b/driver/src/main/java/org/neo4j/driver/internal/async/NetworkConnection.java index 342c71a055..99be9972d5 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/async/NetworkConnection.java +++ b/driver/src/main/java/org/neo4j/driver/internal/async/NetworkConnection.java @@ -43,11 +43,13 @@ import org.neo4j.driver.internal.handlers.ResetResponseHandler; import org.neo4j.driver.internal.messaging.BoltProtocol; import org.neo4j.driver.internal.messaging.Message; +import org.neo4j.driver.internal.messaging.request.CommitMessage; import org.neo4j.driver.internal.messaging.request.DiscardAllMessage; import org.neo4j.driver.internal.messaging.request.DiscardMessage; import org.neo4j.driver.internal.messaging.request.PullAllMessage; import org.neo4j.driver.internal.messaging.request.PullMessage; import org.neo4j.driver.internal.messaging.request.ResetMessage; +import org.neo4j.driver.internal.messaging.request.RollbackMessage; import org.neo4j.driver.internal.messaging.request.RunWithMetadataMessage; import org.neo4j.driver.internal.metrics.ListenerEvent; import org.neo4j.driver.internal.metrics.MetricsListener; @@ -295,7 +297,9 @@ private boolean isQueryMessage(Message message) { || message instanceof PullMessage || message instanceof PullAllMessage || message instanceof DiscardMessage - || message instanceof DiscardAllMessage; + || message instanceof DiscardAllMessage + || message instanceof CommitMessage + || message instanceof RollbackMessage; } private enum Status { diff --git a/driver/src/main/java/org/neo4j/driver/internal/async/UnmanagedTransaction.java b/driver/src/main/java/org/neo4j/driver/internal/async/UnmanagedTransaction.java index 4a8858d971..1523fa4cb1 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/async/UnmanagedTransaction.java +++ b/driver/src/main/java/org/neo4j/driver/internal/async/UnmanagedTransaction.java @@ -247,6 +247,10 @@ private void ensureCanRunQueries() { + "it has either experienced an fatal error or was explicitly terminated", causeOfTermination); } + } else if (commitFuture != null) { + throw new ClientException("Cannot run more queries in this transaction, it is being committed"); + } else if (rollbackFuture != null) { + throw new ClientException("Cannot run more queries in this transaction, it is being rolled back"); } }); } diff --git a/driver/src/test/java/org/neo4j/driver/internal/async/NetworkConnectionTest.java b/driver/src/test/java/org/neo4j/driver/internal/async/NetworkConnectionTest.java index 7fabe97b60..dd5d93affc 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/async/NetworkConnectionTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/async/NetworkConnectionTest.java @@ -68,10 +68,12 @@ import org.neo4j.driver.internal.async.pool.ExtendedChannelPool; import org.neo4j.driver.internal.handlers.NoOpResponseHandler; import org.neo4j.driver.internal.messaging.Message; +import org.neo4j.driver.internal.messaging.request.CommitMessage; import org.neo4j.driver.internal.messaging.request.DiscardAllMessage; import org.neo4j.driver.internal.messaging.request.DiscardMessage; import org.neo4j.driver.internal.messaging.request.PullAllMessage; import org.neo4j.driver.internal.messaging.request.PullMessage; +import org.neo4j.driver.internal.messaging.request.RollbackMessage; import org.neo4j.driver.internal.messaging.request.RunWithMetadataMessage; import org.neo4j.driver.internal.metrics.DevNullMetricsListener; import org.neo4j.driver.internal.spi.ResponseHandler; @@ -559,7 +561,11 @@ static List queryMessages() { new QueryMessage(false, mock(DiscardMessage.class)), new QueryMessage(true, mock(DiscardMessage.class)), new QueryMessage(false, mock(DiscardAllMessage.class)), - new QueryMessage(true, mock(DiscardAllMessage.class))); + new QueryMessage(true, mock(DiscardAllMessage.class)), + new QueryMessage(false, mock(CommitMessage.class)), + new QueryMessage(true, mock(CommitMessage.class)), + new QueryMessage(false, mock(RollbackMessage.class)), + new QueryMessage(true, mock(RollbackMessage.class))); } private record QueryMessage(boolean flush, Message message) {} diff --git a/driver/src/test/java/org/neo4j/driver/internal/async/UnmanagedTransactionTest.java b/driver/src/test/java/org/neo4j/driver/internal/async/UnmanagedTransactionTest.java index 54d3c62ea6..25c1c05c90 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/async/UnmanagedTransactionTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/async/UnmanagedTransactionTest.java @@ -29,6 +29,7 @@ import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.argThat; +import static org.mockito.ArgumentMatchers.eq; import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.then; import static org.mockito.Mockito.doAnswer; @@ -52,13 +53,16 @@ import static org.neo4j.driver.testutil.TestUtil.verifyRunRx; import java.util.Collections; +import java.util.List; import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionStage; import java.util.concurrent.ExecutionException; import java.util.function.Consumer; +import java.util.function.Function; import java.util.function.Supplier; import java.util.stream.Stream; +import org.junit.jupiter.api.Named; import org.junit.jupiter.api.Test; import org.junit.jupiter.params.ParameterizedTest; import org.junit.jupiter.params.provider.Arguments; @@ -72,10 +76,12 @@ import org.neo4j.driver.exceptions.ConnectionReadTimeoutException; import org.neo4j.driver.exceptions.Neo4jException; import org.neo4j.driver.exceptions.TransactionTerminatedException; +import org.neo4j.driver.internal.DatabaseBookmark; import org.neo4j.driver.internal.FailableCursor; import org.neo4j.driver.internal.InternalBookmark; import org.neo4j.driver.internal.messaging.BoltProtocol; import org.neo4j.driver.internal.messaging.v4.BoltProtocolV4; +import org.neo4j.driver.internal.messaging.v53.BoltProtocolV53; import org.neo4j.driver.internal.spi.Connection; import org.neo4j.driver.internal.spi.ResponseHandler; @@ -476,6 +482,76 @@ void shouldHandleTerminationWhenAlreadyTerminated() throws ExecutionException, I assertEquals(exception, actualException); } + @ParameterizedTest + @MethodSource("transactionClosingTestParams") + void shouldThrowOnRunningNewQueriesWhenTransactionIsClosing(TransactionClosingTestParams testParams) { + // Given + var boltProtocol = mock(BoltProtocol.class); + given(boltProtocol.version()).willReturn(BoltProtocolV53.VERSION); + var closureStage = new CompletableFuture(); + var connection = connectionMock(boltProtocol); + given(boltProtocol.beginTransaction(eq(connection), any(), any(), any(), any())) + .willReturn(completedFuture(null)); + given(boltProtocol.commitTransaction(connection)).willReturn(closureStage); + given(boltProtocol.rollbackTransaction(connection)).willReturn(closureStage.thenApply(ignored -> null)); + var tx = beginTx(connection); + + // When + testParams.closeAction().apply(tx); + var exception = assertThrows( + ClientException.class, () -> await(testParams.runAction().apply(tx))); + + // Then + assertEquals(testParams.expectedMessage(), exception.getMessage()); + } + + static List transactionClosingTestParams() { + Function> asyncRun = tx -> tx.runAsync(new Query("query")); + Function> reactiveRun = tx -> tx.runRx(new Query("query")); + return List.of( + Arguments.of(Named.of( + "commit and run async", + new TransactionClosingTestParams( + UnmanagedTransaction::commitAsync, + asyncRun, + "Cannot run more queries in this transaction, it is being committed"))), + Arguments.of(Named.of( + "commit and run reactive", + new TransactionClosingTestParams( + UnmanagedTransaction::commitAsync, + reactiveRun, + "Cannot run more queries in this transaction, it is being committed"))), + Arguments.of(Named.of( + "rollback and run async", + new TransactionClosingTestParams( + UnmanagedTransaction::rollbackAsync, + asyncRun, + "Cannot run more queries in this transaction, it is being rolled back"))), + Arguments.of(Named.of( + "rollback and run reactive", + new TransactionClosingTestParams( + UnmanagedTransaction::rollbackAsync, + reactiveRun, + "Cannot run more queries in this transaction, it is being rolled back"))), + Arguments.of(Named.of( + "close and run async", + new TransactionClosingTestParams( + UnmanagedTransaction::closeAsync, + asyncRun, + "Cannot run more queries in this transaction, it is being rolled back"))), + Arguments.of(Named.of( + "close and run reactive", + new TransactionClosingTestParams( + UnmanagedTransaction::closeAsync, + reactiveRun, + "Cannot run more queries in this transaction, it is being rolled back")))); + } + + private record TransactionClosingTestParams( + Function> closeAction, + Function> runAction, + String expectedMessage) {} + private static UnmanagedTransaction beginTx(Connection connection) { return beginTx(connection, Collections.emptySet()); }