Skip to content
This repository was archived by the owner on May 28, 2025. It is now read-only.

Commit f633537

Browse files
committed
Auto merge of rust-lang#2469 - RalfJung:math, r=RalfJung
implement some missing float functions With this we support the entire float API surface of the standard library. :) Also fixes rust-lang/miri#2468 by using host floats to implement FMA.
2 parents d6b750e + b1316ec commit f633537

File tree

5 files changed

+127
-51
lines changed

5 files changed

+127
-51
lines changed

src/shims/foreign_items.rs

Lines changed: 28 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -575,88 +575,106 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
575575
this.write_scalar(Scalar::from_machine_usize(u64::try_from(n).unwrap(), this), dest)?;
576576
}
577577

578-
// math functions
578+
// math functions (note that there are also intrinsics for some other functions)
579579
#[rustfmt::skip]
580580
| "cbrtf"
581581
| "coshf"
582582
| "sinhf"
583583
| "tanf"
584+
| "tanhf"
584585
| "acosf"
585586
| "asinf"
586587
| "atanf"
588+
| "log1pf"
589+
| "expm1f"
587590
=> {
588591
let [f] = this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
589592
// FIXME: Using host floats.
590593
let f = f32::from_bits(this.read_scalar(f)?.to_u32()?);
591-
let f = match link_name.as_str() {
594+
let res = match link_name.as_str() {
592595
"cbrtf" => f.cbrt(),
593596
"coshf" => f.cosh(),
594597
"sinhf" => f.sinh(),
595598
"tanf" => f.tan(),
599+
"tanhf" => f.tanh(),
596600
"acosf" => f.acos(),
597601
"asinf" => f.asin(),
598602
"atanf" => f.atan(),
603+
"log1pf" => f.ln_1p(),
604+
"expm1f" => f.exp_m1(),
599605
_ => bug!(),
600606
};
601-
this.write_scalar(Scalar::from_u32(f.to_bits()), dest)?;
607+
this.write_scalar(Scalar::from_u32(res.to_bits()), dest)?;
602608
}
603609
#[rustfmt::skip]
604610
| "_hypotf"
605611
| "hypotf"
606612
| "atan2f"
613+
| "fdimf"
607614
=> {
608615
let [f1, f2] = this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
609616
// underscore case for windows, here and below
610617
// (see https://docs.microsoft.com/en-us/cpp/c-runtime-library/reference/floating-point-primitives?view=vs-2019)
611618
// FIXME: Using host floats.
612619
let f1 = f32::from_bits(this.read_scalar(f1)?.to_u32()?);
613620
let f2 = f32::from_bits(this.read_scalar(f2)?.to_u32()?);
614-
let n = match link_name.as_str() {
621+
let res = match link_name.as_str() {
615622
"_hypotf" | "hypotf" => f1.hypot(f2),
616623
"atan2f" => f1.atan2(f2),
624+
#[allow(deprecated)]
625+
"fdimf" => f1.abs_sub(f2),
617626
_ => bug!(),
618627
};
619-
this.write_scalar(Scalar::from_u32(n.to_bits()), dest)?;
628+
this.write_scalar(Scalar::from_u32(res.to_bits()), dest)?;
620629
}
621630
#[rustfmt::skip]
622631
| "cbrt"
623632
| "cosh"
624633
| "sinh"
625634
| "tan"
635+
| "tanh"
626636
| "acos"
627637
| "asin"
628638
| "atan"
639+
| "log1p"
640+
| "expm1"
629641
=> {
630642
let [f] = this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
631643
// FIXME: Using host floats.
632644
let f = f64::from_bits(this.read_scalar(f)?.to_u64()?);
633-
let f = match link_name.as_str() {
645+
let res = match link_name.as_str() {
634646
"cbrt" => f.cbrt(),
635647
"cosh" => f.cosh(),
636648
"sinh" => f.sinh(),
637649
"tan" => f.tan(),
650+
"tanh" => f.tanh(),
638651
"acos" => f.acos(),
639652
"asin" => f.asin(),
640653
"atan" => f.atan(),
654+
"log1p" => f.ln_1p(),
655+
"expm1" => f.exp_m1(),
641656
_ => bug!(),
642657
};
643-
this.write_scalar(Scalar::from_u64(f.to_bits()), dest)?;
658+
this.write_scalar(Scalar::from_u64(res.to_bits()), dest)?;
644659
}
645660
#[rustfmt::skip]
646661
| "_hypot"
647662
| "hypot"
648663
| "atan2"
664+
| "fdim"
649665
=> {
650666
let [f1, f2] = this.check_shim(abi, Abi::C { unwind: false }, link_name, args)?;
651667
// FIXME: Using host floats.
652668
let f1 = f64::from_bits(this.read_scalar(f1)?.to_u64()?);
653669
let f2 = f64::from_bits(this.read_scalar(f2)?.to_u64()?);
654-
let n = match link_name.as_str() {
670+
let res = match link_name.as_str() {
655671
"_hypot" | "hypot" => f1.hypot(f2),
656672
"atan2" => f1.atan2(f2),
673+
#[allow(deprecated)]
674+
"fdim" => f1.abs_sub(f2),
657675
_ => bug!(),
658676
};
659-
this.write_scalar(Scalar::from_u64(n.to_bits()), dest)?;
677+
this.write_scalar(Scalar::from_u64(res.to_bits()), dest)?;
660678
}
661679
#[rustfmt::skip]
662680
| "_ldexp"
@@ -668,7 +686,7 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
668686
let x = this.read_scalar(x)?.to_f64()?;
669687
let exp = this.read_scalar(exp)?.to_i32()?;
670688

671-
// Saturating cast to i16. Even those are outside the valid exponent range to
689+
// Saturating cast to i16. Even those are outside the valid exponent range so
672690
// `scalbn` below will do its over/underflow handling.
673691
let exp = if exp > i32::from(i16::MAX) {
674692
i16::MAX

src/shims/intrinsics/mod.rs

Lines changed: 20 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -285,49 +285,55 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
285285
// FIXME: Using host floats.
286286
let f = f32::from_bits(this.read_scalar(f)?.to_u32()?);
287287
let f2 = f32::from_bits(this.read_scalar(f2)?.to_u32()?);
288-
this.write_scalar(Scalar::from_u32(f.powf(f2).to_bits()), dest)?;
288+
let res = f.powf(f2);
289+
this.write_scalar(Scalar::from_u32(res.to_bits()), dest)?;
289290
}
290291

291292
"powf64" => {
292293
let [f, f2] = check_arg_count(args)?;
293294
// FIXME: Using host floats.
294295
let f = f64::from_bits(this.read_scalar(f)?.to_u64()?);
295296
let f2 = f64::from_bits(this.read_scalar(f2)?.to_u64()?);
296-
this.write_scalar(Scalar::from_u64(f.powf(f2).to_bits()), dest)?;
297+
let res = f.powf(f2);
298+
this.write_scalar(Scalar::from_u64(res.to_bits()), dest)?;
297299
}
298300

299301
"fmaf32" => {
300302
let [a, b, c] = check_arg_count(args)?;
301-
let a = this.read_scalar(a)?.to_f32()?;
302-
let b = this.read_scalar(b)?.to_f32()?;
303-
let c = this.read_scalar(c)?.to_f32()?;
304-
let res = a.mul_add(b, c).value;
305-
this.write_scalar(Scalar::from_f32(res), dest)?;
303+
// FIXME: Using host floats, to work around https://github.com/rust-lang/miri/issues/2468.
304+
let a = f32::from_bits(this.read_scalar(a)?.to_u32()?);
305+
let b = f32::from_bits(this.read_scalar(b)?.to_u32()?);
306+
let c = f32::from_bits(this.read_scalar(c)?.to_u32()?);
307+
let res = a.mul_add(b, c);
308+
this.write_scalar(Scalar::from_u32(res.to_bits()), dest)?;
306309
}
307310

308311
"fmaf64" => {
309312
let [a, b, c] = check_arg_count(args)?;
310-
let a = this.read_scalar(a)?.to_f64()?;
311-
let b = this.read_scalar(b)?.to_f64()?;
312-
let c = this.read_scalar(c)?.to_f64()?;
313-
let res = a.mul_add(b, c).value;
314-
this.write_scalar(Scalar::from_f64(res), dest)?;
313+
// FIXME: Using host floats, to work around https://github.com/rust-lang/miri/issues/2468.
314+
let a = f64::from_bits(this.read_scalar(a)?.to_u64()?);
315+
let b = f64::from_bits(this.read_scalar(b)?.to_u64()?);
316+
let c = f64::from_bits(this.read_scalar(c)?.to_u64()?);
317+
let res = a.mul_add(b, c);
318+
this.write_scalar(Scalar::from_u64(res.to_bits()), dest)?;
315319
}
316320

317321
"powif32" => {
318322
let [f, i] = check_arg_count(args)?;
319323
// FIXME: Using host floats.
320324
let f = f32::from_bits(this.read_scalar(f)?.to_u32()?);
321325
let i = this.read_scalar(i)?.to_i32()?;
322-
this.write_scalar(Scalar::from_u32(f.powi(i).to_bits()), dest)?;
326+
let res = f.powi(i);
327+
this.write_scalar(Scalar::from_u32(res.to_bits()), dest)?;
323328
}
324329

325330
"powif64" => {
326331
let [f, i] = check_arg_count(args)?;
327332
// FIXME: Using host floats.
328333
let f = f64::from_bits(this.read_scalar(f)?.to_u64()?);
329334
let i = this.read_scalar(i)?.to_i32()?;
330-
this.write_scalar(Scalar::from_u64(f.powi(i).to_bits()), dest)?;
335+
let res = f.powi(i);
336+
this.write_scalar(Scalar::from_u64(res.to_bits()), dest)?;
331337
}
332338

333339
"float_to_int_unchecked" => {

src/shims/intrinsics/simd.rs

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -238,14 +238,25 @@ pub trait EvalContextExt<'mir, 'tcx: 'mir>: crate::MiriEvalContextExt<'mir, 'tcx
238238
let dest = this.mplace_index(&dest, i)?;
239239

240240
// Works for f32 and f64.
241+
// FIXME: using host floats to work around https://github.com/rust-lang/miri/issues/2468.
241242
let ty::Float(float_ty) = dest.layout.ty.kind() else {
242243
span_bug!(this.cur_span(), "{} operand is not a float", intrinsic_name)
243244
};
244245
let val = match float_ty {
245-
FloatTy::F32 =>
246-
Scalar::from_f32(a.to_f32()?.mul_add(b.to_f32()?, c.to_f32()?).value),
247-
FloatTy::F64 =>
248-
Scalar::from_f64(a.to_f64()?.mul_add(b.to_f64()?, c.to_f64()?).value),
246+
FloatTy::F32 => {
247+
let a = f32::from_bits(a.to_u32()?);
248+
let b = f32::from_bits(b.to_u32()?);
249+
let c = f32::from_bits(c.to_u32()?);
250+
let res = a.mul_add(b, c);
251+
Scalar::from_u32(res.to_bits())
252+
}
253+
FloatTy::F64 => {
254+
let a = f64::from_bits(a.to_u64()?);
255+
let b = f64::from_bits(b.to_u64()?);
256+
let c = f64::from_bits(c.to_u64()?);
257+
let res = a.mul_add(b, c);
258+
Scalar::from_u64(res.to_bits())
259+
}
249260
};
250261
this.write_scalar(val, &dest.into())?;
251262
}

tests/pass/intrinsics-math.rs

Lines changed: 54 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -32,24 +32,24 @@ pub fn main() {
3232
assert_approx_eq!(25f32.powi(-2), 0.0016f32);
3333
assert_approx_eq!(23.2f64.powi(2), 538.24f64);
3434

35-
assert_approx_eq!(0f32.sin(), 0f32);
36-
assert_approx_eq!((f64::consts::PI / 2f64).sin(), 1f64);
37-
38-
assert_approx_eq!(0f32.cos(), 1f32);
39-
assert_approx_eq!((f64::consts::PI * 2f64).cos(), 1f64);
40-
4135
assert_approx_eq!(25f32.powf(-2f32), 0.0016f32);
4236
assert_approx_eq!(400f64.powf(0.5f64), 20f64);
4337

44-
assert_approx_eq!((1f32.exp() - f32::consts::E).abs(), 0f32);
38+
assert_approx_eq!(1f32.exp(), f32::consts::E);
4539
assert_approx_eq!(1f64.exp(), f64::consts::E);
4640

41+
assert_approx_eq!(1f32.exp_m1(), f32::consts::E - 1.0);
42+
assert_approx_eq!(1f64.exp_m1(), f64::consts::E - 1.0);
43+
4744
assert_approx_eq!(10f32.exp2(), 1024f32);
4845
assert_approx_eq!(50f64.exp2(), 1125899906842624f64);
4946

50-
assert_approx_eq!((f32::consts::E.ln() - 1f32).abs(), 0f32);
47+
assert_approx_eq!(f32::consts::E.ln(), 1f32);
5148
assert_approx_eq!(1f64.ln(), 0f64);
5249

50+
assert_approx_eq!(0f32.ln_1p(), 0f32);
51+
assert_approx_eq!(0f64.ln_1p(), 0f64);
52+
5353
assert_approx_eq!(10f32.log10(), 1f32);
5454
assert_approx_eq!(f64::consts::E.log10(), f64::consts::LOG10_E);
5555

@@ -60,10 +60,18 @@ pub fn main() {
6060
assert_eq!(0.0f32.mul_add(-2.0, f32::consts::E), f32::consts::E);
6161
assert_approx_eq!(3.0f64.mul_add(2.0, 5.0), 11.0);
6262
assert_eq!(0.0f64.mul_add(-2.0f64, f64::consts::E), f64::consts::E);
63+
assert_eq!((-3.2f32).mul_add(2.4, f32::NEG_INFINITY), f32::NEG_INFINITY);
64+
assert_eq!((-3.2f64).mul_add(2.4, f64::NEG_INFINITY), f64::NEG_INFINITY);
6365

6466
assert_approx_eq!((-1.0f32).abs(), 1.0f32);
6567
assert_approx_eq!(34.2f64.abs(), 34.2f64);
6668

69+
#[allow(deprecated)]
70+
{
71+
assert_approx_eq!(5.0f32.abs_sub(3.0), 2.0);
72+
assert_approx_eq!(3.0f64.abs_sub(5.0), 0.0);
73+
}
74+
6775
assert_approx_eq!(3.8f32.floor(), 3.0f32);
6876
assert_approx_eq!((-1.1f64).floor(), -2.0f64);
6977

@@ -79,31 +87,54 @@ pub fn main() {
7987
assert_approx_eq!(3.0f32.hypot(4.0f32), 5.0f32);
8088
assert_approx_eq!(3.0f64.hypot(4.0f64), 5.0f64);
8189

82-
assert_approx_eq!(1.0f32.atan2(2.0f32), 0.46364761f32);
83-
assert_approx_eq!(1.0f32.atan2(2.0f32), 0.46364761f32);
90+
assert_eq!(3.3_f32.round(), 3.0);
91+
assert_eq!(3.3_f64.round(), 3.0);
8492

85-
assert_approx_eq!(1.0f32.cosh(), 1.54308f32);
86-
assert_approx_eq!(1.0f64.cosh(), 1.54308f64);
93+
assert_eq!(ldexp(0.65f64, 3i32), 5.2f64);
94+
assert_eq!(ldexp(1.42, 0xFFFF), f64::INFINITY);
95+
assert_eq!(ldexp(1.42, -0xFFFF), 0f64);
96+
97+
// Trigonometric functions.
98+
99+
assert_approx_eq!(0f32.sin(), 0f32);
100+
assert_approx_eq!((f64::consts::PI / 2f64).sin(), 1f64);
101+
assert_approx_eq!(f32::consts::FRAC_PI_6.sin(), 0.5);
102+
assert_approx_eq!(f64::consts::FRAC_PI_6.sin(), 0.5);
103+
assert_approx_eq!(f32::consts::FRAC_PI_4.sin().asin(), f32::consts::FRAC_PI_4);
104+
assert_approx_eq!(f64::consts::FRAC_PI_4.sin().asin(), f64::consts::FRAC_PI_4);
87105

88106
assert_approx_eq!(1.0f32.sinh(), 1.1752012f32);
89107
assert_approx_eq!(1.0f64.sinh(), 1.1752012f64);
108+
assert_approx_eq!(2.0f32.asinh(), 1.443635475178810342493276740273105f32);
109+
assert_approx_eq!((-2.0f64).asinh(), -1.443635475178810342493276740273105f64);
90110

91-
assert_approx_eq!(1.0f32.tan(), 1.557408f32);
92-
assert_approx_eq!(1.0f64.tan(), 1.557408f64);
93-
111+
assert_approx_eq!(0f32.cos(), 1f32);
112+
assert_approx_eq!((f64::consts::PI * 2f64).cos(), 1f64);
113+
assert_approx_eq!(f32::consts::FRAC_PI_3.cos(), 0.5);
114+
assert_approx_eq!(f64::consts::FRAC_PI_3.cos(), 0.5);
94115
assert_approx_eq!(f32::consts::FRAC_PI_4.cos().acos(), f32::consts::FRAC_PI_4);
95116
assert_approx_eq!(f64::consts::FRAC_PI_4.cos().acos(), f64::consts::FRAC_PI_4);
96117

97-
assert_approx_eq!(f32::consts::FRAC_PI_4.sin().asin(), f32::consts::FRAC_PI_4);
98-
assert_approx_eq!(f64::consts::FRAC_PI_4.sin().asin(), f64::consts::FRAC_PI_4);
118+
assert_approx_eq!(1.0f32.cosh(), 1.54308f32);
119+
assert_approx_eq!(1.0f64.cosh(), 1.54308f64);
120+
assert_approx_eq!(2.0f32.acosh(), 1.31695789692481670862504634730796844f32);
121+
assert_approx_eq!(3.0f64.acosh(), 1.76274717403908605046521864995958461f64);
99122

123+
assert_approx_eq!(1.0f32.tan(), 1.557408f32);
124+
assert_approx_eq!(1.0f64.tan(), 1.557408f64);
100125
assert_approx_eq!(1.0_f32, 1.0_f32.tan().atan());
101126
assert_approx_eq!(1.0_f64, 1.0_f64.tan().atan());
127+
assert_approx_eq!(1.0f32.atan2(2.0f32), 0.46364761f32);
128+
assert_approx_eq!(1.0f32.atan2(2.0f32), 0.46364761f32);
102129

103-
assert_eq!(3.3_f32.round(), 3.0);
104-
assert_eq!(3.3_f64.round(), 3.0);
105-
106-
assert_eq!(ldexp(0.65f64, 3i32), 5.2f64);
107-
assert_eq!(ldexp(1.42, 0xFFFF), f64::INFINITY);
108-
assert_eq!(ldexp(1.42, -0xFFFF), 0f64);
130+
assert_approx_eq!(
131+
1.0f32.tanh(),
132+
(1.0 - f32::consts::E.powi(-2)) / (1.0 + f32::consts::E.powi(-2))
133+
);
134+
assert_approx_eq!(
135+
1.0f64.tanh(),
136+
(1.0 - f64::consts::E.powi(-2)) / (1.0 + f64::consts::E.powi(-2))
137+
);
138+
assert_approx_eq!(0.5f32.atanh(), 0.54930614433405484569762261846126285f32);
139+
assert_approx_eq!(0.5f64.atanh(), 0.54930614433405484569762261846126285f64);
109140
}

tests/pass/portable-simd.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,11 @@ fn simd_ops_f32() {
1818

1919
assert_eq!(a.mul_add(b, a), (a * b) + a);
2020
assert_eq!(b.mul_add(b, a), (b * b) + a);
21+
assert_eq!(a.mul_add(b, b), (a * b) + b);
22+
assert_eq!(
23+
f32x4::splat(-3.2).mul_add(b, f32x4::splat(f32::NEG_INFINITY)),
24+
f32x4::splat(f32::NEG_INFINITY)
25+
);
2126
assert_eq!((a * a).sqrt(), a);
2227
assert_eq!((b * b).sqrt(), b.abs());
2328

@@ -67,6 +72,11 @@ fn simd_ops_f64() {
6772

6873
assert_eq!(a.mul_add(b, a), (a * b) + a);
6974
assert_eq!(b.mul_add(b, a), (b * b) + a);
75+
assert_eq!(a.mul_add(b, b), (a * b) + b);
76+
assert_eq!(
77+
f64x4::splat(-3.2).mul_add(b, f64x4::splat(f64::NEG_INFINITY)),
78+
f64x4::splat(f64::NEG_INFINITY)
79+
);
7080
assert_eq!((a * a).sqrt(), a);
7181
assert_eq!((b * b).sqrt(), b.abs());
7282

0 commit comments

Comments
 (0)