org.codehaus.gmaven
diff --git a/src/main/java/com/rabbitmq/client/ConnectionFactoryConfigurator.java b/src/main/java/com/rabbitmq/client/ConnectionFactoryConfigurator.java
index 4470760d75..9cd9c7ee31 100644
--- a/src/main/java/com/rabbitmq/client/ConnectionFactoryConfigurator.java
+++ b/src/main/java/com/rabbitmq/client/ConnectionFactoryConfigurator.java
@@ -18,36 +18,35 @@
import com.rabbitmq.client.impl.AMQConnection;
import com.rabbitmq.client.impl.nio.NioParams;
-import java.io.BufferedReader;
-import java.io.FileReader;
+import javax.net.ssl.*;
+import java.io.FileInputStream;
+import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
-import java.io.Reader;
import java.net.URISyntaxException;
-import java.security.KeyManagementException;
-import java.security.NoSuchAlgorithmException;
-import java.util.HashMap;
-import java.util.Map;
-import java.util.Properties;
+import java.security.*;
+import java.security.cert.CertificateException;
+import java.util.*;
+import java.util.concurrent.ConcurrentHashMap;
/**
* Helper class to load {@link ConnectionFactory} settings from a property file.
- *
+ *
* The authorised keys are the constants values in this class (e.g. USERNAME).
* The property file/properties instance/map instance keys can have
* a prefix, the default being rabbitmq.
.
- *
+ *
* Property files can be loaded from the file system (the default),
* but also from the classpath, by using the classpath:
prefix
* in the location.
- *
+ *
* Client properties can be set by using
* the client.properties.
prefix, e.g. client.properties.app.name
.
* Default client properties and custom client properties are merged. To remove
* a default client property, set its key to an empty value.
*
- * @since 4.4.0
* @see ConnectionFactory#load(String, String)
+ * @since 5.1.0
*/
public class ConnectionFactoryConfigurator {
@@ -76,6 +75,33 @@ public class ConnectionFactoryConfigurator {
public static final String NIO_NB_IO_THREADS = "nio.nb.io.threads";
public static final String NIO_WRITE_ENQUEUING_TIMEOUT_IN_MS = "nio.write.enqueuing.timeout.in.ms";
public static final String NIO_WRITE_QUEUE_CAPACITY = "nio.write.queue.capacity";
+ public static final String SSL_ALGORITHM = "ssl.algorithm";
+ public static final String SSL_ENABLED = "ssl.enabled";
+ public static final String SSL_KEY_STORE = "ssl.key.store";
+ public static final String SSL_KEY_STORE_PASSWORD = "ssl.key.store.password";
+ public static final String SSL_KEY_STORE_TYPE = "ssl.key.store.type";
+ public static final String SSL_KEY_STORE_ALGORITHM = "ssl.key.store.algorithm";
+ public static final String SSL_TRUST_STORE = "ssl.trust.store";
+ public static final String SSL_TRUST_STORE_PASSWORD = "ssl.trust.store.password";
+ public static final String SSL_TRUST_STORE_TYPE = "ssl.trust.store.type";
+ public static final String SSL_TRUST_STORE_ALGORITHM = "ssl.trust.store.algorithm";
+ public static final String SSL_VALIDATE_SERVER_CERTIFICATE = "ssl.validate.server.certificate";
+ public static final String SSL_VERIFY_HOSTNAME = "ssl.verify.hostname";
+
+ // aliases allow to be compatible with keys from Spring Boot and still be consistent with
+ // the initial naming of the keys
+ private static final Map> ALIASES = new ConcurrentHashMap>() {{
+ put(SSL_KEY_STORE, Arrays.asList("ssl.key-store"));
+ put(SSL_KEY_STORE_PASSWORD, Arrays.asList("ssl.key-store-password"));
+ put(SSL_KEY_STORE_TYPE, Arrays.asList("ssl.key-store-type"));
+ put(SSL_KEY_STORE_ALGORITHM, Arrays.asList("ssl.key-store-algorithm"));
+ put(SSL_TRUST_STORE, Arrays.asList("ssl.trust-store"));
+ put(SSL_TRUST_STORE_PASSWORD, Arrays.asList("ssl.trust-store-password"));
+ put(SSL_TRUST_STORE_TYPE, Arrays.asList("ssl.trust-store-type"));
+ put(SSL_TRUST_STORE_ALGORITHM, Arrays.asList("ssl.trust-store-algorithm"));
+ put(SSL_VALIDATE_SERVER_CERTIFICATE, Arrays.asList("ssl.validate-server-certificate"));
+ put(SSL_VERIFY_HOSTNAME, Arrays.asList("ssl.verify-hostname"));
+ }};
@SuppressWarnings("unchecked")
public static void load(ConnectionFactory cf, String propertyFileLocation, String prefix) throws IOException {
@@ -83,32 +109,22 @@ public static void load(ConnectionFactory cf, String propertyFileLocation, Strin
throw new IllegalArgumentException("Property file argument cannot be null or empty");
}
Properties properties = new Properties();
- if (propertyFileLocation.startsWith("classpath:")) {
- InputStream in = null;
- try {
- in = ConnectionFactoryConfigurator.class.getResourceAsStream(
- propertyFileLocation.substring("classpath:".length())
- );
- properties.load(in);
- } finally {
- if (in != null) {
- in.close();
- }
- }
- } else {
- Reader reader = null;
- try {
- reader = new BufferedReader(new FileReader(propertyFileLocation));
- properties.load(reader);
- } finally {
- if (reader != null) {
- reader.close();
- }
- }
+ try (InputStream in = loadResource(propertyFileLocation)) {
+ properties.load(in);
}
load(cf, (Map) properties, prefix);
}
+ private static InputStream loadResource(String location) throws FileNotFoundException {
+ if (location.startsWith("classpath:")) {
+ return ConnectionFactoryConfigurator.class.getResourceAsStream(
+ location.substring("classpath:".length())
+ );
+ } else {
+ return new FileInputStream(location);
+ }
+ }
+
public static void load(ConnectionFactory cf, Map properties, String prefix) {
prefix = prefix == null ? "" : prefix;
String uri = properties.get(prefix + "uri");
@@ -116,54 +132,54 @@ public static void load(ConnectionFactory cf, Map properties, St
try {
cf.setUri(uri);
} catch (URISyntaxException e) {
- throw new IllegalArgumentException("Error while setting AMQP URI: "+uri, e);
+ throw new IllegalArgumentException("Error while setting AMQP URI: " + uri, e);
} catch (NoSuchAlgorithmException e) {
- throw new IllegalArgumentException("Error while setting AMQP URI: "+uri, e);
+ throw new IllegalArgumentException("Error while setting AMQP URI: " + uri, e);
} catch (KeyManagementException e) {
- throw new IllegalArgumentException("Error while setting AMQP URI: "+uri, e);
+ throw new IllegalArgumentException("Error while setting AMQP URI: " + uri, e);
}
}
- String username = properties.get(prefix + USERNAME);
+ String username = lookUp(USERNAME, properties, prefix);
if (username != null) {
cf.setUsername(username);
}
- String password = properties.get(prefix + PASSWORD);
+ String password = lookUp(PASSWORD, properties, prefix);
if (password != null) {
cf.setPassword(password);
}
- String vhost = properties.get(prefix + VIRTUAL_HOST);
+ String vhost = lookUp(VIRTUAL_HOST, properties, prefix);
if (vhost != null) {
cf.setVirtualHost(vhost);
}
- String host = properties.get(prefix + HOST);
+ String host = lookUp(HOST, properties, prefix);
if (host != null) {
cf.setHost(host);
}
- String port = properties.get(prefix + PORT);
+ String port = lookUp(PORT, properties, prefix);
if (port != null) {
cf.setPort(Integer.valueOf(port));
}
- String requestedChannelMax = properties.get(prefix + CONNECTION_CHANNEL_MAX);
+ String requestedChannelMax = lookUp(CONNECTION_CHANNEL_MAX, properties, prefix);
if (requestedChannelMax != null) {
cf.setRequestedChannelMax(Integer.valueOf(requestedChannelMax));
}
- String requestedFrameMax = properties.get(prefix + CONNECTION_FRAME_MAX);
+ String requestedFrameMax = lookUp(CONNECTION_FRAME_MAX, properties, prefix);
if (requestedFrameMax != null) {
cf.setRequestedFrameMax(Integer.valueOf(requestedFrameMax));
}
- String requestedHeartbeat = properties.get(prefix + CONNECTION_HEARTBEAT);
+ String requestedHeartbeat = lookUp(CONNECTION_HEARTBEAT, properties, prefix);
if (requestedHeartbeat != null) {
cf.setRequestedHeartbeat(Integer.valueOf(requestedHeartbeat));
}
- String connectionTimeout = properties.get(prefix + CONNECTION_TIMEOUT);
+ String connectionTimeout = lookUp(CONNECTION_TIMEOUT, properties, prefix);
if (connectionTimeout != null) {
cf.setConnectionTimeout(Integer.valueOf(connectionTimeout));
}
- String handshakeTimeout = properties.get(prefix + HANDSHAKE_TIMEOUT);
+ String handshakeTimeout = lookUp(HANDSHAKE_TIMEOUT, properties, prefix);
if (handshakeTimeout != null) {
cf.setHandshakeTimeout(Integer.valueOf(handshakeTimeout));
}
- String shutdownTimeout = properties.get(prefix + SHUTDOWN_TIMEOUT);
+ String shutdownTimeout = lookUp(SHUTDOWN_TIMEOUT, properties, prefix);
if (shutdownTimeout != null) {
cf.setShutdownTimeout(Integer.valueOf(shutdownTimeout));
}
@@ -180,63 +196,175 @@ public static void load(ConnectionFactory cf, Map properties, St
clientProperties.remove(clientPropertyKey);
} else {
clientProperties.put(
- clientPropertyKey,
- entry.getValue()
+ clientPropertyKey,
+ entry.getValue()
);
}
}
}
cf.setClientProperties(clientProperties);
- String automaticRecovery = properties.get(prefix + CONNECTION_RECOVERY_ENABLED);
+ String automaticRecovery = lookUp(CONNECTION_RECOVERY_ENABLED, properties, prefix);
if (automaticRecovery != null) {
cf.setAutomaticRecoveryEnabled(Boolean.valueOf(automaticRecovery));
}
- String topologyRecovery = properties.get(prefix + TOPOLOGY_RECOVERY_ENABLED);
+ String topologyRecovery = lookUp(TOPOLOGY_RECOVERY_ENABLED, properties, prefix);
if (topologyRecovery != null) {
cf.setTopologyRecoveryEnabled(Boolean.getBoolean(topologyRecovery));
}
- String networkRecoveryInterval = properties.get(prefix + CONNECTION_RECOVERY_INTERVAL);
+ String networkRecoveryInterval = lookUp(CONNECTION_RECOVERY_INTERVAL, properties, prefix);
if (networkRecoveryInterval != null) {
cf.setNetworkRecoveryInterval(Long.valueOf(networkRecoveryInterval));
}
- String channelRpcTimeout = properties.get(prefix + CHANNEL_RPC_TIMEOUT);
+ String channelRpcTimeout = lookUp(CHANNEL_RPC_TIMEOUT, properties, prefix);
if (channelRpcTimeout != null) {
cf.setChannelRpcTimeout(Integer.valueOf(channelRpcTimeout));
}
- String channelShouldCheckRpcResponseType = properties.get(prefix + CHANNEL_SHOULD_CHECK_RPC_RESPONSE_TYPE);
+ String channelShouldCheckRpcResponseType = lookUp(CHANNEL_SHOULD_CHECK_RPC_RESPONSE_TYPE, properties, prefix);
if (channelShouldCheckRpcResponseType != null) {
cf.setChannelShouldCheckRpcResponseType(Boolean.valueOf(channelShouldCheckRpcResponseType));
}
- String useNio = properties.get(prefix + USE_NIO);
+ String useNio = lookUp(USE_NIO, properties, prefix);
if (useNio != null && Boolean.valueOf(useNio)) {
cf.useNio();
NioParams nioParams = new NioParams();
- String readByteBufferSize = properties.get(prefix + NIO_READ_BYTE_BUFFER_SIZE);
+ String readByteBufferSize = lookUp(NIO_READ_BYTE_BUFFER_SIZE, properties, prefix);
if (readByteBufferSize != null) {
nioParams.setReadByteBufferSize(Integer.valueOf(readByteBufferSize));
}
- String writeByteBufferSize = properties.get(prefix + NIO_WRITE_BYTE_BUFFER_SIZE);
+ String writeByteBufferSize = lookUp(NIO_WRITE_BYTE_BUFFER_SIZE, properties, prefix);
if (writeByteBufferSize != null) {
nioParams.setWriteByteBufferSize(Integer.valueOf(writeByteBufferSize));
}
- String nbIoThreads = properties.get(prefix + NIO_NB_IO_THREADS);
+ String nbIoThreads = lookUp(NIO_NB_IO_THREADS, properties, prefix);
if (nbIoThreads != null) {
nioParams.setNbIoThreads(Integer.valueOf(nbIoThreads));
}
- String writeEnqueuingTime = properties.get(prefix + NIO_WRITE_ENQUEUING_TIMEOUT_IN_MS);
+ String writeEnqueuingTime = lookUp(NIO_WRITE_ENQUEUING_TIMEOUT_IN_MS, properties, prefix);
if (writeEnqueuingTime != null) {
nioParams.setWriteEnqueuingTimeoutInMs(Integer.valueOf(writeEnqueuingTime));
}
- String writeQueueCapacity = properties.get(prefix + NIO_WRITE_QUEUE_CAPACITY);
+ String writeQueueCapacity = lookUp(NIO_WRITE_QUEUE_CAPACITY, properties, prefix);
if (writeQueueCapacity != null) {
nioParams.setWriteQueueCapacity(Integer.valueOf(writeQueueCapacity));
}
cf.setNioParams(nioParams);
}
+
+ String useSsl = lookUp(SSL_ENABLED, properties, prefix);
+ if (useSsl != null && Boolean.valueOf(useSsl)) {
+ setUpSsl(cf, properties, prefix);
+ }
+ }
+
+ private static void setUpSsl(ConnectionFactory cf, Map properties, String prefix) {
+ String algorithm = lookUp(SSL_ALGORITHM, properties, prefix);
+ String keyStoreLocation = lookUp(SSL_KEY_STORE, properties, prefix);
+ String keyStorePassword = lookUp(SSL_KEY_STORE_PASSWORD, properties, prefix);
+ String keyStoreType = lookUp(SSL_KEY_STORE_TYPE, properties, prefix, "PKCS12");
+ String keyStoreAlgorithm = lookUp(SSL_KEY_STORE_ALGORITHM, properties, prefix, "SunX509");
+ String trustStoreLocation = lookUp(SSL_TRUST_STORE, properties, prefix);
+ String trustStorePassword = lookUp(SSL_TRUST_STORE_PASSWORD, properties, prefix);
+ String trustStoreType = lookUp(SSL_TRUST_STORE_TYPE, properties, prefix, "JKS");
+ String trustStoreAlgorithm = lookUp(SSL_TRUST_STORE_ALGORITHM, properties, prefix, "SunX509");
+ String validateServerCertificate = lookUp(SSL_VALIDATE_SERVER_CERTIFICATE, properties, prefix);
+ String verifyHostname = lookUp(SSL_VERIFY_HOSTNAME, properties, prefix);
+
+ try {
+ algorithm = algorithm == null ?
+ ConnectionFactory.computeDefaultTlsProtocol(SSLContext.getDefault().getSupportedSSLParameters().getProtocols()) : algorithm;
+ boolean enableHostnameVerification = verifyHostname == null ? Boolean.FALSE : Boolean.valueOf(verifyHostname);
+
+ if (keyStoreLocation == null && trustStoreLocation == null) {
+ setUpBasicSsl(
+ cf,
+ validateServerCertificate == null ? Boolean.FALSE : Boolean.valueOf(validateServerCertificate),
+ enableHostnameVerification,
+ algorithm
+ );
+ } else {
+ KeyManager[] keyManagers = configureKeyManagers(keyStoreLocation, keyStorePassword, keyStoreType, keyStoreAlgorithm);
+ TrustManager[] trustManagers = configureTrustManagers(trustStoreLocation, trustStorePassword, trustStoreType, trustStoreAlgorithm);
+
+ // create ssl context
+ SSLContext sslContext = SSLContext.getInstance(algorithm);
+ sslContext.init(keyManagers, trustManagers, null);
+
+ cf.useSslProtocol(sslContext);
+
+ if (enableHostnameVerification) {
+ cf.enableHostnameVerification();
+ }
+ }
+ } catch (NoSuchAlgorithmException | IOException | CertificateException |
+ UnrecoverableKeyException | KeyStoreException | KeyManagementException e) {
+ throw new IllegalStateException("Error while configuring TLS", e);
+ }
+ }
+
+ private static KeyManager[] configureKeyManagers(String keystore, String keystorePassword, String keystoreType, String keystoreAlgorithm) throws KeyStoreException, IOException, NoSuchAlgorithmException,
+ CertificateException, UnrecoverableKeyException {
+ char[] keyPassphrase = null;
+ if (keystorePassword != null) {
+ keyPassphrase = keystorePassword.toCharArray();
+ }
+ KeyManager[] keyManagers = null;
+ if (keystore != null) {
+ KeyStore ks = KeyStore.getInstance(keystoreType);
+ try (InputStream in = loadResource(keystore)) {
+ ks.load(in, keyPassphrase);
+ }
+ KeyManagerFactory kmf = KeyManagerFactory.getInstance(keystoreAlgorithm);
+ kmf.init(ks, keyPassphrase);
+ keyManagers = kmf.getKeyManagers();
+ }
+ return keyManagers;
+ }
+
+ private static TrustManager[] configureTrustManagers(String truststore, String truststorePassword, String truststoreType, String truststoreAlgorithm)
+ throws KeyStoreException, IOException, NoSuchAlgorithmException, CertificateException {
+ char[] trustPassphrase = null;
+ if (truststorePassword != null) {
+ trustPassphrase = truststorePassword.toCharArray();
+ }
+ TrustManager[] trustManagers = null;
+ if (truststore != null) {
+ KeyStore tks = KeyStore.getInstance(truststoreType);
+ try (InputStream in = loadResource(truststore)) {
+ tks.load(in, trustPassphrase);
+ }
+ TrustManagerFactory tmf = TrustManagerFactory.getInstance(truststoreAlgorithm);
+ tmf.init(tks);
+ trustManagers = tmf.getTrustManagers();
+ }
+ return trustManagers;
+ }
+
+ private static void setUpBasicSsl(ConnectionFactory cf, boolean validateServerCertificate, boolean verifyHostname, String sslAlgorithm) throws KeyManagementException, NoSuchAlgorithmException, KeyStoreException {
+ if (validateServerCertificate) {
+ useDefaultTrustStore(cf, sslAlgorithm, verifyHostname);
+ } else {
+ if (sslAlgorithm == null) {
+ cf.useSslProtocol();
+ } else {
+ cf.useSslProtocol(sslAlgorithm);
+ }
+ }
+ }
+
+ private static void useDefaultTrustStore(ConnectionFactory cf, String sslAlgorithm, boolean verifyHostname) throws NoSuchAlgorithmException, KeyStoreException, KeyManagementException {
+ SSLContext sslContext = SSLContext.getInstance(sslAlgorithm);
+ TrustManagerFactory trustManagerFactory =
+ TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm());
+ trustManagerFactory.init((KeyStore) null);
+ sslContext.init(null, trustManagerFactory.getTrustManagers(), null);
+ cf.useSslProtocol(sslContext);
+ if (verifyHostname) {
+ cf.enableHostnameVerification();
+ }
}
public static void load(ConnectionFactory connectionFactory, String propertyFileLocation) throws IOException {
@@ -256,4 +384,21 @@ public static void load(ConnectionFactory connectionFactory, Properties properti
public static void load(ConnectionFactory connectionFactory, Map properties) {
load(connectionFactory, properties, DEFAULT_PREFIX);
}
+
+ public static String lookUp(String key, Map properties, String prefix) {
+ return lookUp(key, properties, prefix, null);
+ }
+
+ public static String lookUp(String key, Map properties, String prefix, String defaultValue) {
+ String value = properties.get(prefix + key);
+ if (value == null) {
+ value = ALIASES.getOrDefault(key, Collections.emptyList()).stream()
+ .map(alias -> properties.get(prefix + alias))
+ .filter(v -> v != null)
+ .findFirst().orElse(defaultValue);
+ }
+ return value;
+ }
+
+
}
diff --git a/src/test/java/com/rabbitmq/client/test/PropertyFileInitialisationTest.java b/src/test/java/com/rabbitmq/client/test/PropertyFileInitialisationTest.java
index 0c5df5823f..2a140721a5 100644
--- a/src/test/java/com/rabbitmq/client/test/PropertyFileInitialisationTest.java
+++ b/src/test/java/com/rabbitmq/client/test/PropertyFileInitialisationTest.java
@@ -18,66 +18,64 @@
import com.rabbitmq.client.ConnectionFactory;
import com.rabbitmq.client.ConnectionFactoryConfigurator;
import org.junit.Test;
-import org.junit.runner.RunWith;
-import org.junit.runners.Parameterized;
+import javax.net.ssl.SSLContext;
import java.io.FileReader;
import java.io.IOException;
import java.io.Reader;
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.Map;
-import java.util.Properties;
+import java.util.*;
+import java.util.concurrent.atomic.AtomicBoolean;
+import java.util.stream.Stream;
import static com.rabbitmq.client.impl.AMQConnection.defaultClientProperties;
import static org.assertj.core.api.Assertions.assertThat;
+import static org.mockito.ArgumentMatchers.any;
+import static org.mockito.Mockito.*;
/**
*
*/
-@RunWith(Parameterized.class)
public class PropertyFileInitialisationTest {
- @Parameterized.Parameters
- public static Object[] data() {
- return new Object[] {
- "./src/test/resources/property-file-initialisation/configuration.properties",
- "classpath:/property-file-initialisation/configuration.properties"
- };
- }
-
- @Parameterized.Parameter
- public String propertyFileLocation;
-
ConnectionFactory cf = new ConnectionFactory();
- @Test public void propertyInitialisationFromFile() throws IOException {
- cf.load(propertyFileLocation);
- checkConnectionFactory();
+ @Test
+ public void propertyInitialisationFromFile() throws IOException {
+ for (String propertyFileLocation : Arrays.asList(
+ "./src/test/resources/property-file-initialisation/configuration.properties",
+ "classpath:/property-file-initialisation/configuration.properties")) {
+ ConnectionFactory connectionFactory = new ConnectionFactory();
+ connectionFactory.load(propertyFileLocation);
+ checkConnectionFactory(connectionFactory);
+ }
}
- @Test public void propertyInitialisationCustomPrefix() throws Exception {
+ @Test
+ public void propertyInitialisationCustomPrefix() throws Exception {
Properties propertiesCustomPrefix = getPropertiesWitPrefix("prefix.");
cf.load(propertiesCustomPrefix, "prefix.");
checkConnectionFactory();
}
- @Test public void propertyInitialisationNoPrefix() throws Exception {
+ @Test
+ public void propertyInitialisationNoPrefix() throws Exception {
Properties propertiesCustomPrefix = getPropertiesWitPrefix("");
cf.load(propertiesCustomPrefix, "");
checkConnectionFactory();
}
- @Test public void propertyInitialisationNullPrefix() throws Exception {
+ @Test
+ public void propertyInitialisationNullPrefix() throws Exception {
Properties propertiesCustomPrefix = getPropertiesWitPrefix("");
cf.load(propertiesCustomPrefix, null);
checkConnectionFactory();
}
- @Test public void propertyInitialisationUri() {
+ @Test
+ public void propertyInitialisationUri() {
cf.load(Collections.singletonMap("rabbitmq.uri", "amqp://foo:bar@127.0.0.1:5673/dummy"));
assertThat(cf.getUsername()).isEqualTo("foo");
@@ -87,12 +85,14 @@ public static Object[] data() {
assertThat(cf.getPort()).isEqualTo(5673);
}
- @Test public void propertyInitialisationIncludeDefaultClientPropertiesByDefault() {
+ @Test
+ public void propertyInitialisationIncludeDefaultClientPropertiesByDefault() {
cf.load(new HashMap<>());
assertThat(cf.getClientProperties().entrySet()).hasSize(defaultClientProperties().size());
}
- @Test public void propertyInitialisationAddCustomClientProperty() {
+ @Test
+ public void propertyInitialisationAddCustomClientProperty() {
cf.load(new HashMap() {{
put("rabbitmq.client.properties.foo", "bar");
}});
@@ -100,7 +100,8 @@ public static Object[] data() {
assertThat(cf.getClientProperties()).extracting("foo").isEqualTo("bar");
}
- @Test public void propertyInitialisationGetRidOfDefaultClientPropertyWithEmptyValue() {
+ @Test
+ public void propertyInitialisationGetRidOfDefaultClientPropertyWithEmptyValue() {
final String key = defaultClientProperties().entrySet().iterator().next().getKey();
cf.load(new HashMap() {{
put("rabbitmq.client.properties." + key, "");
@@ -108,7 +109,8 @@ public static Object[] data() {
assertThat(cf.getClientProperties().entrySet()).hasSize(defaultClientProperties().size() - 1);
}
- @Test public void propertyInitialisationOverrideDefaultClientProperty() {
+ @Test
+ public void propertyInitialisationOverrideDefaultClientProperty() {
final String key = defaultClientProperties().entrySet().iterator().next().getKey();
cf.load(new HashMap() {{
put("rabbitmq.client.properties." + key, "whatever");
@@ -117,7 +119,8 @@ public static Object[] data() {
assertThat(cf.getClientProperties()).extracting(key).isEqualTo("whatever");
}
- @Test public void propertyInitialisationDoNotUseNio() throws Exception {
+ @Test
+ public void propertyInitialisationDoNotUseNio() throws Exception {
cf.load(new HashMap() {{
put("rabbitmq.use.nio", "false");
put("rabbitmq.nio.nb.io.threads", "2");
@@ -125,34 +128,150 @@ public static Object[] data() {
assertThat(cf.getNioParams().getNbIoThreads()).isNotEqualTo(2);
}
+ @Test
+ public void lookUp() {
+ assertThat(ConnectionFactoryConfigurator.lookUp(
+ ConnectionFactoryConfigurator.SSL_KEY_STORE,
+ Collections.singletonMap(ConnectionFactoryConfigurator.SSL_KEY_STORE, "some file"),
+ ""
+ )).as("exact key should be looked up").isEqualTo("some file");
+
+ assertThat(ConnectionFactoryConfigurator.lookUp(
+ ConnectionFactoryConfigurator.SSL_KEY_STORE,
+ Collections.emptyMap(),
+ ""
+ )).as("lookup should return null when no match").isNull();
+
+ assertThat(ConnectionFactoryConfigurator.lookUp(
+ ConnectionFactoryConfigurator.SSL_KEY_STORE,
+ Collections.singletonMap("ssl.key-store", "some file"), // key alias
+ ""
+ )).as("alias key should be used when initial is missing").isEqualTo("some file");
+
+ assertThat(ConnectionFactoryConfigurator.lookUp(
+ ConnectionFactoryConfigurator.SSL_TRUST_STORE_TYPE,
+ Collections.emptyMap(),
+ "",
+ "JKS"
+ )).as("default value should be returned when key is not found").isEqualTo("JKS");
+ }
+
+ @Test
+ public void tlsInitialisationWithKeyManagerAndTrustManagerShouldSucceed() {
+ Stream.of("./src/test/resources/property-file-initialisation/tls/",
+ "classpath:/property-file-initialisation/tls/").forEach(baseDirectory -> {
+ Map configuration = new HashMap<>();
+ configuration.put(ConnectionFactoryConfigurator.SSL_ENABLED, "true");
+ configuration.put(ConnectionFactoryConfigurator.SSL_KEY_STORE, baseDirectory + "keystore.p12");
+ configuration.put(ConnectionFactoryConfigurator.SSL_KEY_STORE_PASSWORD, "bunnies");
+ configuration.put(ConnectionFactoryConfigurator.SSL_KEY_STORE_TYPE, "PKCS12");
+ configuration.put(ConnectionFactoryConfigurator.SSL_KEY_STORE_ALGORITHM, "SunX509");
+
+ configuration.put(ConnectionFactoryConfigurator.SSL_TRUST_STORE, baseDirectory + "truststore.jks");
+ configuration.put(ConnectionFactoryConfigurator.SSL_TRUST_STORE_PASSWORD, "bunnies");
+ configuration.put(ConnectionFactoryConfigurator.SSL_TRUST_STORE_TYPE, "JKS");
+ configuration.put(ConnectionFactoryConfigurator.SSL_TRUST_STORE_ALGORITHM, "SunX509");
+
+ configuration.put(ConnectionFactoryConfigurator.SSL_VERIFY_HOSTNAME, "true");
+
+ ConnectionFactory connectionFactory = mock(ConnectionFactory.class);
+ ConnectionFactoryConfigurator.load(connectionFactory, configuration, "");
+
+ verify(connectionFactory, times(1)).useSslProtocol(any(SSLContext.class));
+ verify(connectionFactory, times(1)).enableHostnameVerification();
+ });
+ }
+
+ @Test
+ public void tlsNotEnabledIfNotConfigured() {
+ ConnectionFactory connectionFactory = mock(ConnectionFactory.class);
+ ConnectionFactoryConfigurator.load(connectionFactory, Collections.emptyMap(), "");
+ verify(connectionFactory, never()).useSslProtocol(any(SSLContext.class));
+ }
+
+ @Test
+ public void tlsNotEnabledIfDisabled() {
+ ConnectionFactory connectionFactory = mock(ConnectionFactory.class);
+ ConnectionFactoryConfigurator.load(
+ connectionFactory,
+ Collections.singletonMap(ConnectionFactoryConfigurator.SSL_ENABLED, "false"),
+ ""
+ );
+ verify(connectionFactory, never()).useSslProtocol(any(SSLContext.class));
+ }
+
+ @Test
+ public void tlsSslContextSetIfTlsEnabled() {
+ AtomicBoolean sslProtocolSet = new AtomicBoolean(false);
+ ConnectionFactory connectionFactory = new ConnectionFactory() {
+ @Override
+ public void useSslProtocol(SSLContext context) {
+ sslProtocolSet.set(true);
+ super.useSslProtocol(context);
+ }
+ };
+ ConnectionFactoryConfigurator.load(
+ connectionFactory,
+ Collections.singletonMap(ConnectionFactoryConfigurator.SSL_ENABLED, "true"),
+ ""
+ );
+ assertThat(sslProtocolSet).isTrue();
+ }
+
+ @Test
+ public void tlsBasicSetupShouldTrustEveryoneWhenServerValidationIsNotEnabled() throws Exception {
+ String algorithm = ConnectionFactory.computeDefaultTlsProtocol(SSLContext.getDefault().getSupportedSSLParameters().getProtocols());
+ Map configuration = new HashMap<>();
+ configuration.put(ConnectionFactoryConfigurator.SSL_ENABLED, "true");
+ configuration.put(ConnectionFactoryConfigurator.SSL_VALIDATE_SERVER_CERTIFICATE, "false");
+ ConnectionFactory connectionFactory = mock(ConnectionFactory.class);
+ ConnectionFactoryConfigurator.load(connectionFactory, configuration, "");
+ verify(connectionFactory, times(1)).useSslProtocol(algorithm);
+ }
+
+ @Test
+ public void tlsBasicSetupShouldSetDefaultTrustManagerWhenServerValidationIsEnabled() throws Exception {
+ Map configuration = new HashMap<>();
+ configuration.put(ConnectionFactoryConfigurator.SSL_ENABLED, "true");
+ configuration.put(ConnectionFactoryConfigurator.SSL_VALIDATE_SERVER_CERTIFICATE, "true");
+ ConnectionFactory connectionFactory = mock(ConnectionFactory.class);
+ ConnectionFactoryConfigurator.load(connectionFactory, configuration, "");
+ verify(connectionFactory, never()).useSslProtocol(anyString());
+ verify(connectionFactory, times(1)).useSslProtocol(any(SSLContext.class));
+ }
+
private void checkConnectionFactory() {
- assertThat(cf.getUsername()).isEqualTo("foo");
- assertThat(cf.getPassword()).isEqualTo("bar");
- assertThat(cf.getVirtualHost()).isEqualTo("dummy");
- assertThat(cf.getHost()).isEqualTo("127.0.0.1");
- assertThat(cf.getPort()).isEqualTo(5673);
+ checkConnectionFactory(this.cf);
+ }
- assertThat(cf.getRequestedChannelMax()).isEqualTo(1);
- assertThat(cf.getRequestedFrameMax()).isEqualTo(2);
- assertThat(cf.getRequestedHeartbeat()).isEqualTo(10);
- assertThat(cf.getConnectionTimeout()).isEqualTo(10000);
- assertThat(cf.getHandshakeTimeout()).isEqualTo(5000);
+ private void checkConnectionFactory(ConnectionFactory connectionFactory) {
+ assertThat(connectionFactory.getUsername()).isEqualTo("foo");
+ assertThat(connectionFactory.getPassword()).isEqualTo("bar");
+ assertThat(connectionFactory.getVirtualHost()).isEqualTo("dummy");
+ assertThat(connectionFactory.getHost()).isEqualTo("127.0.0.1");
+ assertThat(connectionFactory.getPort()).isEqualTo(5673);
- assertThat(cf.getClientProperties().entrySet()).hasSize(defaultClientProperties().size() + 1);
- assertThat(cf.getClientProperties()).extracting("foo").isEqualTo("bar");
+ assertThat(connectionFactory.getRequestedChannelMax()).isEqualTo(1);
+ assertThat(connectionFactory.getRequestedFrameMax()).isEqualTo(2);
+ assertThat(connectionFactory.getRequestedHeartbeat()).isEqualTo(10);
+ assertThat(connectionFactory.getConnectionTimeout()).isEqualTo(10000);
+ assertThat(connectionFactory.getHandshakeTimeout()).isEqualTo(5000);
+
+ assertThat(connectionFactory.getClientProperties().entrySet()).hasSize(defaultClientProperties().size() + 1);
+ assertThat(connectionFactory.getClientProperties()).extracting("foo").isEqualTo("bar");
- assertThat(cf.isAutomaticRecoveryEnabled()).isFalse();
- assertThat(cf.isTopologyRecoveryEnabled()).isFalse();
- assertThat(cf.getNetworkRecoveryInterval()).isEqualTo(10000l);
- assertThat(cf.getChannelRpcTimeout()).isEqualTo(10000);
- assertThat(cf.isChannelShouldCheckRpcResponseType()).isTrue();
+ assertThat(connectionFactory.isAutomaticRecoveryEnabled()).isFalse();
+ assertThat(connectionFactory.isTopologyRecoveryEnabled()).isFalse();
+ assertThat(connectionFactory.getNetworkRecoveryInterval()).isEqualTo(10000l);
+ assertThat(connectionFactory.getChannelRpcTimeout()).isEqualTo(10000);
+ assertThat(connectionFactory.isChannelShouldCheckRpcResponseType()).isTrue();
- assertThat(cf.getNioParams()).isNotNull();
- assertThat(cf.getNioParams().getReadByteBufferSize()).isEqualTo(32000);
- assertThat(cf.getNioParams().getWriteByteBufferSize()).isEqualTo(32000);
- assertThat(cf.getNioParams().getNbIoThreads()).isEqualTo(2);
- assertThat(cf.getNioParams().getWriteEnqueuingTimeoutInMs()).isEqualTo(5000);
- assertThat(cf.getNioParams().getWriteQueueCapacity()).isEqualTo(1000);
+ assertThat(connectionFactory.getNioParams()).isNotNull();
+ assertThat(connectionFactory.getNioParams().getReadByteBufferSize()).isEqualTo(32000);
+ assertThat(connectionFactory.getNioParams().getWriteByteBufferSize()).isEqualTo(32000);
+ assertThat(connectionFactory.getNioParams().getNbIoThreads()).isEqualTo(2);
+ assertThat(connectionFactory.getNioParams().getWriteEnqueuingTimeoutInMs()).isEqualTo(5000);
+ assertThat(connectionFactory.getNioParams().getWriteQueueCapacity()).isEqualTo(1000);
}
private Properties getPropertiesWitPrefix(String prefix) throws IOException {
@@ -168,8 +287,8 @@ private Properties getPropertiesWitPrefix(String prefix) throws IOException {
Properties propertiesCustomPrefix = new Properties();
for (Map.Entry entry : properties.entrySet()) {
propertiesCustomPrefix.put(
- prefix + entry.getKey().toString().substring(ConnectionFactoryConfigurator.DEFAULT_PREFIX.length()),
- entry.getValue()
+ prefix + entry.getKey().toString().substring(ConnectionFactoryConfigurator.DEFAULT_PREFIX.length()),
+ entry.getValue()
);
}
return propertiesCustomPrefix;
diff --git a/src/test/resources/property-file-initialisation/tls/keystore.p12 b/src/test/resources/property-file-initialisation/tls/keystore.p12
new file mode 100644
index 0000000000..a5280a6cbf
Binary files /dev/null and b/src/test/resources/property-file-initialisation/tls/keystore.p12 differ
diff --git a/src/test/resources/property-file-initialisation/tls/truststore.jks b/src/test/resources/property-file-initialisation/tls/truststore.jks
new file mode 100644
index 0000000000..4dd357d17a
Binary files /dev/null and b/src/test/resources/property-file-initialisation/tls/truststore.jks differ