diff --git a/driver/src/main/java/org/neo4j/driver/internal/async/inbound/InboundMessageDispatcher.java b/driver/src/main/java/org/neo4j/driver/internal/async/inbound/InboundMessageDispatcher.java index 8cefed3be9..6e296c83d2 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/async/inbound/InboundMessageDispatcher.java +++ b/driver/src/main/java/org/neo4j/driver/internal/async/inbound/InboundMessageDispatcher.java @@ -164,6 +164,10 @@ public void handleIgnoredMessage() { handler.onFailure(error); } + public HandlerHook getBeforeLastHandlerHook() { + return this.beforeLastHandlerHook; + } + private Optional getPendingResetHandler() { return handlers.stream() .filter(h -> h instanceof ResetResponseHandler) diff --git a/driver/src/main/java/org/neo4j/driver/internal/async/pool/NettyChannelHealthChecker.java b/driver/src/main/java/org/neo4j/driver/internal/async/pool/NettyChannelHealthChecker.java index 062f0255b2..af3b916a17 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/async/pool/NettyChannelHealthChecker.java +++ b/driver/src/main/java/org/neo4j/driver/internal/async/pool/NettyChannelHealthChecker.java @@ -28,10 +28,14 @@ import io.netty.util.concurrent.Promise; import io.netty.util.concurrent.PromiseNotifier; import java.time.Clock; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicLong; import org.neo4j.driver.Logger; import org.neo4j.driver.Logging; import org.neo4j.driver.internal.async.connection.AuthorizationStateListener; +import org.neo4j.driver.internal.async.connection.ChannelAttributes; +import org.neo4j.driver.internal.async.inbound.ConnectionReadTimeoutHandler; +import org.neo4j.driver.internal.async.inbound.InboundMessageDispatcher; import org.neo4j.driver.internal.handlers.PingResponseHandler; import org.neo4j.driver.internal.messaging.request.ResetMessage; import org.neo4j.driver.internal.messaging.v51.BoltProtocolV51; @@ -164,8 +168,24 @@ private boolean hasBeenIdleForTooLong(Channel channel) { private Future ping(Channel channel) { Promise result = channel.eventLoop().newPromise(); - messageDispatcher(channel).enqueue(new PingResponseHandler(result, channel, logging)); + var messageDispatcher = messageDispatcher(channel); + messageDispatcher.enqueue(new PingResponseHandler(result, channel, logging)); + attachConnectionReadTimeoutHandler(channel, messageDispatcher); channel.writeAndFlush(ResetMessage.RESET, channel.voidPromise()); return result; } + + private void attachConnectionReadTimeoutHandler(Channel channel, InboundMessageDispatcher messageDispatcher) { + ChannelAttributes.connectionReadTimeout(channel).ifPresent(connectionReadTimeout -> { + var connectionReadTimeoutHandler = + new ConnectionReadTimeoutHandler(connectionReadTimeout, TimeUnit.SECONDS); + channel.pipeline().addFirst(connectionReadTimeoutHandler); + log.debug("Added ConnectionReadTimeoutHandler"); + messageDispatcher.setBeforeLastHandlerHook((messageType) -> { + channel.pipeline().remove(connectionReadTimeoutHandler); + messageDispatcher.setBeforeLastHandlerHook(null); + log.debug("Removed ConnectionReadTimeoutHandler"); + }); + }); + } } diff --git a/driver/src/test/java/org/neo4j/driver/internal/async/pool/NettyChannelHealthCheckerTest.java b/driver/src/test/java/org/neo4j/driver/internal/async/pool/NettyChannelHealthCheckerTest.java index b44553e3ac..b99d7ff599 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/async/pool/NettyChannelHealthCheckerTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/async/pool/NettyChannelHealthCheckerTest.java @@ -21,6 +21,9 @@ import static org.hamcrest.Matchers.is; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.BDDMockito.given; import static org.mockito.BDDMockito.then; @@ -29,6 +32,7 @@ import static org.mockito.Mockito.times; import static org.neo4j.driver.internal.async.connection.ChannelAttributes.authContext; import static org.neo4j.driver.internal.async.connection.ChannelAttributes.setAuthContext; +import static org.neo4j.driver.internal.async.connection.ChannelAttributes.setConnectionReadTimeout; import static org.neo4j.driver.internal.async.connection.ChannelAttributes.setCreationTimestamp; import static org.neo4j.driver.internal.async.connection.ChannelAttributes.setLastUsedTimestamp; import static org.neo4j.driver.internal.async.connection.ChannelAttributes.setMessageDispatcher; @@ -54,6 +58,7 @@ import org.junit.jupiter.params.provider.MethodSource; import org.neo4j.driver.AuthTokenManager; import org.neo4j.driver.AuthTokens; +import org.neo4j.driver.internal.async.inbound.ConnectionReadTimeoutHandler; import org.neo4j.driver.internal.async.inbound.InboundMessageDispatcher; import org.neo4j.driver.internal.messaging.BoltProtocolVersion; import org.neo4j.driver.internal.messaging.request.ResetMessage; @@ -248,6 +253,65 @@ void shouldKeepIdleConnectionWhenPingSucceeds() { testPing(true); } + @Test + void shouldHandlePingWithConnectionReceiveTimeout() { + var idleTimeBeforeConnectionTest = 1000; + var connectionReadTimeout = 60L; + var settings = new PoolSettings( + DEFAULT_MAX_CONNECTION_POOL_SIZE, + DEFAULT_CONNECTION_ACQUISITION_TIMEOUT, + NOT_CONFIGURED, + idleTimeBeforeConnectionTest); + var clock = Clock.systemUTC(); + var healthChecker = newHealthChecker(settings, clock); + + setCreationTimestamp(channel, clock.millis()); + setConnectionReadTimeout(channel, connectionReadTimeout); + setLastUsedTimestamp(channel, clock.millis() - idleTimeBeforeConnectionTest * 2); + + var healthy = healthChecker.isHealthy(channel); + channel.runPendingTasks(); + + var firstElementOnPipeline = channel.pipeline().first(); + assertInstanceOf(ConnectionReadTimeoutHandler.class, firstElementOnPipeline); + assertNotNull(dispatcher.getBeforeLastHandlerHook()); + var readTimeoutHandler = (ConnectionReadTimeoutHandler) firstElementOnPipeline; + assertEquals(connectionReadTimeout * 1000L, readTimeoutHandler.getReaderIdleTimeInMillis()); + assertEquals(ResetMessage.RESET, single(channel.outboundMessages())); + assertFalse(healthy.isDone()); + + dispatcher.handleSuccessMessage(Collections.emptyMap()); + assertThat(await(healthy), is(true)); + assertNull(channel.pipeline().first()); + assertNull(dispatcher.getBeforeLastHandlerHook()); + } + + @Test + void shouldHandlePingWithoutConnectionReceiveTimeout() { + var idleTimeBeforeConnectionTest = 1000; + var settings = new PoolSettings( + DEFAULT_MAX_CONNECTION_POOL_SIZE, + DEFAULT_CONNECTION_ACQUISITION_TIMEOUT, + NOT_CONFIGURED, + idleTimeBeforeConnectionTest); + var clock = Clock.systemUTC(); + var healthChecker = newHealthChecker(settings, clock); + + setCreationTimestamp(channel, clock.millis()); + setLastUsedTimestamp(channel, clock.millis() - idleTimeBeforeConnectionTest * 2); + + var healthy = healthChecker.isHealthy(channel); + channel.runPendingTasks(); + + assertNull(channel.pipeline().first()); + assertEquals(ResetMessage.RESET, single(channel.outboundMessages())); + assertFalse(healthy.isDone()); + + dispatcher.handleSuccessMessage(Collections.emptyMap()); + assertThat(await(healthy), is(true)); + assertNull(channel.pipeline().first()); + } + @Test void shouldDropIdleConnectionWhenPingFails() { testPing(false);