diff --git a/src/main/java/io/r2dbc/postgresql/ClientFactory.java b/src/main/java/io/r2dbc/postgresql/ClientFactory.java new file mode 100644 index 00000000..5f7a6f8d --- /dev/null +++ b/src/main/java/io/r2dbc/postgresql/ClientFactory.java @@ -0,0 +1,22 @@ +package io.r2dbc.postgresql; + +import io.r2dbc.postgresql.client.Client; +import reactor.core.publisher.Mono; +import reactor.util.annotation.Nullable; + +import java.util.Map; + +public interface ClientFactory { + + static ClientFactory getFactory(PostgresqlConnectionConfiguration configuration, ClientSupplier clientSupplier) { + if (configuration.getSingleHostConfiguration() != null) { + return new SingleHostClientFactory(configuration, clientSupplier); + } + if (configuration.getMultipleHostsConfiguration() != null) { + return new MultipleHostsClientFactory(configuration, clientSupplier); + } + throw new IllegalArgumentException("Can't build client factory based on configuration " + configuration); + } + + Mono create(@Nullable Map options); +} diff --git a/src/main/java/io/r2dbc/postgresql/ClientFactoryBase.java b/src/main/java/io/r2dbc/postgresql/ClientFactoryBase.java new file mode 100644 index 00000000..042bebb7 --- /dev/null +++ b/src/main/java/io/r2dbc/postgresql/ClientFactoryBase.java @@ -0,0 +1,73 @@ +package io.r2dbc.postgresql; + +import io.r2dbc.postgresql.authentication.AuthenticationHandler; +import io.r2dbc.postgresql.authentication.PasswordAuthenticationHandler; +import io.r2dbc.postgresql.authentication.SASLAuthenticationHandler; +import io.r2dbc.postgresql.client.Client; +import io.r2dbc.postgresql.client.SSLConfig; +import io.r2dbc.postgresql.client.SSLMode; +import io.r2dbc.postgresql.client.StartupMessageFlow; +import io.r2dbc.postgresql.message.backend.AuthenticationMessage; +import io.r2dbc.postgresql.util.Assert; +import reactor.core.publisher.Mono; + +import javax.annotation.Nullable; +import java.net.SocketAddress; +import java.util.Map; +import java.util.function.Predicate; + +public abstract class ClientFactoryBase implements ClientFactory { + + private final ClientSupplier clientSupplier; + + private final PostgresqlConnectionConfiguration configuration; + + protected ClientFactoryBase(PostgresqlConnectionConfiguration configuration, ClientSupplier clientSupplier) { + this.configuration = configuration; + this.clientSupplier = clientSupplier; + } + + protected AuthenticationHandler getAuthenticationHandler(AuthenticationMessage message) { + if (PasswordAuthenticationHandler.supports(message)) { + CharSequence password = Assert.requireNonNull(this.configuration.getPassword(), "Password must not be null"); + return new PasswordAuthenticationHandler(password, this.configuration.getUsername()); + } else if (SASLAuthenticationHandler.supports(message)) { + CharSequence password = Assert.requireNonNull(this.configuration.getPassword(), "Password must not be null"); + return new SASLAuthenticationHandler(password, this.configuration.getUsername()); + } else { + throw new IllegalStateException(String.format("Unable to provide AuthenticationHandler capable of handling %s", message)); + } + } + + protected Mono tryConnectToEndpoint(SocketAddress endpoint, @Nullable Map options) { + SSLConfig sslConfig = this.configuration.getSslConfig(); + Predicate isAuthSpecificationError = e -> e instanceof ExceptionFactory.PostgresqlAuthenticationFailure; + return this.tryConnectWithConfig(sslConfig, endpoint, options) + .onErrorResume( + isAuthSpecificationError.and(e -> sslConfig.getSslMode() == SSLMode.ALLOW), + e -> this.tryConnectWithConfig(sslConfig.mutateMode(SSLMode.REQUIRE), endpoint, options) + .onErrorResume(sslAuthError -> { + e.addSuppressed(sslAuthError); + return Mono.error(e); + }) + ) + .onErrorResume( + isAuthSpecificationError.and(e -> sslConfig.getSslMode() == SSLMode.PREFER), + e -> this.tryConnectWithConfig(sslConfig.mutateMode(SSLMode.DISABLE), endpoint, options) + .onErrorResume(sslAuthError -> { + e.addSuppressed(sslAuthError); + return Mono.error(e); + }) + ); + } + + protected Mono tryConnectWithConfig(SSLConfig sslConfig, SocketAddress endpoint, @Nullable Map options) { + return this.clientSupplier.connect(endpoint, this.configuration.getConnectTimeout(), sslConfig) + .delayUntil(client -> StartupMessageFlow + .exchange(this.configuration.getApplicationName(), this::getAuthenticationHandler, client, this.configuration + .getDatabase(), this.configuration.getUsername(), options) + .handle(ExceptionFactory.INSTANCE::handleErrorResponse)) + .cast(Client.class); + } + +} diff --git a/src/main/java/io/r2dbc/postgresql/ClientSupplier.java b/src/main/java/io/r2dbc/postgresql/ClientSupplier.java new file mode 100644 index 00000000..21f1a5d8 --- /dev/null +++ b/src/main/java/io/r2dbc/postgresql/ClientSupplier.java @@ -0,0 +1,14 @@ +package io.r2dbc.postgresql; + +import io.r2dbc.postgresql.client.Client; +import io.r2dbc.postgresql.client.SSLConfig; +import reactor.core.publisher.Mono; + +import javax.annotation.Nullable; +import java.net.SocketAddress; +import java.time.Duration; + +public interface ClientSupplier { + + Mono connect(SocketAddress endpoint, @Nullable Duration connectTimeout, SSLConfig sslConfig); +} diff --git a/src/main/java/io/r2dbc/postgresql/MultipleHostsClientFactory.java b/src/main/java/io/r2dbc/postgresql/MultipleHostsClientFactory.java new file mode 100644 index 00000000..b8864f55 --- /dev/null +++ b/src/main/java/io/r2dbc/postgresql/MultipleHostsClientFactory.java @@ -0,0 +1,210 @@ +package io.r2dbc.postgresql; + +import io.r2dbc.postgresql.client.Client; +import io.r2dbc.postgresql.client.MultipleHostsConfiguration; +import io.r2dbc.postgresql.codec.DefaultCodecs; +import io.r2dbc.postgresql.util.Assert; +import io.r2dbc.spi.IsolationLevel; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import javax.annotation.Nullable; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicReference; + +import static io.r2dbc.postgresql.TargetServerType.ANY; +import static io.r2dbc.postgresql.TargetServerType.MASTER; +import static io.r2dbc.postgresql.TargetServerType.PREFER_SECONDARY; +import static io.r2dbc.postgresql.TargetServerType.SECONDARY; + +class MultipleHostsClientFactory extends ClientFactoryBase { + + private final List addresses; + + private final MultipleHostsConfiguration configuration; + + private final Map statusMap = new ConcurrentHashMap<>(); + + public MultipleHostsClientFactory(PostgresqlConnectionConfiguration configuration, ClientSupplier clientSupplier) { + super(configuration, clientSupplier); + this.configuration = Assert.requireNonNull(configuration.getMultipleHostsConfiguration(), "MultipleHostsConfiguration must not be null"); + this.addresses = MultipleHostsClientFactory.createSocketAddress(this.configuration); + } + + @Override + public Mono create(@Nullable Map options) { + AtomicReference exceptionRef = new AtomicReference<>(); + TargetServerType targetServerType = this.configuration.getTargetServerType(); + return this.tryConnect(targetServerType, options) + .onErrorResume(e -> { + if (!exceptionRef.compareAndSet(null, e)) { + exceptionRef.get().addSuppressed(e); + } + return Mono.empty(); + }) + .switchIfEmpty(Mono.defer(() -> targetServerType == PREFER_SECONDARY + ? this.tryConnect(MASTER, options) + : Mono.empty())) + .switchIfEmpty(Mono.error(() -> { + Throwable error = exceptionRef.get(); + if (error == null) { + return new PostgresqlConnectionFactory.PostgresConnectionException(String.format("No server matches target type %s", targetServerType + .getValue()), null); + } else { + return error; + } + })); + } + + public Mono tryConnect(TargetServerType targetServerType, @Nullable Map options) { + AtomicReference exceptionRef = new AtomicReference<>(); + return this.getCandidates(targetServerType).concatMap(candidate -> this.tryConnectToCandidate(targetServerType, candidate, options) + .onErrorResume(e -> { + if (!exceptionRef.compareAndSet(null, e)) { + exceptionRef.get().addSuppressed(e); + } + this.statusMap.put(candidate, HostSpecStatus.fail(candidate)); + return Mono.empty(); + })) + .next() + .switchIfEmpty(Mono.defer(() -> exceptionRef.get() != null + ? Mono.error(exceptionRef.get()) + : Mono.empty())); + } + + private static List createSocketAddress(MultipleHostsConfiguration configuration) { + List addressList = new ArrayList<>(configuration.getHosts().size()); + for (MultipleHostsConfiguration.ServerHost host : configuration.getHosts()) { + addressList.add(InetSocketAddress.createUnresolved(host.getHost(), host.getPort())); + } + return addressList; + } + + private static HostSpecStatus evaluateStatus(SocketAddress candidate, @Nullable HostSpecStatus oldStatus) { + return oldStatus == null || oldStatus.hostStatus == HostStatus.CONNECT_FAIL + ? HostSpecStatus.ok(candidate) + : oldStatus; + } + + private static Mono isPrimaryServer(Client client) { + DefaultCodecs codecs = new DefaultCodecs(client.getByteBufAllocator()); + StatementCache disabledStatementCache = StatementCache.fromPreparedStatementCacheQueries(client, 0); + PostgresqlConnection connection = new PostgresqlConnection(client, codecs, DefaultPortalNameSupplier.INSTANCE, disabledStatementCache, + IsolationLevel.READ_COMMITTED, false); + ConnectionContext context = new ConnectionContext(client, codecs, connection); + return new SimpleQueryPostgresqlStatement(context, "show transaction_read_only") + .execute() + .flatMap(result -> result.map((row, rowMetadata) -> row.get(0, String.class))) + .map(s -> s.equalsIgnoreCase("off")) + .next(); + } + + private Flux getCandidates(TargetServerType targetServerType) { + return Flux.create(sink -> { + long now = System.currentTimeMillis(); + List addresses = new ArrayList<>(this.addresses); + if (this.configuration.isLoadBalanceHosts()) { + Collections.shuffle(addresses); + } + int counter = 0; + for (SocketAddress address : addresses) { + HostSpecStatus currentStatus = this.statusMap.get(address); + Duration hostRecheckDuration = this.configuration.getHostRecheckTime(); + boolean recheck = currentStatus == null || hostRecheckDuration.plusMillis(currentStatus.updated).toMillis() < now; + if (recheck) { + sink.next(address); + counter++; + } else if (targetServerType.allowStatus(currentStatus.hostStatus)) { + sink.next(address); + counter++; + } + } + if (counter == 0) { + // if no candidate match the requirement or all of them are in unavailable status try all the hosts + addresses = new ArrayList<>(this.addresses); + if (this.configuration.isLoadBalanceHosts()) { + Collections.shuffle(addresses); + } + for (SocketAddress address : addresses) { + sink.next(address); + } + } + sink.complete(); + }); + } + + private Mono tryConnectToCandidate(TargetServerType targetServerType, SocketAddress candidate, @Nullable Map options) { + return Mono.create(sink -> this.tryConnectToEndpoint(candidate, options).subscribe(client -> { + this.statusMap.compute(candidate, (a, oldStatus) -> MultipleHostsClientFactory.evaluateStatus(candidate, oldStatus)); + if (targetServerType == ANY) { + sink.success(client); + return; + } + MultipleHostsClientFactory.isPrimaryServer(client).subscribe( + isPrimary -> { + if (isPrimary) { + this.statusMap.put(candidate, HostSpecStatus.primary(candidate)); + } else { + this.statusMap.put(candidate, HostSpecStatus.standby(candidate)); + } + if (isPrimary && targetServerType == MASTER) { + sink.success(client); + } else if (!isPrimary && (targetServerType == SECONDARY || targetServerType == PREFER_SECONDARY)) { + sink.success(client); + } else { + client.close().subscribe(v -> sink.success(), sink::error, sink::success, sink.currentContext()); + } + }, + sink::error, + () -> { + }, + sink.currentContext() + ); + }, sink::error, () -> { + }, sink.currentContext())); + } + + enum HostStatus { + CONNECT_FAIL, + CONNECT_OK, + PRIMARY, + STANDBY + } + + private static class HostSpecStatus { + + public final SocketAddress address; + + public final HostStatus hostStatus; + + public final long updated = System.currentTimeMillis(); + + private HostSpecStatus(SocketAddress address, HostStatus hostStatus) { + this.address = address; + this.hostStatus = hostStatus; + } + + public static HostSpecStatus fail(SocketAddress host) { + return new HostSpecStatus(host, HostStatus.CONNECT_FAIL); + } + + public static HostSpecStatus ok(SocketAddress host) { + return new HostSpecStatus(host, HostStatus.CONNECT_OK); + } + + public static HostSpecStatus primary(SocketAddress host) { + return new HostSpecStatus(host, HostStatus.PRIMARY); + } + + public static HostSpecStatus standby(SocketAddress host) { + return new HostSpecStatus(host, HostStatus.STANDBY); + } + } +} diff --git a/src/main/java/io/r2dbc/postgresql/PostgresqlConnectionConfiguration.java b/src/main/java/io/r2dbc/postgresql/PostgresqlConnectionConfiguration.java index 65e63942..4be0395c 100644 --- a/src/main/java/io/r2dbc/postgresql/PostgresqlConnectionConfiguration.java +++ b/src/main/java/io/r2dbc/postgresql/PostgresqlConnectionConfiguration.java @@ -19,8 +19,10 @@ import io.netty.handler.ssl.SslContextBuilder; import io.netty.handler.ssl.util.InsecureTrustManagerFactory; import io.r2dbc.postgresql.client.DefaultHostnameVerifier; +import io.r2dbc.postgresql.client.MultipleHostsConfiguration; import io.r2dbc.postgresql.client.SSLConfig; import io.r2dbc.postgresql.client.SSLMode; +import io.r2dbc.postgresql.client.SingleHostConfiguration; import io.r2dbc.postgresql.codec.Codec; import io.r2dbc.postgresql.extension.CodecRegistrar; import io.r2dbc.postgresql.extension.Extension; @@ -62,42 +64,43 @@ public final class PostgresqlConnectionConfiguration { private final boolean forceBinary; - private final String host; - private final Map options; private final CharSequence password; - private final int port; - private final String schema; - private final String socket; - private final String username; private final SSLConfig sslConfig; + private final MultipleHostsConfiguration multipleHostsConfiguration; + + private final SingleHostConfiguration singleHostConfiguration; + private final int preparedStatementCacheQueries; private PostgresqlConnectionConfiguration(String applicationName, boolean autodetectExtensions, - @Nullable Duration connectTimeout, @Nullable String database, List extensions, boolean forceBinary, @Nullable String host, - @Nullable Map options, @Nullable CharSequence password, int port, @Nullable String schema, @Nullable String socket, String username, - SSLConfig sslConfig, int preparedStatementCacheQueries) { + @Nullable Duration connectTimeout, @Nullable String database, List extensions, boolean forceBinary, + @Nullable Map options, @Nullable CharSequence password, @Nullable String schema, String username, + SSLConfig sslConfig, + @Nullable SingleHostConfiguration singleHostConfiguration, + @Nullable MultipleHostsConfiguration multipleHostsConfiguration, + int preparedStatementCacheQueries) { + this.applicationName = Assert.requireNonNull(applicationName, "applicationName must not be null"); this.autodetectExtensions = autodetectExtensions; this.connectTimeout = connectTimeout; this.extensions = Assert.requireNonNull(extensions, "extensions must not be null"); this.database = database; this.forceBinary = forceBinary; - this.host = host; this.options = options; this.password = password; - this.port = port; this.schema = schema; - this.socket = socket; this.username = Assert.requireNonNull(username, "username must not be null"); this.sslConfig = sslConfig; + this.singleHostConfiguration = singleHostConfiguration; + this.multipleHostsConfiguration = multipleHostsConfiguration; this.preparedStatementCacheQueries = preparedStatementCacheQueries; } @@ -111,22 +114,32 @@ public static Builder builder() { } + @Nullable + public MultipleHostsConfiguration getMultipleHostsConfiguration() { + return multipleHostsConfiguration; + } + + @Nullable + public SingleHostConfiguration getSingleHostConfiguration() { + return singleHostConfiguration; + } + @Override public String toString() { return "PostgresqlConnectionConfiguration{" + - "applicationName='" + this.applicationName + '\'' + - ", autodetectExtensions='" + this.autodetectExtensions + '\'' + - ", connectTimeout=" + this.connectTimeout + - ", database='" + this.database + '\'' + - ", extensions=" + this.extensions + - ", forceBinary='" + this.forceBinary + '\'' + - ", host='" + this.host + '\'' + - ", options='" + this.options + '\'' + - ", password='" + obfuscate(this.password != null ? this.password.length() : 0) + '\'' + - ", port=" + this.port + - ", schema='" + this.schema + '\'' + - ", username='" + this.username + '\'' + - '}'; + "applicationName='" + this.applicationName + '\'' + + ", singleHostConfiguration='" + this.singleHostConfiguration + '\'' + + ", multipleHostsConfiguration='" + this.multipleHostsConfiguration + '\'' + + ", autodetectExtensions='" + this.autodetectExtensions + '\'' + + ", connectTimeout=" + this.connectTimeout + + ", database='" + this.database + '\'' + + ", extensions=" + this.extensions + + ", forceBinary='" + this.forceBinary + '\'' + + ", options='" + this.options + '\'' + + ", password='" + obfuscate(this.password != null ? this.password.length() : 0) + '\'' + + ", schema='" + this.schema + '\'' + + ", username='" + this.username + '\'' + + '}'; } String getApplicationName() { @@ -147,22 +160,6 @@ List getExtensions() { return this.extensions; } - @Nullable - String getHost() { - return this.host; - } - - String getRequiredHost() { - - String host = getHost(); - - if (host == null || host.isEmpty()) { - throw new IllegalStateException("Connection is configured for socket connections and not for host usage"); - } - - return host; - } - @Nullable Map getOptions() { return this.options; @@ -173,30 +170,11 @@ CharSequence getPassword() { return this.password; } - int getPort() { - return this.port; - } - @Nullable String getSchema() { return this.schema; } - @Nullable - String getSocket() { - return this.socket; - } - - String getRequiredSocket() { - - String socket = getSocket(); - - if (socket == null || socket.isEmpty()) { - throw new IllegalStateException("Connection is configured to use host and port connections and not for socket usage"); - } - - return socket; - } String getUsername() { return this.username; @@ -210,10 +188,6 @@ boolean isForceBinary() { return this.forceBinary; } - boolean isUseSocket() { - return getSocket() != null; - } - SSLConfig getSslConfig() { return this.sslConfig; } @@ -255,21 +229,19 @@ public static final class Builder { private boolean forceBinary = false; @Nullable - private String host; + private MultipleHostsConfiguration.Builder multipleHostsConfiguration; + + @Nullable + private SingleHostConfiguration.Builder singleHostConfiguration; private Map options; @Nullable private CharSequence password; - private int port = DEFAULT_PORT; - @Nullable private String schema; - @Nullable - private String socket; - @Nullable private String sslCert = null; @@ -325,21 +297,25 @@ public Builder autodetectExtensions(boolean autodetectExtensions) { * @return a configured {@link PostgresqlConnectionConfiguration} */ public PostgresqlConnectionConfiguration build() { - - if (this.host == null && this.socket == null) { - throw new IllegalArgumentException("host or socket must not be null"); + SingleHostConfiguration singleHostConfiguration = this.singleHostConfiguration != null + ? this.singleHostConfiguration.build() + : null; + MultipleHostsConfiguration multipleHostsConfiguration = this.multipleHostsConfiguration != null + ? this.multipleHostsConfiguration.build() + : null; + if (singleHostConfiguration == null && multipleHostsConfiguration == null) { + throw new IllegalArgumentException("Either multiple hosts configuration or single host configuration should be provided"); } - - if (this.host != null && this.socket != null) { - throw new IllegalArgumentException("Connection must be configured for either host/port or socket usage but not both"); + if (singleHostConfiguration != null && multipleHostsConfiguration != null) { + throw new IllegalArgumentException("Either multiple hosts configuration or single host configuration should be provided"); } if (this.username == null) { throw new IllegalArgumentException("username must not be null"); } - return new PostgresqlConnectionConfiguration(this.applicationName, this.autodetectExtensions, this.connectTimeout, this.database, this.extensions, this.forceBinary, this.host, - this.options, this.password, this.port, this.schema, this.socket, this.username, this.createSslConfig(), this.preparedStatementCacheQueries); + return new PostgresqlConnectionConfiguration(this.applicationName, this.autodetectExtensions, this.connectTimeout, this.database, this.extensions, this.forceBinary, this.options, + this.password, this.schema, this.username, this.createSslConfig(), singleHostConfiguration, multipleHostsConfiguration, this.preparedStatementCacheQueries); } /** @@ -413,7 +389,11 @@ public Builder forceBinary(boolean forceBinary) { * @throws IllegalArgumentException if {@code host} is {@code null} */ public Builder host(String host) { - this.host = Assert.requireNonNull(host, "host must not be null"); + Assert.requireNonNull(host, "host must not be null"); + if (this.singleHostConfiguration == null) { + this.singleHostConfiguration = SingleHostConfiguration.builder(); + } + this.singleHostConfiguration.host(host); return this; } @@ -451,17 +431,6 @@ public Builder password(@Nullable CharSequence password) { return this; } - /** - * Configure the port. Defaults to {@code 5432}. - * - * @param port the port - * @return this {@link Builder} - */ - public Builder port(int port) { - this.port = port; - return this; - } - /** * Configure the schema. * @@ -474,15 +443,16 @@ public Builder schema(@Nullable String schema) { } /** - * Configure the unix domain socket to connect to. + * Configure the port. Defaults to {@code 5432}. * - * @param socket the socket path + * @param port the port * @return this {@link Builder} - * @throws IllegalArgumentException if {@code socket} is {@code null} */ - public Builder socket(String socket) { - this.socket = Assert.requireNonNull(socket, "host must not be null"); - sslMode(SSLMode.DISABLE); + public Builder port(int port) { + if (this.singleHostConfiguration == null) { + this.singleHostConfiguration = SingleHostConfiguration.builder(); + } + this.singleHostConfiguration.port(port); return this; } @@ -591,34 +561,130 @@ public Builder preparedStatementCacheQueries(int preparedStatementCacheQueries) return this; } + /** + * Configure the unix domain socket to connect to. + * + * @param socket the socket path + * @return this {@link Builder} + * @throws IllegalArgumentException if {@code socket} is {@code null} + */ + public Builder socket(String socket) { + Assert.requireNonNull(socket, "host must not be null"); + if (this.singleHostConfiguration == null) { + this.singleHostConfiguration = SingleHostConfiguration.builder(); + } + this.singleHostConfiguration.socket(socket); + + sslMode(SSLMode.DISABLE); + return this; + } + + /** + * Allows opening connections to only servers with required state, the allowed values are any, master, slave, secondary, preferSlave and preferSecondary. + * The master/secondary distinction is currently done by observing if the server allows writes. + * The value preferSecondary tries to connect to secondary if any are available, otherwise allows falls back to connecting also to master. + * Default value is any. + * + * @param targetServerType target server type + * @return this {@link Builder} + * @throws IllegalArgumentException if {@code targetServerType} is {@code null} + */ + public Builder targetServerType(TargetServerType targetServerType) { + if (this.multipleHostsConfiguration == null) { + this.multipleHostsConfiguration = MultipleHostsConfiguration.builder(); + } + this.multipleHostsConfiguration.targetServerType(targetServerType); + return this; + } + + /** + * Controls how long in milliseconds the knowledge about a host state is cached connection factory. The default value is 10000 milliseconds. + * + * @param hostRecheckTime host recheck time in milliseconds + * @return this {@link Builder} + */ + public Builder hostRecheckTime(@Nullable Duration hostRecheckTime) { + if (this.multipleHostsConfiguration == null) { + this.multipleHostsConfiguration = MultipleHostsConfiguration.builder(); + } + this.multipleHostsConfiguration.hostRecheckTime(hostRecheckTime); + return this; + } + + /** + * In default mode (disabled) hosts are connected in the given order. If enabled hosts are chosen randomly from the set of suitable candidates. + * + * @param loadBalanceHosts is load balance mode enabled + * @return this {@link Builder} + */ + public Builder loadBalanceHosts(boolean loadBalanceHosts) { + if (this.multipleHostsConfiguration == null) { + this.multipleHostsConfiguration = MultipleHostsConfiguration.builder(); + } + this.multipleHostsConfiguration.loadBalanceHosts(loadBalanceHosts); + return this; + } + + /** + * Add host with default port to hosts list. + * + * @param host the host + * @return this {@link Builder} + * @throws IllegalArgumentException if {@code host} is {@code null} + */ + public Builder addHost(String host) { + Assert.requireNonNull(host, "host must not be null"); + if (this.multipleHostsConfiguration == null) { + this.multipleHostsConfiguration = MultipleHostsConfiguration.builder(); + } + this.multipleHostsConfiguration.addHost(host); + return this; + } + + /** + * Add host to hosts list. + * + * @param host the host + * @param port the port + * @return this {@link Builder} + * @throws IllegalArgumentException if {@code host} is {@code null} + */ + public Builder addHost(String host, int port) { + Assert.requireNonNull(host, "host must not be null"); + if (this.multipleHostsConfiguration == null) { + this.multipleHostsConfiguration = MultipleHostsConfiguration.builder(); + } + this.multipleHostsConfiguration.addHost(host, port); + return this; + } + @Override public String toString() { return "Builder{" + - "applicationName='" + this.applicationName + '\'' + - ", autodetectExtensions='" + this.autodetectExtensions + '\'' + - ", connectTimeout='" + this.connectTimeout + '\'' + - ", database='" + this.database + '\'' + - ", extensions='" + this.extensions + '\'' + - ", forceBinary='" + this.forceBinary + '\'' + - ", host='" + this.host + '\'' + - ", parameters='" + this.options + '\'' + - ", password='" + obfuscate(this.password != null ? this.password.length() : 0) + '\'' + - ", port=" + this.port + - ", schema='" + this.schema + '\'' + - ", username='" + this.username + '\'' + - ", socket='" + this.socket + '\'' + - ", sslContextBuilderCustomizer='" + this.sslContextBuilderCustomizer + '\'' + - ", sslMode='" + this.sslMode + '\'' + - ", sslRootCert='" + this.sslRootCert + '\'' + - ", sslCert='" + this.sslCert + '\'' + - ", sslKey='" + this.sslKey + '\'' + - ", sslHostnameVerifier='" + this.sslHostnameVerifier + '\'' + - ", preparedStatementCacheQueries='" + this.preparedStatementCacheQueries + '\'' + - '}'; + "applicationName='" + this.applicationName + '\'' + + ", singleHostConfiguration='" + this.singleHostConfiguration + '\'' + + ", multipleHostsConfiguration='" + this.multipleHostsConfiguration + '\'' + + ", autodetectExtensions='" + this.autodetectExtensions + '\'' + + ", connectTimeout='" + this.connectTimeout + '\'' + + ", database='" + this.database + '\'' + + ", extensions='" + this.extensions + '\'' + + ", forceBinary='" + this.forceBinary + '\'' + + ", parameters='" + this.options + '\'' + + ", password='" + obfuscate(this.password != null ? this.password.length() : 0) + '\'' + + ", schema='" + this.schema + '\'' + + ", username='" + this.username + '\'' + + ", sslContextBuilderCustomizer='" + this.sslContextBuilderCustomizer + '\'' + + ", sslMode='" + this.sslMode + '\'' + + ", sslRootCert='" + this.sslRootCert + '\'' + + ", sslCert='" + this.sslCert + '\'' + + ", sslKey='" + this.sslKey + '\'' + + ", sslHostnameVerifier='" + this.sslHostnameVerifier + '\'' + + ", preparedStatementCacheQueries='" + this.preparedStatementCacheQueries + '\'' + + '}'; } private SSLConfig createSslConfig() { - if (this.socket != null || this.sslMode == SSLMode.DISABLE) { + if (this.singleHostConfiguration != null && this.singleHostConfiguration.getSocket() != null || this.sslMode == SSLMode.DISABLE) { return SSLConfig.disabled(); } @@ -669,9 +735,9 @@ private Supplier createSslProvider() { } return () -> SslProvider.builder() - .sslContext(this.sslContextBuilderCustomizer.apply(sslContextBuilder)) - .defaultConfiguration(TCP) - .build(); + .sslContext(this.sslContextBuilderCustomizer.apply(sslContextBuilder)) + .defaultConfiguration(TCP) + .build(); } } } diff --git a/src/main/java/io/r2dbc/postgresql/PostgresqlConnectionFactory.java b/src/main/java/io/r2dbc/postgresql/PostgresqlConnectionFactory.java index 0009ef81..31b80b18 100644 --- a/src/main/java/io/r2dbc/postgresql/PostgresqlConnectionFactory.java +++ b/src/main/java/io/r2dbc/postgresql/PostgresqlConnectionFactory.java @@ -17,18 +17,10 @@ package io.r2dbc.postgresql; import io.netty.buffer.ByteBufAllocator; -import io.netty.channel.unix.DomainSocketAddress; -import io.r2dbc.postgresql.authentication.AuthenticationHandler; -import io.r2dbc.postgresql.authentication.PasswordAuthenticationHandler; -import io.r2dbc.postgresql.authentication.SASLAuthenticationHandler; import io.r2dbc.postgresql.client.Client; import io.r2dbc.postgresql.client.ReactorNettyClient; -import io.r2dbc.postgresql.client.SSLConfig; -import io.r2dbc.postgresql.client.SSLMode; -import io.r2dbc.postgresql.client.StartupMessageFlow; import io.r2dbc.postgresql.codec.DefaultCodecs; import io.r2dbc.postgresql.extension.CodecRegistrar; -import io.r2dbc.postgresql.message.backend.AuthenticationMessage; import io.r2dbc.postgresql.util.Assert; import io.r2dbc.spi.ConnectionFactory; import io.r2dbc.spi.ConnectionFactoryMetadata; @@ -41,30 +33,28 @@ import reactor.netty.resources.ConnectionProvider; import reactor.util.annotation.Nullable; -import java.net.InetSocketAddress; -import java.net.SocketAddress; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Locale; import java.util.Map; -import java.util.function.Function; -import java.util.function.Predicate; /** * An implementation of {@link ConnectionFactory} for creating connections to a PostgreSQL database. */ public final class PostgresqlConnectionFactory implements ConnectionFactory { - private static final String REPLICATION_OPTION = "replication"; + private static final ClientSupplier DEFAULT_CLIENT_SUPPLIER = (endpoint, connectTimeout, sslConfig) -> + ReactorNettyClient.connect(ConnectionProvider.newConnection(), endpoint, connectTimeout, sslConfig) + .cast(Client.class); private static final String REPLICATION_DATABASE = "database"; - private final Function> clientFactory; + private static final String REPLICATION_OPTION = "replication"; - private final PostgresqlConnectionConfiguration configuration; + private final ClientFactory clientFactory; - private final SocketAddress endpoint; + private final PostgresqlConnectionConfiguration configuration; private final Extensions extensions; @@ -76,41 +66,16 @@ public final class PostgresqlConnectionFactory implements ConnectionFactory { */ public PostgresqlConnectionFactory(PostgresqlConnectionConfiguration configuration) { this.configuration = Assert.requireNonNull(configuration, "configuration must not be null"); - this.endpoint = createSocketAddress(configuration); - this.clientFactory = sslConfig -> ReactorNettyClient.connect(ConnectionProvider.newConnection(), this.endpoint, configuration.getConnectTimeout(), sslConfig).cast(Client.class); + this.clientFactory = ClientFactory.getFactory(configuration, DEFAULT_CLIENT_SUPPLIER); this.extensions = getExtensions(configuration); } - PostgresqlConnectionFactory(Function> clientFactory, PostgresqlConnectionConfiguration configuration) { + PostgresqlConnectionFactory(ClientFactory clientFactory, PostgresqlConnectionConfiguration configuration) { this.configuration = Assert.requireNonNull(configuration, "configuration must not be null"); - this.endpoint = createSocketAddress(configuration); this.clientFactory = Assert.requireNonNull(clientFactory, "clientFactory must not be null"); this.extensions = getExtensions(configuration); } - private static SocketAddress createSocketAddress(PostgresqlConnectionConfiguration configuration) { - - if (!configuration.isUseSocket()) { - return InetSocketAddress.createUnresolved(configuration.getRequiredHost(), configuration.getPort()); - } - - if (configuration.isUseSocket()) { - return new DomainSocketAddress(configuration.getRequiredSocket()); - } - - throw new IllegalArgumentException("Cannot create SocketAddress for " + configuration); - } - - private static Extensions getExtensions(PostgresqlConnectionConfiguration configuration) { - Extensions extensions = Extensions.from(configuration.getExtensions()); - - if (configuration.isAutodetectExtensions()) { - extensions = extensions.mergeWith(Extensions.autodetect()); - } - - return extensions; - } - @Override public Mono create() { @@ -121,6 +86,11 @@ public Mono create() { return doCreateConnection(false, this.configuration.getOptions()).cast(io.r2dbc.postgresql.api.PostgresqlConnection.class); } + @Override + public ConnectionFactoryMetadata getMetadata() { + return PostgresqlConnectionFactoryMetadata.INSTANCE; + } + /** * Creates a new {@link io.r2dbc.postgresql.api.PostgresqlReplicationConnection} for interaction with replication streams. * @@ -140,119 +110,65 @@ public Mono replication return doCreateConnection(true, options).map(DefaultPostgresqlReplicationConnection::new); } - private Mono doCreateConnection(boolean forReplication, @Nullable Map options) { - - SSLConfig sslConfig = this.configuration.getSslConfig(); - Predicate isAuthSpecificationError = e -> e instanceof ExceptionFactory.PostgresqlAuthenticationFailure; - return this.tryConnectWithConfig(sslConfig, options) - .onErrorResume( - isAuthSpecificationError.and(e -> sslConfig.getSslMode() == SSLMode.ALLOW), - e -> this.tryConnectWithConfig(sslConfig.mutateMode(SSLMode.REQUIRE), options) - .onErrorResume(sslAuthError -> { - e.addSuppressed(sslAuthError); - return Mono.error(e); - }) - ) - .onErrorResume( - isAuthSpecificationError.and(e -> sslConfig.getSslMode() == SSLMode.PREFER), - e -> this.tryConnectWithConfig(sslConfig.mutateMode(SSLMode.DISABLE), options) - .onErrorResume(sslAuthError -> { - e.addSuppressed(sslAuthError); - return Mono.error(e); - }) - ) - .flatMap(client -> { - - DefaultCodecs codecs = new DefaultCodecs(client.getByteBufAllocator()); - StatementCache statementCache = StatementCache.fromPreparedStatementCacheQueries(client, this.configuration.getPreparedStatementCacheQueries()); - - // early connection object to retrieve initialization details - PostgresqlConnection earlyConnection = new PostgresqlConnection(client, codecs, DefaultPortalNameSupplier.INSTANCE, statementCache, IsolationLevel.READ_COMMITTED, - this.configuration.isForceBinary()); - - Mono isolationLevelMono = Mono.just(IsolationLevel.READ_COMMITTED); - if (!forReplication) { - isolationLevelMono = getIsolationLevel(earlyConnection); - } - return isolationLevelMono - // actual connection to be used - .map(isolationLevel -> new PostgresqlConnection(client, codecs, DefaultPortalNameSupplier.INSTANCE, statementCache, isolationLevel, this.configuration.isForceBinary())) - .delayUntil(connection -> { - return prepareConnection(connection, client.getByteBufAllocator(), codecs); - }) - .onErrorResume(throwable -> this.closeWithError(client, throwable)); - }).onErrorMap(this::cannotConnect); - } - - private boolean isReplicationConnection() { - Map options = this.configuration.getOptions(); - return options != null && REPLICATION_DATABASE.equalsIgnoreCase(options.get(REPLICATION_OPTION)); + @Override + public String toString() { + return "PostgresqlConnectionFactory{" + + "clientFactory=" + this.clientFactory + + ", configuration=" + this.configuration + + ", extensions=" + this.extensions + + '}'; } - private Mono tryConnectWithConfig(SSLConfig sslConfig, @Nullable Map options) { - return this.clientFactory.apply(sslConfig) - .delayUntil(client -> StartupMessageFlow - .exchange(this.configuration.getApplicationName(), this::getAuthenticationHandler, client, this.configuration.getDatabase(), this.configuration.getUsername(), - options) - .handle(ExceptionFactory.INSTANCE::handleErrorResponse)) - .cast(Client.class); + PostgresqlConnectionConfiguration getConfiguration() { + return this.configuration; } - private Publisher prepareConnection(PostgresqlConnection connection, ByteBufAllocator byteBufAllocator, DefaultCodecs codecs) { - - List> publishers = new ArrayList<>(); - publishers.add(setSchema(connection)); - - this.extensions.forEach(CodecRegistrar.class, it -> { - publishers.add(it.register(connection, byteBufAllocator, codecs)); - }); + private static Extensions getExtensions(PostgresqlConnectionConfiguration configuration) { + Extensions extensions = Extensions.from(configuration.getExtensions()); - return Flux.concat(publishers).then(); - } + if (configuration.isAutodetectExtensions()) { + extensions = extensions.mergeWith(Extensions.autodetect()); + } - private Mono closeWithError(Client client, Throwable throwable) { - return client.close().then(Mono.error(throwable)); + return extensions; } private Throwable cannotConnect(Throwable throwable) { - if (throwable instanceof R2dbcException) { return throwable; } return new PostgresConnectionException( - String.format("Cannot connect to %s", this.endpoint), throwable + String.format("Cannot connect to %s", "TODO"), throwable // TODO ); } - @Override - public ConnectionFactoryMetadata getMetadata() { - return PostgresqlConnectionFactoryMetadata.INSTANCE; + private Mono closeWithError(Client client, Throwable throwable) { + return client.close().then(Mono.error(throwable)); } - PostgresqlConnectionConfiguration getConfiguration() { - return this.configuration; - } + private Mono doCreateConnection(boolean forReplication, @Nullable Map options) { + return clientFactory.create(options) + .flatMap(client -> { + DefaultCodecs codecs = new DefaultCodecs(client.getByteBufAllocator()); + StatementCache statementCache = StatementCache.fromPreparedStatementCacheQueries(client, this.configuration + .getPreparedStatementCacheQueries()); - @Override - public String toString() { - return "PostgresqlConnectionFactory{" + - "clientFactory=" + this.clientFactory + - ", configuration=" + this.configuration + - ", extensions=" + this.extensions + - '}'; - } + // early connection object to retrieve initialization details + PostgresqlConnection earlyConnection = new PostgresqlConnection(client, codecs, DefaultPortalNameSupplier.INSTANCE, statementCache, IsolationLevel.READ_COMMITTED, + this.configuration.isForceBinary()); - private AuthenticationHandler getAuthenticationHandler(AuthenticationMessage message) { - if (PasswordAuthenticationHandler.supports(message)) { - CharSequence password = Assert.requireNonNull(this.configuration.getPassword(), "Password must not be null"); - return new PasswordAuthenticationHandler(password, this.configuration.getUsername()); - } else if (SASLAuthenticationHandler.supports(message)) { - CharSequence password = Assert.requireNonNull(this.configuration.getPassword(), "Password must not be null"); - return new SASLAuthenticationHandler(password, this.configuration.getUsername()); - } else { - throw new IllegalStateException(String.format("Unable to provide AuthenticationHandler capable of handling %s", message)); - } + Mono isolationLevelMono = Mono.just(IsolationLevel.READ_COMMITTED); + if (!forReplication) { + isolationLevelMono = getIsolationLevel(earlyConnection); + } + return isolationLevelMono + // actual connection to be used + .map(isolationLevel -> new PostgresqlConnection(client, codecs, DefaultPortalNameSupplier.INSTANCE, statementCache, isolationLevel, this.configuration + .isForceBinary())) + .delayUntil(connection -> prepareConnection(connection, client.getByteBufAllocator(), codecs)) + .onErrorResume(throwable -> this.closeWithError(client, throwable)); + }).onErrorMap(this::cannotConnect); } private Mono getIsolationLevel(io.r2dbc.postgresql.api.PostgresqlConnection connection) { @@ -269,6 +185,22 @@ private Mono getIsolationLevel(io.r2dbc.postgresql.api.Postgresq })).defaultIfEmpty(IsolationLevel.READ_COMMITTED).last(); } + private boolean isReplicationConnection() { + Map options = this.configuration.getOptions(); + return options != null && REPLICATION_DATABASE.equalsIgnoreCase(options.get(REPLICATION_OPTION)); + } + + private Publisher prepareConnection(PostgresqlConnection connection, ByteBufAllocator byteBufAllocator, DefaultCodecs codecs) { + List> publishers = new ArrayList<>(); + publishers.add(setSchema(connection)); + + this.extensions.forEach(CodecRegistrar.class, it -> { + publishers.add(it.register(connection, byteBufAllocator, codecs)); + }); + + return Flux.concat(publishers).then(); + } + private Mono setSchema(PostgresqlConnection connection) { if (this.configuration.getSchema() == null) { return Mono.empty(); diff --git a/src/main/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryProvider.java b/src/main/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryProvider.java index cad330c5..70a02d01 100644 --- a/src/main/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryProvider.java +++ b/src/main/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryProvider.java @@ -25,6 +25,7 @@ import io.r2dbc.spi.Option; import javax.net.ssl.HostnameVerifier; +import java.time.Duration; import java.util.Map; import java.util.function.Function; @@ -34,6 +35,7 @@ import static io.r2dbc.spi.ConnectionFactoryOptions.HOST; import static io.r2dbc.spi.ConnectionFactoryOptions.PASSWORD; import static io.r2dbc.spi.ConnectionFactoryOptions.PORT; +import static io.r2dbc.spi.ConnectionFactoryOptions.PROTOCOL; import static io.r2dbc.spi.ConnectionFactoryOptions.SSL; import static io.r2dbc.spi.ConnectionFactoryOptions.USER; @@ -52,21 +54,47 @@ public final class PostgresqlConnectionFactoryProvider implements ConnectionFact */ public static final Option AUTODETECT_EXTENSIONS = Option.valueOf("autodetectExtensions"); + /** + * Failover driver protocol. + */ + public static final String FAILOVER_PROTOCOL = "failover"; + /** * Force binary transfer. */ public static final Option FORCE_BINARY = Option.valueOf("forceBinary"); /** - * Driver option value. + * Host status recheck time im ms. */ - public static final String POSTGRESQL_DRIVER = "postgresql"; + public static final Option HOST_RECHECK_TIME = Option.valueOf("hostRecheckTime"); /** * Legacy driver option value. */ public static final String LEGACY_POSTGRESQL_DRIVER = "postgres"; + /** + * Load balance hosts. + */ + public static final Option LOAD_BALANCE_HOSTS = Option.valueOf("loadBalanceHosts"); + + /** + * Connection options which are applied once after the connection has been created. + */ + public static final Option> OPTIONS = Option.valueOf("options"); + + /** + * Driver option value. + */ + public static final String POSTGRESQL_DRIVER = "postgresql"; + + /** + * Determine the number of queries that are cached in each connection. + * The default is {@code -1}, meaning there's no limit. The value of {@code 0} disables the cache. Any other value specifies the cache size. + */ + public static final Option PREPARED_STATEMENT_CACHE_QUERIES = Option.valueOf("preparedStatementCacheQueries"); + /** * Schema. */ @@ -78,14 +106,14 @@ public final class PostgresqlConnectionFactoryProvider implements ConnectionFact public static final Option SOCKET = Option.valueOf("socket"); /** - * Customizer {@link Function} for {@link SslContextBuilder}. + * Full path for the certificate file. */ - public static final Option> SSL_CONTEXT_BUILDER_CUSTOMIZER = Option.valueOf("sslContextBuilderCustomizer"); + public static final Option SSL_CERT = Option.valueOf("sslCert"); /** - * Full path for the certificate file. + * Customizer {@link Function} for {@link SslContextBuilder}. */ - public static final Option SSL_CERT = Option.valueOf("sslCert"); + public static final Option> SSL_CONTEXT_BUILDER_CUSTOMIZER = Option.valueOf("sslContextBuilderCustomizer"); /** * Class name of hostname verifier. Defaults to {@link DefaultHostnameVerifier}. @@ -113,15 +141,9 @@ public final class PostgresqlConnectionFactoryProvider implements ConnectionFact public static final Option SSL_ROOT_CERT = Option.valueOf("sslRootCert"); /** - * Determine the number of queries that are cached in each connection. - * The default is {@code -1}, meaning there's no limit. The value of {@code 0} disables the cache. Any other value specifies the cache size. - */ - public static final Option PREPARED_STATEMENT_CACHE_QUERIES = Option.valueOf("preparedStatementCacheQueries"); - - /** - * Connection options which are applied once after the connection has been created. + * Target server type. Allowed values: any, master, secondary, preferSecondary. */ - public static final Option> OPTIONS = Option.valueOf("options"); + public static final Option TARGET_SERVER_TYPE = Option.valueOf("targetServerType"); /** * Returns a new {@link PostgresqlConnectionConfiguration.Builder} configured with the given {@link ConnectionFactoryOptions}. @@ -152,80 +174,14 @@ public boolean supports(ConnectionFactoryOptions connectionFactoryOptions) { return driver != null && (driver.equals(POSTGRESQL_DRIVER) || driver.equals(LEGACY_POSTGRESQL_DRIVER)); } - private static int convertToInt(Object value) { - return value instanceof Integer ? (int) value : Integer.parseInt(value.toString()); - } - - private static void setupSsl(PostgresqlConnectionConfiguration.Builder builder, ConnectionFactoryOptions connectionFactoryOptions) { - Boolean ssl = connectionFactoryOptions.getValue(SSL); - if (ssl != null && ssl) { - builder.enableSsl(); - } - - Object sslMode = connectionFactoryOptions.getValue(SSL_MODE); - if (sslMode != null) { - if (sslMode instanceof String) { - builder.sslMode(SSLMode.fromValue(sslMode.toString())); - } else { - builder.sslMode((SSLMode) sslMode); - } - } - - String sslRootCert = connectionFactoryOptions.getValue(SSL_ROOT_CERT); - if (sslRootCert != null) { - builder.sslRootCert(sslRootCert); - } - - String sslCert = connectionFactoryOptions.getValue(SSL_CERT); - if (sslCert != null) { - builder.sslCert(sslCert); - } - - String sslKey = connectionFactoryOptions.getValue(SSL_KEY); - if (sslKey != null) { - builder.sslKey(sslKey); - } - - String sslPassword = connectionFactoryOptions.getValue(SSL_PASSWORD); - if (sslPassword != null) { - builder.sslPassword(sslPassword); - } - - if (connectionFactoryOptions.hasOption(SSL_CONTEXT_BUILDER_CUSTOMIZER)) { - builder.sslContextBuilderCustomizer(connectionFactoryOptions.getRequiredValue(SSL_CONTEXT_BUILDER_CUSTOMIZER)); - } - - setSslHostnameVerifier(builder, connectionFactoryOptions); - } - - private static void setSslHostnameVerifier(PostgresqlConnectionConfiguration.Builder builder, ConnectionFactoryOptions connectionFactoryOptions) { - Object sslHostnameVerifier = connectionFactoryOptions.getValue(SSL_HOSTNAME_VERIFIER); - if (sslHostnameVerifier != null) { - - if (sslHostnameVerifier instanceof String) { - - try { - Class verifierClass = Class.forName((String) sslHostnameVerifier); - Object verifier = verifierClass.getConstructor().newInstance(); - - builder.sslHostnameVerifier((HostnameVerifier) verifier); - } catch (ReflectiveOperationException e) { - throw new IllegalStateException("Cannot instantiate " + sslHostnameVerifier, e); - } - } else { - builder.sslHostnameVerifier((HostnameVerifier) sslHostnameVerifier); - } - } - } - - private static boolean isUsingTcp(ConnectionFactoryOptions connectionFactoryOptions) { - return !connectionFactoryOptions.hasOption(SOCKET); - } - private static boolean convertToBoolean(Object value) { return value instanceof Boolean ? (boolean) value : Boolean.parseBoolean(value.toString()); } + private static int convertToInt(Object value) { + return value instanceof Integer ? (int) value : Integer.parseInt(value.toString()); + } + /** * Configure the builder with the given {@link ConnectionFactoryOptions}. * @@ -276,12 +232,124 @@ private static PostgresqlConnectionConfiguration.Builder fromConnectionFactoryOp } if (isUsingTcp(connectionFactoryOptions)) { - builder.host(connectionFactoryOptions.getRequiredValue(HOST)); setupSsl(builder, connectionFactoryOptions); + setupFailover(builder, connectionFactoryOptions); } else { builder.socket(connectionFactoryOptions.getRequiredValue(SOCKET)); } return builder; } + + private static boolean isUsingTcp(ConnectionFactoryOptions connectionFactoryOptions) { + return !connectionFactoryOptions.hasOption(SOCKET); + } + + private static void setSslHostnameVerifier(PostgresqlConnectionConfiguration.Builder builder, ConnectionFactoryOptions connectionFactoryOptions) { + Object sslHostnameVerifier = connectionFactoryOptions.getValue(SSL_HOSTNAME_VERIFIER); + if (sslHostnameVerifier != null) { + + if (sslHostnameVerifier instanceof String) { + + try { + Class verifierClass = Class.forName((String) sslHostnameVerifier); + Object verifier = verifierClass.getConstructor().newInstance(); + + builder.sslHostnameVerifier((HostnameVerifier) verifier); + } catch (ReflectiveOperationException e) { + throw new IllegalStateException("Cannot instantiate " + sslHostnameVerifier, e); + } + } else { + builder.sslHostnameVerifier((HostnameVerifier) sslHostnameVerifier); + } + } + } + + private static void setupFailover(PostgresqlConnectionConfiguration.Builder builder, ConnectionFactoryOptions connectionFactoryOptions) { + if (FAILOVER_PROTOCOL.equals(connectionFactoryOptions.getValue(PROTOCOL))) { + if (connectionFactoryOptions.hasOption(HOST_RECHECK_TIME)) { + Duration hostRecheckTime = Duration.ofMillis(connectionFactoryOptions.getRequiredValue(HOST_RECHECK_TIME)); + builder.hostRecheckTime(hostRecheckTime); + } + if (connectionFactoryOptions.hasOption(LOAD_BALANCE_HOSTS)) { + Object loadBalanceHosts = connectionFactoryOptions.getRequiredValue(LOAD_BALANCE_HOSTS); + if (loadBalanceHosts instanceof Boolean) { + builder.loadBalanceHosts((Boolean) loadBalanceHosts); + } else { + builder.loadBalanceHosts(Boolean.parseBoolean(loadBalanceHosts.toString())); + } + } + if (connectionFactoryOptions.hasOption(TARGET_SERVER_TYPE)) { + Object targetServerType = connectionFactoryOptions.getRequiredValue(TARGET_SERVER_TYPE); + if (targetServerType instanceof TargetServerType) { + builder.targetServerType((TargetServerType) targetServerType); + } else { + builder.targetServerType(TargetServerType.fromValue(targetServerType.toString())); + } + } + String hosts = connectionFactoryOptions.getRequiredValue(HOST); + String[] hostsArray = hosts.split(","); + for (String host : hostsArray) { + String[] hostParts = host.split(":"); + if (hostParts.length == 1) { + builder.addHost(hostParts[0]); + } else { + int port = Integer.parseInt(hostParts[1]); + builder.addHost(hostParts[0], port); + } + } + } else { + if (connectionFactoryOptions.hasOption(SOCKET)) { + builder.socket(connectionFactoryOptions.getRequiredValue(SOCKET)); + } else { + builder.host(connectionFactoryOptions.getRequiredValue(HOST)); + } + Integer port = connectionFactoryOptions.getValue(PORT); + if (port != null) { + builder.port(port); + } + } + } + + private static void setupSsl(PostgresqlConnectionConfiguration.Builder builder, ConnectionFactoryOptions connectionFactoryOptions) { + Boolean ssl = connectionFactoryOptions.getValue(SSL); + if (ssl != null && ssl) { + builder.enableSsl(); + } + + Object sslMode = connectionFactoryOptions.getValue(SSL_MODE); + if (sslMode != null) { + if (sslMode instanceof String) { + builder.sslMode(SSLMode.fromValue(sslMode.toString())); + } else { + builder.sslMode((SSLMode) sslMode); + } + } + + String sslRootCert = connectionFactoryOptions.getValue(SSL_ROOT_CERT); + if (sslRootCert != null) { + builder.sslRootCert(sslRootCert); + } + + String sslCert = connectionFactoryOptions.getValue(SSL_CERT); + if (sslCert != null) { + builder.sslCert(sslCert); + } + + String sslKey = connectionFactoryOptions.getValue(SSL_KEY); + if (sslKey != null) { + builder.sslKey(sslKey); + } + + String sslPassword = connectionFactoryOptions.getValue(SSL_PASSWORD); + if (sslPassword != null) { + builder.sslPassword(sslPassword); + } + + if (connectionFactoryOptions.hasOption(SSL_CONTEXT_BUILDER_CUSTOMIZER)) { + builder.sslContextBuilderCustomizer(connectionFactoryOptions.getRequiredValue(SSL_CONTEXT_BUILDER_CUSTOMIZER)); + } + + setSslHostnameVerifier(builder, connectionFactoryOptions); + } } diff --git a/src/main/java/io/r2dbc/postgresql/SingleHostClientFactory.java b/src/main/java/io/r2dbc/postgresql/SingleHostClientFactory.java new file mode 100644 index 00000000..e0a7f288 --- /dev/null +++ b/src/main/java/io/r2dbc/postgresql/SingleHostClientFactory.java @@ -0,0 +1,40 @@ +package io.r2dbc.postgresql; + +import io.netty.channel.unix.DomainSocketAddress; +import io.r2dbc.postgresql.client.Client; +import io.r2dbc.postgresql.client.SingleHostConfiguration; +import reactor.core.publisher.Mono; +import reactor.util.annotation.Nullable; + +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.util.Map; + +class SingleHostClientFactory extends ClientFactoryBase { + + private final SocketAddress endpoint; + + protected SingleHostClientFactory(PostgresqlConnectionConfiguration configuration, ClientSupplier clientSupplier) { + super(configuration, clientSupplier); + this.endpoint = SingleHostClientFactory.createSocketAddress(configuration.getSingleHostConfiguration()); + } + + @Override + public Mono create(@Nullable Map options) { + return this.tryConnectToEndpoint(this.endpoint, options); + } + + + protected static SocketAddress createSocketAddress(SingleHostConfiguration configuration) { + if (!configuration.isUseSocket()) { + return InetSocketAddress.createUnresolved(configuration.getRequiredHost(), configuration.getPort()); + } + + if (configuration.isUseSocket()) { + return new DomainSocketAddress(configuration.getRequiredSocket()); + } + + throw new IllegalArgumentException("Cannot create SocketAddress for " + configuration); + } + +} diff --git a/src/main/java/io/r2dbc/postgresql/TargetServerType.java b/src/main/java/io/r2dbc/postgresql/TargetServerType.java new file mode 100644 index 00000000..27b3cda2 --- /dev/null +++ b/src/main/java/io/r2dbc/postgresql/TargetServerType.java @@ -0,0 +1,53 @@ +package io.r2dbc.postgresql; + +import javax.annotation.Nullable; + +public enum TargetServerType { + ANY("any") { + @Override + public boolean allowStatus(MultipleHostsClientFactory.HostStatus hostStatus) { + return hostStatus != MultipleHostsClientFactory.HostStatus.CONNECT_FAIL; + } + }, + MASTER("master") { + @Override + public boolean allowStatus(MultipleHostsClientFactory.HostStatus hostStatus) { + return hostStatus == MultipleHostsClientFactory.HostStatus.PRIMARY || hostStatus == MultipleHostsClientFactory.HostStatus.CONNECT_OK; + } + }, + SECONDARY("secondary") { + @Override + public boolean allowStatus(MultipleHostsClientFactory.HostStatus hostStatus) { + return hostStatus == MultipleHostsClientFactory.HostStatus.STANDBY || hostStatus == MultipleHostsClientFactory.HostStatus.CONNECT_OK; + } + }, + PREFER_SECONDARY("preferSecondary") { + @Override + public boolean allowStatus(MultipleHostsClientFactory.HostStatus hostStatus) { + return hostStatus == MultipleHostsClientFactory.HostStatus.STANDBY || hostStatus == MultipleHostsClientFactory.HostStatus.CONNECT_OK; + } + }; + + private final String value; + + TargetServerType(String value) { + this.value = value; + } + + @Nullable + public static TargetServerType fromValue(String value) { + String fixedValue = value.replace("lave", "econdary"); + for (TargetServerType type : values()) { + if (type.value.equals(fixedValue)) { + return type; + } + } + return null; + } + + public abstract boolean allowStatus(MultipleHostsClientFactory.HostStatus hostStatus); + + public String getValue() { + return value; + } +} diff --git a/src/main/java/io/r2dbc/postgresql/client/MultipleHostsConfiguration.java b/src/main/java/io/r2dbc/postgresql/client/MultipleHostsConfiguration.java new file mode 100644 index 00000000..82226e0c --- /dev/null +++ b/src/main/java/io/r2dbc/postgresql/client/MultipleHostsConfiguration.java @@ -0,0 +1,194 @@ +package io.r2dbc.postgresql.client; + +import io.r2dbc.postgresql.TargetServerType; +import io.r2dbc.postgresql.util.Assert; + +import java.time.Duration; +import java.util.ArrayList; +import java.util.List; + +import static io.r2dbc.postgresql.PostgresqlConnectionConfiguration.DEFAULT_PORT; + +public class MultipleHostsConfiguration { + + private final List hosts; + + private final Duration hostRecheckTime; + + private final boolean loadBalanceHosts; + + private final TargetServerType targetServerType; + + public MultipleHostsConfiguration(List hosts, Duration hostRecheckTime, boolean loadBalanceHosts, TargetServerType targetServerType) { + this.hosts = hosts; + this.hostRecheckTime = hostRecheckTime; + this.loadBalanceHosts = loadBalanceHosts; + this.targetServerType = targetServerType; + } + + public Duration getHostRecheckTime() { + return hostRecheckTime; + } + + public List getHosts() { + return hosts; + } + + public TargetServerType getTargetServerType() { + return targetServerType; + } + + public boolean isLoadBalanceHosts() { + return loadBalanceHosts; + } + + @Override + public String toString() { + return "MultipleHostsConfiguration{" + + "hosts=" + this.hosts + + ", hostRecheckTime=" + this.hostRecheckTime + + ", loadBalanceHosts=" + this.loadBalanceHosts + + ", targetServerType=" + this.targetServerType + + '}'; + } + + public static class ServerHost { + + private final String host; + + private final int port; + + public ServerHost(String host, int port) { + this.host = host; + this.port = port; + } + + public String getHost() { + return this.host; + } + + public int getPort() { + return this.port; + } + + @Override + public String toString() { + return "ServerHost{" + + "host='" + this.host + '\'' + + ", port=" + this.port + + '}'; + } + } + + /** + * Returns a new {@link Builder}. + * + * @return a new {@link Builder} + */ + public static Builder builder() { + return new Builder(); + } + + /** + * A builder for {@link MultipleHostsConfiguration} instances. + *

+ * This class is not threadsafe + */ + public static class Builder { + + private Duration hostRecheckTime = Duration.ofMillis(10000); + + private List hosts = new ArrayList<>(); + + private boolean loadBalanceHosts = false; + + private TargetServerType targetServerType = TargetServerType.ANY; + + /** + * Allows opening connections to only servers with required state, the allowed values are any, master, slave, secondary, preferSlave and preferSecondary. + * The master/secondary distinction is currently done by observing if the server allows writes. + * The value preferSecondary tries to connect to secondary if any are available, otherwise allows falls back to connecting also to master. + * Default value is any. + * + * @param targetServerType target server type + * @return this {@link Builder} + * @throws IllegalArgumentException if {@code targetServerType} is {@code null} + */ + public Builder targetServerType(TargetServerType targetServerType) { + this.targetServerType = Assert.requireNonNull(targetServerType, "targetServerType must not be null"); + return this; + } + + /** + * Controls how long in seconds the knowledge about a host state is cached connection factory. The default value is 10000 milliseconds. + * + * @param hostRecheckTime host recheck time in milliseconds + * @return this {@link Builder} + */ + public Builder hostRecheckTime(Duration hostRecheckTime) { + this.hostRecheckTime = hostRecheckTime; + return this; + } + + /** + * In default mode (disabled) hosts are connected in the given order. If enabled hosts are chosen randomly from the set of suitable candidates. + * + * @param loadBalanceHosts is load balance mode enabled + * @return this {@link Builder} + */ + public Builder loadBalanceHosts(boolean loadBalanceHosts) { + this.loadBalanceHosts = loadBalanceHosts; + return this; + } + + /** + * Add host with default port to hosts list. + * + * @param host the host + * @return this {@link Builder} + * @throws IllegalArgumentException if {@code host} is {@code null} + */ + public Builder addHost(String host) { + Assert.requireNonNull(host, "host must not be null"); + this.hosts.add(new ServerHost(host, DEFAULT_PORT)); + return this; + } + + /** + * Add host to hosts list. + * + * @param host the host + * @param port the port + * @return this {@link Builder} + * @throws IllegalArgumentException if {@code host} is {@code null} + */ + public Builder addHost(String host, int port) { + Assert.requireNonNull(host, "host must not be null"); + this.hosts.add(new ServerHost(host, port)); + return this; + } + + /** + * Returns a configured {@link MultipleHostsConfiguration}. + * + * @return a configured {@link MultipleHostsConfiguration} + */ + public MultipleHostsConfiguration build() { + if (this.hosts.isEmpty()) { + throw new IllegalArgumentException("At least one host should be provided"); + } + + return new MultipleHostsConfiguration(this.hosts, this.hostRecheckTime, this.loadBalanceHosts, this.targetServerType); + } + + @Override + public String toString() { + return "Builder{" + + "hostRecheckTime=" + this.hostRecheckTime + + ", hosts=" + this.hosts + + ", loadBalanceHosts=" + this.loadBalanceHosts + + ", targetServerType=" + this.targetServerType + + '}'; + } + } +} diff --git a/src/main/java/io/r2dbc/postgresql/client/SingleHostConfiguration.java b/src/main/java/io/r2dbc/postgresql/client/SingleHostConfiguration.java new file mode 100644 index 00000000..2578ce6d --- /dev/null +++ b/src/main/java/io/r2dbc/postgresql/client/SingleHostConfiguration.java @@ -0,0 +1,164 @@ +package io.r2dbc.postgresql.client; + +import io.r2dbc.postgresql.util.Assert; +import reactor.util.annotation.Nullable; + +import static io.r2dbc.postgresql.PostgresqlConnectionConfiguration.DEFAULT_PORT; + +public class SingleHostConfiguration { + + @Nullable + private final String host; + + private final int port; + + @Nullable + private final String socket; + + public SingleHostConfiguration(@Nullable String host, int port, @Nullable String socket) { + this.host = host; + this.port = port; + this.socket = socket; + } + + @Nullable + public String getHost() { + return host; + } + + public int getPort() { + return port; + } + + public String getRequiredHost() { + + String host = getHost(); + + if (host == null || host.isEmpty()) { + throw new IllegalStateException("Connection is configured for socket connections and not for host usage"); + } + + return host; + } + + public String getRequiredSocket() { + + String socket = getSocket(); + + if (socket == null || socket.isEmpty()) { + throw new IllegalStateException("Connection is configured to use host and port connections and not for socket usage"); + } + + return socket; + } + + @Nullable + public String getSocket() { + return socket; + } + + public boolean isUseSocket() { + return getSocket() != null; + } + + @Override + public String toString() { + return "SingleHostConfiguration{" + + "host='" + this.host + '\'' + + ", port=" + this.port + + ", socket='" + this.socket + '\'' + + '}'; + } + + /** + * Returns a new {@link Builder}. + * + * @return a new {@link Builder} + */ + public static Builder builder() { + return new Builder(); + } + + /** + * A builder for {@link SingleHostConfiguration} instances. + *

+ * This class is not threadsafe + */ + public static class Builder { + + @Nullable + private String host; + + private int port = DEFAULT_PORT; + + @Nullable + private String socket; + + /** + * Configure the host. + * + * @param host the host + * @return this {@link Builder} + * @throws IllegalArgumentException if {@code host} is {@code null} + */ + public Builder host(String host) { + this.host = Assert.requireNonNull(host, "host must not be null"); + return this; + } + + /** + * Configure the port. Defaults to {@code 5432}. + * + * @param port the port + * @return this {@link Builder} + */ + public Builder port(int port) { + this.port = port; + return this; + } + + /** + * Configure the unix domain socket to connect to. + * + * @param socket the socket path + * @return this {@link Builder} + * @throws IllegalArgumentException if {@code socket} is {@code null} + */ + public Builder socket(String socket) { + this.socket = Assert.requireNonNull(socket, "host must not be null"); + return this; + } + + /** + * Returns a configured {@link SingleHostConfiguration}. + * + * @return a configured {@link SingleHostConfiguration} + */ + public SingleHostConfiguration build() { + if (this.host == null && this.socket == null) { + throw new IllegalArgumentException("host or socket must not be null"); + } + if (this.host != null && this.socket != null) { + throw new IllegalArgumentException("Connection must be configured for either host/port or socket usage but not both"); + } + + return new SingleHostConfiguration(this.host, this.port, this.socket); + } + + @Nullable + public String getSocket() { + return socket; + } + + @Override + public String toString() { + return "Builder{" + + "host='" + this.host + '\'' + + ", port=" + this.port + + ", socket='" + this.socket + '\'' + + '}'; + } + } + + +} diff --git a/src/test/java/io/r2dbc/postgresql/PostgresqlConnectionConfigurationTest.java b/src/test/java/io/r2dbc/postgresql/PostgresqlConnectionConfigurationTest.java index 157acd36..6f0b7dc6 100644 --- a/src/test/java/io/r2dbc/postgresql/PostgresqlConnectionConfigurationTest.java +++ b/src/test/java/io/r2dbc/postgresql/PostgresqlConnectionConfigurationTest.java @@ -35,9 +35,9 @@ void builderNoApplicationName() { } @Test - void builderNoHostAndSocket() { + void builderNoHostConfiguration() { assertThatIllegalArgumentException().isThrownBy(() -> PostgresqlConnectionConfiguration.builder().build()) - .withMessage("host or socket must not be null"); + .withMessage("Either multiple hosts configuration or single host configuration should be provided"); } @Test @@ -74,10 +74,10 @@ void configuration() { .hasFieldOrPropertyWithValue("applicationName", "test-application-name") .hasFieldOrPropertyWithValue("connectTimeout", Duration.ofMillis(1000)) .hasFieldOrPropertyWithValue("database", "test-database") - .hasFieldOrPropertyWithValue("host", "test-host") + .hasFieldOrPropertyWithValue("singleHostConfiguration.host", "test-host") + .hasFieldOrPropertyWithValue("singleHostConfiguration.port", 100) .hasFieldOrPropertyWithValue("options", options) .hasFieldOrPropertyWithValue("password", null) - .hasFieldOrPropertyWithValue("port", 100) .hasFieldOrPropertyWithValue("schema", "test-schema") .hasFieldOrPropertyWithValue("username", "test-username") .hasFieldOrProperty("sslConfig"); @@ -96,9 +96,9 @@ void configurationDefaults() { assertThat(configuration) .hasFieldOrPropertyWithValue("applicationName", "r2dbc-postgresql") .hasFieldOrPropertyWithValue("database", "test-database") - .hasFieldOrPropertyWithValue("host", "test-host") + .hasFieldOrPropertyWithValue("singleHostConfiguration.host", "test-host") + .hasFieldOrPropertyWithValue("singleHostConfiguration.port", 5432) .hasFieldOrPropertyWithValue("password", "test-password") - .hasFieldOrPropertyWithValue("port", 5432) .hasFieldOrPropertyWithValue("schema", "test-schema") .hasFieldOrPropertyWithValue("username", "test-username") .hasFieldOrProperty("sslConfig"); diff --git a/src/test/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryProviderTest.java b/src/test/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryProviderTest.java index 7c3da972..acad3cb9 100644 --- a/src/test/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryProviderTest.java +++ b/src/test/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryProviderTest.java @@ -16,27 +16,35 @@ package io.r2dbc.postgresql; +import io.r2dbc.postgresql.client.MultipleHostsConfiguration; import io.r2dbc.postgresql.client.SSLConfig; import io.r2dbc.postgresql.client.SSLMode; import io.r2dbc.spi.ConnectionFactoryOptions; import io.r2dbc.spi.Option; import org.junit.jupiter.api.Test; +import java.time.Duration; import java.util.HashMap; +import java.util.List; import java.util.Map; import static io.r2dbc.postgresql.PostgresqlConnectionFactoryProvider.AUTODETECT_EXTENSIONS; +import static io.r2dbc.postgresql.PostgresqlConnectionFactoryProvider.FAILOVER_PROTOCOL; import static io.r2dbc.postgresql.PostgresqlConnectionFactoryProvider.FORCE_BINARY; +import static io.r2dbc.postgresql.PostgresqlConnectionFactoryProvider.HOST_RECHECK_TIME; import static io.r2dbc.postgresql.PostgresqlConnectionFactoryProvider.LEGACY_POSTGRESQL_DRIVER; +import static io.r2dbc.postgresql.PostgresqlConnectionFactoryProvider.LOAD_BALANCE_HOSTS; import static io.r2dbc.postgresql.PostgresqlConnectionFactoryProvider.OPTIONS; import static io.r2dbc.postgresql.PostgresqlConnectionFactoryProvider.POSTGRESQL_DRIVER; import static io.r2dbc.postgresql.PostgresqlConnectionFactoryProvider.PREPARED_STATEMENT_CACHE_QUERIES; import static io.r2dbc.postgresql.PostgresqlConnectionFactoryProvider.SOCKET; import static io.r2dbc.postgresql.PostgresqlConnectionFactoryProvider.SSL_CONTEXT_BUILDER_CUSTOMIZER; import static io.r2dbc.postgresql.PostgresqlConnectionFactoryProvider.SSL_MODE; +import static io.r2dbc.postgresql.PostgresqlConnectionFactoryProvider.TARGET_SERVER_TYPE; import static io.r2dbc.spi.ConnectionFactoryOptions.DRIVER; import static io.r2dbc.spi.ConnectionFactoryOptions.HOST; import static io.r2dbc.spi.ConnectionFactoryOptions.PASSWORD; +import static io.r2dbc.spi.ConnectionFactoryOptions.PROTOCOL; import static io.r2dbc.spi.ConnectionFactoryOptions.SSL; import static io.r2dbc.spi.ConnectionFactoryOptions.USER; import static io.r2dbc.spi.ConnectionFactoryOptions.builder; @@ -272,8 +280,30 @@ void shouldConnectUsingUnixDomainSocket() { .option(USER, "postgres") .build()); - assertThat(factory.getConfiguration().isUseSocket()).isTrue(); - assertThat(factory.getConfiguration().getRequiredSocket()).isEqualTo("/tmp/.s.PGSQL.5432"); + assertThat(factory.getConfiguration().getSingleHostConfiguration().isUseSocket()).isTrue(); + assertThat(factory.getConfiguration().getSingleHostConfiguration().getRequiredSocket()).isEqualTo("/tmp/.s.PGSQL.5432"); } + @Test + void testFailoverConfiguration() { + PostgresqlConnectionFactory factory = this.provider.create(builder() + .option(DRIVER, POSTGRESQL_DRIVER) + .option(PROTOCOL, FAILOVER_PROTOCOL) + .option(HOST, "host1:5433,host2:5432,host3") + .option(USER, "postgres") + .option(LOAD_BALANCE_HOSTS, true) + .option(HOST_RECHECK_TIME, 20000) + .option(TARGET_SERVER_TYPE, TargetServerType.SECONDARY) + .build()); + + assertThat(factory.getConfiguration().getSingleHostConfiguration()).isNull(); + assertThat(factory.getConfiguration().getMultipleHostsConfiguration().isLoadBalanceHosts()).isEqualTo(true); + assertThat(factory.getConfiguration().getMultipleHostsConfiguration().getHostRecheckTime()).isEqualTo(Duration.ofMillis(20000)); + assertThat(factory.getConfiguration().getMultipleHostsConfiguration().getTargetServerType()).isEqualTo(TargetServerType.SECONDARY); + List hosts = factory.getConfiguration().getMultipleHostsConfiguration().getHosts(); + assertThat(hosts).hasSize(3); + assertThat(hosts.get(0)).isEqualToComparingFieldByField(new MultipleHostsConfiguration.ServerHost("host1", 5433)); + assertThat(hosts.get(1)).isEqualToComparingFieldByField(new MultipleHostsConfiguration.ServerHost("host2", 5432)); + assertThat(hosts.get(2)).isEqualToComparingFieldByField(new MultipleHostsConfiguration.ServerHost("host3", 5432)); + } } diff --git a/src/test/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryTest.java b/src/test/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryTest.java index 215a3ff3..77771da9 100644 --- a/src/test/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryTest.java +++ b/src/test/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryTest.java @@ -77,7 +77,7 @@ void createAuthenticationMD5Password() { .password("test-password") .build(); - new PostgresqlConnectionFactory(c -> Mono.just(client), configuration) + new PostgresqlConnectionFactory(testClientFactory(client, configuration), configuration) .create() .as(StepVerifier::create) .expectNextCount(1) @@ -128,7 +128,7 @@ void createError() { .password("test-password") .build(); - new PostgresqlConnectionFactory(c -> Mono.just(client), configuration).create() + new PostgresqlConnectionFactory(testClientFactory(client, configuration), configuration).create() .as(StepVerifier::create) .verifyErrorMatches(R2dbcNonTransientResourceException.class::isInstance); } @@ -152,7 +152,12 @@ void getMetadata() { .password("test-password") .build(); - assertThat(new PostgresqlConnectionFactory(c -> Mono.just(client), configuration).getMetadata()).isNotNull(); + assertThat(new PostgresqlConnectionFactory(testClientFactory(client, configuration), configuration).getMetadata()).isNotNull(); + } + + + private ClientFactory testClientFactory(Client client, PostgresqlConnectionConfiguration configuration) { + return new SingleHostClientFactory(configuration, (endpoint, timeout, ssl) -> Mono.just(client)); } } diff --git a/src/test/java/io/r2dbc/postgresql/client/HighAvailabilityClusterTest.java b/src/test/java/io/r2dbc/postgresql/client/HighAvailabilityClusterTest.java new file mode 100644 index 00000000..4b2b34de --- /dev/null +++ b/src/test/java/io/r2dbc/postgresql/client/HighAvailabilityClusterTest.java @@ -0,0 +1,195 @@ +package io.r2dbc.postgresql.client; + +import io.r2dbc.postgresql.PostgresqlConnectionConfiguration; +import io.r2dbc.postgresql.PostgresqlConnectionFactory; +import io.r2dbc.postgresql.TargetServerType; +import io.r2dbc.postgresql.api.PostgresqlConnection; +import io.r2dbc.postgresql.util.PostgresqlHighAvailabilityClusterExtension; +import io.r2dbc.spi.R2dbcException; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; +import org.testcontainers.containers.PostgreSQLContainer; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + +public class HighAvailabilityClusterTest { + + @RegisterExtension + static final PostgresqlHighAvailabilityClusterExtension SERVERS = new PostgresqlHighAvailabilityClusterExtension(); + + @Test + void testPrimaryAndStandbyStartup() { + Assertions.assertFalse(SERVERS.getPrimaryJdbc().queryForObject("show transaction_read_only", Boolean.class)); + Assertions.assertTrue(SERVERS.getStandbyJdbc().queryForObject("show transaction_read_only", Boolean.class)); + } + + @Test + void testMultipleCallsOnSameFactory() { + PostgresqlConnectionFactory connectionFactory = this.multipleHostsConnectionFactory(TargetServerType.PREFER_SECONDARY, SERVERS.getPrimary(), SERVERS.getStandby()); + + connectionFactory + .create() + .flatMapMany(connection -> this.isPrimary(connection) + .concatWith(connection.close().then(Mono.empty()))) + .next() + .as(StepVerifier::create) + .expectNext(false) + .verifyComplete(); + + connectionFactory + .create() + .flatMapMany(connection -> this.isPrimary(connection) + .concatWith(connection.close().then(Mono.empty()))) + .next() + .as(StepVerifier::create) + .expectNext(false) + .verifyComplete(); + } + + @Test + void testTargetAnyChooseFirst() { + isConnectedToPrimary(TargetServerType.ANY, SERVERS.getPrimary(), SERVERS.getStandby()) + .as(StepVerifier::create) + .expectNext(true) + .verifyComplete(); + + isConnectedToPrimary(TargetServerType.ANY, SERVERS.getStandby(), SERVERS.getPrimary()) + .as(StepVerifier::create) + .expectNext(false) + .verifyComplete(); + } + + @Test + void testTargetAnyConnectedToPrimary() { + isConnectedToPrimary(TargetServerType.ANY, SERVERS.getPrimary()) + .as(StepVerifier::create) + .expectNext(true) + .verifyComplete(); + } + + @Test + void testTargetAnyConnectedToStandby() { + isConnectedToPrimary(TargetServerType.ANY, SERVERS.getStandby()) + .as(StepVerifier::create) + .expectNext(false) + .verifyComplete(); + } + + @Test + void testTargetPreferSecondaryChooseStandby() { + isConnectedToPrimary(TargetServerType.PREFER_SECONDARY, SERVERS.getStandby(), SERVERS.getPrimary()) + .as(StepVerifier::create) + .expectNext(false) + .verifyComplete(); + + isConnectedToPrimary(TargetServerType.PREFER_SECONDARY, SERVERS.getPrimary(), SERVERS.getStandby()) + .as(StepVerifier::create) + .expectNext(false) + .verifyComplete(); + } + + @Test + void testTargetPreferSecondaryConnectedToPrimary() { + isConnectedToPrimary(TargetServerType.PREFER_SECONDARY, SERVERS.getPrimary()) + .as(StepVerifier::create) + .expectNext(true) + .verifyComplete(); + } + + @Test + void testTargetPreferSecondaryConnectedToStandby() { + isConnectedToPrimary(TargetServerType.PREFER_SECONDARY, SERVERS.getStandby()) + .as(StepVerifier::create) + .expectNext(false) + .verifyComplete(); + } + + @Test + void testTargetPrimaryChoosePrimary() { + isConnectedToPrimary(TargetServerType.MASTER, SERVERS.getPrimary(), SERVERS.getStandby()) + .as(StepVerifier::create) + .expectNext(true) + .verifyComplete(); + + isConnectedToPrimary(TargetServerType.MASTER, SERVERS.getStandby(), SERVERS.getPrimary()) + .as(StepVerifier::create) + .expectNext(true) + .verifyComplete(); + } + + @Test + void testTargetPrimaryConnectedOnPrimary() { + isConnectedToPrimary(TargetServerType.MASTER, SERVERS.getPrimary()) + .as(StepVerifier::create) + .expectNext(true) + .verifyComplete(); + } + + @Test + void testTargetPrimaryFailedOnStandby() { + isConnectedToPrimary(TargetServerType.MASTER, SERVERS.getStandby()) + .as(StepVerifier::create) + .verifyError(R2dbcException.class); + } + + @Test + void testTargetSecondaryChooseStandby() { + isConnectedToPrimary(TargetServerType.SECONDARY, SERVERS.getStandby(), SERVERS.getPrimary()) + .as(StepVerifier::create) + .expectNext(false) + .verifyComplete(); + + isConnectedToPrimary(TargetServerType.SECONDARY, SERVERS.getPrimary(), SERVERS.getStandby()) + .as(StepVerifier::create) + .expectNext(false) + .verifyComplete(); + } + + @Test + void testTargetSecondaryConnectedOnStandby() { + isConnectedToPrimary(TargetServerType.SECONDARY, SERVERS.getStandby()) + .as(StepVerifier::create) + .expectNext(false) + .verifyComplete(); + } + + @Test + void testTargetSecondaryFailedOnPrimary() { + isConnectedToPrimary(TargetServerType.SECONDARY, SERVERS.getPrimary()) + .as(StepVerifier::create) + .verifyError(R2dbcException.class); + } + + private Mono isConnectedToPrimary(TargetServerType targetServerType, PostgreSQLContainer... servers) { + PostgresqlConnectionFactory connectionFactory = this.multipleHostsConnectionFactory(targetServerType, servers); + + return connectionFactory + .create() + .flatMapMany(connection -> this.isPrimary(connection) + .concatWith(connection.close().then(Mono.empty()))) + .next(); + } + + private Mono isPrimary(PostgresqlConnection connection) { + return connection.createStatement("show transaction_read_only") + .execute() + .flatMap(result -> result.map((row, meta) -> row.get(0, String.class))) + .map(str -> str.equalsIgnoreCase("off")) + .next(); + } + + private PostgresqlConnectionFactory multipleHostsConnectionFactory(TargetServerType targetServerType, PostgreSQLContainer... servers) { + PostgreSQLContainer firstServer = servers[0]; + PostgresqlConnectionConfiguration.Builder builder = PostgresqlConnectionConfiguration.builder(); + for (PostgreSQLContainer server : servers) { + builder.addHost(server.getContainerIpAddress(), server.getMappedPort(5432)); + } + PostgresqlConnectionConfiguration configuration = builder + .targetServerType(targetServerType) + .username(firstServer.getUsername()) + .password(firstServer.getPassword()) + .build(); + return new PostgresqlConnectionFactory(configuration); + } +} diff --git a/src/test/java/io/r2dbc/postgresql/util/PostgresqlHighAvailabilityClusterExtension.java b/src/test/java/io/r2dbc/postgresql/util/PostgresqlHighAvailabilityClusterExtension.java new file mode 100644 index 00000000..3698cdf7 --- /dev/null +++ b/src/test/java/io/r2dbc/postgresql/util/PostgresqlHighAvailabilityClusterExtension.java @@ -0,0 +1,124 @@ +package io.r2dbc.postgresql.util; + +import com.zaxxer.hikari.HikariConfig; +import com.zaxxer.hikari.HikariDataSource; +import org.junit.jupiter.api.extension.AfterAllCallback; +import org.junit.jupiter.api.extension.BeforeAllCallback; +import org.junit.jupiter.api.extension.ExtensionContext; +import org.springframework.jdbc.core.JdbcTemplate; +import org.testcontainers.containers.Network; +import org.testcontainers.containers.PostgreSQLContainer; +import org.testcontainers.containers.wait.strategy.LogMessageWaitStrategy; +import org.testcontainers.utility.MountableFile; + +import java.net.URISyntaxException; +import java.net.URL; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.time.Duration; +import java.time.temporal.ChronoUnit; + +import static org.testcontainers.utility.MountableFile.forHostPath; + +public class PostgresqlHighAvailabilityClusterExtension implements BeforeAllCallback, AfterAllCallback { + + private PostgreSQLContainer primary; + + private HikariDataSource primaryDataSource; + + private PostgreSQLContainer standby; + + private HikariDataSource standbyDataSource; + + @Override + public void afterAll(ExtensionContext extensionContext) { + if (this.standbyDataSource != null) { + this.standbyDataSource.close(); + } + if (this.standby != null) { + this.standby.stop(); + } + if (this.primaryDataSource != null) { + this.primaryDataSource.close(); + } + if (this.primary != null) { + this.primary.stop(); + } + } + + @Override + public void beforeAll(ExtensionContext extensionContext) { + Network network = Network.newNetwork(); + this.startPrimary(network); + this.startStandby(network); + } + + public PostgreSQLContainer getPrimary() { + return this.primary; + } + + public JdbcTemplate getPrimaryJdbc() { + return new JdbcTemplate(this.primaryDataSource); + } + + public PostgreSQLContainer getStandby() { + return standby; + } + + public JdbcTemplate getStandbyJdbc() { + return new JdbcTemplate(this.standbyDataSource); + } + + private static MountableFile getHostPath(String name, int mode) { + return forHostPath(getResourcePath(name), mode); + } + + private static Path getResourcePath(String name) { + URL resource = PostgresqlHighAvailabilityClusterExtension.class.getClassLoader().getResource(name); + if (resource == null) { + throw new IllegalStateException("Resource not found: " + name); + } + + try { + return Paths.get(resource.toURI()); + } catch (URISyntaxException e) { + throw new IllegalStateException("Cannot convert to path for: " + name, e); + } + } + + private void startPrimary(Network network) { + this.primary = new PostgreSQLContainer<>("postgres:latest") + .withNetwork(network) + .withNetworkAliases("postgres-primary") + .withCopyFileToContainer(getHostPath("setup-primary.sh", 0755), "/docker-entrypoint-initdb.d/setup-primary.sh") + .withEnv("PG_REP_USER", "replication") + .withEnv("PG_REP_PASSWORD", "replication_password"); + this.primary.start(); + HikariConfig primaryConfig = new HikariConfig(); + primaryConfig.setJdbcUrl(this.primary.getJdbcUrl()); + primaryConfig.setUsername(this.primary.getUsername()); + primaryConfig.setPassword(this.primary.getPassword()); + this.primaryDataSource = new HikariDataSource(primaryConfig); + } + + private void startStandby(Network network) { + this.standby = new PostgreSQLContainer<>("postgres:latest") + .withNetwork(network) + .withCopyFileToContainer(getHostPath("setup-standby.sh", 0755), "/setup-standby.sh") + .withCommand("/setup-standby.sh") + .withEnv("PG_REP_USER", "replication") + .withEnv("PG_REP_PASSWORD", "replication_password") + .withEnv("PG_MASTER_HOST", "postgres-primary") + .withEnv("PG_MASTER_PORT", "5432"); + this.standby.setWaitStrategy(new LogMessageWaitStrategy() + .withRegEx(".*database system is ready to accept read only connections.*\\s") + .withTimes(1) + .withStartupTimeout(Duration.of(60L, ChronoUnit.SECONDS))); + this.standby.start(); + HikariConfig standbyConfig = new HikariConfig(); + standbyConfig.setJdbcUrl(this.standby.getJdbcUrl()); + standbyConfig.setUsername(this.standby.getUsername()); + standbyConfig.setPassword(this.standby.getPassword()); + this.standbyDataSource = new HikariDataSource(standbyConfig); + } +} diff --git a/src/test/resources/setup-primary.sh b/src/test/resources/setup-primary.sh new file mode 100644 index 00000000..68fc820a --- /dev/null +++ b/src/test/resources/setup-primary.sh @@ -0,0 +1,17 @@ +#!/bin/sh + +echo "host replication all 0.0.0.0/0 md5" >> "$PGDATA/pg_hba.conf" + +set -e +psql -v ON_ERROR_STOP=1 --username "$POSTGRES_USER" --dbname "$POSTGRES_DB" <<-EOSQL + CREATE USER $PG_REP_USER REPLICATION LOGIN CONNECTION LIMIT 100 ENCRYPTED PASSWORD '$PG_REP_PASSWORD'; +EOSQL + +cat >> ${PGDATA}/postgresql.conf < ~/.pgpass + chmod 0600 ~/.pgpass + until pg_basebackup -h "${PG_MASTER_HOST}" -p "${PG_MASTER_PORT}" -D "${PGDATA}" -U "${PG_REP_USER}" -vP -W + do + echo "Waiting for primary server to connect..." + sleep 1s + done + echo "host replication all 0.0.0.0/0 md5" >> "$PGDATA/pg_hba.conf" + set -e + cat > "${PGDATA}"/standby.signal <