Skip to content

Commit 18bfe5d

Browse files
committed
Auto merge of rust-lang#92048 - Urgau:num-midpoint, r=scottmcm
Add midpoint function for all integers and floating numbers This pull-request adds the `midpoint` function to `{u,i}{8,16,32,64,128,size}`, `NonZeroU{8,16,32,64,size}` and `f{32,64}`. This new function is analogous to the [C++ midpoint](https://en.cppreference.com/w/cpp/numeric/midpoint) function, and basically computes `(a + b) / 2` with rounding towards ~~`a`~~ negative infinity in the case of integers. Or simply said: `midpoint(a, b)` is `(a + b) >> 1` as if it were performed in a sufficiently-large signed integral type. Note that unlike the C++ function, this pull-request does not implement this function on pointers (`*const T` or `*mut T`). This could be implemented in a future pull-request if desired. ### Implementation For `f32` and `f64` the implementation is based on the `libcxx` [one](https://github.com/llvm/llvm-project/blob/18ab892ff7e9032914ff7fdb07685d5945c84fef/libcxx/include/__numeric/midpoint.h#L65-L77). I originally tried many different approaches, but all of them failed or left me with a poorer version of the `libcxx` one. Note that `libstdc++` has a very similar one; the Microsoft STL implementation is also basically the same as `libcxx`. Unfortunately, there doesn't seem to be a better way. For unsigned integers I created the macro `midpoint_impl!`; this macro has two branches: - The first one takes `$SelfT` and is used when there is no unsigned integer with at least double the number of bits. The code simply uses the formula `a + (b - a) / 2` with the arguments in the correct order and signs to get the right rounding. - The second branch is used when a `$WideT` (with at least double the number of bits of `$SelfT`) is provided; using a wider type means that no overflow can occur, which greatly improves the codegen (no branch and fewer instructions). For signed integers the code basically forwards the signed numbers to the unsigned version of midpoint by mapping the signed numbers to their unsigned counterparts (`ex: i8 [-128; 127] to [0; 255]`) and vice versa. 
I originally created a version that worked directly on the signed numbers, but the code was "ugly" and not understandable. Despite this mapping "overhead", the codegen is better than my most optimized version on signed integers. ~~Note that in the case of unsigned numbers I tried to be smart and used `#[cfg(target_pointer_width = "64")]` to determine if using the wide version was better or not by looking at the assembly on godbolt. This was applied to `u32`, `u64` and `usize` and doesn't change the behavior, only the assembly code generated.~~
2 parents eda41ad + 95a383b commit 18bfe5d

File tree

10 files changed

+313
-3
lines changed

10 files changed

+313
-3
lines changed

library/core/src/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -133,6 +133,7 @@
133133
#![feature(const_maybe_uninit_assume_init)]
134134
#![feature(const_maybe_uninit_uninit_array)]
135135
#![feature(const_nonnull_new)]
136+
#![feature(const_num_midpoint)]
136137
#![feature(const_option)]
137138
#![feature(const_option_ext)]
138139
#![feature(const_pin)]

library/core/src/num/f32.rs

+36
Original file line numberDiff line numberDiff line change
@@ -940,6 +940,42 @@ impl f32 {
940940
}
941941
}
942942

943+
/// Calculates the middle point of `self` and `other`.
944+
///
945+
/// This returns NaN when *either* argument is NaN or if a combination of
946+
/// +inf and -inf is provided as arguments.
947+
///
948+
/// # Examples
949+
///
950+
/// ```
951+
/// #![feature(num_midpoint)]
952+
/// assert_eq!(1f32.midpoint(4.0), 2.5);
953+
/// assert_eq!((-5.5f32).midpoint(8.0), 1.25);
954+
/// ```
955+
#[unstable(feature = "num_midpoint", issue = "110840")]
956+
pub fn midpoint(self, other: f32) -> f32 {
957+
const LO: f32 = f32::MIN_POSITIVE * 2.;
958+
const HI: f32 = f32::MAX / 2.;
959+
960+
let (a, b) = (self, other);
961+
let abs_a = a.abs_private();
962+
let abs_b = b.abs_private();
963+
964+
if abs_a <= HI && abs_b <= HI {
965+
// Overflow is impossible
966+
(a + b) / 2.
967+
} else if abs_a < LO {
968+
// Not safe to halve a
969+
a + (b / 2.)
970+
} else if abs_b < LO {
971+
// Not safe to halve b
972+
(a / 2.) + b
973+
} else {
974+
// Not safe to halve a and b
975+
(a / 2.) + (b / 2.)
976+
}
977+
}
978+
943979
/// Rounds toward zero and converts to any primitive integer type,
944980
/// assuming that the value is finite and fits in that type.
945981
///

library/core/src/num/f64.rs

+36
Original file line numberDiff line numberDiff line change
@@ -951,6 +951,42 @@ impl f64 {
951951
}
952952
}
953953

954+
/// Calculates the middle point of `self` and `other`.
955+
///
956+
/// This returns NaN when *either* argument is NaN or if a combination of
957+
/// +inf and -inf is provided as arguments.
958+
///
959+
/// # Examples
960+
///
961+
/// ```
962+
/// #![feature(num_midpoint)]
963+
/// assert_eq!(1f64.midpoint(4.0), 2.5);
964+
/// assert_eq!((-5.5f64).midpoint(8.0), 1.25);
965+
/// ```
966+
#[unstable(feature = "num_midpoint", issue = "110840")]
967+
pub fn midpoint(self, other: f64) -> f64 {
968+
const LO: f64 = f64::MIN_POSITIVE * 2.;
969+
const HI: f64 = f64::MAX / 2.;
970+
971+
let (a, b) = (self, other);
972+
let abs_a = a.abs_private();
973+
let abs_b = b.abs_private();
974+
975+
if abs_a <= HI && abs_b <= HI {
976+
// Overflow is impossible
977+
(a + b) / 2.
978+
} else if abs_a < LO {
979+
// Not safe to halve a
980+
a + (b / 2.)
981+
} else if abs_b < LO {
982+
// Not safe to halve b
983+
(a / 2.) + b
984+
} else {
985+
// Not safe to halve a and b
986+
(a / 2.) + (b / 2.)
987+
}
988+
}
989+
954990
/// Rounds toward zero and converts to any primitive integer type,
955991
/// assuming that the value is finite and fits in that type.
956992
///

library/core/src/num/int_macros.rs

+38
Original file line numberDiff line numberDiff line change
@@ -2332,6 +2332,44 @@ macro_rules! int_impl {
23322332
}
23332333
}
23342334

2335+
/// Calculates the middle point of `self` and `rhs`.
2336+
///
2337+
/// `midpoint(a, b)` is `(a + b) >> 1` as if it were performed in a
2338+
/// sufficiently-large signed integral type. This implies that the result is
2339+
/// always rounded towards negative infinity and that no overflow will ever occur.
2340+
///
2341+
/// # Examples
2342+
///
2343+
/// ```
2344+
/// #![feature(num_midpoint)]
2345+
#[doc = concat!("assert_eq!(0", stringify!($SelfT), ".midpoint(4), 2);")]
2346+
#[doc = concat!("assert_eq!(0", stringify!($SelfT), ".midpoint(-1), -1);")]
2347+
#[doc = concat!("assert_eq!((-1", stringify!($SelfT), ").midpoint(0), -1);")]
2348+
/// ```
2349+
#[unstable(feature = "num_midpoint", issue = "110840")]
2350+
#[rustc_const_unstable(feature = "const_num_midpoint", issue = "110840")]
2351+
#[rustc_allow_const_fn_unstable(const_num_midpoint)]
2352+
#[must_use = "this returns the result of the operation, \
2353+
without modifying the original"]
2354+
#[inline]
2355+
pub const fn midpoint(self, rhs: Self) -> Self {
2356+
const U: $UnsignedT = <$SelfT>::MIN.unsigned_abs();
2357+
2358+
// Map an $SelfT to an $UnsignedT
2359+
// ex: i8 [-128; 127] to [0; 255]
2360+
const fn map(a: $SelfT) -> $UnsignedT {
2361+
(a as $UnsignedT) ^ U
2362+
}
2363+
2364+
// Map an $UnsignedT to an $SelfT
2365+
// ex: u8 [0; 255] to [-128; 127]
2366+
const fn demap(a: $UnsignedT) -> $SelfT {
2367+
(a ^ U) as $SelfT
2368+
}
2369+
2370+
demap(<$UnsignedT>::midpoint(map(self), map(rhs)))
2371+
}
2372+
23352373
/// Returns the logarithm of the number with respect to an arbitrary base,
23362374
/// rounded down.
23372375
///

library/core/src/num/mod.rs

+59
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,57 @@ depending on the target pointer size.
9595
};
9696
}
9797

98+
macro_rules! midpoint_impl {
99+
($SelfT:ty, unsigned) => {
100+
/// Calculates the middle point of `self` and `rhs`.
101+
///
102+
/// `midpoint(a, b)` is `(a + b) >> 1` as if it were performed in a
103+
/// sufficiently-large signed integral type. This implies that the result is
104+
/// always rounded towards negative infinity and that no overflow will ever occur.
105+
///
106+
/// # Examples
107+
///
108+
/// ```
109+
/// #![feature(num_midpoint)]
110+
#[doc = concat!("assert_eq!(0", stringify!($SelfT), ".midpoint(4), 2);")]
111+
#[doc = concat!("assert_eq!(1", stringify!($SelfT), ".midpoint(4), 2);")]
112+
/// ```
113+
#[unstable(feature = "num_midpoint", issue = "110840")]
114+
#[rustc_const_unstable(feature = "const_num_midpoint", issue = "110840")]
115+
#[must_use = "this returns the result of the operation, \
116+
without modifying the original"]
117+
#[inline]
118+
pub const fn midpoint(self, rhs: $SelfT) -> $SelfT {
119+
// Use the well-known branchless algorithm from Hacker's Delight to compute
120+
// `(a + b) / 2` without overflowing: `((a ^ b) >> 1) + (a & b)`.
121+
((self ^ rhs) >> 1) + (self & rhs)
122+
}
123+
};
124+
($SelfT:ty, $WideT:ty, unsigned) => {
125+
/// Calculates the middle point of `self` and `rhs`.
126+
///
127+
/// `midpoint(a, b)` is `(a + b) >> 1` as if it were performed in a
128+
/// sufficiently-large signed integral type. This implies that the result is
129+
/// always rounded towards negative infinity and that no overflow will ever occur.
130+
///
131+
/// # Examples
132+
///
133+
/// ```
134+
/// #![feature(num_midpoint)]
135+
#[doc = concat!("assert_eq!(0", stringify!($SelfT), ".midpoint(4), 2);")]
136+
#[doc = concat!("assert_eq!(1", stringify!($SelfT), ".midpoint(4), 2);")]
137+
/// ```
138+
#[unstable(feature = "num_midpoint", issue = "110840")]
139+
#[rustc_const_unstable(feature = "const_num_midpoint", issue = "110840")]
140+
#[must_use = "this returns the result of the operation, \
141+
without modifying the original"]
142+
#[inline]
143+
pub const fn midpoint(self, rhs: $SelfT) -> $SelfT {
144+
((self as $WideT + rhs as $WideT) / 2) as $SelfT
145+
}
146+
};
147+
}
148+
98149
macro_rules! widening_impl {
99150
($SelfT:ty, $WideT:ty, $BITS:literal, unsigned) => {
100151
/// Calculates the complete product `self * rhs` without the possibility to overflow.
@@ -455,6 +506,7 @@ impl u8 {
455506
bound_condition = "",
456507
}
457508
widening_impl! { u8, u16, 8, unsigned }
509+
midpoint_impl! { u8, u16, unsigned }
458510

459511
/// Checks if the value is within the ASCII range.
460512
///
@@ -1066,6 +1118,7 @@ impl u16 {
10661118
bound_condition = "",
10671119
}
10681120
widening_impl! { u16, u32, 16, unsigned }
1121+
midpoint_impl! { u16, u32, unsigned }
10691122

10701123
/// Checks if the value is a Unicode surrogate code point, which are disallowed values for [`char`].
10711124
///
@@ -1114,6 +1167,7 @@ impl u32 {
11141167
bound_condition = "",
11151168
}
11161169
widening_impl! { u32, u64, 32, unsigned }
1170+
midpoint_impl! { u32, u64, unsigned }
11171171
}
11181172

11191173
impl u64 {
@@ -1137,6 +1191,7 @@ impl u64 {
11371191
bound_condition = "",
11381192
}
11391193
widening_impl! { u64, u128, 64, unsigned }
1194+
midpoint_impl! { u64, u128, unsigned }
11401195
}
11411196

11421197
impl u128 {
@@ -1161,6 +1216,7 @@ impl u128 {
11611216
from_xe_bytes_doc = "",
11621217
bound_condition = "",
11631218
}
1219+
midpoint_impl! { u128, unsigned }
11641220
}
11651221

11661222
#[cfg(target_pointer_width = "16")]
@@ -1185,6 +1241,7 @@ impl usize {
11851241
bound_condition = " on 16-bit targets",
11861242
}
11871243
widening_impl! { usize, u32, 16, unsigned }
1244+
midpoint_impl! { usize, u32, unsigned }
11881245
}
11891246

11901247
#[cfg(target_pointer_width = "32")]
@@ -1209,6 +1266,7 @@ impl usize {
12091266
bound_condition = " on 32-bit targets",
12101267
}
12111268
widening_impl! { usize, u64, 32, unsigned }
1269+
midpoint_impl! { usize, u64, unsigned }
12121270
}
12131271

12141272
#[cfg(target_pointer_width = "64")]
@@ -1233,6 +1291,7 @@ impl usize {
12331291
bound_condition = " on 64-bit targets",
12341292
}
12351293
widening_impl! { usize, u128, 64, unsigned }
1294+
midpoint_impl! { usize, u128, unsigned }
12361295
}
12371296

12381297
impl usize {

library/core/src/num/nonzero.rs

+37
Original file line numberDiff line numberDiff line change
@@ -493,6 +493,43 @@ macro_rules! nonzero_unsigned_operations {
493493
pub const fn ilog10(self) -> u32 {
494494
super::int_log10::$Int(self.0)
495495
}
496+
497+
/// Calculates the middle point of `self` and `rhs`.
498+
///
499+
/// `midpoint(a, b)` is `(a + b) >> 1` as if it were performed in a
500+
/// sufficiently-large signed integral type. This implies that the result is
501+
/// always rounded towards negative infinity and that no overflow will ever occur.
502+
///
503+
/// # Examples
504+
///
505+
/// ```
506+
/// #![feature(num_midpoint)]
507+
#[doc = concat!("# use std::num::", stringify!($Ty), ";")]
508+
///
509+
/// # fn main() { test().unwrap(); }
510+
/// # fn test() -> Option<()> {
511+
#[doc = concat!("let one = ", stringify!($Ty), "::new(1)?;")]
512+
#[doc = concat!("let two = ", stringify!($Ty), "::new(2)?;")]
513+
#[doc = concat!("let four = ", stringify!($Ty), "::new(4)?;")]
514+
///
515+
/// assert_eq!(one.midpoint(four), two);
516+
/// assert_eq!(four.midpoint(one), two);
517+
/// # Some(())
518+
/// # }
519+
/// ```
520+
#[unstable(feature = "num_midpoint", issue = "110840")]
521+
#[rustc_const_unstable(feature = "const_num_midpoint", issue = "110840")]
522+
#[rustc_allow_const_fn_unstable(const_num_midpoint)]
523+
#[must_use = "this returns the result of the operation, \
524+
without modifying the original"]
525+
#[inline]
526+
pub const fn midpoint(self, rhs: Self) -> Self {
527+
// SAFETY: The only way to get `0` with midpoint is to have two opposite or
528+
// near opposite numbers: (-5, 5), (0, 1), (0, 0) which is impossible because
529+
// of the unsignedness of this number and also because $Ty is guaranteed to
530+
// never be 0.
531+
unsafe { $Ty::new_unchecked(self.get().midpoint(rhs.get())) }
532+
}
496533
}
497534
)+
498535
}

library/core/tests/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
#![feature(maybe_uninit_uninit_array_transpose)]
5454
#![feature(min_specialization)]
5555
#![feature(numfmt)]
56+
#![feature(num_midpoint)]
5657
#![feature(step_trait)]
5758
#![feature(str_internals)]
5859
#![feature(std_internals)]

library/core/tests/num/int_macros.rs

+26
Original file line numberDiff line numberDiff line change
@@ -364,6 +364,32 @@ macro_rules! int_module {
364364
assert_eq!((0 as $T).borrowing_sub($T::MIN, false), ($T::MIN, true));
365365
assert_eq!((0 as $T).borrowing_sub($T::MIN, true), ($T::MAX, false));
366366
}
367+
368+
#[test]
369+
fn test_midpoint() {
370+
assert_eq!(<$T>::midpoint(1, 3), 2);
371+
assert_eq!(<$T>::midpoint(3, 1), 2);
372+
373+
assert_eq!(<$T>::midpoint(0, 0), 0);
374+
assert_eq!(<$T>::midpoint(0, 2), 1);
375+
assert_eq!(<$T>::midpoint(2, 0), 1);
376+
assert_eq!(<$T>::midpoint(2, 2), 2);
377+
378+
assert_eq!(<$T>::midpoint(1, 4), 2);
379+
assert_eq!(<$T>::midpoint(4, 1), 2);
380+
assert_eq!(<$T>::midpoint(3, 4), 3);
381+
assert_eq!(<$T>::midpoint(4, 3), 3);
382+
383+
assert_eq!(<$T>::midpoint(<$T>::MIN, <$T>::MAX), -1);
384+
assert_eq!(<$T>::midpoint(<$T>::MAX, <$T>::MIN), -1);
385+
assert_eq!(<$T>::midpoint(<$T>::MIN, <$T>::MIN), <$T>::MIN);
386+
assert_eq!(<$T>::midpoint(<$T>::MAX, <$T>::MAX), <$T>::MAX);
387+
388+
assert_eq!(<$T>::midpoint(<$T>::MIN, 6), <$T>::MIN / 2 + 3);
389+
assert_eq!(<$T>::midpoint(6, <$T>::MIN), <$T>::MIN / 2 + 3);
390+
assert_eq!(<$T>::midpoint(<$T>::MAX, 6), <$T>::MAX / 2 + 3);
391+
assert_eq!(<$T>::midpoint(6, <$T>::MAX), <$T>::MAX / 2 + 3);
392+
}
367393
}
368394
};
369395
}

0 commit comments

Comments
 (0)