Skip to content

Commit ba6b2f4

Browse files
committed
update: keep SSL handler adapters in the pipeline
1 parent 6de187d commit ba6b2f4

File tree

3 files changed

+34
-33
lines changed

3 files changed

+34
-33
lines changed

src/main/java/io/r2dbc/postgresql/client/ReactorNettyClient.java

+2-6
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@
2525
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
2626
import io.netty.handler.logging.LogLevel;
2727
import io.netty.handler.logging.LoggingHandler;
28-
import io.netty.util.AttributeKey;
2928
import io.netty.util.ReferenceCountUtil;
3029
import io.netty.util.internal.logging.InternalLogger;
3130
import io.netty.util.internal.logging.InternalLoggerFactory;
@@ -102,8 +101,6 @@ public final class ReactorNettyClient implements Client {
102101

103102
private static final Supplier<PostgresConnectionClosedException> EXPECTED = () -> new PostgresConnectionClosedException("Connection closed");
104103

105-
private static final AttributeKey<Mono<Void>> SSL_HANDSHAKE_KEY = AttributeKey.valueOf("ssl-handshake");
106-
107104
private final ByteBufAllocator byteBufAllocator;
108105

109106
private final ConnectionSettings settings;
@@ -422,16 +419,15 @@ private static void registerSslHandler(SSLConfig sslConfig, Channel channel) {
422419
}
423420

424421
channel.pipeline().addFirst(sslAdapter);
425-
channel.attr(SSL_HANDSHAKE_KEY).set(sslAdapter.getHandshake());
426422
}
427423
} catch (Throwable e) {
428424
throw new RuntimeException(e);
429425
}
430426
}
431427

432428
private static Mono<Void> getSslHandshake(Channel channel) {
433-
Mono<Void> sslHandshake = channel.attr(SSL_HANDSHAKE_KEY).getAndSet(null);
434-
return (sslHandshake == null) ? Mono.empty() : sslHandshake;
429+
AbstractPostgresSSLHandlerAdapter sslAdapter = channel.pipeline().get(AbstractPostgresSSLHandlerAdapter.class);
430+
return (sslAdapter != null) ? sslAdapter.getHandshake() : Mono.empty();
435431
}
436432

437433
@Override

src/main/java/io/r2dbc/postgresql/client/SSLSessionHandlerAdapter.java

+31-24
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,8 @@ final class SSLSessionHandlerAdapter extends AbstractPostgresSSLHandlerAdapter {
3333

3434
private final SSLConfig sslConfig;
3535

36+
private boolean negotiating = true;
37+
3638
SSLSessionHandlerAdapter(ByteBufAllocator alloc, SSLConfig sslConfig) {
3739
super(alloc, sslConfig);
3840
this.alloc = alloc;
@@ -41,36 +43,44 @@ final class SSLSessionHandlerAdapter extends AbstractPostgresSSLHandlerAdapter {
4143

4244
@Override
4345
public void channelActive(ChannelHandlerContext ctx) throws Exception {
44-
Mono.from(SSLRequest.INSTANCE.encode(this.alloc)).subscribe(ctx::writeAndFlush);
46+
if (negotiating) {
47+
Mono.from(SSLRequest.INSTANCE.encode(this.alloc)).subscribe(ctx::writeAndFlush);
48+
}
4549
super.channelActive(ctx);
4650
}
4751

4852
@Override
4953
public void channelInactive(ChannelHandlerContext ctx) throws Exception {
50-
// If we receive channel inactive before removing this handler, then the inbound has closed early.
51-
PostgresqlSslException e = new PostgresqlSslException("Connection closed during SSL negotiation");
52-
completeHandshakeExceptionally(e);
54+
if (negotiating) {
55+
// If we receive channel inactive before negotiated, then the inbound has closed early.
56+
PostgresqlSslException e = new PostgresqlSslException("Connection closed during SSL negotiation");
57+
completeHandshakeExceptionally(e);
58+
}
5359
super.channelInactive(ctx);
5460
}
5561

5662
@Override
57-
public void channelRead(ChannelHandlerContext ctx, Object msg) {
58-
ByteBuf buf = (ByteBuf) msg;
59-
char response = (char) buf.readByte();
60-
try {
61-
switch (response) {
62-
case 'S':
63-
processSslEnabled(ctx, buf);
64-
break;
65-
case 'N':
66-
processSslDisabled(ctx);
67-
break;
68-
default:
69-
buf.release();
70-
throw new IllegalStateException("Unknown SSLResponse from server: '" + response + "'");
63+
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
64+
if (negotiating) {
65+
ByteBuf buf = (ByteBuf) msg;
66+
char response = (char) buf.readByte();
67+
try {
68+
switch (response) {
69+
case 'S':
70+
processSslEnabled(ctx, buf);
71+
break;
72+
case 'N':
73+
processSslDisabled(ctx);
74+
break;
75+
default:
76+
throw new IllegalStateException("Unknown SSLResponse from server: '" + response + "'");
77+
}
78+
} finally {
79+
buf.release();
80+
negotiating = false;
7181
}
72-
} finally {
73-
buf.release();
82+
} else {
83+
super.channelRead(ctx, msg);
7484
}
7585
}
7686

@@ -81,7 +91,6 @@ private void processSslDisabled(ChannelHandlerContext ctx) {
8191
completeHandshakeExceptionally(e);
8292
} else {
8393
completeHandshake();
84-
ctx.channel().pipeline().remove(this);
8594
}
8695
}
8796

@@ -92,9 +101,7 @@ private void processSslEnabled(ChannelHandlerContext ctx, ByteBuf msg) {
92101
completeHandshakeExceptionally(e);
93102
return;
94103
}
95-
ctx.channel().pipeline()
96-
.addFirst(this.getSslHandler())
97-
.remove(this);
104+
ctx.channel().pipeline().addFirst(this.getSslHandler());
98105
ctx.fireChannelRead(msg.retain());
99106
}
100107

src/main/java/io/r2dbc/postgresql/client/SSLTunnelHandlerAdapter.java

+1-3
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,7 @@ public void handlerAdded(ChannelHandlerContext ctx) {
4040
completeHandshakeExceptionally(e);
4141
return;
4242
}
43-
ctx.channel().pipeline()
44-
.addFirst(this.getSslHandler())
45-
.remove(this);
43+
ctx.channel().pipeline().addFirst(this.getSslHandler());
4644
}
4745

4846
}

0 commit comments

Comments
 (0)