|
1 | 1 | package io.r2dbc.postgresql.authentication;
|
2 | 2 |
|
3 | 3 | import com.ongres.scram.client.ScramClient;
|
4 |
| -import com.ongres.scram.client.ScramSession; |
5 |
| -import com.ongres.scram.common.exception.ScramInvalidServerSignatureException; |
6 |
| -import com.ongres.scram.common.exception.ScramParseException; |
7 |
| -import com.ongres.scram.common.exception.ScramServerErrorException; |
| 4 | +import com.ongres.scram.common.StringPreparation; |
| 5 | +import com.ongres.scram.common.exception.ScramException; |
| 6 | + |
8 | 7 | import io.r2dbc.postgresql.message.backend.AuthenticationMessage;
|
9 | 8 | import io.r2dbc.postgresql.message.backend.AuthenticationSASL;
|
10 | 9 | import io.r2dbc.postgresql.message.backend.AuthenticationSASLContinue;
|
|
17 | 16 | import reactor.core.Exceptions;
|
18 | 17 | import reactor.util.annotation.Nullable;
|
19 | 18 |
|
20 |
| -import static com.ongres.scram.client.ScramClient.ChannelBinding.NO; |
21 |
| -import static com.ongres.scram.common.stringprep.StringPreparations.NO_PREPARATION; |
22 |
| - |
23 | 19 | public class SASLAuthenticationHandler implements AuthenticationHandler {
|
24 | 20 |
|
25 | 21 | private final CharSequence password;
|
26 | 22 |
|
27 | 23 | private final String username;
|
28 | 24 |
|
29 |
| - private ScramSession.ClientFinalProcessor clientFinalProcessor; |
30 |
| - |
31 |
| - private ScramSession scramSession; |
| 25 | + private ScramClient scramClient; |
32 | 26 |
|
33 | 27 | /**
|
34 | 28 | * Create a new handler.
|
@@ -73,35 +67,32 @@ public FrontendMessage handle(AuthenticationMessage message) {
|
73 | 67 | }
|
74 | 68 |
|
75 | 69 | private FrontendMessage handleAuthenticationSASL(AuthenticationSASL message) {
|
76 |
| - ScramClient scramClient = ScramClient |
77 |
| - .channelBinding(NO) |
78 |
| - .stringPreparation(NO_PREPARATION) |
79 |
| - .selectMechanismBasedOnServerAdvertised(message.getAuthenticationMechanisms().toArray(new String[0])) |
80 |
| - .setup(); |
81 |
| - |
82 |
| - this.scramSession = scramClient.scramSession(this.username); |
83 |
| - |
84 |
| - return new SASLInitialResponse(ByteBufferUtils.encode(this.scramSession.clientFirstMessage()), scramClient.getScramMechanism().getName()); |
| 70 | + this.scramClient = ScramClient.builder() |
| 71 | + .advertisedMechanisms(message.getAuthenticationMechanisms()) |
| 72 | + .username(username) // ignored by the server, use startup message |
| 73 | + .password(password.toString().toCharArray()) |
| 74 | + .stringPreparation(StringPreparation.POSTGRESQL_PREPARATION) |
| 75 | + .build(); |
| 76 | + |
| 77 | + return new SASLInitialResponse(ByteBufferUtils.encode(this.scramClient.clientFirstMessage().toString()), scramClient.getScramMechanism().getName()); |
85 | 78 | }
|
86 | 79 |
|
87 | 80 | private FrontendMessage handleAuthenticationSASLContinue(AuthenticationSASLContinue message) {
|
88 | 81 | try {
|
89 |
| - this.clientFinalProcessor = this.scramSession |
90 |
| - .receiveServerFirstMessage(ByteBufferUtils.decode(message.getData())) |
91 |
| - .clientFinalProcessor(this.password.toString()); |
| 82 | + this.scramClient.serverFirstMessage(ByteBufferUtils.decode(message.getData())); |
92 | 83 |
|
93 |
| - return new SASLResponse(ByteBufferUtils.encode(clientFinalProcessor.clientFinalMessage())); |
94 |
| - } catch (ScramParseException e) { |
| 84 | + return new SASLResponse(ByteBufferUtils.encode(this.scramClient.clientFinalMessage().toString())); |
| 85 | + } catch (ScramException e) { |
95 | 86 | throw Exceptions.propagate(e);
|
96 | 87 | }
|
97 | 88 | }
|
98 | 89 |
|
99 | 90 | @Nullable
|
100 | 91 | private FrontendMessage handleAuthenticationSASLFinal(AuthenticationSASLFinal message) {
|
101 | 92 | try {
|
102 |
| - this.clientFinalProcessor.receiveServerFinalMessage(ByteBufferUtils.decode(message.getAdditionalData())); |
| 93 | + this.scramClient.serverFinalMessage(ByteBufferUtils.decode(message.getAdditionalData())); |
103 | 94 | return null;
|
104 |
| - } catch (ScramParseException | ScramInvalidServerSignatureException | ScramServerErrorException e) { |
| 95 | + } catch (ScramException e) { |
105 | 96 | throw Exceptions.propagate(e);
|
106 | 97 | }
|
107 | 98 | }
|
|
0 commit comments