Skip to content

Fix handshake with NIO on TLS 1.3 #716

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 74 additions & 27 deletions src/main/java/com/rabbitmq/client/impl/nio/SslEngineHelper.java
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2007-2020 VMware, Inc. or its affiliates. All rights reserved.
// Copyright (c) 2007-2021 VMware, Inc. or its affiliates. All rights reserved.
//
// This software, the RabbitMQ Java client library, is triple-licensed under the
// Mozilla Public License 2.0 ("MPL"), the GNU General Public License version 2
Expand All @@ -23,43 +23,66 @@
import java.nio.channels.ReadableByteChannel;
import java.nio.channels.SocketChannel;
import java.nio.channels.WritableByteChannel;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import static javax.net.ssl.SSLEngineResult.HandshakeStatus.FINISHED;
import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NEED_TASK;
import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NEED_WRAP;
import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING;

/**
*
*/
public class SslEngineHelper {

private static final Logger LOGGER = LoggerFactory.getLogger(SslEngineHelper.class);

public static boolean doHandshake(SocketChannel socketChannel, SSLEngine engine) throws IOException {

ByteBuffer plainOut = ByteBuffer.allocate(engine.getSession().getApplicationBufferSize());
ByteBuffer plainIn = ByteBuffer.allocate(engine.getSession().getApplicationBufferSize());
ByteBuffer cipherOut = ByteBuffer.allocate(engine.getSession().getPacketBufferSize());
ByteBuffer cipherIn = ByteBuffer.allocate(engine.getSession().getPacketBufferSize());

LOGGER.debug("Starting TLS handshake");

SSLEngineResult.HandshakeStatus handshakeStatus = engine.getHandshakeStatus();
LOGGER.debug("Initial handshake status is {}", handshakeStatus);
while (handshakeStatus != FINISHED && handshakeStatus != NOT_HANDSHAKING) {
LOGGER.debug("Handshake status is {}", handshakeStatus);
switch (handshakeStatus) {
case NEED_TASK:
LOGGER.debug("Running tasks");
handshakeStatus = runDelegatedTasks(engine);
break;
case NEED_UNWRAP:
LOGGER.debug("Unwrapping...");
handshakeStatus = unwrap(cipherIn, plainIn, socketChannel, engine);
break;
case NEED_WRAP:
LOGGER.debug("Wrapping...");
handshakeStatus = wrap(plainOut, cipherOut, socketChannel, engine);
break;
case FINISHED:
break;
case NOT_HANDSHAKING:
break;
default:
throw new SSLException("Unexpected handshake status " + handshakeStatus);
}
}


LOGGER.debug("TLS handshake completed");
return true;
}

private static SSLEngineResult.HandshakeStatus runDelegatedTasks(SSLEngine sslEngine) {
// FIXME run in executor?
Runnable runnable;
while ((runnable = sslEngine.getDelegatedTask()) != null) {
LOGGER.debug("Running delegated task");
runnable.run();
}
return sslEngine.getHandshakeStatus();
Expand All @@ -68,29 +91,57 @@ private static SSLEngineResult.HandshakeStatus runDelegatedTasks(SSLEngine sslEn
private static SSLEngineResult.HandshakeStatus unwrap(ByteBuffer cipherIn, ByteBuffer plainIn,
ReadableByteChannel channel, SSLEngine sslEngine) throws IOException {
SSLEngineResult.HandshakeStatus handshakeStatus = sslEngine.getHandshakeStatus();

if (channel.read(cipherIn) < 0) {
throw new SSLException("Could not read from socket channel");
LOGGER.debug("Handshake status is {} before unwrapping", handshakeStatus);

LOGGER.debug("Cipher in position {}", cipherIn.position());
int read;
if (cipherIn.position() == 0) {
LOGGER.debug("Reading from channel");
read = channel.read(cipherIn);
LOGGER.debug("Read {} byte(s) from channel", read);
if (read < 0) {
throw new SSLException("Could not read from socket channel");
}
cipherIn.flip();
} else {
LOGGER.debug("Not reading");
}
cipherIn.flip();

SSLEngineResult.Status status;
SSLEngineResult unwrapResult;
do {
SSLEngineResult unwrapResult = sslEngine.unwrap(cipherIn, plainIn);
int positionBeforeUnwrapping = cipherIn.position();
unwrapResult = sslEngine.unwrap(cipherIn, plainIn);
LOGGER.debug("SSL engine result is {} after unwrapping", unwrapResult);
status = unwrapResult.getStatus();
switch (status) {
case OK:
plainIn.clear();
handshakeStatus = runDelegatedTasks(sslEngine);
if (unwrapResult.getHandshakeStatus() == NEED_TASK) {
handshakeStatus = runDelegatedTasks(sslEngine);
int newPosition = positionBeforeUnwrapping + unwrapResult.bytesConsumed();
if (newPosition == cipherIn.limit()) {
LOGGER.debug("Clearing cipherIn because all bytes have been read and unwrapped");
cipherIn.clear();
} else {
LOGGER.debug("Setting cipherIn position to {} (limit is {})", newPosition, cipherIn.limit());
cipherIn.position(positionBeforeUnwrapping + unwrapResult.bytesConsumed());
}
} else {
handshakeStatus = unwrapResult.getHandshakeStatus();
}
break;
case BUFFER_OVERFLOW:
throw new SSLException("Buffer overflow during handshake");
case BUFFER_UNDERFLOW:
LOGGER.debug("Buffer underflow");
cipherIn.compact();
int read = NioHelper.read(channel, cipherIn);
LOGGER.debug("Reading from channel...");
read = NioHelper.read(channel, cipherIn);
if(read <= 0) {
retryRead(channel, cipherIn);
}
LOGGER.debug("Done reading from channel...");
cipherIn.flip();
break;
case CLOSED:
Expand All @@ -100,9 +151,9 @@ private static SSLEngineResult.HandshakeStatus unwrap(ByteBuffer cipherIn, ByteB
throw new SSLException("Unexpected status from " + unwrapResult);
}
}
while (cipherIn.hasRemaining());
while (unwrapResult.getHandshakeStatus() != NEED_WRAP && unwrapResult.getHandshakeStatus() != FINISHED);

cipherIn.compact();
LOGGER.debug("cipherIn position after unwrap {}", cipherIn.position());
return handshakeStatus;
}

Expand All @@ -127,36 +178,32 @@ private static int retryRead(ReadableByteChannel channel, ByteBuffer buffer) thr
private static SSLEngineResult.HandshakeStatus wrap(ByteBuffer plainOut, ByteBuffer cipherOut,
WritableByteChannel channel, SSLEngine sslEngine) throws IOException {
SSLEngineResult.HandshakeStatus handshakeStatus = sslEngine.getHandshakeStatus();
SSLEngineResult.Status status = sslEngine.wrap(plainOut, cipherOut).getStatus();
switch (status) {
LOGGER.debug("Handshake status is {} before wrapping", handshakeStatus);
SSLEngineResult result = sslEngine.wrap(plainOut, cipherOut);
LOGGER.debug("SSL engine result is {} after wrapping", result);
switch (result.getStatus()) {
case OK:
handshakeStatus = runDelegatedTasks(sslEngine);
cipherOut.flip();
while (cipherOut.hasRemaining()) {
channel.write(cipherOut);
int written = channel.write(cipherOut);
LOGGER.debug("Wrote {} byte(s)", written);
}
cipherOut.clear();
if (result.getHandshakeStatus() == NEED_TASK) {
handshakeStatus = runDelegatedTasks(sslEngine);
} else {
handshakeStatus = result.getHandshakeStatus();
}

break;
case BUFFER_OVERFLOW:
throw new SSLException("Buffer overflow during handshake");
default:
throw new SSLException("Unexpected status " + status);
throw new SSLException("Unexpected status " + result.getStatus());
}
return handshakeStatus;
}

static int bufferCopy(ByteBuffer from, ByteBuffer to) {
int maxTransfer = Math.min(to.remaining(), from.remaining());

ByteBuffer temporaryBuffer = from.duplicate();
temporaryBuffer.limit(temporaryBuffer.position() + maxTransfer);
to.put(temporaryBuffer);

from.position(from.position() + maxTransfer);

return maxTransfer;
}

public static void write(WritableByteChannel socketChannel, SSLEngine engine, ByteBuffer plainOut, ByteBuffer cypherOut) throws IOException {
while (plainOut.hasRemaining()) {
cypherOut.clear();
Expand Down
7 changes: 1 addition & 6 deletions src/test/java/com/rabbitmq/client/test/BrokerTestCase.java
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2007-2020 VMware, Inc. or its affiliates. All rights reserved.
// Copyright (c) 2007-2021 VMware, Inc. or its affiliates. All rights reserved.
//
// This software, the RabbitMQ Java client library, is triple-licensed under the
// Mozilla Public License 2.0 ("MPL"), the GNU General Public License version 2
Expand Down Expand Up @@ -28,9 +28,7 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.net.ssl.SSLContext;
import java.io.IOException;
import java.security.NoSuchAlgorithmException;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.TimeoutException;
Expand Down Expand Up @@ -348,7 +346,4 @@ protected String generateExchangeName() {
return "exchange" + UUID.randomUUID().toString();
}

protected SSLContext getSSLContext() throws NoSuchAlgorithmException {
return TestUtils.getSSLContext();
}
}
45 changes: 24 additions & 21 deletions src/test/java/com/rabbitmq/client/test/TestUtils.java
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2007-2020 VMware, Inc. or its affiliates. All rights reserved.
// Copyright (c) 2007-2021 VMware, Inc. or its affiliates. All rights reserved.
//
// This software, the RabbitMQ Java client library, is triple-licensed under the
// Mozilla Public License 2.0 ("MPL"), the GNU General Public License version 2
Expand Down Expand Up @@ -26,19 +26,15 @@
import org.junit.runners.model.Statement;
import org.slf4j.LoggerFactory;

import javax.net.ssl.SSLContext;
import java.io.IOException;
import java.net.ServerSocket;
import java.security.NoSuchAlgorithmException;
import java.time.Duration;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.concurrent.Callable;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;
import java.util.function.BooleanSupplier;

import static org.junit.Assert.assertTrue;

Expand Down Expand Up @@ -109,22 +105,6 @@ public static void abort(Connection connection) {
}
}

public static SSLContext getSSLContext() throws NoSuchAlgorithmException {
SSLContext c = null;

// pick the first protocol available, preferring TLSv1.2, then TLSv1,
// falling back to SSLv3 if running on an ancient/crippled JDK
for (String proto : Arrays.asList("TLSv1.2", "TLSv1", "SSLv3")) {
try {
c = SSLContext.getInstance(proto);
return c;
} catch (NoSuchAlgorithmException x) {
// keep trying
}
}
throw new NoSuchAlgorithmException();
}

public static TestRule atLeast38() {
return new BrokerVersionTestRule("3.8.0");
}
Expand Down Expand Up @@ -361,4 +341,27 @@ public interface CallableFunction<T, R> {

}

public static boolean basicGetBasicConsume(Connection connection, String queue, final CountDownLatch latch, int msgSize)
throws Exception {
Channel channel = connection.createChannel();
channel.queueDeclare(queue, false, true, false, null);
channel.queuePurge(queue);

channel.basicPublish("", queue, null, new byte[msgSize]);

String tag = channel.basicConsume(queue, false, new DefaultConsumer(channel) {

@Override
public void handleDelivery(String consumerTag, Envelope envelope, AMQP.BasicProperties properties, byte[] body) throws IOException {
getChannel().basicAck(envelope.getDeliveryTag(), false);
latch.countDown();
}
});

boolean messageReceived = latch.await(20, TimeUnit.SECONDS);

channel.basicCancel(tag);

return messageReceived;
}
}
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
// Copyright (c) 2007-2020 VMware, Inc. or its affiliates. All rights reserved.
// Copyright (c) 2007-2021 VMware, Inc. or its affiliates. All rights reserved.
//
// This software, the RabbitMQ Java client library, is triple-licensed under the
// Mozilla Public License 2.0 ("MPL"), the GNU General Public License version 2
Expand All @@ -15,20 +15,13 @@

package com.rabbitmq.client.test.ssl;

import com.rabbitmq.client.test.TestUtils;
import org.junit.Test;

import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLHandshakeException;
import javax.net.ssl.TrustManagerFactory;
import java.io.FileInputStream;
import java.io.IOException;
import java.security.*;
import java.security.cert.CertificateException;
import java.util.concurrent.TimeoutException;

import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.fail;

/**
Expand All @@ -39,44 +32,10 @@ public class BadVerifiedConnection extends UnverifiedConnection {
public void openConnection()
throws IOException, TimeoutException {
try {
String keystorePath = System.getProperty("test-keystore.empty");
assertNotNull(keystorePath);
String keystorePasswd = System.getProperty("test-keystore.password");
assertNotNull(keystorePasswd);
char [] keystorePassword = keystorePasswd.toCharArray();

KeyStore tks = KeyStore.getInstance("JKS");
tks.load(new FileInputStream(keystorePath), keystorePassword);

TrustManagerFactory tmf = TrustManagerFactory.getInstance("SunX509");
tmf.init(tks);

String p12Path = System.getProperty("test-client-cert.path");
assertNotNull(p12Path);
String p12Passwd = System.getProperty("test-client-cert.password");
assertNotNull(p12Passwd);
KeyStore ks = KeyStore.getInstance("PKCS12");
char [] p12Password = p12Passwd.toCharArray();
ks.load(new FileInputStream(p12Path), p12Password);

KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509");
kmf.init(ks, p12Password);

SSLContext c = getSSLContext();
c.init(kmf.getKeyManagers(), tmf.getTrustManagers(), null);

connectionFactory = TestUtils.connectionFactory();
SSLContext c = TlsTestUtils.badVerifiedSslContext();
connectionFactory.useSslProtocol(c);
} catch (NoSuchAlgorithmException ex) {
throw new IOException(ex.toString());
} catch (KeyManagementException ex) {
throw new IOException(ex.toString());
} catch (KeyStoreException ex) {
throw new IOException(ex.toString());
} catch (CertificateException ex) {
throw new IOException(ex.toString());
} catch (UnrecoverableKeyException ex) {
throw new IOException(ex.toString());
} catch (Exception ex) {
throw new IOException(ex);
}

try {
Expand Down
Loading