Skip to content

Commit 2139651

Browse files
committed
Improve isqrt tests and add benchmarks
* Choose test inputs more thoroughly and systematically. * Check that `isqrt` and `checked_isqrt` have equivalent results for signed types, either equivalent numerically or equivalent as a panic and a `None`. * Check that `isqrt` has numerically-equivalent results for unsigned types and their `NonZero` counterparts. * Reuse `ilog10` benchmarks, plus benchmarks that use a uniform distribution.
1 parent 932cbd4 commit 2139651

File tree

6 files changed

+313
-32
lines changed

6 files changed

+313
-32
lines changed

Diff for: core/benches/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#![feature(iter_array_chunks)]
99
#![feature(iter_next_chunk)]
1010
#![feature(iter_advance_by)]
11+
#![feature(isqrt)]
1112

1213
extern crate test;
1314

Diff for: core/benches/num/int_sqrt/mod.rs

+62
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
use rand::Rng;
2+
use test::{black_box, Bencher};
3+
4+
macro_rules! int_sqrt_bench {
5+
($t:ty, $predictable:ident, $random:ident, $random_small:ident, $random_uniform:ident) => {
6+
#[bench]
7+
fn $predictable(bench: &mut Bencher) {
8+
bench.iter(|| {
9+
for n in 0..(<$t>::BITS / 8) {
10+
for i in 1..=(100 as $t) {
11+
let x = black_box(i << (n * 8));
12+
black_box(x.isqrt());
13+
}
14+
}
15+
});
16+
}
17+
18+
#[bench]
19+
fn $random(bench: &mut Bencher) {
20+
let mut rng = crate::bench_rng();
21+
/* Exponentially distributed random numbers from the whole range of the type. */
22+
let numbers: Vec<$t> =
23+
(0..256).map(|_| rng.gen::<$t>() >> rng.gen_range(0..<$t>::BITS)).collect();
24+
bench.iter(|| {
25+
for x in &numbers {
26+
black_box(black_box(x).isqrt());
27+
}
28+
});
29+
}
30+
31+
#[bench]
32+
fn $random_small(bench: &mut Bencher) {
33+
let mut rng = crate::bench_rng();
34+
/* Exponentially distributed random numbers from the range 0..256. */
35+
let numbers: Vec<$t> =
36+
(0..256).map(|_| (rng.gen::<u8>() >> rng.gen_range(0..u8::BITS)) as $t).collect();
37+
bench.iter(|| {
38+
for x in &numbers {
39+
black_box(black_box(x).isqrt());
40+
}
41+
});
42+
}
43+
44+
#[bench]
45+
fn $random_uniform(bench: &mut Bencher) {
46+
let mut rng = crate::bench_rng();
47+
/* Exponentially distributed random numbers from the whole range of the type. */
48+
let numbers: Vec<$t> = (0..256).map(|_| rng.gen::<$t>()).collect();
49+
bench.iter(|| {
50+
for x in &numbers {
51+
black_box(black_box(x).isqrt());
52+
}
53+
});
54+
}
55+
};
56+
}
57+
58+
int_sqrt_bench! {u8, u8_sqrt_predictable, u8_sqrt_random, u8_sqrt_random_small, u8_sqrt_uniform}
59+
int_sqrt_bench! {u16, u16_sqrt_predictable, u16_sqrt_random, u16_sqrt_random_small, u16_sqrt_uniform}
60+
int_sqrt_bench! {u32, u32_sqrt_predictable, u32_sqrt_random, u32_sqrt_random_small, u32_sqrt_uniform}
61+
int_sqrt_bench! {u64, u64_sqrt_predictable, u64_sqrt_random, u64_sqrt_random_small, u64_sqrt_uniform}
62+
int_sqrt_bench! {u128, u128_sqrt_predictable, u128_sqrt_random, u128_sqrt_random_small, u128_sqrt_uniform}

Diff for: core/benches/num/mod.rs

+1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ mod dec2flt;
22
mod flt2dec;
33
mod int_log;
44
mod int_pow;
5+
mod int_sqrt;
56

67
use std::str::FromStr;
78

Diff for: core/tests/num/int_macros.rs

-32
Original file line numberDiff line numberDiff line change
@@ -288,38 +288,6 @@ macro_rules! int_module {
288288
assert_eq!(r.saturating_pow(0), 1 as $T);
289289
}
290290

291-
#[test]
292-
fn test_isqrt() {
293-
assert_eq!($T::MIN.checked_isqrt(), None);
294-
assert_eq!((-1 as $T).checked_isqrt(), None);
295-
assert_eq!((0 as $T).isqrt(), 0 as $T);
296-
assert_eq!((1 as $T).isqrt(), 1 as $T);
297-
assert_eq!((2 as $T).isqrt(), 1 as $T);
298-
assert_eq!((99 as $T).isqrt(), 9 as $T);
299-
assert_eq!((100 as $T).isqrt(), 10 as $T);
300-
}
301-
302-
#[cfg(not(miri))] // Miri is too slow
303-
#[test]
304-
fn test_lots_of_isqrt() {
305-
let n_max: $T = (1024 * 1024).min($T::MAX as u128) as $T;
306-
for n in 0..=n_max {
307-
let isqrt: $T = n.isqrt();
308-
309-
assert!(isqrt.pow(2) <= n);
310-
let (square, overflow) = (isqrt + 1).overflowing_pow(2);
311-
assert!(overflow || square > n);
312-
}
313-
314-
for n in ($T::MAX - 127)..=$T::MAX {
315-
let isqrt: $T = n.isqrt();
316-
317-
assert!(isqrt.pow(2) <= n);
318-
let (square, overflow) = (isqrt + 1).overflowing_pow(2);
319-
assert!(overflow || square > n);
320-
}
321-
}
322-
323291
#[test]
324292
fn test_div_floor() {
325293
let a: $T = 8;

Diff for: core/tests/num/int_sqrt.rs

+248
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,248 @@
1+
macro_rules! tests {
2+
($isqrt_consistency_check_fn_macro:ident : $($T:ident)+) => {
3+
$(
4+
mod $T {
5+
$isqrt_consistency_check_fn_macro!($T);
6+
7+
// Check that the following produce the correct values from
8+
// `isqrt`:
9+
//
10+
// * the first and last 128 nonnegative values
11+
// * powers of two, minus one
12+
// * powers of two
13+
//
14+
// For signed types, check that `checked_isqrt` and `isqrt`
15+
// either produce the same numeric value or respectively
16+
// produce `None` and a panic. Make sure to do a consistency
17+
// check for `<$T>::MIN` as well, as no nonnegative values
18+
// negate to it.
19+
//
20+
// For unsigned types check that `isqrt` produces the same
21+
// numeric value for `$T` and `NonZero<$T>`.
22+
#[test]
23+
fn isqrt() {
24+
isqrt_consistency_check(<$T>::MIN);
25+
26+
for n in (0..=127)
27+
.chain(<$T>::MAX - 127..=<$T>::MAX)
28+
.chain((0..<$T>::MAX.count_ones()).map(|exponent| (1 << exponent) - 1))
29+
.chain((0..<$T>::MAX.count_ones()).map(|exponent| 1 << exponent))
30+
{
31+
isqrt_consistency_check(n);
32+
33+
let isqrt_n = n.isqrt();
34+
assert!(
35+
isqrt_n
36+
.checked_mul(isqrt_n)
37+
.map(|isqrt_n_squared| isqrt_n_squared <= n)
38+
.unwrap_or(false),
39+
"`{n}.isqrt()` should be lower than {isqrt_n}."
40+
);
41+
assert!(
42+
(isqrt_n + 1)
43+
.checked_mul(isqrt_n + 1)
44+
.map(|isqrt_n_plus_1_squared| n < isqrt_n_plus_1_squared)
45+
.unwrap_or(true),
46+
"`{n}.isqrt()` should be higher than {isqrt_n})."
47+
);
48+
}
49+
}
50+
51+
// Check the square roots of:
52+
//
53+
// * the first 1,024 perfect squares
54+
// * halfway between each of the first 1,024 perfect squares
55+
// and the next perfect square
56+
// * the next perfect square after the each of the first 1,024
57+
// perfect squares, minus one
58+
// * the last 1,024 perfect squares
59+
// * the last 1,024 perfect squares, minus one
60+
// * halfway between each of the last 1,024 perfect squares
61+
// and the previous perfect square
62+
#[test]
63+
// Skip this test on Miri, as it takes too long to run.
64+
#[cfg(not(miri))]
65+
fn isqrt_extended() {
66+
// The correct value is worked out by using the fact that
67+
// the nth nonzero perfect square is the sum of the first n
68+
// odd numbers:
69+
//
70+
// 1 = 1
71+
// 4 = 1 + 3
72+
// 9 = 1 + 3 + 5
73+
// 16 = 1 + 3 + 5 + 7
74+
//
75+
// Note also that the last odd number added in is two times
76+
// the square root of the previous perfect square, plus
77+
// one:
78+
//
79+
// 1 = 2*0 + 1
80+
// 3 = 2*1 + 1
81+
// 5 = 2*2 + 1
82+
// 7 = 2*3 + 1
83+
//
84+
// That means we can add the square root of this perfect
85+
// square once to get about halfway to the next perfect
86+
// square, then we can add the square root of this perfect
87+
// square again to get to the next perfect square, minus
88+
// one, then we can add one to get to the next perfect
89+
// square.
90+
//
91+
// This allows us to, for each of the first 1,024 perfect
92+
// squares, test that the square roots of the following are
93+
// all correct and equal to each other:
94+
//
95+
// * the current perfect square
96+
// * about halfway to the next perfect square
97+
// * the next perfect square, minus one
98+
let mut n: $T = 0;
99+
for sqrt_n in 0..1_024.min((1_u128 << (<$T>::MAX.count_ones()/2)) - 1) as $T {
100+
isqrt_consistency_check(n);
101+
assert_eq!(
102+
n.isqrt(),
103+
sqrt_n,
104+
"`{sqrt_n}.pow(2).isqrt()` should be {sqrt_n}."
105+
);
106+
107+
n += sqrt_n;
108+
isqrt_consistency_check(n);
109+
assert_eq!(
110+
n.isqrt(),
111+
sqrt_n,
112+
"{n} is about halfway between `{sqrt_n}.pow(2)` and `{}.pow(2)`, so `{n}.isqrt()` should be {sqrt_n}.",
113+
sqrt_n + 1
114+
);
115+
116+
n += sqrt_n;
117+
isqrt_consistency_check(n);
118+
assert_eq!(
119+
n.isqrt(),
120+
sqrt_n,
121+
"`({}.pow(2) - 1).isqrt()` should be {sqrt_n}.",
122+
sqrt_n + 1
123+
);
124+
125+
n += 1;
126+
}
127+
128+
// Similarly, for each of the last 1,024 perfect squares,
129+
// check:
130+
//
131+
// * the current perfect square
132+
// * the current perfect square, minus one
133+
// * about halfway to the previous perfect square
134+
//
135+
// `MAX`'s `isqrt` return value is verified in the `isqrt`
136+
// test function above.
137+
let maximum_sqrt = <$T>::MAX.isqrt();
138+
let mut n = maximum_sqrt * maximum_sqrt;
139+
140+
for sqrt_n in (maximum_sqrt - 1_024.min((1_u128 << (<$T>::MAX.count_ones()/2)) - 1) as $T..maximum_sqrt).rev() {
141+
isqrt_consistency_check(n);
142+
assert_eq!(
143+
n.isqrt(),
144+
sqrt_n + 1,
145+
"`{0}.pow(2).isqrt()` should be {0}.",
146+
sqrt_n + 1
147+
);
148+
149+
n -= 1;
150+
isqrt_consistency_check(n);
151+
assert_eq!(
152+
n.isqrt(),
153+
sqrt_n,
154+
"`({}.pow(2) - 1).isqrt()` should be {sqrt_n}.",
155+
sqrt_n + 1
156+
);
157+
158+
n -= sqrt_n;
159+
isqrt_consistency_check(n);
160+
assert_eq!(
161+
n.isqrt(),
162+
sqrt_n,
163+
"{n} is about halfway between `{sqrt_n}.pow(2)` and `{}.pow(2)`, so `{n}.isqrt()` should be {sqrt_n}.",
164+
sqrt_n + 1
165+
);
166+
167+
n -= sqrt_n;
168+
}
169+
}
170+
}
171+
)*
172+
};
173+
}
174+
175+
macro_rules! signed_check {
176+
($T:ident) => {
177+
/// This takes an input and, if it's nonnegative or
178+
#[doc = concat!("`", stringify!($T), "::MIN`,")]
179+
/// checks that `isqrt` and `checked_isqrt` produce equivalent results
180+
/// for that input and for the negative of that input.
181+
///
182+
/// # Note
183+
///
184+
/// This cannot check that negative inputs to `isqrt` cause panics if
185+
/// panics abort instead of unwind.
186+
fn isqrt_consistency_check(n: $T) {
187+
// `<$T>::MIN` will be negative, so ignore it in this nonnegative
188+
// section.
189+
if n >= 0 {
190+
assert_eq!(
191+
Some(n.isqrt()),
192+
n.checked_isqrt(),
193+
"`{n}.checked_isqrt()` should match `Some({n}.isqrt())`.",
194+
);
195+
}
196+
197+
// `wrapping_neg` so that `<$T>::MIN` will negate to itself rather
198+
// than panicking.
199+
let negative_n = n.wrapping_neg();
200+
201+
// Zero negated will still be nonnegative, so ignore it in this
202+
// negative section.
203+
if negative_n < 0 {
204+
assert_eq!(
205+
negative_n.checked_isqrt(),
206+
None,
207+
"`({negative_n}).checked_isqrt()` should be `None`, as {negative_n} is negative.",
208+
);
209+
210+
// `catch_unwind` only works when panics unwind rather than abort.
211+
#[cfg(panic = "unwind")]
212+
{
213+
std::panic::catch_unwind(core::panic::AssertUnwindSafe(|| (-n).isqrt())).expect_err(
214+
&format!("`({negative_n}).isqrt()` should have panicked, as {negative_n} is negative.")
215+
);
216+
}
217+
}
218+
}
219+
};
220+
}
221+
222+
macro_rules! unsigned_check {
223+
($T:ident) => {
224+
/// This takes an input and, if it's nonzero, checks that `isqrt`
225+
/// produces the same numeric value for both
226+
#[doc = concat!("`", stringify!($T), "` and ")]
227+
#[doc = concat!("`NonZero<", stringify!($T), ">`.")]
228+
fn isqrt_consistency_check(n: $T) {
229+
// Zero cannot be turned into a `NonZero` value, so ignore it in
230+
// this nonzero section.
231+
if n > 0 {
232+
assert_eq!(
233+
n.isqrt(),
234+
core::num::NonZero::<$T>::new(n)
235+
.expect(
236+
"Was not able to create a new `NonZero` value from a nonzero number."
237+
)
238+
.isqrt()
239+
.get(),
240+
"`{n}.isqrt` should match `NonZero`'s `{n}.isqrt().get()`.",
241+
);
242+
}
243+
}
244+
};
245+
}
246+
247+
tests!(signed_check: i8 i16 i32 i64 i128);
248+
tests!(unsigned_check: u8 u16 u32 u64 u128);

Diff for: core/tests/num/mod.rs

+1
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ mod const_from;
2727
mod dec2flt;
2828
mod flt2dec;
2929
mod int_log;
30+
mod int_sqrt;
3031
mod ops;
3132
mod wrapping;
3233

0 commit comments

Comments
 (0)