Skip to content

Add support for cancellation on reactive session run #1457

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jul 18, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -122,10 +122,13 @@ public CompletionStage<ResultCursor> runAsync(Query query, TransactionConfig con
.thenApply(cursor -> cursor); // convert the return type
}

public CompletionStage<RxResultCursor> runRx(Query query, TransactionConfig config) {
public CompletionStage<RxResultCursor> runRx(
Query query, TransactionConfig config, CompletionStage<RxResultCursor> cursorPublishStage) {
var newResultCursorStage = buildResultCursorFactory(query, config).thenCompose(ResultCursorFactory::rxResult);

resultCursorStage = newResultCursorStage.exceptionally(error -> null);
resultCursorStage = newResultCursorStage
.thenCompose(cursor -> cursor == null ? CompletableFuture.completedFuture(null) : cursorPublishStage)
.exceptionally(throwable -> null);
return newResultCursorStage;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ public CompletionStage<AsyncResultCursor> asyncResult() {
@Override
public CompletionStage<RxResultCursor> rxResult() {
connection.writeAndFlush(runMessage, runHandler);
return runFuture.handle((ignored, error) -> new RxResultCursorImpl(error, runHandler, pullHandler));
return runFuture.handle(
(ignored, error) -> new RxResultCursorImpl(error, runHandler, pullHandler, connection::release));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,13 @@ public interface RxResultCursor extends Subscription, FailableCursor {
boolean isDone();

Throwable getRunError();

/**
* Rolls back this instance by releasing connection with RESET.
* <p>
* This must never be called on a published instance.
* @return reset completion stage
* @since 5.11
*/
CompletionStage<Void> rollback();
}
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.function.BiConsumer;
import java.util.function.Supplier;
import org.neo4j.driver.Record;
import org.neo4j.driver.exceptions.TransactionNestingException;
import org.neo4j.driver.internal.handlers.RunResponseHandler;
Expand All @@ -41,23 +42,30 @@ public class RxResultCursorImpl implements RxResultCursor {
private final RunResponseHandler runHandler;
private final PullResponseHandler pullHandler;
private final Throwable runResponseError;
private final Supplier<CompletionStage<Void>> connectionReleaseSupplier;
private boolean runErrorSurfaced;
private final CompletableFuture<ResultSummary> summaryFuture = new CompletableFuture<>();
private boolean summaryFutureExposed;
private boolean resultConsumed;
private RecordConsumerStatus consumerStatus = NOT_INSTALLED;

// for testing only
public RxResultCursorImpl(RunResponseHandler runHandler, PullResponseHandler pullHandler) {
this(null, runHandler, pullHandler);
this(null, runHandler, pullHandler, () -> CompletableFuture.completedFuture(null));
}

public RxResultCursorImpl(Throwable runError, RunResponseHandler runHandler, PullResponseHandler pullHandler) {
public RxResultCursorImpl(
Throwable runError,
RunResponseHandler runHandler,
PullResponseHandler pullHandler,
Supplier<CompletionStage<Void>> connectionReleaseSupplier) {
Objects.requireNonNull(runHandler);
Objects.requireNonNull(pullHandler);

this.runResponseError = runError;
this.runHandler = runHandler;
this.pullHandler = pullHandler;
this.connectionReleaseSupplier = connectionReleaseSupplier;
installSummaryConsumer();
}

Expand Down Expand Up @@ -130,6 +138,12 @@ public Throwable getRunError() {
return runResponseError;
}

@Override
public CompletionStage<Void> rollback() {
summaryFuture.complete(null);
return connectionReleaseSupplier.get();
}

public CompletionStage<ResultSummary> summaryStage() {
if (!isDone() && !resultConsumed) // the summary is called before record streaming
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,19 @@

import java.util.Set;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import org.neo4j.driver.AccessMode;
import org.neo4j.driver.Bookmark;
import org.neo4j.driver.Query;
import org.neo4j.driver.TransactionConfig;
import org.neo4j.driver.exceptions.ClientException;
import org.neo4j.driver.exceptions.TransactionNestingException;
import org.neo4j.driver.internal.async.NetworkSession;
import org.neo4j.driver.internal.async.UnmanagedTransaction;
import org.neo4j.driver.internal.cursor.RxResultCursor;
import org.neo4j.driver.internal.util.Futures;
import org.neo4j.driver.reactive.RxResult;
import org.neo4j.driver.reactivestreams.ReactiveResult;
Expand Down Expand Up @@ -142,6 +147,73 @@ public Set<Bookmark> lastBookmarks() {
return session.lastBookmarks();
}

protected <T> Publisher<T> run(Query query, TransactionConfig config, Function<RxResultCursor, T> cursorToResult) {
var cursorPublishFuture = new CompletableFuture<RxResultCursor>();
var cursorReference = new AtomicReference<RxResultCursor>();

return createSingleItemPublisher(
() -> runAsStage(query, config, cursorPublishFuture)
.thenApply(cursor -> {
cursorReference.set(cursor);
return cursor;
})
.thenApply(cursorToResult),
() -> new IllegalStateException(
"Unexpected condition, run call has completed successfully with result being null"),
value -> {
if (value != null) {
cursorReference.get().rollback().whenComplete((unused, throwable) -> {
if (throwable != null) {
cursorPublishFuture.completeExceptionally(throwable);
} else {
cursorPublishFuture.complete(null);
}
});
}
})
.doOnNext(value -> cursorPublishFuture.complete(cursorReference.get()))
.doOnError(cursorPublishFuture::completeExceptionally);
}

private CompletionStage<RxResultCursor> runAsStage(
Query query, TransactionConfig config, CompletionStage<RxResultCursor> finalStage) {
CompletionStage<RxResultCursor> cursorStage;
try {
cursorStage = session.runRx(query, config, finalStage);
} catch (Throwable t) {
cursorStage = Futures.failedFuture(t);
}

return cursorStage
.handle((cursor, throwable) -> {
if (throwable != null) {
return this.<RxResultCursor>releaseConnectionAndRethrow(throwable);
} else {
var runError = cursor.getRunError();
if (runError != null) {
return this.<RxResultCursor>releaseConnectionAndRethrow(runError);
} else {
return CompletableFuture.completedFuture(cursor);
}
}
})
.thenCompose(stage -> stage);
}

private <T> CompletionStage<T> releaseConnectionAndRethrow(Throwable throwable) {
return session.releaseConnectionAsync().handle((ignored, releaseThrowable) -> {
if (releaseThrowable != null) {
throw Futures.combineErrors(throwable, releaseThrowable);
} else {
if (throwable instanceof RuntimeException e) {
throw e;
} else {
throw new CompletionException(throwable);
}
}
});
}

protected <T> Publisher<T> doClose() {
return createEmptyPublisher(session::closeAsync);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,21 +23,17 @@

import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.Flow.Publisher;
import org.neo4j.driver.AccessMode;
import org.neo4j.driver.Bookmark;
import org.neo4j.driver.Query;
import org.neo4j.driver.TransactionConfig;
import org.neo4j.driver.internal.async.NetworkSession;
import org.neo4j.driver.internal.async.UnmanagedTransaction;
import org.neo4j.driver.internal.cursor.RxResultCursor;
import org.neo4j.driver.internal.util.Futures;
import org.neo4j.driver.reactive.ReactiveResult;
import org.neo4j.driver.reactive.ReactiveSession;
import org.neo4j.driver.reactive.ReactiveTransaction;
import org.neo4j.driver.reactive.ReactiveTransactionCallback;
import reactor.core.publisher.Mono;

public class InternalReactiveSession extends AbstractReactiveSession<ReactiveTransaction>
implements ReactiveSession, BaseReactiveQueryRunner {
Expand Down Expand Up @@ -89,30 +85,7 @@ public Publisher<ReactiveResult> run(Query query) {

@Override
public Publisher<ReactiveResult> run(Query query, TransactionConfig config) {
CompletionStage<RxResultCursor> cursorStage;
try {
cursorStage = session.runRx(query, config);
} catch (Throwable t) {
cursorStage = Futures.failedFuture(t);
}

return publisherToFlowPublisher(Mono.fromCompletionStage(cursorStage)
.onErrorResume(error -> Mono.fromCompletionStage(session.releaseConnectionAsync())
.onErrorMap(releaseError -> Futures.combineErrors(error, releaseError))
.then(Mono.error(error)))
.flatMap(cursor -> {
Mono<RxResultCursor> publisher;
var runError = cursor.getRunError();
if (runError != null) {
publisher = Mono.fromCompletionStage(session.releaseConnectionAsync())
.onErrorMap(releaseError -> Futures.combineErrors(runError, releaseError))
.then(Mono.error(runError));
} else {
publisher = Mono.just(cursor);
}
return publisher;
})
.map(InternalReactiveResult::new));
return publisherToFlowPublisher(run(query, config, InternalReactiveResult::new));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ public RxResult run(Query query) {
public RxResult run(Query query, TransactionConfig config) {
return new InternalRxResult(() -> {
var resultCursorFuture = new CompletableFuture<RxResultCursor>();
session.runRx(query, config).whenComplete((cursor, completionError) -> {
session.runRx(query, config, resultCursorFuture).whenComplete((cursor, completionError) -> {
if (cursor != null) {
resultCursorFuture.complete(cursor);
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ public static <T> Publisher<T> createEmptyPublisher(Supplier<CompletionStage<Voi
* @param <T> the type of the item to publish.
* @return A publisher that succeeds exactly one item or fails with an error.
*/
public static <T> Publisher<T> createSingleItemPublisher(
public static <T> Mono<T> createSingleItemPublisher(
Supplier<CompletionStage<T>> supplier,
Supplier<Throwable> nullResultThrowableSupplier,
Consumer<T> cancellationHandler) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,18 @@

import java.util.HashSet;
import java.util.Set;
import java.util.concurrent.CompletionStage;
import org.neo4j.driver.AccessMode;
import org.neo4j.driver.Bookmark;
import org.neo4j.driver.Query;
import org.neo4j.driver.TransactionConfig;
import org.neo4j.driver.internal.async.NetworkSession;
import org.neo4j.driver.internal.async.UnmanagedTransaction;
import org.neo4j.driver.internal.cursor.RxResultCursor;
import org.neo4j.driver.internal.reactive.AbstractReactiveSession;
import org.neo4j.driver.internal.util.Futures;
import org.neo4j.driver.reactivestreams.ReactiveResult;
import org.neo4j.driver.reactivestreams.ReactiveSession;
import org.neo4j.driver.reactivestreams.ReactiveTransaction;
import org.neo4j.driver.reactivestreams.ReactiveTransactionCallback;
import org.reactivestreams.Publisher;
import reactor.core.publisher.Mono;

public class InternalReactiveSession extends AbstractReactiveSession<ReactiveTransaction>
implements ReactiveSession, BaseReactiveQueryRunner {
Expand Down Expand Up @@ -83,30 +79,7 @@ public Publisher<ReactiveResult> run(Query query) {

@Override
public Publisher<ReactiveResult> run(Query query, TransactionConfig config) {
CompletionStage<RxResultCursor> cursorStage;
try {
cursorStage = session.runRx(query, config);
} catch (Throwable t) {
cursorStage = Futures.failedFuture(t);
}

return Mono.fromCompletionStage(cursorStage)
.onErrorResume(error -> Mono.fromCompletionStage(session.releaseConnectionAsync())
.onErrorMap(releaseError -> Futures.combineErrors(error, releaseError))
.then(Mono.error(error)))
.flatMap(cursor -> {
Mono<RxResultCursor> publisher;
var runError = cursor.getRunError();
if (runError != null) {
publisher = Mono.fromCompletionStage(session.releaseConnectionAsync())
.onErrorMap(releaseError -> Futures.combineErrors(runError, releaseError))
.then(Mono.error(runError));
} else {
publisher = Mono.just(cursor);
}
return publisher;
})
.map(InternalReactiveResult::new);
return run(query, config, InternalReactiveResult::new);
}

@Override
Expand Down
Loading