Skip to content

Use Protocol Version instead of server_agent for Bolt V4 connections #718

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,34 +23,46 @@

import java.util.Map;

import org.neo4j.driver.internal.messaging.BoltProtocolVersion;
import org.neo4j.driver.internal.messaging.v3.BoltProtocolV3;
import org.neo4j.driver.internal.spi.ResponseHandler;
import org.neo4j.driver.internal.util.ServerVersion;
import org.neo4j.driver.Value;

import static org.neo4j.driver.internal.async.connection.ChannelAttributes.setConnectionId;
import static org.neo4j.driver.internal.async.connection.ChannelAttributes.setServerVersion;
import static org.neo4j.driver.internal.util.MetadataExtractor.extractNeo4jServerVersion;
import static org.neo4j.driver.internal.util.ServerVersion.fromBoltProtocolVersion;

public class HelloResponseHandler implements ResponseHandler
{
private static final String CONNECTION_ID_METADATA_KEY = "connection_id";

private final ChannelPromise connectionInitializedPromise;
private final Channel channel;
private final BoltProtocolVersion protocolVersion;

public HelloResponseHandler( ChannelPromise connectionInitializedPromise )
public HelloResponseHandler( ChannelPromise connectionInitializedPromise, BoltProtocolVersion protocolVersion )
{
this.connectionInitializedPromise = connectionInitializedPromise;
this.channel = connectionInitializedPromise.channel();
this.protocolVersion = protocolVersion;
}

@Override
public void onSuccess( Map<String,Value> metadata )
{
try
{
ServerVersion serverVersion = extractNeo4jServerVersion( metadata );
setServerVersion( channel, serverVersion );
// From Server V4 extracting server from metadata in the success message is unreliable
// so we fix the Server version against the Bolt Protocol version for Server V4 and above.
if ( BoltProtocolV3.VERSION.equals( protocolVersion ) )
{
setServerVersion( channel, extractNeo4jServerVersion( metadata ) );
}
else
{
setServerVersion( channel, fromBoltProtocolVersion( protocolVersion ) );
}

String connectionId = extractConnectionId( metadata );
setConnectionId( channel, connectionId );
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ public void initializeChannel( String userAgent, Map<String,Value> authToken, Ch
Channel channel = channelInitializedPromise.channel();

HelloMessage message = new HelloMessage( userAgent, authToken );
HelloResponseHandler handler = new HelloResponseHandler( channelInitializedPromise );
HelloResponseHandler handler = new HelloResponseHandler( channelInitializedPromise, version() );

messageDispatcher( channel ).enqueue( handler );
channel.writeAndFlush( message, channel.voidPromise() );
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,9 @@

import org.neo4j.driver.Driver;
import org.neo4j.driver.Session;
import org.neo4j.driver.internal.messaging.BoltProtocolVersion;
import org.neo4j.driver.internal.messaging.v4.BoltProtocolV4;
import org.neo4j.driver.internal.messaging.v41.BoltProtocolV41;

import static java.lang.Integer.compare;

Expand Down Expand Up @@ -174,4 +177,19 @@ private static String stringValue( String product, int major, int minor, int pat
}
return String.format( "%s/%s.%s.%s", product, major, minor, patch );
}

public static ServerVersion fromBoltProtocolVersion( BoltProtocolVersion protocolVersion )
{

if ( BoltProtocolV4.VERSION.equals( protocolVersion ) )
{
return ServerVersion.v4_0_0;
}
else if ( BoltProtocolV41.VERSION.equals( protocolVersion ) )
{
return ServerVersion.v4_1_0;
}

return ServerVersion.vInDev;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,10 @@
import org.neo4j.driver.internal.async.inbound.InboundMessageDispatcher;
import org.neo4j.driver.internal.async.outbound.OutboundMessageHandler;
import org.neo4j.driver.internal.messaging.v1.MessageFormatV1;
import org.neo4j.driver.internal.messaging.v3.BoltProtocolV3;
import org.neo4j.driver.internal.messaging.v4.BoltProtocolV4;
import org.neo4j.driver.internal.messaging.v41.BoltProtocolV41;
import org.neo4j.driver.internal.util.ServerVersion;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
Expand Down Expand Up @@ -73,7 +77,7 @@ void tearDown()
void shouldSetServerVersionOnChannel()
{
ChannelPromise channelPromise = channel.newPromise();
HelloResponseHandler handler = new HelloResponseHandler( channelPromise );
HelloResponseHandler handler = new HelloResponseHandler( channelPromise, BoltProtocolV3.VERSION );

Map<String,Value> metadata = metadata( anyServerVersion(), "bolt-1" );
handler.onSuccess( metadata );
Expand All @@ -86,7 +90,7 @@ void shouldSetServerVersionOnChannel()
void shouldThrowWhenServerVersionNotReturned()
{
ChannelPromise channelPromise = channel.newPromise();
HelloResponseHandler handler = new HelloResponseHandler( channelPromise );
HelloResponseHandler handler = new HelloResponseHandler( channelPromise, BoltProtocolV3.VERSION );

Map<String,Value> metadata = metadata( null, "bolt-1" );
assertThrows( UntrustedServerException.class, () -> handler.onSuccess( metadata ) );
Expand All @@ -99,7 +103,7 @@ void shouldThrowWhenServerVersionNotReturned()
void shouldThrowWhenServerVersionIsNull()
{
ChannelPromise channelPromise = channel.newPromise();
HelloResponseHandler handler = new HelloResponseHandler( channelPromise );
HelloResponseHandler handler = new HelloResponseHandler( channelPromise, BoltProtocolV3.VERSION );

Map<String,Value> metadata = metadata( Values.NULL, "bolt-x" );
assertThrows( UntrustedServerException.class, () -> handler.onSuccess( metadata ) );
Expand All @@ -112,7 +116,7 @@ void shouldThrowWhenServerVersionIsNull()
void shouldThrowWhenServerVersionCantBeParsed()
{
ChannelPromise channelPromise = channel.newPromise();
HelloResponseHandler handler = new HelloResponseHandler( channelPromise );
HelloResponseHandler handler = new HelloResponseHandler( channelPromise, BoltProtocolV3.VERSION );

Map<String,Value> metadata = metadata( "WrongServerVersion", "bolt-x" );
assertThrows( IllegalArgumentException.class, () -> handler.onSuccess( metadata ) );
Expand All @@ -121,11 +125,39 @@ void shouldThrowWhenServerVersionCantBeParsed()
assertTrue( channel.closeFuture().isDone() ); // channel was closed
}

@Test
void shouldUseProtocolVersionForServerVersionWhenConnectedWithBoltV4()
{
ChannelPromise channelPromise = channel.newPromise();
HelloResponseHandler handler = new HelloResponseHandler( channelPromise, BoltProtocolV4.VERSION );

// server used in metadata should be ignored
Map<String,Value> metadata = metadata( ServerVersion.vInDev, "bolt-1" );
handler.onSuccess( metadata );

assertTrue( channelPromise.isSuccess() );
assertEquals( ServerVersion.v4_0_0, serverVersion( channel ) );
}

@Test
void shouldUseProtocolVersionForServerVersionWhenConnectedWithBoltV41()
{
ChannelPromise channelPromise = channel.newPromise();
HelloResponseHandler handler = new HelloResponseHandler( channelPromise, BoltProtocolV41.VERSION );

// server used in metadata should be ignored
Map<String,Value> metadata = metadata( ServerVersion.vInDev, "bolt-1" );
handler.onSuccess( metadata );

assertTrue( channelPromise.isSuccess() );
assertEquals( ServerVersion.v4_1_0, serverVersion( channel ) );
}

@Test
void shouldSetConnectionIdOnChannel()
{
ChannelPromise channelPromise = channel.newPromise();
HelloResponseHandler handler = new HelloResponseHandler( channelPromise );
HelloResponseHandler handler = new HelloResponseHandler( channelPromise, BoltProtocolV3.VERSION );

Map<String,Value> metadata = metadata( anyServerVersion(), "bolt-42" );
handler.onSuccess( metadata );
Expand All @@ -138,7 +170,7 @@ void shouldSetConnectionIdOnChannel()
void shouldThrowWhenConnectionIdNotReturned()
{
ChannelPromise channelPromise = channel.newPromise();
HelloResponseHandler handler = new HelloResponseHandler( channelPromise );
HelloResponseHandler handler = new HelloResponseHandler( channelPromise, BoltProtocolV3.VERSION );

Map<String,Value> metadata = metadata( anyServerVersion(), null );
assertThrows( IllegalStateException.class, () -> handler.onSuccess( metadata ) );
Expand All @@ -151,7 +183,7 @@ void shouldThrowWhenConnectionIdNotReturned()
void shouldThrowWhenConnectionIdIsNull()
{
ChannelPromise channelPromise = channel.newPromise();
HelloResponseHandler handler = new HelloResponseHandler( channelPromise );
HelloResponseHandler handler = new HelloResponseHandler( channelPromise, BoltProtocolV3.VERSION );

Map<String,Value> metadata = metadata( anyServerVersion(), Values.NULL );
assertThrows( IllegalStateException.class, () -> handler.onSuccess( metadata ) );
Expand All @@ -164,7 +196,7 @@ void shouldThrowWhenConnectionIdIsNull()
void shouldCloseChannelOnFailure() throws Exception
{
ChannelPromise channelPromise = channel.newPromise();
HelloResponseHandler handler = new HelloResponseHandler( channelPromise );
HelloResponseHandler handler = new HelloResponseHandler( channelPromise, BoltProtocolV3.VERSION );

RuntimeException error = new RuntimeException( "Hi!" );
handler.onFailure( error );
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@

import org.junit.jupiter.api.Test;

import org.neo4j.driver.internal.messaging.BoltProtocolVersion;
import org.neo4j.driver.internal.messaging.v4.BoltProtocolV4;
import org.neo4j.driver.internal.messaging.v41.BoltProtocolV41;

import static java.lang.Integer.MAX_VALUE;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.is;
import static org.junit.jupiter.api.Assertions.assertEquals;
Expand Down Expand Up @@ -60,4 +65,12 @@ void shouldFailToCompareDifferentProducts()

assertThrows( IllegalArgumentException.class, () -> version1.greaterThanOrEqual( version2 ) );
}

@Test
void shouldReturnCorrectServerVersionFromBoltProtocolVersion()
{
assertEquals( ServerVersion.v4_0_0, ServerVersion.fromBoltProtocolVersion( BoltProtocolV4.VERSION ) );
assertEquals( ServerVersion.v4_1_0, ServerVersion.fromBoltProtocolVersion( BoltProtocolV41.VERSION ) );
assertEquals( ServerVersion.vInDev, ServerVersion.fromBoltProtocolVersion( new BoltProtocolVersion( MAX_VALUE, MAX_VALUE ) ) );
}
}