From 448d3dd0ccb9755db2fccb11f433ddf5b55ca48f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Arnaud=20Cogolu=C3=A8gnes?= Date: Fri, 5 Nov 2021 13:41:55 +0100 Subject: [PATCH] Fix handshake with NIO on TLS 1.3 The unwrapping does not work the same way between TLS 1.2 and 1.3. This commit makes the unwrapping more reliable by getting the number of bytes consumed in the unwrapping and then set the position of the reading ByteBuffer accordingly to the number of bytes. With TLS 1.3, the unwrapping seems to read the whole content of the buffer and to extract only the first record, so the rewinding is necessary. The commit also adds some debug logging, adds tests on TLS 1.2 and 1.3, and re-arranges the TLS test (add utility class). Fixes #715 --- .../client/impl/nio/SslEngineHelper.java | 101 +++++++++++---- .../rabbitmq/client/test/BrokerTestCase.java | 7 +- .../com/rabbitmq/client/test/TestUtils.java | 45 +++---- .../test/ssl/BadVerifiedConnection.java | 49 +------- .../client/test/ssl/HostnameVerification.java | 35 +----- .../test/ssl/NioTlsUnverifiedConnection.java | 55 +++++---- .../client/test/ssl/TlsConnectionLogging.java | 4 +- .../client/test/ssl/TlsTestUtils.java | 115 ++++++++++++++++++ .../client/test/ssl/UnverifiedConnection.java | 10 +- .../client/test/ssl/VerifiedConnection.java | 93 +++++++------- src/test/resources/logback-test.xml | 2 +- 11 files changed, 302 insertions(+), 214 deletions(-) create mode 100644 src/test/java/com/rabbitmq/client/test/ssl/TlsTestUtils.java diff --git a/src/main/java/com/rabbitmq/client/impl/nio/SslEngineHelper.java b/src/main/java/com/rabbitmq/client/impl/nio/SslEngineHelper.java index 1e7e3a0793..bcefe8b205 100644 --- a/src/main/java/com/rabbitmq/client/impl/nio/SslEngineHelper.java +++ b/src/main/java/com/rabbitmq/client/impl/nio/SslEngineHelper.java @@ -1,4 +1,4 @@ -// Copyright (c) 2007-2020 VMware, Inc. or its affiliates. All rights reserved. +// Copyright (c) 2007-2021 VMware, Inc. or its affiliates. All rights reserved. // // This software, the RabbitMQ Java client library, is triple-licensed under the // Mozilla Public License 2.0 ("MPL"), the GNU General Public License version 2 @@ -23,8 +23,12 @@ import java.nio.channels.ReadableByteChannel; import java.nio.channels.SocketChannel; import java.nio.channels.WritableByteChannel; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import static javax.net.ssl.SSLEngineResult.HandshakeStatus.FINISHED; +import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NEED_TASK; +import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NEED_WRAP; import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING; /** @@ -32,6 +36,8 @@ */ public class SslEngineHelper { + private static final Logger LOGGER = LoggerFactory.getLogger(SslEngineHelper.class); + public static boolean doHandshake(SocketChannel socketChannel, SSLEngine engine) throws IOException { ByteBuffer plainOut = ByteBuffer.allocate(engine.getSession().getApplicationBufferSize()); @@ -39,20 +45,36 @@ public static boolean doHandshake(SocketChannel socketChannel, SSLEngine engine) ByteBuffer cipherOut = ByteBuffer.allocate(engine.getSession().getPacketBufferSize()); ByteBuffer cipherIn = ByteBuffer.allocate(engine.getSession().getPacketBufferSize()); + LOGGER.debug("Starting TLS handshake"); + SSLEngineResult.HandshakeStatus handshakeStatus = engine.getHandshakeStatus(); + LOGGER.debug("Initial handshake status is {}", handshakeStatus); while (handshakeStatus != FINISHED && handshakeStatus != NOT_HANDSHAKING) { + LOGGER.debug("Handshake status is {}", handshakeStatus); switch (handshakeStatus) { case NEED_TASK: + LOGGER.debug("Running tasks"); handshakeStatus = runDelegatedTasks(engine); break; case NEED_UNWRAP: + LOGGER.debug("Unwrapping..."); handshakeStatus = unwrap(cipherIn, plainIn, socketChannel, engine); break; case NEED_WRAP: + LOGGER.debug("Wrapping..."); handshakeStatus = wrap(plainOut, cipherOut, socketChannel, engine); break; + case FINISHED: + break; + case NOT_HANDSHAKING: + break; + default: + throw new SSLException("Unexpected handshake status " + handshakeStatus); } } + + + LOGGER.debug("TLS handshake completed"); return true; } @@ -60,6 +82,7 @@ private static SSLEngineResult.HandshakeStatus runDelegatedTasks(SSLEngine sslEn // FIXME run in executor? Runnable runnable; while ((runnable = sslEngine.getDelegatedTask()) != null) { + LOGGER.debug("Running delegated task"); runnable.run(); } return sslEngine.getHandshakeStatus(); @@ -68,29 +91,57 @@ private static SSLEngineResult.HandshakeStatus runDelegatedTasks(SSLEngine sslEn private static SSLEngineResult.HandshakeStatus unwrap(ByteBuffer cipherIn, ByteBuffer plainIn, ReadableByteChannel channel, SSLEngine sslEngine) throws IOException { SSLEngineResult.HandshakeStatus handshakeStatus = sslEngine.getHandshakeStatus(); - - if (channel.read(cipherIn) < 0) { - throw new SSLException("Could not read from socket channel"); + LOGGER.debug("Handshake status is {} before unwrapping", handshakeStatus); + + LOGGER.debug("Cipher in position {}", cipherIn.position()); + int read; + if (cipherIn.position() == 0) { + LOGGER.debug("Reading from channel"); + read = channel.read(cipherIn); + LOGGER.debug("Read {} byte(s) from channel", read); + if (read < 0) { + throw new SSLException("Could not read from socket channel"); + } + cipherIn.flip(); + } else { + LOGGER.debug("Not reading"); } - cipherIn.flip(); SSLEngineResult.Status status; + SSLEngineResult unwrapResult; do { - SSLEngineResult unwrapResult = sslEngine.unwrap(cipherIn, plainIn); + int positionBeforeUnwrapping = cipherIn.position(); + unwrapResult = sslEngine.unwrap(cipherIn, plainIn); + LOGGER.debug("SSL engine result is {} after unwrapping", unwrapResult); status = unwrapResult.getStatus(); switch (status) { case OK: plainIn.clear(); - handshakeStatus = runDelegatedTasks(sslEngine); + if (unwrapResult.getHandshakeStatus() == NEED_TASK) { + handshakeStatus = runDelegatedTasks(sslEngine); + int newPosition = positionBeforeUnwrapping + unwrapResult.bytesConsumed(); + if (newPosition == cipherIn.limit()) { + LOGGER.debug("Clearing cipherIn because all bytes have been read and unwrapped"); + cipherIn.clear(); + } else { + LOGGER.debug("Setting cipherIn position to {} (limit is {})", newPosition, cipherIn.limit()); + cipherIn.position(positionBeforeUnwrapping + unwrapResult.bytesConsumed()); + } + } else { + handshakeStatus = unwrapResult.getHandshakeStatus(); + } break; case BUFFER_OVERFLOW: throw new SSLException("Buffer overflow during handshake"); case BUFFER_UNDERFLOW: + LOGGER.debug("Buffer underflow"); cipherIn.compact(); - int read = NioHelper.read(channel, cipherIn); + LOGGER.debug("Reading from channel..."); + read = NioHelper.read(channel, cipherIn); if(read <= 0) { retryRead(channel, cipherIn); } + LOGGER.debug("Done reading from channel..."); cipherIn.flip(); break; case CLOSED: @@ -100,9 +151,9 @@ private static SSLEngineResult.HandshakeStatus unwrap(ByteBuffer cipherIn, ByteB throw new SSLException("Unexpected status from " + unwrapResult); } } - while (cipherIn.hasRemaining()); + while (unwrapResult.getHandshakeStatus() != NEED_WRAP && unwrapResult.getHandshakeStatus() != FINISHED); - cipherIn.compact(); + LOGGER.debug("cipherIn position after unwrap {}", cipherIn.position()); return handshakeStatus; } @@ -127,36 +178,32 @@ private static int retryRead(ReadableByteChannel channel, ByteBuffer buffer) thr private static SSLEngineResult.HandshakeStatus wrap(ByteBuffer plainOut, ByteBuffer cipherOut, WritableByteChannel channel, SSLEngine sslEngine) throws IOException { SSLEngineResult.HandshakeStatus handshakeStatus = sslEngine.getHandshakeStatus(); - SSLEngineResult.Status status = sslEngine.wrap(plainOut, cipherOut).getStatus(); - switch (status) { + LOGGER.debug("Handshake status is {} before wrapping", handshakeStatus); + SSLEngineResult result = sslEngine.wrap(plainOut, cipherOut); + LOGGER.debug("SSL engine result is {} after wrapping", result); + switch (result.getStatus()) { case OK: - handshakeStatus = runDelegatedTasks(sslEngine); cipherOut.flip(); while (cipherOut.hasRemaining()) { - channel.write(cipherOut); + int written = channel.write(cipherOut); + LOGGER.debug("Wrote {} byte(s)", written); } cipherOut.clear(); + if (result.getHandshakeStatus() == NEED_TASK) { + handshakeStatus = runDelegatedTasks(sslEngine); + } else { + handshakeStatus = result.getHandshakeStatus(); + } + break; case BUFFER_OVERFLOW: throw new SSLException("Buffer overflow during handshake"); default: - throw new SSLException("Unexpected status " + status); + throw new SSLException("Unexpected status " + result.getStatus()); } return handshakeStatus; } - static int bufferCopy(ByteBuffer from, ByteBuffer to) { - int maxTransfer = Math.min(to.remaining(), from.remaining()); - - ByteBuffer temporaryBuffer = from.duplicate(); - temporaryBuffer.limit(temporaryBuffer.position() + maxTransfer); - to.put(temporaryBuffer); - - from.position(from.position() + maxTransfer); - - return maxTransfer; - } - public static void write(WritableByteChannel socketChannel, SSLEngine engine, ByteBuffer plainOut, ByteBuffer cypherOut) throws IOException { while (plainOut.hasRemaining()) { cypherOut.clear(); diff --git a/src/test/java/com/rabbitmq/client/test/BrokerTestCase.java b/src/test/java/com/rabbitmq/client/test/BrokerTestCase.java index 7c23a3c0e6..37cf436db4 100644 --- a/src/test/java/com/rabbitmq/client/test/BrokerTestCase.java +++ b/src/test/java/com/rabbitmq/client/test/BrokerTestCase.java @@ -1,4 +1,4 @@ -// Copyright (c) 2007-2020 VMware, Inc. or its affiliates. All rights reserved. +// Copyright (c) 2007-2021 VMware, Inc. or its affiliates. All rights reserved. // // This software, the RabbitMQ Java client library, is triple-licensed under the // Mozilla Public License 2.0 ("MPL"), the GNU General Public License version 2 @@ -28,9 +28,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import javax.net.ssl.SSLContext; import java.io.IOException; -import java.security.NoSuchAlgorithmException; import java.util.Map; import java.util.UUID; import java.util.concurrent.TimeoutException; @@ -348,7 +346,4 @@ protected String generateExchangeName() { return "exchange" + UUID.randomUUID().toString(); } - protected SSLContext getSSLContext() throws NoSuchAlgorithmException { - return TestUtils.getSSLContext(); - } } diff --git a/src/test/java/com/rabbitmq/client/test/TestUtils.java b/src/test/java/com/rabbitmq/client/test/TestUtils.java index 7544893760..c488fcff6d 100644 --- a/src/test/java/com/rabbitmq/client/test/TestUtils.java +++ b/src/test/java/com/rabbitmq/client/test/TestUtils.java @@ -1,4 +1,4 @@ -// Copyright (c) 2007-2020 VMware, Inc. or its affiliates. All rights reserved. +// Copyright (c) 2007-2021 VMware, Inc. or its affiliates. All rights reserved. // // This software, the RabbitMQ Java client library, is triple-licensed under the // Mozilla Public License 2.0 ("MPL"), the GNU General Public License version 2 @@ -26,19 +26,15 @@ import org.junit.runners.model.Statement; import org.slf4j.LoggerFactory; -import javax.net.ssl.SSLContext; import java.io.IOException; import java.net.ServerSocket; -import java.security.NoSuchAlgorithmException; import java.time.Duration; -import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.concurrent.Callable; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; -import java.util.function.BooleanSupplier; import static org.junit.Assert.assertTrue; @@ -109,22 +105,6 @@ public static void abort(Connection connection) { } } - public static SSLContext getSSLContext() throws NoSuchAlgorithmException { - SSLContext c = null; - - // pick the first protocol available, preferring TLSv1.2, then TLSv1, - // falling back to SSLv3 if running on an ancient/crippled JDK - for (String proto : Arrays.asList("TLSv1.2", "TLSv1", "SSLv3")) { - try { - c = SSLContext.getInstance(proto); - return c; - } catch (NoSuchAlgorithmException x) { - // keep trying - } - } - throw new NoSuchAlgorithmException(); - } - public static TestRule atLeast38() { return new BrokerVersionTestRule("3.8.0"); } @@ -361,4 +341,27 @@ public interface CallableFunction { } + public static boolean basicGetBasicConsume(Connection connection, String queue, final CountDownLatch latch, int msgSize) + throws Exception { + Channel channel = connection.createChannel(); + channel.queueDeclare(queue, false, true, false, null); + channel.queuePurge(queue); + + channel.basicPublish("", queue, null, new byte[msgSize]); + + String tag = channel.basicConsume(queue, false, new DefaultConsumer(channel) { + + @Override + public void handleDelivery(String consumerTag, Envelope envelope, AMQP.BasicProperties properties, byte[] body) throws IOException { + getChannel().basicAck(envelope.getDeliveryTag(), false); + latch.countDown(); + } + }); + + boolean messageReceived = latch.await(20, TimeUnit.SECONDS); + + channel.basicCancel(tag); + + return messageReceived; + } } diff --git a/src/test/java/com/rabbitmq/client/test/ssl/BadVerifiedConnection.java b/src/test/java/com/rabbitmq/client/test/ssl/BadVerifiedConnection.java index 9137213578..fe33af7dec 100644 --- a/src/test/java/com/rabbitmq/client/test/ssl/BadVerifiedConnection.java +++ b/src/test/java/com/rabbitmq/client/test/ssl/BadVerifiedConnection.java @@ -1,4 +1,4 @@ -// Copyright (c) 2007-2020 VMware, Inc. or its affiliates. All rights reserved. +// Copyright (c) 2007-2021 VMware, Inc. or its affiliates. All rights reserved. // // This software, the RabbitMQ Java client library, is triple-licensed under the // Mozilla Public License 2.0 ("MPL"), the GNU General Public License version 2 @@ -15,20 +15,13 @@ package com.rabbitmq.client.test.ssl; -import com.rabbitmq.client.test.TestUtils; import org.junit.Test; -import javax.net.ssl.KeyManagerFactory; import javax.net.ssl.SSLContext; import javax.net.ssl.SSLHandshakeException; -import javax.net.ssl.TrustManagerFactory; -import java.io.FileInputStream; import java.io.IOException; -import java.security.*; -import java.security.cert.CertificateException; import java.util.concurrent.TimeoutException; -import static org.junit.Assert.assertNotNull; import static org.junit.Assert.fail; /** @@ -39,44 +32,10 @@ public class BadVerifiedConnection extends UnverifiedConnection { public void openConnection() throws IOException, TimeoutException { try { - String keystorePath = System.getProperty("test-keystore.empty"); - assertNotNull(keystorePath); - String keystorePasswd = System.getProperty("test-keystore.password"); - assertNotNull(keystorePasswd); - char [] keystorePassword = keystorePasswd.toCharArray(); - - KeyStore tks = KeyStore.getInstance("JKS"); - tks.load(new FileInputStream(keystorePath), keystorePassword); - - TrustManagerFactory tmf = TrustManagerFactory.getInstance("SunX509"); - tmf.init(tks); - - String p12Path = System.getProperty("test-client-cert.path"); - assertNotNull(p12Path); - String p12Passwd = System.getProperty("test-client-cert.password"); - assertNotNull(p12Passwd); - KeyStore ks = KeyStore.getInstance("PKCS12"); - char [] p12Password = p12Passwd.toCharArray(); - ks.load(new FileInputStream(p12Path), p12Password); - - KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509"); - kmf.init(ks, p12Password); - - SSLContext c = getSSLContext(); - c.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null); - - connectionFactory = TestUtils.connectionFactory(); + SSLContext c = TlsTestUtils.badVerifiedSslContext(); connectionFactory.useSslProtocol(c); - } catch (NoSuchAlgorithmException ex) { - throw new IOException(ex.toString()); - } catch (KeyManagementException ex) { - throw new IOException(ex.toString()); - } catch (KeyStoreException ex) { - throw new IOException(ex.toString()); - } catch (CertificateException ex) { - throw new IOException(ex.toString()); - } catch (UnrecoverableKeyException ex) { - throw new IOException(ex.toString()); + } catch (Exception ex) { + throw new IOException(ex); } try { diff --git a/src/test/java/com/rabbitmq/client/test/ssl/HostnameVerification.java b/src/test/java/com/rabbitmq/client/test/ssl/HostnameVerification.java index 36a66a940d..acbfe48260 100644 --- a/src/test/java/com/rabbitmq/client/test/ssl/HostnameVerification.java +++ b/src/test/java/com/rabbitmq/client/test/ssl/HostnameVerification.java @@ -1,4 +1,4 @@ -// Copyright (c) 2007-2020 VMware, Inc. or its affiliates. All rights reserved. +// Copyright (c) 2007-2021 VMware, Inc. or its affiliates. All rights reserved. // // This software, the RabbitMQ Java client library, is triple-licensed under the // Mozilla Public License 2.0 ("MPL"), the GNU General Public License version 2 @@ -24,17 +24,11 @@ import org.junit.runner.RunWith; import org.junit.runners.Parameterized; -import javax.net.ssl.KeyManagerFactory; import javax.net.ssl.SSLContext; import javax.net.ssl.SSLHandshakeException; -import javax.net.ssl.TrustManagerFactory; -import java.io.FileInputStream; -import java.security.KeyStore; import java.util.function.Consumer; -import static com.rabbitmq.client.test.TestUtils.getSSLContext; import static java.util.Collections.singletonList; -import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -73,32 +67,7 @@ private static Consumer enableHostnameVerification() { @BeforeClass public static void initCrypto() throws Exception { - String keystorePath = System.getProperty("test-keystore.ca"); - assertNotNull(keystorePath); - String keystorePasswd = System.getProperty("test-keystore.password"); - assertNotNull(keystorePasswd); - char[] keystorePassword = keystorePasswd.toCharArray(); - - KeyStore tks = KeyStore.getInstance("JKS"); - tks.load(new FileInputStream(keystorePath), keystorePassword); - - TrustManagerFactory tmf = TrustManagerFactory.getInstance("SunX509"); - tmf.init(tks); - - String p12Path = System.getProperty("test-client-cert.path"); - assertNotNull(p12Path); - String p12Passwd = System.getProperty("test-client-cert.password"); - assertNotNull(p12Passwd); - - KeyStore ks = KeyStore.getInstance("PKCS12"); - char[] p12Password = p12Passwd.toCharArray(); - ks.load(new FileInputStream(p12Path), p12Password); - - KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509"); - kmf.init(ks, p12Password); - - sslContext = getSSLContext(); - sslContext.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null); + sslContext = TlsTestUtils.verifiedSslContext(); } @Test(expected = SSLHandshakeException.class) diff --git a/src/test/java/com/rabbitmq/client/test/ssl/NioTlsUnverifiedConnection.java b/src/test/java/com/rabbitmq/client/test/ssl/NioTlsUnverifiedConnection.java index 29fe35899e..37048739e2 100644 --- a/src/test/java/com/rabbitmq/client/test/ssl/NioTlsUnverifiedConnection.java +++ b/src/test/java/com/rabbitmq/client/test/ssl/NioTlsUnverifiedConnection.java @@ -1,4 +1,4 @@ -// Copyright (c) 2007-2020 VMware, Inc. or its affiliates. All rights reserved. +// Copyright (c) 2007-2021 VMware, Inc. or its affiliates. All rights reserved. // // This software, the RabbitMQ Java client library, is triple-licensed under the // Mozilla Public License 2.0 ("MPL"), the GNU General Public License version 2 @@ -18,6 +18,10 @@ import com.rabbitmq.client.*; import com.rabbitmq.client.impl.nio.NioParams; import com.rabbitmq.client.test.BrokerTestCase; +import com.rabbitmq.client.test.TestUtils; +import java.util.concurrent.atomic.AtomicReference; +import javax.net.ssl.SSLContext; +import javax.net.ssl.TrustManager; import org.junit.Test; import org.slf4j.LoggerFactory; @@ -28,6 +32,8 @@ import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicBoolean; +import static com.rabbitmq.client.test.TestUtils.basicGetBasicConsume; +import static org.assertj.core.api.Assertions.assertThat; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -76,6 +82,29 @@ public void connectionGetConsume() throws Exception { assertTrue("Message has not been received", messagesReceived); } + @Test + public void connectionGetConsumeProtocols() throws Exception { + String [] protocols = new String[] {"TLSv1.2", "TLSv1.3"}; + for (String protocol : protocols) { + SSLContext sslContext = SSLContext.getInstance(protocol); + sslContext.init(null, new TrustManager[] {new TrustEverythingTrustManager()}, null); + ConnectionFactory cf = TestUtils.connectionFactory(); + cf.useSslProtocol(sslContext); + cf.useNio(); + AtomicReference engine = new AtomicReference<>(); + cf.setNioParams(new NioParams() + .setSslEngineConfigurator(sslEngine -> engine.set(sslEngine))); + try (Connection c = cf.newConnection()) { + CountDownLatch latch = new CountDownLatch(1); + basicGetBasicConsume(c, QUEUE, latch, 100); + boolean messagesReceived = latch.await(5, TimeUnit.SECONDS); + assertTrue("Message has not been received", messagesReceived); + assertThat(engine.get()).isNotNull(); + assertThat(engine.get().getEnabledProtocols()).contains(protocol); + } + } + } + @Test public void socketChannelConfigurator() throws Exception { ConnectionFactory connectionFactory = new ConnectionFactory(); connectionFactory.useNio(); @@ -119,28 +148,4 @@ private void sendAndVerifyMessage(int size) throws Exception { assertTrue("Message has not been received", messageReceived); } - private boolean basicGetBasicConsume(Connection connection, String queue, final CountDownLatch latch, int msgSize) - throws Exception { - Channel channel = connection.createChannel(); - channel.queueDeclare(queue, false, false, false, null); - channel.queuePurge(queue); - - channel.basicPublish("", queue, null, new byte[msgSize]); - - String tag = channel.basicConsume(queue, false, new DefaultConsumer(channel) { - - @Override - public void handleDelivery(String consumerTag, Envelope envelope, AMQP.BasicProperties properties, byte[] body) throws IOException { - getChannel().basicAck(envelope.getDeliveryTag(), false); - latch.countDown(); - } - }); - - boolean messageReceived = latch.await(20, TimeUnit.SECONDS); - - channel.basicCancel(tag); - - return messageReceived; - } - } diff --git a/src/test/java/com/rabbitmq/client/test/ssl/TlsConnectionLogging.java b/src/test/java/com/rabbitmq/client/test/ssl/TlsConnectionLogging.java index 3aa6fbe330..4693525ea3 100644 --- a/src/test/java/com/rabbitmq/client/test/ssl/TlsConnectionLogging.java +++ b/src/test/java/com/rabbitmq/client/test/ssl/TlsConnectionLogging.java @@ -1,4 +1,4 @@ -// Copyright (c) 2019-2020 VMware, Inc. or its affiliates. All rights reserved. +// Copyright (c) 2019-2021 VMware, Inc. or its affiliates. All rights reserved. // // This software, the RabbitMQ Java client library, is triple-licensed under the // Mozilla Public License 2.0 ("MPL"), the GNU General Public License version 2 @@ -65,7 +65,7 @@ public static Function> nio() { @Test public void certificateInfoAreProperlyExtracted() throws Exception { - SSLContext sslContext = TestUtils.getSSLContext(); + SSLContext sslContext = TlsTestUtils.getSSLContext(); sslContext.init(null, new TrustManager[]{new AlwaysTrustTrustManager()}, null); ConnectionFactory connectionFactory = TestUtils.connectionFactory(); connectionFactory.useSslProtocol(sslContext); diff --git a/src/test/java/com/rabbitmq/client/test/ssl/TlsTestUtils.java b/src/test/java/com/rabbitmq/client/test/ssl/TlsTestUtils.java new file mode 100644 index 0000000000..891bec7f04 --- /dev/null +++ b/src/test/java/com/rabbitmq/client/test/ssl/TlsTestUtils.java @@ -0,0 +1,115 @@ +// Copyright (c) 2021 VMware, Inc. or its affiliates. All rights reserved. +// +// This software, the RabbitMQ Java client library, is triple-licensed under the +// Mozilla Public License 2.0 ("MPL"), the GNU General Public License version 2 +// ("GPL") and the Apache License version 2 ("ASL"). For the MPL, please see +// LICENSE-MPL-RabbitMQ. For the GPL, please see LICENSE-GPL2. For the ASL, +// please see LICENSE-APACHE2. +// +// This software is distributed on an "AS IS" basis, WITHOUT WARRANTY OF ANY KIND, +// either express or implied. See the LICENSE file for specific language governing +// rights and limitations of this software. +// +// If you have any questions regarding licensing, please contact us at +// info@rabbitmq.com. + +package com.rabbitmq.client.test.ssl; + +import static org.junit.Assert.assertNotNull; + +import java.io.FileInputStream; +import java.security.KeyStore; +import java.security.NoSuchAlgorithmException; +import java.util.Arrays; +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.SSLContext; +import javax.net.ssl.TrustManagerFactory; + +class TlsTestUtils { + + private TlsTestUtils() {} + + static SSLContext badVerifiedSslContext() throws Exception { + return verifiedSslContext(() -> getSSLContext(), emptyKeystoreCa()); + } + + static SSLContext verifiedSslContext() throws Exception { + return verifiedSslContext(() -> getSSLContext(), keystoreCa()); + } + + static SSLContext verifiedSslContext(CallableSupplier sslContextSupplier) throws Exception { + return verifiedSslContext(sslContextSupplier, keystoreCa()); + } + + static SSLContext verifiedSslContext(CallableSupplier sslContextSupplier, String keystorePath) throws Exception { + // for local testing, run ./mvnw test-compile -Dtest-tls-certs.dir=/tmp/tls-gen/basic + // (generates the Java keystores) + assertNotNull(keystorePath); + String keystorePasswd = keystorePassword(); + assertNotNull(keystorePasswd); + char [] keystorePassword = keystorePasswd.toCharArray(); + + KeyStore tks = KeyStore.getInstance("JKS"); + tks.load(new FileInputStream(keystorePath), keystorePassword); + + TrustManagerFactory tmf = TrustManagerFactory.getInstance("SunX509"); + tmf.init(tks); + + String p12Path = clientCertPath(); + assertNotNull(p12Path); + String p12Passwd = clientCertPassword(); + assertNotNull(p12Passwd); + KeyStore ks = KeyStore.getInstance("PKCS12"); + char [] p12Password = p12Passwd.toCharArray(); + ks.load(new FileInputStream(p12Path), p12Password); + + KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509"); + kmf.init(ks, p12Password); + + SSLContext c = sslContextSupplier.get(); + c.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null); + return c; + } + + static String keystoreCa() { + return System.getProperty("test-keystore.ca", "./target/ca.keystore"); + } + + static String emptyKeystoreCa() { + return System.getProperty("test-keystore.empty", "./target/empty.keystore"); + } + + static String keystorePassword() { + return System.getProperty("test-keystore.password", "bunnies"); + } + + static String clientCertPath() { + return System.getProperty("test-client-cert.path", "/tmp/tls-gen/basic/client/keycert.p12"); + } + + static String clientCertPassword() { + return System.getProperty("test-client-cert.password", ""); + } + + public static SSLContext getSSLContext() throws NoSuchAlgorithmException { + SSLContext c; + + // pick the first protocol available, preferring TLSv1.2, then TLSv1, + // falling back to SSLv3 if running on an ancient/crippled JDK + for (String proto : Arrays.asList("TLSv1.3", "TLSv1.2", "TLSv1", "SSLv3")) { + try { + c = SSLContext.getInstance(proto); + return c; + } catch (NoSuchAlgorithmException x) { + // keep trying + } + } + throw new NoSuchAlgorithmException(); + } + + @FunctionalInterface + interface CallableSupplier { + + T get() throws Exception; + } +} diff --git a/src/test/java/com/rabbitmq/client/test/ssl/UnverifiedConnection.java b/src/test/java/com/rabbitmq/client/test/ssl/UnverifiedConnection.java index a14a257c24..39e1e90ba5 100644 --- a/src/test/java/com/rabbitmq/client/test/ssl/UnverifiedConnection.java +++ b/src/test/java/com/rabbitmq/client/test/ssl/UnverifiedConnection.java @@ -1,4 +1,4 @@ -// Copyright (c) 2007-2020 VMware, Inc. or its affiliates. All rights reserved. +// Copyright (c) 2007-2021 VMware, Inc. or its affiliates. All rights reserved. // // This software, the RabbitMQ Java client library, is triple-licensed under the // Mozilla Public License 2.0 ("MPL"), the GNU General Public License version 2 @@ -21,8 +21,6 @@ import org.slf4j.LoggerFactory; import java.io.IOException; -import java.security.KeyManagementException; -import java.security.NoSuchAlgorithmException; import java.util.concurrent.TimeoutException; import static org.junit.Assert.*; @@ -37,10 +35,8 @@ public void openConnection() throws IOException, TimeoutException { try { connectionFactory.useSslProtocol(); - } catch (NoSuchAlgorithmException ex) { - throw new IOException(ex.toString()); - } catch (KeyManagementException ex) { - throw new IOException(ex.toString()); + } catch (Exception ex) { + throw new IOException(ex); } int attempt = 0; diff --git a/src/test/java/com/rabbitmq/client/test/ssl/VerifiedConnection.java b/src/test/java/com/rabbitmq/client/test/ssl/VerifiedConnection.java index 50d4d9003b..0f82fb1194 100644 --- a/src/test/java/com/rabbitmq/client/test/ssl/VerifiedConnection.java +++ b/src/test/java/com/rabbitmq/client/test/ssl/VerifiedConnection.java @@ -1,4 +1,4 @@ -// Copyright (c) 2007-2020 VMware, Inc. or its affiliates. All rights reserved. +// Copyright (c) 2007-2021 VMware, Inc. or its affiliates. All rights reserved. // // This software, the RabbitMQ Java client library, is triple-licensed under the // Mozilla Public License 2.0 ("MPL"), the GNU General Public License version 2 @@ -15,25 +15,25 @@ package com.rabbitmq.client.test.ssl; -import static org.junit.Assert.assertNotNull; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; -import java.io.FileInputStream; +import com.rabbitmq.client.Connection; +import com.rabbitmq.client.impl.nio.NioParams; import java.io.IOException; -import java.security.KeyManagementException; -import java.security.KeyStore; -import java.security.KeyStoreException; -import java.security.NoSuchAlgorithmException; -import java.security.UnrecoverableKeyException; -import java.security.cert.CertificateException; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; -import javax.net.ssl.KeyManagerFactory; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Supplier; import javax.net.ssl.SSLContext; -import javax.net.ssl.TrustManagerFactory; +import javax.net.ssl.SSLSocket; import com.rabbitmq.client.ConnectionFactory; import com.rabbitmq.client.test.TestUtils; +import org.junit.Test; import org.slf4j.LoggerFactory; /** @@ -45,44 +45,11 @@ public class VerifiedConnection extends UnverifiedConnection { public void openConnection() throws IOException, TimeoutException { try { - String keystorePath = System.getProperty("test-keystore.ca"); - assertNotNull(keystorePath); - String keystorePasswd = System.getProperty("test-keystore.password"); - assertNotNull(keystorePasswd); - char [] keystorePassword = keystorePasswd.toCharArray(); - - KeyStore tks = KeyStore.getInstance("JKS"); - tks.load(new FileInputStream(keystorePath), keystorePassword); - - TrustManagerFactory tmf = TrustManagerFactory.getInstance("SunX509"); - tmf.init(tks); - - String p12Path = System.getProperty("test-client-cert.path"); - assertNotNull(p12Path); - String p12Passwd = System.getProperty("test-client-cert.password"); - assertNotNull(p12Passwd); - KeyStore ks = KeyStore.getInstance("PKCS12"); - char [] p12Password = p12Passwd.toCharArray(); - ks.load(new FileInputStream(p12Path), p12Password); - - KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509"); - kmf.init(ks, p12Password); - - SSLContext c = getSSLContext(); - c.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null); - + SSLContext c = TlsTestUtils.verifiedSslContext(); connectionFactory = TestUtils.connectionFactory(); connectionFactory.useSslProtocol(c); - } catch (NoSuchAlgorithmException ex) { - throw new IOException(ex.toString()); - } catch (KeyManagementException ex) { - throw new IOException(ex.toString()); - } catch (KeyStoreException ex) { - throw new IOException(ex.toString()); - } catch (CertificateException ex) { - throw new IOException(ex.toString()); - } catch (UnrecoverableKeyException ex) { - throw new IOException(ex.toString()); + } catch (Exception ex) { + throw new IOException(ex); } int attempt = 0; @@ -99,4 +66,36 @@ public void openConnection() fail("Couldn't open TLS connection after 3 attempts"); } } + + @Test + public void connectionGetConsumeProtocols() throws Exception { + String [] protocols = new String[] {"TLSv1.2", "TLSv1.3"}; + for (String protocol : protocols) { + SSLContext sslContext = SSLContext.getInstance(protocol); + ConnectionFactory cf = TestUtils.connectionFactory(); + cf.useSslProtocol(TlsTestUtils.verifiedSslContext(() -> sslContext)); + AtomicReference> protocolsSupplier = new AtomicReference<>(); + if (TestUtils.USE_NIO) { + cf.useNio(); + cf.setNioParams(new NioParams() + .setSslEngineConfigurator(sslEngine -> { + protocolsSupplier.set(() -> sslEngine.getEnabledProtocols()); + })); + } else { + cf.setSocketConfigurator(socket -> { + SSLSocket s = (SSLSocket) socket; + protocolsSupplier.set(() -> s.getEnabledProtocols()); + }); + } + try (Connection c = cf.newConnection()) { + CountDownLatch latch = new CountDownLatch(1); + TestUtils.basicGetBasicConsume(c, VerifiedConnection.class.getName(), latch, 100); + boolean messagesReceived = latch.await(5, TimeUnit.SECONDS); + assertTrue("Message has not been received", messagesReceived); + assertThat(protocolsSupplier.get()).isNotNull(); + assertThat(protocolsSupplier.get().get()).contains(protocol); + } + } + } + } diff --git a/src/test/resources/logback-test.xml b/src/test/resources/logback-test.xml index 4bd2e37606..ee88f442c2 100644 --- a/src/test/resources/logback-test.xml +++ b/src/test/resources/logback-test.xml @@ -5,7 +5,7 @@ - + \ No newline at end of file