Skip to content

Adapt for Spring Framework Coroutines AOP support #2926

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,16 @@
*/
package org.springframework.data.repository.core.support;

import kotlin.coroutines.Continuation;
import kotlin.reflect.KFunction;
import kotlinx.coroutines.reactive.AwaitKt;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.Collection;
import java.util.stream.Stream;

import kotlin.reflect.KFunction;
import org.reactivestreams.Publisher;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;

import org.springframework.aop.support.AopUtils;
import org.springframework.core.KotlinDetector;
import org.springframework.data.repository.core.support.RepositoryMethodInvocationListener.RepositoryMethodInvocation;
import org.springframework.data.repository.core.support.RepositoryMethodInvocationListener.RepositoryMethodInvocationResult;
Expand Down Expand Up @@ -116,12 +114,7 @@ public static boolean canInvoke(Method declaredMethod, Method baseClassMethod) {
@Nullable
public Object invoke(Class<?> repositoryInterface, RepositoryInvocationMulticaster multicaster, Object[] args)
throws Exception {
return shouldAdaptReactiveToSuspended() ? doInvokeReactiveToSuspended(repositoryInterface, multicaster, args)
: doInvoke(repositoryInterface, multicaster, args);
}

protected boolean shouldAdaptReactiveToSuspended() {
return suspendedDeclaredMethod;
return doInvoke(repositoryInterface, multicaster, args);
}

@Nullable
Expand Down Expand Up @@ -153,41 +146,6 @@ private Object doInvoke(Class<?> repositoryInterface, RepositoryInvocationMultic
}
}

@Nullable
@SuppressWarnings({ "unchecked", "ConstantConditions" })
private Object doInvokeReactiveToSuspended(Class<?> repositoryInterface, RepositoryInvocationMulticaster multicaster,
Object[] args) throws Exception {

/*
* Kotlin suspended functions are invoked with a synthetic Continuation parameter that keeps track of the Coroutine context.
* We're invoking a method without Continuation as we expect the method to return any sort of reactive type,
* therefore we need to strip the Continuation parameter.
*/
Continuation<Object> continuation = (Continuation) args[args.length - 1];
args[args.length - 1] = null;

RepositoryMethodInvocationCaptor invocationResultCaptor = RepositoryMethodInvocationCaptor
.captureInvocationOn(repositoryInterface);
try {

Publisher<?> result = new ReactiveInvocationListenerDecorator().decorate(repositoryInterface, multicaster, args,
invokable.invoke(args));

if (returnsReactiveType) {
return ReactiveWrapperConverters.toWrapper(result, returnedType);
}

if (Collection.class.isAssignableFrom(returnedType)) {
result = (Publisher<?>) collectToList(result);
}

return AwaitKt.awaitFirstOrNull(result, continuation);
} catch (Exception e) {
multicaster.notifyListeners(method, args, computeInvocationResult(invocationResultCaptor.error(e)));
throw e;
}
}

// to avoid NoClassDefFoundError: org/reactivestreams/Publisher when loading this class ¯\_(ツ)_/¯
private static Object collectToList(Object result) {
return Flux.from((Publisher<?>) result).collectList();
Expand Down Expand Up @@ -271,30 +229,26 @@ public RepositoryFragmentMethodInvoker(Method declaredMethod, Object instance, M
public RepositoryFragmentMethodInvoker(CoroutineAdapterInformation adapterInformation, Method declaredMethod,
Object instance, Method baseClassMethod) {
super(declaredMethod, args -> {

if (adapterInformation.isAdapterMethod()) {

/*
* Kotlin suspended functions are invoked with a synthetic Continuation parameter that keeps track of the Coroutine context.
* We're invoking a method without Continuation as we expect the method to return any sort of reactive type,
* therefore we need to strip the Continuation parameter.
*/
Object[] invocationArguments = new Object[args.length - 1];
System.arraycopy(args, 0, invocationArguments, 0, invocationArguments.length);

return baseClassMethod.invoke(instance, invocationArguments);
try {
if(adapterInformation.shouldAdaptReactiveToSuspended()) {
/*
* Kotlin suspended functions are invoked with a synthetic Continuation parameter that keeps track of the Coroutine context.
* We're invoking a method without Continuation as we expect the method to return any sort of reactive type,
* therefore we need to strip the Continuation parameter.
*/
Object[] invocationArguments = new Object[args.length - 1];
System.arraycopy(args, 0, invocationArguments, 0, invocationArguments.length);
return AopUtils.invokeJoinpointUsingReflection(instance, baseClassMethod, invocationArguments);
}
return AopUtils.invokeJoinpointUsingReflection(instance, baseClassMethod, args);
}
catch (Throwable e) {
throw new RuntimeException(e);
}

return baseClassMethod.invoke(instance, args);
});
this.adapterInformation = adapterInformation;
}

@Override
protected boolean shouldAdaptReactiveToSuspended() {
return adapterInformation.shouldAdaptReactiveToSuspended();
}

/**
* Value object capturing whether a suspended Kotlin method (Coroutine method) can be bridged with a native or
* reactive fragment method.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,6 @@
*/
package org.springframework.data.repository.core.support;

import static org.assertj.core.api.Assertions.*;
import static org.mockito.Mockito.*;

import kotlin.coroutines.Continuation;
import kotlin.coroutines.CoroutineContext;
import kotlinx.coroutines.flow.Flow;
import kotlinx.coroutines.flow.FlowKt;
import kotlinx.coroutines.reactor.ReactorContext;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;

import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Iterator;
Expand All @@ -38,6 +26,8 @@
import java.util.function.Consumer;
import java.util.stream.Stream;

import kotlin.coroutines.Continuation;
import kotlinx.coroutines.reactive.ReactiveFlowKt;
import org.assertj.core.api.Assertions;
import org.assertj.core.data.Percentage;
import org.jetbrains.annotations.NotNull;
Expand All @@ -49,6 +39,10 @@
import org.mockito.internal.stubbing.answers.Returns;
import org.mockito.junit.jupiter.MockitoExtension;
import org.reactivestreams.Subscription;
import reactor.core.publisher.Flux;
import reactor.core.publisher.Mono;
import reactor.test.StepVerifier;

import org.springframework.data.repository.CrudRepository;
import org.springframework.data.repository.core.support.CoroutineRepositoryMetadataUnitTests.MyCoroutineRepository;
import org.springframework.data.repository.core.support.RepositoryMethodInvocationListener.RepositoryMethodInvocation;
Expand All @@ -59,6 +53,12 @@
import org.springframework.util.CollectionUtils;
import org.springframework.util.ReflectionUtils;

import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalStateException;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

/**
* @author Christoph Strobl
* @author Johannes Englmeier
Expand Down Expand Up @@ -244,29 +244,12 @@ void capturesReactiveCancellationCorrectly() throws Exception {
@Test // DATACMNS-1764
void capturesKotlinSuspendFunctionsCorrectly() throws Exception {

var result = Flux.just(new TestDummy());
var result = ReactiveFlowKt.asFlow(Flux.just(new TestDummy()));
when(query.execute(any())).thenReturn(result);

Flow<TestDummy> flow = new RepositoryMethodInvokerStub(MyCoroutineRepository.class, multicaster,
Flux<TestDummy> flux = new RepositoryMethodInvokerStub(MyCoroutineRepository.class, multicaster,
"suspendedQueryMethod", query::execute).invoke(mock(Continuation.class));

assertThat(multicaster).isEmpty();

FlowKt.toCollection(flow, new ArrayList<>(), new Continuation<ArrayList<? extends Object>>() {

ReactorContext ctx = new ReactorContext(reactor.util.context.Context.empty());

@NotNull
@Override
public CoroutineContext getContext() {
return ctx;
}

@Override
public void resumeWith(@NotNull Object o) {

}
});
flux.subscribe();

assertThat(multicaster.first().getResult().getState()).isEqualTo(State.SUCCESS);
assertThat(multicaster.first().getResult().getError()).isNull();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package org.springframework.data.repository.kotlin

import io.mockk.mockk
import kotlinx.coroutines.delay
import kotlinx.coroutines.runBlocking
import org.assertj.core.api.Assertions.assertThat
import org.junit.jupiter.api.BeforeEach
Expand Down Expand Up @@ -45,7 +46,7 @@ class CoroutineCrudRepositoryCustomImplementationUnitTests {
}

@Test // DATACMNS-1508
fun shouldInvokeFindAll() {
fun shouldInvokeFindOne() {

val result = runBlocking {
coRepository.findOne("foo")
Expand All @@ -71,6 +72,7 @@ class CoroutineCrudRepositoryCustomImplementationUnitTests {
class MyCustomCoRepositoryImpl : MyCustomCoRepository {

override suspend fun findOne(id: String): User {
delay(1)
return User()
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ import io.mockk.every
import io.mockk.mockk
import io.mockk.verify
import io.reactivex.rxjava3.core.Observable
import io.reactivex.rxjava3.core.Single
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.flowOf
import kotlinx.coroutines.flow.toList
Expand All @@ -28,6 +27,7 @@ import org.assertj.core.api.Assertions.assertThat
import org.junit.jupiter.api.BeforeEach
import org.junit.jupiter.api.Test
import org.mockito.ArgumentCaptor
import org.mockito.ArgumentMatchers.any
import org.mockito.Mockito
import org.reactivestreams.Publisher
import org.springframework.data.repository.core.support.DummyReactiveRepositoryFactory
Expand Down Expand Up @@ -199,7 +199,7 @@ class CoroutineCrudRepositoryUnitTests {

val sample = User()

Mockito.`when`(factory.queryOne.execute(arrayOf("foo", null))).thenReturn(Mono.just(sample))
Mockito.`when`(factory.queryOne.execute(arrayOf("foo", null, any()))).thenReturn(Mono.just(sample))

val result = runBlocking {
coRepository.findOne("foo")
Expand All @@ -215,7 +215,7 @@ class CoroutineCrudRepositoryUnitTests {

val sample = User()

Mockito.`when`(factory.queryOne.execute(arrayOf("foo", null))).thenReturn(Single.just(sample))
Mockito.`when`(factory.queryOne.execute(arrayOf("foo", null, any()))).thenReturn(Mono.just(sample))

val result = runBlocking {
coRepository.findOne("foo")
Expand Down Expand Up @@ -263,7 +263,7 @@ class CoroutineCrudRepositoryUnitTests {

val sample = User()

Mockito.`when`(factory.queryOne.execute(arrayOf("foo", null))).thenReturn(Flux.just(sample), Flux.empty<User>())
Mockito.`when`(factory.queryOne.execute(arrayOf("foo", null, any()))).thenReturn(Flux.just(sample), Flux.empty<User>())

val result = runBlocking {
coRepository.findSuspendedMultiple("foo").toList()
Expand All @@ -283,7 +283,7 @@ class CoroutineCrudRepositoryUnitTests {

val sample = User()

Mockito.`when`(factory.queryOne.execute(arrayOf("foo", null))).thenReturn(Flux.just(sample), Flux.empty<User>())
Mockito.`when`(factory.queryOne.execute(arrayOf("foo", null, any()))).thenReturn(Mono.just(listOf(sample)), Mono.empty<User>())

val result = runBlocking {
coRepository.findSuspendedAsList("foo")
Expand All @@ -295,7 +295,7 @@ class CoroutineCrudRepositoryUnitTests {
coRepository.findSuspendedAsList("foo")
}

assertThat(emptyResult).isEmpty()
assertThat(emptyResult).isNull()
}

interface MyCoRepository : CoroutineCrudRepository<User, String> {
Expand Down