Skip to content

Commit bb8b11e

Browse files
committed
Auto merge of #120718 - saethlin:reasonable-fast-math, r=nnethercote
Add "algebraic" fast-math intrinsics, based on fast-math ops that cannot return poison. Setting all of LLVM's fast-math flags makes our fast-math intrinsics very dangerous, because some inputs are UB. The new "algebraic" set of flags still permits common algebraic transformations, and according to the [LangRef](https://llvm.org/docs/LangRef.html#fastmath), only the flags `nnan` (no nans) and `ninf` (no infs) can produce poison, neither of which is part of this set. This change also uses the algebraic float ops to fix #120720. cc `@orlp`
2 parents 7168c13 + cc73b71 commit bb8b11e

File tree

12 files changed

+226
-14
lines changed

12 files changed

+226
-14
lines changed

Diff for: compiler/rustc_codegen_cranelift/src/intrinsics/mod.rs

+15-6
Original file line numberDiff line numberDiff line change
@@ -1152,17 +1152,26 @@ fn codegen_regular_intrinsic_call<'tcx>(
11521152
ret.write_cvalue(fx, ret_val);
11531153
}
11541154

1155-
sym::fadd_fast | sym::fsub_fast | sym::fmul_fast | sym::fdiv_fast | sym::frem_fast => {
1155+
sym::fadd_fast
1156+
| sym::fsub_fast
1157+
| sym::fmul_fast
1158+
| sym::fdiv_fast
1159+
| sym::frem_fast
1160+
| sym::fadd_algebraic
1161+
| sym::fsub_algebraic
1162+
| sym::fmul_algebraic
1163+
| sym::fdiv_algebraic
1164+
| sym::frem_algebraic => {
11561165
intrinsic_args!(fx, args => (x, y); intrinsic);
11571166

11581167
let res = crate::num::codegen_float_binop(
11591168
fx,
11601169
match intrinsic {
1161-
sym::fadd_fast => BinOp::Add,
1162-
sym::fsub_fast => BinOp::Sub,
1163-
sym::fmul_fast => BinOp::Mul,
1164-
sym::fdiv_fast => BinOp::Div,
1165-
sym::frem_fast => BinOp::Rem,
1170+
sym::fadd_fast | sym::fadd_algebraic => BinOp::Add,
1171+
sym::fsub_fast | sym::fsub_algebraic => BinOp::Sub,
1172+
sym::fmul_fast | sym::fmul_algebraic => BinOp::Mul,
1173+
sym::fdiv_fast | sym::fdiv_algebraic => BinOp::Div,
1174+
sym::frem_fast | sym::frem_algebraic => BinOp::Rem,
11661175
_ => unreachable!(),
11671176
},
11681177
x,

Diff for: compiler/rustc_codegen_gcc/src/builder.rs

+25
Original file line numberDiff line numberDiff line change
@@ -705,6 +705,31 @@ impl<'a, 'gcc, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'gcc, 'tcx> {
705705
self.frem(lhs, rhs)
706706
}
707707

708+
fn fadd_algebraic(&mut self, lhs: RValue<'gcc>, rhs: RValue<'gcc>) -> RValue<'gcc> {
709+
// NOTE: it seems like we cannot enable fast-math for a single operation in GCC, so this is a plain addition.
710+
lhs + rhs
711+
}
712+
713+
fn fsub_algebraic(&mut self, lhs: RValue<'gcc>, rhs: RValue<'gcc>) -> RValue<'gcc> {
714+
// NOTE: it seems like we cannot enable fast-math for a single operation in GCC, so this is a plain subtraction.
715+
lhs - rhs
716+
}
717+
718+
fn fmul_algebraic(&mut self, lhs: RValue<'gcc>, rhs: RValue<'gcc>) -> RValue<'gcc> {
719+
// NOTE: it seems like we cannot enable fast-math for a single operation in GCC, so this is a plain multiplication.
720+
lhs * rhs
721+
}
722+
723+
fn fdiv_algebraic(&mut self, lhs: RValue<'gcc>, rhs: RValue<'gcc>) -> RValue<'gcc> {
724+
// NOTE: it seems like we cannot enable fast-math for a single operation in GCC, so this is a plain division.
725+
lhs / rhs
726+
}
727+
728+
fn frem_algebraic(&mut self, lhs: RValue<'gcc>, rhs: RValue<'gcc>) -> RValue<'gcc> {
729+
// NOTE: it seems like we cannot enable fast-math for a single operation in GCC, so this defers to the regular `frem`.
730+
self.frem(lhs, rhs)
731+
}
732+
708733
fn checked_binop(&mut self, oop: OverflowOp, typ: Ty<'_>, lhs: Self::Value, rhs: Self::Value) -> (Self::Value, Self::Value) {
709734
self.gcc_checked_binop(oop, typ, lhs, rhs)
710735
}

Diff for: compiler/rustc_codegen_llvm/src/builder.rs

+44-4
Original file line numberDiff line numberDiff line change
@@ -340,6 +340,46 @@ impl<'a, 'll, 'tcx> BuilderMethods<'a, 'tcx> for Builder<'a, 'll, 'tcx> {
340340
}
341341
}
342342

343+
fn fadd_algebraic(&mut self, lhs: &'ll Value, rhs: &'ll Value) -> &'ll Value {
344+
unsafe {
345+
let instr = llvm::LLVMBuildFAdd(self.llbuilder, lhs, rhs, UNNAMED);
346+
llvm::LLVMRustSetAlgebraicMath(instr);
347+
instr
348+
}
349+
}
350+
351+
fn fsub_algebraic(&mut self, lhs: &'ll Value, rhs: &'ll Value) -> &'ll Value {
352+
unsafe {
353+
let instr = llvm::LLVMBuildFSub(self.llbuilder, lhs, rhs, UNNAMED);
354+
llvm::LLVMRustSetAlgebraicMath(instr);
355+
instr
356+
}
357+
}
358+
359+
fn fmul_algebraic(&mut self, lhs: &'ll Value, rhs: &'ll Value) -> &'ll Value {
360+
unsafe {
361+
let instr = llvm::LLVMBuildFMul(self.llbuilder, lhs, rhs, UNNAMED);
362+
llvm::LLVMRustSetAlgebraicMath(instr);
363+
instr
364+
}
365+
}
366+
367+
fn fdiv_algebraic(&mut self, lhs: &'ll Value, rhs: &'ll Value) -> &'ll Value {
368+
unsafe {
369+
let instr = llvm::LLVMBuildFDiv(self.llbuilder, lhs, rhs, UNNAMED);
370+
llvm::LLVMRustSetAlgebraicMath(instr);
371+
instr
372+
}
373+
}
374+
375+
fn frem_algebraic(&mut self, lhs: &'ll Value, rhs: &'ll Value) -> &'ll Value {
376+
unsafe {
377+
let instr = llvm::LLVMBuildFRem(self.llbuilder, lhs, rhs, UNNAMED);
378+
llvm::LLVMRustSetAlgebraicMath(instr);
379+
instr
380+
}
381+
}
382+
343383
fn checked_binop(
344384
&mut self,
345385
oop: OverflowOp,
@@ -1327,17 +1367,17 @@ impl<'a, 'll, 'tcx> Builder<'a, 'll, 'tcx> {
13271367
pub fn vector_reduce_fmul(&mut self, acc: &'ll Value, src: &'ll Value) -> &'ll Value {
13281368
unsafe { llvm::LLVMRustBuildVectorReduceFMul(self.llbuilder, acc, src) }
13291369
}
1330-
pub fn vector_reduce_fadd_fast(&mut self, acc: &'ll Value, src: &'ll Value) -> &'ll Value {
1370+
pub fn vector_reduce_fadd_algebraic(&mut self, acc: &'ll Value, src: &'ll Value) -> &'ll Value {
13311371
unsafe {
13321372
let instr = llvm::LLVMRustBuildVectorReduceFAdd(self.llbuilder, acc, src);
1333-
llvm::LLVMRustSetFastMath(instr);
1373+
llvm::LLVMRustSetAlgebraicMath(instr);
13341374
instr
13351375
}
13361376
}
1337-
pub fn vector_reduce_fmul_fast(&mut self, acc: &'ll Value, src: &'ll Value) -> &'ll Value {
1377+
pub fn vector_reduce_fmul_algebraic(&mut self, acc: &'ll Value, src: &'ll Value) -> &'ll Value {
13381378
unsafe {
13391379
let instr = llvm::LLVMRustBuildVectorReduceFMul(self.llbuilder, acc, src);
1340-
llvm::LLVMRustSetFastMath(instr);
1380+
llvm::LLVMRustSetAlgebraicMath(instr);
13411381
instr
13421382
}
13431383
}

Diff for: compiler/rustc_codegen_llvm/src/intrinsic.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -1880,14 +1880,14 @@ fn generic_simd_intrinsic<'ll, 'tcx>(
18801880
arith_red!(simd_reduce_mul_ordered: vector_reduce_mul, vector_reduce_fmul, true, mul, 1.0);
18811881
arith_red!(
18821882
simd_reduce_add_unordered: vector_reduce_add,
1883-
vector_reduce_fadd_fast,
1883+
vector_reduce_fadd_algebraic,
18841884
false,
18851885
add,
18861886
0.0
18871887
);
18881888
arith_red!(
18891889
simd_reduce_mul_unordered: vector_reduce_mul,
1890-
vector_reduce_fmul_fast,
1890+
vector_reduce_fmul_algebraic,
18911891
false,
18921892
mul,
18931893
1.0

Diff for: compiler/rustc_codegen_llvm/src/llvm/ffi.rs

+1
Original file line numberDiff line numberDiff line change
@@ -1618,6 +1618,7 @@ extern "C" {
16181618
) -> &'a Value;
16191619

16201620
pub fn LLVMRustSetFastMath(Instr: &Value);
1621+
pub fn LLVMRustSetAlgebraicMath(Instr: &Value);
16211622

16221623
// Miscellaneous instructions
16231624
pub fn LLVMRustGetInstrProfIncrementIntrinsic(M: &Module) -> &Value;

Diff for: compiler/rustc_codegen_ssa/src/mir/intrinsic.rs

+32
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,38 @@ impl<'a, 'tcx, Bx: BuilderMethods<'a, 'tcx>> FunctionCx<'a, 'tcx, Bx> {
250250
}
251251
}
252252
}
253+
sym::fadd_algebraic
254+
| sym::fsub_algebraic
255+
| sym::fmul_algebraic
256+
| sym::fdiv_algebraic
257+
| sym::frem_algebraic => match float_type_width(arg_tys[0]) {
258+
Some(_width) => match name {
259+
sym::fadd_algebraic => {
260+
bx.fadd_algebraic(args[0].immediate(), args[1].immediate())
261+
}
262+
sym::fsub_algebraic => {
263+
bx.fsub_algebraic(args[0].immediate(), args[1].immediate())
264+
}
265+
sym::fmul_algebraic => {
266+
bx.fmul_algebraic(args[0].immediate(), args[1].immediate())
267+
}
268+
sym::fdiv_algebraic => {
269+
bx.fdiv_algebraic(args[0].immediate(), args[1].immediate())
270+
}
271+
sym::frem_algebraic => {
272+
bx.frem_algebraic(args[0].immediate(), args[1].immediate())
273+
}
274+
_ => bug!(),
275+
},
276+
None => {
277+
bx.tcx().dcx().emit_err(InvalidMonomorphization::BasicFloatType {
278+
span,
279+
name,
280+
ty: arg_tys[0],
281+
});
282+
return Ok(());
283+
}
284+
},
253285

254286
sym::float_to_int_unchecked => {
255287
if float_type_width(arg_tys[0]).is_none() {

Diff for: compiler/rustc_codegen_ssa/src/traits/builder.rs

+5
Original file line numberDiff line numberDiff line change
@@ -86,22 +86,27 @@ pub trait BuilderMethods<'a, 'tcx>:
8686
fn add(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
8787
fn fadd(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
8888
fn fadd_fast(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
89+
fn fadd_algebraic(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
8990
fn sub(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
9091
fn fsub(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
9192
fn fsub_fast(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
93+
fn fsub_algebraic(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
9294
fn mul(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
9395
fn fmul(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
9496
fn fmul_fast(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
97+
fn fmul_algebraic(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
9598
fn udiv(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
9699
fn exactudiv(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
97100
fn sdiv(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
98101
fn exactsdiv(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
99102
fn fdiv(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
100103
fn fdiv_fast(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
104+
fn fdiv_algebraic(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
101105
fn urem(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
102106
fn srem(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
103107
fn frem(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
104108
fn frem_fast(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
109+
fn frem_algebraic(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
105110
fn shl(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
106111
fn lshr(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;
107112
fn ashr(&mut self, lhs: Self::Value, rhs: Self::Value) -> Self::Value;

Diff for: compiler/rustc_hir_analysis/src/check/intrinsic.rs

+11-1
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,12 @@ pub fn intrinsic_operation_unsafety(tcx: TyCtxt<'_>, intrinsic_id: LocalDefId) -
123123
| sym::variant_count
124124
| sym::is_val_statically_known
125125
| sym::ptr_mask
126-
| sym::debug_assertions => hir::Unsafety::Normal,
126+
| sym::debug_assertions
127+
| sym::fadd_algebraic
128+
| sym::fsub_algebraic
129+
| sym::fmul_algebraic
130+
| sym::fdiv_algebraic
131+
| sym::frem_algebraic => hir::Unsafety::Normal,
127132
_ => hir::Unsafety::Unsafe,
128133
};
129134

@@ -405,6 +410,11 @@ pub fn check_intrinsic_type(
405410
sym::fadd_fast | sym::fsub_fast | sym::fmul_fast | sym::fdiv_fast | sym::frem_fast => {
406411
(1, 0, vec![param(0), param(0)], param(0))
407412
}
413+
sym::fadd_algebraic
414+
| sym::fsub_algebraic
415+
| sym::fmul_algebraic
416+
| sym::fdiv_algebraic
417+
| sym::frem_algebraic => (1, 0, vec![param(0), param(0)], param(0)),
408418
sym::float_to_int_unchecked => (2, 0, vec![param(0)], param(1)),
409419

410420
sym::assume => (0, 1, vec![tcx.types.bool], Ty::new_unit(tcx)),

Diff for: compiler/rustc_llvm/llvm-wrapper/RustWrapper.cpp

+24-1
Original file line numberDiff line numberDiff line change
@@ -418,7 +418,11 @@ extern "C" LLVMAttributeRef LLVMRustCreateMemoryEffectsAttr(LLVMContextRef C,
418418
}
419419
}
420420

421-
// Enable a fast-math flag
421+
// Enable all fast-math flags, including those which will cause floating-point operations
422+
// to return poison for some well-defined inputs. This function can only be used to build
423+
// unsafe Rust intrinsics. That unsafety does permit additional optimizations, but at the
424+
// time of writing, their value is not well-understood relative to those enabled by
425+
// LLVMRustSetAlgebraicMath.
422426
//
423427
// https://llvm.org/docs/LangRef.html#fast-math-flags
424428
extern "C" void LLVMRustSetFastMath(LLVMValueRef V) {
@@ -427,6 +431,25 @@ extern "C" void LLVMRustSetFastMath(LLVMValueRef V) {
427431
}
428432
}
429433

434+
// Enable fast-math flags which permit algebraic transformations that are not allowed by
435+
// IEEE floating point. For example:
436+
// a + (b + c) = (a + b) + c
437+
// and
438+
// a / b = a * (1 / b)
439+
// Note that this does NOT enable any flags which can cause a floating-point operation on
440+
// well-defined inputs to return poison, and therefore this function can be used to build
441+
// safe Rust intrinsics (such as fadd_algebraic).
442+
//
443+
// https://llvm.org/docs/LangRef.html#fast-math-flags
444+
extern "C" void LLVMRustSetAlgebraicMath(LLVMValueRef V) {
445+
if (auto I = dyn_cast<Instruction>(unwrap<Value>(V))) {
446+
I->setHasAllowReassoc(true);
447+
I->setHasAllowContract(true);
448+
I->setHasAllowReciprocal(true);
449+
I->setHasNoSignedZeros(true);
450+
}
451+
}
452+
430453
extern "C" LLVMValueRef
431454
LLVMRustBuildAtomicLoad(LLVMBuilderRef B, LLVMTypeRef Ty, LLVMValueRef Source,
432455
const char *Name, LLVMAtomicOrdering Order) {

Diff for: compiler/rustc_span/src/symbol.rs

+5
Original file line numberDiff line numberDiff line change
@@ -764,8 +764,10 @@ symbols! {
764764
f64_nan,
765765
fabsf32,
766766
fabsf64,
767+
fadd_algebraic,
767768
fadd_fast,
768769
fake_variadic,
770+
fdiv_algebraic,
769771
fdiv_fast,
770772
feature,
771773
fence,
@@ -785,6 +787,7 @@ symbols! {
785787
fmaf32,
786788
fmaf64,
787789
fmt,
790+
fmul_algebraic,
788791
fmul_fast,
789792
fn_align,
790793
fn_delegation,
@@ -810,6 +813,7 @@ symbols! {
810813
format_unsafe_arg,
811814
freeze,
812815
freg,
816+
frem_algebraic,
813817
frem_fast,
814818
from,
815819
from_desugaring,
@@ -823,6 +827,7 @@ symbols! {
823827
from_usize,
824828
from_yeet,
825829
fs_create_dir,
830+
fsub_algebraic,
826831
fsub_fast,
827832
fundamental,
828833
future,

Diff for: library/core/src/intrinsics.rs

+40
Original file line numberDiff line numberDiff line change
@@ -1898,6 +1898,46 @@ extern "rust-intrinsic" {
18981898
#[rustc_nounwind]
18991899
pub fn frem_fast<T: Copy>(a: T, b: T) -> T;
19001900

1901+
/// Float addition that allows optimizations based on algebraic rules.
1902+
///
1903+
/// This intrinsic does not have a stable counterpart.
1904+
#[rustc_nounwind]
1905+
#[rustc_safe_intrinsic]
1906+
#[cfg(not(bootstrap))]
1907+
pub fn fadd_algebraic<T: Copy>(a: T, b: T) -> T;
1908+
1909+
/// Float subtraction that allows optimizations based on algebraic rules.
1910+
///
1911+
/// This intrinsic does not have a stable counterpart.
1912+
#[rustc_nounwind]
1913+
#[rustc_safe_intrinsic]
1914+
#[cfg(not(bootstrap))]
1915+
pub fn fsub_algebraic<T: Copy>(a: T, b: T) -> T;
1916+
1917+
/// Float multiplication that allows optimizations based on algebraic rules.
1918+
///
1919+
/// This intrinsic does not have a stable counterpart.
1920+
#[rustc_nounwind]
1921+
#[rustc_safe_intrinsic]
1922+
#[cfg(not(bootstrap))]
1923+
pub fn fmul_algebraic<T: Copy>(a: T, b: T) -> T;
1924+
1925+
/// Float division that allows optimizations based on algebraic rules.
1926+
///
1927+
/// This intrinsic does not have a stable counterpart.
1928+
#[rustc_nounwind]
1929+
#[rustc_safe_intrinsic]
1930+
#[cfg(not(bootstrap))]
1931+
pub fn fdiv_algebraic<T: Copy>(a: T, b: T) -> T;
1932+
1933+
/// Float remainder that allows optimizations based on algebraic rules.
1934+
///
1935+
/// This intrinsic does not have a stable counterpart.
1936+
#[rustc_nounwind]
1937+
#[rustc_safe_intrinsic]
1938+
#[cfg(not(bootstrap))]
1939+
pub fn frem_algebraic<T: Copy>(a: T, b: T) -> T;
1940+
19011941
/// Convert with LLVM’s fptoui/fptosi, which may return undef for values out of range
19021942
/// (<https://github.com/rust-lang/rust/issues/10184>)
19031943
///

Diff for: tests/codegen/simd/issue-120720-reduce-nan.rs

+22
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
// compile-flags: -C opt-level=3 -C target-cpu=cannonlake
2+
// only-x86_64
3+
4+
// In a previous implementation, _mm512_reduce_add_pd did the reduction with all fast-math flags
5+
// enabled, making it UB to reduce a vector containing a NaN.
6+
7+
#![crate_type = "lib"]
8+
#![feature(stdarch_x86_avx512, avx512_target_feature)]
9+
use std::arch::x86_64::*;
10+
11+
// CHECK-LABEL: @demo(
12+
#[no_mangle]
13+
#[target_feature(enable = "avx512f")] // Function-level target feature mismatches inhibit inlining
14+
pub unsafe fn demo() -> bool {
15+
// CHECK: %0 = tail call reassoc nsz arcp contract double @llvm.vector.reduce.fadd.v8f64(
16+
// CHECK: %_0.i = fcmp uno double %0, 0.000000e+00
17+
// CHECK: ret i1 %_0.i
18+
let res = unsafe {
19+
_mm512_reduce_add_pd(_mm512_set_pd(0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, f64::NAN))
20+
};
21+
res.is_nan()
22+
}

0 commit comments

Comments
 (0)