diff --git a/driver/src/main/java/org/neo4j/driver/internal/DriverFactory.java b/driver/src/main/java/org/neo4j/driver/internal/DriverFactory.java new file mode 100644 index 0000000000..022a456947 --- /dev/null +++ b/driver/src/main/java/org/neo4j/driver/internal/DriverFactory.java @@ -0,0 +1,182 @@ +/* + * Copyright (c) 2002-2016 "Neo Technology," + * Network Engine for Objects in Lund AB [http://neotechnology.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.internal; + +import java.io.IOException; +import java.net.URI; +import java.security.GeneralSecurityException; + +import org.neo4j.driver.internal.cluster.RoutingSettings; +import org.neo4j.driver.internal.net.BoltServerAddress; +import org.neo4j.driver.internal.net.SocketConnector; +import org.neo4j.driver.internal.net.pooling.PoolSettings; +import org.neo4j.driver.internal.net.pooling.SocketConnectionPool; +import org.neo4j.driver.internal.security.SecurityPlan; +import org.neo4j.driver.internal.spi.ConnectionPool; +import org.neo4j.driver.internal.spi.Connector; +import org.neo4j.driver.internal.util.Clock; +import org.neo4j.driver.v1.AuthToken; +import org.neo4j.driver.v1.AuthTokens; +import org.neo4j.driver.v1.Config; +import org.neo4j.driver.v1.Driver; +import org.neo4j.driver.v1.Logger; +import org.neo4j.driver.v1.exceptions.ClientException; + +import static java.lang.String.format; +import static org.neo4j.driver.internal.security.SecurityPlan.insecure; +import static org.neo4j.driver.v1.Config.EncryptionLevel.REQUIRED; + +public class DriverFactory +{ + public final Driver newInstance( URI uri, AuthToken authToken, RoutingSettings routingSettings, Config config ) + { + BoltServerAddress address = BoltServerAddress.from( uri ); + SecurityPlan securityPlan = createSecurityPlan( address, config ); + ConnectionPool connectionPool = createConnectionPool( authToken, securityPlan, config ); + + try + { + return createDriver( address, uri.getScheme(), connectionPool, config, routingSettings, securityPlan ); + } + catch ( Throwable driverError ) + { + // we need to close the connection pool if driver creation threw exception + try + { + connectionPool.close(); + } + catch ( Throwable closeError ) + { + driverError.addSuppressed( closeError ); + } + throw driverError; + } + } + + private Driver createDriver( BoltServerAddress address, String scheme, ConnectionPool connectionPool, + Config config, RoutingSettings routingSettings, SecurityPlan securityPlan ) + { + switch ( scheme.toLowerCase() ) + { + case "bolt": + return createDirectDriver( address, connectionPool, config, securityPlan ); + case "bolt+routing": + return createRoutingDriver( address, connectionPool, config, routingSettings, securityPlan ); + default: + throw new ClientException( format( "Unsupported URI scheme: %s", scheme ) ); + } + } + + /** + * Creates new {@link DirectDriver}. + *

+ * This method is package-private only for testing + */ + DirectDriver createDirectDriver( BoltServerAddress address, ConnectionPool connectionPool, Config config, + SecurityPlan securityPlan ) + { + return new DirectDriver( address, connectionPool, securityPlan, config.logging() ); + } + + /** + * Creates new {@link RoutingDriver}. + *

+ * This method is package-private only for testing + */ + RoutingDriver createRoutingDriver( BoltServerAddress address, ConnectionPool connectionPool, + Config config, RoutingSettings routingSettings, SecurityPlan securityPlan ) + { + return new RoutingDriver( routingSettings, address, connectionPool, securityPlan, Clock.SYSTEM, + config.logging() ); + } + + /** + * Creates new {@link ConnectionPool}. + *

+ * This method is package-private only for testing + */ + ConnectionPool createConnectionPool( AuthToken authToken, SecurityPlan securityPlan, Config config ) + { + authToken = authToken == null ? AuthTokens.none() : authToken; + + ConnectionSettings connectionSettings = new ConnectionSettings( authToken ); + PoolSettings poolSettings = new PoolSettings( config.maxIdleConnectionPoolSize() ); + Connector connector = new SocketConnector( connectionSettings, securityPlan, config.logging() ); + + return new SocketConnectionPool( poolSettings, connector, Clock.SYSTEM, config.logging() ); + } + + private static SecurityPlan createSecurityPlan( BoltServerAddress address, Config config ) + { + try + { + return createSecurityPlanImpl( address, config ); + } + catch ( GeneralSecurityException | IOException ex ) + { + throw new ClientException( "Unable to establish SSL parameters", ex ); + } + } + + /* + * Establish a complete SecurityPlan based on the details provided for + * driver construction. + */ + private static SecurityPlan createSecurityPlanImpl( BoltServerAddress address, Config config ) + throws GeneralSecurityException, IOException + { + Config.EncryptionLevel encryptionLevel = config.encryptionLevel(); + boolean requiresEncryption = encryptionLevel.equals( REQUIRED ); + + if ( requiresEncryption ) + { + Logger logger = config.logging().getLog( "session" ); + switch ( config.trustStrategy().strategy() ) + { + + // DEPRECATED CASES // + case TRUST_ON_FIRST_USE: + logger.warn( + "Option `TRUST_ON_FIRST_USE` has been deprecated and will be removed in a future " + + "version of the driver. Please switch to use `TRUST_ALL_CERTIFICATES` instead." ); + return SecurityPlan.forTrustOnFirstUse( config.trustStrategy().certFile(), address, logger ); + case TRUST_SIGNED_CERTIFICATES: + logger.warn( + "Option `TRUST_SIGNED_CERTIFICATE` has been deprecated and will be removed in a future " + + "version of the driver. Please switch to use `TRUST_CUSTOM_CA_SIGNED_CERTIFICATES` instead." ); + // intentional fallthrough + // END OF DEPRECATED CASES // + + case TRUST_CUSTOM_CA_SIGNED_CERTIFICATES: + return SecurityPlan.forCustomCASignedCertificates( config.trustStrategy().certFile() ); + case TRUST_SYSTEM_CA_SIGNED_CERTIFICATES: + return SecurityPlan.forSystemCASignedCertificates(); + case TRUST_ALL_CERTIFICATES: + return SecurityPlan.forAllCertificates(); + default: + throw new ClientException( + "Unknown TLS authentication strategy: " + config.trustStrategy().strategy().name() ); + } + } + else + { + return insecure(); + } + } +} diff --git a/driver/src/main/java/org/neo4j/driver/internal/connector/socket/SocketUtils.java b/driver/src/main/java/org/neo4j/driver/internal/connector/socket/SocketUtils.java deleted file mode 100644 index 0f6c1e75fe..0000000000 --- a/driver/src/main/java/org/neo4j/driver/internal/connector/socket/SocketUtils.java +++ /dev/null @@ -1,83 +0,0 @@ -/* - * Copyright (c) 2002-2016 "Neo Technology," - * Network Engine for Objects in Lund AB [http://neotechnology.com] - * - * This file is part of Neo4j. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.neo4j.driver.internal.connector.socket; - -import java.io.IOException; -import java.nio.ByteBuffer; -import java.nio.channels.ByteChannel; - -import org.neo4j.driver.internal.util.BytePrinter; -import org.neo4j.driver.v1.exceptions.ClientException; - -/** - * Utility class for common operations. - */ -public final class SocketUtils -{ - private SocketUtils() - { - throw new UnsupportedOperationException( "Do not instantiate" ); - } - - public static void blockingRead(ByteChannel channel, ByteBuffer buf) throws IOException - { - while(buf.hasRemaining()) - { - if (channel.read( buf ) < 0) - { - try - { - channel.close(); - } - catch ( IOException e ) - { - // best effort - } - String bufStr = BytePrinter.hex( buf ).trim(); - throw new ClientException( String.format( - "Connection terminated while receiving data. This can happen due to network " + - "instabilities, or due to restarts of the database. Expected %s bytes, received %s.", - buf.limit(), bufStr.isEmpty() ? "none" : bufStr ) ); - } - } - } - - public static void blockingWrite(ByteChannel channel, ByteBuffer buf) throws IOException - { - while(buf.hasRemaining()) - { - if (channel.write( buf ) < 0) - { - try - { - channel.close(); - } - catch ( IOException e ) - { - // best effort - } - String bufStr = BytePrinter.hex( buf ).trim(); - throw new ClientException( String.format( - "Connection terminated while sending data. This can happen due to network " + - "instabilities, or due to restarts of the database. Expected %s bytes, wrote %s.", - buf.limit(), bufStr.isEmpty() ? "none" :bufStr ) ); - } - } - } -} diff --git a/driver/src/main/java/org/neo4j/driver/internal/net/ConcurrencyGuardingConnection.java b/driver/src/main/java/org/neo4j/driver/internal/net/ConcurrencyGuardingConnection.java index fd5f8dab45..577360019b 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/net/ConcurrencyGuardingConnection.java +++ b/driver/src/main/java/org/neo4j/driver/internal/net/ConcurrencyGuardingConnection.java @@ -174,15 +174,9 @@ public void receiveOne() @Override public void close() { - try - { - markAsInUse(); - delegate.close(); - } - finally - { - markAsAvailable(); - } + // It is fine to call close concurrently with this connection being used somewhere else. + // This could happen when driver is closed while there still exist sessions that do some work. + delegate.close(); } @Override diff --git a/driver/src/main/java/org/neo4j/driver/internal/net/SocketConnector.java b/driver/src/main/java/org/neo4j/driver/internal/net/SocketConnector.java new file mode 100644 index 0000000000..979b1f1018 --- /dev/null +++ b/driver/src/main/java/org/neo4j/driver/internal/net/SocketConnector.java @@ -0,0 +1,92 @@ +/* + * Copyright (c) 2002-2016 "Neo Technology," + * Network Engine for Objects in Lund AB [http://neotechnology.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.internal.net; + +import java.util.Map; + +import org.neo4j.driver.internal.ConnectionSettings; +import org.neo4j.driver.internal.security.InternalAuthToken; +import org.neo4j.driver.internal.security.SecurityPlan; +import org.neo4j.driver.internal.spi.Connection; +import org.neo4j.driver.internal.spi.Connector; +import org.neo4j.driver.v1.AuthToken; +import org.neo4j.driver.v1.AuthTokens; +import org.neo4j.driver.v1.Logging; +import org.neo4j.driver.v1.Value; +import org.neo4j.driver.v1.exceptions.ClientException; + +public class SocketConnector implements Connector +{ + private final ConnectionSettings connectionSettings; + private final SecurityPlan securityPlan; + private final Logging logging; + + public SocketConnector( ConnectionSettings connectionSettings, SecurityPlan securityPlan, Logging logging ) + { + this.connectionSettings = connectionSettings; + this.securityPlan = securityPlan; + this.logging = logging; + } + + @Override + public final Connection connect( BoltServerAddress address ) + { + Connection connection = createConnection( address, securityPlan, logging ); + + // Because SocketConnection is not thread safe, wrap it in this guard + // to ensure concurrent access leads causes application errors + connection = new ConcurrencyGuardingConnection( connection ); + + try + { + connection.init( connectionSettings.userAgent(), tokenAsMap( connectionSettings.authToken() ) ); + } + catch ( Throwable initError ) + { + connection.close(); + throw initError; + } + + return connection; + } + + /** + * Create new connection. + *

+ * This method is package-private only for testing + */ + Connection createConnection( BoltServerAddress address, SecurityPlan securityPlan, Logging logging ) + { + return new SocketConnection( address, securityPlan, logging ); + } + + private static Map tokenAsMap( AuthToken token ) + { + if ( token instanceof InternalAuthToken ) + { + return ((InternalAuthToken) token).toMap(); + } + else + { + throw new ClientException( + "Unknown authentication token, `" + token + "`. Please use one of the supported " + + "tokens from `" + AuthTokens.class.getSimpleName() + "`." ); + } + } +} diff --git a/driver/src/main/java/org/neo4j/driver/internal/net/pooling/BlockingPooledConnectionQueue.java b/driver/src/main/java/org/neo4j/driver/internal/net/pooling/BlockingPooledConnectionQueue.java index a23e0106e1..156e9c9b04 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/net/pooling/BlockingPooledConnectionQueue.java +++ b/driver/src/main/java/org/neo4j/driver/internal/net/pooling/BlockingPooledConnectionQueue.java @@ -27,7 +27,10 @@ import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.atomic.AtomicBoolean; +import org.neo4j.driver.internal.net.BoltServerAddress; import org.neo4j.driver.internal.util.Supplier; +import org.neo4j.driver.v1.Logger; +import org.neo4j.driver.v1.Logging; /** * A blocking queue that also keeps track of connections that are acquired in order @@ -37,6 +40,7 @@ public class BlockingPooledConnectionQueue { /** The backing queue, keeps track of connections currently in queue */ private final BlockingQueue queue; + private final Logger logger; private final AtomicBoolean isTerminating = new AtomicBoolean( false ); @@ -44,9 +48,10 @@ public class BlockingPooledConnectionQueue private final Set acquiredConnections = Collections.newSetFromMap(new ConcurrentHashMap()); - public BlockingPooledConnectionQueue( int capacity ) + public BlockingPooledConnectionQueue( BoltServerAddress address, int capacity, Logging logging ) { this.queue = new LinkedBlockingQueue<>( capacity ); + this.logger = createLogger( address, logging ); } /** @@ -64,10 +69,10 @@ public boolean offer( PooledConnection pooledConnection ) pooledConnection.dispose(); } if (isTerminating.get()) { - PooledConnection poll = queue.poll(); - if (poll != null) + PooledConnection connection = queue.poll(); + if (connection != null) { - poll.dispose(); + connection.dispose(); } } return offer; @@ -81,19 +86,19 @@ public boolean offer( PooledConnection pooledConnection ) public PooledConnection acquire( Supplier supplier ) { - PooledConnection poll = queue.poll(); - if ( poll == null ) + PooledConnection connection = queue.poll(); + if ( connection == null ) { - poll = supplier.get(); + connection = supplier.get(); } - acquiredConnections.add( poll ); + acquiredConnections.add( connection ); if (isTerminating.get()) { - acquiredConnections.remove( poll ); - poll.dispose(); + acquiredConnections.remove( connection ); + connection.dispose(); throw new IllegalStateException( "Pool has been closed, cannot acquire new values." ); } - return poll; + return connection; } public List toList() @@ -119,24 +124,43 @@ public boolean contains( PooledConnection pooledConnection ) /** * Terminates all connections, both those that are currently in the queue as well * as those that have been acquired. + *

+ * This method does not throw runtime exceptions. All connection close failures are only logged. */ public void terminate() { - if (isTerminating.compareAndSet( false, true )) + if ( isTerminating.compareAndSet( false, true ) ) { while ( !queue.isEmpty() ) { - PooledConnection conn = queue.poll(); - if ( conn != null ) - { - //close the underlying connection without adding it back to the queue - conn.dispose(); - } + PooledConnection idleConnection = queue.poll(); + disposeSafely( idleConnection ); } - for ( PooledConnection pooledConnection : acquiredConnections ) + for ( PooledConnection acquiredConnection : acquiredConnections ) { - pooledConnection.dispose(); + disposeSafely( acquiredConnection ); } } } + + private void disposeSafely( PooledConnection connection ) + { + try + { + if ( connection != null ) + { + // close the underlying connection without adding it back to the queue + connection.dispose(); + } + } + catch ( Throwable disposeError ) + { + logger.error( "Error disposing connection", disposeError ); + } + } + + private static Logger createLogger( BoltServerAddress address, Logging logging ) + { + return logging.getLog( "connectionQueue[" + address + "]" ); + } } diff --git a/driver/src/main/java/org/neo4j/driver/internal/net/pooling/SocketConnectionPool.java b/driver/src/main/java/org/neo4j/driver/internal/net/pooling/SocketConnectionPool.java index f709e009b4..a30ef274d0 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/net/pooling/SocketConnectionPool.java +++ b/driver/src/main/java/org/neo4j/driver/internal/net/pooling/SocketConnectionPool.java @@ -18,27 +18,16 @@ */ package org.neo4j.driver.internal.net.pooling; -import java.util.List; -import java.util.Map; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicBoolean; -import org.neo4j.driver.internal.ConnectionSettings; import org.neo4j.driver.internal.net.BoltServerAddress; -import org.neo4j.driver.internal.net.ConcurrencyGuardingConnection; -import org.neo4j.driver.internal.net.SocketConnection; -import org.neo4j.driver.internal.security.InternalAuthToken; -import org.neo4j.driver.internal.security.SecurityPlan; import org.neo4j.driver.internal.spi.Connection; import org.neo4j.driver.internal.spi.ConnectionPool; +import org.neo4j.driver.internal.spi.Connector; import org.neo4j.driver.internal.util.Clock; import org.neo4j.driver.internal.util.Supplier; -import org.neo4j.driver.v1.AuthToken; -import org.neo4j.driver.v1.AuthTokens; import org.neo4j.driver.v1.Logging; -import org.neo4j.driver.v1.Value; -import org.neo4j.driver.v1.exceptions.ClientException; - -import static java.util.Collections.emptyList; /** * The pool is designed to buffer certain amount of free sessions into session pool. When closing a session, we first @@ -60,52 +49,26 @@ public class SocketConnectionPool implements ConnectionPool private final ConcurrentHashMap pools = new ConcurrentHashMap<>(); - private final Clock clock = Clock.SYSTEM; + private final AtomicBoolean closed = new AtomicBoolean(); - private final ConnectionSettings connectionSettings; - private final SecurityPlan securityPlan; private final PoolSettings poolSettings; + private final Connector connector; + private final Clock clock; private final Logging logging; - /** Shutdown flag */ - - public SocketConnectionPool( ConnectionSettings connectionSettings, SecurityPlan securityPlan, - PoolSettings poolSettings, Logging logging ) + public SocketConnectionPool( PoolSettings poolSettings, Connector connector, Clock clock, Logging logging ) { - this.connectionSettings = connectionSettings; - this.securityPlan = securityPlan; this.poolSettings = poolSettings; + this.connector = connector; + this.clock = clock; this.logging = logging; } - private Connection connect( BoltServerAddress address ) throws ClientException - { - Connection conn = new SocketConnection( address, securityPlan, logging ); - - // Because SocketConnection is not thread safe, wrap it in this guard - // to ensure concurrent access leads causes application errors - conn = new ConcurrencyGuardingConnection( conn ); - conn.init( connectionSettings.userAgent(), tokenAsMap( connectionSettings.authToken() ) ); - return conn; - } - - private static Map tokenAsMap( AuthToken token ) - { - if ( token instanceof InternalAuthToken ) - { - return ((InternalAuthToken) token).toMap(); - } - else - { - throw new ClientException( - "Unknown authentication token, `" + token + "`. Please use one of the supported " + - "tokens from `" + AuthTokens.class.getSimpleName() + "`." ); - } - } - @Override public Connection acquire( final BoltServerAddress address ) { + assertNotClosed(); + final BlockingPooledConnectionQueue connections = pool( address ); Supplier supplier = new Supplier() { @@ -116,10 +79,17 @@ public PooledConnection get() new PooledConnectionValidator( SocketConnectionPool.this ); PooledConnectionReleaseConsumer releaseConsumer = new PooledConnectionReleaseConsumer( connections, connectionValidator ); - return new PooledConnection( connect( address ), releaseConsumer, clock ); + return new PooledConnection( connector.connect( address ), releaseConsumer, clock ); } }; PooledConnection conn = connections.acquire( supplier ); + + if ( closed.get() ) + { + connections.terminate(); + throw poolClosedException(); + } + conn.updateTimestamp(); return conn; } @@ -129,7 +99,7 @@ private BlockingPooledConnectionQueue pool( BoltServerAddress address ) BlockingPooledConnectionQueue pool = pools.get( address ); if ( pool == null ) { - pool = new BlockingPooledConnectionQueue( poolSettings.maxIdleConnectionPoolSize() ); + pool = new BlockingPooledConnectionQueue( address, poolSettings.maxIdleConnectionPoolSize(), logging ); if ( pools.putIfAbsent( address, pool ) != null ) { @@ -144,12 +114,10 @@ private BlockingPooledConnectionQueue pool( BoltServerAddress address ) public void purge( BoltServerAddress address ) { BlockingPooledConnectionQueue connections = pools.remove( address ); - if ( connections == null ) + if ( connections != null ) { - return; + connections.terminate(); } - - connections.terminate(); } @Override @@ -161,28 +129,27 @@ public boolean hasAddress( BoltServerAddress address ) @Override public void close() { - for ( BlockingPooledConnectionQueue pool : pools.values() ) + if ( closed.compareAndSet( false, true ) ) { - pool.terminate(); - } + for ( BlockingPooledConnectionQueue pool : pools.values() ) + { + pool.terminate(); + } - pools.clear(); + pools.clear(); + } } - - //for testing - public List connectionsForAddress( BoltServerAddress address ) + private void assertNotClosed() { - BlockingPooledConnectionQueue pooledConnections = pools.get( address ); - if ( pooledConnections == null ) - { - return emptyList(); - } - else + if ( closed.get() ) { - return pooledConnections.toList(); + throw poolClosedException(); } } - + private static RuntimeException poolClosedException() + { + return new IllegalStateException( "Pool closed" ); + } } diff --git a/driver/src/main/java/org/neo4j/driver/internal/spi/Connector.java b/driver/src/main/java/org/neo4j/driver/internal/spi/Connector.java new file mode 100644 index 0000000000..b512203c56 --- /dev/null +++ b/driver/src/main/java/org/neo4j/driver/internal/spi/Connector.java @@ -0,0 +1,26 @@ +/* + * Copyright (c) 2002-2016 "Neo Technology," + * Network Engine for Objects in Lund AB [http://neotechnology.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.internal.spi; + +import org.neo4j.driver.internal.net.BoltServerAddress; + +public interface Connector +{ + Connection connect( BoltServerAddress address ); +} diff --git a/driver/src/main/java/org/neo4j/driver/v1/GraphDatabase.java b/driver/src/main/java/org/neo4j/driver/v1/GraphDatabase.java index 4c61e9d05a..3ab9e17b8e 100644 --- a/driver/src/main/java/org/neo4j/driver/v1/GraphDatabase.java +++ b/driver/src/main/java/org/neo4j/driver/v1/GraphDatabase.java @@ -18,27 +18,9 @@ */ package org.neo4j.driver.v1; -import java.io.IOException; import java.net.URI; -import java.security.GeneralSecurityException; -import org.neo4j.driver.internal.ConnectionSettings; -import org.neo4j.driver.internal.DirectDriver; -import org.neo4j.driver.internal.NetworkSession; -import org.neo4j.driver.internal.RoutingDriver; -import org.neo4j.driver.internal.net.BoltServerAddress; -import org.neo4j.driver.internal.net.pooling.PoolSettings; -import org.neo4j.driver.internal.net.pooling.SocketConnectionPool; -import org.neo4j.driver.internal.security.SecurityPlan; -import org.neo4j.driver.internal.spi.Connection; -import org.neo4j.driver.internal.spi.ConnectionPool; -import org.neo4j.driver.internal.util.Clock; -import org.neo4j.driver.v1.exceptions.ClientException; -import org.neo4j.driver.v1.util.Function; - -import static java.lang.String.format; -import static org.neo4j.driver.internal.security.SecurityPlan.insecure; -import static org.neo4j.driver.v1.Config.EncryptionLevel.REQUIRED; +import org.neo4j.driver.internal.DriverFactory; /** * Creates {@link Driver drivers}, optionally letting you {@link #driver(URI, Config)} to configure them. @@ -47,17 +29,6 @@ */ public class GraphDatabase { - - private static final Function - SESSION_PROVIDER = new Function() - { - @Override - public Session apply( Connection connection ) - { - return new NetworkSession( connection ); - } - }; - /** * Return a driver for a Neo4j instance with the default configuration settings * @@ -151,97 +122,9 @@ public static Driver driver( String uri, AuthToken authToken, Config config ) */ public static Driver driver( URI uri, AuthToken authToken, Config config ) { - // Break down the URI into its constituent parts - String scheme = uri.getScheme(); - BoltServerAddress address = BoltServerAddress.from( uri ); - - // Collate session parameters - ConnectionSettings connectionSettings = - new ConnectionSettings( authToken == null ? AuthTokens.none() : authToken ); - // Make sure we have some configuration to play with - if ( config == null ) - { - config = Config.defaultConfig(); - } - - // Construct security plan - SecurityPlan securityPlan; - try - { - securityPlan = createSecurityPlan( address, config ); - } - catch ( GeneralSecurityException | IOException ex ) - { - throw new ClientException( "Unable to establish SSL parameters", ex ); - } - - // Establish pool settings - PoolSettings poolSettings = new PoolSettings( config.maxIdleConnectionPoolSize() ); - - // And finally, construct the driver proper - ConnectionPool connectionPool = - new SocketConnectionPool( connectionSettings, securityPlan, poolSettings, config.logging() ); - switch ( scheme.toLowerCase() ) - { - case "bolt": - return new DirectDriver( address, connectionPool, securityPlan, config.logging() ); - case "bolt+routing": - return new RoutingDriver( - config.routingSettings(), - address, - connectionPool, - securityPlan, - Clock.SYSTEM, - config.logging() ); - default: - throw new ClientException( format( "Unsupported URI scheme: %s", scheme ) ); - } - } - - /* - * Establish a complete SecurityPlan based on the details provided for - * driver construction. - */ - private static SecurityPlan createSecurityPlan( BoltServerAddress address, Config config ) - throws GeneralSecurityException, IOException - { - Config.EncryptionLevel encryptionLevel = config.encryptionLevel(); - boolean requiresEncryption = encryptionLevel.equals( REQUIRED ); - - if ( requiresEncryption ) - { - Logger logger = config.logging().getLog( "session" ); - switch ( config.trustStrategy().strategy() ) - { - - // DEPRECATED CASES // - case TRUST_ON_FIRST_USE: - logger.warn( - "Option `TRUST_ON_FIRST_USE` has been deprecated and will be removed in a future " + - "version of the driver. Please switch to use `TRUST_ALL_CERTIFICATES` instead." ); - return SecurityPlan.forTrustOnFirstUse( config.trustStrategy().certFile(), address, logger ); - case TRUST_SIGNED_CERTIFICATES: - logger.warn( - "Option `TRUST_SIGNED_CERTIFICATE` has been deprecated and will be removed in a future " + - "version of the driver. Please switch to use `TRUST_CUSTOM_CA_SIGNED_CERTIFICATES` instead." ); - // intentional fallthrough - // END OF DEPRECATED CASES // + config = config == null ? Config.defaultConfig() : config; - case TRUST_CUSTOM_CA_SIGNED_CERTIFICATES: - return SecurityPlan.forCustomCASignedCertificates( config.trustStrategy().certFile() ); - case TRUST_SYSTEM_CA_SIGNED_CERTIFICATES: - return SecurityPlan.forSystemCASignedCertificates(); - case TRUST_ALL_CERTIFICATES: - return SecurityPlan.forAllCertificates(); - default: - throw new ClientException( - "Unknown TLS authentication strategy: " + config.trustStrategy().strategy().name() ); - } - } - else - { - return insecure(); - } + return new DriverFactory().newInstance( uri, authToken, config.routingSettings(), config ); } } diff --git a/driver/src/test/java/org/neo4j/driver/internal/DriverFactoryTest.java b/driver/src/test/java/org/neo4j/driver/internal/DriverFactoryTest.java new file mode 100644 index 0000000000..f96b9762f6 --- /dev/null +++ b/driver/src/test/java/org/neo4j/driver/internal/DriverFactoryTest.java @@ -0,0 +1,142 @@ +/* + * Copyright (c) 2002-2016 "Neo Technology," + * Network Engine for Objects in Lund AB [http://neotechnology.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.internal; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameter; +import org.junit.runners.Parameterized.Parameters; + +import java.net.URI; +import java.util.Arrays; +import java.util.List; + +import org.neo4j.driver.internal.cluster.RoutingSettings; +import org.neo4j.driver.internal.net.BoltServerAddress; +import org.neo4j.driver.internal.security.SecurityPlan; +import org.neo4j.driver.internal.spi.ConnectionPool; +import org.neo4j.driver.v1.AuthToken; +import org.neo4j.driver.v1.AuthTokens; +import org.neo4j.driver.v1.Config; + +import static org.hamcrest.Matchers.instanceOf; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.fail; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.neo4j.driver.v1.Config.defaultConfig; + +@RunWith( Parameterized.class ) +public class DriverFactoryTest +{ + @Parameter + public URI uri; + + @Parameters( name = "{0}" ) + public static List uris() + { + return Arrays.asList( + URI.create( "bolt://localhost:7687" ), + URI.create( "bolt+routing://localhost:7687" ) + ); + } + + @Test + public void connectionPoolClosedWhenDriverCreationFails() throws Exception + { + ConnectionPool connectionPool = mock( ConnectionPool.class ); + DriverFactory factory = new ThrowingDriverFactory( connectionPool ); + + try + { + factory.newInstance( uri, dummyAuthToken(), dummyRoutingSettings(), defaultConfig() ); + fail( "Exception expected" ); + } + catch ( Exception e ) + { + assertThat( e, instanceOf( UnsupportedOperationException.class ) ); + } + verify( connectionPool ).close(); + } + + @Test + public void connectionPoolCloseExceptionIsSupressedWhenDriverCreationFails() throws Exception + { + ConnectionPool connectionPool = mock( ConnectionPool.class ); + RuntimeException poolCloseError = new RuntimeException( "Pool close error" ); + doThrow( poolCloseError ).when( connectionPool ).close(); + + DriverFactory factory = new ThrowingDriverFactory( connectionPool ); + + try + { + factory.newInstance( uri, dummyAuthToken(), dummyRoutingSettings(), defaultConfig() ); + fail( "Exception expected" ); + } + catch ( Exception e ) + { + assertThat( e, instanceOf( UnsupportedOperationException.class ) ); + assertArrayEquals( new Throwable[]{poolCloseError}, e.getSuppressed() ); + } + verify( connectionPool ).close(); + } + + private static AuthToken dummyAuthToken() + { + return AuthTokens.basic( "neo4j", "neo4j" ); + } + + private static RoutingSettings dummyRoutingSettings() + { + return new RoutingSettings( 42, 42 ); + } + + private static class ThrowingDriverFactory extends DriverFactory + { + final ConnectionPool connectionPool; + + ThrowingDriverFactory( ConnectionPool connectionPool ) + { + this.connectionPool = connectionPool; + } + + @Override + DirectDriver createDirectDriver( BoltServerAddress address, ConnectionPool connectionPool, Config config, + SecurityPlan securityPlan ) + { + throw new UnsupportedOperationException( "Can't create direct driver" ); + } + + @Override + RoutingDriver createRoutingDriver( BoltServerAddress address, ConnectionPool connectionPool, Config config, + RoutingSettings routingSettings, SecurityPlan securityPlan ) + { + throw new UnsupportedOperationException( "Can't create routing driver" ); + } + + @Override + ConnectionPool createConnectionPool( AuthToken authToken, SecurityPlan securityPlan, Config config ) + { + return connectionPool; + } + } +} diff --git a/driver/src/test/java/org/neo4j/driver/internal/net/ConcurrencyGuardingConnectionTest.java b/driver/src/test/java/org/neo4j/driver/internal/net/ConcurrencyGuardingConnectionTest.java index 0e136fc98a..319f79de00 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/net/ConcurrencyGuardingConnectionTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/net/ConcurrencyGuardingConnectionTest.java @@ -28,14 +28,16 @@ import java.util.concurrent.atomic.AtomicReference; import org.neo4j.driver.internal.spi.Connection; -import org.neo4j.driver.v1.util.Function; import org.neo4j.driver.v1.exceptions.ClientException; +import org.neo4j.driver.v1.util.Function; import static java.util.Arrays.asList; import static junit.framework.TestCase.fail; import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.MatcherAssert.assertThat; +import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; @RunWith( Parameterized.class ) public class ConcurrencyGuardingConnectionTest @@ -44,17 +46,19 @@ public class ConcurrencyGuardingConnectionTest public Function operation; @Parameterized.Parameters - public static List params() + public static List> params() { return asList( - new Object[]{INIT}, - new Object[]{RUN}, - new Object[]{PULL_ALL}, - new Object[]{DISCARD_ALL}, - new Object[]{CLOSE}, - new Object[]{RECIEVE_ONE}, - new Object[]{FLUSH}, - new Object[]{SYNC}); + INIT, + RUN, + PULL_ALL, + DISCARD_ALL, + RECIEVE_ONE, + FLUSH, + SYNC, + RESET, + ACK_FAILURE + ); } @Test @@ -95,6 +99,32 @@ public Object answer( InvocationOnMock invocationOnMock ) throws Throwable "do that is to give each thread its own dedicated session.") ); } + @Test + public void shouldAllowConcurrentClose() + { + // Given + final AtomicReference connection = new AtomicReference<>(); + + Connection delegate = mock( Connection.class, new Answer() + { + @Override + public Void answer( InvocationOnMock invocation ) throws Throwable + { + connection.get().close(); + return null; + } + } ); + doNothing().when( delegate ).close(); + + connection.set( new ConcurrencyGuardingConnection( delegate ) ); + + // When + operation.apply( connection.get() ); + + // Then + verify( delegate ).close(); + } + public static final Function INIT = new Function() { @Override @@ -135,22 +165,32 @@ public Void apply( Connection connection ) } }; - public static final Function RECIEVE_ONE = new Function() + public static final Function RESET = new Function() { @Override public Void apply( Connection connection ) { - connection.receiveOne(); + connection.reset(); return null; } }; - public static final Function CLOSE = new Function() + public static final Function ACK_FAILURE = new Function() { @Override public Void apply( Connection connection ) { - connection.close(); + connection.ackFailure(); + return null; + } + }; + + public static final Function RECIEVE_ONE = new Function() + { + @Override + public Void apply( Connection connection ) + { + connection.receiveOne(); return null; } }; @@ -174,4 +214,4 @@ public Void apply( Connection connection ) return null; } }; -} \ No newline at end of file +} diff --git a/driver/src/test/java/org/neo4j/driver/internal/net/SocketConnectorTest.java b/driver/src/test/java/org/neo4j/driver/internal/net/SocketConnectorTest.java new file mode 100644 index 0000000000..b97f6c4c39 --- /dev/null +++ b/driver/src/test/java/org/neo4j/driver/internal/net/SocketConnectorTest.java @@ -0,0 +1,163 @@ +/* + * Copyright (c) 2002-2016 "Neo Technology," + * Network Engine for Objects in Lund AB [http://neotechnology.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.internal.net; + +import org.junit.Test; + +import java.util.List; +import java.util.Map; +import java.util.concurrent.CopyOnWriteArrayList; + +import org.neo4j.driver.internal.ConnectionSettings; +import org.neo4j.driver.internal.security.SecurityPlan; +import org.neo4j.driver.internal.spi.Connection; +import org.neo4j.driver.v1.AuthToken; +import org.neo4j.driver.v1.AuthTokens; +import org.neo4j.driver.v1.Logging; +import org.neo4j.driver.v1.exceptions.ClientException; + +import static org.hamcrest.Matchers.instanceOf; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertSame; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.fail; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyString; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.RETURNS_MOCKS; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.neo4j.driver.internal.net.BoltServerAddress.LOCAL_DEFAULT; + +public class SocketConnectorTest +{ + @Test + public void connectCreatesConnection() + { + ConnectionSettings settings = new ConnectionSettings( basicAuthToken() ); + SocketConnector connector = new RecordingSocketConnector( settings ); + + Connection connection = connector.connect( LOCAL_DEFAULT ); + + assertThat( connection, instanceOf( ConcurrencyGuardingConnection.class ) ); + } + + @Test + @SuppressWarnings( "unchecked" ) + public void connectSendsInit() + { + String userAgent = "agentSmith"; + ConnectionSettings settings = new ConnectionSettings( basicAuthToken(), userAgent ); + RecordingSocketConnector connector = new RecordingSocketConnector( settings ); + + connector.connect( LOCAL_DEFAULT ); + + assertEquals( 1, connector.createConnections.size() ); + Connection connection = connector.createConnections.get( 0 ); + verify( connection ).init( eq( userAgent ), any( Map.class ) ); + } + + @Test + public void connectThrowsForUnknownAuthToken() + { + ConnectionSettings settings = new ConnectionSettings( mock( AuthToken.class ) ); + RecordingSocketConnector connector = new RecordingSocketConnector( settings ); + + try + { + connector.connect( LOCAL_DEFAULT ); + fail( "Exception expected" ); + } + catch ( Exception e ) + { + assertThat( e, instanceOf( ClientException.class ) ); + } + } + + @Test + @SuppressWarnings( "unchecked" ) + public void connectClosesOpenedConnectionIfInitThrows() + { + Connection connection = mock( Connection.class ); + RuntimeException initError = new RuntimeException( "Init error" ); + doThrow( initError ).when( connection ).init( anyString(), any( Map.class ) ); + + StubSocketConnector connector = new StubSocketConnector( connection ); + + try + { + connector.connect( LOCAL_DEFAULT ); + fail( "Exception expected" ); + } + catch ( Exception e ) + { + assertSame( initError, e ); + } + + verify( connection ).close(); + } + + private static Logging loggingMock() + { + return mock( Logging.class, RETURNS_MOCKS ); + } + + private static AuthToken basicAuthToken() + { + return AuthTokens.basic( "neo4j", "neo4j" ); + } + + private static class RecordingSocketConnector extends SocketConnector + { + final List createConnections = new CopyOnWriteArrayList<>(); + + RecordingSocketConnector( ConnectionSettings settings ) + { + super( settings, SecurityPlan.insecure(), loggingMock() ); + } + + @Override + Connection createConnection( BoltServerAddress address, SecurityPlan securityPlan, Logging logging ) + { + Connection connection = mock( Connection.class ); + when( connection.boltServerAddress() ).thenReturn( address ); + createConnections.add( connection ); + return connection; + } + } + + private static class StubSocketConnector extends SocketConnector + { + final Connection connection; + + StubSocketConnector( Connection connection ) + { + super( new ConnectionSettings( basicAuthToken() ), SecurityPlan.insecure(), loggingMock() ); + this.connection = connection; + } + + @Override + Connection createConnection( BoltServerAddress address, SecurityPlan securityPlan, Logging logging ) + { + return connection; + } + } +} diff --git a/driver/src/test/java/org/neo4j/driver/internal/net/pooling/BlockingPooledConnectionQueueTest.java b/driver/src/test/java/org/neo4j/driver/internal/net/pooling/BlockingPooledConnectionQueueTest.java index 4a0dc81640..3f35a033e2 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/net/pooling/BlockingPooledConnectionQueueTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/net/pooling/BlockingPooledConnectionQueueTest.java @@ -21,16 +21,28 @@ import org.junit.Test; +import org.neo4j.driver.internal.spi.Connection; +import org.neo4j.driver.internal.util.Consumer; import org.neo4j.driver.internal.util.Supplier; +import org.neo4j.driver.v1.Logger; +import org.neo4j.driver.v1.Logging; import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertSame; import static org.junit.Assert.assertTrue; +import static org.mockito.Matchers.anyString; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.RETURNS_MOCKS; +import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.neo4j.driver.internal.net.BoltServerAddress.LOCAL_DEFAULT; +import static org.neo4j.driver.internal.util.Clock.SYSTEM; public class BlockingPooledConnectionQueueTest { @@ -42,7 +54,7 @@ public void shouldCreateNewConnectionWhenEmpty() PooledConnection connection = mock( PooledConnection.class ); Supplier supplier = mock( Supplier.class ); when( supplier.get() ).thenReturn( connection ); - BlockingPooledConnectionQueue queue = new BlockingPooledConnectionQueue( 10 ); + BlockingPooledConnectionQueue queue = newConnectionQueue( 10 ); // When queue.acquire( supplier ); @@ -59,7 +71,7 @@ public void shouldNotCreateNewConnectionWhenNotEmpty() PooledConnection connection = mock( PooledConnection.class ); Supplier supplier = mock( Supplier.class ); when( supplier.get() ).thenReturn( connection ); - BlockingPooledConnectionQueue queue = new BlockingPooledConnectionQueue( 1 ); + BlockingPooledConnectionQueue queue = newConnectionQueue( 1 ); queue.offer( connection ); // When @@ -78,7 +90,7 @@ public void shouldTerminateAllSeenConnections() PooledConnection connection2 = mock( PooledConnection.class ); Supplier supplier = mock( Supplier.class ); when( supplier.get() ).thenReturn( connection1 ); - BlockingPooledConnectionQueue queue = new BlockingPooledConnectionQueue( 2 ); + BlockingPooledConnectionQueue queue = newConnectionQueue( 2 ); queue.offer( connection1 ); queue.offer( connection2 ); assertThat( queue.size(), equalTo( 2 ) ); @@ -99,10 +111,140 @@ public void shouldNotAcceptWhenFull() // Given PooledConnection connection1 = mock( PooledConnection.class ); PooledConnection connection2 = mock( PooledConnection.class ); - BlockingPooledConnectionQueue queue = new BlockingPooledConnectionQueue( 1 ); + BlockingPooledConnectionQueue queue = newConnectionQueue( 1 ); // Then - assertTrue(queue.offer( connection1 )); - assertFalse(queue.offer( connection2 )); + assertTrue( queue.offer( connection1 ) ); + assertFalse( queue.offer( connection2 ) ); } -} \ No newline at end of file + + @Test + public void shouldDisposeAllConnectionsWhenOneOfThemFailsToDispose() + { + BlockingPooledConnectionQueue queue = newConnectionQueue( 5 ); + + PooledConnection connection1 = mock( PooledConnection.class ); + PooledConnection connection2 = mock( PooledConnection.class ); + PooledConnection connection3 = mock( PooledConnection.class ); + + RuntimeException disposeError = new RuntimeException( "Failed to stop socket" ); + doThrow( disposeError ).when( connection2 ).dispose(); + + queue.offer( connection1 ); + queue.offer( connection2 ); + queue.offer( connection3 ); + + queue.terminate(); + + verify( connection1 ).dispose(); + verify( connection2 ).dispose(); + verify( connection3 ).dispose(); + } + + @Test + @SuppressWarnings( "unchecked" ) + public void shouldTryToCloseAllUnderlyingConnections() + { + BlockingPooledConnectionQueue queue = newConnectionQueue( 5 ); + + Connection connection1 = mock( Connection.class ); + Connection connection2 = mock( Connection.class ); + Connection connection3 = mock( Connection.class ); + + RuntimeException closeError1 = new RuntimeException( "Failed to close 1" ); + RuntimeException closeError2 = new RuntimeException( "Failed to close 2" ); + RuntimeException closeError3 = new RuntimeException( "Failed to close 3" ); + + doThrow( closeError1 ).when( connection1 ).close(); + doThrow( closeError2 ).when( connection2 ).close(); + doThrow( closeError3 ).when( connection3 ).close(); + + PooledConnection pooledConnection1 = new PooledConnection( connection1, mock( Consumer.class ), SYSTEM ); + PooledConnection pooledConnection2 = new PooledConnection( connection2, mock( Consumer.class ), SYSTEM ); + PooledConnection pooledConnection3 = new PooledConnection( connection3, mock( Consumer.class ), SYSTEM ); + + queue.offer( pooledConnection1 ); + queue.offer( pooledConnection2 ); + queue.offer( pooledConnection3 ); + + queue.terminate(); + + verify( connection1 ).close(); + verify( connection2 ).close(); + verify( connection3 ).close(); + } + + @Test + @SuppressWarnings( "unchecked" ) + public void shouldLogWhenConnectionDisposeFails() + { + Logging logging = mock( Logging.class ); + Logger logger = mock( Logger.class ); + when( logging.getLog( anyString() ) ).thenReturn( logger ); + + BlockingPooledConnectionQueue queue = newConnectionQueue( 5, logging ); + + Connection connection = mock( Connection.class ); + RuntimeException closeError = new RuntimeException( "Fail" ); + doThrow( closeError ).when( connection ).close(); + PooledConnection pooledConnection = new PooledConnection( connection, mock( Consumer.class ), SYSTEM ); + queue.offer( pooledConnection ); + + queue.terminate(); + + verify( logger ).error( anyString(), eq( closeError ) ); + } + + @Test + public void shouldHaveZeroSizeAfterTermination() + { + BlockingPooledConnectionQueue queue = newConnectionQueue( 5 ); + + queue.offer( mock( PooledConnection.class ) ); + queue.offer( mock( PooledConnection.class ) ); + queue.offer( mock( PooledConnection.class ) ); + + queue.terminate(); + + assertEquals( 0, queue.size() ); + } + + @Test + @SuppressWarnings( "unchecked" ) + public void shouldTerminateBothAcquiredAndIdleConnections() + { + BlockingPooledConnectionQueue queue = newConnectionQueue( 5 ); + + PooledConnection connection1 = mock( PooledConnection.class ); + PooledConnection connection2 = mock( PooledConnection.class ); + PooledConnection connection3 = mock( PooledConnection.class ); + PooledConnection connection4 = mock( PooledConnection.class ); + + queue.offer( connection1 ); + queue.offer( connection2 ); + queue.offer( connection3 ); + queue.offer( connection4 ); + + PooledConnection acquiredConnection1 = queue.acquire( mock( Supplier.class ) ); + PooledConnection acquiredConnection2 = queue.acquire( mock( Supplier.class ) ); + assertSame( connection1, acquiredConnection1 ); + assertSame( connection2, acquiredConnection2 ); + + queue.terminate(); + + verify( connection1 ).dispose(); + verify( connection2 ).dispose(); + verify( connection3 ).dispose(); + verify( connection4 ).dispose(); + } + + private static BlockingPooledConnectionQueue newConnectionQueue( int capacity ) + { + return newConnectionQueue( capacity, mock( Logging.class, RETURNS_MOCKS ) ); + } + + private static BlockingPooledConnectionQueue newConnectionQueue( int capacity, Logging logging ) + { + return new BlockingPooledConnectionQueue( LOCAL_DEFAULT, capacity, logging ); + } +} diff --git a/driver/src/test/java/org/neo4j/driver/internal/net/pooling/PooledConnectionTest.java b/driver/src/test/java/org/neo4j/driver/internal/net/pooling/PooledConnectionTest.java index 9feb17d304..76738dec2f 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/net/pooling/PooledConnectionTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/net/pooling/PooledConnectionTest.java @@ -23,6 +23,7 @@ import org.neo4j.driver.internal.spi.Connection; import org.neo4j.driver.internal.util.Clock; import org.neo4j.driver.internal.util.Supplier; +import org.neo4j.driver.v1.Logging; import org.neo4j.driver.v1.exceptions.ClientException; import org.neo4j.driver.v1.util.Function; @@ -31,11 +32,13 @@ import static org.junit.Assert.assertSame; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; +import static org.mockito.Mockito.RETURNS_MOCKS; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.neo4j.driver.internal.net.BoltServerAddress.LOCAL_DEFAULT; public class PooledConnectionTest { @@ -63,8 +66,7 @@ public Boolean apply( PooledConnection pooledConnection ) public void shouldDisposeConnectionIfNotValidConnection() throws Throwable { // Given - final BlockingPooledConnectionQueue - pool = new BlockingPooledConnectionQueue(1); + final BlockingPooledConnectionQueue pool = newConnectionQueue(1); final boolean[] flags = {false}; @@ -93,8 +95,7 @@ public void dispose() public void shouldReturnToThePoolIfIsValidConnectionAndIdlePoolIsNotFull() throws Throwable { // Given - final BlockingPooledConnectionQueue - pool = new BlockingPooledConnectionQueue(1); + final BlockingPooledConnectionQueue pool = newConnectionQueue(1); final boolean[] flags = {false}; @@ -124,8 +125,7 @@ public void dispose() public void shouldDisposeConnectionIfValidConnectionAndIdlePoolIsFull() throws Throwable { // Given - final BlockingPooledConnectionQueue - pool = new BlockingPooledConnectionQueue(1); + final BlockingPooledConnectionQueue pool = newConnectionQueue(1); final boolean[] flags = {false}; @@ -158,7 +158,7 @@ public void shouldDisposeAcquiredConnectionsWhenPoolIsClosed() { PooledConnection connection = mock( PooledConnection.class ); - BlockingPooledConnectionQueue pool = new BlockingPooledConnectionQueue( 5 ); + BlockingPooledConnectionQueue pool = newConnectionQueue( 5 ); Supplier pooledConnectionFactory = mock( Supplier.class ); when( pooledConnectionFactory.get() ).thenReturn( connection ); @@ -178,7 +178,7 @@ public void shouldDisposeAcquiredAndIdleConnectionsWhenPoolIsClosed() PooledConnection connection2 = mock( PooledConnection.class ); PooledConnection connection3 = mock( PooledConnection.class ); - BlockingPooledConnectionQueue pool = new BlockingPooledConnectionQueue( 5 ); + BlockingPooledConnectionQueue pool = newConnectionQueue( 5 ); Supplier pooledConnectionFactory = mock( Supplier.class ); when( pooledConnectionFactory.get() ) @@ -212,7 +212,7 @@ public void shouldDisposeConnectionIfPoolAlreadyClosed() throws Throwable // session.close() -> well, close the connection directly without putting back to the pool // Given - final BlockingPooledConnectionQueue pool = new BlockingPooledConnectionQueue(1); + final BlockingPooledConnectionQueue pool = newConnectionQueue(1); pool.terminate(); final boolean[] flags = {false}; @@ -240,8 +240,7 @@ public void dispose() public void shouldDisposeConnectionIfPoolStoppedAfterPuttingConnectionBackToPool() throws Throwable { // Given - final BlockingPooledConnectionQueue - pool = new BlockingPooledConnectionQueue(1); + final BlockingPooledConnectionQueue pool = newConnectionQueue(1); pool.terminate(); final boolean[] flags = {false}; @@ -362,4 +361,9 @@ public void shouldThrowExceptionIfFailureReceivedForAckFailure() verify( conn, times( 1 ) ).ackFailure(); assertThat( pooledConnection.hasUnrecoverableErrors(), equalTo( true ) ); } + + private static BlockingPooledConnectionQueue newConnectionQueue( int capacity ) + { + return new BlockingPooledConnectionQueue( LOCAL_DEFAULT, capacity, mock( Logging.class, RETURNS_MOCKS ) ); + } } diff --git a/driver/src/test/java/org/neo4j/driver/internal/net/pooling/SocketConnectionPoolTest.java b/driver/src/test/java/org/neo4j/driver/internal/net/pooling/SocketConnectionPoolTest.java new file mode 100644 index 0000000000..2215bc6f9a --- /dev/null +++ b/driver/src/test/java/org/neo4j/driver/internal/net/pooling/SocketConnectionPoolTest.java @@ -0,0 +1,335 @@ +/* + * Copyright (c) 2002-2016 "Neo Technology," + * Network Engine for Objects in Lund AB [http://neotechnology.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.internal.net.pooling; + +import org.junit.Test; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; + +import java.util.ArrayList; +import java.util.List; +import java.util.Set; +import java.util.concurrent.Callable; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; + +import org.neo4j.driver.internal.net.BoltServerAddress; +import org.neo4j.driver.internal.spi.Connection; +import org.neo4j.driver.internal.spi.Connector; +import org.neo4j.driver.v1.Logging; + +import static java.util.Collections.newSetFromMap; +import static org.hamcrest.Matchers.instanceOf; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.RETURNS_MOCKS; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.neo4j.driver.internal.net.BoltServerAddress.DEFAULT_PORT; +import static org.neo4j.driver.internal.net.BoltServerAddress.LOCAL_DEFAULT; +import static org.neo4j.driver.internal.util.Clock.SYSTEM; + +public class SocketConnectionPoolTest +{ + private static final BoltServerAddress ADDRESS_1 = LOCAL_DEFAULT; + private static final BoltServerAddress ADDRESS_2 = new BoltServerAddress( "localhost", DEFAULT_PORT + 42 ); + private static final BoltServerAddress ADDRESS_3 = new BoltServerAddress( "localhost", DEFAULT_PORT + 4242 ); + + @Test + public void acquireCreatesNewConnectionWhenPoolIsEmpty() + { + Connector connector = newMockConnector(); + SocketConnectionPool pool = newPool( connector ); + + Connection connection = pool.acquire( ADDRESS_1 ); + + assertThat( connection, instanceOf( PooledConnection.class ) ); + verify( connector ).connect( ADDRESS_1 ); + } + + @Test + public void acquireUsesExistingConnectionIfPresent() + { + Connection connection = newConnectionMock( ADDRESS_1 ); + Connector connector = newMockConnector( connection ); + + SocketConnectionPool pool = newPool( connector ); + + Connection acquiredConnection1 = pool.acquire( ADDRESS_1 ); + assertThat( acquiredConnection1, instanceOf( PooledConnection.class ) ); + acquiredConnection1.close(); // return connection to the pool + + Connection acquiredConnection2 = pool.acquire( ADDRESS_1 ); + assertThat( acquiredConnection2, instanceOf( PooledConnection.class ) ); + + verify( connector ).connect( ADDRESS_1 ); + } + + @Test + public void purgeDoesNothingForNonExistingAddress() + { + Connection connection = newConnectionMock( ADDRESS_1 ); + SocketConnectionPool pool = newPool( newMockConnector( connection ) ); + + pool.acquire( ADDRESS_1 ).close(); + + assertTrue( pool.hasAddress( ADDRESS_1 ) ); + pool.purge( ADDRESS_2 ); + assertTrue( pool.hasAddress( ADDRESS_1 ) ); + } + + @Test + public void purgeRemovesAddress() + { + Connection connection = newConnectionMock( ADDRESS_1 ); + SocketConnectionPool pool = newPool( newMockConnector( connection ) ); + + pool.acquire( ADDRESS_1 ).close(); + + assertTrue( pool.hasAddress( ADDRESS_1 ) ); + pool.purge( ADDRESS_1 ); + assertFalse( pool.hasAddress( ADDRESS_1 ) ); + } + + @Test + public void purgeTerminatesPoolCorrespondingToTheAddress() + { + Connection connection1 = newConnectionMock( ADDRESS_1 ); + Connection connection2 = newConnectionMock( ADDRESS_1 ); + Connection connection3 = newConnectionMock( ADDRESS_1 ); + SocketConnectionPool pool = newPool( newMockConnector( connection1, connection2, connection3 ) ); + + Connection pooledConnection1 = pool.acquire( ADDRESS_1 ); + Connection pooledConnection2 = pool.acquire( ADDRESS_1 ); + pool.acquire( ADDRESS_1 ); + + // return two connections to the pool + pooledConnection1.close(); + pooledConnection2.close(); + + pool.purge( ADDRESS_1 ); + + verify( connection1 ).close(); + verify( connection2 ).close(); + verify( connection3 ).close(); + } + + @Test + public void hasAddressReturnsFalseWhenPoolIsEmpty() + { + SocketConnectionPool pool = newPool( newMockConnector() ); + + assertFalse( pool.hasAddress( ADDRESS_1 ) ); + assertFalse( pool.hasAddress( ADDRESS_2 ) ); + } + + @Test + public void hasAddressReturnsFalseForUnknownAddress() + { + SocketConnectionPool pool = newPool( newMockConnector() ); + + assertNotNull( pool.acquire( ADDRESS_1 ) ); + + assertFalse( pool.hasAddress( ADDRESS_2 ) ); + } + + @Test + public void hasAddressReturnsTrueForKnownAddress() + { + SocketConnectionPool pool = newPool( newMockConnector() ); + + assertNotNull( pool.acquire( ADDRESS_1 ) ); + + assertTrue( pool.hasAddress( ADDRESS_1 ) ); + } + + @Test + public void closeTerminatesAllPools() + { + Connection connection1 = newConnectionMock( ADDRESS_1 ); + Connection connection2 = newConnectionMock( ADDRESS_1 ); + Connection connection3 = newConnectionMock( ADDRESS_2 ); + Connection connection4 = newConnectionMock( ADDRESS_2 ); + + Connector connector = newMockConnector( connection1, connection2, connection3, connection4 ); + + SocketConnectionPool pool = newPool( connector ); + + assertNotNull( pool.acquire( ADDRESS_1 ) ); + pool.acquire( ADDRESS_1 ).close(); // return to the pool + assertNotNull( pool.acquire( ADDRESS_2 ) ); + pool.acquire( ADDRESS_2 ).close(); // return to the pool + + assertTrue( pool.hasAddress( ADDRESS_1 ) ); + assertTrue( pool.hasAddress( ADDRESS_2 ) ); + + pool.close(); + + verify( connection1 ).close(); + verify( connection2 ).close(); + verify( connection3 ).close(); + verify( connection4 ).close(); + } + + @Test + public void closeRemovesAllPools() + { + Connection connection1 = newConnectionMock( ADDRESS_1 ); + Connection connection2 = newConnectionMock( ADDRESS_2 ); + Connection connection3 = newConnectionMock( ADDRESS_3 ); + + Connector connector = newMockConnector( connection1, connection2, connection3 ); + + SocketConnectionPool pool = newPool( connector ); + + assertNotNull( pool.acquire( ADDRESS_1 ) ); + assertNotNull( pool.acquire( ADDRESS_2 ) ); + assertNotNull( pool.acquire( ADDRESS_3 ) ); + + assertTrue( pool.hasAddress( ADDRESS_1 ) ); + assertTrue( pool.hasAddress( ADDRESS_2 ) ); + assertTrue( pool.hasAddress( ADDRESS_3 ) ); + + pool.close(); + + assertFalse( pool.hasAddress( ADDRESS_1 ) ); + assertFalse( pool.hasAddress( ADDRESS_2 ) ); + assertFalse( pool.hasAddress( ADDRESS_3 ) ); + } + + @Test + public void closeWithConcurrentAcquisitionsEmptiesThePool() throws InterruptedException + { + Connector connector = mock( Connector.class ); + Set createdConnections = newSetFromMap( new ConcurrentHashMap() ); + when( connector.connect( any( BoltServerAddress.class ) ) ) + .then( createConnectionAnswer( createdConnections ) ); + + SocketConnectionPool pool = newPool( connector ); + + ExecutorService executor = Executors.newCachedThreadPool(); + List> results = new ArrayList<>(); + + AtomicInteger port = new AtomicInteger(); + for ( int i = 0; i < 5; i++ ) + { + Future result = executor.submit( acquireConnection( pool, port ) ); + results.add( result ); + } + + Thread.sleep( 500 ); // allow workers to do something + + pool.close(); + + for ( Future result : results ) + { + try + { + result.get( 20, TimeUnit.SECONDS ); + fail( "Exception expected" ); + } + catch ( Exception e ) + { + assertThat( e, instanceOf( ExecutionException.class ) ); + assertThat( e.getCause(), instanceOf( IllegalStateException.class ) ); + } + } + executor.shutdownNow(); + executor.awaitTermination( 10, TimeUnit.SECONDS ); + + for ( int i = 0; i < port.intValue(); i++ ) + { + assertFalse( pool.hasAddress( new BoltServerAddress( "localhost", i ) ) ); + } + for ( Connection connection : createdConnections ) + { + verify( connection ).close(); + } + } + + private static Answer createConnectionAnswer( final Set createdConnections ) + { + return new Answer() + { + @Override + public Connection answer( InvocationOnMock invocation ) + { + BoltServerAddress address = invocation.getArgumentAt( 0, BoltServerAddress.class ); + Connection connection = newConnectionMock( address ); + createdConnections.add( connection ); + return connection; + } + }; + } + + private static Callable acquireConnection( final SocketConnectionPool pool, final AtomicInteger port ) + { + return new Callable() + { + @Override + public Void call() throws Exception + { + while ( true ) + { + pool.acquire( new BoltServerAddress( "localhost", port.incrementAndGet() ) ); + } + } + }; + } + + private static Connector newMockConnector() + { + Connection connection = mock( Connection.class ); + return newMockConnector( connection ); + } + + private static Connector newMockConnector( Connection connection, Connection... otherConnections ) + { + Connector connector = mock( Connector.class ); + when( connector.connect( any( BoltServerAddress.class ) ) ).thenReturn( connection, otherConnections ); + return connector; + } + + private static SocketConnectionPool newPool( Connector connector ) + { + PoolSettings poolSettings = new PoolSettings( 42 ); + Logging logging = mock( Logging.class, RETURNS_MOCKS ); + return new SocketConnectionPool( poolSettings, connector, SYSTEM, logging ); + } + + private static Connection newConnectionMock( BoltServerAddress address ) + { + Connection connection = mock( Connection.class ); + if ( address != null ) + { + when( connection.boltServerAddress() ).thenReturn( address ); + } + return connection; + } +}