Skip to content

Commit a9b1381

Browse files
committed
fix copy specialization not updating Take wrappers
1 parent 9b390e7 commit a9b1381

File tree

3 files changed

+45
-14
lines changed

3 files changed

+45
-14
lines changed

Diff for: library/std/src/sys/unix/fs.rs

+2-1
Original file line numberDiff line numberDiff line change
@@ -1211,7 +1211,8 @@ pub fn copy(from: &Path, to: &Path) -> io::Result<u64> {
12111211
use super::kernel_copy::{copy_regular_files, CopyResult};
12121212

12131213
match copy_regular_files(reader.as_raw_fd(), writer.as_raw_fd(), max_len) {
1214-
CopyResult::Ended(result) => result,
1214+
CopyResult::Ended(bytes) => Ok(bytes),
1215+
CopyResult::Error(e, _) => Err(e),
12151216
CopyResult::Fallback(written) => match io::copy::generic_copy(&mut reader, &mut writer) {
12161217
Ok(bytes) => Ok(bytes + written),
12171218
Err(e) => Err(e),

Diff for: library/std/src/sys/unix/kernel_copy.rs

+42-12
Original file line numberDiff line numberDiff line change
@@ -167,10 +167,11 @@ impl<R: CopyRead, W: CopyWrite> SpecCopy for Copier<'_, '_, R, W> {
167167

168168
if input_meta.copy_file_range_candidate() && output_meta.copy_file_range_candidate() {
169169
let result = copy_regular_files(readfd, writefd, max_write);
170+
result.update_take(reader);
170171

171172
match result {
172-
CopyResult::Ended(Ok(bytes_copied)) => return Ok(bytes_copied + written),
173-
CopyResult::Ended(err) => return err,
173+
CopyResult::Ended(bytes_copied) => return Ok(bytes_copied + written),
174+
CopyResult::Error(e, _) => return Err(e),
174175
CopyResult::Fallback(bytes) => written += bytes,
175176
}
176177
}
@@ -182,20 +183,22 @@ impl<R: CopyRead, W: CopyWrite> SpecCopy for Copier<'_, '_, R, W> {
182183
// fall back to the generic copy loop.
183184
if input_meta.potential_sendfile_source() {
184185
let result = sendfile_splice(SpliceMode::Sendfile, readfd, writefd, max_write);
186+
result.update_take(reader);
185187

186188
match result {
187-
CopyResult::Ended(Ok(bytes_copied)) => return Ok(bytes_copied + written),
188-
CopyResult::Ended(err) => return err,
189+
CopyResult::Ended(bytes_copied) => return Ok(bytes_copied + written),
190+
CopyResult::Error(e, _) => return Err(e),
189191
CopyResult::Fallback(bytes) => written += bytes,
190192
}
191193
}
192194

193195
if input_meta.maybe_fifo() || output_meta.maybe_fifo() {
194196
let result = sendfile_splice(SpliceMode::Splice, readfd, writefd, max_write);
197+
result.update_take(reader);
195198

196199
match result {
197-
CopyResult::Ended(Ok(bytes_copied)) => return Ok(bytes_copied + written),
198-
CopyResult::Ended(err) => return err,
200+
CopyResult::Ended(bytes_copied) => return Ok(bytes_copied + written),
201+
CopyResult::Error(e, _) => return Err(e),
199202
CopyResult::Fallback(0) => { /* use the fallback below */ }
200203
CopyResult::Fallback(_) => {
201204
unreachable!("splice should not return > 0 bytes on the fallback path")
@@ -225,6 +228,9 @@ trait CopyRead: Read {
225228
Ok(0)
226229
}
227230

231+
/// Updates `Take` wrappers to remove the number of bytes copied.
232+
fn taken(&mut self, _bytes: u64) {}
233+
228234
/// The minimum of the limit of all `Take<_>` wrappers, `u64::MAX` otherwise.
229235
/// This method does not account for data `BufReader` buffers and would underreport
230236
/// the limit of a `Take<BufReader<Take<_>>>` type. Thus its result is only valid
@@ -251,6 +257,10 @@ where
251257
(**self).drain_to(writer, limit)
252258
}
253259

260+
fn taken(&mut self, bytes: u64) {
261+
(**self).taken(bytes);
262+
}
263+
254264
fn min_limit(&self) -> u64 {
255265
(**self).min_limit()
256266
}
@@ -407,6 +417,11 @@ impl<T: CopyRead> CopyRead for Take<T> {
407417
Ok(bytes_drained)
408418
}
409419

420+
fn taken(&mut self, bytes: u64) {
421+
self.set_limit(self.limit() - bytes);
422+
self.get_mut().taken(bytes);
423+
}
424+
410425
fn min_limit(&self) -> u64 {
411426
min(Take::limit(self), self.get_ref().min_limit())
412427
}
@@ -432,6 +447,10 @@ impl<T: CopyRead> CopyRead for BufReader<T> {
432447
Ok(bytes as u64 + inner_bytes)
433448
}
434449

450+
fn taken(&mut self, bytes: u64) {
451+
self.get_mut().taken(bytes);
452+
}
453+
435454
fn min_limit(&self) -> u64 {
436455
self.get_ref().min_limit()
437456
}
@@ -457,10 +476,21 @@ fn fd_to_meta<T: AsRawFd>(fd: &T) -> FdMeta {
457476
}
458477

459478
pub(super) enum CopyResult {
460-
Ended(Result<u64>),
479+
Ended(u64),
480+
Error(Error, u64),
461481
Fallback(u64),
462482
}
463483

484+
impl CopyResult {
485+
fn update_take(&self, reader: &mut impl CopyRead) {
486+
match *self {
487+
CopyResult::Fallback(bytes)
488+
| CopyResult::Ended(bytes)
489+
| CopyResult::Error(_, bytes) => reader.taken(bytes),
490+
}
491+
}
492+
}
493+
464494
/// linux-specific implementation that will attempt to use copy_file_range for copy offloading
465495
/// as the name says, it only works on regular files
466496
///
@@ -527,7 +557,7 @@ pub(super) fn copy_regular_files(reader: RawFd, writer: RawFd, max_len: u64) ->
527557
// - copying from an overlay filesystem in docker. reported to occur on fedora 32.
528558
return CopyResult::Fallback(0);
529559
}
530-
Ok(0) => return CopyResult::Ended(Ok(written)), // reached EOF
560+
Ok(0) => return CopyResult::Ended(written), // reached EOF
531561
Ok(ret) => written += ret as u64,
532562
Err(err) => {
533563
return match err.raw_os_error() {
@@ -545,12 +575,12 @@ pub(super) fn copy_regular_files(reader: RawFd, writer: RawFd, max_len: u64) ->
545575
assert_eq!(written, 0);
546576
CopyResult::Fallback(0)
547577
}
548-
_ => CopyResult::Ended(Err(err)),
578+
_ => CopyResult::Error(err, written),
549579
};
550580
}
551581
}
552582
}
553-
CopyResult::Ended(Ok(written))
583+
CopyResult::Ended(written)
554584
}
555585

556586
#[derive(PartialEq)]
@@ -623,10 +653,10 @@ fn sendfile_splice(mode: SpliceMode, reader: RawFd, writer: RawFd, len: u64) ->
623653
Some(os_err) if mode == SpliceMode::Sendfile && os_err == libc::EOVERFLOW => {
624654
CopyResult::Fallback(written)
625655
}
626-
_ => CopyResult::Ended(Err(err)),
656+
_ => CopyResult::Error(err, written),
627657
};
628658
}
629659
}
630660
}
631-
CopyResult::Ended(Ok(written))
661+
CopyResult::Ended(written)
632662
}

Diff for: library/std/src/sys/unix/kernel_copy/tests.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ fn bench_socket_pipe_socket_copy(b: &mut test::Bencher) {
217217
);
218218

219219
match probe {
220-
CopyResult::Ended(Ok(1)) => {
220+
CopyResult::Ended(1) => {
221221
// splice works
222222
}
223223
_ => {

0 commit comments

Comments
 (0)