Skip to content

Commit 95be53a

Browse files
author
Zhen
committed
Fix after review
1 parent 7f4d202 commit 95be53a

File tree

4 files changed

+42
-24
lines changed

4 files changed

+42
-24
lines changed

driver/src/main/java/org/neo4j/driver/internal/net/ChannelFactory.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ private static void connect( SocketChannel soChannel, BoltServerAddress address,
6565
}
6666
catch ( SocketTimeoutException e )
6767
{
68-
throw new ConnectException( "Timeout " + timeoutMillis + "ms expired" + e.getMessage() );
68+
throw new ConnectException( "Timeout " + timeoutMillis + "ms expired" + e );
6969
}
7070
}
7171
}

driver/src/main/java/org/neo4j/driver/internal/security/TLSSocketChannel.java

Lines changed: 20 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
import org.neo4j.driver.v1.exceptions.ClientException;
3434
import org.neo4j.driver.v1.exceptions.ServiceUnavailableException;
3535

36+
import static java.lang.String.format;
3637
import static javax.net.ssl.SSLEngineResult.HandshakeStatus.FINISHED;
3738
import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING;
3839

@@ -80,8 +81,7 @@ public static TLSSocketChannel create( ByteChannel channel, Logger logger, SSLEn
8081
}
8182
catch ( SSLHandshakeException e )
8283
{
83-
throw new ClientException( "Failed to establish secured connection with the server: " + e.getMessage(),
84-
e.getCause() );
84+
throw new ClientException( "Failed to establish secured connection with the server: " + e.getMessage(), e );
8585
}
8686
return tlsChannel;
8787
}
@@ -153,12 +153,10 @@ private HandshakeStatus runDelegatedTasks()
153153
* @param toBuffer the destination where the data read from the socket channel are saved
154154
* @throws IOException when failed to read from channel
155155
*/
156-
void channelRead( ByteBuffer toBuffer ) throws IOException
156+
int channelRead( ByteBuffer toBuffer ) throws IOException
157157
{
158-
/**
159-
* This is the only place to read from the underlying channel
160-
*/
161-
if ( channel.read( toBuffer ) < 0 )
158+
int read = channel.read( toBuffer );
159+
if ( read < 0 )
162160
{
163161
try
164162
{
@@ -172,16 +170,18 @@ void channelRead( ByteBuffer toBuffer ) throws IOException
172170
"SSL Connection terminated while receiving data. " +
173171
"This can happen due to network instabilities, or due to restarts of the database." );
174172
}
173+
return read;
175174
}
176175

177176
/**
178177
* Write the data saved in the buffer to the socket channel
179178
* @param fromBuffer the source where the data written to the socket channel are saved
180179
* @throws IOException when failed to write to channel
181180
*/
182-
void channelWrite( ByteBuffer fromBuffer ) throws IOException
181+
int channelWrite( ByteBuffer fromBuffer ) throws IOException
183182
{
184-
if ( channel.write( fromBuffer ) < 0 )
183+
int written = channel.write( fromBuffer );
184+
if ( written < 0 )
185185
{
186186
try
187187
{
@@ -195,6 +195,7 @@ void channelWrite( ByteBuffer fromBuffer ) throws IOException
195195
"SSL Connection terminated while writing data. " +
196196
"This can happen due to network instabilities, or due to restarts of the database." );
197197
}
198+
return written;
198199
}
199200

200201
/**
@@ -259,7 +260,7 @@ private HandshakeStatus unwrap( ByteBuffer buffer ) throws IOException
259260
if ( newAppSize > appSize * 2 )
260261
{
261262
throw new ClientException(
262-
String.format( "Failed ro enlarge application input buffer from %s to %s, as the maximum " +
263+
format( "Failed ro enlarge application input buffer from %s to %s, as the maximum " +
263264
"buffer size allowed is %s. The content in the buffer is: %s\n",
264265
curAppSize, newAppSize, appSize * 2, BytePrinter.hex( plainIn ) ) );
265266
}
@@ -311,7 +312,7 @@ private HandshakeStatus unwrap( ByteBuffer buffer ) throws IOException
311312
* @return The status of the current handshake
312313
* @throws IOException
313314
*/
314-
private HandshakeStatus wrap( ByteBuffer buffer ) throws IOException
315+
private HandshakeStatus wrap( ByteBuffer buffer ) throws IOException, ClientException
315316
{
316317
HandshakeStatus handshakeStatus = sslEngine.getHandshakeStatus();
317318
Status status = sslEngine.wrap( buffer, cipherOut ).getStatus();
@@ -345,7 +346,14 @@ private HandshakeStatus wrap( ByteBuffer buffer ) throws IOException
345346
{
346347
// flush as much data as possible
347348
cipherOut.flip();
348-
channelWrite( cipherOut );
349+
if ( channelWrite( cipherOut ) == 0 )
350+
{
351+
throw new ClientException( format(
352+
"Failed to enlarge network buffer from %s to %s. This is either because the " +
353+
"new size is however less than the old size, or because the application " +
354+
"buffer size %s is so big that the application data still cannot fit into the " +
355+
"new network buffer.", curNetSize, netSize, buffer.capacity() ) );
356+
}
349357
cipherOut.compact();
350358
logger.debug( "Network output buffer couldn't be enlarged, flushing data to the channel instead." );
351359
}

driver/src/test/java/org/neo4j/driver/internal/security/TLSSocketChannelTest.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@
3636
import static org.mockito.Matchers.any;
3737
import static org.mockito.Mockito.doThrow;
3838
import static org.mockito.Mockito.mock;
39-
import static org.mockito.Mockito.times;
39+
import static org.mockito.Mockito.never;
4040
import static org.mockito.Mockito.verify;
4141
import static org.mockito.Mockito.when;
4242
import static org.neo4j.driver.internal.logging.DevNullLogger.DEV_NULL_LOGGER;
@@ -72,7 +72,7 @@ public void shouldCloseConnectionIfFailedToRead() throws Throwable
7272
assertThat( e.getMessage(), startsWith( "SSL Connection terminated while receiving data. " ) );
7373
}
7474
// Then
75-
verify( mockedChannel, times( 1 ) ).close();
75+
verify( mockedChannel ).close();
7676
}
7777

7878
@Test
@@ -103,7 +103,7 @@ public void shouldCloseConnectionIfFailedToWrite() throws Throwable
103103
}
104104

105105
// Then
106-
verify( mockedChannel, times( 1 ) ).close();
106+
verify( mockedChannel ).close();
107107
}
108108

109109
@Test
@@ -131,6 +131,6 @@ public void shouldThrowClientErrorIfFailedToHandshake() throws Throwable
131131
assertThat( e, instanceOf( ClientException.class ) );
132132
assertThat( e.getMessage(), startsWith( "Failed to establish secured connection with the server: Failed handshake!" ) );
133133
}
134-
verify( mockedChannel, times( 0 ) ).close();
134+
verify( mockedChannel, never() ).close();
135135
}
136136
}

driver/src/test/java/org/neo4j/driver/v1/integration/TLSSocketChannelIT.java

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,6 @@
5353
import static org.hamcrest.Matchers.containsString;
5454
import static org.junit.Assert.assertEquals;
5555
import static org.junit.Assert.assertFalse;
56-
import static org.junit.Assert.assertTrue;
5756
import static org.mockito.Matchers.anyString;
5857
import static org.mockito.Mockito.atLeastOnce;
5958
import static org.mockito.Mockito.mock;
@@ -160,8 +159,7 @@ public void shouldNotPerformTLSHandshakeWithNonSystemCert() throws Throwable
160159
catch ( ClientException e )
161160
{
162161
assertThat( e.getMessage(), containsString( "General SSLEngine problem" ) );
163-
assertEquals( "General SSLEngine problem", e.getCause().getMessage() );
164-
assertThat( e.getCause().getCause().getMessage(),
162+
assertThat( getRootCause( e ).getMessage(),
165163
containsString( "unable to find valid certification path to requested target" ) );
166164
}
167165
}
@@ -196,8 +194,7 @@ public void shouldFailTLSHandshakeDueToWrongCertInKnownCertsFile() throws Throwa
196194
catch ( SSLHandshakeException e )
197195
{
198196
assertEquals( "General SSLEngine problem", e.getMessage() );
199-
assertEquals( "General SSLEngine problem", e.getCause().getMessage() );
200-
assertTrue( e.getCause().getCause().getMessage().contains(
197+
assertThat( getRootCause( e ).getMessage(), containsString(
201198
"If you trust the certificate the server uses now, simply remove the line that starts with" ) );
202199
}
203200
finally
@@ -249,8 +246,7 @@ public void shouldFailTLSHandshakeDueToServerCertNotSignedByKnownCA() throws Thr
249246
catch ( ClientException e )
250247
{
251248
assertThat( e.getMessage(), containsString( "General SSLEngine problem" ) );
252-
assertEquals( "General SSLEngine problem", e.getCause().getMessage() );
253-
assertEquals( "No trusted certificate found", e.getCause().getCause().getMessage() );
249+
assertThat( getRootCause( e ).getMessage(), containsString( "No trusted certificate found" ) );
254250
}
255251
finally
256252
{
@@ -261,6 +257,20 @@ public void shouldFailTLSHandshakeDueToServerCertNotSignedByKnownCA() throws Thr
261257
}
262258
}
263259

260+
private Throwable getRootCause( Throwable e )
261+
{
262+
Throwable parentError = e;
263+
Throwable error = null;
264+
do
265+
{
266+
error = parentError;
267+
parentError = error.getCause();
268+
269+
}
270+
while( parentError != null );
271+
return error;
272+
}
273+
264274
@Test
265275
public void shouldPerformTLSHandshakeWithTheSameTrustedServerCert() throws Throwable
266276
{

0 commit comments

Comments
 (0)