Skip to content

Commit 0198b89

Browse files
committed
Greatly sped up checked_isqrt and isqrt methods
* Uses a lookup table for 8-bit integers and then the Karatsuba square root algorithm for larger integers. * Includes optimization hints that give the compiler the exact numeric range of results.
1 parent 186d224 commit 0198b89

File tree

5 files changed

+237
-35
lines changed

5 files changed

+237
-35
lines changed

library/core/src/num/int_macros.rs

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1581,7 +1581,18 @@ macro_rules! int_impl {
15811581
if self < 0 {
15821582
None
15831583
} else {
1584-
Some((self as $UnsignedT).isqrt() as Self)
1584+
let result = crate::num::int_sqrt::$ActualT(self as $ActualT) as $SelfT;
1585+
1586+
// SAFETY: Inform the optimizer that square roots of
1587+
// nonnegative integers are nonnegative and what the maximum
1588+
// result is.
1589+
unsafe {
1590+
crate::hint::assert_unchecked(result >= 0);
1591+
const MAX_RESULT: $SelfT = crate::num::int_sqrt::$ActualT($ActualT::MAX) as $SelfT;
1592+
crate::hint::assert_unchecked(result <= MAX_RESULT);
1593+
}
1594+
1595+
Some(result)
15851596
}
15861597
}
15871598

@@ -2769,14 +2780,21 @@ macro_rules! int_impl {
27692780
without modifying the original"]
27702781
#[inline]
27712782
pub const fn isqrt(self) -> Self {
2772-
// I would like to implement it as
2773-
// ```
2774-
// self.checked_isqrt().expect("argument of integer square root must be non-negative")
2775-
// ```
2776-
// but `expect` is not yet stable as a `const fn`.
2777-
match self.checked_isqrt() {
2778-
Some(sqrt) => sqrt,
2779-
None => panic!("argument of integer square root must be non-negative"),
2783+
if self < 0 {
2784+
crate::num::int_sqrt::panic_for_negative_argument();
2785+
} else {
2786+
let result = crate::num::int_sqrt::$ActualT(self as $ActualT) as $SelfT;
2787+
2788+
// SAFETY: Inform the optimizer that square roots of
2789+
// nonnegative integers are nonnegative and what the maximum
2790+
// result is.
2791+
unsafe {
2792+
crate::hint::assert_unchecked(result >= 0);
2793+
const MAX_RESULT: $SelfT = crate::num::int_sqrt::$ActualT($ActualT::MAX) as $SelfT;
2794+
crate::hint::assert_unchecked(result <= MAX_RESULT);
2795+
}
2796+
2797+
result
27802798
}
27812799
}
27822800

library/core/src/num/int_sqrt.rs

Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
/// These functions compute the integer square root of their type, assuming
2+
/// that someone has already checked that the value is nonnegative.
3+
4+
const ISQRT_AND_REMAINDER_8_BIT: [(u8, u8); 256] = {
5+
let mut result = [(0, 0); 256];
6+
7+
let mut sqrt = 0;
8+
let mut i = 0;
9+
'outer: loop {
10+
let mut remaining = 2 * sqrt + 1;
11+
while remaining > 0 {
12+
result[i as usize] = (sqrt, 2 * sqrt + 1 - remaining);
13+
i += 1;
14+
if i >= result.len() {
15+
break 'outer;
16+
}
17+
remaining -= 1;
18+
}
19+
sqrt += 1;
20+
}
21+
22+
result
23+
};
24+
25+
// `#[inline(always)]` because the programmer-accessible functions will use
26+
// this internally and the contents of this should be inlined there.
27+
#[inline(always)]
28+
pub const fn u8(n: u8) -> u8 {
29+
ISQRT_AND_REMAINDER_8_BIT[n as usize].0
30+
}
31+
32+
#[inline(always)]
33+
const fn intermediate_u8(n: u8) -> (u8, u8) {
34+
ISQRT_AND_REMAINDER_8_BIT[n as usize]
35+
}
36+
37+
macro_rules! karatsuba_isqrt {
38+
($FullBitsT:ty, $fn:ident, $intermediate_fn:ident, $HalfBitsT:ty, $half_fn:ident, $intermediate_half_fn:ident) => {
39+
// `#[inline(always)]` because the programmer-accessible functions will
40+
// use this internally and the contents of this should be inlined
41+
// there.
42+
#[inline(always)]
43+
pub const fn $fn(mut n: $FullBitsT) -> $FullBitsT {
44+
// Performs a Karatsuba square root.
45+
// https://web.archive.org/web/20230511212802/https://inria.hal.science/inria-00072854v1/file/RR-3805.pdf
46+
47+
const HALF_BITS: u32 = <$FullBitsT>::BITS >> 1;
48+
const QUARTER_BITS: u32 = <$FullBitsT>::BITS >> 2;
49+
50+
let leading_zeros = n.leading_zeros();
51+
let result = if leading_zeros >= HALF_BITS {
52+
$half_fn(n as $HalfBitsT) as $FullBitsT
53+
} else {
54+
// Either the most-significant bit or its neighbor must be a one, so we shift left to make that happen.
55+
let precondition_shift = leading_zeros & (HALF_BITS - 2);
56+
n <<= precondition_shift;
57+
58+
let hi = (n >> HALF_BITS) as $HalfBitsT;
59+
let lo = n & (<$HalfBitsT>::MAX as $FullBitsT);
60+
61+
let (s_prime, r_prime) = $intermediate_half_fn(hi);
62+
63+
let numerator = ((r_prime as $FullBitsT) << QUARTER_BITS) | (lo >> QUARTER_BITS);
64+
let denominator = (s_prime as $FullBitsT) << 1;
65+
66+
let q = numerator / denominator;
67+
let u = numerator % denominator;
68+
69+
let mut s = (s_prime << QUARTER_BITS) as $FullBitsT + q;
70+
if ((u << QUARTER_BITS) | (lo & ((1 << QUARTER_BITS) - 1))) < q * q {
71+
s -= 1;
72+
}
73+
s >> (precondition_shift >> 1)
74+
};
75+
76+
result
77+
}
78+
79+
const fn $intermediate_fn(mut n: $FullBitsT) -> ($FullBitsT, $FullBitsT) {
80+
// Performs a Karatsuba square root.
81+
// https://web.archive.org/web/20230511212802/https://inria.hal.science/inria-00072854v1/file/RR-3805.pdf
82+
83+
const HALF_BITS: u32 = <$FullBitsT>::BITS >> 1;
84+
const QUARTER_BITS: u32 = <$FullBitsT>::BITS >> 2;
85+
86+
let leading_zeros = n.leading_zeros();
87+
let result = if leading_zeros >= HALF_BITS {
88+
let (s, r) = $intermediate_half_fn(n as $HalfBitsT);
89+
(s as $FullBitsT, r as $FullBitsT)
90+
} else {
91+
// Either the most-significant bit or its neighbor must be a one, so we shift left to make that happen.
92+
let precondition_shift = leading_zeros & (HALF_BITS - 2);
93+
n <<= precondition_shift;
94+
95+
let hi = (n >> HALF_BITS) as $HalfBitsT;
96+
let lo = n & (<$HalfBitsT>::MAX as $FullBitsT);
97+
98+
let (s_prime, r_prime) = $intermediate_half_fn(hi);
99+
100+
let numerator = ((r_prime as $FullBitsT) << QUARTER_BITS) | (lo >> QUARTER_BITS);
101+
let denominator = (s_prime as $FullBitsT) << 1;
102+
103+
let q = numerator / denominator;
104+
let u = numerator % denominator;
105+
106+
let mut s = (s_prime << QUARTER_BITS) as $FullBitsT + q;
107+
let (mut r, overflow) =
108+
((u << QUARTER_BITS) | (lo & ((1 << QUARTER_BITS) - 1))).overflowing_sub(q * q);
109+
if overflow {
110+
r = r.wrapping_add((s << 1) - 1);
111+
s -= 1;
112+
}
113+
(s >> (precondition_shift >> 1), r >> (precondition_shift >> 1))
114+
};
115+
116+
result
117+
}
118+
};
119+
}
120+
121+
karatsuba_isqrt!(u16, u16, intermediate_u16, u8, u8, intermediate_u8);
122+
karatsuba_isqrt!(u32, u32, intermediate_u32, u16, u16, intermediate_u16);
123+
karatsuba_isqrt!(u64, u64, intermediate_u64, u32, u32, intermediate_u32);
124+
karatsuba_isqrt!(u128, u128, _intermediate_u128, u64, u64, intermediate_u64);
125+
126+
#[cfg(target_pointer_width = "16")]
127+
#[inline(always)]
128+
pub const fn usize(n: usize) -> usize {
129+
u16(n as u16) as usize
130+
}
131+
132+
#[cfg(target_pointer_width = "32")]
133+
#[inline(always)]
134+
pub const fn usize(n: usize) -> usize {
135+
u32(n as u32) as usize
136+
}
137+
138+
#[cfg(target_pointer_width = "64")]
139+
#[inline(always)]
140+
pub const fn usize(n: usize) -> usize {
141+
u64(n as u64) as usize
142+
}
143+
144+
// 0 <= val <= i8::MAX
145+
#[inline(always)]
146+
pub const fn i8(n: i8) -> i8 {
147+
u8(n as u8) as i8
148+
}
149+
150+
// 0 <= val <= i16::MAX
151+
#[inline(always)]
152+
pub const fn i16(n: i16) -> i16 {
153+
u16(n as u16) as i16
154+
}
155+
156+
// 0 <= val <= i32::MAX
157+
#[inline(always)]
158+
pub const fn i32(n: i32) -> i32 {
159+
u32(n as u32) as i32
160+
}
161+
162+
// 0 <= val <= i64::MAX
163+
#[inline(always)]
164+
pub const fn i64(n: i64) -> i64 {
165+
u64(n as u64) as i64
166+
}
167+
168+
// 0 <= val <= i128::MAX
169+
#[inline(always)]
170+
pub const fn i128(n: i128) -> i128 {
171+
u128(n as u128) as i128
172+
}
173+
174+
/*
175+
This function is not used.
176+
177+
// 0 <= val <= isize::MAX
178+
#[inline(always)]
179+
pub const fn isize(n: isize) -> isize {
180+
usize(n as usize) as isize
181+
}
182+
*/
183+
184+
/// Instantiate this panic logic once, rather than for all the ilog methods
185+
/// on every single primitive type.
186+
#[cold]
187+
#[track_caller]
188+
pub const fn panic_for_negative_argument() -> ! {
189+
panic!("argument of integer square root cannot be negative")
190+
}

library/core/src/num/mod.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ mod uint_macros; // import uint_impl!
4141

4242
mod error;
4343
mod int_log10;
44+
mod int_sqrt;
4445
mod nonzero;
4546
mod overflow_panic;
4647
mod saturating;

library/core/src/num/nonzero.rs

Lines changed: 11 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1545,31 +1545,19 @@ macro_rules! nonzero_integer_signedness_dependent_methods {
15451545
without modifying the original"]
15461546
#[inline]
15471547
pub const fn isqrt(self) -> Self {
1548-
// The algorithm is based on the one presented in
1549-
// <https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Binary_numeral_system_(base_2)>
1550-
// which cites as source the following C code:
1551-
// <https://web.archive.org/web/20120306040058/http://medialab.freaknet.org/martin/src/sqrt/sqrt.c>.
1552-
1553-
let mut op = self.get();
1554-
let mut res = 0;
1555-
let mut one = 1 << (self.ilog2() & !1);
1556-
1557-
while one != 0 {
1558-
if op >= res + one {
1559-
op -= res + one;
1560-
res = (res >> 1) + one;
1561-
} else {
1562-
res >>= 1;
1563-
}
1564-
one >>= 2;
1548+
let result = super::int_sqrt::$Int(self.get());
1549+
1550+
// SAFETY: Inform the optimizer that square roots of positive
1551+
// integers are positive and what the maximum result is.
1552+
unsafe {
1553+
hint::assert_unchecked(result > 0);
1554+
const MAX_RESULT: $Int = super::int_sqrt::$Int($Int::MAX);
1555+
hint::assert_unchecked(result <= MAX_RESULT);
15651556
}
15661557

1567-
// SAFETY: The result fits in an integer with half as many bits.
1568-
// Inform the optimizer about it.
1569-
unsafe { hint::assert_unchecked(res < 1 << (Self::BITS / 2)) };
1570-
1571-
// SAFETY: The square root of an integer >= 1 is always >= 1.
1572-
unsafe { Self::new_unchecked(res) }
1558+
// SAFETY: The square root of a positive integer is always
1559+
// positive.
1560+
unsafe { Self::new_unchecked(result) }
15731561
}
15741562
};
15751563

library/core/src/num/uint_macros.rs

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2677,10 +2677,15 @@ macro_rules! uint_impl {
26772677
without modifying the original"]
26782678
#[inline]
26792679
pub const fn isqrt(self) -> Self {
2680-
match NonZero::new(self) {
2681-
Some(x) => x.isqrt().get(),
2682-
None => 0,
2680+
let result = crate::num::int_sqrt::$ActualT(self as $ActualT) as $SelfT;
2681+
2682+
// SAFETY: Inform the optimizer of what the maximum result is.
2683+
unsafe {
2684+
const MAX_RESULT: $SelfT = crate::num::int_sqrt::$ActualT($ActualT::MAX) as $SelfT;
2685+
crate::hint::assert_unchecked(result <= MAX_RESULT);
26832686
}
2687+
2688+
result
26842689
}
26852690

26862691
/// Performs Euclidean division.

0 commit comments

Comments
 (0)