Skip to content

Session acquisition update #232

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Sep 22, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 123 additions & 30 deletions driver/src/main/java/org/neo4j/driver/internal/ClusterDriver.java
Original file line number Diff line number Diff line change
Expand Up @@ -20,12 +20,16 @@

import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.concurrent.atomic.AtomicLong;

import org.neo4j.driver.internal.net.BoltServerAddress;
import org.neo4j.driver.internal.security.SecurityPlan;
import org.neo4j.driver.internal.spi.Connection;
import org.neo4j.driver.internal.spi.ConnectionPool;
import org.neo4j.driver.internal.util.Clock;
import org.neo4j.driver.internal.util.ConcurrentRoundRobinSet;
import org.neo4j.driver.internal.util.Consumer;
import org.neo4j.driver.v1.AccessMode;
Expand All @@ -34,91 +38,133 @@
import org.neo4j.driver.v1.Record;
import org.neo4j.driver.v1.Session;
import org.neo4j.driver.v1.StatementResult;
import org.neo4j.driver.v1.Value;
import org.neo4j.driver.v1.exceptions.ClientException;
import org.neo4j.driver.v1.exceptions.ConnectionFailureException;
import org.neo4j.driver.v1.exceptions.ServiceUnavailableException;
import org.neo4j.driver.v1.util.BiFunction;
import org.neo4j.driver.v1.util.Function;

import static java.lang.String.format;

public class ClusterDriver extends BaseDriver
{
private static final String GET_SERVERS = "dbms.cluster.routing.getServers";
private static final long MAX_TTL = Long.MAX_VALUE / 1000L;
private final static Comparator<BoltServerAddress> COMPARATOR = new Comparator<BoltServerAddress>()
{
@Override
public int compare( BoltServerAddress o1, BoltServerAddress o2 )
{
int compare = o1.host().compareTo( o2.host() );
if (compare == 0)
if ( compare == 0 )
{
compare = Integer.compare( o1.port(), o2.port() );
}

return compare;
}
};
private static final int MIN_SERVERS = 2;
private static final int MIN_SERVERS = 1;
private final ConnectionPool connections;
private final BiFunction<Connection,Logger, Session> sessionProvider;

private final ConcurrentRoundRobinSet<BoltServerAddress> routingServers = new ConcurrentRoundRobinSet<>(COMPARATOR);
private final ConcurrentRoundRobinSet<BoltServerAddress> readServers = new ConcurrentRoundRobinSet<>(COMPARATOR);
private final ConcurrentRoundRobinSet<BoltServerAddress> writeServers = new ConcurrentRoundRobinSet<>(COMPARATOR);
private final BiFunction<Connection,Logger,Session> sessionProvider;
private final Clock clock;
private final ConcurrentRoundRobinSet<BoltServerAddress> routingServers =
new ConcurrentRoundRobinSet<>( COMPARATOR );
private final ConcurrentRoundRobinSet<BoltServerAddress> readServers = new ConcurrentRoundRobinSet<>( COMPARATOR );
private final ConcurrentRoundRobinSet<BoltServerAddress> writeServers = new ConcurrentRoundRobinSet<>( COMPARATOR );
private final AtomicLong expires = new AtomicLong( 0L );

public ClusterDriver( BoltServerAddress seedAddress,
ConnectionPool connections,
SecurityPlan securityPlan,
BiFunction<Connection,Logger, Session> sessionProvider,
BiFunction<Connection,Logger,Session> sessionProvider,
Clock clock,
Logging logging )
{
super( securityPlan, logging );
routingServers.add( seedAddress );
this.connections = connections;
this.sessionProvider = sessionProvider;
this.clock = clock;
checkServers();
}

private void checkServers()
{
synchronized ( routingServers )
{
if ( routingServers.size() < MIN_SERVERS ||
if ( expires.get() < clock.millis() ||
routingServers.size() < MIN_SERVERS ||
readServers.isEmpty() ||
writeServers.isEmpty())
writeServers.isEmpty() )
{
getServers();
}
}
}

private Set<BoltServerAddress> forgetAllServers()
{
final Set<BoltServerAddress> seen = new HashSet<>();
seen.addAll( routingServers );
seen.addAll( readServers );
seen.addAll( writeServers );
routingServers.clear();
readServers.clear();
writeServers.clear();
return seen;
}

private long calculateNewExpiry( Record record )
{
long ttl = record.get( "ttl" ).asLong();
long nextExpiry = clock.millis() + 1000L * ttl;
if ( ttl < 0 || ttl >= MAX_TTL || nextExpiry < 0 )
{
return Long.MAX_VALUE;
}
else
{
return nextExpiry;
}
}

//must be called from a synchronized block
private void getServers()
{
BoltServerAddress address = null;
try
{
boolean success = false;
while ( !routingServers.isEmpty() && !success )

ConcurrentRoundRobinSet<BoltServerAddress> routers = new ConcurrentRoundRobinSet<>( routingServers );
final Set<BoltServerAddress> seen = forgetAllServers();
while ( !routers.isEmpty() && !success )
{
address = routingServers.hop();
address = routers.hop();
success = call( address, GET_SERVERS, new Consumer<Record>()
{
@Override
public void accept( Record record )
{
BoltServerAddress newAddress = new BoltServerAddress( record.get( "address" ).asString() );
switch ( record.get( "mode" ).asString().toUpperCase() )
expires.set( calculateNewExpiry( record ) );
List<ServerInfo> servers = servers( record );
for ( ServerInfo server : servers )
{
case "READ":
readServers.add( newAddress );
break;
case "WRITE":
writeServers.add( newAddress );
break;
case "ROUTE":
routingServers.add( newAddress );
break;
seen.removeAll( server.addresses() );
switch ( server.role() )
{
case "READ":
readServers.addAll( server.addresses() );
break;
case "WRITE":
writeServers.addAll( server.addresses() );
break;
case "ROUTE":
routingServers.addAll( server.addresses() );
break;
}
}
}
} );
Expand All @@ -127,6 +173,12 @@ public void accept( Record record )
{
throw new ServiceUnavailableException( "Run out of servers" );
}

//the server no longer think we should care about these
for ( BoltServerAddress remove : seen )
{
connections.purge( remove );
}
}
catch ( ClientException ex )
{
Expand All @@ -137,7 +189,7 @@ public void accept( Record record )
this.close();
throw new ServiceUnavailableException(
String.format( "Server %s couldn't perform discovery",
address == null ? "`UNKNOWN`" : address.toString()), ex );
address == null ? "`UNKNOWN`" : address.toString() ), ex );
}
else
{
Expand All @@ -146,14 +198,55 @@ public void accept( Record record )
}
}

private static class ServerInfo
{
private final List<BoltServerAddress> addresses;
private final String role;

public ServerInfo( List<BoltServerAddress> addresses, String role )
{
this.addresses = addresses;
this.role = role;
}

public String role()
{
return role;
}

List<BoltServerAddress> addresses()
{
return addresses;
}
}

private List<ServerInfo> servers( Record record )
{
return record.get( "servers" ).asList( new Function<Value,ServerInfo>()
{
@Override
public ServerInfo apply( Value value )
{
return new ServerInfo( value.get( "addresses" ).asList( new Function<Value,BoltServerAddress>()
{
@Override
public BoltServerAddress apply( Value value )
{
return new BoltServerAddress( value.asString() );
}
} ), value.get( "role" ).asString() );
}
} );
}

//must be called from a synchronized method
private boolean call( BoltServerAddress address, String procedureName, Consumer<Record> recorder )
{
Connection acquire = null;
Session session = null;
try
{
acquire = connections.acquire(address);
acquire = connections.acquire( address );
session = sessionProvider.apply( acquire, log );

StatementResult records = session.run( format( "CALL %s", procedureName ) );
Expand Down Expand Up @@ -217,19 +310,19 @@ public void onWriteFailure( BoltServerAddress address )
log );
}

private Connection acquireConnection( AccessMode mode )
private Connection acquireConnection( AccessMode role )
{
//Potentially rediscover servers if we are not happy with our current knowledge
checkServers();

switch ( mode )
switch ( role )
{
case READ:
return connections.acquire( readServers.hop() );
case WRITE:
return connections.acquire( writeServers.hop() );
default:
throw new ClientException( mode + " is not supported for creating new sessions" );
throw new ClientException( role + " is not supported for creating new sessions" );
}
}

Expand All @@ -255,13 +348,13 @@ Set<BoltServerAddress> routingServers()
//For testing
Set<BoltServerAddress> readServers()
{
return Collections.unmodifiableSet(readServers);
return Collections.unmodifiableSet( readServers );
}

//For testing
Set<BoltServerAddress> writeServers()
{
return Collections.unmodifiableSet( writeServers);
return Collections.unmodifiableSet( writeServers );
}

//For testing
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ public StatementResult run( Statement statement )
}
catch ( ClientException e )
{
if ( e.code().equals( "Neo.ClientError.General.ForbiddenOnFollower" ) )
if ( e.code().equals( "Neo.ClientError.Cluster.NotALeader" ) )
{
onError.onWriteFailure( connection.address() );
throw new SessionExpiredException(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -256,6 +256,6 @@ private SessionExpiredException failedWrite()

private boolean isFailedToWrite( ClientException e )
{
return e.code().equals( "Neo.ClientError.General.ForbiddenOnFollower" );
return e.code().equals( "Neo.ClientError.Cluster.NotALeader" );
}
}
Loading