Skip to content

Commit 98e89d8

Browse files
committed
Leverage KType in Kotlin Serialization WebFlux support
In order to take in account properly Kotlin null-safety with the annotation programming model. Closes gh-33016
1 parent 23dccc5 commit 98e89d8

File tree

3 files changed

+75
-3
lines changed

3 files changed

+75
-3
lines changed

spring-web/src/main/java/org/springframework/http/codec/KotlinSerializationSupport.java

+39-3
Original file line numberDiff line numberDiff line change
@@ -16,21 +16,29 @@
1616

1717
package org.springframework.http.codec;
1818

19+
import java.lang.reflect.Method;
1920
import java.lang.reflect.Type;
2021
import java.util.Arrays;
2122
import java.util.HashSet;
2223
import java.util.List;
2324
import java.util.Map;
2425
import java.util.Set;
2526

27+
import kotlin.reflect.KFunction;
28+
import kotlin.reflect.KType;
29+
import kotlin.reflect.full.KCallables;
30+
import kotlin.reflect.jvm.ReflectJvmMapping;
2631
import kotlinx.serialization.KSerializer;
2732
import kotlinx.serialization.SerialFormat;
2833
import kotlinx.serialization.SerializersKt;
2934
import kotlinx.serialization.descriptors.PolymorphicKind;
3035
import kotlinx.serialization.descriptors.SerialDescriptor;
3136

37+
import org.springframework.core.KotlinDetector;
38+
import org.springframework.core.MethodParameter;
3239
import org.springframework.core.ResolvableType;
3340
import org.springframework.lang.Nullable;
41+
import org.springframework.util.Assert;
3442
import org.springframework.util.ConcurrentReferenceHashMap;
3543
import org.springframework.util.MimeType;
3644

@@ -46,7 +54,10 @@
4654
*/
4755
public abstract class KotlinSerializationSupport<T extends SerialFormat> {
4856

49-
private final Map<Type, KSerializer<Object>> serializerCache = new ConcurrentReferenceHashMap<>();
57+
private final Map<Type, KSerializer<Object>> typeSerializerCache = new ConcurrentReferenceHashMap<>();
58+
59+
private final Map<KType, KSerializer<Object>> kTypeSerializerCache = new ConcurrentReferenceHashMap<>();
60+
5061

5162
private final T format;
5263

@@ -117,8 +128,33 @@ private boolean supports(@Nullable MimeType mimeType) {
117128
*/
118129
@Nullable
119130
protected final KSerializer<Object> serializer(ResolvableType resolvableType) {
131+
if (resolvableType.getSource() instanceof MethodParameter parameter) {
132+
Method method = parameter.getMethod();
133+
Assert.notNull(method, "Method must not be null");
134+
if (KotlinDetector.isKotlinType(method.getDeclaringClass())) {
135+
KFunction<?> function = ReflectJvmMapping.getKotlinFunction(method);
136+
Assert.notNull(function, "Kotlin function must not be null");
137+
KType type = (parameter.getParameterIndex() == -1 ? function.getReturnType() :
138+
KCallables.getValueParameters(function).get(parameter.getParameterIndex()).getType());
139+
KSerializer<Object> serializer = this.kTypeSerializerCache.get(type);
140+
if (serializer == null) {
141+
try {
142+
serializer = SerializersKt.serializerOrNull(this.format.getSerializersModule(), type);
143+
}
144+
catch (IllegalArgumentException ignored) {
145+
}
146+
if (serializer != null) {
147+
if (hasPolymorphism(serializer.getDescriptor(), new HashSet<>())) {
148+
return null;
149+
}
150+
this.kTypeSerializerCache.put(type, serializer);
151+
}
152+
}
153+
return serializer;
154+
}
155+
}
120156
Type type = resolvableType.getType();
121-
KSerializer<Object> serializer = this.serializerCache.get(type);
157+
KSerializer<Object> serializer = this.typeSerializerCache.get(type);
122158
if (serializer == null) {
123159
try {
124160
serializer = SerializersKt.serializerOrNull(this.format.getSerializersModule(), type);
@@ -129,7 +165,7 @@ protected final KSerializer<Object> serializer(ResolvableType resolvableType) {
129165
if (hasPolymorphism(serializer.getDescriptor(), new HashSet<>())) {
130166
return null;
131167
}
132-
this.serializerCache.put(type, serializer);
168+
this.typeSerializerCache.put(type, serializer);
133169
}
134170
}
135171
return serializer;

spring-web/src/test/kotlin/org/springframework/http/codec/json/KotlinSerializationJsonDecoderTests.kt

+21
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,11 @@ package org.springframework.http.codec.json
1919
import kotlinx.serialization.Serializable
2020
import org.assertj.core.api.Assertions.assertThat
2121
import org.junit.jupiter.api.Test
22+
import org.springframework.core.MethodParameter
2223
import org.springframework.core.Ordered
2324
import org.springframework.core.ResolvableType
2425
import org.springframework.core.io.buffer.DataBuffer
26+
import org.springframework.core.io.buffer.DataBufferUtils
2527
import org.springframework.core.testfixture.codec.AbstractDecoderTests
2628
import org.springframework.http.MediaType
2729
import reactor.core.publisher.Flux
@@ -32,6 +34,7 @@ import java.lang.UnsupportedOperationException
3234
import java.math.BigDecimal
3335
import java.nio.charset.Charset
3436
import java.nio.charset.StandardCharsets
37+
import kotlin.reflect.jvm.javaMethod
3538

3639
/**
3740
* Tests for the JSON decoding using kotlinx.serialization.
@@ -128,6 +131,22 @@ class KotlinSerializationJsonDecoderTests : AbstractDecoderTests<KotlinSerializa
128131
}, null, null)
129132
}
130133

134+
@Test
135+
fun decodeToMonoWithNullableWithNull() {
136+
val input = Flux.concat(
137+
stringBuffer("{\"value\":null}\n"),
138+
)
139+
140+
val methodParameter = MethodParameter.forExecutable(::handleMapWithNullable::javaMethod.get()!!, -1)
141+
val elementType = ResolvableType.forMethodParameter(methodParameter)
142+
143+
testDecodeToMonoAll(input, elementType, {
144+
it.expectNext(mapOf("value" to null))
145+
.expectComplete()
146+
.verify()
147+
}, null, null)
148+
}
149+
131150
private fun stringBuffer(value: String): Mono<DataBuffer> {
132151
return stringBuffer(value, StandardCharsets.UTF_8)
133152
}
@@ -145,4 +164,6 @@ class KotlinSerializationJsonDecoderTests : AbstractDecoderTests<KotlinSerializa
145164
@Serializable
146165
data class Pojo(val foo: String, val bar: String, val pojo: Pojo? = null)
147166

167+
fun handleMapWithNullable(map: Map<String, String?>) = map
168+
148169
}

spring-web/src/test/kotlin/org/springframework/http/codec/json/KotlinSerializationJsonEncoderTests.kt

+15
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ package org.springframework.http.codec.json
1919
import kotlinx.serialization.Serializable
2020
import org.assertj.core.api.Assertions.assertThat
2121
import org.junit.jupiter.api.Test
22+
import org.springframework.core.MethodParameter
2223
import org.springframework.core.Ordered
2324
import org.springframework.core.ResolvableType
2425
import org.springframework.core.io.buffer.DataBuffer
@@ -31,6 +32,7 @@ import reactor.core.publisher.Mono
3132
import reactor.test.StepVerifier.FirstStep
3233
import java.math.BigDecimal
3334
import java.nio.charset.StandardCharsets
35+
import kotlin.reflect.jvm.javaMethod
3436

3537
/**
3638
* Tests for the JSON encoding using kotlinx.serialization.
@@ -109,6 +111,17 @@ class KotlinSerializationJsonEncoderTests : AbstractEncoderTests<KotlinSerializa
109111
}
110112
}
111113

114+
@Test
115+
fun encodeMonoWithNullableWithNull() {
116+
val input = Mono.just(mapOf("value" to null))
117+
val methodParameter = MethodParameter.forExecutable(::handleMapWithNullable::javaMethod.get()!!, -1)
118+
testEncode(input, ResolvableType.forMethodParameter(methodParameter), null, null) {
119+
it.consumeNextWith(expectString("{\"value\":null}")
120+
.andThen { dataBuffer: DataBuffer? -> DataBufferUtils.release(dataBuffer) })
121+
.verifyComplete()
122+
}
123+
}
124+
112125
@Test
113126
fun canNotEncode() {
114127
assertThat(encoder.canEncode(ResolvableType.forClass(String::class.java), null)).isFalse()
@@ -123,4 +136,6 @@ class KotlinSerializationJsonEncoderTests : AbstractEncoderTests<KotlinSerializa
123136
@Serializable
124137
data class Pojo(val foo: String, val bar: String, val pojo: Pojo? = null)
125138

139+
fun handleMapWithNullable(map: Map<String, String?>) = map
140+
126141
}

0 commit comments

Comments
 (0)