diff --git a/driver/src/main/java/org/neo4j/driver/internal/ClusterView.java b/driver/src/main/java/org/neo4j/driver/internal/ClusterView.java new file mode 100644 index 0000000000..3e378abb08 --- /dev/null +++ b/driver/src/main/java/org/neo4j/driver/internal/ClusterView.java @@ -0,0 +1,175 @@ +/** + * Copyright (c) 2002-2016 "Neo Technology," + * Network Engine for Objects in Lund AB [http://neotechnology.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.internal; + +import java.util.Collections; +import java.util.Comparator; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +import org.neo4j.driver.internal.net.BoltServerAddress; +import org.neo4j.driver.internal.util.Clock; +import org.neo4j.driver.internal.util.ConcurrentRoundRobinSet; +import org.neo4j.driver.v1.Logger; + +/** + * Defines a snapshot view of the cluster. + */ +class ClusterView +{ + private final static Comparator COMPARATOR = new Comparator() + { + @Override + public int compare( BoltServerAddress o1, BoltServerAddress o2 ) + { + int compare = o1.host().compareTo( o2.host() ); + if ( compare == 0 ) + { + compare = Integer.compare( o1.port(), o2.port() ); + } + + return compare; + } + }; + + private static final int MIN_ROUTERS = 1; + + private final ConcurrentRoundRobinSet routingServers = + new ConcurrentRoundRobinSet<>( COMPARATOR ); + private final ConcurrentRoundRobinSet readServers = + new ConcurrentRoundRobinSet<>( COMPARATOR ); + private final ConcurrentRoundRobinSet writeServers = + new ConcurrentRoundRobinSet<>( COMPARATOR ); + private final Clock clock; + private final long expires; + private final Logger log; + + public ClusterView( long expires, Clock clock, Logger log ) + { + this.expires = expires; + this.clock = clock; + this.log = log; + } + + public void addRouter( BoltServerAddress router ) + { + this.routingServers.add( router ); + } + + public boolean isStale() + { + return expires < clock.millis() || + routingServers.size() <= MIN_ROUTERS || + readServers.isEmpty() || + writeServers.isEmpty(); + } + + Set all() + { + HashSet all = + new HashSet<>( routingServers.size() + readServers.size() + writeServers.size() ); + all.addAll( routingServers ); + all.addAll( readServers ); + all.addAll( writeServers ); + return all; + } + + + public BoltServerAddress nextRouter() + { + return routingServers.hop(); + } + + public BoltServerAddress nextReader() + { + return readServers.hop(); + } + + public BoltServerAddress nextWriter() + { + return writeServers.hop(); + } + + public void addReaders( List addresses ) + { + readServers.addAll( addresses ); + } + + public void addWriters( List addresses ) + { + writeServers.addAll( addresses ); + } + + public void addRouters( List addresses ) + { + routingServers.addAll( addresses ); + } + + public void remove( BoltServerAddress address ) + { + if ( routingServers.remove( address ) ) + { + log.debug( "Removing %s from routers", address.toString() ); + } + if ( readServers.remove( address ) ) + { + log.debug( "Removing %s from readers", address.toString() ); + } + if ( writeServers.remove( address ) ) + { + log.debug( "Removing %s from writers", address.toString() ); + } + } + + public boolean removeWriter( BoltServerAddress address ) + { + return writeServers.remove( address ); + } + + public int numberOfRouters() + { + return routingServers.size(); + } + + public int numberOfReaders() + { + return readServers.size(); + } + + public int numberOfWriters() + { + return writeServers.size(); + } + + public Set routingServers() + { + return Collections.unmodifiableSet( routingServers ); + } + + public Set readServers() + { + return Collections.unmodifiableSet( readServers ); + } + + public Set writeServers() + { + return Collections.unmodifiableSet( writeServers ); + } +} diff --git a/driver/src/main/java/org/neo4j/driver/internal/RoutingDriver.java b/driver/src/main/java/org/neo4j/driver/internal/RoutingDriver.java index 90ff41f7e7..03cb3f1c2a 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/RoutingDriver.java +++ b/driver/src/main/java/org/neo4j/driver/internal/RoutingDriver.java @@ -18,9 +18,6 @@ */ package org.neo4j.driver.internal; -import java.util.Collections; -import java.util.Comparator; -import java.util.HashSet; import java.util.List; import java.util.Set; @@ -29,9 +26,7 @@ import org.neo4j.driver.internal.spi.Connection; import org.neo4j.driver.internal.spi.ConnectionPool; import org.neo4j.driver.internal.util.Clock; -import org.neo4j.driver.internal.util.ConcurrentRoundRobinSet; import org.neo4j.driver.v1.AccessMode; -import org.neo4j.driver.v1.Logger; import org.neo4j.driver.v1.Logging; import org.neo4j.driver.v1.Record; import org.neo4j.driver.v1.Session; @@ -48,20 +43,7 @@ public class RoutingDriver extends BaseDriver { private static final String GET_SERVERS = "dbms.cluster.routing.getServers"; private static final long MAX_TTL = Long.MAX_VALUE / 1000L; - private final static Comparator COMPARATOR = new Comparator() - { - @Override - public int compare( BoltServerAddress o1, BoltServerAddress o2 ) - { - int compare = o1.host().compareTo( o2.host() ); - if ( compare == 0 ) - { - compare = Integer.compare( o1.port(), o2.port() ); - } - return compare; - } - }; private final ConnectionPool connections; private final Function sessionProvider; private final Clock clock; @@ -197,117 +179,6 @@ List addresses() } } - private static class ClusterView - { - private static final int MIN_ROUTERS = 1; - - private final ConcurrentRoundRobinSet routingServers = - new ConcurrentRoundRobinSet<>( COMPARATOR ); - private final ConcurrentRoundRobinSet readServers = - new ConcurrentRoundRobinSet<>( COMPARATOR ); - private final ConcurrentRoundRobinSet writeServers = - new ConcurrentRoundRobinSet<>( COMPARATOR ); - private final Clock clock; - private final long expires; - private final Logger log; - - private ClusterView( long expires, Clock clock, Logger log ) - { - this.expires = expires; - this.clock = clock; - this.log = log; - } - - public void addRouter( BoltServerAddress router ) - { - this.routingServers.add( router ); - } - - public boolean isStale() - { - return expires < clock.millis() || - routingServers.size() <= MIN_ROUTERS || - readServers.isEmpty() || - writeServers.isEmpty(); - } - - Set all() - { - HashSet all = - new HashSet<>( routingServers.size() + readServers.size() + writeServers.size() ); - all.addAll( routingServers ); - all.addAll( readServers ); - all.addAll( writeServers ); - return all; - } - - public int numberOfRouters() - { - return routingServers.size(); - } - - public BoltServerAddress nextRouter() - { - return routingServers.hop(); - } - - public BoltServerAddress nextReader() - { - return readServers.hop(); - } - - public BoltServerAddress nextWriter() - { - return writeServers.hop(); - } - - public void addReaders( List addresses ) - { - readServers.addAll( addresses ); - } - - public void addWriters( List addresses ) - { - writeServers.addAll( addresses ); - } - - public void addRouters( List addresses ) - { - routingServers.addAll( addresses ); - } - - public void remove( BoltServerAddress address ) - { - if ( routingServers.remove( address ) ) - { - log.debug( "Removing %s from routers", address.toString() ); - } - if ( readServers.remove( address ) ) - { - log.debug( "Removing %s from readers", address.toString() ); - } - if ( writeServers.remove( address ) ) - { - log.debug( "Removing %s from writers", address.toString() ); - } - } - - public boolean removeWriter( BoltServerAddress address ) - { - return writeServers.remove( address ); - } - - public int numberOfReaders() - { - return readServers.size(); - } - - public int numberOfWriters() - { - return writeServers.size(); - } - } - private List servers( Record record ) { return record.get( "servers" ).asList( new Function() @@ -371,7 +242,8 @@ public Session session() @Override public Session session( final AccessMode mode ) { - return new RoutingNetworkSession( mode, acquireConnection( mode ), + Connection connection = acquireConnection( mode ); + return new RoutingNetworkSession( new NetworkSession( connection ), mode, connection.address(), new RoutingErrorHandler() { @Override @@ -458,19 +330,19 @@ public void close() //For testing public Set routingServers() { - return Collections.unmodifiableSet( clusterView.routingServers ); + return clusterView.routingServers(); } //For testing public Set readServers() { - return Collections.unmodifiableSet( clusterView.readServers ); + return clusterView.readServers(); } //For testing public Set writeServers() { - return Collections.unmodifiableSet( clusterView.writeServers ); + return clusterView.writeServers( ); } //For testing diff --git a/driver/src/main/java/org/neo4j/driver/internal/RoutingNetworkSession.java b/driver/src/main/java/org/neo4j/driver/internal/RoutingNetworkSession.java index 6e53194e3b..5046ad9d2e 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/RoutingNetworkSession.java +++ b/driver/src/main/java/org/neo4j/driver/internal/RoutingNetworkSession.java @@ -19,68 +19,150 @@ package org.neo4j.driver.internal; +import java.util.Map; + import org.neo4j.driver.internal.net.BoltServerAddress; -import org.neo4j.driver.internal.spi.Connection; import org.neo4j.driver.v1.AccessMode; +import org.neo4j.driver.v1.Record; +import org.neo4j.driver.v1.Session; import org.neo4j.driver.v1.Statement; import org.neo4j.driver.v1.StatementResult; +import org.neo4j.driver.v1.Transaction; +import org.neo4j.driver.v1.Value; +import org.neo4j.driver.v1.Values; import org.neo4j.driver.v1.exceptions.ClientException; import org.neo4j.driver.v1.exceptions.ConnectionFailureException; import org.neo4j.driver.v1.exceptions.Neo4jException; import org.neo4j.driver.v1.exceptions.SessionExpiredException; +import org.neo4j.driver.v1.types.TypeSystem; import static java.lang.String.format; +import static org.neo4j.driver.v1.Values.value; -public class RoutingNetworkSession extends NetworkSession +/** + * A session that safely handles routing errors. + */ +public class RoutingNetworkSession implements Session { + protected final Session delegate; + private final BoltServerAddress address; private final AccessMode mode; private final RoutingErrorHandler onError; - RoutingNetworkSession( AccessMode mode, Connection connection, + RoutingNetworkSession( Session delegate, AccessMode mode, BoltServerAddress address, RoutingErrorHandler onError ) { - super( connection ); + this.delegate = delegate; this.mode = mode; + this.address = address; this.onError = onError; } + @Override + public StatementResult run( String statementText ) + { + return run( statementText, Values.EmptyMap ); + } + + @Override + public StatementResult run( String statementText, Map statementParameters ) + { + Value params = statementParameters == null ? Values.EmptyMap : value( statementParameters ); + return run( statementText, params ); + } + + @Override + public StatementResult run( String statementTemplate, Record statementParameters ) + { + Value params = statementParameters == null ? Values.EmptyMap : value( statementParameters.asMap() ); + return run( statementTemplate, params ); + } + + @Override + public StatementResult run( String statementText, Value statementParameters ) + { + return run( new Statement( statementText, statementParameters ) ); + } + @Override public StatementResult run( Statement statement ) { try { - return new RoutingStatementResult( super.run( statement ), mode, connection.address(), onError ); + return new RoutingStatementResult( delegate.run( statement ), mode, address, onError ); } catch ( ConnectionFailureException e ) { - throw sessionExpired( e, onError, connection.address() ); + throw sessionExpired( e, onError, address ); } catch ( ClientException e ) { - throw filterFailureToWrite( e, mode, onError, connection.address() ); + throw filterFailureToWrite( e, mode, onError, address ); } } + @Override + public TypeSystem typeSystem() + { + return delegate.typeSystem(); + } + + @Override + public Transaction beginTransaction() + { + return new RoutingTransaction( delegate.beginTransaction(), mode, address, onError); + } + + @Override + public Transaction beginTransaction( String bookmark ) + { + return new RoutingTransaction( delegate.beginTransaction(bookmark), mode, address, onError); + } + + @Override + public String lastBookmark() + { + return delegate.lastBookmark(); + } + + @Override + public void reset() + { + delegate.reset(); + } + + @Override + public boolean isOpen() + { + return delegate.isOpen(); + } + @Override public void close() { try { - super.close(); + delegate.close(); } catch ( ConnectionFailureException e ) { - throw sessionExpired(e, onError, connection.address()); + throw sessionExpired(e, onError, address); } catch ( ClientException e ) { - throw filterFailureToWrite( e, mode, onError, connection.address() ); + throw filterFailureToWrite( e, mode, onError, address ); } } + @Override + public String server() + { + return delegate.server(); + } + public BoltServerAddress address() { - return connection.address(); + return address; } static Neo4jException filterFailureToWrite( ClientException e, AccessMode mode, RoutingErrorHandler onError, diff --git a/driver/src/main/java/org/neo4j/driver/internal/RoutingTransaction.java b/driver/src/main/java/org/neo4j/driver/internal/RoutingTransaction.java new file mode 100644 index 0000000000..289269d919 --- /dev/null +++ b/driver/src/main/java/org/neo4j/driver/internal/RoutingTransaction.java @@ -0,0 +1,143 @@ +/** + * Copyright (c) 2002-2016 "Neo Technology," + * Network Engine for Objects in Lund AB [http://neotechnology.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.internal; + + +import java.util.Map; + +import org.neo4j.driver.internal.net.BoltServerAddress; +import org.neo4j.driver.v1.AccessMode; +import org.neo4j.driver.v1.Record; +import org.neo4j.driver.v1.Statement; +import org.neo4j.driver.v1.StatementResult; +import org.neo4j.driver.v1.Transaction; +import org.neo4j.driver.v1.Value; +import org.neo4j.driver.v1.Values; +import org.neo4j.driver.v1.exceptions.ClientException; +import org.neo4j.driver.v1.exceptions.ConnectionFailureException; +import org.neo4j.driver.v1.types.TypeSystem; + +import static org.neo4j.driver.internal.RoutingNetworkSession.filterFailureToWrite; +import static org.neo4j.driver.internal.RoutingNetworkSession.sessionExpired; +import static org.neo4j.driver.v1.Values.value; + +/** + * A transaction that safely handles routing errors. + */ +public class RoutingTransaction implements Transaction +{ + protected final Transaction delegate; + private final AccessMode mode; + private final BoltServerAddress address; + private final RoutingErrorHandler onError; + + RoutingTransaction( Transaction delegate, AccessMode mode, BoltServerAddress address, + RoutingErrorHandler onError ) + { + this.delegate = delegate; + this.mode = mode; + this.address = address; + this.onError = onError; + } + + @Override + public StatementResult run( String statementText ) + { + return run( statementText, Values.EmptyMap ); + } + + @Override + public StatementResult run( String statementText, Map statementParameters ) + { + Value params = statementParameters == null ? Values.EmptyMap : value( statementParameters ); + return run( statementText, params ); + } + + @Override + public StatementResult run( String statementTemplate, Record statementParameters ) + { + Value params = statementParameters == null ? Values.EmptyMap : value( statementParameters.asMap() ); + return run( statementTemplate, params ); + } + + @Override + public StatementResult run( String statementText, Value statementParameters ) + { + return run( new Statement( statementText, statementParameters ) ); + } + + @Override + public StatementResult run( Statement statement ) + { + try + { + return new RoutingStatementResult( delegate.run( statement ), mode, address, onError ); + } + catch ( ConnectionFailureException e ) + { + throw sessionExpired( e, onError, address ); + } + catch ( ClientException e ) + { + throw filterFailureToWrite( e, mode, onError, address ); + } + } + + @Override + public TypeSystem typeSystem() + { + return delegate.typeSystem(); + } + + + @Override + public void success() + { + delegate.success(); + } + + @Override + public void failure() + { + delegate.failure(); + } + + @Override + public boolean isOpen() + { + return delegate.isOpen(); + } + + @Override + public void close() + { + try + { + delegate.close(); + } + catch ( ConnectionFailureException e ) + { + throw sessionExpired(e, onError, address); + } + catch ( ClientException e ) + { + throw filterFailureToWrite( e, mode, onError, address ); + } + } +} diff --git a/driver/src/test/java/org/neo4j/driver/internal/ClusterViewTest.java b/driver/src/test/java/org/neo4j/driver/internal/ClusterViewTest.java new file mode 100644 index 0000000000..2f9b7fb778 --- /dev/null +++ b/driver/src/test/java/org/neo4j/driver/internal/ClusterViewTest.java @@ -0,0 +1,189 @@ +/** + * Copyright (c) 2002-2016 "Neo Technology," + * Network Engine for Objects in Lund AB [http://neotechnology.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.internal; + + +import org.junit.Test; + +import org.neo4j.driver.internal.net.BoltServerAddress; +import org.neo4j.driver.internal.util.Clock; +import org.neo4j.driver.v1.Logger; + +import static java.util.Arrays.asList; +import static java.util.Collections.singletonList; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class ClusterViewTest +{ + + @Test + public void shouldRoundRobinAmongRoutingServers() + { + // Given + ClusterView clusterView = new ClusterView( 5L, mock( Clock.class ), mock( Logger.class ) ); + + // When + clusterView.addRouters( asList( address("host1"), address( "host2" ), address( "host3" ))); + + // Then + assertThat(clusterView.nextRouter(), equalTo(address( "host1" ))); + assertThat(clusterView.nextRouter(), equalTo(address( "host2" ))); + assertThat(clusterView.nextRouter(), equalTo(address( "host3" ))); + assertThat(clusterView.nextRouter(), equalTo(address( "host1" ))); + } + + @Test + public void shouldRoundRobinAmongReadServers() + { + // Given + ClusterView clusterView = new ClusterView( 5L, mock( Clock.class ), mock( Logger.class ) ); + + // When + clusterView.addReaders( asList( address("host1"), address( "host2" ), address( "host3" ))); + + // Then + assertThat(clusterView.nextReader(), equalTo(address( "host1" ))); + assertThat(clusterView.nextReader(), equalTo(address( "host2" ))); + assertThat(clusterView.nextReader(), equalTo(address( "host3" ))); + assertThat(clusterView.nextReader(), equalTo(address( "host1" ))); + } + + @Test + public void shouldRoundRobinAmongWriteServers() + { + // Given + ClusterView clusterView = new ClusterView( 5L, mock( Clock.class ), mock( Logger.class ) ); + + // When + clusterView.addWriters( asList( address("host1"), address( "host2" ), address( "host3" ))); + + // Then + assertThat(clusterView.nextWriter(), equalTo(address( "host1" ))); + assertThat(clusterView.nextWriter(), equalTo(address( "host2" ))); + assertThat(clusterView.nextWriter(), equalTo(address( "host3" ))); + assertThat(clusterView.nextWriter(), equalTo(address( "host1" ))); + } + + @Test + public void shouldRemoveServer() + { + // Given + ClusterView clusterView = new ClusterView( 5L, mock( Clock.class ), mock( Logger.class ) ); + + clusterView.addRouters( asList( address("host1"), address( "host2" ), address( "host3" ))); + clusterView.addReaders( asList( address("host1"), address( "host2" ), address( "host3" ))); + clusterView.addWriters( asList( address("host2"), address( "host4" ))); + + // When + clusterView.remove( address( "host2" ) ); + + // Then + assertThat(clusterView.routingServers(), containsInAnyOrder(address( "host1" ), address( "host3" ))); + assertThat(clusterView.readServers(), containsInAnyOrder(address( "host1" ), address( "host3" ))); + assertThat(clusterView.writeServers(), containsInAnyOrder(address( "host4" ))); + assertThat(clusterView.all(), containsInAnyOrder( address( "host1" ), address( "host3" ), address( "host4" ) )); + } + + @Test + public void shouldBeStaleIfExpired() + { + // Given + Clock clock = mock( Clock.class ); + when(clock.millis()).thenReturn( 6L ); + ClusterView clusterView = new ClusterView( 5L, clock, mock( Logger.class ) ); + clusterView.addRouters( asList( address("host1"), address( "host2" ), address( "host3" ))); + clusterView.addReaders( asList( address("host1"), address( "host2" ), address( "host3" ))); + clusterView.addWriters( asList( address("host2"), address( "host4" ))); + + // Then + assertTrue(clusterView.isStale()); + } + + @Test + public void shouldNotBeStaleIfNotExpired() + { + // Given + Clock clock = mock( Clock.class ); + when(clock.millis()).thenReturn( 4L ); + ClusterView clusterView = new ClusterView( 5L, clock, mock( Logger.class ) ); + clusterView.addRouters( asList( address("host1"), address( "host2" ), address( "host3" ))); + clusterView.addReaders( asList( address("host1"), address( "host2" ), address( "host3" ))); + clusterView.addWriters( asList( address("host2"), address( "host4" ))); + + // Then + assertFalse(clusterView.isStale()); + } + + @Test + public void shouldBeStaleIfOnlyOneRouter() + { + // Given + Clock clock = mock( Clock.class ); + when(clock.millis()).thenReturn( 4L ); + ClusterView clusterView = new ClusterView( 5L, clock, mock( Logger.class ) ); + clusterView.addRouters( singletonList( address( "host1" ) ) ); + clusterView.addReaders( asList( address("host1"), address( "host2" ), address( "host3" ))); + clusterView.addWriters( asList( address("host2"), address( "host4" ))); + + // When + + // Then + assertTrue(clusterView.isStale()); + } + + @Test + public void shouldBeStaleIfNoReader() + { + // Given + Clock clock = mock( Clock.class ); + when(clock.millis()).thenReturn( 4L ); + ClusterView clusterView = new ClusterView( 5L, clock, mock( Logger.class ) ); + clusterView.addRouters( singletonList( address( "host1" ) ) ); + clusterView.addWriters( asList( address("host2"), address( "host4" ))); + + // Then + assertTrue(clusterView.isStale()); + } + + @Test + public void shouldBeStaleIfNoWriter() + { + // Given + Clock clock = mock( Clock.class ); + when(clock.millis()).thenReturn( 4L ); + ClusterView clusterView = new ClusterView( 5L, clock, mock( Logger.class ) ); + clusterView.addRouters( singletonList( address( "host1" ) ) ); + clusterView.addReaders( asList( address("host1"), address( "host2" ), address( "host3" ))); + + // Then + assertTrue(clusterView.isStale()); + } + + private BoltServerAddress address(String host) + { + return new BoltServerAddress( host ); + } + +} \ No newline at end of file diff --git a/driver/src/test/java/org/neo4j/driver/internal/RoutingDriverStubTest.java b/driver/src/test/java/org/neo4j/driver/internal/RoutingDriverStubTest.java index 773104c9fe..0f1613f57c 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/RoutingDriverStubTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/RoutingDriverStubTest.java @@ -18,6 +18,7 @@ */ package org.neo4j.driver.internal; +import gherkin.lexer.Tr; import org.junit.Ignore; import org.junit.Rule; import org.junit.Test; @@ -43,6 +44,7 @@ import org.neo4j.driver.v1.GraphDatabase; import org.neo4j.driver.v1.Record; import org.neo4j.driver.v1.Session; +import org.neo4j.driver.v1.Transaction; import org.neo4j.driver.v1.exceptions.ConnectionFailureException; import org.neo4j.driver.v1.exceptions.ServiceUnavailableException; import org.neo4j.driver.v1.exceptions.SessionExpiredException; @@ -81,7 +83,7 @@ public void shouldDiscoverServers() throws IOException, InterruptedException, St { // Then Set addresses = driver.routingServers(); - assertThat( addresses, containsInAnyOrder( address(9001), address( 9002 ), address( 9003 ) ) ); + assertThat( addresses, containsInAnyOrder( address( 9001 ), address( 9002 ), address( 9003 ) ) ); } // Finally @@ -101,7 +103,7 @@ public void shouldOnlyPutConnectionInPoolOnce() throws IOException, InterruptedE // Then SocketConnectionPool pool = (SocketConnectionPool) driver.connectionPool(); List pooledConnections = pool.connectionsForAddress( address( 9001 ) ); - assertThat(pooledConnections, hasSize( 1 )); + assertThat( pooledConnections, hasSize( 1 ) ); } // Finally @@ -112,7 +114,7 @@ public void shouldOnlyPutConnectionInPoolOnce() throws IOException, InterruptedE public void shouldDiscoverNewServers() throws IOException, InterruptedException, StubServer.ForceKilled { // Given - StubServer server = StubServer.start( "discover_new_servers.script" , 9001 ); + StubServer server = StubServer.start( "discover_new_servers.script", 9001 ); URI uri = URI.create( "bolt+routing://127.0.0.1:9001" ); BoltServerAddress seed = address( 9001 ); @@ -121,7 +123,7 @@ public void shouldDiscoverNewServers() throws IOException, InterruptedException, { // Then Set addresses = driver.routingServers(); - assertThat( addresses, containsInAnyOrder( address(9002), address( 9003 ), address( 9004 ) ) ); + assertThat( addresses, containsInAnyOrder( address( 9002 ), address( 9003 ), address( 9004 ) ) ); } // Finally @@ -132,7 +134,7 @@ public void shouldDiscoverNewServers() throws IOException, InterruptedException, public void shouldHandleEmptyResponse() throws IOException, InterruptedException, StubServer.ForceKilled { // Given - StubServer server = StubServer.start( "handle_empty_response.script" , 9001 ); + StubServer server = StubServer.start( "handle_empty_response.script", 9001 ); URI uri = URI.create( "bolt+routing://127.0.0.1:9001" ); // When @@ -140,7 +142,8 @@ public void shouldHandleEmptyResponse() throws IOException, InterruptedException { GraphDatabase.driver( uri, config ); fail(); - } catch ( ServiceUnavailableException e ) + } + catch ( ServiceUnavailableException e ) { //ignore } @@ -156,7 +159,7 @@ public void shouldHandleAcquireReadSession() throws IOException, InterruptedExce StubServer server = StubServer.start( "acquire_endpoints.script", 9001 ); //START a read server - StubServer readServer = StubServer.start( "read_server.script" , 9005 ); + StubServer readServer = StubServer.start( "read_server.script", 9005 ); URI uri = URI.create( "bolt+routing://127.0.0.1:9001" ); try ( RoutingDriver driver = (RoutingDriver) GraphDatabase.driver( uri, config ); Session session = driver.session( AccessMode.READ ) ) @@ -178,15 +181,46 @@ public String apply( Record record ) assertThat( readServer.exitStatus(), equalTo( 0 ) ); } + @Test + public void shouldHandleAcquireReadSessionPlusTransaction() + throws IOException, InterruptedException, StubServer.ForceKilled + { + // Given + StubServer server = StubServer.start( "acquire_endpoints.script", 9001 ); + + //START a read server + StubServer readServer = StubServer.start( "read_server.script", 9005 ); + URI uri = URI.create( "bolt+routing://127.0.0.1:9001" ); + try ( RoutingDriver driver = (RoutingDriver) GraphDatabase.driver( uri, config ); + Session session = driver.session( AccessMode.READ ); + Transaction tx = session.beginTransaction() ) + { + List result = tx.run( "MATCH (n) RETURN n.name" ).list( new Function() + { + @Override + public String apply( Record record ) + { + return record.get( "n.name" ).asString(); + } + } ); + + assertThat( result, equalTo( Arrays.asList( "Bob", "Alice", "Tina" ) ) ); + + } + // Finally + assertThat( server.exitStatus(), equalTo( 0 ) ); + assertThat( readServer.exitStatus(), equalTo( 0 ) ); + } + @Test public void shouldRoundRobinReadServers() throws IOException, InterruptedException, StubServer.ForceKilled { // Given - StubServer server = StubServer.start( "acquire_endpoints.script" , 9001 ); + StubServer server = StubServer.start( "acquire_endpoints.script", 9001 ); //START two read servers StubServer readServer1 = StubServer.start( "read_server.script", 9005 ); - StubServer readServer2 = StubServer.start( "read_server.script" , 9006 ); + StubServer readServer2 = StubServer.start( "read_server.script", 9006 ); URI uri = URI.create( "bolt+routing://127.0.0.1:9001" ); try ( RoutingDriver driver = (RoutingDriver) GraphDatabase.driver( uri, config ) ) { @@ -212,6 +246,42 @@ public String apply( Record record ) assertThat( readServer2.exitStatus(), equalTo( 0 ) ); } + @Test + public void shouldRoundRobinReadServersWhenUsingTransaction() + throws IOException, InterruptedException, StubServer.ForceKilled + { + // Given + StubServer server = StubServer.start( "acquire_endpoints.script", 9001 ); + + //START two read servers + StubServer readServer1 = StubServer.start( "read_server.script", 9005 ); + StubServer readServer2 = StubServer.start( "read_server.script", 9006 ); + URI uri = URI.create( "bolt+routing://127.0.0.1:9001" ); + try ( RoutingDriver driver = (RoutingDriver) GraphDatabase.driver( uri, config ) ) + { + // Run twice, one on each read server + for ( int i = 0; i < 2; i++ ) + { + try ( Session session = driver.session( AccessMode.READ ); + Transaction tx = session.beginTransaction() ) + { + assertThat( tx.run( "MATCH (n) RETURN n.name" ).list( new Function() + { + @Override + public String apply( Record record ) + { + return record.get( "n.name" ).asString(); + } + } ), equalTo( Arrays.asList( "Bob", "Alice", "Tina" ) ) ); + } + } + } + // Finally + assertThat( server.exitStatus(), equalTo( 0 ) ); + assertThat( readServer1.exitStatus(), equalTo( 0 ) ); + assertThat( readServer2.exitStatus(), equalTo( 0 ) ); + } + @Test public void shouldThrowSessionExpiredIfReadServerDisappears() throws IOException, InterruptedException, StubServer.ForceKilled @@ -221,10 +291,10 @@ public void shouldThrowSessionExpiredIfReadServerDisappears() exception.expectMessage( "Server at 127.0.0.1:9005 is no longer available" ); // Given - StubServer server = StubServer.start( "acquire_endpoints.script" , 9001 ); + StubServer server = StubServer.start( "acquire_endpoints.script", 9001 ); //START a read server - StubServer.start( "dead_server.script" , 9005 ); + StubServer.start( "dead_server.script", 9005 ); URI uri = URI.create( "bolt+routing://127.0.0.1:9001" ); try ( RoutingDriver driver = (RoutingDriver) GraphDatabase.driver( uri, config ); Session session = driver.session( AccessMode.READ ) ) @@ -235,6 +305,31 @@ public void shouldThrowSessionExpiredIfReadServerDisappears() assertThat( server.exitStatus(), equalTo( 0 ) ); } + @Test + public void shouldThrowSessionExpiredIfReadServerDisappearsWhenUsingTransaction() + throws IOException, InterruptedException, StubServer.ForceKilled + { + //Expect + exception.expect( SessionExpiredException.class ); + exception.expectMessage( "Server at 127.0.0.1:9005 is no longer available" ); + + // Given + StubServer server = StubServer.start( "acquire_endpoints.script", 9001 ); + + //START a read server + StubServer.start( "dead_server.script", 9005 ); + URI uri = URI.create( "bolt+routing://127.0.0.1:9001" ); + try ( RoutingDriver driver = (RoutingDriver) GraphDatabase.driver( uri, config ); + Session session = driver.session( AccessMode.READ ); + Transaction tx = session.beginTransaction() ) + { + tx.run( "MATCH (n) RETURN n.name" ); + tx.success(); + } + // Finally + assertThat( server.exitStatus(), equalTo( 0 ) ); + } + @Test public void shouldThrowSessionExpiredIfWriteServerDisappears() throws IOException, InterruptedException, StubServer.ForceKilled @@ -244,10 +339,10 @@ public void shouldThrowSessionExpiredIfWriteServerDisappears() //exception.expectMessage( "Server at 127.0.0.1:9006 is no longer available" ); // Given - StubServer server = StubServer.start( "acquire_endpoints.script" , 9001 ); + StubServer server = StubServer.start( "acquire_endpoints.script", 9001 ); //START a dead write servers - StubServer.start( "dead_server.script" , 9007 ); + StubServer.start( "dead_server.script", 9007 ); URI uri = URI.create( "bolt+routing://127.0.0.1:9001" ); try ( RoutingDriver driver = (RoutingDriver) GraphDatabase.driver( uri, config ); Session session = driver.session( AccessMode.WRITE ) ) @@ -258,14 +353,40 @@ public void shouldThrowSessionExpiredIfWriteServerDisappears() assertThat( server.exitStatus(), equalTo( 0 ) ); } + @Test + public void shouldThrowSessionExpiredIfWriteServerDisappearsWhenUsingTransaction() + throws IOException, InterruptedException, StubServer.ForceKilled + { + //Expect + exception.expect( SessionExpiredException.class ); + //exception.expectMessage( "Server at 127.0.0.1:9006 is no longer available" ); + + // Given + StubServer server = StubServer.start( "acquire_endpoints.script", 9001 ); + + //START a dead write servers + StubServer.start( "dead_server.script", 9007 ); + URI uri = URI.create( "bolt+routing://127.0.0.1:9001" ); + try ( RoutingDriver driver = (RoutingDriver) GraphDatabase.driver( uri, config ); + Session session = driver.session( AccessMode.WRITE ); + Transaction tx = session.beginTransaction() ) + { + tx.run( "MATCH (n) RETURN n.name" ).consume(); + tx.success(); + } + // Finally + assertThat( server.exitStatus(), equalTo( 0 ) ); + } + + @Test public void shouldHandleAcquireWriteSession() throws IOException, InterruptedException, StubServer.ForceKilled { // Given - StubServer server = StubServer.start( "acquire_endpoints.script" , 9001 ); + StubServer server = StubServer.start( "acquire_endpoints.script", 9001 ); //START a write server - StubServer writeServer = StubServer.start( "write_server.script" , 9007 ); + StubServer writeServer = StubServer.start( "write_server.script", 9007 ); URI uri = URI.create( "bolt+routing://127.0.0.1:9001" ); try ( RoutingDriver driver = (RoutingDriver) GraphDatabase.driver( uri, config ); Session session = driver.session( AccessMode.WRITE ) ) @@ -277,6 +398,28 @@ public void shouldHandleAcquireWriteSession() throws IOException, InterruptedExc assertThat( writeServer.exitStatus(), equalTo( 0 ) ); } + @Test + public void shouldHandleAcquireWriteSessionAndTransaction() + throws IOException, InterruptedException, StubServer.ForceKilled + { + // Given + StubServer server = StubServer.start( "acquire_endpoints.script", 9001 ); + + //START a write server + StubServer writeServer = StubServer.start( "write_server.script", 9007 ); + URI uri = URI.create( "bolt+routing://127.0.0.1:9001" ); + try ( RoutingDriver driver = (RoutingDriver) GraphDatabase.driver( uri, config ); + Session session = driver.session( AccessMode.WRITE ); + Transaction tx = session.beginTransaction() ) + { + tx.run( "CREATE (n {name:'Bob'})" ); + tx.success(); + } + // Finally + assertThat( server.exitStatus(), equalTo( 0 ) ); + assertThat( writeServer.exitStatus(), equalTo( 0 ) ); + } + @Test public void shouldRoundRobinWriteSessions() throws IOException, InterruptedException, StubServer.ForceKilled { @@ -303,6 +446,34 @@ public void shouldRoundRobinWriteSessions() throws IOException, InterruptedExcep assertThat( writeServer2.exitStatus(), equalTo( 0 ) ); } + @Test + public void shouldRoundRobinWriteSessionsInTransaction() throws IOException, InterruptedException, StubServer.ForceKilled + { + // Given + StubServer server = StubServer.start( "acquire_endpoints.script", 9001 ); + + //START a write server + StubServer writeServer1 = StubServer.start( "write_server.script", 9007 ); + StubServer writeServer2 = StubServer.start( "write_server.script", 9008 ); + URI uri = URI.create( "bolt+routing://127.0.0.1:9001" ); + try ( RoutingDriver driver = (RoutingDriver) GraphDatabase.driver( uri, config ) ) + { + for ( int i = 0; i < 2; i++ ) + { + try ( Session session = driver.session(); + Transaction tx = session.beginTransaction()) + { + tx.run( "CREATE (n {name:'Bob'})" ); + tx.success(); + } + } + } + // Finally + assertThat( server.exitStatus(), equalTo( 0 ) ); + assertThat( writeServer1.exitStatus(), equalTo( 0 ) ); + assertThat( writeServer2.exitStatus(), equalTo( 0 ) ); + } + @Test public void shouldRememberEndpoints() throws IOException, InterruptedException, StubServer.ForceKilled { @@ -319,7 +490,8 @@ public void shouldRememberEndpoints() throws IOException, InterruptedException, assertThat( driver.readServers(), containsInAnyOrder( address( 9005 ), address( 9006 ) ) ); assertThat( driver.writeServers(), containsInAnyOrder( address( 9007 ), address( 9008 ) ) ); - assertThat( driver.routingServers(), containsInAnyOrder( address( 9001 ), address( 9002 ), address( 9003 ) ) ); + assertThat( driver.routingServers(), + containsInAnyOrder( address( 9001 ), address( 9002 ), address( 9003 ) ) ); } // Finally assertThat( server.exitStatus(), equalTo( 0 ) ); @@ -358,7 +530,8 @@ public void shouldForgetEndpointsOnFailure() throws IOException, InterruptedExce } @Test - public void shouldForgetEndpointsOnFailedSessionAcquisition() throws IOException, InterruptedException, StubServer.ForceKilled + public void shouldForgetEndpointsOnFailedSessionAcquisition() + throws IOException, InterruptedException, StubServer.ForceKilled { // Given StubServer server = StubServer.start( "acquire_endpoints.script", 9001 ); @@ -406,8 +579,8 @@ public void shouldRediscoverIfNecessaryOnSessionAcquisition() //since we have no write nor read servers we must rediscover Session session = driver.session( AccessMode.READ ); - assertThat( driver.routingServers(), containsInAnyOrder(address( 9002 ), - address( 9003 ), address( 9004 ) ) ); + assertThat( driver.routingServers(), containsInAnyOrder( address( 9002 ), + address( 9003 ), address( 9004 ) ) ); //server told os to forget 9001 assertFalse( driver.connectionPool().hasAddress( address( 9001 ) ) ); session.close(); @@ -502,7 +675,7 @@ public void shouldHandleLeaderSwitchWhenWriting() boolean failed = false; try ( Session session = driver.session( AccessMode.WRITE ) ) { - assertThat( driver.writeServers(), hasItem(address( 9007 ) ) ); + assertThat( driver.writeServers(), hasItem( address( 9007 ) ) ); assertThat( driver.writeServers(), hasItem( address( 9008 ) ) ); session.run( "CREATE ()" ).consume(); } @@ -521,6 +694,103 @@ public void shouldHandleLeaderSwitchWhenWriting() assertThat( server.exitStatus(), equalTo( 0 ) ); } + @Test + public void shouldHandleLeaderSwitchWhenWritingWithoutConsuming() + throws IOException, InterruptedException, StubServer.ForceKilled + { + // Given + StubServer server = StubServer.start( "acquire_endpoints.script", 9001 ); + + //START a write server that doesn't accept writes + StubServer.start( "not_able_to_write_server.script", 9007 ); + URI uri = URI.create( "bolt+routing://127.0.0.1:9001" ); + RoutingDriver driver = (RoutingDriver) GraphDatabase.driver( uri, config ); + boolean failed = false; + try ( Session session = driver.session( AccessMode.WRITE ) ) + { + assertThat( driver.writeServers(), hasItem( address( 9007 ) ) ); + assertThat( driver.writeServers(), hasItem( address( 9008 ) ) ); + session.run( "CREATE ()" ); + } + catch ( SessionExpiredException e ) + { + failed = true; + assertThat( e.getMessage(), equalTo( "Server at 127.0.0.1:9007 no longer accepts writes" ) ); + } + assertTrue( failed ); + assertThat( driver.writeServers(), not( hasItem( address( 9007 ) ) ) ); + assertThat( driver.writeServers(), hasItem( address( 9008 ) ) ); + assertTrue( driver.connectionPool().hasAddress( address( 9007 ) ) ); + + driver.close(); + // Finally + assertThat( server.exitStatus(), equalTo( 0 ) ); + } + + @Test + public void shouldHandleLeaderSwitchWhenWritingInTransaction() + throws IOException, InterruptedException, StubServer.ForceKilled + { + // Given + StubServer server = StubServer.start( "acquire_endpoints.script", 9001 ); + + //START a write server that doesn't accept writes + StubServer.start( "not_able_to_write_server.script", 9007 ); + URI uri = URI.create( "bolt+routing://127.0.0.1:9001" ); + RoutingDriver driver = (RoutingDriver) GraphDatabase.driver( uri, config ); + boolean failed = false; + try ( Session session = driver.session( AccessMode.WRITE ); + Transaction tx = session.beginTransaction() ) + { + tx.run( "CREATE ()" ).consume(); + } + catch ( SessionExpiredException e ) + { + failed = true; + assertThat( e.getMessage(), equalTo( "Server at 127.0.0.1:9007 no longer accepts writes" ) ); + } + assertTrue( failed ); + assertThat( driver.writeServers(), not( hasItem( address( 9007 ) ) ) ); + assertThat( driver.writeServers(), hasItem( address( 9008 ) ) ); + assertTrue( driver.connectionPool().hasAddress( address( 9007 ) ) ); + + driver.close(); + // Finally + assertThat( server.exitStatus(), equalTo( 0 ) ); + } + + @Test + public void shouldHandleLeaderSwitchWhenWritingInTransactionWithoutConsuming() + throws IOException, InterruptedException, StubServer.ForceKilled + { + // Given + StubServer server = StubServer.start( "acquire_endpoints.script", 9001 ); + + //START a write server that doesn't accept writes + StubServer.start( "not_able_to_write_server.script", 9007 ); + URI uri = URI.create( "bolt+routing://127.0.0.1:9001" ); + RoutingDriver driver = (RoutingDriver) GraphDatabase.driver( uri, config ); + boolean failed = false; + try ( Session session = driver.session( AccessMode.WRITE ); + Transaction tx = session.beginTransaction() ) + { + tx.run( "CREATE ()" ); + } + catch ( SessionExpiredException e ) + { + failed = true; + assertThat( e.getMessage(), equalTo( "Server at 127.0.0.1:9007 no longer accepts writes" ) ); + } + assertTrue( failed ); + assertThat( driver.writeServers(), not( hasItem( address( 9007 ) ) ) ); + assertThat( driver.writeServers(), hasItem( address( 9008 ) ) ); + assertTrue( driver.connectionPool().hasAddress( address( 9007 ) ) ); + + driver.close(); + // Finally + assertThat( server.exitStatus(), equalTo( 0 ) ); + } + @Test public void shouldRediscoverOnExpiry() throws IOException, InterruptedException, StubServer.ForceKilled { @@ -531,15 +801,15 @@ public void shouldRediscoverOnExpiry() throws IOException, InterruptedException, StubServer readServer = StubServer.start( "empty.script", 9005 ); URI uri = URI.create( "bolt+routing://127.0.0.1:9001" ); RoutingDriver driver = (RoutingDriver) GraphDatabase.driver( uri, config ); - assertThat(driver.routingServers(), contains(address( 9001 ))); - assertThat(driver.readServers(), contains(address( 9002 ))); - assertThat(driver.writeServers(), contains(address( 9003 ))); + assertThat( driver.routingServers(), contains( address( 9001 ) ) ); + assertThat( driver.readServers(), contains( address( 9002 ) ) ); + assertThat( driver.writeServers(), contains( address( 9003 ) ) ); //On acquisition we should update our view Session session = driver.session( AccessMode.READ ); - assertThat(driver.routingServers(), contains(address( 9004 ))); - assertThat(driver.readServers(), contains(address( 9005 ))); - assertThat(driver.writeServers(), contains(address( 9006 ))); + assertThat( driver.routingServers(), contains( address( 9004 ) ) ); + assertThat( driver.readServers(), contains( address( 9005 ) ) ); + assertThat( driver.writeServers(), contains( address( 9006 ) ) ); session.close(); driver.close(); // Finally @@ -572,28 +842,28 @@ public void shouldNotPutBackPurgedConnection() throws IOException, InterruptedEx writeSession.close(); fail(); } - catch (SessionExpiredException e) + catch ( SessionExpiredException e ) { //ignore } //We now lost all write servers - assertThat(driver.writeServers(), hasSize( 0 )); + assertThat( driver.writeServers(), hasSize( 0 ) ); //reacquiring will trow out the current read server at 9002 writeSession = driver.session( AccessMode.WRITE ); - assertThat(driver.routingServers(), contains(address( 9004 ))); - assertThat(driver.readServers(), contains(address( 9005 ))); - assertThat(driver.writeServers(), contains(address( 9006 ))); - assertFalse(driver.connectionPool().hasAddress(address( 9002 ) )); + assertThat( driver.routingServers(), contains( address( 9004 ) ) ); + assertThat( driver.readServers(), contains( address( 9005 ) ) ); + assertThat( driver.writeServers(), contains( address( 9006 ) ) ); + assertFalse( driver.connectionPool().hasAddress( address( 9002 ) ) ); // now we close the read session and the connection should not be put // back to the pool - Connection connection = ((RoutingNetworkSession) readSession).connection; + Connection connection = ((NetworkSession) ((RoutingNetworkSession) readSession).delegate).connection; assertTrue( connection.isOpen() ); readSession.close(); assertFalse( connection.isOpen() ); - assertFalse(driver.connectionPool().hasAddress(address( 9002 ) )); + assertFalse( driver.connectionPool().hasAddress( address( 9002 ) ) ); writeSession.close(); driver.close(); diff --git a/driver/src/test/java/org/neo4j/driver/internal/RoutingDriverTest.java b/driver/src/test/java/org/neo4j/driver/internal/RoutingDriverTest.java index 2fd896f815..f551a8dbfd 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/RoutingDriverTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/RoutingDriverTest.java @@ -22,6 +22,8 @@ import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; import java.util.Collections; import java.util.HashMap; @@ -47,10 +49,12 @@ import static java.util.Arrays.asList; import static java.util.Collections.singletonList; +import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.CoreMatchers.hasItem; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.containsInAnyOrder; import static org.hamcrest.core.IsNot.not; +import static org.mockito.Matchers.any; import static org.mockito.Matchers.anyString; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; @@ -243,6 +247,57 @@ public void shouldNotRediscoverWheNoTimeout() assertThat( routingDriver.writeServers(), containsInAnyOrder( boltAddress( "localhost", 3333 ) ) ); } + @Test + public void shouldRoundRobinAmongReadServers() + { + // Given + final Session session = mock( Session.class ); + when( session.run( GET_SERVERS ) ).thenReturn( + getServers( asList( "localhost:1111", "localhost:1112" ), + asList( "localhost:2222", "localhost:2223", "localhost:2224" ), + singletonList( "localhost:3333" ) ) ); + + // When + RoutingDriver routingDriver = forSession( session ); + RoutingNetworkSession read1 = (RoutingNetworkSession) routingDriver.session( AccessMode.READ ); + RoutingNetworkSession read2 = (RoutingNetworkSession) routingDriver.session( AccessMode.READ ); + RoutingNetworkSession read3 = (RoutingNetworkSession) routingDriver.session( AccessMode.READ ); + RoutingNetworkSession read4 = (RoutingNetworkSession) routingDriver.session( AccessMode.READ ); + + + // Then + assertThat(read1.address(), equalTo(boltAddress( "localhost", 2222 ))); + assertThat(read2.address(), equalTo(boltAddress( "localhost", 2223 ))); + assertThat(read3.address(), equalTo(boltAddress( "localhost", 2224 ))); + assertThat(read4.address(), equalTo(boltAddress( "localhost", 2222 ))); + + } + + @Test + public void shouldRoundRobinAmongWriteServers() + { + // Given + final Session session = mock( Session.class ); + when( session.run( GET_SERVERS ) ).thenReturn( + getServers( asList( "localhost:1111", "localhost:1112" ), + singletonList( "localhost:3333" ), asList( "localhost:2222", "localhost:2223", "localhost:2224" ) ) ); + + // When + RoutingDriver routingDriver = forSession( session ); + RoutingNetworkSession write1 = (RoutingNetworkSession) routingDriver.session( AccessMode.WRITE ); + RoutingNetworkSession write2 = (RoutingNetworkSession) routingDriver.session( AccessMode.WRITE ); + RoutingNetworkSession write3 = (RoutingNetworkSession) routingDriver.session( AccessMode.WRITE ); + RoutingNetworkSession write4 = (RoutingNetworkSession) routingDriver.session( AccessMode.WRITE ); + + + // Then + assertThat(write1.address(), equalTo(boltAddress( "localhost", 2222 ))); + assertThat(write2.address(), equalTo(boltAddress( "localhost", 2223 ))); + assertThat(write3.address(), equalTo(boltAddress( "localhost", 2224 ))); + assertThat(write4.address(), equalTo(boltAddress( "localhost", 2222 ))); + + } + private RoutingDriver forSession( final Session session ) { return forSession( session, Clock.SYSTEM ); @@ -353,9 +408,22 @@ private Map serverInfo( String role, List addresses ) private ConnectionPool pool() { ConnectionPool pool = mock( ConnectionPool.class ); - Connection connection = mock( Connection.class ); - when( connection.isOpen() ).thenReturn( true ); - when( pool.acquire( SEED ) ).thenReturn( connection ); + + + when( pool.acquire( any(BoltServerAddress.class) ) ).thenAnswer( new Answer() + { + @Override + public Connection answer( InvocationOnMock invocationOnMock ) throws Throwable + { + BoltServerAddress address = (BoltServerAddress) invocationOnMock.getArguments()[0]; + Connection connection = mock( Connection.class ); + when( connection.isOpen() ).thenReturn( true ); + when(connection.address()).thenReturn( address ); + + return connection; + } + } ); + return pool; } diff --git a/driver/src/test/java/org/neo4j/driver/internal/RoutingNetworkSessionTest.java b/driver/src/test/java/org/neo4j/driver/internal/RoutingNetworkSessionTest.java index b988ea7b3f..888f721ff0 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/RoutingNetworkSessionTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/RoutingNetworkSessionTest.java @@ -27,7 +27,7 @@ import org.neo4j.driver.internal.spi.Collector; import org.neo4j.driver.internal.spi.Connection; import org.neo4j.driver.v1.AccessMode; -import org.neo4j.driver.v1.Logger; +import org.neo4j.driver.v1.Session; import org.neo4j.driver.v1.exceptions.ClientException; import org.neo4j.driver.v1.exceptions.ConnectionFailureException; import org.neo4j.driver.v1.exceptions.SessionExpiredException; @@ -68,7 +68,8 @@ public void shouldHandleConnectionFailures() when( connection ).run( anyString(), any( Map.class ), any( Collector.class ) ); RoutingNetworkSession result = - new RoutingNetworkSession( AccessMode.WRITE, connection, onError ); + new RoutingNetworkSession( new NetworkSession( connection ), AccessMode.WRITE, connection.address(), + onError ); // When try @@ -94,7 +95,8 @@ public void shouldHandleWriteFailuresInWriteAccessMode() doThrow( new ClientException( "Neo.ClientError.Cluster.NotALeader", "oh no!" ) ). when( connection ).run( anyString(), any( Map.class ), any( Collector.class ) ); RoutingNetworkSession session = - new RoutingNetworkSession( AccessMode.WRITE, connection, onError ); + new RoutingNetworkSession( new NetworkSession(connection), AccessMode.WRITE, connection.address(), + onError ); // When try @@ -120,7 +122,7 @@ public void shouldHandleWriteFailuresInReadAccessMode() doThrow( new ClientException( "Neo.ClientError.Cluster.NotALeader", "oh no!" ) ). when( connection ).run( anyString(), any( Map.class ), any( Collector.class ) ); RoutingNetworkSession session = - new RoutingNetworkSession( AccessMode.READ, connection, onError ); + new RoutingNetworkSession( new NetworkSession( connection ), AccessMode.READ, connection.address(), onError ); // When try @@ -144,7 +146,7 @@ public void shouldRethrowNonWriteFailures() doThrow( toBeThrown ). when( connection ).run( anyString(), any( Map.class ), any( Collector.class ) ); RoutingNetworkSession session = - new RoutingNetworkSession( AccessMode.WRITE, connection, onError ); + new RoutingNetworkSession( new NetworkSession( connection ), AccessMode.WRITE, connection.address(), onError ); // When try @@ -169,7 +171,8 @@ public void shouldHandleConnectionFailuresOnClose() when( connection ).sync(); RoutingNetworkSession session = - new RoutingNetworkSession( AccessMode.WRITE, connection, onError ); + new RoutingNetworkSession( new NetworkSession( connection ), AccessMode.WRITE, connection.address(), + onError ); // When try @@ -194,7 +197,7 @@ public void shouldHandleWriteFailuresOnClose() doThrow( new ClientException( "Neo.ClientError.Cluster.NotALeader", "oh no!" ) ).when( connection ).sync(); RoutingNetworkSession session = - new RoutingNetworkSession( AccessMode.WRITE, connection, onError ); + new RoutingNetworkSession( new NetworkSession( connection ), AccessMode.WRITE, connection.address(), onError ); // When try @@ -211,4 +214,69 @@ public void shouldHandleWriteFailuresOnClose() verify( onError ).onWriteFailure( LOCALHOST ); verifyNoMoreInteractions( onError ); } + + @Test + public void shouldDelegateLastBookmark() + { + // Given + Session inner = mock( Session.class ); + RoutingNetworkSession session = + new RoutingNetworkSession( inner, AccessMode.WRITE, connection.address(), onError ); + + + // When + session.lastBookmark(); + + // Then + verify( inner ).lastBookmark(); + } + + @Test + public void shouldDelegateReset() + { + // Given + Session inner = mock( Session.class ); + RoutingNetworkSession session = + new RoutingNetworkSession( inner, AccessMode.WRITE, connection.address(), onError ); + + + // When + session.reset(); + + // Then + verify( inner ).reset(); + } + + @Test + public void shouldDelegateIsOpen() + { + // Given + Session inner = mock( Session.class ); + RoutingNetworkSession session = + new RoutingNetworkSession( inner, AccessMode.WRITE, connection.address(), onError ); + + + // When + session.isOpen(); + + // Then + verify( inner ).isOpen(); + } + + @Test + public void shouldDelegateServer() + { + // Given + Session inner = mock( Session.class ); + RoutingNetworkSession session = + new RoutingNetworkSession( inner, AccessMode.WRITE, connection.address(), onError ); + + + // When + session.server(); + + // Then + verify( inner ).server(); + } + } diff --git a/driver/src/test/java/org/neo4j/driver/internal/RoutingTransactionTest.java b/driver/src/test/java/org/neo4j/driver/internal/RoutingTransactionTest.java new file mode 100644 index 0000000000..1b4f9a0e32 --- /dev/null +++ b/driver/src/test/java/org/neo4j/driver/internal/RoutingTransactionTest.java @@ -0,0 +1,312 @@ +/** + * Copyright (c) 2002-2016 "Neo Technology," + * Network Engine for Objects in Lund AB [http://neotechnology.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.internal; + +import org.junit.Before; +import org.junit.Test; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; + +import java.util.Map; + +import org.neo4j.driver.internal.net.BoltServerAddress; +import org.neo4j.driver.internal.spi.Collector; +import org.neo4j.driver.internal.spi.Connection; +import org.neo4j.driver.v1.AccessMode; +import org.neo4j.driver.v1.Transaction; +import org.neo4j.driver.v1.exceptions.ClientException; +import org.neo4j.driver.v1.exceptions.ConnectionFailureException; +import org.neo4j.driver.v1.exceptions.SessionExpiredException; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.is; +import static org.junit.Assert.fail; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyString; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.verifyZeroInteractions; +import static org.mockito.Mockito.when; + +public class RoutingTransactionTest +{ + private static final BoltServerAddress LOCALHOST = new BoltServerAddress( "localhost", 7687 ); + private Connection connection; + private RoutingErrorHandler onError; + private Runnable cleanup; + + private Answer throwingAnswer( final Throwable throwable ) + { + return new Answer() + { + @Override + public Void answer( InvocationOnMock invocationOnMock ) throws Throwable + { + String statement = (String) invocationOnMock.getArguments()[0]; + if ( statement.equals( "BEGIN" ) ) + { + return null; + } + else + { + throw throwable; + } + } + }; + } + + @Before + public void setUp() + { + connection = mock( Connection.class ); + when( connection.address() ).thenReturn( LOCALHOST ); + when( connection.isOpen() ).thenReturn( true ); + onError = mock( RoutingErrorHandler.class ); + cleanup = mock( Runnable.class ); + } + + @SuppressWarnings( "unchecked" ) + @Test + public void shouldHandleConnectionFailures() + { + // Given + + doAnswer( throwingAnswer( new ConnectionFailureException( "oh no" ) ) ) + .when( connection ).run( anyString(), any( Map.class ), any( Collector.class ) ); + + RoutingTransaction tx = + new RoutingTransaction( new ExplicitTransaction( connection, cleanup ), AccessMode.READ, LOCALHOST, + onError ); + + // When + try + { + tx.run( "CREATE ()" ); + fail(); + } + catch ( SessionExpiredException e ) + { + //ignore + } + + // Then + verify( onError ).onConnectionFailure( LOCALHOST ); + verifyNoMoreInteractions( onError ); + } + + @SuppressWarnings( "unchecked" ) + @Test + public void shouldHandleWriteFailuresInWriteAccessMode() + { + // Given + doAnswer( throwingAnswer( new ClientException( "Neo.ClientError.Cluster.NotALeader", "oh no!" ) ) ) + .when( connection ).run( anyString(), any( Map.class ), any( Collector.class ) ); + + RoutingTransaction tx = + new RoutingTransaction( new ExplicitTransaction( connection, cleanup ), AccessMode.WRITE, + connection.address(), onError ); + + // When + try + { + tx.run( "CREATE ()" ); + fail(); + } + catch ( SessionExpiredException e ) + { + //ignore + } + + // Then + verify( onError ).onWriteFailure( LOCALHOST ); + verifyNoMoreInteractions( onError ); + } + + @SuppressWarnings( "unchecked" ) + @Test + public void shouldHandleWriteFailuresInReadAccessMode() + { + // Given + doAnswer( throwingAnswer( new ClientException( "Neo.ClientError.Cluster.NotALeader", "oh no!" ) ) ) + .when( connection ).run( anyString(), any( Map.class ), any( Collector.class ) ); + RoutingTransaction tx = + new RoutingTransaction( new ExplicitTransaction( connection, cleanup ), AccessMode.READ, + connection.address(), onError ); + + // When + try + { + tx.run( "CREATE ()" ); + fail(); + } + catch ( ClientException e ) + { + //ignore + } + verifyNoMoreInteractions( onError ); + } + + @SuppressWarnings( "unchecked" ) + @Test + public void shouldRethrowNonWriteFailures() + { + // Given + ClientException toBeThrown = new ClientException( "code", "oh no!" ); + doAnswer( throwingAnswer( toBeThrown ) ) + .when( connection ).run( anyString(), any( Map.class ), any( Collector.class ) ); + RoutingTransaction tx = + new RoutingTransaction( new ExplicitTransaction( connection, cleanup ), AccessMode.WRITE, + connection.address(), onError ); + + // When + try + { + tx.run( "CREATE ()" ); + fail(); + } + catch ( ClientException e ) + { + assertThat( e, is( toBeThrown ) ); + } + + // Then + verifyZeroInteractions( onError ); + } + + @Test + public void shouldHandleConnectionFailuresOnClose() + { + // Given + doThrow( new ConnectionFailureException( "oh no" ) ). + when( connection ).sync(); + + RoutingTransaction tx = + new RoutingTransaction( new ExplicitTransaction( connection, cleanup ), AccessMode.WRITE, + connection.address(), onError ); + + // When + try + { + tx.close(); + fail(); + } + catch ( SessionExpiredException e ) + { + //ignore + } + + // Then + verify( onError ).onConnectionFailure( LOCALHOST ); + verifyNoMoreInteractions( onError ); + } + + @Test + public void shouldHandleWriteFailuresOnClose() + { + // Given + doThrow( new ClientException( "Neo.ClientError.Cluster.NotALeader", "oh no!" ) ).when( connection ).sync(); + + RoutingTransaction tx = + new RoutingTransaction( new ExplicitTransaction( connection, cleanup ), AccessMode.WRITE, + connection.address(), onError ); + + // When + try + { + tx.close(); + fail(); + } + catch ( SessionExpiredException e ) + { + //ignore + } + + // Then + verify( onError ).onWriteFailure( LOCALHOST ); + verifyNoMoreInteractions( onError ); + } + + + @Test + public void shouldDelegateSuccess() + { + // Given + Transaction inner = mock( Transaction.class ); + RoutingTransaction tx = + new RoutingTransaction(inner, AccessMode.WRITE, + connection.address(), onError ); + + // When + tx.success(); + + // Then + verify( inner ).success(); + } + + @Test + public void shouldDelegateFailure() + { + // Given + Transaction inner = mock( Transaction.class ); + RoutingTransaction tx = + new RoutingTransaction(inner, AccessMode.WRITE, + connection.address(), onError ); + + // When + tx.failure(); + + // Then + verify( inner ).failure(); + } + + @Test + public void shouldDelegateIsOpen() + { + // Given + Transaction inner = mock( Transaction.class ); + RoutingTransaction tx = + new RoutingTransaction(inner, AccessMode.WRITE, + connection.address(), onError ); + + // When + tx.isOpen(); + + // Then + verify( inner ).isOpen(); + } + + @Test + public void shouldDelegateTypesystem() + { + // Given + Transaction inner = mock( Transaction.class ); + RoutingTransaction tx = + new RoutingTransaction(inner, AccessMode.WRITE, + connection.address(), onError ); + + // When + tx.typeSystem(); + + // Then + verify( inner ).typeSystem(); + } +} \ No newline at end of file diff --git a/driver/src/test/resources/not_able_to_write_server.script b/driver/src/test/resources/not_able_to_write_server.script index 6a97dd41d9..f14c391ee9 100644 --- a/driver/src/test/resources/not_able_to_write_server.script +++ b/driver/src/test/resources/not_able_to_write_server.script @@ -2,6 +2,9 @@ !: AUTO RESET !: AUTO RUN "RETURN 1 // JavaDriver poll to test connection" {} !: AUTO PULL_ALL +!: AUTO RUN "ROLLBACK" {} +!: AUTO RUN "BEGIN" {} +!: AUTO PULL_ALL C: RUN "CREATE ()" {} C: PULL_ALL diff --git a/driver/src/test/resources/read_server.script b/driver/src/test/resources/read_server.script index 17e2d3a22c..2cc3489dde 100644 --- a/driver/src/test/resources/read_server.script +++ b/driver/src/test/resources/read_server.script @@ -2,6 +2,8 @@ !: AUTO RESET !: AUTO RUN "RETURN 1 // JavaDriver poll to test connection" {} !: AUTO PULL_ALL +!: AUTO RUN "ROLLBACK" {} +!: AUTO RUN "BEGIN" {} C: RUN "MATCH (n) RETURN n.name" {} PULL_ALL diff --git a/driver/src/test/resources/write_server.script b/driver/src/test/resources/write_server.script index 27edbefa0c..cf8f1bad68 100644 --- a/driver/src/test/resources/write_server.script +++ b/driver/src/test/resources/write_server.script @@ -2,6 +2,8 @@ !: AUTO RESET !: AUTO RUN "RETURN 1 // JavaDriver poll to test connection" {} !: AUTO PULL_ALL +!: AUTO RUN "ROLLBACK" {} +!: AUTO RUN "BEGIN" {} C: RUN "CREATE (n {name:'Bob'})" {} PULL_ALL