Skip to content

Commit 12dfe1a

Browse files
committed
Merge pull request #165 from jakewins/tls
Resolve deadlock issue in TLSSocketChannel
2 parents dbea6ff + 54610a4 commit 12dfe1a

File tree

7 files changed

+318
-463
lines changed

7 files changed

+318
-463
lines changed

driver/src/main/java/org/neo4j/driver/internal/connector/socket/SocketClient.java

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -215,6 +215,7 @@ public static ByteChannel create( String host, int port, Config config, Logger l
215215
SocketChannel soChannel = SocketChannel.open();
216216
soChannel.setOption( StandardSocketOptions.SO_REUSEADDR, true );
217217
soChannel.setOption( StandardSocketOptions.SO_KEEPALIVE, true );
218+
218219
soChannel.connect( new InetSocketAddress( host, port ) );
219220

220221
ByteChannel channel;

driver/src/main/java/org/neo4j/driver/internal/connector/socket/TLSSocketChannel.java

Lines changed: 68 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,12 @@
2121
import java.io.IOException;
2222
import java.nio.ByteBuffer;
2323
import java.nio.channels.ByteChannel;
24-
import java.nio.channels.SocketChannel;
2524
import java.security.GeneralSecurityException;
2625
import javax.net.ssl.SSLContext;
2726
import javax.net.ssl.SSLEngine;
2827
import javax.net.ssl.SSLEngineResult;
2928
import javax.net.ssl.SSLEngineResult.HandshakeStatus;
3029
import javax.net.ssl.SSLEngineResult.Status;
31-
import javax.net.ssl.SSLSession;
3230

3331
import org.neo4j.driver.internal.spi.Logger;
3432
import org.neo4j.driver.internal.util.BytePrinter;
@@ -51,10 +49,9 @@
5149
*/
5250
public class TLSSocketChannel implements ByteChannel
5351
{
54-
private final SocketChannel channel; // The real channel the data is sent to and read from
52+
private final ByteChannel channel; // The real channel the data is sent to and read from
5553
private final Logger logger;
5654

57-
private final SSLContext sslContext;
5855
private SSLEngine sslEngine;
5956

6057
/** The buffer for network data */
@@ -66,34 +63,36 @@ public class TLSSocketChannel implements ByteChannel
6663

6764
private static final ByteBuffer DUMMY_BUFFER = ByteBuffer.allocateDirect( 0 );
6865

69-
public TLSSocketChannel( String host, int port, SocketChannel channel, Logger logger,
66+
public TLSSocketChannel( String host, int port, ByteChannel channel, Logger logger,
7067
TrustStrategy trustStrategy )
7168
throws GeneralSecurityException, IOException
7269
{
73-
logger.debug( "TLS connection enabled" );
74-
this.logger = logger;
75-
this.channel = channel;
76-
this.channel.configureBlocking( true );
70+
this(channel, logger,
71+
createSSLEngine( host, port, new SSLContextFactory( host, port, trustStrategy, logger ).create() ) );
7772

78-
sslContext = new SSLContextFactory( host, port, trustStrategy, logger ).create();
79-
createSSLEngine( host, port );
80-
createBuffers();
81-
runHandshake();
82-
logger.debug( "TLS connection established" );
8373
}
8474

85-
/** Used in internal tests only */
86-
TLSSocketChannel( SocketChannel channel, Logger logger, SSLEngine sslEngine,
75+
public TLSSocketChannel( ByteChannel channel, Logger logger, SSLEngine sslEngine ) throws GeneralSecurityException, IOException
76+
{
77+
this(channel, logger, sslEngine,
78+
ByteBuffer.allocateDirect( sslEngine.getSession().getApplicationBufferSize() ),
79+
ByteBuffer.allocateDirect( sslEngine.getSession().getPacketBufferSize() ),
80+
ByteBuffer.allocateDirect( sslEngine.getSession().getApplicationBufferSize() ),
81+
ByteBuffer.allocateDirect( sslEngine.getSession().getPacketBufferSize() ) );
82+
}
83+
84+
TLSSocketChannel( ByteChannel channel, Logger logger, SSLEngine sslEngine,
8785
ByteBuffer plainIn, ByteBuffer cipherIn, ByteBuffer plainOut, ByteBuffer cipherOut )
8886
throws GeneralSecurityException, IOException
8987
{
90-
logger.debug( "Testing TLS buffers" );
9188
this.logger = logger;
9289
this.channel = channel;
93-
94-
this.sslContext = SSLContext.getInstance( "TLS" );
9590
this.sslEngine = sslEngine;
96-
resetBuffers( plainIn, cipherIn, plainOut, cipherOut ); // reset buffer size
91+
this.plainIn = plainIn;
92+
this.cipherIn = cipherIn;
93+
this.plainOut = plainOut;
94+
this.cipherOut = cipherOut;
95+
runHandshake();
9796
}
9897

9998
/**
@@ -126,17 +125,13 @@ private void runHandshake() throws IOException
126125
case NEED_UNWRAP:
127126
// Unwrap the ssl packet to value ssl handshake information
128127
handshakeStatus = unwrap( DUMMY_BUFFER );
129-
plainIn.clear();
130128
break;
131129
case NEED_WRAP:
132130
// Wrap the app packet into an ssl packet to add ssl handshake information
133131
handshakeStatus = wrap( plainOut );
134132
break;
135133
}
136134
}
137-
138-
plainIn.clear();
139-
plainOut.clear();
140135
}
141136

142137
private HandshakeStatus runDelegatedTasks()
@@ -185,10 +180,11 @@ private HandshakeStatus unwrap( ByteBuffer buffer ) throws IOException
185180
}
186181
cipherIn.flip();
187182

188-
Status status = null;
183+
Status status;
189184
do
190185
{
191-
status = sslEngine.unwrap( cipherIn, plainIn ).getStatus();
186+
SSLEngineResult unwrapResult = sslEngine.unwrap( cipherIn, plainIn );
187+
status = unwrapResult.getStatus();
192188
// Possible status here:
193189
// OK - good
194190
// BUFFER_OVERFLOW - we need to enlarge* plainIn
@@ -244,17 +240,13 @@ private HandshakeStatus unwrap( ByteBuffer buffer ) throws IOException
244240
// Otherwise, make room for reading more data from channel
245241
cipherIn.compact();
246242
}
247-
248-
// I skipped the following check as it "should not" happen at all:
249-
// The channel should not provide us ciphered bytes that cannot hold in the channel buffer at all
250-
// if( cipherIn.remaining() == 0 )
251-
// {throw new ClientException( "cannot enlarge as it already reached the limit" );}
252-
253-
// Obtain more inbound network data for cipherIn,
254-
// then retry the operation.
255243
return handshakeStatus; // old status
244+
case CLOSED:
245+
// RFC 2246 #7.2.1 requires us to stop accepting input.
246+
sslEngine.closeInbound();
247+
break;
256248
default:
257-
throw new ClientException( "Got unexpected status " + status );
249+
throw new ClientException( "Got unexpected status " + status + ", " + unwrapResult );
258250
}
259251
}
260252
while ( cipherIn.hasRemaining() ); /* Remember we are doing blocking reading.
@@ -285,7 +277,10 @@ private HandshakeStatus wrap( ByteBuffer buffer ) throws IOException
285277
case OK:
286278
handshakeStatus = runDelegatedTasks();
287279
cipherOut.flip();
288-
channel.write( cipherOut );
280+
while(cipherOut.hasRemaining())
281+
{
282+
channel.write( cipherOut );
283+
}
289284
cipherOut.clear();
290285
break;
291286
case BUFFER_OVERFLOW:
@@ -344,42 +339,17 @@ static int bufferCopy( ByteBuffer from, ByteBuffer to )
344339
return maxTransfer;
345340
}
346341

347-
/**
348-
* Create network buffers and application buffers
349-
*
350-
* @throws IOException
351-
*/
352-
private void createBuffers() throws IOException
353-
{
354-
SSLSession session = sslEngine.getSession();
355-
int appBufferSize = session.getApplicationBufferSize();
356-
int netBufferSize = session.getPacketBufferSize();
357-
358-
plainOut = ByteBuffer.allocateDirect( appBufferSize );
359-
plainIn = ByteBuffer.allocateDirect( appBufferSize );
360-
cipherOut = ByteBuffer.allocateDirect( netBufferSize );
361-
cipherIn = ByteBuffer.allocateDirect( netBufferSize );
362-
}
363-
364-
/** Should only be used in tests */
365-
void resetBuffers( ByteBuffer plainIn, ByteBuffer cipherIn, ByteBuffer plainOut, ByteBuffer cipherOut )
366-
{
367-
this.plainIn = plainIn;
368-
this.cipherIn = cipherIn;
369-
this.plainOut = plainOut;
370-
this.cipherOut = cipherOut;
371-
}
372-
373342
/**
374343
* Create SSLEngine with the SSLContext just created.
375-
*
376344
* @param host
377345
* @param port
346+
* @param sslContext
378347
*/
379-
private void createSSLEngine( String host, int port )
348+
private static SSLEngine createSSLEngine( String host, int port, SSLContext sslContext )
380349
{
381-
sslEngine = sslContext.createSSLEngine( host, port );
350+
SSLEngine sslEngine = sslContext.createSSLEngine( host, port );
382351
sslEngine.setUseClientMode( true );
352+
return sslEngine;
383353
}
384354

385355
@Override
@@ -431,33 +401,46 @@ public boolean isOpen()
431401
@Override
432402
public void close() throws IOException
433403
{
434-
plainOut.clear();
435-
// Indicate that application is done with engine
436-
sslEngine.closeOutbound();
437-
438-
while ( !sslEngine.isOutboundDone() )
404+
try
439405
{
440-
// Get close message
441-
SSLEngineResult res = sslEngine.wrap( plainOut, cipherOut );
442-
443-
// Check res statuses
406+
plainOut.clear();
407+
// Indicate that application is done with engine
408+
sslEngine.closeOutbound();
444409

445-
// Send close message to peer
446-
cipherOut.flip();
447-
while ( cipherOut.hasRemaining() )
410+
while ( !sslEngine.isOutboundDone() )
448411
{
449-
int num = channel.write( cipherOut );
450-
if ( num == -1 )
412+
// Get close message
413+
SSLEngineResult res = sslEngine.wrap( plainOut, cipherOut );
414+
415+
// Check res statuses
416+
417+
// Send close message to peer
418+
cipherOut.flip();
419+
while ( cipherOut.hasRemaining() )
451420
{
452-
// handle closed channel
453-
break;
421+
int num = channel.write( cipherOut );
422+
if ( num == -1 )
423+
{
424+
// handle closed channel
425+
break;
426+
}
454427
}
428+
cipherOut.clear();
455429
}
456-
cipherOut.clear();
430+
// Close transport
431+
channel.close();
432+
logger.debug( "TLS connection closed" );
433+
}
434+
catch(IOException e)
435+
{
436+
// Treat this as ok - the connection is closed, even if the TLS session did not exit cleanly.
437+
logger.warn( "TLS socket could not be closed cleanly: '"+e.getMessage()+"'", e );
457438
}
458-
// Close transport
459-
channel.close();
460-
logger.debug( "TLS connection closed" );
461439
}
462440

441+
@Override
442+
public String toString()
443+
{
444+
return "TLSSocketChannel{plainIn: " + plainIn + ", cipherIn:" + cipherIn + "}";
445+
}
463446
}

0 commit comments

Comments
 (0)