//! These functions use the [Karatsuba square root algorithm][1] to compute the
//! [integer square root](https://en.wikipedia.org/wiki/Integer_square_root)
//! for the primitive integer types.
//!
//! The signed integer functions can only handle **nonnegative** inputs, so
//! callers must check for that before calling them.
//!
//! [1]: <https://web.archive.org/web/20230511212802/https://inria.hal.science/inria-00072854v1/file/RR-3805.pdf>
//! "Paul Zimmermann. Karatsuba Square Root. \[Research Report\] RR-3805,
//! INRIA. 1999, pp.8. (inria-00072854)"

/// Lookup table of the [integer square root](
/// https://en.wikipedia.org/wiki/Integer_square_root) and remainder of every
/// [`u8`](prim@u8) value.
///
/// Entry `n` is `(s, r)`, where `s` is the largest integer with `s * s <= n`
/// and `r == n - s * s`. For example, `U8_ISQRT_WITH_REMAINDER[17]` is
/// `(4, 1)` because the integer square root of 17 is 4 and because 17 is 1
/// higher than 4 squared.
const U8_ISQRT_WITH_REMAINDER: [(u8, u8); 256] = {
    let mut result = [(0, 0); 256];

    // The counters are `usize` so that `(isqrt_n + 1).pow(2)` can reach 256
    // without overflowing; every value actually stored fits in a `u8`.
    let mut n: usize = 0;
    let mut isqrt_n: usize = 0;
    // A `while` loop is used because `for` loops are not available in
    // constant evaluation.
    while n < result.len() {
        result[n] = (isqrt_n as u8, (n - isqrt_n.pow(2)) as u8);

        // The root only grows when `n` reaches the next perfect square.
        n += 1;
        if n == (isqrt_n + 1).pow(2) {
            isqrt_n += 1;
        }
    }

    result
};
| 33 | + |
| 34 | +/// Returns the [integer square root]( |
| 35 | +/// https://en.wikipedia.org/wiki/Integer_square_root) of any [`u8`](prim@u8) |
| 36 | +/// input. |
| 37 | +#[must_use = "this returns the result of the operation, \ |
| 38 | + without modifying the original"] |
| 39 | +#[inline] |
| 40 | +pub const fn u8(n: u8) -> u8 { |
| 41 | + U8_ISQRT_WITH_REMAINDER[n as usize].0 |
| 42 | +} |
| 43 | + |
/// Generates an `i*` function that returns the [integer square root](
/// https://en.wikipedia.org/wiki/Integer_square_root) of any **nonnegative**
/// input of a specific signed integer type.
///
/// `$SignedT` is both the signed type and the name of the generated function.
/// `$UnsignedT` is both the unsigned type of the same width and the name of
/// the unsigned function that the generated function delegates to.
macro_rules! signed_fn {
    ($SignedT:ident, $UnsignedT:ident) => {
        /// Returns the [integer square root](
        /// https://en.wikipedia.org/wiki/Integer_square_root) of any
        /// **nonnegative**
        #[doc = concat!("[`", stringify!($SignedT), "`](prim@", stringify!($SignedT), ")")]
        /// input.
        ///
        /// # Safety
        ///
        /// This results in undefined behavior when the input is negative.
        #[must_use = "this returns the result of the operation, \
                      without modifying the original"]
        #[inline]
        pub const unsafe fn $SignedT(n: $SignedT) -> $SignedT {
            debug_assert!(n >= 0, "Negative input inside `isqrt`.");
            // A nonnegative signed value has the same bit pattern as its
            // unsigned counterpart, so the cast does not change the value.
            $UnsignedT(n as $UnsignedT) as $SignedT
        }
    };
}
| 67 | + |
| 68 | +signed_fn!(i8, u8); |
| 69 | +signed_fn!(i16, u16); |
| 70 | +signed_fn!(i32, u32); |
| 71 | +signed_fn!(i64, u64); |
| 72 | +signed_fn!(i128, u128); |
| 73 | + |
/// Generates a `u*` function that returns the [integer square root](
/// https://en.wikipedia.org/wiki/Integer_square_root) of any input of
/// a specific unsigned integer type.
///
/// `$UnsignedT` is both the type and the name of the generated function.
/// `$HalfBitsT` is the type, and the function, of half that width. `$stages`
/// is the function that runs the Karatsuba stages on a normalized
/// `$UnsignedT` value.
macro_rules! unsigned_fn {
    ($UnsignedT:ident, $HalfBitsT:ident, $stages:ident) => {
        /// Returns the [integer square root](
        /// https://en.wikipedia.org/wiki/Integer_square_root) of any
        #[doc = concat!("[`", stringify!($UnsignedT), "`](prim@", stringify!($UnsignedT), ")")]
        /// input.
        #[must_use = "this returns the result of the operation, \
                      without modifying the original"]
        #[inline]
        pub const fn $UnsignedT(mut n: $UnsignedT) -> $UnsignedT {
            if n <= <$HalfBitsT>::MAX as $UnsignedT {
                // Small inputs (including zero) fit in the half-width type,
                // so the cheaper half-width function handles them.
                $HalfBitsT(n as $HalfBitsT) as $UnsignedT
            } else {
                // The normalization shift satisfies the Karatsuba square root
                // algorithm precondition "a₃ ≥ b/4" where a₃ is the most
                // significant quarter of `n`'s bits and b is the number of
                // values that can be represented by that quarter of the bits.
                //
                // b/4 would then be all 0s except the second most significant
                // bit (010...0) in binary. Since a₃ must be at least b/4, a₃'s
                // most significant bit or its neighbor must be a 1. Since a₃'s
                // most significant bits are `n`'s most significant bits, the
                // same applies to `n`.
                //
                // The reason to shift by an even number of bits is because an
                // even number of bits produces the square root shifted to the
                // left by half of the normalization shift:
                //
                // sqrt(n << (2 * p))
                // sqrt(2.pow(2 * p) * n)
                // sqrt(2.pow(2 * p)) * sqrt(n)
                // 2.pow(p) * sqrt(n)
                // sqrt(n) << p
                //
                // Shifting by an odd number of bits leaves an ugly sqrt(2)
                // multiplied in:
                //
                // sqrt(n << (2 * p + 1))
                // sqrt(2.pow(2 * p + 1) * n)
                // sqrt(2 * 2.pow(2 * p) * n)
                // sqrt(2) * sqrt(2.pow(2 * p)) * sqrt(n)
                // sqrt(2) * 2.pow(p) * sqrt(n)
                // sqrt(2) * (sqrt(n) << p)
                //
                // Clearing the lowest bit of the leading-zero count rounds it
                // down to an even number.
                const EVEN_MAKING_BITMASK: u32 = !1;
                let normalization_shift = n.leading_zeros() & EVEN_MAKING_BITMASK;
                n <<= normalization_shift;

                let s = $stages(n);

                // Undo the normalization: the root was scaled by half of the
                // shift that was applied to `n`.
                let denormalization_shift = normalization_shift >> 1;
                s >> denormalization_shift
            }
        }
    };
}
| 132 | + |
/// Generates the first stage of the computation after normalization.
///
/// Evaluates to `(s, r)`: the integer square root and remainder of the eight
/// most-significant bits of `$n`, both as `u8`. `$original_bits` is the bit
/// width of `$n`'s type.
///
/// # Safety
///
/// `$n` must be normalized, meaning at least one of its two most-significant
/// bits is set. Being nonzero is not enough on its own: only the top eight
/// bits are inspected, and the unchecked assertion below requires their root
/// to be nonzero.
macro_rules! first_stage {
    ($original_bits:literal, $n:ident) => {{
        debug_assert!($n != 0, "`$n` is zero in `first_stage!`.");

        // Keep only the eight most-significant bits of the input.
        const N_SHIFT: u32 = $original_bits - 8;
        let n = $n >> N_SHIFT;

        let (s, r) = U8_ISQRT_WITH_REMAINDER[n as usize];

        // Inform the optimizer that `s` is nonzero. This will allow it to
        // avoid generating code to handle division-by-zero panics in the next
        // stage.
        //
        // SAFETY: If the original `$n` is zero, the top of the `unsigned_fn`
        // macro recurses instead of continuing to this point, so the original
        // `$n` wasn't a 0 if we've reached here.
        //
        // Then the `unsigned_fn` macro normalizes `$n` so that at least one of
        // its two most-significant bits is a 1.
        //
        // Then this stage puts the eight most-significant bits of `$n` into
        // `n`. This means that `n` here has at least one 1 bit in its two
        // most-significant bits, making `n` nonzero.
        //
        // `U8_ISQRT_WITH_REMAINDER[n as usize]` will give a nonzero `s` when
        // given a nonzero `n`.
        unsafe { crate::hint::assert_unchecked(s != 0) };
        (s, r)
    }};
}
| 168 | + |
/// Generates a middle stage of the computation.
///
/// Given the root `$s` and remainder `$r` of the most-significant
/// `<$ty>::BITS / 2` bits of `$n`, evaluates to `(s, r)`: the integer square
/// root and remainder of the most-significant `<$ty>::BITS` bits of `$n`,
/// both as `$ty`. `$original_bits` is the bit width of `$n`'s type.
///
/// # Safety
///
/// `$s` must be nonzero, and `$s` and `$r` must be the values produced by the
/// previous stage for the same normalized `$n`.
macro_rules! middle_stage {
    ($original_bits:literal, $ty:ty, $n:ident, $s:ident, $r:ident) => {{
        debug_assert!($s != 0, "`$s` is zero in `middle_stage!`.");

        // This stage works on the most-significant `<$ty>::BITS` bits of the
        // input.
        const N_SHIFT: u32 = $original_bits - <$ty>::BITS;
        let n = ($n >> N_SHIFT) as $ty;

        const HALF_BITS: u32 = <$ty>::BITS >> 1;
        const QUARTER_BITS: u32 = <$ty>::BITS >> 2;
        const LOWER_HALF_1_BITS: $ty = (1 << HALF_BITS) - 1;
        const LOWEST_QUARTER_1_BITS: $ty = (1 << QUARTER_BITS) - 1;

        // The upper half of `n` was consumed by the previous stage. The upper
        // quarter of `lo` is appended to the previous remainder, and the
        // result is divided by twice the previous root.
        let lo = n & LOWER_HALF_1_BITS;
        let numerator = (($r as $ty) << QUARTER_BITS) | (lo >> QUARTER_BITS);
        let denominator = ($s as $ty) << 1;
        let q = numerator / denominator;
        let u = numerator % denominator;

        // Candidate root: the previous root followed by the quotient digits.
        // Candidate remainder: the division remainder followed by the lowest
        // quarter of `n`, minus the square of the quotient.
        let mut s = ($s << QUARTER_BITS) as $ty + q;
        let (mut r, overflow) =
            ((u << QUARTER_BITS) | (lo & LOWEST_QUARTER_1_BITS)).overflowing_sub(q * q);
        // A borrow means the candidate root was one too large. Step it down
        // and add `2 * s - 1` back, since `s² - (s - 1)² = 2s - 1`.
        if overflow {
            r = r.wrapping_add(2 * s - 1);
            s -= 1;
        }

        // Inform the optimizer that `s` is nonzero. This will allow it to
        // avoid generating code to handle division-by-zero panics in the next
        // stage.
        //
        // SAFETY: If the original `$n` is zero, the top of the `unsigned_fn`
        // macro recurses instead of continuing to this point, so the original
        // `$n` wasn't a 0 if we've reached here.
        //
        // Then the `unsigned_fn` macro normalizes `$n` so that at least one of
        // its two most-significant bits is a 1.
        //
        // Then these stages take as many of the most-significant bits of `$n`
        // as will fit in this stage's type. For example, the stage that
        // handles `u32` deals with the 32 most-significant bits of `$n`. This
        // means that each stage has at least one 1 bit in `n`'s two
        // most-significant bits, making `n` nonzero.
        //
        // Then this stage will produce the correct integer square root for
        // that `n` value. Since `n` is nonzero, `s` will also be nonzero.
        unsafe { crate::hint::assert_unchecked(s != 0) };
        (s, r)
    }};
}
| 223 | + |
/// Generates the last stage of the computation before denormalization.
///
/// Given the root `$s` and remainder `$r` of the upper half of `$n`,
/// evaluates to the integer square root of the whole `$n`, as `$ty`. Unlike
/// the earlier stages, no remainder is produced.
///
/// # Safety
///
/// `$s` must be nonzero; the division by `2 * $s` would otherwise be a
/// division by zero.
macro_rules! last_stage {
    ($ty:ty, $n:ident, $s:ident, $r:ident) => {{
        debug_assert!($s != 0, "`$s` is zero in `last_stage!`.");

        const HALF_BITS: u32 = <$ty>::BITS >> 1;
        const QUARTER_BITS: u32 = <$ty>::BITS >> 2;
        const LOWER_HALF_1_BITS: $ty = (1 << HALF_BITS) - 1;

        // Append the upper quarter of the low half to the previous remainder,
        // then divide by twice the previous root.
        let lo = $n & LOWER_HALF_1_BITS;
        let numerator = (($r as $ty) << QUARTER_BITS) | (lo >> QUARTER_BITS);
        let denominator = ($s as $ty) << 1;

        let q = numerator / denominator;
        let mut s = ($s << QUARTER_BITS) as $ty + q;
        // The candidate is either correct or one too large. Since the
        // remainder is not needed, squaring the candidate is enough to tell.
        // An overflowing square also means the candidate is too large.
        let (s_squared, overflow) = s.overflowing_mul(s);
        if overflow || s_squared > $n {
            s -= 1;
        }
        s
    }};
}
| 250 | + |
| 251 | +/// Takes the normalized [`u16`](prim@u16) input and gets its normalized |
| 252 | +/// [integer square root](https://en.wikipedia.org/wiki/Integer_square_root). |
| 253 | +/// |
| 254 | +/// # Safety |
| 255 | +/// |
| 256 | +/// `n` must be nonzero. |
| 257 | +#[inline] |
| 258 | +const fn u16_stages(n: u16) -> u16 { |
| 259 | + let (s, r) = first_stage!(16, n); |
| 260 | + last_stage!(u16, n, s, r) |
| 261 | +} |
| 262 | + |
| 263 | +/// Takes the normalized [`u32`](prim@u32) input and gets its normalized |
| 264 | +/// [integer square root](https://en.wikipedia.org/wiki/Integer_square_root). |
| 265 | +/// |
| 266 | +/// # Safety |
| 267 | +/// |
| 268 | +/// `n` must be nonzero. |
| 269 | +#[inline] |
| 270 | +const fn u32_stages(n: u32) -> u32 { |
| 271 | + let (s, r) = first_stage!(32, n); |
| 272 | + let (s, r) = middle_stage!(32, u16, n, s, r); |
| 273 | + last_stage!(u32, n, s, r) |
| 274 | +} |
| 275 | + |
| 276 | +/// Takes the normalized [`u64`](prim@u64) input and gets its normalized |
| 277 | +/// [integer square root](https://en.wikipedia.org/wiki/Integer_square_root). |
| 278 | +/// |
| 279 | +/// # Safety |
| 280 | +/// |
| 281 | +/// `n` must be nonzero. |
| 282 | +#[inline] |
| 283 | +const fn u64_stages(n: u64) -> u64 { |
| 284 | + let (s, r) = first_stage!(64, n); |
| 285 | + let (s, r) = middle_stage!(64, u16, n, s, r); |
| 286 | + let (s, r) = middle_stage!(64, u32, n, s, r); |
| 287 | + last_stage!(u64, n, s, r) |
| 288 | +} |
| 289 | + |
| 290 | +/// Takes the normalized [`u128`](prim@u128) input and gets its normalized |
| 291 | +/// [integer square root](https://en.wikipedia.org/wiki/Integer_square_root). |
| 292 | +/// |
| 293 | +/// # Safety |
| 294 | +/// |
| 295 | +/// `n` must be nonzero. |
| 296 | +#[inline] |
| 297 | +const fn u128_stages(n: u128) -> u128 { |
| 298 | + let (s, r) = first_stage!(128, n); |
| 299 | + let (s, r) = middle_stage!(128, u16, n, s, r); |
| 300 | + let (s, r) = middle_stage!(128, u32, n, s, r); |
| 301 | + let (s, r) = middle_stage!(128, u64, n, s, r); |
| 302 | + last_stage!(u128, n, s, r) |
| 303 | +} |
| 304 | + |
| 305 | +unsigned_fn!(u16, u8, u16_stages); |
| 306 | +unsigned_fn!(u32, u16, u32_stages); |
| 307 | +unsigned_fn!(u64, u32, u64_stages); |
| 308 | +unsigned_fn!(u128, u64, u128_stages); |
| 309 | + |
/// Instantiate this panic logic once, rather than for all the isqrt methods
/// on every single primitive type.
///
/// # Panics
///
/// Always panics. `#[track_caller]` makes the reported location that of the
/// caller, and `#[cold]` marks this path as unlikely.
#[cold]
#[track_caller]
pub const fn panic_for_negative_argument() -> ! {
    panic!("argument of integer square root cannot be negative")
}
0 commit comments