From 598ab4cca4b9096301ce186f46ea26348c7ed49b Mon Sep 17 00:00:00 2001 From: Dmitriy Tverdiakov <11927660+injectives@users.noreply.github.com> Date: Mon, 29 Jan 2024 15:00:58 +0000 Subject: [PATCH] Fix Bolt handshake write handling and timeout management (#1528) --- .../connection/ChannelConnectedListener.java | 16 +++++++++++-- .../connection/ChannelConnectorImpl.java | 6 ++++- .../org/neo4j/driver/GraphDatabaseTest.java | 2 ++ .../integration/ChannelConnectorImplIT.java | 2 ++ .../driver/integration/EncryptionIT.java | 8 +++---- .../ChannelConnectedListenerTest.java | 23 +++++++++++++++++++ 6 files changed, 50 insertions(+), 7 deletions(-) diff --git a/driver/src/main/java/org/neo4j/driver/internal/async/connection/ChannelConnectedListener.java b/driver/src/main/java/org/neo4j/driver/internal/async/connection/ChannelConnectedListener.java index 0025574fc5..d55e090a22 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/async/connection/ChannelConnectedListener.java +++ b/driver/src/main/java/org/neo4j/driver/internal/async/connection/ChannelConnectedListener.java @@ -19,7 +19,6 @@ package org.neo4j.driver.internal.async.connection; import static java.lang.String.format; -import static org.neo4j.driver.internal.async.connection.BoltProtocolUtil.handshakeBuf; import static org.neo4j.driver.internal.async.connection.BoltProtocolUtil.handshakeString; import io.netty.channel.Channel; @@ -27,8 +26,10 @@ import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelPipeline; import io.netty.channel.ChannelPromise; +import javax.net.ssl.SSLHandshakeException; import org.neo4j.driver.Logger; import org.neo4j.driver.Logging; +import org.neo4j.driver.exceptions.SecurityException; import org.neo4j.driver.exceptions.ServiceUnavailableException; import org.neo4j.driver.internal.BoltServerAddress; import org.neo4j.driver.internal.logging.ChannelActivityLogger; @@ -61,7 +62,18 @@ public void operationComplete(ChannelFuture future) { ChannelPipeline pipeline = channel.pipeline(); pipeline.addLast(new HandshakeHandler(pipelineBuilder, handshakeCompletedPromise, logging)); log.debug("C: [Bolt Handshake] %s", handshakeString()); - channel.writeAndFlush(handshakeBuf(), channel.voidPromise()); + channel.writeAndFlush(BoltProtocolUtil.handshakeBuf()).addListener(f -> { + if (!f.isSuccess()) { + Throwable error = f.cause(); + if (error instanceof SSLHandshakeException) { + error = new SecurityException("Failed to establish secured connection with the server", error); + } else { + error = new ServiceUnavailableException( + String.format("Unable to write Bolt handshake to %s.", this.address), error); + } + this.handshakeCompletedPromise.setFailure(error); + } + }); } else { handshakeCompletedPromise.setFailure(databaseUnavailableError(address, future.cause())); } diff --git a/driver/src/main/java/org/neo4j/driver/internal/async/connection/ChannelConnectorImpl.java b/driver/src/main/java/org/neo4j/driver/internal/async/connection/ChannelConnectorImpl.java index 32630d6450..10953e64fd 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/async/connection/ChannelConnectorImpl.java +++ b/driver/src/main/java/org/neo4j/driver/internal/async/connection/ChannelConnectorImpl.java @@ -136,7 +136,11 @@ private void installHandshakeCompletedListeners( // remove timeout handler from the pipeline once TLS and Bolt handshakes are completed. regular protocol // messages will flow next and we do not want to have read timeout for them - handshakeCompleted.addListener(future -> pipeline.remove(ConnectTimeoutHandler.class)); + handshakeCompleted.addListener(future -> { + if (future.isSuccess()) { + pipeline.remove(ConnectTimeoutHandler.class); + } + }); // add listener that sends an INIT message. connection is now fully established. channel pipeline if fully // set to send/receive messages for a selected protocol version diff --git a/driver/src/test/java/org/neo4j/driver/GraphDatabaseTest.java b/driver/src/test/java/org/neo4j/driver/GraphDatabaseTest.java index 854bec853a..4bff5dd388 100644 --- a/driver/src/test/java/org/neo4j/driver/GraphDatabaseTest.java +++ b/driver/src/test/java/org/neo4j/driver/GraphDatabaseTest.java @@ -40,6 +40,7 @@ import java.util.Arrays; import java.util.Iterator; import java.util.List; +import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.neo4j.driver.exceptions.ServiceUnavailableException; import org.neo4j.driver.internal.BoltServerAddress; @@ -152,6 +153,7 @@ void shouldFailToCreateUnencryptedDriverWhenServerDoesNotRespond() throws IOExce } @Test + @Disabled("TLS actually fails, the test setup is not valid") void shouldFailToCreateEncryptedDriverWhenServerDoesNotRespond() throws IOException { testFailureWhenServerDoesNotRespond(true); } diff --git a/driver/src/test/java/org/neo4j/driver/integration/ChannelConnectorImplIT.java b/driver/src/test/java/org/neo4j/driver/integration/ChannelConnectorImplIT.java index 1bf87f0d82..e3de67c876 100644 --- a/driver/src/test/java/org/neo4j/driver/integration/ChannelConnectorImplIT.java +++ b/driver/src/test/java/org/neo4j/driver/integration/ChannelConnectorImplIT.java @@ -45,6 +45,7 @@ import java.util.concurrent.TimeUnit; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; import org.neo4j.driver.AuthToken; @@ -158,6 +159,7 @@ void shouldFailWhenProtocolNegotiationTakesTooLong() throws Exception { } @Test + @Disabled("TLS actually fails, the test setup is not valid") void shouldFailWhenTLSHandshakeTakesTooLong() throws Exception { // run with TLS so that TLS handshake is the very first operation after connection is established testReadTimeoutOnConnect(trustAllCertificates()); diff --git a/driver/src/test/java/org/neo4j/driver/integration/EncryptionIT.java b/driver/src/test/java/org/neo4j/driver/integration/EncryptionIT.java index 2337b01aa9..f34f4e0da6 100644 --- a/driver/src/test/java/org/neo4j/driver/integration/EncryptionIT.java +++ b/driver/src/test/java/org/neo4j/driver/integration/EncryptionIT.java @@ -59,7 +59,7 @@ void shouldOperateWithEncryptionWhenItIsOptionalInTheDatabase() { @Test void shouldFailWithoutEncryptionWhenItIsRequiredInTheDatabase() { - testMismatchingEncryption(BoltTlsLevel.REQUIRED, false); + testMismatchingEncryption(BoltTlsLevel.REQUIRED, false, "Connection to the database terminated"); } @Test @@ -74,7 +74,7 @@ void shouldOperateWithEncryptionWhenConfiguredUsingBoltSscURI() { @Test void shouldFailWithEncryptionWhenItIsDisabledInTheDatabase() { - testMismatchingEncryption(BoltTlsLevel.DISABLED, true); + testMismatchingEncryption(BoltTlsLevel.DISABLED, true, "Unable to write Bolt handshake to"); } @Test @@ -110,7 +110,7 @@ private void testMatchingEncryption(BoltTlsLevel tlsLevel, boolean driverEncrypt } } - private void testMismatchingEncryption(BoltTlsLevel tlsLevel, boolean driverEncrypted) { + private void testMismatchingEncryption(BoltTlsLevel tlsLevel, boolean driverEncrypted, String errorMessage) { Map tlsConfig = new HashMap<>(); tlsConfig.put(Neo4jSettings.BOLT_TLS_LEVEL, tlsLevel.toString()); neo4j.deleteAndStartNeo4j(tlsConfig); @@ -120,7 +120,7 @@ private void testMismatchingEncryption(BoltTlsLevel tlsLevel, boolean driverEncr ServiceUnavailableException.class, () -> GraphDatabase.driver(neo4j.uri(), neo4j.authToken(), config) .verifyConnectivity()); - assertThat(e.getMessage(), startsWith("Connection to the database terminated")); + assertThat(e.getMessage(), startsWith(errorMessage)); } private static Config newConfig(boolean withEncryption) { diff --git a/driver/src/test/java/org/neo4j/driver/internal/async/connection/ChannelConnectedListenerTest.java b/driver/src/test/java/org/neo4j/driver/internal/async/connection/ChannelConnectedListenerTest.java index 7f35473b4a..a683d9f0f4 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/async/connection/ChannelConnectedListenerTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/async/connection/ChannelConnectedListenerTest.java @@ -19,6 +19,8 @@ package org.neo4j.driver.internal.async.connection; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -29,7 +31,9 @@ import io.netty.channel.ChannelPromise; import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.util.concurrent.Future; import java.io.IOException; +import java.util.concurrent.CompletableFuture; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Test; import org.neo4j.driver.exceptions.ServiceUnavailableException; @@ -73,6 +77,25 @@ void shouldWriteHandshakeWhenChannelConnected() { assertEquals(handshakeBuf(), channel.readOutbound()); } + @Test + void shouldCompleteHandshakePromiseExceptionallyOnWriteFailure() { + ChannelPromise handshakeCompletedPromise = channel.newPromise(); + ChannelConnectedListener listener = newListener(handshakeCompletedPromise); + ChannelPromise channelConnectedPromise = channel.newPromise(); + channelConnectedPromise.setSuccess(); + channel.close(); + + listener.operationComplete(channelConnectedPromise); + + assertTrue(handshakeCompletedPromise.isDone()); + CompletableFuture> future = new CompletableFuture<>(); + handshakeCompletedPromise.addListener(future::complete); + Future handshakeFuture = future.join(); + assertTrue(handshakeFuture.isDone()); + assertFalse(handshakeFuture.isSuccess()); + assertInstanceOf(ServiceUnavailableException.class, handshakeFuture.cause()); + } + private static ChannelConnectedListener newListener(ChannelPromise handshakeCompletedPromise) { return new ChannelConnectedListener( LOCAL_DEFAULT, new ChannelPipelineBuilderImpl(), handshakeCompletedPromise, DEV_NULL_LOGGING);