From 449ae97e7286404e415067f9c0fab7e3db90e9bc Mon Sep 17 00:00:00 2001 From: lutovich Date: Wed, 7 Dec 2016 13:16:24 +0100 Subject: [PATCH 1/6] Allow concurrent close of connections It is possible for connection to be closed from another thread while it is being used to do some work (send/receive data from socket, etc.). This could happen when whole driver is closed while there still exist threads doing actual work. Each `SocketConnection` is guarded by `ConcurrencyGuardingConnection` that throws if concurrent access is detected. This wrapper guarded `#close()` call as well so exceptions could be thrown on `Driver#close()`. Close procedure would then not be completed and some connections would not be closed. This commit allows concurrent close in `ConcurrencyGuardingConnection`. --- .../net/ConcurrencyGuardingConnection.java | 12 +--- .../ConcurrencyGuardingConnectionTest.java | 70 +++++++++++++++---- 2 files changed, 58 insertions(+), 24 deletions(-) diff --git a/driver/src/main/java/org/neo4j/driver/internal/net/ConcurrencyGuardingConnection.java b/driver/src/main/java/org/neo4j/driver/internal/net/ConcurrencyGuardingConnection.java index fd5f8dab45..577360019b 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/net/ConcurrencyGuardingConnection.java +++ b/driver/src/main/java/org/neo4j/driver/internal/net/ConcurrencyGuardingConnection.java @@ -174,15 +174,9 @@ public void receiveOne() @Override public void close() { - try - { - markAsInUse(); - delegate.close(); - } - finally - { - markAsAvailable(); - } + // It is fine to call close concurrently with this connection being used somewhere else. + // This could happen when driver is closed while there still exist sessions that do some work. + delegate.close(); } @Override diff --git a/driver/src/test/java/org/neo4j/driver/internal/net/ConcurrencyGuardingConnectionTest.java b/driver/src/test/java/org/neo4j/driver/internal/net/ConcurrencyGuardingConnectionTest.java index 0e136fc98a..319f79de00 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/net/ConcurrencyGuardingConnectionTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/net/ConcurrencyGuardingConnectionTest.java @@ -28,14 +28,16 @@ import java.util.concurrent.atomic.AtomicReference; import org.neo4j.driver.internal.spi.Connection; -import org.neo4j.driver.v1.util.Function; import org.neo4j.driver.v1.exceptions.ClientException; +import org.neo4j.driver.v1.util.Function; import static java.util.Arrays.asList; import static junit.framework.TestCase.fail; import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.MatcherAssert.assertThat; +import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; @RunWith( Parameterized.class ) public class ConcurrencyGuardingConnectionTest @@ -44,17 +46,19 @@ public class ConcurrencyGuardingConnectionTest public Function operation; @Parameterized.Parameters - public static List params() + public static List> params() { return asList( - new Object[]{INIT}, - new Object[]{RUN}, - new Object[]{PULL_ALL}, - new Object[]{DISCARD_ALL}, - new Object[]{CLOSE}, - new Object[]{RECIEVE_ONE}, - new Object[]{FLUSH}, - new Object[]{SYNC}); + INIT, + RUN, + PULL_ALL, + DISCARD_ALL, + RECIEVE_ONE, + FLUSH, + SYNC, + RESET, + ACK_FAILURE + ); } @Test @@ -95,6 +99,32 @@ public Object answer( InvocationOnMock invocationOnMock ) throws Throwable "do that is to give each thread its own dedicated session.") ); } + @Test + public void shouldAllowConcurrentClose() + { + // Given + final AtomicReference connection = new AtomicReference<>(); + + Connection delegate = mock( Connection.class, new Answer() + { + @Override + public Void answer( InvocationOnMock invocation ) throws Throwable + { + connection.get().close(); + return null; + } + } ); + doNothing().when( delegate ).close(); + + connection.set( new ConcurrencyGuardingConnection( delegate ) ); + + // When + operation.apply( connection.get() ); + + // Then + verify( delegate ).close(); + } + public static final Function INIT = new Function() { @Override @@ -135,22 +165,32 @@ public Void apply( Connection connection ) } }; - public static final Function RECIEVE_ONE = new Function() + public static final Function RESET = new Function() { @Override public Void apply( Connection connection ) { - connection.receiveOne(); + connection.reset(); return null; } }; - public static final Function CLOSE = new Function() + public static final Function ACK_FAILURE = new Function() { @Override public Void apply( Connection connection ) { - connection.close(); + connection.ackFailure(); + return null; + } + }; + + public static final Function RECIEVE_ONE = new Function() + { + @Override + public Void apply( Connection connection ) + { + connection.receiveOne(); return null; } }; @@ -174,4 +214,4 @@ public Void apply( Connection connection ) return null; } }; -} \ No newline at end of file +} From 6faf667f47f739401d02ea82d2cd8691057be1be Mon Sep 17 00:00:00 2001 From: lutovich Date: Wed, 7 Dec 2016 13:34:17 +0100 Subject: [PATCH 2/6] Removed unused SocketUtils class --- .../connector/socket/SocketUtils.java | 83 ------------------- 1 file changed, 83 deletions(-) delete mode 100644 driver/src/main/java/org/neo4j/driver/internal/connector/socket/SocketUtils.java diff --git a/driver/src/main/java/org/neo4j/driver/internal/connector/socket/SocketUtils.java b/driver/src/main/java/org/neo4j/driver/internal/connector/socket/SocketUtils.java deleted file mode 100644 index 0f6c1e75fe..0000000000 --- a/driver/src/main/java/org/neo4j/driver/internal/connector/socket/SocketUtils.java +++ /dev/null @@ -1,83 +0,0 @@ -/* - * Copyright (c) 2002-2016 "Neo Technology," - * Network Engine for Objects in Lund AB [http://neotechnology.com] - * - * This file is part of Neo4j. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.neo4j.driver.internal.connector.socket; - -import java.io.IOException; -import java.nio.ByteBuffer; -import java.nio.channels.ByteChannel; - -import org.neo4j.driver.internal.util.BytePrinter; -import org.neo4j.driver.v1.exceptions.ClientException; - -/** - * Utility class for common operations. - */ -public final class SocketUtils -{ - private SocketUtils() - { - throw new UnsupportedOperationException( "Do not instantiate" ); - } - - public static void blockingRead(ByteChannel channel, ByteBuffer buf) throws IOException - { - while(buf.hasRemaining()) - { - if (channel.read( buf ) < 0) - { - try - { - channel.close(); - } - catch ( IOException e ) - { - // best effort - } - String bufStr = BytePrinter.hex( buf ).trim(); - throw new ClientException( String.format( - "Connection terminated while receiving data. This can happen due to network " + - "instabilities, or due to restarts of the database. Expected %s bytes, received %s.", - buf.limit(), bufStr.isEmpty() ? "none" : bufStr ) ); - } - } - } - - public static void blockingWrite(ByteChannel channel, ByteBuffer buf) throws IOException - { - while(buf.hasRemaining()) - { - if (channel.write( buf ) < 0) - { - try - { - channel.close(); - } - catch ( IOException e ) - { - // best effort - } - String bufStr = BytePrinter.hex( buf ).trim(); - throw new ClientException( String.format( - "Connection terminated while sending data. This can happen due to network " + - "instabilities, or due to restarts of the database. Expected %s bytes, wrote %s.", - buf.limit(), bufStr.isEmpty() ? "none" :bufStr ) ); - } - } - } -} From 539b90f0b92cde8440ea706b497ad5db92249836 Mon Sep 17 00:00:00 2001 From: lutovich Date: Wed, 7 Dec 2016 14:23:41 +0100 Subject: [PATCH 3/6] Better handling of termination failures in connection queue `BlockingPooledConnectionQueue` disposes both idle and acquired connections when it is terminated. It was previously possible for termination process to stop halfway if disposal of some connection fails. This commit makes pool termination try to dispose all connections regardless of errors. --- .../BlockingPooledConnectionQueue.java | 64 ++++--- .../net/pooling/SocketConnectionPool.java | 2 +- .../BlockingPooledConnectionQueueTest.java | 156 +++++++++++++++++- .../net/pooling/PooledConnectionTest.java | 26 +-- 4 files changed, 209 insertions(+), 39 deletions(-) diff --git a/driver/src/main/java/org/neo4j/driver/internal/net/pooling/BlockingPooledConnectionQueue.java b/driver/src/main/java/org/neo4j/driver/internal/net/pooling/BlockingPooledConnectionQueue.java index a23e0106e1..156e9c9b04 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/net/pooling/BlockingPooledConnectionQueue.java +++ b/driver/src/main/java/org/neo4j/driver/internal/net/pooling/BlockingPooledConnectionQueue.java @@ -27,7 +27,10 @@ import java.util.concurrent.LinkedBlockingQueue; import java.util.concurrent.atomic.AtomicBoolean; +import org.neo4j.driver.internal.net.BoltServerAddress; import org.neo4j.driver.internal.util.Supplier; +import org.neo4j.driver.v1.Logger; +import org.neo4j.driver.v1.Logging; /** * A blocking queue that also keeps track of connections that are acquired in order @@ -37,6 +40,7 @@ public class BlockingPooledConnectionQueue { /** The backing queue, keeps track of connections currently in queue */ private final BlockingQueue queue; + private final Logger logger; private final AtomicBoolean isTerminating = new AtomicBoolean( false ); @@ -44,9 +48,10 @@ public class BlockingPooledConnectionQueue private final Set acquiredConnections = Collections.newSetFromMap(new ConcurrentHashMap()); - public BlockingPooledConnectionQueue( int capacity ) + public BlockingPooledConnectionQueue( BoltServerAddress address, int capacity, Logging logging ) { this.queue = new LinkedBlockingQueue<>( capacity ); + this.logger = createLogger( address, logging ); } /** @@ -64,10 +69,10 @@ public boolean offer( PooledConnection pooledConnection ) pooledConnection.dispose(); } if (isTerminating.get()) { - PooledConnection poll = queue.poll(); - if (poll != null) + PooledConnection connection = queue.poll(); + if (connection != null) { - poll.dispose(); + connection.dispose(); } } return offer; @@ -81,19 +86,19 @@ public boolean offer( PooledConnection pooledConnection ) public PooledConnection acquire( Supplier supplier ) { - PooledConnection poll = queue.poll(); - if ( poll == null ) + PooledConnection connection = queue.poll(); + if ( connection == null ) { - poll = supplier.get(); + connection = supplier.get(); } - acquiredConnections.add( poll ); + acquiredConnections.add( connection ); if (isTerminating.get()) { - acquiredConnections.remove( poll ); - poll.dispose(); + acquiredConnections.remove( connection ); + connection.dispose(); throw new IllegalStateException( "Pool has been closed, cannot acquire new values." ); } - return poll; + return connection; } public List toList() @@ -119,24 +124,43 @@ public boolean contains( PooledConnection pooledConnection ) /** * Terminates all connections, both those that are currently in the queue as well * as those that have been acquired. + *

+ * This method does not throw runtime exceptions. All connection close failures are only logged. */ public void terminate() { - if (isTerminating.compareAndSet( false, true )) + if ( isTerminating.compareAndSet( false, true ) ) { while ( !queue.isEmpty() ) { - PooledConnection conn = queue.poll(); - if ( conn != null ) - { - //close the underlying connection without adding it back to the queue - conn.dispose(); - } + PooledConnection idleConnection = queue.poll(); + disposeSafely( idleConnection ); } - for ( PooledConnection pooledConnection : acquiredConnections ) + for ( PooledConnection acquiredConnection : acquiredConnections ) { - pooledConnection.dispose(); + disposeSafely( acquiredConnection ); } } } + + private void disposeSafely( PooledConnection connection ) + { + try + { + if ( connection != null ) + { + // close the underlying connection without adding it back to the queue + connection.dispose(); + } + } + catch ( Throwable disposeError ) + { + logger.error( "Error disposing connection", disposeError ); + } + } + + private static Logger createLogger( BoltServerAddress address, Logging logging ) + { + return logging.getLog( "connectionQueue[" + address + "]" ); + } } diff --git a/driver/src/main/java/org/neo4j/driver/internal/net/pooling/SocketConnectionPool.java b/driver/src/main/java/org/neo4j/driver/internal/net/pooling/SocketConnectionPool.java index f709e009b4..6f42d73618 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/net/pooling/SocketConnectionPool.java +++ b/driver/src/main/java/org/neo4j/driver/internal/net/pooling/SocketConnectionPool.java @@ -129,7 +129,7 @@ private BlockingPooledConnectionQueue pool( BoltServerAddress address ) BlockingPooledConnectionQueue pool = pools.get( address ); if ( pool == null ) { - pool = new BlockingPooledConnectionQueue( poolSettings.maxIdleConnectionPoolSize() ); + pool = new BlockingPooledConnectionQueue( address, poolSettings.maxIdleConnectionPoolSize(), logging ); if ( pools.putIfAbsent( address, pool ) != null ) { diff --git a/driver/src/test/java/org/neo4j/driver/internal/net/pooling/BlockingPooledConnectionQueueTest.java b/driver/src/test/java/org/neo4j/driver/internal/net/pooling/BlockingPooledConnectionQueueTest.java index 4a0dc81640..3f35a033e2 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/net/pooling/BlockingPooledConnectionQueueTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/net/pooling/BlockingPooledConnectionQueueTest.java @@ -21,16 +21,28 @@ import org.junit.Test; +import org.neo4j.driver.internal.spi.Connection; +import org.neo4j.driver.internal.util.Consumer; import org.neo4j.driver.internal.util.Supplier; +import org.neo4j.driver.v1.Logger; +import org.neo4j.driver.v1.Logging; import static org.hamcrest.CoreMatchers.equalTo; import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertSame; import static org.junit.Assert.assertTrue; +import static org.mockito.Matchers.anyString; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.RETURNS_MOCKS; +import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.neo4j.driver.internal.net.BoltServerAddress.LOCAL_DEFAULT; +import static org.neo4j.driver.internal.util.Clock.SYSTEM; public class BlockingPooledConnectionQueueTest { @@ -42,7 +54,7 @@ public void shouldCreateNewConnectionWhenEmpty() PooledConnection connection = mock( PooledConnection.class ); Supplier supplier = mock( Supplier.class ); when( supplier.get() ).thenReturn( connection ); - BlockingPooledConnectionQueue queue = new BlockingPooledConnectionQueue( 10 ); + BlockingPooledConnectionQueue queue = newConnectionQueue( 10 ); // When queue.acquire( supplier ); @@ -59,7 +71,7 @@ public void shouldNotCreateNewConnectionWhenNotEmpty() PooledConnection connection = mock( PooledConnection.class ); Supplier supplier = mock( Supplier.class ); when( supplier.get() ).thenReturn( connection ); - BlockingPooledConnectionQueue queue = new BlockingPooledConnectionQueue( 1 ); + BlockingPooledConnectionQueue queue = newConnectionQueue( 1 ); queue.offer( connection ); // When @@ -78,7 +90,7 @@ public void shouldTerminateAllSeenConnections() PooledConnection connection2 = mock( PooledConnection.class ); Supplier supplier = mock( Supplier.class ); when( supplier.get() ).thenReturn( connection1 ); - BlockingPooledConnectionQueue queue = new BlockingPooledConnectionQueue( 2 ); + BlockingPooledConnectionQueue queue = newConnectionQueue( 2 ); queue.offer( connection1 ); queue.offer( connection2 ); assertThat( queue.size(), equalTo( 2 ) ); @@ -99,10 +111,140 @@ public void shouldNotAcceptWhenFull() // Given PooledConnection connection1 = mock( PooledConnection.class ); PooledConnection connection2 = mock( PooledConnection.class ); - BlockingPooledConnectionQueue queue = new BlockingPooledConnectionQueue( 1 ); + BlockingPooledConnectionQueue queue = newConnectionQueue( 1 ); // Then - assertTrue(queue.offer( connection1 )); - assertFalse(queue.offer( connection2 )); + assertTrue( queue.offer( connection1 ) ); + assertFalse( queue.offer( connection2 ) ); } -} \ No newline at end of file + + @Test + public void shouldDisposeAllConnectionsWhenOneOfThemFailsToDispose() + { + BlockingPooledConnectionQueue queue = newConnectionQueue( 5 ); + + PooledConnection connection1 = mock( PooledConnection.class ); + PooledConnection connection2 = mock( PooledConnection.class ); + PooledConnection connection3 = mock( PooledConnection.class ); + + RuntimeException disposeError = new RuntimeException( "Failed to stop socket" ); + doThrow( disposeError ).when( connection2 ).dispose(); + + queue.offer( connection1 ); + queue.offer( connection2 ); + queue.offer( connection3 ); + + queue.terminate(); + + verify( connection1 ).dispose(); + verify( connection2 ).dispose(); + verify( connection3 ).dispose(); + } + + @Test + @SuppressWarnings( "unchecked" ) + public void shouldTryToCloseAllUnderlyingConnections() + { + BlockingPooledConnectionQueue queue = newConnectionQueue( 5 ); + + Connection connection1 = mock( Connection.class ); + Connection connection2 = mock( Connection.class ); + Connection connection3 = mock( Connection.class ); + + RuntimeException closeError1 = new RuntimeException( "Failed to close 1" ); + RuntimeException closeError2 = new RuntimeException( "Failed to close 2" ); + RuntimeException closeError3 = new RuntimeException( "Failed to close 3" ); + + doThrow( closeError1 ).when( connection1 ).close(); + doThrow( closeError2 ).when( connection2 ).close(); + doThrow( closeError3 ).when( connection3 ).close(); + + PooledConnection pooledConnection1 = new PooledConnection( connection1, mock( Consumer.class ), SYSTEM ); + PooledConnection pooledConnection2 = new PooledConnection( connection2, mock( Consumer.class ), SYSTEM ); + PooledConnection pooledConnection3 = new PooledConnection( connection3, mock( Consumer.class ), SYSTEM ); + + queue.offer( pooledConnection1 ); + queue.offer( pooledConnection2 ); + queue.offer( pooledConnection3 ); + + queue.terminate(); + + verify( connection1 ).close(); + verify( connection2 ).close(); + verify( connection3 ).close(); + } + + @Test + @SuppressWarnings( "unchecked" ) + public void shouldLogWhenConnectionDisposeFails() + { + Logging logging = mock( Logging.class ); + Logger logger = mock( Logger.class ); + when( logging.getLog( anyString() ) ).thenReturn( logger ); + + BlockingPooledConnectionQueue queue = newConnectionQueue( 5, logging ); + + Connection connection = mock( Connection.class ); + RuntimeException closeError = new RuntimeException( "Fail" ); + doThrow( closeError ).when( connection ).close(); + PooledConnection pooledConnection = new PooledConnection( connection, mock( Consumer.class ), SYSTEM ); + queue.offer( pooledConnection ); + + queue.terminate(); + + verify( logger ).error( anyString(), eq( closeError ) ); + } + + @Test + public void shouldHaveZeroSizeAfterTermination() + { + BlockingPooledConnectionQueue queue = newConnectionQueue( 5 ); + + queue.offer( mock( PooledConnection.class ) ); + queue.offer( mock( PooledConnection.class ) ); + queue.offer( mock( PooledConnection.class ) ); + + queue.terminate(); + + assertEquals( 0, queue.size() ); + } + + @Test + @SuppressWarnings( "unchecked" ) + public void shouldTerminateBothAcquiredAndIdleConnections() + { + BlockingPooledConnectionQueue queue = newConnectionQueue( 5 ); + + PooledConnection connection1 = mock( PooledConnection.class ); + PooledConnection connection2 = mock( PooledConnection.class ); + PooledConnection connection3 = mock( PooledConnection.class ); + PooledConnection connection4 = mock( PooledConnection.class ); + + queue.offer( connection1 ); + queue.offer( connection2 ); + queue.offer( connection3 ); + queue.offer( connection4 ); + + PooledConnection acquiredConnection1 = queue.acquire( mock( Supplier.class ) ); + PooledConnection acquiredConnection2 = queue.acquire( mock( Supplier.class ) ); + assertSame( connection1, acquiredConnection1 ); + assertSame( connection2, acquiredConnection2 ); + + queue.terminate(); + + verify( connection1 ).dispose(); + verify( connection2 ).dispose(); + verify( connection3 ).dispose(); + verify( connection4 ).dispose(); + } + + private static BlockingPooledConnectionQueue newConnectionQueue( int capacity ) + { + return newConnectionQueue( capacity, mock( Logging.class, RETURNS_MOCKS ) ); + } + + private static BlockingPooledConnectionQueue newConnectionQueue( int capacity, Logging logging ) + { + return new BlockingPooledConnectionQueue( LOCAL_DEFAULT, capacity, logging ); + } +} diff --git a/driver/src/test/java/org/neo4j/driver/internal/net/pooling/PooledConnectionTest.java b/driver/src/test/java/org/neo4j/driver/internal/net/pooling/PooledConnectionTest.java index 9feb17d304..76738dec2f 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/net/pooling/PooledConnectionTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/net/pooling/PooledConnectionTest.java @@ -23,6 +23,7 @@ import org.neo4j.driver.internal.spi.Connection; import org.neo4j.driver.internal.util.Clock; import org.neo4j.driver.internal.util.Supplier; +import org.neo4j.driver.v1.Logging; import org.neo4j.driver.v1.exceptions.ClientException; import org.neo4j.driver.v1.util.Function; @@ -31,11 +32,13 @@ import static org.junit.Assert.assertSame; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; +import static org.mockito.Mockito.RETURNS_MOCKS; import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.times; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.neo4j.driver.internal.net.BoltServerAddress.LOCAL_DEFAULT; public class PooledConnectionTest { @@ -63,8 +66,7 @@ public Boolean apply( PooledConnection pooledConnection ) public void shouldDisposeConnectionIfNotValidConnection() throws Throwable { // Given - final BlockingPooledConnectionQueue - pool = new BlockingPooledConnectionQueue(1); + final BlockingPooledConnectionQueue pool = newConnectionQueue(1); final boolean[] flags = {false}; @@ -93,8 +95,7 @@ public void dispose() public void shouldReturnToThePoolIfIsValidConnectionAndIdlePoolIsNotFull() throws Throwable { // Given - final BlockingPooledConnectionQueue - pool = new BlockingPooledConnectionQueue(1); + final BlockingPooledConnectionQueue pool = newConnectionQueue(1); final boolean[] flags = {false}; @@ -124,8 +125,7 @@ public void dispose() public void shouldDisposeConnectionIfValidConnectionAndIdlePoolIsFull() throws Throwable { // Given - final BlockingPooledConnectionQueue - pool = new BlockingPooledConnectionQueue(1); + final BlockingPooledConnectionQueue pool = newConnectionQueue(1); final boolean[] flags = {false}; @@ -158,7 +158,7 @@ public void shouldDisposeAcquiredConnectionsWhenPoolIsClosed() { PooledConnection connection = mock( PooledConnection.class ); - BlockingPooledConnectionQueue pool = new BlockingPooledConnectionQueue( 5 ); + BlockingPooledConnectionQueue pool = newConnectionQueue( 5 ); Supplier pooledConnectionFactory = mock( Supplier.class ); when( pooledConnectionFactory.get() ).thenReturn( connection ); @@ -178,7 +178,7 @@ public void shouldDisposeAcquiredAndIdleConnectionsWhenPoolIsClosed() PooledConnection connection2 = mock( PooledConnection.class ); PooledConnection connection3 = mock( PooledConnection.class ); - BlockingPooledConnectionQueue pool = new BlockingPooledConnectionQueue( 5 ); + BlockingPooledConnectionQueue pool = newConnectionQueue( 5 ); Supplier pooledConnectionFactory = mock( Supplier.class ); when( pooledConnectionFactory.get() ) @@ -212,7 +212,7 @@ public void shouldDisposeConnectionIfPoolAlreadyClosed() throws Throwable // session.close() -> well, close the connection directly without putting back to the pool // Given - final BlockingPooledConnectionQueue pool = new BlockingPooledConnectionQueue(1); + final BlockingPooledConnectionQueue pool = newConnectionQueue(1); pool.terminate(); final boolean[] flags = {false}; @@ -240,8 +240,7 @@ public void dispose() public void shouldDisposeConnectionIfPoolStoppedAfterPuttingConnectionBackToPool() throws Throwable { // Given - final BlockingPooledConnectionQueue - pool = new BlockingPooledConnectionQueue(1); + final BlockingPooledConnectionQueue pool = newConnectionQueue(1); pool.terminate(); final boolean[] flags = {false}; @@ -362,4 +361,9 @@ public void shouldThrowExceptionIfFailureReceivedForAckFailure() verify( conn, times( 1 ) ).ackFailure(); assertThat( pooledConnection.hasUnrecoverableErrors(), equalTo( true ) ); } + + private static BlockingPooledConnectionQueue newConnectionQueue( int capacity ) + { + return new BlockingPooledConnectionQueue( LOCAL_DEFAULT, capacity, mock( Logging.class, RETURNS_MOCKS ) ); + } } From 6f0cf2d31c7ced9992532e0ef4ba83bcd69ec58f Mon Sep 17 00:00:00 2001 From: lutovich Date: Thu, 8 Dec 2016 14:21:33 +0100 Subject: [PATCH 4/6] Better handle async close in SocketConnectionPool This pool can be used from multiple threads. It is possible for `#close()` to be called concurrently with `#acquire()`. Previously we did not guard against this situation and it was possible to acquire connections from the closed pool. It could also happen that `#close()` does not actually close all connection queues because iteration over them is not atomic. This commit makes `SocketConnectionPool` track if it is closed or not. It also introduces `Connector` abstraction (which is merely a connection factory) so that `SocketConnectionPool` can be easily tested. --- .../driver/internal/net/SocketConnector.java | 83 +++++ .../net/pooling/SocketConnectionPool.java | 101 ++---- .../neo4j/driver/internal/spi/Connector.java | 26 ++ .../org/neo4j/driver/v1/GraphDatabase.java | 24 +- .../internal/net/SocketConnectorTest.java | 121 +++++++ .../net/pooling/SocketConnectionPoolTest.java | 335 ++++++++++++++++++ 6 files changed, 614 insertions(+), 76 deletions(-) create mode 100644 driver/src/main/java/org/neo4j/driver/internal/net/SocketConnector.java create mode 100644 driver/src/main/java/org/neo4j/driver/internal/spi/Connector.java create mode 100644 driver/src/test/java/org/neo4j/driver/internal/net/SocketConnectorTest.java create mode 100644 driver/src/test/java/org/neo4j/driver/internal/net/pooling/SocketConnectionPoolTest.java diff --git a/driver/src/main/java/org/neo4j/driver/internal/net/SocketConnector.java b/driver/src/main/java/org/neo4j/driver/internal/net/SocketConnector.java new file mode 100644 index 0000000000..15158e88fa --- /dev/null +++ b/driver/src/main/java/org/neo4j/driver/internal/net/SocketConnector.java @@ -0,0 +1,83 @@ +/* + * Copyright (c) 2002-2016 "Neo Technology," + * Network Engine for Objects in Lund AB [http://neotechnology.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.internal.net; + +import java.util.Map; + +import org.neo4j.driver.internal.ConnectionSettings; +import org.neo4j.driver.internal.security.InternalAuthToken; +import org.neo4j.driver.internal.security.SecurityPlan; +import org.neo4j.driver.internal.spi.Connection; +import org.neo4j.driver.internal.spi.Connector; +import org.neo4j.driver.v1.AuthToken; +import org.neo4j.driver.v1.AuthTokens; +import org.neo4j.driver.v1.Logging; +import org.neo4j.driver.v1.Value; +import org.neo4j.driver.v1.exceptions.ClientException; + +public class SocketConnector implements Connector +{ + private final ConnectionSettings connectionSettings; + private final SecurityPlan securityPlan; + private final Logging logging; + + public SocketConnector( ConnectionSettings connectionSettings, SecurityPlan securityPlan, Logging logging ) + { + this.connectionSettings = connectionSettings; + this.securityPlan = securityPlan; + this.logging = logging; + } + + @Override + public final Connection connect( BoltServerAddress address ) + { + Connection connection = createConnection( address, securityPlan, logging ); + + // Because SocketConnection is not thread safe, wrap it in this guard + // to ensure concurrent access leads causes application errors + connection = new ConcurrencyGuardingConnection( connection ); + + connection.init( connectionSettings.userAgent(), tokenAsMap( connectionSettings.authToken() ) ); + return connection; + } + + /** + * Create new connection. + *

+ * This method is package-private only for testing + */ + Connection createConnection( BoltServerAddress address, SecurityPlan securityPlan, Logging logging ) + { + return new SocketConnection( address, securityPlan, logging ); + } + + private static Map tokenAsMap( AuthToken token ) + { + if ( token instanceof InternalAuthToken ) + { + return ((InternalAuthToken) token).toMap(); + } + else + { + throw new ClientException( + "Unknown authentication token, `" + token + "`. Please use one of the supported " + + "tokens from `" + AuthTokens.class.getSimpleName() + "`." ); + } + } +} diff --git a/driver/src/main/java/org/neo4j/driver/internal/net/pooling/SocketConnectionPool.java b/driver/src/main/java/org/neo4j/driver/internal/net/pooling/SocketConnectionPool.java index 6f42d73618..a30ef274d0 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/net/pooling/SocketConnectionPool.java +++ b/driver/src/main/java/org/neo4j/driver/internal/net/pooling/SocketConnectionPool.java @@ -18,27 +18,16 @@ */ package org.neo4j.driver.internal.net.pooling; -import java.util.List; -import java.util.Map; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicBoolean; -import org.neo4j.driver.internal.ConnectionSettings; import org.neo4j.driver.internal.net.BoltServerAddress; -import org.neo4j.driver.internal.net.ConcurrencyGuardingConnection; -import org.neo4j.driver.internal.net.SocketConnection; -import org.neo4j.driver.internal.security.InternalAuthToken; -import org.neo4j.driver.internal.security.SecurityPlan; import org.neo4j.driver.internal.spi.Connection; import org.neo4j.driver.internal.spi.ConnectionPool; +import org.neo4j.driver.internal.spi.Connector; import org.neo4j.driver.internal.util.Clock; import org.neo4j.driver.internal.util.Supplier; -import org.neo4j.driver.v1.AuthToken; -import org.neo4j.driver.v1.AuthTokens; import org.neo4j.driver.v1.Logging; -import org.neo4j.driver.v1.Value; -import org.neo4j.driver.v1.exceptions.ClientException; - -import static java.util.Collections.emptyList; /** * The pool is designed to buffer certain amount of free sessions into session pool. When closing a session, we first @@ -60,52 +49,26 @@ public class SocketConnectionPool implements ConnectionPool private final ConcurrentHashMap pools = new ConcurrentHashMap<>(); - private final Clock clock = Clock.SYSTEM; + private final AtomicBoolean closed = new AtomicBoolean(); - private final ConnectionSettings connectionSettings; - private final SecurityPlan securityPlan; private final PoolSettings poolSettings; + private final Connector connector; + private final Clock clock; private final Logging logging; - /** Shutdown flag */ - - public SocketConnectionPool( ConnectionSettings connectionSettings, SecurityPlan securityPlan, - PoolSettings poolSettings, Logging logging ) + public SocketConnectionPool( PoolSettings poolSettings, Connector connector, Clock clock, Logging logging ) { - this.connectionSettings = connectionSettings; - this.securityPlan = securityPlan; this.poolSettings = poolSettings; + this.connector = connector; + this.clock = clock; this.logging = logging; } - private Connection connect( BoltServerAddress address ) throws ClientException - { - Connection conn = new SocketConnection( address, securityPlan, logging ); - - // Because SocketConnection is not thread safe, wrap it in this guard - // to ensure concurrent access leads causes application errors - conn = new ConcurrencyGuardingConnection( conn ); - conn.init( connectionSettings.userAgent(), tokenAsMap( connectionSettings.authToken() ) ); - return conn; - } - - private static Map tokenAsMap( AuthToken token ) - { - if ( token instanceof InternalAuthToken ) - { - return ((InternalAuthToken) token).toMap(); - } - else - { - throw new ClientException( - "Unknown authentication token, `" + token + "`. Please use one of the supported " + - "tokens from `" + AuthTokens.class.getSimpleName() + "`." ); - } - } - @Override public Connection acquire( final BoltServerAddress address ) { + assertNotClosed(); + final BlockingPooledConnectionQueue connections = pool( address ); Supplier supplier = new Supplier() { @@ -116,10 +79,17 @@ public PooledConnection get() new PooledConnectionValidator( SocketConnectionPool.this ); PooledConnectionReleaseConsumer releaseConsumer = new PooledConnectionReleaseConsumer( connections, connectionValidator ); - return new PooledConnection( connect( address ), releaseConsumer, clock ); + return new PooledConnection( connector.connect( address ), releaseConsumer, clock ); } }; PooledConnection conn = connections.acquire( supplier ); + + if ( closed.get() ) + { + connections.terminate(); + throw poolClosedException(); + } + conn.updateTimestamp(); return conn; } @@ -144,12 +114,10 @@ private BlockingPooledConnectionQueue pool( BoltServerAddress address ) public void purge( BoltServerAddress address ) { BlockingPooledConnectionQueue connections = pools.remove( address ); - if ( connections == null ) + if ( connections != null ) { - return; + connections.terminate(); } - - connections.terminate(); } @Override @@ -161,28 +129,27 @@ public boolean hasAddress( BoltServerAddress address ) @Override public void close() { - for ( BlockingPooledConnectionQueue pool : pools.values() ) + if ( closed.compareAndSet( false, true ) ) { - pool.terminate(); - } + for ( BlockingPooledConnectionQueue pool : pools.values() ) + { + pool.terminate(); + } - pools.clear(); + pools.clear(); + } } - - //for testing - public List connectionsForAddress( BoltServerAddress address ) + private void assertNotClosed() { - BlockingPooledConnectionQueue pooledConnections = pools.get( address ); - if ( pooledConnections == null ) - { - return emptyList(); - } - else + if ( closed.get() ) { - return pooledConnections.toList(); + throw poolClosedException(); } } - + private static RuntimeException poolClosedException() + { + return new IllegalStateException( "Pool closed" ); + } } diff --git a/driver/src/main/java/org/neo4j/driver/internal/spi/Connector.java b/driver/src/main/java/org/neo4j/driver/internal/spi/Connector.java new file mode 100644 index 0000000000..b512203c56 --- /dev/null +++ b/driver/src/main/java/org/neo4j/driver/internal/spi/Connector.java @@ -0,0 +1,26 @@ +/* + * Copyright (c) 2002-2016 "Neo Technology," + * Network Engine for Objects in Lund AB [http://neotechnology.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.internal.spi; + +import org.neo4j.driver.internal.net.BoltServerAddress; + +public interface Connector +{ + Connection connect( BoltServerAddress address ); +} diff --git a/driver/src/main/java/org/neo4j/driver/v1/GraphDatabase.java b/driver/src/main/java/org/neo4j/driver/v1/GraphDatabase.java index 4c61e9d05a..265908fdb3 100644 --- a/driver/src/main/java/org/neo4j/driver/v1/GraphDatabase.java +++ b/driver/src/main/java/org/neo4j/driver/v1/GraphDatabase.java @@ -27,11 +27,13 @@ import org.neo4j.driver.internal.NetworkSession; import org.neo4j.driver.internal.RoutingDriver; import org.neo4j.driver.internal.net.BoltServerAddress; +import org.neo4j.driver.internal.net.SocketConnector; import org.neo4j.driver.internal.net.pooling.PoolSettings; import org.neo4j.driver.internal.net.pooling.SocketConnectionPool; import org.neo4j.driver.internal.security.SecurityPlan; import org.neo4j.driver.internal.spi.Connection; import org.neo4j.driver.internal.spi.ConnectionPool; +import org.neo4j.driver.internal.spi.Connector; import org.neo4j.driver.internal.util.Clock; import org.neo4j.driver.v1.exceptions.ClientException; import org.neo4j.driver.v1.util.Function; @@ -155,10 +157,6 @@ public static Driver driver( URI uri, AuthToken authToken, Config config ) String scheme = uri.getScheme(); BoltServerAddress address = BoltServerAddress.from( uri ); - // Collate session parameters - ConnectionSettings connectionSettings = - new ConnectionSettings( authToken == null ? AuthTokens.none() : authToken ); - // Make sure we have some configuration to play with if ( config == null ) { @@ -176,12 +174,8 @@ public static Driver driver( URI uri, AuthToken authToken, Config config ) throw new ClientException( "Unable to establish SSL parameters", ex ); } - // Establish pool settings - PoolSettings poolSettings = new PoolSettings( config.maxIdleConnectionPoolSize() ); + ConnectionPool connectionPool = createConnectionPool( authToken, securityPlan, config ); - // And finally, construct the driver proper - ConnectionPool connectionPool = - new SocketConnectionPool( connectionSettings, securityPlan, poolSettings, config.logging() ); switch ( scheme.toLowerCase() ) { case "bolt": @@ -199,6 +193,18 @@ public static Driver driver( URI uri, AuthToken authToken, Config config ) } } + private static ConnectionPool createConnectionPool( AuthToken authToken, SecurityPlan securityPlan, + Config config ) + { + authToken = authToken == null ? AuthTokens.none() : authToken; + + ConnectionSettings connectionSettings = new ConnectionSettings( authToken ); + PoolSettings poolSettings = new PoolSettings( config.maxIdleConnectionPoolSize() ); + Connector connector = new SocketConnector( connectionSettings, securityPlan, config.logging() ); + + return new SocketConnectionPool( poolSettings, connector, Clock.SYSTEM, config.logging() ); + } + /* * Establish a complete SecurityPlan based on the details provided for * driver construction. diff --git a/driver/src/test/java/org/neo4j/driver/internal/net/SocketConnectorTest.java b/driver/src/test/java/org/neo4j/driver/internal/net/SocketConnectorTest.java new file mode 100644 index 0000000000..3a7d335ce0 --- /dev/null +++ b/driver/src/test/java/org/neo4j/driver/internal/net/SocketConnectorTest.java @@ -0,0 +1,121 @@ +/* + * Copyright (c) 2002-2016 "Neo Technology," + * Network Engine for Objects in Lund AB [http://neotechnology.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.internal.net; + +import org.junit.Test; + +import java.util.List; +import java.util.Map; +import java.util.concurrent.CopyOnWriteArrayList; + +import org.neo4j.driver.internal.ConnectionSettings; +import org.neo4j.driver.internal.security.SecurityPlan; +import org.neo4j.driver.internal.spi.Connection; +import org.neo4j.driver.v1.AuthToken; +import org.neo4j.driver.v1.AuthTokens; +import org.neo4j.driver.v1.Logging; +import org.neo4j.driver.v1.exceptions.ClientException; + +import static org.hamcrest.Matchers.instanceOf; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.fail; +import static org.mockito.Matchers.any; +import static org.mockito.Matchers.eq; +import static org.mockito.Mockito.RETURNS_MOCKS; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.neo4j.driver.internal.net.BoltServerAddress.LOCAL_DEFAULT; +import static org.neo4j.driver.internal.security.SecurityPlan.insecure; + +public class SocketConnectorTest +{ + @Test + public void connectCreatesConnection() + { + ConnectionSettings settings = new ConnectionSettings( basicAuthToken() ); + SocketConnector connector = new TestSocketConnector( settings, insecure(), loggingMock() ); + + Connection connection = connector.connect( LOCAL_DEFAULT ); + + assertThat( connection, instanceOf( ConcurrencyGuardingConnection.class ) ); + } + + @Test + @SuppressWarnings( "unchecked" ) + public void connectSendsInit() + { + String userAgent = "agentSmith"; + ConnectionSettings settings = new ConnectionSettings( basicAuthToken(), userAgent ); + TestSocketConnector connector = new TestSocketConnector( settings, insecure(), loggingMock() ); + + connector.connect( LOCAL_DEFAULT ); + + assertEquals( 1, connector.createConnections.size() ); + Connection connection = connector.createConnections.get( 0 ); + verify( connection ).init( eq( userAgent ), any( Map.class ) ); + } + + @Test + public void connectThrowsForUnknownAuthToken() + { + ConnectionSettings settings = new ConnectionSettings( mock( AuthToken.class ) ); + TestSocketConnector connector = new TestSocketConnector( settings, insecure(), loggingMock() ); + + try + { + connector.connect( LOCAL_DEFAULT ); + fail( "Exception expected" ); + } + catch ( Exception e ) + { + assertThat( e, instanceOf( ClientException.class ) ); + } + } + + private static Logging loggingMock() + { + return mock( Logging.class, RETURNS_MOCKS ); + } + + private static AuthToken basicAuthToken() + { + return AuthTokens.basic( "neo4j", "neo4j" ); + } + + private static class TestSocketConnector extends SocketConnector + { + final List createConnections = new CopyOnWriteArrayList<>(); + + TestSocketConnector( ConnectionSettings settings, SecurityPlan securityPlan, Logging logging ) + { + super( settings, securityPlan, logging ); + } + + @Override + Connection createConnection( BoltServerAddress address, SecurityPlan securityPlan, Logging logging ) + { + Connection connection = mock( Connection.class ); + when( connection.boltServerAddress() ).thenReturn( address ); + createConnections.add( connection ); + return connection; + } + } +} diff --git a/driver/src/test/java/org/neo4j/driver/internal/net/pooling/SocketConnectionPoolTest.java b/driver/src/test/java/org/neo4j/driver/internal/net/pooling/SocketConnectionPoolTest.java new file mode 100644 index 0000000000..2215bc6f9a --- /dev/null +++ b/driver/src/test/java/org/neo4j/driver/internal/net/pooling/SocketConnectionPoolTest.java @@ -0,0 +1,335 @@ +/* + * Copyright (c) 2002-2016 "Neo Technology," + * Network Engine for Objects in Lund AB [http://neotechnology.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.internal.net.pooling; + +import org.junit.Test; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; + +import java.util.ArrayList; +import java.util.List; +import java.util.Set; +import java.util.concurrent.Callable; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; + +import org.neo4j.driver.internal.net.BoltServerAddress; +import org.neo4j.driver.internal.spi.Connection; +import org.neo4j.driver.internal.spi.Connector; +import org.neo4j.driver.v1.Logging; + +import static java.util.Collections.newSetFromMap; +import static org.hamcrest.Matchers.instanceOf; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.RETURNS_MOCKS; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.neo4j.driver.internal.net.BoltServerAddress.DEFAULT_PORT; +import static org.neo4j.driver.internal.net.BoltServerAddress.LOCAL_DEFAULT; +import static org.neo4j.driver.internal.util.Clock.SYSTEM; + +public class SocketConnectionPoolTest +{ + private static final BoltServerAddress ADDRESS_1 = LOCAL_DEFAULT; + private static final BoltServerAddress ADDRESS_2 = new BoltServerAddress( "localhost", DEFAULT_PORT + 42 ); + private static final BoltServerAddress ADDRESS_3 = new BoltServerAddress( "localhost", DEFAULT_PORT + 4242 ); + + @Test + public void acquireCreatesNewConnectionWhenPoolIsEmpty() + { + Connector connector = newMockConnector(); + SocketConnectionPool pool = newPool( connector ); + + Connection connection = pool.acquire( ADDRESS_1 ); + + assertThat( connection, instanceOf( PooledConnection.class ) ); + verify( connector ).connect( ADDRESS_1 ); + } + + @Test + public void acquireUsesExistingConnectionIfPresent() + { + Connection connection = newConnectionMock( ADDRESS_1 ); + Connector connector = newMockConnector( connection ); + + SocketConnectionPool pool = newPool( connector ); + + Connection acquiredConnection1 = pool.acquire( ADDRESS_1 ); + assertThat( acquiredConnection1, instanceOf( PooledConnection.class ) ); + acquiredConnection1.close(); // return connection to the pool + + Connection acquiredConnection2 = pool.acquire( ADDRESS_1 ); + assertThat( acquiredConnection2, instanceOf( PooledConnection.class ) ); + + verify( connector ).connect( ADDRESS_1 ); + } + + @Test + public void purgeDoesNothingForNonExistingAddress() + { + Connection connection = newConnectionMock( ADDRESS_1 ); + SocketConnectionPool pool = newPool( newMockConnector( connection ) ); + + pool.acquire( ADDRESS_1 ).close(); + + assertTrue( pool.hasAddress( ADDRESS_1 ) ); + pool.purge( ADDRESS_2 ); + assertTrue( pool.hasAddress( ADDRESS_1 ) ); + } + + @Test + public void purgeRemovesAddress() + { + Connection connection = newConnectionMock( ADDRESS_1 ); + SocketConnectionPool pool = newPool( newMockConnector( connection ) ); + + pool.acquire( ADDRESS_1 ).close(); + + assertTrue( pool.hasAddress( ADDRESS_1 ) ); + pool.purge( ADDRESS_1 ); + assertFalse( pool.hasAddress( ADDRESS_1 ) ); + } + + @Test + public void purgeTerminatesPoolCorrespondingToTheAddress() + { + Connection connection1 = newConnectionMock( ADDRESS_1 ); + Connection connection2 = newConnectionMock( ADDRESS_1 ); + Connection connection3 = newConnectionMock( ADDRESS_1 ); + SocketConnectionPool pool = newPool( newMockConnector( connection1, connection2, connection3 ) ); + + Connection pooledConnection1 = pool.acquire( ADDRESS_1 ); + Connection pooledConnection2 = pool.acquire( ADDRESS_1 ); + pool.acquire( ADDRESS_1 ); + + // return two connections to the pool + pooledConnection1.close(); + pooledConnection2.close(); + + pool.purge( ADDRESS_1 ); + + verify( connection1 ).close(); + verify( connection2 ).close(); + verify( connection3 ).close(); + } + + @Test + public void hasAddressReturnsFalseWhenPoolIsEmpty() + { + SocketConnectionPool pool = newPool( newMockConnector() ); + + assertFalse( pool.hasAddress( ADDRESS_1 ) ); + assertFalse( pool.hasAddress( ADDRESS_2 ) ); + } + + @Test + public void hasAddressReturnsFalseForUnknownAddress() + { + SocketConnectionPool pool = newPool( newMockConnector() ); + + assertNotNull( pool.acquire( ADDRESS_1 ) ); + + assertFalse( pool.hasAddress( ADDRESS_2 ) ); + } + + @Test + public void hasAddressReturnsTrueForKnownAddress() + { + SocketConnectionPool pool = newPool( newMockConnector() ); + + assertNotNull( pool.acquire( ADDRESS_1 ) ); + + assertTrue( pool.hasAddress( ADDRESS_1 ) ); + } + + @Test + public void closeTerminatesAllPools() + { + Connection connection1 = newConnectionMock( ADDRESS_1 ); + Connection connection2 = newConnectionMock( ADDRESS_1 ); + Connection connection3 = newConnectionMock( ADDRESS_2 ); + Connection connection4 = newConnectionMock( ADDRESS_2 ); + + Connector connector = newMockConnector( connection1, connection2, connection3, connection4 ); + + SocketConnectionPool pool = newPool( connector ); + + assertNotNull( pool.acquire( ADDRESS_1 ) ); + pool.acquire( ADDRESS_1 ).close(); // return to the pool + assertNotNull( pool.acquire( ADDRESS_2 ) ); + pool.acquire( ADDRESS_2 ).close(); // return to the pool + + assertTrue( pool.hasAddress( ADDRESS_1 ) ); + assertTrue( pool.hasAddress( ADDRESS_2 ) ); + + pool.close(); + + verify( connection1 ).close(); + verify( connection2 ).close(); + verify( connection3 ).close(); + verify( connection4 ).close(); + } + + @Test + public void closeRemovesAllPools() + { + Connection connection1 = newConnectionMock( ADDRESS_1 ); + Connection connection2 = newConnectionMock( ADDRESS_2 ); + Connection connection3 = newConnectionMock( ADDRESS_3 ); + + Connector connector = newMockConnector( connection1, connection2, connection3 ); + + SocketConnectionPool pool = newPool( connector ); + + assertNotNull( pool.acquire( ADDRESS_1 ) ); + assertNotNull( pool.acquire( ADDRESS_2 ) ); + assertNotNull( pool.acquire( ADDRESS_3 ) ); + + assertTrue( pool.hasAddress( ADDRESS_1 ) ); + assertTrue( pool.hasAddress( ADDRESS_2 ) ); + assertTrue( pool.hasAddress( ADDRESS_3 ) ); + + pool.close(); + + assertFalse( pool.hasAddress( ADDRESS_1 ) ); + assertFalse( pool.hasAddress( ADDRESS_2 ) ); + assertFalse( pool.hasAddress( ADDRESS_3 ) ); + } + + @Test + public void closeWithConcurrentAcquisitionsEmptiesThePool() throws InterruptedException + { + Connector connector = mock( Connector.class ); + Set createdConnections = newSetFromMap( new ConcurrentHashMap() ); + when( connector.connect( any( BoltServerAddress.class ) ) ) + .then( createConnectionAnswer( createdConnections ) ); + + SocketConnectionPool pool = newPool( connector ); + + ExecutorService executor = Executors.newCachedThreadPool(); + List> results = new ArrayList<>(); + + AtomicInteger port = new AtomicInteger(); + for ( int i = 0; i < 5; i++ ) + { + Future result = executor.submit( acquireConnection( pool, port ) ); + results.add( result ); + } + + Thread.sleep( 500 ); // allow workers to do something + + pool.close(); + + for ( Future result : results ) + { + try + { + result.get( 20, TimeUnit.SECONDS ); + fail( "Exception expected" ); + } + catch ( Exception e ) + { + assertThat( e, instanceOf( ExecutionException.class ) ); + assertThat( e.getCause(), instanceOf( IllegalStateException.class ) ); + } + } + executor.shutdownNow(); + executor.awaitTermination( 10, TimeUnit.SECONDS ); + + for ( int i = 0; i < port.intValue(); i++ ) + { + assertFalse( pool.hasAddress( new BoltServerAddress( "localhost", i ) ) ); + } + for ( Connection connection : createdConnections ) + { + verify( connection ).close(); + } + } + + private static Answer createConnectionAnswer( final Set createdConnections ) + { + return new Answer() + { + @Override + public Connection answer( InvocationOnMock invocation ) + { + BoltServerAddress address = invocation.getArgumentAt( 0, BoltServerAddress.class ); + Connection connection = newConnectionMock( address ); + createdConnections.add( connection ); + return connection; + } + }; + } + + private static Callable acquireConnection( final SocketConnectionPool pool, final AtomicInteger port ) + { + return new Callable() + { + @Override + public Void call() throws Exception + { + while ( true ) + { + pool.acquire( new BoltServerAddress( "localhost", port.incrementAndGet() ) ); + } + } + }; + } + + private static Connector newMockConnector() + { + Connection connection = mock( Connection.class ); + return newMockConnector( connection ); + } + + private static Connector newMockConnector( Connection connection, Connection... otherConnections ) + { + Connector connector = mock( Connector.class ); + when( connector.connect( any( BoltServerAddress.class ) ) ).thenReturn( connection, otherConnections ); + return connector; + } + + private static SocketConnectionPool newPool( Connector connector ) + { + PoolSettings poolSettings = new PoolSettings( 42 ); + Logging logging = mock( Logging.class, RETURNS_MOCKS ); + return new SocketConnectionPool( poolSettings, connector, SYSTEM, logging ); + } + + private static Connection newConnectionMock( BoltServerAddress address ) + { + Connection connection = mock( Connection.class ); + if ( address != null ) + { + when( connection.boltServerAddress() ).thenReturn( address ); + } + return connection; + } +} From 4b0c5aa995263b82392e883165b2dacc59a9cffd Mon Sep 17 00:00:00 2001 From: lutovich Date: Thu, 8 Dec 2016 14:37:12 +0100 Subject: [PATCH 5/6] Close connection if init failed Init message is send every time we establish a new socket connection. This is the very first thing connection does. It is possible for the init to fail (for example when credentials are wrong). Previously socket connection was not closed after such init failure. This commit changes `SocketConnector` to close connection on any init failure. --- .../driver/internal/net/SocketConnector.java | 11 +++- .../internal/net/SocketConnectorTest.java | 56 ++++++++++++++++--- 2 files changed, 59 insertions(+), 8 deletions(-) diff --git a/driver/src/main/java/org/neo4j/driver/internal/net/SocketConnector.java b/driver/src/main/java/org/neo4j/driver/internal/net/SocketConnector.java index 15158e88fa..979b1f1018 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/net/SocketConnector.java +++ b/driver/src/main/java/org/neo4j/driver/internal/net/SocketConnector.java @@ -53,7 +53,16 @@ public final Connection connect( BoltServerAddress address ) // to ensure concurrent access leads causes application errors connection = new ConcurrencyGuardingConnection( connection ); - connection.init( connectionSettings.userAgent(), tokenAsMap( connectionSettings.authToken() ) ); + try + { + connection.init( connectionSettings.userAgent(), tokenAsMap( connectionSettings.authToken() ) ); + } + catch ( Throwable initError ) + { + connection.close(); + throw initError; + } + return connection; } diff --git a/driver/src/test/java/org/neo4j/driver/internal/net/SocketConnectorTest.java b/driver/src/test/java/org/neo4j/driver/internal/net/SocketConnectorTest.java index 3a7d335ce0..b97f6c4c39 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/net/SocketConnectorTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/net/SocketConnectorTest.java @@ -34,16 +34,18 @@ import static org.hamcrest.Matchers.instanceOf; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertSame; import static org.junit.Assert.assertThat; import static org.junit.Assert.fail; import static org.mockito.Matchers.any; +import static org.mockito.Matchers.anyString; import static org.mockito.Matchers.eq; import static org.mockito.Mockito.RETURNS_MOCKS; +import static org.mockito.Mockito.doThrow; import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; import static org.neo4j.driver.internal.net.BoltServerAddress.LOCAL_DEFAULT; -import static org.neo4j.driver.internal.security.SecurityPlan.insecure; public class SocketConnectorTest { @@ -51,7 +53,7 @@ public class SocketConnectorTest public void connectCreatesConnection() { ConnectionSettings settings = new ConnectionSettings( basicAuthToken() ); - SocketConnector connector = new TestSocketConnector( settings, insecure(), loggingMock() ); + SocketConnector connector = new RecordingSocketConnector( settings ); Connection connection = connector.connect( LOCAL_DEFAULT ); @@ -64,7 +66,7 @@ public void connectSendsInit() { String userAgent = "agentSmith"; ConnectionSettings settings = new ConnectionSettings( basicAuthToken(), userAgent ); - TestSocketConnector connector = new TestSocketConnector( settings, insecure(), loggingMock() ); + RecordingSocketConnector connector = new RecordingSocketConnector( settings ); connector.connect( LOCAL_DEFAULT ); @@ -77,7 +79,7 @@ public void connectSendsInit() public void connectThrowsForUnknownAuthToken() { ConnectionSettings settings = new ConnectionSettings( mock( AuthToken.class ) ); - TestSocketConnector connector = new TestSocketConnector( settings, insecure(), loggingMock() ); + RecordingSocketConnector connector = new RecordingSocketConnector( settings ); try { @@ -90,6 +92,29 @@ public void connectThrowsForUnknownAuthToken() } } + @Test + @SuppressWarnings( "unchecked" ) + public void connectClosesOpenedConnectionIfInitThrows() + { + Connection connection = mock( Connection.class ); + RuntimeException initError = new RuntimeException( "Init error" ); + doThrow( initError ).when( connection ).init( anyString(), any( Map.class ) ); + + StubSocketConnector connector = new StubSocketConnector( connection ); + + try + { + connector.connect( LOCAL_DEFAULT ); + fail( "Exception expected" ); + } + catch ( Exception e ) + { + assertSame( initError, e ); + } + + verify( connection ).close(); + } + private static Logging loggingMock() { return mock( Logging.class, RETURNS_MOCKS ); @@ -100,13 +125,13 @@ private static AuthToken basicAuthToken() return AuthTokens.basic( "neo4j", "neo4j" ); } - private static class TestSocketConnector extends SocketConnector + private static class RecordingSocketConnector extends SocketConnector { final List createConnections = new CopyOnWriteArrayList<>(); - TestSocketConnector( ConnectionSettings settings, SecurityPlan securityPlan, Logging logging ) + RecordingSocketConnector( ConnectionSettings settings ) { - super( settings, securityPlan, logging ); + super( settings, SecurityPlan.insecure(), loggingMock() ); } @Override @@ -118,4 +143,21 @@ Connection createConnection( BoltServerAddress address, SecurityPlan securityPla return connection; } } + + private static class StubSocketConnector extends SocketConnector + { + final Connection connection; + + StubSocketConnector( Connection connection ) + { + super( new ConnectionSettings( basicAuthToken() ), SecurityPlan.insecure(), loggingMock() ); + this.connection = connection; + } + + @Override + Connection createConnection( BoltServerAddress address, SecurityPlan securityPlan, Logging logging ) + { + return connection; + } + } } From 668815c49075349338a4c4c2b5b9049dc50f48ed Mon Sep 17 00:00:00 2001 From: lutovich Date: Thu, 8 Dec 2016 17:05:04 +0100 Subject: [PATCH 6/6] Close connection pool if driver creation fails Connection pool (`SocketConnectionPool`) is created before drivers and was never closed if driver creation failed. This was especially possible with `RoutingDriver` which tries to build routing table in constructor. This commit adds closing of the connection pool when creation of driver fails. It also extracts driver creation logic into an internal class `DriverFactory` to make it testable. --- .../neo4j/driver/internal/DriverFactory.java | 182 ++++++++++++++++++ .../org/neo4j/driver/v1/GraphDatabase.java | 129 +------------ .../driver/internal/DriverFactoryTest.java | 142 ++++++++++++++ 3 files changed, 327 insertions(+), 126 deletions(-) create mode 100644 driver/src/main/java/org/neo4j/driver/internal/DriverFactory.java create mode 100644 driver/src/test/java/org/neo4j/driver/internal/DriverFactoryTest.java diff --git a/driver/src/main/java/org/neo4j/driver/internal/DriverFactory.java b/driver/src/main/java/org/neo4j/driver/internal/DriverFactory.java new file mode 100644 index 0000000000..022a456947 --- /dev/null +++ b/driver/src/main/java/org/neo4j/driver/internal/DriverFactory.java @@ -0,0 +1,182 @@ +/* + * Copyright (c) 2002-2016 "Neo Technology," + * Network Engine for Objects in Lund AB [http://neotechnology.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.internal; + +import java.io.IOException; +import java.net.URI; +import java.security.GeneralSecurityException; + +import org.neo4j.driver.internal.cluster.RoutingSettings; +import org.neo4j.driver.internal.net.BoltServerAddress; +import org.neo4j.driver.internal.net.SocketConnector; +import org.neo4j.driver.internal.net.pooling.PoolSettings; +import org.neo4j.driver.internal.net.pooling.SocketConnectionPool; +import org.neo4j.driver.internal.security.SecurityPlan; +import org.neo4j.driver.internal.spi.ConnectionPool; +import org.neo4j.driver.internal.spi.Connector; +import org.neo4j.driver.internal.util.Clock; +import org.neo4j.driver.v1.AuthToken; +import org.neo4j.driver.v1.AuthTokens; +import org.neo4j.driver.v1.Config; +import org.neo4j.driver.v1.Driver; +import org.neo4j.driver.v1.Logger; +import org.neo4j.driver.v1.exceptions.ClientException; + +import static java.lang.String.format; +import static org.neo4j.driver.internal.security.SecurityPlan.insecure; +import static org.neo4j.driver.v1.Config.EncryptionLevel.REQUIRED; + +public class DriverFactory +{ + public final Driver newInstance( URI uri, AuthToken authToken, RoutingSettings routingSettings, Config config ) + { + BoltServerAddress address = BoltServerAddress.from( uri ); + SecurityPlan securityPlan = createSecurityPlan( address, config ); + ConnectionPool connectionPool = createConnectionPool( authToken, securityPlan, config ); + + try + { + return createDriver( address, uri.getScheme(), connectionPool, config, routingSettings, securityPlan ); + } + catch ( Throwable driverError ) + { + // we need to close the connection pool if driver creation threw exception + try + { + connectionPool.close(); + } + catch ( Throwable closeError ) + { + driverError.addSuppressed( closeError ); + } + throw driverError; + } + } + + private Driver createDriver( BoltServerAddress address, String scheme, ConnectionPool connectionPool, + Config config, RoutingSettings routingSettings, SecurityPlan securityPlan ) + { + switch ( scheme.toLowerCase() ) + { + case "bolt": + return createDirectDriver( address, connectionPool, config, securityPlan ); + case "bolt+routing": + return createRoutingDriver( address, connectionPool, config, routingSettings, securityPlan ); + default: + throw new ClientException( format( "Unsupported URI scheme: %s", scheme ) ); + } + } + + /** + * Creates new {@link DirectDriver}. + *

+ * This method is package-private only for testing + */ + DirectDriver createDirectDriver( BoltServerAddress address, ConnectionPool connectionPool, Config config, + SecurityPlan securityPlan ) + { + return new DirectDriver( address, connectionPool, securityPlan, config.logging() ); + } + + /** + * Creates new {@link RoutingDriver}. + *

+ * This method is package-private only for testing + */ + RoutingDriver createRoutingDriver( BoltServerAddress address, ConnectionPool connectionPool, + Config config, RoutingSettings routingSettings, SecurityPlan securityPlan ) + { + return new RoutingDriver( routingSettings, address, connectionPool, securityPlan, Clock.SYSTEM, + config.logging() ); + } + + /** + * Creates new {@link ConnectionPool}. + *

+ * This method is package-private only for testing + */ + ConnectionPool createConnectionPool( AuthToken authToken, SecurityPlan securityPlan, Config config ) + { + authToken = authToken == null ? AuthTokens.none() : authToken; + + ConnectionSettings connectionSettings = new ConnectionSettings( authToken ); + PoolSettings poolSettings = new PoolSettings( config.maxIdleConnectionPoolSize() ); + Connector connector = new SocketConnector( connectionSettings, securityPlan, config.logging() ); + + return new SocketConnectionPool( poolSettings, connector, Clock.SYSTEM, config.logging() ); + } + + private static SecurityPlan createSecurityPlan( BoltServerAddress address, Config config ) + { + try + { + return createSecurityPlanImpl( address, config ); + } + catch ( GeneralSecurityException | IOException ex ) + { + throw new ClientException( "Unable to establish SSL parameters", ex ); + } + } + + /* + * Establish a complete SecurityPlan based on the details provided for + * driver construction. + */ + private static SecurityPlan createSecurityPlanImpl( BoltServerAddress address, Config config ) + throws GeneralSecurityException, IOException + { + Config.EncryptionLevel encryptionLevel = config.encryptionLevel(); + boolean requiresEncryption = encryptionLevel.equals( REQUIRED ); + + if ( requiresEncryption ) + { + Logger logger = config.logging().getLog( "session" ); + switch ( config.trustStrategy().strategy() ) + { + + // DEPRECATED CASES // + case TRUST_ON_FIRST_USE: + logger.warn( + "Option `TRUST_ON_FIRST_USE` has been deprecated and will be removed in a future " + + "version of the driver. Please switch to use `TRUST_ALL_CERTIFICATES` instead." ); + return SecurityPlan.forTrustOnFirstUse( config.trustStrategy().certFile(), address, logger ); + case TRUST_SIGNED_CERTIFICATES: + logger.warn( + "Option `TRUST_SIGNED_CERTIFICATE` has been deprecated and will be removed in a future " + + "version of the driver. Please switch to use `TRUST_CUSTOM_CA_SIGNED_CERTIFICATES` instead." ); + // intentional fallthrough + // END OF DEPRECATED CASES // + + case TRUST_CUSTOM_CA_SIGNED_CERTIFICATES: + return SecurityPlan.forCustomCASignedCertificates( config.trustStrategy().certFile() ); + case TRUST_SYSTEM_CA_SIGNED_CERTIFICATES: + return SecurityPlan.forSystemCASignedCertificates(); + case TRUST_ALL_CERTIFICATES: + return SecurityPlan.forAllCertificates(); + default: + throw new ClientException( + "Unknown TLS authentication strategy: " + config.trustStrategy().strategy().name() ); + } + } + else + { + return insecure(); + } + } +} diff --git a/driver/src/main/java/org/neo4j/driver/v1/GraphDatabase.java b/driver/src/main/java/org/neo4j/driver/v1/GraphDatabase.java index 265908fdb3..3ab9e17b8e 100644 --- a/driver/src/main/java/org/neo4j/driver/v1/GraphDatabase.java +++ b/driver/src/main/java/org/neo4j/driver/v1/GraphDatabase.java @@ -18,29 +18,9 @@ */ package org.neo4j.driver.v1; -import java.io.IOException; import java.net.URI; -import java.security.GeneralSecurityException; -import org.neo4j.driver.internal.ConnectionSettings; -import org.neo4j.driver.internal.DirectDriver; -import org.neo4j.driver.internal.NetworkSession; -import org.neo4j.driver.internal.RoutingDriver; -import org.neo4j.driver.internal.net.BoltServerAddress; -import org.neo4j.driver.internal.net.SocketConnector; -import org.neo4j.driver.internal.net.pooling.PoolSettings; -import org.neo4j.driver.internal.net.pooling.SocketConnectionPool; -import org.neo4j.driver.internal.security.SecurityPlan; -import org.neo4j.driver.internal.spi.Connection; -import org.neo4j.driver.internal.spi.ConnectionPool; -import org.neo4j.driver.internal.spi.Connector; -import org.neo4j.driver.internal.util.Clock; -import org.neo4j.driver.v1.exceptions.ClientException; -import org.neo4j.driver.v1.util.Function; - -import static java.lang.String.format; -import static org.neo4j.driver.internal.security.SecurityPlan.insecure; -import static org.neo4j.driver.v1.Config.EncryptionLevel.REQUIRED; +import org.neo4j.driver.internal.DriverFactory; /** * Creates {@link Driver drivers}, optionally letting you {@link #driver(URI, Config)} to configure them. @@ -49,17 +29,6 @@ */ public class GraphDatabase { - - private static final Function - SESSION_PROVIDER = new Function() - { - @Override - public Session apply( Connection connection ) - { - return new NetworkSession( connection ); - } - }; - /** * Return a driver for a Neo4j instance with the default configuration settings * @@ -153,101 +122,9 @@ public static Driver driver( String uri, AuthToken authToken, Config config ) */ public static Driver driver( URI uri, AuthToken authToken, Config config ) { - // Break down the URI into its constituent parts - String scheme = uri.getScheme(); - BoltServerAddress address = BoltServerAddress.from( uri ); - // Make sure we have some configuration to play with - if ( config == null ) - { - config = Config.defaultConfig(); - } - - // Construct security plan - SecurityPlan securityPlan; - try - { - securityPlan = createSecurityPlan( address, config ); - } - catch ( GeneralSecurityException | IOException ex ) - { - throw new ClientException( "Unable to establish SSL parameters", ex ); - } - - ConnectionPool connectionPool = createConnectionPool( authToken, securityPlan, config ); - - switch ( scheme.toLowerCase() ) - { - case "bolt": - return new DirectDriver( address, connectionPool, securityPlan, config.logging() ); - case "bolt+routing": - return new RoutingDriver( - config.routingSettings(), - address, - connectionPool, - securityPlan, - Clock.SYSTEM, - config.logging() ); - default: - throw new ClientException( format( "Unsupported URI scheme: %s", scheme ) ); - } - } - - private static ConnectionPool createConnectionPool( AuthToken authToken, SecurityPlan securityPlan, - Config config ) - { - authToken = authToken == null ? AuthTokens.none() : authToken; - - ConnectionSettings connectionSettings = new ConnectionSettings( authToken ); - PoolSettings poolSettings = new PoolSettings( config.maxIdleConnectionPoolSize() ); - Connector connector = new SocketConnector( connectionSettings, securityPlan, config.logging() ); - - return new SocketConnectionPool( poolSettings, connector, Clock.SYSTEM, config.logging() ); - } - - /* - * Establish a complete SecurityPlan based on the details provided for - * driver construction. - */ - private static SecurityPlan createSecurityPlan( BoltServerAddress address, Config config ) - throws GeneralSecurityException, IOException - { - Config.EncryptionLevel encryptionLevel = config.encryptionLevel(); - boolean requiresEncryption = encryptionLevel.equals( REQUIRED ); - - if ( requiresEncryption ) - { - Logger logger = config.logging().getLog( "session" ); - switch ( config.trustStrategy().strategy() ) - { - - // DEPRECATED CASES // - case TRUST_ON_FIRST_USE: - logger.warn( - "Option `TRUST_ON_FIRST_USE` has been deprecated and will be removed in a future " + - "version of the driver. Please switch to use `TRUST_ALL_CERTIFICATES` instead." ); - return SecurityPlan.forTrustOnFirstUse( config.trustStrategy().certFile(), address, logger ); - case TRUST_SIGNED_CERTIFICATES: - logger.warn( - "Option `TRUST_SIGNED_CERTIFICATE` has been deprecated and will be removed in a future " + - "version of the driver. Please switch to use `TRUST_CUSTOM_CA_SIGNED_CERTIFICATES` instead." ); - // intentional fallthrough - // END OF DEPRECATED CASES // + config = config == null ? Config.defaultConfig() : config; - case TRUST_CUSTOM_CA_SIGNED_CERTIFICATES: - return SecurityPlan.forCustomCASignedCertificates( config.trustStrategy().certFile() ); - case TRUST_SYSTEM_CA_SIGNED_CERTIFICATES: - return SecurityPlan.forSystemCASignedCertificates(); - case TRUST_ALL_CERTIFICATES: - return SecurityPlan.forAllCertificates(); - default: - throw new ClientException( - "Unknown TLS authentication strategy: " + config.trustStrategy().strategy().name() ); - } - } - else - { - return insecure(); - } + return new DriverFactory().newInstance( uri, authToken, config.routingSettings(), config ); } } diff --git a/driver/src/test/java/org/neo4j/driver/internal/DriverFactoryTest.java b/driver/src/test/java/org/neo4j/driver/internal/DriverFactoryTest.java new file mode 100644 index 0000000000..f96b9762f6 --- /dev/null +++ b/driver/src/test/java/org/neo4j/driver/internal/DriverFactoryTest.java @@ -0,0 +1,142 @@ +/* + * Copyright (c) 2002-2016 "Neo Technology," + * Network Engine for Objects in Lund AB [http://neotechnology.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.internal; + +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameter; +import org.junit.runners.Parameterized.Parameters; + +import java.net.URI; +import java.util.Arrays; +import java.util.List; + +import org.neo4j.driver.internal.cluster.RoutingSettings; +import org.neo4j.driver.internal.net.BoltServerAddress; +import org.neo4j.driver.internal.security.SecurityPlan; +import org.neo4j.driver.internal.spi.ConnectionPool; +import org.neo4j.driver.v1.AuthToken; +import org.neo4j.driver.v1.AuthTokens; +import org.neo4j.driver.v1.Config; + +import static org.hamcrest.Matchers.instanceOf; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.fail; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.neo4j.driver.v1.Config.defaultConfig; + +@RunWith( Parameterized.class ) +public class DriverFactoryTest +{ + @Parameter + public URI uri; + + @Parameters( name = "{0}" ) + public static List uris() + { + return Arrays.asList( + URI.create( "bolt://localhost:7687" ), + URI.create( "bolt+routing://localhost:7687" ) + ); + } + + @Test + public void connectionPoolClosedWhenDriverCreationFails() throws Exception + { + ConnectionPool connectionPool = mock( ConnectionPool.class ); + DriverFactory factory = new ThrowingDriverFactory( connectionPool ); + + try + { + factory.newInstance( uri, dummyAuthToken(), dummyRoutingSettings(), defaultConfig() ); + fail( "Exception expected" ); + } + catch ( Exception e ) + { + assertThat( e, instanceOf( UnsupportedOperationException.class ) ); + } + verify( connectionPool ).close(); + } + + @Test + public void connectionPoolCloseExceptionIsSupressedWhenDriverCreationFails() throws Exception + { + ConnectionPool connectionPool = mock( ConnectionPool.class ); + RuntimeException poolCloseError = new RuntimeException( "Pool close error" ); + doThrow( poolCloseError ).when( connectionPool ).close(); + + DriverFactory factory = new ThrowingDriverFactory( connectionPool ); + + try + { + factory.newInstance( uri, dummyAuthToken(), dummyRoutingSettings(), defaultConfig() ); + fail( "Exception expected" ); + } + catch ( Exception e ) + { + assertThat( e, instanceOf( UnsupportedOperationException.class ) ); + assertArrayEquals( new Throwable[]{poolCloseError}, e.getSuppressed() ); + } + verify( connectionPool ).close(); + } + + private static AuthToken dummyAuthToken() + { + return AuthTokens.basic( "neo4j", "neo4j" ); + } + + private static RoutingSettings dummyRoutingSettings() + { + return new RoutingSettings( 42, 42 ); + } + + private static class ThrowingDriverFactory extends DriverFactory + { + final ConnectionPool connectionPool; + + ThrowingDriverFactory( ConnectionPool connectionPool ) + { + this.connectionPool = connectionPool; + } + + @Override + DirectDriver createDirectDriver( BoltServerAddress address, ConnectionPool connectionPool, Config config, + SecurityPlan securityPlan ) + { + throw new UnsupportedOperationException( "Can't create direct driver" ); + } + + @Override + RoutingDriver createRoutingDriver( BoltServerAddress address, ConnectionPool connectionPool, Config config, + RoutingSettings routingSettings, SecurityPlan securityPlan ) + { + throw new UnsupportedOperationException( "Can't create routing driver" ); + } + + @Override + ConnectionPool createConnectionPool( AuthToken authToken, SecurityPlan securityPlan, Config config ) + { + return connectionPool; + } + } +}