Skip to content

Commit de4013a

Browse files
committed
Consider declaring class when evaluating method return type for query method post-processing.
We now consider the declaring class to properly resolve type variable references for the result post-processing of a query method result. Previously, we attempted to resolve the return type without considering the actual repository class resolving always Object instead of the type parameter. Closes #3125
1 parent 75d0992 commit de4013a

File tree

3 files changed

+71
-20
lines changed

3 files changed

+71
-20
lines changed

src/main/java/org/springframework/data/repository/core/support/QueryExecutionResultHandler.java

+39-14
Original file line numberDiff line numberDiff line change
@@ -15,21 +15,23 @@
1515
*/
1616
package org.springframework.data.repository.core.support;
1717

18-
import java.lang.reflect.Method;
1918
import java.util.Collection;
2019
import java.util.Collections;
2120
import java.util.HashMap;
2221
import java.util.Map;
2322
import java.util.Optional;
2423

2524
import org.springframework.core.CollectionFactory;
25+
import org.springframework.core.KotlinDetector;
2626
import org.springframework.core.MethodParameter;
2727
import org.springframework.core.convert.ConversionService;
2828
import org.springframework.core.convert.TypeDescriptor;
2929
import org.springframework.core.convert.support.GenericConversionService;
30+
import org.springframework.data.repository.util.ClassUtils;
3031
import org.springframework.data.repository.util.QueryExecutionConverters;
3132
import org.springframework.data.repository.util.ReactiveWrapperConverters;
3233
import org.springframework.data.util.NullableWrapper;
34+
import org.springframework.data.util.ReactiveWrappers;
3335
import org.springframework.data.util.Streamable;
3436
import org.springframework.lang.Nullable;
3537

@@ -44,12 +46,14 @@ class QueryExecutionResultHandler {
4446

4547
private static final TypeDescriptor WRAPPER_TYPE = TypeDescriptor.valueOf(NullableWrapper.class);
4648

49+
private static final Class<?> FLOW_TYPE = loadIfPresent("kotlinx.coroutines.flow.Flow");
50+
4751
private final GenericConversionService conversionService;
4852

4953
private final Object mutex = new Object();
5054

5155
// concurrent access guarded by mutex.
52-
private Map<Method, ReturnTypeDescriptor> descriptorCache = Collections.emptyMap();
56+
private Map<MethodParameter, ReturnTypeDescriptor> descriptorCache = Collections.emptyMap();
5357

5458
/**
5559
* Creates a new {@link QueryExecutionResultHandler}.
@@ -58,6 +62,17 @@ class QueryExecutionResultHandler {
5862
this.conversionService = conversionService;
5963
}
6064

65+
@Nullable
66+
@SuppressWarnings("unchecked")
67+
public static <T> Class<T> loadIfPresent(String type) {
68+
69+
try {
70+
return (Class<T>) org.springframework.util.ClassUtils.forName(type, ClassUtils.class.getClassLoader());
71+
} catch (ClassNotFoundException | LinkageError e) {
72+
return null;
73+
}
74+
}
75+
6176
/**
6277
* Post-processes the given result of a query invocation to match the return type of the given method.
6378
*
@@ -66,9 +81,9 @@ class QueryExecutionResultHandler {
6681
* @return
6782
*/
6883
@Nullable
69-
Object postProcessInvocationResult(@Nullable Object result, Method method) {
84+
Object postProcessInvocationResult(@Nullable Object result, MethodParameter method) {
7085

71-
if (!processingRequired(result, method.getReturnType())) {
86+
if (!processingRequired(result, method)) {
7287
return result;
7388
}
7489

@@ -77,24 +92,23 @@ Object postProcessInvocationResult(@Nullable Object result, Method method) {
7792
return postProcessInvocationResult(result, 0, descriptor);
7893
}
7994

80-
private ReturnTypeDescriptor getOrCreateReturnTypeDescriptor(Method method) {
95+
private ReturnTypeDescriptor getOrCreateReturnTypeDescriptor(MethodParameter method) {
8196

82-
Map<Method, ReturnTypeDescriptor> descriptorCache = this.descriptorCache;
97+
Map<MethodParameter, ReturnTypeDescriptor> descriptorCache = this.descriptorCache;
8398
ReturnTypeDescriptor descriptor = descriptorCache.get(method);
8499

85100
if (descriptor == null) {
86101

87102
descriptor = ReturnTypeDescriptor.of(method);
88103

89-
Map<Method, ReturnTypeDescriptor> updatedDescriptorCache;
104+
Map<MethodParameter, ReturnTypeDescriptor> updatedDescriptorCache;
90105

91106
if (descriptorCache.isEmpty()) {
92107
updatedDescriptorCache = Collections.singletonMap(method, descriptor);
93108
} else {
94109
updatedDescriptorCache = new HashMap<>(descriptorCache.size() + 1, 1);
95110
updatedDescriptorCache.putAll(descriptorCache);
96111
updatedDescriptorCache.put(method, descriptor);
97-
98112
}
99113

100114
synchronized (mutex) {
@@ -234,10 +248,21 @@ private static Object unwrapOptional(@Nullable Object source) {
234248
* Returns whether we have to process the given source object in the first place.
235249
*
236250
* @param source can be {@literal null}.
237-
* @param targetType must not be {@literal null}.
251+
* @param methodParameter must not be {@literal null}.
238252
* @return
239253
*/
240-
private static boolean processingRequired(@Nullable Object source, Class<?> targetType) {
254+
private static boolean processingRequired(@Nullable Object source, MethodParameter methodParameter) {
255+
256+
Class<?> targetType = methodParameter.getParameterType();
257+
258+
if (source != null && ReactiveWrappers.KOTLIN_COROUTINES_PRESENT
259+
&& KotlinDetector.isSuspendingFunction(methodParameter.getMethod())) {
260+
261+
// Spring's AOP invoker handles Publisher to Flow conversion, so we have to exempt these from post-processing.
262+
if (FLOW_TYPE != null && FLOW_TYPE.isAssignableFrom(targetType)) {
263+
return false;
264+
}
265+
}
241266

242267
return !targetType.isInstance(source) //
243268
|| source == null //
@@ -253,19 +278,19 @@ static class ReturnTypeDescriptor {
253278
private final TypeDescriptor typeDescriptor;
254279
private final @Nullable TypeDescriptor nestedTypeDescriptor;
255280

256-
private ReturnTypeDescriptor(Method method) {
257-
this.methodParameter = new MethodParameter(method, -1);
281+
private ReturnTypeDescriptor(MethodParameter methodParameter) {
282+
this.methodParameter = methodParameter;
258283
this.typeDescriptor = TypeDescriptor.nested(this.methodParameter, 0);
259284
this.nestedTypeDescriptor = TypeDescriptor.nested(this.methodParameter, 1);
260285
}
261286

262287
/**
263-
* Create a {@link ReturnTypeDescriptor} from a {@link Method}.
288+
* Create a {@link ReturnTypeDescriptor} from a {@link MethodParameter}.
264289
*
265290
* @param method
266291
* @return
267292
*/
268-
public static ReturnTypeDescriptor of(Method method) {
293+
public static ReturnTypeDescriptor of(MethodParameter method) {
269294
return new ReturnTypeDescriptor(method);
270295
}
271296

src/main/java/org/springframework/data/repository/core/support/QueryExecutorMethodInterceptor.java

+8-3
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,12 @@
2121
import java.util.List;
2222
import java.util.Map;
2323
import java.util.Optional;
24+
import java.util.concurrent.ConcurrentHashMap;
2425

2526
import org.aopalliance.intercept.MethodInterceptor;
2627
import org.aopalliance.intercept.MethodInvocation;
28+
29+
import org.springframework.core.MethodParameter;
2730
import org.springframework.core.ResolvableType;
2831
import org.springframework.data.projection.ProjectionFactory;
2932
import org.springframework.data.repository.core.NamedQueries;
@@ -55,6 +58,7 @@ class QueryExecutorMethodInterceptor implements MethodInterceptor {
5558
private final RepositoryInformation repositoryInformation;
5659
private final Map<Method, RepositoryQuery> queries;
5760
private final Map<Method, RepositoryMethodInvoker> invocationMetadataCache = new ConcurrentReferenceHashMap<>();
61+
private final Map<Method, MethodParameter> returnTypeMap = new ConcurrentHashMap<>();
5862
private final QueryExecutionResultHandler resultHandler;
5963
private final NamedQueries namedQueries;
6064
private final List<QueryCreationListener<?>> queryPostProcessors;
@@ -135,16 +139,17 @@ private void invokeListeners(RepositoryQuery query) {
135139
public Object invoke(@SuppressWarnings("null") MethodInvocation invocation) throws Throwable {
136140

137141
Method method = invocation.getMethod();
142+
MethodParameter returnType = returnTypeMap.computeIfAbsent(method, it -> new MethodParameter(it, -1));
138143

139144
QueryExecutionConverters.ExecutionAdapter executionAdapter = QueryExecutionConverters //
140-
.getExecutionAdapter(method.getReturnType());
145+
.getExecutionAdapter(returnType.getParameterType());
141146

142147
if (executionAdapter == null) {
143-
return resultHandler.postProcessInvocationResult(doInvoke(invocation), method);
148+
return resultHandler.postProcessInvocationResult(doInvoke(invocation), returnType);
144149
}
145150

146151
return executionAdapter //
147-
.apply(() -> resultHandler.postProcessInvocationResult(doInvoke(invocation), method));
152+
.apply(() -> resultHandler.postProcessInvocationResult(doInvoke(invocation), returnType));
148153
}
149154

150155
@Nullable

src/test/java/org/springframework/data/repository/core/support/QueryExecutionResultHandlerUnitTests.java

+24-3
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,6 @@
2626
import reactor.core.publisher.Flux;
2727
import reactor.core.publisher.Mono;
2828

29-
import java.lang.reflect.Method;
3029
import java.math.BigDecimal;
3130
import java.util.Arrays;
3231
import java.util.Collections;
@@ -40,6 +39,8 @@
4039
import org.assertj.core.api.SoftAssertions;
4140
import org.junit.jupiter.api.Test;
4241
import org.reactivestreams.Publisher;
42+
43+
import org.springframework.core.MethodParameter;
4344
import org.springframework.dao.InvalidDataAccessApiUsageException;
4445
import org.springframework.data.repository.Repository;
4546
import org.springframework.data.util.Streamable;
@@ -404,6 +405,17 @@ void nestedConversion() throws Exception {
404405
});
405406
}
406407

408+
@Test // GH-3125
409+
void considersTypeBoundsFromBaseInterface() throws NoSuchMethodException {
410+
411+
var method = CustomizedRepository.class.getMethod("findById", Object.class);
412+
413+
var result = handler.postProcessInvocationResult(Optional.of(new Entity()),
414+
new MethodParameter(method, -1).withContainingClass(CustomizedRepository.class));
415+
416+
assertThat(result).isInstanceOf(Entity.class);
417+
}
418+
407419
@Test // DATACMNS-1552
408420
void keepsVavrOptionType() throws Exception {
409421

@@ -412,8 +424,17 @@ void keepsVavrOptionType() throws Exception {
412424
assertThat(handler.postProcessInvocationResult(source, getMethod("option"))).isSameAs(source);
413425
}
414426

415-
private static Method getMethod(String methodName) throws Exception {
416-
return Sample.class.getMethod(methodName);
427+
private static MethodParameter getMethod(String methodName) throws Exception {
428+
return new MethodParameter(Sample.class.getMethod(methodName), -1);
429+
}
430+
431+
interface BaseRepository<T, ID> extends Repository<T, ID> {
432+
433+
T findById(ID id);
434+
}
435+
436+
interface CustomizedRepository extends BaseRepository<Entity, Long> {
437+
417438
}
418439

419440
static interface Sample extends Repository<Entity, Long> {

0 commit comments

Comments
 (0)