Skip to content

Commit b5bd98d

Browse files
committed
Update MIR with MirPatch in UninhabitedEnumBranching
1 parent 3d7f8b4 commit b5bd98d

15 files changed

+155
-165
lines changed

Diff for: compiler/rustc_middle/src/mir/patch.rs

+25-2
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,8 @@ pub struct MirPatch<'tcx> {
1111
resume_block: Option<BasicBlock>,
1212
// Only for unreachable in cleanup path.
1313
unreachable_cleanup_block: Option<BasicBlock>,
14+
// Only for unreachable not in cleanup path.
15+
unreachable_no_cleanup_block: Option<BasicBlock>,
1416
// Cached block for UnwindTerminate (with reason)
1517
terminate_block: Option<(BasicBlock, UnwindTerminateReason)>,
1618
body_span: Span,
@@ -27,6 +29,7 @@ impl<'tcx> MirPatch<'tcx> {
2729
next_local: body.local_decls.len(),
2830
resume_block: None,
2931
unreachable_cleanup_block: None,
32+
unreachable_no_cleanup_block: None,
3033
terminate_block: None,
3134
body_span: body.span,
3235
};
@@ -43,9 +46,12 @@ impl<'tcx> MirPatch<'tcx> {
4346
// Check if we already have an unreachable block
4447
if matches!(block.terminator().kind, TerminatorKind::Unreachable)
4548
&& block.statements.is_empty()
46-
&& block.is_cleanup
4749
{
48-
result.unreachable_cleanup_block = Some(bb);
50+
if block.is_cleanup {
51+
result.unreachable_cleanup_block = Some(bb);
52+
} else {
53+
result.unreachable_no_cleanup_block = Some(bb);
54+
}
4955
continue;
5056
}
5157

@@ -95,6 +101,23 @@ impl<'tcx> MirPatch<'tcx> {
95101
bb
96102
}
97103

104+
pub fn unreachable_no_cleanup_block(&mut self) -> BasicBlock {
105+
if let Some(bb) = self.unreachable_no_cleanup_block {
106+
return bb;
107+
}
108+
109+
let bb = self.new_block(BasicBlockData {
110+
statements: vec![],
111+
terminator: Some(Terminator {
112+
source_info: SourceInfo::outermost(self.body_span),
113+
kind: TerminatorKind::Unreachable,
114+
}),
115+
is_cleanup: false,
116+
});
117+
self.unreachable_no_cleanup_block = Some(bb);
118+
bb
119+
}
120+
98121
pub fn terminate_block(&mut self, reason: UnwindTerminateReason) -> BasicBlock {
99122
if let Some((cached_bb, cached_reason)) = self.terminate_block
100123
&& reason == cached_reason

Diff for: compiler/rustc_mir_transform/src/uninhabited_enum_branching.rs

+30-37
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,9 @@
22
33
use crate::MirPass;
44
use rustc_data_structures::fx::FxHashSet;
5+
use rustc_middle::mir::patch::MirPatch;
56
use rustc_middle::mir::{
6-
BasicBlockData, Body, Local, Operand, Rvalue, StatementKind, Terminator, TerminatorKind,
7+
BasicBlockData, Body, Local, Operand, Rvalue, StatementKind, TerminatorKind,
78
};
89
use rustc_middle::ty::layout::TyAndLayout;
910
use rustc_middle::ty::{Ty, TyCtxt};
@@ -77,8 +78,8 @@ impl<'tcx> MirPass<'tcx> for UninhabitedEnumBranching {
7778
fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
7879
trace!("UninhabitedEnumBranching starting for {:?}", body.source);
7980

80-
let mut removable_switchs = Vec::new();
81-
let mut otherwise_is_last_variant_switchs = Vec::new();
81+
let mut unreachable_targets = Vec::new();
82+
let mut patch = MirPatch::new(body);
8283

8384
for (bb, bb_data) in body.basic_blocks.iter_enumerated() {
8485
trace!("processing block {:?}", bb);
@@ -107,49 +108,41 @@ impl<'tcx> MirPass<'tcx> for UninhabitedEnumBranching {
107108

108109
trace!("allowed_variants = {:?}", allowed_variants);
109110

110-
let terminator = bb_data.terminator();
111-
let TerminatorKind::SwitchInt { targets, .. } = &terminator.kind else { bug!() };
111+
unreachable_targets.clear();
112+
let TerminatorKind::SwitchInt { targets, discr } = &bb_data.terminator().kind else {
113+
bug!()
114+
};
112115

113116
for (index, (val, _)) in targets.iter().enumerate() {
114117
if !allowed_variants.remove(&val) {
115-
removable_switchs.push((bb, index));
118+
unreachable_targets.push(index);
116119
}
117120
}
118121

119-
if allowed_variants.is_empty() {
120-
removable_switchs.push((bb, targets.iter().count()));
121-
} else if allowed_variants.len() == 1
122-
&& !body.basic_blocks[targets.otherwise()].is_empty_unreachable()
123-
{
124-
#[allow(rustc::potential_query_instability)]
125-
let last_variant = *allowed_variants.iter().next().unwrap();
126-
otherwise_is_last_variant_switchs.push((bb, last_variant));
127-
}
128-
}
122+
let replace_otherwise_to_unreachable = allowed_variants.len() <= 1
123+
&& !body.basic_blocks[targets.otherwise()].is_empty_unreachable();
129124

130-
for (bb, last_variant) in otherwise_is_last_variant_switchs {
131-
let bb_data = &mut body.basic_blocks.as_mut()[bb];
132-
let terminator = bb_data.terminator_mut();
133-
let TerminatorKind::SwitchInt { targets, .. } = &mut terminator.kind else { bug!() };
134-
targets.add_target(last_variant, targets.otherwise());
135-
removable_switchs.push((bb, targets.iter().count()));
136-
}
125+
if unreachable_targets.is_empty() && !replace_otherwise_to_unreachable {
126+
continue;
127+
}
137128

138-
if removable_switchs.is_empty() {
139-
return;
129+
let unreachable_block = patch.unreachable_no_cleanup_block();
130+
let mut targets = targets.clone();
131+
if replace_otherwise_to_unreachable {
132+
let otherwise_is_last_variant = !allowed_variants.is_empty();
133+
if otherwise_is_last_variant {
134+
#[allow(rustc::potential_query_instability)]
135+
let last_variant = *allowed_variants.iter().next().unwrap();
136+
targets.add_target(last_variant, targets.otherwise());
137+
}
138+
unreachable_targets.push(targets.iter().count());
139+
}
140+
for index in unreachable_targets.iter() {
141+
targets.all_targets_mut()[*index] = unreachable_block;
142+
}
143+
patch.patch_terminator(bb, TerminatorKind::SwitchInt { targets, discr: discr.clone() });
140144
}
141145

142-
let new_block = BasicBlockData::new(Some(Terminator {
143-
source_info: body.basic_blocks[removable_switchs[0].0].terminator().source_info,
144-
kind: TerminatorKind::Unreachable,
145-
}));
146-
let unreachable_block = body.basic_blocks.as_mut().push(new_block);
147-
148-
for (bb, index) in removable_switchs {
149-
let bb = &mut body.basic_blocks.as_mut()[bb];
150-
let terminator = bb.terminator_mut();
151-
let TerminatorKind::SwitchInt { targets, .. } = &mut terminator.kind else { bug!() };
152-
targets.all_targets_mut()[index] = unreachable_block;
153-
}
146+
patch.apply(body);
154147
}
155148
}

Diff for: tests/mir-opt/pre-codegen/issue_117368_print_invalid_constant.main.GVN.32bit.panic-abort.diff

+13-13
Original file line numberDiff line numberDiff line change
@@ -58,16 +58,20 @@
5858
+ _2 = const Option::<Layout>::None;
5959
StorageLive(_10);
6060
- _10 = discriminant(_2);
61-
- switchInt(move _10) -> [0: bb1, 1: bb2, otherwise: bb6];
61+
- switchInt(move _10) -> [0: bb2, 1: bb3, otherwise: bb1];
6262
+ _10 = const 0_isize;
63-
+ switchInt(const 0_isize) -> [0: bb1, 1: bb2, otherwise: bb6];
63+
+ switchInt(const 0_isize) -> [0: bb2, 1: bb3, otherwise: bb1];
6464
}
6565

6666
bb1: {
67-
_11 = option::unwrap_failed() -> unwind unreachable;
67+
unreachable;
6868
}
6969

7070
bb2: {
71+
_11 = option::unwrap_failed() -> unwind unreachable;
72+
}
73+
74+
bb3: {
7175
- _1 = move ((_2 as Some).0: std::alloc::Layout);
7276
+ _1 = const Layout {{ size: Indirect { alloc_id: ALLOC0, offset: Size(4 bytes) }: usize, align: std::ptr::Alignment(Scalar(0x00000000): std::ptr::alignment::AlignmentEnum) }};
7377
StorageDead(_10);
@@ -82,21 +86,21 @@
8286
+ _7 = const {ALLOC1<imm>: &std::alloc::Global};
8387
StorageLive(_8);
8488
- _8 = _1;
85-
- _6 = std::alloc::Global::alloc_impl(move _7, move _8, const false) -> [return: bb3, unwind unreachable];
89+
- _6 = std::alloc::Global::alloc_impl(move _7, move _8, const false) -> [return: bb4, unwind unreachable];
8690
+ _8 = const Layout {{ size: Indirect { alloc_id: ALLOC0, offset: Size(4 bytes) }: usize, align: std::ptr::Alignment(Scalar(0x00000000): std::ptr::alignment::AlignmentEnum) }};
87-
+ _6 = std::alloc::Global::alloc_impl(const {ALLOC1<imm>: &std::alloc::Global}, const Layout {{ size: Indirect { alloc_id: ALLOC0, offset: Size(4 bytes) }: usize, align: std::ptr::Alignment(Scalar(0x00000000): std::ptr::alignment::AlignmentEnum) }}, const false) -> [return: bb3, unwind unreachable];
91+
+ _6 = std::alloc::Global::alloc_impl(const {ALLOC1<imm>: &std::alloc::Global}, const Layout {{ size: Indirect { alloc_id: ALLOC0, offset: Size(4 bytes) }: usize, align: std::ptr::Alignment(Scalar(0x00000000): std::ptr::alignment::AlignmentEnum) }}, const false) -> [return: bb4, unwind unreachable];
8892
}
8993

90-
bb3: {
94+
bb4: {
9195
StorageDead(_8);
9296
StorageDead(_7);
9397
StorageLive(_12);
9498
StorageLive(_15);
9599
_12 = discriminant(_6);
96-
switchInt(move _12) -> [0: bb5, 1: bb4, otherwise: bb6];
100+
switchInt(move _12) -> [0: bb6, 1: bb5, otherwise: bb1];
97101
}
98102

99-
bb4: {
103+
bb5: {
100104
_15 = const "called `Result::unwrap()` on an `Err` value";
101105
StorageLive(_16);
102106
StorageLive(_17);
@@ -106,7 +110,7 @@
106110
_14 = result::unwrap_failed(move _15, move _16) -> unwind unreachable;
107111
}
108112

109-
bb5: {
113+
bb6: {
110114
_5 = move ((_6 as Ok).0: std::ptr::NonNull<[u8]>);
111115
StorageDead(_15);
112116
StorageDead(_12);
@@ -127,10 +131,6 @@
127131
+ nop;
128132
return;
129133
}
130-
131-
bb6: {
132-
unreachable;
133-
}
134134
}
135135
+
136136
+ ALLOC0 (size: 8, align: 4) {

Diff for: tests/mir-opt/pre-codegen/issue_117368_print_invalid_constant.main.GVN.32bit.panic-unwind.diff

+10-10
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,9 @@
4343
+ _2 = const Option::<Layout>::None;
4444
StorageLive(_10);
4545
- _10 = discriminant(_2);
46-
- switchInt(move _10) -> [0: bb2, 1: bb3, otherwise: bb5];
46+
- switchInt(move _10) -> [0: bb3, 1: bb4, otherwise: bb2];
4747
+ _10 = const 0_isize;
48-
+ switchInt(const 0_isize) -> [0: bb2, 1: bb3, otherwise: bb5];
48+
+ switchInt(const 0_isize) -> [0: bb3, 1: bb4, otherwise: bb2];
4949
}
5050

5151
bb1: {
@@ -68,10 +68,14 @@
6868
}
6969

7070
bb2: {
71-
_11 = option::unwrap_failed() -> unwind continue;
71+
unreachable;
7272
}
7373

7474
bb3: {
75+
_11 = option::unwrap_failed() -> unwind continue;
76+
}
77+
78+
bb4: {
7579
- _1 = move ((_2 as Some).0: std::alloc::Layout);
7680
+ _1 = const Layout {{ size: Indirect { alloc_id: ALLOC0, offset: Size(4 bytes) }: usize, align: std::ptr::Alignment(Scalar(0x00000000): std::ptr::alignment::AlignmentEnum) }};
7781
StorageDead(_10);
@@ -86,20 +90,16 @@
8690
+ _7 = const {ALLOC1<imm>: &std::alloc::Global};
8791
StorageLive(_8);
8892
- _8 = _1;
89-
- _6 = std::alloc::Global::alloc_impl(move _7, move _8, const false) -> [return: bb4, unwind continue];
93+
- _6 = std::alloc::Global::alloc_impl(move _7, move _8, const false) -> [return: bb5, unwind continue];
9094
+ _8 = const Layout {{ size: Indirect { alloc_id: ALLOC0, offset: Size(4 bytes) }: usize, align: std::ptr::Alignment(Scalar(0x00000000): std::ptr::alignment::AlignmentEnum) }};
91-
+ _6 = std::alloc::Global::alloc_impl(const {ALLOC1<imm>: &std::alloc::Global}, const Layout {{ size: Indirect { alloc_id: ALLOC0, offset: Size(4 bytes) }: usize, align: std::ptr::Alignment(Scalar(0x00000000): std::ptr::alignment::AlignmentEnum) }}, const false) -> [return: bb4, unwind continue];
95+
+ _6 = std::alloc::Global::alloc_impl(const {ALLOC1<imm>: &std::alloc::Global}, const Layout {{ size: Indirect { alloc_id: ALLOC0, offset: Size(4 bytes) }: usize, align: std::ptr::Alignment(Scalar(0x00000000): std::ptr::alignment::AlignmentEnum) }}, const false) -> [return: bb5, unwind continue];
9296
}
9397

94-
bb4: {
98+
bb5: {
9599
StorageDead(_8);
96100
StorageDead(_7);
97101
_5 = Result::<NonNull<[u8]>, std::alloc::AllocError>::unwrap(move _6) -> [return: bb1, unwind continue];
98102
}
99-
100-
bb5: {
101-
unreachable;
102-
}
103103
}
104104
+
105105
+ ALLOC0 (size: 8, align: 4) {

Diff for: tests/mir-opt/pre-codegen/issue_117368_print_invalid_constant.main.GVN.64bit.panic-abort.diff

+13-13
Original file line numberDiff line numberDiff line change
@@ -58,16 +58,20 @@
5858
+ _2 = const Option::<Layout>::None;
5959
StorageLive(_10);
6060
- _10 = discriminant(_2);
61-
- switchInt(move _10) -> [0: bb1, 1: bb2, otherwise: bb6];
61+
- switchInt(move _10) -> [0: bb2, 1: bb3, otherwise: bb1];
6262
+ _10 = const 0_isize;
63-
+ switchInt(const 0_isize) -> [0: bb1, 1: bb2, otherwise: bb6];
63+
+ switchInt(const 0_isize) -> [0: bb2, 1: bb3, otherwise: bb1];
6464
}
6565

6666
bb1: {
67-
_11 = option::unwrap_failed() -> unwind unreachable;
67+
unreachable;
6868
}
6969

7070
bb2: {
71+
_11 = option::unwrap_failed() -> unwind unreachable;
72+
}
73+
74+
bb3: {
7175
- _1 = move ((_2 as Some).0: std::alloc::Layout);
7276
+ _1 = const Layout {{ size: Indirect { alloc_id: ALLOC0, offset: Size(8 bytes) }: usize, align: std::ptr::Alignment(Scalar(0x0000000000000000): std::ptr::alignment::AlignmentEnum) }};
7377
StorageDead(_10);
@@ -82,21 +86,21 @@
8286
+ _7 = const {ALLOC1<imm>: &std::alloc::Global};
8387
StorageLive(_8);
8488
- _8 = _1;
85-
- _6 = std::alloc::Global::alloc_impl(move _7, move _8, const false) -> [return: bb3, unwind unreachable];
89+
- _6 = std::alloc::Global::alloc_impl(move _7, move _8, const false) -> [return: bb4, unwind unreachable];
8690
+ _8 = const Layout {{ size: Indirect { alloc_id: ALLOC0, offset: Size(8 bytes) }: usize, align: std::ptr::Alignment(Scalar(0x0000000000000000): std::ptr::alignment::AlignmentEnum) }};
87-
+ _6 = std::alloc::Global::alloc_impl(const {ALLOC1<imm>: &std::alloc::Global}, const Layout {{ size: Indirect { alloc_id: ALLOC0, offset: Size(8 bytes) }: usize, align: std::ptr::Alignment(Scalar(0x0000000000000000): std::ptr::alignment::AlignmentEnum) }}, const false) -> [return: bb3, unwind unreachable];
91+
+ _6 = std::alloc::Global::alloc_impl(const {ALLOC1<imm>: &std::alloc::Global}, const Layout {{ size: Indirect { alloc_id: ALLOC0, offset: Size(8 bytes) }: usize, align: std::ptr::Alignment(Scalar(0x0000000000000000): std::ptr::alignment::AlignmentEnum) }}, const false) -> [return: bb4, unwind unreachable];
8892
}
8993

90-
bb3: {
94+
bb4: {
9195
StorageDead(_8);
9296
StorageDead(_7);
9397
StorageLive(_12);
9498
StorageLive(_15);
9599
_12 = discriminant(_6);
96-
switchInt(move _12) -> [0: bb5, 1: bb4, otherwise: bb6];
100+
switchInt(move _12) -> [0: bb6, 1: bb5, otherwise: bb1];
97101
}
98102

99-
bb4: {
103+
bb5: {
100104
_15 = const "called `Result::unwrap()` on an `Err` value";
101105
StorageLive(_16);
102106
StorageLive(_17);
@@ -106,7 +110,7 @@
106110
_14 = result::unwrap_failed(move _15, move _16) -> unwind unreachable;
107111
}
108112

109-
bb5: {
113+
bb6: {
110114
_5 = move ((_6 as Ok).0: std::ptr::NonNull<[u8]>);
111115
StorageDead(_15);
112116
StorageDead(_12);
@@ -127,10 +131,6 @@
127131
+ nop;
128132
return;
129133
}
130-
131-
bb6: {
132-
unreachable;
133-
}
134134
}
135135
+
136136
+ ALLOC0 (size: 16, align: 8) {

Diff for: tests/mir-opt/pre-codegen/issue_117368_print_invalid_constant.main.GVN.64bit.panic-unwind.diff

+10-10
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,9 @@
4343
+ _2 = const Option::<Layout>::None;
4444
StorageLive(_10);
4545
- _10 = discriminant(_2);
46-
- switchInt(move _10) -> [0: bb2, 1: bb3, otherwise: bb5];
46+
- switchInt(move _10) -> [0: bb3, 1: bb4, otherwise: bb2];
4747
+ _10 = const 0_isize;
48-
+ switchInt(const 0_isize) -> [0: bb2, 1: bb3, otherwise: bb5];
48+
+ switchInt(const 0_isize) -> [0: bb3, 1: bb4, otherwise: bb2];
4949
}
5050

5151
bb1: {
@@ -68,10 +68,14 @@
6868
}
6969

7070
bb2: {
71-
_11 = option::unwrap_failed() -> unwind continue;
71+
unreachable;
7272
}
7373

7474
bb3: {
75+
_11 = option::unwrap_failed() -> unwind continue;
76+
}
77+
78+
bb4: {
7579
- _1 = move ((_2 as Some).0: std::alloc::Layout);
7680
+ _1 = const Layout {{ size: Indirect { alloc_id: ALLOC0, offset: Size(8 bytes) }: usize, align: std::ptr::Alignment(Scalar(0x0000000000000000): std::ptr::alignment::AlignmentEnum) }};
7781
StorageDead(_10);
@@ -86,20 +90,16 @@
8690
+ _7 = const {ALLOC1<imm>: &std::alloc::Global};
8791
StorageLive(_8);
8892
- _8 = _1;
89-
- _6 = std::alloc::Global::alloc_impl(move _7, move _8, const false) -> [return: bb4, unwind continue];
93+
- _6 = std::alloc::Global::alloc_impl(move _7, move _8, const false) -> [return: bb5, unwind continue];
9094
+ _8 = const Layout {{ size: Indirect { alloc_id: ALLOC0, offset: Size(8 bytes) }: usize, align: std::ptr::Alignment(Scalar(0x0000000000000000): std::ptr::alignment::AlignmentEnum) }};
91-
+ _6 = std::alloc::Global::alloc_impl(const {ALLOC1<imm>: &std::alloc::Global}, const Layout {{ size: Indirect { alloc_id: ALLOC0, offset: Size(8 bytes) }: usize, align: std::ptr::Alignment(Scalar(0x0000000000000000): std::ptr::alignment::AlignmentEnum) }}, const false) -> [return: bb4, unwind continue];
95+
+ _6 = std::alloc::Global::alloc_impl(const {ALLOC1<imm>: &std::alloc::Global}, const Layout {{ size: Indirect { alloc_id: ALLOC0, offset: Size(8 bytes) }: usize, align: std::ptr::Alignment(Scalar(0x0000000000000000): std::ptr::alignment::AlignmentEnum) }}, const false) -> [return: bb5, unwind continue];
9296
}
9397

94-
bb4: {
98+
bb5: {
9599
StorageDead(_8);
96100
StorageDead(_7);
97101
_5 = Result::<NonNull<[u8]>, std::alloc::AllocError>::unwrap(move _6) -> [return: bb1, unwind continue];
98102
}
99-
100-
bb5: {
101-
unreachable;
102-
}
103103
}
104104
+
105105
+ ALLOC0 (size: 16, align: 8) {

0 commit comments

Comments
 (0)