Skip to content

Commit 7fc11b1

Browse files
Merge pull request #716 from rabbitmq/rabbitmq-java-client-715-handshake-hangs-nio-tls-1-3
Fix handshake with NIO on TLS 1.3
2 parents 0f2a53e + 448d3dd commit 7fc11b1

11 files changed

+302
-214
lines changed

src/main/java/com/rabbitmq/client/impl/nio/SslEngineHelper.java

+74-27
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (c) 2007-2020 VMware, Inc. or its affiliates. All rights reserved.
1+
// Copyright (c) 2007-2021 VMware, Inc. or its affiliates. All rights reserved.
22
//
33
// This software, the RabbitMQ Java client library, is triple-licensed under the
44
// Mozilla Public License 2.0 ("MPL"), the GNU General Public License version 2
@@ -23,43 +23,66 @@
2323
import java.nio.channels.ReadableByteChannel;
2424
import java.nio.channels.SocketChannel;
2525
import java.nio.channels.WritableByteChannel;
26+
import org.slf4j.Logger;
27+
import org.slf4j.LoggerFactory;
2628

2729
import static javax.net.ssl.SSLEngineResult.HandshakeStatus.FINISHED;
30+
import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NEED_TASK;
31+
import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NEED_WRAP;
2832
import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING;
2933

3034
/**
3135
*
3236
*/
3337
public class SslEngineHelper {
3438

39+
private static final Logger LOGGER = LoggerFactory.getLogger(SslEngineHelper.class);
40+
3541
public static boolean doHandshake(SocketChannel socketChannel, SSLEngine engine) throws IOException {
3642

3743
ByteBuffer plainOut = ByteBuffer.allocate(engine.getSession().getApplicationBufferSize());
3844
ByteBuffer plainIn = ByteBuffer.allocate(engine.getSession().getApplicationBufferSize());
3945
ByteBuffer cipherOut = ByteBuffer.allocate(engine.getSession().getPacketBufferSize());
4046
ByteBuffer cipherIn = ByteBuffer.allocate(engine.getSession().getPacketBufferSize());
4147

48+
LOGGER.debug("Starting TLS handshake");
49+
4250
SSLEngineResult.HandshakeStatus handshakeStatus = engine.getHandshakeStatus();
51+
LOGGER.debug("Initial handshake status is {}", handshakeStatus);
4352
while (handshakeStatus != FINISHED && handshakeStatus != NOT_HANDSHAKING) {
53+
LOGGER.debug("Handshake status is {}", handshakeStatus);
4454
switch (handshakeStatus) {
4555
case NEED_TASK:
56+
LOGGER.debug("Running tasks");
4657
handshakeStatus = runDelegatedTasks(engine);
4758
break;
4859
case NEED_UNWRAP:
60+
LOGGER.debug("Unwrapping...");
4961
handshakeStatus = unwrap(cipherIn, plainIn, socketChannel, engine);
5062
break;
5163
case NEED_WRAP:
64+
LOGGER.debug("Wrapping...");
5265
handshakeStatus = wrap(plainOut, cipherOut, socketChannel, engine);
5366
break;
67+
case FINISHED:
68+
break;
69+
case NOT_HANDSHAKING:
70+
break;
71+
default:
72+
throw new SSLException("Unexpected handshake status " + handshakeStatus);
5473
}
5574
}
75+
76+
77+
LOGGER.debug("TLS handshake completed");
5678
return true;
5779
}
5880

5981
private static SSLEngineResult.HandshakeStatus runDelegatedTasks(SSLEngine sslEngine) {
6082
// FIXME run in executor?
6183
Runnable runnable;
6284
while ((runnable = sslEngine.getDelegatedTask()) != null) {
85+
LOGGER.debug("Running delegated task");
6386
runnable.run();
6487
}
6588
return sslEngine.getHandshakeStatus();
@@ -68,29 +91,57 @@ private static SSLEngineResult.HandshakeStatus runDelegatedTasks(SSLEngine sslEn
6891
private static SSLEngineResult.HandshakeStatus unwrap(ByteBuffer cipherIn, ByteBuffer plainIn,
6992
ReadableByteChannel channel, SSLEngine sslEngine) throws IOException {
7093
SSLEngineResult.HandshakeStatus handshakeStatus = sslEngine.getHandshakeStatus();
71-
72-
if (channel.read(cipherIn) < 0) {
73-
throw new SSLException("Could not read from socket channel");
94+
LOGGER.debug("Handshake status is {} before unwrapping", handshakeStatus);
95+
96+
LOGGER.debug("Cipher in position {}", cipherIn.position());
97+
int read;
98+
if (cipherIn.position() == 0) {
99+
LOGGER.debug("Reading from channel");
100+
read = channel.read(cipherIn);
101+
LOGGER.debug("Read {} byte(s) from channel", read);
102+
if (read < 0) {
103+
throw new SSLException("Could not read from socket channel");
104+
}
105+
cipherIn.flip();
106+
} else {
107+
LOGGER.debug("Not reading");
74108
}
75-
cipherIn.flip();
76109

77110
SSLEngineResult.Status status;
111+
SSLEngineResult unwrapResult;
78112
do {
79-
SSLEngineResult unwrapResult = sslEngine.unwrap(cipherIn, plainIn);
113+
int positionBeforeUnwrapping = cipherIn.position();
114+
unwrapResult = sslEngine.unwrap(cipherIn, plainIn);
115+
LOGGER.debug("SSL engine result is {} after unwrapping", unwrapResult);
80116
status = unwrapResult.getStatus();
81117
switch (status) {
82118
case OK:
83119
plainIn.clear();
84-
handshakeStatus = runDelegatedTasks(sslEngine);
120+
if (unwrapResult.getHandshakeStatus() == NEED_TASK) {
121+
handshakeStatus = runDelegatedTasks(sslEngine);
122+
int newPosition = positionBeforeUnwrapping + unwrapResult.bytesConsumed();
123+
if (newPosition == cipherIn.limit()) {
124+
LOGGER.debug("Clearing cipherIn because all bytes have been read and unwrapped");
125+
cipherIn.clear();
126+
} else {
127+
LOGGER.debug("Setting cipherIn position to {} (limit is {})", newPosition, cipherIn.limit());
128+
cipherIn.position(positionBeforeUnwrapping + unwrapResult.bytesConsumed());
129+
}
130+
} else {
131+
handshakeStatus = unwrapResult.getHandshakeStatus();
132+
}
85133
break;
86134
case BUFFER_OVERFLOW:
87135
throw new SSLException("Buffer overflow during handshake");
88136
case BUFFER_UNDERFLOW:
137+
LOGGER.debug("Buffer underflow");
89138
cipherIn.compact();
90-
int read = NioHelper.read(channel, cipherIn);
139+
LOGGER.debug("Reading from channel...");
140+
read = NioHelper.read(channel, cipherIn);
91141
if(read <= 0) {
92142
retryRead(channel, cipherIn);
93143
}
144+
LOGGER.debug("Done reading from channel...");
94145
cipherIn.flip();
95146
break;
96147
case CLOSED:
@@ -100,9 +151,9 @@ private static SSLEngineResult.HandshakeStatus unwrap(ByteBuffer cipherIn, ByteB
100151
throw new SSLException("Unexpected status from " + unwrapResult);
101152
}
102153
}
103-
while (cipherIn.hasRemaining());
154+
while (unwrapResult.getHandshakeStatus() != NEED_WRAP && unwrapResult.getHandshakeStatus() != FINISHED);
104155

105-
cipherIn.compact();
156+
LOGGER.debug("cipherIn position after unwrap {}", cipherIn.position());
106157
return handshakeStatus;
107158
}
108159

@@ -127,36 +178,32 @@ private static int retryRead(ReadableByteChannel channel, ByteBuffer buffer) thr
127178
private static SSLEngineResult.HandshakeStatus wrap(ByteBuffer plainOut, ByteBuffer cipherOut,
128179
WritableByteChannel channel, SSLEngine sslEngine) throws IOException {
129180
SSLEngineResult.HandshakeStatus handshakeStatus = sslEngine.getHandshakeStatus();
130-
SSLEngineResult.Status status = sslEngine.wrap(plainOut, cipherOut).getStatus();
131-
switch (status) {
181+
LOGGER.debug("Handshake status is {} before wrapping", handshakeStatus);
182+
SSLEngineResult result = sslEngine.wrap(plainOut, cipherOut);
183+
LOGGER.debug("SSL engine result is {} after wrapping", result);
184+
switch (result.getStatus()) {
132185
case OK:
133-
handshakeStatus = runDelegatedTasks(sslEngine);
134186
cipherOut.flip();
135187
while (cipherOut.hasRemaining()) {
136-
channel.write(cipherOut);
188+
int written = channel.write(cipherOut);
189+
LOGGER.debug("Wrote {} byte(s)", written);
137190
}
138191
cipherOut.clear();
192+
if (result.getHandshakeStatus() == NEED_TASK) {
193+
handshakeStatus = runDelegatedTasks(sslEngine);
194+
} else {
195+
handshakeStatus = result.getHandshakeStatus();
196+
}
197+
139198
break;
140199
case BUFFER_OVERFLOW:
141200
throw new SSLException("Buffer overflow during handshake");
142201
default:
143-
throw new SSLException("Unexpected status " + status);
202+
throw new SSLException("Unexpected status " + result.getStatus());
144203
}
145204
return handshakeStatus;
146205
}
147206

148-
static int bufferCopy(ByteBuffer from, ByteBuffer to) {
149-
int maxTransfer = Math.min(to.remaining(), from.remaining());
150-
151-
ByteBuffer temporaryBuffer = from.duplicate();
152-
temporaryBuffer.limit(temporaryBuffer.position() + maxTransfer);
153-
to.put(temporaryBuffer);
154-
155-
from.position(from.position() + maxTransfer);
156-
157-
return maxTransfer;
158-
}
159-
160207
public static void write(WritableByteChannel socketChannel, SSLEngine engine, ByteBuffer plainOut, ByteBuffer cypherOut) throws IOException {
161208
while (plainOut.hasRemaining()) {
162209
cypherOut.clear();

src/test/java/com/rabbitmq/client/test/BrokerTestCase.java

+1-6
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (c) 2007-2020 VMware, Inc. or its affiliates. All rights reserved.
1+
// Copyright (c) 2007-2021 VMware, Inc. or its affiliates. All rights reserved.
22
//
33
// This software, the RabbitMQ Java client library, is triple-licensed under the
44
// Mozilla Public License 2.0 ("MPL"), the GNU General Public License version 2
@@ -28,9 +28,7 @@
2828
import org.slf4j.Logger;
2929
import org.slf4j.LoggerFactory;
3030

31-
import javax.net.ssl.SSLContext;
3231
import java.io.IOException;
33-
import java.security.NoSuchAlgorithmException;
3432
import java.util.Map;
3533
import java.util.UUID;
3634
import java.util.concurrent.TimeoutException;
@@ -348,7 +346,4 @@ protected String generateExchangeName() {
348346
return "exchange" + UUID.randomUUID().toString();
349347
}
350348

351-
protected SSLContext getSSLContext() throws NoSuchAlgorithmException {
352-
return TestUtils.getSSLContext();
353-
}
354349
}

src/test/java/com/rabbitmq/client/test/TestUtils.java

+24-21
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (c) 2007-2020 VMware, Inc. or its affiliates. All rights reserved.
1+
// Copyright (c) 2007-2021 VMware, Inc. or its affiliates. All rights reserved.
22
//
33
// This software, the RabbitMQ Java client library, is triple-licensed under the
44
// Mozilla Public License 2.0 ("MPL"), the GNU General Public License version 2
@@ -26,19 +26,15 @@
2626
import org.junit.runners.model.Statement;
2727
import org.slf4j.LoggerFactory;
2828

29-
import javax.net.ssl.SSLContext;
3029
import java.io.IOException;
3130
import java.net.ServerSocket;
32-
import java.security.NoSuchAlgorithmException;
3331
import java.time.Duration;
34-
import java.util.Arrays;
3532
import java.util.Collection;
3633
import java.util.Collections;
3734
import java.util.concurrent.Callable;
3835
import java.util.concurrent.CountDownLatch;
3936
import java.util.concurrent.TimeUnit;
4037
import java.util.concurrent.TimeoutException;
41-
import java.util.function.BooleanSupplier;
4238

4339
import static org.junit.Assert.assertTrue;
4440

@@ -109,22 +105,6 @@ public static void abort(Connection connection) {
109105
}
110106
}
111107

112-
public static SSLContext getSSLContext() throws NoSuchAlgorithmException {
113-
SSLContext c = null;
114-
115-
// pick the first protocol available, preferring TLSv1.2, then TLSv1,
116-
// falling back to SSLv3 if running on an ancient/crippled JDK
117-
for (String proto : Arrays.asList("TLSv1.2", "TLSv1", "SSLv3")) {
118-
try {
119-
c = SSLContext.getInstance(proto);
120-
return c;
121-
} catch (NoSuchAlgorithmException x) {
122-
// keep trying
123-
}
124-
}
125-
throw new NoSuchAlgorithmException();
126-
}
127-
128108
public static TestRule atLeast38() {
129109
return new BrokerVersionTestRule("3.8.0");
130110
}
@@ -361,4 +341,27 @@ public interface CallableFunction<T, R> {
361341

362342
}
363343

344+
public static boolean basicGetBasicConsume(Connection connection, String queue, final CountDownLatch latch, int msgSize)
345+
throws Exception {
346+
Channel channel = connection.createChannel();
347+
channel.queueDeclare(queue, false, true, false, null);
348+
channel.queuePurge(queue);
349+
350+
channel.basicPublish("", queue, null, new byte[msgSize]);
351+
352+
String tag = channel.basicConsume(queue, false, new DefaultConsumer(channel) {
353+
354+
@Override
355+
public void handleDelivery(String consumerTag, Envelope envelope, AMQP.BasicProperties properties, byte[] body) throws IOException {
356+
getChannel().basicAck(envelope.getDeliveryTag(), false);
357+
latch.countDown();
358+
}
359+
});
360+
361+
boolean messageReceived = latch.await(20, TimeUnit.SECONDS);
362+
363+
channel.basicCancel(tag);
364+
365+
return messageReceived;
366+
}
364367
}

src/test/java/com/rabbitmq/client/test/ssl/BadVerifiedConnection.java

+4-45
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copyright (c) 2007-2020 VMware, Inc. or its affiliates. All rights reserved.
1+
// Copyright (c) 2007-2021 VMware, Inc. or its affiliates. All rights reserved.
22
//
33
// This software, the RabbitMQ Java client library, is triple-licensed under the
44
// Mozilla Public License 2.0 ("MPL"), the GNU General Public License version 2
@@ -15,20 +15,13 @@
1515

1616
package com.rabbitmq.client.test.ssl;
1717

18-
import com.rabbitmq.client.test.TestUtils;
1918
import org.junit.Test;
2019

21-
import javax.net.ssl.KeyManagerFactory;
2220
import javax.net.ssl.SSLContext;
2321
import javax.net.ssl.SSLHandshakeException;
24-
import javax.net.ssl.TrustManagerFactory;
25-
import java.io.FileInputStream;
2622
import java.io.IOException;
27-
import java.security.*;
28-
import java.security.cert.CertificateException;
2923
import java.util.concurrent.TimeoutException;
3024

31-
import static org.junit.Assert.assertNotNull;
3225
import static org.junit.Assert.fail;
3326

3427
/**
@@ -39,44 +32,10 @@ public class BadVerifiedConnection extends UnverifiedConnection {
3932
public void openConnection()
4033
throws IOException, TimeoutException {
4134
try {
42-
String keystorePath = System.getProperty("test-keystore.empty");
43-
assertNotNull(keystorePath);
44-
String keystorePasswd = System.getProperty("test-keystore.password");
45-
assertNotNull(keystorePasswd);
46-
char [] keystorePassword = keystorePasswd.toCharArray();
47-
48-
KeyStore tks = KeyStore.getInstance("JKS");
49-
tks.load(new FileInputStream(keystorePath), keystorePassword);
50-
51-
TrustManagerFactory tmf = TrustManagerFactory.getInstance("SunX509");
52-
tmf.init(tks);
53-
54-
String p12Path = System.getProperty("test-client-cert.path");
55-
assertNotNull(p12Path);
56-
String p12Passwd = System.getProperty("test-client-cert.password");
57-
assertNotNull(p12Passwd);
58-
KeyStore ks = KeyStore.getInstance("PKCS12");
59-
char [] p12Password = p12Passwd.toCharArray();
60-
ks.load(new FileInputStream(p12Path), p12Password);
61-
62-
KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509");
63-
kmf.init(ks, p12Password);
64-
65-
SSLContext c = getSSLContext();
66-
c.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null);
67-
68-
connectionFactory = TestUtils.connectionFactory();
35+
SSLContext c = TlsTestUtils.badVerifiedSslContext();
6936
connectionFactory.useSslProtocol(c);
70-
} catch (NoSuchAlgorithmException ex) {
71-
throw new IOException(ex.toString());
72-
} catch (KeyManagementException ex) {
73-
throw new IOException(ex.toString());
74-
} catch (KeyStoreException ex) {
75-
throw new IOException(ex.toString());
76-
} catch (CertificateException ex) {
77-
throw new IOException(ex.toString());
78-
} catch (UnrecoverableKeyException ex) {
79-
throw new IOException(ex.toString());
37+
} catch (Exception ex) {
38+
throw new IOException(ex);
8039
}
8140

8241
try {

0 commit comments

Comments
 (0)