Skip to content

Commit ddafcc0

Browse files
committed
Auto merge of #79650 - the8472:fix-take, r=dtolnay
Fix incorrect io::Take's limit resulting from io::copy specialization The specialization introduced in #75272 fails to update `io::Take` wrappers after performing the copy syscalls which bypass those wrappers. The buffer flushing before the copy does update them correctly, but the bytes copied after the initial flush weren't subtracted. The fix is to subtract the bytes copied from each `Take` in the chain of wrappers, even when an error occurs during the syscall loop. To do so the `CopyResult` enum now has to carry the bytes copied so far in the error case.
2 parents bb0d481 + a9b1381 commit ddafcc0

File tree

3 files changed

+54
-16
lines changed

3 files changed

+54
-16
lines changed

Diff for: library/std/src/sys/unix/fs.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -1211,7 +1211,8 @@ pub fn copy(from: &Path, to: &Path) -> io::Result<u64> {
12111211
use super::kernel_copy::{copy_regular_files, CopyResult};
12121212

12131213
match copy_regular_files(reader.as_raw_fd(), writer.as_raw_fd(), max_len) {
1214-
CopyResult::Ended(result) => result,
1214+
CopyResult::Ended(bytes) => Ok(bytes),
1215+
CopyResult::Error(e, _) => Err(e),
12151216
CopyResult::Fallback(written) => match io::copy::generic_copy(&mut reader, &mut writer) {
12161217
Ok(bytes) => Ok(bytes + written),
12171218
Err(e) => Err(e),

Diff for: library/std/src/sys/unix/kernel_copy.rs

+42-12
Original file line numberDiff line numberDiff line change
@@ -167,10 +167,11 @@ impl<R: CopyRead, W: CopyWrite> SpecCopy for Copier<'_, '_, R, W> {
167167

168168
if input_meta.copy_file_range_candidate() && output_meta.copy_file_range_candidate() {
169169
let result = copy_regular_files(readfd, writefd, max_write);
170+
result.update_take(reader);
170171

171172
match result {
172-
CopyResult::Ended(Ok(bytes_copied)) => return Ok(bytes_copied + written),
173-
CopyResult::Ended(err) => return err,
173+
CopyResult::Ended(bytes_copied) => return Ok(bytes_copied + written),
174+
CopyResult::Error(e, _) => return Err(e),
174175
CopyResult::Fallback(bytes) => written += bytes,
175176
}
176177
}
@@ -182,20 +183,22 @@ impl<R: CopyRead, W: CopyWrite> SpecCopy for Copier<'_, '_, R, W> {
182183
// fall back to the generic copy loop.
183184
if input_meta.potential_sendfile_source() {
184185
let result = sendfile_splice(SpliceMode::Sendfile, readfd, writefd, max_write);
186+
result.update_take(reader);
185187

186188
match result {
187-
CopyResult::Ended(Ok(bytes_copied)) => return Ok(bytes_copied + written),
188-
CopyResult::Ended(err) => return err,
189+
CopyResult::Ended(bytes_copied) => return Ok(bytes_copied + written),
190+
CopyResult::Error(e, _) => return Err(e),
189191
CopyResult::Fallback(bytes) => written += bytes,
190192
}
191193
}
192194

193195
if input_meta.maybe_fifo() || output_meta.maybe_fifo() {
194196
let result = sendfile_splice(SpliceMode::Splice, readfd, writefd, max_write);
197+
result.update_take(reader);
195198

196199
match result {
197-
CopyResult::Ended(Ok(bytes_copied)) => return Ok(bytes_copied + written),
198-
CopyResult::Ended(err) => return err,
200+
CopyResult::Ended(bytes_copied) => return Ok(bytes_copied + written),
201+
CopyResult::Error(e, _) => return Err(e),
199202
CopyResult::Fallback(0) => { /* use the fallback below */ }
200203
CopyResult::Fallback(_) => {
201204
unreachable!("splice should not return > 0 bytes on the fallback path")
@@ -225,6 +228,9 @@ trait CopyRead: Read {
225228
Ok(0)
226229
}
227230

231+
/// Updates `Take` wrappers to remove the number of bytes copied.
232+
fn taken(&mut self, _bytes: u64) {}
233+
228234
/// The minimum of the limit of all `Take<_>` wrappers, `u64::MAX` otherwise.
229235
/// This method does not account for data `BufReader` buffers and would underreport
230236
/// the limit of a `Take<BufReader<Take<_>>>` type. Thus its result is only valid
@@ -251,6 +257,10 @@ where
251257
(**self).drain_to(writer, limit)
252258
}
253259

260+
fn taken(&mut self, bytes: u64) {
261+
(**self).taken(bytes);
262+
}
263+
254264
fn min_limit(&self) -> u64 {
255265
(**self).min_limit()
256266
}
@@ -407,6 +417,11 @@ impl<T: CopyRead> CopyRead for Take<T> {
407417
Ok(bytes_drained)
408418
}
409419

420+
fn taken(&mut self, bytes: u64) {
421+
self.set_limit(self.limit() - bytes);
422+
self.get_mut().taken(bytes);
423+
}
424+
410425
fn min_limit(&self) -> u64 {
411426
min(Take::limit(self), self.get_ref().min_limit())
412427
}
@@ -432,6 +447,10 @@ impl<T: CopyRead> CopyRead for BufReader<T> {
432447
Ok(bytes as u64 + inner_bytes)
433448
}
434449

450+
fn taken(&mut self, bytes: u64) {
451+
self.get_mut().taken(bytes);
452+
}
453+
435454
fn min_limit(&self) -> u64 {
436455
self.get_ref().min_limit()
437456
}
@@ -457,10 +476,21 @@ fn fd_to_meta<T: AsRawFd>(fd: &T) -> FdMeta {
457476
}
458477

459478
pub(super) enum CopyResult {
460-
Ended(Result<u64>),
479+
Ended(u64),
480+
Error(Error, u64),
461481
Fallback(u64),
462482
}
463483

484+
impl CopyResult {
485+
fn update_take(&self, reader: &mut impl CopyRead) {
486+
match *self {
487+
CopyResult::Fallback(bytes)
488+
| CopyResult::Ended(bytes)
489+
| CopyResult::Error(_, bytes) => reader.taken(bytes),
490+
}
491+
}
492+
}
493+
464494
/// linux-specific implementation that will attempt to use copy_file_range for copy offloading
465495
/// as the name says, it only works on regular files
466496
///
@@ -527,7 +557,7 @@ pub(super) fn copy_regular_files(reader: RawFd, writer: RawFd, max_len: u64) ->
527557
// - copying from an overlay filesystem in docker. reported to occur on fedora 32.
528558
return CopyResult::Fallback(0);
529559
}
530-
Ok(0) => return CopyResult::Ended(Ok(written)), // reached EOF
560+
Ok(0) => return CopyResult::Ended(written), // reached EOF
531561
Ok(ret) => written += ret as u64,
532562
Err(err) => {
533563
return match err.raw_os_error() {
@@ -545,12 +575,12 @@ pub(super) fn copy_regular_files(reader: RawFd, writer: RawFd, max_len: u64) ->
545575
assert_eq!(written, 0);
546576
CopyResult::Fallback(0)
547577
}
548-
_ => CopyResult::Ended(Err(err)),
578+
_ => CopyResult::Error(err, written),
549579
};
550580
}
551581
}
552582
}
553-
CopyResult::Ended(Ok(written))
583+
CopyResult::Ended(written)
554584
}
555585

556586
#[derive(PartialEq)]
@@ -623,10 +653,10 @@ fn sendfile_splice(mode: SpliceMode, reader: RawFd, writer: RawFd, len: u64) ->
623653
Some(os_err) if mode == SpliceMode::Sendfile && os_err == libc::EOVERFLOW => {
624654
CopyResult::Fallback(written)
625655
}
626-
_ => CopyResult::Ended(Err(err)),
656+
_ => CopyResult::Error(err, written),
627657
};
628658
}
629659
}
630660
}
631-
CopyResult::Ended(Ok(written))
661+
CopyResult::Ended(written)
632662
}

Diff for: library/std/src/sys/unix/kernel_copy/tests.rs

+10-3
Original file line numberDiff line numberDiff line change
@@ -42,8 +42,15 @@ fn copy_specialization() -> Result<()> {
4242
assert_eq!(sink.buffer(), b"wxyz");
4343

4444
let copied = crate::io::copy(&mut source, &mut sink)?;
45-
assert_eq!(copied, 10);
46-
assert_eq!(sink.buffer().len(), 0);
45+
assert_eq!(copied, 10, "copy obeyed limit imposed by Take");
46+
assert_eq!(sink.buffer().len(), 0, "sink buffer was flushed");
47+
assert_eq!(source.limit(), 0, "outer Take was exhausted");
48+
assert_eq!(source.get_ref().buffer().len(), 0, "source buffer should be drained");
49+
assert_eq!(
50+
source.get_ref().get_ref().limit(),
51+
1,
52+
"inner Take allowed reading beyond end of file, some bytes should be left"
53+
);
4754

4855
let mut sink = sink.into_inner()?;
4956
sink.seek(SeekFrom::Start(0))?;
@@ -210,7 +217,7 @@ fn bench_socket_pipe_socket_copy(b: &mut test::Bencher) {
210217
);
211218

212219
match probe {
213-
CopyResult::Ended(Ok(1)) => {
220+
CopyResult::Ended(1) => {
214221
// splice works
215222
}
216223
_ => {

0 commit comments

Comments
 (0)