From 28a3c3f7b14a140e85c75d0bde5603531fd56495 Mon Sep 17 00:00:00 2001 From: Tobias Lindaaker Date: Mon, 24 Oct 2016 13:38:20 +0200 Subject: [PATCH] Refactor routing logic to better handle failures Separate out different parts of the logic into different classes: * When to update the view of the cluster, and the maintenance of the set of servers in the cluster: - LoadBalancer * The ability to retrieve a list of servers with roles (ClusterComposition) from a server: - ClusterComposition.Provider * Round Robin cycling through a set of addresses (maintaining the order/position on addition/removal of members in the set): - RoundRobinAddressSet The new logic distinguishes being unable to connect to the routing servers and maintains the server in the list in this case, under the assumption that the server is only temporarily unavailable. The attempt to contact routing servers only fail after a number of failed attempts with an exponential back off in between. These changes introduce event based testing for some of the components, this makes it easier to test for the behaviour of units, rather than asserting on their state, while the explicitness of it makes it easier to follow than using mocks. --- .../neo4j/driver/internal/ClusterView.java | 175 ----- .../neo4j/driver/internal/NetworkSession.java | 5 + .../neo4j/driver/internal/RoutingDriver.java | 285 +-------- .../driver/internal/RoutingErrorHandler.java | 2 +- .../internal/cluster/ClusterComposition.java | 203 ++++++ .../driver/internal/cluster/LoadBalancer.java | 228 +++++++ .../cluster/RoundRobinAddressSet.java | 136 ++++ .../internal/cluster/RoutingSettings.java | 31 + .../org/neo4j/driver/internal/util/Clock.java | 8 + .../util/ConcurrentRoundRobinSet.java | 152 ----- .../main/java/org/neo4j/driver/v1/Config.java | 58 ++ .../org/neo4j/driver/v1/GraphDatabase.java | 9 +- .../driver/internal/ClusterViewTest.java | 189 ------ .../java/org/neo4j/driver/internal/Event.java | 64 ++ .../neo4j/driver/internal/EventHandler.java | 285 +++++++++ .../driver/internal/RoutingDriverTest.java | 446 ++++++------- .../ClusterCompositionProviderTest.java | 248 ++++++++ .../internal/cluster/ClusterTopology.java | 198 ++++++ .../internal/cluster/LoadBalancerTest.java | 599 ++++++++++++++++++ .../cluster/RoundRobinAddressSetTest.java | 226 +++++++ .../internal/spi/StubConnectionPool.java | 498 +++++++++++++++ .../util/ConcurrentRoundRobinSetTest.java | 170 ----- .../neo4j/driver/internal/util/FakeClock.java | 260 ++++++++ .../driver/internal/util/MatcherFactory.java | 230 +++++++ .../java/org/neo4j/driver/v1/EventLogger.java | 225 +++++++ 25 files changed, 3722 insertions(+), 1208 deletions(-) delete mode 100644 driver/src/main/java/org/neo4j/driver/internal/ClusterView.java create mode 100644 driver/src/main/java/org/neo4j/driver/internal/cluster/ClusterComposition.java create mode 100644 driver/src/main/java/org/neo4j/driver/internal/cluster/LoadBalancer.java create mode 100644 driver/src/main/java/org/neo4j/driver/internal/cluster/RoundRobinAddressSet.java create mode 100644 driver/src/main/java/org/neo4j/driver/internal/cluster/RoutingSettings.java delete mode 100644 driver/src/main/java/org/neo4j/driver/internal/util/ConcurrentRoundRobinSet.java delete mode 100644 driver/src/test/java/org/neo4j/driver/internal/ClusterViewTest.java create mode 100644 driver/src/test/java/org/neo4j/driver/internal/Event.java create mode 100644 driver/src/test/java/org/neo4j/driver/internal/EventHandler.java create mode 100644 driver/src/test/java/org/neo4j/driver/internal/cluster/ClusterCompositionProviderTest.java create mode 100644 driver/src/test/java/org/neo4j/driver/internal/cluster/ClusterTopology.java create mode 100644 driver/src/test/java/org/neo4j/driver/internal/cluster/LoadBalancerTest.java create mode 100644 driver/src/test/java/org/neo4j/driver/internal/cluster/RoundRobinAddressSetTest.java create mode 100644 driver/src/test/java/org/neo4j/driver/internal/spi/StubConnectionPool.java delete mode 100644 driver/src/test/java/org/neo4j/driver/internal/util/ConcurrentRoundRobinSetTest.java create mode 100644 driver/src/test/java/org/neo4j/driver/internal/util/FakeClock.java create mode 100644 driver/src/test/java/org/neo4j/driver/internal/util/MatcherFactory.java create mode 100644 driver/src/test/java/org/neo4j/driver/v1/EventLogger.java diff --git a/driver/src/main/java/org/neo4j/driver/internal/ClusterView.java b/driver/src/main/java/org/neo4j/driver/internal/ClusterView.java deleted file mode 100644 index 3e378abb08..0000000000 --- a/driver/src/main/java/org/neo4j/driver/internal/ClusterView.java +++ /dev/null @@ -1,175 +0,0 @@ -/** - * Copyright (c) 2002-2016 "Neo Technology," - * Network Engine for Objects in Lund AB [http://neotechnology.com] - * - * This file is part of Neo4j. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.neo4j.driver.internal; - -import java.util.Collections; -import java.util.Comparator; -import java.util.HashSet; -import java.util.List; -import java.util.Set; - -import org.neo4j.driver.internal.net.BoltServerAddress; -import org.neo4j.driver.internal.util.Clock; -import org.neo4j.driver.internal.util.ConcurrentRoundRobinSet; -import org.neo4j.driver.v1.Logger; - -/** - * Defines a snapshot view of the cluster. - */ -class ClusterView -{ - private final static Comparator COMPARATOR = new Comparator() - { - @Override - public int compare( BoltServerAddress o1, BoltServerAddress o2 ) - { - int compare = o1.host().compareTo( o2.host() ); - if ( compare == 0 ) - { - compare = Integer.compare( o1.port(), o2.port() ); - } - - return compare; - } - }; - - private static final int MIN_ROUTERS = 1; - - private final ConcurrentRoundRobinSet routingServers = - new ConcurrentRoundRobinSet<>( COMPARATOR ); - private final ConcurrentRoundRobinSet readServers = - new ConcurrentRoundRobinSet<>( COMPARATOR ); - private final ConcurrentRoundRobinSet writeServers = - new ConcurrentRoundRobinSet<>( COMPARATOR ); - private final Clock clock; - private final long expires; - private final Logger log; - - public ClusterView( long expires, Clock clock, Logger log ) - { - this.expires = expires; - this.clock = clock; - this.log = log; - } - - public void addRouter( BoltServerAddress router ) - { - this.routingServers.add( router ); - } - - public boolean isStale() - { - return expires < clock.millis() || - routingServers.size() <= MIN_ROUTERS || - readServers.isEmpty() || - writeServers.isEmpty(); - } - - Set all() - { - HashSet all = - new HashSet<>( routingServers.size() + readServers.size() + writeServers.size() ); - all.addAll( routingServers ); - all.addAll( readServers ); - all.addAll( writeServers ); - return all; - } - - - public BoltServerAddress nextRouter() - { - return routingServers.hop(); - } - - public BoltServerAddress nextReader() - { - return readServers.hop(); - } - - public BoltServerAddress nextWriter() - { - return writeServers.hop(); - } - - public void addReaders( List addresses ) - { - readServers.addAll( addresses ); - } - - public void addWriters( List addresses ) - { - writeServers.addAll( addresses ); - } - - public void addRouters( List addresses ) - { - routingServers.addAll( addresses ); - } - - public void remove( BoltServerAddress address ) - { - if ( routingServers.remove( address ) ) - { - log.debug( "Removing %s from routers", address.toString() ); - } - if ( readServers.remove( address ) ) - { - log.debug( "Removing %s from readers", address.toString() ); - } - if ( writeServers.remove( address ) ) - { - log.debug( "Removing %s from writers", address.toString() ); - } - } - - public boolean removeWriter( BoltServerAddress address ) - { - return writeServers.remove( address ); - } - - public int numberOfRouters() - { - return routingServers.size(); - } - - public int numberOfReaders() - { - return readServers.size(); - } - - public int numberOfWriters() - { - return writeServers.size(); - } - - public Set routingServers() - { - return Collections.unmodifiableSet( routingServers ); - } - - public Set readServers() - { - return Collections.unmodifiableSet( readServers ); - } - - public Set writeServers() - { - return Collections.unmodifiableSet( writeServers ); - } -} diff --git a/driver/src/main/java/org/neo4j/driver/internal/NetworkSession.java b/driver/src/main/java/org/neo4j/driver/internal/NetworkSession.java index 93cd1c558e..3cf91d468a 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/NetworkSession.java +++ b/driver/src/main/java/org/neo4j/driver/internal/NetworkSession.java @@ -113,6 +113,11 @@ public StatementResult run( String statementText, Value statementParameters ) public StatementResult run( Statement statement ) { ensureConnectionIsValidBeforeRunningSession(); + return run( connection, statement ); + } + + public static StatementResult run( Connection connection, Statement statement ) + { InternalStatementResult cursor = new InternalStatementResult( connection, null, statement ); connection.run( statement.text(), statement.parameters().asMap( Values.ofValue() ), cursor.runResponseCollector() ); diff --git a/driver/src/main/java/org/neo4j/driver/internal/RoutingDriver.java b/driver/src/main/java/org/neo4j/driver/internal/RoutingDriver.java index 38cc595144..8705f08d53 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/RoutingDriver.java +++ b/driver/src/main/java/org/neo4j/driver/internal/RoutingDriver.java @@ -18,9 +18,10 @@ */ package org.neo4j.driver.internal; -import java.util.List; import java.util.Set; +import org.neo4j.driver.internal.cluster.LoadBalancer; +import org.neo4j.driver.internal.cluster.RoutingSettings; import org.neo4j.driver.internal.net.BoltServerAddress; import org.neo4j.driver.internal.security.SecurityPlan; import org.neo4j.driver.internal.spi.Connection; @@ -28,209 +29,25 @@ import org.neo4j.driver.internal.util.Clock; import org.neo4j.driver.v1.AccessMode; import org.neo4j.driver.v1.Logging; -import org.neo4j.driver.v1.Record; import org.neo4j.driver.v1.Session; -import org.neo4j.driver.v1.StatementResult; -import org.neo4j.driver.v1.Value; import org.neo4j.driver.v1.exceptions.ClientException; -import org.neo4j.driver.v1.exceptions.ServiceUnavailableException; -import org.neo4j.driver.v1.exceptions.SessionExpiredException; -import org.neo4j.driver.v1.util.Function; import static java.lang.String.format; public class RoutingDriver extends BaseDriver { - private static final String GET_SERVERS = "dbms.cluster.routing.getServers"; - private static final long MAX_TTL = Long.MAX_VALUE / 1000L; + private final LoadBalancer loadBalancer; - private final ConnectionPool connections; - private final Function sessionProvider; - private final Clock clock; - private ClusterView clusterView; - - - public RoutingDriver( BoltServerAddress seedAddress, + public RoutingDriver( + RoutingSettings settings, + BoltServerAddress seedAddress, ConnectionPool connections, SecurityPlan securityPlan, - Function sessionProvider, Clock clock, Logging logging ) { super( securityPlan, logging ); - this.connections = connections; - this.sessionProvider = sessionProvider; - this.clock = clock; - this.clusterView = new ClusterView( 0L, clock, log ); - this.clusterView.addRouter( seedAddress ); - checkServers(); - } - - private synchronized void checkServers() - { - if ( clusterView.isStale() ) - { - Set oldAddresses = clusterView.all(); - ClusterView newView = newClusterView(); - Set newAddresses = newView.all(); - - oldAddresses.removeAll( newAddresses ); - for ( BoltServerAddress boltServerAddress : oldAddresses ) - { - connections.purge( boltServerAddress ); - } - - this.clusterView = newView; - } - } - - private long calculateNewExpiry( Record record ) - { - long ttl = record.get( "ttl" ).asLong(); - long nextExpiry = clock.millis() + 1000L * ttl; - if ( ttl < 0 || ttl >= MAX_TTL || nextExpiry < 0 ) - { - return Long.MAX_VALUE; - } - else - { - return nextExpiry; - } - } - - private ClusterView newClusterView() - { - BoltServerAddress address = null; - for ( int i = 0; i < clusterView.numberOfRouters(); i++ ) - { - address = clusterView.nextRouter(); - ClusterView newClusterView; - try - { - newClusterView = call( address, GET_SERVERS, new Function() - - { - @Override - public ClusterView apply( Record record ) - { - long expire = calculateNewExpiry( record ); - ClusterView newClusterView = new ClusterView( expire, clock, log ); - List servers = servers( record ); - for ( ServerInfo server : servers ) - { - switch ( server.role() ) - { - case "READ": - newClusterView.addReaders( server.addresses() ); - break; - case "WRITE": - newClusterView.addWriters( server.addresses() ); - break; - case "ROUTE": - newClusterView.addRouters( server.addresses() ); - break; - } - } - return newClusterView; - } - } ); - } - catch ( Throwable t ) - { - forget( address ); - continue; - } - - if ( newClusterView.numberOfRouters() != 0 ) - { - return newClusterView; - } - } - - - //discovery failed, not much to do, stick with what we've got - //this may happen because server is running in standalone mode - this.close(); - throw new ServiceUnavailableException( - String.format( "Server %s couldn't perform discovery", - address == null ? "`UNKNOWN`" : address.toString() ) ); - - } - - private static class ServerInfo - { - private final List addresses; - private final String role; - - public ServerInfo( List addresses, String role ) - { - this.addresses = addresses; - this.role = role; - } - - public String role() - { - return role; - } - - List addresses() - { - return addresses; - } - } - - private List servers( Record record ) - { - return record.get( "servers" ).asList( new Function() - { - @Override - public ServerInfo apply( Value value ) - { - return new ServerInfo( value.get( "addresses" ).asList( new Function() - { - @Override - public BoltServerAddress apply( Value value ) - { - return new BoltServerAddress( value.asString() ); - } - } ), value.get( "role" ).asString() ); - } - } ); - } - - //must be called from a synchronized method - private T call( BoltServerAddress address, String procedureName, Function recorder ) - { - Connection acquire; - Session session = null; - try - { - acquire = connections.acquire( address ); - session = sessionProvider.apply( acquire ); - - StatementResult records = session.run( format( "CALL %s", procedureName ) ); - //got a result but was empty - if ( !records.hasNext() ) - { - forget( address ); - throw new IllegalStateException("Server responded with empty result"); - } - //consume the results - return recorder.apply( records.single() ); - } - finally - { - if ( session != null ) - { - session.close(); - } - } - } - - private synchronized void forget( BoltServerAddress address ) - { - connections.purge( address ); - clusterView.remove(address); + this.loadBalancer = new LoadBalancer( settings, clock, log, connections, seedAddress ); } @Override @@ -243,83 +60,28 @@ public Session session() public Session session( final AccessMode mode ) { Connection connection = acquireConnection( mode ); - return new RoutingNetworkSession( new NetworkSession( connection ), mode, connection.address(), - new RoutingErrorHandler() - { - @Override - public void onConnectionFailure( BoltServerAddress address ) - { - forget( address ); - } - - @Override - public void onWriteFailure( BoltServerAddress address ) - { - clusterView.removeWriter( address ); - } - } ); + return new RoutingNetworkSession( new NetworkSession( connection ), mode, connection.address(), loadBalancer ); } private Connection acquireConnection( AccessMode role ) { - //Potentially rediscover servers if we are not happy with our current knowledge - checkServers(); - switch ( role ) { case READ: - return acquireReadConnection(); + return loadBalancer.acquireReadConnection(); case WRITE: - return acquireWriteConnection(); + return loadBalancer.acquireWriteConnection(); default: throw new ClientException( role + " is not supported for creating new sessions" ); } } - private Connection acquireReadConnection() - { - int numberOfServers = clusterView.numberOfReaders(); - for ( int i = 0; i < numberOfServers; i++ ) - { - BoltServerAddress address = clusterView.nextReader(); - try - { - return connections.acquire( address ); - } - catch ( ServiceUnavailableException e ) - { - forget( address ); - } - } - - throw new SessionExpiredException( "Failed to connect to any read server" ); - } - - private Connection acquireWriteConnection() - { - int numberOfServers = clusterView.numberOfWriters(); - for ( int i = 0; i < numberOfServers; i++ ) - { - BoltServerAddress address = clusterView.nextWriter(); - try - { - return connections.acquire( address ); - } - catch ( ServiceUnavailableException e ) - { - forget( address ); - } - } - - throw new SessionExpiredException( "Failed to connect to any write server" ); - } - @Override public void close() { try { - connections.close(); + loadBalancer.close(); } catch ( Exception ex ) { @@ -327,28 +89,27 @@ public void close() } } - //For testing - public Set routingServers() + Set routingServers() { - return clusterView.routingServers(); + // TODO: the tests that use this should be testing for effect instead + throw new UnsupportedOperationException( "not implemented" ); } - //For testing - public Set readServers() + Set readServers() { - return clusterView.readServers(); + // TODO: the tests that use this should be testing for effect instead + throw new UnsupportedOperationException( "not implemented" ); } - //For testing - public Set writeServers() + Set writeServers() { - return clusterView.writeServers( ); + // TODO: the tests that use this should be testing for effect instead + throw new UnsupportedOperationException( "not implemented" ); } - //For testing - public ConnectionPool connectionPool() + ConnectionPool connectionPool() { - return connections; + // TODO: the tests that use this should be testing for effect instead, perhaps by injecting a pool delegate + throw new UnsupportedOperationException( "not implemented" ); } - } diff --git a/driver/src/main/java/org/neo4j/driver/internal/RoutingErrorHandler.java b/driver/src/main/java/org/neo4j/driver/internal/RoutingErrorHandler.java index ca178b46d1..944e1eddd4 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/RoutingErrorHandler.java +++ b/driver/src/main/java/org/neo4j/driver/internal/RoutingErrorHandler.java @@ -23,7 +23,7 @@ /** * Interface used for tracking errors when connected to a cluster. */ -interface RoutingErrorHandler +public interface RoutingErrorHandler { void onConnectionFailure( BoltServerAddress address ); diff --git a/driver/src/main/java/org/neo4j/driver/internal/cluster/ClusterComposition.java b/driver/src/main/java/org/neo4j/driver/internal/cluster/ClusterComposition.java new file mode 100644 index 0000000000..acfe0069bf --- /dev/null +++ b/driver/src/main/java/org/neo4j/driver/internal/cluster/ClusterComposition.java @@ -0,0 +1,203 @@ +/** + * Copyright (c) 2002-2016 "Neo Technology," + * Network Engine for Objects in Lund AB [http://neotechnology.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.internal.cluster; + +import java.util.HashSet; +import java.util.Set; + +import org.neo4j.driver.internal.NetworkSession; +import org.neo4j.driver.internal.net.BoltServerAddress; +import org.neo4j.driver.internal.spi.Connection; +import org.neo4j.driver.internal.util.Clock; +import org.neo4j.driver.v1.Record; +import org.neo4j.driver.v1.Statement; +import org.neo4j.driver.v1.StatementResult; +import org.neo4j.driver.v1.Value; +import org.neo4j.driver.v1.exceptions.ServiceUnavailableException; +import org.neo4j.driver.v1.exceptions.value.ValueException; +import org.neo4j.driver.v1.util.Function; + +final class ClusterComposition +{ + interface Provider + { + String GET_SERVERS = "CALL dbms.cluster.routing.getServers"; + + ClusterComposition getClusterComposition( Connection connection ) throws ServiceUnavailableException; + + final class Default implements Provider + { + private static final Statement GET_SERVER = new Statement( Provider.GET_SERVERS ); + private final Clock clock; + + Default( Clock clock ) + { + this.clock = clock; + } + + @Override + public ClusterComposition getClusterComposition( Connection connection ) throws ServiceUnavailableException + { + StatementResult cursor = getServers( connection ); + long now = clock.millis(); + try + { + if ( !cursor.hasNext() ) + { + return null; // server returned too few rows, this is a contract violation, treat as incapable + } + Record record = cursor.next(); + if ( cursor.hasNext() ) + { + return null; // server returned too many rows, this is a contract violation, treat as incapable + } + return read( record, now ); + } + finally + { + cursor.consume(); // make sure we exhaust the results + } + } + + private StatementResult getServers( Connection connection ) + { + return NetworkSession.run( connection, GET_SERVER ); + } + } + } + + private static final long MAX_TTL = Long.MAX_VALUE / 1000L; + private static final Function OF_BoltServerAddress = + new Function() + { + @Override + public BoltServerAddress apply( Value value ) + { + return new BoltServerAddress( value.asString() ); + } + }; + private final Set readers, writers, routers; + final long expirationTimestamp; + + private ClusterComposition( long expirationTimestamp ) + { + this.readers = new HashSet<>(); + this.writers = new HashSet<>(); + this.routers = new HashSet<>(); + this.expirationTimestamp = expirationTimestamp; + } + + /** For testing */ + ClusterComposition( + long expirationTimestamp, + Set readers, + Set writers, + Set routers ) + { + this( expirationTimestamp ); + this.readers.addAll( readers ); + this.writers.addAll( writers ); + this.routers.addAll( routers ); + } + + public boolean isValid() + { + return !routers.isEmpty() && !writers.isEmpty(); + } + + public Set readers() + { + return new HashSet<>( readers ); + } + + public Set writers() + { + return new HashSet<>( writers ); + } + + public Set routers() + { + return new HashSet<>( routers ); + } + + @Override + public String toString() + { + return "ClusterComposition{" + + "expirationTimestamp=" + expirationTimestamp + + ", readers=" + readers + + ", writers=" + writers + + ", routers=" + routers + + '}'; + } + + private static ClusterComposition read( Record record, long now ) + { + if ( record == null ) + { + return null; + } + try + { + final ClusterComposition result; + result = new ClusterComposition( expirationTimestamp( now, record ) ); + record.get( "servers" ).asList( new Function() + { + @Override + public Void apply( Value value ) + { + result.servers( value.get( "role" ).asString() ) + .addAll( value.get( "addresses" ).asList( OF_BoltServerAddress ) ); + return null; + } + } ); + return result; + } + catch ( ValueException e ) + { + return null; + } + } + + private static long expirationTimestamp( long now, Record record ) + { + long ttl = record.get( "ttl" ).asLong(); + long expirationTimestamp = now + ttl * 1000; + if ( ttl < 0 || ttl >= MAX_TTL || expirationTimestamp < 0 ) + { + expirationTimestamp = Long.MAX_VALUE; + } + return expirationTimestamp; + } + + private Set servers( String role ) + { + switch ( role ) + { + case "READ": + return readers; + case "WRITE": + return writers; + case "ROUTE": + return routers; + default: + throw new IllegalArgumentException( "invalid server role: " + role ); + } + } +} diff --git a/driver/src/main/java/org/neo4j/driver/internal/cluster/LoadBalancer.java b/driver/src/main/java/org/neo4j/driver/internal/cluster/LoadBalancer.java new file mode 100644 index 0000000000..ac2a83815a --- /dev/null +++ b/driver/src/main/java/org/neo4j/driver/internal/cluster/LoadBalancer.java @@ -0,0 +1,228 @@ +/** + * Copyright (c) 2002-2016 "Neo Technology," + * Network Engine for Objects in Lund AB [http://neotechnology.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.internal.cluster; + +import java.util.HashSet; + +import org.neo4j.driver.internal.RoutingErrorHandler; +import org.neo4j.driver.internal.net.BoltServerAddress; +import org.neo4j.driver.internal.spi.Connection; +import org.neo4j.driver.internal.spi.ConnectionPool; +import org.neo4j.driver.internal.util.Clock; +import org.neo4j.driver.v1.Logger; +import org.neo4j.driver.v1.exceptions.ServiceUnavailableException; + +import static java.util.Arrays.asList; + +public final class LoadBalancer implements RoutingErrorHandler, AutoCloseable +{ + private static final int MIN_ROUTERS = 1; + private static final String NO_ROUTERS_AVAILABLE = "Could not perform discovery. No routing servers available."; + // dependencies + private final RoutingSettings settings; + private final Clock clock; + private final Logger log; + private final ConnectionPool connections; + private final ClusterComposition.Provider provider; + // state + private long expirationTimeout; + private final RoundRobinAddressSet readers, writers, routers; + + public LoadBalancer( + RoutingSettings settings, + Clock clock, + Logger log, + ConnectionPool connections, + BoltServerAddress... routingAddresses ) throws ServiceUnavailableException + { + this( settings, clock, log, connections, new ClusterComposition.Provider.Default( clock ), routingAddresses ); + } + + LoadBalancer( + RoutingSettings settings, + Clock clock, + Logger log, + ConnectionPool connections, + ClusterComposition.Provider provider, + BoltServerAddress... routingAddresses ) throws ServiceUnavailableException + { + this.clock = clock; + this.log = log; + this.connections = connections; + this.expirationTimeout = clock.millis() - 1; + this.provider = provider; + this.settings = settings; + this.readers = new RoundRobinAddressSet(); + this.writers = new RoundRobinAddressSet(); + this.routers = new RoundRobinAddressSet(); + routers.update( new HashSet<>( asList( routingAddresses ) ), new HashSet() ); + // initialize the routing table + ensureRouting(); + } + + public Connection acquireReadConnection() throws ServiceUnavailableException + { + return acquireConnection( readers ); + } + + public Connection acquireWriteConnection() throws ServiceUnavailableException + { + return acquireConnection( writers ); + } + + @Override + public void onConnectionFailure( BoltServerAddress address ) + { + forget( address ); + } + + @Override + public void onWriteFailure( BoltServerAddress address ) + { + writers.remove( address ); + } + + @Override + public void close() throws Exception + { + connections.close(); + } + + private Connection acquireConnection( RoundRobinAddressSet servers ) throws ServiceUnavailableException + { + for ( ; ; ) + { + // refresh the routing table if needed + ensureRouting(); + for ( BoltServerAddress address; (address = servers.next()) != null; ) + { + try + { + return connections.acquire( address ); + } + catch ( ServiceUnavailableException e ) + { + forget( address ); + } + } + // if we get here, we failed to connect to any server, so we will rebuild the routing table + } + } + + private synchronized void ensureRouting() throws ServiceUnavailableException + { + if ( stale() ) + { + try + { + // get a new routing table + ClusterComposition cluster = lookupRoutingTable(); + expirationTimeout = cluster.expirationTimestamp; + HashSet removed = new HashSet<>(); + readers.update( cluster.readers(), removed ); + writers.update( cluster.writers(), removed ); + routers.update( cluster.routers(), removed ); + // purge connections to removed addresses + for ( BoltServerAddress address : removed ) + { + connections.purge( address ); + } + } + catch ( InterruptedException e ) + { + throw new ServiceUnavailableException( "Thread was interrupted while establishing connection.", e ); + } + } + } + + private ClusterComposition lookupRoutingTable() throws InterruptedException, ServiceUnavailableException + { + int size = routers.size(), failures = 0; + if ( size == 0 ) + { + throw new ServiceUnavailableException( NO_ROUTERS_AVAILABLE ); + } + for ( long start = clock.millis(), delay = 0; ; delay = Math.max( settings.retryTimeoutDelay, delay * 2 ) ) + { + long waitTime = start + delay - clock.millis(); + if ( waitTime > 0 ) + { + clock.sleep( waitTime ); + } + start = clock.millis(); + for ( int i = 0; i < size; i++ ) + { + BoltServerAddress address = routers.next(); + if ( address == null ) + { + throw new ServiceUnavailableException( NO_ROUTERS_AVAILABLE ); + } + ClusterComposition cluster; + try ( Connection connection = connections.acquire( address ) ) + { + cluster = provider.getClusterComposition( connection ); + } + catch ( Exception e ) + { + log.error( String.format( "Failed to connect to routing server '%s'.", address ), e ); + continue; + } + if ( cluster == null || !cluster.isValid() ) + { + log.info( + "Server <%s> unable to perform routing capability, dropping from list of routers.", + address ); + routers.remove( address ); + if ( --size == 0 ) + { + throw new ServiceUnavailableException( NO_ROUTERS_AVAILABLE ); + } + } + else + { + return cluster; + } + } + if ( ++failures > settings.maxRoutingFailures ) + { + throw new ServiceUnavailableException( NO_ROUTERS_AVAILABLE ); + } + } + } + + private synchronized void forget( BoltServerAddress address ) + { + // First remove from the load balancer, to prevent concurrent threads from making connections to them. + // Don't remove it from the set of routers, since that might mean we lose our ability to re-discover, + // just remove it from the set of readers and writers, so that we don't use it for actual work without + // performing discovery first. + readers.remove( address ); + writers.remove( address ); + // drop all current connections to the address + connections.purge( address ); + } + + private boolean stale() + { + return expirationTimeout < clock.millis() || // the expiration timeout has been reached + routers.size() <= MIN_ROUTERS || // we need to discover more routing servers + readers.size() == 0 || // we need to discover more read servers + writers.size() == 0; // we need to discover more write servers + } +} diff --git a/driver/src/main/java/org/neo4j/driver/internal/cluster/RoundRobinAddressSet.java b/driver/src/main/java/org/neo4j/driver/internal/cluster/RoundRobinAddressSet.java new file mode 100644 index 0000000000..acae91704f --- /dev/null +++ b/driver/src/main/java/org/neo4j/driver/internal/cluster/RoundRobinAddressSet.java @@ -0,0 +1,136 @@ +/** + * Copyright (c) 2002-2016 "Neo Technology," + * Network Engine for Objects in Lund AB [http://neotechnology.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.internal.cluster; + +import java.util.Set; +import java.util.concurrent.atomic.AtomicInteger; + +import org.neo4j.driver.internal.net.BoltServerAddress; + +class RoundRobinAddressSet +{ + private static final BoltServerAddress[] NONE = {}; + private final AtomicInteger offset = new AtomicInteger(); + private volatile BoltServerAddress[] addresses = NONE; + + public int size() + { + return addresses.length; + } + + public BoltServerAddress next() + { + BoltServerAddress[] addresses = this.addresses; + if ( addresses.length == 0 ) + { + return null; + } + return addresses[next( addresses.length )]; + } + + int next( int divisor ) + { + int index = offset.getAndIncrement(); + for ( ; index == Integer.MAX_VALUE; index = offset.getAndIncrement() ) + { + offset.compareAndSet( Integer.MIN_VALUE, index % divisor ); + } + return index % divisor; + } + + public synchronized void update( Set addresses, Set removed ) + { + BoltServerAddress[] prev = this.addresses; + if ( addresses.isEmpty() ) + { + this.addresses = NONE; + return; + } + if ( prev.length == 0 ) + { + this.addresses = addresses.toArray( NONE ); + return; + } + BoltServerAddress[] copy = null; + if ( addresses.size() != prev.length ) + { + copy = new BoltServerAddress[addresses.size()]; + } + int j = 0; + for ( int i = 0; i < prev.length; i++ ) + { + if ( addresses.remove( prev[i] ) ) + { + if ( copy != null ) + { + copy[j++] = prev[i]; + } + } + else + { + removed.add( prev[i] ); + if ( copy == null ) + { + copy = new BoltServerAddress[prev.length]; + System.arraycopy( prev, 0, copy, 0, i ); + j = i; + } + } + } + if ( copy == null ) + { + return; + } + for ( BoltServerAddress address : addresses ) + { + copy[j++] = address; + } + this.addresses = copy; + } + + public synchronized void remove( BoltServerAddress address ) + { + BoltServerAddress[] addresses = this.addresses; + if ( addresses != null ) + { + for ( int i = 0; i < addresses.length; i++ ) + { + if ( addresses[i].equals( address ) ) + { + if ( addresses.length == 1 ) + { + this.addresses = NONE; + return; + } + BoltServerAddress[] copy = new BoltServerAddress[addresses.length - 1]; + System.arraycopy( addresses, 0, copy, 0, i ); + System.arraycopy( addresses, i + 1, copy, i, addresses.length - i - 1 ); + this.addresses = copy; + return; + } + } + } + } + + /** breaking encapsulation in order to perform white-box testing of boundary case */ + void setOffset( int target ) + { + offset.set( target ); + } +} diff --git a/driver/src/main/java/org/neo4j/driver/internal/cluster/RoutingSettings.java b/driver/src/main/java/org/neo4j/driver/internal/cluster/RoutingSettings.java new file mode 100644 index 0000000000..0b7fe7957e --- /dev/null +++ b/driver/src/main/java/org/neo4j/driver/internal/cluster/RoutingSettings.java @@ -0,0 +1,31 @@ +/** + * Copyright (c) 2002-2016 "Neo Technology," + * Network Engine for Objects in Lund AB [http://neotechnology.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.internal.cluster; + +public class RoutingSettings +{ + final int maxRoutingFailures; + final long retryTimeoutDelay; + + public RoutingSettings( int maxRoutingFailures, long retryTimeoutDelay ) + { + this.maxRoutingFailures = maxRoutingFailures; + this.retryTimeoutDelay = retryTimeoutDelay; + } +} diff --git a/driver/src/main/java/org/neo4j/driver/internal/util/Clock.java b/driver/src/main/java/org/neo4j/driver/internal/util/Clock.java index 21b4b8497e..e5a5a4d932 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/util/Clock.java +++ b/driver/src/main/java/org/neo4j/driver/internal/util/Clock.java @@ -26,6 +26,8 @@ public interface Clock /** Current time, in milliseconds. */ long millis(); + void sleep( long millis ) throws InterruptedException; + Clock SYSTEM = new Clock() { @Override @@ -33,5 +35,11 @@ public long millis() { return System.currentTimeMillis(); } + + @Override + public void sleep( long millis ) throws InterruptedException + { + Thread.sleep( millis ); + } }; } diff --git a/driver/src/main/java/org/neo4j/driver/internal/util/ConcurrentRoundRobinSet.java b/driver/src/main/java/org/neo4j/driver/internal/util/ConcurrentRoundRobinSet.java deleted file mode 100644 index ffd369c248..0000000000 --- a/driver/src/main/java/org/neo4j/driver/internal/util/ConcurrentRoundRobinSet.java +++ /dev/null @@ -1,152 +0,0 @@ -/** - * Copyright (c) 2002-2016 "Neo Technology," - * Network Engine for Objects in Lund AB [http://neotechnology.com] - * - * This file is part of Neo4j. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.neo4j.driver.internal.util; - -import java.util.Collection; -import java.util.Comparator; -import java.util.Iterator; -import java.util.Set; -import java.util.concurrent.ConcurrentSkipListSet; - -/** - * A set that exposes a method {@link #hop()} that cycles through the members of the set. - * @param the type of elements in the set - */ -public class ConcurrentRoundRobinSet implements Set -{ - private final ConcurrentSkipListSet set; - private T current; - - public ConcurrentRoundRobinSet() - { - set = new ConcurrentSkipListSet<>(); - } - - public ConcurrentRoundRobinSet( Comparator comparator ) - { - set = new ConcurrentSkipListSet<>( comparator ); - } - - public ConcurrentRoundRobinSet(ConcurrentRoundRobinSet original) - { - set = new ConcurrentSkipListSet<>( original.set.comparator() ); - set.addAll( original ); - } - - public T hop() - { - if ( current == null ) - { - current = set.first(); - } - else - { - current = set.higher( current ); - //We've gone through all connections, start over - if ( current == null ) - { - current = set.first(); - } - } - - if ( current == null ) - { - throw new IllegalStateException( "nothing in the set" ); - } - - return current; - } - - @Override - public boolean add( T item ) - { - return set.add( item ); - } - - @Override - public boolean containsAll( Collection c ) - { - return set.containsAll( c ); - } - - @Override - public boolean addAll( Collection c ) - { - return set.addAll( c ); - } - - @Override - public boolean retainAll( Collection c ) - { - return set.retainAll( c ); - } - - @Override - public boolean removeAll( Collection c ) - { - return set.retainAll( c ); - } - - @Override - public void clear() - { - set.clear(); - } - - @Override - public boolean remove( Object o ) - { - return set.remove( o ); - } - - public int size() - { - return set.size(); - } - - public boolean isEmpty() - { - return set.isEmpty(); - } - - @Override - public boolean contains( Object o ) - { - return set.contains( o ); - } - - @Override - public Iterator iterator() - { - return set.iterator(); - } - - @Override - public Object[] toArray() - { - return set.toArray(); - } - - @SuppressWarnings( "SuspiciousToArrayCall" ) - @Override - public T1[] toArray( T1[] a ) - { - return set.toArray( a ); - } -} diff --git a/driver/src/main/java/org/neo4j/driver/v1/Config.java b/driver/src/main/java/org/neo4j/driver/v1/Config.java index 41f743b1dc..fc9514295c 100644 --- a/driver/src/main/java/org/neo4j/driver/v1/Config.java +++ b/driver/src/main/java/org/neo4j/driver/v1/Config.java @@ -19,8 +19,10 @@ package org.neo4j.driver.v1; import java.io.File; +import java.util.concurrent.TimeUnit; import java.util.logging.Level; +import org.neo4j.driver.internal.cluster.RoutingSettings; import org.neo4j.driver.internal.logging.JULogging; import org.neo4j.driver.internal.net.pooling.PoolSettings; import org.neo4j.driver.v1.util.Immutable; @@ -60,6 +62,8 @@ public class Config private final TrustStrategy trustStrategy; private final int minServersInCluster; + private final int maxRoutingFailures; + private final long routingRetryDelayMillis; private Config( ConfigBuilder builder) { @@ -71,6 +75,8 @@ private Config( ConfigBuilder builder) this.encryptionLevel = builder.encryptionLevel; this.trustStrategy = builder.trustStrategy; this.minServersInCluster = builder.minServersInCluster; + this.maxRoutingFailures = builder.maxRoutingFailures; + this.routingRetryDelayMillis = builder.routingRetryDelayMillis; } /** @@ -144,6 +150,11 @@ public static Config defaultConfig() return Config.build().toConfig(); } + RoutingSettings routingSettings() + { + return new RoutingSettings( maxRoutingFailures, routingRetryDelayMillis ); + } + /** * Used to build new config instances */ @@ -156,6 +167,8 @@ public static class ConfigBuilder private TrustStrategy trustStrategy = trustOnFirstUse( new File( getProperty( "user.home" ), ".neo4j" + File.separator + "known_hosts" ) ); public int minServersInCluster = 3; + private int maxRoutingFailures = 10; + private long routingRetryDelayMillis = 5_000; private ConfigBuilder() {} @@ -259,6 +272,51 @@ public ConfigBuilder withTrustStrategy( TrustStrategy trustStrategy ) return this; } + /** + * Specify how many times the client should attempt to reconnect to the routing servers before declaring the + * cluster unavailable. + *

+ * The routing servers are tried in order. If connecting any of them fails, they are all retried after + * {@linkplain #withRoutingRetryDelay a delay}. This process of retrying all servers is then repeated for the + * number of times specified here before considering the cluster unavailable. + * + * @param routingRetryLimit + * the number of times to retry each server in the list of routing servers + * @return this builder + */ + public ConfigBuilder withRoutingRetryLimit( int routingRetryLimit ) + { + this.maxRoutingFailures = routingRetryLimit; + return this; + } + + /** + * Specify how long to wait before retrying to connect to a routing server. + *

+ * When connecting to all routing servers fail, connecting will be retried after the delay specified here. + * The delay is measured from when the first attempt to connect was made, so that the delay time specifies a + * retry interval. + *

+ * For each {@linkplain #withRoutingRetryLimit retry attempt} the delay time will be doubled. The time + * specified here is the base time, i.e. the time to wait before the first retry. If that attempt (on all + * servers) also fails, the delay before the next retry will be double the time specified here, and the next + * attempt after that will be double that, et.c. So if, for example, the delay specified here is + * {@code 5 SECONDS}, then after attempting to connect to each server fails reconnecting will be attempted + * 5 seconds after the first connection attempt to the first server. If that attempt also fails to connect to + * all servers, the next attempt will start 10 seconds after the second attempt started. + * + * @param delay + * the amount of time between attempts to reconnect to the same server + * @param unit + * the unit in which the duration is given + * @return this builder + */ + public ConfigBuilder withRoutingRetryDelay( long delay, TimeUnit unit ) + { + this.routingRetryDelayMillis = unit.toMillis( delay ); + return this; + } + /** * Create a config instance from this builder. * @return a {@link Config} instance diff --git a/driver/src/main/java/org/neo4j/driver/v1/GraphDatabase.java b/driver/src/main/java/org/neo4j/driver/v1/GraphDatabase.java index a031815e5e..8e975b9e31 100644 --- a/driver/src/main/java/org/neo4j/driver/v1/GraphDatabase.java +++ b/driver/src/main/java/org/neo4j/driver/v1/GraphDatabase.java @@ -34,7 +34,6 @@ import org.neo4j.driver.internal.spi.ConnectionPool; import org.neo4j.driver.internal.util.Clock; import org.neo4j.driver.v1.exceptions.ClientException; -import org.neo4j.driver.v1.util.BiFunction; import org.neo4j.driver.v1.util.Function; import static java.lang.String.format; @@ -191,7 +190,13 @@ public static Driver driver( URI uri, AuthToken authToken, Config config ) case "bolt": return new DirectDriver( address, connectionPool, securityPlan, config.logging() ); case "bolt+routing": - return new RoutingDriver( address, connectionPool, securityPlan, SESSION_PROVIDER, Clock.SYSTEM, config.logging() ); + return new RoutingDriver( + config.routingSettings(), + address, + connectionPool, + securityPlan, + Clock.SYSTEM, + config.logging() ); default: throw new ClientException( format( "Unsupported URI scheme: %s", scheme ) ); } diff --git a/driver/src/test/java/org/neo4j/driver/internal/ClusterViewTest.java b/driver/src/test/java/org/neo4j/driver/internal/ClusterViewTest.java deleted file mode 100644 index 2f9b7fb778..0000000000 --- a/driver/src/test/java/org/neo4j/driver/internal/ClusterViewTest.java +++ /dev/null @@ -1,189 +0,0 @@ -/** - * Copyright (c) 2002-2016 "Neo Technology," - * Network Engine for Objects in Lund AB [http://neotechnology.com] - * - * This file is part of Neo4j. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.neo4j.driver.internal; - - -import org.junit.Test; - -import org.neo4j.driver.internal.net.BoltServerAddress; -import org.neo4j.driver.internal.util.Clock; -import org.neo4j.driver.v1.Logger; - -import static java.util.Arrays.asList; -import static java.util.Collections.singletonList; -import static org.hamcrest.CoreMatchers.equalTo; -import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.Matchers.containsInAnyOrder; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; - -public class ClusterViewTest -{ - - @Test - public void shouldRoundRobinAmongRoutingServers() - { - // Given - ClusterView clusterView = new ClusterView( 5L, mock( Clock.class ), mock( Logger.class ) ); - - // When - clusterView.addRouters( asList( address("host1"), address( "host2" ), address( "host3" ))); - - // Then - assertThat(clusterView.nextRouter(), equalTo(address( "host1" ))); - assertThat(clusterView.nextRouter(), equalTo(address( "host2" ))); - assertThat(clusterView.nextRouter(), equalTo(address( "host3" ))); - assertThat(clusterView.nextRouter(), equalTo(address( "host1" ))); - } - - @Test - public void shouldRoundRobinAmongReadServers() - { - // Given - ClusterView clusterView = new ClusterView( 5L, mock( Clock.class ), mock( Logger.class ) ); - - // When - clusterView.addReaders( asList( address("host1"), address( "host2" ), address( "host3" ))); - - // Then - assertThat(clusterView.nextReader(), equalTo(address( "host1" ))); - assertThat(clusterView.nextReader(), equalTo(address( "host2" ))); - assertThat(clusterView.nextReader(), equalTo(address( "host3" ))); - assertThat(clusterView.nextReader(), equalTo(address( "host1" ))); - } - - @Test - public void shouldRoundRobinAmongWriteServers() - { - // Given - ClusterView clusterView = new ClusterView( 5L, mock( Clock.class ), mock( Logger.class ) ); - - // When - clusterView.addWriters( asList( address("host1"), address( "host2" ), address( "host3" ))); - - // Then - assertThat(clusterView.nextWriter(), equalTo(address( "host1" ))); - assertThat(clusterView.nextWriter(), equalTo(address( "host2" ))); - assertThat(clusterView.nextWriter(), equalTo(address( "host3" ))); - assertThat(clusterView.nextWriter(), equalTo(address( "host1" ))); - } - - @Test - public void shouldRemoveServer() - { - // Given - ClusterView clusterView = new ClusterView( 5L, mock( Clock.class ), mock( Logger.class ) ); - - clusterView.addRouters( asList( address("host1"), address( "host2" ), address( "host3" ))); - clusterView.addReaders( asList( address("host1"), address( "host2" ), address( "host3" ))); - clusterView.addWriters( asList( address("host2"), address( "host4" ))); - - // When - clusterView.remove( address( "host2" ) ); - - // Then - assertThat(clusterView.routingServers(), containsInAnyOrder(address( "host1" ), address( "host3" ))); - assertThat(clusterView.readServers(), containsInAnyOrder(address( "host1" ), address( "host3" ))); - assertThat(clusterView.writeServers(), containsInAnyOrder(address( "host4" ))); - assertThat(clusterView.all(), containsInAnyOrder( address( "host1" ), address( "host3" ), address( "host4" ) )); - } - - @Test - public void shouldBeStaleIfExpired() - { - // Given - Clock clock = mock( Clock.class ); - when(clock.millis()).thenReturn( 6L ); - ClusterView clusterView = new ClusterView( 5L, clock, mock( Logger.class ) ); - clusterView.addRouters( asList( address("host1"), address( "host2" ), address( "host3" ))); - clusterView.addReaders( asList( address("host1"), address( "host2" ), address( "host3" ))); - clusterView.addWriters( asList( address("host2"), address( "host4" ))); - - // Then - assertTrue(clusterView.isStale()); - } - - @Test - public void shouldNotBeStaleIfNotExpired() - { - // Given - Clock clock = mock( Clock.class ); - when(clock.millis()).thenReturn( 4L ); - ClusterView clusterView = new ClusterView( 5L, clock, mock( Logger.class ) ); - clusterView.addRouters( asList( address("host1"), address( "host2" ), address( "host3" ))); - clusterView.addReaders( asList( address("host1"), address( "host2" ), address( "host3" ))); - clusterView.addWriters( asList( address("host2"), address( "host4" ))); - - // Then - assertFalse(clusterView.isStale()); - } - - @Test - public void shouldBeStaleIfOnlyOneRouter() - { - // Given - Clock clock = mock( Clock.class ); - when(clock.millis()).thenReturn( 4L ); - ClusterView clusterView = new ClusterView( 5L, clock, mock( Logger.class ) ); - clusterView.addRouters( singletonList( address( "host1" ) ) ); - clusterView.addReaders( asList( address("host1"), address( "host2" ), address( "host3" ))); - clusterView.addWriters( asList( address("host2"), address( "host4" ))); - - // When - - // Then - assertTrue(clusterView.isStale()); - } - - @Test - public void shouldBeStaleIfNoReader() - { - // Given - Clock clock = mock( Clock.class ); - when(clock.millis()).thenReturn( 4L ); - ClusterView clusterView = new ClusterView( 5L, clock, mock( Logger.class ) ); - clusterView.addRouters( singletonList( address( "host1" ) ) ); - clusterView.addWriters( asList( address("host2"), address( "host4" ))); - - // Then - assertTrue(clusterView.isStale()); - } - - @Test - public void shouldBeStaleIfNoWriter() - { - // Given - Clock clock = mock( Clock.class ); - when(clock.millis()).thenReturn( 4L ); - ClusterView clusterView = new ClusterView( 5L, clock, mock( Logger.class ) ); - clusterView.addRouters( singletonList( address( "host1" ) ) ); - clusterView.addReaders( asList( address("host1"), address( "host2" ), address( "host3" ))); - - // Then - assertTrue(clusterView.isStale()); - } - - private BoltServerAddress address(String host) - { - return new BoltServerAddress( host ); - } - -} \ No newline at end of file diff --git a/driver/src/test/java/org/neo4j/driver/internal/Event.java b/driver/src/test/java/org/neo4j/driver/internal/Event.java new file mode 100644 index 0000000000..bbf8fbaa1f --- /dev/null +++ b/driver/src/test/java/org/neo4j/driver/internal/Event.java @@ -0,0 +1,64 @@ +/** + * Copyright (c) 2002-2016 "Neo Technology," + * Network Engine for Objects in Lund AB [http://neotechnology.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.internal; + +import java.io.PrintWriter; +import java.io.StringWriter; +import java.lang.reflect.Method; + +public abstract class Event +{ + final Class handlerType; + + @SuppressWarnings( "unchecked" ) + public Event() + { + this.handlerType = (Class) handlerType( getClass() ); + } + + @Override + public String toString() + { + StringWriter res = new StringWriter(); + try ( PrintWriter out = new PrintWriter( res ) ) + { + EventHandler.write( this, out ); + } + return res.toString(); + } + + public abstract void dispatch( Handler handler ); + + private static Class handlerType( Class type ) + { + for ( Class c = type; c != Object.class; c = c.getSuperclass() ) + { + for ( Method method : c.getDeclaredMethods() ) + { + if ( method.getName().equals( "dispatch" ) + && method.getParameterTypes().length == 1 + && !method.isSynthetic() ) + { + return method.getParameterTypes()[0]; + } + } + } + throw new Error( "Cannot determine Handler type from dispatch(Handler) method." ); + } +} diff --git a/driver/src/test/java/org/neo4j/driver/internal/EventHandler.java b/driver/src/test/java/org/neo4j/driver/internal/EventHandler.java new file mode 100644 index 0000000000..bb9b87250d --- /dev/null +++ b/driver/src/test/java/org/neo4j/driver/internal/EventHandler.java @@ -0,0 +1,285 @@ +/** + * Copyright (c) 2002-2016 "Neo Technology," + * Network Engine for Objects in Lund AB [http://neotechnology.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.internal; + +import java.io.PrintStream; +import java.io.PrintWriter; +import java.lang.reflect.Constructor; +import java.lang.reflect.InvocationHandler; +import java.lang.reflect.Method; +import java.lang.reflect.Proxy; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Objects; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.CopyOnWriteArrayList; + +import org.hamcrest.Matcher; + +import org.neo4j.driver.internal.util.MatcherFactory; +import org.neo4j.driver.v1.EventLogger; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.any; +import static org.hamcrest.Matchers.not; +import static org.neo4j.driver.internal.util.MatcherFactory.containsAtLeast; +import static org.neo4j.driver.internal.util.MatcherFactory.count; + +public final class EventHandler +{ + private final List events = new ArrayList<>(); + private final ConcurrentMap,List> handlers = new ConcurrentHashMap<>(); + + public void add( Event event ) + { + synchronized ( events ) + { + events.add( event ); + } + List handlers = this.handlers.get( event.handlerType ); + if ( handlers != null ) + { + for ( Object handler : handlers ) + { + try + { + dispatch( event, handler ); + } + catch ( Exception e ) + { + System.err.println( "Failed to dispatch event: " + event + " to handler: " + handler ); + e.printStackTrace( System.err ); + } + } + } + } + + @SuppressWarnings( "unchecked" ) + public void registerHandler( Class type, Handler handler ) + { + List handlers = this.handlers.get( type ); + if ( handlers == null ) + { + List candidate = new CopyOnWriteArrayList<>(); + handlers = this.handlers.putIfAbsent( type, candidate ); + if ( handlers == null ) + { + handlers = candidate; + } + } + handlers.add( handler ); + } + + @SafeVarargs + public final void assertContains( MatcherFactory... matchers ) + { + synchronized ( events ) + { + assertThat( events, containsAtLeast( (MatcherFactory[]) matchers ) ); + } + } + + @SafeVarargs + public final void assertContains( Matcher... matchers ) + { + synchronized ( events ) + { + assertThat( events, containsAtLeast( (Matcher[]) matchers ) ); + } + } + + public final void assertCount( Matcher matcher, Matcher count ) + { + synchronized ( events ) + { + assertThat( events, (Matcher) count( matcher, count ) ); + } + } + + public void assertNone( Matcher matcher ) + { + synchronized ( events ) + { + assertThat( events, not( containsAtLeast( (Matcher) matcher ) ) ); + } + } + + public void printEvents( PrintStream out ) + { + printEvents( any( Event.class ), out ); + } + + public void printEvents( Matcher matcher, PrintStream out ) + { + try ( PrintWriter writer = new PrintWriter( out ) ) + { + printEvents( matcher, writer ); + } + } + + public void printEvents( PrintWriter out ) + { + printEvents( any( Event.class ), out ); + } + + public void printEvents( Matcher matcher, PrintWriter out ) + { + synchronized ( events ) + { + for ( Event event : events ) + { + if ( matcher.matches( event ) ) + { + write( event, out ); + out.println(); + } + } + } + } + + public void forEach( Object handler ) + { + synchronized ( events ) + { + for ( Event event : events ) + { + if ( event.handlerType.isInstance( handler ) ) + { + dispatch( event, handler ); + } + } + } + } + + private static void dispatch( Event event, Object handler ) + { + event.dispatch( event.handlerType.cast( handler ) ); + } + + static void write( Event event, PrintWriter out ) + { + dispatch( event, proxy( event.handlerType, new WriteHandler( out ) ) ); + } + + private static Handler proxy( Class handlerType, InvocationHandler handler ) + { + if ( handlerType.isInstance( handler ) ) + { + return handlerType.cast( handler ); + } + try + { + return handlerType.cast( proxies.get( handlerType ).newInstance( handler ) ); + } + catch ( RuntimeException e ) + { + throw e; + } + catch ( Exception e ) + { + throw new RuntimeException( e ); + } + } + + private static final ClassValue proxies = new ClassValue() + { + @Override + protected Constructor computeValue( Class type ) + { + Class proxy = Proxy.getProxyClass( type.getClassLoader(), type ); + try + { + return proxy.getConstructor( InvocationHandler.class ); + } + catch ( NoSuchMethodException e ) + { + throw new RuntimeException( e ); + } + } + }; + + private static class WriteHandler implements InvocationHandler, EventLogger.Sink + { + private final PrintWriter out; + + WriteHandler( PrintWriter out ) + { + this.out = out; + } + + @Override + public Object invoke( Object proxy, Method method, Object[] args ) throws Throwable + { + out.append( method.getName() ).append( '(' ); + String sep = " "; + for ( Object arg : args ) + { + out.append( sep ); + if ( arg == null || !arg.getClass().isArray() ) + { + out.append( Objects.toString( arg ) ); + } + else if ( arg instanceof Object[] ) + { + out.append( Arrays.toString( (Object[]) arg ) ); + } + else + { + out.append( Arrays.class.getMethod( "toString", arg.getClass() ).invoke( null, arg ).toString() ); + } + sep = ", "; + } + if ( args.length > 0 ) + { + out.append( ' ' ); + } + out.append( ')' ); + return null; + } + + @Override + public void log( String name, EventLogger.Level level, Throwable cause, String message, Object... params ) + { + out.append( level.name() ).append( ": " ); + if ( params != null ) + { + try + { + out.format( message, params ); + } + catch ( Exception e ) + { + out.format( "InvalidFormat(message=\"%s\", params=%s, failure=%s)", + message, Arrays.toString( params ), e ); + } + out.println(); + } + else + { + out.println( message ); + } + if ( cause != null ) + { + cause.printStackTrace( out ); + } + } + } +} diff --git a/driver/src/test/java/org/neo4j/driver/internal/RoutingDriverTest.java b/driver/src/test/java/org/neo4j/driver/internal/RoutingDriverTest.java index f551a8dbfd..61b81a6d4a 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/RoutingDriverTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/RoutingDriverTest.java @@ -18,47 +18,41 @@ */ package org.neo4j.driver.internal; -import org.hamcrest.Matchers; +import java.util.Collections; +import java.util.Map; + import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; +import org.mockito.internal.stubbing.answers.ThrowsException; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - +import org.neo4j.driver.internal.cluster.RoutingSettings; import org.neo4j.driver.internal.net.BoltServerAddress; +import org.neo4j.driver.internal.spi.Collector; import org.neo4j.driver.internal.spi.Connection; import org.neo4j.driver.internal.spi.ConnectionPool; -import org.neo4j.driver.internal.util.Clock; +import org.neo4j.driver.internal.util.FakeClock; import org.neo4j.driver.v1.AccessMode; -import org.neo4j.driver.v1.Logger; +import org.neo4j.driver.v1.EventLogger; import org.neo4j.driver.v1.Logging; -import org.neo4j.driver.v1.Record; -import org.neo4j.driver.v1.Session; -import org.neo4j.driver.v1.StatementResult; import org.neo4j.driver.v1.Value; import org.neo4j.driver.v1.exceptions.ClientException; -import org.neo4j.driver.v1.exceptions.NoSuchRecordException; import org.neo4j.driver.v1.exceptions.ServiceUnavailableException; -import org.neo4j.driver.v1.summary.ResultSummary; -import org.neo4j.driver.v1.util.Function; import static java.util.Arrays.asList; -import static java.util.Collections.singletonList; -import static org.hamcrest.CoreMatchers.equalTo; -import static org.hamcrest.CoreMatchers.hasItem; -import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.Matchers.containsInAnyOrder; -import static org.hamcrest.core.IsNot.not; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; import static org.mockito.Matchers.any; -import static org.mockito.Matchers.anyString; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.neo4j.driver.internal.cluster.ClusterCompositionProviderTest.serverInfo; +import static org.neo4j.driver.internal.cluster.ClusterCompositionProviderTest.withKeys; +import static org.neo4j.driver.internal.cluster.ClusterCompositionProviderTest.withServerList; import static org.neo4j.driver.internal.security.SecurityPlan.insecure; import static org.neo4j.driver.v1.Values.value; @@ -66,359 +60,304 @@ public class RoutingDriverTest { @Rule public ExpectedException exception = ExpectedException.none(); - private static final BoltServerAddress SEED = new BoltServerAddress( "localhost", 7687 ); private static final String GET_SERVERS = "CALL dbms.cluster.routing.getServers"; - private static final List NO_ADDRESSES = Collections.emptyList(); - private final ConnectionPool pool = pool(); + private final EventHandler events = new EventHandler(); + private final FakeClock clock = new FakeClock( events, true ); + private final Logging logging = EventLogger.provider( events, EventLogger.Level.TRACE ); @Test public void shouldDoRoutingOnInitialization() { // Given - final Session session = mock( Session.class ); - when( session.run( GET_SERVERS ) ).thenReturn( - getServers( singletonList( "localhost:1111" ), - singletonList( "localhost:2222" ), - singletonList( "localhost:3333" ) ) ); + ConnectionPool pool = poolWithServers( + 10, + serverInfo( "ROUTE", "localhost:1111" ), + serverInfo( "READ", "localhost:2222" ), + serverInfo( "WRITE", "localhost:3333" ) ); // When - RoutingDriver routingDriver = forSession( session ); + driverWithPool( pool ); // Then - assertThat( routingDriver.routingServers(), - containsInAnyOrder( boltAddress( "localhost", 1111 )) ); - assertThat( routingDriver.readServers(), - containsInAnyOrder( boltAddress( "localhost", 2222 ) ) ); - assertThat( routingDriver.writeServers(), - containsInAnyOrder( boltAddress( "localhost", 3333 ) ) ); - + verify( pool ).acquire( SEED ); } @Test public void shouldDoReRoutingOnSessionAcquisitionIfNecessary() { // Given - final Session session = mock( Session.class ); - when( session.run( GET_SERVERS ) ) - .thenReturn( - getServers( singletonList( "localhost:1111" ), NO_ADDRESSES, NO_ADDRESSES ) ) - .thenReturn( - getServers( singletonList( "localhost:1112" ), - singletonList( "localhost:2222" ), - singletonList( "localhost:3333" ) ) ); - - RoutingDriver routingDriver = forSession( session ); - - assertThat( routingDriver.routingServers(), - containsInAnyOrder( boltAddress( "localhost", 1111 )) ); - assertThat( routingDriver.readServers(), Matchers.empty() ); - assertThat( routingDriver.writeServers(), Matchers.empty() ); - + RoutingDriver routingDriver = driverWithPool( pool( + withServers( 10, serverInfo( "ROUTE", "localhost:1111" ), + serverInfo( "READ" ), + serverInfo( "WRITE", "localhost:5555" ) ), + withServers( 10, serverInfo( "ROUTE", "localhost:1112" ), + serverInfo( "READ", "localhost:2222" ), + serverInfo( "WRITE", "localhost:3333" ) ) ) ); // When - routingDriver.session( AccessMode.READ ); + RoutingNetworkSession writing = (RoutingNetworkSession) routingDriver.session( AccessMode.WRITE ); // Then - assertThat( routingDriver.routingServers(), - containsInAnyOrder( boltAddress( "localhost", 1112 ) )); - assertThat( routingDriver.readServers(), - containsInAnyOrder( boltAddress( "localhost", 2222 ) ) ); - assertThat( routingDriver.writeServers(), - containsInAnyOrder( boltAddress( "localhost", 3333 ) ) ); + assertEquals( boltAddress( "localhost", 3333 ), writing.address() ); } @Test public void shouldNotDoReRoutingOnSessionAcquisitionIfNotNecessary() { // Given - final Session session = mock( Session.class ); - when( session.run( GET_SERVERS ) ) - .thenReturn( - getServers( asList( "localhost:1111", "localhost:1112", "localhost:1113" ), - singletonList( "localhost:2222" ), - singletonList( "localhost:3333" ) ) ) - .thenReturn( - getServers( singletonList( "localhost:5555" ), NO_ADDRESSES, NO_ADDRESSES ) ); - - RoutingDriver routingDriver = forSession( session ); + RoutingDriver routingDriver = driverWithPool( pool( + withServers( 10, serverInfo( "ROUTE", "localhost:1111", "localhost:1112", "localhost:1113" ), + serverInfo( "READ", "localhost:2222" ), + serverInfo( "WRITE", "localhost:3333" ) ), + withServers( 10, serverInfo( "ROUTE", "localhost:5555" ), + serverInfo( "READ", "localhost:5555" ), + serverInfo( "WRITE", "localhost:5555" ) ) ) ); // When - routingDriver.session( AccessMode.WRITE ); + RoutingNetworkSession writing = (RoutingNetworkSession) routingDriver.session( AccessMode.WRITE ); + RoutingNetworkSession reading = (RoutingNetworkSession) routingDriver.session( AccessMode.READ ); // Then - assertThat( routingDriver.routingServers(), - not( hasItem( boltAddress( "localhost", 5555 ) ) ) ); + assertEquals( boltAddress( "localhost", 3333 ), writing.address() ); + assertEquals( boltAddress( "localhost", 2222 ), reading.address() ); } @Test public void shouldFailIfNoRouting() { // Given - final Session session = mock( Session.class ); - when( session.run( GET_SERVERS ) ) - .thenThrow( - new ClientException( "Neo.ClientError.Procedure.ProcedureNotFound", "Procedure not found" ) ); + ConnectionPool pool = pool( new ThrowsException( new ClientException( + "Neo.ClientError.Procedure.ProcedureNotFound", "Procedure not found" ) ) ); - // Expect - exception.expect( ServiceUnavailableException.class ); + // When + try + { + driverWithPool( pool ); + } + // Then + catch ( ServiceUnavailableException e ) + { + assertEquals( "Could not perform discovery. No routing servers available.", e.getMessage() ); + } + } + + @Test + public void shouldFailIfNoRoutersProvided() + { + // Given + ConnectionPool pool = poolWithServers( + 10, + serverInfo( "ROUTE" ), + serverInfo( "READ", "localhost:1111" ), + serverInfo( "WRITE", "localhost:1111" ) ); // When - forSession( session ); + try + { + driverWithPool( pool ); + } + // Then + catch ( ServiceUnavailableException e ) + { + assertEquals( "Could not perform discovery. No routing servers available.", e.getMessage() ); + } } @Test - public void shouldForgetAboutServersOnRerouting() + public void shouldFailIfNoWritersProvided() { // Given - final Session session = mock( Session.class ); - when( session.run( GET_SERVERS ) ) - .thenReturn( - getServers( singletonList( "localhost:1111" ), NO_ADDRESSES, NO_ADDRESSES ) ) - .thenReturn( - getServers( singletonList( "localhost:1112" ), - singletonList( "localhost:2222" ), - singletonList( "localhost:3333" ) ) ); + ConnectionPool pool = poolWithServers( + 10, + serverInfo( "ROUTE", "localhost:1111" ), + serverInfo( "READ", "localhost:1111" ), + serverInfo( "WRITE" ) ); - RoutingDriver routingDriver = forSession( session ); + // When + try + { + driverWithPool( pool ); + } + // Then + catch ( ServiceUnavailableException e ) + { + assertEquals( "Could not perform discovery. No routing servers available.", e.getMessage() ); + } + } - assertThat( routingDriver.routingServers(), - containsInAnyOrder( boltAddress( "localhost", 1111 )) ); + @Test + public void shouldForgetAboutServersOnRerouting() + { + // Given + ConnectionPool pool = pool( + withServers( 10, serverInfo( "ROUTE", "localhost:1111" ), + serverInfo( "READ" ), + serverInfo( "WRITE", "localhost:5555" ) ), + withServers( 10, serverInfo( "ROUTE", "localhost:1112" ), + serverInfo( "READ", "localhost:2222" ), + serverInfo( "WRITE", "localhost:3333" ) ) ); + RoutingDriver routingDriver = driverWithPool( pool ); // When - routingDriver.session( AccessMode.READ ); + RoutingNetworkSession write1 = (RoutingNetworkSession) routingDriver.session( AccessMode.WRITE ); + RoutingNetworkSession write2 = (RoutingNetworkSession) routingDriver.session( AccessMode.WRITE ); // Then - assertThat( routingDriver.routingServers(), - containsInAnyOrder( boltAddress( "localhost", 1112 ) )); - verify( pool ).purge( boltAddress( "localhost", 1111 ) ); + assertEquals( boltAddress( "localhost", 3333 ), write1.address() ); + assertEquals( boltAddress( "localhost", 3333 ), write2.address() ); } @Test public void shouldRediscoverOnTimeout() { // Given - final Session session = mock( Session.class ); - Clock clock = mock( Clock.class ); - when(clock.millis()).thenReturn( 0L, 11000L, 22000L ); - when( session.run( GET_SERVERS ) ) - .thenReturn( - getServers( asList( "localhost:1111", "localhost:1112", "localhost:1113" ), - singletonList( "localhost:2222" ), - singletonList( "localhost:3333" ), 10L/*seconds*/ ) ) - .thenReturn( - getServers( singletonList( "localhost:5555" ), singletonList( "localhost:5555" ), singletonList( "localhost:5555" ) ) ); - - RoutingDriver routingDriver = forSession( session, clock ); + RoutingDriver routingDriver = driverWithPool( pool( + withServers( 10, serverInfo( "ROUTE", "localhost:1111", "localhost:1112", "localhost:1113" ), + serverInfo( "READ", "localhost:2222" ), + serverInfo( "WRITE", "localhost:3333" ) ), + withServers( 60, serverInfo( "ROUTE", "localhost:5555", "localhost:6666" ), + serverInfo( "READ", "localhost:7777" ), + serverInfo( "WRITE", "localhost:8888" ) ) ) ); + + clock.progress( 11_000 ); // When - routingDriver.session( AccessMode.WRITE ); + RoutingNetworkSession writing = (RoutingNetworkSession) routingDriver.session( AccessMode.WRITE ); + RoutingNetworkSession reading = (RoutingNetworkSession) routingDriver.session( AccessMode.READ ); // Then - assertThat( routingDriver.routingServers(), containsInAnyOrder( boltAddress( "localhost", 5555 ) ) ); - assertThat( routingDriver.readServers(), containsInAnyOrder( boltAddress( "localhost", 5555 ) ) ); - assertThat( routingDriver.writeServers(), containsInAnyOrder( boltAddress( "localhost", 5555 ) ) ); + assertEquals( boltAddress( "localhost", 8888 ), writing.address() ); + assertEquals( boltAddress( "localhost", 7777 ), reading.address() ); } @Test - public void shouldNotRediscoverWheNoTimeout() + public void shouldNotRediscoverWhenNoTimeout() { // Given - final Session session = mock( Session.class ); - Clock clock = mock( Clock.class ); - when(clock.millis()).thenReturn( 0L, 9900L, 18800L ); - when( session.run( GET_SERVERS ) ) - .thenReturn( - getServers( asList( "localhost:1111", "localhost:1112", "localhost:1113" ), - singletonList( "localhost:2222" ), - singletonList( "localhost:3333" ), 10L/*seconds*/ ) ) - .thenReturn( - getServers( singletonList( "localhost:5555" ), singletonList( "localhost:5555" ), singletonList( "localhost:5555" ) ) ); - - RoutingDriver routingDriver = forSession( session, clock ); + RoutingDriver routingDriver = driverWithPool( pool( + withServers( 10, serverInfo( "ROUTE", "localhost:1111", "localhost:1112", "localhost:1113" ), + serverInfo( "READ", "localhost:2222" ), + serverInfo( "WRITE", "localhost:3333" ) ), + withServers( 10, serverInfo( "ROUTE", "localhost:5555" ), + serverInfo( "READ", "localhost:5555" ), + serverInfo( "WRITE", "localhost:5555" ) ) ) ); + clock.progress( 9900 ); // When - routingDriver.session( AccessMode.WRITE ); + RoutingNetworkSession writer = (RoutingNetworkSession) routingDriver.session( AccessMode.WRITE ); + RoutingNetworkSession reader = (RoutingNetworkSession) routingDriver.session( AccessMode.READ ); // Then - assertThat( routingDriver.routingServers(), containsInAnyOrder( boltAddress( "localhost", 1111 ), boltAddress( "localhost", 1112 ), boltAddress( "localhost", 1113 ) ) ); - assertThat( routingDriver.readServers(), containsInAnyOrder( boltAddress( "localhost", 2222 ) ) ); - assertThat( routingDriver.writeServers(), containsInAnyOrder( boltAddress( "localhost", 3333 ) ) ); + assertEquals( boltAddress( "localhost", 2222 ), reader.address() ); + assertEquals( boltAddress( "localhost", 3333 ), writer.address() ); } @Test public void shouldRoundRobinAmongReadServers() { // Given - final Session session = mock( Session.class ); - when( session.run( GET_SERVERS ) ).thenReturn( - getServers( asList( "localhost:1111", "localhost:1112" ), - asList( "localhost:2222", "localhost:2223", "localhost:2224" ), - singletonList( "localhost:3333" ) ) ); + RoutingDriver routingDriver = driverWithServers( 60, serverInfo( "ROUTE", "localhost:1111", "localhost:1112" ), + serverInfo( "READ", "localhost:2222", "localhost:2223", "localhost:2224" ), + serverInfo( "WRITE", "localhost:3333" ) ); // When - RoutingDriver routingDriver = forSession( session ); RoutingNetworkSession read1 = (RoutingNetworkSession) routingDriver.session( AccessMode.READ ); RoutingNetworkSession read2 = (RoutingNetworkSession) routingDriver.session( AccessMode.READ ); RoutingNetworkSession read3 = (RoutingNetworkSession) routingDriver.session( AccessMode.READ ); RoutingNetworkSession read4 = (RoutingNetworkSession) routingDriver.session( AccessMode.READ ); - + RoutingNetworkSession read5 = (RoutingNetworkSession) routingDriver.session( AccessMode.READ ); + RoutingNetworkSession read6 = (RoutingNetworkSession) routingDriver.session( AccessMode.READ ); // Then - assertThat(read1.address(), equalTo(boltAddress( "localhost", 2222 ))); - assertThat(read2.address(), equalTo(boltAddress( "localhost", 2223 ))); - assertThat(read3.address(), equalTo(boltAddress( "localhost", 2224 ))); - assertThat(read4.address(), equalTo(boltAddress( "localhost", 2222 ))); - + assertEquals( read1.address(), read4.address() ); + assertEquals( read2.address(), read5.address() ); + assertEquals( read3.address(), read6.address() ); + assertNotEquals( read1.address(), read2.address() ); + assertNotEquals( read2.address(), read3.address() ); + assertNotEquals( read3.address(), read1.address() ); } @Test public void shouldRoundRobinAmongWriteServers() { // Given - final Session session = mock( Session.class ); - when( session.run( GET_SERVERS ) ).thenReturn( - getServers( asList( "localhost:1111", "localhost:1112" ), - singletonList( "localhost:3333" ), asList( "localhost:2222", "localhost:2223", "localhost:2224" ) ) ); + RoutingDriver routingDriver = driverWithServers( 60, serverInfo( "ROUTE", "localhost:1111", "localhost:1112" ), + serverInfo( "READ", "localhost:3333" ), + serverInfo( "WRITE", "localhost:2222", "localhost:2223", "localhost:2224" ) ); // When - RoutingDriver routingDriver = forSession( session ); RoutingNetworkSession write1 = (RoutingNetworkSession) routingDriver.session( AccessMode.WRITE ); RoutingNetworkSession write2 = (RoutingNetworkSession) routingDriver.session( AccessMode.WRITE ); RoutingNetworkSession write3 = (RoutingNetworkSession) routingDriver.session( AccessMode.WRITE ); RoutingNetworkSession write4 = (RoutingNetworkSession) routingDriver.session( AccessMode.WRITE ); - + RoutingNetworkSession write5 = (RoutingNetworkSession) routingDriver.session( AccessMode.WRITE ); + RoutingNetworkSession write6 = (RoutingNetworkSession) routingDriver.session( AccessMode.WRITE ); // Then - assertThat(write1.address(), equalTo(boltAddress( "localhost", 2222 ))); - assertThat(write2.address(), equalTo(boltAddress( "localhost", 2223 ))); - assertThat(write3.address(), equalTo(boltAddress( "localhost", 2224 ))); - assertThat(write4.address(), equalTo(boltAddress( "localhost", 2222 ))); - + assertEquals( write1.address(), write4.address() ); + assertEquals( write2.address(), write5.address() ); + assertEquals( write3.address(), write6.address() ); + assertNotEquals( write1.address(), write2.address() ); + assertNotEquals( write2.address(), write3.address() ); + assertNotEquals( write3.address(), write1.address() ); } - private RoutingDriver forSession( final Session session ) + @SafeVarargs + private final RoutingDriver driverWithServers( long ttl, Map... serverInfo ) { - return forSession( session, Clock.SYSTEM ); - } - private RoutingDriver forSession( final Session session, Clock clock ) - { - return new RoutingDriver( SEED, pool, insecure(), - new Function() - { - @Override - public Session apply( Connection connection ) - { - return session; - } - }, clock, logging() ); + return driverWithPool( poolWithServers( ttl, serverInfo ) ); } - private BoltServerAddress boltAddress( String host, int port ) + private RoutingDriver driverWithPool( ConnectionPool pool ) { - return new BoltServerAddress( host, port ); + return new RoutingDriver( new RoutingSettings( 10, 5_000 ), SEED, pool, insecure(), clock, logging ); } - - StatementResult getServers( final List routers, final List readers, - final List writers ) + @SafeVarargs + private final ConnectionPool poolWithServers( long ttl, Map... serverInfo ) { - return getServers( routers,readers, writers, Long.MAX_VALUE ); + return pool( withServers( ttl, serverInfo ) ); } - StatementResult getServers( final List routers, final List readers, - final List writers, final long ttl ) + @SafeVarargs + private static Answer withServers( long ttl, Map... serverInfo ) { - return new StatementResult() - { - private int counter = 0; - - @Override - public List keys() - { - return asList( "ttl", "servers" ); - } - - @Override - public boolean hasNext() - { - return counter < 1; - } - - @Override - public Record next() - { - counter++; - return new InternalRecord( asList( "ttl", "servers" ), - new Value[]{ - value( ttl ), - value( asList( serverInfo( "ROUTE", routers ), serverInfo( "WRITE", writers ), - serverInfo( "READ", readers ) ) ) - } ); - } - - @Override - public Record single() throws NoSuchRecordException - { - return next(); - } - - @Override - public Record peek() - { - return null; - } - - @Override - public List list() - { - return null; - } - - @Override - public List list( Function mapFunction ) - { - return null; - } - - @Override - public ResultSummary consume() - { - return null; - } - - @Override - public void remove() - { - throw new UnsupportedOperationException(); - } - }; + return withServerList( new Value[] {value( ttl ), value( asList( serverInfo ) )} ); } - private Map serverInfo( String role, List addresses ) + private BoltServerAddress boltAddress( String host, int port ) { - Map map = new HashMap<>(); - map.put( "role", role ); - map.put( "addresses", addresses ); - - return map; + return new BoltServerAddress( host, port ); } - private ConnectionPool pool() + private ConnectionPool pool( final Answer toGetServers, final Answer... furtherGetServers ) { ConnectionPool pool = mock( ConnectionPool.class ); - - when( pool.acquire( any(BoltServerAddress.class) ) ).thenAnswer( new Answer() + when( pool.acquire( any( BoltServerAddress.class ) ) ).thenAnswer( new Answer() { + int answer; + @Override public Connection answer( InvocationOnMock invocationOnMock ) throws Throwable { - BoltServerAddress address = (BoltServerAddress) invocationOnMock.getArguments()[0]; + BoltServerAddress address = invocationOnMock.getArgumentAt( 0, BoltServerAddress.class ); Connection connection = mock( Connection.class ); when( connection.isOpen() ).thenReturn( true ); - when(connection.address()).thenReturn( address ); + when( connection.address() ).thenReturn( address ); + doAnswer( withKeys( "ttl", "servers" ) ).when( connection ).run( + eq( GET_SERVERS ), + eq( Collections.emptyMap() ), + any( Collector.class ) ); + if ( answer > furtherGetServers.length ) + { + answer = furtherGetServers.length; + } + int offset = answer++; + doAnswer( offset == 0 ? toGetServers : furtherGetServers[offset - 1] ) + .when( connection ).pullAll( any( Collector.class ) ); return connection; } @@ -426,11 +365,4 @@ public Connection answer( InvocationOnMock invocationOnMock ) throws Throwable return pool; } - - private Logging logging() - { - Logging mock = mock( Logging.class ); - when( mock.getLog( anyString() ) ).thenReturn( mock( Logger.class ) ); - return mock; - } -} \ No newline at end of file +} diff --git a/driver/src/test/java/org/neo4j/driver/internal/cluster/ClusterCompositionProviderTest.java b/driver/src/test/java/org/neo4j/driver/internal/cluster/ClusterCompositionProviderTest.java new file mode 100644 index 0000000000..0e1803521d --- /dev/null +++ b/driver/src/test/java/org/neo4j/driver/internal/cluster/ClusterCompositionProviderTest.java @@ -0,0 +1,248 @@ +/** + * Copyright (c) 2002-2016 "Neo Technology," + * Network Engine for Objects in Lund AB [http://neotechnology.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.internal.cluster; + +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; + +import org.junit.Test; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; +import org.mockito.stubbing.Stubber; + +import org.neo4j.driver.internal.EventHandler; +import org.neo4j.driver.internal.net.BoltServerAddress; +import org.neo4j.driver.internal.spi.Collector; +import org.neo4j.driver.internal.spi.Connection; +import org.neo4j.driver.internal.util.FakeClock; +import org.neo4j.driver.v1.Value; +import org.neo4j.driver.v1.exceptions.ServiceUnavailableException; + +import static java.util.Arrays.asList; +import static java.util.Collections.singletonList; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.fail; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.neo4j.driver.v1.Values.value; + +public class ClusterCompositionProviderTest +{ + private final FakeClock clock = new FakeClock( (EventHandler) null, true ); + private final Connection connection = mock( Connection.class ); + + @Test + public void shouldParseClusterComposition() throws Exception + { + // given + clock.progress( 16500 ); + keys( "ttl", "servers" ); + values( new Value[] { + value( 100 ), value( asList( + serverInfo( "READ", "one:1337", "two:1337" ), + serverInfo( "WRITE", "one:1337" ), + serverInfo( "ROUTE", "one:1337", "two:1337" ) ) )} ); + + // when + ClusterComposition composition = getClusterComposition(); + + // then + assertNotNull( composition ); + assertEquals( 16500 + 100_000, composition.expirationTimestamp ); + assertEquals( serverSet( "one:1337", "two:1337" ), composition.readers() ); + assertEquals( serverSet( "one:1337" ), composition.writers() ); + assertEquals( serverSet( "one:1337", "two:1337" ), composition.routers() ); + } + + @Test + public void shouldReturnNullIfResultContainsTooManyRows() throws Exception + { + // given + keys( "ttl", "servers" ); + values( + new Value[] { + value( 100 ), value( singletonList( + serverInfo( "READ", "one:1337", "two:1337" ) ) )}, + new Value[] { + value( 100 ), value( singletonList( + serverInfo( "WRITE", "one:1337" ) ) )}, + new Value[] { + value( 100 ), value( singletonList( + serverInfo( "ROUTE", "one:1337", "two:1337" ) ) )} ); + + // then + assertNull( getClusterComposition() ); + } + + @Test + public void shouldReturnNullOnEmptyResult() throws Exception + { + // given + keys( "ttl", "servers" ); + values(); + + // then + assertNull( getClusterComposition() ); + } + + @Test + public void shouldReturnNullOnResultWithWrongFormat() throws Exception + { + // given + clock.progress( 16500 ); + keys( "ttl", "addresses" ); + values( new Value[] { + value( 100 ), value( asList( + serverInfo( "READ", "one:1337", "two:1337" ), + serverInfo( "WRITE", "one:1337" ), + serverInfo( "ROUTE", "one:1337", "two:1337" ) ) )} ); + + // then + assertNull( getClusterComposition() ); + } + + @Test + public void shouldPropagateConnectionFailureExceptions() throws Exception + { + // given + ServiceUnavailableException expected = new ServiceUnavailableException( "spanish inquisition" ); + onGetServers( doThrow( expected ) ); + + // when + try + { + getClusterComposition(); + fail( "Expected exception" ); + } + // then + catch ( ServiceUnavailableException e ) + { + assertSame( expected, e ); + } + } + + private ClusterComposition getClusterComposition() + { + return new ClusterComposition.Provider.Default( clock ).getClusterComposition( connection ); + } + + private void keys( final String... keys ) + { + onGetServers( doAnswer( withKeys( keys ) ) ); + } + + private void values( final Value[]... records ) + { + onPullAll( doAnswer( withServerList( records ) ) ); + } + + private void onGetServers( Stubber stubber ) + { + stubber.when( connection ).run( + eq( ClusterComposition.Provider.GET_SERVERS ), + eq( Collections.emptyMap() ), + any( Collector.class ) ); + } + + private void onPullAll( Stubber stubber ) + { + stubber.when( connection ).pullAll( any( Collector.class ) ); + } + + public static CollectorAnswer withKeys( final String... keys ) + { + return new CollectorAnswer() + { + @Override + void collect( Collector collector ) + { + collector.keys( keys ); + } + }; + } + + public static CollectorAnswer withServerList( final Value[]... records ) + { + return new CollectorAnswer() + { + @Override + void collect( Collector collector ) + { + for ( Value[] fields : records ) + { + collector.record( fields ); + } + } + }; + } + + private static abstract class CollectorAnswer implements Answer + { + abstract void collect( Collector collector ); + + @Override + public final Object answer( InvocationOnMock invocation ) throws Throwable + { + Collector collector = collector( invocation ); + collect( collector ); + collector.done(); + return null; + } + + private Collector collector( InvocationOnMock invocation ) + { + switch ( invocation.getMethod().getName() ) + { + case "pullAll": + return invocation.getArgumentAt( 0, Collector.class ); + case "run": + return invocation.getArgumentAt( 2, Collector.class ); + default: + throw new UnsupportedOperationException( invocation.getMethod().getName() ); + } + } + } + + public static Map serverInfo( String role, String... addresses ) + { + Map map = new HashMap<>(); + map.put( "role", role ); + map.put( "addresses", asList( addresses ) ); + return map; + } + + private static Set serverSet( String... addresses ) + { + Set result = new HashSet<>(); + for ( String address : addresses ) + { + result.add( new BoltServerAddress( address ) ); + } + return result; + } +} diff --git a/driver/src/test/java/org/neo4j/driver/internal/cluster/ClusterTopology.java b/driver/src/test/java/org/neo4j/driver/internal/cluster/ClusterTopology.java new file mode 100644 index 0000000000..ff75414528 --- /dev/null +++ b/driver/src/test/java/org/neo4j/driver/internal/cluster/ClusterTopology.java @@ -0,0 +1,198 @@ +/** + * Copyright (c) 2002-2016 "Neo Technology," + * Network Engine for Objects in Lund AB [http://neotechnology.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.internal.cluster; + +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import org.hamcrest.Description; +import org.hamcrest.Matcher; +import org.hamcrest.TypeSafeMatcher; + +import org.neo4j.driver.internal.Event; +import org.neo4j.driver.internal.EventHandler; +import org.neo4j.driver.internal.net.BoltServerAddress; +import org.neo4j.driver.internal.spi.Connection; +import org.neo4j.driver.internal.util.Clock; + +import static java.util.Arrays.asList; +import static org.neo4j.driver.internal.cluster.ClusterTopology.Role.READ; +import static org.neo4j.driver.internal.cluster.ClusterTopology.Role.ROUTE; +import static org.neo4j.driver.internal.cluster.ClusterTopology.Role.WRITE; + +class ClusterTopology implements ClusterComposition.Provider +{ + public interface EventSink + { + EventSink VOID = new Adapter(); + + void clusterComposition( BoltServerAddress address, ClusterComposition result ); + + class Adapter implements EventSink + { + @Override + public void clusterComposition( BoltServerAddress address, ClusterComposition result ) + { + } + } + } + + private static final List KEYS = Collections.unmodifiableList( asList( "servers", "ttl" ) ); + private final Map views = new HashMap<>(); + private final EventSink events; + private final Clock clock; + + ClusterTopology( final EventHandler events, Clock clock ) + { + this( events == null ? null : new EventSink() + { + @Override + public void clusterComposition( BoltServerAddress address, ClusterComposition result ) + { + events.add( new CompositionRequest( Thread.currentThread(), address, result ) ); + } + }, clock ); + } + + ClusterTopology( EventSink events, Clock clock ) + { + this.events = events == null ? EventSink.VOID : events; + this.clock = clock; + } + + public View on( String host, int port ) + { + View view = new View(); + views.put( new BoltServerAddress( host, port ), view ); + return view; + } + + public enum Role + { + READ, + WRITE, + ROUTE + } + + public static class View + { + private long ttl = 60_000; + private final Set readers = new HashSet<>(), + writers = new HashSet<>(), + routers = new HashSet<>(); + + public View ttlSeconds( long ttl ) + { + this.ttl = ttl * 1000; + return this; + } + + public View provide( String host, int port, Role... roles ) + { + for ( Role role : roles ) + { + servers( role ).add( new BoltServerAddress( host, port ) ); + } + return this; + } + + private Set servers( Role role ) + { + switch ( role ) + { + case READ: + return readers; + case WRITE: + return writers; + case ROUTE: + return routers; + default: + throw new IllegalArgumentException( role.name() ); + } + } + + ClusterComposition composition( long now ) + { + return new ClusterComposition( now + ttl, servers( READ ), servers( WRITE ), servers( ROUTE ) ); + } + } + + @Override + public ClusterComposition getClusterComposition( Connection connection ) + { + BoltServerAddress router = connection.address(); + View view = views.get( router ); + ClusterComposition result = view == null ? null : view.composition( clock.millis() ); + events.clusterComposition( router, result ); + return result; + } + + public static final class CompositionRequest extends Event + { + final Thread thread; + final BoltServerAddress address; + private final ClusterComposition result; + + private CompositionRequest( Thread thread, BoltServerAddress address, ClusterComposition result ) + { + this.thread = thread; + this.address = address; + this.result = result; + } + + @Override + public void dispatch( EventSink sink ) + { + sink.clusterComposition( address, result ); + } + + public static Matcher clusterComposition( + final Matcher thread, + final Matcher address, + final Matcher result ) + { + return new TypeSafeMatcher() + { + @Override + protected boolean matchesSafely( CompositionRequest event ) + { + return thread.matches( event.thread ) + && address.matches( event.address ) + && result.matches( event.result ); + } + + @Override + public void describeTo( Description description ) + { + description.appendText( "a successful cluster composition request on thread <" ) + .appendDescriptionOf( thread ) + .appendText( "> from address <" ) + .appendDescriptionOf( address ) + .appendText( "> returning <" ) + .appendDescriptionOf( result ) + .appendText( ">" ); + } + }; + } + } +} diff --git a/driver/src/test/java/org/neo4j/driver/internal/cluster/LoadBalancerTest.java b/driver/src/test/java/org/neo4j/driver/internal/cluster/LoadBalancerTest.java new file mode 100644 index 0000000000..d0644ecb28 --- /dev/null +++ b/driver/src/test/java/org/neo4j/driver/internal/cluster/LoadBalancerTest.java @@ -0,0 +1,599 @@ +/** + * Copyright (c) 2002-2016 "Neo Technology," + * Network Engine for Objects in Lund AB [http://neotechnology.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.internal.cluster; + +import java.util.ArrayList; +import java.util.List; + +import org.hamcrest.Matcher; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TestRule; +import org.junit.runner.Description; +import org.junit.runners.model.Statement; + +import org.neo4j.driver.internal.EventHandler; +import org.neo4j.driver.internal.net.BoltServerAddress; +import org.neo4j.driver.internal.spi.Connection; +import org.neo4j.driver.internal.spi.StubConnectionPool; +import org.neo4j.driver.internal.util.FakeClock; +import org.neo4j.driver.internal.util.MatcherFactory; +import org.neo4j.driver.v1.EventLogger; +import org.neo4j.driver.v1.exceptions.ServiceUnavailableException; +import org.neo4j.driver.v1.util.Function; + +import static org.hamcrest.Matchers.any; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.sameInstance; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.fail; +import static org.neo4j.driver.internal.cluster.ClusterTopology.Role.READ; +import static org.neo4j.driver.internal.cluster.ClusterTopology.Role.ROUTE; +import static org.neo4j.driver.internal.cluster.ClusterTopology.Role.WRITE; +import static org.neo4j.driver.internal.spi.StubConnectionPool.Event.acquire; +import static org.neo4j.driver.internal.spi.StubConnectionPool.Event.connectionFailure; +import static org.neo4j.driver.internal.util.FakeClock.Event.sleep; +import static org.neo4j.driver.internal.util.MatcherFactory.inAnyOrder; +import static org.neo4j.driver.internal.util.MatcherFactory.matches; +import static org.neo4j.driver.v1.EventLogger.Entry.message; +import static org.neo4j.driver.v1.EventLogger.Level.INFO; + +public class LoadBalancerTest +{ + private static final long RETRY_TIMEOUT_DELAY = 5_000; + private static final int MAX_ROUTING_FAILURES = 5; + @Rule + public final TestRule printEventsOnFailure = new TestRule() + { + @Override + public Statement apply( final Statement base, Description description ) + { + return new Statement() + { + @Override + public void evaluate() throws Throwable + { + try + { + base.evaluate(); + } + catch ( Throwable e ) + { + events.printEvents( System.err ); + throw e; + } + } + }; + } + }; + private final EventHandler events = new EventHandler(); + private final FakeClock clock = new FakeClock( events, true ); + private final EventLogger log = new EventLogger( events, null, INFO ); + private final StubConnectionPool connections = new StubConnectionPool( clock, events, null ); + private final ClusterTopology cluster = new ClusterTopology( events, clock ); + + private LoadBalancer seedLoadBalancer( String host, int port ) throws Exception + { + return new LoadBalancer( + new RoutingSettings( MAX_ROUTING_FAILURES, RETRY_TIMEOUT_DELAY ), + clock, + log, + connections, + cluster, + new BoltServerAddress( host, port ) ); + } + + @Test + public void shouldConnectToRouter() throws Exception + { + // given + connections.up( "some.host", 1337 ); + cluster.on( "some.host", 1337 ) + .provide( "some.host", 1337, READ, WRITE, ROUTE ) + .provide( "another.host", 1337, ROUTE ); + + // when + Connection connection = seedLoadBalancer( "some.host", 1337 ).acquireReadConnection(); + + // then + events.assertCount( any( ClusterTopology.CompositionRequest.class ), equalTo( 1 ) ); + events.assertContains( acquiredConnection( "some.host", 1337, connection ) ); + } + + @Test + public void shouldConnectToRouterOnInitialization() throws Exception + { + // given + connections.up( "some.host", 1337 ); + cluster.on( "some.host", 1337 ) + .provide( "some.host", 1337, READ, WRITE, ROUTE ) + .provide( "another.host", 1337, ROUTE ); + + // when + seedLoadBalancer( "some.host", 1337 ); + + // then + events.assertCount( any( ClusterTopology.CompositionRequest.class ), equalTo( 1 ) ); + } + + @Test + public void shouldReconnectWithRouterAfterTtlExpires() throws Exception + { + // given + coreClusterOn( 20, "some.host", 1337, "another.host" ); + connections.up( "some.host", 1337 ).up( "another.host", 1337 ); + + LoadBalancer routing = seedLoadBalancer( "some.host", 1337 ); + + // when + clock.progress( 25_000 ); // will cause TTL timeout + Connection connection = routing.acquireWriteConnection(); + + // then + events.assertCount( any( ClusterTopology.CompositionRequest.class ), equalTo( 2 ) ); + events.assertContains( acquiredConnection( "some.host", 1337, connection ) ); + } + + @Test + public void shouldNotReconnectWithRouterWithinTtl() throws Exception + { + // given + coreClusterOn( 20, "some.host", 1337, "another.host" ); + connections.up( "some.host", 1337 ).up( "another.host", 1337 ); + + LoadBalancer routing = seedLoadBalancer( "some.host", 1337 ); + + // when + clock.progress( 15_000 ); // not enough to cause TTL timeout + routing.acquireWriteConnection(); + + // then + events.assertCount( any( ClusterTopology.CompositionRequest.class ), equalTo( 1 ) ); + } + + @Test + public void shouldReconnectWithRouterIfOnlyOneRouterIsFound() throws Exception + { + // given + cluster.on( "here", 1337 ) + .ttlSeconds( 20 ) + .provide( "here", 1337, READ, WRITE, ROUTE ); + connections.up( "here", 1337 ); + + LoadBalancer routing = seedLoadBalancer( "here", 1337 ); + + // when + routing.acquireReadConnection(); + + // then + events.assertCount( any( ClusterTopology.CompositionRequest.class ), equalTo( 2 ) ); + } + + @Test + public void shouldReconnectWithRouterIfNoReadersAreAvailable() throws Exception + { + // given + cluster.on( "one", 1337 ) + .ttlSeconds( 20 ) + .provide( "one", 1337, WRITE, ROUTE ) + .provide( "two", 1337, ROUTE ); + cluster.on( "two", 1337 ) + .ttlSeconds( 20 ) + .provide( "one", 1337, READ, WRITE, ROUTE ) + .provide( "two", 1337, READ, ROUTE ); + connections.up( "one", 1337 ).up( "two", 1337 ); + + LoadBalancer routing = seedLoadBalancer( "one", 1337 ); + + events.assertCount( any( ClusterTopology.CompositionRequest.class ), equalTo( 1 ) ); + + cluster.on( "one", 1337 ) + .ttlSeconds( 20 ) + .provide( "one", 1337, READ, WRITE, ROUTE ) + .provide( "two", 1337, READ, ROUTE ); + + // when + routing.acquireWriteConnection(); // we should require the presence of a READER even though we ask for a WRITER + + // then + events.assertCount( any( ClusterTopology.CompositionRequest.class ), equalTo( 2 ) ); + } + + @Test + public void shouldReconnectWithRouterIfNoWritersAreAvailable() throws Exception + { + // given + cluster.on( "one", 1337 ) + .ttlSeconds( 20 ) + .provide( "one", 1337, READ, ROUTE ) + .provide( "two", 1337, READ, WRITE, ROUTE ); + cluster.on( "two", 1337 ) + .ttlSeconds( 20 ) + .provide( "one", 1337, READ, ROUTE ) + .provide( "two", 1337, READ, WRITE, ROUTE ); + connections.up( "one", 1337 ); + + events.registerHandler( StubConnectionPool.EventSink.class, new StubConnectionPool.EventSink.Adapter() + { + @Override + public void connectionFailure( BoltServerAddress address ) + { + connections.up( "two", 1337 ); + } + } ); + + LoadBalancer routing = seedLoadBalancer( "one", 1337 ); + + events.assertCount( any( ClusterTopology.CompositionRequest.class ), equalTo( 1 ) ); + + cluster.on( "one", 1337 ) + .ttlSeconds( 20 ) + .provide( "one", 1337, READ, WRITE, ROUTE ) + .provide( "two", 1337, READ, ROUTE ); + + // when + routing.acquireWriteConnection(); + + // then + events.assertCount( any( ClusterTopology.CompositionRequest.class ), equalTo( 2 ) ); + } + + @Test + public void shouldDropRouterUnableToPerformRoutingTask() throws Exception + { + // given + connections.up( "some.host", 1337 ) + .up( "other.host", 1337 ) + .up( "another.host", 1337 ); + cluster.on( "some.host", 1337 ) + .ttlSeconds( 20 ) + .provide( "some.host", 1337, READ, WRITE, ROUTE ) + .provide( "other.host", 1337, READ, ROUTE ); + cluster.on( "another.host", 1337 ) + .ttlSeconds( 20 ) + .provide( "some.host", 1337, READ, WRITE, ROUTE ) + .provide( "another.host", 1337, READ, ROUTE ); + events.registerHandler( ClusterTopology.EventSink.class, new ClusterTopology.EventSink.Adapter() + { + @Override + public void clusterComposition( BoltServerAddress address, ClusterComposition result ) + { + if ( result == null ) + { + connections.up( "some.host", 1337 ); + cluster.on( "some.host", 1337 ) + .ttlSeconds( 20 ) + .provide( "some.host", 1337, READ, WRITE, ROUTE ) + .provide( "another.host", 1337, READ, ROUTE ); + } + } + } ); + + LoadBalancer routing = seedLoadBalancer( "some.host", 1337 ); + + // when + connections.down( "some.host", 1337 ); + clock.progress( 25_000 ); // will cause TTL timeout + Connection connection = routing.acquireWriteConnection(); + + // then + events.assertCount( + message( + equalTo( INFO ), + equalTo( "Server unable to perform routing capability, " + + "dropping from list of routers." ) ), + equalTo( 1 ) ); + events.assertContains( acquiredConnection( "some.host", 1337, connection ) ); + } + + @Test + public void shouldConnectToRoutingServersInTimeoutOrder() throws Exception + { + // given + coreClusterOn( 20, "one", 1337, "two", "tre" ); + connections.up( "one", 1337 ); + events.registerHandler( StubConnectionPool.EventSink.class, new StubConnectionPool.EventSink.Adapter() + { + int failed; + + @Override + public void connectionFailure( BoltServerAddress address ) + { + if ( ++failed >= 9 ) // three times per server + { + for ( String host : new String[] {"one", "two", "tre"} ) + { + connections.up( host, 1337 ); + } + } + } + } ); + + LoadBalancer routing = seedLoadBalancer( "one", 1337 ); + + // when + connections.down( "one", 1337 ); + clock.progress( 25_000 ); // will cause TTL timeout + routing.acquireWriteConnection(); + + // then + MatcherFactory failedAttempts = inAnyOrder( + connectionFailure( "one", 1337 ), + connectionFailure( "two", 1337 ), + connectionFailure( "tre", 1337 ) ); + events.assertContains( + failedAttempts, + matches( sleep( RETRY_TIMEOUT_DELAY ) ), + failedAttempts, + matches( sleep( 2 * RETRY_TIMEOUT_DELAY ) ), + failedAttempts, + matches( sleep( 4 * RETRY_TIMEOUT_DELAY ) ), + matches( ClusterTopology.CompositionRequest.clusterComposition( + any( Thread.class ), + any( BoltServerAddress.class ), + any( ClusterComposition.class ) ) ) ); + } + + @Test + public void shouldFailIfEnoughConnectionAttemptsFail() throws Exception + { + // when + try + { + seedLoadBalancer( "one", 1337 ); + fail( "expected failure" ); + } + // then + catch ( ServiceUnavailableException e ) + { + assertEquals( "Could not perform discovery. No routing servers available.", e.getMessage() ); + } + } + + private static final Function READ_SERVERS = new Function() + { + @Override + public Connection apply( LoadBalancer routing ) + { + return routing.acquireReadConnection(); + } + }; + + @Test + public void shouldRoundRobinAmongReadServers() throws Exception + { + shouldRoundRobinAmong( READ_SERVERS ); + } + + private static final Function WRITE_SERVERS = new Function() + { + @Override + public Connection apply( LoadBalancer routing ) + { + return routing.acquireWriteConnection(); + } + }; + + @Test + public void shouldRoundRobinAmongWriteServers() throws Exception + { + shouldRoundRobinAmong( WRITE_SERVERS ); + } + + private void shouldRoundRobinAmong( Function acquire ) throws Exception + { + // given + for ( String host : new String[] {"one", "two", "tre"} ) + { + connections.up( host, 1337 ); + cluster.on( host, 1337 ) + .ttlSeconds( 20 ) + .provide( "one", 1337, READ, WRITE, ROUTE ) + .provide( "two", 1337, READ, WRITE, ROUTE ) + .provide( "tre", 1337, READ, WRITE, ROUTE ); + } + LoadBalancer routing = seedLoadBalancer( "one", 1337 ); + + // when + Connection a = acquire.apply( routing ); + Connection b = acquire.apply( routing ); + Connection c = acquire.apply( routing ); + assertNotEquals( a.address(), b.address() ); + assertNotEquals( b.address(), c.address() ); + assertNotEquals( c.address(), a.address() ); + assertEquals( a.address(), acquire.apply( routing ).address() ); + assertEquals( b.address(), acquire.apply( routing ).address() ); + assertEquals( c.address(), acquire.apply( routing ).address() ); + assertEquals( a.address(), acquire.apply( routing ).address() ); + assertEquals( b.address(), acquire.apply( routing ).address() ); + assertEquals( c.address(), acquire.apply( routing ).address() ); + + // then + MatcherFactory acquireConnections = + inAnyOrder( acquire( "one", 1337 ), acquire( "two", 1337 ), acquire( "tre", 1337 ) ); + events.assertContains( acquireConnections, acquireConnections, acquireConnections ); + events.assertContains( inAnyOrder( acquire( a ), acquire( b ), acquire( c ) ) ); + } + + @Test + public void shouldRoundRobinAmongRouters() throws Exception + { + // given + coreClusterOn( 20, "one", 1337, "two", "tre" ); + connections.up( "one", 1337 ).up( "two", 1337 ).up( "tre", 1337 ); + + // when + LoadBalancer routing = seedLoadBalancer( "one", 1337 ); + for ( int i = 1; i < 9; i++ ) + { + clock.progress( 25_000 ); + routing.acquireReadConnection(); + } + + // then + final List hosts = new ArrayList<>(); + events.forEach( new ClusterTopology.EventSink() + { + @Override + public void clusterComposition( BoltServerAddress address, ClusterComposition result ) + { + hosts.add( address.host() ); + } + } ); + assertEquals( 9, hosts.size() ); + assertEquals( hosts.get( 0 ), hosts.get( 3 ) ); + assertEquals( hosts.get( 1 ), hosts.get( 4 ) ); + assertEquals( hosts.get( 2 ), hosts.get( 5 ) ); + assertEquals( hosts.get( 0 ), hosts.get( 6 ) ); + assertEquals( hosts.get( 1 ), hosts.get( 7 ) ); + assertEquals( hosts.get( 2 ), hosts.get( 8 ) ); + assertNotEquals( hosts.get( 0 ), hosts.get( 1 ) ); + assertNotEquals( hosts.get( 1 ), hosts.get( 2 ) ); + assertNotEquals( hosts.get( 2 ), hosts.get( 0 ) ); + } + + @Test + public void shouldForgetPreviousServersOnRerouting() throws Exception + { + // given + connections.up( "one", 1337 ) + .up( "two", 1337 ); + cluster.on( "one", 1337 ) + .ttlSeconds( 20 ) + .provide( "bad", 1337, READ, WRITE, ROUTE ) + .provide( "one", 1337, READ, ROUTE ); + + LoadBalancer routing = seedLoadBalancer( "one", 1337 ); + + // when + coreClusterOn( 20, "one", 1337, "two" ); + clock.progress( 25_000 ); // will cause TTL timeout + Connection ra = routing.acquireReadConnection(); + Connection rb = routing.acquireReadConnection(); + Connection w = routing.acquireWriteConnection(); + assertNotEquals( ra.address(), rb.address() ); + assertEquals( ra.address(), routing.acquireReadConnection().address() ); + assertEquals( rb.address(), routing.acquireReadConnection().address() ); + assertEquals( w.address(), routing.acquireWriteConnection().address() ); + assertEquals( ra.address(), routing.acquireReadConnection().address() ); + assertEquals( rb.address(), routing.acquireReadConnection().address() ); + assertEquals( w.address(), routing.acquireWriteConnection().address() ); + + // then + events.assertNone( acquire( "bad", 1337 ) ); + } + + @Test + public void shouldFailIfNoRouting() throws Exception + { + // given + connections.up( "one", 1337 ); + cluster.on( "one", 1337 ) + .provide( "one", 1337, READ, WRITE ); + + // when + try + { + seedLoadBalancer( "one", 1337 ); + fail( "expected failure" ); + } + // then + catch ( ServiceUnavailableException e ) + { + assertEquals( "Could not perform discovery. No routing servers available.", e.getMessage() ); + } + } + + @Test + public void shouldFailIfNoWriting() throws Exception + { + // given + connections.up( "one", 1337 ); + cluster.on( "one", 1337 ) + .provide( "one", 1337, READ, ROUTE ); + + // when + try + { + seedLoadBalancer( "one", 1337 ); + fail( "expected failure" ); + } + // then + catch ( ServiceUnavailableException e ) + { + assertEquals( "Could not perform discovery. No routing servers available.", e.getMessage() ); + } + } + + @Test + public void shouldNotForgetAddressForRoutingPurposesWhenUnavailableForOtherUse() throws Exception + { + // given + cluster.on( "one", 1337 ) + .provide( "one", 1337, READ, ROUTE ) + .provide( "two", 1337, WRITE, ROUTE ); + cluster.on( "two", 1337 ) + .provide( "one", 1337, READ, ROUTE ) + .provide( "two", 1337, WRITE, ROUTE ); + connections.up( "one", 1337 ); + + LoadBalancer routing = seedLoadBalancer( "one", 1337 ); + connections.down( "one", 1337 ); + events.registerHandler( FakeClock.EventSink.class, new FakeClock.EventSink.Adapter() + { + @Override + public void sleep( long timestamp, long millis ) + { + connections.up( "two", 1337 ); + } + } ); + + // when + Connection connection = routing.acquireWriteConnection(); + + // then + assertEquals( new BoltServerAddress( "two", 1337 ), connection.address() ); + events.printEvents( System.out ); + } + + private void coreClusterOn( int ttlSeconds, String leader, int port, String... others ) + { + for ( int i = 0; i <= others.length; i++ ) + { + String host = (i == others.length) ? leader : others[i]; + ClusterTopology.View view = cluster.on( host, port ) + .ttlSeconds( ttlSeconds ) + .provide( leader, port, READ, WRITE, ROUTE ); + for ( String other : others ) + { + view.provide( other, port, READ, ROUTE ); + } + } + } + + private Matcher acquiredConnection( + String host, int port, Connection connection ) + { + return acquire( + any( Thread.class ), + equalTo( new BoltServerAddress( host, port ) ), + sameInstance( connection ) ); + } +} diff --git a/driver/src/test/java/org/neo4j/driver/internal/cluster/RoundRobinAddressSetTest.java b/driver/src/test/java/org/neo4j/driver/internal/cluster/RoundRobinAddressSetTest.java new file mode 100644 index 0000000000..4b69ac8f95 --- /dev/null +++ b/driver/src/test/java/org/neo4j/driver/internal/cluster/RoundRobinAddressSetTest.java @@ -0,0 +1,226 @@ +/** + * Copyright (c) 2002-2016 "Neo Technology," + * Network Engine for Objects in Lund AB [http://neotechnology.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.internal.cluster; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; + +import org.junit.Test; + +import org.neo4j.driver.internal.net.BoltServerAddress; + +import static java.util.Arrays.asList; +import static java.util.Collections.singleton; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.fail; + +public class RoundRobinAddressSetTest +{ + @Test + public void shouldReturnNullWhenEmpty() throws Exception + { + // given + RoundRobinAddressSet set = new RoundRobinAddressSet(); + + // then + assertNull( set.next() ); + } + + @Test + public void shouldReturnRoundRobin() throws Exception + { + // given + RoundRobinAddressSet set = new RoundRobinAddressSet(); + set.update( new HashSet<>( asList( + new BoltServerAddress( "one" ), + new BoltServerAddress( "two" ), + new BoltServerAddress( "tre" ) ) ), new HashSet() ); + + // when + BoltServerAddress a = set.next(); + BoltServerAddress b = set.next(); + BoltServerAddress c = set.next(); + + // then + assertEquals( a, set.next() ); + assertEquals( b, set.next() ); + assertEquals( c, set.next() ); + assertEquals( a, set.next() ); + assertEquals( b, set.next() ); + assertEquals( c, set.next() ); + assertNotEquals( a, c ); + assertNotEquals( b, a ); + assertNotEquals( c, b ); + } + + @Test + public void shouldPreserveOrderWhenAdding() throws Exception + { + // given + HashSet servers = new HashSet<>( asList( + new BoltServerAddress( "one" ), + new BoltServerAddress( "two" ), + new BoltServerAddress( "tre" ) ) ); + RoundRobinAddressSet set = new RoundRobinAddressSet(); + set.update( servers, new HashSet() ); + + List order = new ArrayList<>(); + for ( int i = 3 * 4 + 1; i-- > 0; ) + { + BoltServerAddress server = set.next(); + if ( !order.contains( server ) ) + { + order.add( server ); + } + } + assertEquals( 3, order.size() ); + + // when + servers.add( new BoltServerAddress( "fyr" ) ); + set.update( servers, new HashSet() ); + + // then + assertEquals( order.get( 1 ), set.next() ); + assertEquals( order.get( 2 ), set.next() ); + BoltServerAddress next = set.next(); + assertNotEquals( order.get( 0 ), next ); + assertNotEquals( order.get( 1 ), next ); + assertNotEquals( order.get( 2 ), next ); + assertEquals( order.get( 0 ), set.next() ); + // ... and once more + assertEquals( order.get( 1 ), set.next() ); + assertEquals( order.get( 2 ), set.next() ); + assertEquals( next, set.next() ); + assertEquals( order.get( 0 ), set.next() ); + } + + @Test + public void shouldPreserveOrderWhenRemoving() throws Exception + { + // given + HashSet servers = new HashSet<>( asList( + new BoltServerAddress( "one" ), + new BoltServerAddress( "two" ), + new BoltServerAddress( "tre" ) ) ); + RoundRobinAddressSet set = new RoundRobinAddressSet(); + set.update( servers, new HashSet() ); + + List order = new ArrayList<>(); + for ( int i = 3 * 2 + 1; i-- > 0; ) + { + BoltServerAddress server = set.next(); + if ( !order.contains( server ) ) + { + order.add( server ); + } + } + assertEquals( 3, order.size() ); + + // when + set.remove( order.get( 1 ) ); + + // then + assertEquals( order.get( 2 ), set.next() ); + assertEquals( order.get( 0 ), set.next() ); + assertEquals( order.get( 2 ), set.next() ); + assertEquals( order.get( 0 ), set.next() ); + } + + @Test + public void shouldPreserveOrderWhenRemovingThroughUpdate() throws Exception + { + // given + HashSet servers = new HashSet<>( asList( + new BoltServerAddress( "one" ), + new BoltServerAddress( "two" ), + new BoltServerAddress( "tre" ) ) ); + RoundRobinAddressSet set = new RoundRobinAddressSet(); + set.update( servers, new HashSet() ); + + List order = new ArrayList<>(); + for ( int i = 3 * 2 + 1; i-- > 0; ) + { + BoltServerAddress server = set.next(); + if ( !order.contains( server ) ) + { + order.add( server ); + } + } + assertEquals( 3, order.size() ); + + // when + servers.remove( order.get( 1 ) ); + set.update( servers, new HashSet() ); + + // then + assertEquals( order.get( 2 ), set.next() ); + assertEquals( order.get( 0 ), set.next() ); + assertEquals( order.get( 2 ), set.next() ); + assertEquals( order.get( 0 ), set.next() ); + } + + @Test + public void shouldRecordRemovedAddressesWhenUpdating() throws Exception + { + // given + RoundRobinAddressSet set = new RoundRobinAddressSet(); + set.update( + new HashSet<>( asList( + new BoltServerAddress( "one" ), + new BoltServerAddress( "two" ), + new BoltServerAddress( "tre" ) ) ), + new HashSet() ); + + // when + HashSet removed = new HashSet<>(); + set.update( + new HashSet<>( asList( + new BoltServerAddress( "one" ), + new BoltServerAddress( "two" ), + new BoltServerAddress( "fyr" ) ) ), + removed ); + + // then + assertEquals( singleton( new BoltServerAddress( "tre" ) ), removed ); + } + + @Test + public void shouldPreserveOrderEvenWhenIntegerOverflows() throws Exception + { + // given + RoundRobinAddressSet set = new RoundRobinAddressSet(); + + for ( int div = 1; div <= 1024; div++ ) + { + // when - white box testing! + set.setOffset( Integer.MAX_VALUE - 1 ); + int a = set.next( div ); + int b = set.next( div ); + + // then + if ( b != (a + 1) % div ) + { + fail( String.format( "a=%d, b=%d, div=%d, (a+1)%%div=%d", a, b, div, (a + 1) % div ) ); + } + } + } +} diff --git a/driver/src/test/java/org/neo4j/driver/internal/spi/StubConnectionPool.java b/driver/src/test/java/org/neo4j/driver/internal/spi/StubConnectionPool.java new file mode 100644 index 0000000000..387beaf57e --- /dev/null +++ b/driver/src/test/java/org/neo4j/driver/internal/spi/StubConnectionPool.java @@ -0,0 +1,498 @@ +/** + * Copyright (c) 2002-2016 "Neo Technology," + * Network Engine for Objects in Lund AB [http://neotechnology.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.internal.spi; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; + +import org.hamcrest.Description; +import org.hamcrest.Matcher; +import org.hamcrest.TypeSafeMatcher; + +import org.neo4j.driver.internal.EventHandler; +import org.neo4j.driver.internal.net.BoltServerAddress; +import org.neo4j.driver.internal.net.pooling.PooledConnection; +import org.neo4j.driver.internal.util.Clock; +import org.neo4j.driver.internal.util.Consumer; +import org.neo4j.driver.v1.exceptions.ServiceUnavailableException; +import org.neo4j.driver.v1.util.Function; + +import static org.hamcrest.Matchers.any; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.sameInstance; + +public class StubConnectionPool implements ConnectionPool +{ + public interface EventSink + { + EventSink VOID = new Adapter(); + + void acquire( BoltServerAddress address, Connection connection ); + + void release( BoltServerAddress address, Connection connection ); + + void connectionFailure( BoltServerAddress address ); + + void purge( BoltServerAddress address, boolean connected ); + + void close( Collection connected ); + + class Adapter implements EventSink + { + @Override + public void acquire( BoltServerAddress address, Connection connection ) + { + } + + @Override + public void release( BoltServerAddress address, Connection connection ) + { + } + + @Override + public void connectionFailure( BoltServerAddress address ) + { + } + + @Override + public void purge( BoltServerAddress address, boolean connected ) + { + } + + @Override + public void close( Collection connected ) + { + } + } + } + + private final Clock clock; + private final EventSink events; + private final Function factory; + private static final Function NULL_FACTORY = + new Function() + { + @Override + public Connection apply( BoltServerAddress boltServerAddress ) + { + return null; + } + }; + + public StubConnectionPool( Clock clock, final EventHandler events, Function factory ) + { + this( clock, events == null ? null : new EventSink() + { + @Override + public void acquire( BoltServerAddress address, Connection connection ) + { + events.add( new AcquireEvent( Thread.currentThread(), address, connection ) ); + } + + @Override + public void release( BoltServerAddress address, Connection connection ) + { + events.add( new ReleaseEvent( Thread.currentThread(), address, connection ) ); + } + + @Override + public void connectionFailure( BoltServerAddress address ) + { + events.add( new ConnectionFailureEvent( Thread.currentThread(), address ) ); + } + + @Override + public void purge( BoltServerAddress address, boolean connected ) + { + events.add( new PurgeEvent( Thread.currentThread(), address, connected ) ); + } + + @Override + public void close( Collection connected ) + { + events.add( new CloseEvent( Thread.currentThread(), connected ) ); + } + }, factory ); + } + + public StubConnectionPool( Clock clock, EventSink events, Function factory ) + { + this.clock = clock; + this.events = events == null ? EventSink.VOID : events; + this.factory = factory == null ? NULL_FACTORY : factory; + } + + private enum State + { + AVAILABLE, + CONNECTED, + PURGED + } + + private final ConcurrentMap hosts = new ConcurrentHashMap<>(); + + public StubConnectionPool up( String host, int port ) + { + hosts.putIfAbsent( new BoltServerAddress( host, port ), State.AVAILABLE ); + return this; + } + + public StubConnectionPool down( String host, int port ) + { + hosts.remove( new BoltServerAddress( host, port ) ); + return this; + } + + @Override + public Connection acquire( BoltServerAddress address ) + { + if ( hosts.replace( address, State.CONNECTED ) == null ) + { + events.connectionFailure( address ); + throw new ServiceUnavailableException( "Host unavailable: " + address ); + } + Connection connection = new StubConnection( address, factory.apply( address ), events, clock ); + events.acquire( address, connection ); + return connection; + } + + @Override + public void purge( BoltServerAddress address ) + { + State state = hosts.replace( address, State.PURGED ); + events.purge( address, state == State.CONNECTED ); + } + + @Override + public boolean hasAddress( BoltServerAddress address ) + { + return State.CONNECTED == hosts.get( address ); + } + + @Override + public void close() + { + List connected = new ArrayList<>( hosts.size() ); + for ( Map.Entry entry : hosts.entrySet() ) + { + if ( entry.getValue() == State.CONNECTED ) + { + connected.add( entry.getKey() ); + } + } + events.close( connected ); + hosts.clear(); + } + + private static class StubConnection extends PooledConnection + { + private final BoltServerAddress address; + + StubConnection( + final BoltServerAddress address, + final Connection delegate, + final EventSink events, + Clock clock ) + { + super( delegate, new Consumer() + { + @Override + public void accept( PooledConnection self ) + { + events.release( address, self ); + if ( delegate != null ) + { + delegate.close(); + } + } + }, clock ); + this.address = address; + } + + @Override + public String toString() + { + return String.format( "StubConnection{%s}@%s", address, System.identityHashCode( this ) ); + } + + @Override + public BoltServerAddress address() + { + return address; + } + } + + public static abstract class Event extends org.neo4j.driver.internal.Event + { + final Thread thread; + + private Event( Thread thread ) + { + this.thread = thread; + } + + public static Matcher acquire( String host, int port ) + { + return acquire( + any( Thread.class ), + equalTo( new BoltServerAddress( host, port ) ), + any( Connection.class ) ); + } + + public static Matcher acquire( Connection connection ) + { + return acquire( any( Thread.class ), any( BoltServerAddress.class ), sameInstance( connection ) ); + } + + public static Matcher acquire( + final Matcher thread, + final Matcher address, + final Matcher connection ) + { + return new TypeSafeMatcher() + { + @Override + public void describeTo( Description description ) + { + description.appendText( "acquire event on thread <" ) + .appendDescriptionOf( thread ) + .appendText( "> of address <" ) + .appendDescriptionOf( address ) + .appendText( "> resulting in connection <" ) + .appendDescriptionOf( connection ) + .appendText( ">" ); + } + + @Override + protected boolean matchesSafely( AcquireEvent event ) + { + return thread.matches( event.thread ) && + address.matches( event.address ) && + connection.matches( event.connection ); + } + }; + } + + public static Matcher release( + final Matcher thread, + final Matcher address, + final Matcher connection ) + { + return new TypeSafeMatcher() + { + @Override + public void describeTo( Description description ) + { + description.appendText( "release event on thread <" ) + .appendDescriptionOf( thread ) + .appendText( "> of address <" ) + .appendDescriptionOf( address ) + .appendText( "> and connection <" ) + .appendDescriptionOf( connection ) + .appendText( ">" ); + } + + @Override + protected boolean matchesSafely( ReleaseEvent event ) + { + return thread.matches( event.thread ) && + address.matches( event.address ) && + connection.matches( event.connection ); + } + }; + } + + public static Matcher connectionFailure( String host, int port ) + { + return connectionFailure( any( Thread.class ), equalTo( new BoltServerAddress( host, port ) ) ); + } + + public static Matcher connectionFailure( + final Matcher thread, + final Matcher address ) + { + return new TypeSafeMatcher() + { + @Override + public void describeTo( Description description ) + { + description.appendText( "connection failure event on thread <" ) + .appendDescriptionOf( thread ) + .appendText( "> of address <" ) + .appendDescriptionOf( address ) + .appendText( ">" ); + } + + @Override + protected boolean matchesSafely( ConnectionFailureEvent event ) + { + return thread.matches( event.thread ) && + address.matches( event.address ); + } + }; + } + + public static Matcher purge( + final Matcher thread, + final Matcher address, + final Matcher removed ) + { + return new TypeSafeMatcher() + { + @Override + public void describeTo( Description description ) + { + description.appendText( "purge event on thread <" ) + .appendDescriptionOf( thread ) + .appendText( "> of address <" ) + .appendDescriptionOf( address ) + .appendText( "> resulting in actual removal: " ) + .appendDescriptionOf( removed ); + } + + @Override + protected boolean matchesSafely( PurgeEvent event ) + { + return thread.matches( event.thread ) && + address.matches( event.address ) && + removed.matches( event.connected ); + } + }; + } + + public static Matcher close( + final Matcher thread, + final Matcher> addresses ) + { + return new TypeSafeMatcher() + { + @Override + public void describeTo( Description description ) + { + description.appendText( "close event on thread <" ) + .appendDescriptionOf( thread ) + .appendText( "> resulting in closing connections to <" ) + .appendDescriptionOf( addresses ) + .appendText( ">" ); + } + + @Override + protected boolean matchesSafely( CloseEvent event ) + { + return thread.matches( event.thread ) && addresses.matches( event.connected ); + } + }; + } + } + + private static class AcquireEvent extends Event + { + private final BoltServerAddress address; + private final Connection connection; + + AcquireEvent( Thread thread, BoltServerAddress address, Connection connection ) + { + super( thread ); + this.address = address; + this.connection = connection; + } + + @Override + public void dispatch( EventSink sink ) + { + sink.acquire( address, connection ); + } + } + + private static class ReleaseEvent extends Event + { + private final BoltServerAddress address; + private final Connection connection; + + ReleaseEvent( Thread thread, BoltServerAddress address, Connection connection ) + { + super( thread ); + this.address = address; + this.connection = connection; + } + + @Override + public void dispatch( EventSink sink ) + { + sink.release( address, connection ); + } + } + + private static class ConnectionFailureEvent extends Event + { + private final BoltServerAddress address; + + ConnectionFailureEvent( Thread thread, BoltServerAddress address ) + { + super( thread ); + this.address = address; + } + + @Override + public void dispatch( EventSink sink ) + { + sink.connectionFailure( address ); + } + } + + private static class PurgeEvent extends Event + { + private final BoltServerAddress address; + private final boolean connected; + + PurgeEvent( Thread thread, BoltServerAddress address, boolean connected ) + { + super( thread ); + this.address = address; + this.connected = connected; + } + + @Override + public void dispatch( EventSink sink ) + { + sink.purge( address, connected ); + } + } + + private static class CloseEvent extends Event + { + private final Collection connected; + + CloseEvent( Thread thread, Collection connected ) + { + super( thread ); + this.connected = connected; + } + + @Override + public void dispatch( EventSink sink ) + { + sink.close( connected ); + } + } +} diff --git a/driver/src/test/java/org/neo4j/driver/internal/util/ConcurrentRoundRobinSetTest.java b/driver/src/test/java/org/neo4j/driver/internal/util/ConcurrentRoundRobinSetTest.java deleted file mode 100644 index a2c1e32f0b..0000000000 --- a/driver/src/test/java/org/neo4j/driver/internal/util/ConcurrentRoundRobinSetTest.java +++ /dev/null @@ -1,170 +0,0 @@ -/** - * Copyright (c) 2002-2016 "Neo Technology," - * Network Engine for Objects in Lund AB [http://neotechnology.com] - * - * This file is part of Neo4j. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.neo4j.driver.internal.util; - - -import org.junit.Test; - -import java.util.Comparator; -import java.util.HashSet; - -import static java.util.Arrays.asList; -import static junit.framework.TestCase.assertFalse; -import static junit.framework.TestCase.assertTrue; -import static org.hamcrest.CoreMatchers.equalTo; -import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.Matchers.empty; - -public class ConcurrentRoundRobinSetTest -{ - - @Test - public void shouldBeAbleToIterateIndefinitely() - { - // Given - ConcurrentRoundRobinSet integers = new ConcurrentRoundRobinSet<>(); - - // When - integers.addAll( asList( 0, 1, 2, 3, 4 ) ); - - // Then - for ( int i = 0; i < 100; i++ ) - { - assertThat( integers.hop(), equalTo( i % 5 ) ); - } - } - - @Test - public void shouldBeAbleToUseCustomComparator() - { - // Given - ConcurrentRoundRobinSet integers = new ConcurrentRoundRobinSet<>( new Comparator() - { - @Override - public int compare( Integer o1, Integer o2 ) - { - return Integer.compare( o2, o1 ); - } - } ); - - // When - integers.addAll( asList( 0, 1, 2, 3, 4 ) ); - - // Then - assertThat( integers.hop(), equalTo( 4 ) ); - assertThat( integers.hop(), equalTo( 3 ) ); - assertThat( integers.hop(), equalTo( 2 ) ); - assertThat( integers.hop(), equalTo( 1 ) ); - assertThat( integers.hop(), equalTo( 0 ) ); - assertThat( integers.hop(), equalTo( 4 ) ); - assertThat( integers.hop(), equalTo( 3 ) ); - //.... - } - - @Test - public void shouldBeAbleToClearSet() - { - // Given - ConcurrentRoundRobinSet integers = new ConcurrentRoundRobinSet<>(); - - // When - integers.addAll( asList( 0, 1, 2, 3, 4 ) ); - integers.clear(); - - // Then - assertThat( integers, empty() ); - } - - @Test - public void shouldBeAbleToCheckIfContainsElement() - { - // Given - ConcurrentRoundRobinSet integers = new ConcurrentRoundRobinSet<>(); - - // When - integers.addAll( asList( 0, 1, 2, 3, 4 ) ); - - - // Then - assertTrue( integers.contains( 3 ) ); - assertFalse( integers.contains( 7 ) ); - } - - @Test - public void shouldBeAbleToCheckIfContainsMultipleElements() - { - // Given - ConcurrentRoundRobinSet integers = new ConcurrentRoundRobinSet<>(); - - // When - integers.addAll( asList( 0, 1, 2, 3, 4 ) ); - - - // Then - assertTrue( integers.containsAll( asList( 3, 1 ) ) ); - assertFalse( integers.containsAll( asList( 2, 3, 4, 7 ) ) ); - } - - @Test - public void shouldBeAbleToCheckIfEmptyAndSize() - { - // Given - ConcurrentRoundRobinSet integers = new ConcurrentRoundRobinSet<>(); - - // When - integers.addAll( asList( 0, 1, 2, 3, 4 ) ); - - - // Then - assertFalse( integers.isEmpty() ); - assertThat( integers.size(), equalTo( 5 ) ); - integers.clear(); - assertTrue( integers.isEmpty() ); - assertThat( integers.size(), equalTo( 0 ) ); - } - - - @Test - public void shouldBeAbleToCreateArray() - { - // Given - ConcurrentRoundRobinSet integers = new ConcurrentRoundRobinSet<>(); - - // When - integers.addAll( asList( 0, 1, 2, 3, 4 ) ); - Object[] objects = integers.toArray(); - - // Then - assertThat( objects, equalTo( new Object[]{0, 1, 2, 3, 4} ) ); - } - - @Test - public void shouldBeAbleToCreateTypedArray() - { - // Given - ConcurrentRoundRobinSet integers = new ConcurrentRoundRobinSet<>(); - - // When - integers.addAll( asList( 0, 1, 2, 3, 4 ) ); - Integer[] array = integers.toArray( new Integer[5] ); - - // Then - assertThat( array, equalTo( new Integer[]{0, 1, 2, 3, 4} ) ); - } -} \ No newline at end of file diff --git a/driver/src/test/java/org/neo4j/driver/internal/util/FakeClock.java b/driver/src/test/java/org/neo4j/driver/internal/util/FakeClock.java new file mode 100644 index 0000000000..c697eb0fd4 --- /dev/null +++ b/driver/src/test/java/org/neo4j/driver/internal/util/FakeClock.java @@ -0,0 +1,260 @@ +/** + * Copyright (c) 2002-2016 "Neo Technology," + * Network Engine for Objects in Lund AB [http://neotechnology.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.internal.util; + +import java.util.concurrent.PriorityBlockingQueue; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import java.util.concurrent.locks.LockSupport; + +import org.hamcrest.Description; +import org.hamcrest.Matcher; +import org.hamcrest.TypeSafeMatcher; + +import org.neo4j.driver.internal.EventHandler; + +import static java.util.concurrent.atomic.AtomicLongFieldUpdater.newUpdater; +import static org.hamcrest.Matchers.any; +import static org.hamcrest.Matchers.equalTo; + +public class FakeClock implements Clock +{ + public interface EventSink + { + EventSink VOID = new Adapter(); + + void sleep( long timestamp, long millis ); + + void progress( long millis ); + + class Adapter implements EventSink + { + @Override + public void sleep( long timestamp, long millis ) + { + } + + @Override + public void progress( long millis ) + { + } + } + } + + private final EventSink events; + @SuppressWarnings( "unused"/*assigned through AtomicLongFieldUpdater*/ ) + private volatile long timestamp; + private static final AtomicLongFieldUpdater TIMESTAMP = newUpdater( FakeClock.class, "timestamp" ); + private PriorityBlockingQueue threads; + + public FakeClock( final EventHandler events, boolean progressOnSleep ) + { + this( events == null ? null : new EventSink() + { + @Override + public void sleep( long timestamp, long duration ) + { + events.add( new Event.Sleep( Thread.currentThread(), timestamp, duration ) ); + } + + @Override + public void progress( long timestamp ) + { + events.add( new Event.Progress( Thread.currentThread(), timestamp ) ); + } + }, progressOnSleep ); + } + + public FakeClock( EventSink events, boolean progressOnSleep ) + { + this.events = events == null ? EventSink.VOID : events; + this.threads = progressOnSleep ? null : new PriorityBlockingQueue(); + } + + @Override + public long millis() + { + return timestamp; + } + + @Override + public void sleep( long millis ) + { + if ( millis <= 0 ) + { + return; + } + long target = timestamp + millis; + events.sleep( target - millis, millis ); + if ( threads == null ) + { + progress( millis ); + } + else + { + // park until the target time has been reached + WaitingThread token = new WaitingThread( Thread.currentThread(), target ); + threads.add( token ); + for ( ; ; ) + { + if ( timestamp >= target ) + { + threads.remove( token ); + return; + } + // park with a timeout to guarantee that we make progress even if something goes wrong + LockSupport.parkNanos( this, TimeUnit.MILLISECONDS.toNanos( millis ) ); + } + } + } + + public void progress( long millis ) + { + if ( millis < 0 ) + { + throw new IllegalArgumentException( "time can only progress forwards" ); + } + events.progress( TIMESTAMP.addAndGet( this, millis ) ); + if ( threads != null ) + { + // wake up the threads that are sleeping awaiting the current time + for ( WaitingThread thread; (thread = threads.peek()) != null; ) + { + if ( thread.timestamp < timestamp ) + { + threads.remove( thread ); + LockSupport.unpark( thread.thread ); + } + } + } + } + + public static abstract class Event extends org.neo4j.driver.internal.Event + { + final Thread thread; + + private Event( Thread thread ) + { + this.thread = thread; + } + + public static Matcher sleep( long duration ) + { + return sleep( any( Thread.class ), any( Long.class ), equalTo( duration ) ); + } + + public static Matcher sleep( + final Matcher thread, + final Matcher timestamp, + final Matcher duration ) + { + return new TypeSafeMatcher() + { + @Override + public void describeTo( Description description ) + { + description.appendText( "Sleep Event on thread <" ) + .appendDescriptionOf( thread ) + .appendText( "> at timestamp " ) + .appendDescriptionOf( timestamp ) + .appendText( " for duration " ) + .appendDescriptionOf( timestamp ) + .appendText( " (in milliseconds)" ); + } + + @Override + protected boolean matchesSafely( Sleep event ) + { + return thread.matches( event.thread ) + && timestamp.matches( event.timestamp ) + && duration.matches( event.duration ); + } + }; + } + + public static Matcher progress( final Matcher thread, final Matcher timestamp ) + { + return new TypeSafeMatcher() + { + @Override + public void describeTo( Description description ) + { + description.appendText( "Time progresses to timestamp " ) + .appendDescriptionOf( timestamp ) + .appendText( " by thread <" ) + .appendDescriptionOf( thread ) + .appendText( ">" ); + } + + @Override + protected boolean matchesSafely( Progress event ) + { + return thread.matches( event.thread ) && timestamp.matches( event.timestamp ); + } + }; + } + + private static class Sleep extends Event + { + private final long timestamp, duration; + + Sleep( Thread thread, long timestamp, long duration ) + { + super( thread ); + this.timestamp = timestamp; + this.duration = duration; + } + + @Override + public void dispatch( EventSink sink ) + { + sink.sleep( timestamp, duration ); + } + } + + private static class Progress extends Event + { + private final long timestamp; + + Progress( Thread thread, long timestamp ) + { + super( thread ); + this.timestamp = timestamp; + } + + @Override + public void dispatch( EventSink sink ) + { + sink.progress( timestamp ); + } + } + } + + private static class WaitingThread + { + final Thread thread; + final long timestamp; + + private WaitingThread( Thread thread, long timestamp ) + { + this.thread = thread; + this.timestamp = timestamp; + } + } +} diff --git a/driver/src/test/java/org/neo4j/driver/internal/util/MatcherFactory.java b/driver/src/test/java/org/neo4j/driver/internal/util/MatcherFactory.java new file mode 100644 index 0000000000..3d1a8878f6 --- /dev/null +++ b/driver/src/test/java/org/neo4j/driver/internal/util/MatcherFactory.java @@ -0,0 +1,230 @@ +/** + * Copyright (c) 2002-2016 "Neo Technology," + * Network Engine for Objects in Lund AB [http://neotechnology.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.internal.util; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; + +import org.hamcrest.BaseMatcher; +import org.hamcrest.Description; +import org.hamcrest.Matcher; +import org.hamcrest.SelfDescribing; +import org.hamcrest.TypeSafeDiagnosingMatcher; +import org.hamcrest.TypeSafeMatcher; + +public abstract class MatcherFactory implements SelfDescribing +{ + public abstract Matcher createMatcher(); + + /** + * Matches a collection based on the number of elements in the collection that match a given matcher. + * + * @param matcher + * The matcher used for counting matching elements. + * @param count + * The matcher used for evaluating the number of matching elements. + * @param + * The type of elements in the collection. + * @return A matcher for a collection. + */ + public static Matcher> count( final Matcher matcher, final Matcher count ) + { + return new TypeSafeDiagnosingMatcher>() + { + @Override + protected boolean matchesSafely( Iterable collection, Description mismatchDescription ) + { + int matches = 0; + for ( T item : collection ) + { + if ( matcher.matches( item ) ) + { + matches++; + } + } + if ( count.matches( matches ) ) + { + return true; + } + mismatchDescription.appendText( "actual number of matches was " ).appendValue( matches ) + .appendText( " in " ).appendValue( collection ); + return false; + } + + @Override + public void describeTo( Description description ) + { + description.appendText( "collection containing " ) + .appendDescriptionOf( count ) + .appendText( " occurences of " ) + .appendDescriptionOf( matcher ); + } + }; + } + + /** + * Matches a collection that contains elements that match the specified matchers. The elements must be in the same + * order as the given matchers, but the collection may contain other elements in between the matching elements. + * + * @param matchers + * The matchers for the elements of the collection. + * @param + * The type of the elements in the collection. + * @return A matcher for a collection. + */ + @SafeVarargs + public static Matcher> containsAtLeast( final Matcher... matchers ) + { + @SuppressWarnings( "unchecked" ) + MatcherFactory[] factories = new MatcherFactory[matchers.length]; + for ( int i = 0; i < factories.length; i++ ) + { + factories[i] = matches( matchers[i] ); + } + return containsAtLeast( factories ); + } + + @SafeVarargs + public static Matcher> containsAtLeast( final MatcherFactory... matcherFactories ) + { + return new TypeSafeMatcher>() + { + @Override + protected boolean matchesSafely( Iterable collection ) + { + @SuppressWarnings( "unchecked" ) + Matcher[] matchers = new Matcher[matcherFactories.length]; + for ( int i = 0; i < matchers.length; i++ ) + { + matchers[i] = matcherFactories[i].createMatcher(); + } + int i = 0; + for ( T item : collection ) + { + if ( i >= matchers.length ) + { + return true; + } + if ( matchers[i].matches( item ) ) + { + i++; + } + } + return i == matchers.length; + } + + @Override + public void describeTo( Description description ) + { + description.appendText( "collection containing at least " ); + for ( int i = 0; i < matcherFactories.length; i++ ) + { + if ( i != 0 ) + { + if ( i == matcherFactories.length - 1 ) + { + description.appendText( " and " ); + } + else + { + description.appendText( ", " ); + } + } + description.appendDescriptionOf( matcherFactories[i] ); + } + description.appendText( " (in that order) " ); + } + }; + } + + @SafeVarargs + public static MatcherFactory inAnyOrder( final Matcher... matchers ) + { + return new MatcherFactory() + { + @Override + public Matcher createMatcher() + { + final List> remaining = new ArrayList<>( matchers.length ); + Collections.addAll( remaining, matchers ); + return new BaseMatcher() + { + @Override + public boolean matches( Object item ) + { + for ( Iterator> matcher = remaining.iterator(); matcher.hasNext(); ) + { + if ( matcher.next().matches( item ) ) + { + matcher.remove(); + return remaining.isEmpty(); + } + } + return remaining.isEmpty(); + } + + @Override + public void describeTo( Description description ) + { + describe( description ); + } + }; + } + + @Override + public void describeTo( Description description ) + { + describe( description ); + } + + private void describe( Description description ) + { + description.appendText( "in any order" ); + String sep = " {"; + for ( Matcher matcher : matchers ) + { + description.appendText( sep ); + description.appendDescriptionOf( matcher ); + sep = ", "; + } + description.appendText( "}" ); + } + }; + } + + public static MatcherFactory matches( final Matcher matcher ) + { + return new MatcherFactory() + { + @Override + public Matcher createMatcher() + { + return matcher; + } + + @Override + public void describeTo( Description description ) + { + matcher.describeTo( description ); + } + }; + } +} diff --git a/driver/src/test/java/org/neo4j/driver/v1/EventLogger.java b/driver/src/test/java/org/neo4j/driver/v1/EventLogger.java new file mode 100644 index 0000000000..084131a568 --- /dev/null +++ b/driver/src/test/java/org/neo4j/driver/v1/EventLogger.java @@ -0,0 +1,225 @@ +/** + * Copyright (c) 2002-2016 "Neo Technology," + * Network Engine for Objects in Lund AB [http://neotechnology.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.v1; + +import org.hamcrest.Description; +import org.hamcrest.Matcher; +import org.hamcrest.TypeSafeMatcher; + +import org.neo4j.driver.internal.Event; +import org.neo4j.driver.internal.EventHandler; + +import static java.util.Objects.requireNonNull; + +public class EventLogger implements Logger +{ + public static Logging provider( EventHandler events, Level level ) + { + return provider( sink( requireNonNull( events, "events" ) ), level ); + } + + public static Logging provider( final Sink events, final Level level ) + { + requireNonNull( events, "events" ); + requireNonNull( level, "level" ); + return new Logging() + { + @Override + public Logger getLog( String name ) + { + return new EventLogger( events, name, level ); + } + }; + } + + public interface Sink + { + void log( String name, Level level, Throwable cause, String message, Object... params ); + } + + private final boolean debug, trace; + private final Sink events; + private final String name; + + public EventLogger( EventHandler events, String name, Level level ) + { + this( sink( requireNonNull( events, "events" ) ), name, level ); + } + + public EventLogger( Sink events, String name, Level level ) + { + this.events = requireNonNull( events, "events" ); + this.name = name; + level = requireNonNull( level, "level" ); + this.debug = Level.DEBUG.compareTo( level ) <= 0; + this.trace = Level.TRACE.compareTo( level ) <= 0; + } + + private static Sink sink( final EventHandler events ) + { + return new Sink() + { + @Override + public void log( String name, Level level, Throwable cause, String message, Object... params ) + { + events.add( new Entry( Thread.currentThread(), name, level, cause, message, params ) ); + } + }; + } + + public enum Level + { + ERROR, + WARN, + INFO, + DEBUG, + TRACE + } + + @Override + public void error( String message, Throwable cause ) + { + events.log( name, Level.ERROR, cause, message ); + } + + @Override + public void info( String message, Object... params ) + { + events.log( name, Level.INFO, null, message, params ); + } + + @Override + public void warn( String message, Object... params ) + { + events.log( name, Level.WARN, null, message, params ); + } + + @Override + public void debug( String message, Object... params ) + { + events.log( name, Level.DEBUG, null, message, params ); + } + + @Override + public void trace( String message, Object... params ) + { + events.log( name, Level.TRACE, null, message, params ); + } + + @Override + public boolean isTraceEnabled() + { + return trace; + } + + @Override + public boolean isDebugEnabled() + { + return debug; + } + + public static final class Entry extends Event + { + private final Thread thread; + private final String name; + private final Level level; + private final Throwable cause; + private final String message; + private final Object[] params; + + private Entry( Thread thread, String name, Level level, Throwable cause, String message, Object... params ) + { + this.thread = thread; + this.name = name; + this.level = requireNonNull( level, "level" ); + this.cause = cause; + this.message = message; + this.params = params; + } + + @Override + public void dispatch( Sink sink ) + { + sink.log( name, level, cause, message, params ); + } + + private String formatted() + { + return params == null ? message : String.format( message, params ); + } + + public static Matcher logEntry( + final Matcher thread, + final Matcher name, + final Matcher level, + final Matcher cause, + final Matcher message, + final Matcher params ) + { + return new TypeSafeMatcher() + { + @Override + protected boolean matchesSafely( Entry entry ) + { + return level.matches( entry.level ) && + thread.matches( entry.thread ) && + name.matches( entry.name ) && + cause.matches( entry.cause ) && + message.matches( entry.message ) && + params.matches( entry.params ); + } + + @Override + public void describeTo( Description description ) + { + description.appendText( "Log entry where level " ) + .appendDescriptionOf( level ) + .appendText( " name <" ) + .appendDescriptionOf( name ) + .appendText( "> cause <" ) + .appendDescriptionOf( cause ) + .appendText( "> message <" ) + .appendDescriptionOf( message ) + .appendText( "> and parameters <" ) + .appendDescriptionOf( params ) + .appendText( ">" ); + } + }; + } + + public static Matcher message( final Matcher level, final Matcher message ) + { + return new TypeSafeMatcher() + { + @Override + protected boolean matchesSafely( Entry entry ) + { + return level.matches( entry.level ) && message.matches( entry.formatted() ); + } + + @Override + public void describeTo( Description description ) + { + description.appendText( "Log entry where level " ).appendDescriptionOf( level ) + .appendText( " and formatted message " ).appendDescriptionOf( message ); + } + }; + } + } +}