1
- // Copyright (c) 2007-2020 VMware, Inc. or its affiliates. All rights reserved.
1
+ // Copyright (c) 2007-2021 VMware, Inc. or its affiliates. All rights reserved.
2
2
//
3
3
// This software, the RabbitMQ Java client library, is triple-licensed under the
4
4
// Mozilla Public License 2.0 ("MPL"), the GNU General Public License version 2
23
23
import java .nio .channels .ReadableByteChannel ;
24
24
import java .nio .channels .SocketChannel ;
25
25
import java .nio .channels .WritableByteChannel ;
26
+ import org .slf4j .Logger ;
27
+ import org .slf4j .LoggerFactory ;
26
28
27
29
import static javax .net .ssl .SSLEngineResult .HandshakeStatus .FINISHED ;
30
+ import static javax .net .ssl .SSLEngineResult .HandshakeStatus .NEED_TASK ;
31
+ import static javax .net .ssl .SSLEngineResult .HandshakeStatus .NEED_WRAP ;
28
32
import static javax .net .ssl .SSLEngineResult .HandshakeStatus .NOT_HANDSHAKING ;
29
33
30
34
/**
31
35
*
32
36
*/
33
37
public class SslEngineHelper {
34
38
39
+ private static final Logger LOGGER = LoggerFactory .getLogger (SslEngineHelper .class );
40
+
35
41
public static boolean doHandshake (SocketChannel socketChannel , SSLEngine engine ) throws IOException {
36
42
37
43
ByteBuffer plainOut = ByteBuffer .allocate (engine .getSession ().getApplicationBufferSize ());
38
44
ByteBuffer plainIn = ByteBuffer .allocate (engine .getSession ().getApplicationBufferSize ());
39
45
ByteBuffer cipherOut = ByteBuffer .allocate (engine .getSession ().getPacketBufferSize ());
40
46
ByteBuffer cipherIn = ByteBuffer .allocate (engine .getSession ().getPacketBufferSize ());
41
47
48
+ LOGGER .debug ("Starting TLS handshake" );
49
+
42
50
SSLEngineResult .HandshakeStatus handshakeStatus = engine .getHandshakeStatus ();
51
+ LOGGER .debug ("Initial handshake status is {}" , handshakeStatus );
43
52
while (handshakeStatus != FINISHED && handshakeStatus != NOT_HANDSHAKING ) {
53
+ LOGGER .debug ("Handshake status is {}" , handshakeStatus );
44
54
switch (handshakeStatus ) {
45
55
case NEED_TASK :
56
+ LOGGER .debug ("Running tasks" );
46
57
handshakeStatus = runDelegatedTasks (engine );
47
58
break ;
48
59
case NEED_UNWRAP :
60
+ LOGGER .debug ("Unwrapping..." );
49
61
handshakeStatus = unwrap (cipherIn , plainIn , socketChannel , engine );
50
62
break ;
51
63
case NEED_WRAP :
64
+ LOGGER .debug ("Wrapping..." );
52
65
handshakeStatus = wrap (plainOut , cipherOut , socketChannel , engine );
53
66
break ;
67
+ case FINISHED :
68
+ break ;
69
+ case NOT_HANDSHAKING :
70
+ break ;
71
+ default :
72
+ throw new SSLException ("Unexpected handshake status " + handshakeStatus );
54
73
}
55
74
}
75
+
76
+
77
+ LOGGER .debug ("TLS handshake completed" );
56
78
return true ;
57
79
}
58
80
59
81
private static SSLEngineResult .HandshakeStatus runDelegatedTasks (SSLEngine sslEngine ) {
60
82
// FIXME run in executor?
61
83
Runnable runnable ;
62
84
while ((runnable = sslEngine .getDelegatedTask ()) != null ) {
85
+ LOGGER .debug ("Running delegated task" );
63
86
runnable .run ();
64
87
}
65
88
return sslEngine .getHandshakeStatus ();
@@ -68,29 +91,57 @@ private static SSLEngineResult.HandshakeStatus runDelegatedTasks(SSLEngine sslEn
68
91
private static SSLEngineResult .HandshakeStatus unwrap (ByteBuffer cipherIn , ByteBuffer plainIn ,
69
92
ReadableByteChannel channel , SSLEngine sslEngine ) throws IOException {
70
93
SSLEngineResult .HandshakeStatus handshakeStatus = sslEngine .getHandshakeStatus ();
71
-
72
- if (channel .read (cipherIn ) < 0 ) {
73
- throw new SSLException ("Could not read from socket channel" );
94
+ LOGGER .debug ("Handshake status is {} before unwrapping" , handshakeStatus );
95
+
96
+ LOGGER .debug ("Cipher in position {}" , cipherIn .position ());
97
+ int read ;
98
+ if (cipherIn .position () == 0 ) {
99
+ LOGGER .debug ("Reading from channel" );
100
+ read = channel .read (cipherIn );
101
+ LOGGER .debug ("Read {} byte(s) from channel" , read );
102
+ if (read < 0 ) {
103
+ throw new SSLException ("Could not read from socket channel" );
104
+ }
105
+ cipherIn .flip ();
106
+ } else {
107
+ LOGGER .debug ("Not reading" );
74
108
}
75
- cipherIn .flip ();
76
109
77
110
SSLEngineResult .Status status ;
111
+ SSLEngineResult unwrapResult ;
78
112
do {
79
- SSLEngineResult unwrapResult = sslEngine .unwrap (cipherIn , plainIn );
113
+ int positionBeforeUnwrapping = cipherIn .position ();
114
+ unwrapResult = sslEngine .unwrap (cipherIn , plainIn );
115
+ LOGGER .debug ("SSL engine result is {} after unwrapping" , unwrapResult );
80
116
status = unwrapResult .getStatus ();
81
117
switch (status ) {
82
118
case OK :
83
119
plainIn .clear ();
84
- handshakeStatus = runDelegatedTasks (sslEngine );
120
+ if (unwrapResult .getHandshakeStatus () == NEED_TASK ) {
121
+ handshakeStatus = runDelegatedTasks (sslEngine );
122
+ int newPosition = positionBeforeUnwrapping + unwrapResult .bytesConsumed ();
123
+ if (newPosition == cipherIn .limit ()) {
124
+ LOGGER .debug ("Clearing cipherIn because all bytes have been read and unwrapped" );
125
+ cipherIn .clear ();
126
+ } else {
127
+ LOGGER .debug ("Setting cipherIn position to {} (limit is {})" , newPosition , cipherIn .limit ());
128
+ cipherIn .position (positionBeforeUnwrapping + unwrapResult .bytesConsumed ());
129
+ }
130
+ } else {
131
+ handshakeStatus = unwrapResult .getHandshakeStatus ();
132
+ }
85
133
break ;
86
134
case BUFFER_OVERFLOW :
87
135
throw new SSLException ("Buffer overflow during handshake" );
88
136
case BUFFER_UNDERFLOW :
137
+ LOGGER .debug ("Buffer underflow" );
89
138
cipherIn .compact ();
90
- int read = NioHelper .read (channel , cipherIn );
139
+ LOGGER .debug ("Reading from channel..." );
140
+ read = NioHelper .read (channel , cipherIn );
91
141
if (read <= 0 ) {
92
142
retryRead (channel , cipherIn );
93
143
}
144
+ LOGGER .debug ("Done reading from channel..." );
94
145
cipherIn .flip ();
95
146
break ;
96
147
case CLOSED :
@@ -100,9 +151,9 @@ private static SSLEngineResult.HandshakeStatus unwrap(ByteBuffer cipherIn, ByteB
100
151
throw new SSLException ("Unexpected status from " + unwrapResult );
101
152
}
102
153
}
103
- while (cipherIn . hasRemaining () );
154
+ while (unwrapResult . getHandshakeStatus () != NEED_WRAP && unwrapResult . getHandshakeStatus () != FINISHED );
104
155
105
- cipherIn . compact ( );
156
+ LOGGER . debug ( " cipherIn position after unwrap {}" , cipherIn . position () );
106
157
return handshakeStatus ;
107
158
}
108
159
@@ -127,36 +178,32 @@ private static int retryRead(ReadableByteChannel channel, ByteBuffer buffer) thr
127
178
private static SSLEngineResult .HandshakeStatus wrap (ByteBuffer plainOut , ByteBuffer cipherOut ,
128
179
WritableByteChannel channel , SSLEngine sslEngine ) throws IOException {
129
180
SSLEngineResult .HandshakeStatus handshakeStatus = sslEngine .getHandshakeStatus ();
130
- SSLEngineResult .Status status = sslEngine .wrap (plainOut , cipherOut ).getStatus ();
131
- switch (status ) {
181
+ LOGGER .debug ("Handshake status is {} before wrapping" , handshakeStatus );
182
+ SSLEngineResult result = sslEngine .wrap (plainOut , cipherOut );
183
+ LOGGER .debug ("SSL engine result is {} after wrapping" , result );
184
+ switch (result .getStatus ()) {
132
185
case OK :
133
- handshakeStatus = runDelegatedTasks (sslEngine );
134
186
cipherOut .flip ();
135
187
while (cipherOut .hasRemaining ()) {
136
- channel .write (cipherOut );
188
+ int written = channel .write (cipherOut );
189
+ LOGGER .debug ("Wrote {} byte(s)" , written );
137
190
}
138
191
cipherOut .clear ();
192
+ if (result .getHandshakeStatus () == NEED_TASK ) {
193
+ handshakeStatus = runDelegatedTasks (sslEngine );
194
+ } else {
195
+ handshakeStatus = result .getHandshakeStatus ();
196
+ }
197
+
139
198
break ;
140
199
case BUFFER_OVERFLOW :
141
200
throw new SSLException ("Buffer overflow during handshake" );
142
201
default :
143
- throw new SSLException ("Unexpected status " + status );
202
+ throw new SSLException ("Unexpected status " + result . getStatus () );
144
203
}
145
204
return handshakeStatus ;
146
205
}
147
206
148
- static int bufferCopy (ByteBuffer from , ByteBuffer to ) {
149
- int maxTransfer = Math .min (to .remaining (), from .remaining ());
150
-
151
- ByteBuffer temporaryBuffer = from .duplicate ();
152
- temporaryBuffer .limit (temporaryBuffer .position () + maxTransfer );
153
- to .put (temporaryBuffer );
154
-
155
- from .position (from .position () + maxTransfer );
156
-
157
- return maxTransfer ;
158
- }
159
-
160
207
public static void write (WritableByteChannel socketChannel , SSLEngine engine , ByteBuffer plainOut , ByteBuffer cypherOut ) throws IOException {
161
208
while (plainOut .hasRemaining ()) {
162
209
cypherOut .clear ();
0 commit comments