28
28
import java .io .IOException ;
29
29
import java .util .ArrayList ;
30
30
import java .util .Arrays ;
31
- import java .util .Collections ;
32
31
import java .util .Iterator ;
33
32
import java .util .List ;
34
33
import java .util .concurrent .Future ;
46
45
import org .neo4j .driver .v1 .TransactionWork ;
47
46
import org .neo4j .driver .v1 .Value ;
48
47
import org .neo4j .driver .v1 .exceptions .ClientException ;
48
+ import org .neo4j .driver .v1 .exceptions .DatabaseException ;
49
49
import org .neo4j .driver .v1 .exceptions .ServiceUnavailableException ;
50
50
import org .neo4j .driver .v1 .exceptions .SessionExpiredException ;
51
51
import org .neo4j .driver .v1 .exceptions .TransientException ;
54
54
import org .neo4j .driver .v1 .types .Node ;
55
55
import org .neo4j .driver .v1 .util .TestNeo4j ;
56
56
57
+ import static java .util .Collections .emptyIterator ;
57
58
import static org .hamcrest .Matchers .containsString ;
58
59
import static org .hamcrest .Matchers .greaterThanOrEqualTo ;
59
60
import static org .hamcrest .Matchers .instanceOf ;
@@ -371,11 +372,13 @@ public void shouldRunAsyncTransactionWithoutRetries()
371
372
}
372
373
373
374
@ Test
374
- public void shouldRunAsyncTransactionWithRetries ()
375
+ public void shouldRunAsyncTransactionWithRetriesOnAsyncFailures ()
375
376
{
376
- List <Throwable > failures = Arrays .<Throwable >asList ( new ServiceUnavailableException ( "Oh!" ),
377
- new SessionExpiredException ( "Ah!" ), new TransientException ( "Code" , "Message" ) );
378
- InvocationTrackingWork work = new InvocationTrackingWork ( "CREATE (:Node) RETURN 24" , failures );
377
+ InvocationTrackingWork work = new InvocationTrackingWork ( "CREATE (:Node) RETURN 24" ).withAsyncFailures (
378
+ new ServiceUnavailableException ( "Oh!" ),
379
+ new SessionExpiredException ( "Ah!" ),
380
+ new TransientException ( "Code" , "Message" ) );
381
+
379
382
Response <Record > txResponse = session .writeTransactionAsync ( work );
380
383
381
384
Record record = await ( txResponse );
@@ -386,6 +389,23 @@ public void shouldRunAsyncTransactionWithRetries()
386
389
assertEquals ( 1 , countNodesByLabel ( "Node" ) );
387
390
}
388
391
392
+ @ Test
393
+ public void shouldRunAsyncTransactionWithRetriesOnSyncFailures ()
394
+ {
395
+ InvocationTrackingWork work = new InvocationTrackingWork ( "CREATE (:Test) RETURN 12" ).withSyncFailures (
396
+ new TransientException ( "Oh!" , "Deadlock!" ),
397
+ new ServiceUnavailableException ( "Oh! Network Failure" ) );
398
+
399
+ Response <Record > txResponse = session .writeTransactionAsync ( work );
400
+
401
+ Record record = await ( txResponse );
402
+ assertNotNull ( record );
403
+ assertEquals ( 12L , record .get ( 0 ).asLong () );
404
+
405
+ assertEquals ( 3 , work .invocationCount () );
406
+ assertEquals ( 1 , countNodesByLabel ( "Test" ) );
407
+ }
408
+
389
409
@ Test
390
410
public void shouldRunAsyncTransactionThatCanNotBeRetried ()
391
411
{
@@ -406,6 +426,32 @@ public void shouldRunAsyncTransactionThatCanNotBeRetried()
406
426
assertEquals ( 0 , countNodesByLabel ( "Hi" ) );
407
427
}
408
428
429
+ @ Test
430
+ public void shouldRunAsyncTransactionThatCanNotBeRetriedAfterATransientFailure ()
431
+ {
432
+ // first throw TransientException directly from work, retry can happen afterwards
433
+ // then return a future failed with DatabaseException, retry can't happen afterwards
434
+ InvocationTrackingWork work = new InvocationTrackingWork ( "CREATE (:Person) RETURN 1" )
435
+ .withSyncFailures ( new TransientException ( "Oh!" , "Deadlock!" ) )
436
+ .withAsyncFailures ( new DatabaseException ( "Oh!" , "OutOfMemory!" ) );
437
+ Response <Record > txResponse = session .writeTransactionAsync ( work );
438
+
439
+ try
440
+ {
441
+ await ( txResponse );
442
+ fail ( "Exception expected" );
443
+ }
444
+ catch ( Exception e )
445
+ {
446
+ assertThat ( e , instanceOf ( DatabaseException .class ) );
447
+ assertEquals ( 1 , e .getSuppressed ().length );
448
+ assertThat ( e .getSuppressed ()[0 ], instanceOf ( TransientException .class ) );
449
+ }
450
+
451
+ assertEquals ( 2 , work .invocationCount () );
452
+ assertEquals ( 0 , countNodesByLabel ( "Person" ) );
453
+ }
454
+
409
455
private Future <List <Future <Boolean >>> runNestedQueries ( StatementResultCursor inputCursor )
410
456
{
411
457
Promise <List <Future <Boolean >>> resultPromise = GlobalEventExecutor .INSTANCE .newPromise ();
@@ -524,19 +570,27 @@ void killDb()
524
570
private static class InvocationTrackingWork implements TransactionWork <Response <Record >>
525
571
{
526
572
final String query ;
527
- final Iterator <Throwable > failures ;
528
573
final AtomicInteger invocationCount ;
529
574
575
+ Iterator <RuntimeException > asyncFailures = emptyIterator ();
576
+ Iterator <RuntimeException > syncFailures = emptyIterator ();
577
+
530
578
InvocationTrackingWork ( String query )
531
579
{
532
- this ( query , Collections .<Throwable >emptyList () );
580
+ this .query = query ;
581
+ this .invocationCount = new AtomicInteger ();
582
+ }
583
+
584
+ InvocationTrackingWork withAsyncFailures ( RuntimeException ... failures )
585
+ {
586
+ asyncFailures = Arrays .asList ( failures ).iterator ();
587
+ return this ;
533
588
}
534
589
535
- InvocationTrackingWork ( String query , List < Throwable > failures )
590
+ InvocationTrackingWork withSyncFailures ( RuntimeException ... failures )
536
591
{
537
- this .query = query ;
538
- this .failures = failures .iterator ();
539
- this .invocationCount = new AtomicInteger ();
592
+ syncFailures = Arrays .asList ( failures ).iterator ();
593
+ return this ;
540
594
}
541
595
542
596
int invocationCount ()
@@ -549,6 +603,11 @@ public Response<Record> execute( Transaction tx )
549
603
{
550
604
invocationCount .incrementAndGet ();
551
605
606
+ if ( syncFailures .hasNext () )
607
+ {
608
+ throw syncFailures .next ();
609
+ }
610
+
552
611
final InternalPromise <Record > resultPromise = new InternalPromise <>( GlobalEventExecutor .INSTANCE );
553
612
554
613
tx .runAsync ( query ).addListener ( new ResponseListener <StatementResultCursor >()
@@ -597,9 +656,9 @@ private void processFetchResult( Boolean recordAvailable, Throwable error,
597
656
return ;
598
657
}
599
658
600
- if ( failures .hasNext () )
659
+ if ( asyncFailures .hasNext () )
601
660
{
602
- resultPromise .setFailure ( failures .next () );
661
+ resultPromise .setFailure ( asyncFailures .next () );
603
662
}
604
663
else
605
664
{
0 commit comments