diff --git a/driver/src/main/java/org/neo4j/driver/internal/ExplicitTransaction.java b/driver/src/main/java/org/neo4j/driver/internal/ExplicitTransaction.java index 58c589bedb..625f19f3a7 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/ExplicitTransaction.java +++ b/driver/src/main/java/org/neo4j/driver/internal/ExplicitTransaction.java @@ -122,17 +122,30 @@ public void close() { if ( state == State.MARKED_SUCCESS ) { - conn.run( "COMMIT", Collections.emptyMap(), Collector.NO_OP ); - conn.pullAll( new BookmarkCollector( this ) ); - conn.sync(); - state = State.SUCCEEDED; + try + { + conn.run( "COMMIT", Collections.emptyMap(), Collector.NO_OP ); + conn.pullAll( new BookmarkCollector( this ) ); + conn.sync(); + state = State.SUCCEEDED; + } + catch( Throwable e ) + { + // failed to commit + try + { + rollbackTx(); + } + catch( Throwable ignored ) + { + // best effort. + } + throw e; + } } else if ( state == State.MARKED_FAILED || state == State.ACTIVE ) { - conn.run( "ROLLBACK", Collections.emptyMap(), Collector.NO_OP ); - conn.pullAll( new BookmarkCollector( this ) ); - conn.sync(); - state = State.ROLLED_BACK; + rollbackTx(); } } } @@ -142,6 +155,14 @@ else if ( state == State.MARKED_FAILED || state == State.ACTIVE ) } } + private void rollbackTx() + { + conn.run( "ROLLBACK", Collections.emptyMap(), Collector.NO_OP ); + conn.pullAll( new BookmarkCollector( this ) ); + conn.sync(); + state = State.ROLLED_BACK; + } + @Override @SuppressWarnings( "unchecked" ) public StatementResult run( String statementText, Value statementParameters ) diff --git a/driver/src/test/java/org/neo4j/driver/v1/integration/SessionIT.java b/driver/src/test/java/org/neo4j/driver/v1/integration/SessionIT.java index 4b8f945616..d45bf141d8 100644 --- a/driver/src/test/java/org/neo4j/driver/v1/integration/SessionIT.java +++ b/driver/src/test/java/org/neo4j/driver/v1/integration/SessionIT.java @@ -360,6 +360,28 @@ public void shouldMarkTxAsFailedAndDisallowRunAfterSessionReset() } } + @Test + public void shouldAllowMoreTxAfterSessionResetInTx() + { + // Given + try( Driver driver = GraphDatabase.driver( neo4j.uri() ); + Session session = driver.session() ) + { + try( Transaction tx = session.beginTransaction() ) + { + // When reset the state of this session + session.reset(); + } + + // Then can run more Tx + try( Transaction tx = session.beginTransaction() ) + { + tx.run("Return 2"); + tx.success(); + } + } + } + @Test public void shouldCloseSessionWhenDriverIsClosed() throws Throwable { diff --git a/driver/src/test/java/org/neo4j/driver/v1/integration/TransactionIT.java b/driver/src/test/java/org/neo4j/driver/v1/integration/TransactionIT.java index fee670cf86..fcb915339a 100644 --- a/driver/src/test/java/org/neo4j/driver/v1/integration/TransactionIT.java +++ b/driver/src/test/java/org/neo4j/driver/v1/integration/TransactionIT.java @@ -18,6 +18,7 @@ */ package org.neo4j.driver.v1.integration; +import org.junit.Assert; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; @@ -38,6 +39,7 @@ import static org.hamcrest.MatcherAssert.assertThat; import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; public class TransactionIT { @@ -303,4 +305,60 @@ public void run() long nodes = result.single().get( "count(n)" ).asLong(); assertThat( nodes, equalTo( 1L ) ); } + + @Test + public void shouldRollBackTxIfErrorWithoutConsume() throws Throwable + { + // Given + int failingPos = 0; + try ( Transaction tx = session.beginTransaction() ) + { + StatementResult result = tx.run( "invalid" ); // send run, pull_all + tx.success(); + failingPos = 1; // fail to send? + } // When send run_commit, and pull_all + + // Then error and should also send ack_fail, roll_back and pull_all + catch ( ClientException e ) + { + failingPos = 2; // fail in tx.close in sync? + try ( Transaction tx = session.beginTransaction() ) + { + StatementResult cursor = tx.run( "RETURN 1" ); + int val = cursor.single().get( "1" ).asInt(); + + + assertThat( val, equalTo( 1 ) ); + } + } + assertThat( failingPos, equalTo( 2 ) ); + } + + @Test + public void shouldRollBackTxIfErrorWithConsume() throws Throwable + { + + // Given + try ( Transaction tx = session.beginTransaction() ) + { + StatementResult result = tx.run( "invalid" ); + tx.success(); + + // When + result.consume(); // run, pull_all + fail( "Should fail tx due to syntax error" ); + } // ack_fail, roll_back, pull_all + // Then + catch ( ClientException e ) + { + try ( Transaction tx = session.beginTransaction() ) + { + StatementResult cursor = tx.run( "RETURN 1" ); + int val = cursor.single().get( "1" ).asInt(); + + Assert.assertThat( val, equalTo( 1 ) ); + } + } + + } }