Skip to content

Commit 2795848

Browse files
committed
Convert fmaf to a generic implementation
Introduce a version of generic `fma` that works when there is a larger hardware-backed float type available to compute the result with more precision. This is currently used only for `f32`, but with some minor adjustments it should work for `f16` as well.
1 parent 57a21a1 commit 2795848

File tree

6 files changed

+129
-99
lines changed

6 files changed

+129
-99
lines changed

src/math/fmaf.rs

Lines changed: 2 additions & 94 deletions
Original file line numberDiff line numberDiff line change
@@ -1,103 +1,11 @@
1-
/* origin: FreeBSD /usr/src/lib/msun/src/s_fmaf.c */
2-
/*-
3-
* Copyright (c) 2005-2011 David Schultz <[email protected]>
4-
* All rights reserved.
5-
*
6-
* Redistribution and use in source and binary forms, with or without
7-
* modification, are permitted provided that the following conditions
8-
* are met:
9-
* 1. Redistributions of source code must retain the above copyright
10-
* notice, this list of conditions and the following disclaimer.
11-
* 2. Redistributions in binary form must reproduce the above copyright
12-
* notice, this list of conditions and the following disclaimer in the
13-
* documentation and/or other materials provided with the distribution.
14-
*
15-
* THIS SOFTWARE IS PROVIDED BY THE AUTHOR AND CONTRIBUTORS ``AS IS'' AND
16-
* ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
17-
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
18-
* ARE DISCLAIMED. IN NO EVENT SHALL THE AUTHOR OR CONTRIBUTORS BE LIABLE
19-
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
20-
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS
21-
* OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
22-
* HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
23-
* LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY
24-
* OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF
25-
* SUCH DAMAGE.
26-
*/
27-
28-
use core::f32;
29-
use core::ptr::read_volatile;
30-
31-
use super::fenv::{
32-
FE_INEXACT, FE_TONEAREST, FE_UNDERFLOW, feclearexcept, fegetround, feraiseexcept, fetestexcept,
33-
};
34-
35-
/*
36-
* Fused multiply-add: Compute x * y + z with a single rounding error.
37-
*
38-
* A double has more than twice as much precision than a float, so
39-
* direct double-precision arithmetic suffices, except where double
40-
* rounding occurs.
41-
*/
42-
431
/// Floating multiply add (f32)
442
///
453
/// Computes `(x*y)+z`, rounded as one ternary operation:
464
/// Computes the value (as if) to infinite precision and rounds once to the result format,
475
/// according to the rounding mode characterized by the value of FLT_ROUNDS.
486
#[cfg_attr(all(test, assert_no_panic), no_panic::no_panic)]
49-
pub fn fmaf(x: f32, y: f32, mut z: f32) -> f32 {
50-
let xy: f64;
51-
let mut result: f64;
52-
let mut ui: u64;
53-
let e: i32;
54-
55-
xy = x as f64 * y as f64;
56-
result = xy + z as f64;
57-
ui = result.to_bits();
58-
e = (ui >> 52) as i32 & 0x7ff;
59-
/* Common case: The double precision result is fine. */
60-
if (
61-
/* not a halfway case */
62-
ui & 0x1fffffff) != 0x10000000 ||
63-
/* NaN */
64-
e == 0x7ff ||
65-
/* exact */
66-
(result - xy == z as f64 && result - z as f64 == xy) ||
67-
/* not round-to-nearest */
68-
fegetround() != FE_TONEAREST
69-
{
70-
/*
71-
underflow may not be raised correctly, example:
72-
fmaf(0x1p-120f, 0x1p-120f, 0x1p-149f)
73-
*/
74-
if ((0x3ff - 149)..(0x3ff - 126)).contains(&e) && fetestexcept(FE_INEXACT) != 0 {
75-
feclearexcept(FE_INEXACT);
76-
// prevent `xy + vz` from being CSE'd with `xy + z` above
77-
let vz: f32 = unsafe { read_volatile(&z) };
78-
result = xy + vz as f64;
79-
if fetestexcept(FE_INEXACT) != 0 {
80-
feraiseexcept(FE_UNDERFLOW);
81-
} else {
82-
feraiseexcept(FE_INEXACT);
83-
}
84-
}
85-
z = result as f32;
86-
return z;
87-
}
88-
89-
/*
90-
* If result is inexact, and exactly halfway between two float values,
91-
* we need to adjust the low-order bit in the direction of the error.
92-
*/
93-
let neg = ui >> 63 != 0;
94-
let err = if neg == (z as f64 > xy) { xy - result + z as f64 } else { z as f64 - result + xy };
95-
if neg == (err < 0.0) {
96-
ui += 1;
97-
} else {
98-
ui -= 1;
99-
}
100-
f64::from_bits(ui) as f32
7+
pub fn fmaf(x: f32, y: f32, z: f32) -> f32 {
8+
super::generic::fma_wide(x, y, z)
1019
}
10210

10311
#[cfg(test)]

src/math/generic/fma.rs

Lines changed: 65 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,13 @@
11
/* SPDX-License-Identifier: MIT */
2-
/* origin: musl src/math/fma.c. Ported to generic Rust algorithm in 2025, TG. */
2+
/* origin: musl src/math/{fma,fmaf}.c. Ported to generic Rust algorithm in 2025, TG. */
33

44
use core::{f32, f64};
55

6+
use super::super::fenv::{
7+
FE_INEXACT, FE_TONEAREST, FE_UNDERFLOW, feclearexcept, fegetround, feraiseexcept, fetestexcept,
8+
};
69
use super::super::support::{DInt, HInt, IntTy};
7-
use super::super::{CastFrom, CastInto, Float, Int, MinInt};
10+
use super::super::{CastFrom, CastInto, DFloat, Float, HFloat, Int, MinInt};
811

912
/// Fused multiply-add that works when there is not a larger float size available. Currently this
1013
/// is still specialized only for `f64`. Computes `(x * y) + z`.
@@ -212,6 +215,66 @@ where
212215
super::scalbn(r, e)
213216
}
214217

218+
/// Fma implementation when a hardware-backed larger float type is available. For `f32` and `f64`,
219+
/// `f64` has enough precision to compute the `f32` result directly, except where double rounding occurs.
220+
pub fn fma_wide<F, B>(x: F, y: F, z: F) -> F
221+
where
222+
F: Float + HFloat<D = B>,
223+
B: Float + DFloat<H = F>,
224+
B::Int: CastInto<i32>,
225+
i32: CastFrom<i32>,
226+
{
227+
let one = IntTy::<B>::ONE;
228+
229+
let xy: B = x.widen() * y.widen();
230+
let mut result: B = xy + z.widen();
231+
let mut ui: B::Int = result.to_bits();
232+
let re = result.exp();
233+
let zb: B = z.widen();
234+
235+
let prec_diff = B::SIG_BITS - F::SIG_BITS;
236+
let excess_prec = ui & ((one << prec_diff) - one);
237+
let halfway = one << (prec_diff - 1);
238+
239+
// Common case: the larger precision is fine if...
240+
// This is not a halfway case
241+
if excess_prec != halfway
242+
// Or the result is NaN
243+
|| re == B::EXP_SAT
244+
// Or the result is exact
245+
|| (result - xy == zb && result - zb == xy)
246+
// Or the mode is something other than round to nearest
247+
|| fegetround() != FE_TONEAREST
248+
{
249+
let min_inexact_exp = (B::EXP_BIAS as i32 + F::EXP_MIN_SUBNORM) as u32;
250+
let max_inexact_exp = (B::EXP_BIAS as i32 + F::EXP_MIN) as u32;
251+
252+
if (min_inexact_exp..max_inexact_exp).contains(&re) && fetestexcept(FE_INEXACT) != 0 {
253+
feclearexcept(FE_INEXACT);
254+
// prevent `xy + vz` from being CSE'd with `xy + z` above
255+
let vz: F = force_eval!(z);
256+
result = xy + vz.widen();
257+
if fetestexcept(FE_INEXACT) != 0 {
258+
feraiseexcept(FE_UNDERFLOW);
259+
} else {
260+
feraiseexcept(FE_INEXACT);
261+
}
262+
}
263+
264+
return result.narrow();
265+
}
266+
267+
let neg = ui >> (B::BITS - 1) != IntTy::<B>::ZERO;
268+
let err = if neg == (zb > xy) { xy - result + zb } else { zb - result + xy };
269+
if neg == (err < B::ZERO) {
270+
ui += one;
271+
} else {
272+
ui -= one;
273+
}
274+
275+
B::from_bits(ui).narrow()
276+
}
277+
215278
/// Representation of `F` that has handled subnormals.
216279
#[derive(Clone, Copy, Debug)]
217280
struct Norm<F: Float> {

src/math/generic/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ pub use copysign::copysign;
1818
pub use fabs::fabs;
1919
pub use fdim::fdim;
2020
pub use floor::floor;
21-
pub use fma::fma;
21+
pub use fma::{fma, fma_wide};
2222
pub use fmax::fmax;
2323
pub use fmin::fmin;
2424
pub use fmod::fmod;

src/math/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ use self::rem_pio2::rem_pio2;
121121
use self::rem_pio2_large::rem_pio2_large;
122122
use self::rem_pio2f::rem_pio2f;
123123
#[allow(unused_imports)]
124-
use self::support::{CastFrom, CastInto, DInt, Float, HInt, Int, IntTy, MinInt};
124+
use self::support::{CastFrom, CastInto, DFloat, DInt, Float, HFloat, HInt, Int, IntTy, MinInt};
125125

126126
// Public modules
127127
mod acos;

src/math/support/float_traits.rs

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -276,6 +276,64 @@ pub const fn f64_from_bits(bits: u64) -> f64 {
276276
unsafe { mem::transmute::<u64, f64>(bits) }
277277
}
278278

279+
/// Trait for floats twice the bit width of another float.
280+
pub trait DFloat: Float {
281+
/// Float that is half the bit width of the float this trait is implemented for.
282+
type H: HFloat<D = Self>;
283+
284+
/// Narrow the float type.
285+
fn narrow(self) -> Self::H;
286+
}
287+
288+
/// Trait for floats half the bit width of another float.
289+
pub trait HFloat: Float {
290+
/// Float that is double the bit width of the float this trait is implemented for.
291+
type D: DFloat<H = Self>;
292+
293+
/// Widen the float type.
294+
fn widen(self) -> Self::D;
295+
}
296+
297+
macro_rules! impl_d_float {
298+
($($X:ident $D:ident),*) => {
299+
$(
300+
impl DFloat for $D {
301+
type H = $X;
302+
303+
fn narrow(self) -> Self::H {
304+
self as $X
305+
}
306+
}
307+
)*
308+
};
309+
}
310+
311+
macro_rules! impl_h_float {
312+
($($H:ident $X:ident),*) => {
313+
$(
314+
impl HFloat for $H {
315+
type D = $X;
316+
317+
fn widen(self) -> Self::D {
318+
self as $X
319+
}
320+
}
321+
)*
322+
};
323+
}
324+
325+
impl_d_float!(f32 f64);
326+
#[cfg(f16_enabled)]
327+
impl_d_float!(f16 f32);
328+
#[cfg(f128_enabled)]
329+
impl_d_float!(f64 f128);
330+
331+
impl_h_float!(f32 f64);
332+
#[cfg(f16_enabled)]
333+
impl_h_float!(f16 f32);
334+
#[cfg(f128_enabled)]
335+
impl_h_float!(f64 f128);
336+
279337
#[cfg(test)]
280338
mod tests {
281339
use super::*;

src/math/support/mod.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@ mod float_traits;
55
pub mod hex_float;
66
mod int_traits;
77

8-
pub use float_traits::{Float, IntTy};
8+
#[allow(unused_imports)]
9+
pub use float_traits::{DFloat, Float, HFloat, IntTy};
910
pub(crate) use float_traits::{f32_from_bits, f64_from_bits};
1011
#[cfg(f16_enabled)]
1112
#[allow(unused_imports)]

0 commit comments

Comments
 (0)