diff --git a/driver/src/main/java/org/neo4j/driver/internal/connector/socket/SocketClient.java b/driver/src/main/java/org/neo4j/driver/internal/connector/socket/SocketClient.java index 5931c6f6d5..cabe03543d 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/connector/socket/SocketClient.java +++ b/driver/src/main/java/org/neo4j/driver/internal/connector/socket/SocketClient.java @@ -35,6 +35,8 @@ import org.neo4j.driver.v1.exceptions.ClientException; import static java.nio.ByteOrder.BIG_ENDIAN; +import static org.neo4j.driver.internal.connector.socket.SocketUtils.blockingRead; +import static org.neo4j.driver.internal.connector.socket.SocketUtils.blockingWrite; public class SocketClient { @@ -70,7 +72,6 @@ public void start() { logger.debug( "~~ [CONNECT] %s:%d.", host, port ); channel = ChannelFactory.create( host, port, config, logger ); - protocol = negotiateProtocol(); reader = protocol.reader(); writer = protocol.writer(); @@ -170,7 +171,7 @@ private SocketProtocol negotiateProtocol() throws IOException { logger.debug( "~~ [HANDSHAKE] [0x6060B017, 1, 0, 0, 0]." ); //Propose protocol versions - ByteBuffer buf = ByteBuffer.allocate( 5 * 4 ).order( BIG_ENDIAN ); + ByteBuffer buf = ByteBuffer.allocateDirect( 5 * 4 ).order( BIG_ENDIAN ); buf.putInt( MAGIC_PREAMBLE ); for ( int version : SUPPORTED_VERSIONS ) { @@ -178,13 +179,13 @@ private SocketProtocol negotiateProtocol() throws IOException } buf.flip(); - channel.write( buf ); + //Do a blocking write + blockingWrite(channel, buf); - // Read back the servers choice + // Read (blocking) back the servers choice buf.clear(); buf.limit( 4 ); - - channel.read( buf ); + blockingRead(channel, buf); // Choose protocol, or fail buf.flip(); @@ -223,7 +224,6 @@ public static ByteChannel create( String host, int port, Config config, Logger l SocketChannel soChannel = SocketChannel.open(); soChannel.setOption( StandardSocketOptions.SO_REUSEADDR, true ); soChannel.setOption( StandardSocketOptions.SO_KEEPALIVE, true ); - soChannel.connect( new InetSocketAddress( host, port ) ); ByteChannel channel; diff --git a/driver/src/main/java/org/neo4j/driver/internal/connector/socket/AllOrNothingChannel.java b/driver/src/main/java/org/neo4j/driver/internal/connector/socket/SocketUtils.java similarity index 61% rename from driver/src/main/java/org/neo4j/driver/internal/connector/socket/AllOrNothingChannel.java rename to driver/src/main/java/org/neo4j/driver/internal/connector/socket/SocketUtils.java index 2b1961d7c3..1f94ac9c3b 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/connector/socket/AllOrNothingChannel.java +++ b/driver/src/main/java/org/neo4j/driver/internal/connector/socket/SocketUtils.java @@ -21,33 +21,25 @@ import java.io.IOException; import java.nio.ByteBuffer; import java.nio.channels.ByteChannel; -import java.nio.channels.SocketChannel; import org.neo4j.driver.internal.util.BytePrinter; import org.neo4j.driver.v1.exceptions.ClientException; /** - * Wraps a regular socket channel such that read and write will not return until the full buffers given have been sent - * or received, respectively. + * Utility class for common operations. */ -public class AllOrNothingChannel implements ByteChannel +public final class SocketUtils { - private final SocketChannel channel; - - public AllOrNothingChannel( SocketChannel channel ) throws IOException + private SocketUtils() { - this.channel = channel; - this.channel.configureBlocking( true ); + throw new UnsupportedOperationException( "Do not instantiate" ); } - @Override - public int read( ByteBuffer buf ) throws IOException + public static void blockingRead(ByteChannel channel, ByteBuffer buf) throws IOException { - int toRead = buf.remaining(); - while ( buf.remaining() > 0 ) + while(buf.hasRemaining()) { - int read = channel.read( buf ); - if ( read == -1 ) + if (channel.read( buf ) < 0) { throw new ClientException( String.format( "Connection terminated while receiving data. This can happen due to network " + @@ -55,17 +47,13 @@ public int read( ByteBuffer buf ) throws IOException buf.limit(), BytePrinter.hex( buf ) ) ); } } - return toRead; } - @Override - public int write( ByteBuffer buf ) throws IOException + public static void blockingWrite(ByteChannel channel, ByteBuffer buf) throws IOException { - int toWrite = buf.remaining(); - while( buf.remaining() > 0 ) + while(buf.hasRemaining()) { - int write = channel.write( buf ); - if( write == -1 ) + if (channel.write( buf ) < 0) { throw new ClientException( String.format( "Connection terminated while sending data. This can happen due to network " + @@ -73,18 +61,5 @@ public int write( ByteBuffer buf ) throws IOException buf.limit(), BytePrinter.hex( buf ) ) ); } } - return toWrite; - } - - @Override - public boolean isOpen() - { - return channel.isOpen(); - } - - @Override - public void close() throws IOException - { - channel.close(); } } diff --git a/driver/src/test/java/org/neo4j/driver/internal/connector/socket/SocketUtilsTest.java b/driver/src/test/java/org/neo4j/driver/internal/connector/socket/SocketUtilsTest.java new file mode 100644 index 0000000000..4f8783eb7b --- /dev/null +++ b/driver/src/test/java/org/neo4j/driver/internal/connector/socket/SocketUtilsTest.java @@ -0,0 +1,152 @@ +/** + * Copyright (c) 2002-2016 "Neo Technology," + * Network Engine for Objects in Lund AB [http://neotechnology.com] + * + * This file is part of Neo4j. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.neo4j.driver.internal.connector.socket; + +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.ExpectedException; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.ByteChannel; +import java.util.ArrayList; +import java.util.List; + +import org.neo4j.driver.v1.exceptions.ClientException; + +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class SocketUtilsTest +{ + @Rule + public ExpectedException exception = ExpectedException.none(); + + @Test + public void shouldReadAllBytes() throws IOException + { + // Given + ByteBuffer buffer = ByteBuffer.allocate( 4 ); + ByteAtATimeChannel channel = new ByteAtATimeChannel( new byte[]{0, 1, 2, 3} ); + + // When + SocketUtils.blockingRead(channel, buffer ); + buffer.flip(); + + // Then + assertThat(buffer.get(), equalTo((byte) 0)); + assertThat(buffer.get(), equalTo((byte) 1)); + assertThat(buffer.get(), equalTo((byte) 2)); + assertThat(buffer.get(), equalTo((byte) 3)); + } + + @Test + public void shouldFailIfConnectionFailsWhileReading() throws IOException + { + // Given + ByteBuffer buffer = ByteBuffer.allocate( 4 ); + ByteChannel channel = mock( ByteChannel.class ); + when(channel.read( buffer )).thenReturn( -1 ); + + //Expect + exception.expect( ClientException.class ); + + // When + SocketUtils.blockingRead(channel, buffer ); + } + + @Test + public void shouldWriteAllBytes() throws IOException + { + // Given + ByteBuffer buffer = ByteBuffer.wrap( new byte[]{0, 1, 2, 3}); + ByteAtATimeChannel channel = new ByteAtATimeChannel( new byte[0] ); + + // When + SocketUtils.blockingWrite(channel, buffer ); + + // Then + assertThat(channel.writtenBytes.get(0), equalTo((byte) 0)); + assertThat(channel.writtenBytes.get(1), equalTo((byte) 1)); + assertThat(channel.writtenBytes.get(2), equalTo((byte) 2)); + assertThat(channel.writtenBytes.get(3), equalTo((byte) 3)); + } + + @Test + public void shouldFailIfConnectionFailsWhileWriting() throws IOException + { + // Given + ByteBuffer buffer = ByteBuffer.allocate( 4 ); + ByteChannel channel = mock( ByteChannel.class ); + when(channel.write( buffer )).thenReturn( -1 ); + + //Expect + exception.expect( ClientException.class ); + + // When + SocketUtils.blockingWrite(channel, buffer ); + } + + private static class ByteAtATimeChannel implements ByteChannel + { + + private final byte[] bytes; + private int index = 0; + private List writtenBytes = new ArrayList<>( ); + + private ByteAtATimeChannel( byte[] bytes ) + { + this.bytes = bytes; + } + + @Override + public int read( ByteBuffer dst ) throws IOException + { + if (index >= bytes.length) + { + return -1; + } + + dst.put( bytes[index++]); + return 1; + } + + @Override + public int write( ByteBuffer src ) throws IOException + { + writtenBytes.add( src.get() ); + return 1; + } + + @Override + public boolean isOpen() + { + return true; + } + + @Override + public void close() throws IOException + { + + } + } + +} \ No newline at end of file