Skip to content

Commit 2cc2ea7

Browse files
authored
Merge pull request #442 from lutovich/1.4-handshake-with-timeout
Use connect timeout in Bolt and TLS handshake
2 parents 1a8ac12 + 1ebe435 commit 2cc2ea7

File tree

6 files changed

+340
-36
lines changed

6 files changed

+340
-36
lines changed

driver/src/main/java/org/neo4j/driver/internal/net/ChannelFactory.java

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -22,31 +22,26 @@
2222
import java.net.ConnectException;
2323
import java.net.Socket;
2424
import java.net.SocketTimeoutException;
25-
import java.net.StandardSocketOptions;
2625
import java.nio.channels.ByteChannel;
27-
import java.nio.channels.SocketChannel;
2826

2927
import org.neo4j.driver.internal.security.SecurityPlan;
3028
import org.neo4j.driver.internal.security.TLSSocketChannel;
3129
import org.neo4j.driver.v1.Logger;
3230

3331
class ChannelFactory
3432
{
35-
static ByteChannel create( BoltServerAddress address, SecurityPlan securityPlan, int timeoutMillis, Logger log )
36-
throws IOException
33+
static ByteChannel create( Socket socket, BoltServerAddress address, SecurityPlan securityPlan, int timeoutMillis,
34+
Logger log ) throws IOException
3735
{
38-
SocketChannel soChannel = SocketChannel.open();
39-
soChannel.setOption( StandardSocketOptions.SO_REUSEADDR, true );
40-
soChannel.setOption( StandardSocketOptions.SO_KEEPALIVE, true );
41-
connect( soChannel, address, timeoutMillis );
36+
connect( socket, address, timeoutMillis );
4237

43-
ByteChannel channel = soChannel;
38+
ByteChannel channel = new UnencryptedSocketChannel( socket );
4439

4540
if ( securityPlan.requiresEncryption() )
4641
{
4742
try
4843
{
49-
channel = TLSSocketChannel.create( address, securityPlan, soChannel, log );
44+
channel = TLSSocketChannel.create( address, securityPlan, channel, log );
5045
}
5146
catch ( Exception e )
5247
{
@@ -70,10 +65,8 @@ static ByteChannel create( BoltServerAddress address, SecurityPlan securityPlan,
7065
return channel;
7166
}
7267

73-
private static void connect( SocketChannel soChannel, BoltServerAddress address, int timeoutMillis )
74-
throws IOException
68+
private static void connect( Socket socket, BoltServerAddress address, int timeoutMillis ) throws IOException
7569
{
76-
Socket socket = soChannel.socket();
7770
try
7871
{
7972
socket.connect( address.toSocketAddress(), timeoutMillis );

driver/src/main/java/org/neo4j/driver/internal/net/SocketClient.java

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,14 +20,19 @@
2020

2121
import java.io.IOException;
2222
import java.net.ConnectException;
23+
import java.net.Socket;
24+
import java.net.SocketAddress;
25+
import java.net.SocketTimeoutException;
2326
import java.nio.ByteBuffer;
2427
import java.nio.channels.ByteChannel;
2528
import java.util.Queue;
2629

2730
import org.neo4j.driver.internal.messaging.Message;
2831
import org.neo4j.driver.internal.messaging.MessageFormat;
2932
import org.neo4j.driver.internal.security.SecurityPlan;
33+
import org.neo4j.driver.internal.security.TLSSocketChannel;
3034
import org.neo4j.driver.internal.util.BytePrinter;
35+
import org.neo4j.driver.v1.Config;
3136
import org.neo4j.driver.v1.Logger;
3237
import org.neo4j.driver.v1.exceptions.ClientException;
3338
import org.neo4j.driver.v1.exceptions.ServiceUnavailableException;
@@ -118,22 +123,31 @@ void blockingWrite( ByteBuffer buf ) throws IOException
118123

119124
public void start()
120125
{
126+
Socket socket = null;
127+
boolean connected = false;
121128
try
122129
{
123130
logger.debug( "Connecting to %s, secure: %s", address, securityPlan.requiresEncryption() );
124131
if( channel == null )
125132
{
126-
setChannel( ChannelFactory.create( address, securityPlan, timeoutMillis, logger ) );
133+
socket = newSocket( timeoutMillis );
134+
setChannel( ChannelFactory.create( socket, address, securityPlan, timeoutMillis, logger ) );
127135
logger.debug( "Connected to %s, secure: %s", address, securityPlan.requiresEncryption() );
128136
}
129137

130138
logger.debug( "Negotiating protocol with %s", address );
131139
SocketProtocol protocol = negotiateProtocol();
132140
setProtocol( protocol );
133141
logger.debug( "Selected protocol %s with %s", protocol.getClass(), address );
142+
143+
// reset read timeout (SO_TIMEOUT) to the original value of zero
144+
// we do not want to permanently limit amount of time driver waits for database to execute query
145+
socket.setSoTimeout( 0 );
146+
connected = true;
134147
}
135-
catch ( ConnectException e )
148+
catch ( ConnectException | SocketTimeoutException e )
136149
{
150+
// unable to connect socket or TLS/Bolt handshake took too much time
137151
throw new ServiceUnavailableException( format(
138152
"Unable to connect to %s, ensure the database is running and that there is a " +
139153
"working network connection to it.", address ), e );
@@ -142,6 +156,19 @@ public void start()
142156
{
143157
throw new ServiceUnavailableException( "Unable to process request: " + e.getMessage(), e );
144158
}
159+
finally
160+
{
161+
if ( !connected && socket != null )
162+
{
163+
try
164+
{
165+
socket.close();
166+
}
167+
catch ( Throwable ignore )
168+
{
169+
}
170+
}
171+
}
145172
}
146173

147174
public void updateProtocol( String serverVersion )
@@ -301,4 +328,35 @@ public BoltServerAddress address()
301328
{
302329
return address;
303330
}
331+
332+
/**
333+
* Creates new {@link Socket} object with {@link Socket#setSoTimeout(int) read timeout} set to the given value.
334+
* Connection to bolt server includes:
335+
* <ol>
336+
* <li>TCP connect via {@link Socket#connect(SocketAddress, int)}</li>
337+
* <li>Optional TLS handshake using {@link TLSSocketChannel}</li>
338+
* <li>Bolt handshake</li>
339+
* </ol>
340+
* We do not want any of these steps to hang infinitely if server does not respond and thus:
341+
* <ol>
342+
* <li>Use {@link Socket#connect(SocketAddress, int)} with timeout, as configured in
343+
* {@link Config#connectionTimeoutMillis()}</li>
344+
* <li>Initially set {@link Socket#setSoTimeout(int) read timeout} on the socket. Same connection-timeout value
345+
* from {@link Config#connectionTimeoutMillis()} is used. This way blocking reads during TLS and Bolt handshakes
346+
* have limited waiting time</li>
347+
* </ol>
348+
*
349+
* @param configuredConnectTimeout user-defined connection timeout to be initially used as read timeout.
350+
* @return new socket.
351+
* @throws IOException when creation or configuration of the socket fails.
352+
*/
353+
private static Socket newSocket( int configuredConnectTimeout ) throws IOException
354+
{
355+
Socket socket = new Socket();
356+
socket.setReuseAddress( true );
357+
socket.setKeepAlive( true );
358+
// set read timeout initially
359+
socket.setSoTimeout( configuredConnectTimeout );
360+
return socket;
361+
}
304362
}
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
/*
2+
* Copyright (c) 2002-2017 "Neo Technology,"
3+
* Network Engine for Objects in Lund AB [http://neotechnology.com]
4+
*
5+
* This file is part of Neo4j.
6+
*
7+
* Licensed under the Apache License, Version 2.0 (the "License");
8+
* you may not use this file except in compliance with the License.
9+
* You may obtain a copy of the License at
10+
*
11+
* http://www.apache.org/licenses/LICENSE-2.0
12+
*
13+
* Unless required by applicable law or agreed to in writing, software
14+
* distributed under the License is distributed on an "AS IS" BASIS,
15+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
* See the License for the specific language governing permissions and
17+
* limitations under the License.
18+
*/
19+
package org.neo4j.driver.internal.net;
20+
21+
import java.io.IOException;
22+
import java.net.Socket;
23+
import java.nio.ByteBuffer;
24+
import java.nio.channels.ByteChannel;
25+
import java.nio.channels.Channels;
26+
import java.nio.channels.ReadableByteChannel;
27+
import java.nio.channels.WritableByteChannel;
28+
29+
public class UnencryptedSocketChannel implements ByteChannel
30+
{
31+
private final Socket socket;
32+
private final ReadableByteChannel readableChannel;
33+
private final WritableByteChannel writableChannel;
34+
35+
public UnencryptedSocketChannel( Socket socket ) throws IOException
36+
{
37+
this.socket = socket;
38+
this.readableChannel = Channels.newChannel( socket.getInputStream() );
39+
this.writableChannel = Channels.newChannel( socket.getOutputStream() );
40+
}
41+
42+
@Override
43+
public int read( ByteBuffer dst ) throws IOException
44+
{
45+
return readableChannel.read( dst );
46+
}
47+
48+
@Override
49+
public int write( ByteBuffer src ) throws IOException
50+
{
51+
return writableChannel.write( src );
52+
}
53+
54+
@Override
55+
public boolean isOpen()
56+
{
57+
return !socket.isClosed();
58+
}
59+
60+
@Override
61+
public void close() throws IOException
62+
{
63+
socket.close();
64+
}
65+
}

driver/src/test/java/org/neo4j/driver/internal/net/SocketClientTest.java

Lines changed: 31 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -18,24 +18,25 @@
1818
*/
1919
package org.neo4j.driver.internal.net;
2020

21-
import org.junit.Ignore;
2221
import org.junit.Rule;
2322
import org.junit.Test;
2423
import org.junit.rules.ExpectedException;
2524

2625
import java.io.IOException;
2726
import java.net.ServerSocket;
27+
import java.net.SocketTimeoutException;
2828
import java.nio.ByteBuffer;
2929
import java.nio.channels.ByteChannel;
3030
import java.util.ArrayList;
3131
import java.util.List;
3232

3333
import org.neo4j.driver.internal.security.SecurityPlan;
34-
import org.neo4j.driver.v1.exceptions.ClientException;
3534
import org.neo4j.driver.v1.exceptions.ServiceUnavailableException;
3635

3736
import static org.hamcrest.CoreMatchers.equalTo;
38-
import static org.hamcrest.MatcherAssert.assertThat;
37+
import static org.hamcrest.Matchers.instanceOf;
38+
import static org.junit.Assert.assertThat;
39+
import static org.junit.Assert.fail;
3940
import static org.mockito.Matchers.any;
4041
import static org.mockito.Mockito.mock;
4142
import static org.mockito.Mockito.when;
@@ -49,24 +50,16 @@ public class SocketClientTest
4950
@Rule
5051
public ExpectedException exception = ExpectedException.none();
5152

52-
// TODO: This is not possible with blocking NIO channels, unless we use inputStreams, but then we can't use
53-
// off-heap buffers. We need to swap to use selectors, which would allow us to time out.
5453
@Test
55-
@Ignore
56-
public void testNetworkTimeout() throws Throwable
54+
public void shouldFailWhenProtocolNegotiationTakesTooLong() throws Exception
5755
{
58-
// Given a server that will never reply
59-
ServerSocket server = new ServerSocket( 0 );
60-
BoltServerAddress address = new BoltServerAddress( "localhost", server.getLocalPort() );
61-
62-
SocketClient client = dummyClient( address );
63-
64-
// Expect
65-
exception.expect( ClientException.class );
66-
exception.expectMessage( "database took longer than network timeout (100ms) to reply." );
56+
testReadTimeoutOnConnect( SecurityPlan.insecure() );
57+
}
6758

68-
// When
69-
client.start();
59+
@Test
60+
public void shouldFailWhenTLSHandshakeTakesTooLong() throws Exception
61+
{
62+
testReadTimeoutOnConnect( SecurityPlan.forAllCertificates() );
7063
}
7164

7265
@Test
@@ -190,6 +183,26 @@ public void shouldFailIfConnectionFailsWhileWriting() throws IOException
190183
client.blockingWrite( buffer );
191184
}
192185

186+
private static void testReadTimeoutOnConnect( SecurityPlan securityPlan ) throws IOException
187+
{
188+
try ( ServerSocket server = new ServerSocket( 0 ) ) // server that does not reply
189+
{
190+
int timeoutMillis = 1_000;
191+
BoltServerAddress address = new BoltServerAddress( "localhost", server.getLocalPort() );
192+
SocketClient client = new SocketClient( address, securityPlan, timeoutMillis, DEV_NULL_LOGGER );
193+
194+
try
195+
{
196+
client.start();
197+
fail( "Exception expected" );
198+
}
199+
catch ( ServiceUnavailableException e )
200+
{
201+
assertThat( e.getCause(), instanceOf( SocketTimeoutException.class ) );
202+
}
203+
}
204+
}
205+
193206
private static class ByteAtATimeChannel implements ByteChannel
194207
{
195208

0 commit comments

Comments
 (0)