diff --git a/driver/src/main/java/org/neo4j/driver/internal/reactive/InternalRxSession.java b/driver/src/main/java/org/neo4j/driver/internal/reactive/InternalRxSession.java index 1e5aeb7d91..70f57c96d0 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/reactive/InternalRxSession.java +++ b/driver/src/main/java/org/neo4j/driver/internal/reactive/InternalRxSession.java @@ -37,6 +37,7 @@ import org.neo4j.driver.reactive.RxTransactionWork; import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; public class InternalRxSession extends AbstractRxQueryRunner implements RxSession { private final NetworkSession session; @@ -69,7 +70,8 @@ public Publisher beginTransaction(TransactionConfig config) { return txFuture; }, () -> new IllegalStateException( - "Unexpected condition, begin transaction call has completed successfully with transaction being null")); + "Unexpected condition, begin transaction call has completed successfully with transaction being null"), + (tx) -> Mono.fromDirect(tx.close()).subscribe()); } private Publisher beginTransaction(AccessMode mode, TransactionConfig config) { @@ -86,7 +88,8 @@ private Publisher beginTransaction(AccessMode mode, Trans return txFuture; }, () -> new IllegalStateException( - "Unexpected condition, begin transaction call has completed successfully with transaction being null")); + "Unexpected condition, begin transaction call has completed successfully with transaction being null"), + (tx) -> Mono.fromDirect(tx.close()).subscribe()); } @Override diff --git a/driver/src/main/java/org/neo4j/driver/internal/reactive/RxUtils.java b/driver/src/main/java/org/neo4j/driver/internal/reactive/RxUtils.java index d0048bc198..aa86457bf6 100644 --- a/driver/src/main/java/org/neo4j/driver/internal/reactive/RxUtils.java +++ b/driver/src/main/java/org/neo4j/driver/internal/reactive/RxUtils.java @@ -18,8 +18,11 @@ */ package org.neo4j.driver.internal.reactive; +import static java.util.Objects.requireNonNull; + import java.util.Optional; import java.util.concurrent.CompletionStage; +import java.util.function.Consumer; import java.util.function.Supplier; import org.neo4j.driver.internal.util.Futures; import org.reactivestreams.Publisher; @@ -28,6 +31,7 @@ public class RxUtils { /** * The publisher created by this method will either succeed without publishing anything or fail with an error. + * * @param supplier supplies a {@link CompletionStage}. * @return A publisher that publishes nothing on completion or fails with an error. */ @@ -48,23 +52,79 @@ public static Publisher createEmptyPublisher(Supplier} that MUST produce a non-null result when completed successfully. * @param nullResultThrowableSupplier supplies a {@link Throwable} that is used as an error when the supplied completion stage completes successfully with * null. + * @param cancellationHandler handles cancellation, may be used to release associated resources * @param the type of the item to publish. * @return A publisher that succeeds exactly one item or fails with an error. */ public static Publisher createSingleItemPublisher( - Supplier> supplier, Supplier nullResultThrowableSupplier) { - return Mono.create(sink -> supplier.get().whenComplete((item, completionError) -> { - if (completionError == null) { - if (item != null) { - sink.success(item); - } else { - sink.error(nullResultThrowableSupplier.get()); + Supplier> supplier, + Supplier nullResultThrowableSupplier, + Consumer cancellationHandler) { + requireNonNull(supplier, "supplier must not be null"); + requireNonNull(nullResultThrowableSupplier, "nullResultThrowableSupplier must not be null"); + requireNonNull(cancellationHandler, "cancellationHandler must not be null"); + return Mono.create(sink -> { + SinkState state = new SinkState(); + sink.onRequest(ignored -> { + CompletionStage stage; + synchronized (state) { + if (state.isCancelled()) { + return; + } + if (state.getStage() != null) { + return; + } + stage = supplier.get(); + state.setStage(stage); } - } else { - Throwable error = Optional.ofNullable(Futures.completionExceptionCause(completionError)) - .orElse(completionError); - sink.error(error); - } - })); + stage.whenComplete((item, completionError) -> { + if (completionError == null) { + if (item != null) { + sink.success(item); + } else { + sink.error(nullResultThrowableSupplier.get()); + } + } else { + Throwable error = Optional.ofNullable(Futures.completionExceptionCause(completionError)) + .orElse(completionError); + sink.error(error); + } + }); + }); + sink.onCancel(() -> { + CompletionStage stage; + synchronized (state) { + if (state.isCancelled()) { + return; + } + state.setCancelled(true); + stage = state.getStage(); + } + if (stage != null) { + stage.whenComplete((value, ignored) -> cancellationHandler.accept(value)); + } + }); + }); + } + + private static class SinkState { + private CompletionStage stage; + private boolean cancelled; + + public CompletionStage getStage() { + return stage; + } + + public void setStage(CompletionStage stage) { + this.stage = stage; + } + + public boolean isCancelled() { + return cancelled; + } + + public void setCancelled(boolean cancelled) { + this.cancelled = cancelled; + } } } diff --git a/driver/src/test/java/org/neo4j/driver/internal/reactive/RxUtilsTest.java b/driver/src/test/java/org/neo4j/driver/internal/reactive/RxUtilsTest.java index b0bf9ceae8..69556a7e8b 100644 --- a/driver/src/test/java/org/neo4j/driver/internal/reactive/RxUtilsTest.java +++ b/driver/src/test/java/org/neo4j/driver/internal/reactive/RxUtilsTest.java @@ -18,16 +18,22 @@ */ package org.neo4j.driver.internal.reactive; +import static org.mockito.BDDMockito.then; import static org.mockito.Mockito.mock; import static org.neo4j.driver.internal.reactive.RxUtils.createEmptyPublisher; import static org.neo4j.driver.internal.reactive.RxUtils.createSingleItemPublisher; import static org.neo4j.driver.internal.util.Futures.failedFuture; import java.util.concurrent.CompletableFuture; +import java.util.concurrent.CompletionStage; +import java.util.function.Consumer; import java.util.function.Predicate; +import java.util.function.Supplier; import org.junit.jupiter.api.Test; import org.neo4j.driver.internal.util.Futures; import org.reactivestreams.Publisher; +import org.reactivestreams.Subscription; +import reactor.core.publisher.BaseSubscriber; import reactor.test.StepVerifier; class RxUtilsTest { @@ -47,15 +53,16 @@ void emptyPublisherShouldErrorWhenSupplierErrors() { @Test void singleItemPublisherShouldCompleteWithValue() { - Publisher publisher = - createSingleItemPublisher(() -> CompletableFuture.completedFuture("One"), () -> mock(Throwable.class)); + Publisher publisher = createSingleItemPublisher( + () -> CompletableFuture.completedFuture("One"), () -> mock(Throwable.class), (ignored) -> {}); StepVerifier.create(publisher).expectNext("One").verifyComplete(); } @Test void singleItemPublisherShouldErrorWhenFutureCompletesWithNull() { Throwable error = mock(Throwable.class); - Publisher publisher = createSingleItemPublisher(Futures::completedWithNull, () -> error); + Publisher publisher = + createSingleItemPublisher(Futures::completedWithNull, () -> error, (ignored) -> {}); StepVerifier.create(publisher).verifyErrorMatches(actualError -> error == actualError); } @@ -63,8 +70,41 @@ void singleItemPublisherShouldErrorWhenFutureCompletesWithNull() { @Test void singleItemPublisherShouldErrorWhenSupplierErrors() { RuntimeException error = mock(RuntimeException.class); - Publisher publisher = createSingleItemPublisher(() -> failedFuture(error), () -> mock(Throwable.class)); + Publisher publisher = + createSingleItemPublisher(() -> failedFuture(error), () -> mock(Throwable.class), (ignored) -> {}); StepVerifier.create(publisher).verifyErrorMatches(actualError -> error == actualError); } + + @Test + void singleItemPublisherShouldHandleCancellationAfterRequestProcessingBegins() { + // GIVEN + String value = "value"; + CompletableFuture valueFuture = new CompletableFuture<>(); + CompletableFuture supplierInvokedFuture = new CompletableFuture<>(); + Supplier> valueFutureSupplier = () -> { + supplierInvokedFuture.complete(null); + return valueFuture; + }; + @SuppressWarnings("unchecked") + Consumer cancellationHandler = mock(Consumer.class); + Publisher publisher = + createSingleItemPublisher(valueFutureSupplier, () -> mock(Throwable.class), cancellationHandler); + + // WHEN + publisher.subscribe(new BaseSubscriber() { + @Override + protected void hookOnSubscribe(Subscription subscription) { + subscription.request(1); + supplierInvokedFuture.thenAccept(ignored -> { + subscription.cancel(); + valueFuture.complete(value); + }); + } + }); + + // THEN + valueFuture.join(); + then(cancellationHandler).should().accept(value); + } }