|
25 | 25 | import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
|
26 | 26 | import io.netty.handler.logging.LogLevel;
|
27 | 27 | import io.netty.handler.logging.LoggingHandler;
|
| 28 | +import io.netty.util.AttributeKey; |
28 | 29 | import io.netty.util.ReferenceCountUtil;
|
29 | 30 | import io.netty.util.internal.logging.InternalLogger;
|
30 | 31 | import io.netty.util.internal.logging.InternalLoggerFactory;
|
@@ -101,6 +102,8 @@ public final class ReactorNettyClient implements Client {
|
101 | 102 |
|
102 | 103 | private static final Supplier<PostgresConnectionClosedException> EXPECTED = () -> new PostgresConnectionClosedException("Connection closed");
|
103 | 104 |
|
| 105 | + private static final AttributeKey<Mono<Void>> SSL_HANDSHAKE_KEY = AttributeKey.valueOf("ssl-handshake"); |
| 106 | + |
104 | 107 | private final ByteBufAllocator byteBufAllocator;
|
105 | 108 |
|
106 | 109 | private final ConnectionSettings settings;
|
@@ -144,7 +147,7 @@ private ReactorNettyClient(Connection connection, ConnectionSettings settings) {
|
144 | 147 | Assert.requireNonNull(connection, "Connection must not be null");
|
145 | 148 | this.settings = Assert.requireNonNull(settings, "ConnectionSettings must not be null");
|
146 | 149 |
|
147 |
| - connection.addHandlerFirst(new EnsureSubscribersCompleteChannelHandler(this.requestSink)); |
| 150 | + connection.addHandlerLast(new EnsureSubscribersCompleteChannelHandler(this.requestSink)); |
148 | 151 | connection.addHandlerLast(new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE - 5, 1, 4, -4, 0));
|
149 | 152 | this.connection = connection;
|
150 | 153 | this.byteBufAllocator = connection.outbound().alloc();
|
@@ -392,43 +395,43 @@ public static Mono<ReactorNettyClient> connect(SocketAddress socketAddress, Conn
|
392 | 395 | tcpClient = tcpClient.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, settings.getConnectTimeoutMs());
|
393 | 396 | }
|
394 | 397 |
|
395 |
| - return tcpClient.connect().flatMap(it -> { |
396 |
| - |
397 |
| - ChannelPipeline pipeline = it.channel().pipeline(); |
| 398 | + return tcpClient.doOnChannelInit((observer, channel, remoteAddress) -> { |
| 399 | + ChannelPipeline pipeline = channel.pipeline(); |
398 | 400 |
|
399 | 401 | InternalLogger logger = InternalLoggerFactory.getInstance(ReactorNettyClient.class);
|
400 | 402 | if (logger.isTraceEnabled()) {
|
401 | 403 | pipeline.addFirst(LoggingHandler.class.getSimpleName(),
|
402 | 404 | new LoggingHandler(ReactorNettyClient.class, LogLevel.TRACE));
|
403 | 405 | }
|
404 | 406 |
|
405 |
| - return registerSslHandler(settings.getSslConfig(), it).thenReturn(new ReactorNettyClient(it, settings)); |
406 |
| - }); |
| 407 | + registerSslHandler(settings.getSslConfig(), channel); |
| 408 | + }).connect().flatMap(it -> |
| 409 | + getSslHandshake(it.channel()).thenReturn(new ReactorNettyClient(it, settings)) |
| 410 | + ); |
407 | 411 | }
|
408 | 412 |
|
409 |
| - private static Mono<? extends Void> registerSslHandler(SSLConfig sslConfig, Connection it) { |
410 |
| - |
| 413 | + private static void registerSslHandler(SSLConfig sslConfig, Channel channel) { |
411 | 414 | try {
|
412 | 415 | if (sslConfig.getSslMode().startSsl()) {
|
413 | 416 |
|
414 |
| - return Mono.defer(() -> { |
415 |
| - AbstractPostgresSSLHandlerAdapter sslAdapter; |
416 |
| - if (sslConfig.getSslMode() == SSLMode.TUNNEL) { |
417 |
| - sslAdapter = new SSLTunnelHandlerAdapter(it.outbound().alloc(), sslConfig); |
418 |
| - } else { |
419 |
| - sslAdapter = new SSLSessionHandlerAdapter(it.outbound().alloc(), sslConfig); |
420 |
| - } |
421 |
| - |
422 |
| - it.addHandlerFirst(sslAdapter); |
423 |
| - return sslAdapter.getHandshake(); |
| 417 | + AbstractPostgresSSLHandlerAdapter sslAdapter; |
| 418 | + if (sslConfig.getSslMode() == SSLMode.TUNNEL) { |
| 419 | + sslAdapter = new SSLTunnelHandlerAdapter(channel.alloc(), sslConfig); |
| 420 | + } else { |
| 421 | + sslAdapter = new SSLSessionHandlerAdapter(channel.alloc(), sslConfig); |
| 422 | + } |
424 | 423 |
|
425 |
| - }).subscribeOn(Schedulers.boundedElastic()); |
| 424 | + channel.pipeline().addFirst(sslAdapter); |
| 425 | + channel.attr(SSL_HANDSHAKE_KEY).set(sslAdapter.getHandshake()); |
426 | 426 | }
|
427 | 427 | } catch (Throwable e) {
|
428 | 428 | throw new RuntimeException(e);
|
429 | 429 | }
|
| 430 | + } |
430 | 431 |
|
431 |
| - return Mono.empty(); |
| 432 | + private static Mono<Void> getSslHandshake(Channel channel) { |
| 433 | + Mono<Void> sslHandshake = channel.attr(SSL_HANDSHAKE_KEY).getAndSet(null); |
| 434 | + return (sslHandshake == null) ? Mono.empty() : sslHandshake; |
432 | 435 | }
|
433 | 436 |
|
434 | 437 | @Override
|
|
0 commit comments