|
1 | 1 | /*
|
2 |
| - * Copyright 2002-2022 the original author or authors. |
| 2 | + * Copyright 2002-2023 the original author or authors. |
3 | 3 | *
|
4 | 4 | * Licensed under the Apache License, Version 2.0 (the "License");
|
5 | 5 | * you may not use this file except in compliance with the License.
|
|
21 | 21 | import java.util.Collections;
|
22 | 22 | import java.util.List;
|
23 | 23 | import java.util.Map;
|
| 24 | +import java.util.concurrent.atomic.AtomicInteger; |
| 25 | +import java.util.concurrent.atomic.AtomicLong; |
24 | 26 |
|
25 | 27 | import org.reactivestreams.Publisher;
|
26 | 28 | import reactor.core.publisher.Flux;
|
@@ -53,6 +55,10 @@ public class PartEventHttpMessageReader extends LoggingCodecSupport implements H
|
53 | 55 |
|
54 | 56 | private int maxHeadersSize = 10 * 1024;
|
55 | 57 |
|
| 58 | + private int maxParts = -1; |
| 59 | + |
| 60 | + private long maxPartSize = -1; |
| 61 | + |
56 | 62 | private Charset headersCharset = StandardCharsets.UTF_8;
|
57 | 63 |
|
58 | 64 |
|
@@ -85,6 +91,24 @@ public void setMaxHeadersSize(int byteCount) {
|
85 | 91 | this.maxHeadersSize = byteCount;
|
86 | 92 | }
|
87 | 93 |
|
| 94 | + /** |
| 95 | + * Specify the maximum number of parts allowed in a given multipart request. |
| 96 | + * <p>By default this is set to -1, meaning that there is no maximum. |
| 97 | + * @since 6.1 |
| 98 | + */ |
| 99 | + public void setMaxParts(int maxParts) { |
| 100 | + this.maxParts = maxParts; |
| 101 | + } |
| 102 | + |
| 103 | + /** |
| 104 | + * Configure the maximum size allowed for any part. |
| 105 | + * <p>By default this is set to -1, meaning that there is no maximum. |
| 106 | + * @since 6.1 |
| 107 | + */ |
| 108 | + public void setMaxPartSize(long maxPartSize) { |
| 109 | + this.maxPartSize = maxPartSize; |
| 110 | + } |
| 111 | + |
88 | 112 | /**
|
89 | 113 | * Set the character set used to decode headers.
|
90 | 114 | * <p>Defaults to UTF-8 as per RFC 7578.
|
@@ -125,50 +149,81 @@ public Flux<PartEvent> read(ResolvableType elementType, ReactiveHttpInputMessage
|
125 | 149 | return Flux.error(new DecodingException("No multipart boundary found in Content-Type: \"" +
|
126 | 150 | message.getHeaders().getContentType() + "\""));
|
127 | 151 | }
|
128 |
| - return MultipartParser.parse(message.getBody(), boundary, this.maxHeadersSize, this.headersCharset) |
| 152 | + Flux<MultipartParser.Token> allPartsTokens = MultipartParser.parse(message.getBody(), boundary, |
| 153 | + this.maxHeadersSize, this.headersCharset); |
| 154 | + |
| 155 | + AtomicInteger partCount = new AtomicInteger(); |
| 156 | + return allPartsTokens |
129 | 157 | .windowUntil(t -> t instanceof MultipartParser.HeadersToken, true)
|
130 |
| - .concatMap(tokens -> tokens.switchOnFirst((signal, flux) -> { |
131 |
| - if (signal.hasValue()) { |
132 |
| - MultipartParser.HeadersToken headersToken = (MultipartParser.HeadersToken) signal.get(); |
133 |
| - Assert.state(headersToken != null, "Signal should be headers token"); |
134 |
| - |
135 |
| - HttpHeaders headers = headersToken.headers(); |
136 |
| - Flux<MultipartParser.BodyToken> bodyTokens = |
137 |
| - flux.filter(t -> t instanceof MultipartParser.BodyToken) |
138 |
| - .cast(MultipartParser.BodyToken.class); |
139 |
| - return createEvents(headers, bodyTokens); |
| 158 | + .concatMap(partTokens -> { |
| 159 | + if (tooManyParts(partCount)) { |
| 160 | + return Mono.error(new DecodingException("Too many parts (" + partCount.get() + "/" + |
| 161 | + this.maxParts + " allowed)")); |
140 | 162 | }
|
141 | 163 | else {
|
142 |
| - // complete or error signal |
143 |
| - return flux.cast(PartEvent.class); |
| 164 | + return partTokens.switchOnFirst((signal, flux) -> { |
| 165 | + if (signal.hasValue()) { |
| 166 | + MultipartParser.HeadersToken headersToken = (MultipartParser.HeadersToken) signal.get(); |
| 167 | + Assert.state(headersToken != null, "Signal should be headers token"); |
| 168 | + |
| 169 | + HttpHeaders headers = headersToken.headers(); |
| 170 | + Flux<MultipartParser.BodyToken> bodyTokens = |
| 171 | + flux.filter(t -> t instanceof MultipartParser.BodyToken) |
| 172 | + .cast(MultipartParser.BodyToken.class); |
| 173 | + return createEvents(headers, bodyTokens); |
| 174 | + } |
| 175 | + else { |
| 176 | + // complete or error signal |
| 177 | + return flux.cast(PartEvent.class); |
| 178 | + } |
| 179 | + }); |
144 | 180 | }
|
145 |
| - })); |
| 181 | + }); |
146 | 182 | });
|
147 | 183 | }
|
148 | 184 |
|
| 185 | + private boolean tooManyParts(AtomicInteger partCount) { |
| 186 | + int count = partCount.incrementAndGet(); |
| 187 | + return this.maxParts > 0 && count > this.maxParts; |
| 188 | + } |
| 189 | + |
| 190 | + |
149 | 191 | private Publisher<? extends PartEvent> createEvents(HttpHeaders headers, Flux<MultipartParser.BodyToken> bodyTokens) {
|
150 | 192 | if (MultipartUtils.isFormField(headers)) {
|
151 | 193 | Flux<DataBuffer> contents = bodyTokens.map(MultipartParser.BodyToken::buffer);
|
152 |
| - return DataBufferUtils.join(contents, this.maxInMemorySize) |
| 194 | + int maxSize = (int) Math.min(this.maxInMemorySize, this.maxPartSize); |
| 195 | + return DataBufferUtils.join(contents, maxSize) |
153 | 196 | .map(content -> {
|
154 | 197 | String value = content.toString(MultipartUtils.charset(headers));
|
155 | 198 | DataBufferUtils.release(content);
|
156 | 199 | return DefaultPartEvents.form(headers, value);
|
157 | 200 | })
|
158 | 201 | .switchIfEmpty(Mono.fromCallable(() -> DefaultPartEvents.form(headers)));
|
159 | 202 | }
|
160 |
| - else if (headers.getContentDisposition().getFilename() != null) { |
161 |
| - return bodyTokens |
162 |
| - .map(body -> DefaultPartEvents.file(headers, body.buffer(), body.isLast())) |
163 |
| - .switchIfEmpty(Mono.fromCallable(() -> DefaultPartEvents.file(headers))); |
164 |
| - } |
165 | 203 | else {
|
| 204 | + boolean isFilePart = headers.getContentDisposition().getFilename() != null; |
| 205 | + AtomicLong partSize = new AtomicLong(); |
166 | 206 | return bodyTokens
|
167 |
| - .map(body -> DefaultPartEvents.create(headers, body.buffer(), body.isLast())) |
168 |
| - .switchIfEmpty(Mono.fromCallable(() -> DefaultPartEvents.create(headers))); // empty body |
| 207 | + .concatMap(body -> { |
| 208 | + DataBuffer buffer = body.buffer(); |
| 209 | + if (tooLarge(partSize, buffer)) { |
| 210 | + DataBufferUtils.release(buffer); |
| 211 | + return Mono.error(new DataBufferLimitException("Part exceeded the limit of " + |
| 212 | + this.maxPartSize + " bytes")); |
| 213 | + } |
| 214 | + else { |
| 215 | + return isFilePart ? Mono.just(DefaultPartEvents.file(headers, buffer, body.isLast())) |
| 216 | + : Mono.just(DefaultPartEvents.create(headers, body.buffer(), body.isLast())); |
| 217 | + } |
| 218 | + }) |
| 219 | + .switchIfEmpty(Mono.fromCallable(() -> |
| 220 | + isFilePart ? DefaultPartEvents.file(headers) : DefaultPartEvents.create(headers))); |
169 | 221 | }
|
| 222 | + } |
170 | 223 |
|
171 |
| - |
| 224 | + private boolean tooLarge(AtomicLong partSize, DataBuffer buffer) { |
| 225 | + long size = partSize.addAndGet(buffer.readableByteCount()); |
| 226 | + return this.maxPartSize > 0 && size > this.maxPartSize; |
172 | 227 | }
|
173 | 228 |
|
174 | 229 | }
|
0 commit comments