Skip to content

Commit 448d3dd

Browse files
committed
Fix handshake with NIO on TLS 1.3
The unwrapping does not work the same way between TLS 1.2 and 1.3. This commit makes the unwrapping more reliable by getting the number of bytes consumed in the unwrapping and then set the position of the reading ByteBuffer accordingly to the number of bytes. With TLS 1.3, the unwrapping seems to read the whole content of the buffer and to extract only the first record, so the rewinding is necessary. The commit also adds some debug logging, adds tests on TLS 1.2 and 1.3, and re-arranges the TLS test (add utility class). Fixes #715
1 parent edd664c commit 448d3dd

11 files changed

+302
-214
lines changed

src/main/java/com/rabbitmq/client/impl/nio/SslEngineHelper.java

+74-27
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (c) 2007-2020 VMware, Inc. or its affiliates. All rights reserved.
1+
// Copyright (c) 2007-2021 VMware, Inc. or its affiliates. All rights reserved.
22
//
33
// This software, the RabbitMQ Java client library, is triple-licensed under the
44
// Mozilla Public License 2.0 ("MPL"), the GNU General Public License version 2
@@ -23,43 +23,66 @@
2323
import java.nio.channels.ReadableByteChannel;
2424
import java.nio.channels.SocketChannel;
2525
import java.nio.channels.WritableByteChannel;
26+
import org.slf4j.Logger;
27+
import org.slf4j.LoggerFactory;
2628

2729
import static javax.net.ssl.SSLEngineResult.HandshakeStatus.FINISHED;
30+
import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NEED_TASK;
31+
import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NEED_WRAP;
2832
import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING;
2933

3034
/**
3135
*
3236
*/
3337
public class SslEngineHelper {
3438

39+
private static final Logger LOGGER = LoggerFactory.getLogger(SslEngineHelper.class);
40+
3541
public static boolean doHandshake(SocketChannel socketChannel, SSLEngine engine) throws IOException {
3642

3743
ByteBuffer plainOut = ByteBuffer.allocate(engine.getSession().getApplicationBufferSize());
3844
ByteBuffer plainIn = ByteBuffer.allocate(engine.getSession().getApplicationBufferSize());
3945
ByteBuffer cipherOut = ByteBuffer.allocate(engine.getSession().getPacketBufferSize());
4046
ByteBuffer cipherIn = ByteBuffer.allocate(engine.getSession().getPacketBufferSize());
4147

48+
LOGGER.debug("Starting TLS handshake");
49+
4250
SSLEngineResult.HandshakeStatus handshakeStatus = engine.getHandshakeStatus();
51+
LOGGER.debug("Initial handshake status is {}", handshakeStatus);
4352
while (handshakeStatus != FINISHED && handshakeStatus != NOT_HANDSHAKING) {
53+
LOGGER.debug("Handshake status is {}", handshakeStatus);
4454
switch (handshakeStatus) {
4555
case NEED_TASK:
56+
LOGGER.debug("Running tasks");
4657
handshakeStatus = runDelegatedTasks(engine);
4758
break;
4859
case NEED_UNWRAP:
60+
LOGGER.debug("Unwrapping...");
4961
handshakeStatus = unwrap(cipherIn, plainIn, socketChannel, engine);
5062
break;
5163
case NEED_WRAP:
64+
LOGGER.debug("Wrapping...");
5265
handshakeStatus = wrap(plainOut, cipherOut, socketChannel, engine);
5366
break;
67+
case FINISHED:
68+
break;
69+
case NOT_HANDSHAKING:
70+
break;
71+
default:
72+
throw new SSLException("Unexpected handshake status " + handshakeStatus);
5473
}
5574
}
75+
76+
77+
LOGGER.debug("TLS handshake completed");
5678
return true;
5779
}
5880

5981
private static SSLEngineResult.HandshakeStatus runDelegatedTasks(SSLEngine sslEngine) {
6082
// FIXME run in executor?
6183
Runnable runnable;
6284
while ((runnable = sslEngine.getDelegatedTask()) != null) {
85+
LOGGER.debug("Running delegated task");
6386
runnable.run();
6487
}
6588
return sslEngine.getHandshakeStatus();
@@ -68,29 +91,57 @@ private static SSLEngineResult.HandshakeStatus runDelegatedTasks(SSLEngine sslEn
6891
private static SSLEngineResult.HandshakeStatus unwrap(ByteBuffer cipherIn, ByteBuffer plainIn,
6992
ReadableByteChannel channel, SSLEngine sslEngine) throws IOException {
7093
SSLEngineResult.HandshakeStatus handshakeStatus = sslEngine.getHandshakeStatus();
71-
72-
if (channel.read(cipherIn) < 0) {
73-
throw new SSLException("Could not read from socket channel");
94+
LOGGER.debug("Handshake status is {} before unwrapping", handshakeStatus);
95+
96+
LOGGER.debug("Cipher in position {}", cipherIn.position());
97+
int read;
98+
if (cipherIn.position() == 0) {
99+
LOGGER.debug("Reading from channel");
100+
read = channel.read(cipherIn);
101+
LOGGER.debug("Read {} byte(s) from channel", read);
102+
if (read < 0) {
103+
throw new SSLException("Could not read from socket channel");
104+
}
105+
cipherIn.flip();
106+
} else {
107+
LOGGER.debug("Not reading");
74108
}
75-
cipherIn.flip();
76109

77110
SSLEngineResult.Status status;
111+
SSLEngineResult unwrapResult;
78112
do {
79-
SSLEngineResult unwrapResult = sslEngine.unwrap(cipherIn, plainIn);
113+
int positionBeforeUnwrapping = cipherIn.position();
114+
unwrapResult = sslEngine.unwrap(cipherIn, plainIn);
115+
LOGGER.debug("SSL engine result is {} after unwrapping", unwrapResult);
80116
status = unwrapResult.getStatus();
81117
switch (status) {
82118
case OK:
83119
plainIn.clear();
84-
handshakeStatus = runDelegatedTasks(sslEngine);
120+
if (unwrapResult.getHandshakeStatus() == NEED_TASK) {
121+
handshakeStatus = runDelegatedTasks(sslEngine);
122+
int newPosition = positionBeforeUnwrapping + unwrapResult.bytesConsumed();
123+
if (newPosition == cipherIn.limit()) {
124+
LOGGER.debug("Clearing cipherIn because all bytes have been read and unwrapped");
125+
cipherIn.clear();
126+
} else {
127+
LOGGER.debug("Setting cipherIn position to {} (limit is {})", newPosition, cipherIn.limit());
128+
cipherIn.position(positionBeforeUnwrapping + unwrapResult.bytesConsumed());
129+
}
130+
} else {
131+
handshakeStatus = unwrapResult.getHandshakeStatus();
132+
}
85133
break;
86134
case BUFFER_OVERFLOW:
87135
throw new SSLException("Buffer overflow during handshake");
88136
case BUFFER_UNDERFLOW:
137+
LOGGER.debug("Buffer underflow");
89138
cipherIn.compact();
90-
int read = NioHelper.read(channel, cipherIn);
139+
LOGGER.debug("Reading from channel...");
140+
read = NioHelper.read(channel, cipherIn);
91141
if(read <= 0) {
92142
retryRead(channel, cipherIn);
93143
}
144+
LOGGER.debug("Done reading from channel...");
94145
cipherIn.flip();
95146
break;
96147
case CLOSED:
@@ -100,9 +151,9 @@ private static SSLEngineResult.HandshakeStatus unwrap(ByteBuffer cipherIn, ByteB
100151
throw new SSLException("Unexpected status from " + unwrapResult);
101152
}
102153
}
103-
while (cipherIn.hasRemaining());
154+
while (unwrapResult.getHandshakeStatus() != NEED_WRAP && unwrapResult.getHandshakeStatus() != FINISHED);
104155

105-
cipherIn.compact();
156+
LOGGER.debug("cipherIn position after unwrap {}", cipherIn.position());
106157
return handshakeStatus;
107158
}
108159

@@ -127,36 +178,32 @@ private static int retryRead(ReadableByteChannel channel, ByteBuffer buffer) thr
127178
private static SSLEngineResult.HandshakeStatus wrap(ByteBuffer plainOut, ByteBuffer cipherOut,
128179
WritableByteChannel channel, SSLEngine sslEngine) throws IOException {
129180
SSLEngineResult.HandshakeStatus handshakeStatus = sslEngine.getHandshakeStatus();
130-
SSLEngineResult.Status status = sslEngine.wrap(plainOut, cipherOut).getStatus();
131-
switch (status) {
181+
LOGGER.debug("Handshake status is {} before wrapping", handshakeStatus);
182+
SSLEngineResult result = sslEngine.wrap(plainOut, cipherOut);
183+
LOGGER.debug("SSL engine result is {} after wrapping", result);
184+
switch (result.getStatus()) {
132185
case OK:
133-
handshakeStatus = runDelegatedTasks(sslEngine);
134186
cipherOut.flip();
135187
while (cipherOut.hasRemaining()) {
136-
channel.write(cipherOut);
188+
int written = channel.write(cipherOut);
189+
LOGGER.debug("Wrote {} byte(s)", written);
137190
}
138191
cipherOut.clear();
192+
if (result.getHandshakeStatus() == NEED_TASK) {
193+
handshakeStatus = runDelegatedTasks(sslEngine);
194+
} else {
195+
handshakeStatus = result.getHandshakeStatus();
196+
}
197+
139198
break;
140199
case BUFFER_OVERFLOW:
141200
throw new SSLException("Buffer overflow during handshake");
142201
default:
143-
throw new SSLException("Unexpected status " + status);
202+
throw new SSLException("Unexpected status " + result.getStatus());
144203
}
145204
return handshakeStatus;
146205
}
147206

148-
static int bufferCopy(ByteBuffer from, ByteBuffer to) {
149-
int maxTransfer = Math.min(to.remaining(), from.remaining());
150-
151-
ByteBuffer temporaryBuffer = from.duplicate();
152-
temporaryBuffer.limit(temporaryBuffer.position() + maxTransfer);
153-
to.put(temporaryBuffer);
154-
155-
from.position(from.position() + maxTransfer);
156-
157-
return maxTransfer;
158-
}
159-
160207
public static void write(WritableByteChannel socketChannel, SSLEngine engine, ByteBuffer plainOut, ByteBuffer cypherOut) throws IOException {
161208
while (plainOut.hasRemaining()) {
162209
cypherOut.clear();

src/test/java/com/rabbitmq/client/test/BrokerTestCase.java

+1-6
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (c) 2007-2020 VMware, Inc. or its affiliates. All rights reserved.
1+
// Copyright (c) 2007-2021 VMware, Inc. or its affiliates. All rights reserved.
22
//
33
// This software, the RabbitMQ Java client library, is triple-licensed under the
44
// Mozilla Public License 2.0 ("MPL"), the GNU General Public License version 2
@@ -28,9 +28,7 @@
2828
import org.slf4j.Logger;
2929
import org.slf4j.LoggerFactory;
3030

31-
import javax.net.ssl.SSLContext;
3231
import java.io.IOException;
33-
import java.security.NoSuchAlgorithmException;
3432
import java.util.Map;
3533
import java.util.UUID;
3634
import java.util.concurrent.TimeoutException;
@@ -348,7 +346,4 @@ protected String generateExchangeName() {
348346
return "exchange" + UUID.randomUUID().toString();
349347
}
350348

351-
protected SSLContext getSSLContext() throws NoSuchAlgorithmException {
352-
return TestUtils.getSSLContext();
353-
}
354349
}

src/test/java/com/rabbitmq/client/test/TestUtils.java

+24-21
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (c) 2007-2020 VMware, Inc. or its affiliates. All rights reserved.
1+
// Copyright (c) 2007-2021 VMware, Inc. or its affiliates. All rights reserved.
22
//
33
// This software, the RabbitMQ Java client library, is triple-licensed under the
44
// Mozilla Public License 2.0 ("MPL"), the GNU General Public License version 2
@@ -26,19 +26,15 @@
2626
import org.junit.runners.model.Statement;
2727
import org.slf4j.LoggerFactory;
2828

29-
import javax.net.ssl.SSLContext;
3029
import java.io.IOException;
3130
import java.net.ServerSocket;
32-
import java.security.NoSuchAlgorithmException;
3331
import java.time.Duration;
34-
import java.util.Arrays;
3532
import java.util.Collection;
3633
import java.util.Collections;
3734
import java.util.concurrent.Callable;
3835
import java.util.concurrent.CountDownLatch;
3936
import java.util.concurrent.TimeUnit;
4037
import java.util.concurrent.TimeoutException;
41-
import java.util.function.BooleanSupplier;
4238

4339
import static org.junit.Assert.assertTrue;
4440

@@ -109,22 +105,6 @@ public static void abort(Connection connection) {
109105
}
110106
}
111107

112-
public static SSLContext getSSLContext() throws NoSuchAlgorithmException {
113-
SSLContext c = null;
114-
115-
// pick the first protocol available, preferring TLSv1.2, then TLSv1,
116-
// falling back to SSLv3 if running on an ancient/crippled JDK
117-
for (String proto : Arrays.asList("TLSv1.2", "TLSv1", "SSLv3")) {
118-
try {
119-
c = SSLContext.getInstance(proto);
120-
return c;
121-
} catch (NoSuchAlgorithmException x) {
122-
// keep trying
123-
}
124-
}
125-
throw new NoSuchAlgorithmException();
126-
}
127-
128108
public static TestRule atLeast38() {
129109
return new BrokerVersionTestRule("3.8.0");
130110
}
@@ -361,4 +341,27 @@ public interface CallableFunction<T, R> {
361341

362342
}
363343

344+
public static boolean basicGetBasicConsume(Connection connection, String queue, final CountDownLatch latch, int msgSize)
345+
throws Exception {
346+
Channel channel = connection.createChannel();
347+
channel.queueDeclare(queue, false, true, false, null);
348+
channel.queuePurge(queue);
349+
350+
channel.basicPublish("", queue, null, new byte[msgSize]);
351+
352+
String tag = channel.basicConsume(queue, false, new DefaultConsumer(channel) {
353+
354+
@Override
355+
public void handleDelivery(String consumerTag, Envelope envelope, AMQP.BasicProperties properties, byte[] body) throws IOException {
356+
getChannel().basicAck(envelope.getDeliveryTag(), false);
357+
latch.countDown();
358+
}
359+
});
360+
361+
boolean messageReceived = latch.await(20, TimeUnit.SECONDS);
362+
363+
channel.basicCancel(tag);
364+
365+
return messageReceived;
366+
}
364367
}

src/test/java/com/rabbitmq/client/test/ssl/BadVerifiedConnection.java

+4-45
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (c) 2007-2020 VMware, Inc. or its affiliates. All rights reserved.
1+
// Copyright (c) 2007-2021 VMware, Inc. or its affiliates. All rights reserved.
22
//
33
// This software, the RabbitMQ Java client library, is triple-licensed under the
44
// Mozilla Public License 2.0 ("MPL"), the GNU General Public License version 2
@@ -15,20 +15,13 @@
1515

1616
package com.rabbitmq.client.test.ssl;
1717

18-
import com.rabbitmq.client.test.TestUtils;
1918
import org.junit.Test;
2019

21-
import javax.net.ssl.KeyManagerFactory;
2220
import javax.net.ssl.SSLContext;
2321
import javax.net.ssl.SSLHandshakeException;
24-
import javax.net.ssl.TrustManagerFactory;
25-
import java.io.FileInputStream;
2622
import java.io.IOException;
27-
import java.security.*;
28-
import java.security.cert.CertificateException;
2923
import java.util.concurrent.TimeoutException;
3024

31-
import static org.junit.Assert.assertNotNull;
3225
import static org.junit.Assert.fail;
3326

3427
/**
@@ -39,44 +32,10 @@ public class BadVerifiedConnection extends UnverifiedConnection {
3932
public void openConnection()
4033
throws IOException, TimeoutException {
4134
try {
42-
String keystorePath = System.getProperty("test-keystore.empty");
43-
assertNotNull(keystorePath);
44-
String keystorePasswd = System.getProperty("test-keystore.password");
45-
assertNotNull(keystorePasswd);
46-
char [] keystorePassword = keystorePasswd.toCharArray();
47-
48-
KeyStore tks = KeyStore.getInstance("JKS");
49-
tks.load(new FileInputStream(keystorePath), keystorePassword);
50-
51-
TrustManagerFactory tmf = TrustManagerFactory.getInstance("SunX509");
52-
tmf.init(tks);
53-
54-
String p12Path = System.getProperty("test-client-cert.path");
55-
assertNotNull(p12Path);
56-
String p12Passwd = System.getProperty("test-client-cert.password");
57-
assertNotNull(p12Passwd);
58-
KeyStore ks = KeyStore.getInstance("PKCS12");
59-
char [] p12Password = p12Passwd.toCharArray();
60-
ks.load(new FileInputStream(p12Path), p12Password);
61-
62-
KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509");
63-
kmf.init(ks, p12Password);
64-
65-
SSLContext c = getSSLContext();
66-
c.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null);
67-
68-
connectionFactory = TestUtils.connectionFactory();
35+
SSLContext c = TlsTestUtils.badVerifiedSslContext();
6936
connectionFactory.useSslProtocol(c);
70-
} catch (NoSuchAlgorithmException ex) {
71-
throw new IOException(ex.toString());
72-
} catch (KeyManagementException ex) {
73-
throw new IOException(ex.toString());
74-
} catch (KeyStoreException ex) {
75-
throw new IOException(ex.toString());
76-
} catch (CertificateException ex) {
77-
throw new IOException(ex.toString());
78-
} catch (UnrecoverableKeyException ex) {
79-
throw new IOException(ex.toString());
37+
} catch (Exception ex) {
38+
throw new IOException(ex);
8039
}
8140

8241
try {

0 commit comments

Comments
 (0)