Skip to content

Commit b8536c1

Browse files
committed
Auto merge of rust-lang#116176 - FedericoStra:isqrt, r=dtolnay
Add "integer square root" method to integer primitive types For every suffix `N` among `8`, `16`, `32`, `64`, `128` and `size`, this PR adds the methods ```rust const fn uN::isqrt() -> uN; const fn iN::isqrt() -> iN; const fn iN::checked_isqrt() -> Option<iN>; ``` to compute the [integer square root](https://en.wikipedia.org/wiki/Integer_square_root), addressing issue rust-lang#89273. The implementation is based on the [base 2 digit-by-digit algorithm](https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Binary_numeral_system_(base_2)) on Wikipedia, which after some benchmarking has proved to be faster than both binary search and Heron's/Newton's method. I haven't had the time to understand and port [this code](http://atoms.alife.co.uk/sqrt/SquareRoot.java) based on lookup tables instead, but I'm not sure whether it's worth complicating such a function this much for relatively little benefit.
2 parents c1f86f0 + 25648de commit b8536c1

File tree

6 files changed

+165
-0
lines changed

6 files changed

+165
-0
lines changed

Diff for: library/core/src/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,7 @@
178178
#![feature(ip)]
179179
#![feature(ip_bits)]
180180
#![feature(is_ascii_octdigit)]
181+
#![feature(isqrt)]
181182
#![feature(maybe_uninit_uninit_array)]
182183
#![feature(ptr_alignment_type)]
183184
#![feature(ptr_metadata)]

Diff for: library/core/src/num/int_macros.rs

+54
Original file line numberDiff line numberDiff line change
@@ -898,6 +898,30 @@ macro_rules! int_impl {
898898
acc.checked_mul(base)
899899
}
900900

901+
/// Returns the square root of the number, rounded down.
902+
///
903+
/// Returns `None` if `self` is negative.
904+
///
905+
/// # Examples
906+
///
907+
/// Basic usage:
908+
/// ```
909+
/// #![feature(isqrt)]
910+
#[doc = concat!("assert_eq!(10", stringify!($SelfT), ".checked_isqrt(), Some(3));")]
911+
/// ```
912+
#[unstable(feature = "isqrt", issue = "116226")]
913+
#[rustc_const_unstable(feature = "isqrt", issue = "116226")]
914+
#[must_use = "this returns the result of the operation, \
915+
without modifying the original"]
916+
#[inline]
917+
pub const fn checked_isqrt(self) -> Option<Self> {
918+
if self < 0 {
919+
None
920+
} else {
921+
Some((self as $UnsignedT).isqrt() as Self)
922+
}
923+
}
924+
901925
/// Saturating integer addition. Computes `self + rhs`, saturating at the numeric
902926
/// bounds instead of overflowing.
903927
///
@@ -2061,6 +2085,36 @@ macro_rules! int_impl {
20612085
acc * base
20622086
}
20632087

2088+
/// Returns the square root of the number, rounded down.
2089+
///
2090+
/// # Panics
2091+
///
2092+
/// This function will panic if `self` is negative.
2093+
///
2094+
/// # Examples
2095+
///
2096+
/// Basic usage:
2097+
/// ```
2098+
/// #![feature(isqrt)]
2099+
#[doc = concat!("assert_eq!(10", stringify!($SelfT), ".isqrt(), 3);")]
2100+
/// ```
2101+
#[unstable(feature = "isqrt", issue = "116226")]
2102+
#[rustc_const_unstable(feature = "isqrt", issue = "116226")]
2103+
#[must_use = "this returns the result of the operation, \
2104+
without modifying the original"]
2105+
#[inline]
2106+
pub const fn isqrt(self) -> Self {
2107+
// I would like to implement it as
2108+
// ```
2109+
// self.checked_isqrt().expect("argument of integer square root must be non-negative")
2110+
// ```
2111+
// but `expect` is not yet stable as a `const fn`.
2112+
match self.checked_isqrt() {
2113+
Some(sqrt) => sqrt,
2114+
None => panic!("argument of integer square root must be non-negative"),
2115+
}
2116+
}
2117+
20642118
/// Calculates the quotient of Euclidean division of `self` by `rhs`.
20652119
///
20662120
/// This computes the integer `q` such that `self = q * rhs + r`, with

Diff for: library/core/src/num/uint_macros.rs

+48
Original file line numberDiff line numberDiff line change
@@ -1995,6 +1995,54 @@ macro_rules! uint_impl {
19951995
acc * base
19961996
}
19971997

1998+
/// Returns the square root of the number, rounded down.
1999+
///
2000+
/// # Examples
2001+
///
2002+
/// Basic usage:
2003+
/// ```
2004+
/// #![feature(isqrt)]
2005+
#[doc = concat!("assert_eq!(10", stringify!($SelfT), ".isqrt(), 3);")]
2006+
/// ```
2007+
#[unstable(feature = "isqrt", issue = "116226")]
2008+
#[rustc_const_unstable(feature = "isqrt", issue = "116226")]
2009+
#[must_use = "this returns the result of the operation, \
2010+
without modifying the original"]
2011+
#[inline]
2012+
pub const fn isqrt(self) -> Self {
2013+
if self < 2 {
2014+
return self;
2015+
}
2016+
2017+
// The algorithm is based on the one presented in
2018+
// <https://en.wikipedia.org/wiki/Methods_of_computing_square_roots#Binary_numeral_system_(base_2)>
2019+
// which cites as source the following C code:
2020+
// <https://web.archive.org/web/20120306040058/http://medialab.freaknet.org/martin/src/sqrt/sqrt.c>.
2021+
2022+
let mut op = self;
2023+
let mut res = 0;
2024+
let mut one = 1 << (self.ilog2() & !1);
2025+
2026+
while one != 0 {
2027+
if op >= res + one {
2028+
op -= res + one;
2029+
res = (res >> 1) + one;
2030+
} else {
2031+
res >>= 1;
2032+
}
2033+
one >>= 2;
2034+
}
2035+
2036+
// SAFETY: the result is positive and fits in an integer with half as many bits.
2037+
// Inform the optimizer about it.
2038+
unsafe {
2039+
intrinsics::assume(0 < res);
2040+
intrinsics::assume(res < 1 << (Self::BITS / 2));
2041+
}
2042+
2043+
res
2044+
}
2045+
19982046
/// Performs Euclidean division.
19992047
///
20002048
/// Since, for the positive integers, all common

Diff for: library/core/tests/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@
5656
#![feature(min_specialization)]
5757
#![feature(numfmt)]
5858
#![feature(num_midpoint)]
59+
#![feature(isqrt)]
5960
#![feature(step_trait)]
6061
#![feature(str_internals)]
6162
#![feature(std_internals)]

Diff for: library/core/tests/num/int_macros.rs

+32
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,38 @@ macro_rules! int_module {
290290
assert_eq!(r.saturating_pow(0), 1 as $T);
291291
}
292292

293+
#[test]
294+
fn test_isqrt() {
295+
assert_eq!($T::MIN.checked_isqrt(), None);
296+
assert_eq!((-1 as $T).checked_isqrt(), None);
297+
assert_eq!((0 as $T).isqrt(), 0 as $T);
298+
assert_eq!((1 as $T).isqrt(), 1 as $T);
299+
assert_eq!((2 as $T).isqrt(), 1 as $T);
300+
assert_eq!((99 as $T).isqrt(), 9 as $T);
301+
assert_eq!((100 as $T).isqrt(), 10 as $T);
302+
}
303+
304+
#[cfg(not(miri))] // Miri is too slow
305+
#[test]
306+
fn test_lots_of_isqrt() {
307+
let n_max: $T = (1024 * 1024).min($T::MAX as u128) as $T;
308+
for n in 0..=n_max {
309+
let isqrt: $T = n.isqrt();
310+
311+
assert!(isqrt.pow(2) <= n);
312+
let (square, overflow) = (isqrt + 1).overflowing_pow(2);
313+
assert!(overflow || square > n);
314+
}
315+
316+
for n in ($T::MAX - 127)..=$T::MAX {
317+
let isqrt: $T = n.isqrt();
318+
319+
assert!(isqrt.pow(2) <= n);
320+
let (square, overflow) = (isqrt + 1).overflowing_pow(2);
321+
assert!(overflow || square > n);
322+
}
323+
}
324+
293325
#[test]
294326
fn test_div_floor() {
295327
let a: $T = 8;

Diff for: library/core/tests/num/uint_macros.rs

+29
Original file line numberDiff line numberDiff line change
@@ -206,6 +206,35 @@ macro_rules! uint_module {
206206
assert_eq!(r.saturating_pow(2), MAX);
207207
}
208208

209+
#[test]
210+
fn test_isqrt() {
211+
assert_eq!((0 as $T).isqrt(), 0 as $T);
212+
assert_eq!((1 as $T).isqrt(), 1 as $T);
213+
assert_eq!((2 as $T).isqrt(), 1 as $T);
214+
assert_eq!((99 as $T).isqrt(), 9 as $T);
215+
assert_eq!((100 as $T).isqrt(), 10 as $T);
216+
assert_eq!($T::MAX.isqrt(), (1 << ($T::BITS / 2)) - 1);
217+
}
218+
219+
#[cfg(not(miri))] // Miri is too slow
220+
#[test]
221+
fn test_lots_of_isqrt() {
222+
let n_max: $T = (1024 * 1024).min($T::MAX as u128) as $T;
223+
for n in 0..=n_max {
224+
let isqrt: $T = n.isqrt();
225+
226+
assert!(isqrt.pow(2) <= n);
227+
assert!(isqrt + 1 == (1 as $T) << ($T::BITS / 2) || (isqrt + 1).pow(2) > n);
228+
}
229+
230+
for n in ($T::MAX - 255)..=$T::MAX {
231+
let isqrt: $T = n.isqrt();
232+
233+
assert!(isqrt.pow(2) <= n);
234+
assert!(isqrt + 1 == (1 as $T) << ($T::BITS / 2) || (isqrt + 1).pow(2) > n);
235+
}
236+
}
237+
209238
#[test]
210239
fn test_div_floor() {
211240
assert_eq!((8 as $T).div_floor(3), 2);

0 commit comments

Comments
 (0)