Skip to content

Commit 096196d

Browse files
committed
Refactor UninhabitedEnumBranching to mark targets unreachable.
1 parent 0b13e63 commit 096196d

6 files changed

+84
-63
lines changed

compiler/rustc_mir_transform/src/uninhabited_enum_branching.rs

+47-57
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,7 @@
33
use crate::MirPass;
44
use rustc_data_structures::fx::FxHashSet;
55
use rustc_middle::mir::{
6-
BasicBlockData, Body, Local, Operand, Rvalue, StatementKind, SwitchTargets, Terminator,
7-
TerminatorKind,
6+
BasicBlockData, Body, Local, Operand, Rvalue, StatementKind, Terminator, TerminatorKind,
87
};
98
use rustc_middle::ty::layout::TyAndLayout;
109
use rustc_middle::ty::{Ty, TyCtxt};
@@ -30,17 +29,20 @@ fn get_switched_on_type<'tcx>(
3029
let terminator = block_data.terminator();
3130

3231
// Only bother checking blocks which terminate by switching on a local.
33-
if let Some(local) = get_discriminant_local(&terminator.kind)
34-
&& let [.., stmt_before_term] = &block_data.statements[..]
35-
&& let StatementKind::Assign(box (l, Rvalue::Discriminant(place))) = stmt_before_term.kind
32+
let local = get_discriminant_local(&terminator.kind)?;
33+
34+
let stmt_before_term = block_data.statements.last()?;
35+
36+
if let StatementKind::Assign(box (l, Rvalue::Discriminant(place))) = stmt_before_term.kind
3637
&& l.as_local() == Some(local)
37-
&& let ty = place.ty(body, tcx).ty
38-
&& ty.is_enum()
3938
{
40-
Some(ty)
41-
} else {
42-
None
39+
let ty = place.ty(body, tcx).ty;
40+
if ty.is_enum() {
41+
return Some(ty);
42+
}
4343
}
44+
45+
None
4446
}
4547

4648
fn variant_discriminants<'tcx>(
@@ -67,28 +69,6 @@ fn variant_discriminants<'tcx>(
6769
}
6870
}
6971

70-
/// Ensures that the `otherwise` branch leads to an unreachable bb, returning `None` if so and a new
71-
/// bb to use as the new target if not.
72-
fn ensure_otherwise_unreachable<'tcx>(
73-
body: &Body<'tcx>,
74-
targets: &SwitchTargets,
75-
) -> Option<BasicBlockData<'tcx>> {
76-
let otherwise = targets.otherwise();
77-
let bb = &body.basic_blocks[otherwise];
78-
if bb.terminator().kind == TerminatorKind::Unreachable
79-
&& bb.statements.iter().all(|s| matches!(&s.kind, StatementKind::StorageDead(_)))
80-
{
81-
return None;
82-
}
83-
84-
let mut new_block = BasicBlockData::new(Some(Terminator {
85-
source_info: bb.terminator().source_info,
86-
kind: TerminatorKind::Unreachable,
87-
}));
88-
new_block.is_cleanup = bb.is_cleanup;
89-
Some(new_block)
90-
}
91-
9272
impl<'tcx> MirPass<'tcx> for UninhabitedEnumBranching {
9373
fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
9474
sess.mir_opt_level() > 0
@@ -97,13 +77,16 @@ impl<'tcx> MirPass<'tcx> for UninhabitedEnumBranching {
9777
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
9878
trace!("UninhabitedEnumBranching starting for {:?}", body.source);
9979

100-
for bb in body.basic_blocks.indices() {
80+
let mut removable_switchs = Vec::new();
81+
82+
for (bb, bb_data) in body.basic_blocks.iter_enumerated() {
10183
trace!("processing block {:?}", bb);
10284

103-
let Some(discriminant_ty) = get_switched_on_type(&body.basic_blocks[bb], tcx, body)
104-
else {
85+
if bb_data.is_cleanup {
10586
continue;
106-
};
87+
}
88+
89+
let Some(discriminant_ty) = get_switched_on_type(&bb_data, tcx, body) else { continue };
10790

10891
let layout = tcx.layout_of(
10992
tcx.param_env_reveal_all_normalized(body.source.def_id()).and(discriminant_ty),
@@ -117,31 +100,38 @@ impl<'tcx> MirPass<'tcx> for UninhabitedEnumBranching {
117100

118101
trace!("allowed_variants = {:?}", allowed_variants);
119102

120-
if let TerminatorKind::SwitchInt { targets, .. } =
121-
&mut body.basic_blocks_mut()[bb].terminator_mut().kind
122-
{
123-
let mut new_targets = SwitchTargets::new(
124-
targets.iter().filter(|(val, _)| allowed_variants.contains(val)),
125-
targets.otherwise(),
126-
);
127-
128-
if new_targets.iter().count() == allowed_variants.len() {
129-
if let Some(updated) = ensure_otherwise_unreachable(body, &new_targets) {
130-
let new_otherwise = body.basic_blocks_mut().push(updated);
131-
*new_targets.all_targets_mut().last_mut().unwrap() = new_otherwise;
132-
}
133-
}
103+
let terminator = bb_data.terminator();
104+
let TerminatorKind::SwitchInt { targets, .. } = &terminator.kind else { bug!() };
134105

135-
if let TerminatorKind::SwitchInt { targets, .. } =
136-
&mut body.basic_blocks_mut()[bb].terminator_mut().kind
137-
{
138-
*targets = new_targets;
106+
let mut reachable_count = 0;
107+
for (index, (val, _)) in targets.iter().enumerate() {
108+
if allowed_variants.contains(&val) {
109+
reachable_count += 1;
139110
} else {
140-
unreachable!()
111+
removable_switchs.push((bb, index));
141112
}
142-
} else {
143-
unreachable!()
144113
}
114+
115+
if reachable_count == allowed_variants.len() {
116+
removable_switchs.push((bb, targets.iter().count()));
117+
}
118+
}
119+
120+
if removable_switchs.is_empty() {
121+
return;
122+
}
123+
124+
let new_block = BasicBlockData::new(Some(Terminator {
125+
source_info: body.basic_blocks[removable_switchs[0].0].terminator().source_info,
126+
kind: TerminatorKind::Unreachable,
127+
}));
128+
let unreachable_block = body.basic_blocks.as_mut().push(new_block);
129+
130+
for (bb, index) in removable_switchs {
131+
let bb = &mut body.basic_blocks.as_mut()[bb];
132+
let terminator = bb.terminator_mut();
133+
let TerminatorKind::SwitchInt { targets, .. } = &mut terminator.kind else { bug!() };
134+
targets.all_targets_mut()[index] = unreachable_block;
145135
}
146136
}
147137
}

tests/mir-opt/uninhabited_enum_branching.main.SimplifyCfg-after-uninhabited-enum-branching.after.mir

+7-1
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,20 @@ fn main() -> () {
1212
let mut _8: isize;
1313
let _9: &str;
1414
let mut _10: bool;
15+
let mut _11: bool;
16+
let mut _12: bool;
1517

1618
bb0: {
1719
StorageLive(_1);
1820
StorageLive(_2);
1921
_2 = Test1::C;
2022
_3 = discriminant(_2);
21-
_10 = Eq(_3, const 2_isize);
23+
_10 = Ne(_3, const 0_isize);
2224
assume(move _10);
25+
_11 = Ne(_3, const 1_isize);
26+
assume(move _11);
27+
_12 = Eq(_3, const 2_isize);
28+
assume(move _12);
2329
StorageLive(_5);
2430
_5 = const "C";
2531
_1 = &(*_5);

tests/mir-opt/uninhabited_enum_branching.main.UninhabitedEnumBranching.diff

+7-2
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
_2 = Test1::C;
2020
_3 = discriminant(_2);
2121
- switchInt(move _3) -> [0: bb3, 1: bb4, 2: bb1, otherwise: bb2];
22-
+ switchInt(move _3) -> [2: bb1, otherwise: bb2];
22+
+ switchInt(move _3) -> [0: bb9, 1: bb9, 2: bb1, otherwise: bb9];
2323
}
2424

2525
bb1: {
@@ -54,7 +54,8 @@
5454
StorageLive(_7);
5555
_7 = Test2::D;
5656
_8 = discriminant(_7);
57-
switchInt(move _8) -> [4: bb7, 5: bb6, otherwise: bb2];
57+
- switchInt(move _8) -> [4: bb7, 5: bb6, otherwise: bb2];
58+
+ switchInt(move _8) -> [4: bb7, 5: bb6, otherwise: bb9];
5859
}
5960

6061
bb6: {
@@ -75,6 +76,10 @@
7576
StorageDead(_6);
7677
_0 = const ();
7778
return;
79+
+ }
80+
+
81+
+ bb9: {
82+
+ unreachable;
7883
}
7984
}
8085

tests/mir-opt/uninhabited_enum_branching2.main.SimplifyCfg-after-uninhabited-enum-branching.after.mir

+12
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@ fn main() -> () {
1515
let _11: &str;
1616
let _12: &str;
1717
let _13: &str;
18+
let mut _14: bool;
19+
let mut _15: bool;
20+
let mut _16: bool;
21+
let mut _17: bool;
1822
scope 1 {
1923
debug plop => _1;
2024
}
@@ -29,6 +33,10 @@ fn main() -> () {
2933
StorageLive(_4);
3034
_4 = &(_1.1: Test1);
3135
_5 = discriminant((*_4));
36+
_16 = Ne(_5, const 0_isize);
37+
assume(move _16);
38+
_17 = Ne(_5, const 1_isize);
39+
assume(move _17);
3240
switchInt(move _5) -> [2: bb3, 3: bb1, otherwise: bb2];
3341
}
3442

@@ -57,6 +65,10 @@ fn main() -> () {
5765
StorageDead(_3);
5866
StorageLive(_9);
5967
_10 = discriminant((_1.1: Test1));
68+
_14 = Ne(_10, const 0_isize);
69+
assume(move _14);
70+
_15 = Ne(_10, const 1_isize);
71+
assume(move _15);
6072
switchInt(move _10) -> [2: bb6, 3: bb5, otherwise: bb2];
6173
}
6274

tests/mir-opt/uninhabited_enum_branching2.main.UninhabitedEnumBranching.diff

+6-2
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
_4 = &(_1.1: Test1);
3232
_5 = discriminant((*_4));
3333
- switchInt(move _5) -> [0: bb3, 1: bb4, 2: bb5, 3: bb1, otherwise: bb2];
34-
+ switchInt(move _5) -> [2: bb5, 3: bb1, otherwise: bb2];
34+
+ switchInt(move _5) -> [0: bb12, 1: bb12, 2: bb5, 3: bb1, otherwise: bb12];
3535
}
3636

3737
bb1: {
@@ -73,7 +73,7 @@
7373
StorageLive(_9);
7474
_10 = discriminant((_1.1: Test1));
7575
- switchInt(move _10) -> [0: bb8, 1: bb9, 2: bb10, 3: bb7, otherwise: bb2];
76-
+ switchInt(move _10) -> [2: bb10, 3: bb7, otherwise: bb2];
76+
+ switchInt(move _10) -> [0: bb12, 1: bb12, 2: bb10, 3: bb7, otherwise: bb12];
7777
}
7878

7979
bb7: {
@@ -110,6 +110,10 @@
110110
_0 = const ();
111111
StorageDead(_1);
112112
return;
113+
+ }
114+
+
115+
+ bb12: {
116+
+ unreachable;
113117
}
114118
}
115119

tests/mir-opt/uninhabited_fallthrough_elimination.keep_fallthrough.UninhabitedEnumBranching.diff

+5-1
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
bb0: {
1010
_2 = discriminant(_1);
1111
- switchInt(move _2) -> [0: bb2, 1: bb3, otherwise: bb1];
12-
+ switchInt(move _2) -> [1: bb3, otherwise: bb1];
12+
+ switchInt(move _2) -> [0: bb5, 1: bb3, otherwise: bb1];
1313
}
1414

1515
bb1: {
@@ -29,6 +29,10 @@
2929

3030
bb4: {
3131
return;
32+
+ }
33+
+
34+
+ bb5: {
35+
+ unreachable;
3236
}
3337
}
3438

0 commit comments

Comments
 (0)