Skip to content

Commit 5cee607

Browse files
committed
WebFlux @RequestPart support for List and Flux arguments
The resolver now supports List<T>, Flux<T>, and List<Part>. Issue: SPR-16621
1 parent 15182b2 commit 5cee607

File tree

2 files changed

+298
-16
lines changed

2 files changed

+298
-16
lines changed

spring-webflux/src/main/java/org/springframework/web/reactive/result/method/annotation/RequestPartMethodArgumentResolver.java

Lines changed: 41 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -74,35 +74,49 @@ public Mono<Object> resolveArgument(MethodParameter parameter, BindingContext bi
7474
boolean isRequired = (requestPart == null || requestPart.required());
7575
String name = getPartName(parameter, requestPart);
7676

77-
Flux<Part> values = exchange.getMultipartData()
77+
Flux<Part> parts = exchange.getMultipartData()
7878
.flatMapMany(map -> {
79-
List<Part> parts = map.get(name);
80-
if (CollectionUtils.isEmpty(parts)) {
79+
List<Part> list = map.get(name);
80+
if (CollectionUtils.isEmpty(list)) {
8181
return isRequired ?
8282
Flux.error(getMissingPartException(name, parameter)) :
8383
Flux.empty();
8484
}
85-
return Flux.fromIterable(parts);
85+
return Flux.fromIterable(list);
8686
});
8787

88-
ReactiveAdapter adapter = getAdapterRegistry().getAdapter(parameter.getParameterType());
89-
MethodParameter elementType = adapter != null ? parameter.nested() : parameter;
88+
if (Part.class.isAssignableFrom(parameter.getParameterType())) {
89+
return parts.next().cast(Object.class);
90+
}
9091

91-
if (Part.class.isAssignableFrom(elementType.getNestedParameterType())) {
92-
if (adapter != null) {
93-
values = adapter.isMultiValue() ? values : values.take(1);
94-
return Mono.just(adapter.fromPublisher(values));
92+
if (List.class.isAssignableFrom(parameter.getParameterType())) {
93+
MethodParameter elementType = parameter.nested();
94+
if (Part.class.isAssignableFrom(elementType.getNestedParameterType())) {
95+
return parts.collectList().cast(Object.class);
9596
}
9697
else {
97-
return values.next().cast(Object.class);
98+
return decodePartValues(parts, elementType, bindingContext, exchange, isRequired)
99+
.collectList().cast(Object.class);
98100
}
99101
}
100102

101-
return values.next().flatMap(part -> {
102-
ServerHttpRequest partRequest = new PartServerHttpRequest(exchange.getRequest(), part);
103-
ServerWebExchange partExchange = exchange.mutate().request(partRequest).build();
104-
return readBody(parameter, isRequired, bindingContext, partExchange);
105-
});
103+
ReactiveAdapter adapter = getAdapterRegistry().getAdapter(parameter.getParameterType());
104+
if (adapter != null) {
105+
// Mono<Part> or Flux<Part>
106+
MethodParameter elementType = parameter.nested();
107+
if (Part.class.isAssignableFrom(elementType.getNestedParameterType())) {
108+
parts = adapter.isMultiValue() ? parts : parts.take(1);
109+
return Mono.just(adapter.fromPublisher(parts));
110+
}
111+
// We have to decode the content for each part, one at a time
112+
if (adapter.isMultiValue()) {
113+
return Mono.just(decodePartValues(parts, elementType, bindingContext, exchange, isRequired));
114+
}
115+
}
116+
117+
// <T> or Mono<T>
118+
return decodePartValues(parts, parameter, bindingContext, exchange, isRequired)
119+
.next().cast(Object.class);
106120
}
107121

108122
private String getPartName(MethodParameter methodParam, @Nullable RequestPart requestPart) {
@@ -124,6 +138,17 @@ private ServerWebInputException getMissingPartException(String name, MethodParam
124138
}
125139

126140

141+
private Flux<?> decodePartValues(Flux<Part> parts, MethodParameter elementType, BindingContext bindingContext,
142+
ServerWebExchange exchange, boolean isRequired) {
143+
144+
return parts.flatMap(part -> {
145+
ServerHttpRequest partRequest = new PartServerHttpRequest(exchange.getRequest(), part);
146+
ServerWebExchange partExchange = exchange.mutate().request(partRequest).build();
147+
return readBody(elementType, isRequired, bindingContext, partExchange);
148+
});
149+
}
150+
151+
127152
private static class PartServerHttpRequest extends ServerHttpRequestDecorator {
128153

129154
private final Part part;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,257 @@
1+
/*
2+
* Copyright 2002-2018 the original author or authors.
3+
*
4+
* Licensed under the Apache License, Version 2.0 (the "License");
5+
* you may not use this file except in compliance with the License.
6+
* You may obtain a copy of the License at
7+
*
8+
* http://www.apache.org/licenses/LICENSE-2.0
9+
*
10+
* Unless required by applicable law or agreed to in writing, software
11+
* distributed under the License is distributed on an "AS IS" BASIS,
12+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
* See the License for the specific language governing permissions and
14+
* limitations under the License.
15+
*/
16+
17+
package org.springframework.web.reactive.result.method.annotation;
18+
19+
import java.nio.charset.StandardCharsets;
20+
import java.time.Duration;
21+
import java.util.Collections;
22+
import java.util.List;
23+
24+
import com.fasterxml.jackson.annotation.JsonCreator;
25+
import com.fasterxml.jackson.annotation.JsonProperty;
26+
import org.junit.Before;
27+
import org.junit.Test;
28+
import reactor.core.publisher.Flux;
29+
import reactor.core.publisher.Mono;
30+
31+
import org.springframework.core.MethodParameter;
32+
import org.springframework.core.ReactiveAdapterRegistry;
33+
import org.springframework.core.io.buffer.DataBuffer;
34+
import org.springframework.core.io.buffer.DataBufferUtils;
35+
import org.springframework.core.io.buffer.support.DataBufferTestUtils;
36+
import org.springframework.http.HttpMethod;
37+
import org.springframework.http.MediaType;
38+
import org.springframework.http.client.MultipartBodyBuilder;
39+
import org.springframework.http.codec.ClientCodecConfigurer;
40+
import org.springframework.http.codec.HttpMessageReader;
41+
import org.springframework.http.codec.HttpMessageWriter;
42+
import org.springframework.http.codec.ServerCodecConfigurer;
43+
import org.springframework.http.codec.multipart.MultipartHttpMessageWriter;
44+
import org.springframework.http.codec.multipart.Part;
45+
import org.springframework.mock.http.client.reactive.test.MockClientHttpRequest;
46+
import org.springframework.mock.http.server.reactive.test.MockServerHttpRequest;
47+
import org.springframework.mock.web.test.server.MockServerWebExchange;
48+
import org.springframework.util.MultiValueMap;
49+
import org.springframework.web.bind.annotation.RequestPart;
50+
import org.springframework.web.method.ResolvableMethod;
51+
import org.springframework.web.reactive.BindingContext;
52+
import org.springframework.web.server.ServerWebExchange;
53+
54+
import static org.junit.Assert.*;
55+
import static org.springframework.core.ResolvableType.*;
56+
import static org.springframework.web.method.MvcAnnotationPredicates.*;
57+
58+
/**
59+
* Unit tests for {@link RequestPartMethodArgumentResolver}.
60+
* @author Rossen Stoyanchev
61+
*/
62+
public class RequestPartMethodArgumentResolverTests {
63+
64+
private RequestPartMethodArgumentResolver resolver;
65+
66+
private ResolvableMethod testMethod = ResolvableMethod.on(getClass()).named("handle").build();
67+
68+
private MultipartHttpMessageWriter writer;
69+
70+
71+
@Before
72+
public void setup() throws Exception {
73+
List<HttpMessageReader<?>> readers = ServerCodecConfigurer.create().getReaders();
74+
ReactiveAdapterRegistry registry = ReactiveAdapterRegistry.getSharedInstance();
75+
this.resolver = new RequestPartMethodArgumentResolver(readers, registry);
76+
77+
List<HttpMessageWriter<?>> writers = ClientCodecConfigurer.create().getWriters();
78+
this.writer = new MultipartHttpMessageWriter(writers);
79+
}
80+
81+
82+
@Test
83+
public void supportsParameter() {
84+
85+
MethodParameter param;
86+
87+
param = this.testMethod.annot(requestPart().name("name")).arg(Person.class);
88+
assertTrue(this.resolver.supportsParameter(param));
89+
90+
param = this.testMethod.annot(requestPart().name("name")).arg(Mono.class, Person.class);
91+
assertTrue(this.resolver.supportsParameter(param));
92+
93+
param = this.testMethod.annot(requestPart().name("name")).arg(Flux.class, Person.class);
94+
assertTrue(this.resolver.supportsParameter(param));
95+
96+
param = this.testMethod.annot(requestPart().name("name")).arg(Part.class);
97+
assertTrue(this.resolver.supportsParameter(param));
98+
99+
param = this.testMethod.annot(requestPart().name("name")).arg(Mono.class, Part.class);
100+
assertTrue(this.resolver.supportsParameter(param));
101+
102+
param = this.testMethod.annot(requestPart().name("name")).arg(Flux.class, Part.class);
103+
assertTrue(this.resolver.supportsParameter(param));
104+
}
105+
106+
107+
@Test
108+
public void person() {
109+
MethodParameter param = this.testMethod.annot(requestPart().name("name")).arg(Person.class);
110+
MultipartBodyBuilder bodyBuilder = new MultipartBodyBuilder();
111+
bodyBuilder.part("name", new Person("Jones"));
112+
Person actual = resolveArgument(param, bodyBuilder);
113+
114+
assertEquals("Jones", actual.getName());
115+
}
116+
117+
@Test
118+
public void listPerson() {
119+
MethodParameter param = this.testMethod.annot(requestPart().name("name")).arg(List.class, Person.class);
120+
MultipartBodyBuilder bodyBuilder = new MultipartBodyBuilder();
121+
bodyBuilder.part("name", new Person("Jones"));
122+
bodyBuilder.part("name", new Person("James"));
123+
List<Person> actual = resolveArgument(param, bodyBuilder);
124+
125+
assertEquals("Jones", actual.get(0).getName());
126+
assertEquals("James", actual.get(1).getName());
127+
}
128+
129+
@Test
130+
public void monoPerson() {
131+
MethodParameter param = this.testMethod.annot(requestPart().name("name")).arg(Mono.class, Person.class);
132+
MultipartBodyBuilder bodyBuilder = new MultipartBodyBuilder();
133+
bodyBuilder.part("name", new Person("Jones"));
134+
Mono<Person> actual = resolveArgument(param, bodyBuilder);
135+
136+
assertEquals("Jones", actual.block().getName());
137+
}
138+
139+
@Test
140+
public void fluxPerson() {
141+
MethodParameter param = this.testMethod.annot(requestPart().name("name")).arg(Flux.class, Person.class);
142+
MultipartBodyBuilder bodyBuilder = new MultipartBodyBuilder();
143+
bodyBuilder.part("name", new Person("Jones"));
144+
bodyBuilder.part("name", new Person("James"));
145+
Flux<Person> actual = resolveArgument(param, bodyBuilder);
146+
147+
List<Person> persons = actual.collectList().block();
148+
assertEquals("Jones", persons.get(0).getName());
149+
assertEquals("James", persons.get(1).getName());
150+
}
151+
152+
@Test
153+
public void part() {
154+
MethodParameter param = this.testMethod.annot(requestPart().name("name")).arg(Part.class);
155+
MultipartBodyBuilder bodyBuilder = new MultipartBodyBuilder();
156+
bodyBuilder.part("name", new Person("Jones"));
157+
Part actual = resolveArgument(param, bodyBuilder);
158+
159+
DataBuffer buffer = DataBufferUtils.join(actual.content()).block();
160+
assertEquals("{\"name\":\"Jones\"}", DataBufferTestUtils.dumpString(buffer, StandardCharsets.UTF_8));
161+
}
162+
163+
@Test
164+
public void listPart() {
165+
MethodParameter param = this.testMethod.annot(requestPart().name("name")).arg(List.class, Part.class);
166+
MultipartBodyBuilder bodyBuilder = new MultipartBodyBuilder();
167+
bodyBuilder.part("name", new Person("Jones"));
168+
bodyBuilder.part("name", new Person("James"));
169+
List<Part> actual = resolveArgument(param, bodyBuilder);
170+
171+
assertEquals("{\"name\":\"Jones\"}", partToUtf8String(actual.get(0)));
172+
assertEquals("{\"name\":\"James\"}", partToUtf8String(actual.get(1)));
173+
}
174+
175+
@Test
176+
public void monoPart() {
177+
MethodParameter param = this.testMethod.annot(requestPart().name("name")).arg(Mono.class, Part.class);
178+
MultipartBodyBuilder bodyBuilder = new MultipartBodyBuilder();
179+
bodyBuilder.part("name", new Person("Jones"));
180+
Mono<Part> actual = resolveArgument(param, bodyBuilder);
181+
182+
Part part = actual.block();
183+
assertEquals("{\"name\":\"Jones\"}", partToUtf8String(part));
184+
}
185+
186+
@Test
187+
public void fluxPart() {
188+
MethodParameter param = this.testMethod.annot(requestPart().name("name")).arg(Flux.class, Part.class);
189+
MultipartBodyBuilder bodyBuilder = new MultipartBodyBuilder();
190+
bodyBuilder.part("name", new Person("Jones"));
191+
bodyBuilder.part("name", new Person("James"));
192+
Flux<Part> actual = resolveArgument(param, bodyBuilder);
193+
194+
List<Part> parts = actual.collectList().block();
195+
assertEquals("{\"name\":\"Jones\"}", partToUtf8String(parts.get(0)));
196+
assertEquals("{\"name\":\"James\"}", partToUtf8String(parts.get(1)));
197+
}
198+
199+
@SuppressWarnings("unchecked")
200+
private <T> T resolveArgument(MethodParameter param, MultipartBodyBuilder builder) {
201+
ServerWebExchange exchange = createExchange(builder);
202+
Mono<Object> result = this.resolver.resolveArgument(param, new BindingContext(), exchange);
203+
Object value = result.block(Duration.ofSeconds(5));
204+
205+
assertNotNull(value);
206+
assertTrue(param.getParameterType().isAssignableFrom(value.getClass()));
207+
return (T) value;
208+
}
209+
210+
@SuppressWarnings("ConstantConditions")
211+
private ServerWebExchange createExchange(MultipartBodyBuilder builder) {
212+
213+
MockClientHttpRequest clientRequest = new MockClientHttpRequest(HttpMethod.POST, "/");
214+
this.writer.write(Mono.just(builder.build()), forClass(MultiValueMap.class),
215+
MediaType.MULTIPART_FORM_DATA, clientRequest, Collections.emptyMap()).block();
216+
217+
MockServerHttpRequest serverRequest = MockServerHttpRequest.post("/")
218+
.contentType(clientRequest.getHeaders().getContentType())
219+
.body(clientRequest.getBody());
220+
221+
return MockServerWebExchange.from(serverRequest);
222+
}
223+
224+
private String partToUtf8String(Part part) {
225+
DataBuffer buffer = DataBufferUtils.join(part.content()).block();
226+
return DataBufferTestUtils.dumpString(buffer, StandardCharsets.UTF_8);
227+
}
228+
229+
230+
@SuppressWarnings("unused")
231+
void handle(
232+
@RequestPart("name") Person person,
233+
@RequestPart("name") Mono<Person> personMono,
234+
@RequestPart("name") Flux<Person> personFlux,
235+
@RequestPart("name") List<Person> personList,
236+
@RequestPart("name") Part part,
237+
@RequestPart("name") Mono<Part> partMono,
238+
@RequestPart("name") Flux<Part> partFlux,
239+
@RequestPart("name") List<Part> partList,
240+
String notAnnotated) {}
241+
242+
243+
private static class Person {
244+
245+
private String name;
246+
247+
@JsonCreator
248+
public Person(@JsonProperty("name") String name) {
249+
this.name = name;
250+
}
251+
252+
public String getName() {
253+
return name;
254+
}
255+
}
256+
257+
}

0 commit comments

Comments
 (0)