diff --git a/driver/src/main/java/org/neo4j/driver/internal/InternalSession.java b/driver/src/main/java/org/neo4j/driver/internal/InternalSession.java index 4e82b18bcc..ba0c1e158e 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/InternalSession.java +++ b/driver/src/main/java/org/neo4j/driver/internal/InternalSession.java @@ -140,6 +140,13 @@ public Set lastBookmarks() { return session.lastBookmarks(); } + // Private API + public void reset() { + Futures.blockingGet( + session.resetAsync(), + () -> terminateConnectionOnThreadInterrupt("Thread interrupted while resetting the session")); + } + private T transaction( AccessMode mode, @SuppressWarnings("deprecation") TransactionWork work, TransactionConfig config) { // use different code path compared to async so that work is executed in the caller thread diff --git a/driver/src/main/java/org/neo4j/driver/internal/async/NetworkSession.java b/driver/src/main/java/org/neo4j/driver/internal/async/NetworkSession.java index 57b02a713c..be43e241e8 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/async/NetworkSession.java +++ b/driver/src/main/java/org/neo4j/driver/internal/async/NetworkSession.java @@ -164,6 +164,24 @@ public CompletionStage beginTransactionAsync( return newTransactionStage; } + // Private API + public CompletionStage resetAsync() { + return existingTransactionOrNull() + .thenAccept(tx -> { + if (tx != null) { + tx.markTerminated(null); + } + }) + .thenCompose(ignore -> connectionStage) + .thenCompose(connection -> { + if (connection != null) { + // there exists an active connection, send a RESET message over it + return connection.reset(); + } + return completedWithNull(); + }); + } + public RetryLogic retryLogic() { return retryLogic; } diff --git a/driver/src/main/java/org/neo4j/driver/internal/reactive/InternalRxSession.java b/driver/src/main/java/org/neo4j/driver/internal/reactive/InternalRxSession.java index 7dd371c34a..8ac4c6817f 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/reactive/InternalRxSession.java +++ b/driver/src/main/java/org/neo4j/driver/internal/reactive/InternalRxSession.java @@ -18,6 +18,8 @@ */ package org.neo4j.driver.internal.reactive; +import static org.neo4j.driver.internal.reactive.RxUtils.createEmptyPublisher; + import java.util.Map; import java.util.concurrent.CompletableFuture; import org.neo4j.driver.AccessMode; @@ -129,6 +131,10 @@ public Bookmark lastBookmark() { return InternalBookmark.from(session.lastBookmarks()); } + public Publisher reset() { + return createEmptyPublisher(session::resetAsync); + } + @Override public Publisher close() { return doClose(); diff --git a/driver/src/test/java/org/neo4j/driver/integration/SessionResetIT.java b/driver/src/test/java/org/neo4j/driver/integration/SessionResetIT.java new file mode 100644 index 0000000000..bf911d356d --- /dev/null +++ b/driver/src/test/java/org/neo4j/driver/integration/SessionResetIT.java @@ -0,0 +1,754 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.integration; + +import static java.util.Collections.newSetFromMap; +import static java.util.concurrent.CompletableFuture.runAsync; +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static java.util.concurrent.TimeUnit.MINUTES; +import static java.util.concurrent.TimeUnit.SECONDS; +import static java.util.stream.IntStream.range; +import static org.hamcrest.CoreMatchers.containsString; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.startsWith; +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.lessThanOrEqualTo; +import static org.hamcrest.junit.MatcherAssert.assertThat; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.neo4j.driver.Values.parameters; +import static org.neo4j.driver.testutil.DaemonThreadFactory.daemon; +import static org.neo4j.driver.testutil.TestUtil.activeQueryCount; +import static org.neo4j.driver.testutil.TestUtil.activeQueryNames; +import static org.neo4j.driver.testutil.TestUtil.await; +import static org.neo4j.driver.testutil.TestUtil.awaitCondition; + +import java.io.IOException; +import java.io.UncheckedIOException; +import java.net.URI; +import java.nio.channels.ClosedChannelException; +import java.util.ArrayList; +import java.util.List; +import java.util.Random; +import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Collectors; +import org.hamcrest.CoreMatchers; +import org.hamcrest.MatcherAssert; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.neo4j.driver.Driver; +import org.neo4j.driver.Result; +import org.neo4j.driver.Session; +import org.neo4j.driver.SimpleQueryRunner; +import org.neo4j.driver.Transaction; +import org.neo4j.driver.exceptions.ClientException; +import org.neo4j.driver.exceptions.Neo4jException; +import org.neo4j.driver.exceptions.ServiceUnavailableException; +import org.neo4j.driver.exceptions.TransientException; +import org.neo4j.driver.internal.InternalSession; +import org.neo4j.driver.testutil.DatabaseExtension; +import org.neo4j.driver.testutil.ParallelizableIT; +import org.neo4j.driver.testutil.TestUtil; +import org.testcontainers.utility.MountableFile; + +@ParallelizableIT +class SessionResetIT { + private static final int CSV_FILE_SIZE = 10_000; + private static final int LOAD_CSV_BATCH_SIZE = 10; + + private static final String SHORT_QUERY_1 = "CREATE (n:Node {name: 'foo', occupation: 'bar'})"; + private static final String SHORT_QUERY_2 = "MATCH (n:Node {name: 'foo'}) RETURN count(n)"; + private static final String LONG_QUERY = "UNWIND range(0, 10000000) AS i CREATE (n:Node {idx: i}) DELETE n"; + private static final String LONG_PERIODIC_COMMIT_QUERY_TEMPLATE = + """ + USING PERIODIC COMMIT 1 + LOAD CSV FROM '%%s' AS line + UNWIND range(1, %d) AS index + CREATE (n:Node {id: index, name: line[0], occupation: line[1]}) + """ + .formatted(LOAD_CSV_BATCH_SIZE); + private static final String LONG_CALL_IN_TX_QUERY_TEMPLATE = + """ + LOAD CSV FROM '%%s' AS line + CALL { + WITH line + UNWIND range(1, %d) as index + CREATE (n:Node {id: index, name: line[0], occupation: line[1]}) + } IN TRANSACTIONS OF 1 ROW + """ + .formatted(LOAD_CSV_BATCH_SIZE); + + private static final int STRESS_TEST_THREAD_COUNT = Runtime.getRuntime().availableProcessors() * 2; + private static final long STRESS_TEST_DURATION_MS = SECONDS.toMillis(5); + private static final String[] STRESS_TEST_QUERIES = {SHORT_QUERY_1, SHORT_QUERY_2, LONG_QUERY}; + private static final String LONG_RUNNING_PLUGIN_PATH = "/longRunningStatement.jar"; + + @RegisterExtension + static final DatabaseExtension neo4j = + new DatabaseExtension().installPlugin(MountableFile.forClasspathResource(LONG_RUNNING_PLUGIN_PATH)); + + private ExecutorService executor; + + @BeforeEach + void setUp() { + executor = Executors.newCachedThreadPool(daemon(getClass().getSimpleName() + "-thread")); + } + + @AfterEach + void tearDown() { + if (executor != null) { + executor.shutdownNow(); + } + } + + @Test + void shouldTerminateAutoCommitQuery() { + testQueryTermination(LONG_QUERY, true); + } + + @Test + void shouldTerminateQueryInUnmanagedTransaction() { + testQueryTermination(LONG_QUERY, false); + } + + /** + * It is currently unsafe to terminate periodic commit query because it'll then be half-committed. + * So the driver give no guarantee when the periodic commit could be terminated. + * For a user who want to terminate a periodic commit, he or she should use kill query by id. + */ + @Test + void shouldTerminatePeriodicCommitQueryRandomly() { + Future queryResult = runQueryInDifferentThreadAndResetSession(longPeriodicCommitQuery(), true); + + final var e = assertThrows(ExecutionException.class, () -> queryResult.get(1, MINUTES)); + assertThat(e.getMessage(), containsString("The transaction has been terminated")); + assertThat(e.getCause(), instanceOf(Neo4jException.class)); + + awaitNoActiveQueries(); + + assertThat(countNodes(), lessThanOrEqualTo(((long) CSV_FILE_SIZE) * LOAD_CSV_BATCH_SIZE)); + } + + @Test + void shouldTerminateAutoCommitQueriesRandomly() throws Exception { + testRandomQueryTermination(true); + } + + @Test + void shouldTerminateQueriesInUnmanagedTransactionsRandomly() throws Exception { + testRandomQueryTermination(false); + } + + @Test + void shouldRejectNewTransactionWhenOpenTransactionExistsAndShouldFailRunResultOnSessionReset() { + try (Session session = neo4j.driver().session()) { + Transaction tx1 = session.beginTransaction(); + + CompletableFuture txRunFuture = CompletableFuture.runAsync( + () -> tx1.run("CALL test.driver.longRunningStatement($seconds)", parameters("seconds", 10))); + + awaitActiveQueriesToContain("CALL test.driver.longRunningStatement"); + ((InternalSession) session).reset(); + + ClientException e1 = assertThrows(ClientException.class, session::beginTransaction); + assertThat( + e1.getMessage(), + containsString("You cannot begin a transaction on a session with an open transaction")); + + ClientException e2 = assertThrows(ClientException.class, () -> tx1.run("RETURN 1")); + assertThat(e2.getMessage(), containsString("Cannot run more queries in this transaction")); + + // Make sure failure from the terminated long running query is propagated + Neo4jException e3 = assertThrows(Neo4jException.class, () -> await(txRunFuture)); + assertThat(e3.getMessage(), containsString("The transaction has been terminated")); + } + } + + @Test + void shouldSuccessfullyCloseAfterSessionReset() { + try (Session session = neo4j.driver().session()) { + CompletableFuture.runAsync( + () -> session.run("CALL test.driver.longRunningStatement($seconds)", parameters("seconds", 10))); + + awaitActiveQueriesToContain("CALL test.driver.longRunningStatement"); + ((InternalSession) session).reset(); + } + } + + @Test + void shouldBeAbleToBeginNewTransactionAfterFirstTransactionInterruptedBySessionResetIsClosed() { + try (Session session = neo4j.driver().session()) { + Transaction tx1 = session.beginTransaction(); + + CompletableFuture txRunFuture = runAsync( + () -> tx1.run("CALL test.driver.longRunningStatement($seconds)", parameters("seconds", 10))); + + awaitActiveQueriesToContain("CALL test.driver.longRunningStatement"); + ((InternalSession) session).reset(); + + Neo4jException e = assertThrows(Neo4jException.class, () -> await(txRunFuture)); + assertThat(e.getMessage(), containsString("The transaction has been terminated")); + tx1.close(); + + try (Transaction tx2 = session.beginTransaction()) { + tx2.run("CREATE (n:FirstNode)"); + tx2.commit(); + } + + Result result = session.run("MATCH (n) RETURN count(n)"); + long nodes = result.single().get("count(n)").asLong(); + MatcherAssert.assertThat(nodes, equalTo(1L)); + } + } + + @Test + void shouldKillLongRunningQuery() { + final int executionTimeout = 10; // 10s + final int killTimeout = 1; // 1s + final AtomicLong startTime = new AtomicLong(-1); + long endTime; + + try (Session session = neo4j.driver().session()) { + CompletableFuture sessionRunFuture = CompletableFuture.runAsync(() -> { + // When + startTime.set(System.currentTimeMillis()); + session.run("CALL test.driver.longRunningStatement($seconds)", parameters("seconds", executionTimeout)); + }); + + resetSessionAfterTimeout(session, killTimeout); + + assertThrows(Neo4jException.class, () -> await(sessionRunFuture)); + } + + endTime = System.currentTimeMillis(); + assertTrue(startTime.get() > 0); + assertTrue(endTime - startTime.get() > killTimeout * 1000); // get reset by session.reset + assertTrue(endTime - startTime.get() < executionTimeout * 1000 / 2); // finished before execution finished + } + + @Test + void shouldKillLongStreamingResult() { + // Given + final int executionTimeout = 10; // 10s + final int killTimeout = 1; // 1s + final AtomicInteger recordCount = new AtomicInteger(); + final AtomicLong startTime = new AtomicLong(-1); + long endTime; + + Neo4jException e = assertThrows(Neo4jException.class, () -> { + try (Session session = neo4j.driver().session()) { + Result result = session.run( + "CALL test.driver.longStreamingResult($seconds)", parameters("seconds", executionTimeout)); + + resetSessionAfterTimeout(session, killTimeout); + + // When + startTime.set(System.currentTimeMillis()); + while (result.hasNext()) { + result.next(); + recordCount.incrementAndGet(); + } + } + }); + + endTime = System.currentTimeMillis(); + assertThat(e.getMessage(), containsString("The transaction has been terminated")); + assertThat(recordCount.get(), greaterThan(1)); + + assertTrue(startTime.get() > 0); + assertTrue(endTime - startTime.get() > killTimeout * 1000); // get reset by session.reset + assertTrue(endTime - startTime.get() < executionTimeout * 1000 / 2); // finished before execution finished + } + + private void resetSessionAfterTimeout(Session session, int timeout) { + executor.submit(() -> { + try { + Thread.sleep(timeout * 1000); // let the query execute for timeout seconds + } catch (InterruptedException ignore) { + } finally { + ((InternalSession) session).reset(); // reset the session after timeout + } + }); + } + + @Test + void shouldAllowMoreQueriesAfterSessionReset() { + // Given + try (Session session = neo4j.driver().session()) { + + session.run("RETURN 1").consume(); + + // When reset the state of this session + ((InternalSession) session).reset(); + + // Then can run successfully more queries without any error + session.run("RETURN 2").consume(); + } + } + + @Test + void shouldAllowMoreTxAfterSessionReset() { + // Given + try (Session session = neo4j.driver().session()) { + try (Transaction tx = session.beginTransaction()) { + tx.run("RETURN 1"); + tx.commit(); + } + + // When reset the state of this session + ((InternalSession) session).reset(); + + // Then can run more Tx + try (Transaction tx = session.beginTransaction()) { + tx.run("RETURN 2"); + tx.commit(); + } + } + } + + @Test + void shouldMarkTxAsFailedAndDisallowRunAfterSessionReset() { + // Given + try (Session session = neo4j.driver().session()) { + Transaction tx = session.beginTransaction(); + // When reset the state of this session + ((InternalSession) session).reset(); + + // Then + Exception e = assertThrows(Exception.class, () -> { + tx.run("RETURN 1"); + tx.commit(); + }); + assertThat(e.getMessage(), startsWith("Cannot run more queries in this transaction")); + } + } + + @Test + void shouldAllowMoreTxAfterSessionResetInTx() { + // Given + try (Session session = neo4j.driver().session()) { + try (Transaction ignore = session.beginTransaction()) { + // When reset the state of this session + ((InternalSession) session).reset(); + } + + // Then can run more Tx + try (Transaction tx = session.beginTransaction()) { + tx.run("RETURN 2"); + tx.commit(); + } + } + } + + @Test + void resetShouldStopQueryWaitingForALock() throws Exception { + testResetOfQueryWaitingForLock(new NodeIdUpdater() { + @Override + void performUpdate( + Driver driver, + int nodeId, + int newNodeId, + AtomicReference usedSessionRef, + CountDownLatch latchToWait) + throws Exception { + try (Session session = driver.session()) { + usedSessionRef.set(session); + latchToWait.await(); + Result result = updateNodeId(session, nodeId, newNodeId); + result.consume(); + } + } + }); + } + + @Test + void resetShouldStopTransactionWaitingForALock() throws Exception { + testResetOfQueryWaitingForLock(new NodeIdUpdater() { + @Override + public void performUpdate( + Driver driver, + int nodeId, + int newNodeId, + AtomicReference usedSessionRef, + CountDownLatch latchToWait) + throws Exception { + try (Session session = neo4j.driver().session(); + Transaction tx = session.beginTransaction()) { + usedSessionRef.set(session); + latchToWait.await(); + Result result = updateNodeId(tx, nodeId, newNodeId); + result.consume(); + } + } + }); + } + + @Test + void resetShouldStopWriteTransactionWaitingForALock() throws Exception { + AtomicInteger invocationsOfWork = new AtomicInteger(); + + testResetOfQueryWaitingForLock(new NodeIdUpdater() { + @Override + public void performUpdate( + Driver driver, + int nodeId, + int newNodeId, + AtomicReference usedSessionRef, + CountDownLatch latchToWait) + throws Exception { + try (Session session = driver.session()) { + usedSessionRef.set(session); + latchToWait.await(); + + session.executeWrite(tx -> { + invocationsOfWork.incrementAndGet(); + Result result = updateNodeId(tx, nodeId, newNodeId); + result.consume(); + return null; + }); + } + } + }); + + assertEquals(1, invocationsOfWork.get()); + } + + @Test + void shouldBeAbleToRunMoreQueriesAfterResetOnNoErrorState() { + try (Session session = neo4j.driver().session()) { + // Given + ((InternalSession) session).reset(); + + // When + Transaction tx = session.beginTransaction(); + tx.run("CREATE (n:FirstNode)"); + tx.commit(); + + // Then the outcome of both queries should be visible + Result result = session.run("MATCH (n) RETURN count(n)"); + long nodes = result.single().get("count(n)").asLong(); + assertThat(nodes, equalTo(1L)); + } + } + + @Test + void shouldHandleResetBeforeRun() { + try (Session session = neo4j.driver().session(); + Transaction tx = session.beginTransaction()) { + ((InternalSession) session).reset(); + + ClientException e = assertThrows(ClientException.class, () -> tx.run("CREATE (n:FirstNode)")); + assertThat(e.getMessage(), containsString("Cannot run more queries in this transaction")); + } + } + + @Test + void shouldHandleResetFromMultipleThreads() throws Throwable { + Session session = neo4j.driver().session(); + + CountDownLatch beforeCommit = new CountDownLatch(1); + CountDownLatch afterReset = new CountDownLatch(1); + + Future txFuture = executor.submit(() -> { + Transaction tx1 = session.beginTransaction(); + tx1.run("CREATE (n:FirstNode)"); + beforeCommit.countDown(); + afterReset.await(); + + // session has been reset, it should not be possible to commit the transaction + try { + tx1.commit(); + } catch (Neo4jException ignore) { + } + + try (Transaction tx2 = session.beginTransaction()) { + tx2.run("CREATE (n:SecondNode)"); + tx2.commit(); + } + + return null; + }); + + Future resetFuture = executor.submit(() -> { + beforeCommit.await(); + ((InternalSession) session).reset(); + afterReset.countDown(); + return null; + }); + + executor.shutdown(); + executor.awaitTermination(20, SECONDS); + + txFuture.get(20, SECONDS); + resetFuture.get(20, SECONDS); + + assertEquals(0, countNodes("FirstNode")); + assertEquals(1, countNodes("SecondNode")); + } + + private void testResetOfQueryWaitingForLock(NodeIdUpdater nodeIdUpdater) throws Exception { + int nodeId = 42; + int newNodeId1 = 4242; + int newNodeId2 = 424242; + + createNodeWithId(nodeId); + + CountDownLatch nodeLocked = new CountDownLatch(1); + AtomicReference otherSessionRef = new AtomicReference<>(); + + try (Session session = neo4j.driver().session(); + Transaction tx = session.beginTransaction()) { + Future txResult = nodeIdUpdater.update(nodeId, newNodeId1, otherSessionRef, nodeLocked); + + Result result = updateNodeId(tx, nodeId, newNodeId2); + result.consume(); + + nodeLocked.countDown(); + // give separate thread some time to block on a lock + Thread.sleep(2_000); + ((InternalSession) otherSessionRef.get()).reset(); + + assertTransactionTerminated(txResult); + tx.commit(); + } + + try (Session session = neo4j.driver().session()) { + Result result = session.run("MATCH (n) RETURN n.id AS id"); + int value = result.single().get("id").asInt(); + assertEquals(newNodeId2, value); + } + } + + private void createNodeWithId(int id) { + try (Session session = neo4j.driver().session()) { + session.run("CREATE (n {id: $id})", parameters("id", id)); + } + } + + private static Result updateNodeId(SimpleQueryRunner queryRunner, int currentId, int newId) { + return queryRunner.run( + "MATCH (n {id: $currentId}) SET n.id = $newId", parameters("currentId", currentId, "newId", newId)); + } + + private static void assertTransactionTerminated(Future work) { + ExecutionException e = assertThrows(ExecutionException.class, () -> work.get(20, TimeUnit.SECONDS)); + assertThat(e.getCause(), CoreMatchers.instanceOf(ClientException.class)); + assertThat(e.getCause().getMessage(), startsWith("The transaction has been terminated")); + } + + private void testRandomQueryTermination(boolean autoCommit) throws Exception { + Set runningSessions = newSetFromMap(new ConcurrentHashMap<>()); + AtomicBoolean stop = new AtomicBoolean(); + List> futures = new ArrayList<>(); + + for (int i = 0; i < STRESS_TEST_THREAD_COUNT; i++) { + futures.add(executor.submit(() -> { + ThreadLocalRandom random = ThreadLocalRandom.current(); + while (!stop.get()) { + runRandomQuery(autoCommit, random, runningSessions, stop); + } + })); + } + + long deadline = System.currentTimeMillis() + STRESS_TEST_DURATION_MS; + while (!stop.get()) { + if (System.currentTimeMillis() > deadline) { + stop.set(true); + } + + resetAny(runningSessions); + + MILLISECONDS.sleep(30); + } + + futures.forEach(TestUtil::await); + awaitNoActiveQueries(); + } + + private void runRandomQuery(boolean autoCommit, Random random, Set runningSessions, AtomicBoolean stop) { + try { + Session session = neo4j.driver().session(); + runningSessions.add(session); + try { + String query = STRESS_TEST_QUERIES[random.nextInt(STRESS_TEST_QUERIES.length - 1)]; + runQuery(session, query, autoCommit); + } finally { + runningSessions.remove(session); + session.close(); + } + } catch (Throwable error) { + if (!stop.get() && !isAcceptable(error)) { + stop.set(true); + throw error; + } + // else it is fine to receive some errors from the driver because + // sessions are being reset concurrently by the main thread, driver can also be closed concurrently + } + } + + private void testQueryTermination(String query, boolean autoCommit) { + Future queryResult = runQueryInDifferentThreadAndResetSession(query, autoCommit); + ExecutionException e = assertThrows(ExecutionException.class, () -> queryResult.get(10, SECONDS)); + assertThat(e.getCause(), instanceOf(Neo4jException.class)); + awaitNoActiveQueries(); + } + + private Future runQueryInDifferentThreadAndResetSession(String query, boolean autoCommit) { + AtomicReference sessionRef = new AtomicReference<>(); + + Future queryResult = runAsync(() -> { + Session session = neo4j.driver().session(); + sessionRef.set(session); + runQuery(session, query, autoCommit); + }); + + awaitActiveQueriesToContain(query); + + Session session = sessionRef.get(); + assertNotNull(session); + ((InternalSession) session).reset(); + + return queryResult; + } + + private static void runQuery(Session session, String query, boolean autoCommit) { + if (autoCommit) { + session.run(query).consume(); + } else { + try (Transaction tx = session.beginTransaction()) { + tx.run(query); + tx.commit(); + } + } + } + + private void awaitNoActiveQueries() { + awaitCondition(() -> activeQueryCount(neo4j) == 0); + } + + private void awaitActiveQueriesToContain(String value) { + awaitCondition(() -> activeQueryNames(neo4j).stream().anyMatch(query -> query.contains(value))); + } + + private long countNodes() { + return countNodes(null); + } + + private long countNodes(String label) { + try (Session session = neo4j.driver().session()) { + Result result = + session.run("MATCH (n" + (label == null ? "" : ":" + label) + ") RETURN count(n) AS result"); + return result.single().get(0).asLong(); + } + } + + private static void resetAny(Set sessions) { + sessions.stream().findAny().ifPresent(session -> { + if (sessions.remove(session)) { + resetSafely(session); + } + }); + } + + private static void resetSafely(Session session) { + try { + if (session.isOpen()) { + ((InternalSession) session).reset(); + } + } catch (ClientException e) { + if (session.isOpen()) { + throw e; + } + // else this thread lost race with close and it's fine + } + } + + private static boolean isAcceptable(Throwable error) { + // get the root cause + while (error.getCause() != null) { + error = error.getCause(); + } + + return isTransactionTerminatedException(error) + || error instanceof ServiceUnavailableException + || error instanceof ClientException + || error instanceof ClosedChannelException; + } + + private static boolean isTransactionTerminatedException(Throwable error) { + return error instanceof TransientException + && error.getMessage().startsWith("The transaction has been terminated") + || error.getMessage().startsWith("Trying to execute query in a terminated transaction"); + } + + private String longPeriodicCommitQuery() { + URI fileUri = createTmpCsvFile(); + final var query = + neo4j.isNeo4j44OrEarlier() ? LONG_PERIODIC_COMMIT_QUERY_TEMPLATE : LONG_CALL_IN_TX_QUERY_TEMPLATE; + return String.format(query, fileUri); + } + + private static URI createTmpCsvFile() { + try { + final String content = range(0, CSV_FILE_SIZE) + .mapToObj(i -> "Foo-" + i + ", Bar-" + i) + .collect(Collectors.joining("\n")); + final String path = neo4j.addImportFile(SessionResetIT.class.getSimpleName(), ".csv", content); + return URI.create(path); + } catch (IOException e) { + throw new UncheckedIOException(e); + } + } + + private abstract class NodeIdUpdater { + final Future update( + int nodeId, int newNodeId, AtomicReference usedSessionRef, CountDownLatch latchToWait) { + return executor.submit(() -> { + performUpdate(neo4j.driver(), nodeId, newNodeId, usedSessionRef, latchToWait); + return null; + }); + } + + abstract void performUpdate( + Driver driver, + int nodeId, + int newNodeId, + AtomicReference usedSessionRef, + CountDownLatch latchToWait) + throws Exception; + } +} diff --git a/driver/src/test/java/org/neo4j/driver/internal/async/InternalAsyncTransactionTest.java b/driver/src/test/java/org/neo4j/driver/internal/async/InternalAsyncTransactionTest.java index a950ee1e03..191fb448b6 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/async/InternalAsyncTransactionTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/async/InternalAsyncTransactionTest.java @@ -21,6 +21,8 @@ import static java.util.Collections.singletonList; import static java.util.Collections.singletonMap; import static java.util.concurrent.CompletableFuture.completedFuture; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsString; import static org.junit.Assert.assertFalse; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertThrows; @@ -53,11 +55,13 @@ import org.neo4j.driver.Value; import org.neo4j.driver.async.AsyncTransaction; import org.neo4j.driver.async.ResultCursor; +import org.neo4j.driver.exceptions.ClientException; import org.neo4j.driver.internal.DatabaseNameUtil; import org.neo4j.driver.internal.InternalRecord; import org.neo4j.driver.internal.messaging.v4.BoltProtocolV4; import org.neo4j.driver.internal.spi.Connection; import org.neo4j.driver.internal.spi.ConnectionProvider; +import org.neo4j.driver.internal.util.Futures; import org.neo4j.driver.internal.value.IntegerValue; import org.neo4j.driver.summary.ResultSummary; @@ -120,6 +124,18 @@ void shouldRollback() { assertFalse(tx.isOpen()); } + @Test + void shouldRollbackWhenFailedRun() { + Futures.blockingGet(networkSession.resetAsync()); + ClientException clientException = assertThrows(ClientException.class, () -> await(tx.commitAsync())); + + assertThat( + clientException.getMessage(), + containsString("It has been rolled back either because of an error or explicit termination")); + verify(connection).release(); + assertFalse(tx.isOpen()); + } + @Test void shouldReleaseConnectionWhenFailedToCommit() { setupFailingCommit(connection); diff --git a/driver/src/test/java/org/neo4j/driver/internal/async/NetworkSessionTest.java b/driver/src/test/java/org/neo4j/driver/internal/async/NetworkSessionTest.java index 5aba0c63c7..2e8016a9a2 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/async/NetworkSessionTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/async/NetworkSessionTest.java @@ -72,6 +72,7 @@ import org.neo4j.driver.internal.messaging.v4.BoltProtocolV4; import org.neo4j.driver.internal.spi.Connection; import org.neo4j.driver.internal.spi.ConnectionProvider; +import org.neo4j.driver.testutil.TestUtil; class NetworkSessionTest { private static final String DATABASE = "neo4j"; @@ -197,6 +198,13 @@ void releasesOpenConnectionUsedForRunWhenSessionIsClosed() { inOrder.verify(connection, atLeastOnce()).release(); } + @Test + void resetDoesNothingWhenNoTransactionAndNoConnection() { + TestUtil.await(session.resetAsync()); + + verify(connectionProvider, never()).acquireConnection(any(ConnectionContext.class)); + } + @Test void closeWithoutConnection() { NetworkSession session = newSession(connectionProvider); @@ -312,6 +320,22 @@ void testPassingNoBookmarkShouldRetainBookmark() { assertThat(session.lastBookmarks(), equalTo(bookmarks)); } + @Test + void connectionShouldBeResetAfterSessionReset() { + String query = "RETURN 1"; + setupSuccessfulRunAndPull(connection, query); + + run(session, query); + + InOrder connectionInOrder = inOrder(connection); + connectionInOrder.verify(connection, never()).reset(); + connectionInOrder.verify(connection).release(); + + await(session.resetAsync()); + connectionInOrder.verify(connection).reset(); + connectionInOrder.verify(connection, never()).release(); + } + @Test void shouldHaveEmptyLastBookmarksInitially() { assertTrue(session.lastBookmarks().isEmpty()); @@ -438,6 +462,18 @@ void shouldBeginTxAfterRunFailureToAcquireConnection() { verifyBeginTx(connection); } + @Test + void shouldMarkTransactionAsTerminatedAndThenResetConnectionOnReset() { + UnmanagedTransaction tx = beginTransaction(session); + + assertTrue(tx.isOpen()); + verify(connection, never()).reset(); + + TestUtil.await(session.resetAsync()); + + verify(connection).reset(); + } + private static ResultCursor run(NetworkSession session, String query) { return await(session.runAsync(new Query(query), TransactionConfig.empty())); } diff --git a/driver/src/test/java/org/neo4j/driver/internal/reactive/InternalRxSessionTest.java b/driver/src/test/java/org/neo4j/driver/internal/reactive/InternalRxSessionTest.java index e43b91bdf9..0afbbc28d6 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/reactive/InternalRxSessionTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/reactive/InternalRxSessionTest.java @@ -295,6 +295,22 @@ void shouldDelegateBookmarks() { verifyNoMoreInteractions(session); } + @Test + void shouldDelegateReset() throws Throwable { + // Given + NetworkSession session = mock(NetworkSession.class); + when(session.resetAsync()).thenReturn(completedWithNull()); + InternalRxSession rxSession = new InternalRxSession(session); + + // When + Publisher mono = rxSession.reset(); + + // Then + StepVerifier.create(mono).verifyComplete(); + verify(session).resetAsync(); + verifyNoMoreInteractions(session); + } + @Test void shouldDelegateClose() { // Given diff --git a/driver/src/test/java/org/neo4j/driver/testutil/DatabaseExtension.java b/driver/src/test/java/org/neo4j/driver/testutil/DatabaseExtension.java index 8bbba3e005..2ba6c271a4 100644 --- a/driver/src/test/java/org/neo4j/driver/testutil/DatabaseExtension.java +++ b/driver/src/test/java/org/neo4j/driver/testutil/DatabaseExtension.java @@ -182,6 +182,16 @@ public String addImportFile(String prefix, String suffix, String contents) throw return String.format("file:///%s", tmpFile.getName()); } + public DatabaseExtension installPlugin(MountableFile plugin) { + if (driver != null) driver.close(); + if (neo4jContainer != null) neo4jContainer.close(); + neo4jContainer = setupNeo4jContainer(cert, key, defaultConfig).withPlugins(plugin); + neo4jContainer.start(); + driver = GraphDatabase.driver(boltUri, authToken); + waitForBoltAvailability(); + return this; + } + public URI uri() { return boltUri; } diff --git a/driver/src/test/java/org/neo4j/driver/testutil/TestUtil.java b/driver/src/test/java/org/neo4j/driver/testutil/TestUtil.java index a9320edb05..35351e9012 100644 --- a/driver/src/test/java/org/neo4j/driver/testutil/TestUtil.java +++ b/driver/src/test/java/org/neo4j/driver/testutil/TestUtil.java @@ -24,6 +24,7 @@ import static java.util.stream.Collectors.toList; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.fail; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.argThat; import static org.mockito.Mockito.doAnswer; @@ -61,7 +62,9 @@ import java.util.concurrent.ExecutionException; import java.util.concurrent.Future; import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; +import java.util.function.BooleanSupplier; import java.util.function.Predicate; import org.mockito.ArgumentMatcher; import org.mockito.invocation.InvocationOnMock; @@ -485,6 +488,41 @@ public static void interruptWhenInWaitingState(Thread thread) { }); } + public static int activeQueryCount(DatabaseExtension db) { + return activeQueryNames(db).size(); + } + + public static List activeQueryNames(DatabaseExtension db) { + try (Session session = db.driver().session()) { + final var query = db.isNeo4j44OrEarlier() + ? "CALL dbms.listQueries() YIELD query RETURN query" + : "SHOW TRANSACTIONS YIELD currentQuery"; + return session.run(query).stream() + .map(record -> record.get(0).asString()) + .filter(q -> !q.contains(query)) // do not include show transactions query + .collect(toList()); + } + } + + public static void awaitCondition(BooleanSupplier condition) { + awaitCondition(condition, DEFAULT_WAIT_TIME_MS, MILLISECONDS); + } + + public static void awaitCondition(BooleanSupplier condition, long value, TimeUnit unit) { + long deadline = System.currentTimeMillis() + unit.toMillis(value); + while (!condition.getAsBoolean()) { + if (System.currentTimeMillis() > deadline) { + fail("Condition was not met in time"); + } + try { + MILLISECONDS.sleep(100); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + fail("Interrupted while waiting"); + } + } + } + public static String randomString(int size) { StringBuilder sb = new StringBuilder(size); ThreadLocalRandom random = ThreadLocalRandom.current(); diff --git a/driver/src/test/resources/longRunningStatement.jar b/driver/src/test/resources/longRunningStatement.jar index a5416792a7..6bd9612a6a 100644 Binary files a/driver/src/test/resources/longRunningStatement.jar and b/driver/src/test/resources/longRunningStatement.jar differ diff --git a/test-procedures/README.md b/test-procedures/README.md new file mode 100644 index 0000000000..f238bb6915 --- /dev/null +++ b/test-procedures/README.md @@ -0,0 +1,8 @@ +# Test Procedures + +This module is not part of the driver build! + +Run this build manually to generate a jar file with procedures used in testing. + +1. `mvn clean package` +2. `cp test-procedures/target/test-procedures-5.7-SNAPSHOT.jar driver/src/test/resources/longRunningStatement.jar` diff --git a/test-procedures/pom.xml b/test-procedures/pom.xml new file mode 100644 index 0000000000..5408f3fd30 --- /dev/null +++ b/test-procedures/pom.xml @@ -0,0 +1,54 @@ + + 4.0.0 + + org.neo4j.driver + test-procedures + 5.7-SNAPSHOT + + + UTF-8 + UTF-8 + 11 + + 'v'yyyyMMdd-HHmm + true + 4.4.18 + + + + + org.neo4j + neo4j + ${neo4j.version} + provided + + + + + + + maven-compiler-plugin + 3.10.1 + + + com.diffplug.spotless + spotless-maven-plugin + 2.23.0 + + + + check + + + + + + + + + + + + \ No newline at end of file diff --git a/test-procedures/src/main/java/org/neo4j/driver/LongRunningProcedures.java b/test-procedures/src/main/java/org/neo4j/driver/LongRunningProcedures.java new file mode 100644 index 0000000000..e2cef1a25e --- /dev/null +++ b/test-procedures/src/main/java/org/neo4j/driver/LongRunningProcedures.java @@ -0,0 +1,84 @@ +/* + * Copyright (c) "Neo4j" + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver; + +import java.util.stream.LongStream; +import java.util.stream.Stream; +import org.neo4j.graphdb.Transaction; +import org.neo4j.logging.Log; +import org.neo4j.procedure.Context; +import org.neo4j.procedure.Name; +import org.neo4j.procedure.Procedure; + +public class LongRunningProcedures { + @Context + public Log log; + + @Context + public Transaction tx; + + public LongRunningProcedures() {} + + @Procedure("test.driver.longRunningStatement") + public void longRunningStatement(@Name("seconds") long seconds) { + final long start = System.currentTimeMillis(); + + while (System.currentTimeMillis() <= start + seconds * 1000L) { + long count = 0; + try { + Thread.sleep(100L); + count = nodeCount(); // Fails if transaction is terminated + } catch (InterruptedException e) { + this.log.error(e.getMessage() + " (last node count " + count + ")", e); + } + } + } + + @Procedure("test.driver.longStreamingResult") + public Stream longStreamingResult(@Name("seconds") long seconds) { + return LongStream.range(0L, seconds * 100L) + .map((x) -> { + if (x == 0L) { + return x; + } else { + try { + Thread.sleep(10L); + } catch (InterruptedException var4) { + this.log.error(var4.getMessage(), var4); + } + + nodeCount(); // Fails if transaction is terminated + return x; + } + }) + .mapToObj(l -> new Output(l)); + } + + private long nodeCount() { + return tx.getAllNodes().stream().count(); + } + + public static class Output { + public final Long out; + + public Output(long value) { + this.out = value; + } + } +}