Skip to content

Commit 83d39c8

Browse files
committed
Propagate SSLSession in ConnectionContext to enable SASL/SCRAM channel binding.
[#645]
1 parent 26761e8 commit 83d39c8

File tree

8 files changed

+124
-19
lines changed

8 files changed

+124
-19
lines changed

pom.xml

+5
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,11 @@
130130
<artifactId>scram-client</artifactId>
131131
<version>${scram-client.version}</version>
132132
</dependency>
133+
<dependency>
134+
<groupId>com.ongres.scram</groupId>
135+
<artifactId>scram-common</artifactId>
136+
<version>${scram-client.version}</version>
137+
</dependency>
133138
<dependency>
134139
<groupId>io.projectreactor</groupId>
135140
<artifactId>reactor-core</artifactId>

src/main/java/io/r2dbc/postgresql/SingleHostConnectionFunction.java

+4-3
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
import io.r2dbc.postgresql.authentication.PasswordAuthenticationHandler;
2121
import io.r2dbc.postgresql.authentication.SASLAuthenticationHandler;
2222
import io.r2dbc.postgresql.client.Client;
23+
import io.r2dbc.postgresql.client.ConnectionContext;
2324
import io.r2dbc.postgresql.client.ConnectionSettings;
2425
import io.r2dbc.postgresql.client.PostgresStartupParameterProvider;
2526
import io.r2dbc.postgresql.client.StartupMessageFlow;
@@ -46,7 +47,7 @@ public Mono<Client> connect(SocketAddress endpoint, ConnectionSettings settings)
4647

4748
return this.upstreamFunction.connect(endpoint, settings)
4849
.delayUntil(client -> getCredentials().flatMapMany(credentials -> StartupMessageFlow
49-
.exchange(auth -> getAuthenticationHandler(auth, credentials), client, this.configuration.getDatabase(), credentials.getUsername(),
50+
.exchange(auth -> getAuthenticationHandler(auth, credentials, client.getContext()), client, this.configuration.getDatabase(), credentials.getUsername(),
5051
getParameterProvider(this.configuration, settings)))
5152
.handle(ExceptionFactory.INSTANCE::handleErrorResponse));
5253
}
@@ -55,13 +56,13 @@ private static PostgresStartupParameterProvider getParameterProvider(PostgresqlC
5556
return new PostgresStartupParameterProvider(configuration.getApplicationName(), configuration.getTimeZone(), settings);
5657
}
5758

58-
protected AuthenticationHandler getAuthenticationHandler(AuthenticationMessage message, UsernameAndPassword usernameAndPassword) {
59+
protected AuthenticationHandler getAuthenticationHandler(AuthenticationMessage message, UsernameAndPassword usernameAndPassword, ConnectionContext context) {
5960
if (PasswordAuthenticationHandler.supports(message)) {
6061
CharSequence password = Assert.requireNonNull(usernameAndPassword.getPassword(), "Password must not be null");
6162
return new PasswordAuthenticationHandler(password, usernameAndPassword.getUsername());
6263
} else if (SASLAuthenticationHandler.supports(message)) {
6364
CharSequence password = Assert.requireNonNull(usernameAndPassword.getPassword(), "Password must not be null");
64-
return new SASLAuthenticationHandler(password, usernameAndPassword.getUsername());
65+
return new SASLAuthenticationHandler(password, usernameAndPassword.getUsername(), context);
6566
} else {
6667
throw new IllegalStateException(String.format("Unable to provide AuthenticationHandler capable of handling %s", message));
6768
}

src/main/java/io/r2dbc/postgresql/authentication/SASLAuthenticationHandler.java

+53-8
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
import com.ongres.scram.client.ScramClient;
44
import com.ongres.scram.common.StringPreparation;
55
import com.ongres.scram.common.exception.ScramException;
6-
6+
import com.ongres.scram.common.util.TlsServerEndpoint;
7+
import io.r2dbc.postgresql.client.ConnectionContext;
78
import io.r2dbc.postgresql.message.backend.AuthenticationMessage;
89
import io.r2dbc.postgresql.message.backend.AuthenticationSASL;
910
import io.r2dbc.postgresql.message.backend.AuthenticationSASLContinue;
@@ -14,26 +15,40 @@
1415
import io.r2dbc.postgresql.util.Assert;
1516
import io.r2dbc.postgresql.util.ByteBufferUtils;
1617
import reactor.core.Exceptions;
18+
import reactor.util.Logger;
19+
import reactor.util.Loggers;
1720
import reactor.util.annotation.Nullable;
1821

22+
import javax.net.ssl.SSLException;
23+
import javax.net.ssl.SSLSession;
24+
import java.security.cert.Certificate;
25+
import java.security.cert.CertificateException;
26+
import java.security.cert.X509Certificate;
27+
1928
public class SASLAuthenticationHandler implements AuthenticationHandler {
2029

30+
private static final Logger LOG = Loggers.getLogger(SASLAuthenticationHandler.class);
31+
2132
private final CharSequence password;
2233

2334
private final String username;
2435

36+
private final ConnectionContext context;
37+
2538
private ScramClient scramClient;
2639

2740
/**
2841
* Create a new handler.
2942
*
3043
* @param password the password to use for authentication
3144
* @param username the username to use for authentication
45+
* @param context the connection context
3246
* @throws IllegalArgumentException if {@code password} or {@code user} is {@code null}
3347
*/
34-
public SASLAuthenticationHandler(CharSequence password, String username) {
48+
public SASLAuthenticationHandler(CharSequence password, String username, ConnectionContext context) {
3549
this.password = Assert.requireNonNull(password, "password must not be null");
3650
this.username = Assert.requireNonNull(username, "username must not be null");
51+
this.context = Assert.requireNonNull(context, "context must not be null");
3752
}
3853

3954
/**
@@ -67,14 +82,44 @@ public FrontendMessage handle(AuthenticationMessage message) {
6782
}
6883

6984
private FrontendMessage handleAuthenticationSASL(AuthenticationSASL message) {
70-
this.scramClient = ScramClient.builder()
85+
86+
char[] password = new char[this.password.length()];
87+
for (int i = 0; i < password.length; i++) {
88+
password[i] = this.password.charAt(i);
89+
}
90+
91+
ScramClient.FinalBuildStage builder = ScramClient.builder()
7192
.advertisedMechanisms(message.getAuthenticationMechanisms())
72-
.username(username) // ignored by the server, use startup message
73-
.password(password.toString().toCharArray())
74-
.stringPreparation(StringPreparation.POSTGRESQL_PREPARATION)
75-
.build();
93+
.username(this.username) // ignored by the server, use startup message
94+
.password(password)
95+
.stringPreparation(StringPreparation.POSTGRESQL_PREPARATION);
96+
97+
SSLSession sslSession = this.context.getSslSession();
7698

77-
return new SASLInitialResponse(ByteBufferUtils.encode(this.scramClient.clientFirstMessage().toString()), scramClient.getScramMechanism().getName());
99+
if (sslSession != null && sslSession.isValid()) {
100+
builder.channelBinding(TlsServerEndpoint.TLS_SERVER_END_POINT, extractSslEndpoint(sslSession));
101+
}
102+
103+
this.scramClient = builder.build();
104+
105+
return new SASLInitialResponse(ByteBufferUtils.encode(this.scramClient.clientFirstMessage().toString()), this.scramClient.getScramMechanism().getName());
106+
}
107+
108+
private static byte[] extractSslEndpoint(SSLSession sslSession) {
109+
try {
110+
Certificate[] certificates = sslSession.getPeerCertificates();
111+
if (certificates != null && certificates.length > 0) {
112+
Certificate peerCert = certificates[0]; // First certificate is the peer's certificate
113+
if (peerCert instanceof X509Certificate) {
114+
X509Certificate cert = (X509Certificate) peerCert;
115+
return TlsServerEndpoint.getChannelBindingData(cert);
116+
117+
}
118+
}
119+
} catch (CertificateException | SSLException e) {
120+
LOG.debug("Cannot extract X509Certificate from SSL session", e);
121+
}
122+
return new byte[0];
78123
}
79124

80125
private FrontendMessage handleAuthenticationSASLContinue(AuthenticationSASLContinue message) {

src/main/java/io/r2dbc/postgresql/client/ConnectionContext.java

+24-3
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,9 @@
2020
import reactor.util.Loggers;
2121

2222
import javax.annotation.Nullable;
23+
import javax.net.ssl.SSLSession;
2324
import java.util.concurrent.atomic.AtomicLong;
25+
import java.util.function.Supplier;
2426

2527
/**
2628
* Value object capturing diagnostic connection context. Allows for log-message post-processing with {@link #getMessage(String) if the logger category for
@@ -50,6 +52,8 @@ public final class ConnectionContext {
5052

5153
private final String connectionIdPrefix;
5254

55+
private final Supplier<SSLSession> sslSession;
56+
5357
/**
5458
* Create a new {@link ConnectionContext} with a unique connection Id.
5559
*/
@@ -58,13 +62,15 @@ public ConnectionContext() {
5862
this.connectionCounter = incrementConnectionCounter();
5963
this.connectionIdPrefix = getConnectionIdPrefix();
6064
this.channelId = null;
65+
this.sslSession = () -> null;
6166
}
6267

63-
private ConnectionContext(@Nullable Integer processId, @Nullable String channelId, String connectionCounter) {
68+
private ConnectionContext(@Nullable Integer processId, @Nullable String channelId, String connectionCounter, Supplier<SSLSession> sslSession) {
6469
this.processId = processId;
6570
this.channelId = channelId;
6671
this.connectionCounter = connectionCounter;
6772
this.connectionIdPrefix = getConnectionIdPrefix();
73+
this.sslSession = sslSession;
6874
}
6975

7076
private String incrementConnectionCounter() {
@@ -101,14 +107,29 @@ public String getMessage(String original) {
101107
return original;
102108
}
103109

110+
@Nullable
111+
public SSLSession getSslSession() {
112+
return this.sslSession.get();
113+
}
114+
104115
/**
105116
* Create a new {@link ConnectionContext} by associating the {@code channelId}.
106117
*
107118
* @param channelId the channel identifier.
108119
* @return a new {@link ConnectionContext} with all previously set values and the associated {@code channelId}.
109120
*/
110121
public ConnectionContext withChannelId(String channelId) {
111-
return new ConnectionContext(this.processId, channelId, this.connectionCounter);
122+
return new ConnectionContext(this.processId, channelId, this.connectionCounter, this.sslSession);
123+
}
124+
125+
/**
126+
* Create a new {@link ConnectionContext} by associating the {@code sslSession}.
127+
*
128+
* @param sslSession the SSL session supplier.
129+
* @return a new {@link ConnectionContext} with all previously set values and the associated {@code sslSession}.
130+
*/
131+
public ConnectionContext withSslSession(Supplier<SSLSession> sslSession) {
132+
return new ConnectionContext(this.processId, this.channelId, this.connectionCounter, sslSession);
112133
}
113134

114135
/**
@@ -118,7 +139,7 @@ public ConnectionContext withChannelId(String channelId) {
118139
* @return a new {@link ConnectionContext} with all previously set values and the associated {@code processId}.
119140
*/
120141
public ConnectionContext withProcessId(int processId) {
121-
return new ConnectionContext(processId, this.channelId, this.connectionCounter);
142+
return new ConnectionContext(processId, this.channelId, this.connectionCounter, this.sslSession);
122143
}
123144

124145
}

src/main/java/io/r2dbc/postgresql/client/ReactorNettyClient.java

+18-1
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;
2626
import io.netty.handler.logging.LogLevel;
2727
import io.netty.handler.logging.LoggingHandler;
28+
import io.netty.handler.ssl.SslHandler;
2829
import io.netty.util.ReferenceCountUtil;
2930
import io.netty.util.internal.logging.InternalLogger;
3031
import io.netty.util.internal.logging.InternalLoggerFactory;
@@ -148,7 +149,23 @@ private ReactorNettyClient(Connection connection, ConnectionSettings settings) {
148149
connection.addHandlerLast(new LengthFieldBasedFrameDecoder(Integer.MAX_VALUE - 5, 1, 4, -4, 0));
149150
this.connection = connection;
150151
this.byteBufAllocator = connection.outbound().alloc();
151-
this.context = new ConnectionContext().withChannelId(connection.channel().toString());
152+
153+
ConnectionContext connectionContext = new ConnectionContext().withChannelId(connection.channel().toString());
154+
SslHandler sslHandler = this.connection.channel().pipeline().get(SslHandler.class);
155+
156+
if (sslHandler == null) {
157+
SSLSessionHandlerAdapter handlerAdapter = this.connection.channel().pipeline().get(SSLSessionHandlerAdapter.class);
158+
if (handlerAdapter != null) {
159+
sslHandler = handlerAdapter.getSslHandler();
160+
}
161+
}
162+
163+
if (sslHandler != null) {
164+
SslHandler toUse = sslHandler;
165+
connectionContext = connectionContext.withSslSession(() -> toUse.engine().getSession());
166+
}
167+
168+
this.context = connectionContext;
152169

153170
AtomicReference<Throwable> receiveError = new AtomicReference<>();
154171

src/main/java/io/r2dbc/postgresql/client/SSLSessionHandlerAdapter.java

+4-4
Original file line numberDiff line numberDiff line change
@@ -45,15 +45,15 @@ final class SSLSessionHandlerAdapter extends AbstractPostgresSSLHandlerAdapter {
4545

4646
@Override
4747
public void channelActive(ChannelHandlerContext ctx) throws Exception {
48-
if (negotiating) {
48+
if (this.negotiating) {
4949
Mono.from(SSLRequest.INSTANCE.encode(this.alloc)).subscribe(ctx::writeAndFlush);
5050
}
5151
super.channelActive(ctx);
5252
}
5353

5454
@Override
5555
public void channelInactive(ChannelHandlerContext ctx) throws Exception {
56-
if (negotiating) {
56+
if (this.negotiating) {
5757
// If we receive channel inactive before negotiated, then the inbound has closed early.
5858
PostgresqlSslException e = new PostgresqlSslException("Connection closed during SSL negotiation");
5959
completeHandshakeExceptionally(e);
@@ -63,7 +63,7 @@ public void channelInactive(ChannelHandlerContext ctx) throws Exception {
6363

6464
@Override
6565
public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
66-
if (negotiating) {
66+
if (this.negotiating) {
6767
ByteBuf buf = (ByteBuf) msg;
6868
char response = (char) buf.readByte();
6969
try {
@@ -79,7 +79,7 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception
7979
}
8080
} finally {
8181
buf.release();
82-
negotiating = false;
82+
this.negotiating = false;
8383
}
8484
} else {
8585
super.channelRead(ctx, msg);

src/test/java/io/r2dbc/postgresql/client/ReactorNettyClientIntegrationTests.java

+15
Original file line numberDiff line numberDiff line change
@@ -451,6 +451,21 @@ void exchangeSslWithClientCertNoCert() {
451451
.expectError(R2dbcPermissionDeniedException.class));
452452
}
453453

454+
@Test
455+
void exchangeSslWitScram() {
456+
client(
457+
c -> c
458+
.sslRootCert(SERVER.getServerCrt())
459+
.username("test-ssl-scram")
460+
.password("test-ssl-scram"),
461+
c -> c.map(client -> client.createStatement("SELECT 10")
462+
.execute()
463+
.flatMap(r -> r.map((row, meta) -> row.get(0, Integer.class)))
464+
.as(StepVerifier::create)
465+
.expectNext(10)
466+
.verifyComplete()));
467+
}
468+
454469
@Test
455470
void exchangeSslWithPassword() {
456471
client(

src/test/resources/pg_hba.conf

+1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
hostnossl all test all md5
22
hostnossl all test-scram all scram-sha-256
33
hostssl all test-ssl all password
4+
hostssl all test-ssl-scram all scram-sha-256
45
hostssl all test-ssl-with-cert all cert
56
local all all md5

0 commit comments

Comments
 (0)