diff --git a/driver/src/main/java/org/neo4j/driver/internal/DriverFactory.java b/driver/src/main/java/org/neo4j/driver/internal/DriverFactory.java index e1eb995d26..833890a899 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/DriverFactory.java +++ b/driver/src/main/java/org/neo4j/driver/internal/DriverFactory.java @@ -19,7 +19,6 @@ package org.neo4j.driver.internal; import io.netty.bootstrap.Bootstrap; -import io.netty.channel.EventLoopGroup; import io.netty.util.concurrent.EventExecutorGroup; import java.io.IOException; @@ -76,17 +75,17 @@ public final Driver newInstance( URI uri, AuthToken authToken, RoutingSettings r SecurityPlan securityPlan = createSecurityPlan( address, config ); ConnectionPool connectionPool = createConnectionPool( authToken, securityPlan, config ); - Bootstrap bootstrap = BootstrapFactory.newBootstrap(); - EventLoopGroup eventLoopGroup = bootstrap.config().group(); - RetryLogic retryLogic = createRetryLogic( retrySettings, eventLoopGroup, config.logging() ); + Bootstrap bootstrap = createBootstrap(); + EventExecutorGroup eventExecutorGroup = bootstrap.config().group(); + RetryLogic retryLogic = createRetryLogic( retrySettings, eventExecutorGroup, config.logging() ); AsyncConnectionPool asyncConnectionPool = createAsyncConnectionPool( authToken, securityPlan, bootstrap, config ); try { - return createDriver( uri, address, connectionPool, config, newRoutingSettings, securityPlan, retryLogic, - asyncConnectionPool ); + return createDriver( uri, address, connectionPool, asyncConnectionPool, config, newRoutingSettings, + eventExecutorGroup, securityPlan, retryLogic ); } catch ( Throwable driverError ) { @@ -121,8 +120,8 @@ private AsyncConnectionPool createAsyncConnectionPool( AuthToken authToken, Secu } private Driver createDriver( URI uri, BoltServerAddress address, ConnectionPool connectionPool, - Config config, RoutingSettings routingSettings, SecurityPlan securityPlan, RetryLogic retryLogic, - AsyncConnectionPool asyncConnectionPool ) + AsyncConnectionPool asyncConnectionPool, Config config, RoutingSettings routingSettings, + EventExecutorGroup eventExecutorGroup, SecurityPlan securityPlan, RetryLogic retryLogic ) { String scheme = uri.getScheme().toLowerCase(); switch ( scheme ) @@ -131,7 +130,8 @@ private Driver createDriver( URI uri, BoltServerAddress address, ConnectionPool assertNoRoutingContext( uri, routingSettings ); return createDirectDriver( address, connectionPool, config, securityPlan, retryLogic, asyncConnectionPool ); case BOLT_ROUTING_URI_SCHEME: - return createRoutingDriver( address, connectionPool, config, routingSettings, securityPlan, retryLogic ); + return createRoutingDriver( address, connectionPool, asyncConnectionPool, config, routingSettings, + securityPlan, retryLogic, eventExecutorGroup ); default: throw new ClientException( format( "Unsupported URI scheme: %s", scheme ) ); } @@ -158,13 +158,15 @@ protected Driver createDirectDriver( BoltServerAddress address, ConnectionPool c * This method is protected only for testing */ protected Driver createRoutingDriver( BoltServerAddress address, ConnectionPool connectionPool, - Config config, RoutingSettings routingSettings, SecurityPlan securityPlan, RetryLogic retryLogic ) + AsyncConnectionPool asyncConnectionPool, Config config, RoutingSettings routingSettings, + SecurityPlan securityPlan, RetryLogic retryLogic, EventExecutorGroup eventExecutorGroup ) { if ( !securityPlan.isRoutingCompatible() ) { throw new IllegalArgumentException( "The chosen security plan is not compatible with a routing driver" ); } - ConnectionProvider connectionProvider = createLoadBalancer( address, connectionPool, config, routingSettings ); + ConnectionProvider connectionProvider = createLoadBalancer( address, connectionPool, asyncConnectionPool, + eventExecutorGroup, config, routingSettings ); SessionFactory sessionFactory = createSessionFactory( connectionProvider, retryLogic, config ); return createDriver( config, securityPlan, sessionFactory ); } @@ -184,21 +186,25 @@ protected InternalDriver createDriver( Config config, SecurityPlan securityPlan, *

* This method is protected only for testing */ - protected LoadBalancer createLoadBalancer( BoltServerAddress address, ConnectionPool connectionPool, Config config, - RoutingSettings routingSettings ) + protected LoadBalancer createLoadBalancer( BoltServerAddress address, ConnectionPool connectionPool, + AsyncConnectionPool asyncConnectionPool, EventExecutorGroup eventExecutorGroup, + Config config, RoutingSettings routingSettings ) { - return new LoadBalancer( address, routingSettings, connectionPool, createClock(), config.logging(), - createLoadBalancingStrategy( config, connectionPool ) ); + LoadBalancingStrategy loadBalancingStrategy = + createLoadBalancingStrategy( config, connectionPool, asyncConnectionPool ); + return new LoadBalancer( address, routingSettings, connectionPool, asyncConnectionPool, eventExecutorGroup, + createClock(), config.logging(), loadBalancingStrategy ); } - private static LoadBalancingStrategy createLoadBalancingStrategy( Config config, ConnectionPool connectionPool ) + private static LoadBalancingStrategy createLoadBalancingStrategy( Config config, ConnectionPool connectionPool, + AsyncConnectionPool asyncConnectionPool ) { switch ( config.loadBalancingStrategy() ) { case ROUND_ROBIN: return new RoundRobinLoadBalancingStrategy( config.logging() ); case LEAST_CONNECTED: - return new LeastConnectedLoadBalancingStrategy( connectionPool, config.logging() ); + return new LeastConnectedLoadBalancingStrategy( connectionPool, asyncConnectionPool, config.logging() ); default: throw new IllegalArgumentException( "Unknown load balancing strategy: " + config.loadBalancingStrategy() ); } @@ -253,7 +259,7 @@ protected SessionFactory createSessionFactory( ConnectionProvider connectionProv } /** - * Creates new {@link RetryLogic >}. + * Creates new {@link RetryLogic}. *

* This method is protected only for testing */ @@ -263,6 +269,16 @@ protected RetryLogic createRetryLogic( RetrySettings settings, EventExecutorGrou return new ExponentialBackoffRetryLogic( settings, eventExecutorGroup, createClock(), logging ); } + /** + * Creates new {@link Bootstrap}. + *

+ * This method is protected only for testing + */ + protected Bootstrap createBootstrap() + { + return BootstrapFactory.newBootstrap(); + } + private static SecurityPlan createSecurityPlan( BoltServerAddress address, Config config ) { try diff --git a/driver/src/main/java/org/neo4j/driver/internal/async/AsyncConnection.java b/driver/src/main/java/org/neo4j/driver/internal/async/AsyncConnection.java index 8ac425a16a..66ef5aef52 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/async/AsyncConnection.java +++ b/driver/src/main/java/org/neo4j/driver/internal/async/AsyncConnection.java @@ -21,9 +21,10 @@ import java.util.Map; import java.util.concurrent.CompletionStage; +import org.neo4j.driver.internal.net.BoltServerAddress; import org.neo4j.driver.internal.spi.ResponseHandler; +import org.neo4j.driver.internal.util.ServerVersion; import org.neo4j.driver.v1.Value; -import org.neo4j.driver.v1.summary.ServerInfo; public interface AsyncConnection { @@ -43,5 +44,7 @@ void runAndFlush( String statement, Map parameters, ResponseHandle CompletionStage forceRelease(); - ServerInfo serverInfo(); + BoltServerAddress serverAddress(); + + ServerVersion serverVersion(); } diff --git a/driver/src/main/java/org/neo4j/driver/internal/async/ChannelAttributes.java b/driver/src/main/java/org/neo4j/driver/internal/async/ChannelAttributes.java index 9654695118..26dc606726 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/async/ChannelAttributes.java +++ b/driver/src/main/java/org/neo4j/driver/internal/async/ChannelAttributes.java @@ -23,31 +23,42 @@ import org.neo4j.driver.internal.async.inbound.InboundMessageDispatcher; import org.neo4j.driver.internal.net.BoltServerAddress; +import org.neo4j.driver.internal.util.ServerVersion; import static io.netty.util.AttributeKey.newInstance; public final class ChannelAttributes { - private static final AttributeKey ADDRESS = newInstance( "address" ); + private static final AttributeKey ADDRESS = newInstance( "serverAddress" ); + private static final AttributeKey SERVER_VERSION = newInstance( "serverVersion" ); private static final AttributeKey CREATION_TIMESTAMP = newInstance( "creationTimestamp" ); private static final AttributeKey LAST_USED_TIMESTAMP = newInstance( "lastUsedTimestamp" ); private static final AttributeKey MESSAGE_DISPATCHER = newInstance( "messageDispatcher" ); - private static final AttributeKey SERVER_VERSION = newInstance( "serverVersion" ); private ChannelAttributes() { } - public static BoltServerAddress address( Channel channel ) + public static BoltServerAddress serverAddress( Channel channel ) { return get( channel, ADDRESS ); } - public static void setAddress( Channel channel, BoltServerAddress address ) + public static void setServerAddress( Channel channel, BoltServerAddress address ) { setOnce( channel, ADDRESS, address ); } + public static ServerVersion serverVersion( Channel channel ) + { + return get( channel, SERVER_VERSION ); + } + + public static void setServerVersion( Channel channel, ServerVersion version ) + { + setOnce( channel, SERVER_VERSION, version ); + } + public static long creationTimestamp( Channel channel ) { return get( channel, CREATION_TIMESTAMP ); @@ -78,16 +89,6 @@ public static void setMessageDispatcher( Channel channel, InboundMessageDispatch setOnce( channel, MESSAGE_DISPATCHER, messageDispatcher ); } - public static String serverVersion( Channel channel ) - { - return get( channel, SERVER_VERSION ); - } - - public static void setServerVersion( Channel channel, String serverVersion ) - { - setOnce( channel, SERVER_VERSION, serverVersion ); - } - private static T get( Channel channel, AttributeKey key ) { return channel.attr( key ).get(); diff --git a/driver/src/main/java/org/neo4j/driver/internal/async/Main.java b/driver/src/main/java/org/neo4j/driver/internal/async/Main.java index 1006f5267e..cfe5f94a1f 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/async/Main.java +++ b/driver/src/main/java/org/neo4j/driver/internal/async/Main.java @@ -18,14 +18,18 @@ */ package org.neo4j.driver.internal.async; +import io.netty.util.internal.ConcurrentSet; + import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Set; import java.util.concurrent.CompletionStage; import java.util.concurrent.TimeUnit; +import org.neo4j.driver.v1.AccessMode; import org.neo4j.driver.v1.AuthToken; import org.neo4j.driver.v1.AuthTokens; import org.neo4j.driver.v1.Config; @@ -37,13 +41,15 @@ import org.neo4j.driver.v1.StatementResultCursor; import org.neo4j.driver.v1.Transaction; +// todo: remove this class public class Main { - private static final int ITERATIONS = 100; + private static final int ITERATIONS = 200; - private static final String QUERY1 = "MATCH (n:ActiveItem) RETURN n LIMIT 10000"; + private static final String QUERY1 = "RETURN 1"; + private static final String QUERY2 = "MATCH (n:ActiveItem) RETURN n LIMIT 50000"; - private static final String QUERY = + private static final String QUERY3 = "MATCH (s:Sku{sku_no: {skuNo}})-[:HAS_ITEM_SOURCE]->(i:ItemSource{itemsource: {itemSource}})\n" + "//Get master sku for auxiliary item\n" + "OPTIONAL MATCH (s)-[:AUXILIARY_FOR]->(master_sku:Sku) WHERE NOT s.display_auxiliary_content\n" + @@ -97,24 +103,30 @@ public class Main "\thasThirdPartyContent: sku.has_third_party_content\n" + "}) AS overview;\n"; - private static final Map PARAMS_OBJ = new HashMap<>(); + private static final String QUERY = QUERY2; + + private static final Map PARAMS1 = new HashMap<>(); + private static final Map PARAMS2 = new HashMap<>(); + + private static final Map PARAMS = PARAMS1; + private static final String SCHEME = "bolt+routing"; private static final String USER = "neo4j"; private static final String PASSWORD = "test"; - private static final String HOST = "ec2-54-73-57-164.eu-west-1.compute.amazonaws.com"; - private static final int PORT = 7687; - private static final String URI = "bolt://" + HOST + ":" + PORT; + private static final String HOST = "ec2-34-249-23-195.eu-west-1.compute.amazonaws.com"; + private static final int PORT = 26000; + private static final String URI = SCHEME + "://" + HOST + ":" + PORT; static { - PARAMS_OBJ.put( "skuNo", 366421 ); - PARAMS_OBJ.put( "itemSource", "REG" ); - PARAMS_OBJ.put( "catalogId", 2 ); - PARAMS_OBJ.put( "locale", "en" ); + PARAMS1.put( "skuNo", 366421 ); + PARAMS1.put( "itemSource", "REG" ); + PARAMS1.put( "catalogId", 2 ); + PARAMS1.put( "locale", "en" ); Map tmpObj = new HashMap<>(); tmpObj.put( "skuNo", 366421 ); tmpObj.put( "itemSource", "REG" ); - PARAMS_OBJ.put( "itemList", Collections.singletonList( tmpObj ) ); + PARAMS1.put( "itemList", Collections.singletonList( tmpObj ) ); } public static void main( String[] args ) throws Throwable @@ -131,17 +143,18 @@ private static void testSessionRun() throws Throwable test( "Session#run()", new Action() { @Override - public void apply( Driver driver, MutableInt recordsRead ) + public void apply( Driver driver, MutableInt recordsRead, Set serversUsed ) { - try ( Session session = driver.session() ) + try ( Session session = driver.session( AccessMode.READ ) ) { - StatementResult result = session.run( QUERY, PARAMS_OBJ ); + StatementResult result = session.run( QUERY, PARAMS ); while ( result.hasNext() ) { Record record = result.next(); useRecord( record ); recordsRead.increment(); } + serversUsed.add( result.summary().server().address() ); } } } ); @@ -152,10 +165,10 @@ private static void testSessionRunAsync() throws Throwable test( "Session#runAsync()", new Action() { @Override - public void apply( Driver driver, MutableInt recordsRead ) + public void apply( Driver driver, MutableInt recordsRead, Set serversUsed ) { - Session session = driver.session(); - CompletionStage cursorFuture = session.runAsync( QUERY, PARAMS_OBJ ); + Session session = driver.session( AccessMode.READ ); + CompletionStage cursorFuture = session.runAsync( QUERY, PARAMS ); StatementResultCursor cursor = await( cursorFuture ); Record record; while ( (record = await( cursor.nextAsync() )) != null ) @@ -163,6 +176,7 @@ public void apply( Driver driver, MutableInt recordsRead ) useRecord( record ); recordsRead.increment(); } + serversUsed.add( await( cursor.summaryAsync() ).server().address() ); await( session.closeAsync() ); } } ); @@ -173,12 +187,12 @@ private static void testTxRun() throws Throwable test( "Transaction#run()", new Action() { @Override - public void apply( Driver driver, MutableInt recordsRead ) + public void apply( Driver driver, MutableInt recordsRead, Set serversUsed ) { - try ( Session session = driver.session(); + try ( Session session = driver.session( AccessMode.READ ); Transaction tx = session.beginTransaction() ) { - StatementResult result = tx.run( QUERY, PARAMS_OBJ ); + StatementResult result = tx.run( QUERY, PARAMS ); while ( result.hasNext() ) { Record record = result.next(); @@ -186,6 +200,7 @@ public void apply( Driver driver, MutableInt recordsRead ) recordsRead.increment(); } tx.success(); + serversUsed.add( result.summary().server().address() ); } } } ); @@ -196,17 +211,18 @@ private static void testTxRunAsync() throws Throwable test( "Transaction#runAsync()", new Action() { @Override - public void apply( Driver driver, MutableInt recordsRead ) + public void apply( Driver driver, MutableInt recordsRead, Set serversUsed ) { - Session session = driver.session(); + Session session = driver.session( AccessMode.READ ); Transaction tx = await( session.beginTransactionAsync() ); - StatementResultCursor cursor = await( tx.runAsync( QUERY, PARAMS_OBJ ) ); + StatementResultCursor cursor = await( tx.runAsync( QUERY, PARAMS ) ); Record record; while ( (record = await( cursor.nextAsync() )) != null ) { useRecord( record ); recordsRead.increment(); } + serversUsed.add( await( cursor.summaryAsync() ).server().address() ); await( tx.commitAsync() ); await( session.closeAsync() ); } @@ -220,6 +236,7 @@ private static void test( String actionName, Action action ) throws Throwable List timings = new ArrayList<>(); MutableInt recordsRead = new MutableInt(); + ConcurrentSet serversUsed = new ConcurrentSet<>(); try ( Driver driver = GraphDatabase.driver( URI, authToken, config ) ) { @@ -227,7 +244,7 @@ private static void test( String actionName, Action action ) throws Throwable { long start = System.nanoTime(); - action.apply( driver, recordsRead ); + action.apply( driver, recordsRead, serversUsed ); long end = System.nanoTime(); timings.add( TimeUnit.NANOSECONDS.toMillis( end - start ) ); @@ -240,6 +257,7 @@ private static void test( String actionName, Action action ) throws Throwable System.out.println( actionName + ": mean --> " + mean( timings ) + "ms, stdDev --> " + stdDev( timings ) ); System.out.println( actionName + ": timings --> " + timings ); System.out.println( actionName + ": recordsRead --> " + recordsRead ); + System.out.println( actionName + ": serversUsed --> " + serversUsed ); System.out.println( "============================================================" ); } @@ -306,7 +324,7 @@ private static void useRecord( Record record ) private interface Action { - void apply( Driver driver, MutableInt recordsRead ); + void apply( Driver driver, MutableInt recordsRead, Set serversUsed ); } private static class MutableInt diff --git a/driver/src/main/java/org/neo4j/driver/internal/async/NettyChannelInitializer.java b/driver/src/main/java/org/neo4j/driver/internal/async/NettyChannelInitializer.java index 2d62851e56..4498e4ade2 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/async/NettyChannelInitializer.java +++ b/driver/src/main/java/org/neo4j/driver/internal/async/NettyChannelInitializer.java @@ -31,9 +31,9 @@ import org.neo4j.driver.internal.security.SecurityPlan; import org.neo4j.driver.internal.util.Clock; -import static org.neo4j.driver.internal.async.ChannelAttributes.setAddress; import static org.neo4j.driver.internal.async.ChannelAttributes.setCreationTimestamp; import static org.neo4j.driver.internal.async.ChannelAttributes.setMessageDispatcher; +import static org.neo4j.driver.internal.async.ChannelAttributes.setServerAddress; import static org.neo4j.driver.internal.logging.DevNullLogging.DEV_NULL_LOGGING; public class NettyChannelInitializer extends ChannelInitializer @@ -82,7 +82,7 @@ private SSLEngine createSslEngine() private void updateChannelAttributes( Channel channel ) { - setAddress( channel, address ); + setServerAddress( channel, address ); setCreationTimestamp( channel, clock.millis() ); setMessageDispatcher( channel, new InboundMessageDispatcher( channel, DEV_NULL_LOGGING ) ); } diff --git a/driver/src/main/java/org/neo4j/driver/internal/async/NettyConnection.java b/driver/src/main/java/org/neo4j/driver/internal/async/NettyConnection.java index 2328a3d3c6..dc1fc5e90e 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/async/NettyConnection.java +++ b/driver/src/main/java/org/neo4j/driver/internal/async/NettyConnection.java @@ -32,16 +32,14 @@ import org.neo4j.driver.internal.messaging.PullAllMessage; import org.neo4j.driver.internal.messaging.ResetMessage; import org.neo4j.driver.internal.messaging.RunMessage; +import org.neo4j.driver.internal.net.BoltServerAddress; import org.neo4j.driver.internal.spi.ResponseHandler; -import org.neo4j.driver.internal.summary.InternalServerInfo; import org.neo4j.driver.internal.util.Clock; +import org.neo4j.driver.internal.util.ServerVersion; import org.neo4j.driver.v1.Value; -import org.neo4j.driver.v1.summary.ServerInfo; import static java.util.concurrent.CompletableFuture.completedFuture; -import static org.neo4j.driver.internal.async.ChannelAttributes.address; import static org.neo4j.driver.internal.async.ChannelAttributes.messageDispatcher; -import static org.neo4j.driver.internal.async.ChannelAttributes.serverVersion; import static org.neo4j.driver.internal.async.Futures.asCompletionStage; // todo: keep state flags to prohibit interaction with released connections @@ -127,9 +125,15 @@ public CompletionStage forceRelease() } @Override - public ServerInfo serverInfo() + public BoltServerAddress serverAddress() { - return new InternalServerInfo( address( channel ), serverVersion( channel ) ); + return ChannelAttributes.serverAddress( channel ); + } + + @Override + public ServerVersion serverVersion() + { + return ChannelAttributes.serverVersion( channel ); } private void run( String statement, Map parameters, ResponseHandler runHandler, diff --git a/driver/src/main/java/org/neo4j/driver/internal/async/RoutingAsyncConnection.java b/driver/src/main/java/org/neo4j/driver/internal/async/RoutingAsyncConnection.java new file mode 100644 index 0000000000..bfa4a14e15 --- /dev/null +++ b/driver/src/main/java/org/neo4j/driver/internal/async/RoutingAsyncConnection.java @@ -0,0 +1,106 @@ +/* + * Copyright (c) 2002-2017 "Neo Technology," + * Network Engine for Objects in Lund AB [http://neotechnology.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.internal.async; + +import java.util.Map; +import java.util.concurrent.CompletionStage; + +import org.neo4j.driver.internal.RoutingErrorHandler; +import org.neo4j.driver.internal.net.BoltServerAddress; +import org.neo4j.driver.internal.spi.ResponseHandler; +import org.neo4j.driver.internal.util.ServerVersion; +import org.neo4j.driver.v1.AccessMode; +import org.neo4j.driver.v1.Value; + +public class RoutingAsyncConnection implements AsyncConnection +{ + private final AsyncConnection delegate; + private final AccessMode accessMode; + private final RoutingErrorHandler errorHandler; + + public RoutingAsyncConnection( AsyncConnection delegate, AccessMode accessMode, RoutingErrorHandler errorHandler ) + { + this.delegate = delegate; + this.accessMode = accessMode; + this.errorHandler = errorHandler; + } + + @Override + public boolean tryMarkInUse() + { + return delegate.tryMarkInUse(); + } + + @Override + public void enableAutoRead() + { + delegate.enableAutoRead(); + } + + @Override + public void disableAutoRead() + { + delegate.disableAutoRead(); + } + + @Override + public void run( String statement, Map parameters, ResponseHandler runHandler, + ResponseHandler pullAllHandler ) + { + delegate.run( statement, parameters, newRoutingResponseHandler( runHandler ), + newRoutingResponseHandler( pullAllHandler ) ); + } + + @Override + public void runAndFlush( String statement, Map parameters, ResponseHandler runHandler, + ResponseHandler pullAllHandler ) + { + delegate.runAndFlush( statement, parameters, newRoutingResponseHandler( runHandler ), + newRoutingResponseHandler( pullAllHandler ) ); + } + + @Override + public void release() + { + delegate.release(); + } + + @Override + public CompletionStage forceRelease() + { + return delegate.forceRelease(); + } + + @Override + public BoltServerAddress serverAddress() + { + return delegate.serverAddress(); + } + + @Override + public ServerVersion serverVersion() + { + return delegate.serverVersion(); + } + + private RoutingResponseHandler newRoutingResponseHandler( ResponseHandler handler ) + { + return new RoutingResponseHandler( handler, serverAddress(), accessMode, errorHandler ); + } +} diff --git a/driver/src/main/java/org/neo4j/driver/internal/async/RoutingResponseHandler.java b/driver/src/main/java/org/neo4j/driver/internal/async/RoutingResponseHandler.java new file mode 100644 index 0000000000..183eb35168 --- /dev/null +++ b/driver/src/main/java/org/neo4j/driver/internal/async/RoutingResponseHandler.java @@ -0,0 +1,139 @@ +/* + * Copyright (c) 2002-2017 "Neo Technology," + * Network Engine for Objects in Lund AB [http://neotechnology.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.internal.async; + +import java.util.Map; +import java.util.Objects; +import java.util.concurrent.CompletionException; + +import org.neo4j.driver.internal.RoutingErrorHandler; +import org.neo4j.driver.internal.net.BoltServerAddress; +import org.neo4j.driver.internal.spi.ResponseHandler; +import org.neo4j.driver.v1.AccessMode; +import org.neo4j.driver.v1.Value; +import org.neo4j.driver.v1.exceptions.ClientException; +import org.neo4j.driver.v1.exceptions.ServiceUnavailableException; +import org.neo4j.driver.v1.exceptions.SessionExpiredException; +import org.neo4j.driver.v1.exceptions.TransientException; + +import static java.lang.String.format; + +public class RoutingResponseHandler implements ResponseHandler +{ + private final ResponseHandler delegate; + private final BoltServerAddress address; + private final AccessMode accessMode; + private final RoutingErrorHandler errorHandler; + + public RoutingResponseHandler( ResponseHandler delegate, BoltServerAddress address, AccessMode accessMode, + RoutingErrorHandler errorHandler ) + { + this.delegate = delegate; + this.address = address; + this.accessMode = accessMode; + this.errorHandler = errorHandler; + } + + @Override + public void onSuccess( Map metadata ) + { + delegate.onSuccess( metadata ); + } + + @Override + public void onFailure( Throwable error ) + { + Throwable newError = handledError( error ); + delegate.onFailure( newError ); + } + + @Override + public void onRecord( Value[] fields ) + { + delegate.onRecord( fields ); + } + + private Throwable handledError( Throwable error ) + { + if ( error instanceof CompletionException ) + { + error = error.getCause(); + } + + if ( error instanceof ServiceUnavailableException ) + { + return handledServiceUnavailableException( ((ServiceUnavailableException) error) ); + } + else if ( error instanceof ClientException ) + { + return handledClientException( ((ClientException) error) ); + } + else if ( error instanceof TransientException ) + { + return handledTransientException( ((TransientException) error) ); + } + else + { + return error; + } + } + + private Throwable handledServiceUnavailableException( ServiceUnavailableException e ) + { + errorHandler.onConnectionFailure( address ); + return new SessionExpiredException( format( "Server at %s is no longer available", address ), e ); + } + + private Throwable handledTransientException( TransientException e ) + { + String errorCode = e.code(); + if ( Objects.equals( errorCode, "Neo.TransientError.General.DatabaseUnavailable" ) ) + { + errorHandler.onConnectionFailure( address ); + } + return e; + } + + private Throwable handledClientException( ClientException e ) + { + if ( isFailureToWrite( e ) ) + { + // The server is unaware of the session mode, so we have to implement this logic in the driver. + // In the future, we might be able to move this logic to the server. + switch ( accessMode ) + { + case READ: + return new ClientException( "Write queries cannot be performed in READ access mode." ); + case WRITE: + errorHandler.onWriteFailure( address ); + return new SessionExpiredException( format( "Server at %s no longer accepts writes", address ) ); + default: + throw new IllegalArgumentException( accessMode + " not supported." ); + } + } + return e; + } + + private static boolean isFailureToWrite( ClientException e ) + { + String errorCode = e.code(); + return Objects.equals( errorCode, "Neo.ClientError.Cluster.NotALeader" ) || + Objects.equals( errorCode, "Neo.ClientError.General.ForbiddenOnReadOnlyDatabase" ); + } +} diff --git a/driver/src/main/java/org/neo4j/driver/internal/async/pool/ActiveChannelTracker.java b/driver/src/main/java/org/neo4j/driver/internal/async/pool/ActiveChannelTracker.java index 2402cf0cc8..c1e0fdd429 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/async/pool/ActiveChannelTracker.java +++ b/driver/src/main/java/org/neo4j/driver/internal/async/pool/ActiveChannelTracker.java @@ -29,7 +29,7 @@ import org.neo4j.driver.v1.Logger; import org.neo4j.driver.v1.Logging; -import static org.neo4j.driver.internal.async.ChannelAttributes.address; +import static org.neo4j.driver.internal.async.ChannelAttributes.serverAddress; public class ActiveChannelTracker implements ChannelPoolHandler { @@ -83,7 +83,7 @@ public void purge( BoltServerAddress address ) private void channelActive( Channel channel ) { - BoltServerAddress address = address( channel ); + BoltServerAddress address = serverAddress( channel ); ConcurrentSet activeChannels = addressToActiveChannelCount.get( address ); if ( activeChannels == null ) { @@ -105,7 +105,7 @@ private void channelActive( Channel channel ) private void channelInactive( Channel channel ) { - BoltServerAddress address = address( channel ); + BoltServerAddress address = serverAddress( channel ); ConcurrentSet activeChannels = addressToActiveChannelCount.get( address ); if ( activeChannels == null ) { diff --git a/driver/src/main/java/org/neo4j/driver/internal/cluster/ClusterComposition.java b/driver/src/main/java/org/neo4j/driver/internal/cluster/ClusterComposition.java index 7c9eceeaa7..7376cc0548 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/cluster/ClusterComposition.java +++ b/driver/src/main/java/org/neo4j/driver/internal/cluster/ClusterComposition.java @@ -19,6 +19,7 @@ package org.neo4j.driver.internal.cluster; import java.util.LinkedHashSet; +import java.util.Objects; import java.util.Set; import org.neo4j.driver.internal.net.BoltServerAddress; @@ -94,15 +95,39 @@ public long expirationTimestamp() { return this.expirationTimestamp; } + @Override + public boolean equals( Object o ) + { + if ( this == o ) + { + return true; + } + if ( o == null || getClass() != o.getClass() ) + { + return false; + } + ClusterComposition that = (ClusterComposition) o; + return expirationTimestamp == that.expirationTimestamp && + Objects.equals( readers, that.readers ) && + Objects.equals( writers, that.writers ) && + Objects.equals( routers, that.routers ); + } + + @Override + public int hashCode() + { + return Objects.hash( readers, writers, routers, expirationTimestamp ); + } + @Override public String toString() { return "ClusterComposition{" + - "expirationTimestamp=" + expirationTimestamp + - ", readers=" + readers + - ", writers=" + writers + - ", routers=" + routers + - '}'; + "readers=" + readers + + ", writers=" + writers + + ", routers=" + routers + + ", expirationTimestamp=" + expirationTimestamp + + '}'; } public static ClusterComposition parse( Record record, long now ) diff --git a/driver/src/main/java/org/neo4j/driver/internal/cluster/ClusterCompositionProvider.java b/driver/src/main/java/org/neo4j/driver/internal/cluster/ClusterCompositionProvider.java index 70e0fd901f..5192611f0d 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/cluster/ClusterCompositionProvider.java +++ b/driver/src/main/java/org/neo4j/driver/internal/cluster/ClusterCompositionProvider.java @@ -18,9 +18,15 @@ */ package org.neo4j.driver.internal.cluster; +import java.util.concurrent.CompletionStage; + +import org.neo4j.driver.internal.async.AsyncConnection; import org.neo4j.driver.internal.spi.Connection; public interface ClusterCompositionProvider { ClusterCompositionResponse getClusterComposition( Connection connection ); + + CompletionStage getClusterComposition( + CompletionStage connectionStage ); } diff --git a/driver/src/main/java/org/neo4j/driver/internal/cluster/Rediscovery.java b/driver/src/main/java/org/neo4j/driver/internal/cluster/Rediscovery.java index de4db16771..21d295e4ee 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/cluster/Rediscovery.java +++ b/driver/src/main/java/org/neo4j/driver/internal/cluster/Rediscovery.java @@ -18,10 +18,18 @@ */ package org.neo4j.driver.internal.cluster; +import io.netty.util.concurrent.EventExecutorGroup; + import java.util.Collections; import java.util.HashSet; import java.util.Set; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionException; +import java.util.concurrent.CompletionStage; +import java.util.concurrent.TimeUnit; +import org.neo4j.driver.internal.async.AsyncConnection; +import org.neo4j.driver.internal.async.pool.AsyncConnectionPool; import org.neo4j.driver.internal.net.BoltServerAddress; import org.neo4j.driver.internal.spi.Connection; import org.neo4j.driver.internal.spi.ConnectionPool; @@ -31,6 +39,7 @@ import org.neo4j.driver.v1.exceptions.ServiceUnavailableException; import static java.lang.String.format; +import static java.util.concurrent.CompletableFuture.completedFuture; public class Rediscovery { @@ -42,11 +51,21 @@ public class Rediscovery private final Logger logger; private final ClusterCompositionProvider provider; private final HostNameResolver hostNameResolver; + private final EventExecutorGroup eventExecutorGroup; + + private volatile boolean useInitialRouter; - private boolean useInitialRouter; + public Rediscovery( BoltServerAddress initialRouter, RoutingSettings settings, ClusterCompositionProvider provider, + EventExecutorGroup eventExecutorGroup, HostNameResolver hostNameResolver, Clock clock, Logger logger ) + { + // todo: set useInitialRouter to true when driver only does async + this( initialRouter, settings, provider, hostNameResolver, eventExecutorGroup, clock, logger, false ); + } - public Rediscovery( BoltServerAddress initialRouter, RoutingSettings settings, Clock clock, Logger logger, - ClusterCompositionProvider provider, HostNameResolver hostNameResolver ) + // Test-only constructor + public Rediscovery( BoltServerAddress initialRouter, RoutingSettings settings, ClusterCompositionProvider provider, + HostNameResolver hostNameResolver, EventExecutorGroup eventExecutorGroup, Clock clock, Logger logger, + boolean useInitialRouter ) { this.initialRouter = initialRouter; this.settings = settings; @@ -54,6 +73,8 @@ public Rediscovery( BoltServerAddress initialRouter, RoutingSettings settings, C this.logger = logger; this.provider = provider; this.hostNameResolver = hostNameResolver; + this.eventExecutorGroup = eventExecutorGroup; + this.useInitialRouter = useInitialRouter; } /** @@ -87,6 +108,45 @@ public ClusterComposition lookupClusterComposition( RoutingTable routingTable, C } } + public CompletionStage lookupClusterCompositionAsync( RoutingTable routingTable, + AsyncConnectionPool connectionPool ) + { + CompletableFuture result = new CompletableFuture<>(); + lookupClusterComposition( routingTable, connectionPool, 0, 0, result ); + return result; + } + + private void lookupClusterComposition( RoutingTable routingTable, AsyncConnectionPool pool, + int failures, long previousDelay, CompletableFuture result ) + { + if ( failures >= settings.maxRoutingFailures() ) + { + result.completeExceptionally( new ServiceUnavailableException( NO_ROUTERS_AVAILABLE ) ); + return; + } + + lookupAsync( routingTable, pool ).whenComplete( ( composition, error ) -> + { + if ( error != null ) + { + result.completeExceptionally( error ); + } + else if ( composition != null ) + { + result.complete( composition ); + } + else + { + long nextDelay = Math.max( settings.retryTimeoutDelay(), previousDelay * 2 ); + logger.info( "Unable to fetch new routing table, will try again in " + nextDelay + "ms" ); + eventExecutorGroup.next().schedule( + () -> lookupClusterComposition( routingTable, pool, failures + 1, nextDelay, result ), + nextDelay, TimeUnit.MILLISECONDS + ); + } + } ); + } + private ClusterComposition lookup( RoutingTable routingTable, ConnectionPool connections ) { ClusterComposition composition; @@ -109,6 +169,30 @@ private ClusterComposition lookup( RoutingTable routingTable, ConnectionPool con return composition; } + private CompletionStage lookupAsync( RoutingTable routingTable, + AsyncConnectionPool connectionPool ) + { + CompletionStage compositionStage; + + if ( useInitialRouter ) + { + compositionStage = lookupOnInitialRouterThenOnKnownRoutersAsync( routingTable, connectionPool ); + useInitialRouter = false; + } + else + { + compositionStage = lookupOnKnownRoutersThenOnInitialRouterAsync( routingTable, connectionPool ); + } + + return compositionStage.whenComplete( ( composition, error ) -> + { + if ( composition != null && !composition.hasWriters() ) + { + useInitialRouter = true; + } + } ); + } + private ClusterComposition lookupOnKnownRoutersThenOnInitialRouter( RoutingTable routingTable, ConnectionPool connections ) { @@ -121,6 +205,20 @@ private ClusterComposition lookupOnKnownRoutersThenOnInitialRouter( RoutingTable return composition; } + private CompletionStage lookupOnKnownRoutersThenOnInitialRouterAsync( RoutingTable routingTable, + AsyncConnectionPool connectionPool ) + { + Set seenServers = new HashSet<>(); + return lookupOnKnownRoutersAsync( routingTable, connectionPool, seenServers ).thenCompose( composition -> + { + if ( composition != null ) + { + return completedFuture( composition ); + } + return lookupOnInitialRouterAsync( routingTable, connectionPool, seenServers ); + } ); + } + private ClusterComposition lookupOnInitialRouterThenOnKnownRouters( RoutingTable routingTable, ConnectionPool connections ) { @@ -133,6 +231,20 @@ private ClusterComposition lookupOnInitialRouterThenOnKnownRouters( RoutingTable return composition; } + private CompletionStage lookupOnInitialRouterThenOnKnownRoutersAsync( RoutingTable routingTable, + AsyncConnectionPool connectionPool ) + { + Set seenServers = Collections.emptySet(); + return lookupOnInitialRouterAsync( routingTable, connectionPool, seenServers ).thenCompose( composition -> + { + if ( composition != null ) + { + return completedFuture( composition ); + } + return lookupOnKnownRoutersAsync( routingTable, connectionPool, new HashSet<>() ); + } ); + } + private ClusterComposition lookupOnKnownRouters( RoutingTable routingTable, ConnectionPool connections, Set seenServers ) { @@ -154,11 +266,35 @@ private ClusterComposition lookupOnKnownRouters( RoutingTable routingTable, Conn return null; } + private CompletionStage lookupOnKnownRoutersAsync( RoutingTable routingTable, + AsyncConnectionPool connectionPool, Set seenServers ) + { + BoltServerAddress[] addresses = routingTable.routers().toArray(); + + CompletableFuture result = completedFuture( null ); + for ( BoltServerAddress address : addresses ) + { + result = result.thenCompose( composition -> + { + if ( composition != null ) + { + return completedFuture( composition ); + } + else + { + return lookupOnRouterAsync( address, routingTable, connectionPool ) + .whenComplete( ( ignore, error ) -> seenServers.add( address ) ); + } + } ); + } + return result; + } + private ClusterComposition lookupOnInitialRouter( RoutingTable routingTable, - ConnectionPool connections, Set triedServers ) + ConnectionPool connections, Set seenServers ) { Set ips = hostNameResolver.resolve( initialRouter ); - ips.removeAll( triedServers ); + ips.removeAll( seenServers ); for ( BoltServerAddress address : ips ) { ClusterComposition composition = lookupOnRouter( address, routingTable, connections ); @@ -171,6 +307,27 @@ private ClusterComposition lookupOnInitialRouter( RoutingTable routingTable, return null; } + private CompletionStage lookupOnInitialRouterAsync( RoutingTable routingTable, + AsyncConnectionPool connectionPool, Set seenServers ) + { + Set addresses = hostNameResolver.resolve( initialRouter ); + addresses.removeAll( seenServers ); + + CompletableFuture result = completedFuture( null ); + for ( BoltServerAddress address : addresses ) + { + result = result.thenCompose( composition -> + { + if ( composition != null ) + { + return completedFuture( composition ); + } + return lookupOnRouterAsync( address, routingTable, connectionPool ); + } ); + } + return result; + } + private ClusterComposition lookupOnRouter( BoltServerAddress routerAddress, RoutingTable routingTable, ConnectionPool connections ) { @@ -197,6 +354,43 @@ private ClusterComposition lookupOnRouter( BoltServerAddress routerAddress, Rout return cluster; } + private CompletionStage lookupOnRouterAsync( BoltServerAddress routerAddress, + RoutingTable routingTable, AsyncConnectionPool connectionPool ) + { + CompletionStage connectionStage = connectionPool.acquire( routerAddress ); + + return provider.getClusterComposition( connectionStage ).handle( ( response, error ) -> + { + if ( error != null ) + { + return handleRoutingProcedureError( error, routingTable, routerAddress ); + } + else + { + ClusterComposition cluster = response.clusterComposition(); + logger.info( "Got cluster composition %s", cluster ); + return cluster; + } + } ); + } + + private ClusterComposition handleRoutingProcedureError( Throwable error, RoutingTable routingTable, + BoltServerAddress routerAddress ) + { + if ( error instanceof SecurityException ) + { + // auth error happened, terminate the discovery procedure immediately + throw new CompletionException( error ); + } + else + { + // connection turned out to be broken + logger.error( format( "Failed to connect to routing server '%s'.", routerAddress ), error ); + routingTable.forget( routerAddress ); + return null; + } + } + private void sleep( long millis ) { if ( millis > 0 ) diff --git a/driver/src/main/java/org/neo4j/driver/internal/cluster/RoutingProcedureClusterCompositionProvider.java b/driver/src/main/java/org/neo4j/driver/internal/cluster/RoutingProcedureClusterCompositionProvider.java index c6b9f5f732..d8e8a093a1 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/cluster/RoutingProcedureClusterCompositionProvider.java +++ b/driver/src/main/java/org/neo4j/driver/internal/cluster/RoutingProcedureClusterCompositionProvider.java @@ -19,13 +19,14 @@ package org.neo4j.driver.internal.cluster; import java.util.List; +import java.util.concurrent.CompletionStage; +import org.neo4j.driver.internal.async.AsyncConnection; import org.neo4j.driver.internal.spi.Connection; import org.neo4j.driver.internal.util.Clock; import org.neo4j.driver.v1.Logger; import org.neo4j.driver.v1.Record; import org.neo4j.driver.v1.Statement; -import org.neo4j.driver.v1.exceptions.ClientException; import org.neo4j.driver.v1.exceptions.ProtocolException; import org.neo4j.driver.v1.exceptions.ServiceUnavailableException; import org.neo4j.driver.v1.exceptions.value.ValueException; @@ -38,39 +39,48 @@ public class RoutingProcedureClusterCompositionProvider implements ClusterCompos private final Clock clock; private final Logger log; - private final RoutingProcedureRunner getServersRunner; + private final RoutingProcedureRunner routingProcedureRunner; public RoutingProcedureClusterCompositionProvider( Clock clock, Logger log, RoutingSettings settings ) { this( clock, log, new RoutingProcedureRunner( settings.routingContext() ) ); } - RoutingProcedureClusterCompositionProvider( Clock clock, Logger log, RoutingProcedureRunner getServersRunner ) + RoutingProcedureClusterCompositionProvider( Clock clock, Logger log, RoutingProcedureRunner routingProcedureRunner ) { this.clock = clock; this.log = log; - this.getServersRunner = getServersRunner; + this.routingProcedureRunner = routingProcedureRunner; } @Override public ClusterCompositionResponse getClusterComposition( Connection connection ) { - List records; + RoutingProcedureResponse response = routingProcedureRunner.run( connection ); + return processRoutingResponse( response ); + } - // failed to invoke procedure - try - { - records = getServersRunner.run( connection ); - } - catch ( ClientException e ) + @Override + public CompletionStage getClusterComposition( + CompletionStage connectionStage ) + { + return routingProcedureRunner.run( connectionStage ) + .thenApply( this::processRoutingResponse ); + } + + private ClusterCompositionResponse processRoutingResponse( RoutingProcedureResponse response ) + { + if ( !response.isSuccess() ) { return new ClusterCompositionResponse.Failure( new ServiceUnavailableException( format( "Failed to run '%s' on server. " + "Please make sure that there is a Neo4j 3.1+ causal cluster up running.", - invokedProcedureString() ), e + invokedProcedureString( response ) ), response.error() ) ); } + List records = response.records(); + log.info( "Got getServers response: %s", records ); long now = clock.millis(); @@ -79,7 +89,7 @@ public ClusterCompositionResponse getClusterComposition( Connection connection ) { return new ClusterCompositionResponse.Failure( new ProtocolException( format( PROTOCOL_ERROR_MESSAGE + "records received '%s' is too few or too many.", - invokedProcedureString(), records.size() ) ) ); + invokedProcedureString( response ), records.size() ) ) ); } // failed to parse the record @@ -92,7 +102,7 @@ public ClusterCompositionResponse getClusterComposition( Connection connection ) { return new ClusterCompositionResponse.Failure( new ProtocolException( format( PROTOCOL_ERROR_MESSAGE + "unparsable record received.", - invokedProcedureString() ), e ) ); + invokedProcedureString( response ) ), e ) ); } // the cluster result is not a legal reply @@ -100,16 +110,16 @@ public ClusterCompositionResponse getClusterComposition( Connection connection ) { return new ClusterCompositionResponse.Failure( new ProtocolException( format( PROTOCOL_ERROR_MESSAGE + "no router or reader found in response.", - invokedProcedureString() ) ) ); + invokedProcedureString( response ) ) ) ); } // all good return new ClusterCompositionResponse.Success( cluster ); } - private String invokedProcedureString() + private static String invokedProcedureString( RoutingProcedureResponse response ) { - Statement statement = getServersRunner.invokedProcedure(); + Statement statement = response.procedure(); return statement.text() + " " + statement.parameters(); } } diff --git a/driver/src/main/java/org/neo4j/driver/internal/cluster/RoutingProcedureResponse.java b/driver/src/main/java/org/neo4j/driver/internal/cluster/RoutingProcedureResponse.java new file mode 100644 index 0000000000..b1f5a2a1ef --- /dev/null +++ b/driver/src/main/java/org/neo4j/driver/internal/cluster/RoutingProcedureResponse.java @@ -0,0 +1,76 @@ +/* + * Copyright (c) 2002-2017 "Neo Technology," + * Network Engine for Objects in Lund AB [http://neotechnology.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.internal.cluster; + +import java.util.List; + +import org.neo4j.driver.v1.Record; +import org.neo4j.driver.v1.Statement; + +public class RoutingProcedureResponse +{ + private final Statement procedure; + private final List records; + private final Throwable error; + + public RoutingProcedureResponse( Statement procedure, List records ) + { + this( procedure, records, null ); + } + + public RoutingProcedureResponse( Statement procedure, Throwable error ) + { + this( procedure, null, error ); + } + + private RoutingProcedureResponse( Statement procedure, List records, Throwable error ) + { + this.procedure = procedure; + this.records = records; + this.error = error; + } + + public boolean isSuccess() + { + return records != null; + } + + public Statement procedure() + { + return procedure; + } + + public List records() + { + if ( !isSuccess() ) + { + throw new IllegalStateException( "Can't access records of a failed result", error ); + } + return records; + } + + public Throwable error() + { + if ( isSuccess() ) + { + throw new IllegalStateException( "Can't access error of a succeeded result " + records ); + } + return error; + } +} diff --git a/driver/src/main/java/org/neo4j/driver/internal/cluster/RoutingProcedureRunner.java b/driver/src/main/java/org/neo4j/driver/internal/cluster/RoutingProcedureRunner.java index d604a06d83..3fb9d365b3 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/cluster/RoutingProcedureRunner.java +++ b/driver/src/main/java/org/neo4j/driver/internal/cluster/RoutingProcedureRunner.java @@ -20,15 +20,21 @@ package org.neo4j.driver.internal.cluster; import java.util.List; +import java.util.concurrent.CompletionException; +import java.util.concurrent.CompletionStage; import org.neo4j.driver.ResultResourcesHandler; import org.neo4j.driver.internal.NetworkSession; +import org.neo4j.driver.internal.async.AsyncConnection; +import org.neo4j.driver.internal.async.QueryRunner; import org.neo4j.driver.internal.spi.Connection; +import org.neo4j.driver.internal.util.ServerVersion; import org.neo4j.driver.v1.Record; import org.neo4j.driver.v1.Statement; +import org.neo4j.driver.v1.StatementResultCursor; +import org.neo4j.driver.v1.exceptions.ClientException; import static org.neo4j.driver.internal.util.ServerVersion.v3_2_0; -import static org.neo4j.driver.internal.util.ServerVersion.version; import static org.neo4j.driver.v1.Values.parameters; public class RoutingProcedureRunner @@ -38,26 +44,43 @@ public class RoutingProcedureRunner static final String GET_ROUTING_TABLE = "dbms.cluster.routing.getRoutingTable({" + GET_ROUTING_TABLE_PARAM + "})"; private final RoutingContext context; - private Statement invokedProcedure; public RoutingProcedureRunner( RoutingContext context ) { this.context = context; } - public List run( Connection connection ) + public RoutingProcedureResponse run( Connection connection ) { - if( version( connection.server().version() ).greaterThanOrEqual( v3_2_0 ) ) + Statement procedure = procedureStatement( ServerVersion.version( connection.server().version() ) ); + + try { - invokedProcedure = new Statement( "CALL " + GET_ROUTING_TABLE, - parameters( GET_ROUTING_TABLE_PARAM, context.asMap() ) ); + return new RoutingProcedureResponse( procedure, runProcedure( connection, procedure ) ); } - else + catch ( ClientException error ) { - invokedProcedure = new Statement( "CALL " + GET_SERVERS ); + return new RoutingProcedureResponse( procedure, error ); } + } - return runProcedure( connection, invokedProcedure ); + public CompletionStage run( CompletionStage connectionStage ) + { + return connectionStage.thenCompose( connection -> + { + Statement procedure = procedureStatement( connection.serverVersion() ); + return runProcedure( connection, procedure ).handle( ( records, error ) -> + { + if ( error != null ) + { + return handleError( procedure, error ); + } + else + { + return new RoutingProcedureResponse( procedure, records ); + } + } ); + } ); } List runProcedure( Connection connection, Statement procedure ) @@ -65,8 +88,34 @@ List runProcedure( Connection connection, Statement procedure ) return NetworkSession.run( connection, procedure, ResultResourcesHandler.NO_OP ).list(); } - Statement invokedProcedure() + CompletionStage> runProcedure( AsyncConnection connection, Statement procedure ) { - return invokedProcedure; + return QueryRunner.runAsync( connection, procedure ) + .thenCompose( StatementResultCursor::listAsync ); + } + + private Statement procedureStatement( ServerVersion serverVersion ) + { + if ( serverVersion.greaterThanOrEqual( v3_2_0 ) ) + { + return new Statement( "CALL " + GET_ROUTING_TABLE, + parameters( GET_ROUTING_TABLE_PARAM, context.asMap() ) ); + } + else + { + return new Statement( "CALL " + GET_SERVERS ); + } + } + + private RoutingProcedureResponse handleError( Statement procedure, Throwable error ) + { + if ( error instanceof ClientException ) + { + return new RoutingProcedureResponse( procedure, error ); + } + else + { + throw new CompletionException( error ); + } } } diff --git a/driver/src/main/java/org/neo4j/driver/internal/cluster/loadbalancing/LeastConnectedLoadBalancingStrategy.java b/driver/src/main/java/org/neo4j/driver/internal/cluster/loadbalancing/LeastConnectedLoadBalancingStrategy.java index d9157d6f21..8d07f5ae71 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/cluster/loadbalancing/LeastConnectedLoadBalancingStrategy.java +++ b/driver/src/main/java/org/neo4j/driver/internal/cluster/loadbalancing/LeastConnectedLoadBalancingStrategy.java @@ -18,6 +18,9 @@ */ package org.neo4j.driver.internal.cluster.loadbalancing; +import java.util.function.Function; + +import org.neo4j.driver.internal.async.pool.AsyncConnectionPool; import org.neo4j.driver.internal.net.BoltServerAddress; import org.neo4j.driver.internal.spi.ConnectionPool; import org.neo4j.driver.v1.Logger; @@ -36,28 +39,44 @@ public class LeastConnectedLoadBalancingStrategy implements LoadBalancingStrateg private final RoundRobinArrayIndex writersIndex = new RoundRobinArrayIndex(); private final ConnectionPool connectionPool; + private final AsyncConnectionPool asyncConnectionPool; private final Logger log; - public LeastConnectedLoadBalancingStrategy( ConnectionPool connectionPool, Logging logging ) + public LeastConnectedLoadBalancingStrategy( ConnectionPool connectionPool, AsyncConnectionPool asyncConnectionPool, + Logging logging ) { this.connectionPool = connectionPool; + this.asyncConnectionPool = asyncConnectionPool; this.log = logging.getLog( LOGGER_NAME ); } @Override public BoltServerAddress selectReader( BoltServerAddress[] knownReaders ) { - return select( knownReaders, readersIndex, "reader" ); + return select( knownReaders, readersIndex, "reader", connectionPool::activeConnections ); + } + + @Override + public BoltServerAddress selectReaderAsync( BoltServerAddress[] knownReaders ) + { + return select( knownReaders, readersIndex, "reader", asyncConnectionPool::activeConnections ); } @Override public BoltServerAddress selectWriter( BoltServerAddress[] knownWriters ) { - return select( knownWriters, writersIndex, "writer" ); + return select( knownWriters, writersIndex, "writer", connectionPool::activeConnections ); + } + + @Override + public BoltServerAddress selectWriterAsync( BoltServerAddress[] knownWriters ) + { + return select( knownWriters, writersIndex, "writer", asyncConnectionPool::activeConnections ); } + // todo: remove Function from params when only async is supported private BoltServerAddress select( BoltServerAddress[] addresses, RoundRobinArrayIndex addressesIndex, - String addressType ) + String addressType, Function activeConnectionFunction ) { int size = addresses.length; if ( size == 0 ) @@ -77,7 +96,7 @@ private BoltServerAddress select( BoltServerAddress[] addresses, RoundRobinArray do { BoltServerAddress address = addresses[index]; - int activeConnections = connectionPool.activeConnections( address ); + int activeConnections = activeConnectionFunction.apply( address ); if ( activeConnections < leastActiveConnections ) { diff --git a/driver/src/main/java/org/neo4j/driver/internal/cluster/loadbalancing/LoadBalancer.java b/driver/src/main/java/org/neo4j/driver/internal/cluster/loadbalancing/LoadBalancer.java index 2cc0c2de91..4c913298ee 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/cluster/loadbalancing/LoadBalancer.java +++ b/driver/src/main/java/org/neo4j/driver/internal/cluster/loadbalancing/LoadBalancer.java @@ -18,11 +18,17 @@ */ package org.neo4j.driver.internal.cluster.loadbalancing; +import io.netty.util.concurrent.EventExecutorGroup; + import java.util.Set; +import java.util.concurrent.CompletableFuture; import java.util.concurrent.CompletionStage; import org.neo4j.driver.internal.RoutingErrorHandler; import org.neo4j.driver.internal.async.AsyncConnection; +import org.neo4j.driver.internal.async.Futures; +import org.neo4j.driver.internal.async.RoutingAsyncConnection; +import org.neo4j.driver.internal.async.pool.AsyncConnectionPool; import org.neo4j.driver.internal.cluster.AddressSet; import org.neo4j.driver.internal.cluster.ClusterComposition; import org.neo4j.driver.internal.cluster.ClusterCompositionProvider; @@ -44,48 +50,62 @@ import org.neo4j.driver.v1.exceptions.ServiceUnavailableException; import org.neo4j.driver.v1.exceptions.SessionExpiredException; +import static java.util.concurrent.CompletableFuture.completedFuture; + public class LoadBalancer implements ConnectionProvider, RoutingErrorHandler, AutoCloseable { private static final String LOAD_BALANCER_LOG_NAME = "LoadBalancer"; private final ConnectionPool connections; + private final AsyncConnectionPool asyncConnectionPool; private final RoutingTable routingTable; private final Rediscovery rediscovery; private final LoadBalancingStrategy loadBalancingStrategy; + private final EventExecutorGroup eventExecutorGroup; private final Logger log; + private CompletableFuture refreshRoutingTableFuture; + public LoadBalancer( BoltServerAddress initialRouter, RoutingSettings settings, ConnectionPool connections, - Clock clock, Logging logging, LoadBalancingStrategy loadBalancingStrategy ) + AsyncConnectionPool asyncConnectionPool, EventExecutorGroup eventExecutorGroup, Clock clock, + Logging logging, LoadBalancingStrategy loadBalancingStrategy ) { - this( connections, new ClusterRoutingTable( clock, initialRouter ), - createRediscovery( initialRouter, settings, clock, logging ), loadBalancerLogger( logging ), - loadBalancingStrategy ); + this( connections, asyncConnectionPool, new ClusterRoutingTable( clock, initialRouter ), + createRediscovery( initialRouter, settings, eventExecutorGroup, clock, logging ), + loadBalancerLogger( logging ), loadBalancingStrategy, eventExecutorGroup ); } // Used only in testing - public LoadBalancer( ConnectionPool connections, RoutingTable routingTable, Rediscovery rediscovery, - Logging logging ) + public LoadBalancer( ConnectionPool connections, AsyncConnectionPool asyncConnectionPool, + RoutingTable routingTable, Rediscovery rediscovery, EventExecutorGroup eventExecutorGroup, Logging logging ) { - this( connections, routingTable, rediscovery, loadBalancerLogger( logging ), - new LeastConnectedLoadBalancingStrategy( connections, logging ) ); + this( connections, asyncConnectionPool, routingTable, rediscovery, loadBalancerLogger( logging ), + new LeastConnectedLoadBalancingStrategy( connections, asyncConnectionPool, logging ), + eventExecutorGroup ); } - private LoadBalancer( ConnectionPool connections, RoutingTable routingTable, Rediscovery rediscovery, Logger log, - LoadBalancingStrategy loadBalancingStrategy ) + private LoadBalancer( ConnectionPool connections, AsyncConnectionPool asyncConnectionPool, + RoutingTable routingTable, Rediscovery rediscovery, Logger log, + LoadBalancingStrategy loadBalancingStrategy, EventExecutorGroup eventExecutorGroup ) { this.connections = connections; + this.asyncConnectionPool = asyncConnectionPool; this.routingTable = routingTable; this.rediscovery = rediscovery; this.loadBalancingStrategy = loadBalancingStrategy; + this.eventExecutorGroup = eventExecutorGroup; this.log = log; - refreshRoutingTable(); + if ( connections != null ) + { + refreshRoutingTable(); + } } @Override public PooledConnection acquireConnection( AccessMode mode ) { - AddressSet addressSet = addressSetFor( mode ); + AddressSet addressSet = addressSet( mode, routingTable ); PooledConnection connection = acquireConnection( mode, addressSet ); return new RoutingPooledConnection( connection, this, mode ); } @@ -93,7 +113,9 @@ public PooledConnection acquireConnection( AccessMode mode ) @Override public CompletionStage acquireAsyncConnection( AccessMode mode ) { - throw new UnsupportedOperationException(); + return freshRoutingTable( mode ) + .thenCompose( routingTable -> acquireAsync( mode, routingTable ) ) + .thenApply( connection -> new RoutingAsyncConnection( connection, mode, this ) ); } @Override @@ -112,6 +134,7 @@ public void onWriteFailure( BoltServerAddress address ) public void close() throws Exception { connections.close(); + Futures.getBlocking( asyncConnectionPool.closeAsync() ); } private PooledConnection acquireConnection( AccessMode mode, AddressSet servers ) @@ -138,7 +161,11 @@ private synchronized void forget( BoltServerAddress address ) // First remove from the load balancer, to prevent concurrent threads from making connections to them. routingTable.forget( address ); // drop all current connections to the address - connections.purge( address ); + if ( connections != null ) + { + connections.purge( address ); + } + asyncConnectionPool.purge( address ); } synchronized void ensureRouting( AccessMode mode ) @@ -165,7 +192,109 @@ synchronized void refreshRoutingTable() log.info( "Refreshed routing information. %s", routingTable ); } - private AddressSet addressSetFor( AccessMode mode ) + private synchronized CompletionStage freshRoutingTable( AccessMode mode ) + { + if ( refreshRoutingTableFuture != null ) + { + // refresh is already happening concurrently, just use it's result + return refreshRoutingTableFuture; + } + else if ( routingTable.isStaleFor( mode ) ) + { + // existing routing table is not fresh and should be updated + log.info( "Routing information is stale. %s", routingTable ); + + CompletableFuture resultFuture = new CompletableFuture<>(); + refreshRoutingTableFuture = resultFuture; + + rediscovery.lookupClusterCompositionAsync( routingTable, asyncConnectionPool ) + .whenComplete( ( composition, error ) -> + { + if ( error != null ) + { + clusterCompositionLookupFailed( error ); + } + else + { + freshClusterCompositionFetched( composition ); + } + } ); + + return resultFuture; + } + else + { + // existing routing table is fresh, use it + return completedFuture( routingTable ); + } + } + + private synchronized void freshClusterCompositionFetched( ClusterComposition composition ) + { + Set removed = routingTable.update( composition ); + + for ( BoltServerAddress address : removed ) + { + asyncConnectionPool.purge( address ); + } + + log.info( "Refreshed routing information. %s", routingTable ); + + CompletableFuture routingTableFuture = refreshRoutingTableFuture; + refreshRoutingTableFuture = null; + routingTableFuture.complete( routingTable ); + } + + private synchronized void clusterCompositionLookupFailed( Throwable error ) + { + CompletableFuture routingTableFuture = refreshRoutingTableFuture; + refreshRoutingTableFuture = null; + routingTableFuture.completeExceptionally( error ); + } + + private CompletionStage acquireAsync( AccessMode mode, RoutingTable routingTable ) + { + AddressSet addresses = addressSet( mode, routingTable ); + CompletableFuture result = new CompletableFuture<>(); + acquireAsync( mode, addresses, result ); + return result; + } + + private void acquireAsync( AccessMode mode, AddressSet addresses, CompletableFuture result ) + { + BoltServerAddress address = selectAddressAsync( mode, addresses ); + + if ( address == null ) + { + result.completeExceptionally( new SessionExpiredException( + "Failed to obtain connection towards " + mode + " server. " + + "Known routing table is: " + routingTable ) ); + return; + } + + asyncConnectionPool.acquire( address ).whenComplete( ( connection, error ) -> + { + if ( error != null ) + { + if ( error instanceof ServiceUnavailableException ) + { + log.error( "Failed to obtain a connection towards address " + address, error ); + forget( address ); + eventExecutorGroup.next().execute( () -> acquireAsync( mode, addresses, result ) ); + } + else + { + result.completeExceptionally( error ); + } + } + else + { + result.complete( connection ); + } + } ); + } + + private static AddressSet addressSet( AccessMode mode, RoutingTable routingTable ) { switch ( mode ) { @@ -193,13 +322,29 @@ private BoltServerAddress selectAddress( AccessMode mode, AddressSet servers ) } } + private BoltServerAddress selectAddressAsync( AccessMode mode, AddressSet servers ) + { + BoltServerAddress[] addresses = servers.toArray(); + + switch ( mode ) + { + case READ: + return loadBalancingStrategy.selectReaderAsync( addresses ); + case WRITE: + return loadBalancingStrategy.selectWriterAsync( addresses ); + default: + throw unknownMode( mode ); + } + } + private static Rediscovery createRediscovery( BoltServerAddress initialRouter, RoutingSettings settings, - Clock clock, Logging logging ) + EventExecutorGroup eventExecutorGroup, Clock clock, Logging logging ) { Logger log = loadBalancerLogger( logging ); - ClusterCompositionProvider clusterComposition = + ClusterCompositionProvider clusterCompositionProvider = new RoutingProcedureClusterCompositionProvider( clock, log, settings ); - return new Rediscovery( initialRouter, settings, clock, log, clusterComposition, new DnsResolver( log ) ); + return new Rediscovery( initialRouter, settings, clusterCompositionProvider, eventExecutorGroup, + new DnsResolver( log ), clock, log ); } private static Logger loadBalancerLogger( Logging logging ) diff --git a/driver/src/main/java/org/neo4j/driver/internal/cluster/loadbalancing/LoadBalancingStrategy.java b/driver/src/main/java/org/neo4j/driver/internal/cluster/loadbalancing/LoadBalancingStrategy.java index 0b3ee3c8e1..a8e238154c 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/cluster/loadbalancing/LoadBalancingStrategy.java +++ b/driver/src/main/java/org/neo4j/driver/internal/cluster/loadbalancing/LoadBalancingStrategy.java @@ -33,6 +33,8 @@ public interface LoadBalancingStrategy */ BoltServerAddress selectReader( BoltServerAddress[] knownReaders ); + BoltServerAddress selectReaderAsync( BoltServerAddress[] knownReaders ); + /** * Select most appropriate write address from the given array of addresses. * @@ -40,4 +42,6 @@ public interface LoadBalancingStrategy * @return most appropriate writer or {@code null} if it can't be selected. */ BoltServerAddress selectWriter( BoltServerAddress[] knownWriters ); + + BoltServerAddress selectWriterAsync( BoltServerAddress[] knownWriters ); } diff --git a/driver/src/main/java/org/neo4j/driver/internal/cluster/loadbalancing/RoundRobinLoadBalancingStrategy.java b/driver/src/main/java/org/neo4j/driver/internal/cluster/loadbalancing/RoundRobinLoadBalancingStrategy.java index 1a977cd65b..0a7174cb20 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/cluster/loadbalancing/RoundRobinLoadBalancingStrategy.java +++ b/driver/src/main/java/org/neo4j/driver/internal/cluster/loadbalancing/RoundRobinLoadBalancingStrategy.java @@ -46,12 +46,24 @@ public BoltServerAddress selectReader( BoltServerAddress[] knownReaders ) return select( knownReaders, readersIndex, "reader" ); } + @Override + public BoltServerAddress selectReaderAsync( BoltServerAddress[] knownReaders ) + { + return selectReader( knownReaders ); + } + @Override public BoltServerAddress selectWriter( BoltServerAddress[] knownWriters ) { return select( knownWriters, writersIndex, "writer" ); } + @Override + public BoltServerAddress selectWriterAsync( BoltServerAddress[] knownWriters ) + { + return selectWriter( knownWriters ); + } + private BoltServerAddress select( BoltServerAddress[] addresses, RoundRobinArrayIndex roundRobinIndex, String addressType ) { diff --git a/driver/src/main/java/org/neo4j/driver/internal/handlers/AsyncInitResponseHandler.java b/driver/src/main/java/org/neo4j/driver/internal/handlers/AsyncInitResponseHandler.java index 2d8310d926..b2fcbc0652 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/handlers/AsyncInitResponseHandler.java +++ b/driver/src/main/java/org/neo4j/driver/internal/handlers/AsyncInitResponseHandler.java @@ -32,6 +32,7 @@ import org.neo4j.driver.v1.Value; import static org.neo4j.driver.internal.async.ChannelAttributes.setServerVersion; +import static org.neo4j.driver.internal.util.ServerVersion.version; public class AsyncInitResponseHandler implements ResponseHandler { @@ -47,14 +48,23 @@ public AsyncInitResponseHandler( ChannelPromise connectionInitializedPromise ) @Override public void onSuccess( Map metadata ) { - Value versionValue = metadata.get( "server" ); - if ( versionValue != null ) + try { - String serverVersion = versionValue.asString(); - setServerVersion( channel, serverVersion ); - updatePipelineIfNeeded( serverVersion, channel.pipeline() ); + Value versionValue = metadata.get( "server" ); + if ( versionValue != null ) + { + String versionString = versionValue.asString(); + ServerVersion version = version( versionString ); + setServerVersion( channel, version ); + updatePipelineIfNeeded( version, channel.pipeline() ); + } + connectionInitializedPromise.setSuccess(); + } + catch ( Throwable error ) + { + connectionInitializedPromise.setFailure( error ); + throw error; } - connectionInitializedPromise.setSuccess(); } @Override @@ -76,9 +86,8 @@ public void onRecord( Value[] fields ) throw new UnsupportedOperationException(); } - private static void updatePipelineIfNeeded( String serverVersionString, ChannelPipeline pipeline ) + private static void updatePipelineIfNeeded( ServerVersion serverVersion, ChannelPipeline pipeline ) { - ServerVersion serverVersion = ServerVersion.version( serverVersionString ); if ( serverVersion.lessThan( ServerVersion.v3_2_0 ) ) { OutboundMessageHandler outboundHandler = pipeline.get( OutboundMessageHandler.class ); diff --git a/driver/src/main/java/org/neo4j/driver/internal/handlers/PullAllResponseHandler.java b/driver/src/main/java/org/neo4j/driver/internal/handlers/PullAllResponseHandler.java index a8b59e6574..79bd00ab5c 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/handlers/PullAllResponseHandler.java +++ b/driver/src/main/java/org/neo4j/driver/internal/handlers/PullAllResponseHandler.java @@ -33,6 +33,7 @@ import org.neo4j.driver.internal.summary.InternalPlan; import org.neo4j.driver.internal.summary.InternalProfiledPlan; import org.neo4j.driver.internal.summary.InternalResultSummary; +import org.neo4j.driver.internal.summary.InternalServerInfo; import org.neo4j.driver.internal.summary.InternalSummaryCounters; import org.neo4j.driver.v1.Record; import org.neo4j.driver.v1.Statement; @@ -204,7 +205,9 @@ private void failRecordFuture( Throwable error ) private ResultSummary extractResultSummary( Map metadata ) { - return new InternalResultSummary( statement, connection.serverInfo(), extractStatementType( metadata ), + InternalServerInfo serverInfo = new InternalServerInfo( connection.serverAddress(), + connection.serverVersion() ); + return new InternalResultSummary( statement, serverInfo, extractStatementType( metadata ), extractCounters( metadata ), extractPlan( metadata ), extractProfiledPlan( metadata ), extractNotifications( metadata ), runResponseHandler.resultAvailableAfter(), extractResultConsumedAfter( metadata ) ); diff --git a/driver/src/main/java/org/neo4j/driver/internal/summary/InternalServerInfo.java b/driver/src/main/java/org/neo4j/driver/internal/summary/InternalServerInfo.java index a03e5efac9..968e3c6d2a 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/summary/InternalServerInfo.java +++ b/driver/src/main/java/org/neo4j/driver/internal/summary/InternalServerInfo.java @@ -20,6 +20,7 @@ package org.neo4j.driver.internal.summary; import org.neo4j.driver.internal.net.BoltServerAddress; +import org.neo4j.driver.internal.util.ServerVersion; import org.neo4j.driver.v1.summary.ServerInfo; public class InternalServerInfo implements ServerInfo @@ -27,6 +28,11 @@ public class InternalServerInfo implements ServerInfo private final BoltServerAddress address; private final String version; + public InternalServerInfo( BoltServerAddress address, ServerVersion version ) + { + this( address, version.toString() ); + } + public InternalServerInfo( BoltServerAddress address, String version ) { this.address = address; diff --git a/driver/src/main/java/org/neo4j/driver/internal/util/ServerVersion.java b/driver/src/main/java/org/neo4j/driver/internal/util/ServerVersion.java index 0af4b21e85..d3f713b8b5 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/util/ServerVersion.java +++ b/driver/src/main/java/org/neo4j/driver/internal/util/ServerVersion.java @@ -152,6 +152,8 @@ private int compareTo( ServerVersion o ) @Override public String toString() { - return String.format( "%s.%s.%s", major, minor, patch ); + return this == vInDev + ? NEO4J_IN_DEV_VERSION_STRING + : String.format( "Neo4j/%s.%s.%s", major, minor, patch ); } } diff --git a/driver/src/test/java/org/neo4j/driver/internal/DriverFactoryTest.java b/driver/src/test/java/org/neo4j/driver/internal/DriverFactoryTest.java index dd7be8391b..faf0549666 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/DriverFactoryTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/DriverFactoryTest.java @@ -18,6 +18,7 @@ */ package org.neo4j.driver.internal; +import io.netty.util.concurrent.EventExecutorGroup; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; @@ -28,6 +29,7 @@ import java.util.Arrays; import java.util.List; +import org.neo4j.driver.internal.async.pool.AsyncConnectionPool; import org.neo4j.driver.internal.cluster.RoutingSettings; import org.neo4j.driver.internal.cluster.loadbalancing.LoadBalancer; import org.neo4j.driver.internal.net.BoltServerAddress; @@ -168,8 +170,9 @@ protected InternalDriver createDriver( Config config, SecurityPlan securityPlan, } @Override - protected Driver createRoutingDriver( BoltServerAddress address, ConnectionPool connectionPool, Config config, - RoutingSettings routingSettings, SecurityPlan securityPlan, RetryLogic retryLogic ) + protected Driver createRoutingDriver( BoltServerAddress address, ConnectionPool connectionPool, + AsyncConnectionPool asyncConnectionPool, Config config, RoutingSettings routingSettings, + SecurityPlan securityPlan, RetryLogic retryLogic, EventExecutorGroup eventExecutorGroup ) { throw new UnsupportedOperationException( "Can't create routing driver" ); } @@ -193,6 +196,7 @@ protected InternalDriver createDriver( Config config, SecurityPlan securityPlan, @Override protected LoadBalancer createLoadBalancer( BoltServerAddress address, ConnectionPool connectionPool, + AsyncConnectionPool asyncConnectionPool, EventExecutorGroup eventExecutorGroup, Config config, RoutingSettings routingSettings ) { return null; diff --git a/driver/src/test/java/org/neo4j/driver/internal/RoutingDriverTest.java b/driver/src/test/java/org/neo4j/driver/internal/RoutingDriverTest.java index 123af37f74..3f2273aab6 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/RoutingDriverTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/RoutingDriverTest.java @@ -18,6 +18,7 @@ */ package org.neo4j.driver.internal; +import io.netty.util.concurrent.GlobalEventExecutor; import org.junit.After; import org.junit.Rule; import org.junit.Test; @@ -31,9 +32,11 @@ import java.util.Collections; import java.util.Map; +import org.neo4j.driver.internal.async.pool.AsyncConnectionPool; import org.neo4j.driver.internal.cluster.RoutingSettings; import org.neo4j.driver.internal.cluster.loadbalancing.LeastConnectedLoadBalancingStrategy; import org.neo4j.driver.internal.cluster.loadbalancing.LoadBalancer; +import org.neo4j.driver.internal.cluster.loadbalancing.LoadBalancingStrategy; import org.neo4j.driver.internal.net.BoltServerAddress; import org.neo4j.driver.internal.retry.FixedRetryLogic; import org.neo4j.driver.internal.retry.RetryLogic; @@ -65,7 +68,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; -import static org.neo4j.driver.internal.cluster.ClusterCompositionProviderTest.serverInfo; +import static org.neo4j.driver.internal.cluster.RoutingProcedureClusterCompositionProviderTest.serverInfo; import static org.neo4j.driver.internal.logging.DevNullLogging.DEV_NULL_LOGGING; import static org.neo4j.driver.internal.security.SecurityPlan.insecure; import static org.neo4j.driver.v1.Values.value; @@ -360,8 +363,11 @@ private Driver driverWithPool( ConnectionPool pool ) { Logging logging = DEV_NULL_LOGGING; RoutingSettings settings = new RoutingSettings( 10, 5_000, null ); - ConnectionProvider connectionProvider = new LoadBalancer( SEED, settings, pool, clock, logging, - new LeastConnectedLoadBalancingStrategy( pool, logging ) ); + AsyncConnectionPool asyncConnectionPool = mock( AsyncConnectionPool.class ); + LoadBalancingStrategy loadBalancingStrategy = new LeastConnectedLoadBalancingStrategy( pool, + asyncConnectionPool, logging ); + ConnectionProvider connectionProvider = new LoadBalancer( SEED, settings, pool, asyncConnectionPool, + GlobalEventExecutor.INSTANCE, clock, logging, loadBalancingStrategy ); Config config = Config.build().withLogging( logging ).toConfig(); SessionFactory sessionFactory = new NetworkSessionWithAddressFactory( connectionProvider, config ); return new InternalDriver( insecure(), sessionFactory, logging ); diff --git a/driver/src/test/java/org/neo4j/driver/internal/async/ChannelAttributesTest.java b/driver/src/test/java/org/neo4j/driver/internal/async/ChannelAttributesTest.java index 2c9f6073c9..af67380706 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/async/ChannelAttributesTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/async/ChannelAttributesTest.java @@ -24,6 +24,7 @@ import org.neo4j.driver.internal.async.inbound.InboundMessageDispatcher; import org.neo4j.driver.internal.net.BoltServerAddress; +import org.neo4j.driver.internal.util.ServerVersion; import static org.hamcrest.Matchers.instanceOf; import static org.junit.Assert.assertEquals; @@ -31,16 +32,17 @@ import static org.junit.Assert.assertThat; import static org.junit.Assert.fail; import static org.mockito.Mockito.mock; -import static org.neo4j.driver.internal.async.ChannelAttributes.address; import static org.neo4j.driver.internal.async.ChannelAttributes.creationTimestamp; import static org.neo4j.driver.internal.async.ChannelAttributes.lastUsedTimestamp; import static org.neo4j.driver.internal.async.ChannelAttributes.messageDispatcher; +import static org.neo4j.driver.internal.async.ChannelAttributes.serverAddress; import static org.neo4j.driver.internal.async.ChannelAttributes.serverVersion; -import static org.neo4j.driver.internal.async.ChannelAttributes.setAddress; import static org.neo4j.driver.internal.async.ChannelAttributes.setCreationTimestamp; import static org.neo4j.driver.internal.async.ChannelAttributes.setLastUsedTimestamp; import static org.neo4j.driver.internal.async.ChannelAttributes.setMessageDispatcher; +import static org.neo4j.driver.internal.async.ChannelAttributes.setServerAddress; import static org.neo4j.driver.internal.async.ChannelAttributes.setServerVersion; +import static org.neo4j.driver.internal.util.ServerVersion.version; public class ChannelAttributesTest { @@ -56,18 +58,18 @@ public void tearDown() throws Exception public void shouldSetAndGetAddress() { BoltServerAddress address = new BoltServerAddress( "local:42" ); - setAddress( channel, address ); - assertEquals( address, address( channel ) ); + setServerAddress( channel, address ); + assertEquals( address, serverAddress( channel ) ); } @Test public void shouldFailToSetAddressTwice() { - setAddress( channel, BoltServerAddress.LOCAL_DEFAULT ); + setServerAddress( channel, BoltServerAddress.LOCAL_DEFAULT ); try { - setAddress( channel, BoltServerAddress.LOCAL_DEFAULT ); + setServerAddress( channel, BoltServerAddress.LOCAL_DEFAULT ); fail( "Exception expected" ); } catch ( Exception e ) @@ -144,18 +146,19 @@ public void shouldFailToSetMessageDispatcherTwice() @Test public void shouldSetAndGetServerVersion() { - setServerVersion( channel, "3.2.1" ); - assertEquals( "3.2.1", serverVersion( channel ) ); + ServerVersion version = version( "3.2.1" ); + setServerVersion( channel, version ); + assertEquals( version, serverVersion( channel ) ); } @Test public void shouldFailToSetServerVersionTwice() { - setServerVersion( channel, "3.2.2" ); + setServerVersion( channel, version( "3.2.2" ) ); try { - setServerVersion( channel, "3.2.3" ); + setServerVersion( channel, version( "3.2.3" ) ); fail( "Exception expected" ); } catch ( Exception e ) diff --git a/driver/src/test/java/org/neo4j/driver/internal/async/RoutingAsyncConnectionTest.java b/driver/src/test/java/org/neo4j/driver/internal/async/RoutingAsyncConnectionTest.java new file mode 100644 index 0000000000..15b0752820 --- /dev/null +++ b/driver/src/test/java/org/neo4j/driver/internal/async/RoutingAsyncConnectionTest.java @@ -0,0 +1,83 @@ +/* + * Copyright (c) 2002-2017 "Neo Technology," + * Network Engine for Objects in Lund AB [http://neotechnology.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.internal.async; + +import org.junit.Test; +import org.mockito.ArgumentCaptor; + +import org.neo4j.driver.internal.RoutingErrorHandler; +import org.neo4j.driver.internal.spi.ResponseHandler; + +import static java.util.Collections.emptyMap; +import static org.hamcrest.Matchers.instanceOf; +import static org.junit.Assert.assertThat; +import static org.mockito.Mockito.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.neo4j.driver.v1.AccessMode.READ; + +public class RoutingAsyncConnectionTest +{ + @Test + public void shouldWrapGivenHandlersInRun() + { + testHandlersWrapping( false ); + } + + @Test + public void shouldWrapGivenHandlersInRunAndFlush() + { + testHandlersWrapping( true ); + } + + private static void testHandlersWrapping( boolean flush ) + { + AsyncConnection connection = mock( AsyncConnection.class ); + RoutingErrorHandler errorHandler = mock( RoutingErrorHandler.class ); + RoutingAsyncConnection routingConnection = new RoutingAsyncConnection( connection, READ, errorHandler ); + + if ( flush ) + { + routingConnection.runAndFlush( "RETURN 1", emptyMap(), mock( ResponseHandler.class ), + mock( ResponseHandler.class ) ); + } + else + { + routingConnection.run( "RETURN 1", emptyMap(), mock( ResponseHandler.class ), + mock( ResponseHandler.class ) ); + } + + ArgumentCaptor runHandlerCaptor = ArgumentCaptor.forClass( ResponseHandler.class ); + ArgumentCaptor pullAllHandlerCaptor = ArgumentCaptor.forClass( ResponseHandler.class ); + + if ( flush ) + { + verify( connection ).runAndFlush( eq( "RETURN 1" ), eq( emptyMap() ), runHandlerCaptor.capture(), + pullAllHandlerCaptor.capture() ); + } + else + { + verify( connection ).run( eq( "RETURN 1" ), eq( emptyMap() ), runHandlerCaptor.capture(), + pullAllHandlerCaptor.capture() ); + } + + assertThat( runHandlerCaptor.getValue(), instanceOf( RoutingResponseHandler.class ) ); + assertThat( pullAllHandlerCaptor.getValue(), instanceOf( RoutingResponseHandler.class ) ); + } +} diff --git a/driver/src/test/java/org/neo4j/driver/internal/async/RoutingResponseHandlerTest.java b/driver/src/test/java/org/neo4j/driver/internal/async/RoutingResponseHandlerTest.java new file mode 100644 index 0000000000..a5d525d155 --- /dev/null +++ b/driver/src/test/java/org/neo4j/driver/internal/async/RoutingResponseHandlerTest.java @@ -0,0 +1,169 @@ +/* + * Copyright (c) 2002-2017 "Neo Technology," + * Network Engine for Objects in Lund AB [http://neotechnology.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.internal.async; + +import org.junit.Test; +import org.mockito.ArgumentCaptor; + +import java.util.concurrent.CompletionException; + +import org.neo4j.driver.internal.RoutingErrorHandler; +import org.neo4j.driver.internal.spi.ResponseHandler; +import org.neo4j.driver.v1.AccessMode; +import org.neo4j.driver.v1.exceptions.ClientException; +import org.neo4j.driver.v1.exceptions.ServiceUnavailableException; +import org.neo4j.driver.v1.exceptions.SessionExpiredException; +import org.neo4j.driver.v1.exceptions.TransientException; + +import static org.hamcrest.Matchers.instanceOf; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyZeroInteractions; +import static org.neo4j.driver.internal.net.BoltServerAddress.LOCAL_DEFAULT; + +public class RoutingResponseHandlerTest +{ + @Test + public void shouldUnwrapCompletionException() + { + RuntimeException error = new RuntimeException( "Hi" ); + RoutingErrorHandler errorHandler = mock( RoutingErrorHandler.class ); + + Throwable handledError = handle( new CompletionException( error ), errorHandler ); + + assertEquals( error, handledError ); + verifyZeroInteractions( errorHandler ); + } + + @Test + public void shouldHandleServiceUnavailableException() + { + ServiceUnavailableException error = new ServiceUnavailableException( "Hi" ); + RoutingErrorHandler errorHandler = mock( RoutingErrorHandler.class ); + + Throwable handledError = handle( error, errorHandler ); + + assertThat( handledError, instanceOf( SessionExpiredException.class ) ); + verify( errorHandler ).onConnectionFailure( LOCAL_DEFAULT ); + } + + @Test + public void shouldHandleDatabaseUnavailableError() + { + TransientException error = new TransientException( "Neo.TransientError.General.DatabaseUnavailable", "Hi" ); + RoutingErrorHandler errorHandler = mock( RoutingErrorHandler.class ); + + Throwable handledError = handle( error, errorHandler ); + + assertEquals( error, handledError ); + verify( errorHandler ).onConnectionFailure( LOCAL_DEFAULT ); + } + + @Test + public void shouldHandleTransientException() + { + TransientException error = new TransientException( "Neo.TransientError.Transaction.DeadlockDetected", "Hi" ); + RoutingErrorHandler errorHandler = mock( RoutingErrorHandler.class ); + + Throwable handledError = handle( error, errorHandler ); + + assertEquals( error, handledError ); + verifyZeroInteractions( errorHandler ); + } + + @Test + public void shouldHandleNotALeaderErrorWithReadAccessMode() + { + testWriteFailureWithReadAccessMode( "Neo.ClientError.Cluster.NotALeader" ); + } + + @Test + public void shouldHandleNotALeaderErrorWithWriteAccessMode() + { + testWriteFailureWithWriteAccessMode( "Neo.ClientError.Cluster.NotALeader" ); + } + + @Test + public void shouldHandleForbiddenOnReadOnlyDatabaseErrorWithReadAccessMode() + { + testWriteFailureWithReadAccessMode( "Neo.ClientError.General.ForbiddenOnReadOnlyDatabase" ); + } + + @Test + public void shouldHandleForbiddenOnReadOnlyDatabaseErrorWithWriteAccessMode() + { + testWriteFailureWithWriteAccessMode( "Neo.ClientError.General.ForbiddenOnReadOnlyDatabase" ); + } + + @Test + public void shouldHandleClientException() + { + ClientException error = new ClientException( "Neo.ClientError.Request.Invalid", "Hi" ); + RoutingErrorHandler errorHandler = mock( RoutingErrorHandler.class ); + + Throwable handledError = handle( error, errorHandler, AccessMode.READ ); + + assertEquals( error, handledError ); + verifyZeroInteractions( errorHandler ); + } + + private void testWriteFailureWithReadAccessMode( String code ) + { + ClientException error = new ClientException( code, "Hi" ); + RoutingErrorHandler errorHandler = mock( RoutingErrorHandler.class ); + + Throwable handledError = handle( error, errorHandler, AccessMode.READ ); + + assertThat( handledError, instanceOf( ClientException.class ) ); + assertEquals( "Write queries cannot be performed in READ access mode.", handledError.getMessage() ); + verifyZeroInteractions( errorHandler ); + } + + private void testWriteFailureWithWriteAccessMode( String code ) + { + ClientException error = new ClientException( code, "Hi" ); + RoutingErrorHandler errorHandler = mock( RoutingErrorHandler.class ); + + Throwable handledError = handle( error, errorHandler, AccessMode.WRITE ); + + assertThat( handledError, instanceOf( SessionExpiredException.class ) ); + assertEquals( "Server at " + LOCAL_DEFAULT + " no longer accepts writes", handledError.getMessage() ); + verify( errorHandler ).onWriteFailure( LOCAL_DEFAULT ); + } + + private static Throwable handle( Throwable error, RoutingErrorHandler errorHandler ) + { + return handle( error, errorHandler, AccessMode.READ ); + } + + private static Throwable handle( Throwable error, RoutingErrorHandler errorHandler, AccessMode accessMode ) + { + ResponseHandler responseHandler = mock( ResponseHandler.class ); + RoutingResponseHandler routingResponseHandler = + new RoutingResponseHandler( responseHandler, LOCAL_DEFAULT, accessMode, errorHandler ); + + routingResponseHandler.onFailure( error ); + + ArgumentCaptor handledErrorCaptor = ArgumentCaptor.forClass( Throwable.class ); + verify( responseHandler ).onFailure( handledErrorCaptor.capture() ); + return handledErrorCaptor.getValue(); + } +} diff --git a/driver/src/test/java/org/neo4j/driver/internal/async/pool/ActiveChannelTrackerTest.java b/driver/src/test/java/org/neo4j/driver/internal/async/pool/ActiveChannelTrackerTest.java index 57c0fe39de..8705aee247 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/async/pool/ActiveChannelTrackerTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/async/pool/ActiveChannelTrackerTest.java @@ -29,7 +29,7 @@ import static org.junit.Assert.assertNull; import static org.junit.Assert.assertThat; import static org.junit.Assert.fail; -import static org.neo4j.driver.internal.async.ChannelAttributes.setAddress; +import static org.neo4j.driver.internal.async.ChannelAttributes.setServerAddress; import static org.neo4j.driver.internal.logging.DevNullLogging.DEV_NULL_LOGGING; import static org.neo4j.driver.v1.util.TestUtil.await; @@ -138,7 +138,7 @@ public void shouldPruneForExistingAddress() private Channel newChannel() { EmbeddedChannel channel = new EmbeddedChannel(); - setAddress( channel, address ); + setServerAddress( channel, address ); return channel; } } diff --git a/driver/src/test/java/org/neo4j/driver/internal/cluster/RediscoveryAsyncTest.java b/driver/src/test/java/org/neo4j/driver/internal/cluster/RediscoveryAsyncTest.java new file mode 100644 index 0000000000..3980b74233 --- /dev/null +++ b/driver/src/test/java/org/neo4j/driver/internal/cluster/RediscoveryAsyncTest.java @@ -0,0 +1,406 @@ +/* + * Copyright (c) 2002-2017 "Neo Technology," + * Network Engine for Objects in Lund AB [http://neotechnology.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.internal.cluster; + +import io.netty.util.concurrent.GlobalEventExecutor; +import org.junit.Test; + +import java.io.IOException; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.concurrent.CompletionStage; + +import org.neo4j.driver.internal.async.AsyncConnection; +import org.neo4j.driver.internal.async.pool.AsyncConnectionPool; +import org.neo4j.driver.internal.cluster.ClusterCompositionResponse.Failure; +import org.neo4j.driver.internal.cluster.ClusterCompositionResponse.Success; +import org.neo4j.driver.internal.net.BoltServerAddress; +import org.neo4j.driver.internal.util.FakeClock; +import org.neo4j.driver.internal.util.TrackingEventExecutor; +import org.neo4j.driver.v1.exceptions.AuthenticationException; +import org.neo4j.driver.v1.exceptions.ProtocolException; +import org.neo4j.driver.v1.exceptions.ServiceUnavailableException; +import org.neo4j.driver.v1.exceptions.SessionExpiredException; + +import static java.util.Arrays.asList; +import static java.util.Collections.emptySet; +import static java.util.concurrent.CompletableFuture.completedFuture; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.fail; +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.neo4j.driver.internal.async.Futures.failedFuture; +import static org.neo4j.driver.internal.async.Futures.getBlocking; +import static org.neo4j.driver.internal.cluster.ClusterCompositionUtil.A; +import static org.neo4j.driver.internal.cluster.ClusterCompositionUtil.B; +import static org.neo4j.driver.internal.cluster.ClusterCompositionUtil.C; +import static org.neo4j.driver.internal.cluster.ClusterCompositionUtil.D; +import static org.neo4j.driver.internal.cluster.ClusterCompositionUtil.E; +import static org.neo4j.driver.internal.logging.DevNullLogger.DEV_NULL_LOGGER; +import static org.neo4j.driver.v1.util.TestUtil.asOrderedSet; + +public class RediscoveryAsyncTest +{ + private final AsyncConnectionPool pool = asyncConnectionPoolMock(); + + @Test + public void shouldUseFirstRouterInTable() + { + ClusterComposition expectedComposition = new ClusterComposition( 42, + asOrderedSet( B, C ), asOrderedSet( C, D ), asOrderedSet( B ) ); + + Map responsesByAddress = new HashMap<>(); + responsesByAddress.put( B, new Success( expectedComposition ) ); // first -> valid cluster composition + + ClusterCompositionProvider compositionProvider = compositionProviderMock( responsesByAddress ); + Rediscovery rediscovery = newRediscovery( A, compositionProvider, mock( HostNameResolver.class ) ); + RoutingTable table = routingTableMock( B ); + + ClusterComposition actualComposition = getBlocking( rediscovery.lookupClusterCompositionAsync( table, pool ) ); + + assertEquals( expectedComposition, actualComposition ); + verify( table, never() ).forget( B ); + } + + @Test + public void shouldSkipFailingRouters() + { + ClusterComposition expectedComposition = new ClusterComposition( 42, + asOrderedSet( A, B, C ), asOrderedSet( B, C, D ), asOrderedSet( A, B ) ); + + Map responsesByAddress = new HashMap<>(); + responsesByAddress.put( A, new RuntimeException( "Hi!" ) ); // first -> non-fatal failure + responsesByAddress.put( B, new ServiceUnavailableException( "Hi!" ) ); // second -> non-fatal failure + responsesByAddress.put( C, new Success( expectedComposition ) ); // third -> valid cluster composition + + ClusterCompositionProvider compositionProvider = compositionProviderMock( responsesByAddress ); + Rediscovery rediscovery = newRediscovery( A, compositionProvider, mock( HostNameResolver.class ) ); + RoutingTable table = routingTableMock( A, B, C ); + + ClusterComposition actualComposition = getBlocking( rediscovery.lookupClusterCompositionAsync( table, pool ) ); + + assertEquals( expectedComposition, actualComposition ); + verify( table ).forget( A ); + verify( table ).forget( B ); + verify( table, never() ).forget( C ); + } + + @Test + public void shouldFailImmediatelyOnAuthError() + { + AuthenticationException authError = new AuthenticationException( "Neo.ClientError.Security.Unauthorized", + "Wrong password" ); + + Map responsesByAddress = new HashMap<>(); + responsesByAddress.put( A, new RuntimeException( "Hi!" ) ); // first router -> non-fatal failure + responsesByAddress.put( B, authError ); // second router -> fatal auth error + + ClusterCompositionProvider compositionProvider = compositionProviderMock( responsesByAddress ); + Rediscovery rediscovery = newRediscovery( A, compositionProvider, mock( HostNameResolver.class ) ); + RoutingTable table = routingTableMock( A, B, C ); + + try + { + getBlocking( rediscovery.lookupClusterCompositionAsync( table, pool ) ); + fail( "Exception expected" ); + } + catch ( AuthenticationException e ) + { + assertEquals( authError, e ); + verify( table ).forget( A ); + } + } + + @Test + public void shouldFallbackToInitialRouterWhenKnownRoutersFail() + { + BoltServerAddress initialRouter = A; + ClusterComposition expectedComposition = new ClusterComposition( 42, + asOrderedSet( C, B, A ), asOrderedSet( A, B ), asOrderedSet( D, E ) ); + + Map responsesByAddress = new HashMap<>(); + responsesByAddress.put( B, new ServiceUnavailableException( "Hi!" ) ); // first -> non-fatal failure + responsesByAddress.put( C, new ServiceUnavailableException( "Hi!" ) ); // second -> non-fatal failure + responsesByAddress.put( initialRouter, new Success( expectedComposition ) ); // initial -> valid response + + ClusterCompositionProvider compositionProvider = compositionProviderMock( responsesByAddress ); + HostNameResolver resolver = hostNameResolverMock( initialRouter, initialRouter ); + Rediscovery rediscovery = newRediscovery( initialRouter, compositionProvider, resolver ); + RoutingTable table = routingTableMock( B, C ); + + ClusterComposition actualComposition = getBlocking( rediscovery.lookupClusterCompositionAsync( table, pool ) ); + + assertEquals( expectedComposition, actualComposition ); + verify( table ).forget( B ); + verify( table ).forget( C ); + } + + @Test + public void shouldFailImmediatelyWhenClusterCompositionProviderReturnsFailure() + { + ClusterComposition validComposition = new ClusterComposition( 42, + asOrderedSet( A ), asOrderedSet( B ), asOrderedSet( C ) ); + ProtocolException protocolError = new ProtocolException( "Wrong record!" ); + + Map responsesByAddress = new HashMap<>(); + responsesByAddress.put( B, new Failure( protocolError ) ); // first -> fatal failure + responsesByAddress.put( C, new Success( validComposition ) ); // second -> valid cluster composition + + ClusterCompositionProvider compositionProvider = compositionProviderMock( responsesByAddress ); + Rediscovery rediscovery = newRediscovery( A, compositionProvider, mock( HostNameResolver.class ) ); + RoutingTable table = routingTableMock( B, C ); + + try + { + getBlocking( rediscovery.lookupClusterCompositionAsync( table, pool ) ); + fail( "Exception expected" ); + } + catch ( ProtocolException e ) + { + assertEquals( protocolError, e ); + } + } + + @Test + public void shouldResolveInitialRouterAddress() + { + BoltServerAddress initialRouter = A; + ClusterComposition expectedComposition = new ClusterComposition( 42, + asOrderedSet( A, B ), asOrderedSet( A, B ), asOrderedSet( A, B ) ); + + Map responsesByAddress = new HashMap<>(); + responsesByAddress.put( B, new ServiceUnavailableException( "Hi!" ) ); // first -> non-fatal failure + responsesByAddress.put( C, new ServiceUnavailableException( "Hi!" ) ); // second -> non-fatal failure + responsesByAddress.put( D, new IOException( "Hi!" ) ); // resolved first -> non-fatal failure + responsesByAddress.put( E, new Success( expectedComposition ) ); // resolved second -> valid response + + ClusterCompositionProvider compositionProvider = compositionProviderMock( responsesByAddress ); + // initial router resolved to two other addresses + HostNameResolver resolver = hostNameResolverMock( initialRouter, D, E ); + Rediscovery rediscovery = newRediscovery( initialRouter, compositionProvider, resolver ); + RoutingTable table = routingTableMock( B, C ); + + ClusterComposition actualComposition = getBlocking( rediscovery.lookupClusterCompositionAsync( table, pool ) ); + + assertEquals( expectedComposition, actualComposition ); + verify( table ).forget( B ); + verify( table ).forget( C ); + verify( table ).forget( D ); + } + + @Test + public void shouldFailWhenNoRoutersRespond() + { + Map responsesByAddress = new HashMap<>(); + responsesByAddress.put( A, new ServiceUnavailableException( "Hi!" ) ); // first -> non-fatal failure + responsesByAddress.put( B, new SessionExpiredException( "Hi!" ) ); // second -> non-fatal failure + responsesByAddress.put( C, new IOException( "Hi!" ) ); // third -> non-fatal failure + + ClusterCompositionProvider compositionProvider = compositionProviderMock( responsesByAddress ); + Rediscovery rediscovery = newRediscovery( A, compositionProvider, mock( HostNameResolver.class ) ); + RoutingTable table = routingTableMock( A, B, C ); + + try + { + getBlocking( rediscovery.lookupClusterCompositionAsync( table, pool ) ); + fail( "Exception expected" ); + } + catch ( ServiceUnavailableException e ) + { + assertEquals( "Could not perform discovery. No routing servers available.", e.getMessage() ); + } + } + + @Test + public void shouldUseInitialRouterAfterDiscoveryReturnsNoWriters() + { + BoltServerAddress initialRouter = A; + ClusterComposition noWritersComposition = new ClusterComposition( 42, + asOrderedSet( D, E ), emptySet(), asOrderedSet( D, E ) ); + ClusterComposition validComposition = new ClusterComposition( 42, + asOrderedSet( B, A ), asOrderedSet( B, A ), asOrderedSet( B, A ) ); + + Map responsesByAddress = new HashMap<>(); + responsesByAddress.put( B, new Success( noWritersComposition ) ); // first -> valid cluster composition + responsesByAddress.put( initialRouter, new Success( validComposition ) ); // initial -> valid composition + + ClusterCompositionProvider compositionProvider = compositionProviderMock( responsesByAddress ); + HostNameResolver resolver = hostNameResolverMock( initialRouter, initialRouter ); + Rediscovery rediscovery = newRediscovery( initialRouter, compositionProvider, resolver ); + RoutingTable table = routingTableMock( B ); + + ClusterComposition composition1 = getBlocking( rediscovery.lookupClusterCompositionAsync( table, pool ) ); + assertEquals( noWritersComposition, composition1 ); + + ClusterComposition composition2 = getBlocking( rediscovery.lookupClusterCompositionAsync( table, pool ) ); + assertEquals( validComposition, composition2 ); + } + + @Test + public void shouldUseInitialRouterToStartWith() + { + BoltServerAddress initialRouter = A; + ClusterComposition validComposition = new ClusterComposition( 42, + asOrderedSet( A ), asOrderedSet( A ), asOrderedSet( A ) ); + + Map responsesByAddress = new HashMap<>(); + responsesByAddress.put( initialRouter, new Success( validComposition ) ); // initial -> valid composition + + ClusterCompositionProvider compositionProvider = compositionProviderMock( responsesByAddress ); + HostNameResolver resolver = hostNameResolverMock( initialRouter, initialRouter ); + Rediscovery rediscovery = newRediscovery( initialRouter, compositionProvider, resolver, true ); + RoutingTable table = routingTableMock( B, C, D ); + + ClusterComposition composition = getBlocking( rediscovery.lookupClusterCompositionAsync( table, pool ) ); + assertEquals( validComposition, composition ); + } + + @Test + public void shouldUseKnownRoutersWhenInitialRouterFails() + { + BoltServerAddress initialRouter = A; + ClusterComposition validComposition = new ClusterComposition( 42, + asOrderedSet( D, E ), asOrderedSet( E, D ), asOrderedSet( A, B ) ); + + Map responsesByAddress = new HashMap<>(); + responsesByAddress.put( initialRouter, new ServiceUnavailableException( "Hi" ) ); // initial -> non-fatal error + responsesByAddress.put( D, new IOException( "Hi" ) ); // first known -> non-fatal failure + responsesByAddress.put( E, new Success( validComposition ) ); // second known -> valid composition + + ClusterCompositionProvider compositionProvider = compositionProviderMock( responsesByAddress ); + HostNameResolver resolver = hostNameResolverMock( initialRouter, initialRouter ); + Rediscovery rediscovery = newRediscovery( initialRouter, compositionProvider, resolver, true ); + RoutingTable table = routingTableMock( D, E ); + + ClusterComposition composition = getBlocking( rediscovery.lookupClusterCompositionAsync( table, pool ) ); + assertEquals( validComposition, composition ); + verify( table ).forget( initialRouter ); + verify( table ).forget( D ); + } + + @Test + public void shouldRetryConfiguredNumberOfTimesWithDelay() + { + int maxRoutingFailures = 3; + long retryTimeoutDelay = 15; + ClusterComposition expectedComposition = new ClusterComposition( 42, + asOrderedSet( A, C ), asOrderedSet( B, D ), asOrderedSet( A, E ) ); + + Map responsesByAddress = new HashMap<>(); + responsesByAddress.put( A, new ServiceUnavailableException( "Hi!" ) ); + responsesByAddress.put( B, new ServiceUnavailableException( "Hi!" ) ); + responsesByAddress.put( E, new Success( expectedComposition ) ); + + ClusterCompositionProvider compositionProvider = compositionProviderMock( responsesByAddress ); + HostNameResolver resolver = mock( HostNameResolver.class ); + when( resolver.resolve( A ) ).thenReturn( asOrderedSet( A ) ) + .thenReturn( asOrderedSet( A ) ) + .thenReturn( asOrderedSet( E ) ); + + TrackingEventExecutor eventExecutor = new TrackingEventExecutor(); + RoutingSettings settings = new RoutingSettings( maxRoutingFailures, retryTimeoutDelay ); + Rediscovery rediscovery = new Rediscovery( A, settings, compositionProvider, resolver, eventExecutor, + new FakeClock(), DEV_NULL_LOGGER, false ); + RoutingTable table = routingTableMock( A, B ); + + ClusterComposition actualComposition = getBlocking( rediscovery.lookupClusterCompositionAsync( table, pool ) ); + + assertEquals( expectedComposition, actualComposition ); + verify( table, times( maxRoutingFailures ) ).forget( A ); + verify( table, times( maxRoutingFailures ) ).forget( B ); + assertEquals( asList( retryTimeoutDelay, retryTimeoutDelay * 2 ), eventExecutor.scheduleDelays() ); + } + + private Rediscovery newRediscovery( BoltServerAddress initialRouter, ClusterCompositionProvider compositionProvider, + HostNameResolver hostNameResolver ) + { + return newRediscovery( initialRouter, compositionProvider, hostNameResolver, false ); + } + + private Rediscovery newRediscovery( BoltServerAddress initialRouter, ClusterCompositionProvider compositionProvider, + HostNameResolver hostNameResolver, boolean useInitialRouter ) + { + RoutingSettings settings = new RoutingSettings( 1, 0 ); + return new Rediscovery( initialRouter, settings, compositionProvider, hostNameResolver, + GlobalEventExecutor.INSTANCE, new FakeClock(), DEV_NULL_LOGGER, useInitialRouter ); + } + + @SuppressWarnings( "unchecked" ) + private static ClusterCompositionProvider compositionProviderMock( + Map responsesByAddress ) + { + ClusterCompositionProvider provider = mock( ClusterCompositionProvider.class ); + when( provider.getClusterComposition( any( CompletionStage.class ) ) ).then( invocation -> + { + CompletionStage connectionStage = invocation.getArgumentAt( 0, CompletionStage.class ); + BoltServerAddress address = getBlocking( connectionStage ).serverAddress(); + Object response = responsesByAddress.get( address ); + assertNotNull( response ); + if ( response instanceof Throwable ) + { + return failedFuture( (Throwable) response ); + } + else + { + return completedFuture( response ); + } + } ); + return provider; + } + + private static HostNameResolver hostNameResolverMock( BoltServerAddress address, BoltServerAddress... resolved ) + { + HostNameResolver resolver = mock( HostNameResolver.class ); + when( resolver.resolve( address ) ).thenReturn( asOrderedSet( resolved ) ); + return resolver; + } + + private static AsyncConnectionPool asyncConnectionPoolMock() + { + AsyncConnectionPool pool = mock( AsyncConnectionPool.class ); + when( pool.acquire( any() ) ).then( invocation -> + { + BoltServerAddress address = invocation.getArgumentAt( 0, BoltServerAddress.class ); + return completedFuture( asyncConnectionMock( address ) ); + } ); + return pool; + } + + private static AsyncConnection asyncConnectionMock( BoltServerAddress address ) + { + AsyncConnection connection = mock( AsyncConnection.class ); + when( connection.serverAddress() ).thenReturn( address ); + return connection; + } + + private static RoutingTable routingTableMock( BoltServerAddress... routers ) + { + RoutingTable routingTable = mock( RoutingTable.class ); + AddressSet addressSet = new AddressSet(); + addressSet.update( asOrderedSet( routers ), new HashSet<>() ); + when( routingTable.routers() ).thenReturn( addressSet ); + return routingTable; + } +} diff --git a/driver/src/test/java/org/neo4j/driver/internal/cluster/RediscoveryTest.java b/driver/src/test/java/org/neo4j/driver/internal/cluster/RediscoveryTest.java index 33966304db..e978bae5d6 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/cluster/RediscoveryTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/cluster/RediscoveryTest.java @@ -18,6 +18,7 @@ */ package org.neo4j.driver.internal.cluster; +import io.netty.util.concurrent.GlobalEventExecutor; import org.junit.Test; import org.junit.experimental.runners.Enclosed; import org.junit.runner.RunWith; @@ -106,7 +107,8 @@ public void shouldTryConfiguredMaxRoutingFailures() throws Exception ClusterCompositionProvider mockedProvider = mock( ClusterCompositionProvider.class ); when( mockedProvider.getClusterComposition( any( Connection.class ) ) ).thenThrow( new RuntimeException() ); - Rediscovery rediscovery = new Rediscovery( A, settings, clock, DEV_NULL_LOGGER, mockedProvider, directMapProvider ); + Rediscovery rediscovery = new Rediscovery( A, settings, mockedProvider, GlobalEventExecutor.INSTANCE, + directMapProvider, clock, DEV_NULL_LOGGER ); // when try @@ -245,8 +247,9 @@ public void shouldUseInitialRouterWhenRediscoveringAfterNoWriters() throws Throw when( mockedProvider.getClusterComposition( initialRouterConn ) ) .thenReturn( success( VALID_CLUSTER_COMPOSITION ) ); - Rediscovery rediscovery = new Rediscovery( F, new RoutingSettings( 1, 0 ), new FakeClock(), - DEV_NULL_LOGGER, mockedProvider, directMapProvider ); + Rediscovery rediscovery = new Rediscovery( F, new RoutingSettings( 1, 0 ), mockedProvider, + GlobalEventExecutor.INSTANCE, directMapProvider, new FakeClock(), + DEV_NULL_LOGGER ); // first rediscovery should accept table with no writers ClusterComposition composition1 = rediscovery.lookupClusterComposition( routingTable, mockedConnections ); @@ -481,8 +484,9 @@ public void shouldProbeAllKnownRoutersInOrder() Clock mockedClock = mock( Clock.class ); Logger mockedLogger = mock( Logger.class ); - Rediscovery rediscovery = new Rediscovery( A, settings, mockedClock, mockedLogger, clusterComposition, - directMapProvider ); + Rediscovery rediscovery = new Rediscovery( A, settings, clusterComposition, GlobalEventExecutor.INSTANCE, + directMapProvider, mockedClock, mockedLogger + ); ClusterComposition composition1 = rediscovery.lookupClusterComposition( routingTable, connections ); assertEquals( VALID_CLUSTER_COMPOSITION, composition1 ); @@ -512,8 +516,9 @@ private static ClusterComposition rediscover( BoltServerAddress initialRouter, C Clock mockedClock = mock( Clock.class ); Logger mockedLogger = mock( Logger.class ); - Rediscovery rediscovery = new Rediscovery( initialRouter, settings, mockedClock, mockedLogger, provider, - directMapProvider ); + Rediscovery rediscovery = new Rediscovery( initialRouter, settings, provider, GlobalEventExecutor.INSTANCE, + directMapProvider, mockedClock, mockedLogger + ); return rediscovery.lookupClusterComposition( routingTable, connections ); } diff --git a/driver/src/test/java/org/neo4j/driver/internal/cluster/RoutingPooledConnectionErrorHandlingTest.java b/driver/src/test/java/org/neo4j/driver/internal/cluster/RoutingPooledConnectionErrorHandlingTest.java index b8ad0242ea..9a4fe89741 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/cluster/RoutingPooledConnectionErrorHandlingTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/cluster/RoutingPooledConnectionErrorHandlingTest.java @@ -18,6 +18,7 @@ */ package org.neo4j.driver.internal.cluster; +import io.netty.util.concurrent.GlobalEventExecutor; import org.junit.Test; import org.junit.runner.RunWith; import org.junit.runners.Parameterized; @@ -30,6 +31,7 @@ import java.util.HashSet; import java.util.List; +import org.neo4j.driver.internal.async.pool.AsyncConnectionPool; import org.neo4j.driver.internal.cluster.loadbalancing.LoadBalancer; import org.neo4j.driver.internal.handlers.NoOpResponseHandler; import org.neo4j.driver.internal.net.BoltServerAddress; @@ -379,7 +381,9 @@ private static LoadBalancer newLoadBalancer( ClusterComposition clusterCompositi { Rediscovery rediscovery = mock( Rediscovery.class ); when( rediscovery.lookupClusterComposition( routingTable, connectionPool ) ).thenReturn( clusterComposition ); - return new LoadBalancer( connectionPool, routingTable, rediscovery, DEV_NULL_LOGGING ); + AsyncConnectionPool asyncConnectionPool = mock( AsyncConnectionPool.class ); + return new LoadBalancer( connectionPool, asyncConnectionPool, routingTable, rediscovery, + GlobalEventExecutor.INSTANCE, DEV_NULL_LOGGING ); } private interface ConnectionMethod diff --git a/driver/src/test/java/org/neo4j/driver/internal/cluster/ClusterCompositionProviderTest.java b/driver/src/test/java/org/neo4j/driver/internal/cluster/RoutingProcedureClusterCompositionProviderTest.java similarity index 63% rename from driver/src/test/java/org/neo4j/driver/internal/cluster/ClusterCompositionProviderTest.java rename to driver/src/test/java/org/neo4j/driver/internal/cluster/RoutingProcedureClusterCompositionProviderTest.java index b1fece48c7..decb07357b 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/cluster/ClusterCompositionProviderTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/cluster/RoutingProcedureClusterCompositionProviderTest.java @@ -17,17 +17,18 @@ * limitations under the License. */ package org.neo4j.driver.internal.cluster; + import org.junit.Test; -import java.util.ArrayList; import java.util.HashMap; import java.util.HashSet; import java.util.Map; import java.util.Set; +import java.util.concurrent.CompletionStage; import org.neo4j.driver.internal.InternalRecord; +import org.neo4j.driver.internal.async.AsyncConnection; import org.neo4j.driver.internal.net.BoltServerAddress; -import org.neo4j.driver.internal.spi.PooledConnection; import org.neo4j.driver.internal.util.Clock; import org.neo4j.driver.internal.value.StringValue; import org.neo4j.driver.v1.Record; @@ -37,33 +38,36 @@ import org.neo4j.driver.v1.exceptions.ServiceUnavailableException; import static java.util.Arrays.asList; +import static java.util.concurrent.CompletableFuture.completedFuture; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.instanceOf; import static org.junit.Assert.assertEquals; import static org.junit.Assert.fail; -import static org.mockito.Mockito.doThrow; +import static org.mockito.Matchers.any; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import static org.neo4j.driver.internal.async.Futures.failedFuture; +import static org.neo4j.driver.internal.async.Futures.getBlocking; import static org.neo4j.driver.internal.logging.DevNullLogger.DEV_NULL_LOGGER; import static org.neo4j.driver.v1.Values.value; -public class ClusterCompositionProviderTest +public class RoutingProcedureClusterCompositionProviderTest { @Test - public void shouldProtocolErrorWhenNoRecord() throws Throwable + public void shouldProtocolErrorWhenNoRecord() { // Given RoutingProcedureRunner mockedRunner = newProcedureRunnerMock(); ClusterCompositionProvider provider = new RoutingProcedureClusterCompositionProvider( mock( Clock.class ), DEV_NULL_LOGGER, mockedRunner ); - PooledConnection mockedConn = mock( PooledConnection.class ); - ArrayList emptyRecord = new ArrayList<>(); - when( mockedRunner.run( mockedConn ) ).thenReturn( emptyRecord ); + CompletionStage connectionStage = completedFuture( mock( AsyncConnection.class ) ); + RoutingProcedureResponse noRecordsResponse = newRoutingResponse(); + when( mockedRunner.run( connectionStage ) ).thenReturn( completedFuture( noRecordsResponse ) ); // When - ClusterCompositionResponse response = provider.getClusterComposition( mockedConn ); + ClusterCompositionResponse response = getBlocking( provider.getClusterComposition( connectionStage ) ); // Then assertThat( response, instanceOf( ClusterCompositionResponse.Failure.class ) ); @@ -80,19 +84,20 @@ public void shouldProtocolErrorWhenNoRecord() throws Throwable } @Test - public void shouldProtocolErrorWhenMoreThanOneRecord() throws Throwable + public void shouldProtocolErrorWhenMoreThanOneRecord() { // Given RoutingProcedureRunner mockedRunner = newProcedureRunnerMock(); ClusterCompositionProvider provider = new RoutingProcedureClusterCompositionProvider( mock( Clock.class ), DEV_NULL_LOGGER, mockedRunner ); - PooledConnection mockedConn = mock( PooledConnection.class ); + CompletionStage connectionStage = completedFuture( mock( AsyncConnection.class ) ); Record aRecord = new InternalRecord( asList( "key1", "key2" ), new Value[]{ new StringValue( "a value" ) } ); - when( mockedRunner.run( mockedConn ) ).thenReturn( asList( aRecord, aRecord ) ); + RoutingProcedureResponse routingResponse = newRoutingResponse( aRecord, aRecord ); + when( mockedRunner.run( connectionStage ) ).thenReturn( completedFuture( routingResponse ) ); // When - ClusterCompositionResponse response = provider.getClusterComposition( mockedConn ); + ClusterCompositionResponse response = getBlocking( provider.getClusterComposition( connectionStage ) ); // Then assertThat( response, instanceOf( ClusterCompositionResponse.Failure.class ) ); @@ -109,19 +114,20 @@ public void shouldProtocolErrorWhenMoreThanOneRecord() throws Throwable } @Test - public void shouldProtocolErrorWhenUnparsableRecord() throws Throwable + public void shouldProtocolErrorWhenUnparsableRecord() { // Given RoutingProcedureRunner mockedRunner = newProcedureRunnerMock(); ClusterCompositionProvider provider = new RoutingProcedureClusterCompositionProvider( mock( Clock.class ), DEV_NULL_LOGGER, mockedRunner ); - PooledConnection mockedConn = mock( PooledConnection.class ); + CompletionStage connectionStage = completedFuture( mock( AsyncConnection.class ) ); Record aRecord = new InternalRecord( asList( "key1", "key2" ), new Value[]{ new StringValue( "a value" ) } ); - when( mockedRunner.run( mockedConn ) ).thenReturn( asList( aRecord ) ); + RoutingProcedureResponse routingResponse = newRoutingResponse( aRecord ); + when( mockedRunner.run( connectionStage ) ).thenReturn( completedFuture( routingResponse ) ); // When - ClusterCompositionResponse response = provider.getClusterComposition( mockedConn ); + ClusterCompositionResponse response = getBlocking( provider.getClusterComposition( connectionStage ) ); // Then assertThat( response, instanceOf( ClusterCompositionResponse.Failure.class ) ); @@ -138,7 +144,7 @@ public void shouldProtocolErrorWhenUnparsableRecord() throws Throwable } @Test - public void shouldProtocolErrorWhenNoRouters() throws Throwable + public void shouldProtocolErrorWhenNoRouters() { // Given RoutingProcedureRunner mockedRunner = newProcedureRunnerMock(); @@ -146,17 +152,18 @@ public void shouldProtocolErrorWhenNoRouters() throws Throwable ClusterCompositionProvider provider = new RoutingProcedureClusterCompositionProvider( mockedClock, DEV_NULL_LOGGER, mockedRunner ); - PooledConnection mockedConn = mock( PooledConnection.class ); + CompletionStage connectionStage = completedFuture( mock( AsyncConnection.class ) ); Record record = new InternalRecord( asList( "ttl", "servers" ), new Value[]{ value( 100 ), value( asList( serverInfo( "READ", "one:1337", "two:1337" ), serverInfo( "WRITE", "one:1337" ) ) ) } ); - when( mockedRunner.run( mockedConn ) ).thenReturn( asList( record ) ); + RoutingProcedureResponse routingResponse = newRoutingResponse( record ); + when( mockedRunner.run( connectionStage ) ).thenReturn( completedFuture( routingResponse ) ); when( mockedClock.millis() ).thenReturn( 12345L ); // When - ClusterCompositionResponse response = provider.getClusterComposition( mockedConn ); + ClusterCompositionResponse response = getBlocking( provider.getClusterComposition( connectionStage ) ); // Then assertThat( response, instanceOf( ClusterCompositionResponse.Failure.class ) ); @@ -173,7 +180,7 @@ public void shouldProtocolErrorWhenNoRouters() throws Throwable } @Test - public void shouldProtocolErrorWhenNoReaders() throws Throwable + public void shouldProtocolErrorWhenNoReaders() { // Given RoutingProcedureRunner mockedRunner = newProcedureRunnerMock(); @@ -181,17 +188,18 @@ public void shouldProtocolErrorWhenNoReaders() throws Throwable ClusterCompositionProvider provider = new RoutingProcedureClusterCompositionProvider( mockedClock, DEV_NULL_LOGGER, mockedRunner ); - PooledConnection mockedConn = mock( PooledConnection.class ); + CompletionStage connectionStage = completedFuture( mock( AsyncConnection.class ) ); Record record = new InternalRecord( asList( "ttl", "servers" ), new Value[]{ value( 100 ), value( asList( serverInfo( "WRITE", "one:1337" ), serverInfo( "ROUTE", "one:1337", "two:1337" ) ) ) } ); - when( mockedRunner.run( mockedConn ) ).thenReturn( asList( record ) ); + RoutingProcedureResponse routingResponse = newRoutingResponse( record ); + when( mockedRunner.run( connectionStage ) ).thenReturn( completedFuture( routingResponse ) ); when( mockedClock.millis() ).thenReturn( 12345L ); // When - ClusterCompositionResponse response = provider.getClusterComposition( mockedConn ); + ClusterCompositionResponse response = getBlocking( provider.getClusterComposition( connectionStage ) ); // Then assertThat( response, instanceOf( ClusterCompositionResponse.Failure.class ) ); @@ -209,26 +217,21 @@ public void shouldProtocolErrorWhenNoReaders() throws Throwable @Test - public void shouldPropagateConnectionFailureExceptions() throws Exception + public void shouldPropagateConnectionFailureExceptions() { // Given RoutingProcedureRunner mockedRunner = newProcedureRunnerMock(); ClusterCompositionProvider provider = new RoutingProcedureClusterCompositionProvider( mock( Clock.class ), DEV_NULL_LOGGER, mockedRunner ); - PooledConnection mockedConn = mock( PooledConnection.class ); - Record record = new InternalRecord( asList( "ttl", "servers" ), new Value[]{ - value( 100 ), value( asList( - serverInfo( "WRITE", "one:1337" ), - serverInfo( "ROUTE", "one:1337", "two:1337" ) ) ) - } ); - doThrow( new ServiceUnavailableException( "Connection breaks during cypher execution" ) ) - .when( mockedRunner ).run( mockedConn ); + CompletionStage connectionStage = completedFuture( mock( AsyncConnection.class ) ); + when( mockedRunner.run( connectionStage ) ).thenReturn( failedFuture( + new ServiceUnavailableException( "Connection breaks during cypher execution" ) ) ); // When & Then try { - provider.getClusterComposition( mockedConn ); + getBlocking( provider.getClusterComposition( connectionStage ) ); fail( "Expecting a failure but not triggered." ); } catch( Exception e ) @@ -239,7 +242,7 @@ public void shouldPropagateConnectionFailureExceptions() throws Exception } @Test - public void shouldReturnSuccessResultWhenNoError() throws Throwable + public void shouldReturnSuccessResultWhenNoError() { // Given Clock mockedClock = mock( Clock.class ); @@ -247,18 +250,19 @@ public void shouldReturnSuccessResultWhenNoError() throws Throwable ClusterCompositionProvider provider = new RoutingProcedureClusterCompositionProvider( mockedClock, DEV_NULL_LOGGER, mockedRunner ); - PooledConnection mockedConn = mock( PooledConnection.class ); + CompletionStage connectionStage = completedFuture( mock( AsyncConnection.class ) ); Record record = new InternalRecord( asList( "ttl", "servers" ), new Value[]{ value( 100 ), value( asList( serverInfo( "READ", "one:1337", "two:1337" ), serverInfo( "WRITE", "one:1337" ), serverInfo( "ROUTE", "one:1337", "two:1337" ) ) ) } ); - when( mockedRunner.run( mockedConn ) ).thenReturn( asList( record ) ); + RoutingProcedureResponse routingResponse = newRoutingResponse( record ); + when( mockedRunner.run( connectionStage ) ).thenReturn( completedFuture( routingResponse ) ); when( mockedClock.millis() ).thenReturn( 12345L ); // When - ClusterCompositionResponse response = provider.getClusterComposition( mockedConn ); + ClusterCompositionResponse response = getBlocking( provider.getClusterComposition( connectionStage ) ); // Then assertThat( response, instanceOf( ClusterCompositionResponse.Success.class ) ); @@ -269,6 +273,32 @@ public void shouldReturnSuccessResultWhenNoError() throws Throwable assertEquals( serverSet( "one:1337", "two:1337" ), cluster.routers() ); } + @Test + @SuppressWarnings( "unchecked" ) + public void shouldReturnFailureWhenProcedureRunnerFails() + { + RoutingProcedureRunner procedureRunner = newProcedureRunnerMock(); + RuntimeException error = new RuntimeException( "hi" ); + when( procedureRunner.run( any( CompletionStage.class ) ) ) + .thenReturn( completedFuture( newRoutingResponse( error ) ) ); + + RoutingProcedureClusterCompositionProvider provider = new RoutingProcedureClusterCompositionProvider( + mock( Clock.class ), DEV_NULL_LOGGER, procedureRunner ); + + CompletionStage connectionStage = completedFuture( mock( AsyncConnection.class ) ); + ClusterCompositionResponse response = getBlocking( provider.getClusterComposition( connectionStage ) ); + + try + { + response.clusterComposition(); + fail( "Exception expected" ); + } + catch ( ServiceUnavailableException e ) + { + assertEquals( error, e.getCause() ); + } + } + public static Map serverInfo( String role, String... addresses ) { Map map = new HashMap<>(); @@ -289,8 +319,16 @@ private static Set serverSet( String... addresses ) private static RoutingProcedureRunner newProcedureRunnerMock() { - RoutingProcedureRunner mock = mock( RoutingProcedureRunner.class ); - when( mock.invokedProcedure() ).thenReturn( new Statement( "procedure" ) ); - return mock; + return mock( RoutingProcedureRunner.class ); + } + + private static RoutingProcedureResponse newRoutingResponse( Record... records ) + { + return new RoutingProcedureResponse( new Statement( "procedure" ), asList( records ) ); + } + + private static RoutingProcedureResponse newRoutingResponse( Throwable error ) + { + return new RoutingProcedureResponse( new Statement( "procedure" ), error ); } } diff --git a/driver/src/test/java/org/neo4j/driver/internal/cluster/RoutingProcedureResponseTest.java b/driver/src/test/java/org/neo4j/driver/internal/cluster/RoutingProcedureResponseTest.java new file mode 100644 index 0000000000..fc21fc1252 --- /dev/null +++ b/driver/src/test/java/org/neo4j/driver/internal/cluster/RoutingProcedureResponseTest.java @@ -0,0 +1,115 @@ +/* + * Copyright (c) 2002-2017 "Neo Technology," + * Network Engine for Objects in Lund AB [http://neotechnology.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.internal.cluster; + +import org.junit.Test; + +import org.neo4j.driver.internal.InternalRecord; +import org.neo4j.driver.internal.value.StringValue; +import org.neo4j.driver.v1.Record; +import org.neo4j.driver.v1.Statement; +import org.neo4j.driver.v1.Value; + +import static java.util.Arrays.asList; +import static org.hamcrest.Matchers.instanceOf; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +public class RoutingProcedureResponseTest +{ + private static final Statement PROCEDURE = new Statement( "procedure" ); + + private static final Record RECORD_1 = new InternalRecord( asList( "a", "b" ), + new Value[]{new StringValue( "a" ), new StringValue( "b" )} ); + private static final Record RECORD_2 = new InternalRecord( asList( "a", "b" ), + new Value[]{new StringValue( "aa" ), new StringValue( "bb" )} ); + + @Test + public void shouldBeSuccessfulWithRecords() + { + RoutingProcedureResponse response = new RoutingProcedureResponse( PROCEDURE, asList( RECORD_1, RECORD_2 ) ); + assertTrue( response.isSuccess() ); + } + + @Test + public void shouldNotBeSuccessfulWithError() + { + RoutingProcedureResponse response = new RoutingProcedureResponse( PROCEDURE, new RuntimeException() ); + assertFalse( response.isSuccess() ); + } + + @Test + public void shouldThrowWhenFailedAndAskedForRecords() + { + RuntimeException error = new RuntimeException(); + RoutingProcedureResponse response = new RoutingProcedureResponse( PROCEDURE, error ); + + try + { + response.records(); + fail( "Exception expected" ); + } + catch ( Exception e ) + { + assertThat( e, instanceOf( IllegalStateException.class ) ); + assertEquals( e.getCause(), error ); + } + } + + @Test + public void shouldThrowWhenSuccessfulAndAskedForError() + { + RoutingProcedureResponse response = new RoutingProcedureResponse( PROCEDURE, asList( RECORD_1, RECORD_2 ) ); + + try + { + response.error(); + fail( "Exception expected" ); + } + catch ( Exception e ) + { + assertThat( e, instanceOf( IllegalStateException.class ) ); + } + } + + @Test + public void shouldHaveErrorWhenFailed() + { + RuntimeException error = new RuntimeException( "Hi!" ); + RoutingProcedureResponse response = new RoutingProcedureResponse( PROCEDURE, error ); + assertEquals( error, response.error() ); + } + + @Test + public void shouldHaveRecordsWhenSuccessful() + { + RoutingProcedureResponse response = new RoutingProcedureResponse( PROCEDURE, asList( RECORD_1, RECORD_2 ) ); + assertEquals( asList( RECORD_1, RECORD_2 ), response.records() ); + } + + @Test + public void shouldHaveProcedure() + { + RoutingProcedureResponse response = new RoutingProcedureResponse( PROCEDURE, asList( RECORD_1, RECORD_2 ) ); + assertEquals( PROCEDURE, response.procedure() ); + } +} diff --git a/driver/src/test/java/org/neo4j/driver/internal/cluster/RoutingProcedureRunnerTest.java b/driver/src/test/java/org/neo4j/driver/internal/cluster/RoutingProcedureRunnerTest.java index c30735d4f9..3c81591163 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/cluster/RoutingProcedureRunnerTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/cluster/RoutingProcedureRunnerTest.java @@ -22,22 +22,35 @@ import java.net.URI; import java.util.List; +import java.util.concurrent.CompletionStage; +import org.neo4j.driver.internal.async.AsyncConnection; import org.neo4j.driver.internal.net.BoltServerAddress; import org.neo4j.driver.internal.spi.Connection; import org.neo4j.driver.internal.summary.InternalServerInfo; import org.neo4j.driver.v1.Record; import org.neo4j.driver.v1.Statement; import org.neo4j.driver.v1.Value; +import org.neo4j.driver.v1.exceptions.ClientException; +import static java.util.Arrays.asList; import static java.util.Collections.EMPTY_MAP; -import static org.hamcrest.MatcherAssert.assertThat; +import static java.util.Collections.singletonList; +import static java.util.concurrent.CompletableFuture.completedFuture; import static org.hamcrest.core.IsEqual.equalTo; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import static org.neo4j.driver.internal.async.Futures.failedFuture; +import static org.neo4j.driver.internal.async.Futures.getBlocking; import static org.neo4j.driver.internal.cluster.RoutingProcedureRunner.GET_ROUTING_TABLE; import static org.neo4j.driver.internal.cluster.RoutingProcedureRunner.GET_ROUTING_TABLE_PARAM; import static org.neo4j.driver.internal.cluster.RoutingProcedureRunner.GET_SERVERS; +import static org.neo4j.driver.internal.util.ServerVersion.version; import static org.neo4j.driver.v1.Values.parameters; public class RoutingProcedureRunnerTest @@ -51,13 +64,27 @@ public void shouldCallGetRoutingTableWithEmptyMap() throws Throwable when( mock.server() ).thenReturn( new InternalServerInfo( new BoltServerAddress( "123:45" ), "Neo4j/3.2.1" ) ); // When - runner.run( mock ); + RoutingProcedureResponse response = runner.run( mock ); // Then - assertThat( runner.invokedProcedure(), equalTo( + assertThat( response.procedure(), equalTo( new Statement( "CALL " + GET_ROUTING_TABLE, parameters( GET_ROUTING_TABLE_PARAM, EMPTY_MAP ) ) ) ); } + @Test + public void shouldCallGetRoutingTableWithEmptyMapAsync() + { + RoutingProcedureRunner runner = new TestRoutingProcedureRunner( RoutingContext.EMPTY, + completedFuture( asList( mock( Record.class ), mock( Record.class ) ) ) ); + + RoutingProcedureResponse response = getBlocking( runner.run( connectionStage( "Neo4j/3.2.1" ) ) ); + + assertTrue( response.isSuccess() ); + assertEquals( 2, response.records().size() ); + assertEquals( new Statement( "CALL " + GET_ROUTING_TABLE, parameters( GET_ROUTING_TABLE_PARAM, EMPTY_MAP ) ), + response.procedure() ); + } + @Test public void shouldCallGetRoutingTableWithParam() throws Throwable { @@ -69,14 +96,31 @@ public void shouldCallGetRoutingTableWithParam() throws Throwable when( mock.server() ).thenReturn( new InternalServerInfo( new BoltServerAddress( "123:45" ), "Neo4j/3.2.1" ) ); // When - runner.run( mock ); + RoutingProcedureResponse response = runner.run( mock ); // Then Value expectedParams = parameters( GET_ROUTING_TABLE_PARAM, context.asMap() ); - assertThat( runner.invokedProcedure(), equalTo( + assertThat( response.procedure(), equalTo( new Statement( "CALL " + GET_ROUTING_TABLE, expectedParams ) ) ); } + @Test + public void shouldCallGetRoutingTableWithParamAsync() + { + URI uri = URI.create( "bolt+routing://localhost/?key1=value1&key2=value2" ); + RoutingContext context = new RoutingContext( uri ); + + RoutingProcedureRunner runner = new TestRoutingProcedureRunner( context, + completedFuture( singletonList( mock( Record.class ) ) ) ); + + RoutingProcedureResponse response = getBlocking( runner.run( connectionStage( "Neo4j/3.2.1" ) ) ); + + assertTrue( response.isSuccess() ); + assertEquals( 1, response.records().size() ); + Value expectedParams = parameters( GET_ROUTING_TABLE_PARAM, context.asMap() ); + assertEquals( new Statement( "CALL " + GET_ROUTING_TABLE, expectedParams ), response.procedure() ); + } + @Test public void shouldCallGetServers() throws Throwable { @@ -88,18 +132,96 @@ public void shouldCallGetServers() throws Throwable when( mock.server() ).thenReturn( new InternalServerInfo( new BoltServerAddress( "123:45" ), "Neo4j/3.1.8" ) ); // When - runner.run( mock ); + RoutingProcedureResponse response = runner.run( mock ); // Then - assertThat( runner.invokedProcedure(), equalTo( + assertThat( response.procedure(), equalTo( new Statement( "CALL " + GET_SERVERS ) ) ); } + @Test + public void shouldCallGetServersAsync() + { + URI uri = URI.create( "bolt+routing://localhost/?key1=value1&key2=value2" ); + RoutingContext context = new RoutingContext( uri ); + + RoutingProcedureRunner runner = new TestRoutingProcedureRunner( context, + completedFuture( asList( mock( Record.class ), mock( Record.class ) ) ) ); + + RoutingProcedureResponse response = getBlocking( runner.run( connectionStage( "Neo4j/3.1.8" ) ) ); + + assertTrue( response.isSuccess() ); + assertEquals( 2, response.records().size() ); + assertEquals( new Statement( "CALL " + GET_SERVERS ), response.procedure() ); + } + + @Test + public void shouldReturnFailedResponseOnClientException() + { + ClientException error = new ClientException( "Hi" ); + RoutingProcedureRunner runner = new TestRoutingProcedureRunner( RoutingContext.EMPTY, failedFuture( error ) ); + + RoutingProcedureResponse response = getBlocking( runner.run( connectionStage( "Neo4j/3.2.2" ) ) ); + + assertFalse( response.isSuccess() ); + assertEquals( error, response.error() ); + } + + @Test + public void shouldReturnFailedStageOnError() + { + Exception error = new Exception( "Hi" ); + RoutingProcedureRunner runner = new TestRoutingProcedureRunner( RoutingContext.EMPTY, failedFuture( error ) ); + + try + { + getBlocking( runner.run( connectionStage( "Neo4j/3.2.2" ) ) ); + fail( "Exception expected" ); + } + catch ( Exception e ) + { + assertEquals( error, e ); + } + } + + @Test + public void shouldPropagateErrorFromConnectionStage() + { + RuntimeException error = new RuntimeException( "Hi" ); + RoutingProcedureRunner runner = new TestRoutingProcedureRunner( RoutingContext.EMPTY ); + + try + { + getBlocking( runner.run( failedFuture( error ) ) ); + fail( "Exception expected" ); + } + catch ( RuntimeException e ) + { + assertEquals( error, e ); + } + } + + private static CompletionStage connectionStage( String serverVersion ) + { + AsyncConnection connection = mock( AsyncConnection.class ); + when( connection.serverAddress() ).thenReturn( new BoltServerAddress( "123:45" ) ); + when( connection.serverVersion() ).thenReturn( version( serverVersion ) ); + return completedFuture( connection ); + } + private static class TestRoutingProcedureRunner extends RoutingProcedureRunner { + final CompletionStage> runProcedureResult; + TestRoutingProcedureRunner( RoutingContext context ) + { + this( context, null ); + } + + TestRoutingProcedureRunner( RoutingContext context, CompletionStage> runProcedureResult ) { super( context ); + this.runProcedureResult = runProcedureResult; } @Override @@ -108,6 +230,12 @@ List runProcedure( Connection connection, Statement procedure ) // I do not want any network traffic return null; } + + @Override + CompletionStage> runProcedure( AsyncConnection connection, Statement procedure ) + { + return runProcedureResult; + } } } diff --git a/driver/src/test/java/org/neo4j/driver/internal/cluster/loadbalancing/LeastConnectedLoadBalancingStrategyTest.java b/driver/src/test/java/org/neo4j/driver/internal/cluster/loadbalancing/LeastConnectedLoadBalancingStrategyTest.java index 722efc0bbd..e1af9cf9ed 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/cluster/loadbalancing/LeastConnectedLoadBalancingStrategyTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/cluster/loadbalancing/LeastConnectedLoadBalancingStrategyTest.java @@ -22,6 +22,7 @@ import org.junit.Test; import org.mockito.Mock; +import org.neo4j.driver.internal.async.pool.AsyncConnectionPool; import org.neo4j.driver.internal.net.BoltServerAddress; import org.neo4j.driver.internal.spi.ConnectionPool; import org.neo4j.driver.v1.Logger; @@ -44,13 +45,15 @@ public class LeastConnectedLoadBalancingStrategyTest { @Mock private ConnectionPool connectionPool; + @Mock + private AsyncConnectionPool asyncConnectionPool; private LeastConnectedLoadBalancingStrategy strategy; @Before public void setUp() throws Exception { initMocks( this ); - strategy = new LeastConnectedLoadBalancingStrategy( connectionPool, DEV_NULL_LOGGING ); + strategy = new LeastConnectedLoadBalancingStrategy( connectionPool, asyncConnectionPool, DEV_NULL_LOGGING ); } @Test @@ -166,7 +169,8 @@ public void shouldTraceLogWhenNoAddressSelected() Logger logger = mock( Logger.class ); when( logging.getLog( anyString() ) ).thenReturn( logger ); - LoadBalancingStrategy strategy = new LeastConnectedLoadBalancingStrategy( connectionPool, logging ); + LoadBalancingStrategy strategy = new LeastConnectedLoadBalancingStrategy( connectionPool, asyncConnectionPool, + logging ); strategy.selectReader( new BoltServerAddress[0] ); strategy.selectWriter( new BoltServerAddress[0] ); @@ -184,7 +188,8 @@ public void shouldTraceLogSelectedAddress() when( connectionPool.activeConnections( any( BoltServerAddress.class ) ) ).thenReturn( 42 ); - LoadBalancingStrategy strategy = new LeastConnectedLoadBalancingStrategy( connectionPool, logging ); + LoadBalancingStrategy strategy = new LeastConnectedLoadBalancingStrategy( connectionPool, asyncConnectionPool, + logging ); strategy.selectReader( new BoltServerAddress[]{A} ); strategy.selectWriter( new BoltServerAddress[]{A} ); diff --git a/driver/src/test/java/org/neo4j/driver/internal/cluster/loadbalancing/LoadBalancerTest.java b/driver/src/test/java/org/neo4j/driver/internal/cluster/loadbalancing/LoadBalancerTest.java index 4f00ecbc6c..0378ac7c4a 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/cluster/loadbalancing/LoadBalancerTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/cluster/loadbalancing/LoadBalancerTest.java @@ -27,14 +27,19 @@ import java.util.Arrays; import java.util.Collections; import java.util.HashSet; +import java.util.LinkedHashSet; import java.util.Set; import java.util.concurrent.atomic.AtomicInteger; import org.neo4j.driver.internal.ExplicitTransaction; import org.neo4j.driver.internal.NetworkSession; import org.neo4j.driver.internal.SessionResourcesHandler; +import org.neo4j.driver.internal.async.AsyncConnection; +import org.neo4j.driver.internal.async.Futures; +import org.neo4j.driver.internal.async.pool.AsyncConnectionPool; import org.neo4j.driver.internal.cluster.AddressSet; import org.neo4j.driver.internal.cluster.ClusterComposition; +import org.neo4j.driver.internal.cluster.ClusterRoutingTable; import org.neo4j.driver.internal.cluster.Rediscovery; import org.neo4j.driver.internal.cluster.RoutingPooledConnection; import org.neo4j.driver.internal.cluster.RoutingTable; @@ -45,6 +50,7 @@ import org.neo4j.driver.internal.spi.Connection; import org.neo4j.driver.internal.spi.ConnectionPool; import org.neo4j.driver.internal.spi.PooledConnection; +import org.neo4j.driver.internal.util.FakeClock; import org.neo4j.driver.internal.util.SleeplessClock; import org.neo4j.driver.v1.AccessMode; import org.neo4j.driver.v1.Session; @@ -53,11 +59,14 @@ import org.neo4j.driver.v1.exceptions.ServiceUnavailableException; import org.neo4j.driver.v1.exceptions.SessionExpiredException; +import static java.util.Collections.emptySet; import static java.util.Collections.singleton; import static java.util.Collections.singletonList; +import static java.util.concurrent.CompletableFuture.completedFuture; import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.startsWith; import static org.hamcrest.core.IsEqual.equalTo; +import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertThat; @@ -68,10 +77,12 @@ import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.inOrder; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; import static org.mockito.Mockito.spy; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.neo4j.driver.internal.async.Futures.getBlocking; import static org.neo4j.driver.internal.cluster.ClusterCompositionUtil.A; import static org.neo4j.driver.internal.cluster.ClusterCompositionUtil.B; import static org.neo4j.driver.internal.cluster.ClusterCompositionUtil.C; @@ -79,6 +90,7 @@ import static org.neo4j.driver.internal.net.BoltServerAddress.LOCAL_DEFAULT; import static org.neo4j.driver.v1.AccessMode.READ; import static org.neo4j.driver.v1.AccessMode.WRITE; +import static org.neo4j.driver.v1.util.TestUtil.asOrderedSet; public class LoadBalancerTest { @@ -93,7 +105,8 @@ public void ensureRoutingShouldUpdateRoutingTableAndPurgeConnectionPoolWhenStale when( routingTable.update( any( ClusterComposition.class ) ) ).thenReturn( set ); // when - LoadBalancer balancer = new LoadBalancer( conns, routingTable, rediscovery, DEV_NULL_LOGGING ); + LoadBalancer balancer = new LoadBalancer( conns, null, routingTable, rediscovery, GlobalEventExecutor.INSTANCE, + DEV_NULL_LOGGING ); // then assertNotNull( balancer ); @@ -103,13 +116,76 @@ public void ensureRoutingShouldUpdateRoutingTableAndPurgeConnectionPoolWhenStale inOrder.verify( conns ).purge( new BoltServerAddress( "abc", 12 ) ); } + @Test + public void acquireShouldUpdateRoutingTableWhenKnownRoutingTableIsStale() + { + BoltServerAddress initialRouter = new BoltServerAddress( "initialRouter", 1 ); + BoltServerAddress reader1 = new BoltServerAddress( "reader-1", 2 ); + BoltServerAddress reader2 = new BoltServerAddress( "reader-1", 3 ); + BoltServerAddress writer1 = new BoltServerAddress( "writer-1", 4 ); + BoltServerAddress router1 = new BoltServerAddress( "router-1", 5 ); + + AsyncConnectionPool asyncConnectionPool = newAsyncConnectionPoolMock(); + ClusterRoutingTable routingTable = new ClusterRoutingTable( new FakeClock(), initialRouter ); + + Set readers = new LinkedHashSet<>( Arrays.asList( reader1, reader2 ) ); + Set writers = new LinkedHashSet<>( singletonList( writer1 ) ); + Set routers = new LinkedHashSet<>( singletonList( router1 ) ); + ClusterComposition clusterComposition = new ClusterComposition( 42, readers, writers, routers ); + Rediscovery rediscovery = mock( Rediscovery.class ); + when( rediscovery.lookupClusterCompositionAsync( routingTable, asyncConnectionPool ) ) + .thenReturn( completedFuture( clusterComposition ) ); + + LoadBalancer loadBalancer = + new LoadBalancer( null, asyncConnectionPool, routingTable, rediscovery, GlobalEventExecutor.INSTANCE, + DEV_NULL_LOGGING ); + + assertNotNull( getBlocking( loadBalancer.acquireAsyncConnection( READ ) ) ); + + verify( rediscovery ).lookupClusterCompositionAsync( routingTable, asyncConnectionPool ); + assertArrayEquals( new BoltServerAddress[]{reader1, reader2}, routingTable.readers().toArray() ); + assertArrayEquals( new BoltServerAddress[]{writer1}, routingTable.writers().toArray() ); + assertArrayEquals( new BoltServerAddress[]{router1}, routingTable.routers().toArray() ); + } + + @Test + public void acquireShouldPurgeConnectionsWhenKnownRoutingTableIsStale() + { + BoltServerAddress initialRouter1 = new BoltServerAddress( "initialRouter-1", 1 ); + BoltServerAddress initialRouter2 = new BoltServerAddress( "initialRouter-2", 1 ); + BoltServerAddress reader = new BoltServerAddress( "reader", 2 ); + BoltServerAddress writer = new BoltServerAddress( "writer", 3 ); + BoltServerAddress router = new BoltServerAddress( "router", 4 ); + + AsyncConnectionPool asyncConnectionPool = newAsyncConnectionPoolMock(); + ClusterRoutingTable routingTable = new ClusterRoutingTable( new FakeClock(), initialRouter1, initialRouter2 ); + + Set readers = new HashSet<>( singletonList( reader ) ); + Set writers = new HashSet<>( singletonList( writer ) ); + Set routers = new HashSet<>( singletonList( router ) ); + ClusterComposition clusterComposition = new ClusterComposition( 42, readers, writers, routers ); + Rediscovery rediscovery = mock( Rediscovery.class ); + when( rediscovery.lookupClusterCompositionAsync( routingTable, asyncConnectionPool ) ) + .thenReturn( completedFuture( clusterComposition ) ); + + LoadBalancer loadBalancer = + new LoadBalancer( null, asyncConnectionPool, routingTable, rediscovery, GlobalEventExecutor.INSTANCE, + DEV_NULL_LOGGING ); + + assertNotNull( getBlocking( loadBalancer.acquireAsyncConnection( READ ) ) ); + + verify( rediscovery ).lookupClusterCompositionAsync( routingTable, asyncConnectionPool ); + verify( asyncConnectionPool ).purge( initialRouter1 ); + verify( asyncConnectionPool ).purge( initialRouter2 ); + } + @Test public void shouldRefreshRoutingTableOnInitialization() throws Exception { // given & when final AtomicInteger refreshRoutingTableCounter = new AtomicInteger( 0 ); - LoadBalancer balancer = new LoadBalancer( mock( ConnectionPool.class ), mock( RoutingTable.class ), - mock( Rediscovery.class ), DEV_NULL_LOGGING ) + LoadBalancer balancer = new LoadBalancer( mock( ConnectionPool.class ), null, + mock( RoutingTable.class ), mock( Rediscovery.class ), GlobalEventExecutor.INSTANCE, DEV_NULL_LOGGING ) { @Override synchronized void refreshRoutingTable() @@ -162,8 +238,11 @@ public void shouldForgetAddressAndItsConnectionsOnServiceUnavailableWhileClosing { RoutingTable routingTable = mock( RoutingTable.class ); ConnectionPool connectionPool = mock( ConnectionPool.class ); + AsyncConnectionPool asyncConnectionPool = mock( AsyncConnectionPool.class ); + Rediscovery rediscovery = mock( Rediscovery.class ); - LoadBalancer loadBalancer = new LoadBalancer( connectionPool, routingTable, rediscovery, DEV_NULL_LOGGING ); + LoadBalancer loadBalancer = new LoadBalancer( connectionPool, asyncConnectionPool, routingTable, rediscovery, + GlobalEventExecutor.INSTANCE, DEV_NULL_LOGGING ); BoltServerAddress address = new BoltServerAddress( "host", 42 ); PooledConnection connection = newConnectionWithFailingSync( address ); @@ -194,10 +273,12 @@ public void shouldForgetAddressAndItsConnectionsOnServiceUnavailableWhileClosing when( addressSet.toArray() ).thenReturn( new BoltServerAddress[]{address} ); when( routingTable.writers() ).thenReturn( addressSet ); ConnectionPool connectionPool = mock( ConnectionPool.class ); + AsyncConnectionPool asyncConnectionPool = mock( AsyncConnectionPool.class ); PooledConnection connectionWithFailingSync = newConnectionWithFailingSync( address ); when( connectionPool.acquire( any( BoltServerAddress.class ) ) ).thenReturn( connectionWithFailingSync ); Rediscovery rediscovery = mock( Rediscovery.class ); - LoadBalancer loadBalancer = new LoadBalancer( connectionPool, routingTable, rediscovery, DEV_NULL_LOGGING ); + LoadBalancer loadBalancer = new LoadBalancer( connectionPool, asyncConnectionPool, routingTable, rediscovery, + GlobalEventExecutor.INSTANCE, DEV_NULL_LOGGING ); Session session = newSession( loadBalancer ); // begin transaction to make session obtain a connection @@ -221,6 +302,18 @@ public void shouldRediscoverOnWriteWhenRoutingTableIsStaleForWrites() testRediscoveryWhenStale( WRITE ); } + @Test + public void shouldRediscoverOnReadWhenRoutingTableIsStaleForReadsAsync() + { + testRediscoveryWhenStaleAsync( READ ); + } + + @Test + public void shouldRediscoverOnWriteWhenRoutingTableIsStaleForWritesAsync() + { + testRediscoveryWhenStaleAsync( WRITE ); + } + @Test public void shouldNotRediscoverOnReadWhenRoutingTableIsStaleForWritesButNotReads() { @@ -233,17 +326,31 @@ public void shouldNotRediscoverOnWriteWhenRoutingTableIsStaleForReadsButNotWrite testNoRediscoveryWhenNotStale( READ, WRITE ); } + @Test + public void shouldNotRediscoverOnReadWhenRoutingTableIsStaleForWritesButNotReadsAsync() + { + testNoRediscoveryWhenNotStaleAsync( WRITE, READ ); + } + + @Test + public void shouldNotRediscoverOnWriteWhenRoutingTableIsStaleForReadsButNotWritesAsync() + { + testNoRediscoveryWhenNotStaleAsync( READ, WRITE ); + } + @Test public void shouldThrowWhenRediscoveryReturnsNoSuitableServers() { ConnectionPool connections = mock( ConnectionPool.class ); + AsyncConnectionPool asyncConnectionPool = mock( AsyncConnectionPool.class ); RoutingTable routingTable = mock( RoutingTable.class ); when( routingTable.isStaleFor( any( AccessMode.class ) ) ).thenReturn( true ); Rediscovery rediscovery = mock( Rediscovery.class ); when( routingTable.readers() ).thenReturn( new AddressSet() ); when( routingTable.writers() ).thenReturn( new AddressSet() ); - LoadBalancer loadBalancer = new LoadBalancer( connections, routingTable, rediscovery, DEV_NULL_LOGGING ); + LoadBalancer loadBalancer = new LoadBalancer( connections, asyncConnectionPool, routingTable, rediscovery, + GlobalEventExecutor.INSTANCE, DEV_NULL_LOGGING ); try { @@ -268,10 +375,50 @@ public void shouldThrowWhenRediscoveryReturnsNoSuitableServers() } } + @Test + public void shouldThrowWhenRediscoveryReturnsNoSuitableServersAsync() + { + AsyncConnectionPool asyncConnectionPool = newAsyncConnectionPoolMock(); + RoutingTable routingTable = mock( RoutingTable.class ); + when( routingTable.isStaleFor( any( AccessMode.class ) ) ).thenReturn( true ); + Rediscovery rediscovery = mock( Rediscovery.class ); + ClusterComposition emptyClusterComposition = new ClusterComposition( 42, emptySet(), emptySet(), emptySet() ); + when( rediscovery.lookupClusterCompositionAsync( routingTable, asyncConnectionPool ) ) + .thenReturn( completedFuture( emptyClusterComposition ) ); + when( routingTable.readers() ).thenReturn( new AddressSet() ); + when( routingTable.writers() ).thenReturn( new AddressSet() ); + + LoadBalancer loadBalancer = new LoadBalancer( null, asyncConnectionPool, routingTable, rediscovery, + GlobalEventExecutor.INSTANCE, DEV_NULL_LOGGING ); + + try + { + getBlocking( loadBalancer.acquireAsyncConnection( READ ) ); + fail( "Exception expected" ); + } + catch ( Exception e ) + { + assertThat( e, instanceOf( SessionExpiredException.class ) ); + assertThat( e.getMessage(), startsWith( "Failed to obtain connection towards READ server" ) ); + } + + try + { + getBlocking( loadBalancer.acquireAsyncConnection( WRITE ) ); + fail( "Exception expected" ); + } + catch ( Exception e ) + { + assertThat( e, instanceOf( SessionExpiredException.class ) ); + assertThat( e.getMessage(), startsWith( "Failed to obtain connection towards WRITE server" ) ); + } + } + @Test public void shouldSelectLeastConnectedAddress() { ConnectionPool connectionPool = newConnectionPoolMock(); + AsyncConnectionPool asyncConnectionPool = newAsyncConnectionPoolMock(); when( connectionPool.activeConnections( A ) ).thenReturn( 0 ); when( connectionPool.activeConnections( B ) ).thenReturn( 20 ); when( connectionPool.activeConnections( C ) ).thenReturn( 0 ); @@ -283,7 +430,8 @@ public void shouldSelectLeastConnectedAddress() Rediscovery rediscovery = mock( Rediscovery.class ); - LoadBalancer loadBalancer = new LoadBalancer( connectionPool, routingTable, rediscovery, DEV_NULL_LOGGING ); + LoadBalancer loadBalancer = new LoadBalancer( connectionPool, asyncConnectionPool, routingTable, rediscovery, + GlobalEventExecutor.INSTANCE, DEV_NULL_LOGGING ); Set seenAddresses = new HashSet<>(); for ( int i = 0; i < 10; i++ ) @@ -297,10 +445,43 @@ public void shouldSelectLeastConnectedAddress() assertTrue( seenAddresses.containsAll( Arrays.asList( A, C ) ) ); } + @Test + public void shouldSelectLeastConnectedAddressAsync() + { + AsyncConnectionPool asyncConnectionPool = newAsyncConnectionPoolMock(); + + when( asyncConnectionPool.activeConnections( A ) ).thenReturn( 0 ); + when( asyncConnectionPool.activeConnections( B ) ).thenReturn( 20 ); + when( asyncConnectionPool.activeConnections( C ) ).thenReturn( 0 ); + + RoutingTable routingTable = mock( RoutingTable.class ); + AddressSet readerAddresses = mock( AddressSet.class ); + when( readerAddresses.toArray() ).thenReturn( new BoltServerAddress[]{A, B, C} ); + when( routingTable.readers() ).thenReturn( readerAddresses ); + + Rediscovery rediscovery = mock( Rediscovery.class ); + + LoadBalancer loadBalancer = + new LoadBalancer( null, asyncConnectionPool, routingTable, rediscovery, GlobalEventExecutor.INSTANCE, + DEV_NULL_LOGGING ); + + Set seenAddresses = new HashSet<>(); + for ( int i = 0; i < 10; i++ ) + { + AsyncConnection connection = getBlocking( loadBalancer.acquireAsyncConnection( READ ) ); + seenAddresses.add( connection.serverAddress() ); + } + + // server B should never be selected because it has many active connections + assertEquals( 2, seenAddresses.size() ); + assertTrue( seenAddresses.containsAll( Arrays.asList( A, C ) ) ); + } + @Test public void shouldRoundRobinWhenNoActiveConnections() { ConnectionPool connectionPool = newConnectionPoolMock(); + AsyncConnectionPool asyncConnectionPool = newAsyncConnectionPoolMock(); RoutingTable routingTable = mock( RoutingTable.class ); AddressSet readerAddresses = mock( AddressSet.class ); @@ -309,7 +490,8 @@ public void shouldRoundRobinWhenNoActiveConnections() Rediscovery rediscovery = mock( Rediscovery.class ); - LoadBalancer loadBalancer = new LoadBalancer( connectionPool, routingTable, rediscovery, DEV_NULL_LOGGING ); + LoadBalancer loadBalancer = new LoadBalancer( connectionPool, asyncConnectionPool, routingTable, rediscovery, + GlobalEventExecutor.INSTANCE, DEV_NULL_LOGGING ); Set seenAddresses = new HashSet<>(); for ( int i = 0; i < 10; i++ ) @@ -322,6 +504,57 @@ public void shouldRoundRobinWhenNoActiveConnections() assertTrue( seenAddresses.containsAll( Arrays.asList( A, B, C ) ) ); } + @Test + public void shouldRoundRobinWhenNoActiveConnectionsAsync() + { + ConnectionPool connectionPool = newConnectionPoolMock(); + AsyncConnectionPool asyncConnectionPool = newAsyncConnectionPoolMock(); + + RoutingTable routingTable = mock( RoutingTable.class ); + AddressSet readerAddresses = mock( AddressSet.class ); + when( readerAddresses.toArray() ).thenReturn( new BoltServerAddress[]{A, B, C} ); + when( routingTable.readers() ).thenReturn( readerAddresses ); + + Rediscovery rediscovery = mock( Rediscovery.class ); + + LoadBalancer loadBalancer = new LoadBalancer( connectionPool, asyncConnectionPool, routingTable, rediscovery, + GlobalEventExecutor.INSTANCE, DEV_NULL_LOGGING ); + + Set seenAddresses = new HashSet<>(); + for ( int i = 0; i < 10; i++ ) + { + AsyncConnection connection = getBlocking( loadBalancer.acquireAsyncConnection( READ ) ); + seenAddresses.add( connection.serverAddress() ); + } + + assertEquals( 3, seenAddresses.size() ); + assertTrue( seenAddresses.containsAll( Arrays.asList( A, B, C ) ) ); + } + + @Test + public void shouldTryMultipleServersAfterRediscovery() + { + Set unavailableAddresses = asOrderedSet( A ); + AsyncConnectionPool asyncConnectionPool = newAsyncConnectionPoolMockWithFailures( unavailableAddresses ); + + ClusterRoutingTable routingTable = new ClusterRoutingTable( new FakeClock(), A ); + Rediscovery rediscovery = mock( Rediscovery.class ); + ClusterComposition clusterComposition = new ClusterComposition( 42, + asOrderedSet( A, B ), asOrderedSet( A, B ), asOrderedSet( A, B ) ); + when( rediscovery.lookupClusterCompositionAsync( any(), any() ) ) + .thenReturn( completedFuture( clusterComposition ) ); + + LoadBalancer loadBalancer = new LoadBalancer( null, asyncConnectionPool, routingTable, rediscovery, + GlobalEventExecutor.INSTANCE, DEV_NULL_LOGGING ); + + AsyncConnection connection = getBlocking( loadBalancer.acquireAsyncConnection( READ ) ); + + assertNotNull( connection ); + assertEquals( B, connection.serverAddress() ); + // routing table should've forgotten A + assertArrayEquals( new BoltServerAddress[]{B}, routingTable.readers().toArray() ); + } + private void testRediscoveryWhenStale( AccessMode mode ) { ConnectionPool connections = mock( ConnectionPool.class ); @@ -330,7 +563,8 @@ private void testRediscoveryWhenStale( AccessMode mode ) RoutingTable routingTable = newStaleRoutingTableMock( mode ); Rediscovery rediscovery = newRediscoveryMock(); - LoadBalancer loadBalancer = new LoadBalancer( connections, routingTable, rediscovery, DEV_NULL_LOGGING ); + LoadBalancer loadBalancer = new LoadBalancer( connections, null, routingTable, rediscovery, + GlobalEventExecutor.INSTANCE, DEV_NULL_LOGGING ); verify( rediscovery ).lookupClusterComposition( routingTable, connections ); assertNotNull( loadBalancer.acquireConnection( mode ) ); @@ -338,6 +572,25 @@ private void testRediscoveryWhenStale( AccessMode mode ) verify( rediscovery, times( 2 ) ).lookupClusterComposition( routingTable, connections ); } + private void testRediscoveryWhenStaleAsync( AccessMode mode ) + { + AsyncConnectionPool asyncConnectionPool = mock( AsyncConnectionPool.class ); + when( asyncConnectionPool.acquire( LOCAL_DEFAULT ) ) + .thenReturn( completedFuture( mock( AsyncConnection.class ) ) ); + + RoutingTable routingTable = newStaleRoutingTableMock( mode ); + Rediscovery rediscovery = newRediscoveryMock(); + + LoadBalancer loadBalancer = + new LoadBalancer( null, asyncConnectionPool, routingTable, rediscovery, GlobalEventExecutor.INSTANCE, + DEV_NULL_LOGGING ); + AsyncConnection connection = getBlocking( loadBalancer.acquireAsyncConnection( mode ) ); + assertNotNull( connection ); + + verify( routingTable ).isStaleFor( mode ); + verify( rediscovery ).lookupClusterCompositionAsync( routingTable, asyncConnectionPool ); + } + private void testNoRediscoveryWhenNotStale( AccessMode staleMode, AccessMode notStaleMode ) { ConnectionPool connections = mock( ConnectionPool.class ); @@ -346,7 +599,8 @@ private void testNoRediscoveryWhenNotStale( AccessMode staleMode, AccessMode not RoutingTable routingTable = newStaleRoutingTableMock( staleMode ); Rediscovery rediscovery = newRediscoveryMock(); - LoadBalancer loadBalancer = new LoadBalancer( connections, routingTable, rediscovery, DEV_NULL_LOGGING ); + LoadBalancer loadBalancer = new LoadBalancer( connections, null, routingTable, rediscovery, + GlobalEventExecutor.INSTANCE, DEV_NULL_LOGGING ); verify( rediscovery ).lookupClusterComposition( routingTable, connections ); assertNotNull( loadBalancer.acquireConnection( notStaleMode ) ); @@ -354,13 +608,24 @@ private void testNoRediscoveryWhenNotStale( AccessMode staleMode, AccessMode not verify( rediscovery ).lookupClusterComposition( routingTable, connections ); } - private LoadBalancer setupLoadBalancer( PooledConnection writerConn, PooledConnection readConn ) + private void testNoRediscoveryWhenNotStaleAsync( AccessMode staleMode, AccessMode notStaleMode ) { - return setupLoadBalancer( writerConn, readConn, mock( Rediscovery.class ) ); + AsyncConnectionPool asyncConnectionPool = mock( AsyncConnectionPool.class ); + when( asyncConnectionPool.acquire( LOCAL_DEFAULT ) ) + .thenReturn( completedFuture( mock( AsyncConnection.class ) ) ); + + RoutingTable routingTable = newStaleRoutingTableMock( staleMode ); + Rediscovery rediscovery = newRediscoveryMock(); + + LoadBalancer loadBalancer = new LoadBalancer( null, asyncConnectionPool, routingTable, rediscovery, + GlobalEventExecutor.INSTANCE, DEV_NULL_LOGGING ); + + assertNotNull( getBlocking( loadBalancer.acquireAsyncConnection( notStaleMode ) ) ); + verify( routingTable ).isStaleFor( notStaleMode ); + verify( rediscovery, never() ).lookupClusterCompositionAsync( routingTable, asyncConnectionPool ); } - private LoadBalancer setupLoadBalancer( PooledConnection writerConn, PooledConnection readConn, - Rediscovery rediscovery ) + private LoadBalancer setupLoadBalancer( PooledConnection writerConn, PooledConnection readConn ) { BoltServerAddress writer = mock( BoltServerAddress.class ); BoltServerAddress reader = mock( BoltServerAddress.class ); @@ -369,6 +634,8 @@ private LoadBalancer setupLoadBalancer( PooledConnection writerConn, PooledConne when( connPool.acquire( writer ) ).thenReturn( writerConn ); when( connPool.acquire( reader ) ).thenReturn( readConn ); + AsyncConnectionPool asyncConnectionPool = mock( AsyncConnectionPool.class ); + AddressSet writerAddrs = mock( AddressSet.class ); when( writerAddrs.toArray() ).thenReturn( new BoltServerAddress[]{writer} ); @@ -379,7 +646,10 @@ private LoadBalancer setupLoadBalancer( PooledConnection writerConn, PooledConne when( routingTable.readers() ).thenReturn( readerAddrs ); when( routingTable.writers() ).thenReturn( writerAddrs ); - return new LoadBalancer( connPool, routingTable, rediscovery, DEV_NULL_LOGGING ); + Rediscovery rediscovery = mock( Rediscovery.class ); + + return new LoadBalancer( connPool, asyncConnectionPool, routingTable, rediscovery, GlobalEventExecutor.INSTANCE, + DEV_NULL_LOGGING ); } private static Session newSession( LoadBalancer loadBalancer ) @@ -421,6 +691,8 @@ private static Rediscovery newRediscoveryMock() ClusterComposition clusterComposition = new ClusterComposition( 1, noServers, noServers, noServers ); when( rediscovery.lookupClusterComposition( any( RoutingTable.class ), any( ConnectionPool.class ) ) ) .thenReturn( clusterComposition ); + when( rediscovery.lookupClusterCompositionAsync( any( RoutingTable.class ), any( AsyncConnectionPool.class ) ) ) + .thenReturn( completedFuture( clusterComposition ) ); return rediscovery; } @@ -440,4 +712,27 @@ public PooledConnection answer( InvocationOnMock invocation ) throws Throwable } ); return connectionPool; } + + private static AsyncConnectionPool newAsyncConnectionPoolMock() + { + return newAsyncConnectionPoolMockWithFailures( emptySet() ); + } + + private static AsyncConnectionPool newAsyncConnectionPoolMockWithFailures( + Set unavailableAddresses ) + { + AsyncConnectionPool pool = mock( AsyncConnectionPool.class ); + when( pool.acquire( any( BoltServerAddress.class ) ) ).then( invocation -> + { + BoltServerAddress requestedAddress = invocation.getArgumentAt( 0, BoltServerAddress.class ); + if ( unavailableAddresses.contains( requestedAddress ) ) + { + return Futures.failedFuture( new ServiceUnavailableException( requestedAddress + " is unavailable!" ) ); + } + AsyncConnection connection = mock( AsyncConnection.class ); + when( connection.serverAddress() ).thenReturn( requestedAddress ); + return completedFuture( connection ); + } ); + return pool; + } } diff --git a/driver/src/test/java/org/neo4j/driver/internal/security/TrustOnFirstUseTrustManagerTest.java b/driver/src/test/java/org/neo4j/driver/internal/security/TrustOnFirstUseTrustManagerTest.java index 4957c96eb0..f664ec5e41 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/security/TrustOnFirstUseTrustManagerTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/security/TrustOnFirstUseTrustManagerTest.java @@ -54,7 +54,7 @@ public class TrustOnFirstUseTrustManagerTest private String knownServer; @Rule - public TemporaryFolder testDir = new TemporaryFolder(); + public TemporaryFolder testDir = new TemporaryFolder( new File( "target" ) ); private X509Certificate knownCertificate; @Before diff --git a/driver/src/test/java/org/neo4j/driver/internal/util/DriverFactoryWithOneEventLoopThread.java b/driver/src/test/java/org/neo4j/driver/internal/util/DriverFactoryWithOneEventLoopThread.java new file mode 100644 index 0000000000..ecbbf4adc0 --- /dev/null +++ b/driver/src/test/java/org/neo4j/driver/internal/util/DriverFactoryWithOneEventLoopThread.java @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2002-2017 "Neo Technology," + * Network Engine for Objects in Lund AB [http://neotechnology.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.internal.util; + +import io.netty.bootstrap.Bootstrap; + +import java.net.URI; + +import org.neo4j.driver.internal.DriverFactory; +import org.neo4j.driver.internal.async.BootstrapFactory; +import org.neo4j.driver.internal.cluster.RoutingSettings; +import org.neo4j.driver.internal.retry.RetrySettings; +import org.neo4j.driver.v1.AuthToken; +import org.neo4j.driver.v1.Config; +import org.neo4j.driver.v1.Driver; + +public class DriverFactoryWithOneEventLoopThread extends DriverFactory +{ + public Driver newInstance( URI uri, AuthToken authToken, Config config ) + { + return newInstance( uri, authToken, new RoutingSettings( 1, 0 ), RetrySettings.DEFAULT, config ); + } + + @Override + protected Bootstrap createBootstrap() + { + return BootstrapFactory.newBootstrap( 1 ); + } +} diff --git a/driver/src/test/java/org/neo4j/driver/v1/integration/CredentialsIT.java b/driver/src/test/java/org/neo4j/driver/v1/integration/CredentialsIT.java index 020a031187..24f73f7b42 100644 --- a/driver/src/test/java/org/neo4j/driver/v1/integration/CredentialsIT.java +++ b/driver/src/test/java/org/neo4j/driver/v1/integration/CredentialsIT.java @@ -23,6 +23,7 @@ import org.junit.Test; import org.junit.rules.TemporaryFolder; +import java.io.File; import java.util.HashMap; import org.neo4j.driver.internal.security.InternalAuthToken; @@ -52,7 +53,7 @@ public class CredentialsIT { @ClassRule - public static TemporaryFolder tempDir = new TemporaryFolder(); + public static TemporaryFolder tempDir = new TemporaryFolder( new File( "target" ) ); @ClassRule public static TestNeo4j neo4j = new TestNeo4j(); diff --git a/driver/src/test/java/org/neo4j/driver/v1/integration/TLSSocketChannelIT.java b/driver/src/test/java/org/neo4j/driver/v1/integration/TLSSocketChannelIT.java index 4087f8706c..b17b31f38a 100644 --- a/driver/src/test/java/org/neo4j/driver/v1/integration/TLSSocketChannelIT.java +++ b/driver/src/test/java/org/neo4j/driver/v1/integration/TLSSocketChannelIT.java @@ -66,7 +66,7 @@ public class TLSSocketChannelIT public TestNeo4j neo4j = new TestNeo4j(); @Rule - public TemporaryFolder folder = new TemporaryFolder(); + public TemporaryFolder folder = new TemporaryFolder( new File( "target" ) ); @BeforeClass public static void setup() throws IOException, InterruptedException diff --git a/driver/src/test/java/org/neo4j/driver/v1/stress/AbstractStressTestBase.java b/driver/src/test/java/org/neo4j/driver/v1/stress/AbstractStressTestBase.java index 1b3e8e364a..8e952494ae 100644 --- a/driver/src/test/java/org/neo4j/driver/v1/stress/AbstractStressTestBase.java +++ b/driver/src/test/java/org/neo4j/driver/v1/stress/AbstractStressTestBase.java @@ -60,7 +60,6 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNull; import static org.junit.Assert.assertThat; -import static org.junit.Assume.assumeTrue; public abstract class AbstractStressTestBase { @@ -110,9 +109,6 @@ public void blockingApiStressTest() throws Throwable @Test public void asyncApiStressTest() throws Throwable { - // todo: re-enable when async is supported in routing driver - assumeTrue( "bolt".equalsIgnoreCase( databaseUri().getScheme() ) ); - runStressTest( this::launchAsyncWorkerThreads ); } diff --git a/driver/src/test/java/org/neo4j/driver/v1/tck/DriverComplianceIT.java b/driver/src/test/java/org/neo4j/driver/v1/tck/DriverComplianceIT.java index d158960e93..dee8835918 100644 --- a/driver/src/test/java/org/neo4j/driver/v1/tck/DriverComplianceIT.java +++ b/driver/src/test/java/org/neo4j/driver/v1/tck/DriverComplianceIT.java @@ -24,6 +24,7 @@ import org.junit.rules.TemporaryFolder; import org.junit.runner.RunWith; +import java.io.File; import java.io.IOException; import org.neo4j.driver.v1.util.TestNeo4j; @@ -37,7 +38,7 @@ public class DriverComplianceIT { @Rule - TemporaryFolder folder = new TemporaryFolder(); + TemporaryFolder folder = new TemporaryFolder( new File( "target" ) ); @ClassRule public static TestNeo4j neo4j = new TestNeo4j(); diff --git a/driver/src/test/java/org/neo4j/driver/v1/util/Neo4jRunner.java b/driver/src/test/java/org/neo4j/driver/v1/util/Neo4jRunner.java index 4ffda3be4b..8898d22c0e 100644 --- a/driver/src/test/java/org/neo4j/driver/v1/util/Neo4jRunner.java +++ b/driver/src/test/java/org/neo4j/driver/v1/util/Neo4jRunner.java @@ -51,7 +51,7 @@ public class Neo4jRunner private static final boolean debug = true; - private static final String DEFAULT_NEOCTRL_ARGS = "-e 3.2.0"; + private static final String DEFAULT_NEOCTRL_ARGS = "-e 3.2.5"; public static final String NEOCTRL_ARGS = System.getProperty( "neoctrl.args", DEFAULT_NEOCTRL_ARGS ); public static final URI DEFAULT_URI = URI.create( "bolt://localhost:7687" ); public static final BoltServerAddress DEFAULT_ADDRESS = new BoltServerAddress( DEFAULT_URI ); diff --git a/driver/src/test/java/org/neo4j/driver/v1/util/TestUtil.java b/driver/src/test/java/org/neo4j/driver/v1/util/TestUtil.java index 1b15faf947..2f81f3e90e 100644 --- a/driver/src/test/java/org/neo4j/driver/v1/util/TestUtil.java +++ b/driver/src/test/java/org/neo4j/driver/v1/util/TestUtil.java @@ -19,9 +19,13 @@ package org.neo4j.driver.v1.util; import io.netty.buffer.ByteBuf; +import io.netty.util.internal.PlatformDependent; import java.util.ArrayList; +import java.util.Arrays; +import java.util.LinkedHashSet; import java.util.List; +import java.util.Set; import java.util.concurrent.CompletionStage; import java.util.concurrent.ExecutionException; import java.util.concurrent.Future; @@ -66,7 +70,7 @@ public static > T await( U future ) } catch ( ExecutionException e ) { - throwException( e.getCause() ); + PlatformDependent.throwException( e.getCause() ); return null; } catch ( TimeoutException e ) @@ -112,15 +116,10 @@ public static void assertByteBufEquals( ByteBuf expected, ByteBuf actual ) } } - private static void throwException( Throwable t ) + @SafeVarargs + public static Set asOrderedSet( T... elements ) { - TestUtil.doThrowException( t ); - } - - @SuppressWarnings( "unchecked" ) - private static void doThrowException( Throwable t ) throws E - { - throw (E) t; + return new LinkedHashSet<>( Arrays.asList( elements ) ); } private static Number read( ByteBuf buf, Class type ) diff --git a/driver/src/test/java/org/neo4j/driver/v1/util/cc/Cluster.java b/driver/src/test/java/org/neo4j/driver/v1/util/cc/Cluster.java index 4ab07fe00c..8a95f7f82f 100644 --- a/driver/src/test/java/org/neo4j/driver/v1/util/cc/Cluster.java +++ b/driver/src/test/java/org/neo4j/driver/v1/util/cc/Cluster.java @@ -30,16 +30,17 @@ import org.neo4j.driver.internal.net.BoltServerAddress; import org.neo4j.driver.internal.util.Consumer; +import org.neo4j.driver.internal.util.DriverFactoryWithOneEventLoopThread; import org.neo4j.driver.v1.AccessMode; import org.neo4j.driver.v1.AuthTokens; import org.neo4j.driver.v1.Config; import org.neo4j.driver.v1.Driver; -import org.neo4j.driver.v1.GraphDatabase; import org.neo4j.driver.v1.Record; import org.neo4j.driver.v1.Session; import org.neo4j.driver.v1.StatementResult; import static java.util.Collections.unmodifiableSet; +import static org.neo4j.driver.internal.logging.DevNullLogging.DEV_NULL_LOGGING; import static org.neo4j.driver.internal.util.Iterables.single; import static org.neo4j.driver.v1.Config.TrustStrategy.trustAllCertificates; @@ -318,13 +319,13 @@ private static Driver createDriver( Set members, String password private static List findClusterOverview( Session session ) { - StatementResult result = session.run( "call dbms.cluster.overview" ); + StatementResult result = session.run( "CALL dbms.cluster.overview()" ); return result.list(); } private static boolean isCoreMember( Session session ) { - Record record = single( session.run( "call dbms.cluster.role" ).list() ); + Record record = single( session.run( "call dbms.cluster.role()" ).list() ); ClusterMemberRole role = extractRole( record ); return role != ClusterMemberRole.READ_REPLICA; } @@ -412,15 +413,18 @@ private static ClusterMember findByBoltAddress( BoltServerAddress boltAddress, S private static Driver createDriver( URI boltUri, String password ) { - return GraphDatabase.driver( boltUri, AuthTokens.basic( ADMIN_USER, password ), driverConfig() ); + DriverFactoryWithOneEventLoopThread factory = new DriverFactoryWithOneEventLoopThread(); + return factory.newInstance( boltUri, AuthTokens.basic( ADMIN_USER, password ), driverConfig() ); } private static Config driverConfig() { // try to build config for a very lightweight driver return Config.build() + .withLogging( DEV_NULL_LOGGING ) .withTrustStrategy( trustAllCertificates() ) .withEncryption() + .withMaxConnectionPoolSize( 1 ) .withMaxIdleConnections( 1 ) .withConnectionLivenessCheckTimeout( 1, TimeUnit.HOURS ) .toConfig(); diff --git a/driver/src/test/java/org/neo4j/driver/v1/util/cc/LocalOrRemoteClusterRule.java b/driver/src/test/java/org/neo4j/driver/v1/util/cc/LocalOrRemoteClusterRule.java index be28cc49b5..1f85ed4e78 100644 --- a/driver/src/test/java/org/neo4j/driver/v1/util/cc/LocalOrRemoteClusterRule.java +++ b/driver/src/test/java/org/neo4j/driver/v1/util/cc/LocalOrRemoteClusterRule.java @@ -92,7 +92,7 @@ private static void assertValidSystemPropertiesDefined() } if ( uri != null && !BOLT_ROUTING_URI_SCHEME.equals( uri.getScheme() ) ) { - throw new IllegalStateException( "CLuster uri should have bolt+routing scheme: '" + uri + "'" ); + throw new IllegalStateException( "Cluster uri should have bolt+routing scheme: '" + uri + "'" ); } }