Skip to content

Commit b9ab50e

Browse files
committed
Ensure reactive transaction function gets rolled back on cancellation
This update also ensures that session cancellation does not result in multiple rollback attempts.
1 parent a44346c commit b9ab50e

File tree

2 files changed

+118
-132
lines changed

2 files changed

+118
-132
lines changed

driver/src/main/java/org/neo4j/driver/internal/async/UnmanagedTransaction.java

Lines changed: 117 additions & 131 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@
2222
import java.util.EnumSet;
2323
import java.util.concurrent.CompletionException;
2424
import java.util.concurrent.CompletionStage;
25+
import java.util.concurrent.locks.Lock;
26+
import java.util.concurrent.locks.ReentrantLock;
2527
import java.util.function.BiFunction;
2628

2729
import org.neo4j.driver.Bookmark;
@@ -41,6 +43,7 @@
4143

4244
import static org.neo4j.driver.internal.util.Futures.completedWithNull;
4345
import static org.neo4j.driver.internal.util.Futures.failedFuture;
46+
import static org.neo4j.driver.internal.util.LockUtil.executeWithLock;
4447

4548
public class UnmanagedTransaction
4649
{
@@ -62,66 +65,18 @@ private enum State
6265
ROLLED_BACK
6366
}
6467

65-
/**
66-
* This is a holder so that we can have ony the state volatile in the tx without having to synchronize the whole block.
67-
*/
68-
private static final class StateHolder
69-
{
70-
private static final EnumSet<State> OPEN_STATES = EnumSet.of( State.ACTIVE, State.TERMINATED );
71-
private static final StateHolder ACTIVE_HOLDER = new StateHolder( State.ACTIVE, null );
72-
private static final StateHolder COMMITTED_HOLDER = new StateHolder( State.COMMITTED, null );
73-
private static final StateHolder ROLLED_BACK_HOLDER = new StateHolder( State.ROLLED_BACK, null );
74-
75-
/**
76-
* The actual state.
77-
*/
78-
final State value;
79-
80-
/**
81-
* If this holder contains a state of {@link State#TERMINATED}, this represents the cause if any.
82-
*/
83-
final Throwable causeOfTermination;
84-
85-
static StateHolder of( State value )
86-
{
87-
switch ( value )
88-
{
89-
case ACTIVE:
90-
return ACTIVE_HOLDER;
91-
case COMMITTED:
92-
return COMMITTED_HOLDER;
93-
case ROLLED_BACK:
94-
return ROLLED_BACK_HOLDER;
95-
case TERMINATED:
96-
default:
97-
throw new IllegalArgumentException( "Cannot provide a default state holder for state " + value );
98-
}
99-
}
100-
101-
static StateHolder terminatedWith( Throwable cause )
102-
{
103-
return new StateHolder( State.TERMINATED, cause );
104-
}
105-
106-
private StateHolder( State value, Throwable causeOfTermination )
107-
{
108-
this.value = value;
109-
this.causeOfTermination = causeOfTermination;
110-
}
111-
112-
boolean isOpen()
113-
{
114-
return OPEN_STATES.contains( this.value );
115-
}
116-
}
68+
private static final EnumSet<State> OPEN_STATES = EnumSet.of( State.ACTIVE, State.TERMINATED );
11769

11870
private final Connection connection;
11971
private final BoltProtocol protocol;
12072
private final BookmarkHolder bookmarkHolder;
12173
private final ResultCursorsHolder resultCursors;
12274
private final long fetchSize;
123-
124-
private volatile StateHolder state = StateHolder.of( State.ACTIVE );
75+
private final Lock lock = new ReentrantLock();
76+
private State state = State.ACTIVE;
77+
private CompletionStage<Void> commitStage;
78+
private CompletionStage<Void> rollbackStage;
79+
private Throwable causeOfTermination;
12580

12681
public UnmanagedTransaction( Connection connection, BookmarkHolder bookmarkHolder, long fetchSize )
12782
{
@@ -164,50 +119,61 @@ else if ( beginError instanceof ConnectionReadTimeoutException )
164119

165120
public CompletionStage<Void> closeAsync()
166121
{
167-
if ( isOpen() )
168-
{
169-
return rollbackAsync();
170-
}
171-
else
172-
{
173-
return completedWithNull();
174-
}
122+
return executeWithLock( lock, () -> isOpen() ? rollbackAsync() : completedWithNull() );
175123
}
176124

177125
public CompletionStage<Void> commitAsync()
178126
{
179-
if ( state.value == State.COMMITTED )
180-
{
181-
return failedFuture( new ClientException( "Can't commit, transaction has been committed" ) );
182-
}
183-
else if ( state.value == State.ROLLED_BACK )
127+
return executeWithLock( lock, () ->
184128
{
185-
return failedFuture( new ClientException( "Can't commit, transaction has been rolled back" ) );
186-
}
187-
else
188-
{
189-
return resultCursors.retrieveNotConsumedError()
190-
.thenCompose( error -> doCommitAsync( error ).handle( handleCommitOrRollback( error ) ) )
191-
.whenComplete( ( ignore, error ) -> handleTransactionCompletion( true, error ) );
192-
}
129+
if ( state == State.COMMITTED )
130+
{
131+
return failedFuture( new ClientException( "Can't commit, transaction has been committed" ) );
132+
}
133+
else if ( state == State.ROLLED_BACK )
134+
{
135+
return failedFuture( new ClientException( "Can't commit, transaction has been rolled back" ) );
136+
}
137+
else if ( commitStage != null )
138+
{
139+
return commitStage;
140+
}
141+
else
142+
{
143+
CompletionStage<Void> stage = resultCursors.retrieveNotConsumedError()
144+
.thenCompose( error -> doCommitAsync( error ).handle( handleCommitOrRollback( error ) ) );
145+
commitStage = stage.whenComplete( ( ignore, error ) -> releaseConnection( error ) );
146+
stage.whenComplete( ( ignored, error ) -> updateStateAfterCommitOrRollback( true, error ) );
147+
return stage;
148+
}
149+
} );
193150
}
194151

195152
public CompletionStage<Void> rollbackAsync()
196153
{
197-
if ( state.value == State.COMMITTED )
154+
return executeWithLock( lock, () ->
198155
{
199-
return failedFuture( new ClientException( "Can't rollback, transaction has been committed" ) );
200-
}
201-
else if ( state.value == State.ROLLED_BACK )
202-
{
203-
return failedFuture( new ClientException( "Can't rollback, transaction has been rolled back" ) );
204-
}
205-
else
206-
{
207-
return resultCursors.retrieveNotConsumedError()
208-
.thenCompose( error -> doRollbackAsync().handle( handleCommitOrRollback( error ) ) )
209-
.whenComplete( ( ignore, error ) -> handleTransactionCompletion( false, error ) );
210-
}
156+
if ( state == State.COMMITTED )
157+
{
158+
return failedFuture( new ClientException( "Can't rollback, transaction has been committed" ) );
159+
}
160+
else if ( state == State.ROLLED_BACK )
161+
{
162+
return failedFuture( new ClientException( "Can't rollback, transaction has been rolled back" ) );
163+
}
164+
else if ( rollbackStage != null )
165+
{
166+
return rollbackStage;
167+
}
168+
else
169+
{
170+
CompletionStage<Void> stage = resultCursors.retrieveNotConsumedError()
171+
.thenCompose( error -> doRollbackAsync().handle( handleCommitOrRollback( error ) ) );
172+
rollbackStage = stage.whenComplete( ( ignore, error ) -> releaseConnection( error ) );
173+
stage.whenComplete( ( ignored, error ) -> updateStateAfterCommitOrRollback( false, error ) );
174+
return stage;
175+
}
176+
} );
211177
}
212178

213179
public CompletionStage<ResultCursor> runAsync( Query query )
@@ -230,22 +196,27 @@ public CompletionStage<RxResultCursor> runRx(Query query)
230196

231197
public boolean isOpen()
232198
{
233-
return state.isOpen();
199+
State currentState = executeWithLock( lock, () -> state );
200+
return OPEN_STATES.contains( currentState );
234201
}
235202

236203
public void markTerminated( Throwable cause )
237204
{
238-
if ( state.value == State.TERMINATED )
205+
executeWithLock( lock, () ->
239206
{
240-
if ( state.causeOfTermination != null )
207+
if ( state == State.TERMINATED )
241208
{
242-
addSuppressedWhenNotCaptured( state.causeOfTermination, cause );
209+
if ( causeOfTermination != null )
210+
{
211+
addSuppressedWhenNotCaptured( causeOfTermination, cause );
212+
}
243213
}
244-
}
245-
else
246-
{
247-
state = StateHolder.terminatedWith( cause );
248-
}
214+
else
215+
{
216+
state = State.TERMINATED;
217+
causeOfTermination = cause;
218+
}
219+
} );
249220
}
250221

251222
private void addSuppressedWhenNotCaptured( Throwable currentCause, Throwable newCause )
@@ -267,39 +238,40 @@ public Connection connection()
267238

268239
private void ensureCanRunQueries()
269240
{
270-
if ( state.value == State.COMMITTED )
271-
{
272-
throw new ClientException( "Cannot run more queries in this transaction, it has been committed" );
273-
}
274-
else if ( state.value == State.ROLLED_BACK )
275-
{
276-
throw new ClientException( "Cannot run more queries in this transaction, it has been rolled back" );
277-
}
278-
else if ( state.value == State.TERMINATED )
241+
executeWithLock( lock, () ->
279242
{
280-
throw new ClientException( "Cannot run more queries in this transaction, " +
281-
"it has either experienced an fatal error or was explicitly terminated", state.causeOfTermination );
282-
}
243+
if ( state == State.COMMITTED )
244+
{
245+
throw new ClientException( "Cannot run more queries in this transaction, it has been committed" );
246+
}
247+
else if ( state == State.ROLLED_BACK )
248+
{
249+
throw new ClientException( "Cannot run more queries in this transaction, it has been rolled back" );
250+
}
251+
else if ( state == State.TERMINATED )
252+
{
253+
throw new ClientException( "Cannot run more queries in this transaction, " +
254+
"it has either experienced an fatal error or was explicitly terminated", causeOfTermination );
255+
}
256+
} );
283257
}
284258

285259
private CompletionStage<Void> doCommitAsync( Throwable cursorFailure )
286260
{
287-
if ( state.value == State.TERMINATED )
288-
{
289-
return failedFuture( new ClientException( "Transaction can't be committed. " +
290-
"It has been rolled back either because of an error or explicit termination",
291-
cursorFailure != state.causeOfTermination ? state.causeOfTermination : null ) );
292-
}
293-
return protocol.commitTransaction( connection ).thenAccept( bookmarkHolder::setBookmark );
261+
ClientException exception = executeWithLock(
262+
lock, () -> state == State.TERMINATED
263+
? new ClientException( "Transaction can't be committed. " +
264+
"It has been rolled back either because of an error or explicit termination",
265+
cursorFailure != causeOfTermination ? causeOfTermination : null )
266+
: null
267+
);
268+
return exception != null ? failedFuture( exception ) : protocol.commitTransaction( connection ).thenAccept( bookmarkHolder::setBookmark );
294269
}
295270

296271
private CompletionStage<Void> doRollbackAsync()
297272
{
298-
if ( state.value == State.TERMINATED )
299-
{
300-
return completedWithNull();
301-
}
302-
return protocol.rollbackTransaction( connection );
273+
State currentState = executeWithLock( lock, () -> state );
274+
return currentState == State.TERMINATED ? completedWithNull() : protocol.rollbackTransaction( connection );
303275
}
304276

305277
private static BiFunction<Void,Throwable,Void> handleCommitOrRollback( Throwable cursorFailure )
@@ -315,17 +287,8 @@ private static BiFunction<Void,Throwable,Void> handleCommitOrRollback( Throwable
315287
};
316288
}
317289

318-
private void handleTransactionCompletion( boolean commitOnSuccess, Throwable throwable )
290+
private void releaseConnection( Throwable throwable )
319291
{
320-
if ( commitOnSuccess && throwable == null )
321-
{
322-
state = StateHolder.of( State.COMMITTED );
323-
}
324-
else
325-
{
326-
state = StateHolder.of( State.ROLLED_BACK );
327-
}
328-
329292
if ( throwable instanceof AuthorizationExpiredException )
330293
{
331294
connection.terminateAndRelease( AuthorizationExpiredException.DESCRIPTION );
@@ -339,4 +302,27 @@ else if ( throwable instanceof ConnectionReadTimeoutException )
339302
connection.release(); // release in background
340303
}
341304
}
305+
306+
private void updateStateAfterCommitOrRollback( boolean commitAttempt, Throwable throwable )
307+
{
308+
executeWithLock( lock, () ->
309+
{
310+
if ( commitAttempt && throwable == null )
311+
{
312+
state = State.COMMITTED;
313+
}
314+
else
315+
{
316+
state = State.ROLLED_BACK;
317+
}
318+
if ( commitAttempt )
319+
{
320+
commitStage = null;
321+
}
322+
else
323+
{
324+
rollbackStage = null;
325+
}
326+
} );
327+
}
342328
}

driver/src/main/java/org/neo4j/driver/internal/reactive/InternalRxSession.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ public <T> Publisher<T> writeTransaction( RxTransactionWork<? extends Publisher<
130130
private <T> Publisher<T> runTransaction( AccessMode mode, RxTransactionWork<? extends Publisher<T>> work, TransactionConfig config )
131131
{
132132
Flux<T> repeatableWork = Flux.usingWhen( beginTransaction( mode, config ), work::execute,
133-
InternalRxTransaction::commitIfOpen, ( tx, error ) -> tx.close(), null );
133+
InternalRxTransaction::commitIfOpen, ( tx, error ) -> tx.close(), InternalRxTransaction::close );
134134
return session.retryLogic().retryRx( repeatableWork );
135135
}
136136

0 commit comments

Comments
 (0)