Skip to content

Commit 52f7b00

Browse files
committed
Polishing.
Add reactive type translation to Coroutine methods to return the expected type for AOP processing. See #2926
1 parent 110756a commit 52f7b00

File tree

2 files changed

+72
-9
lines changed

2 files changed

+72
-9
lines changed

src/main/java/org/springframework/data/repository/core/support/RepositoryMethodInvoker.java

+42-9
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@
1515
*/
1616
package org.springframework.data.repository.core.support;
1717

18+
import kotlin.Unit;
1819
import kotlin.reflect.KFunction;
20+
import kotlinx.coroutines.flow.Flow;
1921
import reactor.core.publisher.Flux;
2022
import reactor.core.publisher.Mono;
2123

@@ -54,25 +56,59 @@ abstract class RepositoryMethodInvoker {
5456
private final Class<?> returnedType;
5557
private final Invokable invokable;
5658
private final boolean suspendedDeclaredMethod;
57-
private final boolean returnsReactiveType;
5859

5960
protected RepositoryMethodInvoker(Method method, Invokable invokable) {
6061

6162
this.method = method;
62-
this.invokable = invokable;
6363

6464
if (KotlinDetector.isKotlinReflectPresent()) {
6565

6666
this.suspendedDeclaredMethod = KotlinReflectionUtils.isSuspend(method);
6767
this.returnedType = this.suspendedDeclaredMethod ? KotlinReflectionUtils.getReturnType(method)
6868
: method.getReturnType();
69+
70+
// special case for most query methods: These can return Flux but we don't want to fail later on if the method
71+
// is void.
72+
if (suspendedDeclaredMethod) {
73+
74+
this.invokable = args -> {
75+
76+
Object result = invokable.invoke(args);
77+
78+
if (returnedType == Unit.class) {
79+
80+
if (result instanceof Mono<?> m) {
81+
return m.then();
82+
}
83+
84+
if (result instanceof Flux<?> f) {
85+
return f.then();
86+
}
87+
88+
return ReactiveWrapperConverters.toWrapper(result, Flux.class).then();
89+
}
90+
91+
if (returnedType != Flow.class) {
92+
93+
if (result instanceof Mono<?> m) {
94+
return m;
95+
}
96+
97+
return ReactiveWrapperConverters.toWrapper(result, Flux.class).singleOrEmpty();
98+
}
99+
100+
return result;
101+
};
102+
} else {
103+
this.invokable = invokable;
104+
}
105+
69106
} else {
70107

71108
this.suspendedDeclaredMethod = false;
72109
this.returnedType = method.getReturnType();
110+
this.invokable = invokable;
73111
}
74-
75-
this.returnsReactiveType = ReactiveWrappers.supports(returnedType);
76112
}
77113

78114
static RepositoryQueryMethodInvoker forRepositoryQuery(Method declaredMethod, RepositoryQuery query) {
@@ -154,7 +190,7 @@ private RepositoryMethodInvocation computeInvocationResult(RepositoryMethodInvoc
154190
interface Invokable {
155191

156192
@Nullable
157-
Object invoke(Object[] args) throws ReflectiveOperationException;
193+
Object invoke(Object[] args) throws Exception;
158194
}
159195

160196
/**
@@ -214,8 +250,6 @@ Publisher<Object> decorate(Class<?> repositoryInterface, RepositoryInvocationMul
214250
*/
215251
private static class RepositoryFragmentMethodInvoker extends RepositoryMethodInvoker {
216252

217-
private final CoroutineAdapterInformation adapterInformation;
218-
219253
public RepositoryFragmentMethodInvoker(Method declaredMethod, Object instance, Method baseClassMethod) {
220254
this(CoroutineAdapterInformation.create(declaredMethod, baseClassMethod), declaredMethod, instance,
221255
baseClassMethod);
@@ -236,13 +270,12 @@ public RepositoryFragmentMethodInvoker(CoroutineAdapterInformation adapterInform
236270
return AopUtils.invokeJoinpointUsingReflection(instance, baseClassMethod, invocationArguments);
237271
}
238272
return AopUtils.invokeJoinpointUsingReflection(instance, baseClassMethod, args);
239-
} catch (RuntimeException e) {
273+
} catch (Exception e) {
240274
throw e;
241275
} catch (Throwable e) {
242276
throw new RuntimeException(e);
243277
}
244278
});
245-
this.adapterInformation = adapterInformation;
246279
}
247280

248281
/**

src/test/kotlin/org/springframework/data/repository/kotlin/CoroutineCrudRepositoryUnitTests.kt

+30
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,22 @@ class CoroutineCrudRepositoryUnitTests {
210210
Mockito.verify(invocationListener).afterInvocation(captor.capture())
211211
}
212212

213+
@Test // DATACMNS-1508, DATACMNS-1764
214+
fun shouldBridgeFluxQueryMethod() {
215+
216+
val sample = User()
217+
218+
Mockito.`when`(factory.queryOne.execute(arrayOf("foo", null, any()))).thenReturn(Flux.just(sample))
219+
220+
val result = runBlocking {
221+
coRepository.findOne("foo")
222+
}
223+
224+
assertThat(result).isNotNull().isEqualTo(sample)
225+
val captor = ArgumentCaptor.forClass(RepositoryMethodInvocationListener.RepositoryMethodInvocation::class.java)
226+
Mockito.verify(invocationListener).afterInvocation(captor.capture())
227+
}
228+
213229
@Test // DATACMNS-1508
214230
fun shouldBridgeRxJavaQueryMethod() {
215231

@@ -298,6 +314,18 @@ class CoroutineCrudRepositoryUnitTests {
298314
assertThat(emptyResult).isNull()
299315
}
300316

317+
@Test // DATACMNS-1802
318+
fun shouldDiscardResult() {
319+
320+
Mockito.`when`(factory.queryOne.execute(any())).thenReturn(Flux.empty<User>())
321+
322+
val result = runBlocking {
323+
coRepository.someDelete("foo")
324+
}
325+
326+
assertThat(result).isInstanceOf(Unit.javaClass)
327+
}
328+
301329
interface MyCoRepository : CoroutineCrudRepository<User, String> {
302330

303331
suspend fun findOne(id: String): User
@@ -307,5 +335,7 @@ class CoroutineCrudRepositoryUnitTests {
307335
suspend fun findSuspendedMultiple(id: String): Flow<User>
308336

309337
suspend fun findSuspendedAsList(id: String): List<User>
338+
339+
suspend fun someDelete(id: String)
310340
}
311341
}

0 commit comments

Comments
 (0)