Skip to content

Commit 66aac7e

Browse files
committed
Add maxParts and maxPartSize to PartEventHttpMessageReader
This commit introduces the maxParts and maxPartSize properties to PartEventHttpMessageReader, which can be used to limit the amount of parts, and maximum part size respectively. Closes gh-31343
1 parent 93528b6 commit 66aac7e

File tree

2 files changed

+113
-25
lines changed

2 files changed

+113
-25
lines changed

spring-web/src/main/java/org/springframework/http/codec/multipart/PartEventHttpMessageReader.java

Lines changed: 79 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2022 the original author or authors.
2+
* Copyright 2002-2023 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -21,6 +21,8 @@
2121
import java.util.Collections;
2222
import java.util.List;
2323
import java.util.Map;
24+
import java.util.concurrent.atomic.AtomicInteger;
25+
import java.util.concurrent.atomic.AtomicLong;
2426

2527
import org.reactivestreams.Publisher;
2628
import reactor.core.publisher.Flux;
@@ -53,6 +55,10 @@ public class PartEventHttpMessageReader extends LoggingCodecSupport implements H
5355

5456
private int maxHeadersSize = 10 * 1024;
5557

58+
private int maxParts = -1;
59+
60+
private long maxPartSize = -1;
61+
5662
private Charset headersCharset = StandardCharsets.UTF_8;
5763

5864

@@ -85,6 +91,24 @@ public void setMaxHeadersSize(int byteCount) {
8591
this.maxHeadersSize = byteCount;
8692
}
8793

94+
/**
95+
* Specify the maximum number of parts allowed in a given multipart request.
96+
* <p>By default this is set to -1, meaning that there is no maximum.
97+
* @since 6.1
98+
*/
99+
public void setMaxParts(int maxParts) {
100+
this.maxParts = maxParts;
101+
}
102+
103+
/**
104+
* Configure the maximum size allowed for any part.
105+
* <p>By default this is set to -1, meaning that there is no maximum.
106+
* @since 6.1
107+
*/
108+
public void setMaxPartSize(long maxPartSize) {
109+
this.maxPartSize = maxPartSize;
110+
}
111+
88112
/**
89113
* Set the character set used to decode headers.
90114
* <p>Defaults to UTF-8 as per RFC 7578.
@@ -125,50 +149,81 @@ public Flux<PartEvent> read(ResolvableType elementType, ReactiveHttpInputMessage
125149
return Flux.error(new DecodingException("No multipart boundary found in Content-Type: \"" +
126150
message.getHeaders().getContentType() + "\""));
127151
}
128-
return MultipartParser.parse(message.getBody(), boundary, this.maxHeadersSize, this.headersCharset)
152+
Flux<MultipartParser.Token> allPartsTokens = MultipartParser.parse(message.getBody(), boundary,
153+
this.maxHeadersSize, this.headersCharset);
154+
155+
AtomicInteger partCount = new AtomicInteger();
156+
return allPartsTokens
129157
.windowUntil(t -> t instanceof MultipartParser.HeadersToken, true)
130-
.concatMap(tokens -> tokens.switchOnFirst((signal, flux) -> {
131-
if (signal.hasValue()) {
132-
MultipartParser.HeadersToken headersToken = (MultipartParser.HeadersToken) signal.get();
133-
Assert.state(headersToken != null, "Signal should be headers token");
134-
135-
HttpHeaders headers = headersToken.headers();
136-
Flux<MultipartParser.BodyToken> bodyTokens =
137-
flux.filter(t -> t instanceof MultipartParser.BodyToken)
138-
.cast(MultipartParser.BodyToken.class);
139-
return createEvents(headers, bodyTokens);
158+
.concatMap(partTokens -> {
159+
if (tooManyParts(partCount)) {
160+
return Mono.error(new DecodingException("Too many parts (" + partCount.get() + "/" +
161+
this.maxParts + " allowed)"));
140162
}
141163
else {
142-
// complete or error signal
143-
return flux.cast(PartEvent.class);
164+
return partTokens.switchOnFirst((signal, flux) -> {
165+
if (signal.hasValue()) {
166+
MultipartParser.HeadersToken headersToken = (MultipartParser.HeadersToken) signal.get();
167+
Assert.state(headersToken != null, "Signal should be headers token");
168+
169+
HttpHeaders headers = headersToken.headers();
170+
Flux<MultipartParser.BodyToken> bodyTokens =
171+
flux.filter(t -> t instanceof MultipartParser.BodyToken)
172+
.cast(MultipartParser.BodyToken.class);
173+
return createEvents(headers, bodyTokens);
174+
}
175+
else {
176+
// complete or error signal
177+
return flux.cast(PartEvent.class);
178+
}
179+
});
144180
}
145-
}));
181+
});
146182
});
147183
}
148184

185+
private boolean tooManyParts(AtomicInteger partCount) {
186+
int count = partCount.incrementAndGet();
187+
return this.maxParts > 0 && count > this.maxParts;
188+
}
189+
190+
149191
private Publisher<? extends PartEvent> createEvents(HttpHeaders headers, Flux<MultipartParser.BodyToken> bodyTokens) {
150192
if (MultipartUtils.isFormField(headers)) {
151193
Flux<DataBuffer> contents = bodyTokens.map(MultipartParser.BodyToken::buffer);
152-
return DataBufferUtils.join(contents, this.maxInMemorySize)
194+
int maxSize = (int) Math.min(this.maxInMemorySize, this.maxPartSize);
195+
return DataBufferUtils.join(contents, maxSize)
153196
.map(content -> {
154197
String value = content.toString(MultipartUtils.charset(headers));
155198
DataBufferUtils.release(content);
156199
return DefaultPartEvents.form(headers, value);
157200
})
158201
.switchIfEmpty(Mono.fromCallable(() -> DefaultPartEvents.form(headers)));
159202
}
160-
else if (headers.getContentDisposition().getFilename() != null) {
161-
return bodyTokens
162-
.map(body -> DefaultPartEvents.file(headers, body.buffer(), body.isLast()))
163-
.switchIfEmpty(Mono.fromCallable(() -> DefaultPartEvents.file(headers)));
164-
}
165203
else {
204+
boolean isFilePart = headers.getContentDisposition().getFilename() != null;
205+
AtomicLong partSize = new AtomicLong();
166206
return bodyTokens
167-
.map(body -> DefaultPartEvents.create(headers, body.buffer(), body.isLast()))
168-
.switchIfEmpty(Mono.fromCallable(() -> DefaultPartEvents.create(headers))); // empty body
207+
.concatMap(body -> {
208+
DataBuffer buffer = body.buffer();
209+
if (tooLarge(partSize, buffer)) {
210+
DataBufferUtils.release(buffer);
211+
return Mono.error(new DataBufferLimitException("Part exceeded the limit of " +
212+
this.maxPartSize + " bytes"));
213+
}
214+
else {
215+
return isFilePart ? Mono.just(DefaultPartEvents.file(headers, buffer, body.isLast()))
216+
: Mono.just(DefaultPartEvents.create(headers, body.buffer(), body.isLast()));
217+
}
218+
})
219+
.switchIfEmpty(Mono.fromCallable(() ->
220+
isFilePart ? DefaultPartEvents.file(headers) : DefaultPartEvents.create(headers)));
169221
}
222+
}
170223

171-
224+
private boolean tooLarge(AtomicLong partSize, DataBuffer buffer) {
225+
long size = partSize.addAndGet(buffer.readableByteCount());
226+
return this.maxPartSize > 0 && size > this.maxPartSize;
172227
}
173228

174229
}

spring-web/src/test/java/org/springframework/http/codec/multipart/PartEventHttpMessageReaderTests.java

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright 2002-2022 the original author or authors.
2+
* Copyright 2002-2023 the original author or authors.
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -30,6 +30,7 @@
3030
import org.springframework.core.io.Resource;
3131
import org.springframework.core.io.buffer.DataBuffer;
3232
import org.springframework.core.io.buffer.DataBufferFactory;
33+
import org.springframework.core.io.buffer.DataBufferLimitException;
3334
import org.springframework.core.io.buffer.DataBufferUtils;
3435
import org.springframework.core.io.buffer.NettyDataBufferFactory;
3536
import org.springframework.http.ContentDisposition;
@@ -226,6 +227,38 @@ public void safari() {
226227
.verifyComplete();
227228
}
228229

230+
@Test
231+
void tooManyParts() {
232+
MockServerHttpRequest request = createRequest(
233+
new ClassPathResource("simple.multipart", getClass()), "simple-boundary");
234+
235+
PartEventHttpMessageReader reader = new PartEventHttpMessageReader();
236+
reader.setMaxParts(1);
237+
238+
Flux<PartEvent> result = reader.read(forClass(PartEvent.class), request, emptyMap());
239+
240+
StepVerifier.create(result)
241+
.expectError(DecodingException.class)
242+
.verify();
243+
}
244+
245+
@Test
246+
void partSizeTooLarge() {
247+
MockServerHttpRequest request = createRequest(new ClassPathResource("safari.multipart", getClass()),
248+
"----WebKitFormBoundaryG8fJ50opQOML0oGD");
249+
250+
PartEventHttpMessageReader reader = new PartEventHttpMessageReader();
251+
reader.setMaxPartSize(60);
252+
253+
Flux<PartEvent> result = reader.read(forClass(PartEvent.class), request, emptyMap());
254+
255+
StepVerifier.create(result)
256+
.assertNext(data(headersFormField("text1"), bodyText("a"), true))
257+
.assertNext(data(headersFormField("text2"), bodyText("b"), true))
258+
.expectError(DataBufferLimitException.class)
259+
.verify();
260+
261+
}
229262

230263
@Test
231264
public void utf8Headers() {

0 commit comments

Comments
 (0)