diff --git a/spring-tx/src/main/java/org/springframework/transaction/interceptor/TransactionAspectSupport.java b/spring-tx/src/main/java/org/springframework/transaction/interceptor/TransactionAspectSupport.java index db051b108909..2175a92da189 100644 --- a/spring-tx/src/main/java/org/springframework/transaction/interceptor/TransactionAspectSupport.java +++ b/spring-tx/src/main/java/org/springframework/transaction/interceptor/TransactionAspectSupport.java @@ -834,15 +834,18 @@ public Object invokeWithinTransaction(Method method, @Nullable Class targetCl try { // This is an around advice: Invoke the next interceptor in the chain. // This will normally result in a target object being invoked. - Mono retVal = (Mono) invocation.proceedWithInvocation(); - return retVal - .onErrorResume(ex -> completeTransactionAfterThrowing(it, ex).then(Mono.error(ex))).materialize() - .flatMap(signal -> { - if (signal.isOnComplete() || signal.isOnNext()) { - return commitTransactionAfterReturning(it).thenReturn(signal); - } - return Mono.just(signal); - }).dematerialize(); + // Need re-wrapping of ReactiveTransaction until we get hold of the exception + // through usingWhen. + return Mono.usingWhen(Mono.just(it), s -> { + try { + return (Mono) invocation.proceedWithInvocation(); + } + catch (Throwable throwable) { + return Mono.error(throwable); + } + }, this::commitTransactionAfterReturning, s -> Mono.empty()) + .onErrorResume(ex -> completeTransactionAfterThrowing(it, ex) + .then(Mono.error(ex))); } catch (Throwable ex) { // target invocation exception @@ -860,15 +863,19 @@ public Object invokeWithinTransaction(Method method, @Nullable Class targetCl try { // This is an around advice: Invoke the next interceptor in the chain. // This will normally result in a target object being invoked. - Flux retVal = Flux.from(this.adapter.toPublisher(invocation.proceedWithInvocation())); - return retVal - .onErrorResume(ex -> completeTransactionAfterThrowing(it, ex).then(Mono.error(ex))) - .materialize().flatMap(signal -> { - if (signal.isOnComplete()) { - return commitTransactionAfterReturning(it).materialize(); - } - return Mono.just(signal); - }).dematerialize(); + // Need re-wrapping of ReactiveTransaction until we get hold of the exception + // through usingWhen. + return Flux.usingWhen(Mono.just(it), s -> { + try { + return this.adapter.toPublisher( + invocation.proceedWithInvocation()); + } + catch (Throwable throwable) { + return Mono.error(throwable); + } + }, this::commitTransactionAfterReturning, s -> Mono.empty()) + .onErrorResume(ex -> completeTransactionAfterThrowing(it, ex) + .then(Mono.error(ex))); } catch (Throwable ex) { // target invocation exception diff --git a/spring-tx/src/main/java/org/springframework/transaction/reactive/TransactionContextManager.java b/spring-tx/src/main/java/org/springframework/transaction/reactive/TransactionContextManager.java index 53dd59f23db4..7dc42c5c15e3 100644 --- a/spring-tx/src/main/java/org/springframework/transaction/reactive/TransactionContextManager.java +++ b/spring-tx/src/main/java/org/springframework/transaction/reactive/TransactionContextManager.java @@ -88,7 +88,7 @@ public static Function getOrCreateContext() { return context -> { TransactionContextHolder holder = context.get(TransactionContextHolder.class); if (holder.hasContext()) { - context.put(TransactionContext.class, holder.currentContext()); + return context.put(TransactionContext.class, holder.currentContext()); } return context.put(TransactionContext.class, holder.createContext()); }; diff --git a/spring-tx/src/test/java/org/springframework/transaction/interceptor/AbstractReactiveTransactionAspectTests.java b/spring-tx/src/test/java/org/springframework/transaction/interceptor/AbstractReactiveTransactionAspectTests.java index 4e7213c84b66..507c9d4e941e 100644 --- a/spring-tx/src/test/java/org/springframework/transaction/interceptor/AbstractReactiveTransactionAspectTests.java +++ b/spring-tx/src/test/java/org/springframework/transaction/interceptor/AbstractReactiveTransactionAspectTests.java @@ -33,6 +33,7 @@ import static org.assertj.core.api.Assertions.*; import static org.assertj.core.api.Fail.fail; +import static org.junit.Assert.*; import static org.mockito.Mockito.*; /** @@ -321,6 +322,7 @@ public void cannotCommitTransaction() throws Exception { when(rtm.getReactiveTransaction(txatt)).thenReturn(Mono.just(status)); UnexpectedRollbackException ex = new UnexpectedRollbackException("foobar", null); when(rtm.commit(status)).thenReturn(Mono.error(ex)); + when(rtm.rollback(status)).thenReturn(Mono.empty()); DefaultTestBean tb = new DefaultTestBean(); TestBean itb = (TestBean) advised(tb, rtm, tas); @@ -329,7 +331,10 @@ public void cannotCommitTransaction() throws Exception { Mono.from(itb.setName(name)) .as(StepVerifier::create) - .expectError(UnexpectedRollbackException.class) + .consumeErrorWith(throwable -> { + assertEquals(RuntimeException.class, throwable.getClass()); + assertEquals(ex, throwable.getCause()); + }) .verify(); // Should have invoked target and changed name