Skip to content

Improve explicit transaction terminated state handling #1445

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jul 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
/*
* Copyright (c) "Neo4j"
* Neo4j Sweden AB [http://neo4j.com]
*
* This file is part of Neo4j.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.neo4j.driver.exceptions;

import java.io.Serial;

/**
* Indicates that the transaction has been terminated.
* <p>
* Any usage of the terminated transaction and any of its results must be stopped and the transaction must be closed
* explicitly. Moreover, any error in the transaction result(s) should be considered as a transaction termination and
* must be handled in the same way.
* <p>
* The exception will contain a non-null {@link #code()} if it is created based on the server's response. It will not
* have a code if it is generated by the driver.
*
* @since 5.11
*/
public class TransactionTerminatedException extends ClientException {
@Serial
private static final long serialVersionUID = 7639191706067500206L;

/**
* Creates a new instance.
*
* @param message the message
*/
public TransactionTerminatedException(String message) {
super(message);
}

/**
* Creates a new instance.
*
* @param code the code
* @param message the message
*/
public TransactionTerminatedException(String code, String message) {
super(code, message);
}

/**
* Creates a new instance.
*
* @param message the message
* @param cause the cause
*/
public TransactionTerminatedException(String message, Throwable cause) {
super(message, cause);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,22 @@ public boolean isOpen() {
return tx.isOpen();
}

/**
* <b>THIS IS A PRIVATE API</b>
* <p>
* Terminates the transaction by sending the Bolt {@code RESET} message and waiting for its response as long as the
* transaction has not already been terminated, is not closed or closing.
*
* @since 5.11
* @throws org.neo4j.driver.exceptions.ClientException if the transaction is closed or is closing
* @see org.neo4j.driver.exceptions.TransactionTerminatedException
*/
public void terminate() {
Futures.blockingGet(
tx.terminateAsync(),
() -> terminateConnectionOnThreadInterrupt("Thread interrupted while terminating the transaction"));
}

private void terminateConnectionOnThreadInterrupt(String reason) {
tx.connection().terminateAndRelease(reason);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,16 @@
import static org.neo4j.driver.internal.async.connection.ChannelAttributes.poolId;
import static org.neo4j.driver.internal.async.connection.ChannelAttributes.setTerminationReason;
import static org.neo4j.driver.internal.util.Futures.asCompletionStage;
import static org.neo4j.driver.internal.util.LockUtil.executeWithLock;

import io.netty.channel.Channel;
import io.netty.channel.ChannelHandler;
import java.time.Clock;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReentrantLock;
import org.neo4j.driver.Logger;
import org.neo4j.driver.Logging;
import org.neo4j.driver.internal.BoltServerAddress;
Expand All @@ -41,7 +43,12 @@
import org.neo4j.driver.internal.handlers.ResetResponseHandler;
import org.neo4j.driver.internal.messaging.BoltProtocol;
import org.neo4j.driver.internal.messaging.Message;
import org.neo4j.driver.internal.messaging.request.DiscardAllMessage;
import org.neo4j.driver.internal.messaging.request.DiscardMessage;
import org.neo4j.driver.internal.messaging.request.PullAllMessage;
import org.neo4j.driver.internal.messaging.request.PullMessage;
import org.neo4j.driver.internal.messaging.request.ResetMessage;
import org.neo4j.driver.internal.messaging.request.RunWithMetadataMessage;
import org.neo4j.driver.internal.metrics.ListenerEvent;
import org.neo4j.driver.internal.metrics.MetricsListener;
import org.neo4j.driver.internal.spi.Connection;
Expand All @@ -53,6 +60,7 @@
*/
public class NetworkConnection implements Connection {
private final Logger log;
private final Lock lock;
private final Channel channel;
private final InboundMessageDispatcher messageDispatcher;
private final String serverAgent;
Expand All @@ -61,12 +69,13 @@ public class NetworkConnection implements Connection {
private final ExtendedChannelPool channelPool;
private final CompletableFuture<Void> releaseFuture;
private final Clock clock;

private final AtomicReference<Status> status = new AtomicReference<>(Status.OPEN);
private final MetricsListener metricsListener;
private final ListenerEvent<?> inUseEvent;

private final Long connectionReadTimeout;

private Status status = Status.OPEN;
private TerminationAwareStateLockingExecutor terminationAwareStateLockingExecutor;
private ChannelHandler connectionReadTimeoutHandler;

public NetworkConnection(
Expand All @@ -76,6 +85,7 @@ public NetworkConnection(
MetricsListener metricsListener,
Logging logging) {
this.log = logging.getLog(getClass());
this.lock = new ReentrantLock();
this.channel = channel;
this.messageDispatcher = ChannelAttributes.messageDispatcher(channel);
this.serverAgent = ChannelAttributes.serverAgent(channel);
Expand All @@ -93,7 +103,7 @@ public NetworkConnection(

@Override
public boolean isOpen() {
return status.get() == Status.OPEN;
return executeWithLock(lock, () -> status == Status.OPEN);
}

@Override
Expand All @@ -110,52 +120,31 @@ public void disableAutoRead() {
}
}

@Override
public void flush() {
if (verifyOpen(null, null)) {
flushInEventLoop();
}
}

@Override
public void write(Message message, ResponseHandler handler) {
if (verifyOpen(handler, null)) {
if (verifyOpen(handler)) {
writeMessageInEventLoop(message, handler, false);
}
}

@Override
public void write(Message message1, ResponseHandler handler1, Message message2, ResponseHandler handler2) {
if (verifyOpen(handler1, handler2)) {
writeMessagesInEventLoop(message1, handler1, message2, handler2, false);
}
}

@Override
public void writeAndFlush(Message message, ResponseHandler handler) {
if (verifyOpen(handler, null)) {
if (verifyOpen(handler)) {
writeMessageInEventLoop(message, handler, true);
}
}

@Override
public void writeAndFlush(Message message1, ResponseHandler handler1, Message message2, ResponseHandler handler2) {
if (verifyOpen(handler1, handler2)) {
writeMessagesInEventLoop(message1, handler1, message2, handler2, true);
}
}

@Override
public CompletionStage<Void> reset() {
CompletableFuture<Void> result = new CompletableFuture<>();
ResetResponseHandler handler = new ResetResponseHandler(messageDispatcher, result);
public CompletionStage<Void> reset(Throwable throwable) {
var result = new CompletableFuture<Void>();
var handler = new ResetResponseHandler(messageDispatcher, result, throwable);
writeResetMessageIfNeeded(handler, true);
return result;
}

@Override
public CompletionStage<Void> release() {
if (status.compareAndSet(Status.OPEN, Status.RELEASED)) {
if (executeWithLock(lock, () -> updateStateIfOpen(Status.RELEASED))) {
ChannelReleasingResetResponseHandler handler = new ChannelReleasingResetResponseHandler(
channel, channelPool, messageDispatcher, clock, releaseFuture);

Expand All @@ -167,7 +156,7 @@ public CompletionStage<Void> release() {

@Override
public void terminateAndRelease(String reason) {
if (status.compareAndSet(Status.OPEN, Status.TERMINATED)) {
if (executeWithLock(lock, () -> updateStateIfOpen(Status.TERMINATED))) {
setTerminationReason(channel, reason);
asCompletionStage(channel.close())
.exceptionally(throwable -> null)
Expand All @@ -194,6 +183,25 @@ public BoltProtocol protocol() {
return protocol;
}

@Override
public void bindTerminationAwareStateLockingExecutor(TerminationAwareStateLockingExecutor executor) {
executeWithLock(lock, () -> {
if (this.terminationAwareStateLockingExecutor != null) {
throw new IllegalStateException("terminationAwareStateLockingExecutor is already set");
}
this.terminationAwareStateLockingExecutor = executor;
});
}

private boolean updateStateIfOpen(Status newStatus) {
if (Status.OPEN.equals(status)) {
status = newStatus;
return true;
} else {
return false;
}
}

private void writeResetMessageIfNeeded(ResponseHandler resetHandler, boolean isSessionReset) {
channel.eventLoop().execute(() -> {
if (isSessionReset && !isOpen()) {
Expand All @@ -208,73 +216,49 @@ private void writeResetMessageIfNeeded(ResponseHandler resetHandler, boolean isS
});
}

private void flushInEventLoop() {
channel.eventLoop().execute(() -> {
channel.flush();
registerConnectionReadTimeout(channel);
});
}

private void writeMessageInEventLoop(Message message, ResponseHandler handler, boolean flush) {
channel.eventLoop().execute(() -> {
messageDispatcher.enqueue(handler);

if (flush) {
channel.writeAndFlush(message).addListener(future -> registerConnectionReadTimeout(channel));
} else {
channel.write(message, channel.voidPromise());
}
});
}

private void writeMessagesInEventLoop(
Message message1, ResponseHandler handler1, Message message2, ResponseHandler handler2, boolean flush) {
channel.eventLoop().execute(() -> {
messageDispatcher.enqueue(handler1);
messageDispatcher.enqueue(handler2);

channel.write(message1, channel.voidPromise());

if (flush) {
channel.writeAndFlush(message2).addListener(future -> registerConnectionReadTimeout(channel));
} else {
channel.write(message2, channel.voidPromise());
}
});
channel.eventLoop()
.execute(() -> terminationAwareStateLockingExecutor(message).execute(causeOfTermination -> {
if (causeOfTermination == null) {
messageDispatcher.enqueue(handler);

if (flush) {
channel.writeAndFlush(message)
.addListener(future -> registerConnectionReadTimeout(channel));
} else {
channel.write(message, channel.voidPromise());
}
} else {
handler.onFailure(causeOfTermination);
}
}));
}

private void setAutoRead(boolean value) {
channel.config().setAutoRead(value);
}

private boolean verifyOpen(ResponseHandler handler1, ResponseHandler handler2) {
Status connectionStatus = this.status.get();
switch (connectionStatus) {
case OPEN:
return true;
case RELEASED:
private boolean verifyOpen(ResponseHandler handler) {
var connectionStatus = executeWithLock(lock, () -> status);
return switch (connectionStatus) {
case OPEN -> true;
case RELEASED -> {
Exception error =
new IllegalStateException("Connection has been released to the pool and can't be used");
if (handler1 != null) {
handler1.onFailure(error);
if (handler != null) {
handler.onFailure(error);
}
if (handler2 != null) {
handler2.onFailure(error);
}
return false;
case TERMINATED:
yield false;
}
case TERMINATED -> {
Exception terminatedError =
new IllegalStateException("Connection has been terminated and can't be used");
if (handler1 != null) {
handler1.onFailure(terminatedError);
}
if (handler2 != null) {
handler2.onFailure(terminatedError);
if (handler != null) {
handler.onFailure(terminatedError);
}
return false;
default:
throw new IllegalStateException("Unknown status: " + connectionStatus);
}
yield false;
}
};
}

private void registerConnectionReadTimeout(Channel channel) {
Expand All @@ -295,6 +279,25 @@ private void registerConnectionReadTimeout(Channel channel) {
}
}

private TerminationAwareStateLockingExecutor terminationAwareStateLockingExecutor(Message message) {
var result = (TerminationAwareStateLockingExecutor) consumer -> consumer.accept(null);
if (isQueryMessage(message)) {
var lockingExecutor = executeWithLock(lock, () -> this.terminationAwareStateLockingExecutor);
if (lockingExecutor != null) {
result = lockingExecutor;
}
}
return result;
}

private boolean isQueryMessage(Message message) {
return message instanceof RunWithMetadataMessage
|| message instanceof PullMessage
|| message instanceof PullAllMessage
|| message instanceof DiscardMessage
|| message instanceof DiscardAllMessage;
}

private enum Status {
OPEN,
RELEASED,
Expand Down
Loading