diff --git a/src/main/java/org/springframework/data/repository/core/support/RepositoryMethodInvoker.java b/src/main/java/org/springframework/data/repository/core/support/RepositoryMethodInvoker.java index 8e4f283f4d..a76b498935 100644 --- a/src/main/java/org/springframework/data/repository/core/support/RepositoryMethodInvoker.java +++ b/src/main/java/org/springframework/data/repository/core/support/RepositoryMethodInvoker.java @@ -15,18 +15,16 @@ */ package org.springframework.data.repository.core.support; -import kotlin.coroutines.Continuation; -import kotlin.reflect.KFunction; -import kotlinx.coroutines.reactive.AwaitKt; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Method; -import java.util.Collection; import java.util.stream.Stream; +import kotlin.reflect.KFunction; import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; + +import org.springframework.aop.support.AopUtils; import org.springframework.core.KotlinDetector; import org.springframework.data.repository.core.support.RepositoryMethodInvocationListener.RepositoryMethodInvocation; import org.springframework.data.repository.core.support.RepositoryMethodInvocationListener.RepositoryMethodInvocationResult; @@ -116,12 +114,7 @@ public static boolean canInvoke(Method declaredMethod, Method baseClassMethod) { @Nullable public Object invoke(Class repositoryInterface, RepositoryInvocationMulticaster multicaster, Object[] args) throws Exception { - return shouldAdaptReactiveToSuspended() ? doInvokeReactiveToSuspended(repositoryInterface, multicaster, args) - : doInvoke(repositoryInterface, multicaster, args); - } - - protected boolean shouldAdaptReactiveToSuspended() { - return suspendedDeclaredMethod; + return doInvoke(repositoryInterface, multicaster, args); } @Nullable @@ -153,41 +146,6 @@ private Object doInvoke(Class repositoryInterface, RepositoryInvocationMultic } } - @Nullable - @SuppressWarnings({ "unchecked", "ConstantConditions" }) - private Object doInvokeReactiveToSuspended(Class repositoryInterface, RepositoryInvocationMulticaster multicaster, - Object[] args) throws Exception { - - /* - * Kotlin suspended functions are invoked with a synthetic Continuation parameter that keeps track of the Coroutine context. - * We're invoking a method without Continuation as we expect the method to return any sort of reactive type, - * therefore we need to strip the Continuation parameter. - */ - Continuation continuation = (Continuation) args[args.length - 1]; - args[args.length - 1] = null; - - RepositoryMethodInvocationCaptor invocationResultCaptor = RepositoryMethodInvocationCaptor - .captureInvocationOn(repositoryInterface); - try { - - Publisher result = new ReactiveInvocationListenerDecorator().decorate(repositoryInterface, multicaster, args, - invokable.invoke(args)); - - if (returnsReactiveType) { - return ReactiveWrapperConverters.toWrapper(result, returnedType); - } - - if (Collection.class.isAssignableFrom(returnedType)) { - result = (Publisher) collectToList(result); - } - - return AwaitKt.awaitFirstOrNull(result, continuation); - } catch (Exception e) { - multicaster.notifyListeners(method, args, computeInvocationResult(invocationResultCaptor.error(e))); - throw e; - } - } - // to avoid NoClassDefFoundError: org/reactivestreams/Publisher when loading this class ¯\_(ツ)_/¯ private static Object collectToList(Object result) { return Flux.from((Publisher) result).collectList(); @@ -271,30 +229,26 @@ public RepositoryFragmentMethodInvoker(Method declaredMethod, Object instance, M public RepositoryFragmentMethodInvoker(CoroutineAdapterInformation adapterInformation, Method declaredMethod, Object instance, Method baseClassMethod) { super(declaredMethod, args -> { - - if (adapterInformation.isAdapterMethod()) { - - /* - * Kotlin suspended functions are invoked with a synthetic Continuation parameter that keeps track of the Coroutine context. - * We're invoking a method without Continuation as we expect the method to return any sort of reactive type, - * therefore we need to strip the Continuation parameter. - */ - Object[] invocationArguments = new Object[args.length - 1]; - System.arraycopy(args, 0, invocationArguments, 0, invocationArguments.length); - - return baseClassMethod.invoke(instance, invocationArguments); + try { + if(adapterInformation.shouldAdaptReactiveToSuspended()) { + /* + * Kotlin suspended functions are invoked with a synthetic Continuation parameter that keeps track of the Coroutine context. + * We're invoking a method without Continuation as we expect the method to return any sort of reactive type, + * therefore we need to strip the Continuation parameter. + */ + Object[] invocationArguments = new Object[args.length - 1]; + System.arraycopy(args, 0, invocationArguments, 0, invocationArguments.length); + return AopUtils.invokeJoinpointUsingReflection(instance, baseClassMethod, invocationArguments); + } + return AopUtils.invokeJoinpointUsingReflection(instance, baseClassMethod, args); + } + catch (Throwable e) { + throw new RuntimeException(e); } - - return baseClassMethod.invoke(instance, args); }); this.adapterInformation = adapterInformation; } - @Override - protected boolean shouldAdaptReactiveToSuspended() { - return adapterInformation.shouldAdaptReactiveToSuspended(); - } - /** * Value object capturing whether a suspended Kotlin method (Coroutine method) can be bridged with a native or * reactive fragment method. diff --git a/src/test/java/org/springframework/data/repository/core/support/RepositoryMethodInvokerUnitTests.java b/src/test/java/org/springframework/data/repository/core/support/RepositoryMethodInvokerUnitTests.java index 169de59bc2..5a185737f9 100644 --- a/src/test/java/org/springframework/data/repository/core/support/RepositoryMethodInvokerUnitTests.java +++ b/src/test/java/org/springframework/data/repository/core/support/RepositoryMethodInvokerUnitTests.java @@ -15,18 +15,6 @@ */ package org.springframework.data.repository.core.support; -import static org.assertj.core.api.Assertions.*; -import static org.mockito.Mockito.*; - -import kotlin.coroutines.Continuation; -import kotlin.coroutines.CoroutineContext; -import kotlinx.coroutines.flow.Flow; -import kotlinx.coroutines.flow.FlowKt; -import kotlinx.coroutines.reactor.ReactorContext; -import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; -import reactor.test.StepVerifier; - import java.lang.reflect.Method; import java.util.ArrayList; import java.util.Iterator; @@ -38,6 +26,8 @@ import java.util.function.Consumer; import java.util.stream.Stream; +import kotlin.coroutines.Continuation; +import kotlinx.coroutines.reactive.ReactiveFlowKt; import org.assertj.core.api.Assertions; import org.assertj.core.data.Percentage; import org.jetbrains.annotations.NotNull; @@ -49,6 +39,10 @@ import org.mockito.internal.stubbing.answers.Returns; import org.mockito.junit.jupiter.MockitoExtension; import org.reactivestreams.Subscription; +import reactor.core.publisher.Flux; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; + import org.springframework.data.repository.CrudRepository; import org.springframework.data.repository.core.support.CoroutineRepositoryMetadataUnitTests.MyCoroutineRepository; import org.springframework.data.repository.core.support.RepositoryMethodInvocationListener.RepositoryMethodInvocation; @@ -59,6 +53,12 @@ import org.springframework.util.CollectionUtils; import org.springframework.util.ReflectionUtils; +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatIllegalStateException; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + /** * @author Christoph Strobl * @author Johannes Englmeier @@ -244,29 +244,12 @@ void capturesReactiveCancellationCorrectly() throws Exception { @Test // DATACMNS-1764 void capturesKotlinSuspendFunctionsCorrectly() throws Exception { - var result = Flux.just(new TestDummy()); + var result = ReactiveFlowKt.asFlow(Flux.just(new TestDummy())); when(query.execute(any())).thenReturn(result); - Flow flow = new RepositoryMethodInvokerStub(MyCoroutineRepository.class, multicaster, + Flux flux = new RepositoryMethodInvokerStub(MyCoroutineRepository.class, multicaster, "suspendedQueryMethod", query::execute).invoke(mock(Continuation.class)); - - assertThat(multicaster).isEmpty(); - - FlowKt.toCollection(flow, new ArrayList<>(), new Continuation>() { - - ReactorContext ctx = new ReactorContext(reactor.util.context.Context.empty()); - - @NotNull - @Override - public CoroutineContext getContext() { - return ctx; - } - - @Override - public void resumeWith(@NotNull Object o) { - - } - }); + flux.subscribe(); assertThat(multicaster.first().getResult().getState()).isEqualTo(State.SUCCESS); assertThat(multicaster.first().getResult().getError()).isNull(); diff --git a/src/test/kotlin/org/springframework/data/repository/kotlin/CoroutineCrudRepositoryCustomImplementationUnitTests.kt b/src/test/kotlin/org/springframework/data/repository/kotlin/CoroutineCrudRepositoryCustomImplementationUnitTests.kt index e0758f026c..fe756e5b0f 100644 --- a/src/test/kotlin/org/springframework/data/repository/kotlin/CoroutineCrudRepositoryCustomImplementationUnitTests.kt +++ b/src/test/kotlin/org/springframework/data/repository/kotlin/CoroutineCrudRepositoryCustomImplementationUnitTests.kt @@ -16,6 +16,7 @@ package org.springframework.data.repository.kotlin import io.mockk.mockk +import kotlinx.coroutines.delay import kotlinx.coroutines.runBlocking import org.assertj.core.api.Assertions.assertThat import org.junit.jupiter.api.BeforeEach @@ -45,7 +46,7 @@ class CoroutineCrudRepositoryCustomImplementationUnitTests { } @Test // DATACMNS-1508 - fun shouldInvokeFindAll() { + fun shouldInvokeFindOne() { val result = runBlocking { coRepository.findOne("foo") @@ -71,6 +72,7 @@ class CoroutineCrudRepositoryCustomImplementationUnitTests { class MyCustomCoRepositoryImpl : MyCustomCoRepository { override suspend fun findOne(id: String): User { + delay(1) return User() } } diff --git a/src/test/kotlin/org/springframework/data/repository/kotlin/CoroutineCrudRepositoryUnitTests.kt b/src/test/kotlin/org/springframework/data/repository/kotlin/CoroutineCrudRepositoryUnitTests.kt index af0d70628d..87bbb62cce 100644 --- a/src/test/kotlin/org/springframework/data/repository/kotlin/CoroutineCrudRepositoryUnitTests.kt +++ b/src/test/kotlin/org/springframework/data/repository/kotlin/CoroutineCrudRepositoryUnitTests.kt @@ -19,7 +19,6 @@ import io.mockk.every import io.mockk.mockk import io.mockk.verify import io.reactivex.rxjava3.core.Observable -import io.reactivex.rxjava3.core.Single import kotlinx.coroutines.flow.Flow import kotlinx.coroutines.flow.flowOf import kotlinx.coroutines.flow.toList @@ -28,6 +27,7 @@ import org.assertj.core.api.Assertions.assertThat import org.junit.jupiter.api.BeforeEach import org.junit.jupiter.api.Test import org.mockito.ArgumentCaptor +import org.mockito.ArgumentMatchers.any import org.mockito.Mockito import org.reactivestreams.Publisher import org.springframework.data.repository.core.support.DummyReactiveRepositoryFactory @@ -199,7 +199,7 @@ class CoroutineCrudRepositoryUnitTests { val sample = User() - Mockito.`when`(factory.queryOne.execute(arrayOf("foo", null))).thenReturn(Mono.just(sample)) + Mockito.`when`(factory.queryOne.execute(arrayOf("foo", null, any()))).thenReturn(Mono.just(sample)) val result = runBlocking { coRepository.findOne("foo") @@ -215,7 +215,7 @@ class CoroutineCrudRepositoryUnitTests { val sample = User() - Mockito.`when`(factory.queryOne.execute(arrayOf("foo", null))).thenReturn(Single.just(sample)) + Mockito.`when`(factory.queryOne.execute(arrayOf("foo", null, any()))).thenReturn(Mono.just(sample)) val result = runBlocking { coRepository.findOne("foo") @@ -263,7 +263,7 @@ class CoroutineCrudRepositoryUnitTests { val sample = User() - Mockito.`when`(factory.queryOne.execute(arrayOf("foo", null))).thenReturn(Flux.just(sample), Flux.empty()) + Mockito.`when`(factory.queryOne.execute(arrayOf("foo", null, any()))).thenReturn(Flux.just(sample), Flux.empty()) val result = runBlocking { coRepository.findSuspendedMultiple("foo").toList() @@ -283,7 +283,7 @@ class CoroutineCrudRepositoryUnitTests { val sample = User() - Mockito.`when`(factory.queryOne.execute(arrayOf("foo", null))).thenReturn(Flux.just(sample), Flux.empty()) + Mockito.`when`(factory.queryOne.execute(arrayOf("foo", null, any()))).thenReturn(Mono.just(listOf(sample)), Mono.empty()) val result = runBlocking { coRepository.findSuspendedAsList("foo") @@ -295,7 +295,7 @@ class CoroutineCrudRepositoryUnitTests { coRepository.findSuspendedAsList("foo") } - assertThat(emptyResult).isEmpty() + assertThat(emptyResult).isNull() } interface MyCoRepository : CoroutineCrudRepository {