|
3 | 3 | import com.ongres.scram.client.ScramClient;
|
4 | 4 | import com.ongres.scram.common.StringPreparation;
|
5 | 5 | import com.ongres.scram.common.exception.ScramException;
|
6 |
| - |
| 6 | +import com.ongres.scram.common.util.TlsServerEndpoint; |
| 7 | +import io.r2dbc.postgresql.client.ConnectionContext; |
7 | 8 | import io.r2dbc.postgresql.message.backend.AuthenticationMessage;
|
8 | 9 | import io.r2dbc.postgresql.message.backend.AuthenticationSASL;
|
9 | 10 | import io.r2dbc.postgresql.message.backend.AuthenticationSASLContinue;
|
|
14 | 15 | import io.r2dbc.postgresql.util.Assert;
|
15 | 16 | import io.r2dbc.postgresql.util.ByteBufferUtils;
|
16 | 17 | import reactor.core.Exceptions;
|
| 18 | +import reactor.util.Logger; |
| 19 | +import reactor.util.Loggers; |
17 | 20 | import reactor.util.annotation.Nullable;
|
18 | 21 |
|
| 22 | +import javax.net.ssl.SSLException; |
| 23 | +import javax.net.ssl.SSLSession; |
| 24 | +import java.security.cert.Certificate; |
| 25 | +import java.security.cert.CertificateException; |
| 26 | +import java.security.cert.X509Certificate; |
| 27 | + |
19 | 28 | public class SASLAuthenticationHandler implements AuthenticationHandler {
|
20 | 29 |
|
| 30 | + private static final Logger LOG = Loggers.getLogger(SASLAuthenticationHandler.class); |
| 31 | + |
21 | 32 | private final CharSequence password;
|
22 | 33 |
|
23 | 34 | private final String username;
|
24 | 35 |
|
| 36 | + private final ConnectionContext context; |
| 37 | + |
25 | 38 | private ScramClient scramClient;
|
26 | 39 |
|
27 | 40 | /**
|
28 | 41 | * Create a new handler.
|
29 | 42 | *
|
30 | 43 | * @param password the password to use for authentication
|
31 | 44 | * @param username the username to use for authentication
|
| 45 | + * @param context the connection context |
32 | 46 | * @throws IllegalArgumentException if {@code password} or {@code user} is {@code null}
|
33 | 47 | */
|
34 |
| - public SASLAuthenticationHandler(CharSequence password, String username) { |
| 48 | + public SASLAuthenticationHandler(CharSequence password, String username, ConnectionContext context) { |
35 | 49 | this.password = Assert.requireNonNull(password, "password must not be null");
|
36 | 50 | this.username = Assert.requireNonNull(username, "username must not be null");
|
| 51 | + this.context = Assert.requireNonNull(context, "context must not be null"); |
37 | 52 | }
|
38 | 53 |
|
39 | 54 | /**
|
@@ -67,14 +82,44 @@ public FrontendMessage handle(AuthenticationMessage message) {
|
67 | 82 | }
|
68 | 83 |
|
69 | 84 | private FrontendMessage handleAuthenticationSASL(AuthenticationSASL message) {
|
70 |
| - this.scramClient = ScramClient.builder() |
| 85 | + |
| 86 | + char[] password = new char[this.password.length()]; |
| 87 | + for (int i = 0; i < password.length; i++) { |
| 88 | + password[i] = this.password.charAt(i); |
| 89 | + } |
| 90 | + |
| 91 | + ScramClient.FinalBuildStage builder = ScramClient.builder() |
71 | 92 | .advertisedMechanisms(message.getAuthenticationMechanisms())
|
72 |
| - .username(username) // ignored by the server, use startup message |
73 |
| - .password(password.toString().toCharArray()) |
74 |
| - .stringPreparation(StringPreparation.POSTGRESQL_PREPARATION) |
75 |
| - .build(); |
| 93 | + .username(this.username) // ignored by the server, use startup message |
| 94 | + .password(password) |
| 95 | + .stringPreparation(StringPreparation.POSTGRESQL_PREPARATION); |
| 96 | + |
| 97 | + SSLSession sslSession = this.context.getSslSession(); |
76 | 98 |
|
77 |
| - return new SASLInitialResponse(ByteBufferUtils.encode(this.scramClient.clientFirstMessage().toString()), scramClient.getScramMechanism().getName()); |
| 99 | + if (sslSession != null && sslSession.isValid()) { |
| 100 | + builder.channelBinding(TlsServerEndpoint.TLS_SERVER_END_POINT, extractSslEndpoint(sslSession)); |
| 101 | + } |
| 102 | + |
| 103 | + this.scramClient = builder.build(); |
| 104 | + |
| 105 | + return new SASLInitialResponse(ByteBufferUtils.encode(this.scramClient.clientFirstMessage().toString()), this.scramClient.getScramMechanism().getName()); |
| 106 | + } |
| 107 | + |
| 108 | + private static byte[] extractSslEndpoint(SSLSession sslSession) { |
| 109 | + try { |
| 110 | + Certificate[] certificates = sslSession.getPeerCertificates(); |
| 111 | + if (certificates != null && certificates.length > 0) { |
| 112 | + Certificate peerCert = certificates[0]; // First certificate is the peer's certificate |
| 113 | + if (peerCert instanceof X509Certificate) { |
| 114 | + X509Certificate cert = (X509Certificate) peerCert; |
| 115 | + return TlsServerEndpoint.getChannelBindingData(cert); |
| 116 | + |
| 117 | + } |
| 118 | + } |
| 119 | + } catch (CertificateException | SSLException e) { |
| 120 | + LOG.debug("Cannot extract X509Certificate from SSL session", e); |
| 121 | + } |
| 122 | + return new byte[0]; |
78 | 123 | }
|
79 | 124 |
|
80 | 125 | private FrontendMessage handleAuthenticationSASLContinue(AuthenticationSASLContinue message) {
|
|
0 commit comments