Skip to content

Commit 7a08f84

Browse files
committed
Auto merge of rust-lang#126578 - scottmcm:inlining-bonuses-too, r=davidtwco
Account for things that optimize out in inlining costs
This updates the MIR inlining `CostChecker` to have both bonuses and penalties, rather than just penalties. That lets us add bonuses for some things where we want to encourage inlining without risking wrapping into a gigantic cost. For example, we give `switchInt(const …)` an inlining bonus because codegen will actually eliminate the branch (and associated dead blocks) once it's monomorphized, so measuring both sides of the branch gives it an unrealistically high cost. Similarly, an `unreachable` terminator gets a small bonus, because whatever branch leads there doesn't actually exist post-codegen.
2 parents a9c8887 + eac6b29 commit 7a08f84

8 files changed

+371
-77
lines changed

Diff for: compiler/rustc_mir_transform/src/cost_checker.rs

+75-31
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
use rustc_middle::bug;
12
use rustc_middle::mir::visit::*;
23
use rustc_middle::mir::*;
34
use rustc_middle::ty::{self, ParamEnv, Ty, TyCtxt};
@@ -6,13 +7,16 @@ const INSTR_COST: usize = 5;
67
const CALL_PENALTY: usize = 25;
78
const LANDINGPAD_PENALTY: usize = 50;
89
const RESUME_PENALTY: usize = 45;
10+
const LARGE_SWITCH_PENALTY: usize = 20;
11+
const CONST_SWITCH_BONUS: usize = 10;
912

1013
/// Verify that the callee body is compatible with the caller.
1114
#[derive(Clone)]
1215
pub(crate) struct CostChecker<'b, 'tcx> {
1316
tcx: TyCtxt<'tcx>,
1417
param_env: ParamEnv<'tcx>,
15-
cost: usize,
18+
penalty: usize,
19+
bonus: usize,
1620
callee_body: &'b Body<'tcx>,
1721
instance: Option<ty::Instance<'tcx>>,
1822
}
@@ -24,11 +28,11 @@ impl<'b, 'tcx> CostChecker<'b, 'tcx> {
2428
instance: Option<ty::Instance<'tcx>>,
2529
callee_body: &'b Body<'tcx>,
2630
) -> CostChecker<'b, 'tcx> {
27-
CostChecker { tcx, param_env, callee_body, instance, cost: 0 }
31+
CostChecker { tcx, param_env, callee_body, instance, penalty: 0, bonus: 0 }
2832
}
2933

3034
pub fn cost(&self) -> usize {
31-
self.cost
35+
usize::saturating_sub(self.penalty, self.bonus)
3236
}
3337

3438
fn instantiate_ty(&self, v: Ty<'tcx>) -> Ty<'tcx> {
@@ -41,60 +45,100 @@ impl<'b, 'tcx> CostChecker<'b, 'tcx> {
4145
}
4246

4347
impl<'tcx> Visitor<'tcx> for CostChecker<'_, 'tcx> {
44-
fn visit_statement(&mut self, statement: &Statement<'tcx>, _: Location) {
45-
// Don't count StorageLive/StorageDead in the inlining cost.
48+
fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) {
49+
// Most costs are in rvalues and terminators, not in statements.
4650
match statement.kind {
47-
StatementKind::StorageLive(_)
48-
| StatementKind::StorageDead(_)
49-
| StatementKind::Deinit(_)
50-
| StatementKind::Nop => {}
51-
_ => self.cost += INSTR_COST,
51+
StatementKind::Intrinsic(ref ndi) => {
52+
self.penalty += match **ndi {
53+
NonDivergingIntrinsic::Assume(..) => INSTR_COST,
54+
NonDivergingIntrinsic::CopyNonOverlapping(..) => CALL_PENALTY,
55+
};
56+
}
57+
_ => self.super_statement(statement, location),
58+
}
59+
}
60+
61+
fn visit_rvalue(&mut self, rvalue: &Rvalue<'tcx>, _location: Location) {
62+
match rvalue {
63+
Rvalue::NullaryOp(NullOp::UbChecks, ..) if !self.tcx.sess.ub_checks() => {
64+
// If this is in optimized MIR it's because it's used later,
65+
// so if we don't need UB checks this session, give a bonus
66+
// here to offset the cost of the call later.
67+
self.bonus += CALL_PENALTY;
68+
}
69+
// These are essentially constants that didn't end up in an Operand,
70+
// so treat them as also being free.
71+
Rvalue::NullaryOp(..) => {}
72+
_ => self.penalty += INSTR_COST,
5273
}
5374
}
5475

5576
fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, _: Location) {
56-
let tcx = self.tcx;
57-
match terminator.kind {
58-
TerminatorKind::Drop { ref place, unwind, .. } => {
77+
match &terminator.kind {
78+
TerminatorKind::Drop { place, unwind, .. } => {
5979
// If the place doesn't actually need dropping, treat it like a regular goto.
60-
let ty = self.instantiate_ty(place.ty(self.callee_body, tcx).ty);
61-
if ty.needs_drop(tcx, self.param_env) {
62-
self.cost += CALL_PENALTY;
80+
let ty = self.instantiate_ty(place.ty(self.callee_body, self.tcx).ty);
81+
if ty.needs_drop(self.tcx, self.param_env) {
82+
self.penalty += CALL_PENALTY;
6383
if let UnwindAction::Cleanup(_) = unwind {
64-
self.cost += LANDINGPAD_PENALTY;
84+
self.penalty += LANDINGPAD_PENALTY;
6585
}
66-
} else {
67-
self.cost += INSTR_COST;
6886
}
6987
}
70-
TerminatorKind::Call { func: Operand::Constant(ref f), unwind, .. } => {
71-
let fn_ty = self.instantiate_ty(f.const_.ty());
72-
self.cost += if let ty::FnDef(def_id, _) = *fn_ty.kind()
73-
&& tcx.intrinsic(def_id).is_some()
88+
TerminatorKind::Call { func, unwind, .. } => {
89+
self.penalty += if let Some((def_id, ..)) = func.const_fn_def()
90+
&& self.tcx.intrinsic(def_id).is_some()
7491
{
7592
// Don't give intrinsics the extra penalty for calls
7693
INSTR_COST
7794
} else {
7895
CALL_PENALTY
7996
};
8097
if let UnwindAction::Cleanup(_) = unwind {
81-
self.cost += LANDINGPAD_PENALTY;
98+
self.penalty += LANDINGPAD_PENALTY;
99+
}
100+
}
101+
TerminatorKind::SwitchInt { discr, targets } => {
102+
if discr.constant().is_some() {
103+
// Not only will this become a `Goto`, but likely other
104+
// things will be removable as unreachable.
105+
self.bonus += CONST_SWITCH_BONUS;
106+
} else if targets.all_targets().len() > 3 {
107+
// More than false/true/unreachable gets extra cost.
108+
self.penalty += LARGE_SWITCH_PENALTY;
109+
} else {
110+
self.penalty += INSTR_COST;
82111
}
83112
}
84-
TerminatorKind::Assert { unwind, .. } => {
85-
self.cost += CALL_PENALTY;
113+
TerminatorKind::Assert { unwind, msg, .. } => {
114+
self.penalty +=
115+
if msg.is_optional_overflow_check() && !self.tcx.sess.overflow_checks() {
116+
INSTR_COST
117+
} else {
118+
CALL_PENALTY
119+
};
86120
if let UnwindAction::Cleanup(_) = unwind {
87-
self.cost += LANDINGPAD_PENALTY;
121+
self.penalty += LANDINGPAD_PENALTY;
88122
}
89123
}
90-
TerminatorKind::UnwindResume => self.cost += RESUME_PENALTY,
124+
TerminatorKind::UnwindResume => self.penalty += RESUME_PENALTY,
91125
TerminatorKind::InlineAsm { unwind, .. } => {
92-
self.cost += INSTR_COST;
126+
self.penalty += INSTR_COST;
93127
if let UnwindAction::Cleanup(_) = unwind {
94-
self.cost += LANDINGPAD_PENALTY;
128+
self.penalty += LANDINGPAD_PENALTY;
95129
}
96130
}
97-
_ => self.cost += INSTR_COST,
131+
TerminatorKind::Unreachable => {
132+
self.bonus += INSTR_COST;
133+
}
134+
TerminatorKind::Goto { .. } | TerminatorKind::Return => {}
135+
TerminatorKind::UnwindTerminate(..) => {}
136+
kind @ (TerminatorKind::FalseUnwind { .. }
137+
| TerminatorKind::FalseEdge { .. }
138+
| TerminatorKind::Yield { .. }
139+
| TerminatorKind::CoroutineDrop) => {
140+
bug!("{kind:?} should not be in runtime MIR");
141+
}
98142
}
99143
}
100144
}

Diff for: library/core/src/slice/iter/macros.rs

+4-2
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,8 @@ macro_rules! iterator {
103103
// so this new pointer is inside `self` and thus guaranteed to be non-null.
104104
unsafe {
105105
if_zst!(mut self,
106-
len => *len = len.unchecked_sub(offset),
106+
// Using the intrinsic directly avoids emitting a UbCheck
107+
len => *len = crate::intrinsics::unchecked_sub(*len, offset),
107108
_end => self.ptr = self.ptr.add(offset),
108109
);
109110
}
@@ -119,7 +120,8 @@ macro_rules! iterator {
119120
// SAFETY: By our precondition, `offset` can be at most the
120121
// current length, so the subtraction can never overflow.
121122
len => unsafe {
122-
*len = len.unchecked_sub(offset);
123+
// Using the intrinsic directly avoids emitting a UbCheck
124+
*len = crate::intrinsics::unchecked_sub(*len, offset);
123125
self.ptr
124126
},
125127
// SAFETY: the caller guarantees that `offset` doesn't exceed `self.len()`,

Diff for: tests/codegen/issues/issue-112509-slice-get-andthen-get.rs

+5-1
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,12 @@
33

44
// CHECK-LABEL: @write_u8_variant_a
55
// CHECK-NEXT: {{.*}}:
6-
// CHECK-NEXT: getelementptr
76
// CHECK-NEXT: icmp ugt
7+
// CHECK-NEXT: getelementptr
8+
// CHECK-NEXT: select i1 {{.+}} null
9+
// CHECK-NEXT: insertvalue
10+
// CHECK-NEXT: insertvalue
11+
// CHECK-NEXT: ret
812
#[no_mangle]
913
pub fn write_u8_variant_a(bytes: &mut [u8], buf: u8, offset: usize) -> Option<&mut [u8]> {
1014
let buf = buf.to_le_bytes();

Diff for: tests/crashes/123893.rs

+4
Original file line numberDiff line numberDiff line change
@@ -11,10 +11,14 @@ fn generic_impl<T>() {
1111
impl<T> MagicTrait for T {
1212
const IS_BIG: bool = true;
1313
}
14+
more_cost();
1415
if T::IS_BIG {
1516
big_impl::<i32>();
1617
}
1718
}
1819

1920
#[inline(never)]
2021
fn big_impl<T>() {}
22+
23+
#[inline(never)]
24+
fn more_cost() {}

Diff for: tests/mir-opt/pre-codegen/checked_ops.step_forward.PreCodegen.after.panic-abort.mir

+76-1
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,87 @@ fn step_forward(_1: u16, _2: usize) -> u16 {
44
debug x => _1;
55
debug n => _2;
66
let mut _0: u16;
7+
scope 1 (inlined <u16 as Step>::forward) {
8+
let mut _8: u16;
9+
scope 2 {
10+
}
11+
scope 3 (inlined <u16 as Step>::forward_checked) {
12+
scope 4 {
13+
scope 6 (inlined core::num::<impl u16>::checked_add) {
14+
let mut _7: bool;
15+
scope 7 {
16+
}
17+
scope 8 (inlined core::num::<impl u16>::overflowing_add) {
18+
let mut _5: (u16, bool);
19+
let _6: bool;
20+
scope 9 {
21+
}
22+
}
23+
}
24+
}
25+
scope 5 (inlined convert::num::ptr_try_from_impls::<impl TryFrom<usize> for u16>::try_from) {
26+
let mut _3: bool;
27+
let mut _4: u16;
28+
}
29+
}
30+
scope 10 (inlined Option::<u16>::is_none) {
31+
scope 11 (inlined Option::<u16>::is_some) {
32+
}
33+
}
34+
scope 12 (inlined core::num::<impl u16>::wrapping_add) {
35+
}
36+
}
737

838
bb0: {
9-
_0 = <u16 as Step>::forward(move _1, move _2) -> [return: bb1, unwind unreachable];
39+
StorageLive(_4);
40+
StorageLive(_3);
41+
_3 = Gt(_2, const 65535_usize);
42+
switchInt(move _3) -> [0: bb1, otherwise: bb5];
1043
}
1144

1245
bb1: {
46+
_4 = _2 as u16 (IntToInt);
47+
StorageDead(_3);
48+
StorageLive(_6);
49+
StorageLive(_5);
50+
_5 = AddWithOverflow(_1, _4);
51+
_6 = (_5.1: bool);
52+
StorageDead(_5);
53+
StorageLive(_7);
54+
_7 = unlikely(move _6) -> [return: bb2, unwind unreachable];
55+
}
56+
57+
bb2: {
58+
switchInt(move _7) -> [0: bb3, otherwise: bb4];
59+
}
60+
61+
bb3: {
62+
StorageDead(_7);
63+
StorageDead(_6);
64+
goto -> bb7;
65+
}
66+
67+
bb4: {
68+
StorageDead(_7);
69+
StorageDead(_6);
70+
goto -> bb6;
71+
}
72+
73+
bb5: {
74+
StorageDead(_3);
75+
goto -> bb6;
76+
}
77+
78+
bb6: {
79+
assert(!const true, "attempt to compute `{} + {}`, which would overflow", const core::num::<impl u16>::MAX, const 1_u16) -> [success: bb7, unwind unreachable];
80+
}
81+
82+
bb7: {
83+
StorageLive(_8);
84+
_8 = _2 as u16 (IntToInt);
85+
_0 = Add(_1, _8);
86+
StorageDead(_8);
87+
StorageDead(_4);
1388
return;
1489
}
1590
}

Diff for: tests/mir-opt/pre-codegen/checked_ops.step_forward.PreCodegen.after.panic-unwind.mir

+76-1
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,87 @@ fn step_forward(_1: u16, _2: usize) -> u16 {
44
debug x => _1;
55
debug n => _2;
66
let mut _0: u16;
7+
scope 1 (inlined <u16 as Step>::forward) {
8+
let mut _8: u16;
9+
scope 2 {
10+
}
11+
scope 3 (inlined <u16 as Step>::forward_checked) {
12+
scope 4 {
13+
scope 6 (inlined core::num::<impl u16>::checked_add) {
14+
let mut _7: bool;
15+
scope 7 {
16+
}
17+
scope 8 (inlined core::num::<impl u16>::overflowing_add) {
18+
let mut _5: (u16, bool);
19+
let _6: bool;
20+
scope 9 {
21+
}
22+
}
23+
}
24+
}
25+
scope 5 (inlined convert::num::ptr_try_from_impls::<impl TryFrom<usize> for u16>::try_from) {
26+
let mut _3: bool;
27+
let mut _4: u16;
28+
}
29+
}
30+
scope 10 (inlined Option::<u16>::is_none) {
31+
scope 11 (inlined Option::<u16>::is_some) {
32+
}
33+
}
34+
scope 12 (inlined core::num::<impl u16>::wrapping_add) {
35+
}
36+
}
737

838
bb0: {
9-
_0 = <u16 as Step>::forward(move _1, move _2) -> [return: bb1, unwind continue];
39+
StorageLive(_4);
40+
StorageLive(_3);
41+
_3 = Gt(_2, const 65535_usize);
42+
switchInt(move _3) -> [0: bb1, otherwise: bb5];
1043
}
1144

1245
bb1: {
46+
_4 = _2 as u16 (IntToInt);
47+
StorageDead(_3);
48+
StorageLive(_6);
49+
StorageLive(_5);
50+
_5 = AddWithOverflow(_1, _4);
51+
_6 = (_5.1: bool);
52+
StorageDead(_5);
53+
StorageLive(_7);
54+
_7 = unlikely(move _6) -> [return: bb2, unwind unreachable];
55+
}
56+
57+
bb2: {
58+
switchInt(move _7) -> [0: bb3, otherwise: bb4];
59+
}
60+
61+
bb3: {
62+
StorageDead(_7);
63+
StorageDead(_6);
64+
goto -> bb7;
65+
}
66+
67+
bb4: {
68+
StorageDead(_7);
69+
StorageDead(_6);
70+
goto -> bb6;
71+
}
72+
73+
bb5: {
74+
StorageDead(_3);
75+
goto -> bb6;
76+
}
77+
78+
bb6: {
79+
assert(!const true, "attempt to compute `{} + {}`, which would overflow", const core::num::<impl u16>::MAX, const 1_u16) -> [success: bb7, unwind continue];
80+
}
81+
82+
bb7: {
83+
StorageLive(_8);
84+
_8 = _2 as u16 (IntToInt);
85+
_0 = Add(_1, _8);
86+
StorageDead(_8);
87+
StorageDead(_4);
1388
return;
1489
}
1590
}

0 commit comments

Comments
 (0)