Skip to content

Commit c44af61

Browse files
committed
Speed up checked_isqrt and isqrt methods
* Use a lookup table for 8-bit integers and the Karatsuba square root algorithm for larger integers. * Include optimization hints that give the compiler the exact numeric range of results.
1 parent 2139651 commit c44af61

File tree

5 files changed

+371
-35
lines changed

5 files changed

+371
-35
lines changed

Diff for: core/src/num/int_macros.rs

+29-7
Original file line numberDiff line numberDiff line change
@@ -1641,7 +1641,33 @@ macro_rules! int_impl {
16411641
if self < 0 {
16421642
None
16431643
} else {
1644-
Some((self as $UnsignedT).isqrt() as Self)
1644+
// SAFETY: Input is nonnegative in this `else` branch.
1645+
let result = unsafe {
1646+
crate::num::int_sqrt::$ActualT(self as $ActualT) as $SelfT
1647+
};
1648+
1649+
// Inform the optimizer what the range of outputs is. If
1650+
// testing `core` crashes with no panic message and a
1651+
// `num::int_sqrt::i*` test failed, it's because your edits
1652+
// caused these assertions to become false.
1653+
//
1654+
// SAFETY: Integer square root is a monotonically nondecreasing
1655+
// function, which means that increasing the input will never
1656+
// cause the output to decrease. Thus, since the input for
1657+
// nonnegative signed integers is bounded by
1658+
// `[0, <$ActualT>::MAX]`, sqrt(n) will be bounded by
1659+
// `[sqrt(0), sqrt(<$ActualT>::MAX)]`.
1660+
unsafe {
1661+
// SAFETY: `<$ActualT>::MAX` is nonnegative.
1662+
const MAX_RESULT: $SelfT = unsafe {
1663+
crate::num::int_sqrt::$ActualT(<$ActualT>::MAX) as $SelfT
1664+
};
1665+
1666+
crate::hint::assert_unchecked(result >= 0);
1667+
crate::hint::assert_unchecked(result <= MAX_RESULT);
1668+
}
1669+
1670+
Some(result)
16451671
}
16461672
}
16471673

@@ -2862,15 +2888,11 @@ macro_rules! int_impl {
28622888
#[must_use = "this returns the result of the operation, \
28632889
without modifying the original"]
28642890
#[inline]
2891+
#[track_caller]
28652892
pub const fn isqrt(self) -> Self {
2866-
// I would like to implement it as
2867-
// ```
2868-
// self.checked_isqrt().expect("argument of integer square root must be non-negative")
2869-
// ```
2870-
// but `expect` is not yet stable as a `const fn`.
28712893
match self.checked_isqrt() {
28722894
Some(sqrt) => sqrt,
2873-
None => panic!("argument of integer square root must be non-negative"),
2895+
None => crate::num::int_sqrt::panic_for_negative_argument(),
28742896
}
28752897
}
28762898

Diff for: core/src/num/int_sqrt.rs

+316
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,316 @@
1+
//! These functions use the [Karatsuba square root algorithm][1] to compute the
2+
//! [integer square root](https://en.wikipedia.org/wiki/Integer_square_root)
3+
//! for the primitive integer types.
4+
//!
5+
//! The signed integer functions can only handle **nonnegative** inputs, so
6+
//! that must be checked before calling those.
7+
//!
8+
//! [1]: <https://web.archive.org/web/20230511212802/https://inria.hal.science/inria-00072854v1/file/RR-3805.pdf>
9+
//! "Paul Zimmermann. Karatsuba Square Root. \[Research Report\] RR-3805,
10+
//! INRIA. 1999, pp.8. (inria-00072854)"
11+
12+
/// This array stores the [integer square roots](
13+
/// https://en.wikipedia.org/wiki/Integer_square_root) and remainders of each
14+
/// [`u8`](prim@u8) value. For example, `U8_ISQRT_WITH_REMAINDER[17]` will be
15+
/// `(4, 1)` because the integer square root of 17 is 4 and because 17 is 1
16+
/// higher than 4 squared.
17+
const U8_ISQRT_WITH_REMAINDER: [(u8, u8); 256] = {
18+
let mut result = [(0, 0); 256];
19+
20+
let mut n: usize = 0;
21+
let mut isqrt_n: usize = 0;
22+
while n < result.len() {
23+
result[n] = (isqrt_n as u8, (n - isqrt_n.pow(2)) as u8);
24+
25+
n += 1;
26+
if n == (isqrt_n + 1).pow(2) {
27+
isqrt_n += 1;
28+
}
29+
}
30+
31+
result
32+
};
33+
34+
/// Returns the [integer square root](
35+
/// https://en.wikipedia.org/wiki/Integer_square_root) of any [`u8`](prim@u8)
36+
/// input.
37+
#[must_use = "this returns the result of the operation, \
38+
without modifying the original"]
39+
#[inline]
40+
pub const fn u8(n: u8) -> u8 {
41+
U8_ISQRT_WITH_REMAINDER[n as usize].0
42+
}
43+
44+
/// Generates an `i*` function that returns the [integer square root](
45+
/// https://en.wikipedia.org/wiki/Integer_square_root) of any **nonnegative**
46+
/// input of a specific signed integer type.
47+
macro_rules! signed_fn {
48+
($SignedT:ident, $UnsignedT:ident) => {
49+
/// Returns the [integer square root](
50+
/// https://en.wikipedia.org/wiki/Integer_square_root) of any
51+
/// **nonnegative**
52+
#[doc = concat!("[`", stringify!($SignedT), "`](prim@", stringify!($SignedT), ")")]
53+
/// input.
54+
///
55+
/// # Safety
56+
///
57+
/// This results in undefined behavior when the input is negative.
58+
#[must_use = "this returns the result of the operation, \
59+
without modifying the original"]
60+
#[inline]
61+
pub const unsafe fn $SignedT(n: $SignedT) -> $SignedT {
62+
debug_assert!(n >= 0, "Negative input inside `isqrt`.");
63+
$UnsignedT(n as $UnsignedT) as $SignedT
64+
}
65+
};
66+
}
67+
68+
signed_fn!(i8, u8);
69+
signed_fn!(i16, u16);
70+
signed_fn!(i32, u32);
71+
signed_fn!(i64, u64);
72+
signed_fn!(i128, u128);
73+
74+
/// Generates a `u*` function that returns the [integer square root](
75+
/// https://en.wikipedia.org/wiki/Integer_square_root) of any input of
76+
/// a specific unsigned integer type.
77+
macro_rules! unsigned_fn {
78+
($UnsignedT:ident, $HalfBitsT:ident, $stages:ident) => {
79+
/// Returns the [integer square root](
80+
/// https://en.wikipedia.org/wiki/Integer_square_root) of any
81+
#[doc = concat!("[`", stringify!($UnsignedT), "`](prim@", stringify!($UnsignedT), ")")]
82+
/// input.
83+
#[must_use = "this returns the result of the operation, \
84+
without modifying the original"]
85+
#[inline]
86+
pub const fn $UnsignedT(mut n: $UnsignedT) -> $UnsignedT {
87+
if n <= <$HalfBitsT>::MAX as $UnsignedT {
88+
$HalfBitsT(n as $HalfBitsT) as $UnsignedT
89+
} else {
90+
// The normalization shift satisfies the Karatsuba square root
91+
// algorithm precondition "a₃ ≥ b/4" where a₃ is the most
92+
// significant quarter of `n`'s bits and b is the number of
93+
// values that can be represented by that quarter of the bits.
94+
//
95+
// b/4 would then be all 0s except the second most significant
96+
// bit (010...0) in binary. Since a₃ must be at least b/4, a₃'s
97+
// most significant bit or its neighbor must be a 1. Since a₃'s
98+
// most significant bits are `n`'s most significant bits, the
99+
// same applies to `n`.
100+
//
101+
// The reason to shift by an even number of bits is because an
102+
// even number of bits produces the square root shifted to the
103+
// left by half of the normalization shift:
104+
//
105+
// sqrt(n << (2 * p))
106+
// sqrt(2.pow(2 * p) * n)
107+
// sqrt(2.pow(2 * p)) * sqrt(n)
108+
// 2.pow(p) * sqrt(n)
109+
// sqrt(n) << p
110+
//
111+
// Shifting by an odd number of bits leaves an ugly sqrt(2)
112+
// multiplied in:
113+
//
114+
// sqrt(n << (2 * p + 1))
115+
// sqrt(2.pow(2 * p + 1) * n)
116+
// sqrt(2 * 2.pow(2 * p) * n)
117+
// sqrt(2) * sqrt(2.pow(2 * p)) * sqrt(n)
118+
// sqrt(2) * 2.pow(p) * sqrt(n)
119+
// sqrt(2) * (sqrt(n) << p)
120+
const EVEN_MAKING_BITMASK: u32 = !1;
121+
let normalization_shift = n.leading_zeros() & EVEN_MAKING_BITMASK;
122+
n <<= normalization_shift;
123+
124+
let s = $stages(n);
125+
126+
let denormalization_shift = normalization_shift >> 1;
127+
s >> denormalization_shift
128+
}
129+
}
130+
};
131+
}
132+
133+
/// Generates the first stage of the computation after normalization.
134+
///
135+
/// # Safety
136+
///
137+
/// `$n` must be nonzero.
138+
macro_rules! first_stage {
139+
($original_bits:literal, $n:ident) => {{
140+
debug_assert!($n != 0, "`$n` is zero in `first_stage!`.");
141+
142+
const N_SHIFT: u32 = $original_bits - 8;
143+
let n = $n >> N_SHIFT;
144+
145+
let (s, r) = U8_ISQRT_WITH_REMAINDER[n as usize];
146+
147+
// Inform the optimizer that `s` is nonzero. This will allow it to
148+
// avoid generating code to handle division-by-zero panics in the next
149+
// stage.
150+
//
151+
// SAFETY: If the original `$n` is zero, the top of the `unsigned_fn`
152+
// macro recurses instead of continuing to this point, so the original
153+
// `$n` wasn't a 0 if we've reached here.
154+
//
155+
// Then the `unsigned_fn` macro normalizes `$n` so that at least one of
156+
// its two most-significant bits is a 1.
157+
//
158+
// Then this stage puts the eight most-significant bits of `$n` into
159+
// `n`. This means that `n` here has at least one 1 bit in its two
160+
// most-significant bits, making `n` nonzero.
161+
//
162+
// `U8_ISQRT_WITH_REMAINDER[n as usize]` will give a nonzero `s` when
163+
// given a nonzero `n`.
164+
unsafe { crate::hint::assert_unchecked(s != 0) };
165+
(s, r)
166+
}};
167+
}
168+
169+
/// Generates a middle stage of the computation.
170+
///
171+
/// # Safety
172+
///
173+
/// `$s` must be nonzero.
174+
macro_rules! middle_stage {
175+
($original_bits:literal, $ty:ty, $n:ident, $s:ident, $r:ident) => {{
176+
debug_assert!($s != 0, "`$s` is zero in `middle_stage!`.");
177+
178+
const N_SHIFT: u32 = $original_bits - <$ty>::BITS;
179+
let n = ($n >> N_SHIFT) as $ty;
180+
181+
const HALF_BITS: u32 = <$ty>::BITS >> 1;
182+
const QUARTER_BITS: u32 = <$ty>::BITS >> 2;
183+
const LOWER_HALF_1_BITS: $ty = (1 << HALF_BITS) - 1;
184+
const LOWEST_QUARTER_1_BITS: $ty = (1 << QUARTER_BITS) - 1;
185+
186+
let lo = n & LOWER_HALF_1_BITS;
187+
let numerator = (($r as $ty) << QUARTER_BITS) | (lo >> QUARTER_BITS);
188+
let denominator = ($s as $ty) << 1;
189+
let q = numerator / denominator;
190+
let u = numerator % denominator;
191+
192+
let mut s = ($s << QUARTER_BITS) as $ty + q;
193+
let (mut r, overflow) =
194+
((u << QUARTER_BITS) | (lo & LOWEST_QUARTER_1_BITS)).overflowing_sub(q * q);
195+
if overflow {
196+
r = r.wrapping_add(2 * s - 1);
197+
s -= 1;
198+
}
199+
200+
// Inform the optimizer that `s` is nonzero. This will allow it to
201+
// avoid generating code to handle division-by-zero panics in the next
202+
// stage.
203+
//
204+
// SAFETY: If the original `$n` is zero, the top of the `unsigned_fn`
205+
// macro recurses instead of continuing to this point, so the original
206+
// `$n` wasn't a 0 if we've reached here.
207+
//
208+
// Then the `unsigned_fn` macro normalizes `$n` so that at least one of
209+
// its two most-significant bits is a 1.
210+
//
211+
// Then these stages take as many of the most-significant bits of `$n`
212+
// as will fit in this stage's type. For example, the stage that
213+
// handles `u32` deals with the 32 most-significant bits of `$n`. This
214+
// means that each stage has at least one 1 bit in `n`'s two
215+
// most-significant bits, making `n` nonzero.
216+
//
217+
// Then this stage will produce the correct integer square root for
218+
// that `n` value. Since `n` is nonzero, `s` will also be nonzero.
219+
unsafe { crate::hint::assert_unchecked(s != 0) };
220+
(s, r)
221+
}};
222+
}
223+
224+
/// Generates the last stage of the computation before denormalization.
225+
///
226+
/// # Safety
227+
///
228+
/// `$s` must be nonzero.
229+
macro_rules! last_stage {
230+
($ty:ty, $n:ident, $s:ident, $r:ident) => {{
231+
debug_assert!($s != 0, "`$s` is zero in `last_stage!`.");
232+
233+
const HALF_BITS: u32 = <$ty>::BITS >> 1;
234+
const QUARTER_BITS: u32 = <$ty>::BITS >> 2;
235+
const LOWER_HALF_1_BITS: $ty = (1 << HALF_BITS) - 1;
236+
237+
let lo = $n & LOWER_HALF_1_BITS;
238+
let numerator = (($r as $ty) << QUARTER_BITS) | (lo >> QUARTER_BITS);
239+
let denominator = ($s as $ty) << 1;
240+
241+
let q = numerator / denominator;
242+
let mut s = ($s << QUARTER_BITS) as $ty + q;
243+
let (s_squared, overflow) = s.overflowing_mul(s);
244+
if overflow || s_squared > $n {
245+
s -= 1;
246+
}
247+
s
248+
}};
249+
}
250+
251+
/// Takes the normalized [`u16`](prim@u16) input and gets its normalized
252+
/// [integer square root](https://en.wikipedia.org/wiki/Integer_square_root).
253+
///
254+
/// # Safety
255+
///
256+
/// `n` must be nonzero.
257+
#[inline]
258+
const fn u16_stages(n: u16) -> u16 {
259+
let (s, r) = first_stage!(16, n);
260+
last_stage!(u16, n, s, r)
261+
}
262+
263+
/// Takes the normalized [`u32`](prim@u32) input and gets its normalized
264+
/// [integer square root](https://en.wikipedia.org/wiki/Integer_square_root).
265+
///
266+
/// # Safety
267+
///
268+
/// `n` must be nonzero.
269+
#[inline]
270+
const fn u32_stages(n: u32) -> u32 {
271+
let (s, r) = first_stage!(32, n);
272+
let (s, r) = middle_stage!(32, u16, n, s, r);
273+
last_stage!(u32, n, s, r)
274+
}
275+
276+
/// Takes the normalized [`u64`](prim@u64) input and gets its normalized
277+
/// [integer square root](https://en.wikipedia.org/wiki/Integer_square_root).
278+
///
279+
/// # Safety
280+
///
281+
/// `n` must be nonzero.
282+
#[inline]
283+
const fn u64_stages(n: u64) -> u64 {
284+
let (s, r) = first_stage!(64, n);
285+
let (s, r) = middle_stage!(64, u16, n, s, r);
286+
let (s, r) = middle_stage!(64, u32, n, s, r);
287+
last_stage!(u64, n, s, r)
288+
}
289+
290+
/// Takes the normalized [`u128`](prim@u128) input and gets its normalized
291+
/// [integer square root](https://en.wikipedia.org/wiki/Integer_square_root).
292+
///
293+
/// # Safety
294+
///
295+
/// `n` must be nonzero.
296+
#[inline]
297+
const fn u128_stages(n: u128) -> u128 {
298+
let (s, r) = first_stage!(128, n);
299+
let (s, r) = middle_stage!(128, u16, n, s, r);
300+
let (s, r) = middle_stage!(128, u32, n, s, r);
301+
let (s, r) = middle_stage!(128, u64, n, s, r);
302+
last_stage!(u128, n, s, r)
303+
}
304+
305+
unsigned_fn!(u16, u8, u16_stages);
306+
unsigned_fn!(u32, u16, u32_stages);
307+
unsigned_fn!(u64, u32, u64_stages);
308+
unsigned_fn!(u128, u64, u128_stages);
309+
310+
/// Instantiate this panic logic once, rather than for all the isqrt methods
311+
/// on every single primitive type.
312+
#[cold]
313+
#[track_caller]
314+
pub const fn panic_for_negative_argument() -> ! {
315+
panic!("argument of integer square root cannot be negative")
316+
}

Diff for: core/src/num/mod.rs

+1
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ mod uint_macros; // import uint_impl!
4141

4242
mod error;
4343
mod int_log10;
44+
mod int_sqrt;
4445
mod nonzero;
4546
mod overflow_panic;
4647
mod saturating;

0 commit comments

Comments
 (0)