Skip to content

Commit 23dccc5

Browse files
committed
Leverage KType in Kotlin Serialization WebMVC support
In order to take in account properly Kotlin null-safety with the annotation programming model. See gh-33016
1 parent 4555384 commit 23dccc5

File tree

4 files changed

+157
-100
lines changed

4 files changed

+157
-100
lines changed

spring-web/src/main/java/org/springframework/http/converter/AbstractKotlinSerializationHttpMessageConverter.java

+57-34
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2023 the original author or authors.
2+
* Copyright 2002-2024 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -17,22 +17,30 @@
1717
package org.springframework.http.converter;
1818

1919
import java.io.IOException;
20+
import java.lang.reflect.Method;
2021
import java.lang.reflect.Type;
2122
import java.util.HashSet;
2223
import java.util.Map;
2324
import java.util.Set;
2425

26+
import kotlin.reflect.KFunction;
27+
import kotlin.reflect.KType;
28+
import kotlin.reflect.full.KCallables;
29+
import kotlin.reflect.jvm.ReflectJvmMapping;
2530
import kotlinx.serialization.KSerializer;
2631
import kotlinx.serialization.SerialFormat;
2732
import kotlinx.serialization.SerializersKt;
2833
import kotlinx.serialization.descriptors.PolymorphicKind;
2934
import kotlinx.serialization.descriptors.SerialDescriptor;
3035

31-
import org.springframework.core.GenericTypeResolver;
36+
import org.springframework.core.KotlinDetector;
37+
import org.springframework.core.MethodParameter;
38+
import org.springframework.core.ResolvableType;
3239
import org.springframework.http.HttpInputMessage;
3340
import org.springframework.http.HttpOutputMessage;
3441
import org.springframework.http.MediaType;
3542
import org.springframework.lang.Nullable;
43+
import org.springframework.util.Assert;
3644
import org.springframework.util.ConcurrentReferenceHashMap;
3745

3846

@@ -48,9 +56,11 @@
4856
* @since 6.0
4957
* @param <T> the type of {@link SerialFormat}
5058
*/
51-
public abstract class AbstractKotlinSerializationHttpMessageConverter<T extends SerialFormat> extends AbstractGenericHttpMessageConverter<Object> {
59+
public abstract class AbstractKotlinSerializationHttpMessageConverter<T extends SerialFormat> extends AbstractSmartHttpMessageConverter<Object> {
5260

53-
private final Map<Type, KSerializer<Object>> serializerCache = new ConcurrentReferenceHashMap<>();
61+
private final Map<KType, KSerializer<Object>> kTypeSerializerCache = new ConcurrentReferenceHashMap<>();
62+
63+
private final Map<Type, KSerializer<Object>> typeSerializerCache = new ConcurrentReferenceHashMap<>();
5464

5565
private final T format;
5666

@@ -66,15 +76,14 @@ protected AbstractKotlinSerializationHttpMessageConverter(T format, MediaType...
6676
this.format = format;
6777
}
6878

69-
7079
@Override
7180
protected boolean supports(Class<?> clazz) {
72-
return serializer(clazz) != null;
81+
return serializer(ResolvableType.forClass(clazz)) != null;
7382
}
7483

7584
@Override
76-
public boolean canRead(Type type, @Nullable Class<?> contextClass, @Nullable MediaType mediaType) {
77-
if (serializer(GenericTypeResolver.resolveType(type, contextClass)) != null) {
85+
public boolean canRead(ResolvableType type, @Nullable MediaType mediaType) {
86+
if (!ResolvableType.NONE.equals(type) && serializer(type) != null) {
7887
return canRead(mediaType);
7988
}
8089
else {
@@ -83,8 +92,8 @@ public boolean canRead(Type type, @Nullable Class<?> contextClass, @Nullable Med
8392
}
8493

8594
@Override
86-
public boolean canWrite(@Nullable Type type, Class<?> clazz, @Nullable MediaType mediaType) {
87-
if (serializer(type != null ? GenericTypeResolver.resolveType(type, clazz) : clazz) != null) {
95+
public boolean canWrite(ResolvableType type, Class<?> clazz, @Nullable MediaType mediaType) {
96+
if (!ResolvableType.NONE.equals(type) && serializer(type) != null) {
8897
return canWrite(mediaType);
8998
}
9099
else {
@@ -93,24 +102,12 @@ public boolean canWrite(@Nullable Type type, Class<?> clazz, @Nullable MediaType
93102
}
94103

95104
@Override
96-
public final Object read(Type type, @Nullable Class<?> contextClass, HttpInputMessage inputMessage)
105+
public final Object read(ResolvableType type, HttpInputMessage inputMessage, @Nullable Map<String, Object> hints)
97106
throws IOException, HttpMessageNotReadableException {
98107

99-
Type resolvedType = GenericTypeResolver.resolveType(type, contextClass);
100-
KSerializer<Object> serializer = serializer(resolvedType);
108+
KSerializer<Object> serializer = serializer(type);
101109
if (serializer == null) {
102-
throw new HttpMessageNotReadableException("Could not find KSerializer for " + resolvedType, inputMessage);
103-
}
104-
return readInternal(serializer, this.format, inputMessage);
105-
}
106-
107-
@Override
108-
protected final Object readInternal(Class<?> clazz, HttpInputMessage inputMessage)
109-
throws IOException, HttpMessageNotReadableException {
110-
111-
KSerializer<Object> serializer = serializer(clazz);
112-
if (serializer == null) {
113-
throw new HttpMessageNotReadableException("Could not find KSerializer for " + clazz, inputMessage);
110+
throw new HttpMessageNotReadableException("Could not find KSerializer for " + type, inputMessage);
114111
}
115112
return readInternal(serializer, this.format, inputMessage);
116113
}
@@ -122,13 +119,13 @@ protected abstract Object readInternal(KSerializer<Object> serializer, T format,
122119
throws IOException, HttpMessageNotReadableException;
123120

124121
@Override
125-
protected final void writeInternal(Object object, @Nullable Type type, HttpOutputMessage outputMessage)
126-
throws IOException, HttpMessageNotWritableException {
122+
protected final void writeInternal(Object object, ResolvableType type, HttpOutputMessage outputMessage,
123+
@Nullable Map<String, Object> hints) throws IOException, HttpMessageNotWritableException {
127124

128-
Type resolvedType = type != null ? type : object.getClass();
129-
KSerializer<Object> serializer = serializer(resolvedType);
125+
ResolvableType resolvableType = (ResolvableType.NONE.equals(type) ? ResolvableType.forInstance(object) : type);
126+
KSerializer<Object> serializer = serializer(resolvableType);
130127
if (serializer == null) {
131-
throw new HttpMessageNotWritableException("Could not find KSerializer for " + resolvedType);
128+
throw new HttpMessageNotWritableException("Could not find KSerializer for " + resolvableType);
132129
}
133130
writeInternal(object, serializer, this.format, outputMessage);
134131
}
@@ -143,12 +140,38 @@ protected abstract void writeInternal(Object object, KSerializer<Object> seriali
143140
* Tries to find a serializer that can marshall or unmarshall instances of the given type
144141
* using kotlinx.serialization. If no serializer can be found, {@code null} is returned.
145142
* <p>Resolved serializers are cached and cached results are returned on successive calls.
146-
* @param type the type to find a serializer for
143+
* @param resolvableType the type to find a serializer for
147144
* @return a resolved serializer for the given type, or {@code null}
148145
*/
149146
@Nullable
150-
private KSerializer<Object> serializer(Type type) {
151-
KSerializer<Object> serializer = this.serializerCache.get(type);
147+
private KSerializer<Object> serializer(ResolvableType resolvableType) {
148+
if (resolvableType.getSource() instanceof MethodParameter parameter) {
149+
Method method = parameter.getMethod();
150+
Assert.notNull(method, "Method must not be null");
151+
if (KotlinDetector.isKotlinType(method.getDeclaringClass())) {
152+
KFunction<?> function = ReflectJvmMapping.getKotlinFunction(method);
153+
Assert.notNull(function, "Kotlin function must not be null");
154+
KType type = (parameter.getParameterIndex() == -1 ? function.getReturnType() :
155+
KCallables.getValueParameters(function).get(parameter.getParameterIndex()).getType());
156+
KSerializer<Object> serializer = this.kTypeSerializerCache.get(type);
157+
if (serializer == null) {
158+
try {
159+
serializer = SerializersKt.serializerOrNull(this.format.getSerializersModule(), type);
160+
}
161+
catch (IllegalArgumentException ignored) {
162+
}
163+
if (serializer != null) {
164+
if (hasPolymorphism(serializer.getDescriptor(), new HashSet<>())) {
165+
return null;
166+
}
167+
this.kTypeSerializerCache.put(type, serializer);
168+
}
169+
}
170+
return serializer;
171+
}
172+
}
173+
Type type = resolvableType.getType();
174+
KSerializer<Object> serializer = this.typeSerializerCache.get(type);
152175
if (serializer == null) {
153176
try {
154177
serializer = SerializersKt.serializerOrNull(this.format.getSerializersModule(), type);
@@ -159,7 +182,7 @@ private KSerializer<Object> serializer(Type type) {
159182
if (hasPolymorphism(serializer.getDescriptor(), new HashSet<>())) {
160183
return null;
161184
}
162-
this.serializerCache.put(type, serializer);
185+
this.typeSerializerCache.put(type, serializer);
163186
}
164187
}
165188
return serializer;

spring-web/src/test/kotlin/org/springframework/http/converter/cbor/KotlinSerializationCborHttpMessageConverterTests.kt

+21-21
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2021 the original author or authors.
2+
* Copyright 2002-2024 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -17,7 +17,6 @@
1717
package org.springframework.http.converter.cbor
1818

1919
import java.lang.reflect.ParameterizedType
20-
import java.lang.reflect.Type
2120
import java.nio.charset.StandardCharsets
2221

2322
import kotlin.reflect.javaType
@@ -31,6 +30,7 @@ import org.assertj.core.api.Assertions.assertThatExceptionOfType
3130
import org.junit.jupiter.api.Test
3231

3332
import org.springframework.core.Ordered
33+
import org.springframework.core.ResolvableType
3434
import org.springframework.http.MediaType
3535
import org.springframework.http.converter.HttpMessageNotReadableException
3636
import org.springframework.web.testfixture.http.MockHttpInputMessage
@@ -67,18 +67,18 @@ class KotlinSerializationCborHttpMessageConverterTests {
6767
assertThat(converter.canRead(NotSerializableBean::class.java, MediaType.APPLICATION_CBOR)).isFalse()
6868

6969
assertThat(converter.canRead(Map::class.java, MediaType.APPLICATION_CBOR)).isFalse()
70-
assertThat(converter.canRead(typeTokenOf<Map<String, SerializableBean>>(), Map::class.java, MediaType.APPLICATION_CBOR)).isTrue()
70+
assertThat(converter.canRead(resolvableTypeOf<Map<String, SerializableBean>>(), MediaType.APPLICATION_CBOR)).isTrue()
7171
assertThat(converter.canRead(List::class.java, MediaType.APPLICATION_CBOR)).isFalse()
72-
assertThat(converter.canRead(typeTokenOf<List<SerializableBean>>(), List::class.java, MediaType.APPLICATION_CBOR)).isTrue()
72+
assertThat(converter.canRead(resolvableTypeOf<List<SerializableBean>>(), MediaType.APPLICATION_CBOR)).isTrue()
7373
assertThat(converter.canRead(Set::class.java, MediaType.APPLICATION_CBOR)).isFalse()
74-
assertThat(converter.canRead(typeTokenOf<Set<SerializableBean>>(), Set::class.java, MediaType.APPLICATION_CBOR)).isTrue()
74+
assertThat(converter.canRead(resolvableTypeOf<Set<SerializableBean>>(), MediaType.APPLICATION_CBOR)).isTrue()
7575

76-
assertThat(converter.canRead(typeTokenOf<List<Int>>(), List::class.java, MediaType.APPLICATION_CBOR)).isTrue()
77-
assertThat(converter.canRead(typeTokenOf<ArrayList<Int>>(), List::class.java, MediaType.APPLICATION_CBOR)).isTrue()
78-
assertThat(converter.canRead(typeTokenOf<List<Int>>(), List::class.java, MediaType.APPLICATION_JSON)).isFalse()
76+
assertThat(converter.canRead(resolvableTypeOf<List<Int>>(), MediaType.APPLICATION_CBOR)).isTrue()
77+
assertThat(converter.canRead(resolvableTypeOf<ArrayList<Int>>(), MediaType.APPLICATION_CBOR)).isTrue()
78+
assertThat(converter.canRead(resolvableTypeOf<List<Int>>(), MediaType.APPLICATION_JSON)).isFalse()
7979

80-
assertThat(converter.canRead(typeTokenOf<Ordered>(), Ordered::class.java, MediaType.APPLICATION_CBOR)).isFalse()
81-
assertThat(converter.canRead(typeTokenOf<List<Ordered>>(), List::class.java, MediaType.APPLICATION_CBOR)).isFalse()
80+
assertThat(converter.canRead(resolvableTypeOf<Ordered>(), MediaType.APPLICATION_CBOR)).isFalse()
81+
assertThat(converter.canRead(resolvableTypeOf<List<Ordered>>(), MediaType.APPLICATION_CBOR)).isFalse()
8282
}
8383

8484
@Test
@@ -89,17 +89,17 @@ class KotlinSerializationCborHttpMessageConverterTests {
8989
assertThat(converter.canWrite(NotSerializableBean::class.java, MediaType.APPLICATION_CBOR)).isFalse()
9090

9191
assertThat(converter.canWrite(Map::class.java, MediaType.APPLICATION_CBOR)).isFalse()
92-
assertThat(converter.canWrite(typeTokenOf<Map<String, SerializableBean>>(), Map::class.java, MediaType.APPLICATION_CBOR)).isTrue()
92+
assertThat(converter.canWrite(resolvableTypeOf<Map<String, SerializableBean>>(), Map::class.java, MediaType.APPLICATION_CBOR)).isTrue()
9393
assertThat(converter.canWrite(List::class.java, MediaType.APPLICATION_CBOR)).isFalse()
94-
assertThat(converter.canWrite(typeTokenOf<List<SerializableBean>>(), List::class.java, MediaType.APPLICATION_CBOR)).isTrue()
94+
assertThat(converter.canWrite(resolvableTypeOf<List<SerializableBean>>(), List::class.java, MediaType.APPLICATION_CBOR)).isTrue()
9595
assertThat(converter.canWrite(Set::class.java, MediaType.APPLICATION_CBOR)).isFalse()
96-
assertThat(converter.canWrite(typeTokenOf<Set<SerializableBean>>(), Set::class.java, MediaType.APPLICATION_CBOR)).isTrue()
96+
assertThat(converter.canWrite(resolvableTypeOf<Set<SerializableBean>>(), Set::class.java, MediaType.APPLICATION_CBOR)).isTrue()
9797

98-
assertThat(converter.canWrite(typeTokenOf<List<Int>>(), List::class.java, MediaType.APPLICATION_CBOR)).isTrue()
99-
assertThat(converter.canWrite(typeTokenOf<ArrayList<Int>>(), List::class.java, MediaType.APPLICATION_CBOR)).isTrue()
100-
assertThat(converter.canWrite(typeTokenOf<List<Int>>(), List::class.java, MediaType.APPLICATION_JSON)).isFalse()
98+
assertThat(converter.canWrite(resolvableTypeOf<List<Int>>(), List::class.java, MediaType.APPLICATION_CBOR)).isTrue()
99+
assertThat(converter.canWrite(resolvableTypeOf<ArrayList<Int>>(), List::class.java, MediaType.APPLICATION_CBOR)).isTrue()
100+
assertThat(converter.canWrite(resolvableTypeOf<List<Int>>(), List::class.java, MediaType.APPLICATION_JSON)).isFalse()
101101

102-
assertThat(converter.canWrite(typeTokenOf<Ordered>(), Ordered::class.java, MediaType.APPLICATION_CBOR)).isFalse()
102+
assertThat(converter.canWrite(resolvableTypeOf<Ordered>(), Ordered::class.java, MediaType.APPLICATION_CBOR)).isFalse()
103103
}
104104

105105
@Test
@@ -139,7 +139,7 @@ class KotlinSerializationCborHttpMessageConverterTests {
139139
fun readGenericCollection() {
140140
val inputMessage = MockHttpInputMessage(serializableBeanArrayBody)
141141
inputMessage.headers.contentType = MediaType.APPLICATION_CBOR
142-
val result = converter.read(typeOf<List<SerializableBean>>().javaType, null, inputMessage)
142+
val result = converter.read(ResolvableType.forType(typeOf<List<SerializableBean>>().javaType), inputMessage, null)
143143
as List<SerializableBean>
144144

145145
assertThat(result).hasSize(1)
@@ -200,7 +200,7 @@ class KotlinSerializationCborHttpMessageConverterTests {
200200
fun writeGenericCollection() {
201201
val outputMessage = MockHttpOutputMessage()
202202

203-
this.converter.write(listOf(serializableBean), typeOf<List<SerializableBean>>().javaType, null, outputMessage)
203+
this.converter.write(listOf(serializableBean), ResolvableType.forType(typeOf<List<SerializableBean>>().javaType), null, outputMessage, null)
204204

205205
assertThat(outputMessage.headers).containsEntry("Content-Type", listOf("application/cbor"))
206206
assertThat(outputMessage.bodyAsBytes.isNotEmpty()).isTrue()
@@ -222,10 +222,10 @@ class KotlinSerializationCborHttpMessageConverterTests {
222222

223223
open class TypeBase<T>
224224

225-
inline fun <reified T> typeTokenOf(): Type {
225+
private inline fun <reified T> resolvableTypeOf(): ResolvableType {
226226
val base = object : TypeBase<T>() {}
227227
val superType = base::class.java.genericSuperclass!!
228-
return (superType as ParameterizedType).actualTypeArguments.first()!!
228+
return ResolvableType.forType((superType as ParameterizedType).actualTypeArguments.first()!!)
229229
}
230230

231231
}

0 commit comments

Comments
 (0)