Skip to content

Commit 6ebe590

Browse files
authored
Rollup merge of rust-lang#135847 - edwloef:slice_ptr_rotate_opt, r=scottmcm
optimize slice::ptr_rotate for small rotates r? `@scottmcm` This swaps the positions and numberings of algorithms 1 and 2 in `slice::ptr_rotate`, and pulls the entire outer loop into algorithm 3 since it was redundant for the first two. Effectively, `ptr_rotate` now always does the `memcpy`+`memmove`+`memcpy` sequence if the shifts fit into the stack buffer. With this change, an `IndexMap`-style `move_index` function is optimized correctly. Assembly comparisons: - `move_index`, before: https://godbolt.org/z/Kr616KnYM - `move_index`, after: https://godbolt.org/z/1aoov6j8h - the code from `rust-lang#89714`, before: https://godbolt.org/z/Y4zaPxEG6 - the code from `rust-lang#89714`, after: https://godbolt.org/z/1dPx83axc related to rust-lang#89714 some relevant discussion in https://internals.rust-lang.org/t/idea-shift-move-to-efficiently-move-elements-in-a-vec/22184 Behavior tests pass locally. I can't get any consistent microbenchmark results on my machine, but the assembly diffs look promising.
2 parents 4a5f1cc + fb3d1d0 commit 6ebe590

File tree

2 files changed

+212
-152
lines changed

2 files changed

+212
-152
lines changed

Diff for: library/core/src/slice/rotate.rs

+182-152
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,91 @@
11
use crate::mem::{self, MaybeUninit, SizedTypeProperties};
22
use crate::{cmp, ptr};
33

4+
type BufType = [usize; 32];
5+
46
/// Rotates the range `[mid-left, mid+right)` such that the element at `mid` becomes the first
57
/// element. Equivalently, rotates the range `left` elements to the left or `right` elements to the
68
/// right.
79
///
810
/// # Safety
911
///
1012
/// The specified range must be valid for reading and writing.
13+
#[inline]
14+
pub(super) unsafe fn ptr_rotate<T>(left: usize, mid: *mut T, right: usize) {
15+
if T::IS_ZST {
16+
return;
17+
}
18+
// abort early if the rotate is a no-op
19+
if (left == 0) || (right == 0) {
20+
return;
21+
}
22+
// `T` is not a zero-sized type, so it's okay to divide by its size.
23+
if !cfg!(feature = "optimize_for_size")
24+
&& cmp::min(left, right) <= mem::size_of::<BufType>() / mem::size_of::<T>()
25+
{
26+
// SAFETY: guaranteed by the caller
27+
unsafe { ptr_rotate_memmove(left, mid, right) };
28+
} else if !cfg!(feature = "optimize_for_size")
29+
&& ((left + right < 24) || (mem::size_of::<T>() > mem::size_of::<[usize; 4]>()))
30+
{
31+
// SAFETY: guaranteed by the caller
32+
unsafe { ptr_rotate_gcd(left, mid, right) }
33+
} else {
34+
// SAFETY: guaranteed by the caller
35+
unsafe { ptr_rotate_swap(left, mid, right) }
36+
}
37+
}
38+
39+
/// Algorithm 1 is used if `min(left, right)` is small enough to fit onto a stack buffer. The
40+
/// `min(left, right)` elements are copied onto the buffer, `memmove` is applied to the others, and
41+
/// the ones on the buffer are moved back into the hole on the opposite side of where they
42+
/// originated.
1143
///
12-
/// # Algorithm
44+
/// # Safety
1345
///
14-
/// Algorithm 1 is used for small values of `left + right` or for large `T`. The elements are moved
15-
/// into their final positions one at a time starting at `mid - left` and advancing by `right` steps
16-
/// modulo `left + right`, such that only one temporary is needed. Eventually, we arrive back at
17-
/// `mid - left`. However, if `gcd(left + right, right)` is not 1, the above steps skipped over
18-
/// elements. For example:
46+
/// The specified range must be valid for reading and writing.
47+
#[inline]
48+
unsafe fn ptr_rotate_memmove<T>(left: usize, mid: *mut T, right: usize) {
49+
// The `[T; 0]` here is to ensure this is appropriately aligned for T
50+
let mut rawarray = MaybeUninit::<(BufType, [T; 0])>::uninit();
51+
let buf = rawarray.as_mut_ptr() as *mut T;
52+
// SAFETY: `mid-left <= mid-left+right < mid+right`
53+
let dim = unsafe { mid.sub(left).add(right) };
54+
if left <= right {
55+
// SAFETY:
56+
//
57+
// 1) The `if` condition about the sizes ensures `[mid-left; left]` will fit in
58+
// `buf` without overflow and `buf` was created just above and so cannot be
59+
// overlapped with any value of `[mid-left; left]`
60+
// 2) [mid-left, mid+right) are all valid for reading and writing and we don't care
61+
// about overlaps here.
62+
// 3) The `if` condition about `left <= right` ensures writing `left` elements to
63+
// `dim = mid-left+right` is valid because:
64+
// - `buf` is valid and `left` elements were written in it in 1)
65+
// - `dim+left = mid-left+right+left = mid+right` and we write `[dim, dim+left)`
66+
unsafe {
67+
// 1)
68+
ptr::copy_nonoverlapping(mid.sub(left), buf, left);
69+
// 2)
70+
ptr::copy(mid, mid.sub(left), right);
71+
// 3)
72+
ptr::copy_nonoverlapping(buf, dim, left);
73+
}
74+
} else {
75+
// SAFETY: same reasoning as above but with `left` and `right` reversed
76+
unsafe {
77+
ptr::copy_nonoverlapping(mid, buf, right);
78+
ptr::copy(mid.sub(left), dim, left);
79+
ptr::copy_nonoverlapping(buf, mid.sub(left), right);
80+
}
81+
}
82+
}
83+
84+
/// Algorithm 2 is used for small values of `left + right` or for large `T`. The elements
85+
/// are moved into their final positions one at a time starting at `mid - left` and advancing by
86+
/// `right` steps modulo `left + right`, such that only one temporary is needed. Eventually, we
87+
/// arrive back at `mid - left`. However, if `gcd(left + right, right)` is not 1, the above steps
88+
/// skipped over elements. For example:
1989
/// ```text
2090
/// left = 10, right = 6
2191
/// the `^` indicates an element in its final place
@@ -39,17 +109,104 @@ use crate::{cmp, ptr};
39109
/// `gcd(left + right, right)` value). The end result is that all elements are finalized once and
40110
/// only once.
41111
///
42-
/// Algorithm 2 is used if `left + right` is large but `min(left, right)` is small enough to
43-
/// fit onto a stack buffer. The `min(left, right)` elements are copied onto the buffer, `memmove`
44-
/// is applied to the others, and the ones on the buffer are moved back into the hole on the
45-
/// opposite side of where they originated.
46-
///
47-
/// Algorithms that can be vectorized outperform the above once `left + right` becomes large enough.
48-
/// Algorithm 1 can be vectorized by chunking and performing many rounds at once, but there are too
112+
/// Algorithm 2 can be vectorized by chunking and performing many rounds at once, but there are too
49113
/// few rounds on average until `left + right` is enormous, and the worst case of a single
50-
/// round is always there. Instead, algorithm 3 utilizes repeated swapping of
51-
/// `min(left, right)` elements until a smaller rotate problem is left.
114+
/// round is always there.
115+
///
116+
/// # Safety
117+
///
118+
/// The specified range must be valid for reading and writing.
119+
#[inline]
120+
unsafe fn ptr_rotate_gcd<T>(left: usize, mid: *mut T, right: usize) {
121+
// Algorithm 2
122+
// Microbenchmarks indicate that the average performance for random shifts is better all
123+
// the way until about `left + right == 32`, but the worst case performance breaks even
124+
// around 16. 24 was chosen as middle ground. If the size of `T` is larger than 4
125+
// `usize`s, this algorithm also outperforms other algorithms.
126+
// SAFETY: callers must ensure `mid - left` is valid for reading and writing.
127+
let x = unsafe { mid.sub(left) };
128+
// beginning of first round
129+
// SAFETY: see previous comment.
130+
let mut tmp: T = unsafe { x.read() };
131+
let mut i = right;
132+
// `gcd` can be found beforehand by calculating `gcd(left + right, right)`,
133+
// but it is faster to do one loop which calculates the gcd as a side effect, then
134+
// doing the rest of the chunk
135+
let mut gcd = right;
136+
// benchmarks reveal that it is faster to swap temporaries all the way through instead
137+
// of reading one temporary once, copying backwards, and then writing that temporary at
138+
// the very end. This is possibly due to the fact that swapping or replacing temporaries
139+
// uses only one memory address in the loop instead of needing to manage two.
140+
loop {
141+
// [long-safety-expl]
142+
// SAFETY: callers must ensure `[mid-left, mid+right)` are all valid for reading and
143+
// writing.
144+
//
145+
// - `i` starts at `right` so `mid-left <= x+i = x+right = mid-left+right < mid+right`
146+
// - `i <= left+right-1` is always true
147+
// - if `i < left`, `right` is added so `i < left+right` and on the next
148+
// iteration `left` is removed from `i` so it doesn't go further
149+
// - if `i >= left`, `left` is removed immediately and so it doesn't go further.
150+
// - overflows cannot happen for `i` since the function's safety contract asks for
151+
// `mid+right-1 = x+left+right-1` to be valid for writing
152+
// - underflows cannot happen because `i` must be bigger or equal to `left` for
153+
// a subtraction of `left` to happen.
154+
//
155+
// So `x+i` is valid for reading and writing if the caller respected the contract
156+
tmp = unsafe { x.add(i).replace(tmp) };
157+
// instead of incrementing `i` and then checking if it is outside the bounds, we
158+
// check if `i` will go outside the bounds on the next increment. This prevents
159+
// any wrapping of pointers or `usize`.
160+
if i >= left {
161+
i -= left;
162+
if i == 0 {
163+
// end of first round
164+
// SAFETY: tmp has been read from a valid source and x is valid for writing
165+
// according to the caller.
166+
unsafe { x.write(tmp) };
167+
break;
168+
}
169+
// this conditional must be here if `left + right >= 15`
170+
if i < gcd {
171+
gcd = i;
172+
}
173+
} else {
174+
i += right;
175+
}
176+
}
177+
// finish the chunk with more rounds
178+
for start in 1..gcd {
179+
// SAFETY: `gcd` is at most equal to `right` so all values in `1..gcd` are valid for
180+
// reading and writing as per the function's safety contract, see [long-safety-expl]
181+
// above
182+
tmp = unsafe { x.add(start).read() };
183+
// [safety-expl-addition]
184+
//
185+
// Here `start < gcd`, and `gcd` divides both `left` and `right`, so `start < left`
186+
// and therefore `i = start + right < left+right`. The wrap-around below keeps `i` in
187+
// that range, so `x+i = mid-left+i` is always valid for reading and writing
188+
// according to the function's safety contract.
189+
i = start + right;
190+
loop {
191+
// SAFETY: see [long-safety-expl] and [safety-expl-addition]
192+
tmp = unsafe { x.add(i).replace(tmp) };
193+
if i >= left {
194+
i -= left;
195+
if i == start {
196+
// SAFETY: see [long-safety-expl] and [safety-expl-addition]
197+
unsafe { x.add(start).write(tmp) };
198+
break;
199+
}
200+
} else {
201+
i += right;
202+
}
203+
}
204+
}
205+
}
206+
207+
/// Algorithm 3 utilizes repeated swapping of `min(left, right)` elements.
52208
///
209+
///
53210
/// ```text
54211
/// left = 11, right = 4
55212
/// [4 5 6 7 8 9 10 11 12 13 14 . 0 1 2 3]
@@ -60,144 +217,14 @@ use crate::{cmp, ptr};
60217
/// we cannot swap any more, but a smaller rotation problem is left to solve
61218
/// ```
62219
/// when `left < right` the swapping happens from the left instead.
63-
pub(super) unsafe fn ptr_rotate<T>(mut left: usize, mut mid: *mut T, mut right: usize) {
64-
type BufType = [usize; 32];
65-
if T::IS_ZST {
66-
return;
67-
}
220+
///
221+
/// # Safety
222+
///
223+
/// The specified range must be valid for reading and writing.
224+
#[inline]
225+
unsafe fn ptr_rotate_swap<T>(mut left: usize, mut mid: *mut T, mut right: usize) {
68226
loop {
69-
// N.B. the below algorithms can fail if these cases are not checked
70-
if (right == 0) || (left == 0) {
71-
return;
72-
}
73-
if !cfg!(feature = "optimize_for_size")
74-
&& ((left + right < 24) || (mem::size_of::<T>() > mem::size_of::<[usize; 4]>()))
75-
{
76-
// Algorithm 1
77-
// Microbenchmarks indicate that the average performance for random shifts is better all
78-
// the way until about `left + right == 32`, but the worst case performance breaks even
79-
// around 16. 24 was chosen as middle ground. If the size of `T` is larger than 4
80-
// `usize`s, this algorithm also outperforms other algorithms.
81-
// SAFETY: callers must ensure `mid - left` is valid for reading and writing.
82-
let x = unsafe { mid.sub(left) };
83-
// beginning of first round
84-
// SAFETY: see previous comment.
85-
let mut tmp: T = unsafe { x.read() };
86-
let mut i = right;
87-
// `gcd` can be found before hand by calculating `gcd(left + right, right)`,
88-
// but it is faster to do one loop which calculates the gcd as a side effect, then
89-
// doing the rest of the chunk
90-
let mut gcd = right;
91-
// benchmarks reveal that it is faster to swap temporaries all the way through instead
92-
// of reading one temporary once, copying backwards, and then writing that temporary at
93-
// the very end. This is possibly due to the fact that swapping or replacing temporaries
94-
// uses only one memory address in the loop instead of needing to manage two.
95-
loop {
96-
// [long-safety-expl]
97-
// SAFETY: callers must ensure `[left, left+mid+right)` are all valid for reading and
98-
// writing.
99-
//
100-
// - `i` start with `right` so `mid-left <= x+i = x+right = mid-left+right < mid+right`
101-
// - `i <= left+right-1` is always true
102-
// - if `i < left`, `right` is added so `i < left+right` and on the next
103-
// iteration `left` is removed from `i` so it doesn't go further
104-
// - if `i >= left`, `left` is removed immediately and so it doesn't go further.
105-
// - overflows cannot happen for `i` since the function's safety contract ask for
106-
// `mid+right-1 = x+left+right` to be valid for writing
107-
// - underflows cannot happen because `i` must be bigger or equal to `left` for
108-
// a subtraction of `left` to happen.
109-
//
110-
// So `x+i` is valid for reading and writing if the caller respected the contract
111-
tmp = unsafe { x.add(i).replace(tmp) };
112-
// instead of incrementing `i` and then checking if it is outside the bounds, we
113-
// check if `i` will go outside the bounds on the next increment. This prevents
114-
// any wrapping of pointers or `usize`.
115-
if i >= left {
116-
i -= left;
117-
if i == 0 {
118-
// end of first round
119-
// SAFETY: tmp has been read from a valid source and x is valid for writing
120-
// according to the caller.
121-
unsafe { x.write(tmp) };
122-
break;
123-
}
124-
// this conditional must be here if `left + right >= 15`
125-
if i < gcd {
126-
gcd = i;
127-
}
128-
} else {
129-
i += right;
130-
}
131-
}
132-
// finish the chunk with more rounds
133-
for start in 1..gcd {
134-
// SAFETY: `gcd` is at most equal to `right` so all values in `1..gcd` are valid for
135-
// reading and writing as per the function's safety contract, see [long-safety-expl]
136-
// above
137-
tmp = unsafe { x.add(start).read() };
138-
// [safety-expl-addition]
139-
//
140-
// Here `start < gcd` so `start < right` so `i < right+right`: `right` being the
141-
// greatest common divisor of `(left+right, right)` means that `left = right` so
142-
// `i < left+right` so `x+i = mid-left+i` is always valid for reading and writing
143-
// according to the function's safety contract.
144-
i = start + right;
145-
loop {
146-
// SAFETY: see [long-safety-expl] and [safety-expl-addition]
147-
tmp = unsafe { x.add(i).replace(tmp) };
148-
if i >= left {
149-
i -= left;
150-
if i == start {
151-
// SAFETY: see [long-safety-expl] and [safety-expl-addition]
152-
unsafe { x.add(start).write(tmp) };
153-
break;
154-
}
155-
} else {
156-
i += right;
157-
}
158-
}
159-
}
160-
return;
161-
// `T` is not a zero-sized type, so it's okay to divide by its size.
162-
} else if !cfg!(feature = "optimize_for_size")
163-
&& cmp::min(left, right) <= mem::size_of::<BufType>() / mem::size_of::<T>()
164-
{
165-
// Algorithm 2
166-
// The `[T; 0]` here is to ensure this is appropriately aligned for T
167-
let mut rawarray = MaybeUninit::<(BufType, [T; 0])>::uninit();
168-
let buf = rawarray.as_mut_ptr() as *mut T;
169-
// SAFETY: `mid-left <= mid-left+right < mid+right`
170-
let dim = unsafe { mid.sub(left).add(right) };
171-
if left <= right {
172-
// SAFETY:
173-
//
174-
// 1) The `else if` condition about the sizes ensures `[mid-left; left]` will fit in
175-
// `buf` without overflow and `buf` was created just above and so cannot be
176-
// overlapped with any value of `[mid-left; left]`
177-
// 2) [mid-left, mid+right) are all valid for reading and writing and we don't care
178-
// about overlaps here.
179-
// 3) The `if` condition about `left <= right` ensures writing `left` elements to
180-
// `dim = mid-left+right` is valid because:
181-
// - `buf` is valid and `left` elements were written in it in 1)
182-
// - `dim+left = mid-left+right+left = mid+right` and we write `[dim, dim+left)`
183-
unsafe {
184-
// 1)
185-
ptr::copy_nonoverlapping(mid.sub(left), buf, left);
186-
// 2)
187-
ptr::copy(mid, mid.sub(left), right);
188-
// 3)
189-
ptr::copy_nonoverlapping(buf, dim, left);
190-
}
191-
} else {
192-
// SAFETY: same reasoning as above but with `left` and `right` reversed
193-
unsafe {
194-
ptr::copy_nonoverlapping(mid, buf, right);
195-
ptr::copy(mid.sub(left), dim, left);
196-
ptr::copy_nonoverlapping(buf, mid.sub(left), right);
197-
}
198-
}
199-
return;
200-
} else if left >= right {
227+
if left >= right {
201228
// Algorithm 3
202229
// There is an alternate way of swapping that involves finding where the last swap
203230
// of this algorithm would be, and swapping using that last chunk instead of swapping
@@ -233,5 +260,8 @@ pub(super) unsafe fn ptr_rotate<T>(mut left: usize, mut mid: *mut T, mut right:
233260
}
234261
}
235262
}
263+
if (right == 0) || (left == 0) {
264+
return;
265+
}
236266
}
237267
}

0 commit comments

Comments
 (0)