Skip to content

Use connect timeout in Bolt and TLS handshake #447

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Dec 15, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,14 @@
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelOption;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.ChannelPromise;

import java.util.Map;

import org.neo4j.driver.internal.BoltServerAddress;
import org.neo4j.driver.internal.ConnectionSettings;
import org.neo4j.driver.internal.async.inbound.ConnectTimeoutHandler;
import org.neo4j.driver.internal.security.InternalAuthToken;
import org.neo4j.driver.internal.security.SecurityPlan;
import org.neo4j.driver.internal.util.Clock;
Expand Down Expand Up @@ -71,20 +73,47 @@ public ChannelConnectorImpl( ConnectionSettings connectionSettings, SecurityPlan
public ChannelFuture connect( BoltServerAddress address, Bootstrap bootstrap )
{
bootstrap.option( ChannelOption.CONNECT_TIMEOUT_MILLIS, connectTimeoutMillis );
bootstrap.handler( new NettyChannelInitializer( address, securityPlan, clock, logging ) );
bootstrap.handler( new NettyChannelInitializer( address, securityPlan, connectTimeoutMillis, clock, logging ) );

ChannelFuture channelConnected = bootstrap.connect( address.toSocketAddress() );

Channel channel = channelConnected.channel();
ChannelPromise handshakeCompleted = channel.newPromise();
ChannelPromise connectionInitialized = channel.newPromise();

installChannelConnectedListeners( address, channelConnected, handshakeCompleted );
installHandshakeCompletedListeners( handshakeCompleted, connectionInitialized );

return connectionInitialized;
}

private void installChannelConnectedListeners( BoltServerAddress address, ChannelFuture channelConnected,
ChannelPromise handshakeCompleted )
{
ChannelPipeline pipeline = channelConnected.channel().pipeline();

// add timeout handler to the pipeline when channel is connected. it's needed to limit amount of time code
// spends in TLS and Bolt handshakes. prevents infinite waiting when database does not respond
channelConnected.addListener( future ->
pipeline.addFirst( new ConnectTimeoutHandler( connectTimeoutMillis ) ) );

// add listener that sends Bolt handshake bytes when channel is connected
channelConnected.addListener(
new ChannelConnectedListener( address, pipelineBuilder, handshakeCompleted, logging ) );
handshakeCompleted.addListener(
new HandshakeCompletedListener( userAgent, authToken, connectionInitialized ) );
}

return connectionInitialized;
private void installHandshakeCompletedListeners( ChannelPromise handshakeCompleted,
ChannelPromise connectionInitialized )
{
ChannelPipeline pipeline = handshakeCompleted.channel().pipeline();

// remove timeout handler from the pipeline once TLS and Bolt handshakes are completed. regular protocol
// messages will flow next and we do not want to have read timeout for them
handshakeCompleted.addListener( future -> pipeline.remove( ConnectTimeoutHandler.class ) );

// add listener that sends an INIT message. connection is now fully established. channel pipeline if fully
// set to send/receive messages for a selected protocol version
handshakeCompleted.addListener( new HandshakeCompletedListener( userAgent, authToken, connectionInitialized ) );
}

private static Map<String,Value> tokenAsMap( AuthToken token )
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,16 @@ public class NettyChannelInitializer extends ChannelInitializer<Channel>
{
private final BoltServerAddress address;
private final SecurityPlan securityPlan;
private final int connectTimeoutMillis;
private final Clock clock;
private final Logging logging;

public NettyChannelInitializer( BoltServerAddress address, SecurityPlan securityPlan, Clock clock, Logging logging )
public NettyChannelInitializer( BoltServerAddress address, SecurityPlan securityPlan, int connectTimeoutMillis,
Clock clock, Logging logging )
{
this.address = address;
this.securityPlan = securityPlan;
this.connectTimeoutMillis = connectTimeoutMillis;
this.clock = clock;
this.logging = logging;
}
Expand All @@ -65,7 +68,9 @@ protected void initChannel( Channel channel )
private SslHandler createSslHandler()
{
SSLEngine sslEngine = createSslEngine();
return new SslHandler( sslEngine );
SslHandler sslHandler = new SslHandler( sslEngine );
sslHandler.setHandshakeTimeoutMillis( connectTimeoutMillis );
return sslHandler;
}

private SSLEngine createSslEngine()
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
/*
* Copyright (c) 2002-2017 "Neo Technology,"
* Network Engine for Objects in Lund AB [http://neotechnology.com]
*
* This file is part of Neo4j.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.neo4j.driver.internal.async.inbound;

import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.timeout.ReadTimeoutHandler;

import java.util.concurrent.TimeUnit;

import org.neo4j.driver.v1.exceptions.ServiceUnavailableException;

/**
* Handler needed to limit amount of time connection performs TLS and Bolt handshakes.
* It should only be used when connection is established and removed from the pipeline afterwards.
* Otherwise it will make long running queries fail.
*/
public class ConnectTimeoutHandler extends ReadTimeoutHandler
{
private final long timeoutMillis;
private boolean triggered;

public ConnectTimeoutHandler( long timeoutMillis )
{
super( timeoutMillis, TimeUnit.MILLISECONDS );
this.timeoutMillis = timeoutMillis;
}

@Override
protected void readTimedOut( ChannelHandlerContext ctx )
{
if ( !triggered )
{
triggered = true;
ctx.fireExceptionCaught( unableToConnectError() );
}
}

private ServiceUnavailableException unableToConnectError()
{
return new ServiceUnavailableException( "Unable to establish connection in " + timeoutMillis + "ms" );
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,24 @@
import io.netty.bootstrap.Bootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelPipeline;
import io.netty.handler.ssl.SslHandler;
import org.junit.After;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.RuleChain;
import org.junit.rules.Timeout;

import java.io.IOException;
import java.net.ConnectException;
import java.net.ServerSocket;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;

import org.neo4j.driver.internal.BoltServerAddress;
import org.neo4j.driver.internal.ConnectionSettings;
import org.neo4j.driver.internal.async.inbound.ConnectTimeoutHandler;
import org.neo4j.driver.internal.security.SecurityPlan;
import org.neo4j.driver.internal.util.FakeClock;
import org.neo4j.driver.v1.AuthToken;
Expand All @@ -42,7 +49,9 @@

import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.startsWith;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
Expand All @@ -52,19 +61,20 @@

public class ChannelConnectorImplTest
{
private final TestNeo4j neo4j = new TestNeo4j();
@Rule
public final TestNeo4j neo4j = new TestNeo4j();
public final RuleChain ruleChain = RuleChain.outerRule( Timeout.seconds( 20 ) ).around( neo4j );

private Bootstrap bootstrap;

@Before
public void setUp() throws Exception
public void setUp()
{
bootstrap = BootstrapFactory.newBootstrap( 1 );
}

@After
public void tearDown() throws Exception
public void tearDown()
{
if ( bootstrap != null )
{
Expand All @@ -75,7 +85,7 @@ public void tearDown() throws Exception
@Test
public void shouldConnect() throws Exception
{
ChannelConnectorImpl connector = newConnector( neo4j.authToken() );
ChannelConnector connector = newConnector( neo4j.authToken() );

ChannelFuture channelFuture = connector.connect( neo4j.address(), bootstrap );
assertTrue( channelFuture.await( 10, TimeUnit.SECONDS ) );
Expand All @@ -85,10 +95,26 @@ public void shouldConnect() throws Exception
assertTrue( channel.isActive() );
}

@Test
public void shouldSetupHandlers() throws Exception
{
ChannelConnector connector = newConnector( neo4j.authToken(), SecurityPlan.forAllCertificates(), 10_000 );

ChannelFuture channelFuture = connector.connect( neo4j.address(), bootstrap );
assertTrue( channelFuture.await( 10, TimeUnit.SECONDS ) );

Channel channel = channelFuture.channel();
ChannelPipeline pipeline = channel.pipeline();
assertTrue( channel.isActive() );

assertNotNull( pipeline.get( SslHandler.class ) );
assertNull( pipeline.get( ConnectTimeoutHandler.class ) );
}

@Test
public void shouldFailToConnectToWrongAddress() throws Exception
{
ChannelConnectorImpl connector = newConnector( neo4j.authToken() );
ChannelConnector connector = newConnector( neo4j.authToken() );

ChannelFuture channelFuture = connector.connect( new BoltServerAddress( "wrong-localhost" ), bootstrap );
assertTrue( channelFuture.await( 10, TimeUnit.SECONDS ) );
Expand All @@ -112,7 +138,7 @@ public void shouldFailToConnectToWrongAddress() throws Exception
public void shouldFailToConnectWithWrongCredentials() throws Exception
{
AuthToken authToken = AuthTokens.basic( "neo4j", "wrong-password" );
ChannelConnectorImpl connector = newConnector( authToken );
ChannelConnector connector = newConnector( authToken );

ChannelFuture channelFuture = connector.connect( neo4j.address(), bootstrap );
assertTrue( channelFuture.await( 10, TimeUnit.SECONDS ) );
Expand All @@ -131,10 +157,10 @@ public void shouldFailToConnectWithWrongCredentials() throws Exception
assertFalse( channel.isActive() );
}

@Test( timeout = 10000 )
@Test
public void shouldEnforceConnectTimeout() throws Exception
{
ChannelConnectorImpl connector = newConnector( neo4j.authToken(), 1000 );
ChannelConnector connector = newConnector( neo4j.authToken(), 1000 );

// try connect to a non-routable ip address 10.0.0.0, it will never respond
ChannelFuture channelFuture = connector.connect( new BoltServerAddress( "10.0.0.0" ), bootstrap );
Expand All @@ -151,15 +177,55 @@ public void shouldEnforceConnectTimeout() throws Exception
}
}

@Test
public void shouldFailWhenProtocolNegotiationTakesTooLong() throws Exception
{
// run without TLS so that Bolt handshake is the very first operation after connection is established
testReadTimeoutOnConnect( SecurityPlan.insecure() );
}

@Test
public void shouldFailWhenTLSHandshakeTakesTooLong() throws Exception
{
// run with TLS so that TLS handshake is the very first operation after connection is established
testReadTimeoutOnConnect( SecurityPlan.forAllCertificates() );
}

private void testReadTimeoutOnConnect( SecurityPlan securityPlan ) throws IOException
{
try ( ServerSocket server = new ServerSocket( 0 ) ) // server that accepts connections but does not reply
{
int timeoutMillis = 1_000;
BoltServerAddress address = new BoltServerAddress( "localhost", server.getLocalPort() );
ChannelConnector connector = newConnector( neo4j.authToken(), securityPlan, timeoutMillis );

ChannelFuture channelFuture = connector.connect( address, bootstrap );
try
{
await( channelFuture );
fail( "Exception expected" );
}
catch ( ServiceUnavailableException e )
{
assertEquals( e.getMessage(), "Unable to establish connection in " + timeoutMillis + "ms" );
}
}
}

private ChannelConnectorImpl newConnector( AuthToken authToken ) throws Exception
{
return newConnector( authToken, Integer.MAX_VALUE );
}

private ChannelConnectorImpl newConnector( AuthToken authToken, int connectTimeoutMillis ) throws Exception
{
ConnectionSettings settings = new ConnectionSettings( authToken, 1000 );
return new ChannelConnectorImpl( settings, SecurityPlan.forAllCertificates(), DEV_NULL_LOGGING,
new FakeClock() );
return newConnector( authToken, SecurityPlan.forAllCertificates(), connectTimeoutMillis );
}

private ChannelConnectorImpl newConnector( AuthToken authToken, SecurityPlan securityPlan,
int connectTimeoutMillis )
{
ConnectionSettings settings = new ConnectionSettings( authToken, connectTimeoutMillis );
return new ChannelConnectorImpl( settings, securityPlan, DEV_NULL_LOGGING, new FakeClock() );
}
}
Loading