Skip to content

Commit 5a7afbb

Browse files
authored
Update RoutedBoltConnectionProvider (#1582)
1 parent c5b2868 commit 5a7afbb

File tree

2 files changed

+47
-44
lines changed

2 files changed

+47
-44
lines changed

driver/src/main/java/org/neo4j/driver/internal/bolt/routedimpl/RoutedBoltConnection.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -267,7 +267,7 @@ public CompletionStage<Void> forceClose(String reason) {
267267

268268
@Override
269269
public CompletionStage<Void> close() {
270-
provider.decreaseCount(serverAddress());
270+
provider.decrementInUseCount(serverAddress());
271271
return delegate.close();
272272
}
273273

driver/src/main/java/org/neo4j/driver/internal/bolt/routedimpl/RoutedBoltConnectionProvider.java

Lines changed: 46 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
package org.neo4j.driver.internal.bolt.routedimpl;
1818

1919
import static java.lang.String.format;
20-
import static org.neo4j.driver.internal.bolt.routedimpl.util.LockUtil.executeWithLock;
2120

2221
import java.time.Clock;
2322
import java.util.ArrayList;
@@ -30,7 +29,6 @@
3029
import java.util.concurrent.CompletableFuture;
3130
import java.util.concurrent.CompletionStage;
3231
import java.util.concurrent.atomic.AtomicReference;
33-
import java.util.concurrent.locks.ReentrantLock;
3432
import java.util.function.Consumer;
3533
import java.util.function.Function;
3634
import java.util.function.Supplier;
@@ -71,7 +69,6 @@ public class RoutedBoltConnectionProvider implements BoltConnectionProvider {
7169
"Failed to obtain a connection towards address %s, will try other addresses if available. Complete failure is reported separately from this entry.";
7270
private final LoggingProvider logging;
7371
private final System.Logger log;
74-
private final ReentrantLock lock = new ReentrantLock();
7572
private final Supplier<BoltConnectionProvider> boltConnectionProviderSupplier;
7673

7774
private final Map<BoltServerAddress, BoltConnectionProvider> addressToProvider = new HashMap<>();
@@ -85,8 +82,6 @@ public class RoutedBoltConnectionProvider implements BoltConnectionProvider {
8582
private Rediscovery rediscovery;
8683
private RoutingTableRegistry registry;
8784

88-
private BoltServerAddress address;
89-
9085
private RoutingContext routingContext;
9186
private BoltAgent boltAgent;
9287
private String userAgent;
@@ -107,28 +102,21 @@ public RoutedBoltConnectionProvider(
107102
this.resolver = Objects.requireNonNull(resolver);
108103
this.logging = Objects.requireNonNull(logging);
109104
this.log = logging.getLog(getClass());
110-
this.loadBalancingStrategy = new LeastConnectedLoadBalancingStrategy(
111-
(addr) -> {
112-
synchronized (this) {
113-
return addressToInUseCount.getOrDefault(address, 0);
114-
}
115-
},
116-
logging);
105+
this.loadBalancingStrategy = new LeastConnectedLoadBalancingStrategy(this::getInUseCount, logging);
117106
this.domainNameResolver = Objects.requireNonNull(domainNameResolver);
118107
this.routingTablePurgeDelayMs = routingTablePurgeDelayMs;
119108
this.rediscovery = rediscovery;
120109
this.clock = Objects.requireNonNull(clock);
121110
}
122111

123112
@Override
124-
public CompletionStage<Void> init(
113+
public synchronized CompletionStage<Void> init(
125114
BoltServerAddress address,
126115
RoutingContext routingContext,
127116
BoltAgent boltAgent,
128117
String userAgent,
129118
int connectTimeoutMillis,
130119
MetricsListener metricsListener) {
131-
this.address = address;
132120
this.routingContext = routingContext;
133121
this.boltAgent = boltAgent;
134122
this.userAgent = userAgent;
@@ -154,10 +142,12 @@ public CompletionStage<BoltConnection> connect(
154142
BoltProtocolVersion minVersion,
155143
NotificationConfig notificationConfig,
156144
Consumer<DatabaseName> databaseNameConsumer) {
145+
RoutingTableRegistry registry;
157146
synchronized (this) {
158147
if (closeFuture != null) {
159148
return CompletableFuture.failedFuture(new IllegalStateException("Connection provider is closed."));
160149
}
150+
registry = this.registry;
161151
}
162152

163153
var handlerRef = new AtomicReference<RoutingTableHandler>();
@@ -196,6 +186,10 @@ public CompletionStage<BoltConnection> connect(
196186

197187
@Override
198188
public CompletionStage<Void> verifyConnectivity(SecurityPlan securityPlan, Map<String, Value> authMap) {
189+
RoutingTableRegistry registry;
190+
synchronized (this) {
191+
registry = this.registry;
192+
}
199193
return supportsMultiDb(securityPlan, authMap)
200194
.thenCompose(supports -> registry.ensureRoutingTable(
201195
securityPlan,
@@ -244,7 +238,7 @@ private synchronized void shutdownUnusedProviders(Set<BoltServerAddress> address
244238
while (iterator.hasNext()) {
245239
var entry = iterator.next();
246240
var address = entry.getKey();
247-
if (!addressesToRetain.contains(address) && addressToInUseCount.getOrDefault(address, 0) == 0) {
241+
if (!addressesToRetain.contains(address) && getInUseCount(address) == 0) {
248242
entry.getValue().close();
249243
iterator.remove();
250244
}
@@ -256,8 +250,12 @@ private CompletionStage<Boolean> detectFeature(
256250
Map<String, Value> authMap,
257251
String baseErrorMessagePrefix,
258252
Function<BoltConnection, Boolean> featureDetectionFunction) {
259-
List<BoltServerAddress> addresses;
253+
Rediscovery rediscovery;
254+
synchronized (this) {
255+
rediscovery = this.rediscovery;
256+
}
260257

258+
List<BoltServerAddress> addresses;
261259
try {
262260
addresses = rediscovery.resolve();
263261
} catch (Throwable error) {
@@ -390,11 +388,7 @@ private void acquire(
390388
result.completeExceptionally(error);
391389
}
392390
} else {
393-
synchronized (this) {
394-
var inUse = addressToInUseCount.getOrDefault(address, 0);
395-
inUse++;
396-
addressToInUseCount.put(address, inUse);
397-
}
391+
incrementInUseCount(address);
398392
result.complete(connection);
399393
}
400394
});
@@ -414,43 +408,52 @@ private static List<BoltServerAddress> getAddressesByMode(AccessMode mode, Routi
414408
};
415409
}
416410

417-
synchronized void decreaseCount(BoltServerAddress address) {
418-
var inUse = addressToInUseCount.get(address);
419-
if (inUse != null) {
420-
inUse--;
421-
if (inUse <= 0) {
422-
addressToInUseCount.remove(address);
411+
private synchronized int getInUseCount(BoltServerAddress address) {
412+
return addressToInUseCount.getOrDefault(address, 0);
413+
}
414+
415+
private synchronized void incrementInUseCount(BoltServerAddress address) {
416+
addressToInUseCount.merge(address, 1, Integer::sum);
417+
}
418+
419+
synchronized void decrementInUseCount(BoltServerAddress address) {
420+
addressToInUseCount.compute(address, (ignored, value) -> {
421+
if (value == null) {
422+
return null;
423423
} else {
424-
addressToInUseCount.put(address, inUse);
424+
value--;
425+
return value > 0 ? value : null;
425426
}
426-
}
427+
});
427428
}
428429

429430
@Override
430431
public CompletionStage<Void> close() {
431432
CompletableFuture<Void> closeFuture;
432433
synchronized (this) {
433434
if (this.closeFuture == null) {
434-
var futures = executeWithLock(lock, () -> addressToProvider.values().stream()
435-
.map(BoltConnectionProvider::close)
436-
.map(CompletionStage::toCompletableFuture)
437-
.toArray(CompletableFuture[]::new));
435+
@SuppressWarnings({"rawtypes", "RedundantSuppression"})
436+
var futures = new CompletableFuture[addressToProvider.size()];
437+
var iterator = addressToProvider.values().iterator();
438+
var index = 0;
439+
while (iterator.hasNext()) {
440+
futures[index++] = iterator.next().close().toCompletableFuture();
441+
iterator.remove();
442+
}
438443
this.closeFuture = CompletableFuture.allOf(futures);
439444
}
440445
closeFuture = this.closeFuture;
441446
}
442447
return closeFuture;
443448
}
444449

445-
private BoltConnectionProvider get(BoltServerAddress address) {
446-
return executeWithLock(lock, () -> {
447-
var provider = addressToProvider.get(address);
448-
if (provider == null) {
449-
provider = boltConnectionProviderSupplier.get();
450-
provider.init(address, routingContext, boltAgent, userAgent, connectTimeoutMillis, metricsListener);
451-
addressToProvider.put(address, provider);
452-
}
453-
return provider;
454-
});
450+
private synchronized BoltConnectionProvider get(BoltServerAddress address) {
451+
var provider = addressToProvider.get(address);
452+
if (provider == null) {
453+
provider = boltConnectionProviderSupplier.get();
454+
provider.init(address, routingContext, boltAgent, userAgent, connectTimeoutMillis, metricsListener);
455+
addressToProvider.put(address, provider);
456+
}
457+
return provider;
455458
}
456459
}

0 commit comments

Comments
 (0)