Skip to content

Commit cfa9327

Browse files
committed
Null-safe error code checks in connection layer
1 parent 2d694ff commit cfa9327

File tree

3 files changed

+49
-24
lines changed

3 files changed

+49
-24
lines changed

driver/src/main/java/org/neo4j/driver/internal/cluster/RoutingPooledConnection.java

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
package org.neo4j.driver.internal.cluster;
2020

2121
import java.util.Map;
22+
import java.util.Objects;
2223

2324
import org.neo4j.driver.internal.RoutingErrorHandler;
2425
import org.neo4j.driver.internal.net.BoltServerAddress;
@@ -283,7 +284,8 @@ private RuntimeException handledClientException( ClientException e )
283284

284285
private static boolean isFailureToWrite( ClientException e )
285286
{
286-
return e.code().equals( "Neo.ClientError.Cluster.NotALeader" ) ||
287-
e.code().equals( "Neo.ClientError.General.ForbiddenOnReadOnlyDatabase" );
287+
String errorCode = e.code();
288+
return Objects.equals( errorCode, "Neo.ClientError.Cluster.NotALeader" ) ||
289+
Objects.equals( errorCode, "Neo.ClientError.General.ForbiddenOnReadOnlyDatabase" );
288290
}
289291
}

driver/src/main/java/org/neo4j/driver/internal/net/pooling/PooledSocketConnection.java

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -291,18 +291,31 @@ public long lastUsedTimestamp()
291291
return lastUsedTimestamp;
292292
}
293293

294-
private boolean isProtocolViolationError(RuntimeException e )
294+
private boolean isProtocolViolationError( RuntimeException e )
295295
{
296-
return e instanceof Neo4jException
297-
&& ((Neo4jException) e).code().startsWith( "Neo.ClientError.Request" );
296+
if ( e instanceof Neo4jException )
297+
{
298+
String errorCode = ((Neo4jException) e).code();
299+
if ( errorCode != null )
300+
{
301+
return errorCode.startsWith( "Neo.ClientError.Request" );
302+
}
303+
}
304+
return false;
298305
}
299306

300307
private boolean isClientOrTransientError( RuntimeException e )
301308
{
302309
// Eg: DatabaseErrors and unknown (no status code or not neo4j exception) cause session to be discarded
303-
return e instanceof Neo4jException
304-
&& (((Neo4jException) e).code().contains( "ClientError" )
305-
|| ((Neo4jException) e).code().contains( "TransientError" ));
310+
if ( e instanceof Neo4jException )
311+
{
312+
String errorCode = ((Neo4jException) e).code();
313+
if ( errorCode != null )
314+
{
315+
return errorCode.contains( "ClientError" ) || errorCode.contains( "TransientError" );
316+
}
317+
}
318+
return false;
306319
}
307320

308321
private void updateLastUsedTimestamp()

driver/src/test/java/org/neo4j/driver/internal/cluster/RoutingPooledConnectionErrorHandlingTest.java

Lines changed: 26 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -127,20 +127,13 @@ public void shouldHandleFailureToWrite()
127127
@Test
128128
public void shouldPropagateThrowable()
129129
{
130-
RuntimeException error = new RuntimeException( "Random error" );
131-
Connector connector = newConnectorWithThrowingConnections( error );
132-
RoutingTable routingTable = newRoutingTable( ADDRESS1, ADDRESS2, ADDRESS3 );
133-
ConnectionPool connectionPool = newConnectionPool( connector, ADDRESS1, ADDRESS2, ADDRESS3 );
134-
LoadBalancer loadBalancer = newLoadBalancer( routingTable, connectionPool );
135-
136-
Connection readConnection = loadBalancer.acquireReadConnection();
137-
verifyThrowablePropagation( readConnection, method, routingTable, connectionPool );
138-
139-
Connection writeConnection = loadBalancer.acquireWriteConnection();
140-
verifyThrowablePropagation( writeConnection, method, routingTable, connectionPool );
130+
testThrowablePropagation( new RuntimeException( "Random error" ) );
131+
}
141132

142-
assertThat( routingTable, containsRouter( ADDRESS3 ) );
143-
assertTrue( connectionPool.hasAddress( ADDRESS3 ) );
133+
@Test
134+
public void shouldPropagateClientExceptionWithoutErrorCode()
135+
{
136+
testThrowablePropagation( new ClientException( null, "Message" ) );
144137
}
145138

146139
private void testHandleFailureToWriteWithWriteConnection( ClientException error )
@@ -199,6 +192,23 @@ private void testHandleFailureToWrite( ClientException error )
199192
assertTrue( connectionPool.hasAddress( ADDRESS3 ) );
200193
}
201194

195+
private void testThrowablePropagation( Throwable error )
196+
{
197+
Connector connector = newConnectorWithThrowingConnections( error );
198+
RoutingTable routingTable = newRoutingTable( ADDRESS1, ADDRESS2, ADDRESS3 );
199+
ConnectionPool connectionPool = newConnectionPool( connector, ADDRESS1, ADDRESS2, ADDRESS3 );
200+
LoadBalancer loadBalancer = newLoadBalancer( routingTable, connectionPool );
201+
202+
Connection readConnection = loadBalancer.acquireReadConnection();
203+
verifyThrowablePropagation( readConnection, routingTable, connectionPool, error.getClass() );
204+
205+
Connection writeConnection = loadBalancer.acquireWriteConnection();
206+
verifyThrowablePropagation( writeConnection, routingTable, connectionPool, error.getClass() );
207+
208+
assertThat( routingTable, containsRouter( ADDRESS3 ) );
209+
assertTrue( connectionPool.hasAddress( ADDRESS3 ) );
210+
}
211+
202212
private void verifyServiceUnavailableHandling( Connection connection, RoutingTable routingTable,
203213
ConnectionPool connectionPool )
204214
{
@@ -220,8 +230,8 @@ private void verifyServiceUnavailableHandling( Connection connection, RoutingTab
220230
}
221231
}
222232

223-
private void verifyThrowablePropagation( Connection connection, ConnectionMethod method, RoutingTable routingTable,
224-
ConnectionPool connectionPool )
233+
private <T extends Throwable> void verifyThrowablePropagation( Connection connection, RoutingTable routingTable,
234+
ConnectionPool connectionPool, Class<T> expectedClass )
225235
{
226236
try
227237
{
@@ -230,7 +240,7 @@ private void verifyThrowablePropagation( Connection connection, ConnectionMethod
230240
}
231241
catch ( Exception e )
232242
{
233-
assertThat( e, instanceOf( RuntimeException.class ) );
243+
assertThat( e, instanceOf( expectedClass ) );
234244

235245
BoltServerAddress address = connection.boltServerAddress();
236246
assertThat( routingTable, containsRouter( address ) );

0 commit comments

Comments
 (0)