Skip to content

Commit f74fc4f

Browse files
committed
Ensure transaction prevents new query runs when it is closing
In addition, prevent `COMMIT` and `ROLLBACK` Bolt messages from being dispatched if transaction is in terminated state.
1 parent 2220ba5 commit f74fc4f

File tree

4 files changed

+92
-2
lines changed

4 files changed

+92
-2
lines changed

driver/src/main/java/org/neo4j/driver/internal/async/NetworkConnection.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,13 @@
4343
import org.neo4j.driver.internal.handlers.ResetResponseHandler;
4444
import org.neo4j.driver.internal.messaging.BoltProtocol;
4545
import org.neo4j.driver.internal.messaging.Message;
46+
import org.neo4j.driver.internal.messaging.request.CommitMessage;
4647
import org.neo4j.driver.internal.messaging.request.DiscardAllMessage;
4748
import org.neo4j.driver.internal.messaging.request.DiscardMessage;
4849
import org.neo4j.driver.internal.messaging.request.PullAllMessage;
4950
import org.neo4j.driver.internal.messaging.request.PullMessage;
5051
import org.neo4j.driver.internal.messaging.request.ResetMessage;
52+
import org.neo4j.driver.internal.messaging.request.RollbackMessage;
5153
import org.neo4j.driver.internal.messaging.request.RunWithMetadataMessage;
5254
import org.neo4j.driver.internal.metrics.ListenerEvent;
5355
import org.neo4j.driver.internal.metrics.MetricsListener;
@@ -295,7 +297,9 @@ private boolean isQueryMessage(Message message) {
295297
|| message instanceof PullMessage
296298
|| message instanceof PullAllMessage
297299
|| message instanceof DiscardMessage
298-
|| message instanceof DiscardAllMessage;
300+
|| message instanceof DiscardAllMessage
301+
|| message instanceof CommitMessage
302+
|| message instanceof RollbackMessage;
299303
}
300304

301305
private enum Status {

driver/src/main/java/org/neo4j/driver/internal/async/UnmanagedTransaction.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -247,6 +247,10 @@ private void ensureCanRunQueries() {
247247
+ "it has either experienced an fatal error or was explicitly terminated",
248248
causeOfTermination);
249249
}
250+
} else if (commitFuture != null) {
251+
throw new ClientException("Cannot run more queries in this transaction, it is being committed");
252+
} else if (rollbackFuture != null) {
253+
throw new ClientException("Cannot run more queries in this transaction, it is being rolled back");
250254
}
251255
});
252256
}

driver/src/test/java/org/neo4j/driver/internal/async/NetworkConnectionTest.java

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,10 +68,12 @@
6868
import org.neo4j.driver.internal.async.pool.ExtendedChannelPool;
6969
import org.neo4j.driver.internal.handlers.NoOpResponseHandler;
7070
import org.neo4j.driver.internal.messaging.Message;
71+
import org.neo4j.driver.internal.messaging.request.CommitMessage;
7172
import org.neo4j.driver.internal.messaging.request.DiscardAllMessage;
7273
import org.neo4j.driver.internal.messaging.request.DiscardMessage;
7374
import org.neo4j.driver.internal.messaging.request.PullAllMessage;
7475
import org.neo4j.driver.internal.messaging.request.PullMessage;
76+
import org.neo4j.driver.internal.messaging.request.RollbackMessage;
7577
import org.neo4j.driver.internal.messaging.request.RunWithMetadataMessage;
7678
import org.neo4j.driver.internal.metrics.DevNullMetricsListener;
7779
import org.neo4j.driver.internal.spi.ResponseHandler;
@@ -559,7 +561,11 @@ static List<QueryMessage> queryMessages() {
559561
new QueryMessage(false, mock(DiscardMessage.class)),
560562
new QueryMessage(true, mock(DiscardMessage.class)),
561563
new QueryMessage(false, mock(DiscardAllMessage.class)),
562-
new QueryMessage(true, mock(DiscardAllMessage.class)));
564+
new QueryMessage(true, mock(DiscardAllMessage.class)),
565+
new QueryMessage(false, mock(CommitMessage.class)),
566+
new QueryMessage(true, mock(CommitMessage.class)),
567+
new QueryMessage(false, mock(RollbackMessage.class)),
568+
new QueryMessage(true, mock(RollbackMessage.class)));
563569
}
564570

565571
private record QueryMessage(boolean flush, Message message) {}

driver/src/test/java/org/neo4j/driver/internal/async/UnmanagedTransactionTest.java

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929
import static org.junit.jupiter.api.Assertions.assertTrue;
3030
import static org.mockito.ArgumentMatchers.any;
3131
import static org.mockito.ArgumentMatchers.argThat;
32+
import static org.mockito.ArgumentMatchers.eq;
3233
import static org.mockito.BDDMockito.given;
3334
import static org.mockito.BDDMockito.then;
3435
import static org.mockito.Mockito.doAnswer;
@@ -52,13 +53,16 @@
5253
import static org.neo4j.driver.testutil.TestUtil.verifyRunRx;
5354

5455
import java.util.Collections;
56+
import java.util.List;
5557
import java.util.Set;
5658
import java.util.concurrent.CompletableFuture;
5759
import java.util.concurrent.CompletionStage;
5860
import java.util.concurrent.ExecutionException;
5961
import java.util.function.Consumer;
62+
import java.util.function.Function;
6063
import java.util.function.Supplier;
6164
import java.util.stream.Stream;
65+
import org.junit.jupiter.api.Named;
6266
import org.junit.jupiter.api.Test;
6367
import org.junit.jupiter.params.ParameterizedTest;
6468
import org.junit.jupiter.params.provider.Arguments;
@@ -72,10 +76,12 @@
7276
import org.neo4j.driver.exceptions.ConnectionReadTimeoutException;
7377
import org.neo4j.driver.exceptions.Neo4jException;
7478
import org.neo4j.driver.exceptions.TransactionTerminatedException;
79+
import org.neo4j.driver.internal.DatabaseBookmark;
7580
import org.neo4j.driver.internal.FailableCursor;
7681
import org.neo4j.driver.internal.InternalBookmark;
7782
import org.neo4j.driver.internal.messaging.BoltProtocol;
7883
import org.neo4j.driver.internal.messaging.v4.BoltProtocolV4;
84+
import org.neo4j.driver.internal.messaging.v53.BoltProtocolV53;
7985
import org.neo4j.driver.internal.spi.Connection;
8086
import org.neo4j.driver.internal.spi.ResponseHandler;
8187

@@ -476,6 +482,76 @@ void shouldHandleTerminationWhenAlreadyTerminated() throws ExecutionException, I
476482
assertEquals(exception, actualException);
477483
}
478484

485+
@ParameterizedTest
486+
@MethodSource("transactionClosingTestParams")
487+
void shouldThrowOnRunningNewQueriesWhenTransactionIsClosing(TransactionClosingTestParams testParams) {
488+
// Given
489+
var boltProtocol = mock(BoltProtocol.class);
490+
given(boltProtocol.version()).willReturn(BoltProtocolV53.VERSION);
491+
var closureStage = new CompletableFuture<DatabaseBookmark>();
492+
var connection = connectionMock(boltProtocol);
493+
given(boltProtocol.beginTransaction(eq(connection), any(), any(), any(), any()))
494+
.willReturn(completedFuture(null));
495+
given(boltProtocol.commitTransaction(connection)).willReturn(closureStage);
496+
given(boltProtocol.rollbackTransaction(connection)).willReturn(closureStage.thenApply(ignored -> null));
497+
var tx = beginTx(connection);
498+
499+
// When
500+
testParams.closeAction().apply(tx);
501+
var exception = assertThrows(
502+
ClientException.class, () -> await(testParams.runAction().apply(tx)));
503+
504+
// Then
505+
assertEquals(testParams.expectedMessage(), exception.getMessage());
506+
}
507+
508+
static List<Arguments> transactionClosingTestParams() {
509+
Function<UnmanagedTransaction, CompletionStage<?>> asyncRun = tx -> tx.runAsync(new Query("query"));
510+
Function<UnmanagedTransaction, CompletionStage<?>> reactiveRun = tx -> tx.runRx(new Query("query"));
511+
return List.of(
512+
Arguments.of(Named.of(
513+
"commit and run async",
514+
new TransactionClosingTestParams(
515+
UnmanagedTransaction::commitAsync,
516+
asyncRun,
517+
"Cannot run more queries in this transaction, it is being committed"))),
518+
Arguments.of(Named.of(
519+
"commit and run reactive",
520+
new TransactionClosingTestParams(
521+
UnmanagedTransaction::commitAsync,
522+
reactiveRun,
523+
"Cannot run more queries in this transaction, it is being committed"))),
524+
Arguments.of(Named.of(
525+
"rollback and run async",
526+
new TransactionClosingTestParams(
527+
UnmanagedTransaction::rollbackAsync,
528+
asyncRun,
529+
"Cannot run more queries in this transaction, it is being rolled back"))),
530+
Arguments.of(Named.of(
531+
"rollback and run reactive",
532+
new TransactionClosingTestParams(
533+
UnmanagedTransaction::rollbackAsync,
534+
reactiveRun,
535+
"Cannot run more queries in this transaction, it is being rolled back"))),
536+
Arguments.of(Named.of(
537+
"close and run async",
538+
new TransactionClosingTestParams(
539+
UnmanagedTransaction::closeAsync,
540+
asyncRun,
541+
"Cannot run more queries in this transaction, it is being rolled back"))),
542+
Arguments.of(Named.of(
543+
"close and run reactive",
544+
new TransactionClosingTestParams(
545+
UnmanagedTransaction::closeAsync,
546+
reactiveRun,
547+
"Cannot run more queries in this transaction, it is being rolled back"))));
548+
}
549+
550+
private record TransactionClosingTestParams(
551+
Function<UnmanagedTransaction, CompletionStage<?>> closeAction,
552+
Function<UnmanagedTransaction, CompletionStage<?>> runAction,
553+
String expectedMessage) {}
554+
479555
private static UnmanagedTransaction beginTx(Connection connection) {
480556
return beginTx(connection, Collections.emptySet());
481557
}

0 commit comments

Comments
 (0)