diff --git a/src/main/java/com/rabbitmq/client/impl/nio/SslEngineHelper.java b/src/main/java/com/rabbitmq/client/impl/nio/SslEngineHelper.java index 1e7e3a0793..bcefe8b205 100644 --- a/src/main/java/com/rabbitmq/client/impl/nio/SslEngineHelper.java +++ b/src/main/java/com/rabbitmq/client/impl/nio/SslEngineHelper.java @@ -1,4 +1,4 @@ -// Copyright (c) 2007-2020 VMware, Inc. or its affiliates. All rights reserved. +// Copyright (c) 2007-2021 VMware, Inc. or its affiliates. All rights reserved. // // This software, the RabbitMQ Java client library, is triple-licensed under the // Mozilla Public License 2.0 ("MPL"), the GNU General Public License version 2 @@ -23,8 +23,12 @@ import java.nio.channels.ReadableByteChannel; import java.nio.channels.SocketChannel; import java.nio.channels.WritableByteChannel; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; import static javax.net.ssl.SSLEngineResult.HandshakeStatus.FINISHED; +import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NEED_TASK; +import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NEED_WRAP; import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING; /** @@ -32,6 +36,8 @@ */ public class SslEngineHelper { + private static final Logger LOGGER = LoggerFactory.getLogger(SslEngineHelper.class); + public static boolean doHandshake(SocketChannel socketChannel, SSLEngine engine) throws IOException { ByteBuffer plainOut = ByteBuffer.allocate(engine.getSession().getApplicationBufferSize()); @@ -39,20 +45,36 @@ public static boolean doHandshake(SocketChannel socketChannel, SSLEngine engine) ByteBuffer cipherOut = ByteBuffer.allocate(engine.getSession().getPacketBufferSize()); ByteBuffer cipherIn = ByteBuffer.allocate(engine.getSession().getPacketBufferSize()); + LOGGER.debug("Starting TLS handshake"); + SSLEngineResult.HandshakeStatus handshakeStatus = engine.getHandshakeStatus(); + LOGGER.debug("Initial handshake status is {}", handshakeStatus); while (handshakeStatus != FINISHED && handshakeStatus != NOT_HANDSHAKING) { + LOGGER.debug("Handshake status is {}", handshakeStatus); switch (handshakeStatus) { case NEED_TASK: + LOGGER.debug("Running tasks"); handshakeStatus = runDelegatedTasks(engine); break; case NEED_UNWRAP: + LOGGER.debug("Unwrapping..."); handshakeStatus = unwrap(cipherIn, plainIn, socketChannel, engine); break; case NEED_WRAP: + LOGGER.debug("Wrapping..."); handshakeStatus = wrap(plainOut, cipherOut, socketChannel, engine); break; + case FINISHED: + break; + case NOT_HANDSHAKING: + break; + default: + throw new SSLException("Unexpected handshake status " + handshakeStatus); } } + + + LOGGER.debug("TLS handshake completed"); return true; } @@ -60,6 +82,7 @@ private static SSLEngineResult.HandshakeStatus runDelegatedTasks(SSLEngine sslEn // FIXME run in executor? Runnable runnable; while ((runnable = sslEngine.getDelegatedTask()) != null) { + LOGGER.debug("Running delegated task"); runnable.run(); } return sslEngine.getHandshakeStatus(); @@ -68,29 +91,57 @@ private static SSLEngineResult.HandshakeStatus runDelegatedTasks(SSLEngine sslEn private static SSLEngineResult.HandshakeStatus unwrap(ByteBuffer cipherIn, ByteBuffer plainIn, ReadableByteChannel channel, SSLEngine sslEngine) throws IOException { SSLEngineResult.HandshakeStatus handshakeStatus = sslEngine.getHandshakeStatus(); - - if (channel.read(cipherIn) < 0) { - throw new SSLException("Could not read from socket channel"); + LOGGER.debug("Handshake status is {} before unwrapping", handshakeStatus); + + LOGGER.debug("Cipher in position {}", cipherIn.position()); + int read; + if (cipherIn.position() == 0) { + LOGGER.debug("Reading from channel"); + read = channel.read(cipherIn); + LOGGER.debug("Read {} byte(s) from channel", read); + if (read < 0) { + throw new SSLException("Could not read from socket channel"); + } + cipherIn.flip(); + } else { + LOGGER.debug("Not reading"); } - cipherIn.flip(); SSLEngineResult.Status status; + SSLEngineResult unwrapResult; do { - SSLEngineResult unwrapResult = sslEngine.unwrap(cipherIn, plainIn); + int positionBeforeUnwrapping = cipherIn.position(); + unwrapResult = sslEngine.unwrap(cipherIn, plainIn); + LOGGER.debug("SSL engine result is {} after unwrapping", unwrapResult); status = unwrapResult.getStatus(); switch (status) { case OK: plainIn.clear(); - handshakeStatus = runDelegatedTasks(sslEngine); + if (unwrapResult.getHandshakeStatus() == NEED_TASK) { + handshakeStatus = runDelegatedTasks(sslEngine); + int newPosition = positionBeforeUnwrapping + unwrapResult.bytesConsumed(); + if (newPosition == cipherIn.limit()) { + LOGGER.debug("Clearing cipherIn because all bytes have been read and unwrapped"); + cipherIn.clear(); + } else { + LOGGER.debug("Setting cipherIn position to {} (limit is {})", newPosition, cipherIn.limit()); + cipherIn.position(positionBeforeUnwrapping + unwrapResult.bytesConsumed()); + } + } else { + handshakeStatus = unwrapResult.getHandshakeStatus(); + } break; case BUFFER_OVERFLOW: throw new SSLException("Buffer overflow during handshake"); case BUFFER_UNDERFLOW: + LOGGER.debug("Buffer underflow"); cipherIn.compact(); - int read = NioHelper.read(channel, cipherIn); + LOGGER.debug("Reading from channel..."); + read = NioHelper.read(channel, cipherIn); if(read <= 0) { retryRead(channel, cipherIn); } + LOGGER.debug("Done reading from channel..."); cipherIn.flip(); break; case CLOSED: @@ -100,9 +151,9 @@ private static SSLEngineResult.HandshakeStatus unwrap(ByteBuffer cipherIn, ByteB throw new SSLException("Unexpected status from " + unwrapResult); } } - while (cipherIn.hasRemaining()); + while (unwrapResult.getHandshakeStatus() != NEED_WRAP && unwrapResult.getHandshakeStatus() != FINISHED); - cipherIn.compact(); + LOGGER.debug("cipherIn position after unwrap {}", cipherIn.position()); return handshakeStatus; } @@ -127,36 +178,32 @@ private static int retryRead(ReadableByteChannel channel, ByteBuffer buffer) thr private static SSLEngineResult.HandshakeStatus wrap(ByteBuffer plainOut, ByteBuffer cipherOut, WritableByteChannel channel, SSLEngine sslEngine) throws IOException { SSLEngineResult.HandshakeStatus handshakeStatus = sslEngine.getHandshakeStatus(); - SSLEngineResult.Status status = sslEngine.wrap(plainOut, cipherOut).getStatus(); - switch (status) { + LOGGER.debug("Handshake status is {} before wrapping", handshakeStatus); + SSLEngineResult result = sslEngine.wrap(plainOut, cipherOut); + LOGGER.debug("SSL engine result is {} after wrapping", result); + switch (result.getStatus()) { case OK: - handshakeStatus = runDelegatedTasks(sslEngine); cipherOut.flip(); while (cipherOut.hasRemaining()) { - channel.write(cipherOut); + int written = channel.write(cipherOut); + LOGGER.debug("Wrote {} byte(s)", written); } cipherOut.clear(); + if (result.getHandshakeStatus() == NEED_TASK) { + handshakeStatus = runDelegatedTasks(sslEngine); + } else { + handshakeStatus = result.getHandshakeStatus(); + } + break; case BUFFER_OVERFLOW: throw new SSLException("Buffer overflow during handshake"); default: - throw new SSLException("Unexpected status " + status); + throw new SSLException("Unexpected status " + result.getStatus()); } return handshakeStatus; } - static int bufferCopy(ByteBuffer from, ByteBuffer to) { - int maxTransfer = Math.min(to.remaining(), from.remaining()); - - ByteBuffer temporaryBuffer = from.duplicate(); - temporaryBuffer.limit(temporaryBuffer.position() + maxTransfer); - to.put(temporaryBuffer); - - from.position(from.position() + maxTransfer); - - return maxTransfer; - } - public static void write(WritableByteChannel socketChannel, SSLEngine engine, ByteBuffer plainOut, ByteBuffer cypherOut) throws IOException { while (plainOut.hasRemaining()) { cypherOut.clear(); diff --git a/src/test/java/com/rabbitmq/client/test/BrokerTestCase.java b/src/test/java/com/rabbitmq/client/test/BrokerTestCase.java index 7c23a3c0e6..37cf436db4 100644 --- a/src/test/java/com/rabbitmq/client/test/BrokerTestCase.java +++ b/src/test/java/com/rabbitmq/client/test/BrokerTestCase.java @@ -1,4 +1,4 @@ -// Copyright (c) 2007-2020 VMware, Inc. or its affiliates. All rights reserved. +// Copyright (c) 2007-2021 VMware, Inc. or its affiliates. All rights reserved. // // This software, the RabbitMQ Java client library, is triple-licensed under the // Mozilla Public License 2.0 ("MPL"), the GNU General Public License version 2 @@ -28,9 +28,7 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; -import javax.net.ssl.SSLContext; import java.io.IOException; -import java.security.NoSuchAlgorithmException; import java.util.Map; import java.util.UUID; import java.util.concurrent.TimeoutException; @@ -348,7 +346,4 @@ protected String generateExchangeName() { return "exchange" + UUID.randomUUID().toString(); } - protected SSLContext getSSLContext() throws NoSuchAlgorithmException { - return TestUtils.getSSLContext(); - } } diff --git a/src/test/java/com/rabbitmq/client/test/TestUtils.java b/src/test/java/com/rabbitmq/client/test/TestUtils.java index 7544893760..c488fcff6d 100644 --- a/src/test/java/com/rabbitmq/client/test/TestUtils.java +++ b/src/test/java/com/rabbitmq/client/test/TestUtils.java @@ -1,4 +1,4 @@ -// Copyright (c) 2007-2020 VMware, Inc. or its affiliates. All rights reserved. +// Copyright (c) 2007-2021 VMware, Inc. or its affiliates. All rights reserved. // // This software, the RabbitMQ Java client library, is triple-licensed under the // Mozilla Public License 2.0 ("MPL"), the GNU General Public License version 2 @@ -26,19 +26,15 @@ import org.junit.runners.model.Statement; import org.slf4j.LoggerFactory; -import javax.net.ssl.SSLContext; import java.io.IOException; import java.net.ServerSocket; -import java.security.NoSuchAlgorithmException; import java.time.Duration; -import java.util.Arrays; import java.util.Collection; import java.util.Collections; import java.util.concurrent.Callable; import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; -import java.util.function.BooleanSupplier; import static org.junit.Assert.assertTrue; @@ -109,22 +105,6 @@ public static void abort(Connection connection) { } } - public static SSLContext getSSLContext() throws NoSuchAlgorithmException { - SSLContext c = null; - - // pick the first protocol available, preferring TLSv1.2, then TLSv1, - // falling back to SSLv3 if running on an ancient/crippled JDK - for (String proto : Arrays.asList("TLSv1.2", "TLSv1", "SSLv3")) { - try { - c = SSLContext.getInstance(proto); - return c; - } catch (NoSuchAlgorithmException x) { - // keep trying - } - } - throw new NoSuchAlgorithmException(); - } - public static TestRule atLeast38() { return new BrokerVersionTestRule("3.8.0"); } @@ -361,4 +341,27 @@ public interface CallableFunction { } + public static boolean basicGetBasicConsume(Connection connection, String queue, final CountDownLatch latch, int msgSize) + throws Exception { + Channel channel = connection.createChannel(); + channel.queueDeclare(queue, false, true, false, null); + channel.queuePurge(queue); + + channel.basicPublish("", queue, null, new byte[msgSize]); + + String tag = channel.basicConsume(queue, false, new DefaultConsumer(channel) { + + @Override + public void handleDelivery(String consumerTag, Envelope envelope, AMQP.BasicProperties properties, byte[] body) throws IOException { + getChannel().basicAck(envelope.getDeliveryTag(), false); + latch.countDown(); + } + }); + + boolean messageReceived = latch.await(20, TimeUnit.SECONDS); + + channel.basicCancel(tag); + + return messageReceived; + } } diff --git a/src/test/java/com/rabbitmq/client/test/ssl/BadVerifiedConnection.java b/src/test/java/com/rabbitmq/client/test/ssl/BadVerifiedConnection.java index 9137213578..fe33af7dec 100644 --- a/src/test/java/com/rabbitmq/client/test/ssl/BadVerifiedConnection.java +++ b/src/test/java/com/rabbitmq/client/test/ssl/BadVerifiedConnection.java @@ -1,4 +1,4 @@ -// Copyright (c) 2007-2020 VMware, Inc. or its affiliates. All rights reserved. +// Copyright (c) 2007-2021 VMware, Inc. or its affiliates. All rights reserved. // // This software, the RabbitMQ Java client library, is triple-licensed under the // Mozilla Public License 2.0 ("MPL"), the GNU General Public License version 2 @@ -15,20 +15,13 @@ package com.rabbitmq.client.test.ssl; -import com.rabbitmq.client.test.TestUtils; import org.junit.Test; -import javax.net.ssl.KeyManagerFactory; import javax.net.ssl.SSLContext; import javax.net.ssl.SSLHandshakeException; -import javax.net.ssl.TrustManagerFactory; -import java.io.FileInputStream; import java.io.IOException; -import java.security.*; -import java.security.cert.CertificateException; import java.util.concurrent.TimeoutException; -import static org.junit.Assert.assertNotNull; import static org.junit.Assert.fail; /** @@ -39,44 +32,10 @@ public class BadVerifiedConnection extends UnverifiedConnection { public void openConnection() throws IOException, TimeoutException { try { - String keystorePath = System.getProperty("test-keystore.empty"); - assertNotNull(keystorePath); - String keystorePasswd = System.getProperty("test-keystore.password"); - assertNotNull(keystorePasswd); - char [] keystorePassword = keystorePasswd.toCharArray(); - - KeyStore tks = KeyStore.getInstance("JKS"); - tks.load(new FileInputStream(keystorePath), keystorePassword); - - TrustManagerFactory tmf = TrustManagerFactory.getInstance("SunX509"); - tmf.init(tks); - - String p12Path = System.getProperty("test-client-cert.path"); - assertNotNull(p12Path); - String p12Passwd = System.getProperty("test-client-cert.password"); - assertNotNull(p12Passwd); - KeyStore ks = KeyStore.getInstance("PKCS12"); - char [] p12Password = p12Passwd.toCharArray(); - ks.load(new FileInputStream(p12Path), p12Password); - - KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509"); - kmf.init(ks, p12Password); - - SSLContext c = getSSLContext(); - c.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null); - - connectionFactory = TestUtils.connectionFactory(); + SSLContext c = TlsTestUtils.badVerifiedSslContext(); connectionFactory.useSslProtocol(c); - } catch (NoSuchAlgorithmException ex) { - throw new IOException(ex.toString()); - } catch (KeyManagementException ex) { - throw new IOException(ex.toString()); - } catch (KeyStoreException ex) { - throw new IOException(ex.toString()); - } catch (CertificateException ex) { - throw new IOException(ex.toString()); - } catch (UnrecoverableKeyException ex) { - throw new IOException(ex.toString()); + } catch (Exception ex) { + throw new IOException(ex); } try { diff --git a/src/test/java/com/rabbitmq/client/test/ssl/HostnameVerification.java b/src/test/java/com/rabbitmq/client/test/ssl/HostnameVerification.java index 36a66a940d..acbfe48260 100644 --- a/src/test/java/com/rabbitmq/client/test/ssl/HostnameVerification.java +++ b/src/test/java/com/rabbitmq/client/test/ssl/HostnameVerification.java @@ -1,4 +1,4 @@ -// Copyright (c) 2007-2020 VMware, Inc. or its affiliates. All rights reserved. +// Copyright (c) 2007-2021 VMware, Inc. or its affiliates. All rights reserved. // // This software, the RabbitMQ Java client library, is triple-licensed under the // Mozilla Public License 2.0 ("MPL"), the GNU General Public License version 2 @@ -24,17 +24,11 @@ import org.junit.runner.RunWith; import org.junit.runners.Parameterized; -import javax.net.ssl.KeyManagerFactory; import javax.net.ssl.SSLContext; import javax.net.ssl.SSLHandshakeException; -import javax.net.ssl.TrustManagerFactory; -import java.io.FileInputStream; -import java.security.KeyStore; import java.util.function.Consumer; -import static com.rabbitmq.client.test.TestUtils.getSSLContext; import static java.util.Collections.singletonList; -import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -73,32 +67,7 @@ private static Consumer enableHostnameVerification() { @BeforeClass public static void initCrypto() throws Exception { - String keystorePath = System.getProperty("test-keystore.ca"); - assertNotNull(keystorePath); - String keystorePasswd = System.getProperty("test-keystore.password"); - assertNotNull(keystorePasswd); - char[] keystorePassword = keystorePasswd.toCharArray(); - - KeyStore tks = KeyStore.getInstance("JKS"); - tks.load(new FileInputStream(keystorePath), keystorePassword); - - TrustManagerFactory tmf = TrustManagerFactory.getInstance("SunX509"); - tmf.init(tks); - - String p12Path = System.getProperty("test-client-cert.path"); - assertNotNull(p12Path); - String p12Passwd = System.getProperty("test-client-cert.password"); - assertNotNull(p12Passwd); - - KeyStore ks = KeyStore.getInstance("PKCS12"); - char[] p12Password = p12Passwd.toCharArray(); - ks.load(new FileInputStream(p12Path), p12Password); - - KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509"); - kmf.init(ks, p12Password); - - sslContext = getSSLContext(); - sslContext.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null); + sslContext = TlsTestUtils.verifiedSslContext(); } @Test(expected = SSLHandshakeException.class) diff --git a/src/test/java/com/rabbitmq/client/test/ssl/NioTlsUnverifiedConnection.java b/src/test/java/com/rabbitmq/client/test/ssl/NioTlsUnverifiedConnection.java index 29fe35899e..37048739e2 100644 --- a/src/test/java/com/rabbitmq/client/test/ssl/NioTlsUnverifiedConnection.java +++ b/src/test/java/com/rabbitmq/client/test/ssl/NioTlsUnverifiedConnection.java @@ -1,4 +1,4 @@ -// Copyright (c) 2007-2020 VMware, Inc. or its affiliates. All rights reserved. +// Copyright (c) 2007-2021 VMware, Inc. or its affiliates. All rights reserved. // // This software, the RabbitMQ Java client library, is triple-licensed under the // Mozilla Public License 2.0 ("MPL"), the GNU General Public License version 2 @@ -18,6 +18,10 @@ import com.rabbitmq.client.*; import com.rabbitmq.client.impl.nio.NioParams; import com.rabbitmq.client.test.BrokerTestCase; +import com.rabbitmq.client.test.TestUtils; +import java.util.concurrent.atomic.AtomicReference; +import javax.net.ssl.SSLContext; +import javax.net.ssl.TrustManager; import org.junit.Test; import org.slf4j.LoggerFactory; @@ -28,6 +32,8 @@ import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicBoolean; +import static com.rabbitmq.client.test.TestUtils.basicGetBasicConsume; +import static org.assertj.core.api.Assertions.assertThat; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -76,6 +82,29 @@ public void connectionGetConsume() throws Exception { assertTrue("Message has not been received", messagesReceived); } + @Test + public void connectionGetConsumeProtocols() throws Exception { + String [] protocols = new String[] {"TLSv1.2", "TLSv1.3"}; + for (String protocol : protocols) { + SSLContext sslContext = SSLContext.getInstance(protocol); + sslContext.init(null, new TrustManager[] {new TrustEverythingTrustManager()}, null); + ConnectionFactory cf = TestUtils.connectionFactory(); + cf.useSslProtocol(sslContext); + cf.useNio(); + AtomicReference engine = new AtomicReference<>(); + cf.setNioParams(new NioParams() + .setSslEngineConfigurator(sslEngine -> engine.set(sslEngine))); + try (Connection c = cf.newConnection()) { + CountDownLatch latch = new CountDownLatch(1); + basicGetBasicConsume(c, QUEUE, latch, 100); + boolean messagesReceived = latch.await(5, TimeUnit.SECONDS); + assertTrue("Message has not been received", messagesReceived); + assertThat(engine.get()).isNotNull(); + assertThat(engine.get().getEnabledProtocols()).contains(protocol); + } + } + } + @Test public void socketChannelConfigurator() throws Exception { ConnectionFactory connectionFactory = new ConnectionFactory(); connectionFactory.useNio(); @@ -119,28 +148,4 @@ private void sendAndVerifyMessage(int size) throws Exception { assertTrue("Message has not been received", messageReceived); } - private boolean basicGetBasicConsume(Connection connection, String queue, final CountDownLatch latch, int msgSize) - throws Exception { - Channel channel = connection.createChannel(); - channel.queueDeclare(queue, false, false, false, null); - channel.queuePurge(queue); - - channel.basicPublish("", queue, null, new byte[msgSize]); - - String tag = channel.basicConsume(queue, false, new DefaultConsumer(channel) { - - @Override - public void handleDelivery(String consumerTag, Envelope envelope, AMQP.BasicProperties properties, byte[] body) throws IOException { - getChannel().basicAck(envelope.getDeliveryTag(), false); - latch.countDown(); - } - }); - - boolean messageReceived = latch.await(20, TimeUnit.SECONDS); - - channel.basicCancel(tag); - - return messageReceived; - } - } diff --git a/src/test/java/com/rabbitmq/client/test/ssl/TlsConnectionLogging.java b/src/test/java/com/rabbitmq/client/test/ssl/TlsConnectionLogging.java index 3aa6fbe330..4693525ea3 100644 --- a/src/test/java/com/rabbitmq/client/test/ssl/TlsConnectionLogging.java +++ b/src/test/java/com/rabbitmq/client/test/ssl/TlsConnectionLogging.java @@ -1,4 +1,4 @@ -// Copyright (c) 2019-2020 VMware, Inc. or its affiliates. All rights reserved. +// Copyright (c) 2019-2021 VMware, Inc. or its affiliates. All rights reserved. // // This software, the RabbitMQ Java client library, is triple-licensed under the // Mozilla Public License 2.0 ("MPL"), the GNU General Public License version 2 @@ -65,7 +65,7 @@ public static Function> nio() { @Test public void certificateInfoAreProperlyExtracted() throws Exception { - SSLContext sslContext = TestUtils.getSSLContext(); + SSLContext sslContext = TlsTestUtils.getSSLContext(); sslContext.init(null, new TrustManager[]{new AlwaysTrustTrustManager()}, null); ConnectionFactory connectionFactory = TestUtils.connectionFactory(); connectionFactory.useSslProtocol(sslContext); diff --git a/src/test/java/com/rabbitmq/client/test/ssl/TlsTestUtils.java b/src/test/java/com/rabbitmq/client/test/ssl/TlsTestUtils.java new file mode 100644 index 0000000000..891bec7f04 --- /dev/null +++ b/src/test/java/com/rabbitmq/client/test/ssl/TlsTestUtils.java @@ -0,0 +1,115 @@ +// Copyright (c) 2021 VMware, Inc. or its affiliates. All rights reserved. +// +// This software, the RabbitMQ Java client library, is triple-licensed under the +// Mozilla Public License 2.0 ("MPL"), the GNU General Public License version 2 +// ("GPL") and the Apache License version 2 ("ASL"). For the MPL, please see +// LICENSE-MPL-RabbitMQ. For the GPL, please see LICENSE-GPL2. For the ASL, +// please see LICENSE-APACHE2. +// +// This software is distributed on an "AS IS" basis, WITHOUT WARRANTY OF ANY KIND, +// either express or implied. See the LICENSE file for specific language governing +// rights and limitations of this software. +// +// If you have any questions regarding licensing, please contact us at +// info@rabbitmq.com. + +package com.rabbitmq.client.test.ssl; + +import static org.junit.Assert.assertNotNull; + +import java.io.FileInputStream; +import java.security.KeyStore; +import java.security.NoSuchAlgorithmException; +import java.util.Arrays; +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.SSLContext; +import javax.net.ssl.TrustManagerFactory; + +class TlsTestUtils { + + private TlsTestUtils() {} + + static SSLContext badVerifiedSslContext() throws Exception { + return verifiedSslContext(() -> getSSLContext(), emptyKeystoreCa()); + } + + static SSLContext verifiedSslContext() throws Exception { + return verifiedSslContext(() -> getSSLContext(), keystoreCa()); + } + + static SSLContext verifiedSslContext(CallableSupplier sslContextSupplier) throws Exception { + return verifiedSslContext(sslContextSupplier, keystoreCa()); + } + + static SSLContext verifiedSslContext(CallableSupplier sslContextSupplier, String keystorePath) throws Exception { + // for local testing, run ./mvnw test-compile -Dtest-tls-certs.dir=/tmp/tls-gen/basic + // (generates the Java keystores) + assertNotNull(keystorePath); + String keystorePasswd = keystorePassword(); + assertNotNull(keystorePasswd); + char [] keystorePassword = keystorePasswd.toCharArray(); + + KeyStore tks = KeyStore.getInstance("JKS"); + tks.load(new FileInputStream(keystorePath), keystorePassword); + + TrustManagerFactory tmf = TrustManagerFactory.getInstance("SunX509"); + tmf.init(tks); + + String p12Path = clientCertPath(); + assertNotNull(p12Path); + String p12Passwd = clientCertPassword(); + assertNotNull(p12Passwd); + KeyStore ks = KeyStore.getInstance("PKCS12"); + char [] p12Password = p12Passwd.toCharArray(); + ks.load(new FileInputStream(p12Path), p12Password); + + KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509"); + kmf.init(ks, p12Password); + + SSLContext c = sslContextSupplier.get(); + c.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null); + return c; + } + + static String keystoreCa() { + return System.getProperty("test-keystore.ca", "./target/ca.keystore"); + } + + static String emptyKeystoreCa() { + return System.getProperty("test-keystore.empty", "./target/empty.keystore"); + } + + static String keystorePassword() { + return System.getProperty("test-keystore.password", "bunnies"); + } + + static String clientCertPath() { + return System.getProperty("test-client-cert.path", "/tmp/tls-gen/basic/client/keycert.p12"); + } + + static String clientCertPassword() { + return System.getProperty("test-client-cert.password", ""); + } + + public static SSLContext getSSLContext() throws NoSuchAlgorithmException { + SSLContext c; + + // pick the first protocol available, preferring TLSv1.2, then TLSv1, + // falling back to SSLv3 if running on an ancient/crippled JDK + for (String proto : Arrays.asList("TLSv1.3", "TLSv1.2", "TLSv1", "SSLv3")) { + try { + c = SSLContext.getInstance(proto); + return c; + } catch (NoSuchAlgorithmException x) { + // keep trying + } + } + throw new NoSuchAlgorithmException(); + } + + @FunctionalInterface + interface CallableSupplier { + + T get() throws Exception; + } +} diff --git a/src/test/java/com/rabbitmq/client/test/ssl/UnverifiedConnection.java b/src/test/java/com/rabbitmq/client/test/ssl/UnverifiedConnection.java index a14a257c24..39e1e90ba5 100644 --- a/src/test/java/com/rabbitmq/client/test/ssl/UnverifiedConnection.java +++ b/src/test/java/com/rabbitmq/client/test/ssl/UnverifiedConnection.java @@ -1,4 +1,4 @@ -// Copyright (c) 2007-2020 VMware, Inc. or its affiliates. All rights reserved. +// Copyright (c) 2007-2021 VMware, Inc. or its affiliates. All rights reserved. // // This software, the RabbitMQ Java client library, is triple-licensed under the // Mozilla Public License 2.0 ("MPL"), the GNU General Public License version 2 @@ -21,8 +21,6 @@ import org.slf4j.LoggerFactory; import java.io.IOException; -import java.security.KeyManagementException; -import java.security.NoSuchAlgorithmException; import java.util.concurrent.TimeoutException; import static org.junit.Assert.*; @@ -37,10 +35,8 @@ public void openConnection() throws IOException, TimeoutException { try { connectionFactory.useSslProtocol(); - } catch (NoSuchAlgorithmException ex) { - throw new IOException(ex.toString()); - } catch (KeyManagementException ex) { - throw new IOException(ex.toString()); + } catch (Exception ex) { + throw new IOException(ex); } int attempt = 0; diff --git a/src/test/java/com/rabbitmq/client/test/ssl/VerifiedConnection.java b/src/test/java/com/rabbitmq/client/test/ssl/VerifiedConnection.java index 50d4d9003b..0f82fb1194 100644 --- a/src/test/java/com/rabbitmq/client/test/ssl/VerifiedConnection.java +++ b/src/test/java/com/rabbitmq/client/test/ssl/VerifiedConnection.java @@ -1,4 +1,4 @@ -// Copyright (c) 2007-2020 VMware, Inc. or its affiliates. All rights reserved. +// Copyright (c) 2007-2021 VMware, Inc. or its affiliates. All rights reserved. // // This software, the RabbitMQ Java client library, is triple-licensed under the // Mozilla Public License 2.0 ("MPL"), the GNU General Public License version 2 @@ -15,25 +15,25 @@ package com.rabbitmq.client.test.ssl; -import static org.junit.Assert.assertNotNull; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; -import java.io.FileInputStream; +import com.rabbitmq.client.Connection; +import com.rabbitmq.client.impl.nio.NioParams; import java.io.IOException; -import java.security.KeyManagementException; -import java.security.KeyStore; -import java.security.KeyStoreException; -import java.security.NoSuchAlgorithmException; -import java.security.UnrecoverableKeyException; -import java.security.cert.CertificateException; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; -import javax.net.ssl.KeyManagerFactory; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Supplier; import javax.net.ssl.SSLContext; -import javax.net.ssl.TrustManagerFactory; +import javax.net.ssl.SSLSocket; import com.rabbitmq.client.ConnectionFactory; import com.rabbitmq.client.test.TestUtils; +import org.junit.Test; import org.slf4j.LoggerFactory; /** @@ -45,44 +45,11 @@ public class VerifiedConnection extends UnverifiedConnection { public void openConnection() throws IOException, TimeoutException { try { - String keystorePath = System.getProperty("test-keystore.ca"); - assertNotNull(keystorePath); - String keystorePasswd = System.getProperty("test-keystore.password"); - assertNotNull(keystorePasswd); - char [] keystorePassword = keystorePasswd.toCharArray(); - - KeyStore tks = KeyStore.getInstance("JKS"); - tks.load(new FileInputStream(keystorePath), keystorePassword); - - TrustManagerFactory tmf = TrustManagerFactory.getInstance("SunX509"); - tmf.init(tks); - - String p12Path = System.getProperty("test-client-cert.path"); - assertNotNull(p12Path); - String p12Passwd = System.getProperty("test-client-cert.password"); - assertNotNull(p12Passwd); - KeyStore ks = KeyStore.getInstance("PKCS12"); - char [] p12Password = p12Passwd.toCharArray(); - ks.load(new FileInputStream(p12Path), p12Password); - - KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509"); - kmf.init(ks, p12Password); - - SSLContext c = getSSLContext(); - c.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null); - + SSLContext c = TlsTestUtils.verifiedSslContext(); connectionFactory = TestUtils.connectionFactory(); connectionFactory.useSslProtocol(c); - } catch (NoSuchAlgorithmException ex) { - throw new IOException(ex.toString()); - } catch (KeyManagementException ex) { - throw new IOException(ex.toString()); - } catch (KeyStoreException ex) { - throw new IOException(ex.toString()); - } catch (CertificateException ex) { - throw new IOException(ex.toString()); - } catch (UnrecoverableKeyException ex) { - throw new IOException(ex.toString()); + } catch (Exception ex) { + throw new IOException(ex); } int attempt = 0; @@ -99,4 +66,36 @@ public void openConnection() fail("Couldn't open TLS connection after 3 attempts"); } } + + @Test + public void connectionGetConsumeProtocols() throws Exception { + String [] protocols = new String[] {"TLSv1.2", "TLSv1.3"}; + for (String protocol : protocols) { + SSLContext sslContext = SSLContext.getInstance(protocol); + ConnectionFactory cf = TestUtils.connectionFactory(); + cf.useSslProtocol(TlsTestUtils.verifiedSslContext(() -> sslContext)); + AtomicReference> protocolsSupplier = new AtomicReference<>(); + if (TestUtils.USE_NIO) { + cf.useNio(); + cf.setNioParams(new NioParams() + .setSslEngineConfigurator(sslEngine -> { + protocolsSupplier.set(() -> sslEngine.getEnabledProtocols()); + })); + } else { + cf.setSocketConfigurator(socket -> { + SSLSocket s = (SSLSocket) socket; + protocolsSupplier.set(() -> s.getEnabledProtocols()); + }); + } + try (Connection c = cf.newConnection()) { + CountDownLatch latch = new CountDownLatch(1); + TestUtils.basicGetBasicConsume(c, VerifiedConnection.class.getName(), latch, 100); + boolean messagesReceived = latch.await(5, TimeUnit.SECONDS); + assertTrue("Message has not been received", messagesReceived); + assertThat(protocolsSupplier.get()).isNotNull(); + assertThat(protocolsSupplier.get().get()).contains(protocol); + } + } + } + } diff --git a/src/test/resources/logback-test.xml b/src/test/resources/logback-test.xml index 4bd2e37606..ee88f442c2 100644 --- a/src/test/resources/logback-test.xml +++ b/src/test/resources/logback-test.xml @@ -5,7 +5,7 @@ - + \ No newline at end of file