diff --git a/driver/src/main/java/org/neo4j/driver/internal/async/NettyConnection.java b/driver/src/main/java/org/neo4j/driver/internal/async/NettyConnection.java index 7ae902ac80..c77f7a0e65 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/async/NettyConnection.java +++ b/driver/src/main/java/org/neo4j/driver/internal/async/NettyConnection.java @@ -20,9 +20,9 @@ import io.netty.channel.Channel; import io.netty.channel.pool.ChannelPool; -import io.netty.util.concurrent.Promise; import java.util.Map; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionStage; import java.util.concurrent.atomic.AtomicBoolean; @@ -39,9 +39,6 @@ import org.neo4j.driver.internal.util.ServerVersion; import org.neo4j.driver.v1.Value; -import static java.util.concurrent.CompletableFuture.completedFuture; -import static org.neo4j.driver.internal.util.Futures.asCompletionStage; - public class NettyConnection implements Connection { private final Channel channel; @@ -49,6 +46,7 @@ public class NettyConnection implements Connection private final BoltServerAddress serverAddress; private final ServerVersion serverVersion; private final ChannelPool channelPool; + private final CompletableFuture releaseFuture; private final Clock clock; private final AtomicBoolean open = new AtomicBoolean( true ); @@ -61,6 +59,7 @@ public NettyConnection( Channel channel, ChannelPool channelPool, Clock clock ) this.serverAddress = ChannelAttributes.serverAddress( channel ); this.serverVersion = ChannelAttributes.serverVersion( channel ); this.channelPool = channelPool; + this.releaseFuture = new CompletableFuture<>(); this.clock = clock; } @@ -111,14 +110,9 @@ public CompletionStage release() { if ( open.compareAndSet( true, false ) ) { - Promise releasePromise = channel.eventLoop().newPromise(); - reset( new ResetResponseHandler( channel, channelPool, messageDispatcher, clock, releasePromise ) ); - return asCompletionStage( releasePromise ); - } - else - { - return completedFuture( null ); + reset( new ResetResponseHandler( channel, channelPool, messageDispatcher, clock, releaseFuture ) ); } + return releaseFuture; } @Override diff --git a/driver/src/main/java/org/neo4j/driver/internal/async/pool/ConnectionPoolImpl.java b/driver/src/main/java/org/neo4j/driver/internal/async/pool/ConnectionPoolImpl.java index 66299e9b37..52d99ae0b3 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/async/pool/ConnectionPoolImpl.java +++ b/driver/src/main/java/org/neo4j/driver/internal/async/pool/ConnectionPoolImpl.java @@ -25,6 +25,7 @@ import io.netty.util.concurrent.Future; import java.util.Map; +import java.util.Set; import java.util.concurrent.CompletionException; import java.util.concurrent.CompletionStage; import java.util.concurrent.ConcurrentHashMap; @@ -58,10 +59,16 @@ public class ConnectionPoolImpl implements ConnectionPool public ConnectionPoolImpl( ChannelConnector connector, Bootstrap bootstrap, PoolSettings settings, Logging logging, Clock clock ) + { + this( connector, bootstrap, new ActiveChannelTracker( logging ), settings, logging, clock ); + } + + ConnectionPoolImpl( ChannelConnector connector, Bootstrap bootstrap, ActiveChannelTracker activeChannelTracker, + PoolSettings settings, Logging logging, Clock clock ) { this.connector = connector; this.bootstrap = bootstrap; - this.activeChannelTracker = new ActiveChannelTracker( logging ); + this.activeChannelTracker = activeChannelTracker; this.channelHealthChecker = new NettyChannelHealthChecker( settings, clock, logging ); this.settings = settings; this.clock = clock; @@ -86,27 +93,30 @@ public CompletionStage acquire( BoltServerAddress address ) } @Override - public void purge( BoltServerAddress address ) + public void retainAll( Set addressesToRetain ) { - log.info( "Purging connections towards %s", address ); - - // purge active connections - activeChannelTracker.purge( address ); - - // purge idle connections in the pool and pool itself - ChannelPool pool = pools.remove( address ); - if ( pool != null ) + for ( BoltServerAddress address : pools.keySet() ) { - pool.close(); + if ( !addressesToRetain.contains( address ) ) + { + int activeChannels = activeChannelTracker.activeChannelCount( address ); + if ( activeChannels == 0 ) + { + // address is not present in updated routing table and has no active connections + // it's now safe to terminate corresponding connection pool and forget about it + + ChannelPool pool = pools.remove( address ); + if ( pool != null ) + { + log.info( "Closing connection pool towards %s, it has no active connections " + + "and is not in the routing table", address ); + pool.close(); + } + } + } } } - @Override - public boolean hasAddress( BoltServerAddress address ) - { - return pools.containsKey( address ); - } - @Override public int activeConnections( BoltServerAddress address ) { @@ -157,7 +167,7 @@ private ChannelPool getOrCreatePool( BoltServerAddress address ) return pool; } - private NettyChannelPool newPool( BoltServerAddress address ) + ChannelPool newPool( BoltServerAddress address ) { return new NettyChannelPool( address, connector, bootstrap, activeChannelTracker, channelHealthChecker, settings.connectionAcquisitionTimeout(), settings.maxConnectionPoolSize() ); diff --git a/driver/src/main/java/org/neo4j/driver/internal/cluster/AddressSet.java b/driver/src/main/java/org/neo4j/driver/internal/cluster/AddressSet.java index 5f6fc0bfc2..db98e1d449 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/cluster/AddressSet.java +++ b/driver/src/main/java/org/neo4j/driver/internal/cluster/AddressSet.java @@ -39,54 +39,9 @@ public int size() return addresses.length; } - public synchronized void update( Set addresses, Set removed ) + public synchronized void update( Set addresses ) { - BoltServerAddress[] prev = this.addresses; - if ( addresses.isEmpty() ) - { - this.addresses = NONE; - return; - } - if ( prev.length == 0 ) - { - this.addresses = addresses.toArray( NONE ); - return; - } - BoltServerAddress[] copy = null; - if ( addresses.size() != prev.length ) - { - copy = new BoltServerAddress[addresses.size()]; - } - int j = 0; - for ( int i = 0; i < prev.length; i++ ) - { - if ( addresses.remove( prev[i] ) ) - { - if ( copy != null ) - { - copy[j++] = prev[i]; - } - } - else - { - removed.add( prev[i] ); - if ( copy == null ) - { - copy = new BoltServerAddress[prev.length]; - System.arraycopy( prev, 0, copy, 0, i ); - j = i; - } - } - } - if ( copy == null ) - { - return; - } - for ( BoltServerAddress address : addresses ) - { - copy[j++] = address; - } - this.addresses = copy; + this.addresses = addresses.toArray( NONE ); } public synchronized void remove( BoltServerAddress address ) diff --git a/driver/src/main/java/org/neo4j/driver/internal/cluster/ClusterRoutingTable.java b/driver/src/main/java/org/neo4j/driver/internal/cluster/ClusterRoutingTable.java index 83e9b4b7f3..a97423b679 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/cluster/ClusterRoutingTable.java +++ b/driver/src/main/java/org/neo4j/driver/internal/cluster/ClusterRoutingTable.java @@ -19,6 +19,7 @@ package org.neo4j.driver.internal.cluster; +import java.util.Collections; import java.util.HashSet; import java.util.LinkedHashSet; import java.util.Set; @@ -43,7 +44,7 @@ public class ClusterRoutingTable implements RoutingTable public ClusterRoutingTable( Clock clock, BoltServerAddress... routingAddresses ) { this( clock ); - routers.update( new LinkedHashSet<>( asList( routingAddresses ) ), new HashSet() ); + routers.update( new LinkedHashSet<>( asList( routingAddresses ) ) ); } private ClusterRoutingTable( Clock clock ) @@ -66,14 +67,12 @@ public boolean isStaleFor( AccessMode mode ) } @Override - public synchronized Set update( ClusterComposition cluster ) + public synchronized void update( ClusterComposition cluster ) { expirationTimeout = cluster.expirationTimestamp(); - Set removed = new HashSet<>(); - readers.update( cluster.readers(), removed ); - writers.update( cluster.writers(), removed ); - routers.update( cluster.routers(), removed ); - return removed; + readers.update( cluster.readers() ); + writers.update( cluster.writers() ); + routers.update( cluster.routers() ); } @Override @@ -102,6 +101,16 @@ public AddressSet routers() return routers; } + @Override + public Set servers() + { + Set servers = new HashSet<>(); + Collections.addAll( servers, readers.toArray() ); + Collections.addAll( servers, writers.toArray() ); + Collections.addAll( servers, routers.toArray() ); + return servers; + } + @Override public void removeWriter( BoltServerAddress toRemove ) { diff --git a/driver/src/main/java/org/neo4j/driver/internal/cluster/RoutingProcedureRunner.java b/driver/src/main/java/org/neo4j/driver/internal/cluster/RoutingProcedureRunner.java index 2902a5c775..97f00244c9 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/cluster/RoutingProcedureRunner.java +++ b/driver/src/main/java/org/neo4j/driver/internal/cluster/RoutingProcedureRunner.java @@ -53,18 +53,9 @@ public CompletionStage run( CompletionStage { Statement procedure = procedureStatement( connection.serverVersion() ); - return runProcedure( connection, procedure ).handle( ( records, error ) -> - { - Throwable cause = Futures.completionErrorCause( error ); - if ( cause != null ) - { - return handleError( procedure, cause ); - } - else - { - return new RoutingProcedureResponse( procedure, records ); - } - } ); + return runProcedure( connection, procedure ) + .thenCompose( records -> releaseConnection( connection, records ) ) + .handle( ( records, error ) -> processProcedureResponse( procedure, records, error ) ); } ); } @@ -87,6 +78,30 @@ private Statement procedureStatement( ServerVersion serverVersion ) } } + private CompletionStage> releaseConnection( Connection connection, List records ) + { + // It is not strictly required to release connection after routing procedure invocation because it'll + // be released by the PULL_ALL response handler after result is fully fetched. Such release will happen + // in background. However, releasing it early as part of whole chain makes it easier to reason about + // rediscovery in stub server tests. Some of them assume connections to instances not present in new + // routing table will be closed immediately. + return connection.release().thenApply( ignore -> records ); + } + + private RoutingProcedureResponse processProcedureResponse( Statement procedure, List records, + Throwable error ) + { + Throwable cause = Futures.completionErrorCause( error ); + if ( cause != null ) + { + return handleError( procedure, cause ); + } + else + { + return new RoutingProcedureResponse( procedure, records ); + } + } + private RoutingProcedureResponse handleError( Statement procedure, Throwable error ) { if ( error instanceof ClientException ) diff --git a/driver/src/main/java/org/neo4j/driver/internal/cluster/RoutingTable.java b/driver/src/main/java/org/neo4j/driver/internal/cluster/RoutingTable.java index 859920f071..5b9802c492 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/cluster/RoutingTable.java +++ b/driver/src/main/java/org/neo4j/driver/internal/cluster/RoutingTable.java @@ -27,7 +27,7 @@ public interface RoutingTable { boolean isStaleFor( AccessMode mode ); - Set update( ClusterComposition cluster ); + void update( ClusterComposition cluster ); void forget( BoltServerAddress address ); @@ -37,5 +37,7 @@ public interface RoutingTable AddressSet routers(); + Set servers(); + void removeWriter( BoltServerAddress toRemove ); } diff --git a/driver/src/main/java/org/neo4j/driver/internal/cluster/loadbalancing/LoadBalancer.java b/driver/src/main/java/org/neo4j/driver/internal/cluster/loadbalancing/LoadBalancer.java index a89b9bd8aa..5b4a50f4e4 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/cluster/loadbalancing/LoadBalancer.java +++ b/driver/src/main/java/org/neo4j/driver/internal/cluster/loadbalancing/LoadBalancer.java @@ -20,7 +20,6 @@ import io.netty.util.concurrent.EventExecutorGroup; -import java.util.Set; import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionStage; @@ -125,10 +124,8 @@ public CompletionStage close() private synchronized void forget( BoltServerAddress address ) { - // First remove from the load balancer, to prevent concurrent threads from making connections to them. + // remove from the routing table, to prevent concurrent threads from making connections to this address routingTable.forget( address ); - // drop all current connections to the address - connectionPool.purge( address ); } private synchronized CompletionStage freshRoutingTable( AccessMode mode ) @@ -171,18 +168,21 @@ else if ( routingTable.isStaleFor( mode ) ) private synchronized void freshClusterCompositionFetched( ClusterComposition composition ) { - Set removed = routingTable.update( composition ); - - for ( BoltServerAddress address : removed ) + try { - connectionPool.purge( address ); - } + routingTable.update( composition ); + connectionPool.retainAll( routingTable.servers() ); - log.info( "Refreshed routing information. %s", routingTable ); + log.info( "Refreshed routing information. %s", routingTable ); - CompletableFuture routingTableFuture = refreshRoutingTableFuture; - refreshRoutingTableFuture = null; - routingTableFuture.complete( routingTable ); + CompletableFuture routingTableFuture = refreshRoutingTableFuture; + refreshRoutingTableFuture = null; + routingTableFuture.complete( routingTable ); + } + catch ( Throwable error ) + { + clusterCompositionLookupFailed( error ); + } } private synchronized void clusterCompositionLookupFailed( Throwable error ) diff --git a/driver/src/main/java/org/neo4j/driver/internal/handlers/ResetResponseHandler.java b/driver/src/main/java/org/neo4j/driver/internal/handlers/ResetResponseHandler.java index 12bfd06e9d..c09b5ced85 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/handlers/ResetResponseHandler.java +++ b/driver/src/main/java/org/neo4j/driver/internal/handlers/ResetResponseHandler.java @@ -20,9 +20,10 @@ import io.netty.channel.Channel; import io.netty.channel.pool.ChannelPool; -import io.netty.util.concurrent.Promise; +import io.netty.util.concurrent.Future; import java.util.Map; +import java.util.concurrent.CompletableFuture; import org.neo4j.driver.internal.async.inbound.InboundMessageDispatcher; import org.neo4j.driver.internal.spi.ResponseHandler; @@ -37,16 +38,16 @@ public class ResetResponseHandler implements ResponseHandler private final ChannelPool pool; private final InboundMessageDispatcher messageDispatcher; private final Clock clock; - private final Promise releasePromise; + private final CompletableFuture releaseFuture; public ResetResponseHandler( Channel channel, ChannelPool pool, InboundMessageDispatcher messageDispatcher, - Clock clock, Promise releasePromise ) + Clock clock, CompletableFuture releaseFuture ) { this.channel = channel; this.pool = pool; this.messageDispatcher = messageDispatcher; this.clock = clock; - this.releasePromise = releasePromise; + this.releaseFuture = releaseFuture; } @Override @@ -72,13 +73,7 @@ private void releaseChannel() messageDispatcher.unMuteAckFailure(); setLastUsedTimestamp( channel, clock.millis() ); - if ( releasePromise == null ) - { - pool.release( channel ); - } - else - { - pool.release( channel, releasePromise ); - } + Future released = pool.release( channel ); + released.addListener( ignore -> releaseFuture.complete( null ) ); } } diff --git a/driver/src/main/java/org/neo4j/driver/internal/spi/ConnectionPool.java b/driver/src/main/java/org/neo4j/driver/internal/spi/ConnectionPool.java index e6a215756c..e322eaf45e 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/spi/ConnectionPool.java +++ b/driver/src/main/java/org/neo4j/driver/internal/spi/ConnectionPool.java @@ -18,6 +18,7 @@ */ package org.neo4j.driver.internal.spi; +import java.util.Set; import java.util.concurrent.CompletionStage; import org.neo4j.driver.internal.BoltServerAddress; @@ -26,9 +27,7 @@ public interface ConnectionPool { CompletionStage acquire( BoltServerAddress address ); - void purge( BoltServerAddress address ); - - boolean hasAddress( BoltServerAddress address ); + void retainAll( Set addressesToRetain ); int activeConnections( BoltServerAddress address ); diff --git a/driver/src/main/java/org/neo4j/driver/internal/util/Futures.java b/driver/src/main/java/org/neo4j/driver/internal/util/Futures.java index 01441fdcae..ec56b5235d 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/util/Futures.java +++ b/driver/src/main/java/org/neo4j/driver/internal/util/Futures.java @@ -45,6 +45,10 @@ else if ( future.isSuccess() ) { result.complete( future.getNow() ); } + else if ( future.cause() != null ) + { + result.completeExceptionally( future.cause() ); + } else { future.addListener( ignore -> diff --git a/driver/src/test/java/org/neo4j/driver/internal/async/NettyConnectionTest.java b/driver/src/test/java/org/neo4j/driver/internal/async/NettyConnectionTest.java index 4b45a4e6e4..d2ab115426 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/async/NettyConnectionTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/async/NettyConnectionTest.java @@ -28,6 +28,7 @@ import org.junit.Test; import java.util.Set; +import java.util.concurrent.CompletionStage; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; import java.util.concurrent.TimeUnit; @@ -227,6 +228,30 @@ public void shouldReturnServerVersionWhenReleased() assertEquals( version, connection.serverVersion() ); } + @Test + public void shouldReturnSameCompletionStageFromRelease() + { + EmbeddedChannel channel = new EmbeddedChannel(); + InboundMessageDispatcher messageDispatcher = new InboundMessageDispatcher( channel, DEV_NULL_LOGGING ); + ChannelAttributes.setMessageDispatcher( channel, messageDispatcher ); + + NettyConnection connection = newConnection( channel ); + + CompletionStage releaseStage1 = connection.release(); + CompletionStage releaseStage2 = connection.release(); + CompletionStage releaseStage3 = connection.release(); + + channel.runPendingTasks(); + + // RESET should be send only once + assertEquals( 1, channel.outboundMessages().size() ); + assertEquals( RESET, channel.outboundMessages().poll() ); + + // all returned stages should be the same + assertEquals( releaseStage1, releaseStage2 ); + assertEquals( releaseStage2, releaseStage3 ); + } + private void testWriteInEventLoop( String threadName, Consumer action ) throws Exception { EmbeddedChannel channel = spy( new EmbeddedChannel() ); diff --git a/driver/src/test/java/org/neo4j/driver/internal/async/pool/ConnectionPoolImplTest.java b/driver/src/test/java/org/neo4j/driver/internal/async/pool/ConnectionPoolImplTest.java index ef74dc18dd..719e07ef1c 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/async/pool/ConnectionPoolImplTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/async/pool/ConnectionPoolImplTest.java @@ -19,14 +19,22 @@ package org.neo4j.driver.internal.async.pool; import io.netty.bootstrap.Bootstrap; +import io.netty.channel.Channel; +import io.netty.channel.pool.ChannelPool; +import io.netty.util.concurrent.ImmediateEventExecutor; import org.junit.After; import org.junit.Before; import org.junit.Rule; import org.junit.Test; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; + import org.neo4j.driver.internal.BoltServerAddress; import org.neo4j.driver.internal.ConnectionSettings; import org.neo4j.driver.internal.async.BootstrapFactory; +import org.neo4j.driver.internal.async.ChannelConnector; import org.neo4j.driver.internal.async.ChannelConnectorImpl; import org.neo4j.driver.internal.security.SecurityPlan; import org.neo4j.driver.internal.spi.Connection; @@ -34,20 +42,31 @@ import org.neo4j.driver.v1.exceptions.ServiceUnavailableException; import org.neo4j.driver.v1.util.TestNeo4j; +import static java.util.Arrays.asList; +import static java.util.Collections.singleton; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.startsWith; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertThat; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; +import static org.mockito.Mockito.doReturn; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyZeroInteractions; +import static org.mockito.Mockito.when; +import static org.neo4j.driver.internal.BoltServerAddress.LOCAL_DEFAULT; import static org.neo4j.driver.internal.logging.DevNullLogging.DEV_NULL_LOGGING; import static org.neo4j.driver.v1.util.TestUtil.await; public class ConnectionPoolImplTest { + private static final BoltServerAddress ADDRESS_1 = new BoltServerAddress( "server:1" ); + private static final BoltServerAddress ADDRESS_2 = new BoltServerAddress( "server:2" ); + private static final BoltServerAddress ADDRESS_3 = new BoltServerAddress( "server:3" ); + @Rule public final TestNeo4j neo4j = new TestNeo4j(); @@ -60,13 +79,13 @@ public void setUp() throws Exception } @After - public void tearDown() throws Exception + public void tearDown() { pool.close(); } @Test - public void shouldAcquireConnectionWhenPoolIsEmpty() throws Exception + public void shouldAcquireConnectionWhenPoolIsEmpty() { Connection connection = await( pool.acquire( neo4j.address() ) ); @@ -74,7 +93,7 @@ public void shouldAcquireConnectionWhenPoolIsEmpty() throws Exception } @Test - public void shouldAcquireIdleConnection() throws Exception + public void shouldAcquireIdleConnection() { Connection connection1 = await( pool.acquire( neo4j.address() ) ); await( connection1.release() ); @@ -84,7 +103,7 @@ public void shouldAcquireIdleConnection() throws Exception } @Test - public void shouldFailToAcquireConnectionToWrongAddress() throws Exception + public void shouldFailToAcquireConnectionToWrongAddress() { try { @@ -99,7 +118,7 @@ public void shouldFailToAcquireConnectionToWrongAddress() throws Exception } @Test - public void shouldFailToAcquireWhenPoolClosed() throws Exception + public void shouldFailToAcquireWhenPoolClosed() { Connection connection = await( pool.acquire( neo4j.address() ) ); await( connection.release() ); @@ -118,59 +137,121 @@ public void shouldFailToAcquireWhenPoolClosed() throws Exception } @Test - public void shouldPurgeAddressWithConnections() + public void shouldNotCloseWhenClosed() { - Connection connection1 = await( pool.acquire( neo4j.address() ) ); - Connection connection2 = await( pool.acquire( neo4j.address() ) ); - Connection connection3 = await( pool.acquire( neo4j.address() ) ); - - assertNotNull( connection1 ); - assertNotNull( connection2 ); - assertNotNull( connection3 ); + assertNull( await( pool.close() ) ); + assertTrue( pool.close().toCompletableFuture().isDone() ); + } - assertEquals( 3, pool.activeConnections( neo4j.address() ) ); + @Test + public void shouldDoNothingWhenRetainOnEmptyPool() + { + ActiveChannelTracker activeChannelTracker = mock( ActiveChannelTracker.class ); + TestConnectionPool pool = new TestConnectionPool( activeChannelTracker ); - pool.purge( neo4j.address() ); + pool.retainAll( singleton( LOCAL_DEFAULT ) ); - assertEquals( 0, pool.activeConnections( neo4j.address() ) ); + verifyZeroInteractions( activeChannelTracker ); } @Test - public void shouldPurgeAddressWithoutConnections() + public void shouldRetainSpecifiedAddresses() { - assertEquals( 0, pool.activeConnections( neo4j.address() ) ); + ActiveChannelTracker activeChannelTracker = mock( ActiveChannelTracker.class ); + TestConnectionPool pool = new TestConnectionPool( activeChannelTracker ); - pool.purge( neo4j.address() ); + pool.acquire( ADDRESS_1 ); + pool.acquire( ADDRESS_2 ); + pool.acquire( ADDRESS_3 ); - assertEquals( 0, pool.activeConnections( neo4j.address() ) ); + pool.retainAll( new HashSet<>( asList( ADDRESS_1, ADDRESS_2, ADDRESS_3 ) ) ); + for ( ChannelPool channelPool : pool.channelPoolsByAddress.values() ) + { + verify( channelPool, never() ).close(); + } } @Test - public void shouldCheckIfPoolHasAddress() + public void shouldClosePoolsWhenRetaining() { - assertFalse( pool.hasAddress( neo4j.address() ) ); + ActiveChannelTracker activeChannelTracker = mock( ActiveChannelTracker.class ); + TestConnectionPool pool = new TestConnectionPool( activeChannelTracker ); - await( pool.acquire( neo4j.address() ) ); + pool.acquire( ADDRESS_1 ); + pool.acquire( ADDRESS_2 ); + pool.acquire( ADDRESS_3 ); - assertTrue( pool.hasAddress( neo4j.address() ) ); + when( activeChannelTracker.activeChannelCount( ADDRESS_1 ) ).thenReturn( 2 ); + when( activeChannelTracker.activeChannelCount( ADDRESS_2 ) ).thenReturn( 0 ); + when( activeChannelTracker.activeChannelCount( ADDRESS_3 ) ).thenReturn( 3 ); + + pool.retainAll( new HashSet<>( asList( ADDRESS_1, ADDRESS_3 ) ) ); + verify( pool.getPool( ADDRESS_1 ), never() ).close(); + verify( pool.getPool( ADDRESS_2 ) ).close(); + verify( pool.getPool( ADDRESS_3 ), never() ).close(); } @Test - public void shouldNotCloseWhenClosed() + public void shouldNotClosePoolsWithActiveConnectionsWhenRetaining() { - assertNull( await( pool.close() ) ); - assertTrue( pool.close().toCompletableFuture().isDone() ); + ActiveChannelTracker activeChannelTracker = mock( ActiveChannelTracker.class ); + TestConnectionPool pool = new TestConnectionPool( activeChannelTracker ); + + pool.acquire( ADDRESS_1 ); + pool.acquire( ADDRESS_2 ); + pool.acquire( ADDRESS_3 ); + + when( activeChannelTracker.activeChannelCount( ADDRESS_1 ) ).thenReturn( 1 ); + when( activeChannelTracker.activeChannelCount( ADDRESS_2 ) ).thenReturn( 42 ); + when( activeChannelTracker.activeChannelCount( ADDRESS_3 ) ).thenReturn( 0 ); + + pool.retainAll( singleton( ADDRESS_2 ) ); + verify( pool.getPool( ADDRESS_1 ), never() ).close(); + verify( pool.getPool( ADDRESS_2 ), never() ).close(); + verify( pool.getPool( ADDRESS_3 ) ).close(); } private ConnectionPoolImpl newPool() throws Exception { FakeClock clock = new FakeClock(); ConnectionSettings connectionSettings = new ConnectionSettings( neo4j.authToken(), 5000 ); - ChannelConnectorImpl connector = - new ChannelConnectorImpl( connectionSettings, SecurityPlan.forAllCertificates(), + ChannelConnector connector = new ChannelConnectorImpl( connectionSettings, SecurityPlan.forAllCertificates(), DEV_NULL_LOGGING, clock ); - PoolSettings poolSettings = new PoolSettings( 10, 5000, -1, -1 ); + PoolSettings poolSettings = newSettings(); Bootstrap bootstrap = BootstrapFactory.newBootstrap( 1 ); return new ConnectionPoolImpl( connector, bootstrap, poolSettings, DEV_NULL_LOGGING, clock ); } + + private static PoolSettings newSettings() + { + return new PoolSettings( 10, 5000, -1, -1 ); + } + + private static class TestConnectionPool extends ConnectionPoolImpl + { + final Map channelPoolsByAddress = new HashMap<>(); + + TestConnectionPool( ActiveChannelTracker activeChannelTracker ) + { + super( mock( ChannelConnector.class ), mock( Bootstrap.class ), activeChannelTracker, newSettings(), + DEV_NULL_LOGGING, new FakeClock() ); + } + + ChannelPool getPool( BoltServerAddress address ) + { + ChannelPool pool = channelPoolsByAddress.get( address ); + assertNotNull( pool ); + return pool; + } + + @Override + ChannelPool newPool( BoltServerAddress address ) + { + ChannelPool channelPool = mock( ChannelPool.class ); + Channel channel = mock( Channel.class ); + doReturn( ImmediateEventExecutor.INSTANCE.newSucceededFuture( channel ) ).when( channelPool ).acquire(); + channelPoolsByAddress.put( address, channelPool ); + return channelPool; + } + } } diff --git a/driver/src/test/java/org/neo4j/driver/internal/cluster/AddressSetTest.java b/driver/src/test/java/org/neo4j/driver/internal/cluster/AddressSetTest.java index ce92997dba..25eebcfc6a 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/cluster/AddressSetTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/cluster/AddressSetTest.java @@ -20,13 +20,11 @@ import org.junit.Test; -import java.util.HashSet; import java.util.LinkedHashSet; import java.util.Set; import org.neo4j.driver.internal.BoltServerAddress; -import static java.util.Collections.singleton; import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; @@ -39,7 +37,7 @@ public void shouldPreserveOrderWhenAdding() throws Exception Set servers = addresses( "one", "two", "tre" ); AddressSet set = new AddressSet(); - set.update( servers, new HashSet() ); + set.update( servers ); assertArrayEquals( new BoltServerAddress[]{ new BoltServerAddress( "one" ), @@ -48,7 +46,7 @@ public void shouldPreserveOrderWhenAdding() throws Exception // when servers.add( new BoltServerAddress( "fyr" ) ); - set.update( servers, new HashSet() ); + set.update( servers ); // then assertArrayEquals( new BoltServerAddress[]{ @@ -64,7 +62,7 @@ public void shouldPreserveOrderWhenRemoving() throws Exception // given Set servers = addresses( "one", "two", "tre" ); AddressSet set = new AddressSet(); - set.update( servers, new HashSet() ); + set.update( servers ); assertArrayEquals( new BoltServerAddress[]{ new BoltServerAddress( "one" ), @@ -86,7 +84,7 @@ public void shouldPreserveOrderWhenRemovingThroughUpdate() throws Exception // given Set servers = addresses( "one", "two", "tre" ); AddressSet set = new AddressSet(); - set.update( servers, new HashSet() ); + set.update( servers ); assertArrayEquals( new BoltServerAddress[]{ new BoltServerAddress( "one" ), @@ -95,7 +93,7 @@ public void shouldPreserveOrderWhenRemovingThroughUpdate() throws Exception // when servers.remove( new BoltServerAddress( "one" ) ); - set.update( servers, new HashSet() ); + set.update( servers ); // then assertArrayEquals( new BoltServerAddress[]{ @@ -103,21 +101,6 @@ public void shouldPreserveOrderWhenRemovingThroughUpdate() throws Exception new BoltServerAddress( "tre" )}, set.toArray() ); } - @Test - public void shouldRecordRemovedAddressesWhenUpdating() throws Exception - { - // given - AddressSet set = new AddressSet(); - set.update( addresses( "one", "two", "tre" ), new HashSet() ); - - // when - HashSet removed = new HashSet<>(); - set.update( addresses( "one", "two", "fyr" ), removed ); - - // then - assertEquals( singleton( new BoltServerAddress( "tre" ) ), removed ); - } - @Test public void shouldExposeEmptyArrayWhenEmpty() { @@ -132,7 +115,7 @@ public void shouldExposeEmptyArrayWhenEmpty() public void shouldExposeCorrectArray() { AddressSet addressSet = new AddressSet(); - addressSet.update( addresses( "one", "two", "tre" ), new HashSet() ); + addressSet.update( addresses( "one", "two", "tre" ) ); BoltServerAddress[] addresses = addressSet.toArray(); @@ -154,7 +137,7 @@ public void shouldHaveSizeZeroWhenEmpty() public void shouldHaveCorrectSize() { AddressSet addressSet = new AddressSet(); - addressSet.update( addresses( "one", "two" ), new HashSet() ); + addressSet.update( addresses( "one", "two" ) ); assertEquals( 2, addressSet.size() ); } diff --git a/driver/src/test/java/org/neo4j/driver/internal/cluster/RediscoveryTest.java b/driver/src/test/java/org/neo4j/driver/internal/cluster/RediscoveryTest.java index 172d31de6a..163225c11b 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/cluster/RediscoveryTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/cluster/RediscoveryTest.java @@ -23,7 +23,6 @@ import java.io.IOException; import java.util.HashMap; -import java.util.HashSet; import java.util.Map; import java.util.concurrent.CompletionStage; @@ -398,7 +397,7 @@ private static RoutingTable routingTableMock( BoltServerAddress... routers ) { RoutingTable routingTable = mock( RoutingTable.class ); AddressSet addressSet = new AddressSet(); - addressSet.update( asOrderedSet( routers ), new HashSet<>() ); + addressSet.update( asOrderedSet( routers ) ); when( routingTable.routers() ).thenReturn( addressSet ); return routingTable; } diff --git a/driver/src/test/java/org/neo4j/driver/internal/cluster/RoutingProcedureRunnerTest.java b/driver/src/test/java/org/neo4j/driver/internal/cluster/RoutingProcedureRunnerTest.java index ea610c9311..10bc68459a 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/cluster/RoutingProcedureRunnerTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/cluster/RoutingProcedureRunnerTest.java @@ -40,6 +40,7 @@ import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.neo4j.driver.internal.cluster.RoutingProcedureRunner.GET_ROUTING_TABLE; import static org.neo4j.driver.internal.cluster.RoutingProcedureRunner.GET_ROUTING_TABLE_PARAM; @@ -144,11 +145,54 @@ public void shouldPropagateErrorFromConnectionStage() } } + @Test + public void shouldReleaseConnectionOnSuccess() + { + RoutingProcedureRunner runner = new TestRoutingProcedureRunner( RoutingContext.EMPTY, + completedFuture( singletonList( mock( Record.class ) ) ) ); + + CompletionStage connectionStage = connectionStage( "Neo4j/3.2.2" ); + Connection connection = await( connectionStage ); + RoutingProcedureResponse response = await( runner.run( connectionStage ) ); + + assertTrue( response.isSuccess() ); + verify( connection ).release(); + } + + @Test + public void shouldPropagateReleaseError() + { + RoutingProcedureRunner runner = new TestRoutingProcedureRunner( RoutingContext.EMPTY, + completedFuture( singletonList( mock( Record.class ) ) ) ); + + RuntimeException releaseError = new RuntimeException( "Release failed" ); + CompletionStage connectionStage = connectionStage( "Neo4j/3.3.3", failedFuture( releaseError ) ); + Connection connection = await( connectionStage ); + + try + { + await( runner.run( connectionStage ) ); + fail( "Exception expected" ); + } + catch ( RuntimeException e ) + { + assertEquals( releaseError, e ); + } + verify( connection ).release(); + } + private static CompletionStage connectionStage( String serverVersion ) + { + return connectionStage( serverVersion, completedFuture( null ) ); + } + + private static CompletionStage connectionStage( String serverVersion, + CompletionStage releaseStage ) { Connection connection = mock( Connection.class ); when( connection.serverAddress() ).thenReturn( new BoltServerAddress( "123:45" ) ); when( connection.serverVersion() ).thenReturn( version( serverVersion ) ); + when( connection.release() ).thenReturn( releaseStage ); return completedFuture( connection ); } diff --git a/driver/src/test/java/org/neo4j/driver/internal/cluster/loadbalancing/LoadBalancerTest.java b/driver/src/test/java/org/neo4j/driver/internal/cluster/loadbalancing/LoadBalancerTest.java index 1571c99cee..b50229d4a9 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/cluster/loadbalancing/LoadBalancerTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/cluster/loadbalancing/LoadBalancerTest.java @@ -21,7 +21,6 @@ import io.netty.util.concurrent.GlobalEventExecutor; import org.junit.Test; -import java.util.Arrays; import java.util.Collections; import java.util.HashSet; import java.util.LinkedHashSet; @@ -41,6 +40,7 @@ import org.neo4j.driver.v1.exceptions.ServiceUnavailableException; import org.neo4j.driver.v1.exceptions.SessionExpiredException; +import static java.util.Arrays.asList; import static java.util.Collections.emptySet; import static java.util.Collections.singletonList; import static java.util.concurrent.CompletableFuture.completedFuture; @@ -61,6 +61,9 @@ import static org.neo4j.driver.internal.cluster.ClusterCompositionUtil.A; import static org.neo4j.driver.internal.cluster.ClusterCompositionUtil.B; import static org.neo4j.driver.internal.cluster.ClusterCompositionUtil.C; +import static org.neo4j.driver.internal.cluster.ClusterCompositionUtil.D; +import static org.neo4j.driver.internal.cluster.ClusterCompositionUtil.E; +import static org.neo4j.driver.internal.cluster.ClusterCompositionUtil.F; import static org.neo4j.driver.internal.logging.DevNullLogging.DEV_NULL_LOGGING; import static org.neo4j.driver.v1.AccessMode.READ; import static org.neo4j.driver.v1.AccessMode.WRITE; @@ -78,10 +81,10 @@ public void acquireShouldUpdateRoutingTableWhenKnownRoutingTableIsStale() BoltServerAddress writer1 = new BoltServerAddress( "writer-1", 4 ); BoltServerAddress router1 = new BoltServerAddress( "router-1", 5 ); - ConnectionPool connectionPool = newAsyncConnectionPoolMock(); + ConnectionPool connectionPool = newConnectionPoolMock(); ClusterRoutingTable routingTable = new ClusterRoutingTable( new FakeClock(), initialRouter ); - Set readers = new LinkedHashSet<>( Arrays.asList( reader1, reader2 ) ); + Set readers = new LinkedHashSet<>( asList( reader1, reader2 ) ); Set writers = new LinkedHashSet<>( singletonList( writer1 ) ); Set routers = new LinkedHashSet<>( singletonList( router1 ) ); ClusterComposition clusterComposition = new ClusterComposition( 42, readers, writers, routers ); @@ -100,36 +103,6 @@ public void acquireShouldUpdateRoutingTableWhenKnownRoutingTableIsStale() assertArrayEquals( new BoltServerAddress[]{router1}, routingTable.routers().toArray() ); } - @Test - public void acquireShouldPurgeConnectionsWhenKnownRoutingTableIsStale() - { - BoltServerAddress initialRouter1 = new BoltServerAddress( "initialRouter-1", 1 ); - BoltServerAddress initialRouter2 = new BoltServerAddress( "initialRouter-2", 1 ); - BoltServerAddress reader = new BoltServerAddress( "reader", 2 ); - BoltServerAddress writer = new BoltServerAddress( "writer", 3 ); - BoltServerAddress router = new BoltServerAddress( "router", 4 ); - - ConnectionPool connectionPool = newAsyncConnectionPoolMock(); - ClusterRoutingTable routingTable = new ClusterRoutingTable( new FakeClock(), initialRouter1, initialRouter2 ); - - Set readers = new HashSet<>( singletonList( reader ) ); - Set writers = new HashSet<>( singletonList( writer ) ); - Set routers = new HashSet<>( singletonList( router ) ); - ClusterComposition clusterComposition = new ClusterComposition( 42, readers, writers, routers ); - Rediscovery rediscovery = mock( Rediscovery.class ); - when( rediscovery.lookupClusterComposition( routingTable, connectionPool ) ) - .thenReturn( completedFuture( clusterComposition ) ); - - LoadBalancer loadBalancer = new LoadBalancer( connectionPool, routingTable, rediscovery, - GlobalEventExecutor.INSTANCE, DEV_NULL_LOGGING ); - - assertNotNull( await( loadBalancer.acquireConnection( READ ) ) ); - - verify( rediscovery ).lookupClusterComposition( routingTable, connectionPool ); - verify( connectionPool ).purge( initialRouter1 ); - verify( connectionPool ).purge( initialRouter2 ); - } - @Test public void shouldRediscoverOnReadWhenRoutingTableIsStaleForReads() { @@ -157,7 +130,7 @@ public void shouldNotRediscoverOnWriteWhenRoutingTableIsStaleForReadsButNotWrite @Test public void shouldThrowWhenRediscoveryReturnsNoSuitableServers() { - ConnectionPool connectionPool = newAsyncConnectionPoolMock(); + ConnectionPool connectionPool = newConnectionPoolMock(); RoutingTable routingTable = mock( RoutingTable.class ); when( routingTable.isStaleFor( any( AccessMode.class ) ) ).thenReturn( true ); Rediscovery rediscovery = mock( Rediscovery.class ); @@ -196,7 +169,7 @@ public void shouldThrowWhenRediscoveryReturnsNoSuitableServers() @Test public void shouldSelectLeastConnectedAddress() { - ConnectionPool connectionPool = newAsyncConnectionPoolMock(); + ConnectionPool connectionPool = newConnectionPoolMock(); when( connectionPool.activeConnections( A ) ).thenReturn( 0 ); when( connectionPool.activeConnections( B ) ).thenReturn( 20 ); @@ -221,13 +194,13 @@ public void shouldSelectLeastConnectedAddress() // server B should never be selected because it has many active connections assertEquals( 2, seenAddresses.size() ); - assertTrue( seenAddresses.containsAll( Arrays.asList( A, C ) ) ); + assertTrue( seenAddresses.containsAll( asList( A, C ) ) ); } @Test public void shouldRoundRobinWhenNoActiveConnections() { - ConnectionPool connectionPool = newAsyncConnectionPoolMock(); + ConnectionPool connectionPool = newConnectionPoolMock(); RoutingTable routingTable = mock( RoutingTable.class ); AddressSet readerAddresses = mock( AddressSet.class ); @@ -247,7 +220,7 @@ public void shouldRoundRobinWhenNoActiveConnections() } assertEquals( 3, seenAddresses.size() ); - assertTrue( seenAddresses.containsAll( Arrays.asList( A, B, C ) ) ); + assertTrue( seenAddresses.containsAll( asList( A, B, C ) ) ); } @Test @@ -274,6 +247,51 @@ public void shouldTryMultipleServersAfterRediscovery() assertArrayEquals( new BoltServerAddress[]{B}, routingTable.readers().toArray() ); } + @Test + public void shouldRemoveAddressFromRoutingTableOnConnectionFailure() + { + RoutingTable routingTable = new ClusterRoutingTable( new FakeClock() ); + routingTable.update( new ClusterComposition( + 42, asOrderedSet( A, B, C ), asOrderedSet( A, C, E ), asOrderedSet( B, D, F ) ) ); + + LoadBalancer loadBalancer = new LoadBalancer( newConnectionPoolMock(), routingTable, newRediscoveryMock(), + GlobalEventExecutor.INSTANCE, DEV_NULL_LOGGING ); + + loadBalancer.onConnectionFailure( B ); + + assertArrayEquals( new BoltServerAddress[]{A, C}, routingTable.readers().toArray() ); + assertArrayEquals( new BoltServerAddress[]{A, C, E}, routingTable.writers().toArray() ); + assertArrayEquals( new BoltServerAddress[]{D, F}, routingTable.routers().toArray() ); + + loadBalancer.onConnectionFailure( A ); + + assertArrayEquals( new BoltServerAddress[]{C}, routingTable.readers().toArray() ); + assertArrayEquals( new BoltServerAddress[]{C, E}, routingTable.writers().toArray() ); + assertArrayEquals( new BoltServerAddress[]{D, F}, routingTable.routers().toArray() ); + } + + @Test + public void shouldRetainAllFetchedAddressesInConnectionPoolAfterFetchingOfRoutingTable() + { + RoutingTable routingTable = new ClusterRoutingTable( new FakeClock() ); + routingTable.update( new ClusterComposition( + 42, asOrderedSet(), asOrderedSet( B, C ), asOrderedSet( D, E ) ) ); + + ConnectionPool connectionPool = newConnectionPoolMock(); + + Rediscovery rediscovery = newRediscoveryMock(); + when( rediscovery.lookupClusterComposition( any(), any() ) ).thenReturn( completedFuture( + new ClusterComposition( 42, asOrderedSet( A, B ), asOrderedSet( B, C ), asOrderedSet( A, C ) ) ) ); + + LoadBalancer loadBalancer = new LoadBalancer( connectionPool, routingTable, rediscovery, + GlobalEventExecutor.INSTANCE, DEV_NULL_LOGGING ); + + Connection connection = await( loadBalancer.acquireConnection( READ ) ); + assertNotNull( connection ); + + verify( connectionPool ).retainAll( new HashSet<>( asList( A, B, C ) ) ); + } + private void testRediscoveryWhenStale( AccessMode mode ) { ConnectionPool connectionPool = mock( ConnectionPool.class ); @@ -313,10 +331,9 @@ private static RoutingTable newStaleRoutingTableMock( AccessMode mode ) { RoutingTable routingTable = mock( RoutingTable.class ); when( routingTable.isStaleFor( mode ) ).thenReturn( true ); - when( routingTable.update( any( ClusterComposition.class ) ) ).thenReturn( new HashSet<>() ); AddressSet addresses = new AddressSet(); - addresses.update( new HashSet<>( singletonList( LOCAL_DEFAULT ) ), new HashSet<>() ); + addresses.update( new HashSet<>( singletonList( LOCAL_DEFAULT ) ) ); when( routingTable.readers() ).thenReturn( addresses ); when( routingTable.writers() ).thenReturn( addresses ); @@ -333,7 +350,7 @@ private static Rediscovery newRediscoveryMock() return rediscovery; } - private static ConnectionPool newAsyncConnectionPoolMock() + private static ConnectionPool newConnectionPoolMock() { return newConnectionPoolMockWithFailures( emptySet() ); } diff --git a/driver/src/test/java/org/neo4j/driver/internal/handlers/ResetResponseHandlerTest.java b/driver/src/test/java/org/neo4j/driver/internal/handlers/ResetResponseHandlerTest.java index d97818f219..e28930c3a2 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/handlers/ResetResponseHandlerTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/handlers/ResetResponseHandlerTest.java @@ -20,20 +20,26 @@ import io.netty.channel.embedded.EmbeddedChannel; import io.netty.channel.pool.ChannelPool; -import io.netty.util.concurrent.Promise; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.ImmediateEventExecutor; import org.junit.After; import org.junit.Test; +import java.util.concurrent.CompletableFuture; + import org.neo4j.driver.internal.async.inbound.InboundMessageDispatcher; import org.neo4j.driver.internal.util.Clock; import org.neo4j.driver.internal.util.FakeClock; import static java.util.Collections.emptyMap; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; import static org.mockito.Matchers.any; import static org.mockito.Matchers.eq; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; import static org.neo4j.driver.internal.async.ChannelAttributes.lastUsedTimestamp; public class ResetResponseHandlerTest @@ -50,66 +56,42 @@ public void tearDown() @Test public void shouldReleaseChannelOnSuccess() { - ChannelPool pool = mock( ChannelPool.class ); + ChannelPool pool = newChannelPoolMock(); FakeClock clock = new FakeClock(); clock.progress( 5 ); - ResetResponseHandler handler = newHandler( pool, clock ); + CompletableFuture releaseFuture = new CompletableFuture<>(); + ResetResponseHandler handler = newHandler( pool, clock, releaseFuture ); handler.onSuccess( emptyMap() ); verifyLastUsedTimestamp( 5 ); - verify( pool ).release( eq( channel ), any() ); - } - - @Test - public void shouldReleaseChannelWithPromiseOnSuccess() - { - ChannelPool pool = mock( ChannelPool.class ); - FakeClock clock = new FakeClock(); - clock.progress( 42 ); - Promise promise = channel.newPromise(); - ResetResponseHandler handler = newHandler( pool, clock, promise ); - - handler.onSuccess( emptyMap() ); - - verifyLastUsedTimestamp( 42 ); - verify( pool ).release( channel, promise ); + verify( pool ).release( eq( channel ) ); + assertTrue( releaseFuture.isDone() ); + assertFalse( releaseFuture.isCompletedExceptionally() ); } @Test public void shouldReleaseChannelOnFailure() { - ChannelPool pool = mock( ChannelPool.class ); + ChannelPool pool = newChannelPoolMock(); FakeClock clock = new FakeClock(); clock.progress( 100 ); - ResetResponseHandler handler = newHandler( pool, clock ); + CompletableFuture releaseFuture = new CompletableFuture<>(); + ResetResponseHandler handler = newHandler( pool, clock, releaseFuture ); handler.onFailure( new RuntimeException() ); verifyLastUsedTimestamp( 100 ); - verify( pool ).release( eq( channel ), any() ); - } - - @Test - public void shouldReleaseChannelWithPromiseOnFailure() - { - ChannelPool pool = mock( ChannelPool.class ); - FakeClock clock = new FakeClock(); - clock.progress( 99 ); - Promise promise = channel.newPromise(); - ResetResponseHandler handler = newHandler( pool, clock, promise ); - - handler.onFailure( new RuntimeException() ); - - verifyLastUsedTimestamp( 99 ); - verify( pool ).release( channel, promise ); + verify( pool ).release( eq( channel ) ); + assertTrue( releaseFuture.isDone() ); + assertFalse( releaseFuture.isCompletedExceptionally() ); } @Test public void shouldUnMuteAckFailureOnSuccess() { - ChannelPool pool = mock( ChannelPool.class ); - ResetResponseHandler handler = newHandler( pool, new FakeClock() ); + ChannelPool pool = newChannelPoolMock(); + ResetResponseHandler handler = newHandler( pool, new FakeClock(), new CompletableFuture<>() ); handler.onSuccess( emptyMap() ); @@ -119,8 +101,8 @@ public void shouldUnMuteAckFailureOnSuccess() @Test public void shouldUnMuteAckFailureOnFailure() { - ChannelPool pool = mock( ChannelPool.class ); - ResetResponseHandler handler = newHandler( pool, new FakeClock() ); + ChannelPool pool = newChannelPoolMock(); + ResetResponseHandler handler = newHandler( pool, new FakeClock(), new CompletableFuture<>() ); handler.onFailure( new RuntimeException() ); @@ -132,13 +114,16 @@ private void verifyLastUsedTimestamp( int expectedValue ) assertEquals( expectedValue, lastUsedTimestamp( channel ).intValue() ); } - private ResetResponseHandler newHandler( ChannelPool pool, Clock clock ) + private ResetResponseHandler newHandler( ChannelPool pool, Clock clock, CompletableFuture releaseFuture ) { - return newHandler( pool, clock, channel.newPromise() ); + return new ResetResponseHandler( channel, pool, messageDispatcher, clock, releaseFuture ); } - private ResetResponseHandler newHandler( ChannelPool pool, Clock clock, Promise promise ) + private static ChannelPool newChannelPoolMock() { - return new ResetResponseHandler( channel, pool, messageDispatcher, clock, promise ); + ChannelPool pool = mock( ChannelPool.class ); + Future releasedFuture = ImmediateEventExecutor.INSTANCE.newSucceededFuture( null ); + when( pool.release( any() ) ).thenReturn( releasedFuture ); + return pool; } } diff --git a/driver/src/test/java/org/neo4j/driver/internal/util/FailingConnectionDriverFactory.java b/driver/src/test/java/org/neo4j/driver/internal/util/FailingConnectionDriverFactory.java new file mode 100644 index 0000000000..d314fc6c46 --- /dev/null +++ b/driver/src/test/java/org/neo4j/driver/internal/util/FailingConnectionDriverFactory.java @@ -0,0 +1,174 @@ +/* + * Copyright (c) 2002-2017 "Neo Technology," + * Network Engine for Objects in Lund AB [http://neotechnology.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.internal.util; + +import io.netty.bootstrap.Bootstrap; + +import java.util.Map; +import java.util.Set; +import java.util.concurrent.CompletionStage; +import java.util.concurrent.atomic.AtomicReference; + +import org.neo4j.driver.internal.BoltServerAddress; +import org.neo4j.driver.internal.DriverFactory; +import org.neo4j.driver.internal.security.SecurityPlan; +import org.neo4j.driver.internal.spi.Connection; +import org.neo4j.driver.internal.spi.ConnectionPool; +import org.neo4j.driver.internal.spi.ResponseHandler; +import org.neo4j.driver.v1.AuthToken; +import org.neo4j.driver.v1.Config; +import org.neo4j.driver.v1.Value; + +public class FailingConnectionDriverFactory extends DriverFactory +{ + private final AtomicReference nextRunFailure = new AtomicReference<>(); + + @Override + protected ConnectionPool createConnectionPool( AuthToken authToken, SecurityPlan securityPlan, Bootstrap bootstrap, + Config config ) + { + ConnectionPool pool = super.createConnectionPool( authToken, securityPlan, bootstrap, config ); + return new ConnectionPoolWithFailingConnections( pool, nextRunFailure ); + } + + public void setNextRunFailure( Throwable failure ) + { + nextRunFailure.set( failure ); + } + + private static class ConnectionPoolWithFailingConnections implements ConnectionPool + { + final ConnectionPool delegate; + final AtomicReference nextRunFailure; + + ConnectionPoolWithFailingConnections( ConnectionPool delegate, AtomicReference nextRunFailure ) + { + this.delegate = delegate; + this.nextRunFailure = nextRunFailure; + } + + @Override + public CompletionStage acquire( BoltServerAddress address ) + { + return delegate.acquire( address ) + .thenApply( connection -> new FailingConnection( connection, nextRunFailure ) ); + } + + @Override + public void retainAll( Set addressesToRetain ) + { + delegate.retainAll( addressesToRetain ); + } + + @Override + public int activeConnections( BoltServerAddress address ) + { + return delegate.activeConnections( address ); + } + + @Override + public CompletionStage close() + { + return delegate.close(); + } + } + + private static class FailingConnection implements Connection + { + final Connection delegate; + final AtomicReference nextRunFailure; + + FailingConnection( Connection delegate, AtomicReference nextRunFailure ) + { + this.delegate = delegate; + this.nextRunFailure = nextRunFailure; + } + + @Override + public boolean isOpen() + { + return delegate.isOpen(); + } + + @Override + public void enableAutoRead() + { + delegate.enableAutoRead(); + } + + @Override + public void disableAutoRead() + { + delegate.disableAutoRead(); + } + + @Override + public void run( String statement, Map parameters, ResponseHandler runHandler, + ResponseHandler pullAllHandler ) + { + if ( tryFail( runHandler, pullAllHandler ) ) + { + return; + } + delegate.run( statement, parameters, runHandler, pullAllHandler ); + } + + @Override + public void runAndFlush( String statement, Map parameters, ResponseHandler runHandler, + ResponseHandler pullAllHandler ) + { + if ( tryFail( runHandler, pullAllHandler ) ) + { + return; + } + delegate.runAndFlush( statement, parameters, runHandler, pullAllHandler ); + } + + @Override + public CompletionStage release() + { + return delegate.release(); + } + + @Override + public BoltServerAddress serverAddress() + { + return delegate.serverAddress(); + } + + @Override + public ServerVersion serverVersion() + { + return delegate.serverVersion(); + } + + private boolean tryFail( ResponseHandler runHandler, ResponseHandler pullAllHandler ) + { + Throwable failure = nextRunFailure.getAndSet( null ); + if ( failure != null ) + { + runHandler.onFailure( failure ); + pullAllHandler.onFailure( failure ); + return true; + } + return false; + } + } +} + diff --git a/driver/src/test/java/org/neo4j/driver/v1/integration/CausalClusteringIT.java b/driver/src/test/java/org/neo4j/driver/v1/integration/CausalClusteringIT.java index 3597339bb5..6c603129d7 100644 --- a/driver/src/test/java/org/neo4j/driver/v1/integration/CausalClusteringIT.java +++ b/driver/src/test/java/org/neo4j/driver/v1/integration/CausalClusteringIT.java @@ -38,6 +38,7 @@ import org.neo4j.driver.internal.cluster.RoutingSettings; import org.neo4j.driver.internal.retry.RetrySettings; import org.neo4j.driver.internal.util.ChannelTrackingDriverFactory; +import org.neo4j.driver.internal.util.FailingConnectionDriverFactory; import org.neo4j.driver.internal.util.FakeClock; import org.neo4j.driver.v1.AccessMode; import org.neo4j.driver.v1.AuthToken; @@ -579,6 +580,70 @@ public void shouldRespectMaxConnectionPoolSizePerClusterMember() } } + @Test + public void shouldAllowExistingTransactionToCompleteAfterDifferentConnectionBreaks() + { + Cluster cluster = clusterRule.getCluster(); + ClusterMember leader = cluster.leader(); + + FailingConnectionDriverFactory driverFactory = new FailingConnectionDriverFactory(); + RoutingSettings routingSettings = new RoutingSettings( 1, SECONDS.toMillis( 5 ), null ); + Config config = Config.build().toConfig(); + + try ( Driver driver = driverFactory.newInstance( leader.getRoutingUri(), clusterRule.getDefaultAuthToken(), + routingSettings, RetrySettings.DEFAULT, config ) ) + { + Session session1 = driver.session(); + Transaction tx1 = session1.beginTransaction(); + tx1.run( "CREATE (n:Node1 {name: 'Node1'})" ).consume(); + + Session session2 = driver.session(); + Transaction tx2 = session2.beginTransaction(); + tx2.run( "CREATE (n:Node2 {name: 'Node2'})" ).consume(); + + ServiceUnavailableException error = new ServiceUnavailableException( "Connection broke!" ); + driverFactory.setNextRunFailure( error ); + assertUnableToRunMoreStatementsInTx( tx2, error ); + + closeTx( tx2 ); + closeTx( tx1 ); + + try ( Session session3 = driver.session( session1.lastBookmark() ) ) + { + // tx1 should not be terminated and should commit successfully + assertEquals( 1, countNodes( session3, "Node1", "name", "Node1" ) ); + // tx2 should not commit because of a connection failure + assertEquals( 0, countNodes( session3, "Node2", "name", "Node2" ) ); + } + + // rediscovery should happen for the new write query + String session4Bookmark = createNodeAndGetBookmark( driver.session(), "Node3", "name", "Node3" ); + try ( Session session5 = driver.session( session4Bookmark ) ) + { + assertEquals( 1, countNodes( session5, "Node3", "name", "Node3" ) ); + } + } + } + + private static void closeTx( Transaction tx ) + { + tx.success(); + tx.close(); + } + + private static void assertUnableToRunMoreStatementsInTx( Transaction tx, ServiceUnavailableException cause ) + { + try + { + tx.run( "CREATE (n:Node3 {name: 'Node3'})" ).consume(); + fail( "Exception expected" ); + } + catch ( SessionExpiredException e ) + { + assertEquals( cause, e.getCause() ); + } + } + private CompletionStage> combineCursors( StatementResultCursor cursor1, StatementResultCursor cursor2 ) { @@ -814,44 +879,33 @@ private static void closeAndExpectException( AutoCloseable closeable, Class() + return session.readTransaction( tx -> { - @Override - public Integer execute( Transaction tx ) - { - StatementResult result = tx.run( "MATCH (n:" + label + " {" + property + ": $value}) RETURN count(n)", - parameters( "value", value ) ); - return result.single().get( 0 ).asInt(); - } + String query = "MATCH (n:" + label + " {" + property + ": $value}) RETURN count(n)"; + StatementResult result = tx.run( query, parameters( "value", value ) ); + return result.single().get( 0 ).asInt(); } ); } - private static Callable createNodeAndGetBookmark( final Driver driver, final String label, - final String property, final String value ) + private static Callable createNodeAndGetBookmark( Driver driver, String label, String property, + String value ) { - return new Callable() + return () -> createNodeAndGetBookmark( driver.session(), label, property, value ); + } + + private static String createNodeAndGetBookmark( Session session, String label, String property, String value ) + { + try ( Session localSession = session ) { - @Override - public String call() + localSession.writeTransaction( tx -> { - try ( Session session = driver.session() ) - { - session.writeTransaction( new TransactionWork() - { - @Override - public Void execute( Transaction tx ) - { - tx.run( "CREATE (n:" + label + ") SET n." + property + " = $value", - parameters( "value", value ) ); - return null; - } - } ); - return session.lastBookmark(); - } - } - }; + tx.run( "CREATE (n:" + label + ") SET n." + property + " = $value", parameters( "value", value ) ); + return null; + } ); + return localSession.lastBookmark(); + } } private static class RecordAndSummary diff --git a/driver/src/test/java/org/neo4j/driver/v1/util/TestUtil.java b/driver/src/test/java/org/neo4j/driver/v1/util/TestUtil.java index 2abd5e6506..be641e508a 100644 --- a/driver/src/test/java/org/neo4j/driver/v1/util/TestUtil.java +++ b/driver/src/test/java/org/neo4j/driver/v1/util/TestUtil.java @@ -52,6 +52,12 @@ private TestUtil() { } + @SafeVarargs + public static List awaitAll( CompletionStage... stages ) + { + return awaitAll( Arrays.asList( stages ) ); + } + public static List awaitAll( List> stages ) { List result = new ArrayList<>();