Skip to content

Commit 110756a

Browse files
sdeleuzemp911de
authored andcommitted
Adapt for Spring Framework Coroutines AOP support.
This commit adapts Spring Data RepositoryMethodInvoker and related tests in order to remove most of the Coroutines specific code and rely on Spring Framework Coroutines AOP support. Closes #2926
1 parent ccfa93a commit 110756a

File tree

4 files changed

+45
-106
lines changed

4 files changed

+45
-106
lines changed

src/main/java/org/springframework/data/repository/core/support/RepositoryMethodInvoker.java

+18-68
Original file line numberDiff line numberDiff line change
@@ -15,18 +15,16 @@
1515
*/
1616
package org.springframework.data.repository.core.support;
1717

18-
import kotlin.coroutines.Continuation;
1918
import kotlin.reflect.KFunction;
20-
import kotlinx.coroutines.reactive.AwaitKt;
2119
import reactor.core.publisher.Flux;
2220
import reactor.core.publisher.Mono;
2321

2422
import java.lang.reflect.InvocationTargetException;
2523
import java.lang.reflect.Method;
26-
import java.util.Collection;
2724
import java.util.stream.Stream;
2825

2926
import org.reactivestreams.Publisher;
27+
import org.springframework.aop.support.AopUtils;
3028
import org.springframework.core.KotlinDetector;
3129
import org.springframework.data.repository.core.support.RepositoryMethodInvocationListener.RepositoryMethodInvocation;
3230
import org.springframework.data.repository.core.support.RepositoryMethodInvocationListener.RepositoryMethodInvocationResult;
@@ -116,12 +114,7 @@ public static boolean canInvoke(Method declaredMethod, Method baseClassMethod) {
116114
@Nullable
117115
public Object invoke(Class<?> repositoryInterface, RepositoryInvocationMulticaster multicaster, Object[] args)
118116
throws Exception {
119-
return shouldAdaptReactiveToSuspended() ? doInvokeReactiveToSuspended(repositoryInterface, multicaster, args)
120-
: doInvoke(repositoryInterface, multicaster, args);
121-
}
122-
123-
protected boolean shouldAdaptReactiveToSuspended() {
124-
return suspendedDeclaredMethod;
117+
return doInvoke(repositoryInterface, multicaster, args);
125118
}
126119

127120
@Nullable
@@ -153,46 +146,6 @@ private Object doInvoke(Class<?> repositoryInterface, RepositoryInvocationMultic
153146
}
154147
}
155148

156-
@Nullable
157-
@SuppressWarnings({ "unchecked", "ConstantConditions" })
158-
private Object doInvokeReactiveToSuspended(Class<?> repositoryInterface, RepositoryInvocationMulticaster multicaster,
159-
Object[] args) throws Exception {
160-
161-
/*
162-
* Kotlin suspended functions are invoked with a synthetic Continuation parameter that keeps track of the Coroutine context.
163-
* We're invoking a method without Continuation as we expect the method to return any sort of reactive type,
164-
* therefore we need to strip the Continuation parameter.
165-
*/
166-
Continuation<Object> continuation = (Continuation) args[args.length - 1];
167-
args[args.length - 1] = null;
168-
169-
RepositoryMethodInvocationCaptor invocationResultCaptor = RepositoryMethodInvocationCaptor
170-
.captureInvocationOn(repositoryInterface);
171-
try {
172-
173-
Publisher<?> result = new ReactiveInvocationListenerDecorator().decorate(repositoryInterface, multicaster, args,
174-
invokable.invoke(args));
175-
176-
if (returnsReactiveType) {
177-
return ReactiveWrapperConverters.toWrapper(result, returnedType);
178-
}
179-
180-
if (Collection.class.isAssignableFrom(returnedType)) {
181-
result = (Publisher<?>) collectToList(result);
182-
}
183-
184-
return AwaitKt.awaitFirstOrNull(result, continuation);
185-
} catch (Exception e) {
186-
multicaster.notifyListeners(method, args, computeInvocationResult(invocationResultCaptor.error(e)));
187-
throw e;
188-
}
189-
}
190-
191-
// to avoid NoClassDefFoundError: org/reactivestreams/Publisher when loading this class ¯\_(ツ)_/¯
192-
private static Object collectToList(Object result) {
193-
return Flux.from((Publisher<?>) result).collectList();
194-
}
195-
196149
private RepositoryMethodInvocation computeInvocationResult(RepositoryMethodInvocationCaptor captured) {
197150
return new RepositoryMethodInvocation(captured.getRepositoryInterface(), method, captured.getCapturedResult(),
198151
captured.getDuration());
@@ -271,30 +224,27 @@ public RepositoryFragmentMethodInvoker(Method declaredMethod, Object instance, M
271224
public RepositoryFragmentMethodInvoker(CoroutineAdapterInformation adapterInformation, Method declaredMethod,
272225
Object instance, Method baseClassMethod) {
273226
super(declaredMethod, args -> {
274-
275-
if (adapterInformation.isAdapterMethod()) {
276-
277-
/*
278-
* Kotlin suspended functions are invoked with a synthetic Continuation parameter that keeps track of the Coroutine context.
279-
* We're invoking a method without Continuation as we expect the method to return any sort of reactive type,
280-
* therefore we need to strip the Continuation parameter.
281-
*/
282-
Object[] invocationArguments = new Object[args.length - 1];
283-
System.arraycopy(args, 0, invocationArguments, 0, invocationArguments.length);
284-
285-
return baseClassMethod.invoke(instance, invocationArguments);
227+
try {
228+
if (adapterInformation.shouldAdaptReactiveToSuspended()) {
229+
/*
230+
* Kotlin suspended functions are invoked with a synthetic Continuation parameter that keeps track of the Coroutine context.
231+
* We're invoking a method without Continuation as we expect the method to return any sort of reactive type,
232+
* therefore we need to strip the Continuation parameter.
233+
*/
234+
Object[] invocationArguments = new Object[args.length - 1];
235+
System.arraycopy(args, 0, invocationArguments, 0, invocationArguments.length);
236+
return AopUtils.invokeJoinpointUsingReflection(instance, baseClassMethod, invocationArguments);
237+
}
238+
return AopUtils.invokeJoinpointUsingReflection(instance, baseClassMethod, args);
239+
} catch (RuntimeException e) {
240+
throw e;
241+
} catch (Throwable e) {
242+
throw new RuntimeException(e);
286243
}
287-
288-
return baseClassMethod.invoke(instance, args);
289244
});
290245
this.adapterInformation = adapterInformation;
291246
}
292247

293-
@Override
294-
protected boolean shouldAdaptReactiveToSuspended() {
295-
return adapterInformation.shouldAdaptReactiveToSuspended();
296-
}
297-
298248
/**
299249
* Value object capturing whether a suspended Kotlin method (Coroutine method) can be bridged with a native or
300250
* reactive fragment method.

src/test/java/org/springframework/data/repository/core/support/RepositoryMethodInvokerUnitTests.java

+15-32
Original file line numberDiff line numberDiff line change
@@ -15,18 +15,6 @@
1515
*/
1616
package org.springframework.data.repository.core.support;
1717

18-
import static org.assertj.core.api.Assertions.*;
19-
import static org.mockito.Mockito.*;
20-
21-
import kotlin.coroutines.Continuation;
22-
import kotlin.coroutines.CoroutineContext;
23-
import kotlinx.coroutines.flow.Flow;
24-
import kotlinx.coroutines.flow.FlowKt;
25-
import kotlinx.coroutines.reactor.ReactorContext;
26-
import reactor.core.publisher.Flux;
27-
import reactor.core.publisher.Mono;
28-
import reactor.test.StepVerifier;
29-
3018
import java.lang.reflect.Method;
3119
import java.util.ArrayList;
3220
import java.util.Iterator;
@@ -38,6 +26,8 @@
3826
import java.util.function.Consumer;
3927
import java.util.stream.Stream;
4028

29+
import kotlin.coroutines.Continuation;
30+
import kotlinx.coroutines.reactive.ReactiveFlowKt;
4131
import org.assertj.core.api.Assertions;
4232
import org.assertj.core.data.Percentage;
4333
import org.jetbrains.annotations.NotNull;
@@ -49,6 +39,10 @@
4939
import org.mockito.internal.stubbing.answers.Returns;
5040
import org.mockito.junit.jupiter.MockitoExtension;
5141
import org.reactivestreams.Subscription;
42+
import reactor.core.publisher.Flux;
43+
import reactor.core.publisher.Mono;
44+
import reactor.test.StepVerifier;
45+
5246
import org.springframework.data.repository.CrudRepository;
5347
import org.springframework.data.repository.core.support.CoroutineRepositoryMetadataUnitTests.MyCoroutineRepository;
5448
import org.springframework.data.repository.core.support.RepositoryMethodInvocationListener.RepositoryMethodInvocation;
@@ -59,6 +53,12 @@
5953
import org.springframework.util.CollectionUtils;
6054
import org.springframework.util.ReflectionUtils;
6155

56+
import static org.assertj.core.api.Assertions.assertThat;
57+
import static org.assertj.core.api.Assertions.assertThatIllegalStateException;
58+
import static org.mockito.Mockito.any;
59+
import static org.mockito.Mockito.mock;
60+
import static org.mockito.Mockito.when;
61+
6262
/**
6363
* @author Christoph Strobl
6464
* @author Johannes Englmeier
@@ -244,29 +244,12 @@ void capturesReactiveCancellationCorrectly() throws Exception {
244244
@Test // DATACMNS-1764
245245
void capturesKotlinSuspendFunctionsCorrectly() throws Exception {
246246

247-
var result = Flux.just(new TestDummy());
247+
var result = ReactiveFlowKt.asFlow(Flux.just(new TestDummy()));
248248
when(query.execute(any())).thenReturn(result);
249249

250-
Flow<TestDummy> flow = new RepositoryMethodInvokerStub(MyCoroutineRepository.class, multicaster,
250+
Flux<TestDummy> flux = new RepositoryMethodInvokerStub(MyCoroutineRepository.class, multicaster,
251251
"suspendedQueryMethod", query::execute).invoke(mock(Continuation.class));
252-
253-
assertThat(multicaster).isEmpty();
254-
255-
FlowKt.toCollection(flow, new ArrayList<>(), new Continuation<ArrayList<? extends Object>>() {
256-
257-
ReactorContext ctx = new ReactorContext(reactor.util.context.Context.empty());
258-
259-
@NotNull
260-
@Override
261-
public CoroutineContext getContext() {
262-
return ctx;
263-
}
264-
265-
@Override
266-
public void resumeWith(@NotNull Object o) {
267-
268-
}
269-
});
252+
flux.subscribe();
270253

271254
assertThat(multicaster.first().getResult().getState()).isEqualTo(State.SUCCESS);
272255
assertThat(multicaster.first().getResult().getError()).isNull();

src/test/kotlin/org/springframework/data/repository/kotlin/CoroutineCrudRepositoryCustomImplementationUnitTests.kt

+6
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ import org.springframework.data.repository.core.RepositoryMetadata
2525
import org.springframework.data.repository.core.support.DummyReactiveRepositoryFactory
2626
import org.springframework.data.repository.core.support.RepositoryComposition
2727
import org.springframework.data.repository.core.support.RepositoryFragment
28+
import org.springframework.data.repository.core.support.RepositoryMethodInvocationListener
2829
import org.springframework.data.repository.reactive.ReactiveCrudRepository
2930
import org.springframework.data.repository.sample.User
3031

@@ -42,7 +43,12 @@ class CoroutineCrudRepositoryCustomImplementationUnitTests {
4243
@BeforeEach
4344
fun before() {
4445
factory = CustomDummyReactiveRepositoryFactory(backingRepository)
46+
factory.addInvocationListener(RepositoryMethodInvocationListener {
47+
repositoryMethodInvocation ->
48+
println(repositoryMethodInvocation)
49+
})
4550
coRepository = factory.getRepository(MyCoRepository::class.java)
51+
4652
}
4753

4854
@Test // DATACMNS-1508

src/test/kotlin/org/springframework/data/repository/kotlin/CoroutineCrudRepositoryUnitTests.kt

+6-6
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,6 @@ import io.mockk.every
1919
import io.mockk.mockk
2020
import io.mockk.verify
2121
import io.reactivex.rxjava3.core.Observable
22-
import io.reactivex.rxjava3.core.Single
2322
import kotlinx.coroutines.flow.Flow
2423
import kotlinx.coroutines.flow.flowOf
2524
import kotlinx.coroutines.flow.toList
@@ -28,6 +27,7 @@ import org.assertj.core.api.Assertions.assertThat
2827
import org.junit.jupiter.api.BeforeEach
2928
import org.junit.jupiter.api.Test
3029
import org.mockito.ArgumentCaptor
30+
import org.mockito.ArgumentMatchers.any
3131
import org.mockito.Mockito
3232
import org.reactivestreams.Publisher
3333
import org.springframework.data.repository.core.support.DummyReactiveRepositoryFactory
@@ -199,7 +199,7 @@ class CoroutineCrudRepositoryUnitTests {
199199

200200
val sample = User()
201201

202-
Mockito.`when`(factory.queryOne.execute(arrayOf("foo", null))).thenReturn(Mono.just(sample))
202+
Mockito.`when`(factory.queryOne.execute(arrayOf("foo", null, any()))).thenReturn(Mono.just(sample))
203203

204204
val result = runBlocking {
205205
coRepository.findOne("foo")
@@ -215,7 +215,7 @@ class CoroutineCrudRepositoryUnitTests {
215215

216216
val sample = User()
217217

218-
Mockito.`when`(factory.queryOne.execute(arrayOf("foo", null))).thenReturn(Single.just(sample))
218+
Mockito.`when`(factory.queryOne.execute(arrayOf("foo", null, any()))).thenReturn(Mono.just(sample))
219219

220220
val result = runBlocking {
221221
coRepository.findOne("foo")
@@ -263,7 +263,7 @@ class CoroutineCrudRepositoryUnitTests {
263263

264264
val sample = User()
265265

266-
Mockito.`when`(factory.queryOne.execute(arrayOf("foo", null))).thenReturn(Flux.just(sample), Flux.empty<User>())
266+
Mockito.`when`(factory.queryOne.execute(arrayOf("foo", null, any()))).thenReturn(Flux.just(sample), Flux.empty<User>())
267267

268268
val result = runBlocking {
269269
coRepository.findSuspendedMultiple("foo").toList()
@@ -283,7 +283,7 @@ class CoroutineCrudRepositoryUnitTests {
283283

284284
val sample = User()
285285

286-
Mockito.`when`(factory.queryOne.execute(arrayOf("foo", null))).thenReturn(Flux.just(sample), Flux.empty<User>())
286+
Mockito.`when`(factory.queryOne.execute(arrayOf("foo", null, any()))).thenReturn(Mono.just(listOf(sample)), Mono.empty<User>())
287287

288288
val result = runBlocking {
289289
coRepository.findSuspendedAsList("foo")
@@ -295,7 +295,7 @@ class CoroutineCrudRepositoryUnitTests {
295295
coRepository.findSuspendedAsList("foo")
296296
}
297297

298-
assertThat(emptyResult).isEmpty()
298+
assertThat(emptyResult).isNull()
299299
}
300300

301301
interface MyCoRepository : CoroutineCrudRepository<User, String> {

0 commit comments

Comments
 (0)