diff --git a/driver/src/main/java/org/neo4j/driver/internal/ClusterView.java b/driver/src/main/java/org/neo4j/driver/internal/ClusterView.java deleted file mode 100644 index 3e378abb08..0000000000 --- a/driver/src/main/java/org/neo4j/driver/internal/ClusterView.java +++ /dev/null @@ -1,175 +0,0 @@ -/** - * Copyright (c) 2002-2016 "Neo Technology," - * Network Engine for Objects in Lund AB [http://neotechnology.com] - * - * This file is part of Neo4j. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.neo4j.driver.internal; - -import java.util.Collections; -import java.util.Comparator; -import java.util.HashSet; -import java.util.List; -import java.util.Set; - -import org.neo4j.driver.internal.net.BoltServerAddress; -import org.neo4j.driver.internal.util.Clock; -import org.neo4j.driver.internal.util.ConcurrentRoundRobinSet; -import org.neo4j.driver.v1.Logger; - -/** - * Defines a snapshot view of the cluster. - */ -class ClusterView -{ - private final static Comparator COMPARATOR = new Comparator() - { - @Override - public int compare( BoltServerAddress o1, BoltServerAddress o2 ) - { - int compare = o1.host().compareTo( o2.host() ); - if ( compare == 0 ) - { - compare = Integer.compare( o1.port(), o2.port() ); - } - - return compare; - } - }; - - private static final int MIN_ROUTERS = 1; - - private final ConcurrentRoundRobinSet routingServers = - new ConcurrentRoundRobinSet<>( COMPARATOR ); - private final ConcurrentRoundRobinSet readServers = - new ConcurrentRoundRobinSet<>( COMPARATOR ); - private final ConcurrentRoundRobinSet writeServers = - new ConcurrentRoundRobinSet<>( COMPARATOR ); - private final Clock clock; - private final long expires; - private final Logger log; - - public ClusterView( long expires, Clock clock, Logger log ) - { - this.expires = expires; - this.clock = clock; - this.log = log; - } - - public void addRouter( BoltServerAddress router ) - { - this.routingServers.add( router ); - } - - public boolean isStale() - { - return expires < clock.millis() || - routingServers.size() <= MIN_ROUTERS || - readServers.isEmpty() || - writeServers.isEmpty(); - } - - Set all() - { - HashSet all = - new HashSet<>( routingServers.size() + readServers.size() + writeServers.size() ); - all.addAll( routingServers ); - all.addAll( readServers ); - all.addAll( writeServers ); - return all; - } - - - public BoltServerAddress nextRouter() - { - return routingServers.hop(); - } - - public BoltServerAddress nextReader() - { - return readServers.hop(); - } - - public BoltServerAddress nextWriter() - { - return writeServers.hop(); - } - - public void addReaders( List addresses ) - { - readServers.addAll( addresses ); - } - - public void addWriters( List addresses ) - { - writeServers.addAll( addresses ); - } - - public void addRouters( List addresses ) - { - routingServers.addAll( addresses ); - } - - public void remove( BoltServerAddress address ) - { - if ( routingServers.remove( address ) ) - { - log.debug( "Removing %s from routers", address.toString() ); - } - if ( readServers.remove( address ) ) - { - log.debug( "Removing %s from readers", address.toString() ); - } - if ( writeServers.remove( address ) ) - { - log.debug( "Removing %s from writers", address.toString() ); - } - } - - public boolean removeWriter( BoltServerAddress address ) - { - return writeServers.remove( address ); - } - - public int numberOfRouters() - { - return routingServers.size(); - } - - public int numberOfReaders() - { - return readServers.size(); - } - - public int numberOfWriters() - { - return writeServers.size(); - } - - public Set routingServers() - { - return Collections.unmodifiableSet( routingServers ); - } - - public Set readServers() - { - return Collections.unmodifiableSet( readServers ); - } - - public Set writeServers() - { - return Collections.unmodifiableSet( writeServers ); - } -} diff --git a/driver/src/main/java/org/neo4j/driver/internal/NetworkSession.java b/driver/src/main/java/org/neo4j/driver/internal/NetworkSession.java index 93cd1c558e..3cf91d468a 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/NetworkSession.java +++ b/driver/src/main/java/org/neo4j/driver/internal/NetworkSession.java @@ -113,6 +113,11 @@ public StatementResult run( String statementText, Value statementParameters ) public StatementResult run( Statement statement ) { ensureConnectionIsValidBeforeRunningSession(); + return run( connection, statement ); + } + + public static StatementResult run( Connection connection, Statement statement ) + { InternalStatementResult cursor = new InternalStatementResult( connection, null, statement ); connection.run( statement.text(), statement.parameters().asMap( Values.ofValue() ), cursor.runResponseCollector() ); diff --git a/driver/src/main/java/org/neo4j/driver/internal/RoutingDriver.java b/driver/src/main/java/org/neo4j/driver/internal/RoutingDriver.java index 38cc595144..8705f08d53 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/RoutingDriver.java +++ b/driver/src/main/java/org/neo4j/driver/internal/RoutingDriver.java @@ -18,9 +18,10 @@ */ package org.neo4j.driver.internal; -import java.util.List; import java.util.Set; +import org.neo4j.driver.internal.cluster.LoadBalancer; +import org.neo4j.driver.internal.cluster.RoutingSettings; import org.neo4j.driver.internal.net.BoltServerAddress; import org.neo4j.driver.internal.security.SecurityPlan; import org.neo4j.driver.internal.spi.Connection; @@ -28,209 +29,25 @@ import org.neo4j.driver.internal.util.Clock; import org.neo4j.driver.v1.AccessMode; import org.neo4j.driver.v1.Logging; -import org.neo4j.driver.v1.Record; import org.neo4j.driver.v1.Session; -import org.neo4j.driver.v1.StatementResult; -import org.neo4j.driver.v1.Value; import org.neo4j.driver.v1.exceptions.ClientException; -import org.neo4j.driver.v1.exceptions.ServiceUnavailableException; -import org.neo4j.driver.v1.exceptions.SessionExpiredException; -import org.neo4j.driver.v1.util.Function; import static java.lang.String.format; public class RoutingDriver extends BaseDriver { - private static final String GET_SERVERS = "dbms.cluster.routing.getServers"; - private static final long MAX_TTL = Long.MAX_VALUE / 1000L; + private final LoadBalancer loadBalancer; - private final ConnectionPool connections; - private final Function sessionProvider; - private final Clock clock; - private ClusterView clusterView; - - - public RoutingDriver( BoltServerAddress seedAddress, + public RoutingDriver( + RoutingSettings settings, + BoltServerAddress seedAddress, ConnectionPool connections, SecurityPlan securityPlan, - Function sessionProvider, Clock clock, Logging logging ) { super( securityPlan, logging ); - this.connections = connections; - this.sessionProvider = sessionProvider; - this.clock = clock; - this.clusterView = new ClusterView( 0L, clock, log ); - this.clusterView.addRouter( seedAddress ); - checkServers(); - } - - private synchronized void checkServers() - { - if ( clusterView.isStale() ) - { - Set oldAddresses = clusterView.all(); - ClusterView newView = newClusterView(); - Set newAddresses = newView.all(); - - oldAddresses.removeAll( newAddresses ); - for ( BoltServerAddress boltServerAddress : oldAddresses ) - { - connections.purge( boltServerAddress ); - } - - this.clusterView = newView; - } - } - - private long calculateNewExpiry( Record record ) - { - long ttl = record.get( "ttl" ).asLong(); - long nextExpiry = clock.millis() + 1000L * ttl; - if ( ttl < 0 || ttl >= MAX_TTL || nextExpiry < 0 ) - { - return Long.MAX_VALUE; - } - else - { - return nextExpiry; - } - } - - private ClusterView newClusterView() - { - BoltServerAddress address = null; - for ( int i = 0; i < clusterView.numberOfRouters(); i++ ) - { - address = clusterView.nextRouter(); - ClusterView newClusterView; - try - { - newClusterView = call( address, GET_SERVERS, new Function() - - { - @Override - public ClusterView apply( Record record ) - { - long expire = calculateNewExpiry( record ); - ClusterView newClusterView = new ClusterView( expire, clock, log ); - List servers = servers( record ); - for ( ServerInfo server : servers ) - { - switch ( server.role() ) - { - case "READ": - newClusterView.addReaders( server.addresses() ); - break; - case "WRITE": - newClusterView.addWriters( server.addresses() ); - break; - case "ROUTE": - newClusterView.addRouters( server.addresses() ); - break; - } - } - return newClusterView; - } - } ); - } - catch ( Throwable t ) - { - forget( address ); - continue; - } - - if ( newClusterView.numberOfRouters() != 0 ) - { - return newClusterView; - } - } - - - //discovery failed, not much to do, stick with what we've got - //this may happen because server is running in standalone mode - this.close(); - throw new ServiceUnavailableException( - String.format( "Server %s couldn't perform discovery", - address == null ? "`UNKNOWN`" : address.toString() ) ); - - } - - private static class ServerInfo - { - private final List addresses; - private final String role; - - public ServerInfo( List addresses, String role ) - { - this.addresses = addresses; - this.role = role; - } - - public String role() - { - return role; - } - - List addresses() - { - return addresses; - } - } - - private List servers( Record record ) - { - return record.get( "servers" ).asList( new Function() - { - @Override - public ServerInfo apply( Value value ) - { - return new ServerInfo( value.get( "addresses" ).asList( new Function() - { - @Override - public BoltServerAddress apply( Value value ) - { - return new BoltServerAddress( value.asString() ); - } - } ), value.get( "role" ).asString() ); - } - } ); - } - - //must be called from a synchronized method - private T call( BoltServerAddress address, String procedureName, Function recorder ) - { - Connection acquire; - Session session = null; - try - { - acquire = connections.acquire( address ); - session = sessionProvider.apply( acquire ); - - StatementResult records = session.run( format( "CALL %s", procedureName ) ); - //got a result but was empty - if ( !records.hasNext() ) - { - forget( address ); - throw new IllegalStateException("Server responded with empty result"); - } - //consume the results - return recorder.apply( records.single() ); - } - finally - { - if ( session != null ) - { - session.close(); - } - } - } - - private synchronized void forget( BoltServerAddress address ) - { - connections.purge( address ); - clusterView.remove(address); + this.loadBalancer = new LoadBalancer( settings, clock, log, connections, seedAddress ); } @Override @@ -243,83 +60,28 @@ public Session session() public Session session( final AccessMode mode ) { Connection connection = acquireConnection( mode ); - return new RoutingNetworkSession( new NetworkSession( connection ), mode, connection.address(), - new RoutingErrorHandler() - { - @Override - public void onConnectionFailure( BoltServerAddress address ) - { - forget( address ); - } - - @Override - public void onWriteFailure( BoltServerAddress address ) - { - clusterView.removeWriter( address ); - } - } ); + return new RoutingNetworkSession( new NetworkSession( connection ), mode, connection.address(), loadBalancer ); } private Connection acquireConnection( AccessMode role ) { - //Potentially rediscover servers if we are not happy with our current knowledge - checkServers(); - switch ( role ) { case READ: - return acquireReadConnection(); + return loadBalancer.acquireReadConnection(); case WRITE: - return acquireWriteConnection(); + return loadBalancer.acquireWriteConnection(); default: throw new ClientException( role + " is not supported for creating new sessions" ); } } - private Connection acquireReadConnection() - { - int numberOfServers = clusterView.numberOfReaders(); - for ( int i = 0; i < numberOfServers; i++ ) - { - BoltServerAddress address = clusterView.nextReader(); - try - { - return connections.acquire( address ); - } - catch ( ServiceUnavailableException e ) - { - forget( address ); - } - } - - throw new SessionExpiredException( "Failed to connect to any read server" ); - } - - private Connection acquireWriteConnection() - { - int numberOfServers = clusterView.numberOfWriters(); - for ( int i = 0; i < numberOfServers; i++ ) - { - BoltServerAddress address = clusterView.nextWriter(); - try - { - return connections.acquire( address ); - } - catch ( ServiceUnavailableException e ) - { - forget( address ); - } - } - - throw new SessionExpiredException( "Failed to connect to any write server" ); - } - @Override public void close() { try { - connections.close(); + loadBalancer.close(); } catch ( Exception ex ) { @@ -327,28 +89,27 @@ public void close() } } - //For testing - public Set routingServers() + Set routingServers() { - return clusterView.routingServers(); + // TODO: the tests that use this should be testing for effect instead + throw new UnsupportedOperationException( "not implemented" ); } - //For testing - public Set readServers() + Set readServers() { - return clusterView.readServers(); + // TODO: the tests that use this should be testing for effect instead + throw new UnsupportedOperationException( "not implemented" ); } - //For testing - public Set writeServers() + Set writeServers() { - return clusterView.writeServers( ); + // TODO: the tests that use this should be testing for effect instead + throw new UnsupportedOperationException( "not implemented" ); } - //For testing - public ConnectionPool connectionPool() + ConnectionPool connectionPool() { - return connections; + // TODO: the tests that use this should be testing for effect instead, perhaps by injecting a pool delegate + throw new UnsupportedOperationException( "not implemented" ); } - } diff --git a/driver/src/main/java/org/neo4j/driver/internal/RoutingErrorHandler.java b/driver/src/main/java/org/neo4j/driver/internal/RoutingErrorHandler.java index ca178b46d1..944e1eddd4 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/RoutingErrorHandler.java +++ b/driver/src/main/java/org/neo4j/driver/internal/RoutingErrorHandler.java @@ -23,7 +23,7 @@ /** * Interface used for tracking errors when connected to a cluster. */ -interface RoutingErrorHandler +public interface RoutingErrorHandler { void onConnectionFailure( BoltServerAddress address ); diff --git a/driver/src/main/java/org/neo4j/driver/internal/cluster/ClusterComposition.java b/driver/src/main/java/org/neo4j/driver/internal/cluster/ClusterComposition.java new file mode 100644 index 0000000000..acfe0069bf --- /dev/null +++ b/driver/src/main/java/org/neo4j/driver/internal/cluster/ClusterComposition.java @@ -0,0 +1,203 @@ +/** + * Copyright (c) 2002-2016 "Neo Technology," + * Network Engine for Objects in Lund AB [http://neotechnology.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.internal.cluster; + +import java.util.HashSet; +import java.util.Set; + +import org.neo4j.driver.internal.NetworkSession; +import org.neo4j.driver.internal.net.BoltServerAddress; +import org.neo4j.driver.internal.spi.Connection; +import org.neo4j.driver.internal.util.Clock; +import org.neo4j.driver.v1.Record; +import org.neo4j.driver.v1.Statement; +import org.neo4j.driver.v1.StatementResult; +import org.neo4j.driver.v1.Value; +import org.neo4j.driver.v1.exceptions.ServiceUnavailableException; +import org.neo4j.driver.v1.exceptions.value.ValueException; +import org.neo4j.driver.v1.util.Function; + +final class ClusterComposition +{ + interface Provider + { + String GET_SERVERS = "CALL dbms.cluster.routing.getServers"; + + ClusterComposition getClusterComposition( Connection connection ) throws ServiceUnavailableException; + + final class Default implements Provider + { + private static final Statement GET_SERVER = new Statement( Provider.GET_SERVERS ); + private final Clock clock; + + Default( Clock clock ) + { + this.clock = clock; + } + + @Override + public ClusterComposition getClusterComposition( Connection connection ) throws ServiceUnavailableException + { + StatementResult cursor = getServers( connection ); + long now = clock.millis(); + try + { + if ( !cursor.hasNext() ) + { + return null; // server returned too few rows, this is a contract violation, treat as incapable + } + Record record = cursor.next(); + if ( cursor.hasNext() ) + { + return null; // server returned too many rows, this is a contract violation, treat as incapable + } + return read( record, now ); + } + finally + { + cursor.consume(); // make sure we exhaust the results + } + } + + private StatementResult getServers( Connection connection ) + { + return NetworkSession.run( connection, GET_SERVER ); + } + } + } + + private static final long MAX_TTL = Long.MAX_VALUE / 1000L; + private static final Function OF_BoltServerAddress = + new Function() + { + @Override + public BoltServerAddress apply( Value value ) + { + return new BoltServerAddress( value.asString() ); + } + }; + private final Set readers, writers, routers; + final long expirationTimestamp; + + private ClusterComposition( long expirationTimestamp ) + { + this.readers = new HashSet<>(); + this.writers = new HashSet<>(); + this.routers = new HashSet<>(); + this.expirationTimestamp = expirationTimestamp; + } + + /** For testing */ + ClusterComposition( + long expirationTimestamp, + Set readers, + Set writers, + Set routers ) + { + this( expirationTimestamp ); + this.readers.addAll( readers ); + this.writers.addAll( writers ); + this.routers.addAll( routers ); + } + + public boolean isValid() + { + return !routers.isEmpty() && !writers.isEmpty(); + } + + public Set readers() + { + return new HashSet<>( readers ); + } + + public Set writers() + { + return new HashSet<>( writers ); + } + + public Set routers() + { + return new HashSet<>( routers ); + } + + @Override + public String toString() + { + return "ClusterComposition{" + + "expirationTimestamp=" + expirationTimestamp + + ", readers=" + readers + + ", writers=" + writers + + ", routers=" + routers + + '}'; + } + + private static ClusterComposition read( Record record, long now ) + { + if ( record == null ) + { + return null; + } + try + { + final ClusterComposition result; + result = new ClusterComposition( expirationTimestamp( now, record ) ); + record.get( "servers" ).asList( new Function() + { + @Override + public Void apply( Value value ) + { + result.servers( value.get( "role" ).asString() ) + .addAll( value.get( "addresses" ).asList( OF_BoltServerAddress ) ); + return null; + } + } ); + return result; + } + catch ( ValueException e ) + { + return null; + } + } + + private static long expirationTimestamp( long now, Record record ) + { + long ttl = record.get( "ttl" ).asLong(); + long expirationTimestamp = now + ttl * 1000; + if ( ttl < 0 || ttl >= MAX_TTL || expirationTimestamp < 0 ) + { + expirationTimestamp = Long.MAX_VALUE; + } + return expirationTimestamp; + } + + private Set servers( String role ) + { + switch ( role ) + { + case "READ": + return readers; + case "WRITE": + return writers; + case "ROUTE": + return routers; + default: + throw new IllegalArgumentException( "invalid server role: " + role ); + } + } +} diff --git a/driver/src/main/java/org/neo4j/driver/internal/cluster/LoadBalancer.java b/driver/src/main/java/org/neo4j/driver/internal/cluster/LoadBalancer.java new file mode 100644 index 0000000000..ac2a83815a --- /dev/null +++ b/driver/src/main/java/org/neo4j/driver/internal/cluster/LoadBalancer.java @@ -0,0 +1,228 @@ +/** + * Copyright (c) 2002-2016 "Neo Technology," + * Network Engine for Objects in Lund AB [http://neotechnology.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.internal.cluster; + +import java.util.HashSet; + +import org.neo4j.driver.internal.RoutingErrorHandler; +import org.neo4j.driver.internal.net.BoltServerAddress; +import org.neo4j.driver.internal.spi.Connection; +import org.neo4j.driver.internal.spi.ConnectionPool; +import org.neo4j.driver.internal.util.Clock; +import org.neo4j.driver.v1.Logger; +import org.neo4j.driver.v1.exceptions.ServiceUnavailableException; + +import static java.util.Arrays.asList; + +public final class LoadBalancer implements RoutingErrorHandler, AutoCloseable +{ + private static final int MIN_ROUTERS = 1; + private static final String NO_ROUTERS_AVAILABLE = "Could not perform discovery. No routing servers available."; + // dependencies + private final RoutingSettings settings; + private final Clock clock; + private final Logger log; + private final ConnectionPool connections; + private final ClusterComposition.Provider provider; + // state + private long expirationTimeout; + private final RoundRobinAddressSet readers, writers, routers; + + public LoadBalancer( + RoutingSettings settings, + Clock clock, + Logger log, + ConnectionPool connections, + BoltServerAddress... routingAddresses ) throws ServiceUnavailableException + { + this( settings, clock, log, connections, new ClusterComposition.Provider.Default( clock ), routingAddresses ); + } + + LoadBalancer( + RoutingSettings settings, + Clock clock, + Logger log, + ConnectionPool connections, + ClusterComposition.Provider provider, + BoltServerAddress... routingAddresses ) throws ServiceUnavailableException + { + this.clock = clock; + this.log = log; + this.connections = connections; + this.expirationTimeout = clock.millis() - 1; + this.provider = provider; + this.settings = settings; + this.readers = new RoundRobinAddressSet(); + this.writers = new RoundRobinAddressSet(); + this.routers = new RoundRobinAddressSet(); + routers.update( new HashSet<>( asList( routingAddresses ) ), new HashSet() ); + // initialize the routing table + ensureRouting(); + } + + public Connection acquireReadConnection() throws ServiceUnavailableException + { + return acquireConnection( readers ); + } + + public Connection acquireWriteConnection() throws ServiceUnavailableException + { + return acquireConnection( writers ); + } + + @Override + public void onConnectionFailure( BoltServerAddress address ) + { + forget( address ); + } + + @Override + public void onWriteFailure( BoltServerAddress address ) + { + writers.remove( address ); + } + + @Override + public void close() throws Exception + { + connections.close(); + } + + private Connection acquireConnection( RoundRobinAddressSet servers ) throws ServiceUnavailableException + { + for ( ; ; ) + { + // refresh the routing table if needed + ensureRouting(); + for ( BoltServerAddress address; (address = servers.next()) != null; ) + { + try + { + return connections.acquire( address ); + } + catch ( ServiceUnavailableException e ) + { + forget( address ); + } + } + // if we get here, we failed to connect to any server, so we will rebuild the routing table + } + } + + private synchronized void ensureRouting() throws ServiceUnavailableException + { + if ( stale() ) + { + try + { + // get a new routing table + ClusterComposition cluster = lookupRoutingTable(); + expirationTimeout = cluster.expirationTimestamp; + HashSet removed = new HashSet<>(); + readers.update( cluster.readers(), removed ); + writers.update( cluster.writers(), removed ); + routers.update( cluster.routers(), removed ); + // purge connections to removed addresses + for ( BoltServerAddress address : removed ) + { + connections.purge( address ); + } + } + catch ( InterruptedException e ) + { + throw new ServiceUnavailableException( "Thread was interrupted while establishing connection.", e ); + } + } + } + + private ClusterComposition lookupRoutingTable() throws InterruptedException, ServiceUnavailableException + { + int size = routers.size(), failures = 0; + if ( size == 0 ) + { + throw new ServiceUnavailableException( NO_ROUTERS_AVAILABLE ); + } + for ( long start = clock.millis(), delay = 0; ; delay = Math.max( settings.retryTimeoutDelay, delay * 2 ) ) + { + long waitTime = start + delay - clock.millis(); + if ( waitTime > 0 ) + { + clock.sleep( waitTime ); + } + start = clock.millis(); + for ( int i = 0; i < size; i++ ) + { + BoltServerAddress address = routers.next(); + if ( address == null ) + { + throw new ServiceUnavailableException( NO_ROUTERS_AVAILABLE ); + } + ClusterComposition cluster; + try ( Connection connection = connections.acquire( address ) ) + { + cluster = provider.getClusterComposition( connection ); + } + catch ( Exception e ) + { + log.error( String.format( "Failed to connect to routing server '%s'.", address ), e ); + continue; + } + if ( cluster == null || !cluster.isValid() ) + { + log.info( + "Server <%s> unable to perform routing capability, dropping from list of routers.", + address ); + routers.remove( address ); + if ( --size == 0 ) + { + throw new ServiceUnavailableException( NO_ROUTERS_AVAILABLE ); + } + } + else + { + return cluster; + } + } + if ( ++failures > settings.maxRoutingFailures ) + { + throw new ServiceUnavailableException( NO_ROUTERS_AVAILABLE ); + } + } + } + + private synchronized void forget( BoltServerAddress address ) + { + // First remove from the load balancer, to prevent concurrent threads from making connections to them. + // Don't remove it from the set of routers, since that might mean we lose our ability to re-discover, + // just remove it from the set of readers and writers, so that we don't use it for actual work without + // performing discovery first. + readers.remove( address ); + writers.remove( address ); + // drop all current connections to the address + connections.purge( address ); + } + + private boolean stale() + { + return expirationTimeout < clock.millis() || // the expiration timeout has been reached + routers.size() <= MIN_ROUTERS || // we need to discover more routing servers + readers.size() == 0 || // we need to discover more read servers + writers.size() == 0; // we need to discover more write servers + } +} diff --git a/driver/src/main/java/org/neo4j/driver/internal/cluster/RoundRobinAddressSet.java b/driver/src/main/java/org/neo4j/driver/internal/cluster/RoundRobinAddressSet.java new file mode 100644 index 0000000000..acae91704f --- /dev/null +++ b/driver/src/main/java/org/neo4j/driver/internal/cluster/RoundRobinAddressSet.java @@ -0,0 +1,136 @@ +/** + * Copyright (c) 2002-2016 "Neo Technology," + * Network Engine for Objects in Lund AB [http://neotechnology.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.internal.cluster; + +import java.util.Set; +import java.util.concurrent.atomic.AtomicInteger; + +import org.neo4j.driver.internal.net.BoltServerAddress; + +class RoundRobinAddressSet +{ + private static final BoltServerAddress[] NONE = {}; + private final AtomicInteger offset = new AtomicInteger(); + private volatile BoltServerAddress[] addresses = NONE; + + public int size() + { + return addresses.length; + } + + public BoltServerAddress next() + { + BoltServerAddress[] addresses = this.addresses; + if ( addresses.length == 0 ) + { + return null; + } + return addresses[next( addresses.length )]; + } + + int next( int divisor ) + { + int index = offset.getAndIncrement(); + for ( ; index == Integer.MAX_VALUE; index = offset.getAndIncrement() ) + { + offset.compareAndSet( Integer.MIN_VALUE, index % divisor ); + } + return index % divisor; + } + + public synchronized void update( Set addresses, Set removed ) + { + BoltServerAddress[] prev = this.addresses; + if ( addresses.isEmpty() ) + { + this.addresses = NONE; + return; + } + if ( prev.length == 0 ) + { + this.addresses = addresses.toArray( NONE ); + return; + } + BoltServerAddress[] copy = null; + if ( addresses.size() != prev.length ) + { + copy = new BoltServerAddress[addresses.size()]; + } + int j = 0; + for ( int i = 0; i < prev.length; i++ ) + { + if ( addresses.remove( prev[i] ) ) + { + if ( copy != null ) + { + copy[j++] = prev[i]; + } + } + else + { + removed.add( prev[i] ); + if ( copy == null ) + { + copy = new BoltServerAddress[prev.length]; + System.arraycopy( prev, 0, copy, 0, i ); + j = i; + } + } + } + if ( copy == null ) + { + return; + } + for ( BoltServerAddress address : addresses ) + { + copy[j++] = address; + } + this.addresses = copy; + } + + public synchronized void remove( BoltServerAddress address ) + { + BoltServerAddress[] addresses = this.addresses; + if ( addresses != null ) + { + for ( int i = 0; i < addresses.length; i++ ) + { + if ( addresses[i].equals( address ) ) + { + if ( addresses.length == 1 ) + { + this.addresses = NONE; + return; + } + BoltServerAddress[] copy = new BoltServerAddress[addresses.length - 1]; + System.arraycopy( addresses, 0, copy, 0, i ); + System.arraycopy( addresses, i + 1, copy, i, addresses.length - i - 1 ); + this.addresses = copy; + return; + } + } + } + } + + /** breaking encapsulation in order to perform white-box testing of boundary case */ + void setOffset( int target ) + { + offset.set( target ); + } +} diff --git a/driver/src/main/java/org/neo4j/driver/internal/cluster/RoutingSettings.java b/driver/src/main/java/org/neo4j/driver/internal/cluster/RoutingSettings.java new file mode 100644 index 0000000000..0b7fe7957e --- /dev/null +++ b/driver/src/main/java/org/neo4j/driver/internal/cluster/RoutingSettings.java @@ -0,0 +1,31 @@ +/** + * Copyright (c) 2002-2016 "Neo Technology," + * Network Engine for Objects in Lund AB [http://neotechnology.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.internal.cluster; + +public class RoutingSettings +{ + final int maxRoutingFailures; + final long retryTimeoutDelay; + + public RoutingSettings( int maxRoutingFailures, long retryTimeoutDelay ) + { + this.maxRoutingFailures = maxRoutingFailures; + this.retryTimeoutDelay = retryTimeoutDelay; + } +} diff --git a/driver/src/main/java/org/neo4j/driver/internal/util/Clock.java b/driver/src/main/java/org/neo4j/driver/internal/util/Clock.java index 21b4b8497e..e5a5a4d932 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/util/Clock.java +++ b/driver/src/main/java/org/neo4j/driver/internal/util/Clock.java @@ -26,6 +26,8 @@ public interface Clock /** Current time, in milliseconds. */ long millis(); + void sleep( long millis ) throws InterruptedException; + Clock SYSTEM = new Clock() { @Override @@ -33,5 +35,11 @@ public long millis() { return System.currentTimeMillis(); } + + @Override + public void sleep( long millis ) throws InterruptedException + { + Thread.sleep( millis ); + } }; } diff --git a/driver/src/main/java/org/neo4j/driver/internal/util/ConcurrentRoundRobinSet.java b/driver/src/main/java/org/neo4j/driver/internal/util/ConcurrentRoundRobinSet.java deleted file mode 100644 index ffd369c248..0000000000 --- a/driver/src/main/java/org/neo4j/driver/internal/util/ConcurrentRoundRobinSet.java +++ /dev/null @@ -1,152 +0,0 @@ -/** - * Copyright (c) 2002-2016 "Neo Technology," - * Network Engine for Objects in Lund AB [http://neotechnology.com] - * - * This file is part of Neo4j. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.neo4j.driver.internal.util; - -import java.util.Collection; -import java.util.Comparator; -import java.util.Iterator; -import java.util.Set; -import java.util.concurrent.ConcurrentSkipListSet; - -/** - * A set that exposes a method {@link #hop()} that cycles through the members of the set. - * @param the type of elements in the set - */ -public class ConcurrentRoundRobinSet implements Set -{ - private final ConcurrentSkipListSet set; - private T current; - - public ConcurrentRoundRobinSet() - { - set = new ConcurrentSkipListSet<>(); - } - - public ConcurrentRoundRobinSet( Comparator comparator ) - { - set = new ConcurrentSkipListSet<>( comparator ); - } - - public ConcurrentRoundRobinSet(ConcurrentRoundRobinSet original) - { - set = new ConcurrentSkipListSet<>( original.set.comparator() ); - set.addAll( original ); - } - - public T hop() - { - if ( current == null ) - { - current = set.first(); - } - else - { - current = set.higher( current ); - //We've gone through all connections, start over - if ( current == null ) - { - current = set.first(); - } - } - - if ( current == null ) - { - throw new IllegalStateException( "nothing in the set" ); - } - - return current; - } - - @Override - public boolean add( T item ) - { - return set.add( item ); - } - - @Override - public boolean containsAll( Collection c ) - { - return set.containsAll( c ); - } - - @Override - public boolean addAll( Collection c ) - { - return set.addAll( c ); - } - - @Override - public boolean retainAll( Collection c ) - { - return set.retainAll( c ); - } - - @Override - public boolean removeAll( Collection c ) - { - return set.retainAll( c ); - } - - @Override - public void clear() - { - set.clear(); - } - - @Override - public boolean remove( Object o ) - { - return set.remove( o ); - } - - public int size() - { - return set.size(); - } - - public boolean isEmpty() - { - return set.isEmpty(); - } - - @Override - public boolean contains( Object o ) - { - return set.contains( o ); - } - - @Override - public Iterator iterator() - { - return set.iterator(); - } - - @Override - public Object[] toArray() - { - return set.toArray(); - } - - @SuppressWarnings( "SuspiciousToArrayCall" ) - @Override - public T1[] toArray( T1[] a ) - { - return set.toArray( a ); - } -} diff --git a/driver/src/main/java/org/neo4j/driver/v1/Config.java b/driver/src/main/java/org/neo4j/driver/v1/Config.java index 41f743b1dc..fc9514295c 100644 --- a/driver/src/main/java/org/neo4j/driver/v1/Config.java +++ b/driver/src/main/java/org/neo4j/driver/v1/Config.java @@ -19,8 +19,10 @@ package org.neo4j.driver.v1; import java.io.File; +import java.util.concurrent.TimeUnit; import java.util.logging.Level; +import org.neo4j.driver.internal.cluster.RoutingSettings; import org.neo4j.driver.internal.logging.JULogging; import org.neo4j.driver.internal.net.pooling.PoolSettings; import org.neo4j.driver.v1.util.Immutable; @@ -60,6 +62,8 @@ public class Config private final TrustStrategy trustStrategy; private final int minServersInCluster; + private final int maxRoutingFailures; + private final long routingRetryDelayMillis; private Config( ConfigBuilder builder) { @@ -71,6 +75,8 @@ private Config( ConfigBuilder builder) this.encryptionLevel = builder.encryptionLevel; this.trustStrategy = builder.trustStrategy; this.minServersInCluster = builder.minServersInCluster; + this.maxRoutingFailures = builder.maxRoutingFailures; + this.routingRetryDelayMillis = builder.routingRetryDelayMillis; } /** @@ -144,6 +150,11 @@ public static Config defaultConfig() return Config.build().toConfig(); } + RoutingSettings routingSettings() + { + return new RoutingSettings( maxRoutingFailures, routingRetryDelayMillis ); + } + /** * Used to build new config instances */ @@ -156,6 +167,8 @@ public static class ConfigBuilder private TrustStrategy trustStrategy = trustOnFirstUse( new File( getProperty( "user.home" ), ".neo4j" + File.separator + "known_hosts" ) ); public int minServersInCluster = 3; + private int maxRoutingFailures = 10; + private long routingRetryDelayMillis = 5_000; private ConfigBuilder() {} @@ -259,6 +272,51 @@ public ConfigBuilder withTrustStrategy( TrustStrategy trustStrategy ) return this; } + /** + * Specify how many times the client should attempt to reconnect to the routing servers before declaring the + * cluster unavailable. + *

+ * The routing servers are tried in order. If connecting any of them fails, they are all retried after + * {@linkplain #withRoutingRetryDelay a delay}. This process of retrying all servers is then repeated for the + * number of times specified here before considering the cluster unavailable. + * + * @param routingRetryLimit + * the number of times to retry each server in the list of routing servers + * @return this builder + */ + public ConfigBuilder withRoutingRetryLimit( int routingRetryLimit ) + { + this.maxRoutingFailures = routingRetryLimit; + return this; + } + + /** + * Specify how long to wait before retrying to connect to a routing server. + *

+ * When connecting to all routing servers fail, connecting will be retried after the delay specified here. + * The delay is measured from when the first attempt to connect was made, so that the delay time specifies a + * retry interval. + *

+ * For each {@linkplain #withRoutingRetryLimit retry attempt} the delay time will be doubled. The time + * specified here is the base time, i.e. the time to wait before the first retry. If that attempt (on all + * servers) also fails, the delay before the next retry will be double the time specified here, and the next + * attempt after that will be double that, et.c. So if, for example, the delay specified here is + * {@code 5 SECONDS}, then after attempting to connect to each server fails reconnecting will be attempted + * 5 seconds after the first connection attempt to the first server. If that attempt also fails to connect to + * all servers, the next attempt will start 10 seconds after the second attempt started. + * + * @param delay + * the amount of time between attempts to reconnect to the same server + * @param unit + * the unit in which the duration is given + * @return this builder + */ + public ConfigBuilder withRoutingRetryDelay( long delay, TimeUnit unit ) + { + this.routingRetryDelayMillis = unit.toMillis( delay ); + return this; + } + /** * Create a config instance from this builder. * @return a {@link Config} instance diff --git a/driver/src/main/java/org/neo4j/driver/v1/GraphDatabase.java b/driver/src/main/java/org/neo4j/driver/v1/GraphDatabase.java index a031815e5e..8e975b9e31 100644 --- a/driver/src/main/java/org/neo4j/driver/v1/GraphDatabase.java +++ b/driver/src/main/java/org/neo4j/driver/v1/GraphDatabase.java @@ -34,7 +34,6 @@ import org.neo4j.driver.internal.spi.ConnectionPool; import org.neo4j.driver.internal.util.Clock; import org.neo4j.driver.v1.exceptions.ClientException; -import org.neo4j.driver.v1.util.BiFunction; import org.neo4j.driver.v1.util.Function; import static java.lang.String.format; @@ -191,7 +190,13 @@ public static Driver driver( URI uri, AuthToken authToken, Config config ) case "bolt": return new DirectDriver( address, connectionPool, securityPlan, config.logging() ); case "bolt+routing": - return new RoutingDriver( address, connectionPool, securityPlan, SESSION_PROVIDER, Clock.SYSTEM, config.logging() ); + return new RoutingDriver( + config.routingSettings(), + address, + connectionPool, + securityPlan, + Clock.SYSTEM, + config.logging() ); default: throw new ClientException( format( "Unsupported URI scheme: %s", scheme ) ); } diff --git a/driver/src/test/java/org/neo4j/driver/internal/ClusterViewTest.java b/driver/src/test/java/org/neo4j/driver/internal/ClusterViewTest.java deleted file mode 100644 index 2f9b7fb778..0000000000 --- a/driver/src/test/java/org/neo4j/driver/internal/ClusterViewTest.java +++ /dev/null @@ -1,189 +0,0 @@ -/** - * Copyright (c) 2002-2016 "Neo Technology," - * Network Engine for Objects in Lund AB [http://neotechnology.com] - * - * This file is part of Neo4j. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.neo4j.driver.internal; - - -import org.junit.Test; - -import org.neo4j.driver.internal.net.BoltServerAddress; -import org.neo4j.driver.internal.util.Clock; -import org.neo4j.driver.v1.Logger; - -import static java.util.Arrays.asList; -import static java.util.Collections.singletonList; -import static org.hamcrest.CoreMatchers.equalTo; -import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.Matchers.containsInAnyOrder; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertTrue; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; - -public class ClusterViewTest -{ - - @Test - public void shouldRoundRobinAmongRoutingServers() - { - // Given - ClusterView clusterView = new ClusterView( 5L, mock( Clock.class ), mock( Logger.class ) ); - - // When - clusterView.addRouters( asList( address("host1"), address( "host2" ), address( "host3" ))); - - // Then - assertThat(clusterView.nextRouter(), equalTo(address( "host1" ))); - assertThat(clusterView.nextRouter(), equalTo(address( "host2" ))); - assertThat(clusterView.nextRouter(), equalTo(address( "host3" ))); - assertThat(clusterView.nextRouter(), equalTo(address( "host1" ))); - } - - @Test - public void shouldRoundRobinAmongReadServers() - { - // Given - ClusterView clusterView = new ClusterView( 5L, mock( Clock.class ), mock( Logger.class ) ); - - // When - clusterView.addReaders( asList( address("host1"), address( "host2" ), address( "host3" ))); - - // Then - assertThat(clusterView.nextReader(), equalTo(address( "host1" ))); - assertThat(clusterView.nextReader(), equalTo(address( "host2" ))); - assertThat(clusterView.nextReader(), equalTo(address( "host3" ))); - assertThat(clusterView.nextReader(), equalTo(address( "host1" ))); - } - - @Test - public void shouldRoundRobinAmongWriteServers() - { - // Given - ClusterView clusterView = new ClusterView( 5L, mock( Clock.class ), mock( Logger.class ) ); - - // When - clusterView.addWriters( asList( address("host1"), address( "host2" ), address( "host3" ))); - - // Then - assertThat(clusterView.nextWriter(), equalTo(address( "host1" ))); - assertThat(clusterView.nextWriter(), equalTo(address( "host2" ))); - assertThat(clusterView.nextWriter(), equalTo(address( "host3" ))); - assertThat(clusterView.nextWriter(), equalTo(address( "host1" ))); - } - - @Test - public void shouldRemoveServer() - { - // Given - ClusterView clusterView = new ClusterView( 5L, mock( Clock.class ), mock( Logger.class ) ); - - clusterView.addRouters( asList( address("host1"), address( "host2" ), address( "host3" ))); - clusterView.addReaders( asList( address("host1"), address( "host2" ), address( "host3" ))); - clusterView.addWriters( asList( address("host2"), address( "host4" ))); - - // When - clusterView.remove( address( "host2" ) ); - - // Then - assertThat(clusterView.routingServers(), containsInAnyOrder(address( "host1" ), address( "host3" ))); - assertThat(clusterView.readServers(), containsInAnyOrder(address( "host1" ), address( "host3" ))); - assertThat(clusterView.writeServers(), containsInAnyOrder(address( "host4" ))); - assertThat(clusterView.all(), containsInAnyOrder( address( "host1" ), address( "host3" ), address( "host4" ) )); - } - - @Test - public void shouldBeStaleIfExpired() - { - // Given - Clock clock = mock( Clock.class ); - when(clock.millis()).thenReturn( 6L ); - ClusterView clusterView = new ClusterView( 5L, clock, mock( Logger.class ) ); - clusterView.addRouters( asList( address("host1"), address( "host2" ), address( "host3" ))); - clusterView.addReaders( asList( address("host1"), address( "host2" ), address( "host3" ))); - clusterView.addWriters( asList( address("host2"), address( "host4" ))); - - // Then - assertTrue(clusterView.isStale()); - } - - @Test - public void shouldNotBeStaleIfNotExpired() - { - // Given - Clock clock = mock( Clock.class ); - when(clock.millis()).thenReturn( 4L ); - ClusterView clusterView = new ClusterView( 5L, clock, mock( Logger.class ) ); - clusterView.addRouters( asList( address("host1"), address( "host2" ), address( "host3" ))); - clusterView.addReaders( asList( address("host1"), address( "host2" ), address( "host3" ))); - clusterView.addWriters( asList( address("host2"), address( "host4" ))); - - // Then - assertFalse(clusterView.isStale()); - } - - @Test - public void shouldBeStaleIfOnlyOneRouter() - { - // Given - Clock clock = mock( Clock.class ); - when(clock.millis()).thenReturn( 4L ); - ClusterView clusterView = new ClusterView( 5L, clock, mock( Logger.class ) ); - clusterView.addRouters( singletonList( address( "host1" ) ) ); - clusterView.addReaders( asList( address("host1"), address( "host2" ), address( "host3" ))); - clusterView.addWriters( asList( address("host2"), address( "host4" ))); - - // When - - // Then - assertTrue(clusterView.isStale()); - } - - @Test - public void shouldBeStaleIfNoReader() - { - // Given - Clock clock = mock( Clock.class ); - when(clock.millis()).thenReturn( 4L ); - ClusterView clusterView = new ClusterView( 5L, clock, mock( Logger.class ) ); - clusterView.addRouters( singletonList( address( "host1" ) ) ); - clusterView.addWriters( asList( address("host2"), address( "host4" ))); - - // Then - assertTrue(clusterView.isStale()); - } - - @Test - public void shouldBeStaleIfNoWriter() - { - // Given - Clock clock = mock( Clock.class ); - when(clock.millis()).thenReturn( 4L ); - ClusterView clusterView = new ClusterView( 5L, clock, mock( Logger.class ) ); - clusterView.addRouters( singletonList( address( "host1" ) ) ); - clusterView.addReaders( asList( address("host1"), address( "host2" ), address( "host3" ))); - - // Then - assertTrue(clusterView.isStale()); - } - - private BoltServerAddress address(String host) - { - return new BoltServerAddress( host ); - } - -} \ No newline at end of file diff --git a/driver/src/test/java/org/neo4j/driver/internal/Event.java b/driver/src/test/java/org/neo4j/driver/internal/Event.java new file mode 100644 index 0000000000..bbf8fbaa1f --- /dev/null +++ b/driver/src/test/java/org/neo4j/driver/internal/Event.java @@ -0,0 +1,64 @@ +/** + * Copyright (c) 2002-2016 "Neo Technology," + * Network Engine for Objects in Lund AB [http://neotechnology.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.internal; + +import java.io.PrintWriter; +import java.io.StringWriter; +import java.lang.reflect.Method; + +public abstract class Event +{ + final Class handlerType; + + @SuppressWarnings( "unchecked" ) + public Event() + { + this.handlerType = (Class) handlerType( getClass() ); + } + + @Override + public String toString() + { + StringWriter res = new StringWriter(); + try ( PrintWriter out = new PrintWriter( res ) ) + { + EventHandler.write( this, out ); + } + return res.toString(); + } + + public abstract void dispatch( Handler handler ); + + private static Class handlerType( Class type ) + { + for ( Class c = type; c != Object.class; c = c.getSuperclass() ) + { + for ( Method method : c.getDeclaredMethods() ) + { + if ( method.getName().equals( "dispatch" ) + && method.getParameterTypes().length == 1 + && !method.isSynthetic() ) + { + return method.getParameterTypes()[0]; + } + } + } + throw new Error( "Cannot determine Handler type from dispatch(Handler) method." ); + } +} diff --git a/driver/src/test/java/org/neo4j/driver/internal/EventHandler.java b/driver/src/test/java/org/neo4j/driver/internal/EventHandler.java new file mode 100644 index 0000000000..bb9b87250d --- /dev/null +++ b/driver/src/test/java/org/neo4j/driver/internal/EventHandler.java @@ -0,0 +1,285 @@ +/** + * Copyright (c) 2002-2016 "Neo Technology," + * Network Engine for Objects in Lund AB [http://neotechnology.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.internal; + +import java.io.PrintStream; +import java.io.PrintWriter; +import java.lang.reflect.Constructor; +import java.lang.reflect.InvocationHandler; +import java.lang.reflect.Method; +import java.lang.reflect.Proxy; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Objects; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.CopyOnWriteArrayList; + +import org.hamcrest.Matcher; + +import org.neo4j.driver.internal.util.MatcherFactory; +import org.neo4j.driver.v1.EventLogger; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.any; +import static org.hamcrest.Matchers.not; +import static org.neo4j.driver.internal.util.MatcherFactory.containsAtLeast; +import static org.neo4j.driver.internal.util.MatcherFactory.count; + +public final class EventHandler +{ + private final List events = new ArrayList<>(); + private final ConcurrentMap,List> handlers = new ConcurrentHashMap<>(); + + public void add( Event event ) + { + synchronized ( events ) + { + events.add( event ); + } + List handlers = this.handlers.get( event.handlerType ); + if ( handlers != null ) + { + for ( Object handler : handlers ) + { + try + { + dispatch( event, handler ); + } + catch ( Exception e ) + { + System.err.println( "Failed to dispatch event: " + event + " to handler: " + handler ); + e.printStackTrace( System.err ); + } + } + } + } + + @SuppressWarnings( "unchecked" ) + public void registerHandler( Class type, Handler handler ) + { + List handlers = this.handlers.get( type ); + if ( handlers == null ) + { + List candidate = new CopyOnWriteArrayList<>(); + handlers = this.handlers.putIfAbsent( type, candidate ); + if ( handlers == null ) + { + handlers = candidate; + } + } + handlers.add( handler ); + } + + @SafeVarargs + public final void assertContains( MatcherFactory... matchers ) + { + synchronized ( events ) + { + assertThat( events, containsAtLeast( (MatcherFactory[]) matchers ) ); + } + } + + @SafeVarargs + public final void assertContains( Matcher... matchers ) + { + synchronized ( events ) + { + assertThat( events, containsAtLeast( (Matcher[]) matchers ) ); + } + } + + public final void assertCount( Matcher matcher, Matcher count ) + { + synchronized ( events ) + { + assertThat( events, (Matcher) count( matcher, count ) ); + } + } + + public void assertNone( Matcher matcher ) + { + synchronized ( events ) + { + assertThat( events, not( containsAtLeast( (Matcher) matcher ) ) ); + } + } + + public void printEvents( PrintStream out ) + { + printEvents( any( Event.class ), out ); + } + + public void printEvents( Matcher matcher, PrintStream out ) + { + try ( PrintWriter writer = new PrintWriter( out ) ) + { + printEvents( matcher, writer ); + } + } + + public void printEvents( PrintWriter out ) + { + printEvents( any( Event.class ), out ); + } + + public void printEvents( Matcher matcher, PrintWriter out ) + { + synchronized ( events ) + { + for ( Event event : events ) + { + if ( matcher.matches( event ) ) + { + write( event, out ); + out.println(); + } + } + } + } + + public void forEach( Object handler ) + { + synchronized ( events ) + { + for ( Event event : events ) + { + if ( event.handlerType.isInstance( handler ) ) + { + dispatch( event, handler ); + } + } + } + } + + private static void dispatch( Event event, Object handler ) + { + event.dispatch( event.handlerType.cast( handler ) ); + } + + static void write( Event event, PrintWriter out ) + { + dispatch( event, proxy( event.handlerType, new WriteHandler( out ) ) ); + } + + private static Handler proxy( Class handlerType, InvocationHandler handler ) + { + if ( handlerType.isInstance( handler ) ) + { + return handlerType.cast( handler ); + } + try + { + return handlerType.cast( proxies.get( handlerType ).newInstance( handler ) ); + } + catch ( RuntimeException e ) + { + throw e; + } + catch ( Exception e ) + { + throw new RuntimeException( e ); + } + } + + private static final ClassValue proxies = new ClassValue() + { + @Override + protected Constructor computeValue( Class type ) + { + Class proxy = Proxy.getProxyClass( type.getClassLoader(), type ); + try + { + return proxy.getConstructor( InvocationHandler.class ); + } + catch ( NoSuchMethodException e ) + { + throw new RuntimeException( e ); + } + } + }; + + private static class WriteHandler implements InvocationHandler, EventLogger.Sink + { + private final PrintWriter out; + + WriteHandler( PrintWriter out ) + { + this.out = out; + } + + @Override + public Object invoke( Object proxy, Method method, Object[] args ) throws Throwable + { + out.append( method.getName() ).append( '(' ); + String sep = " "; + for ( Object arg : args ) + { + out.append( sep ); + if ( arg == null || !arg.getClass().isArray() ) + { + out.append( Objects.toString( arg ) ); + } + else if ( arg instanceof Object[] ) + { + out.append( Arrays.toString( (Object[]) arg ) ); + } + else + { + out.append( Arrays.class.getMethod( "toString", arg.getClass() ).invoke( null, arg ).toString() ); + } + sep = ", "; + } + if ( args.length > 0 ) + { + out.append( ' ' ); + } + out.append( ')' ); + return null; + } + + @Override + public void log( String name, EventLogger.Level level, Throwable cause, String message, Object... params ) + { + out.append( level.name() ).append( ": " ); + if ( params != null ) + { + try + { + out.format( message, params ); + } + catch ( Exception e ) + { + out.format( "InvalidFormat(message=\"%s\", params=%s, failure=%s)", + message, Arrays.toString( params ), e ); + } + out.println(); + } + else + { + out.println( message ); + } + if ( cause != null ) + { + cause.printStackTrace( out ); + } + } + } +} diff --git a/driver/src/test/java/org/neo4j/driver/internal/RoutingDriverTest.java b/driver/src/test/java/org/neo4j/driver/internal/RoutingDriverTest.java index f551a8dbfd..61b81a6d4a 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/RoutingDriverTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/RoutingDriverTest.java @@ -18,47 +18,41 @@ */ package org.neo4j.driver.internal; -import org.hamcrest.Matchers; +import java.util.Collections; +import java.util.Map; + import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; +import org.mockito.internal.stubbing.answers.ThrowsException; import org.mockito.invocation.InvocationOnMock; import org.mockito.stubbing.Answer; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - +import org.neo4j.driver.internal.cluster.RoutingSettings; import org.neo4j.driver.internal.net.BoltServerAddress; +import org.neo4j.driver.internal.spi.Collector; import org.neo4j.driver.internal.spi.Connection; import org.neo4j.driver.internal.spi.ConnectionPool; -import org.neo4j.driver.internal.util.Clock; +import org.neo4j.driver.internal.util.FakeClock; import org.neo4j.driver.v1.AccessMode; -import org.neo4j.driver.v1.Logger; +import org.neo4j.driver.v1.EventLogger; import org.neo4j.driver.v1.Logging; -import org.neo4j.driver.v1.Record; -import org.neo4j.driver.v1.Session; -import org.neo4j.driver.v1.StatementResult; import org.neo4j.driver.v1.Value; import org.neo4j.driver.v1.exceptions.ClientException; -import org.neo4j.driver.v1.exceptions.NoSuchRecordException; import org.neo4j.driver.v1.exceptions.ServiceUnavailableException; -import org.neo4j.driver.v1.summary.ResultSummary; -import org.neo4j.driver.v1.util.Function; import static java.util.Arrays.asList; -import static java.util.Collections.singletonList; -import static org.hamcrest.CoreMatchers.equalTo; -import static org.hamcrest.CoreMatchers.hasItem; -import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.Matchers.containsInAnyOrder; -import static org.hamcrest.core.IsNot.not; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; import static org.mockito.Matchers.any; -import static org.mockito.Matchers.anyString; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.doAnswer; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.neo4j.driver.internal.cluster.ClusterCompositionProviderTest.serverInfo; +import static org.neo4j.driver.internal.cluster.ClusterCompositionProviderTest.withKeys; +import static org.neo4j.driver.internal.cluster.ClusterCompositionProviderTest.withServerList; import static org.neo4j.driver.internal.security.SecurityPlan.insecure; import static org.neo4j.driver.v1.Values.value; @@ -66,359 +60,304 @@ public class RoutingDriverTest { @Rule public ExpectedException exception = ExpectedException.none(); - private static final BoltServerAddress SEED = new BoltServerAddress( "localhost", 7687 ); private static final String GET_SERVERS = "CALL dbms.cluster.routing.getServers"; - private static final List NO_ADDRESSES = Collections.emptyList(); - private final ConnectionPool pool = pool(); + private final EventHandler events = new EventHandler(); + private final FakeClock clock = new FakeClock( events, true ); + private final Logging logging = EventLogger.provider( events, EventLogger.Level.TRACE ); @Test public void shouldDoRoutingOnInitialization() { // Given - final Session session = mock( Session.class ); - when( session.run( GET_SERVERS ) ).thenReturn( - getServers( singletonList( "localhost:1111" ), - singletonList( "localhost:2222" ), - singletonList( "localhost:3333" ) ) ); + ConnectionPool pool = poolWithServers( + 10, + serverInfo( "ROUTE", "localhost:1111" ), + serverInfo( "READ", "localhost:2222" ), + serverInfo( "WRITE", "localhost:3333" ) ); // When - RoutingDriver routingDriver = forSession( session ); + driverWithPool( pool ); // Then - assertThat( routingDriver.routingServers(), - containsInAnyOrder( boltAddress( "localhost", 1111 )) ); - assertThat( routingDriver.readServers(), - containsInAnyOrder( boltAddress( "localhost", 2222 ) ) ); - assertThat( routingDriver.writeServers(), - containsInAnyOrder( boltAddress( "localhost", 3333 ) ) ); - + verify( pool ).acquire( SEED ); } @Test public void shouldDoReRoutingOnSessionAcquisitionIfNecessary() { // Given - final Session session = mock( Session.class ); - when( session.run( GET_SERVERS ) ) - .thenReturn( - getServers( singletonList( "localhost:1111" ), NO_ADDRESSES, NO_ADDRESSES ) ) - .thenReturn( - getServers( singletonList( "localhost:1112" ), - singletonList( "localhost:2222" ), - singletonList( "localhost:3333" ) ) ); - - RoutingDriver routingDriver = forSession( session ); - - assertThat( routingDriver.routingServers(), - containsInAnyOrder( boltAddress( "localhost", 1111 )) ); - assertThat( routingDriver.readServers(), Matchers.empty() ); - assertThat( routingDriver.writeServers(), Matchers.empty() ); - + RoutingDriver routingDriver = driverWithPool( pool( + withServers( 10, serverInfo( "ROUTE", "localhost:1111" ), + serverInfo( "READ" ), + serverInfo( "WRITE", "localhost:5555" ) ), + withServers( 10, serverInfo( "ROUTE", "localhost:1112" ), + serverInfo( "READ", "localhost:2222" ), + serverInfo( "WRITE", "localhost:3333" ) ) ) ); // When - routingDriver.session( AccessMode.READ ); + RoutingNetworkSession writing = (RoutingNetworkSession) routingDriver.session( AccessMode.WRITE ); // Then - assertThat( routingDriver.routingServers(), - containsInAnyOrder( boltAddress( "localhost", 1112 ) )); - assertThat( routingDriver.readServers(), - containsInAnyOrder( boltAddress( "localhost", 2222 ) ) ); - assertThat( routingDriver.writeServers(), - containsInAnyOrder( boltAddress( "localhost", 3333 ) ) ); + assertEquals( boltAddress( "localhost", 3333 ), writing.address() ); } @Test public void shouldNotDoReRoutingOnSessionAcquisitionIfNotNecessary() { // Given - final Session session = mock( Session.class ); - when( session.run( GET_SERVERS ) ) - .thenReturn( - getServers( asList( "localhost:1111", "localhost:1112", "localhost:1113" ), - singletonList( "localhost:2222" ), - singletonList( "localhost:3333" ) ) ) - .thenReturn( - getServers( singletonList( "localhost:5555" ), NO_ADDRESSES, NO_ADDRESSES ) ); - - RoutingDriver routingDriver = forSession( session ); + RoutingDriver routingDriver = driverWithPool( pool( + withServers( 10, serverInfo( "ROUTE", "localhost:1111", "localhost:1112", "localhost:1113" ), + serverInfo( "READ", "localhost:2222" ), + serverInfo( "WRITE", "localhost:3333" ) ), + withServers( 10, serverInfo( "ROUTE", "localhost:5555" ), + serverInfo( "READ", "localhost:5555" ), + serverInfo( "WRITE", "localhost:5555" ) ) ) ); // When - routingDriver.session( AccessMode.WRITE ); + RoutingNetworkSession writing = (RoutingNetworkSession) routingDriver.session( AccessMode.WRITE ); + RoutingNetworkSession reading = (RoutingNetworkSession) routingDriver.session( AccessMode.READ ); // Then - assertThat( routingDriver.routingServers(), - not( hasItem( boltAddress( "localhost", 5555 ) ) ) ); + assertEquals( boltAddress( "localhost", 3333 ), writing.address() ); + assertEquals( boltAddress( "localhost", 2222 ), reading.address() ); } @Test public void shouldFailIfNoRouting() { // Given - final Session session = mock( Session.class ); - when( session.run( GET_SERVERS ) ) - .thenThrow( - new ClientException( "Neo.ClientError.Procedure.ProcedureNotFound", "Procedure not found" ) ); + ConnectionPool pool = pool( new ThrowsException( new ClientException( + "Neo.ClientError.Procedure.ProcedureNotFound", "Procedure not found" ) ) ); - // Expect - exception.expect( ServiceUnavailableException.class ); + // When + try + { + driverWithPool( pool ); + } + // Then + catch ( ServiceUnavailableException e ) + { + assertEquals( "Could not perform discovery. No routing servers available.", e.getMessage() ); + } + } + + @Test + public void shouldFailIfNoRoutersProvided() + { + // Given + ConnectionPool pool = poolWithServers( + 10, + serverInfo( "ROUTE" ), + serverInfo( "READ", "localhost:1111" ), + serverInfo( "WRITE", "localhost:1111" ) ); // When - forSession( session ); + try + { + driverWithPool( pool ); + } + // Then + catch ( ServiceUnavailableException e ) + { + assertEquals( "Could not perform discovery. No routing servers available.", e.getMessage() ); + } } @Test - public void shouldForgetAboutServersOnRerouting() + public void shouldFailIfNoWritersProvided() { // Given - final Session session = mock( Session.class ); - when( session.run( GET_SERVERS ) ) - .thenReturn( - getServers( singletonList( "localhost:1111" ), NO_ADDRESSES, NO_ADDRESSES ) ) - .thenReturn( - getServers( singletonList( "localhost:1112" ), - singletonList( "localhost:2222" ), - singletonList( "localhost:3333" ) ) ); + ConnectionPool pool = poolWithServers( + 10, + serverInfo( "ROUTE", "localhost:1111" ), + serverInfo( "READ", "localhost:1111" ), + serverInfo( "WRITE" ) ); - RoutingDriver routingDriver = forSession( session ); + // When + try + { + driverWithPool( pool ); + } + // Then + catch ( ServiceUnavailableException e ) + { + assertEquals( "Could not perform discovery. No routing servers available.", e.getMessage() ); + } + } - assertThat( routingDriver.routingServers(), - containsInAnyOrder( boltAddress( "localhost", 1111 )) ); + @Test + public void shouldForgetAboutServersOnRerouting() + { + // Given + ConnectionPool pool = pool( + withServers( 10, serverInfo( "ROUTE", "localhost:1111" ), + serverInfo( "READ" ), + serverInfo( "WRITE", "localhost:5555" ) ), + withServers( 10, serverInfo( "ROUTE", "localhost:1112" ), + serverInfo( "READ", "localhost:2222" ), + serverInfo( "WRITE", "localhost:3333" ) ) ); + RoutingDriver routingDriver = driverWithPool( pool ); // When - routingDriver.session( AccessMode.READ ); + RoutingNetworkSession write1 = (RoutingNetworkSession) routingDriver.session( AccessMode.WRITE ); + RoutingNetworkSession write2 = (RoutingNetworkSession) routingDriver.session( AccessMode.WRITE ); // Then - assertThat( routingDriver.routingServers(), - containsInAnyOrder( boltAddress( "localhost", 1112 ) )); - verify( pool ).purge( boltAddress( "localhost", 1111 ) ); + assertEquals( boltAddress( "localhost", 3333 ), write1.address() ); + assertEquals( boltAddress( "localhost", 3333 ), write2.address() ); } @Test public void shouldRediscoverOnTimeout() { // Given - final Session session = mock( Session.class ); - Clock clock = mock( Clock.class ); - when(clock.millis()).thenReturn( 0L, 11000L, 22000L ); - when( session.run( GET_SERVERS ) ) - .thenReturn( - getServers( asList( "localhost:1111", "localhost:1112", "localhost:1113" ), - singletonList( "localhost:2222" ), - singletonList( "localhost:3333" ), 10L/*seconds*/ ) ) - .thenReturn( - getServers( singletonList( "localhost:5555" ), singletonList( "localhost:5555" ), singletonList( "localhost:5555" ) ) ); - - RoutingDriver routingDriver = forSession( session, clock ); + RoutingDriver routingDriver = driverWithPool( pool( + withServers( 10, serverInfo( "ROUTE", "localhost:1111", "localhost:1112", "localhost:1113" ), + serverInfo( "READ", "localhost:2222" ), + serverInfo( "WRITE", "localhost:3333" ) ), + withServers( 60, serverInfo( "ROUTE", "localhost:5555", "localhost:6666" ), + serverInfo( "READ", "localhost:7777" ), + serverInfo( "WRITE", "localhost:8888" ) ) ) ); + + clock.progress( 11_000 ); // When - routingDriver.session( AccessMode.WRITE ); + RoutingNetworkSession writing = (RoutingNetworkSession) routingDriver.session( AccessMode.WRITE ); + RoutingNetworkSession reading = (RoutingNetworkSession) routingDriver.session( AccessMode.READ ); // Then - assertThat( routingDriver.routingServers(), containsInAnyOrder( boltAddress( "localhost", 5555 ) ) ); - assertThat( routingDriver.readServers(), containsInAnyOrder( boltAddress( "localhost", 5555 ) ) ); - assertThat( routingDriver.writeServers(), containsInAnyOrder( boltAddress( "localhost", 5555 ) ) ); + assertEquals( boltAddress( "localhost", 8888 ), writing.address() ); + assertEquals( boltAddress( "localhost", 7777 ), reading.address() ); } @Test - public void shouldNotRediscoverWheNoTimeout() + public void shouldNotRediscoverWhenNoTimeout() { // Given - final Session session = mock( Session.class ); - Clock clock = mock( Clock.class ); - when(clock.millis()).thenReturn( 0L, 9900L, 18800L ); - when( session.run( GET_SERVERS ) ) - .thenReturn( - getServers( asList( "localhost:1111", "localhost:1112", "localhost:1113" ), - singletonList( "localhost:2222" ), - singletonList( "localhost:3333" ), 10L/*seconds*/ ) ) - .thenReturn( - getServers( singletonList( "localhost:5555" ), singletonList( "localhost:5555" ), singletonList( "localhost:5555" ) ) ); - - RoutingDriver routingDriver = forSession( session, clock ); + RoutingDriver routingDriver = driverWithPool( pool( + withServers( 10, serverInfo( "ROUTE", "localhost:1111", "localhost:1112", "localhost:1113" ), + serverInfo( "READ", "localhost:2222" ), + serverInfo( "WRITE", "localhost:3333" ) ), + withServers( 10, serverInfo( "ROUTE", "localhost:5555" ), + serverInfo( "READ", "localhost:5555" ), + serverInfo( "WRITE", "localhost:5555" ) ) ) ); + clock.progress( 9900 ); // When - routingDriver.session( AccessMode.WRITE ); + RoutingNetworkSession writer = (RoutingNetworkSession) routingDriver.session( AccessMode.WRITE ); + RoutingNetworkSession reader = (RoutingNetworkSession) routingDriver.session( AccessMode.READ ); // Then - assertThat( routingDriver.routingServers(), containsInAnyOrder( boltAddress( "localhost", 1111 ), boltAddress( "localhost", 1112 ), boltAddress( "localhost", 1113 ) ) ); - assertThat( routingDriver.readServers(), containsInAnyOrder( boltAddress( "localhost", 2222 ) ) ); - assertThat( routingDriver.writeServers(), containsInAnyOrder( boltAddress( "localhost", 3333 ) ) ); + assertEquals( boltAddress( "localhost", 2222 ), reader.address() ); + assertEquals( boltAddress( "localhost", 3333 ), writer.address() ); } @Test public void shouldRoundRobinAmongReadServers() { // Given - final Session session = mock( Session.class ); - when( session.run( GET_SERVERS ) ).thenReturn( - getServers( asList( "localhost:1111", "localhost:1112" ), - asList( "localhost:2222", "localhost:2223", "localhost:2224" ), - singletonList( "localhost:3333" ) ) ); + RoutingDriver routingDriver = driverWithServers( 60, serverInfo( "ROUTE", "localhost:1111", "localhost:1112" ), + serverInfo( "READ", "localhost:2222", "localhost:2223", "localhost:2224" ), + serverInfo( "WRITE", "localhost:3333" ) ); // When - RoutingDriver routingDriver = forSession( session ); RoutingNetworkSession read1 = (RoutingNetworkSession) routingDriver.session( AccessMode.READ ); RoutingNetworkSession read2 = (RoutingNetworkSession) routingDriver.session( AccessMode.READ ); RoutingNetworkSession read3 = (RoutingNetworkSession) routingDriver.session( AccessMode.READ ); RoutingNetworkSession read4 = (RoutingNetworkSession) routingDriver.session( AccessMode.READ ); - + RoutingNetworkSession read5 = (RoutingNetworkSession) routingDriver.session( AccessMode.READ ); + RoutingNetworkSession read6 = (RoutingNetworkSession) routingDriver.session( AccessMode.READ ); // Then - assertThat(read1.address(), equalTo(boltAddress( "localhost", 2222 ))); - assertThat(read2.address(), equalTo(boltAddress( "localhost", 2223 ))); - assertThat(read3.address(), equalTo(boltAddress( "localhost", 2224 ))); - assertThat(read4.address(), equalTo(boltAddress( "localhost", 2222 ))); - + assertEquals( read1.address(), read4.address() ); + assertEquals( read2.address(), read5.address() ); + assertEquals( read3.address(), read6.address() ); + assertNotEquals( read1.address(), read2.address() ); + assertNotEquals( read2.address(), read3.address() ); + assertNotEquals( read3.address(), read1.address() ); } @Test public void shouldRoundRobinAmongWriteServers() { // Given - final Session session = mock( Session.class ); - when( session.run( GET_SERVERS ) ).thenReturn( - getServers( asList( "localhost:1111", "localhost:1112" ), - singletonList( "localhost:3333" ), asList( "localhost:2222", "localhost:2223", "localhost:2224" ) ) ); + RoutingDriver routingDriver = driverWithServers( 60, serverInfo( "ROUTE", "localhost:1111", "localhost:1112" ), + serverInfo( "READ", "localhost:3333" ), + serverInfo( "WRITE", "localhost:2222", "localhost:2223", "localhost:2224" ) ); // When - RoutingDriver routingDriver = forSession( session ); RoutingNetworkSession write1 = (RoutingNetworkSession) routingDriver.session( AccessMode.WRITE ); RoutingNetworkSession write2 = (RoutingNetworkSession) routingDriver.session( AccessMode.WRITE ); RoutingNetworkSession write3 = (RoutingNetworkSession) routingDriver.session( AccessMode.WRITE ); RoutingNetworkSession write4 = (RoutingNetworkSession) routingDriver.session( AccessMode.WRITE ); - + RoutingNetworkSession write5 = (RoutingNetworkSession) routingDriver.session( AccessMode.WRITE ); + RoutingNetworkSession write6 = (RoutingNetworkSession) routingDriver.session( AccessMode.WRITE ); // Then - assertThat(write1.address(), equalTo(boltAddress( "localhost", 2222 ))); - assertThat(write2.address(), equalTo(boltAddress( "localhost", 2223 ))); - assertThat(write3.address(), equalTo(boltAddress( "localhost", 2224 ))); - assertThat(write4.address(), equalTo(boltAddress( "localhost", 2222 ))); - + assertEquals( write1.address(), write4.address() ); + assertEquals( write2.address(), write5.address() ); + assertEquals( write3.address(), write6.address() ); + assertNotEquals( write1.address(), write2.address() ); + assertNotEquals( write2.address(), write3.address() ); + assertNotEquals( write3.address(), write1.address() ); } - private RoutingDriver forSession( final Session session ) + @SafeVarargs + private final RoutingDriver driverWithServers( long ttl, Map... serverInfo ) { - return forSession( session, Clock.SYSTEM ); - } - private RoutingDriver forSession( final Session session, Clock clock ) - { - return new RoutingDriver( SEED, pool, insecure(), - new Function() - { - @Override - public Session apply( Connection connection ) - { - return session; - } - }, clock, logging() ); + return driverWithPool( poolWithServers( ttl, serverInfo ) ); } - private BoltServerAddress boltAddress( String host, int port ) + private RoutingDriver driverWithPool( ConnectionPool pool ) { - return new BoltServerAddress( host, port ); + return new RoutingDriver( new RoutingSettings( 10, 5_000 ), SEED, pool, insecure(), clock, logging ); } - - StatementResult getServers( final List routers, final List readers, - final List writers ) + @SafeVarargs + private final ConnectionPool poolWithServers( long ttl, Map... serverInfo ) { - return getServers( routers,readers, writers, Long.MAX_VALUE ); + return pool( withServers( ttl, serverInfo ) ); } - StatementResult getServers( final List routers, final List readers, - final List writers, final long ttl ) + @SafeVarargs + private static Answer withServers( long ttl, Map... serverInfo ) { - return new StatementResult() - { - private int counter = 0; - - @Override - public List keys() - { - return asList( "ttl", "servers" ); - } - - @Override - public boolean hasNext() - { - return counter < 1; - } - - @Override - public Record next() - { - counter++; - return new InternalRecord( asList( "ttl", "servers" ), - new Value[]{ - value( ttl ), - value( asList( serverInfo( "ROUTE", routers ), serverInfo( "WRITE", writers ), - serverInfo( "READ", readers ) ) ) - } ); - } - - @Override - public Record single() throws NoSuchRecordException - { - return next(); - } - - @Override - public Record peek() - { - return null; - } - - @Override - public List list() - { - return null; - } - - @Override - public List list( Function mapFunction ) - { - return null; - } - - @Override - public ResultSummary consume() - { - return null; - } - - @Override - public void remove() - { - throw new UnsupportedOperationException(); - } - }; + return withServerList( new Value[] {value( ttl ), value( asList( serverInfo ) )} ); } - private Map serverInfo( String role, List addresses ) + private BoltServerAddress boltAddress( String host, int port ) { - Map map = new HashMap<>(); - map.put( "role", role ); - map.put( "addresses", addresses ); - - return map; + return new BoltServerAddress( host, port ); } - private ConnectionPool pool() + private ConnectionPool pool( final Answer toGetServers, final Answer... furtherGetServers ) { ConnectionPool pool = mock( ConnectionPool.class ); - - when( pool.acquire( any(BoltServerAddress.class) ) ).thenAnswer( new Answer() + when( pool.acquire( any( BoltServerAddress.class ) ) ).thenAnswer( new Answer() { + int answer; + @Override public Connection answer( InvocationOnMock invocationOnMock ) throws Throwable { - BoltServerAddress address = (BoltServerAddress) invocationOnMock.getArguments()[0]; + BoltServerAddress address = invocationOnMock.getArgumentAt( 0, BoltServerAddress.class ); Connection connection = mock( Connection.class ); when( connection.isOpen() ).thenReturn( true ); - when(connection.address()).thenReturn( address ); + when( connection.address() ).thenReturn( address ); + doAnswer( withKeys( "ttl", "servers" ) ).when( connection ).run( + eq( GET_SERVERS ), + eq( Collections.emptyMap() ), + any( Collector.class ) ); + if ( answer > furtherGetServers.length ) + { + answer = furtherGetServers.length; + } + int offset = answer++; + doAnswer( offset == 0 ? toGetServers : furtherGetServers[offset - 1] ) + .when( connection ).pullAll( any( Collector.class ) ); return connection; } @@ -426,11 +365,4 @@ public Connection answer( InvocationOnMock invocationOnMock ) throws Throwable return pool; } - - private Logging logging() - { - Logging mock = mock( Logging.class ); - when( mock.getLog( anyString() ) ).thenReturn( mock( Logger.class ) ); - return mock; - } -} \ No newline at end of file +} diff --git a/driver/src/test/java/org/neo4j/driver/internal/cluster/ClusterCompositionProviderTest.java b/driver/src/test/java/org/neo4j/driver/internal/cluster/ClusterCompositionProviderTest.java new file mode 100644 index 0000000000..0e1803521d --- /dev/null +++ b/driver/src/test/java/org/neo4j/driver/internal/cluster/ClusterCompositionProviderTest.java @@ -0,0 +1,248 @@ +/** + * Copyright (c) 2002-2016 "Neo Technology," + * Network Engine for Objects in Lund AB [http://neotechnology.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.internal.cluster; + +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; + +import org.junit.Test; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; +import org.mockito.stubbing.Stubber; + +import org.neo4j.driver.internal.EventHandler; +import org.neo4j.driver.internal.net.BoltServerAddress; +import org.neo4j.driver.internal.spi.Collector; +import org.neo4j.driver.internal.spi.Connection; +import org.neo4j.driver.internal.util.FakeClock; +import org.neo4j.driver.v1.Value; +import org.neo4j.driver.v1.exceptions.ServiceUnavailableException; + +import static java.util.Arrays.asList; +import static java.util.Collections.singletonList; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.fail; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.neo4j.driver.v1.Values.value; + +public class ClusterCompositionProviderTest +{ + private final FakeClock clock = new FakeClock( (EventHandler) null, true ); + private final Connection connection = mock( Connection.class ); + + @Test + public void shouldParseClusterComposition() throws Exception + { + // given + clock.progress( 16500 ); + keys( "ttl", "servers" ); + values( new Value[] { + value( 100 ), value( asList( + serverInfo( "READ", "one:1337", "two:1337" ), + serverInfo( "WRITE", "one:1337" ), + serverInfo( "ROUTE", "one:1337", "two:1337" ) ) )} ); + + // when + ClusterComposition composition = getClusterComposition(); + + // then + assertNotNull( composition ); + assertEquals( 16500 + 100_000, composition.expirationTimestamp ); + assertEquals( serverSet( "one:1337", "two:1337" ), composition.readers() ); + assertEquals( serverSet( "one:1337" ), composition.writers() ); + assertEquals( serverSet( "one:1337", "two:1337" ), composition.routers() ); + } + + @Test + public void shouldReturnNullIfResultContainsTooManyRows() throws Exception + { + // given + keys( "ttl", "servers" ); + values( + new Value[] { + value( 100 ), value( singletonList( + serverInfo( "READ", "one:1337", "two:1337" ) ) )}, + new Value[] { + value( 100 ), value( singletonList( + serverInfo( "WRITE", "one:1337" ) ) )}, + new Value[] { + value( 100 ), value( singletonList( + serverInfo( "ROUTE", "one:1337", "two:1337" ) ) )} ); + + // then + assertNull( getClusterComposition() ); + } + + @Test + public void shouldReturnNullOnEmptyResult() throws Exception + { + // given + keys( "ttl", "servers" ); + values(); + + // then + assertNull( getClusterComposition() ); + } + + @Test + public void shouldReturnNullOnResultWithWrongFormat() throws Exception + { + // given + clock.progress( 16500 ); + keys( "ttl", "addresses" ); + values( new Value[] { + value( 100 ), value( asList( + serverInfo( "READ", "one:1337", "two:1337" ), + serverInfo( "WRITE", "one:1337" ), + serverInfo( "ROUTE", "one:1337", "two:1337" ) ) )} ); + + // then + assertNull( getClusterComposition() ); + } + + @Test + public void shouldPropagateConnectionFailureExceptions() throws Exception + { + // given + ServiceUnavailableException expected = new ServiceUnavailableException( "spanish inquisition" ); + onGetServers( doThrow( expected ) ); + + // when + try + { + getClusterComposition(); + fail( "Expected exception" ); + } + // then + catch ( ServiceUnavailableException e ) + { + assertSame( expected, e ); + } + } + + private ClusterComposition getClusterComposition() + { + return new ClusterComposition.Provider.Default( clock ).getClusterComposition( connection ); + } + + private void keys( final String... keys ) + { + onGetServers( doAnswer( withKeys( keys ) ) ); + } + + private void values( final Value[]... records ) + { + onPullAll( doAnswer( withServerList( records ) ) ); + } + + private void onGetServers( Stubber stubber ) + { + stubber.when( connection ).run( + eq( ClusterComposition.Provider.GET_SERVERS ), + eq( Collections.emptyMap() ), + any( Collector.class ) ); + } + + private void onPullAll( Stubber stubber ) + { + stubber.when( connection ).pullAll( any( Collector.class ) ); + } + + public static CollectorAnswer withKeys( final String... keys ) + { + return new CollectorAnswer() + { + @Override + void collect( Collector collector ) + { + collector.keys( keys ); + } + }; + } + + public static CollectorAnswer withServerList( final Value[]... records ) + { + return new CollectorAnswer() + { + @Override + void collect( Collector collector ) + { + for ( Value[] fields : records ) + { + collector.record( fields ); + } + } + }; + } + + private static abstract class CollectorAnswer implements Answer + { + abstract void collect( Collector collector ); + + @Override + public final Object answer( InvocationOnMock invocation ) throws Throwable + { + Collector collector = collector( invocation ); + collect( collector ); + collector.done(); + return null; + } + + private Collector collector( InvocationOnMock invocation ) + { + switch ( invocation.getMethod().getName() ) + { + case "pullAll": + return invocation.getArgumentAt( 0, Collector.class ); + case "run": + return invocation.getArgumentAt( 2, Collector.class ); + default: + throw new UnsupportedOperationException( invocation.getMethod().getName() ); + } + } + } + + public static Map serverInfo( String role, String... addresses ) + { + Map map = new HashMap<>(); + map.put( "role", role ); + map.put( "addresses", asList( addresses ) ); + return map; + } + + private static Set serverSet( String... addresses ) + { + Set result = new HashSet<>(); + for ( String address : addresses ) + { + result.add( new BoltServerAddress( address ) ); + } + return result; + } +} diff --git a/driver/src/test/java/org/neo4j/driver/internal/cluster/ClusterTopology.java b/driver/src/test/java/org/neo4j/driver/internal/cluster/ClusterTopology.java new file mode 100644 index 0000000000..ff75414528 --- /dev/null +++ b/driver/src/test/java/org/neo4j/driver/internal/cluster/ClusterTopology.java @@ -0,0 +1,198 @@ +/** + * Copyright (c) 2002-2016 "Neo Technology," + * Network Engine for Objects in Lund AB [http://neotechnology.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.internal.cluster; + +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import org.hamcrest.Description; +import org.hamcrest.Matcher; +import org.hamcrest.TypeSafeMatcher; + +import org.neo4j.driver.internal.Event; +import org.neo4j.driver.internal.EventHandler; +import org.neo4j.driver.internal.net.BoltServerAddress; +import org.neo4j.driver.internal.spi.Connection; +import org.neo4j.driver.internal.util.Clock; + +import static java.util.Arrays.asList; +import static org.neo4j.driver.internal.cluster.ClusterTopology.Role.READ; +import static org.neo4j.driver.internal.cluster.ClusterTopology.Role.ROUTE; +import static org.neo4j.driver.internal.cluster.ClusterTopology.Role.WRITE; + +class ClusterTopology implements ClusterComposition.Provider +{ + public interface EventSink + { + EventSink VOID = new Adapter(); + + void clusterComposition( BoltServerAddress address, ClusterComposition result ); + + class Adapter implements EventSink + { + @Override + public void clusterComposition( BoltServerAddress address, ClusterComposition result ) + { + } + } + } + + private static final List KEYS = Collections.unmodifiableList( asList( "servers", "ttl" ) ); + private final Map views = new HashMap<>(); + private final EventSink events; + private final Clock clock; + + ClusterTopology( final EventHandler events, Clock clock ) + { + this( events == null ? null : new EventSink() + { + @Override + public void clusterComposition( BoltServerAddress address, ClusterComposition result ) + { + events.add( new CompositionRequest( Thread.currentThread(), address, result ) ); + } + }, clock ); + } + + ClusterTopology( EventSink events, Clock clock ) + { + this.events = events == null ? EventSink.VOID : events; + this.clock = clock; + } + + public View on( String host, int port ) + { + View view = new View(); + views.put( new BoltServerAddress( host, port ), view ); + return view; + } + + public enum Role + { + READ, + WRITE, + ROUTE + } + + public static class View + { + private long ttl = 60_000; + private final Set readers = new HashSet<>(), + writers = new HashSet<>(), + routers = new HashSet<>(); + + public View ttlSeconds( long ttl ) + { + this.ttl = ttl * 1000; + return this; + } + + public View provide( String host, int port, Role... roles ) + { + for ( Role role : roles ) + { + servers( role ).add( new BoltServerAddress( host, port ) ); + } + return this; + } + + private Set servers( Role role ) + { + switch ( role ) + { + case READ: + return readers; + case WRITE: + return writers; + case ROUTE: + return routers; + default: + throw new IllegalArgumentException( role.name() ); + } + } + + ClusterComposition composition( long now ) + { + return new ClusterComposition( now + ttl, servers( READ ), servers( WRITE ), servers( ROUTE ) ); + } + } + + @Override + public ClusterComposition getClusterComposition( Connection connection ) + { + BoltServerAddress router = connection.address(); + View view = views.get( router ); + ClusterComposition result = view == null ? null : view.composition( clock.millis() ); + events.clusterComposition( router, result ); + return result; + } + + public static final class CompositionRequest extends Event + { + final Thread thread; + final BoltServerAddress address; + private final ClusterComposition result; + + private CompositionRequest( Thread thread, BoltServerAddress address, ClusterComposition result ) + { + this.thread = thread; + this.address = address; + this.result = result; + } + + @Override + public void dispatch( EventSink sink ) + { + sink.clusterComposition( address, result ); + } + + public static Matcher clusterComposition( + final Matcher thread, + final Matcher address, + final Matcher result ) + { + return new TypeSafeMatcher() + { + @Override + protected boolean matchesSafely( CompositionRequest event ) + { + return thread.matches( event.thread ) + && address.matches( event.address ) + && result.matches( event.result ); + } + + @Override + public void describeTo( Description description ) + { + description.appendText( "a successful cluster composition request on thread <" ) + .appendDescriptionOf( thread ) + .appendText( "> from address <" ) + .appendDescriptionOf( address ) + .appendText( "> returning <" ) + .appendDescriptionOf( result ) + .appendText( ">" ); + } + }; + } + } +} diff --git a/driver/src/test/java/org/neo4j/driver/internal/cluster/LoadBalancerTest.java b/driver/src/test/java/org/neo4j/driver/internal/cluster/LoadBalancerTest.java new file mode 100644 index 0000000000..d0644ecb28 --- /dev/null +++ b/driver/src/test/java/org/neo4j/driver/internal/cluster/LoadBalancerTest.java @@ -0,0 +1,599 @@ +/** + * Copyright (c) 2002-2016 "Neo Technology," + * Network Engine for Objects in Lund AB [http://neotechnology.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.internal.cluster; + +import java.util.ArrayList; +import java.util.List; + +import org.hamcrest.Matcher; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.TestRule; +import org.junit.runner.Description; +import org.junit.runners.model.Statement; + +import org.neo4j.driver.internal.EventHandler; +import org.neo4j.driver.internal.net.BoltServerAddress; +import org.neo4j.driver.internal.spi.Connection; +import org.neo4j.driver.internal.spi.StubConnectionPool; +import org.neo4j.driver.internal.util.FakeClock; +import org.neo4j.driver.internal.util.MatcherFactory; +import org.neo4j.driver.v1.EventLogger; +import org.neo4j.driver.v1.exceptions.ServiceUnavailableException; +import org.neo4j.driver.v1.util.Function; + +import static org.hamcrest.Matchers.any; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.sameInstance; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.fail; +import static org.neo4j.driver.internal.cluster.ClusterTopology.Role.READ; +import static org.neo4j.driver.internal.cluster.ClusterTopology.Role.ROUTE; +import static org.neo4j.driver.internal.cluster.ClusterTopology.Role.WRITE; +import static org.neo4j.driver.internal.spi.StubConnectionPool.Event.acquire; +import static org.neo4j.driver.internal.spi.StubConnectionPool.Event.connectionFailure; +import static org.neo4j.driver.internal.util.FakeClock.Event.sleep; +import static org.neo4j.driver.internal.util.MatcherFactory.inAnyOrder; +import static org.neo4j.driver.internal.util.MatcherFactory.matches; +import static org.neo4j.driver.v1.EventLogger.Entry.message; +import static org.neo4j.driver.v1.EventLogger.Level.INFO; + +public class LoadBalancerTest +{ + private static final long RETRY_TIMEOUT_DELAY = 5_000; + private static final int MAX_ROUTING_FAILURES = 5; + @Rule + public final TestRule printEventsOnFailure = new TestRule() + { + @Override + public Statement apply( final Statement base, Description description ) + { + return new Statement() + { + @Override + public void evaluate() throws Throwable + { + try + { + base.evaluate(); + } + catch ( Throwable e ) + { + events.printEvents( System.err ); + throw e; + } + } + }; + } + }; + private final EventHandler events = new EventHandler(); + private final FakeClock clock = new FakeClock( events, true ); + private final EventLogger log = new EventLogger( events, null, INFO ); + private final StubConnectionPool connections = new StubConnectionPool( clock, events, null ); + private final ClusterTopology cluster = new ClusterTopology( events, clock ); + + private LoadBalancer seedLoadBalancer( String host, int port ) throws Exception + { + return new LoadBalancer( + new RoutingSettings( MAX_ROUTING_FAILURES, RETRY_TIMEOUT_DELAY ), + clock, + log, + connections, + cluster, + new BoltServerAddress( host, port ) ); + } + + @Test + public void shouldConnectToRouter() throws Exception + { + // given + connections.up( "some.host", 1337 ); + cluster.on( "some.host", 1337 ) + .provide( "some.host", 1337, READ, WRITE, ROUTE ) + .provide( "another.host", 1337, ROUTE ); + + // when + Connection connection = seedLoadBalancer( "some.host", 1337 ).acquireReadConnection(); + + // then + events.assertCount( any( ClusterTopology.CompositionRequest.class ), equalTo( 1 ) ); + events.assertContains( acquiredConnection( "some.host", 1337, connection ) ); + } + + @Test + public void shouldConnectToRouterOnInitialization() throws Exception + { + // given + connections.up( "some.host", 1337 ); + cluster.on( "some.host", 1337 ) + .provide( "some.host", 1337, READ, WRITE, ROUTE ) + .provide( "another.host", 1337, ROUTE ); + + // when + seedLoadBalancer( "some.host", 1337 ); + + // then + events.assertCount( any( ClusterTopology.CompositionRequest.class ), equalTo( 1 ) ); + } + + @Test + public void shouldReconnectWithRouterAfterTtlExpires() throws Exception + { + // given + coreClusterOn( 20, "some.host", 1337, "another.host" ); + connections.up( "some.host", 1337 ).up( "another.host", 1337 ); + + LoadBalancer routing = seedLoadBalancer( "some.host", 1337 ); + + // when + clock.progress( 25_000 ); // will cause TTL timeout + Connection connection = routing.acquireWriteConnection(); + + // then + events.assertCount( any( ClusterTopology.CompositionRequest.class ), equalTo( 2 ) ); + events.assertContains( acquiredConnection( "some.host", 1337, connection ) ); + } + + @Test + public void shouldNotReconnectWithRouterWithinTtl() throws Exception + { + // given + coreClusterOn( 20, "some.host", 1337, "another.host" ); + connections.up( "some.host", 1337 ).up( "another.host", 1337 ); + + LoadBalancer routing = seedLoadBalancer( "some.host", 1337 ); + + // when + clock.progress( 15_000 ); // not enough to cause TTL timeout + routing.acquireWriteConnection(); + + // then + events.assertCount( any( ClusterTopology.CompositionRequest.class ), equalTo( 1 ) ); + } + + @Test + public void shouldReconnectWithRouterIfOnlyOneRouterIsFound() throws Exception + { + // given + cluster.on( "here", 1337 ) + .ttlSeconds( 20 ) + .provide( "here", 1337, READ, WRITE, ROUTE ); + connections.up( "here", 1337 ); + + LoadBalancer routing = seedLoadBalancer( "here", 1337 ); + + // when + routing.acquireReadConnection(); + + // then + events.assertCount( any( ClusterTopology.CompositionRequest.class ), equalTo( 2 ) ); + } + + @Test + public void shouldReconnectWithRouterIfNoReadersAreAvailable() throws Exception + { + // given + cluster.on( "one", 1337 ) + .ttlSeconds( 20 ) + .provide( "one", 1337, WRITE, ROUTE ) + .provide( "two", 1337, ROUTE ); + cluster.on( "two", 1337 ) + .ttlSeconds( 20 ) + .provide( "one", 1337, READ, WRITE, ROUTE ) + .provide( "two", 1337, READ, ROUTE ); + connections.up( "one", 1337 ).up( "two", 1337 ); + + LoadBalancer routing = seedLoadBalancer( "one", 1337 ); + + events.assertCount( any( ClusterTopology.CompositionRequest.class ), equalTo( 1 ) ); + + cluster.on( "one", 1337 ) + .ttlSeconds( 20 ) + .provide( "one", 1337, READ, WRITE, ROUTE ) + .provide( "two", 1337, READ, ROUTE ); + + // when + routing.acquireWriteConnection(); // we should require the presence of a READER even though we ask for a WRITER + + // then + events.assertCount( any( ClusterTopology.CompositionRequest.class ), equalTo( 2 ) ); + } + + @Test + public void shouldReconnectWithRouterIfNoWritersAreAvailable() throws Exception + { + // given + cluster.on( "one", 1337 ) + .ttlSeconds( 20 ) + .provide( "one", 1337, READ, ROUTE ) + .provide( "two", 1337, READ, WRITE, ROUTE ); + cluster.on( "two", 1337 ) + .ttlSeconds( 20 ) + .provide( "one", 1337, READ, ROUTE ) + .provide( "two", 1337, READ, WRITE, ROUTE ); + connections.up( "one", 1337 ); + + events.registerHandler( StubConnectionPool.EventSink.class, new StubConnectionPool.EventSink.Adapter() + { + @Override + public void connectionFailure( BoltServerAddress address ) + { + connections.up( "two", 1337 ); + } + } ); + + LoadBalancer routing = seedLoadBalancer( "one", 1337 ); + + events.assertCount( any( ClusterTopology.CompositionRequest.class ), equalTo( 1 ) ); + + cluster.on( "one", 1337 ) + .ttlSeconds( 20 ) + .provide( "one", 1337, READ, WRITE, ROUTE ) + .provide( "two", 1337, READ, ROUTE ); + + // when + routing.acquireWriteConnection(); + + // then + events.assertCount( any( ClusterTopology.CompositionRequest.class ), equalTo( 2 ) ); + } + + @Test + public void shouldDropRouterUnableToPerformRoutingTask() throws Exception + { + // given + connections.up( "some.host", 1337 ) + .up( "other.host", 1337 ) + .up( "another.host", 1337 ); + cluster.on( "some.host", 1337 ) + .ttlSeconds( 20 ) + .provide( "some.host", 1337, READ, WRITE, ROUTE ) + .provide( "other.host", 1337, READ, ROUTE ); + cluster.on( "another.host", 1337 ) + .ttlSeconds( 20 ) + .provide( "some.host", 1337, READ, WRITE, ROUTE ) + .provide( "another.host", 1337, READ, ROUTE ); + events.registerHandler( ClusterTopology.EventSink.class, new ClusterTopology.EventSink.Adapter() + { + @Override + public void clusterComposition( BoltServerAddress address, ClusterComposition result ) + { + if ( result == null ) + { + connections.up( "some.host", 1337 ); + cluster.on( "some.host", 1337 ) + .ttlSeconds( 20 ) + .provide( "some.host", 1337, READ, WRITE, ROUTE ) + .provide( "another.host", 1337, READ, ROUTE ); + } + } + } ); + + LoadBalancer routing = seedLoadBalancer( "some.host", 1337 ); + + // when + connections.down( "some.host", 1337 ); + clock.progress( 25_000 ); // will cause TTL timeout + Connection connection = routing.acquireWriteConnection(); + + // then + events.assertCount( + message( + equalTo( INFO ), + equalTo( "Server unable to perform routing capability, " + + "dropping from list of routers." ) ), + equalTo( 1 ) ); + events.assertContains( acquiredConnection( "some.host", 1337, connection ) ); + } + + @Test + public void shouldConnectToRoutingServersInTimeoutOrder() throws Exception + { + // given + coreClusterOn( 20, "one", 1337, "two", "tre" ); + connections.up( "one", 1337 ); + events.registerHandler( StubConnectionPool.EventSink.class, new StubConnectionPool.EventSink.Adapter() + { + int failed; + + @Override + public void connectionFailure( BoltServerAddress address ) + { + if ( ++failed >= 9 ) // three times per server + { + for ( String host : new String[] {"one", "two", "tre"} ) + { + connections.up( host, 1337 ); + } + } + } + } ); + + LoadBalancer routing = seedLoadBalancer( "one", 1337 ); + + // when + connections.down( "one", 1337 ); + clock.progress( 25_000 ); // will cause TTL timeout + routing.acquireWriteConnection(); + + // then + MatcherFactory failedAttempts = inAnyOrder( + connectionFailure( "one", 1337 ), + connectionFailure( "two", 1337 ), + connectionFailure( "tre", 1337 ) ); + events.assertContains( + failedAttempts, + matches( sleep( RETRY_TIMEOUT_DELAY ) ), + failedAttempts, + matches( sleep( 2 * RETRY_TIMEOUT_DELAY ) ), + failedAttempts, + matches( sleep( 4 * RETRY_TIMEOUT_DELAY ) ), + matches( ClusterTopology.CompositionRequest.clusterComposition( + any( Thread.class ), + any( BoltServerAddress.class ), + any( ClusterComposition.class ) ) ) ); + } + + @Test + public void shouldFailIfEnoughConnectionAttemptsFail() throws Exception + { + // when + try + { + seedLoadBalancer( "one", 1337 ); + fail( "expected failure" ); + } + // then + catch ( ServiceUnavailableException e ) + { + assertEquals( "Could not perform discovery. No routing servers available.", e.getMessage() ); + } + } + + private static final Function READ_SERVERS = new Function() + { + @Override + public Connection apply( LoadBalancer routing ) + { + return routing.acquireReadConnection(); + } + }; + + @Test + public void shouldRoundRobinAmongReadServers() throws Exception + { + shouldRoundRobinAmong( READ_SERVERS ); + } + + private static final Function WRITE_SERVERS = new Function() + { + @Override + public Connection apply( LoadBalancer routing ) + { + return routing.acquireWriteConnection(); + } + }; + + @Test + public void shouldRoundRobinAmongWriteServers() throws Exception + { + shouldRoundRobinAmong( WRITE_SERVERS ); + } + + private void shouldRoundRobinAmong( Function acquire ) throws Exception + { + // given + for ( String host : new String[] {"one", "two", "tre"} ) + { + connections.up( host, 1337 ); + cluster.on( host, 1337 ) + .ttlSeconds( 20 ) + .provide( "one", 1337, READ, WRITE, ROUTE ) + .provide( "two", 1337, READ, WRITE, ROUTE ) + .provide( "tre", 1337, READ, WRITE, ROUTE ); + } + LoadBalancer routing = seedLoadBalancer( "one", 1337 ); + + // when + Connection a = acquire.apply( routing ); + Connection b = acquire.apply( routing ); + Connection c = acquire.apply( routing ); + assertNotEquals( a.address(), b.address() ); + assertNotEquals( b.address(), c.address() ); + assertNotEquals( c.address(), a.address() ); + assertEquals( a.address(), acquire.apply( routing ).address() ); + assertEquals( b.address(), acquire.apply( routing ).address() ); + assertEquals( c.address(), acquire.apply( routing ).address() ); + assertEquals( a.address(), acquire.apply( routing ).address() ); + assertEquals( b.address(), acquire.apply( routing ).address() ); + assertEquals( c.address(), acquire.apply( routing ).address() ); + + // then + MatcherFactory acquireConnections = + inAnyOrder( acquire( "one", 1337 ), acquire( "two", 1337 ), acquire( "tre", 1337 ) ); + events.assertContains( acquireConnections, acquireConnections, acquireConnections ); + events.assertContains( inAnyOrder( acquire( a ), acquire( b ), acquire( c ) ) ); + } + + @Test + public void shouldRoundRobinAmongRouters() throws Exception + { + // given + coreClusterOn( 20, "one", 1337, "two", "tre" ); + connections.up( "one", 1337 ).up( "two", 1337 ).up( "tre", 1337 ); + + // when + LoadBalancer routing = seedLoadBalancer( "one", 1337 ); + for ( int i = 1; i < 9; i++ ) + { + clock.progress( 25_000 ); + routing.acquireReadConnection(); + } + + // then + final List hosts = new ArrayList<>(); + events.forEach( new ClusterTopology.EventSink() + { + @Override + public void clusterComposition( BoltServerAddress address, ClusterComposition result ) + { + hosts.add( address.host() ); + } + } ); + assertEquals( 9, hosts.size() ); + assertEquals( hosts.get( 0 ), hosts.get( 3 ) ); + assertEquals( hosts.get( 1 ), hosts.get( 4 ) ); + assertEquals( hosts.get( 2 ), hosts.get( 5 ) ); + assertEquals( hosts.get( 0 ), hosts.get( 6 ) ); + assertEquals( hosts.get( 1 ), hosts.get( 7 ) ); + assertEquals( hosts.get( 2 ), hosts.get( 8 ) ); + assertNotEquals( hosts.get( 0 ), hosts.get( 1 ) ); + assertNotEquals( hosts.get( 1 ), hosts.get( 2 ) ); + assertNotEquals( hosts.get( 2 ), hosts.get( 0 ) ); + } + + @Test + public void shouldForgetPreviousServersOnRerouting() throws Exception + { + // given + connections.up( "one", 1337 ) + .up( "two", 1337 ); + cluster.on( "one", 1337 ) + .ttlSeconds( 20 ) + .provide( "bad", 1337, READ, WRITE, ROUTE ) + .provide( "one", 1337, READ, ROUTE ); + + LoadBalancer routing = seedLoadBalancer( "one", 1337 ); + + // when + coreClusterOn( 20, "one", 1337, "two" ); + clock.progress( 25_000 ); // will cause TTL timeout + Connection ra = routing.acquireReadConnection(); + Connection rb = routing.acquireReadConnection(); + Connection w = routing.acquireWriteConnection(); + assertNotEquals( ra.address(), rb.address() ); + assertEquals( ra.address(), routing.acquireReadConnection().address() ); + assertEquals( rb.address(), routing.acquireReadConnection().address() ); + assertEquals( w.address(), routing.acquireWriteConnection().address() ); + assertEquals( ra.address(), routing.acquireReadConnection().address() ); + assertEquals( rb.address(), routing.acquireReadConnection().address() ); + assertEquals( w.address(), routing.acquireWriteConnection().address() ); + + // then + events.assertNone( acquire( "bad", 1337 ) ); + } + + @Test + public void shouldFailIfNoRouting() throws Exception + { + // given + connections.up( "one", 1337 ); + cluster.on( "one", 1337 ) + .provide( "one", 1337, READ, WRITE ); + + // when + try + { + seedLoadBalancer( "one", 1337 ); + fail( "expected failure" ); + } + // then + catch ( ServiceUnavailableException e ) + { + assertEquals( "Could not perform discovery. No routing servers available.", e.getMessage() ); + } + } + + @Test + public void shouldFailIfNoWriting() throws Exception + { + // given + connections.up( "one", 1337 ); + cluster.on( "one", 1337 ) + .provide( "one", 1337, READ, ROUTE ); + + // when + try + { + seedLoadBalancer( "one", 1337 ); + fail( "expected failure" ); + } + // then + catch ( ServiceUnavailableException e ) + { + assertEquals( "Could not perform discovery. No routing servers available.", e.getMessage() ); + } + } + + @Test + public void shouldNotForgetAddressForRoutingPurposesWhenUnavailableForOtherUse() throws Exception + { + // given + cluster.on( "one", 1337 ) + .provide( "one", 1337, READ, ROUTE ) + .provide( "two", 1337, WRITE, ROUTE ); + cluster.on( "two", 1337 ) + .provide( "one", 1337, READ, ROUTE ) + .provide( "two", 1337, WRITE, ROUTE ); + connections.up( "one", 1337 ); + + LoadBalancer routing = seedLoadBalancer( "one", 1337 ); + connections.down( "one", 1337 ); + events.registerHandler( FakeClock.EventSink.class, new FakeClock.EventSink.Adapter() + { + @Override + public void sleep( long timestamp, long millis ) + { + connections.up( "two", 1337 ); + } + } ); + + // when + Connection connection = routing.acquireWriteConnection(); + + // then + assertEquals( new BoltServerAddress( "two", 1337 ), connection.address() ); + events.printEvents( System.out ); + } + + private void coreClusterOn( int ttlSeconds, String leader, int port, String... others ) + { + for ( int i = 0; i <= others.length; i++ ) + { + String host = (i == others.length) ? leader : others[i]; + ClusterTopology.View view = cluster.on( host, port ) + .ttlSeconds( ttlSeconds ) + .provide( leader, port, READ, WRITE, ROUTE ); + for ( String other : others ) + { + view.provide( other, port, READ, ROUTE ); + } + } + } + + private Matcher acquiredConnection( + String host, int port, Connection connection ) + { + return acquire( + any( Thread.class ), + equalTo( new BoltServerAddress( host, port ) ), + sameInstance( connection ) ); + } +} diff --git a/driver/src/test/java/org/neo4j/driver/internal/cluster/RoundRobinAddressSetTest.java b/driver/src/test/java/org/neo4j/driver/internal/cluster/RoundRobinAddressSetTest.java new file mode 100644 index 0000000000..4b69ac8f95 --- /dev/null +++ b/driver/src/test/java/org/neo4j/driver/internal/cluster/RoundRobinAddressSetTest.java @@ -0,0 +1,226 @@ +/** + * Copyright (c) 2002-2016 "Neo Technology," + * Network Engine for Objects in Lund AB [http://neotechnology.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.internal.cluster; + +import java.util.ArrayList; +import java.util.HashSet; +import java.util.List; + +import org.junit.Test; + +import org.neo4j.driver.internal.net.BoltServerAddress; + +import static java.util.Arrays.asList; +import static java.util.Collections.singleton; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.fail; + +public class RoundRobinAddressSetTest +{ + @Test + public void shouldReturnNullWhenEmpty() throws Exception + { + // given + RoundRobinAddressSet set = new RoundRobinAddressSet(); + + // then + assertNull( set.next() ); + } + + @Test + public void shouldReturnRoundRobin() throws Exception + { + // given + RoundRobinAddressSet set = new RoundRobinAddressSet(); + set.update( new HashSet<>( asList( + new BoltServerAddress( "one" ), + new BoltServerAddress( "two" ), + new BoltServerAddress( "tre" ) ) ), new HashSet() ); + + // when + BoltServerAddress a = set.next(); + BoltServerAddress b = set.next(); + BoltServerAddress c = set.next(); + + // then + assertEquals( a, set.next() ); + assertEquals( b, set.next() ); + assertEquals( c, set.next() ); + assertEquals( a, set.next() ); + assertEquals( b, set.next() ); + assertEquals( c, set.next() ); + assertNotEquals( a, c ); + assertNotEquals( b, a ); + assertNotEquals( c, b ); + } + + @Test + public void shouldPreserveOrderWhenAdding() throws Exception + { + // given + HashSet servers = new HashSet<>( asList( + new BoltServerAddress( "one" ), + new BoltServerAddress( "two" ), + new BoltServerAddress( "tre" ) ) ); + RoundRobinAddressSet set = new RoundRobinAddressSet(); + set.update( servers, new HashSet() ); + + List order = new ArrayList<>(); + for ( int i = 3 * 4 + 1; i-- > 0; ) + { + BoltServerAddress server = set.next(); + if ( !order.contains( server ) ) + { + order.add( server ); + } + } + assertEquals( 3, order.size() ); + + // when + servers.add( new BoltServerAddress( "fyr" ) ); + set.update( servers, new HashSet() ); + + // then + assertEquals( order.get( 1 ), set.next() ); + assertEquals( order.get( 2 ), set.next() ); + BoltServerAddress next = set.next(); + assertNotEquals( order.get( 0 ), next ); + assertNotEquals( order.get( 1 ), next ); + assertNotEquals( order.get( 2 ), next ); + assertEquals( order.get( 0 ), set.next() ); + // ... and once more + assertEquals( order.get( 1 ), set.next() ); + assertEquals( order.get( 2 ), set.next() ); + assertEquals( next, set.next() ); + assertEquals( order.get( 0 ), set.next() ); + } + + @Test + public void shouldPreserveOrderWhenRemoving() throws Exception + { + // given + HashSet servers = new HashSet<>( asList( + new BoltServerAddress( "one" ), + new BoltServerAddress( "two" ), + new BoltServerAddress( "tre" ) ) ); + RoundRobinAddressSet set = new RoundRobinAddressSet(); + set.update( servers, new HashSet() ); + + List order = new ArrayList<>(); + for ( int i = 3 * 2 + 1; i-- > 0; ) + { + BoltServerAddress server = set.next(); + if ( !order.contains( server ) ) + { + order.add( server ); + } + } + assertEquals( 3, order.size() ); + + // when + set.remove( order.get( 1 ) ); + + // then + assertEquals( order.get( 2 ), set.next() ); + assertEquals( order.get( 0 ), set.next() ); + assertEquals( order.get( 2 ), set.next() ); + assertEquals( order.get( 0 ), set.next() ); + } + + @Test + public void shouldPreserveOrderWhenRemovingThroughUpdate() throws Exception + { + // given + HashSet servers = new HashSet<>( asList( + new BoltServerAddress( "one" ), + new BoltServerAddress( "two" ), + new BoltServerAddress( "tre" ) ) ); + RoundRobinAddressSet set = new RoundRobinAddressSet(); + set.update( servers, new HashSet() ); + + List order = new ArrayList<>(); + for ( int i = 3 * 2 + 1; i-- > 0; ) + { + BoltServerAddress server = set.next(); + if ( !order.contains( server ) ) + { + order.add( server ); + } + } + assertEquals( 3, order.size() ); + + // when + servers.remove( order.get( 1 ) ); + set.update( servers, new HashSet() ); + + // then + assertEquals( order.get( 2 ), set.next() ); + assertEquals( order.get( 0 ), set.next() ); + assertEquals( order.get( 2 ), set.next() ); + assertEquals( order.get( 0 ), set.next() ); + } + + @Test + public void shouldRecordRemovedAddressesWhenUpdating() throws Exception + { + // given + RoundRobinAddressSet set = new RoundRobinAddressSet(); + set.update( + new HashSet<>( asList( + new BoltServerAddress( "one" ), + new BoltServerAddress( "two" ), + new BoltServerAddress( "tre" ) ) ), + new HashSet() ); + + // when + HashSet removed = new HashSet<>(); + set.update( + new HashSet<>( asList( + new BoltServerAddress( "one" ), + new BoltServerAddress( "two" ), + new BoltServerAddress( "fyr" ) ) ), + removed ); + + // then + assertEquals( singleton( new BoltServerAddress( "tre" ) ), removed ); + } + + @Test + public void shouldPreserveOrderEvenWhenIntegerOverflows() throws Exception + { + // given + RoundRobinAddressSet set = new RoundRobinAddressSet(); + + for ( int div = 1; div <= 1024; div++ ) + { + // when - white box testing! + set.setOffset( Integer.MAX_VALUE - 1 ); + int a = set.next( div ); + int b = set.next( div ); + + // then + if ( b != (a + 1) % div ) + { + fail( String.format( "a=%d, b=%d, div=%d, (a+1)%%div=%d", a, b, div, (a + 1) % div ) ); + } + } + } +} diff --git a/driver/src/test/java/org/neo4j/driver/internal/spi/StubConnectionPool.java b/driver/src/test/java/org/neo4j/driver/internal/spi/StubConnectionPool.java new file mode 100644 index 0000000000..387beaf57e --- /dev/null +++ b/driver/src/test/java/org/neo4j/driver/internal/spi/StubConnectionPool.java @@ -0,0 +1,498 @@ +/** + * Copyright (c) 2002-2016 "Neo Technology," + * Network Engine for Objects in Lund AB [http://neotechnology.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.internal.spi; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; + +import org.hamcrest.Description; +import org.hamcrest.Matcher; +import org.hamcrest.TypeSafeMatcher; + +import org.neo4j.driver.internal.EventHandler; +import org.neo4j.driver.internal.net.BoltServerAddress; +import org.neo4j.driver.internal.net.pooling.PooledConnection; +import org.neo4j.driver.internal.util.Clock; +import org.neo4j.driver.internal.util.Consumer; +import org.neo4j.driver.v1.exceptions.ServiceUnavailableException; +import org.neo4j.driver.v1.util.Function; + +import static org.hamcrest.Matchers.any; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.sameInstance; + +public class StubConnectionPool implements ConnectionPool +{ + public interface EventSink + { + EventSink VOID = new Adapter(); + + void acquire( BoltServerAddress address, Connection connection ); + + void release( BoltServerAddress address, Connection connection ); + + void connectionFailure( BoltServerAddress address ); + + void purge( BoltServerAddress address, boolean connected ); + + void close( Collection connected ); + + class Adapter implements EventSink + { + @Override + public void acquire( BoltServerAddress address, Connection connection ) + { + } + + @Override + public void release( BoltServerAddress address, Connection connection ) + { + } + + @Override + public void connectionFailure( BoltServerAddress address ) + { + } + + @Override + public void purge( BoltServerAddress address, boolean connected ) + { + } + + @Override + public void close( Collection connected ) + { + } + } + } + + private final Clock clock; + private final EventSink events; + private final Function factory; + private static final Function NULL_FACTORY = + new Function() + { + @Override + public Connection apply( BoltServerAddress boltServerAddress ) + { + return null; + } + }; + + public StubConnectionPool( Clock clock, final EventHandler events, Function factory ) + { + this( clock, events == null ? null : new EventSink() + { + @Override + public void acquire( BoltServerAddress address, Connection connection ) + { + events.add( new AcquireEvent( Thread.currentThread(), address, connection ) ); + } + + @Override + public void release( BoltServerAddress address, Connection connection ) + { + events.add( new ReleaseEvent( Thread.currentThread(), address, connection ) ); + } + + @Override + public void connectionFailure( BoltServerAddress address ) + { + events.add( new ConnectionFailureEvent( Thread.currentThread(), address ) ); + } + + @Override + public void purge( BoltServerAddress address, boolean connected ) + { + events.add( new PurgeEvent( Thread.currentThread(), address, connected ) ); + } + + @Override + public void close( Collection connected ) + { + events.add( new CloseEvent( Thread.currentThread(), connected ) ); + } + }, factory ); + } + + public StubConnectionPool( Clock clock, EventSink events, Function factory ) + { + this.clock = clock; + this.events = events == null ? EventSink.VOID : events; + this.factory = factory == null ? NULL_FACTORY : factory; + } + + private enum State + { + AVAILABLE, + CONNECTED, + PURGED + } + + private final ConcurrentMap hosts = new ConcurrentHashMap<>(); + + public StubConnectionPool up( String host, int port ) + { + hosts.putIfAbsent( new BoltServerAddress( host, port ), State.AVAILABLE ); + return this; + } + + public StubConnectionPool down( String host, int port ) + { + hosts.remove( new BoltServerAddress( host, port ) ); + return this; + } + + @Override + public Connection acquire( BoltServerAddress address ) + { + if ( hosts.replace( address, State.CONNECTED ) == null ) + { + events.connectionFailure( address ); + throw new ServiceUnavailableException( "Host unavailable: " + address ); + } + Connection connection = new StubConnection( address, factory.apply( address ), events, clock ); + events.acquire( address, connection ); + return connection; + } + + @Override + public void purge( BoltServerAddress address ) + { + State state = hosts.replace( address, State.PURGED ); + events.purge( address, state == State.CONNECTED ); + } + + @Override + public boolean hasAddress( BoltServerAddress address ) + { + return State.CONNECTED == hosts.get( address ); + } + + @Override + public void close() + { + List connected = new ArrayList<>( hosts.size() ); + for ( Map.Entry entry : hosts.entrySet() ) + { + if ( entry.getValue() == State.CONNECTED ) + { + connected.add( entry.getKey() ); + } + } + events.close( connected ); + hosts.clear(); + } + + private static class StubConnection extends PooledConnection + { + private final BoltServerAddress address; + + StubConnection( + final BoltServerAddress address, + final Connection delegate, + final EventSink events, + Clock clock ) + { + super( delegate, new Consumer() + { + @Override + public void accept( PooledConnection self ) + { + events.release( address, self ); + if ( delegate != null ) + { + delegate.close(); + } + } + }, clock ); + this.address = address; + } + + @Override + public String toString() + { + return String.format( "StubConnection{%s}@%s", address, System.identityHashCode( this ) ); + } + + @Override + public BoltServerAddress address() + { + return address; + } + } + + public static abstract class Event extends org.neo4j.driver.internal.Event + { + final Thread thread; + + private Event( Thread thread ) + { + this.thread = thread; + } + + public static Matcher acquire( String host, int port ) + { + return acquire( + any( Thread.class ), + equalTo( new BoltServerAddress( host, port ) ), + any( Connection.class ) ); + } + + public static Matcher acquire( Connection connection ) + { + return acquire( any( Thread.class ), any( BoltServerAddress.class ), sameInstance( connection ) ); + } + + public static Matcher acquire( + final Matcher thread, + final Matcher address, + final Matcher connection ) + { + return new TypeSafeMatcher() + { + @Override + public void describeTo( Description description ) + { + description.appendText( "acquire event on thread <" ) + .appendDescriptionOf( thread ) + .appendText( "> of address <" ) + .appendDescriptionOf( address ) + .appendText( "> resulting in connection <" ) + .appendDescriptionOf( connection ) + .appendText( ">" ); + } + + @Override + protected boolean matchesSafely( AcquireEvent event ) + { + return thread.matches( event.thread ) && + address.matches( event.address ) && + connection.matches( event.connection ); + } + }; + } + + public static Matcher release( + final Matcher thread, + final Matcher address, + final Matcher connection ) + { + return new TypeSafeMatcher() + { + @Override + public void describeTo( Description description ) + { + description.appendText( "release event on thread <" ) + .appendDescriptionOf( thread ) + .appendText( "> of address <" ) + .appendDescriptionOf( address ) + .appendText( "> and connection <" ) + .appendDescriptionOf( connection ) + .appendText( ">" ); + } + + @Override + protected boolean matchesSafely( ReleaseEvent event ) + { + return thread.matches( event.thread ) && + address.matches( event.address ) && + connection.matches( event.connection ); + } + }; + } + + public static Matcher connectionFailure( String host, int port ) + { + return connectionFailure( any( Thread.class ), equalTo( new BoltServerAddress( host, port ) ) ); + } + + public static Matcher connectionFailure( + final Matcher thread, + final Matcher address ) + { + return new TypeSafeMatcher() + { + @Override + public void describeTo( Description description ) + { + description.appendText( "connection failure event on thread <" ) + .appendDescriptionOf( thread ) + .appendText( "> of address <" ) + .appendDescriptionOf( address ) + .appendText( ">" ); + } + + @Override + protected boolean matchesSafely( ConnectionFailureEvent event ) + { + return thread.matches( event.thread ) && + address.matches( event.address ); + } + }; + } + + public static Matcher purge( + final Matcher thread, + final Matcher address, + final Matcher removed ) + { + return new TypeSafeMatcher() + { + @Override + public void describeTo( Description description ) + { + description.appendText( "purge event on thread <" ) + .appendDescriptionOf( thread ) + .appendText( "> of address <" ) + .appendDescriptionOf( address ) + .appendText( "> resulting in actual removal: " ) + .appendDescriptionOf( removed ); + } + + @Override + protected boolean matchesSafely( PurgeEvent event ) + { + return thread.matches( event.thread ) && + address.matches( event.address ) && + removed.matches( event.connected ); + } + }; + } + + public static Matcher close( + final Matcher thread, + final Matcher> addresses ) + { + return new TypeSafeMatcher() + { + @Override + public void describeTo( Description description ) + { + description.appendText( "close event on thread <" ) + .appendDescriptionOf( thread ) + .appendText( "> resulting in closing connections to <" ) + .appendDescriptionOf( addresses ) + .appendText( ">" ); + } + + @Override + protected boolean matchesSafely( CloseEvent event ) + { + return thread.matches( event.thread ) && addresses.matches( event.connected ); + } + }; + } + } + + private static class AcquireEvent extends Event + { + private final BoltServerAddress address; + private final Connection connection; + + AcquireEvent( Thread thread, BoltServerAddress address, Connection connection ) + { + super( thread ); + this.address = address; + this.connection = connection; + } + + @Override + public void dispatch( EventSink sink ) + { + sink.acquire( address, connection ); + } + } + + private static class ReleaseEvent extends Event + { + private final BoltServerAddress address; + private final Connection connection; + + ReleaseEvent( Thread thread, BoltServerAddress address, Connection connection ) + { + super( thread ); + this.address = address; + this.connection = connection; + } + + @Override + public void dispatch( EventSink sink ) + { + sink.release( address, connection ); + } + } + + private static class ConnectionFailureEvent extends Event + { + private final BoltServerAddress address; + + ConnectionFailureEvent( Thread thread, BoltServerAddress address ) + { + super( thread ); + this.address = address; + } + + @Override + public void dispatch( EventSink sink ) + { + sink.connectionFailure( address ); + } + } + + private static class PurgeEvent extends Event + { + private final BoltServerAddress address; + private final boolean connected; + + PurgeEvent( Thread thread, BoltServerAddress address, boolean connected ) + { + super( thread ); + this.address = address; + this.connected = connected; + } + + @Override + public void dispatch( EventSink sink ) + { + sink.purge( address, connected ); + } + } + + private static class CloseEvent extends Event + { + private final Collection connected; + + CloseEvent( Thread thread, Collection connected ) + { + super( thread ); + this.connected = connected; + } + + @Override + public void dispatch( EventSink sink ) + { + sink.close( connected ); + } + } +} diff --git a/driver/src/test/java/org/neo4j/driver/internal/util/ConcurrentRoundRobinSetTest.java b/driver/src/test/java/org/neo4j/driver/internal/util/ConcurrentRoundRobinSetTest.java deleted file mode 100644 index a2c1e32f0b..0000000000 --- a/driver/src/test/java/org/neo4j/driver/internal/util/ConcurrentRoundRobinSetTest.java +++ /dev/null @@ -1,170 +0,0 @@ -/** - * Copyright (c) 2002-2016 "Neo Technology," - * Network Engine for Objects in Lund AB [http://neotechnology.com] - * - * This file is part of Neo4j. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.neo4j.driver.internal.util; - - -import org.junit.Test; - -import java.util.Comparator; -import java.util.HashSet; - -import static java.util.Arrays.asList; -import static junit.framework.TestCase.assertFalse; -import static junit.framework.TestCase.assertTrue; -import static org.hamcrest.CoreMatchers.equalTo; -import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.Matchers.empty; - -public class ConcurrentRoundRobinSetTest -{ - - @Test - public void shouldBeAbleToIterateIndefinitely() - { - // Given - ConcurrentRoundRobinSet integers = new ConcurrentRoundRobinSet<>(); - - // When - integers.addAll( asList( 0, 1, 2, 3, 4 ) ); - - // Then - for ( int i = 0; i < 100; i++ ) - { - assertThat( integers.hop(), equalTo( i % 5 ) ); - } - } - - @Test - public void shouldBeAbleToUseCustomComparator() - { - // Given - ConcurrentRoundRobinSet integers = new ConcurrentRoundRobinSet<>( new Comparator() - { - @Override - public int compare( Integer o1, Integer o2 ) - { - return Integer.compare( o2, o1 ); - } - } ); - - // When - integers.addAll( asList( 0, 1, 2, 3, 4 ) ); - - // Then - assertThat( integers.hop(), equalTo( 4 ) ); - assertThat( integers.hop(), equalTo( 3 ) ); - assertThat( integers.hop(), equalTo( 2 ) ); - assertThat( integers.hop(), equalTo( 1 ) ); - assertThat( integers.hop(), equalTo( 0 ) ); - assertThat( integers.hop(), equalTo( 4 ) ); - assertThat( integers.hop(), equalTo( 3 ) ); - //.... - } - - @Test - public void shouldBeAbleToClearSet() - { - // Given - ConcurrentRoundRobinSet integers = new ConcurrentRoundRobinSet<>(); - - // When - integers.addAll( asList( 0, 1, 2, 3, 4 ) ); - integers.clear(); - - // Then - assertThat( integers, empty() ); - } - - @Test - public void shouldBeAbleToCheckIfContainsElement() - { - // Given - ConcurrentRoundRobinSet integers = new ConcurrentRoundRobinSet<>(); - - // When - integers.addAll( asList( 0, 1, 2, 3, 4 ) ); - - - // Then - assertTrue( integers.contains( 3 ) ); - assertFalse( integers.contains( 7 ) ); - } - - @Test - public void shouldBeAbleToCheckIfContainsMultipleElements() - { - // Given - ConcurrentRoundRobinSet integers = new ConcurrentRoundRobinSet<>(); - - // When - integers.addAll( asList( 0, 1, 2, 3, 4 ) ); - - - // Then - assertTrue( integers.containsAll( asList( 3, 1 ) ) ); - assertFalse( integers.containsAll( asList( 2, 3, 4, 7 ) ) ); - } - - @Test - public void shouldBeAbleToCheckIfEmptyAndSize() - { - // Given - ConcurrentRoundRobinSet integers = new ConcurrentRoundRobinSet<>(); - - // When - integers.addAll( asList( 0, 1, 2, 3, 4 ) ); - - - // Then - assertFalse( integers.isEmpty() ); - assertThat( integers.size(), equalTo( 5 ) ); - integers.clear(); - assertTrue( integers.isEmpty() ); - assertThat( integers.size(), equalTo( 0 ) ); - } - - - @Test - public void shouldBeAbleToCreateArray() - { - // Given - ConcurrentRoundRobinSet integers = new ConcurrentRoundRobinSet<>(); - - // When - integers.addAll( asList( 0, 1, 2, 3, 4 ) ); - Object[] objects = integers.toArray(); - - // Then - assertThat( objects, equalTo( new Object[]{0, 1, 2, 3, 4} ) ); - } - - @Test - public void shouldBeAbleToCreateTypedArray() - { - // Given - ConcurrentRoundRobinSet integers = new ConcurrentRoundRobinSet<>(); - - // When - integers.addAll( asList( 0, 1, 2, 3, 4 ) ); - Integer[] array = integers.toArray( new Integer[5] ); - - // Then - assertThat( array, equalTo( new Integer[]{0, 1, 2, 3, 4} ) ); - } -} \ No newline at end of file diff --git a/driver/src/test/java/org/neo4j/driver/internal/util/FakeClock.java b/driver/src/test/java/org/neo4j/driver/internal/util/FakeClock.java new file mode 100644 index 0000000000..c697eb0fd4 --- /dev/null +++ b/driver/src/test/java/org/neo4j/driver/internal/util/FakeClock.java @@ -0,0 +1,260 @@ +/** + * Copyright (c) 2002-2016 "Neo Technology," + * Network Engine for Objects in Lund AB [http://neotechnology.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.internal.util; + +import java.util.concurrent.PriorityBlockingQueue; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import java.util.concurrent.locks.LockSupport; + +import org.hamcrest.Description; +import org.hamcrest.Matcher; +import org.hamcrest.TypeSafeMatcher; + +import org.neo4j.driver.internal.EventHandler; + +import static java.util.concurrent.atomic.AtomicLongFieldUpdater.newUpdater; +import static org.hamcrest.Matchers.any; +import static org.hamcrest.Matchers.equalTo; + +public class FakeClock implements Clock +{ + public interface EventSink + { + EventSink VOID = new Adapter(); + + void sleep( long timestamp, long millis ); + + void progress( long millis ); + + class Adapter implements EventSink + { + @Override + public void sleep( long timestamp, long millis ) + { + } + + @Override + public void progress( long millis ) + { + } + } + } + + private final EventSink events; + @SuppressWarnings( "unused"/*assigned through AtomicLongFieldUpdater*/ ) + private volatile long timestamp; + private static final AtomicLongFieldUpdater TIMESTAMP = newUpdater( FakeClock.class, "timestamp" ); + private PriorityBlockingQueue threads; + + public FakeClock( final EventHandler events, boolean progressOnSleep ) + { + this( events == null ? null : new EventSink() + { + @Override + public void sleep( long timestamp, long duration ) + { + events.add( new Event.Sleep( Thread.currentThread(), timestamp, duration ) ); + } + + @Override + public void progress( long timestamp ) + { + events.add( new Event.Progress( Thread.currentThread(), timestamp ) ); + } + }, progressOnSleep ); + } + + public FakeClock( EventSink events, boolean progressOnSleep ) + { + this.events = events == null ? EventSink.VOID : events; + this.threads = progressOnSleep ? null : new PriorityBlockingQueue(); + } + + @Override + public long millis() + { + return timestamp; + } + + @Override + public void sleep( long millis ) + { + if ( millis <= 0 ) + { + return; + } + long target = timestamp + millis; + events.sleep( target - millis, millis ); + if ( threads == null ) + { + progress( millis ); + } + else + { + // park until the target time has been reached + WaitingThread token = new WaitingThread( Thread.currentThread(), target ); + threads.add( token ); + for ( ; ; ) + { + if ( timestamp >= target ) + { + threads.remove( token ); + return; + } + // park with a timeout to guarantee that we make progress even if something goes wrong + LockSupport.parkNanos( this, TimeUnit.MILLISECONDS.toNanos( millis ) ); + } + } + } + + public void progress( long millis ) + { + if ( millis < 0 ) + { + throw new IllegalArgumentException( "time can only progress forwards" ); + } + events.progress( TIMESTAMP.addAndGet( this, millis ) ); + if ( threads != null ) + { + // wake up the threads that are sleeping awaiting the current time + for ( WaitingThread thread; (thread = threads.peek()) != null; ) + { + if ( thread.timestamp < timestamp ) + { + threads.remove( thread ); + LockSupport.unpark( thread.thread ); + } + } + } + } + + public static abstract class Event extends org.neo4j.driver.internal.Event + { + final Thread thread; + + private Event( Thread thread ) + { + this.thread = thread; + } + + public static Matcher sleep( long duration ) + { + return sleep( any( Thread.class ), any( Long.class ), equalTo( duration ) ); + } + + public static Matcher sleep( + final Matcher thread, + final Matcher timestamp, + final Matcher duration ) + { + return new TypeSafeMatcher() + { + @Override + public void describeTo( Description description ) + { + description.appendText( "Sleep Event on thread <" ) + .appendDescriptionOf( thread ) + .appendText( "> at timestamp " ) + .appendDescriptionOf( timestamp ) + .appendText( " for duration " ) + .appendDescriptionOf( timestamp ) + .appendText( " (in milliseconds)" ); + } + + @Override + protected boolean matchesSafely( Sleep event ) + { + return thread.matches( event.thread ) + && timestamp.matches( event.timestamp ) + && duration.matches( event.duration ); + } + }; + } + + public static Matcher progress( final Matcher thread, final Matcher timestamp ) + { + return new TypeSafeMatcher() + { + @Override + public void describeTo( Description description ) + { + description.appendText( "Time progresses to timestamp " ) + .appendDescriptionOf( timestamp ) + .appendText( " by thread <" ) + .appendDescriptionOf( thread ) + .appendText( ">" ); + } + + @Override + protected boolean matchesSafely( Progress event ) + { + return thread.matches( event.thread ) && timestamp.matches( event.timestamp ); + } + }; + } + + private static class Sleep extends Event + { + private final long timestamp, duration; + + Sleep( Thread thread, long timestamp, long duration ) + { + super( thread ); + this.timestamp = timestamp; + this.duration = duration; + } + + @Override + public void dispatch( EventSink sink ) + { + sink.sleep( timestamp, duration ); + } + } + + private static class Progress extends Event + { + private final long timestamp; + + Progress( Thread thread, long timestamp ) + { + super( thread ); + this.timestamp = timestamp; + } + + @Override + public void dispatch( EventSink sink ) + { + sink.progress( timestamp ); + } + } + } + + private static class WaitingThread + { + final Thread thread; + final long timestamp; + + private WaitingThread( Thread thread, long timestamp ) + { + this.thread = thread; + this.timestamp = timestamp; + } + } +} diff --git a/driver/src/test/java/org/neo4j/driver/internal/util/MatcherFactory.java b/driver/src/test/java/org/neo4j/driver/internal/util/MatcherFactory.java new file mode 100644 index 0000000000..3d1a8878f6 --- /dev/null +++ b/driver/src/test/java/org/neo4j/driver/internal/util/MatcherFactory.java @@ -0,0 +1,230 @@ +/** + * Copyright (c) 2002-2016 "Neo Technology," + * Network Engine for Objects in Lund AB [http://neotechnology.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.internal.util; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; + +import org.hamcrest.BaseMatcher; +import org.hamcrest.Description; +import org.hamcrest.Matcher; +import org.hamcrest.SelfDescribing; +import org.hamcrest.TypeSafeDiagnosingMatcher; +import org.hamcrest.TypeSafeMatcher; + +public abstract class MatcherFactory implements SelfDescribing +{ + public abstract Matcher createMatcher(); + + /** + * Matches a collection based on the number of elements in the collection that match a given matcher. + * + * @param matcher + * The matcher used for counting matching elements. + * @param count + * The matcher used for evaluating the number of matching elements. + * @param + * The type of elements in the collection. + * @return A matcher for a collection. + */ + public static Matcher> count( final Matcher matcher, final Matcher count ) + { + return new TypeSafeDiagnosingMatcher>() + { + @Override + protected boolean matchesSafely( Iterable collection, Description mismatchDescription ) + { + int matches = 0; + for ( T item : collection ) + { + if ( matcher.matches( item ) ) + { + matches++; + } + } + if ( count.matches( matches ) ) + { + return true; + } + mismatchDescription.appendText( "actual number of matches was " ).appendValue( matches ) + .appendText( " in " ).appendValue( collection ); + return false; + } + + @Override + public void describeTo( Description description ) + { + description.appendText( "collection containing " ) + .appendDescriptionOf( count ) + .appendText( " occurences of " ) + .appendDescriptionOf( matcher ); + } + }; + } + + /** + * Matches a collection that contains elements that match the specified matchers. The elements must be in the same + * order as the given matchers, but the collection may contain other elements in between the matching elements. + * + * @param matchers + * The matchers for the elements of the collection. + * @param + * The type of the elements in the collection. + * @return A matcher for a collection. + */ + @SafeVarargs + public static Matcher> containsAtLeast( final Matcher... matchers ) + { + @SuppressWarnings( "unchecked" ) + MatcherFactory[] factories = new MatcherFactory[matchers.length]; + for ( int i = 0; i < factories.length; i++ ) + { + factories[i] = matches( matchers[i] ); + } + return containsAtLeast( factories ); + } + + @SafeVarargs + public static Matcher> containsAtLeast( final MatcherFactory... matcherFactories ) + { + return new TypeSafeMatcher>() + { + @Override + protected boolean matchesSafely( Iterable collection ) + { + @SuppressWarnings( "unchecked" ) + Matcher[] matchers = new Matcher[matcherFactories.length]; + for ( int i = 0; i < matchers.length; i++ ) + { + matchers[i] = matcherFactories[i].createMatcher(); + } + int i = 0; + for ( T item : collection ) + { + if ( i >= matchers.length ) + { + return true; + } + if ( matchers[i].matches( item ) ) + { + i++; + } + } + return i == matchers.length; + } + + @Override + public void describeTo( Description description ) + { + description.appendText( "collection containing at least " ); + for ( int i = 0; i < matcherFactories.length; i++ ) + { + if ( i != 0 ) + { + if ( i == matcherFactories.length - 1 ) + { + description.appendText( " and " ); + } + else + { + description.appendText( ", " ); + } + } + description.appendDescriptionOf( matcherFactories[i] ); + } + description.appendText( " (in that order) " ); + } + }; + } + + @SafeVarargs + public static MatcherFactory inAnyOrder( final Matcher... matchers ) + { + return new MatcherFactory() + { + @Override + public Matcher createMatcher() + { + final List> remaining = new ArrayList<>( matchers.length ); + Collections.addAll( remaining, matchers ); + return new BaseMatcher() + { + @Override + public boolean matches( Object item ) + { + for ( Iterator> matcher = remaining.iterator(); matcher.hasNext(); ) + { + if ( matcher.next().matches( item ) ) + { + matcher.remove(); + return remaining.isEmpty(); + } + } + return remaining.isEmpty(); + } + + @Override + public void describeTo( Description description ) + { + describe( description ); + } + }; + } + + @Override + public void describeTo( Description description ) + { + describe( description ); + } + + private void describe( Description description ) + { + description.appendText( "in any order" ); + String sep = " {"; + for ( Matcher matcher : matchers ) + { + description.appendText( sep ); + description.appendDescriptionOf( matcher ); + sep = ", "; + } + description.appendText( "}" ); + } + }; + } + + public static MatcherFactory matches( final Matcher matcher ) + { + return new MatcherFactory() + { + @Override + public Matcher createMatcher() + { + return matcher; + } + + @Override + public void describeTo( Description description ) + { + matcher.describeTo( description ); + } + }; + } +} diff --git a/driver/src/test/java/org/neo4j/driver/v1/EventLogger.java b/driver/src/test/java/org/neo4j/driver/v1/EventLogger.java new file mode 100644 index 0000000000..084131a568 --- /dev/null +++ b/driver/src/test/java/org/neo4j/driver/v1/EventLogger.java @@ -0,0 +1,225 @@ +/** + * Copyright (c) 2002-2016 "Neo Technology," + * Network Engine for Objects in Lund AB [http://neotechnology.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.v1; + +import org.hamcrest.Description; +import org.hamcrest.Matcher; +import org.hamcrest.TypeSafeMatcher; + +import org.neo4j.driver.internal.Event; +import org.neo4j.driver.internal.EventHandler; + +import static java.util.Objects.requireNonNull; + +public class EventLogger implements Logger +{ + public static Logging provider( EventHandler events, Level level ) + { + return provider( sink( requireNonNull( events, "events" ) ), level ); + } + + public static Logging provider( final Sink events, final Level level ) + { + requireNonNull( events, "events" ); + requireNonNull( level, "level" ); + return new Logging() + { + @Override + public Logger getLog( String name ) + { + return new EventLogger( events, name, level ); + } + }; + } + + public interface Sink + { + void log( String name, Level level, Throwable cause, String message, Object... params ); + } + + private final boolean debug, trace; + private final Sink events; + private final String name; + + public EventLogger( EventHandler events, String name, Level level ) + { + this( sink( requireNonNull( events, "events" ) ), name, level ); + } + + public EventLogger( Sink events, String name, Level level ) + { + this.events = requireNonNull( events, "events" ); + this.name = name; + level = requireNonNull( level, "level" ); + this.debug = Level.DEBUG.compareTo( level ) <= 0; + this.trace = Level.TRACE.compareTo( level ) <= 0; + } + + private static Sink sink( final EventHandler events ) + { + return new Sink() + { + @Override + public void log( String name, Level level, Throwable cause, String message, Object... params ) + { + events.add( new Entry( Thread.currentThread(), name, level, cause, message, params ) ); + } + }; + } + + public enum Level + { + ERROR, + WARN, + INFO, + DEBUG, + TRACE + } + + @Override + public void error( String message, Throwable cause ) + { + events.log( name, Level.ERROR, cause, message ); + } + + @Override + public void info( String message, Object... params ) + { + events.log( name, Level.INFO, null, message, params ); + } + + @Override + public void warn( String message, Object... params ) + { + events.log( name, Level.WARN, null, message, params ); + } + + @Override + public void debug( String message, Object... params ) + { + events.log( name, Level.DEBUG, null, message, params ); + } + + @Override + public void trace( String message, Object... params ) + { + events.log( name, Level.TRACE, null, message, params ); + } + + @Override + public boolean isTraceEnabled() + { + return trace; + } + + @Override + public boolean isDebugEnabled() + { + return debug; + } + + public static final class Entry extends Event + { + private final Thread thread; + private final String name; + private final Level level; + private final Throwable cause; + private final String message; + private final Object[] params; + + private Entry( Thread thread, String name, Level level, Throwable cause, String message, Object... params ) + { + this.thread = thread; + this.name = name; + this.level = requireNonNull( level, "level" ); + this.cause = cause; + this.message = message; + this.params = params; + } + + @Override + public void dispatch( Sink sink ) + { + sink.log( name, level, cause, message, params ); + } + + private String formatted() + { + return params == null ? message : String.format( message, params ); + } + + public static Matcher logEntry( + final Matcher thread, + final Matcher name, + final Matcher level, + final Matcher cause, + final Matcher message, + final Matcher params ) + { + return new TypeSafeMatcher() + { + @Override + protected boolean matchesSafely( Entry entry ) + { + return level.matches( entry.level ) && + thread.matches( entry.thread ) && + name.matches( entry.name ) && + cause.matches( entry.cause ) && + message.matches( entry.message ) && + params.matches( entry.params ); + } + + @Override + public void describeTo( Description description ) + { + description.appendText( "Log entry where level " ) + .appendDescriptionOf( level ) + .appendText( " name <" ) + .appendDescriptionOf( name ) + .appendText( "> cause <" ) + .appendDescriptionOf( cause ) + .appendText( "> message <" ) + .appendDescriptionOf( message ) + .appendText( "> and parameters <" ) + .appendDescriptionOf( params ) + .appendText( ">" ); + } + }; + } + + public static Matcher message( final Matcher level, final Matcher message ) + { + return new TypeSafeMatcher() + { + @Override + protected boolean matchesSafely( Entry entry ) + { + return level.matches( entry.level ) && message.matches( entry.formatted() ); + } + + @Override + public void describeTo( Description description ) + { + description.appendText( "Log entry where level " ).appendDescriptionOf( level ) + .appendText( " and formatted message " ).appendDescriptionOf( message ); + } + }; + } + } +}