Skip to content

Commit d05f93e

Browse files
committed
Handle early disconnects before SSL handshake
1 parent 2005e32 commit d05f93e

File tree

3 files changed

+94
-23
lines changed

3 files changed

+94
-23
lines changed

src/main/java/io/r2dbc/postgresql/client/ReactorNettyClient.java

Lines changed: 23 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
2626
import io.netty.handler.logging.LogLevel;
2727
import io.netty.handler.logging.LoggingHandler;
28+
import io.netty.util.AttributeKey;
2829
import io.netty.util.ReferenceCountUtil;
2930
import io.netty.util.internal.logging.InternalLogger;
3031
import io.netty.util.internal.logging.InternalLoggerFactory;
@@ -101,6 +102,8 @@ public final class ReactorNettyClient implements Client {
101102

102103
private static final Supplier<PostgresConnectionClosedException> EXPECTED = () -> new PostgresConnectionClosedException("Connection closed");
103104

105+
private static final AttributeKey<Mono<Void>> SSL_HANDSHAKE_KEY = AttributeKey.valueOf("ssl-handshake");
106+
104107
private final ByteBufAllocator byteBufAllocator;
105108

106109
private final ConnectionSettings settings;
@@ -144,7 +147,7 @@ private ReactorNettyClient(Connection connection, ConnectionSettings settings) {
144147
Assert.requireNonNull(connection, "Connection must not be null");
145148
this.settings = Assert.requireNonNull(settings, "ConnectionSettings must not be null");
146149

147-
connection.addHandlerFirst(new EnsureSubscribersCompleteChannelHandler(this.requestSink));
150+
connection.addHandlerLast(new EnsureSubscribersCompleteChannelHandler(this.requestSink));
148151
connection.addHandlerLast(new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE - 5, 1, 4, -4, 0));
149152
this.connection = connection;
150153
this.byteBufAllocator = connection.outbound().alloc();
@@ -392,43 +395,43 @@ public static Mono<ReactorNettyClient> connect(SocketAddress socketAddress, Conn
392395
tcpClient = tcpClient.option(ChannelOption.CONNECT_TIMEOUT_MILLIS, settings.getConnectTimeoutMs());
393396
}
394397

395-
return tcpClient.connect().flatMap(it -> {
396-
397-
ChannelPipeline pipeline = it.channel().pipeline();
398+
return tcpClient.doOnChannelInit((observer, channel, remoteAddress) -> {
399+
ChannelPipeline pipeline = channel.pipeline();
398400

399401
InternalLogger logger = InternalLoggerFactory.getInstance(ReactorNettyClient.class);
400402
if (logger.isTraceEnabled()) {
401403
pipeline.addFirst(LoggingHandler.class.getSimpleName(),
402404
new LoggingHandler(ReactorNettyClient.class, LogLevel.TRACE));
403405
}
404406

405-
return registerSslHandler(settings.getSslConfig(), it).thenReturn(new ReactorNettyClient(it, settings));
406-
});
407+
registerSslHandler(settings.getSslConfig(), channel);
408+
}).connect().flatMap(it ->
409+
getSslHandshake(it.channel()).thenReturn(new ReactorNettyClient(it, settings))
410+
);
407411
}
408412

409-
private static Mono<? extends Void> registerSslHandler(SSLConfig sslConfig, Connection it) {
410-
413+
private static void registerSslHandler(SSLConfig sslConfig, Channel channel) {
411414
try {
412415
if (sslConfig.getSslMode().startSsl()) {
413416

414-
return Mono.defer(() -> {
415-
AbstractPostgresSSLHandlerAdapter sslAdapter;
416-
if (sslConfig.getSslMode() == SSLMode.TUNNEL) {
417-
sslAdapter = new SSLTunnelHandlerAdapter(it.outbound().alloc(), sslConfig);
418-
} else {
419-
sslAdapter = new SSLSessionHandlerAdapter(it.outbound().alloc(), sslConfig);
420-
}
421-
422-
it.addHandlerFirst(sslAdapter);
423-
return sslAdapter.getHandshake();
417+
AbstractPostgresSSLHandlerAdapter sslAdapter;
418+
if (sslConfig.getSslMode() == SSLMode.TUNNEL) {
419+
sslAdapter = new SSLTunnelHandlerAdapter(channel.alloc(), sslConfig);
420+
} else {
421+
sslAdapter = new SSLSessionHandlerAdapter(channel.alloc(), sslConfig);
422+
}
424423

425-
}).subscribeOn(Schedulers.boundedElastic());
424+
channel.pipeline().addFirst(sslAdapter);
425+
channel.attr(SSL_HANDSHAKE_KEY).set(sslAdapter.getHandshake());
426426
}
427427
} catch (Throwable e) {
428428
throw new RuntimeException(e);
429429
}
430+
}
430431

431-
return Mono.empty();
432+
private static Mono<Void> getSslHandshake(Channel channel) {
433+
Mono<Void> sslHandshake = channel.attr(SSL_HANDSHAKE_KEY).getAndSet(null);
434+
return (sslHandshake == null) ? Mono.empty() : sslHandshake;
432435
}
433436

434437
@Override

src/main/java/io/r2dbc/postgresql/client/SSLSessionHandlerAdapter.java

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,17 @@ final class SSLSessionHandlerAdapter extends AbstractPostgresSSLHandlerAdapter {
4040
}
4141

4242
@Override
43-
public void handlerAdded(ChannelHandlerContext ctx) {
43+
public void channelActive(ChannelHandlerContext ctx) {
4444
Mono.from(SSLRequest.INSTANCE.encode(this.alloc)).subscribe(ctx::writeAndFlush);
45+
ctx.fireChannelActive();
46+
}
47+
48+
@Override
49+
public void channelInactive(ChannelHandlerContext ctx) {
50+
// If we receive channel inactive before removing this handler, then the inbound has closed early.
51+
PostgresqlSslException e = new PostgresqlSslException("Connection closed during SSL negotiation");
52+
completeHandshakeExceptionally(e);
53+
ctx.fireChannelInactive();
4554
}
4655

4756
@Override
@@ -54,7 +63,7 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) {
5463
processSslEnabled(ctx, buf);
5564
break;
5665
case 'N':
57-
processSslDisabled();
66+
processSslDisabled(ctx);
5867
break;
5968
default:
6069
buf.release();
@@ -65,13 +74,14 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) {
6574
}
6675
}
6776

68-
private void processSslDisabled() {
77+
private void processSslDisabled(ChannelHandlerContext ctx) {
6978
if (this.sslConfig.getSslMode().requireSsl()) {
7079
PostgresqlSslException e =
7180
new PostgresqlSslException("Server support for SSL connection is disabled, but client was configured with SSL mode " + this.sslConfig.getSslMode());
7281
completeHandshakeExceptionally(e);
7382
} else {
7483
completeHandshake();
84+
ctx.channel().pipeline().remove(this);
7585
}
7686
}
7787

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
/*
2+
* Copyright 2022 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* https://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package io.r2dbc.postgresql.client;
18+
19+
import io.r2dbc.postgresql.PostgresqlConnectionConfiguration;
20+
import io.r2dbc.postgresql.PostgresqlConnectionFactory;
21+
import org.junit.jupiter.api.Test;
22+
import reactor.netty.DisposableChannel;
23+
import reactor.netty.DisposableServer;
24+
import reactor.netty.tcp.TcpServer;
25+
import reactor.test.StepVerifier;
26+
27+
import static org.assertj.core.api.Assertions.assertThat;
28+
29+
public class DowntimeIntegrationTests {
30+
31+
@Test
32+
void failSslHandshakeIfInboundClosed() {
33+
// Simulate server downtime, where connections are accepted and then closed immediately
34+
DisposableServer server =
35+
TcpServer.create()
36+
.doOnConnection(DisposableChannel::dispose)
37+
.bindNow();
38+
39+
PostgresqlConnectionFactory connectionFactory =
40+
new PostgresqlConnectionFactory(
41+
PostgresqlConnectionConfiguration.builder()
42+
.host(server.host())
43+
.port(server.port())
44+
.username("test")
45+
.sslMode(SSLMode.REQUIRE)
46+
.build());
47+
48+
connectionFactory.create()
49+
.as(StepVerifier::create)
50+
.verifyErrorSatisfies(error ->
51+
assertThat(error)
52+
.isInstanceOf(AbstractPostgresSSLHandlerAdapter.PostgresqlSslException.class)
53+
.hasMessage("Connection closed during SSL negotiation"));
54+
55+
server.disposeNow();
56+
}
57+
58+
}

0 commit comments

Comments
 (0)