Skip to content

Commit 3f70049

Browse files
committed
Move fold logic to iter_fold method and reuse it in count and last
1 parent cbc5f62 commit 3f70049

File tree

2 files changed

+118
-16
lines changed

2 files changed

+118
-16
lines changed

library/core/src/iter/adapters/flatten.rs

+76-16
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,16 @@ where
7878
fn advance_by(&mut self, n: usize) -> Result<(), usize> {
7979
self.inner.advance_by(n)
8080
}
81+
82+
#[inline]
83+
fn count(self) -> usize {
84+
self.inner.count()
85+
}
86+
87+
#[inline]
88+
fn last(self) -> Option<Self::Item> {
89+
self.inner.last()
90+
}
8191
}
8292

8393
#[stable(feature = "rust1", since = "1.0.0")]
@@ -229,6 +239,16 @@ where
229239
fn advance_by(&mut self, n: usize) -> Result<(), usize> {
230240
self.inner.advance_by(n)
231241
}
242+
243+
#[inline]
244+
fn count(self) -> usize {
245+
self.inner.count()
246+
}
247+
248+
#[inline]
249+
fn last(self) -> Option<Self::Item> {
250+
self.inner.last()
251+
}
232252
}
233253

234254
#[stable(feature = "iterator_flatten", since = "1.29.0")]
@@ -304,6 +324,35 @@ impl<I, U> FlattenCompat<I, U>
304324
where
305325
I: Iterator<Item: IntoIterator<IntoIter = U>>,
306326
{
327+
/// Folds the inner iterators into an accumulator by applying an operation.
328+
///
329+
/// Folds over the inner iterators, not over their elements. Is used by the `fold`, `count`,
330+
/// and `last` methods.
331+
#[inline]
332+
fn iter_fold<Acc, Fold>(self, mut acc: Acc, mut fold: Fold) -> Acc
333+
where
334+
Fold: FnMut(Acc, U) -> Acc,
335+
{
336+
#[inline]
337+
fn flatten<T: IntoIterator, Acc>(
338+
fold: &mut impl FnMut(Acc, T::IntoIter) -> Acc,
339+
) -> impl FnMut(Acc, T) -> Acc + '_ {
340+
move |acc, iter| fold(acc, iter.into_iter())
341+
}
342+
343+
if let Some(iter) = self.frontiter {
344+
acc = fold(acc, iter);
345+
}
346+
347+
acc = self.iter.fold(acc, flatten(&mut fold));
348+
349+
if let Some(iter) = self.backiter {
350+
acc = fold(acc, iter);
351+
}
352+
353+
acc
354+
}
355+
307356
/// Folds over the inner iterators as long as the given function returns successfully,
308357
/// always storing the most recent inner iterator in `self.frontiter`.
309358
///
@@ -440,28 +489,18 @@ where
440489
}
441490

442491
#[inline]
443-
fn fold<Acc, Fold>(self, mut init: Acc, mut fold: Fold) -> Acc
492+
fn fold<Acc, Fold>(self, init: Acc, fold: Fold) -> Acc
444493
where
445494
Fold: FnMut(Acc, Self::Item) -> Acc,
446495
{
447496
#[inline]
448-
fn flatten<T: IntoIterator, Acc>(
449-
fold: &mut impl FnMut(Acc, T::Item) -> Acc,
450-
) -> impl FnMut(Acc, T) -> Acc + '_ {
451-
move |acc, x| x.into_iter().fold(acc, &mut *fold)
497+
fn flatten<U: Iterator, Acc>(
498+
mut fold: impl FnMut(Acc, U::Item) -> Acc,
499+
) -> impl FnMut(Acc, U) -> Acc {
500+
move |acc, iter| iter.fold(acc, &mut fold)
452501
}
453502

454-
if let Some(front) = self.frontiter {
455-
init = front.fold(init, &mut fold);
456-
}
457-
458-
init = self.iter.fold(init, flatten(&mut fold));
459-
460-
if let Some(back) = self.backiter {
461-
init = back.fold(init, &mut fold);
462-
}
463-
464-
init
503+
self.iter_fold(init, flatten(fold))
465504
}
466505

467506
#[inline]
@@ -481,6 +520,27 @@ where
481520
_ => Ok(()),
482521
}
483522
}
523+
524+
#[inline]
525+
fn count(self) -> usize {
526+
#[inline]
527+
#[rustc_inherit_overflow_checks]
528+
fn count<U: Iterator>(acc: usize, iter: U) -> usize {
529+
acc + iter.count()
530+
}
531+
532+
self.iter_fold(0, count)
533+
}
534+
535+
#[inline]
536+
fn last(self) -> Option<Self::Item> {
537+
#[inline]
538+
fn last<U: Iterator>(last: Option<U::Item>, iter: U) -> Option<U::Item> {
539+
iter.last().or(last)
540+
}
541+
542+
self.iter_fold(None, last)
543+
}
484544
}
485545

486546
impl<I, U> DoubleEndedIterator for FlattenCompat<I, U>

library/core/tests/iter/adapters/flatten.rs

+42
Original file line numberDiff line numberDiff line change
@@ -168,3 +168,45 @@ fn test_trusted_len_flatten() {
168168
assert_trusted_len(&iter);
169169
assert_eq!(iter.size_hint(), (20, Some(20)));
170170
}
171+
172+
#[test]
173+
fn test_flatten_count() {
174+
let mut it = once(0..10).chain(once(10..30)).chain(once(30..40)).flatten();
175+
176+
assert_eq!(it.clone().count(), 40);
177+
it.advance_by(5).unwrap();
178+
assert_eq!(it.clone().count(), 35);
179+
it.advance_back_by(5).unwrap();
180+
assert_eq!(it.clone().count(), 30);
181+
it.advance_by(10).unwrap();
182+
assert_eq!(it.clone().count(), 20);
183+
it.advance_back_by(8).unwrap();
184+
assert_eq!(it.clone().count(), 12);
185+
it.advance_by(4).unwrap();
186+
assert_eq!(it.clone().count(), 8);
187+
it.advance_back_by(5).unwrap();
188+
assert_eq!(it.clone().count(), 3);
189+
it.advance_by(3).unwrap();
190+
assert_eq!(it.clone().count(), 0);
191+
}
192+
193+
#[test]
194+
fn test_flatten_last() {
195+
let mut it = once(0..10).chain(once(10..30)).chain(once(30..40)).flatten();
196+
197+
assert_eq!(it.clone().last(), Some(39));
198+
it.advance_by(5).unwrap(); // 5..40
199+
assert_eq!(it.clone().last(), Some(39));
200+
it.advance_back_by(5).unwrap(); // 5..35
201+
assert_eq!(it.clone().last(), Some(34));
202+
it.advance_by(10).unwrap(); // 15..35
203+
assert_eq!(it.clone().last(), Some(34));
204+
it.advance_back_by(8).unwrap(); // 15..27
205+
assert_eq!(it.clone().last(), Some(26));
206+
it.advance_by(4).unwrap(); // 19..27
207+
assert_eq!(it.clone().last(), Some(26));
208+
it.advance_back_by(5).unwrap(); // 19..22
209+
assert_eq!(it.clone().last(), Some(21));
210+
it.advance_by(3).unwrap(); // 22..22
211+
assert_eq!(it.clone().last(), None);
212+
}

0 commit comments

Comments
 (0)