21
21
import java .io .IOException ;
22
22
import java .nio .ByteBuffer ;
23
23
import java .nio .channels .ByteChannel ;
24
- import java .nio .channels .SocketChannel ;
25
24
import java .security .GeneralSecurityException ;
26
25
import javax .net .ssl .SSLContext ;
27
26
import javax .net .ssl .SSLEngine ;
28
27
import javax .net .ssl .SSLEngineResult ;
29
28
import javax .net .ssl .SSLEngineResult .HandshakeStatus ;
30
29
import javax .net .ssl .SSLEngineResult .Status ;
31
- import javax .net .ssl .SSLSession ;
32
30
33
31
import org .neo4j .driver .internal .spi .Logger ;
34
32
import org .neo4j .driver .internal .util .BytePrinter ;
51
49
*/
52
50
public class TLSSocketChannel implements ByteChannel
53
51
{
54
- private final SocketChannel channel ; // The real channel the data is sent to and read from
52
+ private final ByteChannel channel ; // The real channel the data is sent to and read from
55
53
private final Logger logger ;
56
54
57
- private final SSLContext sslContext ;
58
55
private SSLEngine sslEngine ;
59
56
60
57
/** The buffer for network data */
@@ -66,34 +63,36 @@ public class TLSSocketChannel implements ByteChannel
66
63
67
64
private static final ByteBuffer DUMMY_BUFFER = ByteBuffer .allocateDirect ( 0 );
68
65
69
- public TLSSocketChannel ( String host , int port , SocketChannel channel , Logger logger ,
66
+ public TLSSocketChannel ( String host , int port , ByteChannel channel , Logger logger ,
70
67
TrustStrategy trustStrategy )
71
68
throws GeneralSecurityException , IOException
72
69
{
73
- logger .debug ( "TLS connection enabled" );
74
- this .logger = logger ;
75
- this .channel = channel ;
76
- this .channel .configureBlocking ( true );
70
+ this (channel , logger ,
71
+ createSSLEngine ( host , port , new SSLContextFactory ( host , port , trustStrategy , logger ).create () ) );
77
72
78
- sslContext = new SSLContextFactory ( host , port , trustStrategy , logger ).create ();
79
- createSSLEngine ( host , port );
80
- createBuffers ();
81
- runHandshake ();
82
- logger .debug ( "TLS connection established" );
83
73
}
84
74
85
- /** Used in internal tests only */
86
- TLSSocketChannel ( SocketChannel channel , Logger logger , SSLEngine sslEngine ,
75
+ public TLSSocketChannel ( ByteChannel channel , Logger logger , SSLEngine sslEngine ) throws GeneralSecurityException , IOException
76
+ {
77
+ this (channel , logger , sslEngine ,
78
+ ByteBuffer .allocateDirect ( sslEngine .getSession ().getApplicationBufferSize () ),
79
+ ByteBuffer .allocateDirect ( sslEngine .getSession ().getPacketBufferSize () ),
80
+ ByteBuffer .allocateDirect ( sslEngine .getSession ().getApplicationBufferSize () ),
81
+ ByteBuffer .allocateDirect ( sslEngine .getSession ().getPacketBufferSize () ) );
82
+ }
83
+
84
+ TLSSocketChannel ( ByteChannel channel , Logger logger , SSLEngine sslEngine ,
87
85
ByteBuffer plainIn , ByteBuffer cipherIn , ByteBuffer plainOut , ByteBuffer cipherOut )
88
86
throws GeneralSecurityException , IOException
89
87
{
90
- logger .debug ( "Testing TLS buffers" );
91
88
this .logger = logger ;
92
89
this .channel = channel ;
93
-
94
- this .sslContext = SSLContext .getInstance ( "TLS" );
95
90
this .sslEngine = sslEngine ;
96
- resetBuffers ( plainIn , cipherIn , plainOut , cipherOut ); // reset buffer size
91
+ this .plainIn = plainIn ;
92
+ this .cipherIn = cipherIn ;
93
+ this .plainOut = plainOut ;
94
+ this .cipherOut = cipherOut ;
95
+ runHandshake ();
97
96
}
98
97
99
98
/**
@@ -126,17 +125,13 @@ private void runHandshake() throws IOException
126
125
case NEED_UNWRAP :
127
126
// Unwrap the ssl packet to value ssl handshake information
128
127
handshakeStatus = unwrap ( DUMMY_BUFFER );
129
- plainIn .clear ();
130
128
break ;
131
129
case NEED_WRAP :
132
130
// Wrap the app packet into an ssl packet to add ssl handshake information
133
131
handshakeStatus = wrap ( plainOut );
134
132
break ;
135
133
}
136
134
}
137
-
138
- plainIn .clear ();
139
- plainOut .clear ();
140
135
}
141
136
142
137
private HandshakeStatus runDelegatedTasks ()
@@ -185,10 +180,11 @@ private HandshakeStatus unwrap( ByteBuffer buffer ) throws IOException
185
180
}
186
181
cipherIn .flip ();
187
182
188
- Status status = null ;
183
+ Status status ;
189
184
do
190
185
{
191
- status = sslEngine .unwrap ( cipherIn , plainIn ).getStatus ();
186
+ SSLEngineResult unwrapResult = sslEngine .unwrap ( cipherIn , plainIn );
187
+ status = unwrapResult .getStatus ();
192
188
// Possible status here:
193
189
// OK - good
194
190
// BUFFER_OVERFLOW - we need to enlarge* plainIn
@@ -244,17 +240,13 @@ private HandshakeStatus unwrap( ByteBuffer buffer ) throws IOException
244
240
// Otherwise, make room for reading more data from channel
245
241
cipherIn .compact ();
246
242
}
247
-
248
- // I skipped the following check as it "should not" happen at all:
249
- // The channel should not provide us ciphered bytes that cannot hold in the channel buffer at all
250
- // if( cipherIn.remaining() == 0 )
251
- // {throw new ClientException( "cannot enlarge as it already reached the limit" );}
252
-
253
- // Obtain more inbound network data for cipherIn,
254
- // then retry the operation.
255
243
return handshakeStatus ; // old status
244
+ case CLOSED :
245
+ // RFC 2246 #7.2.1 requires us to stop accepting input.
246
+ sslEngine .closeInbound ();
247
+ break ;
256
248
default :
257
- throw new ClientException ( "Got unexpected status " + status );
249
+ throw new ClientException ( "Got unexpected status " + status + ", " + unwrapResult );
258
250
}
259
251
}
260
252
while ( cipherIn .hasRemaining () ); /* Remember we are doing blocking reading.
@@ -285,7 +277,10 @@ private HandshakeStatus wrap( ByteBuffer buffer ) throws IOException
285
277
case OK :
286
278
handshakeStatus = runDelegatedTasks ();
287
279
cipherOut .flip ();
288
- channel .write ( cipherOut );
280
+ while (cipherOut .hasRemaining ())
281
+ {
282
+ channel .write ( cipherOut );
283
+ }
289
284
cipherOut .clear ();
290
285
break ;
291
286
case BUFFER_OVERFLOW :
@@ -344,42 +339,17 @@ static int bufferCopy( ByteBuffer from, ByteBuffer to )
344
339
return maxTransfer ;
345
340
}
346
341
347
- /**
348
- * Create network buffers and application buffers
349
- *
350
- * @throws IOException
351
- */
352
- private void createBuffers () throws IOException
353
- {
354
- SSLSession session = sslEngine .getSession ();
355
- int appBufferSize = session .getApplicationBufferSize ();
356
- int netBufferSize = session .getPacketBufferSize ();
357
-
358
- plainOut = ByteBuffer .allocateDirect ( appBufferSize );
359
- plainIn = ByteBuffer .allocateDirect ( appBufferSize );
360
- cipherOut = ByteBuffer .allocateDirect ( netBufferSize );
361
- cipherIn = ByteBuffer .allocateDirect ( netBufferSize );
362
- }
363
-
364
- /** Should only be used in tests */
365
- void resetBuffers ( ByteBuffer plainIn , ByteBuffer cipherIn , ByteBuffer plainOut , ByteBuffer cipherOut )
366
- {
367
- this .plainIn = plainIn ;
368
- this .cipherIn = cipherIn ;
369
- this .plainOut = plainOut ;
370
- this .cipherOut = cipherOut ;
371
- }
372
-
373
342
/**
374
343
* Create SSLEngine with the SSLContext just created.
375
- *
376
344
* @param host
377
345
* @param port
346
+ * @param sslContext
378
347
*/
379
- private void createSSLEngine ( String host , int port )
348
+ private static SSLEngine createSSLEngine ( String host , int port , SSLContext sslContext )
380
349
{
381
- sslEngine = sslContext .createSSLEngine ( host , port );
350
+ SSLEngine sslEngine = sslContext .createSSLEngine ( host , port );
382
351
sslEngine .setUseClientMode ( true );
352
+ return sslEngine ;
383
353
}
384
354
385
355
@ Override
@@ -431,33 +401,46 @@ public boolean isOpen()
431
401
@ Override
432
402
public void close () throws IOException
433
403
{
434
- plainOut .clear ();
435
- // Indicate that application is done with engine
436
- sslEngine .closeOutbound ();
437
-
438
- while ( !sslEngine .isOutboundDone () )
404
+ try
439
405
{
440
- // Get close message
441
- SSLEngineResult res = sslEngine .wrap ( plainOut , cipherOut );
442
-
443
- // Check res statuses
406
+ plainOut .clear ();
407
+ // Indicate that application is done with engine
408
+ sslEngine .closeOutbound ();
444
409
445
- // Send close message to peer
446
- cipherOut .flip ();
447
- while ( cipherOut .hasRemaining () )
410
+ while ( !sslEngine .isOutboundDone () )
448
411
{
449
- int num = channel .write ( cipherOut );
450
- if ( num == -1 )
412
+ // Get close message
413
+ SSLEngineResult res = sslEngine .wrap ( plainOut , cipherOut );
414
+
415
+ // Check res statuses
416
+
417
+ // Send close message to peer
418
+ cipherOut .flip ();
419
+ while ( cipherOut .hasRemaining () )
451
420
{
452
- // handle closed channel
453
- break ;
421
+ int num = channel .write ( cipherOut );
422
+ if ( num == -1 )
423
+ {
424
+ // handle closed channel
425
+ break ;
426
+ }
454
427
}
428
+ cipherOut .clear ();
455
429
}
456
- cipherOut .clear ();
430
+ // Close transport
431
+ channel .close ();
432
+ logger .debug ( "TLS connection closed" );
433
+ }
434
+ catch (IOException e )
435
+ {
436
+ // Treat this as ok - the connection is closed, even if the TLS session did not exit cleanly.
437
+ logger .warn ( "TLS socket could not be closed cleanly: '" +e .getMessage ()+"'" , e );
457
438
}
458
- // Close transport
459
- channel .close ();
460
- logger .debug ( "TLS connection closed" );
461
439
}
462
440
441
+ @ Override
442
+ public String toString ()
443
+ {
444
+ return "TLSSocketChannel{plainIn: " + plainIn + ", cipherIn:" + cipherIn + "}" ;
445
+ }
463
446
}
0 commit comments