diff --git a/driver/src/main/java/org/neo4j/driver/internal/handlers/HelloResponseHandler.java b/driver/src/main/java/org/neo4j/driver/internal/handlers/HelloResponseHandler.java index 18f669ead0..ac4a149f12 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/handlers/HelloResponseHandler.java +++ b/driver/src/main/java/org/neo4j/driver/internal/handlers/HelloResponseHandler.java @@ -23,13 +23,15 @@ import java.util.Map; +import org.neo4j.driver.internal.messaging.BoltProtocolVersion; +import org.neo4j.driver.internal.messaging.v3.BoltProtocolV3; import org.neo4j.driver.internal.spi.ResponseHandler; -import org.neo4j.driver.internal.util.ServerVersion; import org.neo4j.driver.Value; import static org.neo4j.driver.internal.async.connection.ChannelAttributes.setConnectionId; import static org.neo4j.driver.internal.async.connection.ChannelAttributes.setServerVersion; import static org.neo4j.driver.internal.util.MetadataExtractor.extractNeo4jServerVersion; +import static org.neo4j.driver.internal.util.ServerVersion.fromBoltProtocolVersion; public class HelloResponseHandler implements ResponseHandler { @@ -37,11 +39,13 @@ public class HelloResponseHandler implements ResponseHandler private final ChannelPromise connectionInitializedPromise; private final Channel channel; + private final BoltProtocolVersion protocolVersion; - public HelloResponseHandler( ChannelPromise connectionInitializedPromise ) + public HelloResponseHandler( ChannelPromise connectionInitializedPromise, BoltProtocolVersion protocolVersion ) { this.connectionInitializedPromise = connectionInitializedPromise; this.channel = connectionInitializedPromise.channel(); + this.protocolVersion = protocolVersion; } @Override @@ -49,8 +53,16 @@ public void onSuccess( Map metadata ) { try { - ServerVersion serverVersion = extractNeo4jServerVersion( metadata ); - setServerVersion( channel, serverVersion ); + // From Server V4 extracting server from metadata in the success message is unreliable + // so we fix the Server version against the Bolt Protocol version for Server V4 and above. + if ( BoltProtocolV3.VERSION.equals( protocolVersion ) ) + { + setServerVersion( channel, extractNeo4jServerVersion( metadata ) ); + } + else + { + setServerVersion( channel, fromBoltProtocolVersion( protocolVersion ) ); + } String connectionId = extractConnectionId( metadata ); setConnectionId( channel, connectionId ); diff --git a/driver/src/main/java/org/neo4j/driver/internal/messaging/v3/BoltProtocolV3.java b/driver/src/main/java/org/neo4j/driver/internal/messaging/v3/BoltProtocolV3.java index 404339fffc..75e007fc4e 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/messaging/v3/BoltProtocolV3.java +++ b/driver/src/main/java/org/neo4j/driver/internal/messaging/v3/BoltProtocolV3.java @@ -81,7 +81,7 @@ public void initializeChannel( String userAgent, Map authToken, Ch Channel channel = channelInitializedPromise.channel(); HelloMessage message = new HelloMessage( userAgent, authToken ); - HelloResponseHandler handler = new HelloResponseHandler( channelInitializedPromise ); + HelloResponseHandler handler = new HelloResponseHandler( channelInitializedPromise, version() ); messageDispatcher( channel ).enqueue( handler ); channel.writeAndFlush( message, channel.voidPromise() ); diff --git a/driver/src/main/java/org/neo4j/driver/internal/util/ServerVersion.java b/driver/src/main/java/org/neo4j/driver/internal/util/ServerVersion.java index fae9b4d0b5..245949a397 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/util/ServerVersion.java +++ b/driver/src/main/java/org/neo4j/driver/internal/util/ServerVersion.java @@ -24,6 +24,9 @@ import org.neo4j.driver.Driver; import org.neo4j.driver.Session; +import org.neo4j.driver.internal.messaging.BoltProtocolVersion; +import org.neo4j.driver.internal.messaging.v4.BoltProtocolV4; +import org.neo4j.driver.internal.messaging.v41.BoltProtocolV41; import static java.lang.Integer.compare; @@ -174,4 +177,19 @@ private static String stringValue( String product, int major, int minor, int pat } return String.format( "%s/%s.%s.%s", product, major, minor, patch ); } + + public static ServerVersion fromBoltProtocolVersion( BoltProtocolVersion protocolVersion ) + { + + if ( BoltProtocolV4.VERSION.equals( protocolVersion ) ) + { + return ServerVersion.v4_0_0; + } + else if ( BoltProtocolV41.VERSION.equals( protocolVersion ) ) + { + return ServerVersion.v4_1_0; + } + + return ServerVersion.vInDev; + } } diff --git a/driver/src/test/java/org/neo4j/driver/internal/handlers/HelloResponseHandlerTest.java b/driver/src/test/java/org/neo4j/driver/internal/handlers/HelloResponseHandlerTest.java index 7a142e03b5..ef4a373943 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/handlers/HelloResponseHandlerTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/handlers/HelloResponseHandlerTest.java @@ -37,6 +37,10 @@ import org.neo4j.driver.internal.async.inbound.InboundMessageDispatcher; import org.neo4j.driver.internal.async.outbound.OutboundMessageHandler; import org.neo4j.driver.internal.messaging.v1.MessageFormatV1; +import org.neo4j.driver.internal.messaging.v3.BoltProtocolV3; +import org.neo4j.driver.internal.messaging.v4.BoltProtocolV4; +import org.neo4j.driver.internal.messaging.v41.BoltProtocolV41; +import org.neo4j.driver.internal.util.ServerVersion; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertFalse; @@ -73,7 +77,7 @@ void tearDown() void shouldSetServerVersionOnChannel() { ChannelPromise channelPromise = channel.newPromise(); - HelloResponseHandler handler = new HelloResponseHandler( channelPromise ); + HelloResponseHandler handler = new HelloResponseHandler( channelPromise, BoltProtocolV3.VERSION ); Map metadata = metadata( anyServerVersion(), "bolt-1" ); handler.onSuccess( metadata ); @@ -86,7 +90,7 @@ void shouldSetServerVersionOnChannel() void shouldThrowWhenServerVersionNotReturned() { ChannelPromise channelPromise = channel.newPromise(); - HelloResponseHandler handler = new HelloResponseHandler( channelPromise ); + HelloResponseHandler handler = new HelloResponseHandler( channelPromise, BoltProtocolV3.VERSION ); Map metadata = metadata( null, "bolt-1" ); assertThrows( UntrustedServerException.class, () -> handler.onSuccess( metadata ) ); @@ -99,7 +103,7 @@ void shouldThrowWhenServerVersionNotReturned() void shouldThrowWhenServerVersionIsNull() { ChannelPromise channelPromise = channel.newPromise(); - HelloResponseHandler handler = new HelloResponseHandler( channelPromise ); + HelloResponseHandler handler = new HelloResponseHandler( channelPromise, BoltProtocolV3.VERSION ); Map metadata = metadata( Values.NULL, "bolt-x" ); assertThrows( UntrustedServerException.class, () -> handler.onSuccess( metadata ) ); @@ -112,7 +116,7 @@ void shouldThrowWhenServerVersionIsNull() void shouldThrowWhenServerVersionCantBeParsed() { ChannelPromise channelPromise = channel.newPromise(); - HelloResponseHandler handler = new HelloResponseHandler( channelPromise ); + HelloResponseHandler handler = new HelloResponseHandler( channelPromise, BoltProtocolV3.VERSION ); Map metadata = metadata( "WrongServerVersion", "bolt-x" ); assertThrows( IllegalArgumentException.class, () -> handler.onSuccess( metadata ) ); @@ -121,11 +125,39 @@ void shouldThrowWhenServerVersionCantBeParsed() assertTrue( channel.closeFuture().isDone() ); // channel was closed } + @Test + void shouldUseProtocolVersionForServerVersionWhenConnectedWithBoltV4() + { + ChannelPromise channelPromise = channel.newPromise(); + HelloResponseHandler handler = new HelloResponseHandler( channelPromise, BoltProtocolV4.VERSION ); + + // server used in metadata should be ignored + Map metadata = metadata( ServerVersion.vInDev, "bolt-1" ); + handler.onSuccess( metadata ); + + assertTrue( channelPromise.isSuccess() ); + assertEquals( ServerVersion.v4_0_0, serverVersion( channel ) ); + } + + @Test + void shouldUseProtocolVersionForServerVersionWhenConnectedWithBoltV41() + { + ChannelPromise channelPromise = channel.newPromise(); + HelloResponseHandler handler = new HelloResponseHandler( channelPromise, BoltProtocolV41.VERSION ); + + // server used in metadata should be ignored + Map metadata = metadata( ServerVersion.vInDev, "bolt-1" ); + handler.onSuccess( metadata ); + + assertTrue( channelPromise.isSuccess() ); + assertEquals( ServerVersion.v4_1_0, serverVersion( channel ) ); + } + @Test void shouldSetConnectionIdOnChannel() { ChannelPromise channelPromise = channel.newPromise(); - HelloResponseHandler handler = new HelloResponseHandler( channelPromise ); + HelloResponseHandler handler = new HelloResponseHandler( channelPromise, BoltProtocolV3.VERSION ); Map metadata = metadata( anyServerVersion(), "bolt-42" ); handler.onSuccess( metadata ); @@ -138,7 +170,7 @@ void shouldSetConnectionIdOnChannel() void shouldThrowWhenConnectionIdNotReturned() { ChannelPromise channelPromise = channel.newPromise(); - HelloResponseHandler handler = new HelloResponseHandler( channelPromise ); + HelloResponseHandler handler = new HelloResponseHandler( channelPromise, BoltProtocolV3.VERSION ); Map metadata = metadata( anyServerVersion(), null ); assertThrows( IllegalStateException.class, () -> handler.onSuccess( metadata ) ); @@ -151,7 +183,7 @@ void shouldThrowWhenConnectionIdNotReturned() void shouldThrowWhenConnectionIdIsNull() { ChannelPromise channelPromise = channel.newPromise(); - HelloResponseHandler handler = new HelloResponseHandler( channelPromise ); + HelloResponseHandler handler = new HelloResponseHandler( channelPromise, BoltProtocolV3.VERSION ); Map metadata = metadata( anyServerVersion(), Values.NULL ); assertThrows( IllegalStateException.class, () -> handler.onSuccess( metadata ) ); @@ -164,7 +196,7 @@ void shouldThrowWhenConnectionIdIsNull() void shouldCloseChannelOnFailure() throws Exception { ChannelPromise channelPromise = channel.newPromise(); - HelloResponseHandler handler = new HelloResponseHandler( channelPromise ); + HelloResponseHandler handler = new HelloResponseHandler( channelPromise, BoltProtocolV3.VERSION ); RuntimeException error = new RuntimeException( "Hi!" ); handler.onFailure( error ); diff --git a/driver/src/test/java/org/neo4j/driver/internal/util/ServerVersionTest.java b/driver/src/test/java/org/neo4j/driver/internal/util/ServerVersionTest.java index 2b9a12426f..2129023123 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/util/ServerVersionTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/util/ServerVersionTest.java @@ -20,6 +20,11 @@ import org.junit.jupiter.api.Test; +import org.neo4j.driver.internal.messaging.BoltProtocolVersion; +import org.neo4j.driver.internal.messaging.v4.BoltProtocolV4; +import org.neo4j.driver.internal.messaging.v41.BoltProtocolV41; + +import static java.lang.Integer.MAX_VALUE; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.is; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -60,4 +65,12 @@ void shouldFailToCompareDifferentProducts() assertThrows( IllegalArgumentException.class, () -> version1.greaterThanOrEqual( version2 ) ); } + + @Test + void shouldReturnCorrectServerVersionFromBoltProtocolVersion() + { + assertEquals( ServerVersion.v4_0_0, ServerVersion.fromBoltProtocolVersion( BoltProtocolV4.VERSION ) ); + assertEquals( ServerVersion.v4_1_0, ServerVersion.fromBoltProtocolVersion( BoltProtocolV41.VERSION ) ); + assertEquals( ServerVersion.vInDev, ServerVersion.fromBoltProtocolVersion( new BoltProtocolVersion( MAX_VALUE, MAX_VALUE ) ) ); + } }