diff --git a/src/main/java/io/r2dbc/postgresql/ClientSupplier.java b/src/main/java/io/r2dbc/postgresql/ClientSupplier.java new file mode 100644 index 000000000..5b49740fe --- /dev/null +++ b/src/main/java/io/r2dbc/postgresql/ClientSupplier.java @@ -0,0 +1,13 @@ +package io.r2dbc.postgresql; + +import io.r2dbc.postgresql.client.Client; +import io.r2dbc.postgresql.client.ConnectionSettings; +import reactor.core.publisher.Mono; + +import java.net.SocketAddress; + +public interface ClientSupplier { + + Mono connect(SocketAddress endpoint, ConnectionSettings settings); + +} diff --git a/src/main/java/io/r2dbc/postgresql/ConnectionStrategy.java b/src/main/java/io/r2dbc/postgresql/ConnectionStrategy.java new file mode 100644 index 000000000..d6a881a11 --- /dev/null +++ b/src/main/java/io/r2dbc/postgresql/ConnectionStrategy.java @@ -0,0 +1,31 @@ +package io.r2dbc.postgresql; + +import io.r2dbc.postgresql.client.Client; +import io.r2dbc.postgresql.client.ConnectionSettings; +import reactor.core.publisher.Mono; + +import java.net.SocketAddress; +import java.util.Map; +import java.util.function.Function; + +public interface ConnectionStrategy { + + Mono connect(); + + ConnectionStrategy withOptions(Map options); + + interface ComposableConnectionStrategy extends ConnectionStrategy { + + default T chainIf(boolean guard, Function nextStrategyProvider, Class klass) { + return guard ? nextStrategyProvider.apply(this) : klass.cast(this); + } + + ComposableConnectionStrategy withAddress(SocketAddress address); + + ComposableConnectionStrategy withConnectionSettings(ConnectionSettings connectionSettings); + + ComposableConnectionStrategy withOptions(Map options); + + } + +} diff --git a/src/main/java/io/r2dbc/postgresql/ConnectionStrategyFactory.java b/src/main/java/io/r2dbc/postgresql/ConnectionStrategyFactory.java new file mode 100644 index 000000000..3fd41b417 --- /dev/null +++ b/src/main/java/io/r2dbc/postgresql/ConnectionStrategyFactory.java @@ -0,0 +1,47 @@ +package io.r2dbc.postgresql; + +import io.netty.channel.unix.DomainSocketAddress; +import io.r2dbc.postgresql.client.MultiHostConfiguration; +import io.r2dbc.postgresql.client.SSLConfig; +import io.r2dbc.postgresql.client.SSLMode; +import io.r2dbc.postgresql.client.SingleHostConfiguration; + +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.util.ArrayList; +import java.util.List; + +public class ConnectionStrategyFactory { + + public static ConnectionStrategy getConnectionStrategy(ClientSupplier clientSupplier, PostgresqlConnectionConfiguration configuration) { + SingleHostConfiguration singleHostConfiguration = configuration.getSingleHostConfiguration(); + MultiHostConfiguration multiHostConfiguration = configuration.getMultiHostConfiguration(); + SSLConfig sslConfig = configuration.getSslConfig(); + SocketAddress address = singleHostConfiguration != null ? createSocketAddress(singleHostConfiguration) : null; + return new DefaultConnectionStrategy(address, clientSupplier, configuration, configuration.getConnectionSettings(), configuration.getOptions()) + .chainIf(!SSLMode.DISABLE.equals(sslConfig.getSslMode()), strategy -> new SslFallbackConnectionStrategy(configuration, strategy), ConnectionStrategy.ComposableConnectionStrategy.class) + .chainIf(multiHostConfiguration != null, strategy -> new MultiHostConnectionStrategy(createSocketAddress(multiHostConfiguration), configuration, strategy), ConnectionStrategy.class); + } + + private static SocketAddress createSocketAddress(SingleHostConfiguration configuration) { + if (!configuration.isUseSocket()) { + return InetSocketAddress.createUnresolved(configuration.getRequiredHost(), configuration.getPort()); + } + return DomainSocketFactory.getDomainSocketAddress(configuration); + } + + static class DomainSocketFactory { + private static SocketAddress getDomainSocketAddress(SingleHostConfiguration configuration) { + return new DomainSocketAddress(configuration.getRequiredSocket()); + } + } + + private static List createSocketAddress(MultiHostConfiguration configuration) { + List addressList = new ArrayList<>(configuration.getHosts().size()); + for (MultiHostConfiguration.ServerHost host : configuration.getHosts()) { + addressList.add(InetSocketAddress.createUnresolved(host.getHost(), host.getPort())); + } + return addressList; + } + +} diff --git a/src/main/java/io/r2dbc/postgresql/DefaultConnectionStrategy.java b/src/main/java/io/r2dbc/postgresql/DefaultConnectionStrategy.java new file mode 100644 index 000000000..a1b6af87d --- /dev/null +++ b/src/main/java/io/r2dbc/postgresql/DefaultConnectionStrategy.java @@ -0,0 +1,79 @@ +package io.r2dbc.postgresql; + +import io.r2dbc.postgresql.authentication.AuthenticationHandler; +import io.r2dbc.postgresql.authentication.PasswordAuthenticationHandler; +import io.r2dbc.postgresql.authentication.SASLAuthenticationHandler; +import io.r2dbc.postgresql.client.Client; +import io.r2dbc.postgresql.client.ConnectionSettings; +import io.r2dbc.postgresql.client.StartupMessageFlow; +import io.r2dbc.postgresql.message.backend.AuthenticationMessage; +import io.r2dbc.postgresql.util.Assert; +import reactor.core.publisher.Mono; + +import javax.annotation.Nullable; +import java.net.SocketAddress; +import java.util.Map; + +public class DefaultConnectionStrategy implements ConnectionStrategy.ComposableConnectionStrategy { + + private final SocketAddress address; + + private final ClientSupplier clientSupplier; + + private final PostgresqlConnectionConfiguration configuration; + + private final ConnectionSettings connectionSettings; + + private final Map options; + + DefaultConnectionStrategy( + @Nullable SocketAddress address, + ClientSupplier clientSupplier, + PostgresqlConnectionConfiguration configuration, + ConnectionSettings connectionSettings, + @Nullable Map options + ) { + this.address = address; + this.clientSupplier = clientSupplier; + this.configuration = configuration; + this.connectionSettings = connectionSettings; + this.options = options; + } + + @Override + public Mono connect() { + Assert.requireNonNull(this.address, "address must not be null"); + return this.clientSupplier.connect(this.address, this.connectionSettings) + .delayUntil(client -> StartupMessageFlow + .exchange(this.configuration.getApplicationName(), this::getAuthenticationHandler, client, this.configuration.getDatabase(), this.configuration.getUsername(), this.options) + .handle(ExceptionFactory.INSTANCE::handleErrorResponse)); + } + + @Override + public ComposableConnectionStrategy withAddress(SocketAddress address) { + return new DefaultConnectionStrategy(address, this.clientSupplier, this.configuration, this.connectionSettings, this.options); + } + + @Override + public ComposableConnectionStrategy withConnectionSettings(ConnectionSettings connectionSettings) { + return new DefaultConnectionStrategy(this.address, this.clientSupplier, this.configuration, connectionSettings, this.options); + } + + @Override + public ComposableConnectionStrategy withOptions(Map options) { + return new DefaultConnectionStrategy(this.address, this.clientSupplier, this.configuration, this.connectionSettings, options); + } + + protected AuthenticationHandler getAuthenticationHandler(AuthenticationMessage message) { + if (PasswordAuthenticationHandler.supports(message)) { + CharSequence password = Assert.requireNonNull(this.configuration.getPassword(), "Password must not be null"); + return new PasswordAuthenticationHandler(password, this.configuration.getUsername()); + } else if (SASLAuthenticationHandler.supports(message)) { + CharSequence password = Assert.requireNonNull(this.configuration.getPassword(), "Password must not be null"); + return new SASLAuthenticationHandler(password, this.configuration.getUsername()); + } else { + throw new IllegalStateException(String.format("Unable to provide AuthenticationHandler capable of handling %s", message)); + } + } + +} diff --git a/src/main/java/io/r2dbc/postgresql/MultiHostConnectionStrategy.java b/src/main/java/io/r2dbc/postgresql/MultiHostConnectionStrategy.java new file mode 100644 index 000000000..95bee4b8f --- /dev/null +++ b/src/main/java/io/r2dbc/postgresql/MultiHostConnectionStrategy.java @@ -0,0 +1,195 @@ +package io.r2dbc.postgresql; + +import io.r2dbc.postgresql.client.Client; +import io.r2dbc.postgresql.client.MultiHostConfiguration; +import io.r2dbc.postgresql.codec.DefaultCodecs; +import io.r2dbc.spi.IsolationLevel; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.util.context.Context; + +import javax.annotation.Nullable; +import java.net.SocketAddress; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Predicate; + +import static io.r2dbc.postgresql.TargetServerType.ANY; +import static io.r2dbc.postgresql.TargetServerType.MASTER; +import static io.r2dbc.postgresql.TargetServerType.PREFER_SECONDARY; +import static io.r2dbc.postgresql.TargetServerType.SECONDARY; + +public class MultiHostConnectionStrategy implements ConnectionStrategy { + + private final List addresses; + + private final PostgresqlConnectionConfiguration configuration; + + private final ComposableConnectionStrategy connectionStrategy; + + private final MultiHostConfiguration multiHostConfiguration; + + private final Map statusMap; + + MultiHostConnectionStrategy(List addresses, PostgresqlConnectionConfiguration configuration, ComposableConnectionStrategy connectionStrategy) { + this.addresses = addresses; + this.configuration = configuration; + this.connectionStrategy = connectionStrategy; + this.multiHostConfiguration = this.configuration.getMultiHostConfiguration(); + this.statusMap = new ConcurrentHashMap<>(); + } + + @Override + public Mono connect() { + AtomicReference exceptionRef = new AtomicReference<>(); + TargetServerType targetServerType = this.multiHostConfiguration.getTargetServerType(); + return this.tryConnect(targetServerType) + .onErrorResume(e -> { + if (!exceptionRef.compareAndSet(null, e)) { + exceptionRef.get().addSuppressed(e); + } + return Mono.empty(); + }) + .switchIfEmpty(Mono.defer(() -> targetServerType == PREFER_SECONDARY + ? this.tryConnect(MASTER) + : Mono.empty())) + .switchIfEmpty(Mono.error(() -> { + Throwable error = exceptionRef.get(); + if (error == null) { + return new PostgresqlConnectionFactory.PostgresConnectionException(String.format("No server matches target type %s", targetServerType.getValue()), null); + } else { + return error; + } + })); + } + + @Override + public ConnectionStrategy withOptions(Map options) { + return new MultiHostConnectionStrategy(this.addresses, this.configuration, this.connectionStrategy.withOptions(options)); + } + + private Mono tryConnect(TargetServerType targetServerType) { + AtomicReference exceptionRef = new AtomicReference<>(); + return this.getCandidates(targetServerType).concatMap(candidate -> this.tryConnectToCandidate(targetServerType, candidate) + .onErrorResume(e -> { + if (!exceptionRef.compareAndSet(null, e)) { + exceptionRef.get().addSuppressed(e); + } + this.statusMap.put(candidate, HostSpecStatus.fail(candidate)); + return Mono.empty(); + })) + .next() + .switchIfEmpty(Mono.defer(() -> exceptionRef.get() != null + ? Mono.error(exceptionRef.get()) + : Mono.empty())); + } + + private static HostSpecStatus evaluateStatus(SocketAddress candidate, @Nullable HostSpecStatus oldStatus) { + return oldStatus == null || oldStatus.hostStatus == HostStatus.CONNECT_FAIL + ? HostSpecStatus.ok(candidate) + : oldStatus; + } + + private static Mono isPrimaryServer(Client client, PostgresqlConnectionConfiguration configuration) { + PostgresqlConnection connection = new PostgresqlConnection(client, new DefaultCodecs(client.getByteBufAllocator()), DefaultPortalNameSupplier.INSTANCE, + StatementCache.fromPreparedStatementCacheQueries(client, configuration.getPreparedStatementCacheQueries()), IsolationLevel.READ_UNCOMMITTED, configuration); + return connection.createStatement("show transaction_read_only") + .execute() + .flatMap(result -> result.map((row, rowMetadata) -> row.get(0, String.class))) + .map(s -> s.equalsIgnoreCase("off")) + .next(); + } + + private Flux getCandidates(TargetServerType targetServerType) { + return Flux.create(sink -> { + Predicate needsRecheck = updated -> System.currentTimeMillis() > updated + this.multiHostConfiguration.getHostRecheckTime().toMillis(); + List addresses = new ArrayList<>(this.addresses); + if (this.multiHostConfiguration.isLoadBalanceHosts()) { + Collections.shuffle(addresses); + } + boolean addressEmitted = false; + for (SocketAddress address : addresses) { + HostSpecStatus currentStatus = this.statusMap.get(address); + if (currentStatus == null || needsRecheck.test(currentStatus.updated) || targetServerType.allowStatus(currentStatus.hostStatus)) { + sink.next(address); + addressEmitted = true; + } + } + if (!addressEmitted) { + // if no candidate matches the requirement or all of them are in unavailable status, try all the hosts + for (SocketAddress address : addresses) { + sink.next(address); + } + } + sink.complete(); + }); + } + + private Mono tryConnectToCandidate(TargetServerType targetServerType, SocketAddress candidate) { + return Mono.create(sink -> this.connectionStrategy.withAddress(candidate).connect().subscribe(client -> { + this.statusMap.compute(candidate, (a, oldStatus) -> evaluateStatus(candidate, oldStatus)); + if (targetServerType == ANY) { + sink.success(client); + return; + } + isPrimaryServer(client, this.configuration).subscribe( + isPrimary -> { + if (isPrimary) { + this.statusMap.put(candidate, HostSpecStatus.primary(candidate)); + } else { + this.statusMap.put(candidate, HostSpecStatus.standby(candidate)); + } + if (isPrimary && targetServerType == MASTER) { + sink.success(client); + } else if (!isPrimary && (targetServerType == SECONDARY || targetServerType == PREFER_SECONDARY)) { + sink.success(client); + } else { + client.close().subscribe(v -> sink.success(), sink::error, sink::success, Context.of(sink.contextView())); + } + }, + sink::error, () -> {}, Context.of(sink.contextView())); + }, sink::error, () -> {}, Context.of(sink.contextView()))); + } + + enum HostStatus { + CONNECT_FAIL, + CONNECT_OK, + PRIMARY, + STANDBY + } + + private static class HostSpecStatus { + + public final SocketAddress address; + + public final HostStatus hostStatus; + + public final long updated; + + private HostSpecStatus(SocketAddress address, HostStatus hostStatus) { + this.address = address; + this.hostStatus = hostStatus; + this.updated = System.currentTimeMillis(); + } + + public static HostSpecStatus fail(SocketAddress host) { + return new HostSpecStatus(host, HostStatus.CONNECT_FAIL); + } + + public static HostSpecStatus ok(SocketAddress host) { + return new HostSpecStatus(host, HostStatus.CONNECT_OK); + } + + public static HostSpecStatus primary(SocketAddress host) { + return new HostSpecStatus(host, HostStatus.PRIMARY); + } + + public static HostSpecStatus standby(SocketAddress host) { + return new HostSpecStatus(host, HostStatus.STANDBY); + } + } +} diff --git a/src/main/java/io/r2dbc/postgresql/PostgresqlConnectionConfiguration.java b/src/main/java/io/r2dbc/postgresql/PostgresqlConnectionConfiguration.java index 8caba3bbf..f05495d9a 100644 --- a/src/main/java/io/r2dbc/postgresql/PostgresqlConnectionConfiguration.java +++ b/src/main/java/io/r2dbc/postgresql/PostgresqlConnectionConfiguration.java @@ -23,8 +23,10 @@ import io.netty.handler.ssl.util.InsecureTrustManagerFactory; import io.r2dbc.postgresql.client.ConnectionSettings; import io.r2dbc.postgresql.client.DefaultHostnameVerifier; +import io.r2dbc.postgresql.client.MultiHostConfiguration; import io.r2dbc.postgresql.client.SSLConfig; import io.r2dbc.postgresql.client.SSLMode; +import io.r2dbc.postgresql.client.SingleHostConfiguration; import io.r2dbc.postgresql.codec.Codec; import io.r2dbc.postgresql.codec.Codecs; import io.r2dbc.postgresql.codec.Json; @@ -86,30 +88,30 @@ public final class PostgresqlConnectionConfiguration { private final boolean forceBinary; - private final String host; - @Nullable private final Duration lockWaitTimeout; @Nullable private final LoopResources loopResources; + @Nullable + private final MultiHostConfiguration multiHostConfiguration; + private final LogLevel noticeLogLevel; private final Map options; private final CharSequence password; - private final int port; - private final boolean preferAttachedBuffers; private final int preparedStatementCacheQueries; @Nullable - private final Duration statementTimeout; + private final SingleHostConfiguration singleHostConfiguration; - private final String socket; + @Nullable + private final Duration statementTimeout; private final SSLConfig sslConfig; @@ -121,10 +123,12 @@ public final class PostgresqlConnectionConfiguration { private PostgresqlConnectionConfiguration(String applicationName, boolean autodetectExtensions, @Nullable boolean compatibilityMode, Duration connectTimeout, @Nullable String database, LogLevel errorResponseLogLevel, - List extensions, ToIntFunction fetchSize, boolean forceBinary, @Nullable String host, @Nullable Duration lockWaitTimeout, + List extensions, ToIntFunction fetchSize, boolean forceBinary, @Nullable Duration lockWaitTimeout, @Nullable LoopResources loopResources, - LogLevel noticeLogLevel, @Nullable Map options, @Nullable CharSequence password, int port, boolean preferAttachedBuffers, - int preparedStatementCacheQueries, @Nullable String schema, @Nullable String socket, SSLConfig sslConfig, @Nullable Duration statementTimeout, + @Nullable MultiHostConfiguration multiHostConfiguration, + LogLevel noticeLogLevel, @Nullable Map options, @Nullable CharSequence password, boolean preferAttachedBuffers, + int preparedStatementCacheQueries, @Nullable String schema, + @Nullable SingleHostConfiguration singleHostConfiguration, SSLConfig sslConfig, @Nullable Duration statementTimeout, boolean tcpKeepAlive, boolean tcpNoDelay, String username) { this.applicationName = Assert.requireNonNull(applicationName, "applicationName must not be null"); @@ -136,8 +140,8 @@ private PostgresqlConnectionConfiguration(String applicationName, boolean autode this.database = database; this.fetchSize = fetchSize; this.forceBinary = forceBinary; - this.host = host; this.loopResources = loopResources; + this.multiHostConfiguration = multiHostConfiguration; this.noticeLogLevel = noticeLogLevel; this.options = options == null ? new LinkedHashMap<>() : new LinkedHashMap<>(options); this.statementTimeout = statementTimeout; @@ -156,10 +160,9 @@ private PostgresqlConnectionConfiguration(String applicationName, boolean autode } this.password = password; - this.port = port; this.preferAttachedBuffers = preferAttachedBuffers; this.preparedStatementCacheQueries = preparedStatementCacheQueries; - this.socket = socket; + this.singleHostConfiguration = singleHostConfiguration; this.sslConfig = sslConfig; this.tcpKeepAlive = tcpKeepAlive; this.tcpNoDelay = tcpNoDelay; @@ -187,15 +190,14 @@ public String toString() { ", extensions=" + this.extensions + ", fetchSize=" + this.fetchSize + ", forceBinary='" + this.forceBinary + '\'' + - ", host='" + this.host + '\'' + ", lockWaitTimeout='" + this.lockWaitTimeout + ", loopResources='" + this.loopResources + '\'' + + ", multiHostConfiguration='" + this.multiHostConfiguration + '\'' + ", noticeLogLevel='" + this.noticeLogLevel + '\'' + ", options='" + this.options + '\'' + ", password='" + obfuscate(this.password != null ? this.password.length() : 0) + '\'' + - ", port=" + this.port + ", preferAttachedBuffers=" + this.preferAttachedBuffers + - ", socket=" + this.socket + + ", singleHostConfiguration=" + this.singleHostConfiguration + ", statementTimeout=" + this.statementTimeout + ", tcpKeepAlive=" + this.tcpKeepAlive + ", tcpNoDelay=" + this.tcpNoDelay + @@ -234,19 +236,8 @@ int getFetchSize(String sql) { } @Nullable - String getHost() { - return this.host; - } - - String getRequiredHost() { - - String host = getHost(); - - if (host == null || host.isEmpty()) { - throw new IllegalStateException("Connection is configured for socket connections and not for host usage"); - } - - return host; + MultiHostConfiguration getMultiHostConfiguration() { + return this.multiHostConfiguration; } Map getOptions() { @@ -258,10 +249,6 @@ CharSequence getPassword() { return this.password; } - int getPort() { - return this.port; - } - boolean isPreferAttachedBuffers() { return this.preferAttachedBuffers; } @@ -271,19 +258,8 @@ int getPreparedStatementCacheQueries() { } @Nullable - String getSocket() { - return this.socket; - } - - String getRequiredSocket() { - - String socket = getSocket(); - - if (socket == null || socket.isEmpty()) { - throw new IllegalStateException("Connection is configured to use host and port connections and not for socket usage"); - } - - return socket; + SingleHostConfiguration getSingleHostConfiguration() { + return this.singleHostConfiguration; } String getUsername() { @@ -306,10 +282,6 @@ boolean isTcpNoDelay() { return this.tcpNoDelay; } - boolean isUseSocket() { - return getSocket() != null; - } - SSLConfig getSslConfig() { return this.sslConfig; } @@ -365,10 +337,10 @@ public static final class Builder { private boolean forceBinary = false; @Nullable - private String host; + private Duration lockWaitTimeout; @Nullable - private Duration lockWaitTimeout; + private MultiHostConfiguration.Builder multiHostConfiguration; private LogLevel noticeLogLevel = LogLevel.DEBUG; @@ -377,8 +349,6 @@ public static final class Builder { @Nullable private CharSequence password; - private int port = DEFAULT_PORT; - private boolean preferAttachedBuffers = false; private int preparedStatementCacheQueries = -1; @@ -387,7 +357,7 @@ public static final class Builder { private String schema; @Nullable - private String socket; + private SingleHostConfiguration.Builder singleHostConfiguration; @Nullable private URL sslCert = null; @@ -453,12 +423,16 @@ public Builder autodetectExtensions(boolean autodetectExtensions) { */ public PostgresqlConnectionConfiguration build() { - if (this.host == null && this.socket == null) { - throw new IllegalArgumentException("host or socket must not be null"); - } + SingleHostConfiguration singleHostConfiguration = this.singleHostConfiguration != null + ? this.singleHostConfiguration.build() + : null; + + MultiHostConfiguration multiHostConfiguration = this.multiHostConfiguration != null + ? this.multiHostConfiguration.build() + : null; - if (this.host != null && this.socket != null) { - throw new IllegalArgumentException("Connection must be configured for either host/port or socket usage but not both"); + if (!(singleHostConfiguration == null ^ multiHostConfiguration == null)) { + throw new IllegalArgumentException("either multiHostConfiguration or singleHostConfiguration must not be null"); } if (this.username == null) { @@ -467,9 +441,12 @@ public PostgresqlConnectionConfiguration build() { return new PostgresqlConnectionConfiguration(this.applicationName, this.autodetectExtensions, this.compatibilityMode, this.connectTimeout, this.database, this.errorResponseLogLevel, this.extensions, - this.fetchSize - , this.forceBinary, this.host, this.lockWaitTimeout, this.loopResources, this.noticeLogLevel, this.options, this.password, this.port, this.preferAttachedBuffers, - this.preparedStatementCacheQueries, this.schema, this.socket, this.createSslConfig(), this.statementTimeout, this.tcpKeepAlive, this.tcpNoDelay, this.username); + this.fetchSize, this.forceBinary, this.lockWaitTimeout, this.loopResources, + multiHostConfiguration, + this.noticeLogLevel, this.options, this.password, this.preferAttachedBuffers, + this.preparedStatementCacheQueries, this.schema, + singleHostConfiguration, + this.createSslConfig(), this.statementTimeout, this.tcpKeepAlive, this.tcpNoDelay, this.username); } /** @@ -597,7 +574,72 @@ public Builder forceBinary(boolean forceBinary) { * @throws IllegalArgumentException if {@code host} is {@code null} */ public Builder host(String host) { - this.host = Assert.requireNonNull(host, "host must not be null"); + Assert.requireNonNull(host, "host must not be null"); + if (this.singleHostConfiguration == null) { + this.singleHostConfiguration = SingleHostConfiguration.builder(); + } + this.singleHostConfiguration.host(host); + return this; + } + + /** + * Add host with default port to hosts list. + * + * @param host the host + * @return this {@link Builder} + * @throws IllegalArgumentException if {@code host} is {@code null} + */ + public Builder addHost(String host) { + Assert.requireNonNull(host, "host must not be null"); + if (this.multiHostConfiguration == null) { + this.multiHostConfiguration = MultiHostConfiguration.builder(); + } + this.multiHostConfiguration.addHost(host); + return this; + } + + /** + * Add host to hosts list. + * + * @param host the host + * @param port the port + * @return this {@link Builder} + * @throws IllegalArgumentException if {@code host} is {@code null} + */ + public Builder addHost(String host, int port) { + Assert.requireNonNull(host, "host must not be null"); + if (this.multiHostConfiguration == null) { + this.multiHostConfiguration = MultiHostConfiguration.builder(); + } + this.multiHostConfiguration.addHost(host, port); + return this; + } + + /** + * Controls how long the knowledge about a host state is cached connection factory. The default value is 10000 milliseconds. + * + * @param hostRecheckTime host recheck time + * @return this {@link Builder} + */ + public Builder hostRecheckTime(Duration hostRecheckTime) { + if (this.multiHostConfiguration == null) { + this.multiHostConfiguration = MultiHostConfiguration.builder(); + } + this.multiHostConfiguration.hostRecheckTime(hostRecheckTime); + return this; + } + + /** + * In default mode (disabled) hosts are connected in the given order. If enabled hosts are chosen randomly from the set of suitable candidates. + * + * @param loadBalanceHosts is load balance mode enabled + * @return this {@link Builder} + */ + public Builder loadBalanceHosts(boolean loadBalanceHosts) { + if (this.multiHostConfiguration == null) { + this.multiHostConfiguration = MultiHostConfiguration.builder(); + } + this.multiHostConfiguration.loadBalanceHosts(loadBalanceHosts); return this; } @@ -682,7 +724,10 @@ public Builder password(@Nullable CharSequence password) { * @return this {@link Builder} */ public Builder port(int port) { - this.port = port; + if (this.singleHostConfiguration == null) { + this.singleHostConfiguration = SingleHostConfiguration.builder(); + } + this.singleHostConfiguration.port(port); return this; } @@ -732,7 +777,12 @@ public Builder schema(@Nullable String schema) { * @throws IllegalArgumentException if {@code socket} is {@code null} */ public Builder socket(String socket) { - this.socket = Assert.requireNonNull(socket, "host must not be null"); + Assert.requireNonNull(socket, "host must not be null"); + if (this.singleHostConfiguration == null) { + this.singleHostConfiguration = SingleHostConfiguration.builder(); + } + this.singleHostConfiguration.socket(socket); + sslMode(SSLMode.DISABLE); return this; } @@ -866,6 +916,24 @@ public Builder statementTimeout(Duration statementTimeout) { return this; } + /** + * Allows opening connections to only servers with required state, the allowed values are any, master, secondary and preferSecondary. + * The master/secondary distinction is currently done by observing if the server allows writes. + * The value preferSecondary tries to connect to secondary if any are available, otherwise allows falls back to connecting also to master. + * Default value is any. + * + * @param targetServerType target server type + * @return this {@link Builder} + * @throws IllegalArgumentException if {@code targetServerType} is {@code null} + */ + public Builder targetServerType(TargetServerType targetServerType) { + if (this.multiHostConfiguration == null) { + this.multiHostConfiguration = MultiHostConfiguration.builder(); + } + this.multiHostConfiguration.targetServerType(targetServerType); + return this; + } + /** * Configure TCP KeepAlive. * @@ -916,16 +984,15 @@ public String toString() { ", errorResponseLogLevel='" + this.errorResponseLogLevel + '\'' + ", fetchSize='" + this.fetchSize + '\'' + ", forceBinary='" + this.forceBinary + '\'' + - ", host='" + this.host + '\'' + ", lockWaitTimeout='" + this.lockWaitTimeout + '\'' + ", loopResources='" + this.loopResources + '\'' + + ", multiHostConfiguration='" + this.multiHostConfiguration + '\'' + ", noticeLogLevel='" + this.noticeLogLevel + '\'' + ", parameters='" + this.options + '\'' + ", password='" + obfuscate(this.password != null ? this.password.length() : 0) + '\'' + - ", port=" + this.port + ", preparedStatementCacheQueries='" + this.preparedStatementCacheQueries + '\'' + ", schema='" + this.schema + '\'' + - ", socket='" + this.socket + '\'' + + ", singleHostConfiguration='" + this.singleHostConfiguration + '\'' + ", sslContextBuilderCustomizer='" + this.sslContextBuilderCustomizer + '\'' + ", sslMode='" + this.sslMode + '\'' + ", sslRootCert='" + this.sslRootCert + '\'' + @@ -940,7 +1007,7 @@ public String toString() { } private SSLConfig createSslConfig() { - if (this.socket != null || this.sslMode == SSLMode.DISABLE) { + if (this.singleHostConfiguration != null && this.singleHostConfiguration.getSocket() != null || this.sslMode == SSLMode.DISABLE) { return SSLConfig.disabled(); } diff --git a/src/main/java/io/r2dbc/postgresql/PostgresqlConnectionFactory.java b/src/main/java/io/r2dbc/postgresql/PostgresqlConnectionFactory.java index 44e4ed1fb..1d7cd333a 100644 --- a/src/main/java/io/r2dbc/postgresql/PostgresqlConnectionFactory.java +++ b/src/main/java/io/r2dbc/postgresql/PostgresqlConnectionFactory.java @@ -17,19 +17,10 @@ package io.r2dbc.postgresql; import io.netty.buffer.ByteBufAllocator; -import io.netty.channel.unix.DomainSocketAddress; -import io.r2dbc.postgresql.authentication.AuthenticationHandler; -import io.r2dbc.postgresql.authentication.PasswordAuthenticationHandler; -import io.r2dbc.postgresql.authentication.SASLAuthenticationHandler; import io.r2dbc.postgresql.client.Client; -import io.r2dbc.postgresql.client.ConnectionSettings; import io.r2dbc.postgresql.client.ReactorNettyClient; -import io.r2dbc.postgresql.client.SSLConfig; -import io.r2dbc.postgresql.client.SSLMode; -import io.r2dbc.postgresql.client.StartupMessageFlow; import io.r2dbc.postgresql.codec.DefaultCodecs; import io.r2dbc.postgresql.extension.CodecRegistrar; -import io.r2dbc.postgresql.message.backend.AuthenticationMessage; import io.r2dbc.postgresql.util.Assert; import io.r2dbc.spi.ConnectionFactory; import io.r2dbc.spi.ConnectionFactoryMetadata; @@ -41,31 +32,28 @@ import reactor.core.publisher.Mono; import reactor.util.annotation.Nullable; -import java.net.InetSocketAddress; -import java.net.SocketAddress; import java.util.ArrayList; import java.util.LinkedHashMap; import java.util.List; import java.util.Locale; import java.util.Map; -import java.util.function.Function; -import java.util.function.Predicate; /** * An implementation of {@link ConnectionFactory} for creating connections to a PostgreSQL database. */ public final class PostgresqlConnectionFactory implements ConnectionFactory { + private static final ClientSupplier DEFAULT_CLIENT_SUPPLIER = (endpoint, settings) -> + ReactorNettyClient.connect(endpoint, settings).cast(Client.class); + private static final String REPLICATION_OPTION = "replication"; private static final String REPLICATION_DATABASE = "database"; - private final Function> clientFactory; + private final ConnectionStrategy connectionStrategy; private final PostgresqlConnectionConfiguration configuration; - private final SocketAddress endpoint; - private final Extensions extensions; /** @@ -76,27 +64,16 @@ public final class PostgresqlConnectionFactory implements ConnectionFactory { */ public PostgresqlConnectionFactory(PostgresqlConnectionConfiguration configuration) { this.configuration = Assert.requireNonNull(configuration, "configuration must not be null"); - this.endpoint = createSocketAddress(configuration); - this.clientFactory = settings -> ReactorNettyClient.connect(this.endpoint, settings).cast(Client.class); + this.connectionStrategy = ConnectionStrategyFactory.getConnectionStrategy(DEFAULT_CLIENT_SUPPLIER, configuration); this.extensions = getExtensions(configuration); } - PostgresqlConnectionFactory(Function> clientFactory, PostgresqlConnectionConfiguration configuration) { + PostgresqlConnectionFactory(ConnectionStrategy connectionStrategy, PostgresqlConnectionConfiguration configuration) { this.configuration = Assert.requireNonNull(configuration, "configuration must not be null"); - this.endpoint = createSocketAddress(configuration); - this.clientFactory = Assert.requireNonNull(clientFactory, "clientFactory must not be null"); + this.connectionStrategy = Assert.requireNonNull(connectionStrategy, "clientFactory must not be null"); this.extensions = getExtensions(configuration); } - private static SocketAddress createSocketAddress(PostgresqlConnectionConfiguration configuration) { - - if (!configuration.isUseSocket()) { - return InetSocketAddress.createUnresolved(configuration.getRequiredHost(), configuration.getPort()); - } - - return DomainSocketFactory.getDomainSocketAddress(configuration); - } - private static Extensions getExtensions(PostgresqlConnectionConfiguration configuration) { Extensions extensions = Extensions.from(configuration.getExtensions()); @@ -114,7 +91,7 @@ public Mono create() { throw new UnsupportedOperationException("Cannot create replication connection through create(). Use replication() method instead."); } - return doCreateConnection(false, this.configuration.getOptions()).cast(io.r2dbc.postgresql.api.PostgresqlConnection.class); + return doCreateConnection(false, this.connectionStrategy).cast(io.r2dbc.postgresql.api.PostgresqlConnection.class); } /** @@ -127,31 +104,12 @@ public Mono replication Map options = new LinkedHashMap<>(this.configuration.getOptions()); options.put(REPLICATION_OPTION, REPLICATION_DATABASE); - return doCreateConnection(true, options).map(DefaultPostgresqlReplicationConnection::new); + return doCreateConnection(true, this.connectionStrategy.withOptions(options)).map(DefaultPostgresqlReplicationConnection::new); } - private Mono doCreateConnection(boolean forReplication, @Nullable Map options) { - - SSLConfig sslConfig = this.configuration.getSslConfig(); - ConnectionSettings connectionSettings = this.configuration.getConnectionSettings(); - Predicate isAuthSpecificationError = e -> e instanceof ExceptionFactory.PostgresqlAuthenticationFailure; - return this.tryConnectWithConfig(connectionSettings, options) - .onErrorResume( - isAuthSpecificationError.and(e -> sslConfig.getSslMode() == SSLMode.ALLOW), - e -> this.tryConnectWithConfig(connectionSettings.mutate(builder -> builder.sslConfig(sslConfig.mutateMode(SSLMode.REQUIRE))), options) - .onErrorResume(sslAuthError -> { - e.addSuppressed(sslAuthError); - return Mono.error(e); - }) - ) - .onErrorResume( - isAuthSpecificationError.and(e -> sslConfig.getSslMode() == SSLMode.PREFER), - e -> this.tryConnectWithConfig(connectionSettings.mutate(builder -> builder.sslConfig(sslConfig.mutateMode(SSLMode.DISABLE))), options) - .onErrorResume(sslAuthError -> { - e.addSuppressed(sslAuthError); - return Mono.error(e); - }) - ) + private Mono doCreateConnection(boolean forReplication, ConnectionStrategy connectionStrategy) { + + return connectionStrategy.connect() .flatMap(client -> { DefaultCodecs codecs = new DefaultCodecs(client.getByteBufAllocator(), this.configuration.isPreferAttachedBuffers()); @@ -180,15 +138,6 @@ private boolean isReplicationConnection() { return REPLICATION_DATABASE.equalsIgnoreCase(options.get(REPLICATION_OPTION)); } - private Mono tryConnectWithConfig(ConnectionSettings settings, @Nullable Map options) { - return this.clientFactory.apply(settings) - .delayUntil(client -> StartupMessageFlow - .exchange(this.configuration.getApplicationName(), this::getAuthenticationHandler, client, this.configuration.getDatabase(), this.configuration.getUsername(), - options) - .handle(ExceptionFactory.INSTANCE::handleErrorResponse)) - .cast(Client.class); - } - private Publisher prepareConnection(PostgresqlConnection connection, ByteBufAllocator byteBufAllocator, DefaultCodecs codecs, boolean forReplication) { List> publishers = new ArrayList<>(); @@ -213,7 +162,7 @@ private Throwable cannotConnect(Throwable throwable) { } return new PostgresConnectionException( - String.format("Cannot connect to %s", this.endpoint), throwable + String.format("Cannot connect to %s", "TODO"), throwable // TODO ); } @@ -229,24 +178,12 @@ PostgresqlConnectionConfiguration getConfiguration() { @Override public String toString() { return "PostgresqlConnectionFactory{" + - "clientFactory=" + this.clientFactory + + "connectionStrategy=" + this.connectionStrategy + ", configuration=" + this.configuration + ", extensions=" + this.extensions + '}'; } - private AuthenticationHandler getAuthenticationHandler(AuthenticationMessage message) { - if (PasswordAuthenticationHandler.supports(message)) { - CharSequence password = Assert.requireNonNull(this.configuration.getPassword(), "Password must not be null"); - return new PasswordAuthenticationHandler(password, this.configuration.getUsername()); - } else if (SASLAuthenticationHandler.supports(message)) { - CharSequence password = Assert.requireNonNull(this.configuration.getPassword(), "Password must not be null"); - return new SASLAuthenticationHandler(password, this.configuration.getUsername()); - } else { - throw new IllegalStateException(String.format("Unable to provide AuthenticationHandler capable of handling %s", message)); - } - } - private Mono getIsolationLevel(io.r2dbc.postgresql.api.PostgresqlConnection connection) { return connection.createStatement("SHOW TRANSACTION ISOLATION LEVEL") .fetchSize(0) @@ -270,12 +207,4 @@ public PostgresConnectionException(String msg, @Nullable Throwable cause) { } - static class DomainSocketFactory { - - private static SocketAddress getDomainSocketAddress(PostgresqlConnectionConfiguration configuration) { - return new DomainSocketAddress(configuration.getRequiredSocket()); - } - - } - } diff --git a/src/main/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryProvider.java b/src/main/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryProvider.java index 633002d95..4b787309d 100644 --- a/src/main/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryProvider.java +++ b/src/main/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryProvider.java @@ -42,6 +42,7 @@ import static io.r2dbc.spi.ConnectionFactoryOptions.HOST; import static io.r2dbc.spi.ConnectionFactoryOptions.PASSWORD; import static io.r2dbc.spi.ConnectionFactoryOptions.PORT; +import static io.r2dbc.spi.ConnectionFactoryOptions.PROTOCOL; import static io.r2dbc.spi.ConnectionFactoryOptions.SSL; import static io.r2dbc.spi.ConnectionFactoryOptions.USER; @@ -89,6 +90,16 @@ public final class PostgresqlConnectionFactoryProvider implements ConnectionFact */ public static final Option FORCE_BINARY = Option.valueOf("forceBinary"); + /** + * Host status recheck time. + */ + public static final Option HOST_RECHECK_TIME = Option.valueOf("hostRecheckTime"); + + /** + * Load balance hosts. + */ + public static final Option LOAD_BALANCE_HOSTS = Option.valueOf("loadBalanceHosts"); + /** * Lock timeout. * @@ -123,6 +134,11 @@ public final class PostgresqlConnectionFactoryProvider implements ConnectionFact */ public static final String LEGACY_POSTGRESQL_DRIVER = "postgres"; + /** + * Failover driver protocol. + */ + public static final String FAILOVER_PROTOCOL = "failover"; + /** * Configure whether {@link Codecs codecs} should prefer attached data buffers. The default is {@code false}, meaning that codecs will copy data from the input buffer into a {@code byte[]} * or similar data structure that is enabled for garbage collection. Using attached buffers is more efficient but comes with the requirement that decoded values (such as {@link Json}) must @@ -204,6 +220,11 @@ public final class PostgresqlConnectionFactoryProvider implements ConnectionFact */ public static final Option STATEMENT_TIMEOUT = ConnectionFactoryOptions.STATEMENT_TIMEOUT; + /** + * Target server type. Allowed values: any, master, secondary, preferSecondary. + */ + public static final Option TARGET_SERVER_TYPE = Option.valueOf("targetServerType"); + /** * Enable TCP KeepAlive. * @@ -262,6 +283,28 @@ private static PostgresqlConnectionConfiguration.Builder fromConnectionFactoryOp OptionMapper mapper = OptionMapper.create(options); + String protocol = (String) options.getValue(PROTOCOL); + if (protocol != null && FAILOVER_PROTOCOL.equals(protocol)) { + mapper.from(HOST_RECHECK_TIME).map(OptionMapper::toDuration).to(builder::hostRecheckTime); + mapper.from(LOAD_BALANCE_HOSTS).map(OptionMapper::toBoolean).to(builder::loadBalanceHosts); + mapper.from(TARGET_SERVER_TYPE).map(value -> OptionMapper.toEnum(value, TargetServerType.class)).to(builder::targetServerType); + String hosts = "" + options.getRequiredValue(HOST); + for (String host : hosts.split(",")) { + String[] hostParts = host.split(":"); + if (hostParts.length == 1) { + builder.addHost(host); + } else { + builder.addHost(hostParts[0], OptionMapper.toInteger(hostParts[1])); + } + } + setupSsl(builder, mapper); + } else { + mapper.fromTyped(SOCKET).to(builder::socket).otherwise(() -> { + builder.host("" + options.getRequiredValue(HOST)); + setupSsl(builder, mapper); + }); + } + mapper.fromTyped(APPLICATION_NAME).to(builder::applicationName); mapper.from(AUTODETECT_EXTENSIONS).map(OptionMapper::toBoolean).to(builder::autodetectExtensions); mapper.from(COMPATIBILITY_MODE).map(OptionMapper::toBoolean).to(builder::compatibilityMode); @@ -280,10 +323,6 @@ private static PostgresqlConnectionConfiguration.Builder fromConnectionFactoryOp mapper.from(PORT).map(OptionMapper::toInteger).to(builder::port); mapper.from(PREFER_ATTACHED_BUFFERS).map(OptionMapper::toBoolean).to(builder::preferAttachedBuffers); mapper.from(PREPARED_STATEMENT_CACHE_QUERIES).map(OptionMapper::toInteger).to(builder::preparedStatementCacheQueries); - mapper.fromTyped(SOCKET).to(builder::socket).otherwise(() -> { - builder.host("" + options.getRequiredValue(HOST)); - setupSsl(builder, mapper); - }); mapper.from(STATEMENT_TIMEOUT).map(OptionMapper::toDuration).to(builder::statementTimeout); mapper.from(TCP_KEEPALIVE).map(OptionMapper::toBoolean).to(builder::tcpKeepAlive); mapper.from(TCP_NODELAY).map(OptionMapper::toBoolean).to(builder::tcpNoDelay); diff --git a/src/main/java/io/r2dbc/postgresql/SslFallbackConnectionStrategy.java b/src/main/java/io/r2dbc/postgresql/SslFallbackConnectionStrategy.java new file mode 100644 index 000000000..c51847a6d --- /dev/null +++ b/src/main/java/io/r2dbc/postgresql/SslFallbackConnectionStrategy.java @@ -0,0 +1,60 @@ +package io.r2dbc.postgresql; + +import io.r2dbc.postgresql.client.Client; +import io.r2dbc.postgresql.client.ConnectionSettings; +import io.r2dbc.postgresql.client.SSLConfig; +import io.r2dbc.postgresql.client.SSLMode; +import reactor.core.publisher.Mono; + +import java.net.SocketAddress; +import java.util.Map; +import java.util.function.Function; +import java.util.function.Predicate; + +public class SslFallbackConnectionStrategy implements ConnectionStrategy.ComposableConnectionStrategy { + + private final PostgresqlConnectionConfiguration configuration; + + private final ComposableConnectionStrategy connectionStrategy; + + SslFallbackConnectionStrategy(PostgresqlConnectionConfiguration configuration, ComposableConnectionStrategy connectionStrategy) { + this.configuration = configuration; + this.connectionStrategy = connectionStrategy; + } + + @Override + public Mono connect() { + SSLConfig sslConfig = this.configuration.getSslConfig(); + Predicate isAuthSpecificationError = e -> e instanceof ExceptionFactory.PostgresqlAuthenticationFailure; + return this.connectionStrategy.connect() + .onErrorResume(isAuthSpecificationError.and(e -> sslConfig.getSslMode() == SSLMode.ALLOW), fallback(SSLMode.REQUIRE)) + .onErrorResume(isAuthSpecificationError.and(e -> sslConfig.getSslMode() == SSLMode.PREFER), fallback(SSLMode.DISABLE)); + } + + private Function> fallback(SSLMode sslMode) { + ConnectionSettings connectionSettings = this.configuration.getConnectionSettings(); + SSLConfig sslConfig = this.configuration.getSslConfig(); + return e -> this.connectionStrategy.withConnectionSettings(connectionSettings.mutate(builder -> builder.sslConfig(sslConfig.mutateMode(sslMode)))) + .connect() + .onErrorResume(sslAuthError -> { + e.addSuppressed(sslAuthError); + return Mono.error(e); + }); + } + + @Override + public ComposableConnectionStrategy withAddress(SocketAddress address) { + return new SslFallbackConnectionStrategy(this.configuration, this.connectionStrategy.withAddress(address)); + } + + @Override + public ComposableConnectionStrategy withConnectionSettings(ConnectionSettings connectionSettings) { + return new SslFallbackConnectionStrategy(this.configuration, this.connectionStrategy.withConnectionSettings(connectionSettings)); + } + + @Override + public ComposableConnectionStrategy withOptions(Map options) { + return new SslFallbackConnectionStrategy(this.configuration, this.connectionStrategy.withOptions(options)); + } + +} diff --git a/src/main/java/io/r2dbc/postgresql/TargetServerType.java b/src/main/java/io/r2dbc/postgresql/TargetServerType.java new file mode 100644 index 000000000..4dd538c8f --- /dev/null +++ b/src/main/java/io/r2dbc/postgresql/TargetServerType.java @@ -0,0 +1,53 @@ +package io.r2dbc.postgresql; + +import javax.annotation.Nullable; + +public enum TargetServerType { + ANY("any") { + @Override + public boolean allowStatus(MultiHostConnectionStrategy.HostStatus hostStatus) { + return hostStatus != MultiHostConnectionStrategy.HostStatus.CONNECT_FAIL; + } + }, + MASTER("master") { + @Override + public boolean allowStatus(MultiHostConnectionStrategy.HostStatus hostStatus) { + return hostStatus == MultiHostConnectionStrategy.HostStatus.PRIMARY || hostStatus == MultiHostConnectionStrategy.HostStatus.CONNECT_OK; + } + }, + SECONDARY("secondary") { + @Override + public boolean allowStatus(MultiHostConnectionStrategy.HostStatus hostStatus) { + return hostStatus == MultiHostConnectionStrategy.HostStatus.STANDBY || hostStatus == MultiHostConnectionStrategy.HostStatus.CONNECT_OK; + } + }, + PREFER_SECONDARY("preferSecondary") { + @Override + public boolean allowStatus(MultiHostConnectionStrategy.HostStatus hostStatus) { + return hostStatus == MultiHostConnectionStrategy.HostStatus.STANDBY || hostStatus == MultiHostConnectionStrategy.HostStatus.CONNECT_OK; + } + }; + + private final String value; + + TargetServerType(String value) { + this.value = value; + } + + @Nullable + public static TargetServerType fromValue(String value) { + for (TargetServerType type : values()) { + if (type.value.equals(value)) { + return type; + } + } + return null; + } + + public String getValue() { + return value; + } + + public abstract boolean allowStatus(MultiHostConnectionStrategy.HostStatus hostStatus); + +} diff --git a/src/main/java/io/r2dbc/postgresql/client/MultiHostConfiguration.java b/src/main/java/io/r2dbc/postgresql/client/MultiHostConfiguration.java new file mode 100644 index 000000000..799d99c3e --- /dev/null +++ b/src/main/java/io/r2dbc/postgresql/client/MultiHostConfiguration.java @@ -0,0 +1,194 @@ +package io.r2dbc.postgresql.client; + +import io.r2dbc.postgresql.TargetServerType; +import io.r2dbc.postgresql.util.Assert; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; + +import static io.r2dbc.postgresql.PostgresqlConnectionConfiguration.DEFAULT_PORT; + +public class MultiHostConfiguration { + + private final List hosts; + + private final Duration hostRecheckTime; + + private final boolean loadBalanceHosts; + + private final TargetServerType targetServerType; + + public MultiHostConfiguration(List hosts, Duration hostRecheckTime, boolean loadBalanceHosts, TargetServerType targetServerType) { + this.hosts = hosts; + this.hostRecheckTime = hostRecheckTime; + this.loadBalanceHosts = loadBalanceHosts; + this.targetServerType = targetServerType; + } + + public Duration getHostRecheckTime() { + return hostRecheckTime; + } + + public List getHosts() { + return hosts; + } + + public TargetServerType getTargetServerType() { + return targetServerType; + } + + public boolean isLoadBalanceHosts() { + return loadBalanceHosts; + } + + @Override + public String toString() { + return "MultiHostConfiguration{" + + "hosts=" + this.hosts + + ", hostRecheckTime=" + this.hostRecheckTime + + ", loadBalanceHosts=" + this.loadBalanceHosts + + ", targetServerType=" + this.targetServerType + + '}'; + } + + public static class ServerHost { + + private final String host; + + private final int port; + + public ServerHost(String host, int port) { + this.host = host; + this.port = port; + } + + public String getHost() { + return this.host; + } + + public int getPort() { + return this.port; + } + + @Override + public String toString() { + return "ServerHost{" + + "host='" + this.host + '\'' + + ", port=" + this.port + + '}'; + } + } + + /** + * Returns a new {@link Builder}. + * + * @return a new {@link Builder} + */ + public static Builder builder() { + return new Builder(); + } + + /** + * A builder for {@link MultiHostConfiguration} instances. + *

+ * This class is not threadsafe + */ + public static class Builder { + + private Duration hostRecheckTime = Duration.ofMillis(10000); + + private List hosts = new ArrayList<>(); + + private boolean loadBalanceHosts = false; + + private TargetServerType targetServerType = TargetServerType.ANY; + + /** + * Allows opening connections to only servers with required state, the allowed values are any, master, secondary and preferSecondary. + * The master/secondary distinction is currently done by observing if the server allows writes. + * The value preferSecondary tries to connect to secondary if any are available, otherwise allows falls back to connecting also to master. + * Default value is any. + * + * @param targetServerType target server type + * @return this {@link Builder} + * @throws IllegalArgumentException if {@code targetServerType} is {@code null} + */ + public Builder targetServerType(TargetServerType targetServerType) { + this.targetServerType = Assert.requireNonNull(targetServerType, "targetServerType must not be null"); + return this; + } + + /** + * Controls how long in seconds the knowledge about a host state is cached connection factory. The default value is 10000 milliseconds. + * + * @param hostRecheckTime host recheck time in milliseconds + * @return this {@link Builder} + */ + public Builder hostRecheckTime(Duration hostRecheckTime) { + this.hostRecheckTime = hostRecheckTime; + return this; + } + + /** + * In default mode (disabled) hosts are connected in the given order. If enabled hosts are chosen randomly from the set of suitable candidates. + * + * @param loadBalanceHosts is load balance mode enabled + * @return this {@link Builder} + */ + public Builder loadBalanceHosts(boolean loadBalanceHosts) { + this.loadBalanceHosts = loadBalanceHosts; + return this; + } + + /** + * Add host with default port to hosts list. + * + * @param host the host + * @return this {@link Builder} + * @throws IllegalArgumentException if {@code host} is {@code null} + */ + public Builder addHost(String host) { + Assert.requireNonNull(host, "host must not be null"); + this.hosts.add(new ServerHost(host, DEFAULT_PORT)); + return this; + } + + /** + * Add host to hosts list. + * + * @param host the host + * @param port the port + * @return this {@link Builder} + * @throws IllegalArgumentException if {@code host} is {@code null} + */ + public Builder addHost(String host, int port) { + Assert.requireNonNull(host, "host must not be null"); + this.hosts.add(new ServerHost(host, port)); + return this; + } + + /** + * Returns a configured {@link MultiHostConfiguration}. + * + * @return a configured {@link MultiHostConfiguration} + */ + public MultiHostConfiguration build() { + if (this.hosts.isEmpty()) { + throw new IllegalArgumentException("At least one host should be provided"); + } + + return new MultiHostConfiguration(this.hosts, this.hostRecheckTime, this.loadBalanceHosts, this.targetServerType); + } + + @Override + public String toString() { + return "Builder{" + + "hostRecheckTime=" + this.hostRecheckTime + + ", hosts=" + this.hosts + + ", loadBalanceHosts=" + this.loadBalanceHosts + + ", targetServerType=" + this.targetServerType + + '}'; + } + } +} diff --git a/src/main/java/io/r2dbc/postgresql/client/SingleHostConfiguration.java b/src/main/java/io/r2dbc/postgresql/client/SingleHostConfiguration.java new file mode 100644 index 000000000..cfa3db7ef --- /dev/null +++ b/src/main/java/io/r2dbc/postgresql/client/SingleHostConfiguration.java @@ -0,0 +1,164 @@ +package io.r2dbc.postgresql.client; + +import io.r2dbc.postgresql.util.Assert; +import reactor.util.annotation.Nullable; + +import static io.r2dbc.postgresql.PostgresqlConnectionConfiguration.DEFAULT_PORT; + +public class SingleHostConfiguration { + + @Nullable + private final String host; + + private final int port; + + @Nullable + private final String socket; + + public SingleHostConfiguration(@Nullable String host, int port, @Nullable String socket) { + this.host = host; + this.port = port; + this.socket = socket; + } + + @Nullable + public String getHost() { + return host; + } + + public int getPort() { + return port; + } + + public String getRequiredHost() { + + String host = getHost(); + + if (host == null || host.isEmpty()) { + throw new IllegalStateException("Connection is configured for socket connections and not for host usage"); + } + + return host; + } + + public String getRequiredSocket() { + + String socket = getSocket(); + + if (socket == null || socket.isEmpty()) { + throw new IllegalStateException("Connection is configured to use host and port connections and not for socket usage"); + } + + return socket; + } + + @Nullable + public String getSocket() { + return socket; + } + + public boolean isUseSocket() { + return getSocket() != null; + } + + @Override + public String toString() { + return "SingleHostConfiguration{" + + "host='" + this.host + '\'' + + ", port=" + this.port + + ", socket='" + this.socket + '\'' + + '}'; + } + + /** + * Returns a new {@link Builder}. + * + * @return a new {@link Builder} + */ + public static Builder builder() { + return new Builder(); + } + + /** + * A builder for {@link SingleHostConfiguration} instances. + *

+ * This class is not threadsafe + */ + public static class Builder { + + @Nullable + private String host; + + private int port = DEFAULT_PORT; + + @Nullable + private String socket; + + /** + * Configure the host. + * + * @param host the host + * @return this {@link Builder} + * @throws IllegalArgumentException if {@code host} is {@code null} + */ + public Builder host(String host) { + this.host = Assert.requireNonNull(host, "host must not be null"); + return this; + } + + /** + * Configure the port. Defaults to {@code 5432}. + * + * @param port the port + * @return this {@link Builder} + */ + public Builder port(int port) { + this.port = port; + return this; + } + + /** + * Configure the unix domain socket to connect to. + * + * @param socket the socket path + * @return this {@link Builder} + * @throws IllegalArgumentException if {@code socket} is {@code null} + */ + public Builder socket(String socket) { + this.socket = Assert.requireNonNull(socket, "host must not be null"); + return this; + } + + /** + * Returns a configured {@link SingleHostConfiguration}. + * + * @return a configured {@link SingleHostConfiguration} + */ + public SingleHostConfiguration build() { + if (this.host == null && this.socket == null) { + throw new IllegalArgumentException("host or socket must not be null"); + } + if (this.host != null && this.socket != null) { + throw new IllegalArgumentException("Connection must be configured for either host/port or socket usage but not both"); + } + + return new SingleHostConfiguration(this.host, this.port, this.socket); + } + + @Nullable + public String getSocket() { + return socket; + } + + @Override + public String toString() { + return "Builder{" + + "host='" + this.host + '\'' + + ", port=" + this.port + + ", socket='" + this.socket + '\'' + + '}'; + } + } + + +} diff --git a/src/test/java/io/r2dbc/postgresql/PostgresqlConnectionConfigurationUnitTests.java b/src/test/java/io/r2dbc/postgresql/PostgresqlConnectionConfigurationUnitTests.java index d9edc3aa9..b941799cf 100644 --- a/src/test/java/io/r2dbc/postgresql/PostgresqlConnectionConfigurationUnitTests.java +++ b/src/test/java/io/r2dbc/postgresql/PostgresqlConnectionConfigurationUnitTests.java @@ -40,9 +40,9 @@ void builderNoApplicationName() { } @Test - void builderNoHostAndSocket() { + void builderNoHostConfiguration() { assertThatIllegalArgumentException().isThrownBy(() -> PostgresqlConnectionConfiguration.builder().build()) - .withMessage("host or socket must not be null"); + .withMessage("either multiHostConfiguration or singleHostConfiguration must not be null"); } @Test @@ -82,10 +82,10 @@ void configuration() { .hasFieldOrPropertyWithValue("applicationName", "test-application-name") .hasFieldOrPropertyWithValue("connectTimeout", Duration.ofMillis(1000)) .hasFieldOrPropertyWithValue("database", "test-database") - .hasFieldOrPropertyWithValue("host", "test-host") + .hasFieldOrPropertyWithValue("singleHostConfiguration.host", "test-host") .hasFieldOrProperty("options") .hasFieldOrPropertyWithValue("password", null) - .hasFieldOrPropertyWithValue("port", 100) + .hasFieldOrPropertyWithValue("singleHostConfiguration.port", 100) .hasFieldOrPropertyWithValue("username", "test-username") .hasFieldOrProperty("sslConfig") .hasFieldOrPropertyWithValue("tcpKeepAlive", true) @@ -114,10 +114,10 @@ void configureStatementAndLockTimeouts() { .hasFieldOrPropertyWithValue("applicationName", "test-application-name") .hasFieldOrPropertyWithValue("connectTimeout", Duration.ofMillis(1000)) .hasFieldOrPropertyWithValue("database", "test-database") - .hasFieldOrPropertyWithValue("host", "test-host") + .hasFieldOrPropertyWithValue("singleHostConfiguration.host", "test-host") .hasFieldOrProperty("options") .hasFieldOrPropertyWithValue("password", null) - .hasFieldOrPropertyWithValue("port", 100) + .hasFieldOrPropertyWithValue("singleHostConfiguration.port", 100) .hasFieldOrPropertyWithValue("username", "test-username") .hasFieldOrProperty("sslConfig") .hasFieldOrPropertyWithValue("tcpKeepAlive", true) @@ -159,9 +159,9 @@ void configurationDefaults() { assertThat(configuration) .hasFieldOrPropertyWithValue("applicationName", "r2dbc-postgresql") .hasFieldOrPropertyWithValue("database", "test-database") - .hasFieldOrPropertyWithValue("host", "test-host") + .hasFieldOrPropertyWithValue("singleHostConfiguration.host", "test-host") .hasFieldOrPropertyWithValue("password", "test-password") - .hasFieldOrPropertyWithValue("port", 5432) + .hasFieldOrPropertyWithValue("singleHostConfiguration.port", 5432) .hasFieldOrProperty("options") .hasFieldOrPropertyWithValue("username", "test-username") .hasFieldOrProperty("sslConfig") diff --git a/src/test/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryProviderUnitTests.java b/src/test/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryProviderUnitTests.java index 3ccaeb910..f43d3cac5 100644 --- a/src/test/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryProviderUnitTests.java +++ b/src/test/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryProviderUnitTests.java @@ -16,6 +16,7 @@ package io.r2dbc.postgresql; +import io.r2dbc.postgresql.client.MultiHostConfiguration; import io.r2dbc.postgresql.client.SSLConfig; import io.r2dbc.postgresql.client.SSLMode; import io.r2dbc.postgresql.extension.Extension; @@ -27,6 +28,7 @@ import java.time.Duration; import java.util.Arrays; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.Objects; @@ -34,9 +36,12 @@ import static io.r2dbc.postgresql.PostgresqlConnectionFactoryProvider.COMPATIBILITY_MODE; import static io.r2dbc.postgresql.PostgresqlConnectionFactoryProvider.ERROR_RESPONSE_LOG_LEVEL; import static io.r2dbc.postgresql.PostgresqlConnectionFactoryProvider.EXTENSIONS; +import static io.r2dbc.postgresql.PostgresqlConnectionFactoryProvider.FAILOVER_PROTOCOL; import static io.r2dbc.postgresql.PostgresqlConnectionFactoryProvider.FETCH_SIZE; import static io.r2dbc.postgresql.PostgresqlConnectionFactoryProvider.FORCE_BINARY; +import static io.r2dbc.postgresql.PostgresqlConnectionFactoryProvider.HOST_RECHECK_TIME; import static io.r2dbc.postgresql.PostgresqlConnectionFactoryProvider.LEGACY_POSTGRESQL_DRIVER; +import static io.r2dbc.postgresql.PostgresqlConnectionFactoryProvider.LOAD_BALANCE_HOSTS; import static io.r2dbc.postgresql.PostgresqlConnectionFactoryProvider.LOCK_WAIT_TIMEOUT; import static io.r2dbc.postgresql.PostgresqlConnectionFactoryProvider.OPTIONS; import static io.r2dbc.postgresql.PostgresqlConnectionFactoryProvider.POSTGRESQL_DRIVER; @@ -49,11 +54,13 @@ import static io.r2dbc.postgresql.PostgresqlConnectionFactoryProvider.SSL_MODE; import static io.r2dbc.postgresql.PostgresqlConnectionFactoryProvider.SSL_ROOT_CERT; import static io.r2dbc.postgresql.PostgresqlConnectionFactoryProvider.STATEMENT_TIMEOUT; +import static io.r2dbc.postgresql.PostgresqlConnectionFactoryProvider.TARGET_SERVER_TYPE; import static io.r2dbc.postgresql.PostgresqlConnectionFactoryProvider.TCP_KEEPALIVE; import static io.r2dbc.postgresql.PostgresqlConnectionFactoryProvider.TCP_NODELAY; import static io.r2dbc.spi.ConnectionFactoryOptions.DRIVER; import static io.r2dbc.spi.ConnectionFactoryOptions.HOST; import static io.r2dbc.spi.ConnectionFactoryOptions.PASSWORD; +import static io.r2dbc.spi.ConnectionFactoryOptions.PROTOCOL; import static io.r2dbc.spi.ConnectionFactoryOptions.SSL; import static io.r2dbc.spi.ConnectionFactoryOptions.USER; import static io.r2dbc.spi.ConnectionFactoryOptions.builder; @@ -472,8 +479,31 @@ void shouldConnectUsingUnixDomainSocket() { .option(USER, "postgres") .build()); - assertThat(factory.getConfiguration().isUseSocket()).isTrue(); - assertThat(factory.getConfiguration().getRequiredSocket()).isEqualTo("/tmp/.s.PGSQL.5432"); + assertThat(factory.getConfiguration().getSingleHostConfiguration().isUseSocket()).isTrue(); + assertThat(factory.getConfiguration().getSingleHostConfiguration().getRequiredSocket()).isEqualTo("/tmp/.s.PGSQL.5432"); + } + + @Test + void shouldConnectUsingMultiHostConfiguration() { + PostgresqlConnectionFactory factory = this.provider.create(builder() + .option(DRIVER, POSTGRESQL_DRIVER) + .option(PROTOCOL, FAILOVER_PROTOCOL) + .option(HOST, "host1:5433,host2:5432,host3") + .option(USER, "postgres") + .option(LOAD_BALANCE_HOSTS, true) + .option(HOST_RECHECK_TIME, Duration.ofMillis(20000)) + .option(TARGET_SERVER_TYPE, TargetServerType.SECONDARY) + .build()); + + assertThat(factory.getConfiguration().getSingleHostConfiguration()).isNull(); + assertThat(factory.getConfiguration().getMultiHostConfiguration().isLoadBalanceHosts()).isEqualTo(true); + assertThat(factory.getConfiguration().getMultiHostConfiguration().getHostRecheckTime()).isEqualTo(Duration.ofMillis(20000)); + assertThat(factory.getConfiguration().getMultiHostConfiguration().getTargetServerType()).isEqualTo(TargetServerType.SECONDARY); + List hosts = factory.getConfiguration().getMultiHostConfiguration().getHosts(); + assertThat(hosts).hasSize(3); + assertThat(hosts.get(0)).usingRecursiveComparison().isEqualTo(new MultiHostConfiguration.ServerHost("host1", 5433)); + assertThat(hosts.get(1)).usingRecursiveComparison().isEqualTo(new MultiHostConfiguration.ServerHost("host2", 5432)); + assertThat(hosts.get(2)).usingRecursiveComparison().isEqualTo(new MultiHostConfiguration.ServerHost("host3", 5432)); } @Test diff --git a/src/test/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryUnitTests.java b/src/test/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryUnitTests.java index 7fdae1318..28f112505 100644 --- a/src/test/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryUnitTests.java +++ b/src/test/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryUnitTests.java @@ -17,6 +17,8 @@ package io.r2dbc.postgresql; import com.ongres.scram.client.ScramClient; + +import io.netty.channel.unix.DomainSocketAddress; import io.r2dbc.postgresql.client.Client; import io.r2dbc.postgresql.client.TestClient; import io.r2dbc.postgresql.message.backend.AuthenticationMD5Password; @@ -82,7 +84,7 @@ void createAuthenticationMD5Password() { .password("test-password") .build(); - new PostgresqlConnectionFactory(c -> Mono.just(client), configuration) + new PostgresqlConnectionFactory(testClientFactory(client, configuration), configuration) .create() .as(StepVerifier::create) .expectNextCount(1) @@ -133,7 +135,7 @@ void createError() { .password("test-password") .build(); - new PostgresqlConnectionFactory(c -> Mono.just(client), configuration).create() + new PostgresqlConnectionFactory(testClientFactory(client, configuration), configuration).create() .as(StepVerifier::create) .verifyErrorMatches(R2dbcNonTransientResourceException.class::isInstance); } @@ -157,7 +159,11 @@ void getMetadata() { .password("test-password") .build(); - assertThat(new PostgresqlConnectionFactory(c -> Mono.just(client), configuration).getMetadata()).isNotNull(); + assertThat(new PostgresqlConnectionFactory(testClientFactory(client, configuration), configuration).getMetadata()).isNotNull(); + } + + private ConnectionStrategy testClientFactory(Client client, PostgresqlConnectionConfiguration configuration) { + return new DefaultConnectionStrategy(new DomainSocketAddress(""), (endpoint, settings) -> Mono.just(client), configuration, null, Collections.emptyMap()); } } diff --git a/src/test/java/io/r2dbc/postgresql/client/HighAvailabilityClusterIntegrationTests.java b/src/test/java/io/r2dbc/postgresql/client/HighAvailabilityClusterIntegrationTests.java new file mode 100644 index 000000000..6775d0c35 --- /dev/null +++ b/src/test/java/io/r2dbc/postgresql/client/HighAvailabilityClusterIntegrationTests.java @@ -0,0 +1,195 @@ +package io.r2dbc.postgresql.client; + +import io.r2dbc.postgresql.PostgresqlConnectionConfiguration; +import io.r2dbc.postgresql.PostgresqlConnectionFactory; +import io.r2dbc.postgresql.TargetServerType; +import io.r2dbc.postgresql.api.PostgresqlConnection; +import io.r2dbc.postgresql.util.PostgresqlHighAvailabilityClusterExtension; +import io.r2dbc.spi.R2dbcException; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.testcontainers.containers.PostgreSQLContainer; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +public class HighAvailabilityClusterIntegrationTests { + + @RegisterExtension + static final PostgresqlHighAvailabilityClusterExtension SERVERS = new PostgresqlHighAvailabilityClusterExtension(); + + @Test + void testPrimaryAndStandbyStartup() { + Assertions.assertFalse(SERVERS.getPrimaryJdbc().queryForObject("show transaction_read_only", Boolean.class)); + Assertions.assertTrue(SERVERS.getStandbyJdbc().queryForObject("show transaction_read_only", Boolean.class)); + } + + @Test + void testMultipleCallsOnSameFactory() { + PostgresqlConnectionFactory connectionFactory = this.multiHostConnectionFactory(TargetServerType.PREFER_SECONDARY, SERVERS.getPrimary(), SERVERS.getStandby()); + + connectionFactory + .create() + .flatMapMany(connection -> this.isPrimary(connection) + .concatWith(connection.close().then(Mono.empty()))) + .next() + .as(StepVerifier::create) + .expectNext(false) + .verifyComplete(); + + connectionFactory + .create() + .flatMapMany(connection -> this.isPrimary(connection) + .concatWith(connection.close().then(Mono.empty()))) + .next() + .as(StepVerifier::create) + .expectNext(false) + .verifyComplete(); + } + + @Test + void testTargetAnyChooseFirst() { + isConnectedToPrimary(TargetServerType.ANY, SERVERS.getPrimary(), SERVERS.getStandby()) + .as(StepVerifier::create) + .expectNext(true) + .verifyComplete(); + + isConnectedToPrimary(TargetServerType.ANY, SERVERS.getStandby(), SERVERS.getPrimary()) + .as(StepVerifier::create) + .expectNext(false) + .verifyComplete(); + } + + @Test + void testTargetAnyConnectedToPrimary() { + isConnectedToPrimary(TargetServerType.ANY, SERVERS.getPrimary()) + .as(StepVerifier::create) + .expectNext(true) + .verifyComplete(); + } + + @Test + void testTargetAnyConnectedToStandby() { + isConnectedToPrimary(TargetServerType.ANY, SERVERS.getStandby()) + .as(StepVerifier::create) + .expectNext(false) + .verifyComplete(); + } + + @Test + void testTargetPreferSecondaryChooseStandby() { + isConnectedToPrimary(TargetServerType.PREFER_SECONDARY, SERVERS.getStandby(), SERVERS.getPrimary()) + .as(StepVerifier::create) + .expectNext(false) + .verifyComplete(); + + isConnectedToPrimary(TargetServerType.PREFER_SECONDARY, SERVERS.getPrimary(), SERVERS.getStandby()) + .as(StepVerifier::create) + .expectNext(false) + .verifyComplete(); + } + + @Test + void testTargetPreferSecondaryConnectedToPrimary() { + isConnectedToPrimary(TargetServerType.PREFER_SECONDARY, SERVERS.getPrimary()) + .as(StepVerifier::create) + .expectNext(true) + .verifyComplete(); + } + + @Test + void testTargetPreferSecondaryConnectedToStandby() { + isConnectedToPrimary(TargetServerType.PREFER_SECONDARY, SERVERS.getStandby()) + .as(StepVerifier::create) + .expectNext(false) + .verifyComplete(); + } + + @Test + void testTargetPrimaryChoosePrimary() { + isConnectedToPrimary(TargetServerType.MASTER, SERVERS.getPrimary(), SERVERS.getStandby()) + .as(StepVerifier::create) + .expectNext(true) + .verifyComplete(); + + isConnectedToPrimary(TargetServerType.MASTER, SERVERS.getStandby(), SERVERS.getPrimary()) + .as(StepVerifier::create) + .expectNext(true) + .verifyComplete(); + } + + @Test + void testTargetPrimaryConnectedOnPrimary() { + isConnectedToPrimary(TargetServerType.MASTER, SERVERS.getPrimary()) + .as(StepVerifier::create) + .expectNext(true) + .verifyComplete(); + } + + @Test + void testTargetPrimaryFailedOnStandby() { + isConnectedToPrimary(TargetServerType.MASTER, SERVERS.getStandby()) + .as(StepVerifier::create) + .verifyError(R2dbcException.class); + } + + @Test + void testTargetSecondaryChooseStandby() { + isConnectedToPrimary(TargetServerType.SECONDARY, SERVERS.getStandby(), SERVERS.getPrimary()) + .as(StepVerifier::create) + .expectNext(false) + .verifyComplete(); + + isConnectedToPrimary(TargetServerType.SECONDARY, SERVERS.getPrimary(), SERVERS.getStandby()) + .as(StepVerifier::create) + .expectNext(false) + .verifyComplete(); + } + + @Test + void testTargetSecondaryConnectedOnStandby() { + isConnectedToPrimary(TargetServerType.SECONDARY, SERVERS.getStandby()) + .as(StepVerifier::create) + .expectNext(false) + .verifyComplete(); + } + + @Test + void testTargetSecondaryFailedOnPrimary() { + isConnectedToPrimary(TargetServerType.SECONDARY, SERVERS.getPrimary()) + .as(StepVerifier::create) + .verifyError(R2dbcException.class); + } + + private Mono isConnectedToPrimary(TargetServerType targetServerType, PostgreSQLContainer... servers) { + PostgresqlConnectionFactory connectionFactory = this.multiHostConnectionFactory(targetServerType, servers); + + return connectionFactory + .create() + .flatMapMany(connection -> this.isPrimary(connection) + .concatWith(connection.close().then(Mono.empty()))) + .next(); + } + + private Mono isPrimary(PostgresqlConnection connection) { + return connection.createStatement("show transaction_read_only") + .execute() + .flatMap(result -> result.map((row, meta) -> row.get(0, String.class))) + .map(str -> str.equalsIgnoreCase("off")) + .next(); + } + + private PostgresqlConnectionFactory multiHostConnectionFactory(TargetServerType targetServerType, PostgreSQLContainer... servers) { + PostgreSQLContainer firstServer = servers[0]; + PostgresqlConnectionConfiguration.Builder builder = PostgresqlConnectionConfiguration.builder(); + for (PostgreSQLContainer server : servers) { + builder.addHost(server.getContainerIpAddress(), server.getMappedPort(5432)); + } + PostgresqlConnectionConfiguration configuration = builder + .targetServerType(targetServerType) + .username(firstServer.getUsername()) + .password(firstServer.getPassword()) + .build(); + return new PostgresqlConnectionFactory(configuration); + } +} diff --git a/src/test/java/io/r2dbc/postgresql/client/SingleHostConfigurationUnitTests.java b/src/test/java/io/r2dbc/postgresql/client/SingleHostConfigurationUnitTests.java new file mode 100644 index 000000000..137033365 --- /dev/null +++ b/src/test/java/io/r2dbc/postgresql/client/SingleHostConfigurationUnitTests.java @@ -0,0 +1,20 @@ +package io.r2dbc.postgresql.client; + +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; + +public class SingleHostConfigurationUnitTests { + + @Test + void builderNoHostAndSocket() { + assertThatIllegalArgumentException().isThrownBy(() -> SingleHostConfiguration.builder().build()) + .withMessage("host or socket must not be null"); + } + + @Test + void builderHostAndSocket() { + assertThatIllegalArgumentException().isThrownBy(() -> SingleHostConfiguration.builder().host("host").socket("socket").build()) + .withMessageContaining("either host/port or socket"); + } +} diff --git a/src/test/java/io/r2dbc/postgresql/util/PostgresqlHighAvailabilityClusterExtension.java b/src/test/java/io/r2dbc/postgresql/util/PostgresqlHighAvailabilityClusterExtension.java new file mode 100644 index 000000000..15dacb547 --- /dev/null +++ b/src/test/java/io/r2dbc/postgresql/util/PostgresqlHighAvailabilityClusterExtension.java @@ -0,0 +1,124 @@ +package io.r2dbc.postgresql.util; + +import com.zaxxer.hikari.HikariConfig; +import com.zaxxer.hikari.HikariDataSource; +import org.junit.jupiter.api.extension.AfterAllCallback; +import org.junit.jupiter.api.extension.BeforeAllCallback; +import org.junit.jupiter.api.extension.ExtensionContext; +import org.springframework.jdbc.core.JdbcTemplate; +import org.testcontainers.containers.Network; +import org.testcontainers.containers.PostgreSQLContainer; +import org.testcontainers.containers.wait.strategy.LogMessageWaitStrategy; +import org.testcontainers.utility.MountableFile; + +import java.net.URISyntaxException; +import java.net.URL; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.time.Duration; +import java.time.temporal.ChronoUnit; + +import static org.testcontainers.utility.MountableFile.forHostPath; + +public class PostgresqlHighAvailabilityClusterExtension implements BeforeAllCallback, AfterAllCallback { + + private PostgreSQLContainer primary; + + private HikariDataSource primaryDataSource; + + private PostgreSQLContainer standby; + + private HikariDataSource standbyDataSource; + + @Override + public void afterAll(ExtensionContext extensionContext) { + if (this.standbyDataSource != null) { + this.standbyDataSource.close(); + } + if (this.standby != null) { + this.standby.stop(); + } + if (this.primaryDataSource != null) { + this.primaryDataSource.close(); + } + if (this.primary != null) { + this.primary.stop(); + } + } + + @Override + public void beforeAll(ExtensionContext extensionContext) { + Network network = Network.newNetwork(); + this.startPrimary(network); + this.startStandby(network); + } + + public PostgreSQLContainer getPrimary() { + return this.primary; + } + + public JdbcTemplate getPrimaryJdbc() { + return new JdbcTemplate(this.primaryDataSource); + } + + public PostgreSQLContainer getStandby() { + return standby; + } + + public JdbcTemplate getStandbyJdbc() { + return new JdbcTemplate(this.standbyDataSource); + } + + private static MountableFile getHostPath(String name, int mode) { + return forHostPath(getResourcePath(name), mode); + } + + private static Path getResourcePath(String name) { + URL resource = PostgresqlHighAvailabilityClusterExtension.class.getClassLoader().getResource(name); + if (resource == null) { + throw new IllegalStateException("Resource not found: " + name); + } + + try { + return Paths.get(resource.toURI()); + } catch (URISyntaxException e) { + throw new IllegalStateException("Cannot convert to path for: " + name, e); + } + } + + private void startPrimary(Network network) { + this.primary = new PostgreSQLContainer<>("postgres:latest") + .withNetwork(network) + .withNetworkAliases("postgres-primary") + .withCopyFileToContainer(getHostPath("setup-primary.sh", 0755), "/docker-entrypoint-initdb.d/setup-primary.sh") + .withEnv("PG_REP_USER", "replication") + .withEnv("PG_REP_PASSWORD", "replication_password"); + this.primary.start(); + HikariConfig primaryConfig = new HikariConfig(); + primaryConfig.setJdbcUrl(this.primary.getJdbcUrl()); + primaryConfig.setUsername(this.primary.getUsername()); + primaryConfig.setPassword(this.primary.getPassword()); + this.primaryDataSource = new HikariDataSource(primaryConfig); + } + + private void startStandby(Network network) { + this.standby = new PostgreSQLContainer<>("postgres:latest") + .withNetwork(network) + .withCopyFileToContainer(getHostPath("setup-standby.sh", 0755), "/setup-standby.sh") + .withCommand("/setup-standby.sh") + .withEnv("PG_REP_USER", "replication") + .withEnv("PG_REP_PASSWORD", "replication_password") + .withEnv("PG_MASTER_HOST", "postgres-primary") + .withEnv("PG_MASTER_PORT", "5432"); + this.standby.setWaitStrategy(new LogMessageWaitStrategy() + .withRegEx(".*database system is ready to accept read-only connections.*\\s") + .withTimes(1) + .withStartupTimeout(Duration.of(60L, ChronoUnit.SECONDS))); + this.standby.start(); + HikariConfig standbyConfig = new HikariConfig(); + standbyConfig.setJdbcUrl(this.standby.getJdbcUrl()); + standbyConfig.setUsername(this.standby.getUsername()); + standbyConfig.setPassword(this.standby.getPassword()); + this.standbyDataSource = new HikariDataSource(standbyConfig); + } +} diff --git a/src/test/resources/setup-primary.sh b/src/test/resources/setup-primary.sh new file mode 100644 index 000000000..2ee3ba81f --- /dev/null +++ b/src/test/resources/setup-primary.sh @@ -0,0 +1,17 @@ +#!/bin/sh + +echo "host replication all 0.0.0.0/0 md5" >> "$PGDATA/pg_hba.conf" + +set -e +psql -v ON_ERROR_STOP=1 --username "$POSTGRES_USER" --dbname "$POSTGRES_DB" <<-EOSQL + CREATE USER $PG_REP_USER REPLICATION LOGIN CONNECTION LIMIT 100 ENCRYPTED PASSWORD '$PG_REP_PASSWORD'; +EOSQL + +cat >> ${PGDATA}/postgresql.conf < ~/.pgpass + chmod 0600 ~/.pgpass + until pg_basebackup -h "${PG_MASTER_HOST}" -p "${PG_MASTER_PORT}" -D "${PGDATA}" -U "${PG_REP_USER}" -vP -W + do + echo "Waiting for primary server to connect..." + sleep 1s + done + echo "host replication all 0.0.0.0/0 md5" >> "$PGDATA/pg_hba.conf" + set -e + cat > "${PGDATA}"/standby.signal <