From 57618c7eb3034f57f7b71c90d4a474e0bff21ee4 Mon Sep 17 00:00:00 2001 From: Anton Duyun Date: Tue, 26 Nov 2019 14:13:55 +0300 Subject: [PATCH 1/5] Test extension with primary/standby servers --- .../client/HighAvailabilityClusterTest.java | 18 +++ ...resqlHighAvailabilityClusterExtension.java | 124 ++++++++++++++++++ src/test/resources/setup-primary.sh | 17 +++ src/test/resources/setup-standby.sh | 22 ++++ 4 files changed, 181 insertions(+) create mode 100644 src/test/java/io/r2dbc/postgresql/client/HighAvailabilityClusterTest.java create mode 100644 src/test/java/io/r2dbc/postgresql/util/PostgresqlHighAvailabilityClusterExtension.java create mode 100644 src/test/resources/setup-primary.sh create mode 100644 src/test/resources/setup-standby.sh diff --git a/src/test/java/io/r2dbc/postgresql/client/HighAvailabilityClusterTest.java b/src/test/java/io/r2dbc/postgresql/client/HighAvailabilityClusterTest.java new file mode 100644 index 00000000..c7a7fc4b --- /dev/null +++ b/src/test/java/io/r2dbc/postgresql/client/HighAvailabilityClusterTest.java @@ -0,0 +1,18 @@ +package io.r2dbc.postgresql.client; + +import io.r2dbc.postgresql.util.PostgresqlHighAvailabilityClusterExtension; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.RegisterExtension; + +public class HighAvailabilityClusterTest { + + @RegisterExtension + static final PostgresqlHighAvailabilityClusterExtension SERVERS = new PostgresqlHighAvailabilityClusterExtension(); + + @Test + void testPrimaryAndStandbyStartup() { + Assertions.assertFalse(SERVERS.getPrimaryJdbc().queryForObject("show transaction_read_only", Boolean.class)); + Assertions.assertTrue(SERVERS.getStandbyJdbc().queryForObject("show transaction_read_only", Boolean.class)); + } +} diff --git a/src/test/java/io/r2dbc/postgresql/util/PostgresqlHighAvailabilityClusterExtension.java b/src/test/java/io/r2dbc/postgresql/util/PostgresqlHighAvailabilityClusterExtension.java new file mode 100644 index 00000000..3698cdf7 --- /dev/null +++ b/src/test/java/io/r2dbc/postgresql/util/PostgresqlHighAvailabilityClusterExtension.java @@ -0,0 +1,124 @@ +package io.r2dbc.postgresql.util; + +import com.zaxxer.hikari.HikariConfig; +import com.zaxxer.hikari.HikariDataSource; +import org.junit.jupiter.api.extension.AfterAllCallback; +import org.junit.jupiter.api.extension.BeforeAllCallback; +import org.junit.jupiter.api.extension.ExtensionContext; +import org.springframework.jdbc.core.JdbcTemplate; +import org.testcontainers.containers.Network; +import org.testcontainers.containers.PostgreSQLContainer; +import org.testcontainers.containers.wait.strategy.LogMessageWaitStrategy; +import org.testcontainers.utility.MountableFile; + +import java.net.URISyntaxException; +import java.net.URL; +import java.nio.file.Path; +import java.nio.file.Paths; +import java.time.Duration; +import java.time.temporal.ChronoUnit; + +import static org.testcontainers.utility.MountableFile.forHostPath; + +public class PostgresqlHighAvailabilityClusterExtension implements BeforeAllCallback, AfterAllCallback { + + private PostgreSQLContainer primary; + + private HikariDataSource primaryDataSource; + + private PostgreSQLContainer standby; + + private HikariDataSource standbyDataSource; + + @Override + public void afterAll(ExtensionContext extensionContext) { + if (this.standbyDataSource != null) { + this.standbyDataSource.close(); + } + if (this.standby != null) { + this.standby.stop(); + } + if (this.primaryDataSource != null) { + this.primaryDataSource.close(); + } + if (this.primary != null) { + this.primary.stop(); + } + } + + @Override + public void beforeAll(ExtensionContext extensionContext) { + Network network = Network.newNetwork(); + this.startPrimary(network); + this.startStandby(network); + } + + public PostgreSQLContainer getPrimary() { + return this.primary; + } + + public JdbcTemplate getPrimaryJdbc() { + return new JdbcTemplate(this.primaryDataSource); + } + + public PostgreSQLContainer getStandby() { + return standby; + } + + public JdbcTemplate getStandbyJdbc() { + return new JdbcTemplate(this.standbyDataSource); + } + + private static MountableFile getHostPath(String name, int mode) { + return forHostPath(getResourcePath(name), mode); + } + + private static Path getResourcePath(String name) { + URL resource = PostgresqlHighAvailabilityClusterExtension.class.getClassLoader().getResource(name); + if (resource == null) { + throw new IllegalStateException("Resource not found: " + name); + } + + try { + return Paths.get(resource.toURI()); + } catch (URISyntaxException e) { + throw new IllegalStateException("Cannot convert to path for: " + name, e); + } + } + + private void startPrimary(Network network) { + this.primary = new PostgreSQLContainer<>("postgres:latest") + .withNetwork(network) + .withNetworkAliases("postgres-primary") + .withCopyFileToContainer(getHostPath("setup-primary.sh", 0755), "/docker-entrypoint-initdb.d/setup-primary.sh") + .withEnv("PG_REP_USER", "replication") + .withEnv("PG_REP_PASSWORD", "replication_password"); + this.primary.start(); + HikariConfig primaryConfig = new HikariConfig(); + primaryConfig.setJdbcUrl(this.primary.getJdbcUrl()); + primaryConfig.setUsername(this.primary.getUsername()); + primaryConfig.setPassword(this.primary.getPassword()); + this.primaryDataSource = new HikariDataSource(primaryConfig); + } + + private void startStandby(Network network) { + this.standby = new PostgreSQLContainer<>("postgres:latest") + .withNetwork(network) + .withCopyFileToContainer(getHostPath("setup-standby.sh", 0755), "/setup-standby.sh") + .withCommand("/setup-standby.sh") + .withEnv("PG_REP_USER", "replication") + .withEnv("PG_REP_PASSWORD", "replication_password") + .withEnv("PG_MASTER_HOST", "postgres-primary") + .withEnv("PG_MASTER_PORT", "5432"); + this.standby.setWaitStrategy(new LogMessageWaitStrategy() + .withRegEx(".*database system is ready to accept read only connections.*\\s") + .withTimes(1) + .withStartupTimeout(Duration.of(60L, ChronoUnit.SECONDS))); + this.standby.start(); + HikariConfig standbyConfig = new HikariConfig(); + standbyConfig.setJdbcUrl(this.standby.getJdbcUrl()); + standbyConfig.setUsername(this.standby.getUsername()); + standbyConfig.setPassword(this.standby.getPassword()); + this.standbyDataSource = new HikariDataSource(standbyConfig); + } +} diff --git a/src/test/resources/setup-primary.sh b/src/test/resources/setup-primary.sh new file mode 100644 index 00000000..68fc820a --- /dev/null +++ b/src/test/resources/setup-primary.sh @@ -0,0 +1,17 @@ +#!/bin/sh + +echo "host replication all 0.0.0.0/0 md5" >> "$PGDATA/pg_hba.conf" + +set -e +psql -v ON_ERROR_STOP=1 --username "$POSTGRES_USER" --dbname "$POSTGRES_DB" <<-EOSQL + CREATE USER $PG_REP_USER REPLICATION LOGIN CONNECTION LIMIT 100 ENCRYPTED PASSWORD '$PG_REP_PASSWORD'; +EOSQL + +cat >> ${PGDATA}/postgresql.conf < ~/.pgpass + chmod 0600 ~/.pgpass + until pg_basebackup -h "${PG_MASTER_HOST}" -p "${PG_MASTER_PORT}" -D "${PGDATA}" -U "${PG_REP_USER}" -vP -W + do + echo "Waiting for primary server to connect..." + sleep 1s + done + echo "host replication all 0.0.0.0/0 md5" >> "$PGDATA/pg_hba.conf" + set -e + cat > "${PGDATA}"/standby.signal < Date: Tue, 26 Nov 2019 14:16:01 +0300 Subject: [PATCH 2/5] primary/standby servers support implementation --- .../io/r2dbc/postgresql/ClientFactory.java | 266 ++++++++++++++++++ .../PostgresqlConnectionConfiguration.java | 128 +++++++-- .../PostgresqlConnectionFactory.java | 123 ++------ .../io/r2dbc/postgresql/TargetServerType.java | 53 ++++ .../PostgresqlConnectionFactoryTest.java | 8 +- .../client/HighAvailabilityClusterTest.java | 181 ++++++++++++ 6 files changed, 635 insertions(+), 124 deletions(-) create mode 100644 src/main/java/io/r2dbc/postgresql/ClientFactory.java create mode 100644 src/main/java/io/r2dbc/postgresql/TargetServerType.java diff --git a/src/main/java/io/r2dbc/postgresql/ClientFactory.java b/src/main/java/io/r2dbc/postgresql/ClientFactory.java new file mode 100644 index 00000000..7055d31c --- /dev/null +++ b/src/main/java/io/r2dbc/postgresql/ClientFactory.java @@ -0,0 +1,266 @@ +package io.r2dbc.postgresql; + +import io.netty.channel.unix.DomainSocketAddress; +import io.r2dbc.postgresql.authentication.AuthenticationHandler; +import io.r2dbc.postgresql.authentication.PasswordAuthenticationHandler; +import io.r2dbc.postgresql.authentication.SASLAuthenticationHandler; +import io.r2dbc.postgresql.client.Client; +import io.r2dbc.postgresql.client.SSLConfig; +import io.r2dbc.postgresql.client.SSLMode; +import io.r2dbc.postgresql.client.StartupMessageFlow; +import io.r2dbc.postgresql.codec.DefaultCodecs; +import io.r2dbc.postgresql.message.backend.AuthenticationMessage; +import io.r2dbc.postgresql.util.Assert; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import javax.annotation.Nullable; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.time.Duration; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicReference; +import java.util.function.Function; +import java.util.function.Predicate; + +import static io.r2dbc.postgresql.TargetServerType.ANY; +import static io.r2dbc.postgresql.TargetServerType.MASTER; +import static io.r2dbc.postgresql.TargetServerType.PREFER_SECONDARY; +import static io.r2dbc.postgresql.TargetServerType.SECONDARY; + +class ClientFactory implements Function, Mono> { + + private final List addresses; + + private final PostgresqlConnectionConfiguration configuration; + + private final Map statusMap = new ConcurrentHashMap<>(); + + private final ConnectionSupplier connectionSupplier; + + public ClientFactory(PostgresqlConnectionConfiguration configuration, ConnectionSupplier connectionSupplier) { + this.configuration = configuration; + this.addresses = ClientFactory.createSocketAddress(this.configuration); + this.connectionSupplier = connectionSupplier; + } + + @Override + public Mono apply(Map options) { + AtomicReference exceptionRef = new AtomicReference<>(); + TargetServerType targetServerType = this.configuration.getTargetServerType(); + return this.tryConnect(targetServerType, options) + .onErrorResume(e -> this.addresses.size() > 1, e -> { + if (!exceptionRef.compareAndSet(null, e)) { + exceptionRef.get().addSuppressed(e); + } + return Mono.empty(); + }) + .switchIfEmpty(Mono.defer(() -> targetServerType == PREFER_SECONDARY + ? this.tryConnect(MASTER, options) + : Mono.empty())) + .switchIfEmpty(Mono.error(() -> { + Throwable error = exceptionRef.get(); + if (error == null) { + return new PostgresqlConnectionFactory.PostgresConnectionException(String.format("No server matches target type %s", targetServerType.getValue()), null); + } else { + return error; + } + })); + } + + public Mono tryConnect(TargetServerType targetServerType, @Nullable Map options) { + AtomicReference exceptionRef = new AtomicReference<>(); + return this.getCandidates(targetServerType).concatMap(candidate -> this.tryConnectToCandidate(targetServerType, candidate, options) + .onErrorResume(e -> this.addresses.size() > 1, e -> { + if (!exceptionRef.compareAndSet(null, e)) { + exceptionRef.get().addSuppressed(e); + } + this.statusMap.put(candidate, HostSpecStatus.fail(candidate)); + return Mono.empty(); + })) + .next() + .switchIfEmpty(Mono.defer(() -> exceptionRef.get() != null + ? Mono.error(exceptionRef.get()) + : Mono.empty())); + } + + private static List createSocketAddress(PostgresqlConnectionConfiguration configuration) { + if (!configuration.isUseSocket()) { + if (configuration.getTmpHosts() != null) { + String[] hosts = configuration.getTmpHosts(); + int[] ports = configuration.getTmpPorts(); + List addressList = new ArrayList<>(hosts.length); + for (int i = 0; i < hosts.length; i++) { + String host = hosts[i]; + int port = ports[i]; + addressList.add(InetSocketAddress.createUnresolved(host, port)); + } + return addressList; + } else { + return Collections.singletonList(InetSocketAddress.createUnresolved(configuration.getRequiredHost(), configuration.getPort())); + } + } + + if (configuration.isUseSocket()) { + return Collections.singletonList(new DomainSocketAddress(configuration.getRequiredSocket())); + } + + throw new IllegalArgumentException("Cannot create SocketAddress for " + configuration); + } + + private static HostSpecStatus evaluateStatus(SocketAddress candidate, @Nullable HostSpecStatus oldStatus) { + return oldStatus == null || oldStatus.hostStatus == HostStatus.CONNECT_FAIL + ? HostSpecStatus.ok(candidate) + : oldStatus; + } + + private static Mono isPrimaryServer(Client client) { + return new SimpleQueryPostgresqlStatement(client, new DefaultCodecs(client.getByteBufAllocator()), "show transaction_read_only") + .execute() + .flatMap(result -> result.map((row, rowMetadata) -> row.get(0, String.class))) + .map(s -> s.equalsIgnoreCase("off")) + .next(); + } + + private AuthenticationHandler getAuthenticationHandler(AuthenticationMessage message) { + if (PasswordAuthenticationHandler.supports(message)) { + CharSequence password = Assert.requireNonNull(this.configuration.getPassword(), "Password must not be null"); + return new PasswordAuthenticationHandler(password, this.configuration.getUsername()); + } else if (SASLAuthenticationHandler.supports(message)) { + CharSequence password = Assert.requireNonNull(this.configuration.getPassword(), "Password must not be null"); + return new SASLAuthenticationHandler(password, this.configuration.getUsername()); + } else { + throw new IllegalStateException(String.format("Unable to provide AuthenticationHandler capable of handling %s", message)); + } + } + + private Flux getCandidates(TargetServerType targetServerType) { + return Flux.create(sink -> { + if (this.addresses.size() == 1) { + sink.next(this.addresses.get(0)); + sink.complete(); + return; + } + long now = System.currentTimeMillis(); + List addresses = new ArrayList<>(this.addresses); + if (this.configuration.isLoadBalance()) { + Collections.shuffle(addresses); + } + for (SocketAddress address : addresses) { + HostSpecStatus currentStatus = this.statusMap.get(address); + if (currentStatus == null || now > currentStatus.updated + this.configuration.getHostRecheckTime()) { + sink.next(address); + } else if (targetServerType.allowStatus(currentStatus.hostStatus)) { + sink.next(address); + } + } + sink.complete(); + }); + } + + private Mono tryConnectWithConfig(SSLConfig sslConfig, SocketAddress endpoint, @Nullable Map options) { + return this.connectionSupplier.connect(endpoint, this.configuration.getConnectTimeout(), sslConfig) + .delayUntil(client -> StartupMessageFlow + .exchange(this.configuration.getApplicationName(), this::getAuthenticationHandler, client, this.configuration.getDatabase(), this.configuration.getUsername(), options) + .handle(ExceptionFactory.INSTANCE::handleErrorResponse)); + } + + private Mono tryConnectToCandidate(TargetServerType targetServerType, SocketAddress candidate, @Nullable Map options) { + return Mono.create(sink -> this.tryConnectToEndpoint(candidate, options).subscribe(client -> { + this.statusMap.compute(candidate, (a, oldStatus) -> ClientFactory.evaluateStatus(candidate, oldStatus)); + if (targetServerType == ANY) { + sink.success(client); + return; + } + ClientFactory.isPrimaryServer(client).subscribe( + isPrimary -> { + if (isPrimary) { + this.statusMap.put(candidate, HostSpecStatus.primary(candidate)); + } else { + this.statusMap.put(candidate, HostSpecStatus.standby(candidate)); + } + if (isPrimary && targetServerType == MASTER) { + sink.success(client); + } else if (!isPrimary && (targetServerType == SECONDARY || targetServerType == PREFER_SECONDARY)) { + sink.success(client); + } else { + client.close().subscribe(v -> sink.success(), sink::error, sink::success, sink.currentContext()); + } + }, + sink::error, + () -> { + }, + sink.currentContext() + ); + }, sink::error, () -> { + }, sink.currentContext())); + } + + private Mono tryConnectToEndpoint(SocketAddress endpoint, @Nullable Map options) { + SSLConfig sslConfig = this.configuration.getSslConfig(); + Predicate isAuthSpecificationError = e -> e instanceof ExceptionFactory.PostgresqlAuthenticationFailure; + return this.tryConnectWithConfig(sslConfig, endpoint, options) + .onErrorResume( + isAuthSpecificationError.and(e -> sslConfig.getSslMode() == SSLMode.ALLOW), + e -> this.tryConnectWithConfig(sslConfig.mutateMode(SSLMode.REQUIRE), endpoint, options) + .onErrorResume(sslAuthError -> { + e.addSuppressed(sslAuthError); + return Mono.error(e); + }) + ) + .onErrorResume( + isAuthSpecificationError.and(e -> sslConfig.getSslMode() == SSLMode.PREFER), + e -> this.tryConnectWithConfig(sslConfig.mutateMode(SSLMode.DISABLE), endpoint, options) + .onErrorResume(sslAuthError -> { + e.addSuppressed(sslAuthError); + return Mono.error(e); + }) + ); + } + + public interface ConnectionSupplier { + + Mono connect(SocketAddress endpoint, @Nullable Duration connectTimeout, SSLConfig sslConfig); + } + + enum HostStatus { + CONNECT_FAIL, + CONNECT_OK, + PRIMARY, + STANDBY + } + + private static class HostSpecStatus { + + public final SocketAddress address; + + public final HostStatus hostStatus; + + public final long updated = System.currentTimeMillis(); + + private HostSpecStatus(SocketAddress address, HostStatus hostStatus) { + this.address = address; + this.hostStatus = hostStatus; + } + + public static HostSpecStatus fail(SocketAddress host) { + return new HostSpecStatus(host, HostStatus.CONNECT_FAIL); + } + + public static HostSpecStatus ok(SocketAddress host) { + return new HostSpecStatus(host, HostStatus.CONNECT_OK); + } + + public static HostSpecStatus primary(SocketAddress host) { + return new HostSpecStatus(host, HostStatus.PRIMARY); + } + + public static HostSpecStatus standby(SocketAddress host) { + return new HostSpecStatus(host, HostStatus.STANDBY); + } + } +} diff --git a/src/main/java/io/r2dbc/postgresql/PostgresqlConnectionConfiguration.java b/src/main/java/io/r2dbc/postgresql/PostgresqlConnectionConfiguration.java index 294c45ff..8883e976 100644 --- a/src/main/java/io/r2dbc/postgresql/PostgresqlConnectionConfiguration.java +++ b/src/main/java/io/r2dbc/postgresql/PostgresqlConnectionConfiguration.java @@ -79,10 +79,20 @@ public final class PostgresqlConnectionConfiguration { private final SSLConfig sslConfig; + private final int hostRecheckTime; + + private final boolean loadBalance; + + private final TargetServerType targetServerType; + + private final String[] tmpHosts; + + private final int[] tmpPorts; + private PostgresqlConnectionConfiguration(String applicationName, boolean autodetectExtensions, @Nullable Duration connectTimeout, @Nullable String database, List extensions, boolean forceBinary, @Nullable String host, @Nullable Map options, @Nullable CharSequence password, int port, @Nullable String schema, @Nullable String socket, String username, - SSLConfig sslConfig) { + SSLConfig sslConfig, TargetServerType targetServerType, int hostRecheckTime, boolean loadBalance, @Nullable String[] tmpHosts, @Nullable int[] tmpPorts) { this.applicationName = Assert.requireNonNull(applicationName, "applicationName must not be null"); this.autodetectExtensions = autodetectExtensions; this.connectTimeout = connectTimeout; @@ -97,6 +107,11 @@ private PostgresqlConnectionConfiguration(String applicationName, boolean autode this.socket = socket; this.username = Assert.requireNonNull(username, "username must not be null"); this.sslConfig = sslConfig; + this.targetServerType = targetServerType; + this.hostRecheckTime = hostRecheckTime; + this.loadBalance = loadBalance; + this.tmpHosts = tmpHosts; + this.tmpPorts = tmpPorts; } /** @@ -226,6 +241,28 @@ SSLConfig getSslConfig() { return this.sslConfig; } + @Nullable + public String[] getTmpHosts() { + return tmpHosts; + } + + @Nullable + public int[] getTmpPorts() { + return tmpPorts; + } + + int getHostRecheckTime() { + return this.hostRecheckTime; + } + + TargetServerType getTargetServerType() { + return this.targetServerType; + } + + boolean isLoadBalance() { + return this.loadBalance; + } + /** * A builder for {@link PostgresqlConnectionConfiguration} instances. *

@@ -284,6 +321,16 @@ public static final class Builder { @Nullable private String username; + private int hostRecheckTime = 10000; + + private boolean loadBalance = false; + + private TargetServerType targetServerType = TargetServerType.ANY; + + private String[] tmpHosts; + + private int[] tmpPorts; + private Builder() { } @@ -317,7 +364,7 @@ public Builder autodetectExtensions(boolean autodetectExtensions) { */ public PostgresqlConnectionConfiguration build() { - if (this.host == null && this.socket == null) { + if (this.host == null && this.socket == null && this.tmpHosts == null) { throw new IllegalArgumentException("host or socket must not be null"); } @@ -330,7 +377,8 @@ public PostgresqlConnectionConfiguration build() { } return new PostgresqlConnectionConfiguration(this.applicationName, this.autodetectExtensions, this.connectTimeout, this.database, this.extensions, this.forceBinary, this.host, - this.options, this.password, this.port, this.schema, this.socket, this.username, this.createSslConfig()); + this.options, this.password, this.port, this.schema, this.socket, this.username, this.createSslConfig(), this.targetServerType, this.hostRecheckTime, this.loadBalance, tmpHosts, + tmpPorts); } /** @@ -546,29 +594,9 @@ public Builder sslPassword(@Nullable CharSequence sslPassword) { return this; } - @Override - public String toString() { - return "Builder{" + - "applicationName='" + this.applicationName + '\'' + - ", autodetectExtensions='" + this.autodetectExtensions + '\'' + - ", connectTimeout='" + this.connectTimeout + '\'' + - ", database='" + this.database + '\'' + - ", extensions='" + this.extensions + '\'' + - ", forceBinary='" + this.forceBinary + '\'' + - ", host='" + this.host + '\'' + - ", parameters='" + this.options + '\'' + - ", password='" + repeat(this.password != null ? this.password.length() : 0, "*") + '\'' + - ", port=" + this.port + - ", schema='" + this.schema + '\'' + - ", username='" + this.username + '\'' + - ", socket='" + this.socket + '\'' + - ", sslContextBuilderCustomizer='" + this.sslContextBuilderCustomizer + '\'' + - ", sslMode='" + this.sslMode + '\'' + - ", sslRootCert='" + this.sslRootCert + '\'' + - ", sslCert='" + this.sslCert + '\'' + - ", sslKey='" + this.sslKey + '\'' + - ", sslHostnameVerifier='" + this.sslHostnameVerifier + '\'' + - '}'; + public Builder hostRecheckTime(int hostRecheckTime) { + this.hostRecheckTime = hostRecheckTime; + return this; } /** @@ -594,6 +622,54 @@ public Builder username(String username) { return this; } + public Builder loadBalance(boolean loadBalance) { + this.loadBalance = loadBalance; + return this; + } + + public Builder targetServerType(TargetServerType targetServerType) { + this.targetServerType = targetServerType; + return this; + } + + public Builder tmpHosts(String[] tmpHosts) { + this.tmpHosts = tmpHosts; + return this; + } + + public Builder tmpPorts(int[] tmpPorts) { + this.tmpPorts = tmpPorts; + return this; + } + + @Override + public String toString() { + return "Builder{" + + "applicationName='" + this.applicationName + '\'' + + ", autodetectExtensions='" + this.autodetectExtensions + '\'' + + ", connectTimeout='" + this.connectTimeout + '\'' + + ", database='" + this.database + '\'' + + ", extensions='" + this.extensions + '\'' + + ", forceBinary='" + this.forceBinary + '\'' + + ", host='" + this.host + '\'' + + ", parameters='" + this.options + '\'' + + ", password='" + repeat(this.password != null ? this.password.length() : 0, "*") + '\'' + + ", port=" + this.port + + ", schema='" + this.schema + '\'' + + ", username='" + this.username + '\'' + + ", socket='" + this.socket + '\'' + + ", sslContextBuilderCustomizer='" + this.sslContextBuilderCustomizer + '\'' + + ", sslMode='" + this.sslMode + '\'' + + ", sslRootCert='" + this.sslRootCert + '\'' + + ", sslCert='" + this.sslCert + '\'' + + ", sslKey='" + this.sslKey + '\'' + + ", sslHostnameVerifier='" + this.sslHostnameVerifier + '\'' + + ", targetServerType='" + this.targetServerType + '\'' + + ", hostRecheckTime='" + this.hostRecheckTime + '\'' + + ", loadBalance='" + this.loadBalance + '\'' + + '}'; + } + private SSLConfig createSslConfig() { if (this.socket != null || this.sslMode == SSLMode.DISABLE) { return SSLConfig.disabled(); diff --git a/src/main/java/io/r2dbc/postgresql/PostgresqlConnectionFactory.java b/src/main/java/io/r2dbc/postgresql/PostgresqlConnectionFactory.java index 9d4da04a..609a3263 100644 --- a/src/main/java/io/r2dbc/postgresql/PostgresqlConnectionFactory.java +++ b/src/main/java/io/r2dbc/postgresql/PostgresqlConnectionFactory.java @@ -17,18 +17,10 @@ package io.r2dbc.postgresql; import io.netty.buffer.ByteBufAllocator; -import io.netty.channel.unix.DomainSocketAddress; -import io.r2dbc.postgresql.authentication.AuthenticationHandler; -import io.r2dbc.postgresql.authentication.PasswordAuthenticationHandler; -import io.r2dbc.postgresql.authentication.SASLAuthenticationHandler; import io.r2dbc.postgresql.client.Client; import io.r2dbc.postgresql.client.ReactorNettyClient; -import io.r2dbc.postgresql.client.SSLConfig; -import io.r2dbc.postgresql.client.SSLMode; -import io.r2dbc.postgresql.client.StartupMessageFlow; import io.r2dbc.postgresql.codec.DefaultCodecs; import io.r2dbc.postgresql.extension.CodecRegistrar; -import io.r2dbc.postgresql.message.backend.AuthenticationMessage; import io.r2dbc.postgresql.util.Assert; import io.r2dbc.spi.ConnectionFactory; import io.r2dbc.spi.ConnectionFactoryMetadata; @@ -41,15 +33,12 @@ import reactor.netty.resources.ConnectionProvider; import reactor.util.annotation.Nullable; -import java.net.InetSocketAddress; -import java.net.SocketAddress; import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Locale; import java.util.Map; import java.util.function.Function; -import java.util.function.Predicate; /** * An implementation of {@link ConnectionFactory} for creating connections to a PostgreSQL database. @@ -60,12 +49,10 @@ public final class PostgresqlConnectionFactory implements ConnectionFactory { private static final String REPLICATION_DATABASE = "database"; - private final Function> clientFactory; + private final Function, Mono> clientFactory; private final PostgresqlConnectionConfiguration configuration; - private final SocketAddress endpoint; - private final Extensions extensions; /** @@ -76,31 +63,19 @@ public final class PostgresqlConnectionFactory implements ConnectionFactory { */ public PostgresqlConnectionFactory(PostgresqlConnectionConfiguration configuration) { this.configuration = Assert.requireNonNull(configuration, "configuration must not be null"); - this.endpoint = createSocketAddress(configuration); - this.clientFactory = sslConfig -> ReactorNettyClient.connect(ConnectionProvider.newConnection(), this.endpoint, configuration.getConnectTimeout(), sslConfig).cast(Client.class); + this.clientFactory = new ClientFactory(configuration, + (endpoint, connectTimeout, sslConfig) -> ReactorNettyClient.connect(ConnectionProvider.newConnection(), endpoint, connectTimeout, sslConfig) + .cast(Client.class)); this.extensions = getExtensions(configuration); } - PostgresqlConnectionFactory(Function> clientFactory, PostgresqlConnectionConfiguration configuration) { + PostgresqlConnectionFactory(ClientFactory.ConnectionSupplier connectionSupplier, PostgresqlConnectionConfiguration configuration) { this.configuration = Assert.requireNonNull(configuration, "configuration must not be null"); - this.endpoint = createSocketAddress(configuration); - this.clientFactory = Assert.requireNonNull(clientFactory, "clientFactory must not be null"); + Assert.requireNonNull(connectionSupplier, "connectionSupplier must not be null"); + this.clientFactory = new ClientFactory(configuration, connectionSupplier); this.extensions = getExtensions(configuration); } - private static SocketAddress createSocketAddress(PostgresqlConnectionConfiguration configuration) { - - if (!configuration.isUseSocket()) { - return InetSocketAddress.createUnresolved(configuration.getRequiredHost(), configuration.getPort()); - } - - if (configuration.isUseSocket()) { - return new DomainSocketAddress(configuration.getRequiredSocket()); - } - - throw new IllegalArgumentException("Cannot create SocketAddress for " + configuration); - } - private static Extensions getExtensions(PostgresqlConnectionConfiguration configuration) { Extensions extensions = Extensions.from(configuration.getExtensions()); @@ -140,42 +115,15 @@ public Mono replication return doCreateConnection(true, options).map(DefaultPostgresqlReplicationConnection::new); } - private Mono doCreateConnection(boolean forReplication, @Nullable Map options) { - - SSLConfig sslConfig = this.configuration.getSslConfig(); - Predicate isAuthSpecificationError = e -> e instanceof ExceptionFactory.PostgresqlAuthenticationFailure; - return this.tryConnectWithConfig(sslConfig, options) - .onErrorResume( - isAuthSpecificationError.and(e -> sslConfig.getSslMode() == SSLMode.ALLOW), - e -> this.tryConnectWithConfig(sslConfig.mutateMode(SSLMode.REQUIRE), options) - .onErrorResume(sslAuthError -> { - e.addSuppressed(sslAuthError); - return Mono.error(e); - }) - ) - .onErrorResume( - isAuthSpecificationError.and(e -> sslConfig.getSslMode() == SSLMode.PREFER), - e -> this.tryConnectWithConfig(sslConfig.mutateMode(SSLMode.DISABLE), options) - .onErrorResume(sslAuthError -> { - e.addSuppressed(sslAuthError); - return Mono.error(e); - }) - ) - .flatMap(client -> { - DefaultCodecs codecs = new DefaultCodecs(client.getByteBufAllocator()); + private Throwable cannotConnect(Throwable throwable) { - Mono isolationLevelMono = Mono.just(IsolationLevel.READ_COMMITTED); - if (!forReplication) { - isolationLevelMono = getIsolationLevel(client, codecs); - } + if (throwable instanceof R2dbcException) { + return throwable; + } - return isolationLevelMono - .map(it -> new PostgresqlConnection(client, codecs, DefaultPortalNameSupplier.INSTANCE, new IndefiniteStatementCache(client), it, this.configuration.isForceBinary())) - .delayUntil(connection -> { - return prepareConnection(connection, client.getByteBufAllocator(), codecs); - }) - .onErrorResume(throwable -> this.closeWithError(client, throwable)); - }).onErrorMap(this::cannotConnect); + return new PostgresConnectionException( + String.format("Cannot connect to %s", "TODO"), throwable // TODO + ); } private boolean isReplicationConnection() { @@ -183,15 +131,6 @@ private boolean isReplicationConnection() { return options != null && REPLICATION_DATABASE.equalsIgnoreCase(options.get(REPLICATION_OPTION)); } - private Mono tryConnectWithConfig(SSLConfig sslConfig, @Nullable Map options) { - return this.clientFactory.apply(sslConfig) - .delayUntil(client -> StartupMessageFlow - .exchange(this.configuration.getApplicationName(), this::getAuthenticationHandler, client, this.configuration.getDatabase(), this.configuration.getUsername(), - options) - .handle(ExceptionFactory.INSTANCE::handleErrorResponse)) - .cast(Client.class); - } - private Publisher prepareConnection(PostgresqlConnection connection, ByteBufAllocator byteBufAllocator, DefaultCodecs codecs) { List> publishers = new ArrayList<>(); @@ -208,15 +147,23 @@ private Mono closeWithError(Client client, Throwable throw return client.close().then(Mono.error(throwable)); } - private Throwable cannotConnect(Throwable throwable) { + private Mono doCreateConnection(boolean forReplication, @Nullable Map options) { + return this.clientFactory.apply(options) + .flatMap(client -> { + DefaultCodecs codecs = new DefaultCodecs(client.getByteBufAllocator()); - if (throwable instanceof R2dbcException) { - return throwable; - } + Mono isolationLevelMono = Mono.just(IsolationLevel.READ_COMMITTED); + if (!forReplication) { + isolationLevelMono = getIsolationLevel(client, codecs); + } - return new PostgresConnectionException( - String.format("Cannot connect to %s", this.endpoint), throwable - ); + return isolationLevelMono + .map(it -> new PostgresqlConnection(client, codecs, DefaultPortalNameSupplier.INSTANCE, new IndefiniteStatementCache(client), it, this.configuration.isForceBinary())) + .delayUntil(connection -> { + return prepareConnection(connection, client.getByteBufAllocator(), codecs); + }) + .onErrorResume(throwable -> this.closeWithError(client, throwable)); + }).onErrorMap(this::cannotConnect); } @Override @@ -237,18 +184,6 @@ public String toString() { '}'; } - private AuthenticationHandler getAuthenticationHandler(AuthenticationMessage message) { - if (PasswordAuthenticationHandler.supports(message)) { - CharSequence password = Assert.requireNonNull(this.configuration.getPassword(), "Password must not be null"); - return new PasswordAuthenticationHandler(password, this.configuration.getUsername()); - } else if (SASLAuthenticationHandler.supports(message)) { - CharSequence password = Assert.requireNonNull(this.configuration.getPassword(), "Password must not be null"); - return new SASLAuthenticationHandler(password, this.configuration.getUsername()); - } else { - throw new IllegalStateException(String.format("Unable to provide AuthenticationHandler capable of handling %s", message)); - } - } - private Mono getIsolationLevel(Client client, DefaultCodecs codecs) { return new SimpleQueryPostgresqlStatement(client, codecs, "SHOW TRANSACTION ISOLATION LEVEL") .execute() diff --git a/src/main/java/io/r2dbc/postgresql/TargetServerType.java b/src/main/java/io/r2dbc/postgresql/TargetServerType.java new file mode 100644 index 00000000..0ac08656 --- /dev/null +++ b/src/main/java/io/r2dbc/postgresql/TargetServerType.java @@ -0,0 +1,53 @@ +package io.r2dbc.postgresql; + +import javax.annotation.Nullable; + +public enum TargetServerType { + ANY("any") { + @Override + public boolean allowStatus(ClientFactory.HostStatus hostStatus) { + return hostStatus != ClientFactory.HostStatus.CONNECT_FAIL; + } + }, + MASTER("master") { + @Override + public boolean allowStatus(ClientFactory.HostStatus hostStatus) { + return hostStatus == ClientFactory.HostStatus.PRIMARY || hostStatus == ClientFactory.HostStatus.CONNECT_OK; + } + }, + SECONDARY("secondary") { + @Override + public boolean allowStatus(ClientFactory.HostStatus hostStatus) { + return hostStatus == ClientFactory.HostStatus.STANDBY || hostStatus == ClientFactory.HostStatus.CONNECT_OK; + } + }, + PREFER_SECONDARY("preferSecondary") { + @Override + public boolean allowStatus(ClientFactory.HostStatus hostStatus) { + return hostStatus == ClientFactory.HostStatus.STANDBY || hostStatus == ClientFactory.HostStatus.CONNECT_OK; + } + }; + + private final String value; + + TargetServerType(String value) { + this.value = value; + } + + @Nullable + public static TargetServerType fromValue(String value) { + String fixedValue = value.replace("lave", "econdary"); + for (TargetServerType type : values()) { + if (type.value.equals(fixedValue)) { + return type; + } + } + return null; + } + + public String getValue() { + return value; + } + + public abstract boolean allowStatus(ClientFactory.HostStatus hostStatus); +} diff --git a/src/test/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryTest.java b/src/test/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryTest.java index 7ab805b3..f2e7777b 100644 --- a/src/test/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryTest.java +++ b/src/test/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryTest.java @@ -49,7 +49,7 @@ void constructorNoClientFactory() { .password("test-password") .username("test-username") .build())) - .withMessage("clientFactory must not be null"); + .withMessage("connectionSupplier must not be null"); } @Test @@ -77,7 +77,7 @@ void createAuthenticationMD5Password() { .password("test-password") .build(); - new PostgresqlConnectionFactory(c -> Mono.just(client), configuration) + new PostgresqlConnectionFactory((a, b, c) -> Mono.just(client), configuration) .create() .as(StepVerifier::create) .expectNextCount(1) @@ -128,7 +128,7 @@ void createError() { .password("test-password") .build(); - new PostgresqlConnectionFactory(c -> Mono.just(client), configuration).create() + new PostgresqlConnectionFactory((a, b, c) -> Mono.just(client), configuration).create() .as(StepVerifier::create) .verifyErrorMatches(R2dbcNonTransientResourceException.class::isInstance); } @@ -152,7 +152,7 @@ void getMetadata() { .password("test-password") .build(); - assertThat(new PostgresqlConnectionFactory(c -> Mono.just(client), configuration).getMetadata()).isNotNull(); + assertThat(new PostgresqlConnectionFactory((a, b, c) -> Mono.just(client), configuration).getMetadata()).isNotNull(); } } diff --git a/src/test/java/io/r2dbc/postgresql/client/HighAvailabilityClusterTest.java b/src/test/java/io/r2dbc/postgresql/client/HighAvailabilityClusterTest.java index c7a7fc4b..9514d718 100644 --- a/src/test/java/io/r2dbc/postgresql/client/HighAvailabilityClusterTest.java +++ b/src/test/java/io/r2dbc/postgresql/client/HighAvailabilityClusterTest.java @@ -1,9 +1,17 @@ package io.r2dbc.postgresql.client; +import io.r2dbc.postgresql.PostgresqlConnectionConfiguration; +import io.r2dbc.postgresql.PostgresqlConnectionFactory; +import io.r2dbc.postgresql.TargetServerType; +import io.r2dbc.postgresql.api.PostgresqlConnection; import io.r2dbc.postgresql.util.PostgresqlHighAvailabilityClusterExtension; +import io.r2dbc.spi.R2dbcException; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; +import org.testcontainers.containers.PostgreSQLContainer; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; public class HighAvailabilityClusterTest { @@ -15,4 +23,177 @@ void testPrimaryAndStandbyStartup() { Assertions.assertFalse(SERVERS.getPrimaryJdbc().queryForObject("show transaction_read_only", Boolean.class)); Assertions.assertTrue(SERVERS.getStandbyJdbc().queryForObject("show transaction_read_only", Boolean.class)); } + + @Test + void testMultipleCallsOnSameFactory() { + PostgresqlConnectionFactory connectionFactory = this.multipleHostsConnectionFactory(TargetServerType.PREFER_SECONDARY, SERVERS.getPrimary(), SERVERS.getStandby()); + + connectionFactory + .create() + .flatMapMany(connection -> this.isPrimary(connection) + .concatWith(connection.close().then(Mono.empty()))) + .next() + .as(StepVerifier::create) + .expectNext(false) + .verifyComplete(); + + connectionFactory + .create() + .flatMapMany(connection -> this.isPrimary(connection) + .concatWith(connection.close().then(Mono.empty()))) + .next() + .as(StepVerifier::create) + .expectNext(false) + .verifyComplete(); + } + + @Test + void testTargetAnyChooseFirst() { + isConnectedToPrimary(TargetServerType.ANY, SERVERS.getPrimary(), SERVERS.getStandby()) + .as(StepVerifier::create) + .expectNext(true) + .verifyComplete(); + + isConnectedToPrimary(TargetServerType.ANY, SERVERS.getStandby(), SERVERS.getPrimary()) + .as(StepVerifier::create) + .expectNext(false) + .verifyComplete(); + } + + @Test + void testTargetAnyConnectedToPrimary() { + isConnectedToPrimary(TargetServerType.ANY, SERVERS.getPrimary()) + .as(StepVerifier::create) + .expectNext(true) + .verifyComplete(); + } + + @Test + void testTargetAnyConnectedToStandby() { + isConnectedToPrimary(TargetServerType.ANY, SERVERS.getStandby()) + .as(StepVerifier::create) + .expectNext(false) + .verifyComplete(); + } + + @Test + void testTargetPreferSecondaryChooseStandby() { + isConnectedToPrimary(TargetServerType.PREFER_SECONDARY, SERVERS.getStandby(), SERVERS.getPrimary()) + .as(StepVerifier::create) + .expectNext(false) + .verifyComplete(); + + isConnectedToPrimary(TargetServerType.PREFER_SECONDARY, SERVERS.getPrimary(), SERVERS.getStandby()) + .as(StepVerifier::create) + .expectNext(false) + .verifyComplete(); + } + + @Test + void testTargetPreferSecondaryConnectedToPrimary() { + isConnectedToPrimary(TargetServerType.PREFER_SECONDARY, SERVERS.getPrimary()) + .as(StepVerifier::create) + .expectNext(true) + .verifyComplete(); + } + + @Test + void testTargetPreferSecondaryConnectedToStandby() { + isConnectedToPrimary(TargetServerType.PREFER_SECONDARY, SERVERS.getStandby()) + .as(StepVerifier::create) + .expectNext(false) + .verifyComplete(); + } + + @Test + void testTargetPrimaryChoosePrimary() { + isConnectedToPrimary(TargetServerType.MASTER, SERVERS.getPrimary(), SERVERS.getStandby()) + .as(StepVerifier::create) + .expectNext(true) + .verifyComplete(); + + isConnectedToPrimary(TargetServerType.MASTER, SERVERS.getStandby(), SERVERS.getPrimary()) + .as(StepVerifier::create) + .expectNext(true) + .verifyComplete(); + } + + @Test + void testTargetPrimaryConnectedOnPrimary() { + isConnectedToPrimary(TargetServerType.MASTER, SERVERS.getPrimary()) + .as(StepVerifier::create) + .expectNext(true) + .verifyComplete(); + } + + @Test + void testTargetPrimaryFailedOnStandby() { + isConnectedToPrimary(TargetServerType.MASTER, SERVERS.getStandby()) + .as(StepVerifier::create) + .verifyError(R2dbcException.class); + } + + @Test + void testTargetSecondaryChooseStandby() { + isConnectedToPrimary(TargetServerType.SECONDARY, SERVERS.getStandby(), SERVERS.getPrimary()) + .as(StepVerifier::create) + .expectNext(false) + .verifyComplete(); + + isConnectedToPrimary(TargetServerType.SECONDARY, SERVERS.getPrimary(), SERVERS.getStandby()) + .as(StepVerifier::create) + .expectNext(false) + .verifyComplete(); + } + + @Test + void testTargetSecondaryConnectedOnStandby() { + isConnectedToPrimary(TargetServerType.SECONDARY, SERVERS.getStandby()) + .as(StepVerifier::create) + .expectNext(false) + .verifyComplete(); + } + + @Test + void testTargetSecondaryFailedOnPrimary() { + isConnectedToPrimary(TargetServerType.SECONDARY, SERVERS.getPrimary()) + .as(StepVerifier::create) + .verifyError(R2dbcException.class); + } + + private Mono isConnectedToPrimary(TargetServerType targetServerType, PostgreSQLContainer... servers) { + PostgresqlConnectionFactory connectionFactory = this.multipleHostsConnectionFactory(targetServerType, servers); + + return connectionFactory + .create() + .flatMapMany(connection -> this.isPrimary(connection) + .concatWith(connection.close().then(Mono.empty()))) + .next(); + } + + private Mono isPrimary(PostgresqlConnection connection) { + return connection.createStatement("show transaction_read_only") + .execute() + .flatMap(result -> result.map((row, meta) -> row.get(0, String.class))) + .map(str -> str.equalsIgnoreCase("off")) + .next(); + } + + private PostgresqlConnectionFactory multipleHostsConnectionFactory(TargetServerType targetServerType, PostgreSQLContainer... servers) { + PostgreSQLContainer server = servers[0]; + String[] hosts = new String[servers.length]; + int[] ports = new int[servers.length]; + for (int i = 0; i < servers.length; i++) { + hosts[i] = servers[i].getContainerIpAddress(); + ports[i] = servers[i].getMappedPort(5432); + } + PostgresqlConnectionConfiguration configuration = PostgresqlConnectionConfiguration.builder() + .tmpHosts(hosts) + .tmpPorts(ports) + .username(server.getUsername()) + .password(server.getPassword()) + .targetServerType(targetServerType) + .build(); + return new PostgresqlConnectionFactory(configuration); + } } From 4bcc6970c081a7cb79c4116851c78ed17e27a881 Mon Sep 17 00:00:00 2001 From: Anton Duyun Date: Tue, 26 Nov 2019 21:09:45 +0300 Subject: [PATCH 3/5] split configuration for single/multiple hosts --- .../io/r2dbc/postgresql/ClientFactory.java | 262 +----------------- .../r2dbc/postgresql/ClientFactoryBase.java | 71 +++++ .../io/r2dbc/postgresql/ClientSupplier.java | 14 + .../MultipleHostsClientFactory.java | 191 +++++++++++++ .../PostgresqlConnectionConfiguration.java | 215 +++++--------- .../PostgresqlConnectionFactory.java | 18 +- .../postgresql/SingleHostClientFactory.java | 40 +++ .../io/r2dbc/postgresql/TargetServerType.java | 18 +- .../client/MultipleHostsConfiguration.java | 59 ++++ .../client/SingleHostConfiguration.java | 61 ++++ ...stgresqlConnectionFactoryProviderTest.java | 4 +- .../PostgresqlConnectionFactoryTest.java | 11 +- .../client/HighAvailabilityClusterTest.java | 26 +- 13 files changed, 553 insertions(+), 437 deletions(-) create mode 100644 src/main/java/io/r2dbc/postgresql/ClientFactoryBase.java create mode 100644 src/main/java/io/r2dbc/postgresql/ClientSupplier.java create mode 100644 src/main/java/io/r2dbc/postgresql/MultipleHostsClientFactory.java create mode 100644 src/main/java/io/r2dbc/postgresql/SingleHostClientFactory.java create mode 100644 src/main/java/io/r2dbc/postgresql/client/MultipleHostsConfiguration.java create mode 100644 src/main/java/io/r2dbc/postgresql/client/SingleHostConfiguration.java diff --git a/src/main/java/io/r2dbc/postgresql/ClientFactory.java b/src/main/java/io/r2dbc/postgresql/ClientFactory.java index 7055d31c..7929dad4 100644 --- a/src/main/java/io/r2dbc/postgresql/ClientFactory.java +++ b/src/main/java/io/r2dbc/postgresql/ClientFactory.java @@ -1,266 +1,22 @@ package io.r2dbc.postgresql; -import io.netty.channel.unix.DomainSocketAddress; -import io.r2dbc.postgresql.authentication.AuthenticationHandler; -import io.r2dbc.postgresql.authentication.PasswordAuthenticationHandler; -import io.r2dbc.postgresql.authentication.SASLAuthenticationHandler; import io.r2dbc.postgresql.client.Client; -import io.r2dbc.postgresql.client.SSLConfig; -import io.r2dbc.postgresql.client.SSLMode; -import io.r2dbc.postgresql.client.StartupMessageFlow; -import io.r2dbc.postgresql.codec.DefaultCodecs; -import io.r2dbc.postgresql.message.backend.AuthenticationMessage; -import io.r2dbc.postgresql.util.Assert; -import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import reactor.util.annotation.Nullable; -import javax.annotation.Nullable; -import java.net.InetSocketAddress; -import java.net.SocketAddress; -import java.time.Duration; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.atomic.AtomicReference; -import java.util.function.Function; -import java.util.function.Predicate; -import static io.r2dbc.postgresql.TargetServerType.ANY; -import static io.r2dbc.postgresql.TargetServerType.MASTER; -import static io.r2dbc.postgresql.TargetServerType.PREFER_SECONDARY; -import static io.r2dbc.postgresql.TargetServerType.SECONDARY; +public interface ClientFactory { -class ClientFactory implements Function, Mono> { - - private final List addresses; - - private final PostgresqlConnectionConfiguration configuration; - - private final Map statusMap = new ConcurrentHashMap<>(); - - private final ConnectionSupplier connectionSupplier; - - public ClientFactory(PostgresqlConnectionConfiguration configuration, ConnectionSupplier connectionSupplier) { - this.configuration = configuration; - this.addresses = ClientFactory.createSocketAddress(this.configuration); - this.connectionSupplier = connectionSupplier; - } - - @Override - public Mono apply(Map options) { - AtomicReference exceptionRef = new AtomicReference<>(); - TargetServerType targetServerType = this.configuration.getTargetServerType(); - return this.tryConnect(targetServerType, options) - .onErrorResume(e -> this.addresses.size() > 1, e -> { - if (!exceptionRef.compareAndSet(null, e)) { - exceptionRef.get().addSuppressed(e); - } - return Mono.empty(); - }) - .switchIfEmpty(Mono.defer(() -> targetServerType == PREFER_SECONDARY - ? this.tryConnect(MASTER, options) - : Mono.empty())) - .switchIfEmpty(Mono.error(() -> { - Throwable error = exceptionRef.get(); - if (error == null) { - return new PostgresqlConnectionFactory.PostgresConnectionException(String.format("No server matches target type %s", targetServerType.getValue()), null); - } else { - return error; - } - })); - } - - public Mono tryConnect(TargetServerType targetServerType, @Nullable Map options) { - AtomicReference exceptionRef = new AtomicReference<>(); - return this.getCandidates(targetServerType).concatMap(candidate -> this.tryConnectToCandidate(targetServerType, candidate, options) - .onErrorResume(e -> this.addresses.size() > 1, e -> { - if (!exceptionRef.compareAndSet(null, e)) { - exceptionRef.get().addSuppressed(e); - } - this.statusMap.put(candidate, HostSpecStatus.fail(candidate)); - return Mono.empty(); - })) - .next() - .switchIfEmpty(Mono.defer(() -> exceptionRef.get() != null - ? Mono.error(exceptionRef.get()) - : Mono.empty())); - } - - private static List createSocketAddress(PostgresqlConnectionConfiguration configuration) { - if (!configuration.isUseSocket()) { - if (configuration.getTmpHosts() != null) { - String[] hosts = configuration.getTmpHosts(); - int[] ports = configuration.getTmpPorts(); - List addressList = new ArrayList<>(hosts.length); - for (int i = 0; i < hosts.length; i++) { - String host = hosts[i]; - int port = ports[i]; - addressList.add(InetSocketAddress.createUnresolved(host, port)); - } - return addressList; - } else { - return Collections.singletonList(InetSocketAddress.createUnresolved(configuration.getRequiredHost(), configuration.getPort())); - } - } - - if (configuration.isUseSocket()) { - return Collections.singletonList(new DomainSocketAddress(configuration.getRequiredSocket())); + static ClientFactory getFactory(PostgresqlConnectionConfiguration configuration, ClientSupplier clientSupplier) { + if (configuration.getSingleHostConfiguration() != null) { + return new SingleHostClientFactory(configuration, clientSupplier); } - - throw new IllegalArgumentException("Cannot create SocketAddress for " + configuration); - } - - private static HostSpecStatus evaluateStatus(SocketAddress candidate, @Nullable HostSpecStatus oldStatus) { - return oldStatus == null || oldStatus.hostStatus == HostStatus.CONNECT_FAIL - ? HostSpecStatus.ok(candidate) - : oldStatus; - } - - private static Mono isPrimaryServer(Client client) { - return new SimpleQueryPostgresqlStatement(client, new DefaultCodecs(client.getByteBufAllocator()), "show transaction_read_only") - .execute() - .flatMap(result -> result.map((row, rowMetadata) -> row.get(0, String.class))) - .map(s -> s.equalsIgnoreCase("off")) - .next(); - } - - private AuthenticationHandler getAuthenticationHandler(AuthenticationMessage message) { - if (PasswordAuthenticationHandler.supports(message)) { - CharSequence password = Assert.requireNonNull(this.configuration.getPassword(), "Password must not be null"); - return new PasswordAuthenticationHandler(password, this.configuration.getUsername()); - } else if (SASLAuthenticationHandler.supports(message)) { - CharSequence password = Assert.requireNonNull(this.configuration.getPassword(), "Password must not be null"); - return new SASLAuthenticationHandler(password, this.configuration.getUsername()); - } else { - throw new IllegalStateException(String.format("Unable to provide AuthenticationHandler capable of handling %s", message)); + if (configuration.getMultipleHostsConfiguration() != null) { + return new MultipleHostsClientFactory(configuration, clientSupplier); } + throw new IllegalArgumentException("Can't build client factory based on configuration " + configuration); } - private Flux getCandidates(TargetServerType targetServerType) { - return Flux.create(sink -> { - if (this.addresses.size() == 1) { - sink.next(this.addresses.get(0)); - sink.complete(); - return; - } - long now = System.currentTimeMillis(); - List addresses = new ArrayList<>(this.addresses); - if (this.configuration.isLoadBalance()) { - Collections.shuffle(addresses); - } - for (SocketAddress address : addresses) { - HostSpecStatus currentStatus = this.statusMap.get(address); - if (currentStatus == null || now > currentStatus.updated + this.configuration.getHostRecheckTime()) { - sink.next(address); - } else if (targetServerType.allowStatus(currentStatus.hostStatus)) { - sink.next(address); - } - } - sink.complete(); - }); - } - - private Mono tryConnectWithConfig(SSLConfig sslConfig, SocketAddress endpoint, @Nullable Map options) { - return this.connectionSupplier.connect(endpoint, this.configuration.getConnectTimeout(), sslConfig) - .delayUntil(client -> StartupMessageFlow - .exchange(this.configuration.getApplicationName(), this::getAuthenticationHandler, client, this.configuration.getDatabase(), this.configuration.getUsername(), options) - .handle(ExceptionFactory.INSTANCE::handleErrorResponse)); - } - - private Mono tryConnectToCandidate(TargetServerType targetServerType, SocketAddress candidate, @Nullable Map options) { - return Mono.create(sink -> this.tryConnectToEndpoint(candidate, options).subscribe(client -> { - this.statusMap.compute(candidate, (a, oldStatus) -> ClientFactory.evaluateStatus(candidate, oldStatus)); - if (targetServerType == ANY) { - sink.success(client); - return; - } - ClientFactory.isPrimaryServer(client).subscribe( - isPrimary -> { - if (isPrimary) { - this.statusMap.put(candidate, HostSpecStatus.primary(candidate)); - } else { - this.statusMap.put(candidate, HostSpecStatus.standby(candidate)); - } - if (isPrimary && targetServerType == MASTER) { - sink.success(client); - } else if (!isPrimary && (targetServerType == SECONDARY || targetServerType == PREFER_SECONDARY)) { - sink.success(client); - } else { - client.close().subscribe(v -> sink.success(), sink::error, sink::success, sink.currentContext()); - } - }, - sink::error, - () -> { - }, - sink.currentContext() - ); - }, sink::error, () -> { - }, sink.currentContext())); - } - - private Mono tryConnectToEndpoint(SocketAddress endpoint, @Nullable Map options) { - SSLConfig sslConfig = this.configuration.getSslConfig(); - Predicate isAuthSpecificationError = e -> e instanceof ExceptionFactory.PostgresqlAuthenticationFailure; - return this.tryConnectWithConfig(sslConfig, endpoint, options) - .onErrorResume( - isAuthSpecificationError.and(e -> sslConfig.getSslMode() == SSLMode.ALLOW), - e -> this.tryConnectWithConfig(sslConfig.mutateMode(SSLMode.REQUIRE), endpoint, options) - .onErrorResume(sslAuthError -> { - e.addSuppressed(sslAuthError); - return Mono.error(e); - }) - ) - .onErrorResume( - isAuthSpecificationError.and(e -> sslConfig.getSslMode() == SSLMode.PREFER), - e -> this.tryConnectWithConfig(sslConfig.mutateMode(SSLMode.DISABLE), endpoint, options) - .onErrorResume(sslAuthError -> { - e.addSuppressed(sslAuthError); - return Mono.error(e); - }) - ); - } - - public interface ConnectionSupplier { - - Mono connect(SocketAddress endpoint, @Nullable Duration connectTimeout, SSLConfig sslConfig); - } - - enum HostStatus { - CONNECT_FAIL, - CONNECT_OK, - PRIMARY, - STANDBY - } - - private static class HostSpecStatus { - - public final SocketAddress address; - - public final HostStatus hostStatus; - - public final long updated = System.currentTimeMillis(); - - private HostSpecStatus(SocketAddress address, HostStatus hostStatus) { - this.address = address; - this.hostStatus = hostStatus; - } - - public static HostSpecStatus fail(SocketAddress host) { - return new HostSpecStatus(host, HostStatus.CONNECT_FAIL); - } - - public static HostSpecStatus ok(SocketAddress host) { - return new HostSpecStatus(host, HostStatus.CONNECT_OK); - } - - public static HostSpecStatus primary(SocketAddress host) { - return new HostSpecStatus(host, HostStatus.PRIMARY); - } - - public static HostSpecStatus standby(SocketAddress host) { - return new HostSpecStatus(host, HostStatus.STANDBY); - } - } + Mono create(@Nullable Map options); } diff --git a/src/main/java/io/r2dbc/postgresql/ClientFactoryBase.java b/src/main/java/io/r2dbc/postgresql/ClientFactoryBase.java new file mode 100644 index 00000000..7fbeec7d --- /dev/null +++ b/src/main/java/io/r2dbc/postgresql/ClientFactoryBase.java @@ -0,0 +1,71 @@ +package io.r2dbc.postgresql; + +import io.r2dbc.postgresql.authentication.AuthenticationHandler; +import io.r2dbc.postgresql.authentication.PasswordAuthenticationHandler; +import io.r2dbc.postgresql.authentication.SASLAuthenticationHandler; +import io.r2dbc.postgresql.client.Client; +import io.r2dbc.postgresql.client.SSLConfig; +import io.r2dbc.postgresql.client.SSLMode; +import io.r2dbc.postgresql.client.StartupMessageFlow; +import io.r2dbc.postgresql.message.backend.AuthenticationMessage; +import io.r2dbc.postgresql.util.Assert; +import reactor.core.publisher.Mono; + +import javax.annotation.Nullable; +import java.net.SocketAddress; +import java.util.Map; +import java.util.function.Predicate; + +public abstract class ClientFactoryBase implements ClientFactory { + + private final ClientSupplier clientSupplier; + + private final PostgresqlConnectionConfiguration configuration; + + protected ClientFactoryBase(PostgresqlConnectionConfiguration configuration, ClientSupplier clientSupplier) { + this.configuration = configuration; + this.clientSupplier = clientSupplier; + } + + protected AuthenticationHandler getAuthenticationHandler(AuthenticationMessage message) { + if (PasswordAuthenticationHandler.supports(message)) { + CharSequence password = Assert.requireNonNull(this.configuration.getPassword(), "Password must not be null"); + return new PasswordAuthenticationHandler(password, this.configuration.getUsername()); + } else if (SASLAuthenticationHandler.supports(message)) { + CharSequence password = Assert.requireNonNull(this.configuration.getPassword(), "Password must not be null"); + return new SASLAuthenticationHandler(password, this.configuration.getUsername()); + } else { + throw new IllegalStateException(String.format("Unable to provide AuthenticationHandler capable of handling %s", message)); + } + } + + protected Mono tryConnectToEndpoint(SocketAddress endpoint, @Nullable Map options) { + SSLConfig sslConfig = this.configuration.getSslConfig(); + Predicate isAuthSpecificationError = e -> e instanceof ExceptionFactory.PostgresqlAuthenticationFailure; + return this.tryConnectWithConfig(sslConfig, endpoint, options) + .onErrorResume( + isAuthSpecificationError.and(e -> sslConfig.getSslMode() == SSLMode.ALLOW), + e -> this.tryConnectWithConfig(sslConfig.mutateMode(SSLMode.REQUIRE), endpoint, options) + .onErrorResume(sslAuthError -> { + e.addSuppressed(sslAuthError); + return Mono.error(e); + }) + ) + .onErrorResume( + isAuthSpecificationError.and(e -> sslConfig.getSslMode() == SSLMode.PREFER), + e -> this.tryConnectWithConfig(sslConfig.mutateMode(SSLMode.DISABLE), endpoint, options) + .onErrorResume(sslAuthError -> { + e.addSuppressed(sslAuthError); + return Mono.error(e); + }) + ); + } + + protected Mono tryConnectWithConfig(SSLConfig sslConfig, SocketAddress endpoint, @Nullable Map options) { + return this.clientSupplier.connect(endpoint, this.configuration.getConnectTimeout(), sslConfig) + .delayUntil(client -> StartupMessageFlow + .exchange(this.configuration.getApplicationName(), this::getAuthenticationHandler, client, this.configuration.getDatabase(), this.configuration.getUsername(), options) + .handle(ExceptionFactory.INSTANCE::handleErrorResponse)); + } + +} diff --git a/src/main/java/io/r2dbc/postgresql/ClientSupplier.java b/src/main/java/io/r2dbc/postgresql/ClientSupplier.java new file mode 100644 index 00000000..21f1a5d8 --- /dev/null +++ b/src/main/java/io/r2dbc/postgresql/ClientSupplier.java @@ -0,0 +1,14 @@ +package io.r2dbc.postgresql; + +import io.r2dbc.postgresql.client.Client; +import io.r2dbc.postgresql.client.SSLConfig; +import reactor.core.publisher.Mono; + +import javax.annotation.Nullable; +import java.net.SocketAddress; +import java.time.Duration; + +public interface ClientSupplier { + + Mono connect(SocketAddress endpoint, @Nullable Duration connectTimeout, SSLConfig sslConfig); +} diff --git a/src/main/java/io/r2dbc/postgresql/MultipleHostsClientFactory.java b/src/main/java/io/r2dbc/postgresql/MultipleHostsClientFactory.java new file mode 100644 index 00000000..645c168c --- /dev/null +++ b/src/main/java/io/r2dbc/postgresql/MultipleHostsClientFactory.java @@ -0,0 +1,191 @@ +package io.r2dbc.postgresql; + +import io.r2dbc.postgresql.client.Client; +import io.r2dbc.postgresql.client.MultipleHostsConfiguration; +import io.r2dbc.postgresql.codec.DefaultCodecs; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import javax.annotation.Nullable; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicReference; + +import static io.r2dbc.postgresql.TargetServerType.ANY; +import static io.r2dbc.postgresql.TargetServerType.MASTER; +import static io.r2dbc.postgresql.TargetServerType.PREFER_SECONDARY; +import static io.r2dbc.postgresql.TargetServerType.SECONDARY; + +class MultipleHostsClientFactory extends ClientFactoryBase { + + private final List addresses; + + private final MultipleHostsConfiguration configuration; + + private final Map statusMap = new ConcurrentHashMap<>(); + + public MultipleHostsClientFactory(PostgresqlConnectionConfiguration configuration, ClientSupplier clientSupplier) { + super(configuration, clientSupplier); + this.configuration = configuration.getMultipleHostsConfiguration(); + this.addresses = MultipleHostsClientFactory.createSocketAddress(this.configuration); + } + + @Override + public Mono create(@Nullable Map options) { + AtomicReference exceptionRef = new AtomicReference<>(); + TargetServerType targetServerType = this.configuration.getTargetServerType(); + return this.tryConnect(targetServerType, options) + .onErrorResume(e -> { + if (!exceptionRef.compareAndSet(null, e)) { + exceptionRef.get().addSuppressed(e); + } + return Mono.empty(); + }) + .switchIfEmpty(Mono.defer(() -> targetServerType == PREFER_SECONDARY + ? this.tryConnect(MASTER, options) + : Mono.empty())) + .switchIfEmpty(Mono.error(() -> { + Throwable error = exceptionRef.get(); + if (error == null) { + return new PostgresqlConnectionFactory.PostgresConnectionException(String.format("No server matches target type %s", targetServerType.getValue()), null); + } else { + return error; + } + })); + } + + public Mono tryConnect(TargetServerType targetServerType, @Nullable Map options) { + AtomicReference exceptionRef = new AtomicReference<>(); + return this.getCandidates(targetServerType).concatMap(candidate -> this.tryConnectToCandidate(targetServerType, candidate, options) + .onErrorResume(e -> { + if (!exceptionRef.compareAndSet(null, e)) { + exceptionRef.get().addSuppressed(e); + } + this.statusMap.put(candidate, HostSpecStatus.fail(candidate)); + return Mono.empty(); + })) + .next() + .switchIfEmpty(Mono.defer(() -> exceptionRef.get() != null + ? Mono.error(exceptionRef.get()) + : Mono.empty())); + } + + private static List createSocketAddress(MultipleHostsConfiguration configuration) { + List addressList = new ArrayList<>(configuration.getHosts().size()); + for (MultipleHostsConfiguration.ServerHost host : configuration.getHosts()) { + addressList.add(InetSocketAddress.createUnresolved(host.getHost(), host.getPort())); + } + return addressList; + } + + private static HostSpecStatus evaluateStatus(SocketAddress candidate, @Nullable HostSpecStatus oldStatus) { + return oldStatus == null || oldStatus.hostStatus == HostStatus.CONNECT_FAIL + ? HostSpecStatus.ok(candidate) + : oldStatus; + } + + private static Mono isPrimaryServer(Client client) { + return new SimpleQueryPostgresqlStatement(client, new DefaultCodecs(client.getByteBufAllocator()), "show transaction_read_only") + .execute() + .flatMap(result -> result.map((row, rowMetadata) -> row.get(0, String.class))) + .map(s -> s.equalsIgnoreCase("off")) + .next(); + } + + private Flux getCandidates(TargetServerType targetServerType) { + return Flux.create(sink -> { + if (this.addresses.size() == 1) { + sink.next(this.addresses.get(0)); + sink.complete(); + return; + } + long now = System.currentTimeMillis(); + List addresses = new ArrayList<>(this.addresses); + if (this.configuration.isLoadBalance()) { + Collections.shuffle(addresses); + } + for (SocketAddress address : addresses) { + HostSpecStatus currentStatus = this.statusMap.get(address); + if (currentStatus == null || now > currentStatus.updated + this.configuration.getHostRecheckTime()) { + sink.next(address); + } else if (targetServerType.allowStatus(currentStatus.hostStatus)) { + sink.next(address); + } + } + sink.complete(); + }); + } + + private Mono tryConnectToCandidate(TargetServerType targetServerType, SocketAddress candidate, @Nullable Map options) { + return Mono.create(sink -> this.tryConnectToEndpoint(candidate, options).subscribe(client -> { + this.statusMap.compute(candidate, (a, oldStatus) -> MultipleHostsClientFactory.evaluateStatus(candidate, oldStatus)); + if (targetServerType == ANY) { + sink.success(client); + return; + } + MultipleHostsClientFactory.isPrimaryServer(client).subscribe( + isPrimary -> { + if (isPrimary) { + this.statusMap.put(candidate, HostSpecStatus.primary(candidate)); + } else { + this.statusMap.put(candidate, HostSpecStatus.standby(candidate)); + } + if (isPrimary && targetServerType == MASTER) { + sink.success(client); + } else if (!isPrimary && (targetServerType == SECONDARY || targetServerType == PREFER_SECONDARY)) { + sink.success(client); + } else { + client.close().subscribe(v -> sink.success(), sink::error, sink::success, sink.currentContext()); + } + }, + sink::error, + () -> { + }, + sink.currentContext() + ); + }, sink::error, () -> { + }, sink.currentContext())); + } + + enum HostStatus { + CONNECT_FAIL, + CONNECT_OK, + PRIMARY, + STANDBY + } + + private static class HostSpecStatus { + + public final SocketAddress address; + + public final HostStatus hostStatus; + + public final long updated = System.currentTimeMillis(); + + private HostSpecStatus(SocketAddress address, HostStatus hostStatus) { + this.address = address; + this.hostStatus = hostStatus; + } + + public static HostSpecStatus fail(SocketAddress host) { + return new HostSpecStatus(host, HostStatus.CONNECT_FAIL); + } + + public static HostSpecStatus ok(SocketAddress host) { + return new HostSpecStatus(host, HostStatus.CONNECT_OK); + } + + public static HostSpecStatus primary(SocketAddress host) { + return new HostSpecStatus(host, HostStatus.PRIMARY); + } + + public static HostSpecStatus standby(SocketAddress host) { + return new HostSpecStatus(host, HostStatus.STANDBY); + } + } +} diff --git a/src/main/java/io/r2dbc/postgresql/PostgresqlConnectionConfiguration.java b/src/main/java/io/r2dbc/postgresql/PostgresqlConnectionConfiguration.java index 8883e976..1fed357a 100644 --- a/src/main/java/io/r2dbc/postgresql/PostgresqlConnectionConfiguration.java +++ b/src/main/java/io/r2dbc/postgresql/PostgresqlConnectionConfiguration.java @@ -20,8 +20,10 @@ import io.netty.handler.ssl.SslContextBuilder; import io.netty.handler.ssl.util.InsecureTrustManagerFactory; import io.r2dbc.postgresql.client.DefaultHostnameVerifier; +import io.r2dbc.postgresql.client.MultipleHostsConfiguration; import io.r2dbc.postgresql.client.SSLConfig; import io.r2dbc.postgresql.client.SSLMode; +import io.r2dbc.postgresql.client.SingleHostConfiguration; import io.r2dbc.postgresql.codec.Codec; import io.r2dbc.postgresql.extension.CodecRegistrar; import io.r2dbc.postgresql.extension.Extension; @@ -63,55 +65,40 @@ public final class PostgresqlConnectionConfiguration { private final boolean forceBinary; - private final String host; - private final Map options; private final CharSequence password; - private final int port; - private final String schema; - private final String socket; private final String username; private final SSLConfig sslConfig; - private final int hostRecheckTime; - - private final boolean loadBalance; + private final MultipleHostsConfiguration multipleHostsConfiguration; - private final TargetServerType targetServerType; - - private final String[] tmpHosts; - - private final int[] tmpPorts; + private final SingleHostConfiguration singleHostConfiguration; private PostgresqlConnectionConfiguration(String applicationName, boolean autodetectExtensions, - @Nullable Duration connectTimeout, @Nullable String database, List extensions, boolean forceBinary, @Nullable String host, - @Nullable Map options, @Nullable CharSequence password, int port, @Nullable String schema, @Nullable String socket, String username, - SSLConfig sslConfig, TargetServerType targetServerType, int hostRecheckTime, boolean loadBalance, @Nullable String[] tmpHosts, @Nullable int[] tmpPorts) { + @Nullable Duration connectTimeout, @Nullable String database, List extensions, boolean forceBinary, + @Nullable Map options, @Nullable CharSequence password, @Nullable String schema, String username, + SSLConfig sslConfig, + @Nullable SingleHostConfiguration singleHostConfiguration, + @Nullable MultipleHostsConfiguration multipleHostsConfiguration) { this.applicationName = Assert.requireNonNull(applicationName, "applicationName must not be null"); this.autodetectExtensions = autodetectExtensions; this.connectTimeout = connectTimeout; this.extensions = Assert.requireNonNull(extensions, "extensions must not be null"); this.database = database; this.forceBinary = forceBinary; - this.host = host; this.options = options; this.password = password; - this.port = port; this.schema = schema; - this.socket = socket; this.username = Assert.requireNonNull(username, "username must not be null"); this.sslConfig = sslConfig; - this.targetServerType = targetServerType; - this.hostRecheckTime = hostRecheckTime; - this.loadBalance = loadBalance; - this.tmpHosts = tmpHosts; - this.tmpPorts = tmpPorts; + this.singleHostConfiguration = singleHostConfiguration; + this.multipleHostsConfiguration = multipleHostsConfiguration; } /** @@ -134,19 +121,29 @@ private static String repeat(int length, String character) { return builder.toString(); } + @Nullable + public MultipleHostsConfiguration getMultipleHostsConfiguration() { + return multipleHostsConfiguration; + } + + @Nullable + public SingleHostConfiguration getSingleHostConfiguration() { + return singleHostConfiguration; + } + @Override public String toString() { return "PostgresqlConnectionConfiguration{" + "applicationName='" + this.applicationName + '\'' + + ", singleHostConfiguration='" + this.singleHostConfiguration + '\'' + + ", multipleHostsConfiguration='" + this.multipleHostsConfiguration + '\'' + ", autodetectExtensions='" + this.autodetectExtensions + '\'' + ", connectTimeout=" + this.connectTimeout + ", database='" + this.database + '\'' + ", extensions=" + this.extensions + ", forceBinary='" + this.forceBinary + '\'' + - ", host='" + this.host + '\'' + ", options='" + this.options + '\'' + ", password='" + repeat(this.password != null ? this.password.length() : 0, "*") + '\'' + - ", port=" + this.port + ", schema='" + this.schema + '\'' + ", username='" + this.username + '\'' + '}'; @@ -170,22 +167,6 @@ List getExtensions() { return this.extensions; } - @Nullable - String getHost() { - return this.host; - } - - String getRequiredHost() { - - String host = getHost(); - - if (host == null || host.isEmpty()) { - throw new IllegalStateException("Connection is configured for socket connections and not for host usage"); - } - - return host; - } - @Nullable Map getOptions() { return this.options; @@ -196,30 +177,11 @@ CharSequence getPassword() { return this.password; } - int getPort() { - return this.port; - } - @Nullable String getSchema() { return this.schema; } - @Nullable - String getSocket() { - return this.socket; - } - - String getRequiredSocket() { - - String socket = getSocket(); - - if (socket == null || socket.isEmpty()) { - throw new IllegalStateException("Connection is configured to use host and port connections and not for socket usage"); - } - - return socket; - } String getUsername() { return this.username; @@ -233,36 +195,10 @@ boolean isForceBinary() { return this.forceBinary; } - boolean isUseSocket() { - return getSocket() != null; - } - SSLConfig getSslConfig() { return this.sslConfig; } - @Nullable - public String[] getTmpHosts() { - return tmpHosts; - } - - @Nullable - public int[] getTmpPorts() { - return tmpPorts; - } - - int getHostRecheckTime() { - return this.hostRecheckTime; - } - - TargetServerType getTargetServerType() { - return this.targetServerType; - } - - boolean isLoadBalance() { - return this.loadBalance; - } - /** * A builder for {@link PostgresqlConnectionConfiguration} instances. *

@@ -285,21 +221,19 @@ public static final class Builder { private boolean forceBinary = false; @Nullable - private String host; + private MultipleHostsConfiguration multipleHostsConfiguration; + + @Nullable + private SingleHostConfiguration singleHostConfiguration; private Map options; @Nullable private CharSequence password; - private int port = DEFAULT_PORT; - @Nullable private String schema; - @Nullable - private String socket; - @Nullable private String sslCert = null; @@ -321,16 +255,6 @@ public static final class Builder { @Nullable private String username; - private int hostRecheckTime = 10000; - - private boolean loadBalance = false; - - private TargetServerType targetServerType = TargetServerType.ANY; - - private String[] tmpHosts; - - private int[] tmpPorts; - private Builder() { } @@ -364,11 +288,11 @@ public Builder autodetectExtensions(boolean autodetectExtensions) { */ public PostgresqlConnectionConfiguration build() { - if (this.host == null && this.socket == null && this.tmpHosts == null) { + if (this.singleHostConfiguration != null && this.singleHostConfiguration.getHost() == null && this.singleHostConfiguration.getSocket() == null) { throw new IllegalArgumentException("host or socket must not be null"); } - if (this.host != null && this.socket != null) { + if (this.singleHostConfiguration != null && this.singleHostConfiguration.getHost() != null && this.singleHostConfiguration.getSocket() != null) { throw new IllegalArgumentException("Connection must be configured for either host/port or socket usage but not both"); } @@ -376,9 +300,8 @@ public PostgresqlConnectionConfiguration build() { throw new IllegalArgumentException("username must not be null"); } - return new PostgresqlConnectionConfiguration(this.applicationName, this.autodetectExtensions, this.connectTimeout, this.database, this.extensions, this.forceBinary, this.host, - this.options, this.password, this.port, this.schema, this.socket, this.username, this.createSslConfig(), this.targetServerType, this.hostRecheckTime, this.loadBalance, tmpHosts, - tmpPorts); + return new PostgresqlConnectionConfiguration(this.applicationName, this.autodetectExtensions, this.connectTimeout, this.database, this.extensions, this.forceBinary, + this.options, this.password, this.schema, this.username, this.createSslConfig(), this.singleHostConfiguration, this.multipleHostsConfiguration); } /** @@ -452,7 +375,12 @@ public Builder forceBinary(boolean forceBinary) { * @throws IllegalArgumentException if {@code host} is {@code null} */ public Builder host(String host) { - this.host = Assert.requireNonNull(host, "host must not be null"); + Assert.requireNonNull(host, "host must not be null"); + if (this.singleHostConfiguration == null) { + this.singleHostConfiguration = new SingleHostConfiguration(host, DEFAULT_PORT, null); + } else { + this.singleHostConfiguration = new SingleHostConfiguration(host, this.singleHostConfiguration.getPort(), this.singleHostConfiguration.getSocket()); + } return this; } @@ -490,14 +418,8 @@ public Builder password(@Nullable CharSequence password) { return this; } - /** - * Configure the port. Defaults to {@code 5432}. - * - * @param port the port - * @return this {@link Builder} - */ - public Builder port(int port) { - this.port = port; + public Builder multipleHostsConfiguration(MultipleHostsConfiguration multipleHostsConfiguration) { + this.multipleHostsConfiguration = Assert.requireNonNull(multipleHostsConfiguration, "multipleHostsConfiguration must not be null"); return this; } @@ -513,15 +435,17 @@ public Builder schema(@Nullable String schema) { } /** - * Configure the unix domain socket to connect to. + * Configure the port. Defaults to {@code 5432}. * - * @param socket the socket path + * @param port the port * @return this {@link Builder} - * @throws IllegalArgumentException if {@code socket} is {@code null} */ - public Builder socket(String socket) { - this.socket = Assert.requireNonNull(socket, "host must not be null"); - sslMode(SSLMode.DISABLE); + public Builder port(int port) { + if (this.singleHostConfiguration == null) { + this.singleHostConfiguration = new SingleHostConfiguration(null, port, null); + } else { + this.singleHostConfiguration = new SingleHostConfiguration(this.singleHostConfiguration.getHost(), port, this.singleHostConfiguration.getSocket()); + } return this; } @@ -594,11 +518,6 @@ public Builder sslPassword(@Nullable CharSequence sslPassword) { return this; } - public Builder hostRecheckTime(int hostRecheckTime) { - this.hostRecheckTime = hostRecheckTime; - return this; - } - /** * Configure ssl root cert for server certificate validation. * @@ -622,23 +541,23 @@ public Builder username(String username) { return this; } - public Builder loadBalance(boolean loadBalance) { - this.loadBalance = loadBalance; - return this; - } - - public Builder targetServerType(TargetServerType targetServerType) { - this.targetServerType = targetServerType; - return this; - } + /** + * Configure the unix domain socket to connect to. + * + * @param socket the socket path + * @return this {@link Builder} + * @throws IllegalArgumentException if {@code socket} is {@code null} + */ + public Builder socket(String socket) { + Assert.requireNonNull(socket, "host must not be null"); - public Builder tmpHosts(String[] tmpHosts) { - this.tmpHosts = tmpHosts; - return this; - } + if (this.singleHostConfiguration == null) { + this.singleHostConfiguration = new SingleHostConfiguration(null, DEFAULT_PORT, socket); + } else { + this.singleHostConfiguration = new SingleHostConfiguration(this.singleHostConfiguration.getHost(), this.singleHostConfiguration.getPort(), socket); + } - public Builder tmpPorts(int[] tmpPorts) { - this.tmpPorts = tmpPorts; + sslMode(SSLMode.DISABLE); return this; } @@ -646,32 +565,28 @@ public Builder tmpPorts(int[] tmpPorts) { public String toString() { return "Builder{" + "applicationName='" + this.applicationName + '\'' + + ", singleHostConfiguration='" + this.singleHostConfiguration + '\'' + + ", multipleHostsConfiguration='" + this.multipleHostsConfiguration + '\'' + ", autodetectExtensions='" + this.autodetectExtensions + '\'' + ", connectTimeout='" + this.connectTimeout + '\'' + ", database='" + this.database + '\'' + ", extensions='" + this.extensions + '\'' + ", forceBinary='" + this.forceBinary + '\'' + - ", host='" + this.host + '\'' + ", parameters='" + this.options + '\'' + ", password='" + repeat(this.password != null ? this.password.length() : 0, "*") + '\'' + - ", port=" + this.port + ", schema='" + this.schema + '\'' + ", username='" + this.username + '\'' + - ", socket='" + this.socket + '\'' + ", sslContextBuilderCustomizer='" + this.sslContextBuilderCustomizer + '\'' + ", sslMode='" + this.sslMode + '\'' + ", sslRootCert='" + this.sslRootCert + '\'' + ", sslCert='" + this.sslCert + '\'' + ", sslKey='" + this.sslKey + '\'' + ", sslHostnameVerifier='" + this.sslHostnameVerifier + '\'' + - ", targetServerType='" + this.targetServerType + '\'' + - ", hostRecheckTime='" + this.hostRecheckTime + '\'' + - ", loadBalance='" + this.loadBalance + '\'' + '}'; } private SSLConfig createSslConfig() { - if (this.socket != null || this.sslMode == SSLMode.DISABLE) { + if (this.singleHostConfiguration != null && this.singleHostConfiguration.getSocket() != null || this.sslMode == SSLMode.DISABLE) { return SSLConfig.disabled(); } diff --git a/src/main/java/io/r2dbc/postgresql/PostgresqlConnectionFactory.java b/src/main/java/io/r2dbc/postgresql/PostgresqlConnectionFactory.java index 609a3263..5040601d 100644 --- a/src/main/java/io/r2dbc/postgresql/PostgresqlConnectionFactory.java +++ b/src/main/java/io/r2dbc/postgresql/PostgresqlConnectionFactory.java @@ -38,7 +38,6 @@ import java.util.List; import java.util.Locale; import java.util.Map; -import java.util.function.Function; /** * An implementation of {@link ConnectionFactory} for creating connections to a PostgreSQL database. @@ -49,7 +48,11 @@ public final class PostgresqlConnectionFactory implements ConnectionFactory { private static final String REPLICATION_DATABASE = "database"; - private final Function, Mono> clientFactory; + private static final ClientSupplier DEFAULT_CLIENT_SUPPLIER = (endpoint, connectTimeout, sslConfig) -> ReactorNettyClient. + connect(ConnectionProvider.newConnection(), endpoint, connectTimeout, sslConfig) + .cast(Client.class); + + private final ClientFactory clientFactory; private final PostgresqlConnectionConfiguration configuration; @@ -63,16 +66,13 @@ public final class PostgresqlConnectionFactory implements ConnectionFactory { */ public PostgresqlConnectionFactory(PostgresqlConnectionConfiguration configuration) { this.configuration = Assert.requireNonNull(configuration, "configuration must not be null"); - this.clientFactory = new ClientFactory(configuration, - (endpoint, connectTimeout, sslConfig) -> ReactorNettyClient.connect(ConnectionProvider.newConnection(), endpoint, connectTimeout, sslConfig) - .cast(Client.class)); + this.clientFactory = ClientFactory.getFactory(configuration, DEFAULT_CLIENT_SUPPLIER); this.extensions = getExtensions(configuration); } - PostgresqlConnectionFactory(ClientFactory.ConnectionSupplier connectionSupplier, PostgresqlConnectionConfiguration configuration) { + PostgresqlConnectionFactory(ClientFactory clientFactory, PostgresqlConnectionConfiguration configuration) { this.configuration = Assert.requireNonNull(configuration, "configuration must not be null"); - Assert.requireNonNull(connectionSupplier, "connectionSupplier must not be null"); - this.clientFactory = new ClientFactory(configuration, connectionSupplier); + this.clientFactory = Assert.requireNonNull(clientFactory, "clientFactory must not be null"); this.extensions = getExtensions(configuration); } @@ -148,7 +148,7 @@ private Mono closeWithError(Client client, Throwable throw } private Mono doCreateConnection(boolean forReplication, @Nullable Map options) { - return this.clientFactory.apply(options) + return this.clientFactory.create(options) .flatMap(client -> { DefaultCodecs codecs = new DefaultCodecs(client.getByteBufAllocator()); diff --git a/src/main/java/io/r2dbc/postgresql/SingleHostClientFactory.java b/src/main/java/io/r2dbc/postgresql/SingleHostClientFactory.java new file mode 100644 index 00000000..6f2d0723 --- /dev/null +++ b/src/main/java/io/r2dbc/postgresql/SingleHostClientFactory.java @@ -0,0 +1,40 @@ +package io.r2dbc.postgresql; + +import io.netty.channel.unix.DomainSocketAddress; +import io.r2dbc.postgresql.client.Client; +import io.r2dbc.postgresql.client.SingleHostConfiguration; +import reactor.core.publisher.Mono; +import reactor.util.annotation.Nullable; + +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.util.Map; + +class SingleHostClientFactory extends ClientFactoryBase { + + private final SocketAddress endpoint; + + protected SingleHostClientFactory(PostgresqlConnectionConfiguration configuration, ClientSupplier clientSupplier) { + super(configuration, clientSupplier); + this.endpoint = SingleHostClientFactory.createSocketAddress(configuration.getSingleHostConfiguration()); + } + + @Override + public Mono create(@Nullable Map options) { + return this.tryConnectToEndpoint(this.endpoint, options); + } + + + protected static SocketAddress createSocketAddress(SingleHostConfiguration configuration) { + if (!configuration.isUseSocket()) { + return InetSocketAddress.createUnresolved(configuration.getRequiredHost(), configuration.getPort()); + } + + if (configuration.isUseSocket()) { + return new DomainSocketAddress(configuration.getRequiredSocket()); + } + + throw new IllegalArgumentException("Cannot create SocketAddress for " + configuration); + } + +} diff --git a/src/main/java/io/r2dbc/postgresql/TargetServerType.java b/src/main/java/io/r2dbc/postgresql/TargetServerType.java index 0ac08656..9c1dadbb 100644 --- a/src/main/java/io/r2dbc/postgresql/TargetServerType.java +++ b/src/main/java/io/r2dbc/postgresql/TargetServerType.java @@ -5,26 +5,26 @@ public enum TargetServerType { ANY("any") { @Override - public boolean allowStatus(ClientFactory.HostStatus hostStatus) { - return hostStatus != ClientFactory.HostStatus.CONNECT_FAIL; + public boolean allowStatus(MultipleHostsClientFactory.HostStatus hostStatus) { + return hostStatus != MultipleHostsClientFactory.HostStatus.CONNECT_FAIL; } }, MASTER("master") { @Override - public boolean allowStatus(ClientFactory.HostStatus hostStatus) { - return hostStatus == ClientFactory.HostStatus.PRIMARY || hostStatus == ClientFactory.HostStatus.CONNECT_OK; + public boolean allowStatus(MultipleHostsClientFactory.HostStatus hostStatus) { + return hostStatus == MultipleHostsClientFactory.HostStatus.PRIMARY || hostStatus == MultipleHostsClientFactory.HostStatus.CONNECT_OK; } }, SECONDARY("secondary") { @Override - public boolean allowStatus(ClientFactory.HostStatus hostStatus) { - return hostStatus == ClientFactory.HostStatus.STANDBY || hostStatus == ClientFactory.HostStatus.CONNECT_OK; + public boolean allowStatus(MultipleHostsClientFactory.HostStatus hostStatus) { + return hostStatus == MultipleHostsClientFactory.HostStatus.STANDBY || hostStatus == MultipleHostsClientFactory.HostStatus.CONNECT_OK; } }, PREFER_SECONDARY("preferSecondary") { @Override - public boolean allowStatus(ClientFactory.HostStatus hostStatus) { - return hostStatus == ClientFactory.HostStatus.STANDBY || hostStatus == ClientFactory.HostStatus.CONNECT_OK; + public boolean allowStatus(MultipleHostsClientFactory.HostStatus hostStatus) { + return hostStatus == MultipleHostsClientFactory.HostStatus.STANDBY || hostStatus == MultipleHostsClientFactory.HostStatus.CONNECT_OK; } }; @@ -49,5 +49,5 @@ public String getValue() { return value; } - public abstract boolean allowStatus(ClientFactory.HostStatus hostStatus); + public abstract boolean allowStatus(MultipleHostsClientFactory.HostStatus hostStatus); } diff --git a/src/main/java/io/r2dbc/postgresql/client/MultipleHostsConfiguration.java b/src/main/java/io/r2dbc/postgresql/client/MultipleHostsConfiguration.java new file mode 100644 index 00000000..9ca47e13 --- /dev/null +++ b/src/main/java/io/r2dbc/postgresql/client/MultipleHostsConfiguration.java @@ -0,0 +1,59 @@ +package io.r2dbc.postgresql.client; + +import io.r2dbc.postgresql.TargetServerType; + +import java.util.List; + +public class MultipleHostsConfiguration { + + private final int hostRecheckTime; + + private final List hosts; + + private final boolean loadBalance; + + private final TargetServerType targetServerType; + + public MultipleHostsConfiguration(List hosts, int hostRecheckTime, boolean loadBalance, TargetServerType targetServerType) { + this.hosts = hosts; + this.hostRecheckTime = hostRecheckTime; + this.loadBalance = loadBalance; + this.targetServerType = targetServerType; + } + + public int getHostRecheckTime() { + return hostRecheckTime; + } + + public List getHosts() { + return hosts; + } + + public TargetServerType getTargetServerType() { + return targetServerType; + } + + public boolean isLoadBalance() { + return loadBalance; + } + + public static class ServerHost { + + private final String host; + + private final int port; + + public ServerHost(String host, int port) { + this.host = host; + this.port = port; + } + + public String getHost() { + return host; + } + + public int getPort() { + return port; + } + } +} diff --git a/src/main/java/io/r2dbc/postgresql/client/SingleHostConfiguration.java b/src/main/java/io/r2dbc/postgresql/client/SingleHostConfiguration.java new file mode 100644 index 00000000..5a984f08 --- /dev/null +++ b/src/main/java/io/r2dbc/postgresql/client/SingleHostConfiguration.java @@ -0,0 +1,61 @@ +package io.r2dbc.postgresql.client; + +import reactor.util.annotation.Nullable; + +public class SingleHostConfiguration { + + @Nullable + private final String host; + + private final int port; + + @Nullable + private final String socket; + + public SingleHostConfiguration(@Nullable String host, int port, @Nullable String socket) { + this.host = host; + this.port = port; + this.socket = socket; + } + + @Nullable + public String getHost() { + return host; + } + + public int getPort() { + return port; + } + + public String getRequiredHost() { + + String host = getHost(); + + if (host == null || host.isEmpty()) { + throw new IllegalStateException("Connection is configured for socket connections and not for host usage"); + } + + return host; + } + + public String getRequiredSocket() { + + String socket = getSocket(); + + if (socket == null || socket.isEmpty()) { + throw new IllegalStateException("Connection is configured to use host and port connections and not for socket usage"); + } + + return socket; + } + + @Nullable + public String getSocket() { + return socket; + } + + public boolean isUseSocket() { + return getSocket() != null; + } + +} diff --git a/src/test/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryProviderTest.java b/src/test/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryProviderTest.java index 2ab6ba61..0e53caae 100644 --- a/src/test/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryProviderTest.java +++ b/src/test/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryProviderTest.java @@ -231,8 +231,8 @@ void shouldConnectUsingUnixDomainSocket() { .option(USER, "postgres") .build()); - assertThat(factory.getConfiguration().isUseSocket()).isTrue(); - assertThat(factory.getConfiguration().getRequiredSocket()).isEqualTo("/tmp/.s.PGSQL.5432"); + assertThat(factory.getConfiguration().getSingleHostConfiguration().isUseSocket()).isTrue(); + assertThat(factory.getConfiguration().getSingleHostConfiguration().getRequiredSocket()).isEqualTo("/tmp/.s.PGSQL.5432"); } } diff --git a/src/test/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryTest.java b/src/test/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryTest.java index f2e7777b..eb202e67 100644 --- a/src/test/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryTest.java +++ b/src/test/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryTest.java @@ -77,7 +77,7 @@ void createAuthenticationMD5Password() { .password("test-password") .build(); - new PostgresqlConnectionFactory((a, b, c) -> Mono.just(client), configuration) + new PostgresqlConnectionFactory(testClientFactory(client, configuration), configuration) .create() .as(StepVerifier::create) .expectNextCount(1) @@ -128,7 +128,7 @@ void createError() { .password("test-password") .build(); - new PostgresqlConnectionFactory((a, b, c) -> Mono.just(client), configuration).create() + new PostgresqlConnectionFactory(testClientFactory(client, configuration), configuration).create() .as(StepVerifier::create) .verifyErrorMatches(R2dbcNonTransientResourceException.class::isInstance); } @@ -152,7 +152,12 @@ void getMetadata() { .password("test-password") .build(); - assertThat(new PostgresqlConnectionFactory((a, b, c) -> Mono.just(client), configuration).getMetadata()).isNotNull(); + assertThat(new PostgresqlConnectionFactory(testClientFactory(client, configuration), configuration).getMetadata()).isNotNull(); + } + + + private ClientFactory testClientFactory(Client client, PostgresqlConnectionConfiguration configuration) { + return new SingleHostClientFactory(configuration, (endpoint, timeout, ssl) -> Mono.just(client)); } } diff --git a/src/test/java/io/r2dbc/postgresql/client/HighAvailabilityClusterTest.java b/src/test/java/io/r2dbc/postgresql/client/HighAvailabilityClusterTest.java index 9514d718..10b1369b 100644 --- a/src/test/java/io/r2dbc/postgresql/client/HighAvailabilityClusterTest.java +++ b/src/test/java/io/r2dbc/postgresql/client/HighAvailabilityClusterTest.java @@ -13,6 +13,9 @@ import reactor.core.publisher.Mono; import reactor.test.StepVerifier; +import java.util.ArrayList; +import java.util.List; + public class HighAvailabilityClusterTest { @RegisterExtension @@ -180,19 +183,20 @@ private Mono isPrimary(PostgresqlConnection connection) { } private PostgresqlConnectionFactory multipleHostsConnectionFactory(TargetServerType targetServerType, PostgreSQLContainer... servers) { - PostgreSQLContainer server = servers[0]; - String[] hosts = new String[servers.length]; - int[] ports = new int[servers.length]; - for (int i = 0; i < servers.length; i++) { - hosts[i] = servers[i].getContainerIpAddress(); - ports[i] = servers[i].getMappedPort(5432); + PostgreSQLContainer firstServer = servers[0]; + List hosts = new ArrayList<>(servers.length); + for (PostgreSQLContainer server : servers) { + hosts.add(new MultipleHostsConfiguration.ServerHost( + server.getContainerIpAddress(), + server.getMappedPort(5432) + )); } PostgresqlConnectionConfiguration configuration = PostgresqlConnectionConfiguration.builder() - .tmpHosts(hosts) - .tmpPorts(ports) - .username(server.getUsername()) - .password(server.getPassword()) - .targetServerType(targetServerType) + .multipleHostsConfiguration(new MultipleHostsConfiguration( + hosts, 10000, false, targetServerType + )) + .username(firstServer.getUsername()) + .password(firstServer.getPassword()) .build(); return new PostgresqlConnectionFactory(configuration); } From 20be3d3565993a30dffbf20cad77fbeb823c73b9 Mon Sep 17 00:00:00 2001 From: Anton Duyun Date: Wed, 27 Nov 2019 16:13:36 +0300 Subject: [PATCH 4/5] single/multiple hosts configuration polishing --- .../MultipleHostsClientFactory.java | 23 ++- .../PostgresqlConnectionConfiguration.java | 123 +++++++++++--- .../PostgresqlConnectionFactoryProvider.java | 81 ++++++++-- .../client/MultipleHostsConfiguration.java | 152 ++++++++++++++++-- .../client/SingleHostConfiguration.java | 103 ++++++++++++ ...PostgresqlConnectionConfigurationTest.java | 12 +- ...stgresqlConnectionFactoryProviderTest.java | 29 ++++ .../PostgresqlConnectionFactoryTest.java | 2 +- .../client/HighAvailabilityClusterTest.java | 16 +- 9 files changed, 467 insertions(+), 74 deletions(-) diff --git a/src/main/java/io/r2dbc/postgresql/MultipleHostsClientFactory.java b/src/main/java/io/r2dbc/postgresql/MultipleHostsClientFactory.java index 645c168c..edc95456 100644 --- a/src/main/java/io/r2dbc/postgresql/MultipleHostsClientFactory.java +++ b/src/main/java/io/r2dbc/postgresql/MultipleHostsClientFactory.java @@ -3,6 +3,7 @@ import io.r2dbc.postgresql.client.Client; import io.r2dbc.postgresql.client.MultipleHostsConfiguration; import io.r2dbc.postgresql.codec.DefaultCodecs; +import io.r2dbc.postgresql.util.Assert; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; @@ -31,7 +32,7 @@ class MultipleHostsClientFactory extends ClientFactoryBase { public MultipleHostsClientFactory(PostgresqlConnectionConfiguration configuration, ClientSupplier clientSupplier) { super(configuration, clientSupplier); - this.configuration = configuration.getMultipleHostsConfiguration(); + this.configuration = Assert.requireNonNull(configuration.getMultipleHostsConfiguration(), "MultipleHostsConfiguration must not be null"); this.addresses = MultipleHostsClientFactory.createSocketAddress(this.configuration); } @@ -99,22 +100,30 @@ private static Mono isPrimaryServer(Client client) { private Flux getCandidates(TargetServerType targetServerType) { return Flux.create(sink -> { - if (this.addresses.size() == 1) { - sink.next(this.addresses.get(0)); - sink.complete(); - return; - } long now = System.currentTimeMillis(); List addresses = new ArrayList<>(this.addresses); - if (this.configuration.isLoadBalance()) { + if (this.configuration.isLoadBalanceHosts()) { Collections.shuffle(addresses); } + int counter = 0; for (SocketAddress address : addresses) { HostSpecStatus currentStatus = this.statusMap.get(address); if (currentStatus == null || now > currentStatus.updated + this.configuration.getHostRecheckTime()) { sink.next(address); + counter++; } else if (targetServerType.allowStatus(currentStatus.hostStatus)) { sink.next(address); + counter++; + } + } + if (counter == 0) { + // if no candidate match the requirement or all of them are in unavailable status try all the hosts + addresses = new ArrayList<>(this.addresses); + if (this.configuration.isLoadBalanceHosts()) { + Collections.shuffle(addresses); + } + for (SocketAddress address : addresses) { + sink.next(address); } } sink.complete(); diff --git a/src/main/java/io/r2dbc/postgresql/PostgresqlConnectionConfiguration.java b/src/main/java/io/r2dbc/postgresql/PostgresqlConnectionConfiguration.java index 1fed357a..2e046caf 100644 --- a/src/main/java/io/r2dbc/postgresql/PostgresqlConnectionConfiguration.java +++ b/src/main/java/io/r2dbc/postgresql/PostgresqlConnectionConfiguration.java @@ -71,7 +71,6 @@ public final class PostgresqlConnectionConfiguration { private final String schema; - private final String username; private final SSLConfig sslConfig; @@ -221,10 +220,10 @@ public static final class Builder { private boolean forceBinary = false; @Nullable - private MultipleHostsConfiguration multipleHostsConfiguration; + private MultipleHostsConfiguration.Builder multipleHostsConfiguration; @Nullable - private SingleHostConfiguration singleHostConfiguration; + private SingleHostConfiguration.Builder singleHostConfiguration; private Map options; @@ -287,13 +286,17 @@ public Builder autodetectExtensions(boolean autodetectExtensions) { * @return a configured {@link PostgresqlConnectionConfiguration} */ public PostgresqlConnectionConfiguration build() { - - if (this.singleHostConfiguration != null && this.singleHostConfiguration.getHost() == null && this.singleHostConfiguration.getSocket() == null) { - throw new IllegalArgumentException("host or socket must not be null"); + SingleHostConfiguration singleHostConfiguration = this.singleHostConfiguration != null + ? this.singleHostConfiguration.build() + : null; + MultipleHostsConfiguration multipleHostsConfiguration = this.multipleHostsConfiguration != null + ? this.multipleHostsConfiguration.build() + : null; + if (singleHostConfiguration == null && multipleHostsConfiguration == null) { + throw new IllegalArgumentException("Either multiple hosts configuration or single host configuration should be provided"); } - - if (this.singleHostConfiguration != null && this.singleHostConfiguration.getHost() != null && this.singleHostConfiguration.getSocket() != null) { - throw new IllegalArgumentException("Connection must be configured for either host/port or socket usage but not both"); + if (singleHostConfiguration != null && multipleHostsConfiguration != null) { + throw new IllegalArgumentException("Either multiple hosts configuration or single host configuration should be provided"); } if (this.username == null) { @@ -301,7 +304,7 @@ public PostgresqlConnectionConfiguration build() { } return new PostgresqlConnectionConfiguration(this.applicationName, this.autodetectExtensions, this.connectTimeout, this.database, this.extensions, this.forceBinary, - this.options, this.password, this.schema, this.username, this.createSslConfig(), this.singleHostConfiguration, this.multipleHostsConfiguration); + this.options, this.password, this.schema, this.username, this.createSslConfig(), singleHostConfiguration, multipleHostsConfiguration); } /** @@ -377,10 +380,9 @@ public Builder forceBinary(boolean forceBinary) { public Builder host(String host) { Assert.requireNonNull(host, "host must not be null"); if (this.singleHostConfiguration == null) { - this.singleHostConfiguration = new SingleHostConfiguration(host, DEFAULT_PORT, null); - } else { - this.singleHostConfiguration = new SingleHostConfiguration(host, this.singleHostConfiguration.getPort(), this.singleHostConfiguration.getSocket()); + this.singleHostConfiguration = SingleHostConfiguration.builder(); } + this.singleHostConfiguration.host(host); return this; } @@ -418,11 +420,6 @@ public Builder password(@Nullable CharSequence password) { return this; } - public Builder multipleHostsConfiguration(MultipleHostsConfiguration multipleHostsConfiguration) { - this.multipleHostsConfiguration = Assert.requireNonNull(multipleHostsConfiguration, "multipleHostsConfiguration must not be null"); - return this; - } - /** * Configure the schema. * @@ -442,10 +439,9 @@ public Builder schema(@Nullable String schema) { */ public Builder port(int port) { if (this.singleHostConfiguration == null) { - this.singleHostConfiguration = new SingleHostConfiguration(null, port, null); - } else { - this.singleHostConfiguration = new SingleHostConfiguration(this.singleHostConfiguration.getHost(), port, this.singleHostConfiguration.getSocket()); + this.singleHostConfiguration = SingleHostConfiguration.builder(); } + this.singleHostConfiguration.port(port); return this; } @@ -550,17 +546,94 @@ public Builder username(String username) { */ public Builder socket(String socket) { Assert.requireNonNull(socket, "host must not be null"); - if (this.singleHostConfiguration == null) { - this.singleHostConfiguration = new SingleHostConfiguration(null, DEFAULT_PORT, socket); - } else { - this.singleHostConfiguration = new SingleHostConfiguration(this.singleHostConfiguration.getHost(), this.singleHostConfiguration.getPort(), socket); + this.singleHostConfiguration = SingleHostConfiguration.builder(); } + this.singleHostConfiguration.socket(socket); sslMode(SSLMode.DISABLE); return this; } + /** + * Allows opening connections to only servers with required state, the allowed values are any, master, slave, secondary, preferSlave and preferSecondary. + * The master/secondary distinction is currently done by observing if the server allows writes. + * The value preferSecondary tries to connect to secondary if any are available, otherwise allows falls back to connecting also to master. + * Default value is any. + * + * @param targetServerType target server type + * @return this {@link Builder} + * @throws IllegalArgumentException if {@code targetServerType} is {@code null} + */ + public Builder targetServerType(TargetServerType targetServerType) { + if (this.multipleHostsConfiguration == null) { + this.multipleHostsConfiguration = MultipleHostsConfiguration.builder(); + } + this.multipleHostsConfiguration.targetServerType(targetServerType); + return this; + } + + /** + * Controls how long in seconds the knowledge about a host state is cached connection factory. The default value is 10000 milliseconds. + * + * @param hostRecheckTime host recheck time in milliseconds + * @return this {@link Builder} + */ + public Builder hostRecheckTime(int hostRecheckTime) { + if (this.multipleHostsConfiguration == null) { + this.multipleHostsConfiguration = MultipleHostsConfiguration.builder(); + } + this.multipleHostsConfiguration.hostRecheckTime(hostRecheckTime); + return this; + } + + /** + * In default mode (disabled) hosts are connected in the given order. If enabled hosts are chosen randomly from the set of suitable candidates. + * + * @param loadBalanceHosts is load balance mode enabled + * @return this {@link Builder} + */ + public Builder loadBalanceHosts(boolean loadBalanceHosts) { + if (this.multipleHostsConfiguration == null) { + this.multipleHostsConfiguration = MultipleHostsConfiguration.builder(); + } + this.multipleHostsConfiguration.loadBalanceHosts(loadBalanceHosts); + return this; + } + + /** + * Add host with default port to hosts list. + * + * @param host the host + * @return this {@link Builder} + * @throws IllegalArgumentException if {@code host} is {@code null} + */ + public Builder addHost(String host) { + Assert.requireNonNull(host, "host must not be null"); + if (this.multipleHostsConfiguration == null) { + this.multipleHostsConfiguration = MultipleHostsConfiguration.builder(); + } + this.multipleHostsConfiguration.addHost(host); + return this; + } + + /** + * Add host to hosts list. + * + * @param host the host + * @param port the port + * @return this {@link Builder} + * @throws IllegalArgumentException if {@code host} is {@code null} + */ + public Builder addHost(String host, int port) { + Assert.requireNonNull(host, "host must not be null"); + if (this.multipleHostsConfiguration == null) { + this.multipleHostsConfiguration = MultipleHostsConfiguration.builder(); + } + this.multipleHostsConfiguration.addHost(host, port); + return this; + } + @Override public String toString() { return "Builder{" + diff --git a/src/main/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryProvider.java b/src/main/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryProvider.java index ade9b75a..2c9ef4ca 100644 --- a/src/main/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryProvider.java +++ b/src/main/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryProvider.java @@ -33,6 +33,7 @@ import static io.r2dbc.spi.ConnectionFactoryOptions.HOST; import static io.r2dbc.spi.ConnectionFactoryOptions.PASSWORD; import static io.r2dbc.spi.ConnectionFactoryOptions.PORT; +import static io.r2dbc.spi.ConnectionFactoryOptions.PROTOCOL; import static io.r2dbc.spi.ConnectionFactoryOptions.SSL; import static io.r2dbc.spi.ConnectionFactoryOptions.USER; @@ -56,11 +57,31 @@ public final class PostgresqlConnectionFactoryProvider implements ConnectionFact */ public static final Option FORCE_BINARY = Option.valueOf("forceBinary"); + /** + * Load balance hosts. + */ + public static final Option LOAD_BALANCE_HOSTS = Option.valueOf("loadBalanceHosts"); + + /** + * Host status recheck time im ms. + */ + public static final Option HOST_RECHECK_TIME = Option.valueOf("hostRecheckTime"); + + /** + * Target server type. Allowed values: any, master, secondary, preferSecondary. + */ + public static final Option TARGET_SERVER_TYPE = Option.valueOf("targetServerType"); + /** * Driver option value. */ public static final String POSTGRESQL_DRIVER = "postgresql"; + /** + * Failover driver protocol. + */ + public static final String FAILOVER_PROTOCOL = "failover"; + /** * Legacy driver option value. */ @@ -144,21 +165,57 @@ static PostgresqlConnectionConfiguration createConfiguration(ConnectionFactoryOp builder.connectTimeout(connectionFactoryOptions.getValue(CONNECT_TIMEOUT)); builder.database(connectionFactoryOptions.getValue(DATABASE)); - if (connectionFactoryOptions.hasOption(SOCKET)) { - tcp = false; - builder.socket(connectionFactoryOptions.getRequiredValue(SOCKET)); - } else { + if (FAILOVER_PROTOCOL.equals(connectionFactoryOptions.getValue(PROTOCOL))) { + if (connectionFactoryOptions.hasOption(HOST_RECHECK_TIME)) { + builder.hostRecheckTime(connectionFactoryOptions.getRequiredValue(HOST_RECHECK_TIME)); + } + if (connectionFactoryOptions.hasOption(LOAD_BALANCE_HOSTS)) { + Object loadBalanceHosts = connectionFactoryOptions.getRequiredValue(LOAD_BALANCE_HOSTS); + if (loadBalanceHosts instanceof Boolean) { + builder.loadBalanceHosts((Boolean) loadBalanceHosts); + } else { + builder.loadBalanceHosts(Boolean.parseBoolean(loadBalanceHosts.toString())); + } + } + if (connectionFactoryOptions.hasOption(TARGET_SERVER_TYPE)) { + Object targetServerType = connectionFactoryOptions.getRequiredValue(TARGET_SERVER_TYPE); + if (targetServerType instanceof TargetServerType) { + builder.targetServerType((TargetServerType) targetServerType); + } else { + builder.targetServerType(TargetServerType.fromValue(targetServerType.toString())); + } + } + String hosts = connectionFactoryOptions.getRequiredValue(HOST); + String[] hostsArray = hosts.split(","); + for (String host : hostsArray) { + String[] hostParts = host.split(":"); + if (hostParts.length == 1) { + builder.addHost(hostParts[0]); + } else { + int port = Integer.parseInt(hostParts[1]); + builder.addHost(hostParts[0], port); + } + } tcp = true; - builder.host(connectionFactoryOptions.getRequiredValue(HOST)); + } else { + if (connectionFactoryOptions.hasOption(SOCKET)) { + tcp = false; + builder.socket(connectionFactoryOptions.getRequiredValue(SOCKET)); + } else { + tcp = true; + builder.host(connectionFactoryOptions.getRequiredValue(HOST)); + } + Integer port = connectionFactoryOptions.getValue(PORT); + if (port != null) { + builder.port(port); + } } + + builder.password(connectionFactoryOptions.getValue(PASSWORD)); builder.schema(connectionFactoryOptions.getValue(SCHEMA)); builder.username(connectionFactoryOptions.getRequiredValue(USER)); - Integer port = connectionFactoryOptions.getValue(PORT); - if (port != null) { - builder.port(port); - } Boolean forceBinary = connectionFactoryOptions.getValue(FORCE_BINARY); @@ -242,10 +299,6 @@ public boolean supports(ConnectionFactoryOptions connectionFactoryOptions) { Assert.requireNonNull(connectionFactoryOptions, "connectionFactoryOptions must not be null"); String driver = connectionFactoryOptions.getValue(DRIVER); - if (driver == null || !(driver.equals(POSTGRESQL_DRIVER) || driver.equals(LEGACY_POSTGRESQL_DRIVER))) { - return false; - } - - return true; + return driver != null && (driver.equals(POSTGRESQL_DRIVER) || driver.equals(LEGACY_POSTGRESQL_DRIVER)); } } diff --git a/src/main/java/io/r2dbc/postgresql/client/MultipleHostsConfiguration.java b/src/main/java/io/r2dbc/postgresql/client/MultipleHostsConfiguration.java index 9ca47e13..c2d6074e 100644 --- a/src/main/java/io/r2dbc/postgresql/client/MultipleHostsConfiguration.java +++ b/src/main/java/io/r2dbc/postgresql/client/MultipleHostsConfiguration.java @@ -1,23 +1,27 @@ package io.r2dbc.postgresql.client; import io.r2dbc.postgresql.TargetServerType; +import io.r2dbc.postgresql.util.Assert; +import java.util.ArrayList; import java.util.List; -public class MultipleHostsConfiguration { +import static io.r2dbc.postgresql.PostgresqlConnectionConfiguration.DEFAULT_PORT; - private final int hostRecheckTime; +public class MultipleHostsConfiguration { private final List hosts; - private final boolean loadBalance; + private final int hostRecheckTime; + + private final boolean loadBalanceHosts; private final TargetServerType targetServerType; - public MultipleHostsConfiguration(List hosts, int hostRecheckTime, boolean loadBalance, TargetServerType targetServerType) { + public MultipleHostsConfiguration(List hosts, int hostRecheckTime, boolean loadBalanceHosts, TargetServerType targetServerType) { this.hosts = hosts; this.hostRecheckTime = hostRecheckTime; - this.loadBalance = loadBalance; + this.loadBalanceHosts = loadBalanceHosts; this.targetServerType = targetServerType; } @@ -33,8 +37,18 @@ public TargetServerType getTargetServerType() { return targetServerType; } - public boolean isLoadBalance() { - return loadBalance; + public boolean isLoadBalanceHosts() { + return loadBalanceHosts; + } + + @Override + public String toString() { + return "MultipleHostsConfiguration{" + + "hosts=" + this.hosts + + ", hostRecheckTime=" + this.hostRecheckTime + + ", loadBalanceHosts=" + this.loadBalanceHosts + + ", targetServerType=" + this.targetServerType + + '}'; } public static class ServerHost { @@ -49,11 +63,131 @@ public ServerHost(String host, int port) { } public String getHost() { - return host; + return this.host; } public int getPort() { - return port; + return this.port; + } + + @Override + public String toString() { + return "ServerHost{" + + "host='" + this.host + '\'' + + ", port=" + this.port + + '}'; + } + } + + /** + * Returns a new {@link Builder}. + * + * @return a new {@link Builder} + */ + public static Builder builder() { + return new Builder(); + } + + /** + * A builder for {@link MultipleHostsConfiguration} instances. + *

+ * This class is not threadsafe + */ + public static class Builder { + + private int hostRecheckTime = 10000; + + private List hosts = new ArrayList<>(); + + private boolean loadBalanceHosts = false; + + private TargetServerType targetServerType = TargetServerType.ANY; + + /** + * Allows opening connections to only servers with required state, the allowed values are any, master, slave, secondary, preferSlave and preferSecondary. + * The master/secondary distinction is currently done by observing if the server allows writes. + * The value preferSecondary tries to connect to secondary if any are available, otherwise allows falls back to connecting also to master. + * Default value is any. + * + * @param targetServerType target server type + * @return this {@link Builder} + * @throws IllegalArgumentException if {@code targetServerType} is {@code null} + */ + public Builder targetServerType(TargetServerType targetServerType) { + this.targetServerType = Assert.requireNonNull(targetServerType, "targetServerType must not be null"); + return this; + } + + /** + * Controls how long in seconds the knowledge about a host state is cached connection factory. The default value is 10000 milliseconds. + * + * @param hostRecheckTime host recheck time in milliseconds + * @return this {@link Builder} + */ + public Builder hostRecheckTime(int hostRecheckTime) { + this.hostRecheckTime = hostRecheckTime; + return this; + } + + /** + * In default mode (disabled) hosts are connected in the given order. If enabled hosts are chosen randomly from the set of suitable candidates. + * + * @param loadBalanceHosts is load balance mode enabled + * @return this {@link Builder} + */ + public Builder loadBalanceHosts(boolean loadBalanceHosts) { + this.loadBalanceHosts = loadBalanceHosts; + return this; + } + + /** + * Add host with default port to hosts list. + * + * @param host the host + * @return this {@link Builder} + * @throws IllegalArgumentException if {@code host} is {@code null} + */ + public Builder addHost(String host) { + Assert.requireNonNull(host, "host must not be null"); + this.hosts.add(new ServerHost(host, DEFAULT_PORT)); + return this; + } + + /** + * Add host to hosts list. + * + * @param host the host + * @param port the port + * @return this {@link Builder} + * @throws IllegalArgumentException if {@code host} is {@code null} + */ + public Builder addHost(String host, int port) { + Assert.requireNonNull(host, "host must not be null"); + this.hosts.add(new ServerHost(host, port)); + return this; + } + + /** + * Returns a configured {@link MultipleHostsConfiguration}. + * + * @return a configured {@link MultipleHostsConfiguration} + */ + public MultipleHostsConfiguration build() { + if (this.hosts.isEmpty()) { + throw new IllegalArgumentException("At least one host should be provided"); + } + + return new MultipleHostsConfiguration(this.hosts, this.hostRecheckTime, this.loadBalanceHosts, this.targetServerType); + } + + @Override + public String toString() { + return "Builder{" + + "hostRecheckTime=" + this.hostRecheckTime + + ", hosts=" + this.hosts + + ", loadBalanceHosts=" + this.loadBalanceHosts + + ", targetServerType=" + this.targetServerType + + '}'; } } } diff --git a/src/main/java/io/r2dbc/postgresql/client/SingleHostConfiguration.java b/src/main/java/io/r2dbc/postgresql/client/SingleHostConfiguration.java index 5a984f08..2578ce6d 100644 --- a/src/main/java/io/r2dbc/postgresql/client/SingleHostConfiguration.java +++ b/src/main/java/io/r2dbc/postgresql/client/SingleHostConfiguration.java @@ -1,7 +1,10 @@ package io.r2dbc.postgresql.client; +import io.r2dbc.postgresql.util.Assert; import reactor.util.annotation.Nullable; +import static io.r2dbc.postgresql.PostgresqlConnectionConfiguration.DEFAULT_PORT; + public class SingleHostConfiguration { @Nullable @@ -58,4 +61,104 @@ public boolean isUseSocket() { return getSocket() != null; } + @Override + public String toString() { + return "SingleHostConfiguration{" + + "host='" + this.host + '\'' + + ", port=" + this.port + + ", socket='" + this.socket + '\'' + + '}'; + } + + /** + * Returns a new {@link Builder}. + * + * @return a new {@link Builder} + */ + public static Builder builder() { + return new Builder(); + } + + /** + * A builder for {@link SingleHostConfiguration} instances. + *

+ * This class is not threadsafe + */ + public static class Builder { + + @Nullable + private String host; + + private int port = DEFAULT_PORT; + + @Nullable + private String socket; + + /** + * Configure the host. + * + * @param host the host + * @return this {@link Builder} + * @throws IllegalArgumentException if {@code host} is {@code null} + */ + public Builder host(String host) { + this.host = Assert.requireNonNull(host, "host must not be null"); + return this; + } + + /** + * Configure the port. Defaults to {@code 5432}. + * + * @param port the port + * @return this {@link Builder} + */ + public Builder port(int port) { + this.port = port; + return this; + } + + /** + * Configure the unix domain socket to connect to. + * + * @param socket the socket path + * @return this {@link Builder} + * @throws IllegalArgumentException if {@code socket} is {@code null} + */ + public Builder socket(String socket) { + this.socket = Assert.requireNonNull(socket, "host must not be null"); + return this; + } + + /** + * Returns a configured {@link SingleHostConfiguration}. + * + * @return a configured {@link SingleHostConfiguration} + */ + public SingleHostConfiguration build() { + if (this.host == null && this.socket == null) { + throw new IllegalArgumentException("host or socket must not be null"); + } + if (this.host != null && this.socket != null) { + throw new IllegalArgumentException("Connection must be configured for either host/port or socket usage but not both"); + } + + return new SingleHostConfiguration(this.host, this.port, this.socket); + } + + @Nullable + public String getSocket() { + return socket; + } + + @Override + public String toString() { + return "Builder{" + + "host='" + this.host + '\'' + + ", port=" + this.port + + ", socket='" + this.socket + '\'' + + '}'; + } + } + + } diff --git a/src/test/java/io/r2dbc/postgresql/PostgresqlConnectionConfigurationTest.java b/src/test/java/io/r2dbc/postgresql/PostgresqlConnectionConfigurationTest.java index a13112d7..046b1660 100644 --- a/src/test/java/io/r2dbc/postgresql/PostgresqlConnectionConfigurationTest.java +++ b/src/test/java/io/r2dbc/postgresql/PostgresqlConnectionConfigurationTest.java @@ -35,9 +35,9 @@ void builderNoApplicationName() { } @Test - void builderNoHostAndSocket() { + void builderNoHostConfiguration() { assertThatIllegalArgumentException().isThrownBy(() -> PostgresqlConnectionConfiguration.builder().build()) - .withMessage("host or socket must not be null"); + .withMessage("Either multiple hosts configuration or single host configuration should be provided"); } @Test @@ -74,10 +74,10 @@ void configuration() { .hasFieldOrPropertyWithValue("applicationName", "test-application-name") .hasFieldOrPropertyWithValue("connectTimeout", Duration.ofMillis(1000)) .hasFieldOrPropertyWithValue("database", "test-database") - .hasFieldOrPropertyWithValue("host", "test-host") + .hasFieldOrPropertyWithValue("singleHostConfiguration.host", "test-host") + .hasFieldOrPropertyWithValue("singleHostConfiguration.port", 100) .hasFieldOrPropertyWithValue("options", options) .hasFieldOrPropertyWithValue("password", null) - .hasFieldOrPropertyWithValue("port", 100) .hasFieldOrPropertyWithValue("schema", "test-schema") .hasFieldOrPropertyWithValue("username", "test-username") .hasFieldOrProperty("sslConfig"); @@ -96,9 +96,9 @@ void configurationDefaults() { assertThat(configuration) .hasFieldOrPropertyWithValue("applicationName", "r2dbc-postgresql") .hasFieldOrPropertyWithValue("database", "test-database") - .hasFieldOrPropertyWithValue("host", "test-host") + .hasFieldOrPropertyWithValue("singleHostConfiguration.host", "test-host") + .hasFieldOrPropertyWithValue("singleHostConfiguration.port", 5432) .hasFieldOrPropertyWithValue("password", "test-password") - .hasFieldOrPropertyWithValue("port", 5432) .hasFieldOrPropertyWithValue("schema", "test-schema") .hasFieldOrPropertyWithValue("username", "test-username") .hasFieldOrProperty("sslConfig"); diff --git a/src/test/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryProviderTest.java b/src/test/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryProviderTest.java index 0e53caae..989f5ef2 100644 --- a/src/test/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryProviderTest.java +++ b/src/test/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryProviderTest.java @@ -16,25 +16,32 @@ package io.r2dbc.postgresql; +import io.r2dbc.postgresql.client.MultipleHostsConfiguration; import io.r2dbc.postgresql.client.SSLConfig; import io.r2dbc.postgresql.client.SSLMode; import io.r2dbc.spi.ConnectionFactoryOptions; import org.junit.jupiter.api.Test; import java.util.HashMap; +import java.util.List; import java.util.Map; import static io.r2dbc.postgresql.PostgresqlConnectionFactoryProvider.AUTODETECT_EXTENSIONS; +import static io.r2dbc.postgresql.PostgresqlConnectionFactoryProvider.FAILOVER_PROTOCOL; import static io.r2dbc.postgresql.PostgresqlConnectionFactoryProvider.FORCE_BINARY; +import static io.r2dbc.postgresql.PostgresqlConnectionFactoryProvider.HOST_RECHECK_TIME; import static io.r2dbc.postgresql.PostgresqlConnectionFactoryProvider.LEGACY_POSTGRESQL_DRIVER; +import static io.r2dbc.postgresql.PostgresqlConnectionFactoryProvider.LOAD_BALANCE_HOSTS; import static io.r2dbc.postgresql.PostgresqlConnectionFactoryProvider.OPTIONS; import static io.r2dbc.postgresql.PostgresqlConnectionFactoryProvider.POSTGRESQL_DRIVER; import static io.r2dbc.postgresql.PostgresqlConnectionFactoryProvider.SOCKET; import static io.r2dbc.postgresql.PostgresqlConnectionFactoryProvider.SSL_CONTEXT_BUILDER_CUSTOMIZER; import static io.r2dbc.postgresql.PostgresqlConnectionFactoryProvider.SSL_MODE; +import static io.r2dbc.postgresql.PostgresqlConnectionFactoryProvider.TARGET_SERVER_TYPE; import static io.r2dbc.spi.ConnectionFactoryOptions.DRIVER; import static io.r2dbc.spi.ConnectionFactoryOptions.HOST; import static io.r2dbc.spi.ConnectionFactoryOptions.PASSWORD; +import static io.r2dbc.spi.ConnectionFactoryOptions.PROTOCOL; import static io.r2dbc.spi.ConnectionFactoryOptions.SSL; import static io.r2dbc.spi.ConnectionFactoryOptions.USER; import static io.r2dbc.spi.ConnectionFactoryOptions.builder; @@ -235,4 +242,26 @@ void shouldConnectUsingUnixDomainSocket() { assertThat(factory.getConfiguration().getSingleHostConfiguration().getRequiredSocket()).isEqualTo("/tmp/.s.PGSQL.5432"); } + @Test + void testFailoverConfiguration() { + PostgresqlConnectionFactory factory = this.provider.create(builder() + .option(DRIVER, POSTGRESQL_DRIVER) + .option(PROTOCOL, FAILOVER_PROTOCOL) + .option(HOST, "host1:5433,host2:5432,host3") + .option(USER, "postgres") + .option(LOAD_BALANCE_HOSTS, true) + .option(HOST_RECHECK_TIME, 20000) + .option(TARGET_SERVER_TYPE, TargetServerType.SECONDARY) + .build()); + + assertThat(factory.getConfiguration().getSingleHostConfiguration()).isNull(); + assertThat(factory.getConfiguration().getMultipleHostsConfiguration().isLoadBalanceHosts()).isEqualTo(true); + assertThat(factory.getConfiguration().getMultipleHostsConfiguration().getHostRecheckTime()).isEqualTo(20000); + assertThat(factory.getConfiguration().getMultipleHostsConfiguration().getTargetServerType()).isEqualTo(TargetServerType.SECONDARY); + List hosts = factory.getConfiguration().getMultipleHostsConfiguration().getHosts(); + assertThat(hosts).hasSize(3); + assertThat(hosts.get(0)).isEqualToComparingFieldByField(new MultipleHostsConfiguration.ServerHost("host1", 5433)); + assertThat(hosts.get(1)).isEqualToComparingFieldByField(new MultipleHostsConfiguration.ServerHost("host2", 5432)); + assertThat(hosts.get(2)).isEqualToComparingFieldByField(new MultipleHostsConfiguration.ServerHost("host3", 5432)); + } } diff --git a/src/test/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryTest.java b/src/test/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryTest.java index eb202e67..272513a9 100644 --- a/src/test/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryTest.java +++ b/src/test/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryTest.java @@ -49,7 +49,7 @@ void constructorNoClientFactory() { .password("test-password") .username("test-username") .build())) - .withMessage("connectionSupplier must not be null"); + .withMessage("clientFactory must not be null"); } @Test diff --git a/src/test/java/io/r2dbc/postgresql/client/HighAvailabilityClusterTest.java b/src/test/java/io/r2dbc/postgresql/client/HighAvailabilityClusterTest.java index 10b1369b..4b2b34de 100644 --- a/src/test/java/io/r2dbc/postgresql/client/HighAvailabilityClusterTest.java +++ b/src/test/java/io/r2dbc/postgresql/client/HighAvailabilityClusterTest.java @@ -13,9 +13,6 @@ import reactor.core.publisher.Mono; import reactor.test.StepVerifier; -import java.util.ArrayList; -import java.util.List; - public class HighAvailabilityClusterTest { @RegisterExtension @@ -184,17 +181,12 @@ private Mono isPrimary(PostgresqlConnection connection) { private PostgresqlConnectionFactory multipleHostsConnectionFactory(TargetServerType targetServerType, PostgreSQLContainer... servers) { PostgreSQLContainer firstServer = servers[0]; - List hosts = new ArrayList<>(servers.length); + PostgresqlConnectionConfiguration.Builder builder = PostgresqlConnectionConfiguration.builder(); for (PostgreSQLContainer server : servers) { - hosts.add(new MultipleHostsConfiguration.ServerHost( - server.getContainerIpAddress(), - server.getMappedPort(5432) - )); + builder.addHost(server.getContainerIpAddress(), server.getMappedPort(5432)); } - PostgresqlConnectionConfiguration configuration = PostgresqlConnectionConfiguration.builder() - .multipleHostsConfiguration(new MultipleHostsConfiguration( - hosts, 10000, false, targetServerType - )) + PostgresqlConnectionConfiguration configuration = builder + .targetServerType(targetServerType) .username(firstServer.getUsername()) .password(firstServer.getPassword()) .build(); From 4bbdbf144ff3b15183535c231802d07f81dc27f6 Mon Sep 17 00:00:00 2001 From: Viktors Baltauss Date: Fri, 14 Feb 2020 12:56:19 +0200 Subject: [PATCH 5/5] Duration PR comment fixed, polishing --- .../MultipleHostsClientFactory.java | 5 ++- .../PostgresqlConnectionConfiguration.java | 4 +- .../PostgresqlConnectionFactoryProvider.java | 41 +++++++++---------- .../client/MultipleHostsConfiguration.java | 11 ++--- ...stgresqlConnectionFactoryProviderTest.java | 3 +- 5 files changed, 33 insertions(+), 31 deletions(-) diff --git a/src/main/java/io/r2dbc/postgresql/MultipleHostsClientFactory.java b/src/main/java/io/r2dbc/postgresql/MultipleHostsClientFactory.java index e0c4f118..b8864f55 100644 --- a/src/main/java/io/r2dbc/postgresql/MultipleHostsClientFactory.java +++ b/src/main/java/io/r2dbc/postgresql/MultipleHostsClientFactory.java @@ -11,6 +11,7 @@ import javax.annotation.Nullable; import java.net.InetSocketAddress; import java.net.SocketAddress; +import java.time.Duration; import java.util.ArrayList; import java.util.Collections; import java.util.List; @@ -115,7 +116,9 @@ private Flux getCandidates(TargetServerType targetServerType) { int counter = 0; for (SocketAddress address : addresses) { HostSpecStatus currentStatus = this.statusMap.get(address); - if (currentStatus == null || now > currentStatus.updated + this.configuration.getHostRecheckTime()) { + Duration hostRecheckDuration = this.configuration.getHostRecheckTime(); + boolean recheck = currentStatus == null || hostRecheckDuration.plusMillis(currentStatus.updated).toMillis() < now; + if (recheck) { sink.next(address); counter++; } else if (targetServerType.allowStatus(currentStatus.hostStatus)) { diff --git a/src/main/java/io/r2dbc/postgresql/PostgresqlConnectionConfiguration.java b/src/main/java/io/r2dbc/postgresql/PostgresqlConnectionConfiguration.java index b02d0c7f..4be0395c 100644 --- a/src/main/java/io/r2dbc/postgresql/PostgresqlConnectionConfiguration.java +++ b/src/main/java/io/r2dbc/postgresql/PostgresqlConnectionConfiguration.java @@ -598,12 +598,12 @@ public Builder targetServerType(TargetServerType targetServerType) { } /** - * Controls how long in seconds the knowledge about a host state is cached connection factory. The default value is 10000 milliseconds. + * Controls how long in milliseconds the knowledge about a host state is cached connection factory. The default value is 10000 milliseconds. * * @param hostRecheckTime host recheck time in milliseconds * @return this {@link Builder} */ - public Builder hostRecheckTime(int hostRecheckTime) { + public Builder hostRecheckTime(@Nullable Duration hostRecheckTime) { if (this.multipleHostsConfiguration == null) { this.multipleHostsConfiguration = MultipleHostsConfiguration.builder(); } diff --git a/src/main/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryProvider.java b/src/main/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryProvider.java index 7470706c..70a02d01 100644 --- a/src/main/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryProvider.java +++ b/src/main/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryProvider.java @@ -25,6 +25,7 @@ import io.r2dbc.spi.Option; import javax.net.ssl.HostnameVerifier; +import java.time.Duration; import java.util.Map; import java.util.function.Function; @@ -232,6 +233,7 @@ private static PostgresqlConnectionConfiguration.Builder fromConnectionFactoryOp if (isUsingTcp(connectionFactoryOptions)) { setupSsl(builder, connectionFactoryOptions); + setupFailover(builder, connectionFactoryOptions); } else { builder.socket(connectionFactoryOptions.getRequiredValue(SOCKET)); } @@ -263,27 +265,11 @@ private static void setSslHostnameVerifier(PostgresqlConnectionConfiguration.Bui } } - private static void setupSsl(PostgresqlConnectionConfiguration.Builder builder, ConnectionFactoryOptions connectionFactoryOptions) { - Boolean ssl = connectionFactoryOptions.getValue(SSL); - if (ssl != null && ssl) { - builder.enableSsl(); - } - - Object sslMode = connectionFactoryOptions.getValue(SSL_MODE); - if (sslMode != null) { - if (sslMode instanceof String) { - builder.sslMode(SSLMode.fromValue(sslMode.toString())); - } else { - builder.sslMode((SSLMode) sslMode); - } - } - - builder.connectTimeout(connectionFactoryOptions.getValue(CONNECT_TIMEOUT)); - builder.database(connectionFactoryOptions.getValue(DATABASE)); - + private static void setupFailover(PostgresqlConnectionConfiguration.Builder builder, ConnectionFactoryOptions connectionFactoryOptions) { if (FAILOVER_PROTOCOL.equals(connectionFactoryOptions.getValue(PROTOCOL))) { if (connectionFactoryOptions.hasOption(HOST_RECHECK_TIME)) { - builder.hostRecheckTime(connectionFactoryOptions.getRequiredValue(HOST_RECHECK_TIME)); + Duration hostRecheckTime = Duration.ofMillis(connectionFactoryOptions.getRequiredValue(HOST_RECHECK_TIME)); + builder.hostRecheckTime(hostRecheckTime); } if (connectionFactoryOptions.hasOption(LOAD_BALANCE_HOSTS)) { Object loadBalanceHosts = connectionFactoryOptions.getRequiredValue(LOAD_BALANCE_HOSTS); @@ -323,11 +309,22 @@ private static void setupSsl(PostgresqlConnectionConfiguration.Builder builder, builder.port(port); } } + } + private static void setupSsl(PostgresqlConnectionConfiguration.Builder builder, ConnectionFactoryOptions connectionFactoryOptions) { + Boolean ssl = connectionFactoryOptions.getValue(SSL); + if (ssl != null && ssl) { + builder.enableSsl(); + } - builder.password(connectionFactoryOptions.getValue(PASSWORD)); - builder.schema(connectionFactoryOptions.getValue(SCHEMA)); - builder.username(connectionFactoryOptions.getRequiredValue(USER)); + Object sslMode = connectionFactoryOptions.getValue(SSL_MODE); + if (sslMode != null) { + if (sslMode instanceof String) { + builder.sslMode(SSLMode.fromValue(sslMode.toString())); + } else { + builder.sslMode((SSLMode) sslMode); + } + } String sslRootCert = connectionFactoryOptions.getValue(SSL_ROOT_CERT); if (sslRootCert != null) { diff --git a/src/main/java/io/r2dbc/postgresql/client/MultipleHostsConfiguration.java b/src/main/java/io/r2dbc/postgresql/client/MultipleHostsConfiguration.java index c2d6074e..82226e0c 100644 --- a/src/main/java/io/r2dbc/postgresql/client/MultipleHostsConfiguration.java +++ b/src/main/java/io/r2dbc/postgresql/client/MultipleHostsConfiguration.java @@ -3,6 +3,7 @@ import io.r2dbc.postgresql.TargetServerType; import io.r2dbc.postgresql.util.Assert; +import java.time.Duration; import java.util.ArrayList; import java.util.List; @@ -12,20 +13,20 @@ public class MultipleHostsConfiguration { private final List hosts; - private final int hostRecheckTime; + private final Duration hostRecheckTime; private final boolean loadBalanceHosts; private final TargetServerType targetServerType; - public MultipleHostsConfiguration(List hosts, int hostRecheckTime, boolean loadBalanceHosts, TargetServerType targetServerType) { + public MultipleHostsConfiguration(List hosts, Duration hostRecheckTime, boolean loadBalanceHosts, TargetServerType targetServerType) { this.hosts = hosts; this.hostRecheckTime = hostRecheckTime; this.loadBalanceHosts = loadBalanceHosts; this.targetServerType = targetServerType; } - public int getHostRecheckTime() { + public Duration getHostRecheckTime() { return hostRecheckTime; } @@ -95,7 +96,7 @@ public static Builder builder() { */ public static class Builder { - private int hostRecheckTime = 10000; + private Duration hostRecheckTime = Duration.ofMillis(10000); private List hosts = new ArrayList<>(); @@ -124,7 +125,7 @@ public Builder targetServerType(TargetServerType targetServerType) { * @param hostRecheckTime host recheck time in milliseconds * @return this {@link Builder} */ - public Builder hostRecheckTime(int hostRecheckTime) { + public Builder hostRecheckTime(Duration hostRecheckTime) { this.hostRecheckTime = hostRecheckTime; return this; } diff --git a/src/test/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryProviderTest.java b/src/test/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryProviderTest.java index bf8d6503..acad3cb9 100644 --- a/src/test/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryProviderTest.java +++ b/src/test/java/io/r2dbc/postgresql/PostgresqlConnectionFactoryProviderTest.java @@ -23,6 +23,7 @@ import io.r2dbc.spi.Option; import org.junit.jupiter.api.Test; +import java.time.Duration; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -297,7 +298,7 @@ void testFailoverConfiguration() { assertThat(factory.getConfiguration().getSingleHostConfiguration()).isNull(); assertThat(factory.getConfiguration().getMultipleHostsConfiguration().isLoadBalanceHosts()).isEqualTo(true); - assertThat(factory.getConfiguration().getMultipleHostsConfiguration().getHostRecheckTime()).isEqualTo(20000); + assertThat(factory.getConfiguration().getMultipleHostsConfiguration().getHostRecheckTime()).isEqualTo(Duration.ofMillis(20000)); assertThat(factory.getConfiguration().getMultipleHostsConfiguration().getTargetServerType()).isEqualTo(TargetServerType.SECONDARY); List hosts = factory.getConfiguration().getMultipleHostsConfiguration().getHosts(); assertThat(hosts).hasSize(3);