Skip to content

Commit ce3a04c

Browse files
Merge pull request #273 from rabbitmq/rabbitmq-java-client-241
Make SslContext creation more flexible
2 parents 0e97491 + c8ea835 commit ce3a04c

11 files changed

+292
-33
lines changed

src/main/java/com/rabbitmq/client/ConnectionFactory.java

+47-16
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,11 @@
3636
import static java.util.concurrent.TimeUnit.*;
3737

3838
/**
39-
* Convenience "factory" class to facilitate opening a {@link Connection} to an AMQP broker.
39+
* Convenience factory class to facilitate opening a {@link Connection} to a RabbitMQ node.
40+
*
41+
* Most connection and socket settings are configured using this factory.
42+
* Some settings that apply to connections can also be configured here
43+
* and will apply to all connections produced by this factory.
4044
*/
4145

4246
public class ConnectionFactory implements Cloneable {
@@ -94,7 +98,7 @@ public class ConnectionFactory implements Cloneable {
9498
private int handshakeTimeout = DEFAULT_HANDSHAKE_TIMEOUT;
9599
private int shutdownTimeout = DEFAULT_SHUTDOWN_TIMEOUT;
96100
private Map<String, Object> _clientProperties = AMQConnection.defaultClientProperties();
97-
private SocketFactory factory = SocketFactory.getDefault();
101+
private SocketFactory socketFactory = null;
98102
private SaslConfig saslConfig = DefaultSaslConfig.PLAIN;
99103
private ExecutorService sharedExecutor;
100104
private ThreadFactory threadFactory = Executors.defaultThreadFactory();
@@ -119,7 +123,7 @@ public class ConnectionFactory implements Cloneable {
119123
private FrameHandlerFactory frameHandlerFactory;
120124
private NioParams nioParams = new NioParams();
121125

122-
private SSLContext sslContext;
126+
private SslContextFactory sslContextFactory;
123127

124128
/**
125129
* Continuation timeout on RPC calls.
@@ -442,18 +446,19 @@ public void setSaslConfig(SaslConfig saslConfig) {
442446
* Retrieve the socket factory used to make connections with.
443447
*/
444448
public SocketFactory getSocketFactory() {
445-
return this.factory;
449+
return this.socketFactory;
446450
}
447451

448452
/**
449-
* Set the socket factory used to make connections with. Can be
450-
* used to enable SSL connections by passing in a
453+
* Set the socket factory used to create sockets for new connections. Can be
454+
* used to customize TLS-related settings by passing in a
451455
* javax.net.ssl.SSLSocketFactory instance.
452-
*
456+
* Note this applies only to blocking IO, not to
457+
* NIO, as the NIO API doesn't use the SocketFactory API.
453458
* @see #useSslProtocol
454459
*/
455460
public void setSocketFactory(SocketFactory factory) {
456-
this.factory = factory;
461+
this.socketFactory = factory;
457462
}
458463

459464
/**
@@ -556,7 +561,7 @@ public void setExceptionHandler(ExceptionHandler exceptionHandler) {
556561
}
557562

558563
public boolean isSSL(){
559-
return getSocketFactory() instanceof SSLSocketFactory;
564+
return getSocketFactory() instanceof SSLSocketFactory || sslContextFactory != null;
560565
}
561566

562567
/**
@@ -572,6 +577,10 @@ public void useSslProtocol()
572577
/**
573578
* Convenience method for setting up a SSL socket factory/engine, using
574579
* the supplied protocol and a very trusting TrustManager.
580+
* The produced {@link SSLContext} instance will be shared by all
581+
* the connections created by this connection factory. Use
582+
* {@link #setSslContextFactory(SslContextFactory)} for more flexibility.
583+
* @see #setSslContextFactory(SslContextFactory)
575584
*/
576585
public void useSslProtocol(String protocol)
577586
throws NoSuchAlgorithmException, KeyManagementException
@@ -582,8 +591,11 @@ public void useSslProtocol(String protocol)
582591
/**
583592
* Convenience method for setting up an SSL socket factory/engine.
584593
* Pass in the SSL protocol to use, e.g. "TLSv1" or "TLSv1.2".
585-
*
594+
* The produced {@link SSLContext} instance will be shared with all
595+
* the connections created by this connection factory. Use
596+
* {@link #setSslContextFactory(SslContextFactory)} for more flexibility.
586597
* @param protocol SSL protocol to use.
598+
* @see #setSslContextFactory(SslContextFactory)
587599
*/
588600
public void useSslProtocol(String protocol, TrustManager trustManager)
589601
throws NoSuchAlgorithmException, KeyManagementException
@@ -594,14 +606,17 @@ public void useSslProtocol(String protocol, TrustManager trustManager)
594606
}
595607

596608
/**
597-
* Convenience method for setting up an SSL socket factory/engine.
609+
* Convenience method for setting up an SSL socket socketFactory/engine.
598610
* Pass in an initialized SSLContext.
599-
*
611+
* The {@link SSLContext} instance will be shared with all
612+
* the connections created by this connection factory. Use
613+
* {@link #setSslContextFactory(SslContextFactory)} for more flexibility.
600614
* @param context An initialized SSLContext
615+
* @see #setSslContextFactory(SslContextFactory)
601616
*/
602617
public void useSslProtocol(SSLContext context) {
618+
this.sslContextFactory = name -> context;
603619
setSocketFactory(context.getSocketFactory());
604-
this.sslContext = context;
605620
}
606621

607622
public static String computeDefaultTlsProcotol(String[] supportedProtocols) {
@@ -667,11 +682,11 @@ protected synchronized FrameHandlerFactory createFrameHandlerFactory() throws IO
667682
if(this.nioParams.getNioExecutor() == null && this.nioParams.getThreadFactory() == null) {
668683
this.nioParams.setThreadFactory(getThreadFactory());
669684
}
670-
this.frameHandlerFactory = new SocketChannelFrameHandlerFactory(connectionTimeout, nioParams, isSSL(), sslContext);
685+
this.frameHandlerFactory = new SocketChannelFrameHandlerFactory(connectionTimeout, nioParams, isSSL(), sslContextFactory);
671686
}
672687
return this.frameHandlerFactory;
673688
} else {
674-
return new SocketFrameHandlerFactory(connectionTimeout, factory, socketConf, isSSL(), this.shutdownExecutor);
689+
return new SocketFrameHandlerFactory(connectionTimeout, socketFactory, socketConf, isSSL(), this.shutdownExecutor, sslContextFactory);
675690
}
676691

677692
}
@@ -915,7 +930,7 @@ public Connection newConnection(ExecutorService executor, AddressResolver addres
915930
Exception lastException = null;
916931
for (Address addr : addrs) {
917932
try {
918-
FrameHandler handler = fhFactory.create(addr);
933+
FrameHandler handler = fhFactory.create(addr, clientProvidedName);
919934
AMQConnection conn = createConnection(params, handler, metricsCollector);
920935
conn.start();
921936
this.metricsCollector.newConnection(conn);
@@ -1124,4 +1139,20 @@ public void setChannelRpcTimeout(int channelRpcTimeout) {
11241139
public int getChannelRpcTimeout() {
11251140
return channelRpcTimeout;
11261141
}
1142+
1143+
/**
1144+
* The factory to create SSL contexts.
1145+
* This provides more flexibility to create {@link SSLContext}s
1146+
* for different connections than sharing the {@link SSLContext}
1147+
* with all the connections produced by the connection factory
1148+
* (which is the case with the {@link #useSslProtocol()} methods).
1149+
* This way, different connections with a different certificate
1150+
* for each of them is a possible scenario.
1151+
* @param sslContextFactory
1152+
* @see #useSslProtocol(SSLContext)
1153+
* @since 5.0.0
1154+
*/
1155+
public void setSslContextFactory(SslContextFactory sslContextFactory) {
1156+
this.sslContextFactory = sslContextFactory;
1157+
}
11271158
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
// Copyright (c) 2017-Present Pivotal Software, Inc. All rights reserved.
2+
//
3+
// This software, the RabbitMQ Java client library, is triple-licensed under the
4+
// Mozilla Public License 1.1 ("MPL"), the GNU General Public License version 2
5+
// ("GPL") and the Apache License version 2 ("ASL"). For the MPL, please see
6+
// LICENSE-MPL-RabbitMQ. For the GPL, please see LICENSE-GPL2. For the ASL,
7+
// please see LICENSE-APACHE2.
8+
//
9+
// This software is distributed on an "AS IS" basis, WITHOUT WARRANTY OF ANY KIND,
10+
// either express or implied. See the LICENSE file for specific language governing
11+
// rights and limitations of this software.
12+
//
13+
// If you have any questions regarding licensing, please contact us at
14+
15+
16+
package com.rabbitmq.client;
17+
18+
import javax.net.ssl.SSLContext;
19+
20+
/**
21+
* A factory to create {@link SSLContext}s.
22+
*
23+
* @see ConnectionFactory#setSslContextFactory(SslContextFactory)
24+
* @since 5.0.0
25+
*/
26+
public interface SslContextFactory {
27+
28+
/**
29+
* Create a {@link SSLContext} for a given name.
30+
* The name is typically the name of the connection.
31+
* @param name name of the connection the SSLContext is used for
32+
* @return the SSLContext for this name
33+
* @see ConnectionFactory#newConnection(String)
34+
*/
35+
SSLContext create(String name);
36+
37+
}

src/main/java/com/rabbitmq/client/impl/FrameHandlerFactory.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,6 @@
99
*/
1010
public interface FrameHandlerFactory {
1111

12-
FrameHandler create(Address addr) throws IOException;
12+
FrameHandler create(Address addr, String connectionName) throws IOException;
1313

1414
}

src/main/java/com/rabbitmq/client/impl/SocketFrameHandlerFactory.java

+30-7
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import com.rabbitmq.client.Address;
1919
import com.rabbitmq.client.ConnectionFactory;
2020
import com.rabbitmq.client.SocketConfigurator;
21+
import com.rabbitmq.client.SslContextFactory;
2122

2223
import javax.net.SocketFactory;
2324
import java.io.IOException;
@@ -27,25 +28,34 @@
2728

2829
public class SocketFrameHandlerFactory extends AbstractFrameHandlerFactory {
2930

30-
private final SocketFactory factory;
31+
private final SocketFactory socketFactory;
3132
private final ExecutorService shutdownExecutor;
33+
private final SslContextFactory sslContextFactory;
3234

33-
public SocketFrameHandlerFactory(int connectionTimeout, SocketFactory factory, SocketConfigurator configurator, boolean ssl) {
34-
this(connectionTimeout, factory, configurator, ssl, null);
35+
public SocketFrameHandlerFactory(int connectionTimeout, SocketFactory socketFactory, SocketConfigurator configurator,
36+
boolean ssl) {
37+
this(connectionTimeout, socketFactory, configurator, ssl, null);
3538
}
3639

37-
public SocketFrameHandlerFactory(int connectionTimeout, SocketFactory factory, SocketConfigurator configurator, boolean ssl, ExecutorService shutdownExecutor) {
40+
public SocketFrameHandlerFactory(int connectionTimeout, SocketFactory socketFactory, SocketConfigurator configurator,
41+
boolean ssl, ExecutorService shutdownExecutor) {
42+
this(connectionTimeout, socketFactory, configurator, ssl, shutdownExecutor, null);
43+
}
44+
45+
public SocketFrameHandlerFactory(int connectionTimeout, SocketFactory socketFactory, SocketConfigurator configurator,
46+
boolean ssl, ExecutorService shutdownExecutor, SslContextFactory sslContextFactory) {
3847
super(connectionTimeout, configurator, ssl);
39-
this.factory = factory;
48+
this.socketFactory = socketFactory;
4049
this.shutdownExecutor = shutdownExecutor;
50+
this.sslContextFactory = sslContextFactory;
4151
}
4252

43-
public FrameHandler create(Address addr) throws IOException {
53+
public FrameHandler create(Address addr, String connectionName) throws IOException {
4454
String hostName = addr.getHost();
4555
int portNumber = ConnectionFactory.portOrDefault(addr.getPort(), ssl);
4656
Socket socket = null;
4757
try {
48-
socket = factory.createSocket();
58+
socket = createSocket(connectionName);
4959
configurator.configure(socket);
5060
socket.connect(new InetSocketAddress(hostName, portNumber),
5161
connectionTimeout);
@@ -56,6 +66,19 @@ public FrameHandler create(Address addr) throws IOException {
5666
}
5767
}
5868

69+
protected Socket createSocket(String connectionName) throws IOException {
70+
// SocketFactory takes precedence if specified
71+
if (socketFactory != null) {
72+
return socketFactory.createSocket();
73+
} else {
74+
if (ssl) {
75+
return sslContextFactory.create(connectionName).getSocketFactory().createSocket();
76+
} else {
77+
return SocketFactory.getDefault().createSocket();
78+
}
79+
}
80+
}
81+
5982
public FrameHandler create(Socket sock) throws IOException
6083
{
6184
return new SocketFrameHandler(sock, this.shutdownExecutor);

src/main/java/com/rabbitmq/client/impl/nio/SocketChannelFrameHandlerFactory.java

+6-4
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import com.rabbitmq.client.Address;
1919
import com.rabbitmq.client.ConnectionFactory;
20+
import com.rabbitmq.client.SslContextFactory;
2021
import com.rabbitmq.client.impl.AbstractFrameHandlerFactory;
2122
import com.rabbitmq.client.impl.FrameHandler;
2223

@@ -40,34 +41,35 @@ public class SocketChannelFrameHandlerFactory extends AbstractFrameHandlerFactor
4041

4142
final NioParams nioParams;
4243

43-
private final SSLContext sslContext;
44+
private final SslContextFactory sslContextFactory;
4445

4546
private final Lock stateLock = new ReentrantLock();
4647

4748
private final AtomicLong globalConnectionCount = new AtomicLong();
4849

4950
private final List<NioLoopContext> nioLoopContexts;
5051

51-
public SocketChannelFrameHandlerFactory(int connectionTimeout, NioParams nioParams, boolean ssl, SSLContext sslContext)
52+
public SocketChannelFrameHandlerFactory(int connectionTimeout, NioParams nioParams, boolean ssl, SslContextFactory sslContextFactory)
5253
throws IOException {
5354
super(connectionTimeout, null, ssl);
5455
this.nioParams = new NioParams(nioParams);
55-
this.sslContext = sslContext;
56+
this.sslContextFactory = sslContextFactory;
5657
this.nioLoopContexts = new ArrayList<NioLoopContext>(this.nioParams.getNbIoThreads());
5758
for (int i = 0; i < this.nioParams.getNbIoThreads(); i++) {
5859
this.nioLoopContexts.add(new NioLoopContext(this, this.nioParams));
5960
}
6061
}
6162

6263
@Override
63-
public FrameHandler create(Address addr) throws IOException {
64+
public FrameHandler create(Address addr, String connectionName) throws IOException {
6465
int portNumber = ConnectionFactory.portOrDefault(addr.getPort(), ssl);
6566

6667
SSLEngine sslEngine = null;
6768
SocketChannel channel = null;
6869

6970
try {
7071
if (ssl) {
72+
SSLContext sslContext = sslContextFactory.create(connectionName);
7173
sslEngine = sslContext.createSSLEngine(addr.getHost(), portNumber);
7274
sslEngine.setUseClientMode(true);
7375
}

src/main/java/com/rabbitmq/client/impl/recovery/RecoveryAwareAMQConnectionFactory.java

+12-1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import java.util.ArrayList;
2525
import java.util.Collections;
2626
import java.util.List;
27+
import java.util.Map;
2728
import java.util.concurrent.TimeoutException;
2829

2930
public class RecoveryAwareAMQConnectionFactory {
@@ -58,7 +59,7 @@ public RecoveryAwareAMQConnection newConnection() throws IOException, TimeoutExc
5859

5960
for (Address addr : shuffled) {
6061
try {
61-
FrameHandler frameHandler = factory.create(addr);
62+
FrameHandler frameHandler = factory.create(addr, connectionName());
6263
RecoveryAwareAMQConnection conn = createConnection(params, frameHandler, metricsCollector);
6364
conn.start();
6465
metricsCollector.newConnection(conn);
@@ -89,4 +90,14 @@ private static List<Address> shuffle(List<Address> addrs) {
8990
protected RecoveryAwareAMQConnection createConnection(ConnectionParams params, FrameHandler handler, MetricsCollector metricsCollector) {
9091
return new RecoveryAwareAMQConnection(params, handler, metricsCollector);
9192
}
93+
94+
private String connectionName() {
95+
Map<String, Object> clientProperties = params.getClientProperties();
96+
if (clientProperties == null) {
97+
return null;
98+
} else {
99+
Object connectionName = clientProperties.get("connection_name");
100+
return connectionName == null ? null : connectionName.toString();
101+
}
102+
}
92103
}

src/test/java/com/rabbitmq/client/test/ChannelRpcTimeoutIntegrationTest.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ public void tearDown() throws Exception {
8787
private FrameHandler createFrameHandler() throws IOException {
8888
SocketFrameHandlerFactory socketFrameHandlerFactory = new SocketFrameHandlerFactory(ConnectionFactory.DEFAULT_CONNECTION_TIMEOUT,
8989
SocketFactory.getDefault(), new DefaultSocketConfigurator(), false, null);
90-
return socketFrameHandlerFactory.create(new Address("localhost"));
90+
return socketFrameHandlerFactory.create(new Address("localhost"), null);
9191
}
9292

9393
static class WaitingChannel extends ChannelN {

src/test/java/com/rabbitmq/client/test/ClientTests.java

+1
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
ConnectionFactoryTest.class,
5050
RecoveryAwareAMQConnectionFactoryTest.class,
5151
RpcTest.class,
52+
SslContextFactoryTest.class,
5253
LambdaCallbackTest.class
5354
})
5455
public class ClientTests {

0 commit comments

Comments
 (0)