Skip to content

Commit d5cd46c

Browse files
committed
Polishing.
Add reactive type translation to Coroutine methods to return the expected type for AOP processing. See spring-projects#2926
1 parent 110756a commit d5cd46c

File tree

2 files changed

+87
-12
lines changed

2 files changed

+87
-12
lines changed

src/main/java/org/springframework/data/repository/core/support/RepositoryMethodInvoker.java

+47-9
Original file line numberDiff line numberDiff line change
@@ -15,12 +15,15 @@
1515
*/
1616
package org.springframework.data.repository.core.support;
1717

18+
import kotlin.Unit;
1819
import kotlin.reflect.KFunction;
20+
import kotlinx.coroutines.flow.Flow;
1921
import reactor.core.publisher.Flux;
2022
import reactor.core.publisher.Mono;
2123

2224
import java.lang.reflect.InvocationTargetException;
2325
import java.lang.reflect.Method;
26+
import java.util.Collection;
2427
import java.util.stream.Stream;
2528

2629
import org.reactivestreams.Publisher;
@@ -54,25 +57,63 @@ abstract class RepositoryMethodInvoker {
5457
private final Class<?> returnedType;
5558
private final Invokable invokable;
5659
private final boolean suspendedDeclaredMethod;
57-
private final boolean returnsReactiveType;
5860

61+
@SuppressWarnings("ReactiveStreamsUnusedPublisher")
5962
protected RepositoryMethodInvoker(Method method, Invokable invokable) {
6063

6164
this.method = method;
62-
this.invokable = invokable;
6365

6466
if (KotlinDetector.isKotlinReflectPresent()) {
6567

6668
this.suspendedDeclaredMethod = KotlinReflectionUtils.isSuspend(method);
6769
this.returnedType = this.suspendedDeclaredMethod ? KotlinReflectionUtils.getReturnType(method)
6870
: method.getReturnType();
71+
72+
// special case for most query methods: These can return Flux but we don't want to fail later on if the method
73+
// is void.
74+
if (suspendedDeclaredMethod) {
75+
76+
this.invokable = args -> {
77+
78+
Object result = invokable.invoke(args);
79+
80+
if (returnedType == Unit.class) {
81+
82+
if (result instanceof Mono<?> m) {
83+
return m.then();
84+
}
85+
86+
Flux<?> flux = result instanceof Flux<?> f ? f : ReactiveWrapperConverters.toWrapper(result, Flux.class);
87+
return flux.then();
88+
}
89+
90+
if (returnedType != Flow.class) {
91+
92+
if (result instanceof Mono<?> m) {
93+
return m;
94+
}
95+
96+
Flux<?> flux = result instanceof Flux<?> f ? f : ReactiveWrapperConverters.toWrapper(result, Flux.class);
97+
98+
if (Collection.class.isAssignableFrom(returnedType)) {
99+
return flux.collectList();
100+
}
101+
102+
return flux.singleOrEmpty();
103+
}
104+
105+
return result;
106+
};
107+
} else {
108+
this.invokable = invokable;
109+
}
110+
69111
} else {
70112

71113
this.suspendedDeclaredMethod = false;
72114
this.returnedType = method.getReturnType();
115+
this.invokable = invokable;
73116
}
74-
75-
this.returnsReactiveType = ReactiveWrappers.supports(returnedType);
76117
}
77118

78119
static RepositoryQueryMethodInvoker forRepositoryQuery(Method declaredMethod, RepositoryQuery query) {
@@ -154,7 +195,7 @@ private RepositoryMethodInvocation computeInvocationResult(RepositoryMethodInvoc
154195
interface Invokable {
155196

156197
@Nullable
157-
Object invoke(Object[] args) throws ReflectiveOperationException;
198+
Object invoke(Object[] args) throws Exception;
158199
}
159200

160201
/**
@@ -214,8 +255,6 @@ Publisher<Object> decorate(Class<?> repositoryInterface, RepositoryInvocationMul
214255
*/
215256
private static class RepositoryFragmentMethodInvoker extends RepositoryMethodInvoker {
216257

217-
private final CoroutineAdapterInformation adapterInformation;
218-
219258
public RepositoryFragmentMethodInvoker(Method declaredMethod, Object instance, Method baseClassMethod) {
220259
this(CoroutineAdapterInformation.create(declaredMethod, baseClassMethod), declaredMethod, instance,
221260
baseClassMethod);
@@ -236,13 +275,12 @@ public RepositoryFragmentMethodInvoker(CoroutineAdapterInformation adapterInform
236275
return AopUtils.invokeJoinpointUsingReflection(instance, baseClassMethod, invocationArguments);
237276
}
238277
return AopUtils.invokeJoinpointUsingReflection(instance, baseClassMethod, args);
239-
} catch (RuntimeException e) {
278+
} catch (Exception e) {
240279
throw e;
241280
} catch (Throwable e) {
242281
throw new RuntimeException(e);
243282
}
244283
});
245-
this.adapterInformation = adapterInformation;
246284
}
247285

248286
/**

src/test/kotlin/org/springframework/data/repository/kotlin/CoroutineCrudRepositoryUnitTests.kt

+40-3
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,22 @@ class CoroutineCrudRepositoryUnitTests {
210210
Mockito.verify(invocationListener).afterInvocation(captor.capture())
211211
}
212212

213+
@Test // DATACMNS-1508, DATACMNS-1764
214+
fun shouldBridgeFluxQueryMethod() {
215+
216+
val sample = User()
217+
218+
Mockito.`when`(factory.queryOne.execute(arrayOf("foo", null, any()))).thenReturn(Flux.just(sample))
219+
220+
val result = runBlocking {
221+
coRepository.findOne("foo")
222+
}
223+
224+
assertThat(result).isNotNull().isEqualTo(sample)
225+
val captor = ArgumentCaptor.forClass(RepositoryMethodInvocationListener.RepositoryMethodInvocation::class.java)
226+
Mockito.verify(invocationListener).afterInvocation(captor.capture())
227+
}
228+
213229
@Test // DATACMNS-1508
214230
fun shouldBridgeRxJavaQueryMethod() {
215231

@@ -283,13 +299,20 @@ class CoroutineCrudRepositoryUnitTests {
283299

284300
val sample = User()
285301

286-
Mockito.`when`(factory.queryOne.execute(arrayOf("foo", null, any()))).thenReturn(Mono.just(listOf(sample)), Mono.empty<User>())
302+
Mockito.`when`(factory.queryOne.execute(arrayOf("foo", null, any())))
303+
.thenReturn(Mono.just(listOf(sample)), Flux.just(sample, sample), Mono.empty<User>())
287304

288-
val result = runBlocking {
305+
val result1 = runBlocking {
289306
coRepository.findSuspendedAsList("foo")
290307
}
291308

292-
assertThat(result).hasSize(1).containsOnly(sample)
309+
assertThat(result1).hasSize(1).containsOnly(sample)
310+
311+
val result2 = runBlocking {
312+
coRepository.findSuspendedAsList("foo")
313+
}
314+
315+
assertThat(result2).hasSize(2).contains(sample)
293316

294317
val emptyResult = runBlocking {
295318
coRepository.findSuspendedAsList("foo")
@@ -298,6 +321,18 @@ class CoroutineCrudRepositoryUnitTests {
298321
assertThat(emptyResult).isNull()
299322
}
300323

324+
@Test // DATACMNS-1802
325+
fun shouldDiscardResult() {
326+
327+
Mockito.`when`(factory.queryOne.execute(any())).thenReturn(Flux.empty<User>())
328+
329+
val result = runBlocking {
330+
coRepository.someDelete("foo")
331+
}
332+
333+
assertThat(result).isInstanceOf(Unit.javaClass)
334+
}
335+
301336
interface MyCoRepository : CoroutineCrudRepository<User, String> {
302337

303338
suspend fun findOne(id: String): User
@@ -307,5 +342,7 @@ class CoroutineCrudRepositoryUnitTests {
307342
suspend fun findSuspendedMultiple(id: String): Flow<User>
308343

309344
suspend fun findSuspendedAsList(id: String): List<User>
345+
346+
suspend fun someDelete(id: String)
310347
}
311348
}

0 commit comments

Comments
 (0)