Skip to content

Commit 709716b

Browse files
authored
cranelift: Implement scalar FMA on x86 (#4460)
x86 does not have dedicated instructions for scalar FMA, so lower it to a libcall, which seems to be what LLVM does.
1 parent ff6082c commit 709716b

File tree

13 files changed

+167
-50
lines changed

13 files changed

+167
-50
lines changed

cranelift/codegen/src/ir/libcall.rs

+56-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
//! Naming well-known routines in the runtime library.
22
3-
use crate::ir::{types, ExternalName, FuncRef, Function, Opcode, Type};
3+
use crate::ir::{types, AbiParam, ExternalName, FuncRef, Function, Opcode, Signature, Type};
4+
use crate::isa::CallConv;
45
use core::fmt;
56
use core::str::FromStr;
67
#[cfg(feature = "enable-serde")]
@@ -50,6 +51,10 @@ pub enum LibCall {
5051
NearestF32,
5152
/// nearest.f64
5253
NearestF64,
54+
/// fma.f32
55+
FmaF32,
56+
/// fma.f64
57+
FmaF64,
5358
/// libc.memcpy
5459
Memcpy,
5560
/// libc.memset
@@ -91,6 +96,8 @@ impl FromStr for LibCall {
9196
"TruncF64" => Ok(Self::TruncF64),
9297
"NearestF32" => Ok(Self::NearestF32),
9398
"NearestF64" => Ok(Self::NearestF64),
99+
"FmaF32" => Ok(Self::FmaF32),
100+
"FmaF64" => Ok(Self::FmaF64),
94101
"Memcpy" => Ok(Self::Memcpy),
95102
"Memset" => Ok(Self::Memset),
96103
"Memmove" => Ok(Self::Memmove),
@@ -124,13 +131,15 @@ impl LibCall {
124131
Opcode::Floor => Self::FloorF32,
125132
Opcode::Trunc => Self::TruncF32,
126133
Opcode::Nearest => Self::NearestF32,
134+
Opcode::Fma => Self::FmaF32,
127135
_ => return None,
128136
},
129137
types::F64 => match opcode {
130138
Opcode::Ceil => Self::CeilF64,
131139
Opcode::Floor => Self::FloorF64,
132140
Opcode::Trunc => Self::TruncF64,
133141
Opcode::Nearest => Self::NearestF64,
142+
Opcode::Fma => Self::FmaF64,
134143
_ => return None,
135144
},
136145
_ => return None,
@@ -157,13 +166,59 @@ impl LibCall {
157166
TruncF64,
158167
NearestF32,
159168
NearestF64,
169+
FmaF32,
170+
FmaF64,
160171
Memcpy,
161172
Memset,
162173
Memmove,
163174
Memcmp,
164175
ElfTlsGetAddr,
165176
]
166177
}
178+
179+
/// Get a [Signature] for the function targeted by this [LibCall].
180+
pub fn signature(&self, call_conv: CallConv) -> Signature {
181+
use types::*;
182+
let mut sig = Signature::new(call_conv);
183+
184+
match self {
185+
LibCall::UdivI64
186+
| LibCall::SdivI64
187+
| LibCall::UremI64
188+
| LibCall::SremI64
189+
| LibCall::IshlI64
190+
| LibCall::UshrI64
191+
| LibCall::SshrI64 => {
192+
sig.params.push(AbiParam::new(I64));
193+
sig.params.push(AbiParam::new(I64));
194+
sig.returns.push(AbiParam::new(I64));
195+
}
196+
LibCall::CeilF32 | LibCall::FloorF32 | LibCall::TruncF32 | LibCall::NearestF32 => {
197+
sig.params.push(AbiParam::new(F32));
198+
sig.returns.push(AbiParam::new(F32));
199+
}
200+
LibCall::TruncF64 | LibCall::FloorF64 | LibCall::CeilF64 | LibCall::NearestF64 => {
201+
sig.params.push(AbiParam::new(F64));
202+
sig.returns.push(AbiParam::new(F64));
203+
}
204+
LibCall::FmaF32 | LibCall::FmaF64 => {
205+
let ty = if *self == LibCall::FmaF32 { F32 } else { F64 };
206+
207+
sig.params.push(AbiParam::new(ty));
208+
sig.params.push(AbiParam::new(ty));
209+
sig.params.push(AbiParam::new(ty));
210+
sig.returns.push(AbiParam::new(ty));
211+
}
212+
LibCall::Probestack
213+
| LibCall::Memcpy
214+
| LibCall::Memset
215+
| LibCall::Memmove
216+
| LibCall::Memcmp
217+
| LibCall::ElfTlsGetAddr => unimplemented!(),
218+
}
219+
220+
sig
221+
}
167222
}
168223

169224
/// Get a function reference for the probestack function in `func`.

cranelift/codegen/src/isa/aarch64/lower.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -1551,7 +1551,7 @@ impl LowerBackend for AArch64Backend {
15511551
type MInst = Inst;
15521552

15531553
fn lower<C: LowerCtx<I = Inst>>(&self, ctx: &mut C, ir_inst: IRInst) -> CodegenResult<()> {
1554-
lower_inst::lower_insn_to_regs(ctx, ir_inst, &self.flags, &self.isa_flags)
1554+
lower_inst::lower_insn_to_regs(ctx, ir_inst, &self.triple, &self.flags, &self.isa_flags)
15551555
}
15561556

15571557
fn lower_branch_group<C: LowerCtx<I = Inst>>(

cranelift/codegen/src/isa/aarch64/lower/isle.rs

+11-3
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,7 @@ use regalloc2::PReg;
3030
use std::boxed::Box;
3131
use std::convert::TryFrom;
3232
use std::vec::Vec;
33+
use target_lexicon::Triple;
3334

3435
type BoxCallInfo = Box<CallInfo>;
3536
type BoxCallIndInfo = Box<CallIndInfo>;
@@ -40,6 +41,7 @@ type BoxExternalName = Box<ExternalName>;
4041
/// The main entry point for lowering with ISLE.
4142
pub(crate) fn lower<C>(
4243
lower_ctx: &mut C,
44+
triple: &Triple,
4345
flags: &Flags,
4446
isa_flags: &IsaFlags,
4547
outputs: &[InsnOutput],
@@ -48,9 +50,15 @@ pub(crate) fn lower<C>(
4850
where
4951
C: LowerCtx<I = MInst>,
5052
{
51-
lower_common(lower_ctx, flags, isa_flags, outputs, inst, |cx, insn| {
52-
generated_code::constructor_lower(cx, insn)
53-
})
53+
lower_common(
54+
lower_ctx,
55+
triple,
56+
flags,
57+
isa_flags,
58+
outputs,
59+
inst,
60+
|cx, insn| generated_code::constructor_lower(cx, insn),
61+
)
5462
}
5563

5664
pub struct ExtendedValue {

cranelift/codegen/src/isa/aarch64/lower_inst.rs

+3-1
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,13 @@ use crate::{CodegenError, CodegenResult};
1616
use alloc::boxed::Box;
1717
use alloc::vec::Vec;
1818
use core::convert::TryFrom;
19+
use target_lexicon::Triple;
1920

2021
/// Actually codegen an instruction's results into registers.
2122
pub(crate) fn lower_insn_to_regs<C: LowerCtx<I = Inst>>(
2223
ctx: &mut C,
2324
insn: IRInst,
25+
triple: &Triple,
2426
flags: &Flags,
2527
isa_flags: &aarch64_settings::Flags,
2628
) -> CodegenResult<()> {
@@ -33,7 +35,7 @@ pub(crate) fn lower_insn_to_regs<C: LowerCtx<I = Inst>>(
3335
None
3436
};
3537

36-
if let Ok(()) = super::lower::isle::lower(ctx, flags, isa_flags, &outputs, insn) {
38+
if let Ok(()) = super::lower::isle::lower(ctx, triple, flags, isa_flags, &outputs, insn) {
3739
return Ok(());
3840
}
3941

cranelift/codegen/src/isa/s390x/lower.rs

+9-3
Original file line numberDiff line numberDiff line change
@@ -30,9 +30,14 @@ impl LowerBackend for S390xBackend {
3030
None
3131
};
3232

33-
if let Ok(()) =
34-
super::lower::isle::lower(ctx, &self.flags, &self.isa_flags, &outputs, ir_inst)
35-
{
33+
if let Ok(()) = super::lower::isle::lower(
34+
ctx,
35+
&self.triple,
36+
&self.flags,
37+
&self.isa_flags,
38+
&outputs,
39+
ir_inst,
40+
) {
3641
return Ok(());
3742
}
3843

@@ -295,6 +300,7 @@ impl LowerBackend for S390xBackend {
295300
// the second branch (if any) by emitting a two-way conditional branch.
296301
if let Ok(()) = super::lower::isle::lower_branch(
297302
ctx,
303+
&self.triple,
298304
&self.flags,
299305
&self.isa_flags,
300306
branches[0],

cranelift/codegen/src/isa/s390x/lower/isle.rs

+21-6
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ use std::boxed::Box;
2626
use std::cell::Cell;
2727
use std::convert::TryFrom;
2828
use std::vec::Vec;
29+
use target_lexicon::Triple;
2930

3031
type BoxCallInfo = Box<CallInfo>;
3132
type BoxCallIndInfo = Box<CallIndInfo>;
@@ -37,6 +38,7 @@ type VecMInstBuilder = Cell<Vec<MInst>>;
3738
/// The main entry point for lowering with ISLE.
3839
pub(crate) fn lower<C>(
3940
lower_ctx: &mut C,
41+
triple: &Triple,
4042
flags: &Flags,
4143
isa_flags: &IsaFlags,
4244
outputs: &[InsnOutput],
@@ -45,14 +47,21 @@ pub(crate) fn lower<C>(
4547
where
4648
C: LowerCtx<I = MInst>,
4749
{
48-
lower_common(lower_ctx, flags, isa_flags, outputs, inst, |cx, insn| {
49-
generated_code::constructor_lower(cx, insn)
50-
})
50+
lower_common(
51+
lower_ctx,
52+
triple,
53+
flags,
54+
isa_flags,
55+
outputs,
56+
inst,
57+
|cx, insn| generated_code::constructor_lower(cx, insn),
58+
)
5159
}
5260

5361
/// The main entry point for branch lowering with ISLE.
5462
pub(crate) fn lower_branch<C>(
5563
lower_ctx: &mut C,
64+
triple: &Triple,
5665
flags: &Flags,
5766
isa_flags: &IsaFlags,
5867
branch: Inst,
@@ -61,9 +70,15 @@ pub(crate) fn lower_branch<C>(
6170
where
6271
C: LowerCtx<I = MInst>,
6372
{
64-
lower_common(lower_ctx, flags, isa_flags, &[], branch, |cx, insn| {
65-
generated_code::constructor_lower_branch(cx, insn, &targets.to_vec())
66-
})
73+
lower_common(
74+
lower_ctx,
75+
triple,
76+
flags,
77+
isa_flags,
78+
&[],
79+
branch,
80+
|cx, insn| generated_code::constructor_lower_branch(cx, insn, &targets.to_vec()),
81+
)
6782
}
6883

6984
impl<C> generated_code::Context for IsleContext<'_, C, Flags, IsaFlags, 6>

cranelift/codegen/src/isa/x64/inst.isle

+10
Original file line numberDiff line numberDiff line change
@@ -3354,3 +3354,13 @@
33543354
(decl x64_rsp () Reg)
33553355
(rule (x64_rsp)
33563356
(mov_preg (preg_rsp)))
3357+
3358+
;;;; Helpers for Emitting LibCalls ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
3359+
3360+
(type LibCall extern
3361+
(enum
3362+
FmaF32
3363+
FmaF64))
3364+
3365+
(decl libcall_3 (LibCall Reg Reg Reg) Reg)
3366+
(extern constructor libcall_3 libcall_3)

cranelift/codegen/src/isa/x64/lower.isle

+4
Original file line numberDiff line numberDiff line change
@@ -2491,6 +2491,10 @@
24912491

24922492
;; Rules for `fma` ;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;;
24932493

2494+
(rule (lower (has_type $F32 (fma x y z)))
2495+
(libcall_3 (LibCall.FmaF32) x y z))
2496+
(rule (lower (has_type $F64 (fma x y z)))
2497+
(libcall_3 (LibCall.FmaF64) x y z))
24942498
(rule (lower (has_type $F32X4 (fma x y z)))
24952499
(x64_vfmadd213ps x y z))
24962500
(rule (lower (has_type $F64X2 (fma x y z)))

cranelift/codegen/src/isa/x64/lower.rs

+13-29
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,7 @@ pub(super) mod isle;
66
use crate::data_value::DataValue;
77
use crate::ir::{
88
condcodes::{CondCode, FloatCC, IntCC},
9-
types, AbiParam, ExternalName, Inst as IRInst, InstructionData, LibCall, Opcode, Signature,
10-
Type,
9+
types, ExternalName, Inst as IRInst, InstructionData, LibCall, Opcode, Type,
1110
};
1211
use crate::isa::x64::abi::*;
1312
use crate::isa::x64::inst::args::*;
@@ -573,29 +572,13 @@ fn emit_fcmp<C: LowerCtx<I = Inst>>(
573572
cond_result
574573
}
575574

576-
fn make_libcall_sig<C: LowerCtx<I = Inst>>(
577-
ctx: &mut C,
578-
insn: IRInst,
579-
call_conv: CallConv,
580-
) -> Signature {
581-
let mut sig = Signature::new(call_conv);
582-
for i in 0..ctx.num_inputs(insn) {
583-
sig.params.push(AbiParam::new(ctx.input_ty(insn, i)));
584-
}
585-
for i in 0..ctx.num_outputs(insn) {
586-
sig.returns.push(AbiParam::new(ctx.output_ty(insn, i)));
587-
}
588-
sig
589-
}
590-
591575
fn emit_vm_call<C: LowerCtx<I = Inst>>(
592576
ctx: &mut C,
593577
flags: &Flags,
594578
triple: &Triple,
595579
libcall: LibCall,
596-
insn: IRInst,
597-
inputs: SmallVec<[InsnInput; 4]>,
598-
outputs: SmallVec<[InsnOutput; 2]>,
580+
inputs: &[Reg],
581+
outputs: &[Writable<Reg>],
599582
) -> CodegenResult<()> {
600583
let extname = ExternalName::LibCall(libcall);
601584

@@ -607,7 +590,7 @@ fn emit_vm_call<C: LowerCtx<I = Inst>>(
607590

608591
// TODO avoid recreating signatures for every single Libcall function.
609592
let call_conv = CallConv::for_libcall(flags, CallConv::triple_default(triple));
610-
let sig = make_libcall_sig(ctx, insn, call_conv);
593+
let sig = libcall.signature(call_conv);
611594
let caller_conv = ctx.abi().call_conv();
612595

613596
let mut abi = X64ABICaller::from_func(&sig, &extname, dist, caller_conv, flags)?;
@@ -617,14 +600,12 @@ fn emit_vm_call<C: LowerCtx<I = Inst>>(
617600
assert_eq!(inputs.len(), abi.num_args());
618601

619602
for (i, input) in inputs.iter().enumerate() {
620-
let arg_reg = put_input_in_reg(ctx, *input);
621-
abi.emit_copy_regs_to_arg(ctx, i, ValueRegs::one(arg_reg));
603+
abi.emit_copy_regs_to_arg(ctx, i, ValueRegs::one(*input));
622604
}
623605

624606
abi.emit_call(ctx);
625607
for (i, output) in outputs.iter().enumerate() {
626-
let retval_reg = get_output_reg(ctx, *output).only_reg().unwrap();
627-
abi.emit_copy_retval_to_regs(ctx, i, ValueRegs::one(retval_reg));
608+
abi.emit_copy_retval_to_regs(ctx, i, ValueRegs::one(*output));
628609
}
629610
abi.emit_stack_post_adjust(ctx);
630611

@@ -810,7 +791,7 @@ fn lower_insn_to_regs<C: LowerCtx<I = Inst>>(
810791
None
811792
};
812793

813-
if let Ok(()) = isle::lower(ctx, flags, isa_flags, &outputs, insn) {
794+
if let Ok(()) = isle::lower(ctx, triple, flags, isa_flags, &outputs, insn) {
814795
return Ok(());
815796
}
816797

@@ -884,6 +865,7 @@ fn lower_insn_to_regs<C: LowerCtx<I = Inst>>(
884865
| Opcode::FvpromoteLow
885866
| Opcode::Fdemote
886867
| Opcode::Fvdemote
868+
| Opcode::Fma
887869
| Opcode::Icmp
888870
| Opcode::Fcmp
889871
| Opcode::Load
@@ -1974,7 +1956,11 @@ fn lower_insn_to_regs<C: LowerCtx<I = Inst>>(
19741956
ty, op
19751957
),
19761958
};
1977-
emit_vm_call(ctx, flags, triple, libcall, insn, inputs, outputs)?;
1959+
1960+
let input = put_input_in_reg(ctx, inputs[0]);
1961+
let dst = get_output_reg(ctx, outputs[0]).only_reg().unwrap();
1962+
1963+
emit_vm_call(ctx, flags, triple, libcall, &[input], &[dst])?;
19781964
}
19791965
}
19801966

@@ -2726,8 +2712,6 @@ fn lower_insn_to_regs<C: LowerCtx<I = Inst>>(
27262712

27272713
Opcode::Cls => unimplemented!("Cls not supported"),
27282714

2729-
Opcode::Fma => implemented_in_isle(ctx),
2730-
27312715
Opcode::BorNot | Opcode::BxorNot => {
27322716
unimplemented!("or-not / xor-not opcodes not implemented");
27332717
}

0 commit comments

Comments
 (0)