Skip to content

Commit d36c0f5

Browse files
authored
perf: fix decoder streams to make pooled connections reusable (#2484)
When a response body is being decompressed, and the length wasn't known, but was using chunked transfer-encoding, the remaining `0\r\n\r\n` was not consumed. That would leave the connection in a state that could be not be reused, and so the pool had to discard it. This fix makes sure the remaining end chunk is consumed, improving the amount of pooled connections that can be reused. Closes #2381
1 parent 4367d30 commit d36c0f5

File tree

6 files changed

+1033
-28
lines changed

6 files changed

+1033
-28
lines changed

src/async_impl/decoder.rs

Lines changed: 88 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,14 @@ use std::future::Future;
99
use std::pin::Pin;
1010
use std::task::{Context, Poll};
1111

12+
#[cfg(any(
13+
feature = "gzip",
14+
feature = "zstd",
15+
feature = "brotli",
16+
feature = "deflate"
17+
))]
18+
use futures_util::stream::Fuse;
19+
1220
#[cfg(feature = "gzip")]
1321
use async_compression::tokio::bufread::GzipDecoder;
1422

@@ -108,19 +116,19 @@ enum Inner {
108116

109117
/// A `Gzip` decoder will uncompress the gzipped response content before returning it.
110118
#[cfg(feature = "gzip")]
111-
Gzip(Pin<Box<FramedRead<GzipDecoder<PeekableIoStreamReader>, BytesCodec>>>),
119+
Gzip(Pin<Box<Fuse<FramedRead<GzipDecoder<PeekableIoStreamReader>, BytesCodec>>>>),
112120

113121
/// A `Brotli` decoder will uncompress the brotlied response content before returning it.
114122
#[cfg(feature = "brotli")]
115-
Brotli(Pin<Box<FramedRead<BrotliDecoder<PeekableIoStreamReader>, BytesCodec>>>),
123+
Brotli(Pin<Box<Fuse<FramedRead<BrotliDecoder<PeekableIoStreamReader>, BytesCodec>>>>),
116124

117125
/// A `Zstd` decoder will uncompress the zstd compressed response content before returning it.
118126
#[cfg(feature = "zstd")]
119-
Zstd(Pin<Box<FramedRead<ZstdDecoder<PeekableIoStreamReader>, BytesCodec>>>),
127+
Zstd(Pin<Box<Fuse<FramedRead<ZstdDecoder<PeekableIoStreamReader>, BytesCodec>>>>),
120128

121129
/// A `Deflate` decoder will uncompress the deflated response content before returning it.
122130
#[cfg(feature = "deflate")]
123-
Deflate(Pin<Box<FramedRead<ZlibDecoder<PeekableIoStreamReader>, BytesCodec>>>),
131+
Deflate(Pin<Box<Fuse<FramedRead<ZlibDecoder<PeekableIoStreamReader>, BytesCodec>>>>),
124132

125133
/// A decoder that doesn't have a value yet.
126134
#[cfg(any(
@@ -365,34 +373,74 @@ impl HttpBody for Decoder {
365373
}
366374
#[cfg(feature = "gzip")]
367375
Inner::Gzip(ref mut decoder) => {
368-
match futures_core::ready!(Pin::new(decoder).poll_next(cx)) {
376+
match futures_core::ready!(Pin::new(&mut *decoder).poll_next(cx)) {
369377
Some(Ok(bytes)) => Poll::Ready(Some(Ok(Frame::data(bytes.freeze())))),
370378
Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))),
371-
None => Poll::Ready(None),
379+
None => {
380+
// poll inner connection until EOF after gzip stream is finished
381+
let inner_stream = decoder.get_mut().get_mut().get_mut().get_mut();
382+
match futures_core::ready!(Pin::new(inner_stream).poll_next(cx)) {
383+
Some(Ok(_)) => Poll::Ready(Some(Err(crate::error::decode(
384+
"there are extra bytes after body has been decompressed",
385+
)))),
386+
Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))),
387+
None => Poll::Ready(None),
388+
}
389+
}
372390
}
373391
}
374392
#[cfg(feature = "brotli")]
375393
Inner::Brotli(ref mut decoder) => {
376-
match futures_core::ready!(Pin::new(decoder).poll_next(cx)) {
394+
match futures_core::ready!(Pin::new(&mut *decoder).poll_next(cx)) {
377395
Some(Ok(bytes)) => Poll::Ready(Some(Ok(Frame::data(bytes.freeze())))),
378396
Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))),
379-
None => Poll::Ready(None),
397+
None => {
398+
// poll inner connection until EOF after brotli stream is finished
399+
let inner_stream = decoder.get_mut().get_mut().get_mut().get_mut();
400+
match futures_core::ready!(Pin::new(inner_stream).poll_next(cx)) {
401+
Some(Ok(_)) => Poll::Ready(Some(Err(crate::error::decode(
402+
"there are extra bytes after body has been decompressed",
403+
)))),
404+
Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))),
405+
None => Poll::Ready(None),
406+
}
407+
}
380408
}
381409
}
382410
#[cfg(feature = "zstd")]
383411
Inner::Zstd(ref mut decoder) => {
384-
match futures_core::ready!(Pin::new(decoder).poll_next(cx)) {
412+
match futures_core::ready!(Pin::new(&mut *decoder).poll_next(cx)) {
385413
Some(Ok(bytes)) => Poll::Ready(Some(Ok(Frame::data(bytes.freeze())))),
386414
Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))),
387-
None => Poll::Ready(None),
415+
None => {
416+
// poll inner connection until EOF after zstd stream is finished
417+
let inner_stream = decoder.get_mut().get_mut().get_mut().get_mut();
418+
match futures_core::ready!(Pin::new(inner_stream).poll_next(cx)) {
419+
Some(Ok(_)) => Poll::Ready(Some(Err(crate::error::decode(
420+
"there are extra bytes after body has been decompressed",
421+
)))),
422+
Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))),
423+
None => Poll::Ready(None),
424+
}
425+
}
388426
}
389427
}
390428
#[cfg(feature = "deflate")]
391429
Inner::Deflate(ref mut decoder) => {
392-
match futures_core::ready!(Pin::new(decoder).poll_next(cx)) {
430+
match futures_core::ready!(Pin::new(&mut *decoder).poll_next(cx)) {
393431
Some(Ok(bytes)) => Poll::Ready(Some(Ok(Frame::data(bytes.freeze())))),
394432
Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))),
395-
None => Poll::Ready(None),
433+
None => {
434+
// poll inner connection until EOF after deflate stream is finished
435+
let inner_stream = decoder.get_mut().get_mut().get_mut().get_mut();
436+
match futures_core::ready!(Pin::new(inner_stream).poll_next(cx)) {
437+
Some(Ok(_)) => Poll::Ready(Some(Err(crate::error::decode(
438+
"there are extra bytes after body has been decompressed",
439+
)))),
440+
Some(Err(err)) => Poll::Ready(Some(Err(crate::error::decode_io(err)))),
441+
None => Poll::Ready(None),
442+
}
443+
}
396444
}
397445
}
398446
}
@@ -456,25 +504,37 @@ impl Future for Pending {
456504

457505
match self.1 {
458506
#[cfg(feature = "brotli")]
459-
DecoderType::Brotli => Poll::Ready(Ok(Inner::Brotli(Box::pin(FramedRead::new(
460-
BrotliDecoder::new(StreamReader::new(_body)),
461-
BytesCodec::new(),
462-
))))),
507+
DecoderType::Brotli => Poll::Ready(Ok(Inner::Brotli(Box::pin(
508+
FramedRead::new(
509+
BrotliDecoder::new(StreamReader::new(_body)),
510+
BytesCodec::new(),
511+
)
512+
.fuse(),
513+
)))),
463514
#[cfg(feature = "zstd")]
464-
DecoderType::Zstd => Poll::Ready(Ok(Inner::Zstd(Box::pin(FramedRead::new(
465-
ZstdDecoder::new(StreamReader::new(_body)),
466-
BytesCodec::new(),
467-
))))),
515+
DecoderType::Zstd => Poll::Ready(Ok(Inner::Zstd(Box::pin(
516+
FramedRead::new(
517+
ZstdDecoder::new(StreamReader::new(_body)),
518+
BytesCodec::new(),
519+
)
520+
.fuse(),
521+
)))),
468522
#[cfg(feature = "gzip")]
469-
DecoderType::Gzip => Poll::Ready(Ok(Inner::Gzip(Box::pin(FramedRead::new(
470-
GzipDecoder::new(StreamReader::new(_body)),
471-
BytesCodec::new(),
472-
))))),
523+
DecoderType::Gzip => Poll::Ready(Ok(Inner::Gzip(Box::pin(
524+
FramedRead::new(
525+
GzipDecoder::new(StreamReader::new(_body)),
526+
BytesCodec::new(),
527+
)
528+
.fuse(),
529+
)))),
473530
#[cfg(feature = "deflate")]
474-
DecoderType::Deflate => Poll::Ready(Ok(Inner::Deflate(Box::pin(FramedRead::new(
475-
ZlibDecoder::new(StreamReader::new(_body)),
476-
BytesCodec::new(),
477-
))))),
531+
DecoderType::Deflate => Poll::Ready(Ok(Inner::Deflate(Box::pin(
532+
FramedRead::new(
533+
ZlibDecoder::new(StreamReader::new(_body)),
534+
BytesCodec::new(),
535+
)
536+
.fuse(),
537+
)))),
478538
}
479539
}
480540
}

tests/brotli.rs

Lines changed: 210 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
mod support;
22
use std::io::Read;
33
use support::server;
4+
use tokio::io::AsyncWriteExt;
45

56
#[tokio::test]
67
async fn brotli_response() {
@@ -145,3 +146,212 @@ async fn brotli_case(response_size: usize, chunk_size: usize) {
145146
let body = res.text().await.expect("text");
146147
assert_eq!(body, content);
147148
}
149+
150+
const COMPRESSED_RESPONSE_HEADERS: &[u8] = b"HTTP/1.1 200 OK\x0d\x0a\
151+
Content-Type: text/plain\x0d\x0a\
152+
Connection: keep-alive\x0d\x0a\
153+
Content-Encoding: br\x0d\x0a";
154+
155+
const RESPONSE_CONTENT: &str = "some message here";
156+
157+
fn brotli_compress(input: &[u8]) -> Vec<u8> {
158+
let mut encoder = brotli_crate::CompressorReader::new(input, 4096, 5, 20);
159+
let mut brotlied_content = Vec::new();
160+
encoder.read_to_end(&mut brotlied_content).unwrap();
161+
brotlied_content
162+
}
163+
164+
#[tokio::test]
165+
async fn test_non_chunked_non_fragmented_response() {
166+
let server = server::low_level_with_response(|_raw_request, client_socket| {
167+
Box::new(async move {
168+
let brotlied_content = brotli_compress(RESPONSE_CONTENT.as_bytes());
169+
let content_length_header =
170+
format!("Content-Length: {}\r\n\r\n", brotlied_content.len()).into_bytes();
171+
let response = [
172+
COMPRESSED_RESPONSE_HEADERS,
173+
&content_length_header,
174+
&brotlied_content,
175+
]
176+
.concat();
177+
178+
client_socket
179+
.write_all(response.as_slice())
180+
.await
181+
.expect("response write_all failed");
182+
client_socket.flush().await.expect("response flush failed");
183+
})
184+
});
185+
186+
let res = reqwest::Client::new()
187+
.get(&format!("http://{}/", server.addr()))
188+
.send()
189+
.await
190+
.expect("response");
191+
192+
assert_eq!(res.text().await.expect("text"), RESPONSE_CONTENT);
193+
}
194+
195+
#[tokio::test]
196+
async fn test_chunked_fragmented_response_1() {
197+
const DELAY_BETWEEN_RESPONSE_PARTS: tokio::time::Duration =
198+
tokio::time::Duration::from_millis(1000);
199+
const DELAY_MARGIN: tokio::time::Duration = tokio::time::Duration::from_millis(50);
200+
201+
let server = server::low_level_with_response(|_raw_request, client_socket| {
202+
Box::new(async move {
203+
let brotlied_content = brotli_compress(RESPONSE_CONTENT.as_bytes());
204+
let response_first_part = [
205+
COMPRESSED_RESPONSE_HEADERS,
206+
format!(
207+
"Transfer-Encoding: chunked\r\n\r\n{:x}\r\n",
208+
brotlied_content.len()
209+
)
210+
.as_bytes(),
211+
&brotlied_content,
212+
]
213+
.concat();
214+
let response_second_part = b"\r\n0\r\n\r\n";
215+
216+
client_socket
217+
.write_all(response_first_part.as_slice())
218+
.await
219+
.expect("response_first_part write_all failed");
220+
client_socket
221+
.flush()
222+
.await
223+
.expect("response_first_part flush failed");
224+
225+
tokio::time::sleep(DELAY_BETWEEN_RESPONSE_PARTS).await;
226+
227+
client_socket
228+
.write_all(response_second_part)
229+
.await
230+
.expect("response_second_part write_all failed");
231+
client_socket
232+
.flush()
233+
.await
234+
.expect("response_second_part flush failed");
235+
})
236+
});
237+
238+
let start = tokio::time::Instant::now();
239+
let res = reqwest::Client::new()
240+
.get(&format!("http://{}/", server.addr()))
241+
.send()
242+
.await
243+
.expect("response");
244+
245+
assert_eq!(res.text().await.expect("text"), RESPONSE_CONTENT);
246+
assert!(start.elapsed() >= DELAY_BETWEEN_RESPONSE_PARTS - DELAY_MARGIN);
247+
}
248+
249+
#[tokio::test]
250+
async fn test_chunked_fragmented_response_2() {
251+
const DELAY_BETWEEN_RESPONSE_PARTS: tokio::time::Duration =
252+
tokio::time::Duration::from_millis(1000);
253+
const DELAY_MARGIN: tokio::time::Duration = tokio::time::Duration::from_millis(50);
254+
255+
let server = server::low_level_with_response(|_raw_request, client_socket| {
256+
Box::new(async move {
257+
let brotlied_content = brotli_compress(RESPONSE_CONTENT.as_bytes());
258+
let response_first_part = [
259+
COMPRESSED_RESPONSE_HEADERS,
260+
format!(
261+
"Transfer-Encoding: chunked\r\n\r\n{:x}\r\n",
262+
brotlied_content.len()
263+
)
264+
.as_bytes(),
265+
&brotlied_content,
266+
b"\r\n",
267+
]
268+
.concat();
269+
let response_second_part = b"0\r\n\r\n";
270+
271+
client_socket
272+
.write_all(response_first_part.as_slice())
273+
.await
274+
.expect("response_first_part write_all failed");
275+
client_socket
276+
.flush()
277+
.await
278+
.expect("response_first_part flush failed");
279+
280+
tokio::time::sleep(DELAY_BETWEEN_RESPONSE_PARTS).await;
281+
282+
client_socket
283+
.write_all(response_second_part)
284+
.await
285+
.expect("response_second_part write_all failed");
286+
client_socket
287+
.flush()
288+
.await
289+
.expect("response_second_part flush failed");
290+
})
291+
});
292+
293+
let start = tokio::time::Instant::now();
294+
let res = reqwest::Client::new()
295+
.get(&format!("http://{}/", server.addr()))
296+
.send()
297+
.await
298+
.expect("response");
299+
300+
assert_eq!(res.text().await.expect("text"), RESPONSE_CONTENT);
301+
assert!(start.elapsed() >= DELAY_BETWEEN_RESPONSE_PARTS - DELAY_MARGIN);
302+
}
303+
304+
#[tokio::test]
305+
async fn test_chunked_fragmented_response_with_extra_bytes() {
306+
const DELAY_BETWEEN_RESPONSE_PARTS: tokio::time::Duration =
307+
tokio::time::Duration::from_millis(1000);
308+
const DELAY_MARGIN: tokio::time::Duration = tokio::time::Duration::from_millis(50);
309+
310+
let server = server::low_level_with_response(|_raw_request, client_socket| {
311+
Box::new(async move {
312+
let brotlied_content = brotli_compress(RESPONSE_CONTENT.as_bytes());
313+
let response_first_part = [
314+
COMPRESSED_RESPONSE_HEADERS,
315+
format!(
316+
"Transfer-Encoding: chunked\r\n\r\n{:x}\r\n",
317+
brotlied_content.len()
318+
)
319+
.as_bytes(),
320+
&brotlied_content,
321+
]
322+
.concat();
323+
let response_second_part = b"\r\n2ab\r\n0\r\n\r\n";
324+
325+
client_socket
326+
.write_all(response_first_part.as_slice())
327+
.await
328+
.expect("response_first_part write_all failed");
329+
client_socket
330+
.flush()
331+
.await
332+
.expect("response_first_part flush failed");
333+
334+
tokio::time::sleep(DELAY_BETWEEN_RESPONSE_PARTS).await;
335+
336+
client_socket
337+
.write_all(response_second_part)
338+
.await
339+
.expect("response_second_part write_all failed");
340+
client_socket
341+
.flush()
342+
.await
343+
.expect("response_second_part flush failed");
344+
})
345+
});
346+
347+
let start = tokio::time::Instant::now();
348+
let res = reqwest::Client::new()
349+
.get(&format!("http://{}/", server.addr()))
350+
.send()
351+
.await
352+
.expect("response");
353+
354+
let err = res.text().await.expect_err("there must be an error");
355+
assert!(err.is_decode());
356+
assert!(start.elapsed() >= DELAY_BETWEEN_RESPONSE_PARTS - DELAY_MARGIN);
357+
}

0 commit comments

Comments
 (0)