diff --git a/driver/src/main/java/org/neo4j/driver/internal/async/NetworkConnection.java b/driver/src/main/java/org/neo4j/driver/internal/async/NetworkConnection.java index 797a60ad6f..19bf6d5664 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/async/NetworkConnection.java +++ b/driver/src/main/java/org/neo4j/driver/internal/async/NetworkConnection.java @@ -19,7 +19,6 @@ package org.neo4j.driver.internal.async; import io.netty.channel.Channel; -import io.netty.channel.pool.ChannelPool; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionStage; @@ -28,6 +27,7 @@ import org.neo4j.driver.internal.BoltServerAddress; import org.neo4j.driver.internal.async.connection.ChannelAttributes; import org.neo4j.driver.internal.async.inbound.InboundMessageDispatcher; +import org.neo4j.driver.internal.async.pool.ExtendedChannelPool; import org.neo4j.driver.internal.handlers.ChannelReleasingResetResponseHandler; import org.neo4j.driver.internal.handlers.ResetResponseHandler; import org.neo4j.driver.internal.messaging.BoltProtocol; @@ -57,7 +57,7 @@ public class NetworkConnection implements Connection private final BoltServerAddress serverAddress; private final ServerVersion serverVersion; private final BoltProtocol protocol; - private final ChannelPool channelPool; + private final ExtendedChannelPool channelPool; private final CompletableFuture releaseFuture; private final Clock clock; @@ -65,7 +65,7 @@ public class NetworkConnection implements Connection private final MetricsListener metricsListener; private final ListenerEvent inUseEvent; - public NetworkConnection( Channel channel, ChannelPool channelPool, Clock clock, MetricsListener metricsListener ) + public NetworkConnection( Channel channel, ExtendedChannelPool channelPool, Clock clock, MetricsListener metricsListener ) { this.channel = channel; this.messageDispatcher = ChannelAttributes.messageDispatcher( channel ); diff --git a/driver/src/main/java/org/neo4j/driver/internal/async/pool/ConnectionPoolImpl.java b/driver/src/main/java/org/neo4j/driver/internal/async/pool/ConnectionPoolImpl.java index fe52756936..072bff4987 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/async/pool/ConnectionPoolImpl.java +++ b/driver/src/main/java/org/neo4j/driver/internal/async/pool/ConnectionPoolImpl.java @@ -21,11 +21,9 @@ import io.netty.bootstrap.Bootstrap; import io.netty.channel.Channel; import io.netty.channel.EventLoopGroup; -import io.netty.channel.pool.ChannelPool; -import io.netty.util.concurrent.Future; -import java.util.Map; import java.util.Set; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionException; import java.util.concurrent.CompletionStage; import java.util.concurrent.ConcurrentHashMap; @@ -48,6 +46,8 @@ import org.neo4j.driver.internal.util.Futures; import static java.lang.String.format; +import static org.neo4j.driver.internal.util.Futures.combineErrors; +import static org.neo4j.driver.internal.util.Futures.completeWithNullIfNoError; public class ConnectionPoolImpl implements ConnectionPool { @@ -62,6 +62,7 @@ public class ConnectionPoolImpl implements ConnectionPool private final ConcurrentMap pools = new ConcurrentHashMap<>(); private final AtomicBoolean closed = new AtomicBoolean(); + private final CompletableFuture closeFuture = new CompletableFuture<>(); private final ConnectionFactory connectionFactory; public ConnectionPoolImpl( ChannelConnector connector, Bootstrap bootstrap, PoolSettings settings, MetricsListener metricsListener, Logging logging, @@ -95,9 +96,9 @@ public CompletionStage acquire( BoltServerAddress address ) ListenerEvent acquireEvent = metricsListener.createListenerEvent(); metricsListener.beforeAcquiringOrCreating( pool.id(), acquireEvent ); - Future connectionFuture = pool.acquire(); + CompletionStage channelFuture = pool.acquire(); - return Futures.asCompletionStage( connectionFuture ).handle( ( channel, error ) -> + return channelFuture.handle( ( channel, error ) -> { try { @@ -131,8 +132,8 @@ public void retainAll( Set addressesToRetain ) if ( pool != null ) { log.info( "Closing connection pool towards %s, it has no active connections " + - "and is not in the routing table", address ); - closePool( pool ); + "and is not in the routing table registry.", address ); + closePoolInBackground( address, pool ); } } } @@ -156,37 +157,24 @@ public CompletionStage close() { if ( closed.compareAndSet( false, true ) ) { - try - { - nettyChannelTracker.prepareToCloseChannels(); - for ( Map.Entry entry : pools.entrySet() ) - { - BoltServerAddress address = entry.getKey(); - ExtendedChannelPool pool = entry.getValue(); - log.info( "Closing connection pool towards %s", address ); - closePool( pool ); - } + nettyChannelTracker.prepareToCloseChannels(); + CompletableFuture allPoolClosedFuture = closeAllPools(); + // We can only shutdown event loop group when all netty pools are fully closed, + // otherwise the netty pools might missing threads (from event loop group) to execute clean ups. + allPoolClosedFuture.whenComplete( ( ignored, pollCloseError ) -> { pools.clear(); - } - finally - { - - if (ownsEventLoopGroup) { - // This is an attempt to speed up the shut down procedure of the driver - // Feel free return this back to shutdownGracefully() method with default values - // if this proves troublesome!!! - eventLoopGroup().shutdownGracefully(200, 15_000, TimeUnit.MILLISECONDS); + if ( !ownsEventLoopGroup ) + { + completeWithNullIfNoError( closeFuture, pollCloseError ); } - } - } - if (!ownsEventLoopGroup) - { - return Futures.completedWithNull(); + else + { + shutdownEventLoopGroup( pollCloseError ); + } + } ); } - - return Futures.asCompletionStage( eventLoopGroup().terminationFuture() ) - .thenApply( ignore -> null ); + return closeFuture; } @Override @@ -195,31 +183,10 @@ public boolean isOpen( BoltServerAddress address ) return pools.containsKey( address ); } - private ExtendedChannelPool getOrCreatePool( BoltServerAddress address ) - { - return pools.computeIfAbsent( address, this::newPool ); - } - - private void closePool( ExtendedChannelPool pool ) - { - pool.close(); - // after the connection pool is removed/close, I can remove its metrics. - metricsListener.removePoolMetrics( pool.id() ); - } - - ExtendedChannelPool newPool( BoltServerAddress address ) - { - NettyChannelPool pool = - new NettyChannelPool( address, connector, bootstrap, nettyChannelTracker, channelHealthChecker, settings.connectionAcquisitionTimeout(), - settings.maxConnectionPoolSize() ); - // before the connection pool is added I can add the metrics for the pool. - metricsListener.putPoolMetrics( pool.id(), address, this ); - return pool; - } - - private EventLoopGroup eventLoopGroup() + @Override + public String toString() { - return bootstrap.config().group(); + return "ConnectionPoolImpl{" + "pools=" + pools + '}'; } private void processAcquisitionError( ExtendedChannelPool pool, BoltServerAddress serverAddress, Throwable error ) @@ -259,26 +226,84 @@ private void assertNotClosed() } } - private void assertNotClosed( BoltServerAddress address, Channel channel, ChannelPool pool ) + private void assertNotClosed( BoltServerAddress address, Channel channel, ExtendedChannelPool pool ) { if ( closed.get() ) { pool.release( channel ); - pool.close(); + closePoolInBackground( address, pool ); pools.remove( address ); assertNotClosed(); } } - @Override - public String toString() - { - return "ConnectionPoolImpl{" + "pools=" + pools + '}'; - } - // for testing only ExtendedChannelPool getPool( BoltServerAddress address ) { return pools.get( address ); } + + ExtendedChannelPool newPool( BoltServerAddress address ) + { + return new NettyChannelPool( address, connector, bootstrap, nettyChannelTracker, channelHealthChecker, settings.connectionAcquisitionTimeout(), + settings.maxConnectionPoolSize() ); + } + + private ExtendedChannelPool getOrCreatePool( BoltServerAddress address ) + { + return pools.computeIfAbsent( address, ignored -> { + ExtendedChannelPool pool = newPool( address ); + // before the connection pool is added I can add the metrics for the pool. + metricsListener.putPoolMetrics( pool.id(), address, this ); + return pool; + } ); + } + + private CompletionStage closePool( ExtendedChannelPool pool ) + { + return pool.close().whenComplete( ( ignored, error ) -> + // after the connection pool is removed/close, I can remove its metrics. + metricsListener.removePoolMetrics( pool.id() ) ); + } + + private void closePoolInBackground( BoltServerAddress address, ExtendedChannelPool pool ) + { + // Close in the background + closePool( pool ).whenComplete( ( ignored, error ) -> { + if ( error != null ) + { + log.warn( format( "An error occurred while closing connection pool towards %s.", address ), error ); + } + } ); + } + + private EventLoopGroup eventLoopGroup() + { + return bootstrap.config().group(); + } + + private void shutdownEventLoopGroup( Throwable pollCloseError ) + { + // This is an attempt to speed up the shut down procedure of the driver + // This timeout is needed for `closePoolInBackground` to finish background job, especially for races between `acquire` and `close`. + eventLoopGroup().shutdownGracefully( 200, 15_000, TimeUnit.MILLISECONDS ); + + Futures.asCompletionStage( eventLoopGroup().terminationFuture() ) + .whenComplete( ( ignore, eventLoopGroupTerminationError ) -> { + CompletionException combinedErrors = combineErrors( pollCloseError, eventLoopGroupTerminationError ); + completeWithNullIfNoError( closeFuture, combinedErrors ); + } ); + } + + private CompletableFuture closeAllPools() + { + return CompletableFuture.allOf( + pools.entrySet().stream().map( entry -> { + BoltServerAddress address = entry.getKey(); + ExtendedChannelPool pool = entry.getValue(); + log.info( "Closing connection pool towards %s", address ); + // Wait for all pools to be closed. + return closePool( pool ).toCompletableFuture(); + } ).toArray( CompletableFuture[]::new ) ); + } } diff --git a/driver/src/main/java/org/neo4j/driver/internal/async/pool/ExtendedChannelPool.java b/driver/src/main/java/org/neo4j/driver/internal/async/pool/ExtendedChannelPool.java index b1e74564b4..dfa2374d92 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/async/pool/ExtendedChannelPool.java +++ b/driver/src/main/java/org/neo4j/driver/internal/async/pool/ExtendedChannelPool.java @@ -18,11 +18,19 @@ */ package org.neo4j.driver.internal.async.pool; -import io.netty.channel.pool.ChannelPool; +import io.netty.channel.Channel; -public interface ExtendedChannelPool extends ChannelPool +import java.util.concurrent.CompletionStage; + +public interface ExtendedChannelPool { + CompletionStage acquire(); + + CompletionStage release( Channel channel ); + boolean isClosed(); String id(); + + CompletionStage close(); } diff --git a/driver/src/main/java/org/neo4j/driver/internal/async/pool/NettyChannelPool.java b/driver/src/main/java/org/neo4j/driver/internal/async/pool/NettyChannelPool.java index 7be99983da..be90da07ab 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/async/pool/NettyChannelPool.java +++ b/driver/src/main/java/org/neo4j/driver/internal/async/pool/NettyChannelPool.java @@ -24,6 +24,8 @@ import io.netty.channel.pool.ChannelHealthChecker; import io.netty.channel.pool.FixedChannelPool; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionStage; import java.util.concurrent.atomic.AtomicBoolean; import org.neo4j.driver.internal.BoltServerAddress; @@ -32,8 +34,9 @@ import static java.util.Objects.requireNonNull; import static org.neo4j.driver.internal.async.connection.ChannelAttributes.setPoolId; +import static org.neo4j.driver.internal.util.Futures.asCompletionStage; -public class NettyChannelPool extends FixedChannelPool implements ExtendedChannelPool +public class NettyChannelPool implements ExtendedChannelPool { /** * Unlimited amount of parties are allowed to request channels from the pool. @@ -44,60 +47,73 @@ public class NettyChannelPool extends FixedChannelPool implements ExtendedChanne */ private static final boolean RELEASE_HEALTH_CHECK = false; - private final BoltServerAddress address; - private final ChannelConnector connector; - private final NettyChannelTracker handler; + private final FixedChannelPool delegate; private final AtomicBoolean closed = new AtomicBoolean( false ); private final String id; + private final CompletableFuture closeFuture = new CompletableFuture<>(); - public NettyChannelPool( BoltServerAddress address, ChannelConnector connector, Bootstrap bootstrap, NettyChannelTracker handler, + NettyChannelPool( BoltServerAddress address, ChannelConnector connector, Bootstrap bootstrap, NettyChannelTracker handler, ChannelHealthChecker healthCheck, long acquireTimeoutMillis, int maxConnections ) { - super( bootstrap, handler, healthCheck, AcquireTimeoutAction.FAIL, acquireTimeoutMillis, maxConnections, - MAX_PENDING_ACQUIRES, RELEASE_HEALTH_CHECK ); - - this.address = requireNonNull( address ); - this.connector = requireNonNull( connector ); - this.handler = requireNonNull( handler ); + requireNonNull( address ); + requireNonNull( connector ); + requireNonNull( handler ); this.id = poolId( address ); - } - - @Override - protected ChannelFuture connectChannel( Bootstrap bootstrap ) - { - ListenerEvent creatingEvent = handler.channelCreating( this.id ); - ChannelFuture channelFuture = connector.connect( address, bootstrap ); - channelFuture.addListener( future -> + this.delegate = new FixedChannelPool( bootstrap, handler, healthCheck, FixedChannelPool.AcquireTimeoutAction.FAIL, acquireTimeoutMillis, maxConnections, + MAX_PENDING_ACQUIRES, RELEASE_HEALTH_CHECK ) { - if ( future.isSuccess() ) + @Override + protected ChannelFuture connectChannel( Bootstrap bootstrap ) { - // notify pool handler about a successful connection - Channel channel = channelFuture.channel(); - setPoolId( channel, this.id ); - handler.channelCreated( channel, creatingEvent ); + ListenerEvent creatingEvent = handler.channelCreating( id ); + ChannelFuture channelFuture = connector.connect( address, bootstrap ); + channelFuture.addListener( future -> { + if ( future.isSuccess() ) + { + // notify pool handler about a successful connection + Channel channel = channelFuture.channel(); + setPoolId( channel, id ); + handler.channelCreated( channel, creatingEvent ); + } + else + { + handler.channelFailedToCreate( id ); + } + } ); + return channelFuture; } - else - { - handler.channelFailedToCreate( this.id ); - } - } ); - return channelFuture; + }; } @Override - public void close() + public CompletionStage close() { if ( closed.compareAndSet( false, true ) ) { - super.close(); + asCompletionStage( delegate.closeAsync(), closeFuture ); } + return closeFuture; } + @Override + public CompletionStage acquire() + { + return asCompletionStage( delegate.acquire() ); + } + + @Override + public CompletionStage release( Channel channel ) + { + return asCompletionStage( delegate.release( channel ) ); + } + + @Override public boolean isClosed() { return closed.get(); } + @Override public String id() { return this.id; diff --git a/driver/src/main/java/org/neo4j/driver/internal/async/pool/NettyChannelTracker.java b/driver/src/main/java/org/neo4j/driver/internal/async/pool/NettyChannelTracker.java index a7734ded1b..39e62c6345 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/async/pool/NettyChannelTracker.java +++ b/driver/src/main/java/org/neo4j/driver/internal/async/pool/NettyChannelTracker.java @@ -91,9 +91,8 @@ public void channelCreated( Channel channel, ListenerEvent creatingEvent ) log.debug( "Channel [0x%s] created. Local address: %s, remote address: %s", channel.id(), channel.localAddress(), channel.remoteAddress() ); - incrementInUse( channel ); + incrementIdle( channel ); // when it is created, we count it as idle as it has not been acquired out of the pool metricsListener.afterCreated( poolId( channel ), creatingEvent ); - allChannels.add( channel ); } diff --git a/driver/src/main/java/org/neo4j/driver/internal/handlers/ChannelReleasingResetResponseHandler.java b/driver/src/main/java/org/neo4j/driver/internal/handlers/ChannelReleasingResetResponseHandler.java index d96dfc4bf3..e97c84cc93 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/handlers/ChannelReleasingResetResponseHandler.java +++ b/driver/src/main/java/org/neo4j/driver/internal/handlers/ChannelReleasingResetResponseHandler.java @@ -19,12 +19,12 @@ package org.neo4j.driver.internal.handlers; import io.netty.channel.Channel; -import io.netty.channel.pool.ChannelPool; -import io.netty.util.concurrent.Future; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionStage; import org.neo4j.driver.internal.async.inbound.InboundMessageDispatcher; +import org.neo4j.driver.internal.async.pool.ExtendedChannelPool; import org.neo4j.driver.internal.util.Clock; import static org.neo4j.driver.internal.async.connection.ChannelAttributes.setLastUsedTimestamp; @@ -32,10 +32,10 @@ public class ChannelReleasingResetResponseHandler extends ResetResponseHandler { private final Channel channel; - private final ChannelPool pool; + private final ExtendedChannelPool pool; private final Clock clock; - public ChannelReleasingResetResponseHandler( Channel channel, ChannelPool pool, + public ChannelReleasingResetResponseHandler( Channel channel, ExtendedChannelPool pool, InboundMessageDispatcher messageDispatcher, Clock clock, CompletableFuture releaseFuture ) { super( messageDispatcher, releaseFuture ); @@ -58,7 +58,7 @@ protected void resetCompleted( CompletableFuture completionFuture, boolean channel.close(); } - Future released = pool.release( channel ); - released.addListener( ignore -> completionFuture.complete( null ) ); + CompletionStage released = pool.release( channel ); + released.whenComplete( ( ignore, error ) -> completionFuture.complete( null ) ); } } diff --git a/driver/src/main/java/org/neo4j/driver/internal/util/Futures.java b/driver/src/main/java/org/neo4j/driver/internal/util/Futures.java index e2a96d3f06..5cbf96aa8c 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/util/Futures.java +++ b/driver/src/main/java/org/neo4j/driver/internal/util/Futures.java @@ -45,9 +45,27 @@ public static CompletableFuture completedWithNull() return (CompletableFuture) COMPLETED_WITH_NULL; } + public static CompletableFuture completeWithNullIfNoError( CompletableFuture future, Throwable error ) + { + if ( error != null ) + { + future.completeExceptionally( error ); + } + else + { + future.complete( null ); + } + return future; + } + public static CompletionStage asCompletionStage( io.netty.util.concurrent.Future future ) { CompletableFuture result = new CompletableFuture<>(); + return asCompletionStage( future, result ); + } + + public static CompletionStage asCompletionStage( io.netty.util.concurrent.Future future, CompletableFuture result ) + { if ( future.isCancelled() ) { result.cancel( true ); diff --git a/driver/src/test/java/org/neo4j/driver/internal/async/NetworkConnectionTest.java b/driver/src/test/java/org/neo4j/driver/internal/async/NetworkConnectionTest.java index 6df0a4dfde..6df017ac0a 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/async/NetworkConnectionTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/async/NetworkConnectionTest.java @@ -22,7 +22,6 @@ import io.netty.channel.DefaultEventLoop; import io.netty.channel.EventLoop; import io.netty.channel.embedded.EmbeddedChannel; -import io.netty.channel.pool.ChannelPool; import io.netty.util.internal.ConcurrentSet; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.Test; @@ -39,6 +38,7 @@ import org.neo4j.driver.internal.BoltServerAddress; import org.neo4j.driver.internal.async.connection.ChannelAttributes; import org.neo4j.driver.internal.async.inbound.InboundMessageDispatcher; +import org.neo4j.driver.internal.async.pool.ExtendedChannelPool; import org.neo4j.driver.internal.handlers.NoOpResponseHandler; import org.neo4j.driver.internal.messaging.request.RunMessage; import org.neo4j.driver.internal.spi.ResponseHandler; @@ -447,7 +447,7 @@ void shouldCloseChannelWhenTerminated() void shouldReleaseChannelWhenTerminated() { EmbeddedChannel channel = newChannel(); - ChannelPool pool = mock( ChannelPool.class ); + ExtendedChannelPool pool = mock( ExtendedChannelPool.class ); NetworkConnection connection = newConnection( channel, pool ); verify( pool, never() ).release( any() ); @@ -460,7 +460,7 @@ void shouldReleaseChannelWhenTerminated() void shouldNotReleaseChannelMultipleTimesWhenTerminatedMultipleTimes() { EmbeddedChannel channel = newChannel(); - ChannelPool pool = mock( ChannelPool.class ); + ExtendedChannelPool pool = mock( ExtendedChannelPool.class ); NetworkConnection connection = newConnection( channel, pool ); verify( pool, never() ).release( any() ); @@ -478,7 +478,7 @@ void shouldNotReleaseChannelMultipleTimesWhenTerminatedMultipleTimes() void shouldNotReleaseAfterTermination() { EmbeddedChannel channel = newChannel(); - ChannelPool pool = mock( ChannelPool.class ); + ExtendedChannelPool pool = mock( ExtendedChannelPool.class ); NetworkConnection connection = newConnection( channel, pool ); verify( pool, never() ).release( any() ); @@ -611,10 +611,10 @@ private static EmbeddedChannel newChannel() private static NetworkConnection newConnection( Channel channel ) { - return newConnection( channel, mock( ChannelPool.class ) ); + return newConnection( channel, mock( ExtendedChannelPool.class ) ); } - private static NetworkConnection newConnection( Channel channel, ChannelPool pool ) + private static NetworkConnection newConnection( Channel channel, ExtendedChannelPool pool ) { return new NetworkConnection( channel, pool, new FakeClock(), DEV_NULL_METRICS ); } diff --git a/driver/src/test/java/org/neo4j/driver/internal/async/pool/ConnectionPoolImplIT.java b/driver/src/test/java/org/neo4j/driver/internal/async/pool/ConnectionPoolImplIT.java index 306bd94af8..eab6e68194 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/async/pool/ConnectionPoolImplIT.java +++ b/driver/src/test/java/org/neo4j/driver/internal/async/pool/ConnectionPoolImplIT.java @@ -24,6 +24,9 @@ import org.junit.jupiter.api.Test; import org.junit.jupiter.api.extension.RegisterExtension; +import java.util.Collections; +import java.util.concurrent.CompletionStage; + import org.neo4j.driver.exceptions.ServiceUnavailableException; import org.neo4j.driver.internal.BoltServerAddress; import org.neo4j.driver.internal.ConnectionSettings; @@ -86,6 +89,18 @@ void shouldAcquireIdleConnection() assertNotNull( connection2 ); } + @Test + void shouldBeAbleToClosePoolInIOWorkerThread() throws Throwable + { + // In the IO worker thread of a channel obtained from a pool, we shall be able to close the pool. + CompletionStage future = pool.acquire( neo4j.address() ).thenCompose( Connection::release ) + // This shall close all pools + .whenComplete( ( ignored, error ) -> pool.retainAll( Collections.emptySet() ) ); + + // We should be able to come to this line. + await( future ); + } + @Test void shouldFailToAcquireConnectionToWrongAddress() { @@ -118,12 +133,12 @@ void shouldFailToAcquireConnectionWhenPoolIsClosed() { await( pool.acquire( neo4j.address() ) ); ExtendedChannelPool channelPool = this.pool.getPool( neo4j.address() ); - channelPool.close(); + await( channelPool.close() ); ServiceUnavailableException error = assertThrows( ServiceUnavailableException.class, () -> await( pool.acquire( neo4j.address() ) ) ); assertThat( error.getMessage(), containsString( "closed while acquiring a connection" ) ); assertThat( error.getCause(), instanceOf( IllegalStateException.class ) ); - assertThat( error.getCause().getMessage(), containsString( "FixedChannelPooled was closed" ) ); + assertThat( error.getCause().getMessage(), containsString( "FixedChannelPool was closed" ) ); } private ConnectionPoolImpl newPool() throws Exception diff --git a/driver/src/test/java/org/neo4j/driver/internal/async/pool/ConnectionPoolImplTest.java b/driver/src/test/java/org/neo4j/driver/internal/async/pool/ConnectionPoolImplTest.java index 505043e442..f8eecb8876 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/async/pool/ConnectionPoolImplTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/async/pool/ConnectionPoolImplTest.java @@ -19,26 +19,18 @@ package org.neo4j.driver.internal.async.pool; import io.netty.bootstrap.Bootstrap; -import io.netty.channel.Channel; -import io.netty.channel.pool.ChannelPool; -import io.netty.util.concurrent.ImmediateEventExecutor; import org.junit.jupiter.api.Test; -import java.util.HashMap; import java.util.HashSet; -import java.util.Map; import org.neo4j.driver.internal.BoltServerAddress; -import org.neo4j.driver.internal.async.connection.ChannelConnector; import org.neo4j.driver.internal.util.FakeClock; import static java.util.Arrays.asList; import static java.util.Collections.singleton; -import static org.junit.jupiter.api.Assertions.assertNotNull; -import static org.mockito.Mockito.doReturn; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.never; -import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verifyZeroInteractions; import static org.mockito.Mockito.when; import static org.neo4j.driver.internal.BoltServerAddress.LOCAL_DEFAULT; @@ -55,7 +47,7 @@ class ConnectionPoolImplTest void shouldDoNothingWhenRetainOnEmptyPool() { NettyChannelTracker nettyChannelTracker = mock( NettyChannelTracker.class ); - TestConnectionPool pool = new TestConnectionPool( nettyChannelTracker ); + TestConnectionPool pool = newConnectionPool( nettyChannelTracker ); pool.retainAll( singleton( LOCAL_DEFAULT ) ); @@ -66,16 +58,16 @@ void shouldDoNothingWhenRetainOnEmptyPool() void shouldRetainSpecifiedAddresses() { NettyChannelTracker nettyChannelTracker = mock( NettyChannelTracker.class ); - TestConnectionPool pool = new TestConnectionPool( nettyChannelTracker ); + TestConnectionPool pool = newConnectionPool( nettyChannelTracker ); pool.acquire( ADDRESS_1 ); pool.acquire( ADDRESS_2 ); pool.acquire( ADDRESS_3 ); pool.retainAll( new HashSet<>( asList( ADDRESS_1, ADDRESS_2, ADDRESS_3 ) ) ); - for ( ChannelPool channelPool : pool.channelPoolsByAddress.values() ) + for ( ExtendedChannelPool channelPool : pool.channelPoolsByAddress.values() ) { - verify( channelPool, never() ).close(); + assertFalse( channelPool.isClosed() ); } } @@ -83,7 +75,7 @@ void shouldRetainSpecifiedAddresses() void shouldClosePoolsWhenRetaining() { NettyChannelTracker nettyChannelTracker = mock( NettyChannelTracker.class ); - TestConnectionPool pool = new TestConnectionPool( nettyChannelTracker ); + TestConnectionPool pool = newConnectionPool( nettyChannelTracker ); pool.acquire( ADDRESS_1 ); pool.acquire( ADDRESS_2 ); @@ -94,16 +86,16 @@ void shouldClosePoolsWhenRetaining() when( nettyChannelTracker.inUseChannelCount( ADDRESS_3 ) ).thenReturn( 3 ); pool.retainAll( new HashSet<>( asList( ADDRESS_1, ADDRESS_3 ) ) ); - verify( pool.getPool( ADDRESS_1 ), never() ).close(); - verify( pool.getPool( ADDRESS_2 ) ).close(); - verify( pool.getPool( ADDRESS_3 ), never() ).close(); + assertFalse( pool.getPool( ADDRESS_1 ).isClosed() ); + assertTrue( pool.getPool( ADDRESS_2 ).isClosed() ); + assertFalse( pool.getPool( ADDRESS_3 ).isClosed() ); } @Test void shouldNotClosePoolsWithActiveConnectionsWhenRetaining() { NettyChannelTracker nettyChannelTracker = mock( NettyChannelTracker.class ); - TestConnectionPool pool = new TestConnectionPool( nettyChannelTracker ); + TestConnectionPool pool = newConnectionPool( nettyChannelTracker ); pool.acquire( ADDRESS_1 ); pool.acquire( ADDRESS_2 ); @@ -114,9 +106,9 @@ void shouldNotClosePoolsWithActiveConnectionsWhenRetaining() when( nettyChannelTracker.inUseChannelCount( ADDRESS_3 ) ).thenReturn( 0 ); pool.retainAll( singleton( ADDRESS_2 ) ); - verify( pool.getPool( ADDRESS_1 ), never() ).close(); - verify( pool.getPool( ADDRESS_2 ), never() ).close(); - verify( pool.getPool( ADDRESS_3 ) ).close(); + assertFalse( pool.getPool( ADDRESS_1 ).isClosed() ); + assertFalse( pool.getPool( ADDRESS_2 ).isClosed() ); + assertTrue( pool.getPool( ADDRESS_3 ).isClosed() ); } private static PoolSettings newSettings() @@ -124,31 +116,9 @@ private static PoolSettings newSettings() return new PoolSettings( 10, 5000, -1, -1 ); } - private static class TestConnectionPool extends ConnectionPoolImpl + private static TestConnectionPool newConnectionPool( NettyChannelTracker nettyChannelTracker ) { - final Map channelPoolsByAddress = new HashMap<>(); - - TestConnectionPool( NettyChannelTracker nettyChannelTracker ) - { - super( mock( ChannelConnector.class ), mock( Bootstrap.class ), nettyChannelTracker, newSettings(), DEV_NULL_METRICS, DEV_NULL_LOGGING, - new FakeClock(), true, mock( ConnectionFactory.class ) ); - } - - ExtendedChannelPool getPool( BoltServerAddress address ) - { - ExtendedChannelPool pool = channelPoolsByAddress.get( address ); - assertNotNull( pool ); - return pool; - } - - @Override - ExtendedChannelPool newPool( BoltServerAddress address ) - { - ExtendedChannelPool channelPool = mock( ExtendedChannelPool.class ); - Channel channel = mock( Channel.class ); - doReturn( ImmediateEventExecutor.INSTANCE.newSucceededFuture( channel ) ).when( channelPool ).acquire(); - channelPoolsByAddress.put( address, channelPool ); - return channelPool; - } + return new TestConnectionPool( mock( Bootstrap.class ), nettyChannelTracker, newSettings(), DEV_NULL_METRICS, DEV_NULL_LOGGING, + new FakeClock(), true ); } } diff --git a/driver/src/test/java/org/neo4j/driver/internal/async/pool/NettyChannelPoolIT.java b/driver/src/test/java/org/neo4j/driver/internal/async/pool/NettyChannelPoolIT.java index af97f883d0..99a631997a 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/async/pool/NettyChannelPoolIT.java +++ b/driver/src/test/java/org/neo4j/driver/internal/async/pool/NettyChannelPoolIT.java @@ -21,7 +21,6 @@ import io.netty.bootstrap.Bootstrap; import io.netty.channel.Channel; import io.netty.channel.pool.ChannelHealthChecker; -import io.netty.util.concurrent.Future; import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -29,10 +28,12 @@ import java.util.HashMap; import java.util.Map; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.TimeUnit; import java.util.concurrent.TimeoutException; +import org.neo4j.driver.AuthToken; +import org.neo4j.driver.AuthTokens; +import org.neo4j.driver.Value; +import org.neo4j.driver.exceptions.AuthenticationException; import org.neo4j.driver.internal.ConnectionSettings; import org.neo4j.driver.internal.async.connection.BootstrapFactory; import org.neo4j.driver.internal.async.connection.ChannelConnectorImpl; @@ -40,28 +41,22 @@ import org.neo4j.driver.internal.security.SecurityPlan; import org.neo4j.driver.internal.util.FakeClock; import org.neo4j.driver.internal.util.ImmediateSchedulingEventExecutor; -import org.neo4j.driver.AuthToken; -import org.neo4j.driver.AuthTokens; -import org.neo4j.driver.Value; -import org.neo4j.driver.exceptions.AuthenticationException; import org.neo4j.driver.util.DatabaseExtension; import org.neo4j.driver.util.Neo4jRunner; import org.neo4j.driver.util.ParallelizableIT; -import static org.hamcrest.Matchers.instanceOf; -import static org.hamcrest.junit.MatcherAssert.assertThat; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertNotNull; import static org.junit.jupiter.api.Assertions.assertThrows; -import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; +import static org.neo4j.driver.Values.value; import static org.neo4j.driver.internal.logging.DevNullLogging.DEV_NULL_LOGGING; import static org.neo4j.driver.internal.metrics.InternalAbstractMetrics.DEV_NULL_METRICS; -import static org.neo4j.driver.Values.value; +import static org.neo4j.driver.util.TestUtil.await; @ParallelizableIT class NettyChannelPoolIT @@ -98,19 +93,12 @@ void shouldAcquireAndReleaseWithCorrectCredentials() throws Exception { pool = newPool( neo4j.authToken() ); - Future acquireFuture = pool.acquire(); - acquireFuture.await( 5, TimeUnit.SECONDS ); - - assertTrue( acquireFuture.isSuccess() ); - Channel channel = acquireFuture.getNow(); + Channel channel = await( pool.acquire() ); assertNotNull( channel ); verify( poolHandler ).channelCreated( eq( channel ), any() ); verify( poolHandler, never() ).channelReleased( channel ); - Future releaseFuture = pool.release( channel ); - releaseFuture.await( 5, TimeUnit.SECONDS ); - - assertTrue( releaseFuture.isSuccess() ); + await( pool.release( channel ) ); verify( poolHandler ).channelReleased( channel ); } @@ -119,12 +107,7 @@ void shouldFailToAcquireWithWrongCredentials() throws Exception { pool = newPool( AuthTokens.basic( "wrong", "wrong" ) ); - Future future = pool.acquire(); - future.await( 5, TimeUnit.DAYS ); - - assertTrue( future.isDone() ); - assertNotNull( future.cause() ); - assertThat( future.cause(), instanceOf( AuthenticationException.class ) ); + assertThrows( AuthenticationException.class, () -> await( pool.acquire() ) ); verify( poolHandler, never() ).channelCreated( any() ); verify( poolHandler, never() ).channelReleased( any() ); @@ -145,8 +128,7 @@ void shouldAllowAcquireAfterFailures() throws Exception for ( int i = 0; i < maxConnections; i++ ) { - ExecutionException e = assertThrows( ExecutionException.class, () -> acquire( pool ) ); - assertThat( e.getCause(), instanceOf( AuthenticationException.class ) ); + AuthenticationException e = assertThrows( AuthenticationException.class, () -> acquire( pool ) ); } authTokenMap.put( "credentials", value( Neo4jRunner.PASSWORD ) ); @@ -165,9 +147,8 @@ void shouldLimitNumberOfConcurrentConnections() throws Exception assertNotNull( acquire( pool ) ); } - ExecutionException e = assertThrows( ExecutionException.class, () -> acquire( pool ) ); - assertThat( e.getCause(), instanceOf( TimeoutException.class ) ); - assertEquals( e.getCause().getMessage(), "Acquire operation took longer then configured maximum time" ); + TimeoutException e = assertThrows( TimeoutException.class, () -> acquire( pool ) ); + assertEquals( e.getMessage(), "Acquire operation took longer then configured maximum time" ); } @Test @@ -209,11 +190,11 @@ private NettyChannelPool newPool( AuthToken authToken, int maxConnections ) private static Channel acquire( NettyChannelPool pool ) throws Exception { - return pool.acquire().get( 5, TimeUnit.SECONDS ); + return await( pool.acquire() ); } private void release( Channel channel ) throws Exception { - pool.release( channel ).get( 5, TimeUnit.SECONDS ); + await( pool.release( channel ) ); } } diff --git a/driver/src/test/java/org/neo4j/driver/internal/async/pool/NettyChannelTrackerTest.java b/driver/src/test/java/org/neo4j/driver/internal/async/pool/NettyChannelTrackerTest.java index e326836a0c..3f73c6bcb7 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/async/pool/NettyChannelTrackerTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/async/pool/NettyChannelTrackerTest.java @@ -49,15 +49,15 @@ class NettyChannelTrackerTest private final NettyChannelTracker tracker = new NettyChannelTracker( DEV_NULL_METRICS, mock( ChannelGroup.class ), DEV_NULL_LOGGING ); @Test - void shouldIncrementInUseCountWhenChannelCreated() + void shouldIncrementIdleCountWhenChannelCreated() { Channel channel = newChannel(); assertEquals( 0, tracker.inUseChannelCount( address ) ); assertEquals( 0, tracker.idleChannelCount( address ) ); tracker.channelCreated( channel, null ); - assertEquals( 1, tracker.inUseChannelCount( address ) ); - assertEquals( 0, tracker.idleChannelCount( address ) ); + assertEquals( 0, tracker.inUseChannelCount( address ) ); + assertEquals( 1, tracker.idleChannelCount( address ) ); } @Test @@ -68,33 +68,45 @@ void shouldIncrementInUseCountWhenChannelAcquired() assertEquals( 0, tracker.idleChannelCount( address ) ); tracker.channelCreated( channel, null ); + assertEquals( 0, tracker.inUseChannelCount( address ) ); + assertEquals( 1, tracker.idleChannelCount( address ) ); + + tracker.channelAcquired( channel ); assertEquals( 1, tracker.inUseChannelCount( address ) ); assertEquals( 0, tracker.idleChannelCount( address ) ); + } - tracker.channelReleased( channel ); + @Test + void shouldIncrementIdleCountWhenChannelReleased() + { + Channel channel = newChannel(); assertEquals( 0, tracker.inUseChannelCount( address ) ); - assertEquals( 1, tracker.idleChannelCount( address ) ); + assertEquals( 0, tracker.idleChannelCount( address ) ); - tracker.channelAcquired( channel ); + channelCreatedAndAcquired( channel ); assertEquals( 1, tracker.inUseChannelCount( address ) ); assertEquals( 0, tracker.idleChannelCount( address ) ); + + tracker.channelReleased( channel ); + assertEquals( 0, tracker.inUseChannelCount( address ) ); + assertEquals( 1, tracker.idleChannelCount( address ) ); } @Test - void shouldIncrementInuseCountForAddress() + void shouldIncrementIdleCountForAddress() { Channel channel1 = newChannel(); Channel channel2 = newChannel(); Channel channel3 = newChannel(); - assertEquals( 0, tracker.inUseChannelCount( address ) ); + assertEquals( 0, tracker.idleChannelCount( address ) ); tracker.channelCreated( channel1, null ); - assertEquals( 1, tracker.inUseChannelCount( address ) ); + assertEquals( 1, tracker.idleChannelCount( address ) ); tracker.channelCreated( channel2, null ); - assertEquals( 2, tracker.inUseChannelCount( address ) ); + assertEquals( 2, tracker.idleChannelCount( address ) ); tracker.channelCreated( channel3, null ); - assertEquals( 3, tracker.inUseChannelCount( address ) ); - assertEquals( 0, tracker.idleChannelCount( address ) ); + assertEquals( 3, tracker.idleChannelCount( address ) ); + assertEquals( 0, tracker.inUseChannelCount( address ) ); } @Test @@ -104,9 +116,9 @@ void shouldDecrementCountForAddress() Channel channel2 = newChannel(); Channel channel3 = newChannel(); - tracker.channelCreated( channel1, null ); - tracker.channelCreated( channel2, null ); - tracker.channelCreated( channel3, null ); + channelCreatedAndAcquired( channel1 ); + channelCreatedAndAcquired( channel2 ); + channelCreatedAndAcquired( channel3 ); assertEquals( 3, tracker.inUseChannelCount( address ) ); assertEquals( 0, tracker.idleChannelCount( address ) ); @@ -126,7 +138,7 @@ void shouldDecreaseIdleWhenClosedOutsidePool() throws Throwable { // Given Channel channel = newChannel(); - tracker.channelCreated( channel, null ); + channelCreatedAndAcquired( channel ); assertEquals( 1, tracker.inUseChannelCount( address ) ); assertEquals( 0, tracker.idleChannelCount( address ) ); @@ -147,7 +159,7 @@ void shouldDecreaseIdleWhenClosedInsidePool() throws Throwable { // Given Channel channel = newChannel(); - tracker.channelCreated( channel, null ); + channelCreatedAndAcquired( channel ); assertEquals( 1, tracker.inUseChannelCount( address ) ); assertEquals( 0, tracker.idleChannelCount( address ) ); @@ -160,7 +172,6 @@ void shouldDecreaseIdleWhenClosedInsidePool() throws Throwable // Then assertEquals( 0, tracker.inUseChannelCount( address ) ); assertEquals( 0, tracker.idleChannelCount( address ) ); - } @Test @@ -226,4 +237,10 @@ private EmbeddedChannel newChannelWithProtocolV3() setMessageDispatcher( channel, mock( InboundMessageDispatcher.class ) ); return channel; } + + private void channelCreatedAndAcquired( Channel channel ) + { + tracker.channelCreated( channel, null ); + tracker.channelAcquired( channel ); + } } diff --git a/driver/src/test/java/org/neo4j/driver/internal/async/pool/TestConnectionPool.java b/driver/src/test/java/org/neo4j/driver/internal/async/pool/TestConnectionPool.java new file mode 100644 index 0000000000..5259be32a7 --- /dev/null +++ b/driver/src/test/java/org/neo4j/driver/internal/async/pool/TestConnectionPool.java @@ -0,0 +1,122 @@ +/* + * Copyright (c) 2002-2019 "Neo4j," + * Neo4j Sweden AB [http://neo4j.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.internal.async.pool; + +import io.netty.bootstrap.Bootstrap; +import io.netty.channel.Channel; +import io.netty.channel.embedded.EmbeddedChannel; + +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.CompletionStage; +import java.util.concurrent.atomic.AtomicBoolean; + +import org.neo4j.driver.Logging; +import org.neo4j.driver.internal.BoltServerAddress; +import org.neo4j.driver.internal.async.connection.ChannelConnector; +import org.neo4j.driver.internal.metrics.ListenerEvent; +import org.neo4j.driver.internal.metrics.MetricsListener; +import org.neo4j.driver.internal.spi.Connection; +import org.neo4j.driver.internal.util.Clock; + +import static java.util.concurrent.CompletableFuture.completedFuture; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.neo4j.driver.internal.async.connection.ChannelAttributes.setPoolId; +import static org.neo4j.driver.internal.async.connection.ChannelAttributes.setServerAddress; +import static org.neo4j.driver.internal.util.Futures.completedWithNull; + +public class TestConnectionPool extends ConnectionPoolImpl +{ + final Map channelPoolsByAddress = new HashMap<>(); + private final NettyChannelTracker nettyChannelTracker; + + public TestConnectionPool( Bootstrap bootstrap, NettyChannelTracker nettyChannelTracker, PoolSettings settings, + MetricsListener metricsListener, Logging logging, Clock clock, boolean ownsEventLoopGroup ) + { + super( mock( ChannelConnector.class ), bootstrap, nettyChannelTracker, settings, metricsListener, logging, clock, ownsEventLoopGroup, + newConnectionFactory() ); + this.nettyChannelTracker = nettyChannelTracker; + } + + ExtendedChannelPool getPool( BoltServerAddress address ) + { + return channelPoolsByAddress.get( address ); + } + + @Override + ExtendedChannelPool newPool( BoltServerAddress address ) + { + ExtendedChannelPool channelPool = new ExtendedChannelPool() + { + private final AtomicBoolean isClosed = new AtomicBoolean( false ); + @Override + public CompletionStage acquire() + { + EmbeddedChannel channel = new EmbeddedChannel(); + setServerAddress( channel, address ); + setPoolId( channel, id() ); + + ListenerEvent event = nettyChannelTracker.channelCreating( id() ); + nettyChannelTracker.channelCreated( channel, event ); + nettyChannelTracker.channelAcquired( channel ); + + return completedFuture( channel ); + } + + @Override + public CompletionStage release( Channel channel ) + { + nettyChannelTracker.channelReleased( channel ); + nettyChannelTracker.channelClosed( channel ); + return completedWithNull(); + } + + @Override + public boolean isClosed() + { + return isClosed.get(); + } + + @Override + public String id() + { + return "Pool-" + this.hashCode(); + } + + @Override + public CompletionStage close() + { + isClosed.set( true ); + return completedWithNull(); + } + }; + channelPoolsByAddress.put( address, channelPool ); + return channelPool; + } + + private static ConnectionFactory newConnectionFactory() + { + return ( channel, pool ) -> { + Connection conn = mock( Connection.class ); + when( conn.release() ).thenAnswer( invocation -> pool.release( channel ) ); + return conn; + }; + } +} diff --git a/driver/src/test/java/org/neo4j/driver/internal/cluster/loadbalancing/RoutingTableAndConnectionPoolTest.java b/driver/src/test/java/org/neo4j/driver/internal/cluster/loadbalancing/RoutingTableAndConnectionPoolTest.java index 0d3adaa193..28b28d7016 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/cluster/loadbalancing/RoutingTableAndConnectionPoolTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/cluster/loadbalancing/RoutingTableAndConnectionPoolTest.java @@ -19,11 +19,6 @@ package org.neo4j.driver.internal.cluster.loadbalancing; import io.netty.bootstrap.Bootstrap; -import io.netty.channel.Channel; -import io.netty.channel.ChannelFuture; -import io.netty.channel.ChannelPromise; -import io.netty.channel.embedded.EmbeddedChannel; -import io.netty.util.AttributeKey; import io.netty.util.concurrent.GlobalEventExecutor; import org.junit.jupiter.api.Test; @@ -49,26 +44,19 @@ import org.neo4j.driver.internal.BoltServerAddress; import org.neo4j.driver.internal.InternalBookmark; import org.neo4j.driver.internal.async.connection.BootstrapFactory; -import org.neo4j.driver.internal.async.connection.ChannelConnector; -import org.neo4j.driver.internal.async.pool.ConnectionFactory; -import org.neo4j.driver.internal.async.pool.ConnectionPoolImpl; -import org.neo4j.driver.internal.async.pool.ExtendedChannelPool; import org.neo4j.driver.internal.async.pool.NettyChannelTracker; import org.neo4j.driver.internal.async.pool.PoolSettings; +import org.neo4j.driver.internal.async.pool.TestConnectionPool; import org.neo4j.driver.internal.cluster.ClusterComposition; import org.neo4j.driver.internal.cluster.Rediscovery; import org.neo4j.driver.internal.cluster.RoutingTable; import org.neo4j.driver.internal.cluster.RoutingTableRegistry; import org.neo4j.driver.internal.cluster.RoutingTableRegistryImpl; -import org.neo4j.driver.internal.messaging.BoltProtocol; -import org.neo4j.driver.internal.messaging.Message; import org.neo4j.driver.internal.metrics.InternalAbstractMetrics; import org.neo4j.driver.internal.spi.Connection; import org.neo4j.driver.internal.spi.ConnectionPool; -import org.neo4j.driver.internal.spi.ResponseHandler; import org.neo4j.driver.internal.util.Clock; import org.neo4j.driver.internal.util.Futures; -import org.neo4j.driver.internal.util.ServerVersion; import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.junit.MatcherAssert.assertThat; @@ -80,7 +68,6 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; import static org.neo4j.driver.Logging.none; -import static org.neo4j.driver.internal.async.connection.ChannelAttributes.setServerAddress; import static org.neo4j.driver.internal.cluster.RediscoveryUtils.contextWithDatabase; import static org.neo4j.driver.internal.cluster.RoutingSettings.STALE_ROUTING_TABLE_PURGE_DELAY_MS; import static org.neo4j.driver.internal.messaging.request.MultiDatabaseUtil.ABSENT_DB_NAME; @@ -320,27 +307,14 @@ private void acquireAndReleaseConnections( LoadBalancer loadBalancer ) throws In assertThat( errors.size(), equalTo( 0 ) ); } - private ChannelFuture newChannelFuture( BoltServerAddress address ) - { - EmbeddedChannel channel = new EmbeddedChannel(); - ChannelPromise channelPromise = channel.newPromise(); - channelPromise.setSuccess(); - setServerAddress( channel, address ); - return channelPromise; - } - private ConnectionPool newConnectionPool() { InternalAbstractMetrics metrics = DEV_NULL_METRICS; PoolSettings poolSettings = new PoolSettings( 10, 5000, -1, -1 ); - - ChannelConnector connector = ( address, bootstrap ) -> newChannelFuture( address ); Bootstrap bootstrap = BootstrapFactory.newBootstrap(); - NettyChannelTracker channelTracker = new NettyChannelTracker( metrics, bootstrap.config().group().next(), logging ); - ConnectionFactory connectionFactory = PooledConnection::new; - return new ConnectionPoolImpl( connector, bootstrap, channelTracker, poolSettings, metrics, logging, clock, true, connectionFactory ); + return new TestConnectionPool( bootstrap, channelTracker, poolSettings, metrics, logging, clock, true ); } private RoutingTableRegistryImpl newRoutingTables( ConnectionPool connectionPool, Rediscovery rediscovery ) @@ -395,105 +369,4 @@ public CompletionStage lookupClusterComposition( RoutingTabl return CompletableFuture.completedFuture( composition ); } } - - // This connection can be acquired from a connection pool and/or released back to it. - private static class PooledConnection implements Connection - { - private final Channel channel; - private final ExtendedChannelPool pool; - - PooledConnection( Channel channel, ExtendedChannelPool pool ) - { - this.channel = channel; - this.pool = pool; - - this.channel.attr( AttributeKey.valueOf( "channelPool" ) ).setIfAbsent( pool ); - } - - @Override - public boolean isOpen() - { - return false; - } - - @Override - public void enableAutoRead() - { - - } - - @Override - public void disableAutoRead() - { - - } - - @Override - public void write( Message message, ResponseHandler handler ) - { - - } - - @Override - public void write( Message message1, ResponseHandler handler1, Message message2, ResponseHandler handler2 ) - { - - } - - @Override - public void writeAndFlush( Message message, ResponseHandler handler ) - { - - } - - @Override - public void writeAndFlush( Message message1, ResponseHandler handler1, Message message2, ResponseHandler handler2 ) - { - - } - - @Override - public CompletionStage reset() - { - return Futures.completedWithNull(); - } - - @Override - public CompletionStage release() - { - CompletableFuture releaseFuture = new CompletableFuture<>(); - pool.release( channel ).addListener( ignore -> releaseFuture.complete( null ) ); - return releaseFuture; - } - - @Override - public void terminateAndRelease( String reason ) - { - - } - - @Override - public BoltServerAddress serverAddress() - { - return null; - } - - @Override - public ServerVersion serverVersion() - { - return null; - } - - @Override - public BoltProtocol protocol() - { - return null; - } - - @Override - public void flush() - { - - } - } } diff --git a/driver/src/test/java/org/neo4j/driver/internal/handlers/ChannelReleasingResetResponseHandlerTest.java b/driver/src/test/java/org/neo4j/driver/internal/handlers/ChannelReleasingResetResponseHandlerTest.java index 601df886cb..a0762b5373 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/handlers/ChannelReleasingResetResponseHandlerTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/handlers/ChannelReleasingResetResponseHandlerTest.java @@ -19,7 +19,6 @@ package org.neo4j.driver.internal.handlers; import io.netty.channel.embedded.EmbeddedChannel; -import io.netty.channel.pool.ChannelPool; import io.netty.util.concurrent.Future; import io.netty.util.concurrent.ImmediateEventExecutor; import org.junit.jupiter.api.AfterEach; @@ -28,6 +27,7 @@ import java.util.concurrent.CompletableFuture; import org.neo4j.driver.internal.async.inbound.InboundMessageDispatcher; +import org.neo4j.driver.internal.async.pool.ExtendedChannelPool; import org.neo4j.driver.internal.util.Clock; import org.neo4j.driver.internal.util.FakeClock; @@ -41,6 +41,7 @@ import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.neo4j.driver.internal.async.connection.ChannelAttributes.lastUsedTimestamp; +import static org.neo4j.driver.internal.util.Futures.completedWithNull; class ChannelReleasingResetResponseHandlerTest { @@ -56,7 +57,7 @@ void tearDown() @Test void shouldReleaseChannelOnSuccess() { - ChannelPool pool = newChannelPoolMock(); + ExtendedChannelPool pool = newChannelPoolMock(); FakeClock clock = new FakeClock(); clock.progress( 5 ); CompletableFuture releaseFuture = new CompletableFuture<>(); @@ -73,7 +74,7 @@ void shouldReleaseChannelOnSuccess() @Test void shouldCloseAndReleaseChannelOnFailure() { - ChannelPool pool = newChannelPoolMock(); + ExtendedChannelPool pool = newChannelPoolMock(); FakeClock clock = new FakeClock(); clock.progress( 100 ); CompletableFuture releaseFuture = new CompletableFuture<>(); @@ -92,17 +93,16 @@ private void verifyLastUsedTimestamp( int expectedValue ) assertEquals( expectedValue, lastUsedTimestamp( channel ).intValue() ); } - private ChannelReleasingResetResponseHandler newHandler( ChannelPool pool, Clock clock, + private ChannelReleasingResetResponseHandler newHandler( ExtendedChannelPool pool, Clock clock, CompletableFuture releaseFuture ) { return new ChannelReleasingResetResponseHandler( channel, pool, messageDispatcher, clock, releaseFuture ); } - private static ChannelPool newChannelPoolMock() + private static ExtendedChannelPool newChannelPoolMock() { - ChannelPool pool = mock( ChannelPool.class ); - Future releasedFuture = ImmediateEventExecutor.INSTANCE.newSucceededFuture( null ); - when( pool.release( any() ) ).thenReturn( releasedFuture ); + ExtendedChannelPool pool = mock( ExtendedChannelPool.class ); + when( pool.release( any() ) ).thenReturn( completedWithNull() ); return pool; } } diff --git a/pom.xml b/pom.xml index 57e91813c9..e253af1498 100644 --- a/pom.xml +++ b/pom.xml @@ -64,7 +64,7 @@ io.netty netty-handler - 4.1.22.Final + 4.1.41.Final io.projectreactor