Skip to content

Update RoutedBoltConnectionProvider #1582

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -267,7 +267,7 @@ public CompletionStage<Void> forceClose(String reason) {

@Override
public CompletionStage<Void> close() {
provider.decreaseCount(serverAddress());
provider.decrementInUseCount(serverAddress());
return delegate.close();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
package org.neo4j.driver.internal.bolt.routedimpl;

import static java.lang.String.format;
import static org.neo4j.driver.internal.bolt.routedimpl.util.LockUtil.executeWithLock;

import java.time.Clock;
import java.util.ArrayList;
Expand All @@ -30,7 +29,6 @@
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.atomic.AtomicReference;
import java.util.concurrent.locks.ReentrantLock;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;
Expand Down Expand Up @@ -71,7 +69,6 @@ public class RoutedBoltConnectionProvider implements BoltConnectionProvider {
"Failed to obtain a connection towards address %s, will try other addresses if available. Complete failure is reported separately from this entry.";
private final LoggingProvider logging;
private final System.Logger log;
private final ReentrantLock lock = new ReentrantLock();
private final Supplier<BoltConnectionProvider> boltConnectionProviderSupplier;

private final Map<BoltServerAddress, BoltConnectionProvider> addressToProvider = new HashMap<>();
Expand All @@ -85,8 +82,6 @@ public class RoutedBoltConnectionProvider implements BoltConnectionProvider {
private Rediscovery rediscovery;
private RoutingTableRegistry registry;

private BoltServerAddress address;

private RoutingContext routingContext;
private BoltAgent boltAgent;
private String userAgent;
Expand All @@ -107,28 +102,21 @@ public RoutedBoltConnectionProvider(
this.resolver = Objects.requireNonNull(resolver);
this.logging = Objects.requireNonNull(logging);
this.log = logging.getLog(getClass());
this.loadBalancingStrategy = new LeastConnectedLoadBalancingStrategy(
(addr) -> {
synchronized (this) {
return addressToInUseCount.getOrDefault(address, 0);
}
},
logging);
this.loadBalancingStrategy = new LeastConnectedLoadBalancingStrategy(this::getInUseCount, logging);
this.domainNameResolver = Objects.requireNonNull(domainNameResolver);
this.routingTablePurgeDelayMs = routingTablePurgeDelayMs;
this.rediscovery = rediscovery;
this.clock = Objects.requireNonNull(clock);
}

@Override
public CompletionStage<Void> init(
public synchronized CompletionStage<Void> init(
BoltServerAddress address,
RoutingContext routingContext,
BoltAgent boltAgent,
String userAgent,
int connectTimeoutMillis,
MetricsListener metricsListener) {
this.address = address;
this.routingContext = routingContext;
this.boltAgent = boltAgent;
this.userAgent = userAgent;
Expand All @@ -154,10 +142,12 @@ public CompletionStage<BoltConnection> connect(
BoltProtocolVersion minVersion,
NotificationConfig notificationConfig,
Consumer<DatabaseName> databaseNameConsumer) {
RoutingTableRegistry registry;
synchronized (this) {
if (closeFuture != null) {
return CompletableFuture.failedFuture(new IllegalStateException("Connection provider is closed."));
}
registry = this.registry;
}

var handlerRef = new AtomicReference<RoutingTableHandler>();
Expand Down Expand Up @@ -196,6 +186,10 @@ public CompletionStage<BoltConnection> connect(

@Override
public CompletionStage<Void> verifyConnectivity(SecurityPlan securityPlan, Map<String, Value> authMap) {
RoutingTableRegistry registry;
synchronized (this) {
registry = this.registry;
}
return supportsMultiDb(securityPlan, authMap)
.thenCompose(supports -> registry.ensureRoutingTable(
securityPlan,
Expand Down Expand Up @@ -244,7 +238,7 @@ private synchronized void shutdownUnusedProviders(Set<BoltServerAddress> address
while (iterator.hasNext()) {
var entry = iterator.next();
var address = entry.getKey();
if (!addressesToRetain.contains(address) && addressToInUseCount.getOrDefault(address, 0) == 0) {
if (!addressesToRetain.contains(address) && getInUseCount(address) == 0) {
entry.getValue().close();
iterator.remove();
}
Expand All @@ -256,8 +250,12 @@ private CompletionStage<Boolean> detectFeature(
Map<String, Value> authMap,
String baseErrorMessagePrefix,
Function<BoltConnection, Boolean> featureDetectionFunction) {
List<BoltServerAddress> addresses;
Rediscovery rediscovery;
synchronized (this) {
rediscovery = this.rediscovery;
}

List<BoltServerAddress> addresses;
try {
addresses = rediscovery.resolve();
} catch (Throwable error) {
Expand Down Expand Up @@ -390,11 +388,7 @@ private void acquire(
result.completeExceptionally(error);
}
} else {
synchronized (this) {
var inUse = addressToInUseCount.getOrDefault(address, 0);
inUse++;
addressToInUseCount.put(address, inUse);
}
incrementInUseCount(address);
result.complete(connection);
}
});
Expand All @@ -414,43 +408,52 @@ private static List<BoltServerAddress> getAddressesByMode(AccessMode mode, Routi
};
}

synchronized void decreaseCount(BoltServerAddress address) {
var inUse = addressToInUseCount.get(address);
if (inUse != null) {
inUse--;
if (inUse <= 0) {
addressToInUseCount.remove(address);
private synchronized int getInUseCount(BoltServerAddress address) {
return addressToInUseCount.getOrDefault(address, 0);
}

private synchronized void incrementInUseCount(BoltServerAddress address) {
addressToInUseCount.merge(address, 1, Integer::sum);
}

synchronized void decrementInUseCount(BoltServerAddress address) {
addressToInUseCount.compute(address, (ignored, value) -> {
if (value == null) {
return null;
} else {
addressToInUseCount.put(address, inUse);
value--;
return value > 0 ? value : null;
}
}
});
}

@Override
public CompletionStage<Void> close() {
CompletableFuture<Void> closeFuture;
synchronized (this) {
if (this.closeFuture == null) {
var futures = executeWithLock(lock, () -> addressToProvider.values().stream()
.map(BoltConnectionProvider::close)
.map(CompletionStage::toCompletableFuture)
.toArray(CompletableFuture[]::new));
@SuppressWarnings({"rawtypes", "RedundantSuppression"})
var futures = new CompletableFuture[addressToProvider.size()];
var iterator = addressToProvider.values().iterator();
var index = 0;
while (iterator.hasNext()) {
futures[index++] = iterator.next().close().toCompletableFuture();
iterator.remove();
}
this.closeFuture = CompletableFuture.allOf(futures);
}
closeFuture = this.closeFuture;
}
return closeFuture;
}

private BoltConnectionProvider get(BoltServerAddress address) {
return executeWithLock(lock, () -> {
var provider = addressToProvider.get(address);
if (provider == null) {
provider = boltConnectionProviderSupplier.get();
provider.init(address, routingContext, boltAgent, userAgent, connectTimeoutMillis, metricsListener);
addressToProvider.put(address, provider);
}
return provider;
});
private synchronized BoltConnectionProvider get(BoltServerAddress address) {
var provider = addressToProvider.get(address);
if (provider == null) {
provider = boltConnectionProviderSupplier.get();
provider.init(address, routingContext, boltAgent, userAgent, connectTimeoutMillis, metricsListener);
addressToProvider.put(address, provider);
}
return provider;
}
}