diff --git a/driver/src/main/java/org/neo4j/driver/internal/async/pool/ConnectionPoolImpl.java b/driver/src/main/java/org/neo4j/driver/internal/async/pool/ConnectionPoolImpl.java index 3438516ecf..c6f5de94bd 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/async/pool/ConnectionPoolImpl.java +++ b/driver/src/main/java/org/neo4j/driver/internal/async/pool/ConnectionPoolImpl.java @@ -22,15 +22,20 @@ import io.netty.channel.Channel; import io.netty.channel.EventLoopGroup; +import java.util.HashMap; +import java.util.Iterator; +import java.util.Map; import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionException; import java.util.concurrent.CompletionStage; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.ConcurrentMap; import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReadWriteLock; +import java.util.concurrent.locks.ReentrantReadWriteLock; +import java.util.function.Supplier; import org.neo4j.driver.Logger; import org.neo4j.driver.Logging; @@ -61,7 +66,8 @@ public class ConnectionPoolImpl implements ConnectionPool private final MetricsListener metricsListener; private final boolean ownsEventLoopGroup; - private final ConcurrentMap pools = new ConcurrentHashMap<>(); + private final ReadWriteLock addressToPoolLock = new ReentrantReadWriteLock(); + private final Map addressToPool = new HashMap<>(); private final AtomicBoolean closed = new AtomicBoolean(); private final CompletableFuture closeFuture = new CompletableFuture<>(); private final ConnectionFactory connectionFactory; @@ -124,25 +130,32 @@ public CompletionStage acquire( BoltServerAddress address ) @Override public void retainAll( Set addressesToRetain ) { - for ( BoltServerAddress address : pools.keySet() ) + executeWithLock( addressToPoolLock.writeLock(), () -> { 
- if ( !addressesToRetain.contains( address ) ) + Iterator> entryIterator = addressToPool.entrySet().iterator(); + while ( entryIterator.hasNext() ) { - int activeChannels = nettyChannelTracker.inUseChannelCount( address ); - if ( activeChannels == 0 ) + Map.Entry entry = entryIterator.next(); + BoltServerAddress address = entry.getKey(); + if ( !addressesToRetain.contains( address ) ) { - // address is not present in updated routing table and has no active connections - // it's now safe to terminate corresponding connection pool and forget about it - ExtendedChannelPool pool = pools.remove( address ); - if ( pool != null ) + int activeChannels = nettyChannelTracker.inUseChannelCount( address ); + if ( activeChannels == 0 ) { - log.info( "Closing connection pool towards %s, it has no active connections " + - "and is not in the routing table registry.", address ); - closePoolInBackground( address, pool ); + // address is not present in updated routing table and has no active connections + // it's now safe to terminate corresponding connection pool and forget about it + ExtendedChannelPool pool = entry.getValue(); + entryIterator.remove(); + if ( pool != null ) + { + log.info( "Closing connection pool towards %s, it has no active connections " + + "and is not in the routing table registry.", address ); + closePoolInBackground( address, pool ); + } } } } - } + } ); } @Override @@ -163,21 +176,26 @@ public CompletionStage close() if ( closed.compareAndSet( false, true ) ) { nettyChannelTracker.prepareToCloseChannels(); - CompletableFuture allPoolClosedFuture = closeAllPools(); - // We can only shutdown event loop group when all netty pools are fully closed, - // otherwise the netty pools might missing threads (from event loop group) to execute clean ups. 
- allPoolClosedFuture.whenComplete( ( ignored, pollCloseError ) -> { - pools.clear(); - if ( !ownsEventLoopGroup ) - { - completeWithNullIfNoError( closeFuture, pollCloseError ); - } - else - { - shutdownEventLoopGroup( pollCloseError ); - } - } ); + executeWithLockAsync( addressToPoolLock.writeLock(), + () -> + { + // We can only shutdown event loop group when all netty pools are fully closed, + // otherwise the netty pools might be missing threads (from event loop group) to execute clean ups. + return closeAllPools().whenComplete( + ( ignored, pollCloseError ) -> + { + addressToPool.clear(); + if ( !ownsEventLoopGroup ) + { + completeWithNullIfNoError( closeFuture, pollCloseError ); + } + else + { + shutdownEventLoopGroup( pollCloseError ); + } + } ); + } ); } return closeFuture; } @@ -185,13 +203,13 @@ public CompletionStage close() @Override public boolean isOpen( BoltServerAddress address ) { - return pools.containsKey( address ); + return executeWithLock( addressToPoolLock.readLock(), () -> addressToPool.containsKey( address ) ); } @Override public String toString() { - return "ConnectionPoolImpl{" + "pools=" + pools + '}'; + return executeWithLock( addressToPoolLock.readLock(), () -> "ConnectionPoolImpl{" + "pools=" + addressToPool + '}' ); } private void processAcquisitionError( ExtendedChannelPool pool, BoltServerAddress serverAddress, Throwable error ) @@ -237,7 +255,7 @@ private void assertNotClosed( BoltServerAddress address, Channel channel, Extend { pool.release( channel ); closePoolInBackground( address, pool ); - pools.remove( address ); + executeWithLock( addressToPoolLock.writeLock(), () -> addressToPool.remove( address ) ); assertNotClosed(); } } @@ -245,7 +263,7 @@ private void assertNotClosed( BoltServerAddress address, Channel channel, Extend // for testing only ExtendedChannelPool getPool( BoltServerAddress address ) { - return pools.get( address ); + return executeWithLock( addressToPoolLock.readLock(), () -> addressToPool.get( address ) ); } 
ExtendedChannelPool newPool( BoltServerAddress address ) @@ -256,12 +274,22 @@ ExtendedChannelPool newPool( BoltServerAddress address ) private ExtendedChannelPool getOrCreatePool( BoltServerAddress address ) { - return pools.computeIfAbsent( address, ignored -> { - ExtendedChannelPool pool = newPool( address ); - // before the connection pool is added I can add the metrics for the pool. - metricsListener.putPoolMetrics( pool.id(), address, this ); - return pool; - } ); + ExtendedChannelPool existingPool = executeWithLock( addressToPoolLock.readLock(), () -> addressToPool.get( address ) ); + return existingPool != null + ? existingPool + : executeWithLock( addressToPoolLock.writeLock(), + () -> + { + ExtendedChannelPool pool = addressToPool.get( address ); + if ( pool == null ) + { + pool = newPool( address ); + // register metrics for the pool before it is added to the map. + metricsListener.putPoolMetrics( pool.id(), address, this ); + addressToPool.put( address, pool ); + } + return pool; + } ); } private CompletionStage closePool( ExtendedChannelPool pool ) @@ -303,12 +331,45 @@ private void shutdownEventLoopGroup( Throwable pollCloseError ) private CompletableFuture closeAllPools() { return CompletableFuture.allOf( - pools.entrySet().stream().map( entry -> { - BoltServerAddress address = entry.getKey(); - ExtendedChannelPool pool = entry.getValue(); - log.info( "Closing connection pool towards %s", address ); - // Wait for all pools to be closed. - return closePool( pool ).toCompletableFuture(); - } ).toArray( CompletableFuture[]::new ) ); + addressToPool.entrySet().stream() + .map( entry -> + { + BoltServerAddress address = entry.getKey(); + ExtendedChannelPool pool = entry.getValue(); + log.info( "Closing connection pool towards %s", address ); + // Wait for all pools to be closed. 
+ return closePool( pool ).toCompletableFuture(); + } ) + .toArray( CompletableFuture[]::new ) ); + } + + private void executeWithLock( Lock lock, Runnable runnable ) + { + executeWithLock( lock, () -> + { + runnable.run(); + return null; + } ); + } + + private T executeWithLock( Lock lock, Supplier supplier ) + { + lock.lock(); + try + { + return supplier.get(); + } + finally + { + lock.unlock(); + } + } + + private void executeWithLockAsync( Lock lock, Supplier> stageSupplier ) + { + lock.lock(); + CompletableFuture.completedFuture( lock ) + .thenCompose( ignored -> stageSupplier.get() ) + .whenComplete( ( ignored, throwable ) -> lock.unlock() ); } }