Skip to content

Commit c8852fa

Browse files
committed
Make SslContext creation more flexible
Introduce the SslContextFactory interface, in use in the ConnectionFactory to create SslContext instance based on the connection name. The existing ConnectionFactory#useSslProtocol() methods still work the same way (they end using a SslContextFactory that returns the same SslContext, whatever the name of the connection is). This introduces a breaking change in the FrameHandlerFactory by adding a new connectionName parameter. It should impact many users, as this is more an internal API. Fixes #241
1 parent 0e97491 commit c8852fa

10 files changed

+270
-20
lines changed

src/main/java/com/rabbitmq/client/ConnectionFactory.java

+37-10
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ public class ConnectionFactory implements Cloneable {
9494
private int handshakeTimeout = DEFAULT_HANDSHAKE_TIMEOUT;
9595
private int shutdownTimeout = DEFAULT_SHUTDOWN_TIMEOUT;
9696
private Map<String, Object> _clientProperties = AMQConnection.defaultClientProperties();
97-
private SocketFactory factory = SocketFactory.getDefault();
97+
private SocketFactory factory = null;
9898
private SaslConfig saslConfig = DefaultSaslConfig.PLAIN;
9999
private ExecutorService sharedExecutor;
100100
private ThreadFactory threadFactory = Executors.defaultThreadFactory();
@@ -119,7 +119,7 @@ public class ConnectionFactory implements Cloneable {
119119
private FrameHandlerFactory frameHandlerFactory;
120120
private NioParams nioParams = new NioParams();
121121

122-
private SSLContext sslContext;
122+
private SslContextFactory sslContextFactory;
123123

124124
/**
125125
* Continuation timeout on RPC calls.
@@ -449,7 +449,8 @@ public SocketFactory getSocketFactory() {
449449
* Set the socket factory used to make connections with. Can be
450450
* used to enable SSL connections by passing in a
451451
* javax.net.ssl.SSLSocketFactory instance.
452-
*
452+
* Note this applies only to blocking IO, not to
453+
* NIO, as the NIO API doesn't use the SocketFactory API.
453454
* @see #useSslProtocol
454455
*/
455456
public void setSocketFactory(SocketFactory factory) {
@@ -556,7 +557,7 @@ public void setExceptionHandler(ExceptionHandler exceptionHandler) {
556557
}
557558

558559
public boolean isSSL(){
559-
return getSocketFactory() instanceof SSLSocketFactory;
560+
return getSocketFactory() instanceof SSLSocketFactory || sslContextFactory != null;
560561
}
561562

562563
/**
@@ -572,6 +573,10 @@ public void useSslProtocol()
572573
/**
573574
* Convenience method for setting up a SSL socket factory/engine, using
574575
* the supplied protocol and a very trusting TrustManager.
576+
* The produced {@link SSLContext} instance will be shared by all
577+
* the connections created by this connection factory. Use
578+
* {@link #setSslContextFactory(SslContextFactory)} for more flexibility.
579+
* @see #setSslContextFactory(SslContextFactory)
575580
*/
576581
public void useSslProtocol(String protocol)
577582
throws NoSuchAlgorithmException, KeyManagementException
@@ -582,8 +587,11 @@ public void useSslProtocol(String protocol)
582587
/**
583588
* Convenience method for setting up an SSL socket factory/engine.
584589
* Pass in the SSL protocol to use, e.g. "TLSv1" or "TLSv1.2".
585-
*
590+
* The produced {@link SSLContext} instance will be shared with all
591+
* the connections created by this connection factory. Use
592+
* {@link #setSslContextFactory(SslContextFactory)} for more flexibility.
586593
* @param protocol SSL protocol to use.
594+
* @see #setSslContextFactory(SslContextFactory)
587595
*/
588596
public void useSslProtocol(String protocol, TrustManager trustManager)
589597
throws NoSuchAlgorithmException, KeyManagementException
@@ -596,12 +604,15 @@ public void useSslProtocol(String protocol, TrustManager trustManager)
596604
/**
597605
* Convenience method for setting up an SSL socket factory/engine.
598606
* Pass in an initialized SSLContext.
599-
*
607+
* The {@link SSLContext} instance will be shared with all
608+
* the connections created by this connection factory. Use
609+
* {@link #setSslContextFactory(SslContextFactory)} for more flexibility.
600610
* @param context An initialized SSLContext
611+
* @see #setSslContextFactory(SslContextFactory)
601612
*/
602613
public void useSslProtocol(SSLContext context) {
614+
this.sslContextFactory = name -> context;
603615
setSocketFactory(context.getSocketFactory());
604-
this.sslContext = context;
605616
}
606617

607618
public static String computeDefaultTlsProcotol(String[] supportedProtocols) {
@@ -667,11 +678,11 @@ protected synchronized FrameHandlerFactory createFrameHandlerFactory() throws IO
667678
if(this.nioParams.getNioExecutor() == null && this.nioParams.getThreadFactory() == null) {
668679
this.nioParams.setThreadFactory(getThreadFactory());
669680
}
670-
this.frameHandlerFactory = new SocketChannelFrameHandlerFactory(connectionTimeout, nioParams, isSSL(), sslContext);
681+
this.frameHandlerFactory = new SocketChannelFrameHandlerFactory(connectionTimeout, nioParams, isSSL(), sslContextFactory);
671682
}
672683
return this.frameHandlerFactory;
673684
} else {
674-
return new SocketFrameHandlerFactory(connectionTimeout, factory, socketConf, isSSL(), this.shutdownExecutor);
685+
return new SocketFrameHandlerFactory(connectionTimeout, factory, socketConf, isSSL(), this.shutdownExecutor, sslContextFactory);
675686
}
676687

677688
}
@@ -915,7 +926,7 @@ public Connection newConnection(ExecutorService executor, AddressResolver addres
915926
Exception lastException = null;
916927
for (Address addr : addrs) {
917928
try {
918-
FrameHandler handler = fhFactory.create(addr);
929+
FrameHandler handler = fhFactory.create(addr, clientProvidedName);
919930
AMQConnection conn = createConnection(params, handler, metricsCollector);
920931
conn.start();
921932
this.metricsCollector.newConnection(conn);
@@ -1124,4 +1135,20 @@ public void setChannelRpcTimeout(int channelRpcTimeout) {
11241135
public int getChannelRpcTimeout() {
11251136
return channelRpcTimeout;
11261137
}
1138+
1139+
/**
1140+
* The factory to create SSL contexts.
1141+
* This provides more flexibility to create {@link SSLContext}s
1142+
* for different connections than sharing the {@link SSLContext}
1143+
* with all the connections produced by the connection factory
1144+
* (which is the case with the {@link #useSslProtocol()} methods).
1145+
* This way, different connections with a different certificate
1146+
* for each of them is a possible scenario.
1147+
* @param sslContextFactory
1148+
* @see #useSslProtocol(SSLContext)
1149+
* @since 5.0.0
1150+
*/
1151+
public void setSslContextFactory(SslContextFactory sslContextFactory) {
1152+
this.sslContextFactory = sslContextFactory;
1153+
}
11271154
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
// Copyright (c) 2017-Present Pivotal Software, Inc. All rights reserved.
2+
//
3+
// This software, the RabbitMQ Java client library, is triple-licensed under the
4+
// Mozilla Public License 1.1 ("MPL"), the GNU General Public License version 2
5+
// ("GPL") and the Apache License version 2 ("ASL"). For the MPL, please see
6+
// LICENSE-MPL-RabbitMQ. For the GPL, please see LICENSE-GPL2. For the ASL,
7+
// please see LICENSE-APACHE2.
8+
//
9+
// This software is distributed on an "AS IS" basis, WITHOUT WARRANTY OF ANY KIND,
10+
// either express or implied. See the LICENSE file for specific language governing
11+
// rights and limitations of this software.
12+
//
13+
// If you have any questions regarding licensing, please contact us at
14+
15+
16+
package com.rabbitmq.client;
17+
18+
import javax.net.ssl.SSLContext;
19+
20+
/**
21+
* A factory to create {@link SSLContext}s.
22+
*
23+
* @see ConnectionFactory#setSslContextFactory(SslContextFactory)
24+
* @since 5.0.0
25+
*/
26+
public interface SslContextFactory {
27+
28+
/**
29+
* Create a {@link SSLContext} for a given name.
30+
* The name is typically the name of the connection.
31+
* @param name name of the connection the SSLContext is used for
32+
* @return the SSLContext for this name
33+
* @see ConnectionFactory#newConnection(String)
34+
*/
35+
SSLContext create(String name);
36+
37+
}

src/main/java/com/rabbitmq/client/impl/FrameHandlerFactory.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,6 @@
99
*/
1010
public interface FrameHandlerFactory {
1111

12-
FrameHandler create(Address addr) throws IOException;
12+
FrameHandler create(Address addr, String connectionName) throws IOException;
1313

1414
}

src/main/java/com/rabbitmq/client/impl/SocketFrameHandlerFactory.java

+22-2
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import com.rabbitmq.client.Address;
1919
import com.rabbitmq.client.ConnectionFactory;
2020
import com.rabbitmq.client.SocketConfigurator;
21+
import com.rabbitmq.client.SslContextFactory;
2122

2223
import javax.net.SocketFactory;
2324
import java.io.IOException;
@@ -29,23 +30,29 @@ public class SocketFrameHandlerFactory extends AbstractFrameHandlerFactory {
2930

3031
private final SocketFactory factory;
3132
private final ExecutorService shutdownExecutor;
33+
private final SslContextFactory sslContextFactory;
3234

3335
public SocketFrameHandlerFactory(int connectionTimeout, SocketFactory factory, SocketConfigurator configurator, boolean ssl) {
3436
this(connectionTimeout, factory, configurator, ssl, null);
3537
}
3638

3739
public SocketFrameHandlerFactory(int connectionTimeout, SocketFactory factory, SocketConfigurator configurator, boolean ssl, ExecutorService shutdownExecutor) {
40+
this(connectionTimeout, factory, configurator, ssl, shutdownExecutor, null);
41+
}
42+
43+
public SocketFrameHandlerFactory(int connectionTimeout, SocketFactory factory, SocketConfigurator configurator, boolean ssl, ExecutorService shutdownExecutor, SslContextFactory sslContextFactory) {
3844
super(connectionTimeout, configurator, ssl);
3945
this.factory = factory;
4046
this.shutdownExecutor = shutdownExecutor;
47+
this.sslContextFactory = sslContextFactory;
4148
}
4249

43-
public FrameHandler create(Address addr) throws IOException {
50+
public FrameHandler create(Address addr, String connectionName) throws IOException {
4451
String hostName = addr.getHost();
4552
int portNumber = ConnectionFactory.portOrDefault(addr.getPort(), ssl);
4653
Socket socket = null;
4754
try {
48-
socket = factory.createSocket();
55+
socket = createSocket(connectionName);
4956
configurator.configure(socket);
5057
socket.connect(new InetSocketAddress(hostName, portNumber),
5158
connectionTimeout);
@@ -56,6 +63,19 @@ public FrameHandler create(Address addr) throws IOException {
5663
}
5764
}
5865

66+
protected Socket createSocket(String connectionName) throws IOException {
67+
// SocketFactory takes precedence if specified
68+
if (factory != null) {
69+
return factory.createSocket();
70+
} else {
71+
if (ssl) {
72+
return sslContextFactory.create(connectionName).getSocketFactory().createSocket();
73+
} else {
74+
return SocketFactory.getDefault().createSocket();
75+
}
76+
}
77+
}
78+
5979
public FrameHandler create(Socket sock) throws IOException
6080
{
6181
return new SocketFrameHandler(sock, this.shutdownExecutor);

src/main/java/com/rabbitmq/client/impl/nio/SocketChannelFrameHandlerFactory.java

+6-4
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
import com.rabbitmq.client.Address;
1919
import com.rabbitmq.client.ConnectionFactory;
20+
import com.rabbitmq.client.SslContextFactory;
2021
import com.rabbitmq.client.impl.AbstractFrameHandlerFactory;
2122
import com.rabbitmq.client.impl.FrameHandler;
2223

@@ -40,34 +41,35 @@ public class SocketChannelFrameHandlerFactory extends AbstractFrameHandlerFactor
4041

4142
final NioParams nioParams;
4243

43-
private final SSLContext sslContext;
44+
private final SslContextFactory sslContextFactory;
4445

4546
private final Lock stateLock = new ReentrantLock();
4647

4748
private final AtomicLong globalConnectionCount = new AtomicLong();
4849

4950
private final List<NioLoopContext> nioLoopContexts;
5051

51-
public SocketChannelFrameHandlerFactory(int connectionTimeout, NioParams nioParams, boolean ssl, SSLContext sslContext)
52+
public SocketChannelFrameHandlerFactory(int connectionTimeout, NioParams nioParams, boolean ssl, SslContextFactory sslContextFactory)
5253
throws IOException {
5354
super(connectionTimeout, null, ssl);
5455
this.nioParams = new NioParams(nioParams);
55-
this.sslContext = sslContext;
56+
this.sslContextFactory = sslContextFactory;
5657
this.nioLoopContexts = new ArrayList<NioLoopContext>(this.nioParams.getNbIoThreads());
5758
for (int i = 0; i < this.nioParams.getNbIoThreads(); i++) {
5859
this.nioLoopContexts.add(new NioLoopContext(this, this.nioParams));
5960
}
6061
}
6162

6263
@Override
63-
public FrameHandler create(Address addr) throws IOException {
64+
public FrameHandler create(Address addr, String connectionName) throws IOException {
6465
int portNumber = ConnectionFactory.portOrDefault(addr.getPort(), ssl);
6566

6667
SSLEngine sslEngine = null;
6768
SocketChannel channel = null;
6869

6970
try {
7071
if (ssl) {
72+
SSLContext sslContext = sslContextFactory.create(connectionName);
7173
sslEngine = sslContext.createSSLEngine(addr.getHost(), portNumber);
7274
sslEngine.setUseClientMode(true);
7375
}

src/main/java/com/rabbitmq/client/impl/recovery/RecoveryAwareAMQConnectionFactory.java

+12-1
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
import java.util.ArrayList;
2525
import java.util.Collections;
2626
import java.util.List;
27+
import java.util.Map;
2728
import java.util.concurrent.TimeoutException;
2829

2930
public class RecoveryAwareAMQConnectionFactory {
@@ -58,7 +59,7 @@ public RecoveryAwareAMQConnection newConnection() throws IOException, TimeoutExc
5859

5960
for (Address addr : shuffled) {
6061
try {
61-
FrameHandler frameHandler = factory.create(addr);
62+
FrameHandler frameHandler = factory.create(addr, connectionName());
6263
RecoveryAwareAMQConnection conn = createConnection(params, frameHandler, metricsCollector);
6364
conn.start();
6465
metricsCollector.newConnection(conn);
@@ -89,4 +90,14 @@ private static List<Address> shuffle(List<Address> addrs) {
8990
protected RecoveryAwareAMQConnection createConnection(ConnectionParams params, FrameHandler handler, MetricsCollector metricsCollector) {
9091
return new RecoveryAwareAMQConnection(params, handler, metricsCollector);
9192
}
93+
94+
private String connectionName() {
95+
Map<String, Object> clientProperties = params.getClientProperties();
96+
if (clientProperties == null) {
97+
return null;
98+
} else {
99+
Object connectionName = clientProperties.get("connection_name");
100+
return connectionName == null ? null : connectionName.toString();
101+
}
102+
}
92103
}

src/test/java/com/rabbitmq/client/test/ChannelRpcTimeoutIntegrationTest.java

+1-1
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ public void tearDown() throws Exception {
8787
private FrameHandler createFrameHandler() throws IOException {
8888
SocketFrameHandlerFactory socketFrameHandlerFactory = new SocketFrameHandlerFactory(ConnectionFactory.DEFAULT_CONNECTION_TIMEOUT,
8989
SocketFactory.getDefault(), new DefaultSocketConfigurator(), false, null);
90-
return socketFrameHandlerFactory.create(new Address("localhost"));
90+
return socketFrameHandlerFactory.create(new Address("localhost"), null);
9191
}
9292

9393
static class WaitingChannel extends ChannelN {

src/test/java/com/rabbitmq/client/test/ClientTests.java

+1
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@
4949
ConnectionFactoryTest.class,
5050
RecoveryAwareAMQConnectionFactoryTest.class,
5151
RpcTest.class,
52+
SslContextFactoryTest.class,
5253
LambdaCallbackTest.class
5354
})
5455
public class ClientTests {

0 commit comments

Comments
 (0)