Skip to content

Commit 9f35fe4

Browse files
committed
JumpThreading: Re-enable and fix Not ops on non-booleans
1 parent 092a284 commit 9f35fe4

6 files changed

+120
-27
lines changed

compiler/rustc_mir_transform/src/jump_threading.rs

+11-18
Original file line numberDiff line numberDiff line change
@@ -150,14 +150,6 @@ impl Condition {
150150
fn matches(&self, value: ScalarInt) -> bool {
151151
(self.value == value) == (self.polarity == Polarity::Eq)
152152
}
153-
154-
fn inv(mut self) -> Self {
155-
self.polarity = match self.polarity {
156-
Polarity::Eq => Polarity::Ne,
157-
Polarity::Ne => Polarity::Eq,
158-
};
159-
self
160-
}
161153
}
162154

163155
#[derive(Copy, Clone, Debug)]
@@ -495,19 +487,20 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
495487
}
496488
}
497489
}
498-
// Transfer the conditions on the copy rhs, after inversing polarity.
490+
// Transfer the conditions on the copy rhs, after inverting the value of the condition.
499491
Rvalue::UnaryOp(UnOp::Not, Operand::Move(place) | Operand::Copy(place)) => {
500-
if !place.ty(self.body, self.tcx).ty.is_bool() {
501-
// Constructing the conditions by inverting the polarity
502-
// of equality is only correct for bools. That is to say,
503-
// `!a == b` is not `a != b` for integers greater than 1 bit.
504-
return;
505-
}
492+
let layout = self.ecx.layout_of(place.ty(self.body, self.tcx).ty).unwrap();
506493
let Some(conditions) = state.try_get_idx(lhs, &self.map) else { return };
507494
let Some(place) = self.map.find(place.as_ref()) else { return };
508-
// FIXME: I think This could be generalized to not bool if we
509-
// actually perform a logical not on the condition's value.
510-
let conds = conditions.map(self.arena, Condition::inv);
495+
let conds = conditions.map(self.arena, |mut cond| {
496+
cond.value = self
497+
.ecx
498+
.unary_op(UnOp::Not, &ImmTy::from_scalar_int(cond.value, layout))
499+
.unwrap()
500+
.to_scalar_int()
501+
.unwrap();
502+
cond
503+
});
511504
state.insert_value_idx(place, conds, &self.map);
512505
}
513506
// We expect `lhs ?= A`. We found `lhs = Eq(rhs, B)`.

tests/mir-opt/jump_threading.bitwise_not.JumpThreading.panic-abort.diff

+3-3
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
fn bitwise_not() -> i32 {
55
let mut _0: i32;
6-
let mut _1: i32;
6+
let _1: i32;
77
let mut _2: bool;
88
let mut _3: i32;
99
let mut _4: i32;
@@ -13,7 +13,6 @@
1313

1414
bb0: {
1515
StorageLive(_1);
16-
_1 = const 0_i32;
1716
_1 = const 1_i32;
1817
StorageLive(_2);
1918
StorageLive(_3);
@@ -22,7 +21,8 @@
2221
_3 = Not(move _4);
2322
StorageDead(_4);
2423
_2 = Eq(move _3, const 0_i32);
25-
switchInt(move _2) -> [0: bb2, otherwise: bb1];
24+
- switchInt(move _2) -> [0: bb2, otherwise: bb1];
25+
+ goto -> bb2;
2626
}
2727

2828
bb1: {

tests/mir-opt/jump_threading.bitwise_not.JumpThreading.panic-unwind.diff

+3-3
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33

44
fn bitwise_not() -> i32 {
55
let mut _0: i32;
6-
let mut _1: i32;
6+
let _1: i32;
77
let mut _2: bool;
88
let mut _3: i32;
99
let mut _4: i32;
@@ -13,7 +13,6 @@
1313

1414
bb0: {
1515
StorageLive(_1);
16-
_1 = const 0_i32;
1716
_1 = const 1_i32;
1817
StorageLive(_2);
1918
StorageLive(_3);
@@ -22,7 +21,8 @@
2221
_3 = Not(move _4);
2322
StorageDead(_4);
2423
_2 = Eq(move _3, const 0_i32);
25-
switchInt(move _2) -> [0: bb2, otherwise: bb1];
24+
- switchInt(move _2) -> [0: bb2, otherwise: bb1];
25+
+ goto -> bb2;
2626
}
2727

2828
bb1: {
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
- // MIR for `logical_not` before JumpThreading
2+
+ // MIR for `logical_not` after JumpThreading
3+
4+
fn logical_not() -> i32 {
5+
let mut _0: i32;
6+
let _1: bool;
7+
let mut _2: bool;
8+
let mut _3: bool;
9+
let mut _4: bool;
10+
scope 1 {
11+
debug a => _1;
12+
}
13+
14+
bb0: {
15+
StorageLive(_1);
16+
_1 = const false;
17+
StorageLive(_2);
18+
StorageLive(_3);
19+
StorageLive(_4);
20+
_4 = copy _1;
21+
_3 = Not(move _4);
22+
StorageDead(_4);
23+
_2 = Eq(move _3, const true);
24+
- switchInt(move _2) -> [0: bb2, otherwise: bb1];
25+
+ goto -> bb1;
26+
}
27+
28+
bb1: {
29+
StorageDead(_3);
30+
_0 = const 1_i32;
31+
goto -> bb3;
32+
}
33+
34+
bb2: {
35+
StorageDead(_3);
36+
_0 = const 0_i32;
37+
goto -> bb3;
38+
}
39+
40+
bb3: {
41+
StorageDead(_2);
42+
StorageDead(_1);
43+
return;
44+
}
45+
}
46+
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
- // MIR for `logical_not` before JumpThreading
2+
+ // MIR for `logical_not` after JumpThreading
3+
4+
fn logical_not() -> i32 {
5+
let mut _0: i32;
6+
let _1: bool;
7+
let mut _2: bool;
8+
let mut _3: bool;
9+
let mut _4: bool;
10+
scope 1 {
11+
debug a => _1;
12+
}
13+
14+
bb0: {
15+
StorageLive(_1);
16+
_1 = const false;
17+
StorageLive(_2);
18+
StorageLive(_3);
19+
StorageLive(_4);
20+
_4 = copy _1;
21+
_3 = Not(move _4);
22+
StorageDead(_4);
23+
_2 = Eq(move _3, const true);
24+
- switchInt(move _2) -> [0: bb2, otherwise: bb1];
25+
+ goto -> bb1;
26+
}
27+
28+
bb1: {
29+
StorageDead(_3);
30+
_0 = const 1_i32;
31+
goto -> bb3;
32+
}
33+
34+
bb2: {
35+
StorageDead(_3);
36+
_0 = const 0_i32;
37+
goto -> bb3;
38+
}
39+
40+
bb3: {
41+
StorageDead(_2);
42+
StorageDead(_1);
43+
return;
44+
}
45+
}
46+

tests/mir-opt/jump_threading.rs

+11-3
Original file line numberDiff line numberDiff line change
@@ -532,14 +532,19 @@ fn floats() -> u32 {
532532

533533
pub fn bitwise_not() -> i32 {
534534
// CHECK-LABEL: fn bitwise_not(
535-
// CHECK: switchInt(
536535

537536
// Test for #131195, which was optimizing `!a == b` into `a != b`.
538-
let mut a: i32 = 0;
539-
a = 1;
537+
let a = 1;
540538
if !a == 0 { 1 } else { 0 }
541539
}
542540

541+
pub fn logical_not() -> i32 {
542+
// CHECK-LABEL: fn logical_not(
543+
544+
let a = false;
545+
if !a == true { 1 } else { 0 }
546+
}
547+
543548
fn main() {
544549
// CHECK-LABEL: fn main(
545550
too_complex(Ok(0));
@@ -555,6 +560,8 @@ fn main() {
555560
aggregate(7);
556561
assume(7, false);
557562
floats();
563+
bitwise_not();
564+
logical_not();
558565
}
559566

560567
// EMIT_MIR jump_threading.too_complex.JumpThreading.diff
@@ -572,3 +579,4 @@ fn main() {
572579
// EMIT_MIR jump_threading.aggregate_copy.JumpThreading.diff
573580
// EMIT_MIR jump_threading.floats.JumpThreading.diff
574581
// EMIT_MIR jump_threading.bitwise_not.JumpThreading.diff
582+
// EMIT_MIR jump_threading.logical_not.JumpThreading.diff

0 commit comments

Comments
 (0)