Skip to content

Commit 3ca6bb0

Browse files
committed
Expand in-place iteration specialization to Flatten, FlatMap and ArrayChunks
1 parent 2a1af89 commit 3ca6bb0

25 files changed

+391
-63
lines changed

library/alloc/src/collections/binary_heap/mod.rs

+9-2
Original file line numberDiff line numberDiff line change
@@ -145,7 +145,7 @@
145145

146146
use core::alloc::Allocator;
147147
use core::fmt;
148-
use core::iter::{FusedIterator, InPlaceIterable, SourceIter, TrustedLen};
148+
use core::iter::{FusedIterator, InPlaceIterable, SourceIter, TrustedFused, TrustedLen};
149149
use core::mem::{self, swap, ManuallyDrop};
150150
use core::num::NonZeroUsize;
151151
use core::ops::{Deref, DerefMut};
@@ -1540,6 +1540,10 @@ impl<T, A: Allocator> ExactSizeIterator for IntoIter<T, A> {
15401540
#[stable(feature = "fused", since = "1.26.0")]
15411541
impl<T, A: Allocator> FusedIterator for IntoIter<T, A> {}
15421542

1543+
#[doc(hidden)]
1544+
#[unstable(issue = "none", feature = "trusted_fused")]
1545+
unsafe impl<T, A: Allocator> TrustedFused for IntoIter<T, A> {}
1546+
15431547
#[stable(feature = "default_iters", since = "1.70.0")]
15441548
impl<T> Default for IntoIter<T> {
15451549
/// Creates an empty `binary_heap::IntoIter`.
@@ -1569,7 +1573,10 @@ unsafe impl<T, A: Allocator> SourceIter for IntoIter<T, A> {
15691573

15701574
#[unstable(issue = "none", feature = "inplace_iteration")]
15711575
#[doc(hidden)]
1572-
unsafe impl<I, A: Allocator> InPlaceIterable for IntoIter<I, A> {}
1576+
unsafe impl<I, A: Allocator> InPlaceIterable for IntoIter<I, A> {
1577+
const EXPAND_BY: Option<NonZeroUsize> = NonZeroUsize::new(1);
1578+
const MERGE_BY: Option<NonZeroUsize> = NonZeroUsize::new(1);
1579+
}
15731580

15741581
unsafe impl<I> AsVecIntoIter for IntoIter<I> {
15751582
type Item = I;

library/alloc/src/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,7 @@
155155
#![feature(std_internals)]
156156
#![feature(str_internals)]
157157
#![feature(strict_provenance)]
158+
#![feature(trusted_fused)]
158159
#![feature(trusted_len)]
159160
#![feature(trusted_random_access)]
160161
#![feature(try_trait_v2)]

library/alloc/src/vec/in_place_collect.rs

+66-17
Original file line numberDiff line numberDiff line change
@@ -137,44 +137,73 @@
137137
//! }
138138
//! vec.truncate(write_idx);
139139
//! ```
140+
use crate::alloc::{handle_alloc_error, Global};
141+
use core::alloc::Allocator;
142+
use core::alloc::Layout;
140143
use core::iter::{InPlaceIterable, SourceIter, TrustedRandomAccessNoCoerce};
141144
use core::mem::{self, ManuallyDrop, SizedTypeProperties};
142-
use core::ptr::{self};
145+
use core::num::NonZeroUsize;
146+
use core::ptr::{self, NonNull};
143147

144148
use super::{InPlaceDrop, InPlaceDstBufDrop, SpecFromIter, SpecFromIterNested, Vec};
145149

146-
/// Specialization marker for collecting an iterator pipeline into a Vec while reusing the
147-
/// source allocation, i.e. executing the pipeline in place.
148-
#[rustc_unsafe_specialization_marker]
149-
pub(super) trait InPlaceIterableMarker {}
150+
const fn in_place_collectible<DEST, SRC>(
151+
step_merge: Option<NonZeroUsize>,
152+
step_expand: Option<NonZeroUsize>,
153+
) -> bool {
154+
if DEST::IS_ZST || mem::align_of::<SRC>() != mem::align_of::<DEST>() {
155+
return false;
156+
}
157+
158+
match (step_merge, step_expand) {
159+
(Some(step_merge), Some(step_expand)) => {
160+
// At least N merged source items -> at most M expanded destination items
161+
// e.g.
162+
// - 1 x [u8; 4] -> 4x u8, via flatten
163+
// - 4 x u8 -> 1x [u8; 4], via array_chunks
164+
mem::size_of::<SRC>() * step_merge.get() == mem::size_of::<DEST>() * step_expand.get()
165+
}
166+
// Fall back to other from_iter impls if an overflow occured in the step merge/expansion
167+
// tracking.
168+
_ => false,
169+
}
170+
}
150171

151-
impl<T> InPlaceIterableMarker for T where T: InPlaceIterable {}
172+
/// This provides a shorthand for the source type since local type aliases aren't a thing.
173+
#[rustc_specialization_trait]
174+
trait InPlaceCollect: SourceIter<Source: AsVecIntoIter> + InPlaceIterable {
175+
type Src;
176+
}
177+
178+
impl<T> InPlaceCollect for T
179+
where
180+
T: SourceIter<Source: AsVecIntoIter> + InPlaceIterable,
181+
{
182+
type Src = <<T as SourceIter>::Source as AsVecIntoIter>::Item;
183+
}
152184

153185
impl<T, I> SpecFromIter<T, I> for Vec<T>
154186
where
155-
I: Iterator<Item = T> + SourceIter<Source: AsVecIntoIter> + InPlaceIterableMarker,
187+
I: Iterator<Item = T> + InPlaceCollect,
188+
<I as SourceIter>::Source: AsVecIntoIter,
156189
{
157190
default fn from_iter(mut iterator: I) -> Self {
158191
// See "Layout constraints" section in the module documentation. We rely on const
159192
// optimization here since these conditions currently cannot be expressed as trait bounds
160-
if T::IS_ZST
161-
|| mem::size_of::<T>()
162-
!= mem::size_of::<<<I as SourceIter>::Source as AsVecIntoIter>::Item>()
163-
|| mem::align_of::<T>()
164-
!= mem::align_of::<<<I as SourceIter>::Source as AsVecIntoIter>::Item>()
165-
{
193+
if const { !in_place_collectible::<T, I::Src>(I::MERGE_BY, I::EXPAND_BY) } {
166194
// fallback to more generic implementations
167195
return SpecFromIterNested::from_iter(iterator);
168196
}
169197

170-
let (src_buf, src_ptr, dst_buf, dst_end, cap) = unsafe {
198+
let (src_buf, src_ptr, src_cap, mut dst_buf, dst_end, dst_cap) = unsafe {
171199
let inner = iterator.as_inner().as_into_iter();
172200
(
173201
inner.buf.as_ptr(),
174202
inner.ptr,
203+
inner.cap,
175204
inner.buf.as_ptr() as *mut T,
176205
inner.end as *const T,
177-
inner.cap,
206+
inner.cap * mem::size_of::<I::Src>() / mem::size_of::<T>(),
178207
)
179208
};
180209

@@ -203,11 +232,31 @@ where
203232
// Note: This access to the source wouldn't be allowed by the TrustedRandomIteratorNoCoerce
204233
// contract (used by SpecInPlaceCollect below). But see the "O(1) collect" section in the
205234
// module documentation why this is ok anyway.
206-
let dst_guard = InPlaceDstBufDrop { ptr: dst_buf, len, cap };
235+
let dst_guard = InPlaceDstBufDrop { ptr: dst_buf, len, cap: dst_cap };
207236
src.forget_allocation_drop_remaining();
208237
mem::forget(dst_guard);
209238

210-
let vec = unsafe { Vec::from_raw_parts(dst_buf, len, cap) };
239+
// Adjust the allocation size if the source had a capacity in bytes that wasn't a multiple
240+
// of the destination type size.
241+
// Since the discrepancy should generally be small this should only result in some
242+
// bookkeeping updates and no memmove.
243+
if const { mem::size_of::<T>() > mem::size_of::<I::Src>() }
244+
&& src_cap * mem::size_of::<I::Src>() != dst_cap * mem::size_of::<T>()
245+
{
246+
let alloc = Global;
247+
unsafe {
248+
let new_layout = Layout::array::<T>(dst_cap).unwrap();
249+
let result = alloc.shrink(
250+
NonNull::new_unchecked(dst_buf as *mut u8),
251+
Layout::array::<I::Src>(src_cap).unwrap(),
252+
new_layout,
253+
);
254+
let Ok(reallocated) = result else { handle_alloc_error(new_layout) };
255+
dst_buf = reallocated.as_ptr() as *mut T;
256+
}
257+
}
258+
259+
let vec = unsafe { Vec::from_raw_parts(dst_buf, len, dst_cap) };
211260

212261
vec
213262
}

library/alloc/src/vec/into_iter.rs

+10-2
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,8 @@ use crate::raw_vec::RawVec;
77
use core::array;
88
use core::fmt;
99
use core::iter::{
10-
FusedIterator, InPlaceIterable, SourceIter, TrustedLen, TrustedRandomAccessNoCoerce,
10+
FusedIterator, InPlaceIterable, SourceIter, TrustedFused, TrustedLen,
11+
TrustedRandomAccessNoCoerce,
1112
};
1213
use core::marker::PhantomData;
1314
use core::mem::{self, ManuallyDrop, MaybeUninit, SizedTypeProperties};
@@ -339,6 +340,10 @@ impl<T, A: Allocator> ExactSizeIterator for IntoIter<T, A> {
339340
#[stable(feature = "fused", since = "1.26.0")]
340341
impl<T, A: Allocator> FusedIterator for IntoIter<T, A> {}
341342

343+
#[doc(hidden)]
344+
#[unstable(issue = "none", feature = "trusted_fused")]
345+
unsafe impl<T, A: Allocator> TrustedFused for IntoIter<T, A> {}
346+
342347
#[unstable(feature = "trusted_len", issue = "37572")]
343348
unsafe impl<T, A: Allocator> TrustedLen for IntoIter<T, A> {}
344349

@@ -423,7 +428,10 @@ unsafe impl<#[may_dangle] T, A: Allocator> Drop for IntoIter<T, A> {
423428
// also refer to the vec::in_place_collect module documentation to get an overview
424429
#[unstable(issue = "none", feature = "inplace_iteration")]
425430
#[doc(hidden)]
426-
unsafe impl<T, A: Allocator> InPlaceIterable for IntoIter<T, A> {}
431+
unsafe impl<T, A: Allocator> InPlaceIterable for IntoIter<T, A> {
432+
const EXPAND_BY: Option<NonZeroUsize> = NonZeroUsize::new(1);
433+
const MERGE_BY: Option<NonZeroUsize> = NonZeroUsize::new(1);
434+
}
427435

428436
#[unstable(issue = "none", feature = "inplace_iteration")]
429437
#[doc(hidden)]

library/alloc/tests/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
#![feature(allocator_api)]
22
#![feature(alloc_layout_extra)]
3+
#![feature(iter_array_chunks)]
34
#![feature(assert_matches)]
45
#![feature(btree_extract_if)]
56
#![feature(cow_is_borrowed)]

library/alloc/tests/vec.rs

+32-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
use alloc::vec::Vec;
12
use core::alloc::{Allocator, Layout};
2-
use core::assert_eq;
3-
use core::iter::IntoIterator;
3+
use core::{assert_eq, assert_ne};
4+
use core::iter::{IntoIterator, Iterator};
45
use core::num::NonZeroUsize;
56
use core::ptr::NonNull;
67
use std::alloc::System;
@@ -1184,6 +1185,35 @@ fn test_from_iter_specialization_with_iterator_adapters() {
11841185
assert_eq!(srcptr, sinkptr as *const usize);
11851186
}
11861187

1188+
#[test]
1189+
fn test_in_place_specialization_step_up_down() {
1190+
fn assert_in_place_trait<T: InPlaceIterable>(_: &T) {}
1191+
let src = vec![[0u8; 4]; 256];
1192+
let srcptr = src.as_ptr();
1193+
let src_cap = src.capacity();
1194+
let iter = src.into_iter().flatten();
1195+
assert_in_place_trait(&iter);
1196+
let sink = iter.collect::<Vec<_>>();
1197+
let sinkptr = sink.as_ptr();
1198+
assert_eq!(srcptr as *const u8, sinkptr);
1199+
assert_eq!(src_cap * 4, sink.capacity());
1200+
1201+
let iter = sink.into_iter().array_chunks::<4>();
1202+
assert_in_place_trait(&iter);
1203+
let sink = iter.collect::<Vec<_>>();
1204+
let sinkptr = sink.as_ptr();
1205+
assert_eq!(srcptr, sinkptr);
1206+
assert_eq!(src_cap, sink.capacity());
1207+
1208+
let mut src: Vec<u8> = Vec::with_capacity(17);
1209+
let src_bytes = src.capacity();
1210+
src.resize(8, 0u8);
1211+
let sink: Vec<[u8; 4]> = src.into_iter().array_chunks::<4>().collect();
1212+
let sink_bytes = sink.capacity() * 4;
1213+
assert_ne!(src_bytes, sink_bytes);
1214+
assert_eq!(sink.len(), 2);
1215+
}
1216+
11871217
#[test]
11881218
fn test_from_iter_specialization_head_tail_drop() {
11891219
let drop_count: Vec<_> = (0..=2).map(|_| Rc::new(())).collect();

library/core/src/iter/adapters/array_chunks.rs

+33-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
use crate::array;
2-
use crate::iter::{ByRefSized, FusedIterator, Iterator, TrustedRandomAccessNoCoerce};
2+
use crate::iter::adapters::SourceIter;
3+
use crate::iter::{
4+
ByRefSized, FusedIterator, InPlaceIterable, Iterator, TrustedFused, TrustedRandomAccessNoCoerce,
5+
};
6+
use crate::num::NonZeroUsize;
37
use crate::ops::{ControlFlow, NeverShortCircuit, Try};
48

59
/// An iterator over `N` elements of the iterator at a time.
@@ -159,6 +163,9 @@ where
159163
#[unstable(feature = "iter_array_chunks", reason = "recently added", issue = "100450")]
160164
impl<I, const N: usize> FusedIterator for ArrayChunks<I, N> where I: FusedIterator {}
161165

166+
#[unstable(issue = "none", feature = "trusted_fused")]
167+
unsafe impl<I, const N: usize> TrustedFused for ArrayChunks<I, N> where I: TrustedFused + Iterator {}
168+
162169
#[unstable(feature = "iter_array_chunks", reason = "recently added", issue = "100450")]
163170
impl<I, const N: usize> ExactSizeIterator for ArrayChunks<I, N>
164171
where
@@ -229,3 +236,28 @@ where
229236
accum
230237
}
231238
}
239+
240+
#[unstable(issue = "none", feature = "inplace_iteration")]
241+
unsafe impl<I, const N: usize> SourceIter for ArrayChunks<I, N>
242+
where
243+
I: SourceIter + Iterator,
244+
{
245+
type Source = I::Source;
246+
247+
#[inline]
248+
unsafe fn as_inner(&mut self) -> &mut I::Source {
249+
// SAFETY: unsafe function forwarding to unsafe function with the same requirements
250+
unsafe { SourceIter::as_inner(&mut self.iter) }
251+
}
252+
}
253+
254+
#[unstable(issue = "none", feature = "inplace_iteration")]
255+
unsafe impl<I: InPlaceIterable + Iterator, const N: usize> InPlaceIterable for ArrayChunks<I, N> {
256+
const EXPAND_BY: Option<NonZeroUsize> = I::EXPAND_BY;
257+
const MERGE_BY: Option<NonZeroUsize> = const {
258+
match (I::MERGE_BY, NonZeroUsize::new(N)) {
259+
(Some(m), Some(n)) => m.checked_mul(n),
260+
_ => None,
261+
}
262+
};
263+
}

library/core/src/iter/adapters/enumerate.rs

+8-2
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use crate::iter::adapters::{
22
zip::try_get_unchecked, SourceIter, TrustedRandomAccess, TrustedRandomAccessNoCoerce,
33
};
4-
use crate::iter::{FusedIterator, InPlaceIterable, TrustedLen};
4+
use crate::iter::{FusedIterator, InPlaceIterable, TrustedFused, TrustedLen};
55
use crate::num::NonZeroUsize;
66
use crate::ops::Try;
77

@@ -243,6 +243,9 @@ where
243243
#[stable(feature = "fused", since = "1.26.0")]
244244
impl<I> FusedIterator for Enumerate<I> where I: FusedIterator {}
245245

246+
#[unstable(issue = "none", feature = "trusted_fused")]
247+
unsafe impl<I: TrustedFused> TrustedFused for Enumerate<I> {}
248+
246249
#[unstable(feature = "trusted_len", issue = "37572")]
247250
unsafe impl<I> TrustedLen for Enumerate<I> where I: TrustedLen {}
248251

@@ -261,7 +264,10 @@ where
261264
}
262265

263266
#[unstable(issue = "none", feature = "inplace_iteration")]
264-
unsafe impl<I: InPlaceIterable> InPlaceIterable for Enumerate<I> {}
267+
unsafe impl<I: InPlaceIterable> InPlaceIterable for Enumerate<I> {
268+
const EXPAND_BY: Option<NonZeroUsize> = I::EXPAND_BY;
269+
const MERGE_BY: Option<NonZeroUsize> = I::MERGE_BY;
270+
}
265271

266272
#[stable(feature = "default_iters", since = "1.70.0")]
267273
impl<I: Default> Default for Enumerate<I> {

library/core/src/iter/adapters/filter.rs

+9-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
use crate::fmt;
2-
use crate::iter::{adapters::SourceIter, FusedIterator, InPlaceIterable};
2+
use crate::iter::{adapters::SourceIter, FusedIterator, InPlaceIterable, TrustedFused};
3+
use crate::num::NonZeroUsize;
34
use crate::ops::Try;
45
use core::array;
56
use core::mem::{ManuallyDrop, MaybeUninit};
@@ -189,6 +190,9 @@ where
189190
#[stable(feature = "fused", since = "1.26.0")]
190191
impl<I: FusedIterator, P> FusedIterator for Filter<I, P> where P: FnMut(&I::Item) -> bool {}
191192

193+
#[unstable(issue = "none", feature = "trusted_fused")]
194+
unsafe impl<I: TrustedFused, F> TrustedFused for Filter<I, F> {}
195+
192196
#[unstable(issue = "none", feature = "inplace_iteration")]
193197
unsafe impl<P, I> SourceIter for Filter<I, P>
194198
where
@@ -204,4 +208,7 @@ where
204208
}
205209

206210
#[unstable(issue = "none", feature = "inplace_iteration")]
207-
unsafe impl<I: InPlaceIterable, P> InPlaceIterable for Filter<I, P> where P: FnMut(&I::Item) -> bool {}
211+
unsafe impl<I: InPlaceIterable, P> InPlaceIterable for Filter<I, P> {
212+
const EXPAND_BY: Option<NonZeroUsize> = I::EXPAND_BY;
213+
const MERGE_BY: Option<NonZeroUsize> = I::MERGE_BY;
214+
}

library/core/src/iter/adapters/filter_map.rs

+8-4
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
1-
use crate::iter::{adapters::SourceIter, FusedIterator, InPlaceIterable};
1+
use crate::iter::{adapters::SourceIter, FusedIterator, InPlaceIterable, TrustedFused};
22
use crate::mem::{ManuallyDrop, MaybeUninit};
3+
use crate::num::NonZeroUsize;
34
use crate::ops::{ControlFlow, Try};
45
use crate::{array, fmt};
56

@@ -188,6 +189,9 @@ where
188189
#[stable(feature = "fused", since = "1.26.0")]
189190
impl<B, I: FusedIterator, F> FusedIterator for FilterMap<I, F> where F: FnMut(I::Item) -> Option<B> {}
190191

192+
#[unstable(issue = "none", feature = "trusted_fused")]
193+
unsafe impl<I: TrustedFused, F> TrustedFused for FilterMap<I, F> {}
194+
191195
#[unstable(issue = "none", feature = "inplace_iteration")]
192196
unsafe impl<I, F> SourceIter for FilterMap<I, F>
193197
where
@@ -203,7 +207,7 @@ where
203207
}
204208

205209
#[unstable(issue = "none", feature = "inplace_iteration")]
206-
unsafe impl<B, I: InPlaceIterable, F> InPlaceIterable for FilterMap<I, F> where
207-
F: FnMut(I::Item) -> Option<B>
208-
{
210+
unsafe impl<I: InPlaceIterable, F> InPlaceIterable for FilterMap<I, F> {
211+
const EXPAND_BY: Option<NonZeroUsize> = I::EXPAND_BY;
212+
const MERGE_BY: Option<NonZeroUsize> = I::MERGE_BY;
209213
}

0 commit comments

Comments
 (0)