Skip to content

Commit 5c84d77

Browse files
authored
Rollup merge of #60745 - wesleywiser:const_prop_into_terminators, r=oli-obk
Perform constant propagation into terminators Perform constant propagation into MIR `Assert` and `SwitchInt` `Terminator`s which in some cases allows them to be removed by the branch simplification pass. r? @oli-obk
2 parents f9d65c0 + ec853ba commit 5c84d77

File tree

5 files changed

+152
-73
lines changed

5 files changed

+152
-73
lines changed

src/librustc_mir/transform/const_prop.rs

+108-67
Original file line numberDiff line numberDiff line change
@@ -546,6 +546,10 @@ impl<'a, 'mir, 'tcx> ConstPropagator<'a, 'mir, 'tcx> {
546546
}
547547
}
548548
}
549+
550+
fn should_const_prop(&self) -> bool {
551+
self.tcx.sess.opts.debugging_opts.mir_opt_level >= 2
552+
}
549553
}
550554

551555
fn type_size_of<'a, 'tcx>(tcx: TyCtxt<'a, 'tcx, 'tcx>,
@@ -639,7 +643,7 @@ impl<'b, 'a, 'tcx> MutVisitor<'tcx> for ConstPropagator<'b, 'a, 'tcx> {
639643
assert!(self.places[local].is_none());
640644
self.places[local] = Some(value);
641645

642-
if self.tcx.sess.opts.debugging_opts.mir_opt_level >= 2 {
646+
if self.should_const_prop() {
643647
self.replace_with_const(rval, value, statement.source_info.span);
644648
}
645649
}
@@ -656,75 +660,112 @@ impl<'b, 'a, 'tcx> MutVisitor<'tcx> for ConstPropagator<'b, 'a, 'tcx> {
656660
location: Location,
657661
) {
658662
self.super_terminator(terminator, location);
659-
let source_info = terminator.source_info;;
660-
if let TerminatorKind::Assert { expected, msg, cond, .. } = &terminator.kind {
661-
if let Some(value) = self.eval_operand(&cond, source_info) {
662-
trace!("assertion on {:?} should be {:?}", value, expected);
663-
let expected = ScalarMaybeUndef::from(Scalar::from_bool(*expected));
664-
if expected != self.ecx.read_scalar(value).unwrap() {
665-
// poison all places this operand references so that further code
666-
// doesn't use the invalid value
667-
match cond {
668-
Operand::Move(ref place) | Operand::Copy(ref place) => {
669-
let mut place = place;
670-
while let Place::Projection(ref proj) = *place {
671-
place = &proj.base;
672-
}
673-
if let Place::Base(PlaceBase::Local(local)) = *place {
674-
self.places[local] = None;
663+
let source_info = terminator.source_info;
664+
match &mut terminator.kind {
665+
TerminatorKind::Assert { expected, msg, ref mut cond, .. } => {
666+
if let Some(value) = self.eval_operand(&cond, source_info) {
667+
trace!("assertion on {:?} should be {:?}", value, expected);
668+
let expected = ScalarMaybeUndef::from(Scalar::from_bool(*expected));
669+
let value_const = self.ecx.read_scalar(value).unwrap();
670+
if expected != value_const {
671+
// poison all places this operand references so that further code
672+
// doesn't use the invalid value
673+
match cond {
674+
Operand::Move(ref place) | Operand::Copy(ref place) => {
675+
let mut place = place;
676+
while let Place::Projection(ref proj) = *place {
677+
place = &proj.base;
678+
}
679+
if let Place::Base(PlaceBase::Local(local)) = *place {
680+
self.places[local] = None;
681+
}
682+
},
683+
Operand::Constant(_) => {}
684+
}
685+
let span = terminator.source_info.span;
686+
let hir_id = self
687+
.tcx
688+
.hir()
689+
.as_local_hir_id(self.source.def_id())
690+
.expect("some part of a failing const eval must be local");
691+
use rustc::mir::interpret::InterpError::*;
692+
let msg = match msg {
693+
Overflow(_) |
694+
OverflowNeg |
695+
DivisionByZero |
696+
RemainderByZero => msg.description().to_owned(),
697+
BoundsCheck { ref len, ref index } => {
698+
let len = self
699+
.eval_operand(len, source_info)
700+
.expect("len must be const");
701+
let len = match self.ecx.read_scalar(len) {
702+
Ok(ScalarMaybeUndef::Scalar(Scalar::Bits {
703+
bits, ..
704+
})) => bits,
705+
other => bug!("const len not primitive: {:?}", other),
706+
};
707+
let index = self
708+
.eval_operand(index, source_info)
709+
.expect("index must be const");
710+
let index = match self.ecx.read_scalar(index) {
711+
Ok(ScalarMaybeUndef::Scalar(Scalar::Bits {
712+
bits, ..
713+
})) => bits,
714+
other => bug!("const index not primitive: {:?}", other),
715+
};
716+
format!(
717+
"index out of bounds: \
718+
the len is {} but the index is {}",
719+
len,
720+
index,
721+
)
722+
},
723+
// Need proper const propagator for these
724+
_ => return,
725+
};
726+
self.tcx.lint_hir(
727+
::rustc::lint::builtin::CONST_ERR,
728+
hir_id,
729+
span,
730+
&msg,
731+
);
732+
} else {
733+
if self.should_const_prop() {
734+
if let ScalarMaybeUndef::Scalar(scalar) = value_const {
735+
*cond = self.operand_from_scalar(
736+
scalar,
737+
self.tcx.types.bool,
738+
source_info.span,
739+
);
675740
}
676-
},
677-
Operand::Constant(_) => {}
741+
}
678742
}
679-
let span = terminator.source_info.span;
680-
let hir_id = self
681-
.tcx
682-
.hir()
683-
.as_local_hir_id(self.source.def_id())
684-
.expect("some part of a failing const eval must be local");
685-
use rustc::mir::interpret::InterpError::*;
686-
let msg = match msg {
687-
Overflow(_) |
688-
OverflowNeg |
689-
DivisionByZero |
690-
RemainderByZero => msg.description().to_owned(),
691-
BoundsCheck { ref len, ref index } => {
692-
let len = self
693-
.eval_operand(len, source_info)
694-
.expect("len must be const");
695-
let len = match self.ecx.read_scalar(len) {
696-
Ok(ScalarMaybeUndef::Scalar(Scalar::Bits {
697-
bits, ..
698-
})) => bits,
699-
other => bug!("const len not primitive: {:?}", other),
700-
};
701-
let index = self
702-
.eval_operand(index, source_info)
703-
.expect("index must be const");
704-
let index = match self.ecx.read_scalar(index) {
705-
Ok(ScalarMaybeUndef::Scalar(Scalar::Bits {
706-
bits, ..
707-
})) => bits,
708-
other => bug!("const index not primitive: {:?}", other),
709-
};
710-
format!(
711-
"index out of bounds: \
712-
the len is {} but the index is {}",
713-
len,
714-
index,
715-
)
716-
},
717-
// Need proper const propagator for these
718-
_ => return,
719-
};
720-
self.tcx.lint_hir(
721-
::rustc::lint::builtin::CONST_ERR,
722-
hir_id,
723-
span,
724-
&msg,
725-
);
726743
}
727-
}
744+
},
745+
TerminatorKind::SwitchInt { ref mut discr, switch_ty, .. } => {
746+
if self.should_const_prop() {
747+
if let Some(value) = self.eval_operand(&discr, source_info) {
748+
if let ScalarMaybeUndef::Scalar(scalar) =
749+
self.ecx.read_scalar(value).unwrap() {
750+
*discr = self.operand_from_scalar(scalar, switch_ty, source_info.span);
751+
}
752+
}
753+
}
754+
},
755+
//none of these have Operands to const-propagate
756+
TerminatorKind::Goto { .. } |
757+
TerminatorKind::Resume |
758+
TerminatorKind::Abort |
759+
TerminatorKind::Return |
760+
TerminatorKind::Unreachable |
761+
TerminatorKind::Drop { .. } |
762+
TerminatorKind::DropAndReplace { .. } |
763+
TerminatorKind::Yield { .. } |
764+
TerminatorKind::GeneratorDrop |
765+
TerminatorKind::FalseEdges { .. } |
766+
TerminatorKind::FalseUnwind { .. } => { }
767+
//FIXME(wesleywiser) Call does have Operands that could be const-propagated
768+
TerminatorKind::Call { .. } => { }
728769
}
729770
}
730771
}

src/test/mir-opt/const_prop/array_index.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ fn main() {
2323
// bb0: {
2424
// ...
2525
// _5 = const true;
26-
// assert(move _5, "index out of bounds: the len is move _4 but the index is _3") -> bb1;
26+
// assert(const true, "index out of bounds: the len is move _4 but the index is _3") -> bb1;
2727
// }
2828
// bb1: {
2929
// _1 = _2[_3];

src/test/mir-opt/const_prop/checked_add.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,6 @@ fn main() {
1616
// bb0: {
1717
// ...
1818
// _2 = (const 2u32, const false);
19-
// assert(!move (_2.1: bool), "attempt to add with overflow") -> bb1;
19+
// assert(!const false, "attempt to add with overflow") -> bb1;
2020
// }
2121
// END rustc.main.ConstProp.after.mir
+38
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
#[inline(never)]
2+
fn foo(_: i32) { }
3+
4+
fn main() {
5+
match 1 {
6+
1 => foo(0),
7+
_ => foo(-1),
8+
}
9+
}
10+
11+
// END RUST SOURCE
12+
// START rustc.main.ConstProp.before.mir
13+
// bb0: {
14+
// ...
15+
// _1 = const 1i32;
16+
// switchInt(_1) -> [1i32: bb1, otherwise: bb2];
17+
// }
18+
// END rustc.main.ConstProp.before.mir
19+
// START rustc.main.ConstProp.after.mir
20+
// bb0: {
21+
// ...
22+
// switchInt(const 1i32) -> [1i32: bb1, otherwise: bb2];
23+
// }
24+
// END rustc.main.ConstProp.after.mir
25+
// START rustc.main.SimplifyBranches-after-const-prop.before.mir
26+
// bb0: {
27+
// ...
28+
// _1 = const 1i32;
29+
// switchInt(const 1i32) -> [1i32: bb1, otherwise: bb2];
30+
// }
31+
// END rustc.main.SimplifyBranches-after-const-prop.before.mir
32+
// START rustc.main.SimplifyBranches-after-const-prop.after.mir
33+
// bb0: {
34+
// ...
35+
// _1 = const 1i32;
36+
// goto -> bb1;
37+
// }
38+
// END rustc.main.SimplifyBranches-after-const-prop.after.mir

src/test/mir-opt/simplify_if.rs

+4-4
Original file line numberDiff line numberDiff line change
@@ -5,15 +5,15 @@ fn main() {
55
}
66

77
// END RUST SOURCE
8-
// START rustc.main.SimplifyBranches-after-copy-prop.before.mir
8+
// START rustc.main.SimplifyBranches-after-const-prop.before.mir
99
// bb0: {
1010
// ...
1111
// switchInt(const false) -> [false: bb3, otherwise: bb1];
1212
// }
13-
// END rustc.main.SimplifyBranches-after-copy-prop.before.mir
14-
// START rustc.main.SimplifyBranches-after-copy-prop.after.mir
13+
// END rustc.main.SimplifyBranches-after-const-prop.before.mir
14+
// START rustc.main.SimplifyBranches-after-const-prop.after.mir
1515
// bb0: {
1616
// ...
1717
// goto -> bb3;
1818
// }
19-
// END rustc.main.SimplifyBranches-after-copy-prop.after.mir
19+
// END rustc.main.SimplifyBranches-after-const-prop.after.mir

0 commit comments

Comments
 (0)