Skip to content

Commit c4f2577

Browse files
committed
Auto merge of rust-lang#115273 - the8472:take-fold, r=cuviper
Optimize Take::{fold, for_each} when wrapping TrustedRandomAccess iterators
2 parents 9229b1e + f93e125 commit c4f2577

File tree

2 files changed

+97
-19
lines changed

2 files changed

+97
-19
lines changed

library/core/src/iter/adapters/take.rs

+82-19
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
use crate::cmp;
2-
use crate::iter::{adapters::SourceIter, FusedIterator, InPlaceIterable, TrustedLen};
2+
use crate::iter::{
3+
adapters::SourceIter, FusedIterator, InPlaceIterable, TrustedLen, TrustedRandomAccess,
4+
};
35
use crate::num::NonZeroUsize;
46
use crate::ops::{ControlFlow, Try};
57

@@ -98,26 +100,18 @@ where
98100
}
99101
}
100102

101-
impl_fold_via_try_fold! { fold -> try_fold }
102-
103103
#[inline]
104-
fn for_each<F: FnMut(Self::Item)>(mut self, f: F) {
105-
// The default implementation would use a unit accumulator, so we can
106-
// avoid a stateful closure by folding over the remaining number
107-
// of items we wish to return instead.
108-
fn check<'a, Item>(
109-
mut action: impl FnMut(Item) + 'a,
110-
) -> impl FnMut(usize, Item) -> Option<usize> + 'a {
111-
move |more, x| {
112-
action(x);
113-
more.checked_sub(1)
114-
}
115-
}
104+
fn fold<B, F>(self, init: B, f: F) -> B
105+
where
106+
Self: Sized,
107+
F: FnMut(B, Self::Item) -> B,
108+
{
109+
Self::spec_fold(self, init, f)
110+
}
116111

117-
let remaining = self.n;
118-
if remaining > 0 {
119-
self.iter.try_fold(remaining - 1, check(f));
120-
}
112+
#[inline]
113+
fn for_each<F: FnMut(Self::Item)>(self, f: F) {
114+
Self::spec_for_each(self, f)
121115
}
122116

123117
#[inline]
@@ -249,3 +243,72 @@ impl<I> FusedIterator for Take<I> where I: FusedIterator {}
249243

250244
#[unstable(feature = "trusted_len", issue = "37572")]
251245
unsafe impl<I: TrustedLen> TrustedLen for Take<I> {}
246+
247+
trait SpecTake: Iterator {
248+
fn spec_fold<B, F>(self, init: B, f: F) -> B
249+
where
250+
Self: Sized,
251+
F: FnMut(B, Self::Item) -> B;
252+
253+
fn spec_for_each<F: FnMut(Self::Item)>(self, f: F);
254+
}
255+
256+
impl<I: Iterator> SpecTake for Take<I> {
257+
#[inline]
258+
default fn spec_fold<B, F>(mut self, init: B, f: F) -> B
259+
where
260+
Self: Sized,
261+
F: FnMut(B, Self::Item) -> B,
262+
{
263+
use crate::ops::NeverShortCircuit;
264+
self.try_fold(init, NeverShortCircuit::wrap_mut_2(f)).0
265+
}
266+
267+
#[inline]
268+
default fn spec_for_each<F: FnMut(Self::Item)>(mut self, f: F) {
269+
// The default implementation would use a unit accumulator, so we can
270+
// avoid a stateful closure by folding over the remaining number
271+
// of items we wish to return instead.
272+
fn check<'a, Item>(
273+
mut action: impl FnMut(Item) + 'a,
274+
) -> impl FnMut(usize, Item) -> Option<usize> + 'a {
275+
move |more, x| {
276+
action(x);
277+
more.checked_sub(1)
278+
}
279+
}
280+
281+
let remaining = self.n;
282+
if remaining > 0 {
283+
self.iter.try_fold(remaining - 1, check(f));
284+
}
285+
}
286+
}
287+
288+
impl<I: Iterator + TrustedRandomAccess> SpecTake for Take<I> {
289+
#[inline]
290+
fn spec_fold<B, F>(mut self, init: B, mut f: F) -> B
291+
where
292+
Self: Sized,
293+
F: FnMut(B, Self::Item) -> B,
294+
{
295+
let mut acc = init;
296+
let end = self.n.min(self.iter.size());
297+
for i in 0..end {
298+
// SAFETY: i < end <= self.iter.size() and we discard the iterator at the end
299+
let val = unsafe { self.iter.__iterator_get_unchecked(i) };
300+
acc = f(acc, val);
301+
}
302+
acc
303+
}
304+
305+
#[inline]
306+
fn spec_for_each<F: FnMut(Self::Item)>(mut self, mut f: F) {
307+
let end = self.n.min(self.iter.size());
308+
for i in 0..end {
309+
// SAFETY: i < end <= self.iter.size() and we discard the iterator at the end
310+
let val = unsafe { self.iter.__iterator_get_unchecked(i) };
311+
f(val);
312+
}
313+
}
314+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
// ignore-debug: the debug assertions get in the way
2+
// compile-flags: -O
3+
// only-x86_64 (vectorization varies between architectures)
4+
#![crate_type = "lib"]
5+
6+
7+
// Ensure that slice + take + sum gets vectorized.
8+
// Currently this relies on the slice::Iter::try_fold implementation
9+
// CHECK-LABEL: @slice_take_sum
10+
#[no_mangle]
11+
pub fn slice_take_sum(s: &[u64], l: usize) -> u64 {
12+
// CHECK: vector.body:
13+
// CHECK: ret
14+
s.iter().take(l).sum()
15+
}

0 commit comments

Comments
 (0)