Skip to content

Block when negotiating versions #186

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jun 9, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@
import org.neo4j.driver.v1.exceptions.ClientException;

import static java.nio.ByteOrder.BIG_ENDIAN;
import static org.neo4j.driver.internal.connector.socket.SocketUtils.blockingRead;
import static org.neo4j.driver.internal.connector.socket.SocketUtils.blockingWrite;

public class SocketClient
{
Expand Down Expand Up @@ -70,7 +72,6 @@ public void start()
{
logger.debug( "~~ [CONNECT] %s:%d.", host, port );
channel = ChannelFactory.create( host, port, config, logger );

protocol = negotiateProtocol();
reader = protocol.reader();
writer = protocol.writer();
Expand Down Expand Up @@ -170,21 +171,21 @@ private SocketProtocol negotiateProtocol() throws IOException
{
logger.debug( "~~ [HANDSHAKE] [0x6060B017, 1, 0, 0, 0]." );
//Propose protocol versions
ByteBuffer buf = ByteBuffer.allocate( 5 * 4 ).order( BIG_ENDIAN );
ByteBuffer buf = ByteBuffer.allocateDirect( 5 * 4 ).order( BIG_ENDIAN );
buf.putInt( MAGIC_PREAMBLE );
for ( int version : SUPPORTED_VERSIONS )
{
buf.putInt( version );
}
buf.flip();

channel.write( buf );
//Do a blocking write
blockingWrite(channel, buf);

// Read back the servers choice
// Read (blocking) back the servers choice
buf.clear();
buf.limit( 4 );

channel.read( buf );
blockingRead(channel, buf);

// Choose protocol, or fail
buf.flip();
Expand Down Expand Up @@ -223,7 +224,6 @@ public static ByteChannel create( String host, int port, Config config, Logger l
SocketChannel soChannel = SocketChannel.open();
soChannel.setOption( StandardSocketOptions.SO_REUSEADDR, true );
soChannel.setOption( StandardSocketOptions.SO_KEEPALIVE, true );

soChannel.connect( new InetSocketAddress( host, port ) );

ByteChannel channel;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,70 +21,45 @@
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.ByteChannel;
import java.nio.channels.SocketChannel;

import org.neo4j.driver.internal.util.BytePrinter;
import org.neo4j.driver.v1.exceptions.ClientException;

/**
* Wraps a regular socket channel such that read and write will not return until the full buffers given have been sent
* or received, respectively.
* Utility class for common operations.
*/
public class AllOrNothingChannel implements ByteChannel
public final class SocketUtils
{
private final SocketChannel channel;

public AllOrNothingChannel( SocketChannel channel ) throws IOException
private SocketUtils()
{
this.channel = channel;
this.channel.configureBlocking( true );
throw new UnsupportedOperationException( "Do not instantiate" );
}

@Override
public int read( ByteBuffer buf ) throws IOException
public static void blockingRead(ByteChannel channel, ByteBuffer buf) throws IOException
{
int toRead = buf.remaining();
while ( buf.remaining() > 0 )
while(buf.hasRemaining())
{
int read = channel.read( buf );
if ( read == -1 )
if (channel.read( buf ) < 0)
{
throw new ClientException( String.format(
"Connection terminated while receiving data. This can happen due to network " +
"instabilities, or due to restarts of the database. Expected %s bytes, received %s.",
buf.limit(), BytePrinter.hex( buf ) ) );
}
}
return toRead;
}

@Override
public int write( ByteBuffer buf ) throws IOException
public static void blockingWrite(ByteChannel channel, ByteBuffer buf) throws IOException
{
int toWrite = buf.remaining();
while( buf.remaining() > 0 )
while(buf.hasRemaining())
{
int write = channel.write( buf );
if( write == -1 )
if (channel.write( buf ) < 0)
{
throw new ClientException( String.format(
"Connection terminated while sending data. This can happen due to network " +
"instabilities, or due to restarts of the database. Expected %s bytes, wrote %s.",
buf.limit(), BytePrinter.hex( buf ) ) );
}
}
return toWrite;
}

@Override
public boolean isOpen()
{
return channel.isOpen();
}

@Override
public void close() throws IOException
{
channel.close();
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,152 @@
/**
* Copyright (c) 2002-2016 "Neo Technology,"
* Network Engine for Objects in Lund AB [http://neotechnology.com]
*
* This file is part of Neo4j.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.neo4j.driver.internal.connector.socket;

import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.ByteChannel;
import java.util.ArrayList;
import java.util.List;

import org.neo4j.driver.v1.exceptions.ClientException;

import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

public class SocketUtilsTest
{
@Rule
public ExpectedException exception = ExpectedException.none();

@Test
public void shouldReadAllBytes() throws IOException
{
// Given
ByteBuffer buffer = ByteBuffer.allocate( 4 );
ByteAtATimeChannel channel = new ByteAtATimeChannel( new byte[]{0, 1, 2, 3} );

// When
SocketUtils.blockingRead(channel, buffer );
buffer.flip();

// Then
assertThat(buffer.get(), equalTo((byte) 0));
assertThat(buffer.get(), equalTo((byte) 1));
assertThat(buffer.get(), equalTo((byte) 2));
assertThat(buffer.get(), equalTo((byte) 3));
}

@Test
public void shouldFailIfConnectionFailsWhileReading() throws IOException
{
// Given
ByteBuffer buffer = ByteBuffer.allocate( 4 );
ByteChannel channel = mock( ByteChannel.class );
when(channel.read( buffer )).thenReturn( -1 );

//Expect
exception.expect( ClientException.class );

// When
SocketUtils.blockingRead(channel, buffer );
}

@Test
public void shouldWriteAllBytes() throws IOException
{
// Given
ByteBuffer buffer = ByteBuffer.wrap( new byte[]{0, 1, 2, 3});
ByteAtATimeChannel channel = new ByteAtATimeChannel( new byte[0] );

// When
SocketUtils.blockingWrite(channel, buffer );

// Then
assertThat(channel.writtenBytes.get(0), equalTo((byte) 0));
assertThat(channel.writtenBytes.get(1), equalTo((byte) 1));
assertThat(channel.writtenBytes.get(2), equalTo((byte) 2));
assertThat(channel.writtenBytes.get(3), equalTo((byte) 3));
}

@Test
public void shouldFailIfConnectionFailsWhileWriting() throws IOException
{
// Given
ByteBuffer buffer = ByteBuffer.allocate( 4 );
ByteChannel channel = mock( ByteChannel.class );
when(channel.write( buffer )).thenReturn( -1 );

//Expect
exception.expect( ClientException.class );

// When
SocketUtils.blockingWrite(channel, buffer );
}

private static class ByteAtATimeChannel implements ByteChannel
{

private final byte[] bytes;
private int index = 0;
private List<Byte> writtenBytes = new ArrayList<>( );

private ByteAtATimeChannel( byte[] bytes )
{
this.bytes = bytes;
}

@Override
public int read( ByteBuffer dst ) throws IOException
{
if (index >= bytes.length)
{
return -1;
}

dst.put( bytes[index++]);
return 1;
}

@Override
public int write( ByteBuffer src ) throws IOException
{
writtenBytes.add( src.get() );
return 1;
}

@Override
public boolean isOpen()
{
return true;
}

@Override
public void close() throws IOException
{

}
}

}