Skip to content

Commit 98f5ebb

Browse files
committed
Auto merge of rust-lang#113970 - cjgillot:assume-all-the-things, r=nikic
Replace switch to unreachable by assume statements `UnreachablePropagation` currently keeps some switch terminators alive in order to ensure codegen can infer the inequalities on the discriminants. This PR proposes to encode those inequalities as `Assume` statements. This allows to simplify MIR further by removing some useless terminators.
2 parents 09ac6e4 + ae2e211 commit 98f5ebb

26 files changed

+573
-470
lines changed

Diff for: compiler/rustc_mir_transform/src/lib.rs

+3-2
Original file line numberDiff line numberDiff line change
@@ -568,10 +568,11 @@ fn run_optimization_passes<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
568568
&[
569569
&check_alignment::CheckAlignment,
570570
&lower_slice_len::LowerSliceLenCalls, // has to be done before inlining, otherwise actual call will be almost always inlined. Also simple, so can just do first
571-
&unreachable_prop::UnreachablePropagation,
571+
&inline::Inline,
572+
// Substitutions during inlining may introduce switch on enums with uninhabited branches.
572573
&uninhabited_enum_branching::UninhabitedEnumBranching,
574+
&unreachable_prop::UnreachablePropagation,
573575
&o1(simplify::SimplifyCfg::AfterUninhabitedEnumBranching),
574-
&inline::Inline,
575576
&remove_storage_markers::RemoveStorageMarkers,
576577
&remove_zsts::RemoveZsts,
577578
&normalize_array_len::NormalizeArrayLen, // has to run after `slice::len` lowering

Diff for: compiler/rustc_mir_transform/src/simplify_branches.rs

+18-1
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,25 @@ impl<'tcx> MirPass<'tcx> for SimplifyConstCondition {
1616
}
1717

1818
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
19+
trace!("Running SimplifyConstCondition on {:?}", body.source);
1920
let param_env = tcx.param_env_reveal_all_normalized(body.source.def_id());
20-
for block in body.basic_blocks_mut() {
21+
'blocks: for block in body.basic_blocks_mut() {
22+
for stmt in block.statements.iter_mut() {
23+
if let StatementKind::Intrinsic(box ref intrinsic) = stmt.kind
24+
&& let NonDivergingIntrinsic::Assume(discr) = intrinsic
25+
&& let Operand::Constant(ref c) = discr
26+
&& let Some(constant) = c.const_.try_eval_bool(tcx, param_env)
27+
{
28+
if constant {
29+
stmt.make_nop();
30+
} else {
31+
block.statements.clear();
32+
block.terminator_mut().kind = TerminatorKind::Unreachable;
33+
continue 'blocks;
34+
}
35+
}
36+
}
37+
2138
let terminator = block.terminator_mut();
2239
terminator.kind = match terminator.kind {
2340
TerminatorKind::SwitchInt {

Diff for: compiler/rustc_mir_transform/src/uninhabited_enum_branching.rs

+47-57
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,7 @@
33
use crate::MirPass;
44
use rustc_data_structures::fx::FxHashSet;
55
use rustc_middle::mir::{
6-
BasicBlockData, Body, Local, Operand, Rvalue, StatementKind, SwitchTargets, Terminator,
7-
TerminatorKind,
6+
BasicBlockData, Body, Local, Operand, Rvalue, StatementKind, Terminator, TerminatorKind,
87
};
98
use rustc_middle::ty::layout::TyAndLayout;
109
use rustc_middle::ty::{Ty, TyCtxt};
@@ -30,17 +29,20 @@ fn get_switched_on_type<'tcx>(
3029
let terminator = block_data.terminator();
3130

3231
// Only bother checking blocks which terminate by switching on a local.
33-
if let Some(local) = get_discriminant_local(&terminator.kind)
34-
&& let [.., stmt_before_term] = &block_data.statements[..]
35-
&& let StatementKind::Assign(box (l, Rvalue::Discriminant(place))) = stmt_before_term.kind
32+
let local = get_discriminant_local(&terminator.kind)?;
33+
34+
let stmt_before_term = block_data.statements.last()?;
35+
36+
if let StatementKind::Assign(box (l, Rvalue::Discriminant(place))) = stmt_before_term.kind
3637
&& l.as_local() == Some(local)
37-
&& let ty = place.ty(body, tcx).ty
38-
&& ty.is_enum()
3938
{
40-
Some(ty)
41-
} else {
42-
None
39+
let ty = place.ty(body, tcx).ty;
40+
if ty.is_enum() {
41+
return Some(ty);
42+
}
4343
}
44+
45+
None
4446
}
4547

4648
fn variant_discriminants<'tcx>(
@@ -67,28 +69,6 @@ fn variant_discriminants<'tcx>(
6769
}
6870
}
6971

70-
/// Ensures that the `otherwise` branch leads to an unreachable bb, returning `None` if so and a new
71-
/// bb to use as the new target if not.
72-
fn ensure_otherwise_unreachable<'tcx>(
73-
body: &Body<'tcx>,
74-
targets: &SwitchTargets,
75-
) -> Option<BasicBlockData<'tcx>> {
76-
let otherwise = targets.otherwise();
77-
let bb = &body.basic_blocks[otherwise];
78-
if bb.terminator().kind == TerminatorKind::Unreachable
79-
&& bb.statements.iter().all(|s| matches!(&s.kind, StatementKind::StorageDead(_)))
80-
{
81-
return None;
82-
}
83-
84-
let mut new_block = BasicBlockData::new(Some(Terminator {
85-
source_info: bb.terminator().source_info,
86-
kind: TerminatorKind::Unreachable,
87-
}));
88-
new_block.is_cleanup = bb.is_cleanup;
89-
Some(new_block)
90-
}
91-
9272
impl<'tcx> MirPass<'tcx> for UninhabitedEnumBranching {
9373
fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
9474
sess.mir_opt_level() > 0
@@ -97,13 +77,16 @@ impl<'tcx> MirPass<'tcx> for UninhabitedEnumBranching {
9777
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
9878
trace!("UninhabitedEnumBranching starting for {:?}", body.source);
9979

100-
for bb in body.basic_blocks.indices() {
80+
let mut removable_switchs = Vec::new();
81+
82+
for (bb, bb_data) in body.basic_blocks.iter_enumerated() {
10183
trace!("processing block {:?}", bb);
10284

103-
let Some(discriminant_ty) = get_switched_on_type(&body.basic_blocks[bb], tcx, body)
104-
else {
85+
if bb_data.is_cleanup {
10586
continue;
106-
};
87+
}
88+
89+
let Some(discriminant_ty) = get_switched_on_type(&bb_data, tcx, body) else { continue };
10790

10891
let layout = tcx.layout_of(
10992
tcx.param_env_reveal_all_normalized(body.source.def_id()).and(discriminant_ty),
@@ -117,31 +100,38 @@ impl<'tcx> MirPass<'tcx> for UninhabitedEnumBranching {
117100

118101
trace!("allowed_variants = {:?}", allowed_variants);
119102

120-
if let TerminatorKind::SwitchInt { targets, .. } =
121-
&mut body.basic_blocks_mut()[bb].terminator_mut().kind
122-
{
123-
let mut new_targets = SwitchTargets::new(
124-
targets.iter().filter(|(val, _)| allowed_variants.contains(val)),
125-
targets.otherwise(),
126-
);
127-
128-
if new_targets.iter().count() == allowed_variants.len() {
129-
if let Some(updated) = ensure_otherwise_unreachable(body, &new_targets) {
130-
let new_otherwise = body.basic_blocks_mut().push(updated);
131-
*new_targets.all_targets_mut().last_mut().unwrap() = new_otherwise;
132-
}
133-
}
103+
let terminator = bb_data.terminator();
104+
let TerminatorKind::SwitchInt { targets, .. } = &terminator.kind else { bug!() };
134105

135-
if let TerminatorKind::SwitchInt { targets, .. } =
136-
&mut body.basic_blocks_mut()[bb].terminator_mut().kind
137-
{
138-
*targets = new_targets;
106+
let mut reachable_count = 0;
107+
for (index, (val, _)) in targets.iter().enumerate() {
108+
if allowed_variants.contains(&val) {
109+
reachable_count += 1;
139110
} else {
140-
unreachable!()
111+
removable_switchs.push((bb, index));
141112
}
142-
} else {
143-
unreachable!()
144113
}
114+
115+
if reachable_count == allowed_variants.len() {
116+
removable_switchs.push((bb, targets.iter().count()));
117+
}
118+
}
119+
120+
if removable_switchs.is_empty() {
121+
return;
122+
}
123+
124+
let new_block = BasicBlockData::new(Some(Terminator {
125+
source_info: body.basic_blocks[removable_switchs[0].0].terminator().source_info,
126+
kind: TerminatorKind::Unreachable,
127+
}));
128+
let unreachable_block = body.basic_blocks.as_mut().push(new_block);
129+
130+
for (bb, index) in removable_switchs {
131+
let bb = &mut body.basic_blocks.as_mut()[bb];
132+
let terminator = bb.terminator_mut();
133+
let TerminatorKind::SwitchInt { targets, .. } = &mut terminator.kind else { bug!() };
134+
targets.all_targets_mut()[index] = unreachable_block;
145135
}
146136
}
147137
}

0 commit comments

Comments
 (0)