Skip to content

Commit a234e70

Browse files
committed
JumpThreading: Bail out on interp errors
1 parent 3ccfe76 commit a234e70

File tree

1 file changed

+89
-61
lines changed

1 file changed

+89
-61
lines changed

Diff for: compiler/rustc_mir_transform/src/jump_threading.rs

+89-61
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,11 @@ impl<'tcx> crate::MirPass<'tcx> for JumpThreading {
9292
};
9393

9494
for bb in body.basic_blocks.indices() {
95-
finder.start_from_switch(bb);
95+
let old_len = finder.opportunities.len();
96+
// If we have any const-eval errors discard any opportunities found
97+
if finder.start_from_switch(bb).is_none() {
98+
finder.opportunities.truncate(old_len);
99+
}
96100
}
97101

98102
let opportunities = finder.opportunities;
@@ -170,8 +174,21 @@ impl<'a> ConditionSet<'a> {
170174
self.iter().filter(move |c| c.matches(value))
171175
}
172176

173-
fn map(self, arena: &'a DroplessArena, f: impl Fn(Condition) -> Condition) -> ConditionSet<'a> {
174-
ConditionSet(arena.alloc_from_iter(self.iter().map(f)))
177+
fn map(
178+
self,
179+
arena: &'a DroplessArena,
180+
f: impl Fn(Condition) -> Option<Condition>,
181+
) -> Option<ConditionSet<'a>> {
182+
let mut all_ok = true;
183+
let set = arena.alloc_from_iter(self.iter().map_while(|c| {
184+
if let Some(c) = f(c) {
185+
Some(c)
186+
} else {
187+
all_ok = false;
188+
None
189+
}
190+
}));
191+
all_ok.then_some(ConditionSet(set))
175192
}
176193
}
177194

@@ -182,28 +199,28 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
182199

183200
/// Recursion entry point to find threading opportunities.
184201
#[instrument(level = "trace", skip(self))]
185-
fn start_from_switch(&mut self, bb: BasicBlock) {
202+
fn start_from_switch(&mut self, bb: BasicBlock) -> Option<()> {
186203
let bbdata = &self.body[bb];
187204
if bbdata.is_cleanup || self.loop_headers.contains(bb) {
188-
return;
205+
return Some(());
189206
}
190-
let Some((discr, targets)) = bbdata.terminator().kind.as_switch() else { return };
191-
let Some(discr) = discr.place() else { return };
207+
let Some((discr, targets)) = bbdata.terminator().kind.as_switch() else { return Some(()) };
208+
let Some(discr) = discr.place() else { return Some(()) };
192209
debug!(?discr, ?bb);
193210

194211
let discr_ty = discr.ty(self.body, self.tcx).ty;
195212
let Ok(discr_layout) = self.ecx.layout_of(discr_ty) else {
196-
return;
213+
return Some(());
197214
};
198215

199-
let Some(discr) = self.map.find(discr.as_ref()) else { return };
216+
let Some(discr) = self.map.find(discr.as_ref()) else { return Some(()) };
200217
debug!(?discr);
201218

202219
let cost = CostChecker::new(self.tcx, self.param_env, None, self.body);
203220
let mut state = State::new_reachable();
204221

205222
let conds = if let Some((value, then, else_)) = targets.as_static_if() {
206-
let Some(value) = ScalarInt::try_from_uint(value, discr_layout.size) else { return };
223+
let value = ScalarInt::try_from_uint(value, discr_layout.size)?;
207224
self.arena.alloc_from_iter([
208225
Condition { value, polarity: Polarity::Eq, target: then },
209226
Condition { value, polarity: Polarity::Ne, target: else_ },
@@ -217,7 +234,7 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
217234
let conds = ConditionSet(conds);
218235
state.insert_value_idx(discr, conds, &self.map);
219236

220-
self.find_opportunity(bb, state, cost, 0);
237+
self.find_opportunity(bb, state, cost, 0)
221238
}
222239

223240
/// Recursively walk statements backwards from this bb's terminator to find threading
@@ -229,27 +246,27 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
229246
mut state: State<ConditionSet<'a>>,
230247
mut cost: CostChecker<'_, 'tcx>,
231248
depth: usize,
232-
) {
249+
) -> Option<()> {
233250
// Do not thread through loop headers.
234251
if self.loop_headers.contains(bb) {
235-
return;
252+
return Some(());
236253
}
237254

238255
debug!(cost = ?cost.cost());
239256
for (statement_index, stmt) in
240257
self.body.basic_blocks[bb].statements.iter().enumerate().rev()
241258
{
242259
if self.is_empty(&state) {
243-
return;
260+
return Some(());
244261
}
245262

246263
cost.visit_statement(stmt, Location { block: bb, statement_index });
247264
if cost.cost() > MAX_COST {
248-
return;
265+
return Some(());
249266
}
250267

251268
// Attempt to turn the `current_condition` on `lhs` into a condition on another place.
252-
self.process_statement(bb, stmt, &mut state);
269+
self.process_statement(bb, stmt, &mut state)?;
253270

254271
// When a statement mutates a place, assignments to that place that happen
255272
// above the mutation cannot fulfill a condition.
@@ -261,7 +278,7 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
261278
}
262279

263280
if self.is_empty(&state) || depth >= MAX_BACKTRACK {
264-
return;
281+
return Some(());
265282
}
266283

267284
let last_non_rec = self.opportunities.len();
@@ -274,9 +291,9 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
274291
match term.kind {
275292
TerminatorKind::SwitchInt { ref discr, ref targets } => {
276293
self.process_switch_int(discr, targets, bb, &mut state);
277-
self.find_opportunity(pred, state, cost, depth + 1);
294+
self.find_opportunity(pred, state, cost, depth + 1)?;
278295
}
279-
_ => self.recurse_through_terminator(pred, || state, &cost, depth),
296+
_ => self.recurse_through_terminator(pred, || state, &cost, depth)?,
280297
}
281298
} else if let &[ref predecessors @ .., last_pred] = &predecessors[..] {
282299
for &pred in predecessors {
@@ -301,12 +318,13 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
301318
let first = &mut new_tos[0];
302319
*first = ThreadingOpportunity { chain: vec![bb], target: first.target };
303320
self.opportunities.truncate(last_non_rec + 1);
304-
return;
321+
return Some(());
305322
}
306323

307324
for op in self.opportunities[last_non_rec..].iter_mut() {
308325
op.chain.push(bb);
309326
}
327+
Some(())
310328
}
311329

312330
/// Extract the mutated place from a statement.
@@ -419,23 +437,24 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
419437
lhs: PlaceIndex,
420438
rhs: &Operand<'tcx>,
421439
state: &mut State<ConditionSet<'a>>,
422-
) {
440+
) -> Option<()> {
423441
match rhs {
424442
// If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`.
425443
Operand::Constant(constant) => {
426444
let Some(constant) =
427445
self.ecx.eval_mir_constant(&constant.const_, constant.span, None).discard_err()
428446
else {
429-
return;
447+
return Some(());
430448
};
431449
self.process_constant(bb, lhs, constant, state);
432450
}
433451
// Transfer the conditions on the copied rhs.
434452
Operand::Move(rhs) | Operand::Copy(rhs) => {
435-
let Some(rhs) = self.map.find(rhs.as_ref()) else { return };
453+
let Some(rhs) = self.map.find(rhs.as_ref()) else { return Some(()) };
436454
state.insert_place_idx(rhs, lhs, &self.map);
437455
}
438456
}
457+
Some(())
439458
}
440459

441460
#[instrument(level = "trace", skip(self))]
@@ -445,22 +464,26 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
445464
lhs_place: &Place<'tcx>,
446465
rhs: &Rvalue<'tcx>,
447466
state: &mut State<ConditionSet<'a>>,
448-
) {
449-
let Some(lhs) = self.map.find(lhs_place.as_ref()) else { return };
467+
) -> Option<()> {
468+
let Some(lhs) = self.map.find(lhs_place.as_ref()) else {
469+
return Some(());
470+
};
450471
match rhs {
451-
Rvalue::Use(operand) => self.process_operand(bb, lhs, operand, state),
472+
Rvalue::Use(operand) => self.process_operand(bb, lhs, operand, state)?,
452473
// Transfer the conditions on the copy rhs.
453-
Rvalue::CopyForDeref(rhs) => self.process_operand(bb, lhs, &Operand::Copy(*rhs), state),
474+
Rvalue::CopyForDeref(rhs) => {
475+
self.process_operand(bb, lhs, &Operand::Copy(*rhs), state)?
476+
}
454477
Rvalue::Discriminant(rhs) => {
455-
let Some(rhs) = self.map.find_discr(rhs.as_ref()) else { return };
478+
let Some(rhs) = self.map.find_discr(rhs.as_ref()) else { return Some(()) };
456479
state.insert_place_idx(rhs, lhs, &self.map);
457480
}
458481
// If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`.
459482
Rvalue::Aggregate(box ref kind, ref operands) => {
460483
let agg_ty = lhs_place.ty(self.body, self.tcx).ty;
461484
let lhs = match kind {
462485
// Do not support unions.
463-
AggregateKind::Adt(.., Some(_)) => return,
486+
AggregateKind::Adt(.., Some(_)) => return Some(()),
464487
AggregateKind::Adt(_, variant_index, ..) if agg_ty.is_enum() => {
465488
if let Some(discr_target) = self.map.apply(lhs, TrackElem::Discriminant)
466489
&& let Some(discr_value) = self
@@ -473,31 +496,31 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
473496
if let Some(idx) = self.map.apply(lhs, TrackElem::Variant(*variant_index)) {
474497
idx
475498
} else {
476-
return;
499+
return Some(());
477500
}
478501
}
479502
_ => lhs,
480503
};
481504
for (field_index, operand) in operands.iter_enumerated() {
482505
if let Some(field) = self.map.apply(lhs, TrackElem::Field(field_index)) {
483-
self.process_operand(bb, field, operand, state);
506+
self.process_operand(bb, field, operand, state)?;
484507
}
485508
}
486509
}
487510
// Transfer the conditions on the copy rhs, after inverting the value of the condition.
488511
Rvalue::UnaryOp(UnOp::Not, Operand::Move(place) | Operand::Copy(place)) => {
489512
let layout = self.ecx.layout_of(place.ty(self.body, self.tcx).ty).unwrap();
490-
let Some(conditions) = state.try_get_idx(lhs, &self.map) else { return };
491-
let Some(place) = self.map.find(place.as_ref()) else { return };
513+
let Some(conditions) = state.try_get_idx(lhs, &self.map) else { return Some(()) };
514+
let Some(place) = self.map.find(place.as_ref()) else { return Some(()) };
492515
let conds = conditions.map(self.arena, |mut cond| {
493516
cond.value = self
494517
.ecx
495518
.unary_op(UnOp::Not, &ImmTy::from_scalar_int(cond.value, layout))
496-
.unwrap()
519+
.discard_err()?
497520
.to_scalar_int()
498-
.unwrap();
499-
cond
500-
});
521+
.discard_err()?;
522+
Some(cond)
523+
})?;
501524
state.insert_value_idx(place, conds, &self.map);
502525
}
503526
// We expect `lhs ?= A`. We found `lhs = Eq(rhs, B)`.
@@ -507,33 +530,36 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
507530
box (Operand::Move(place) | Operand::Copy(place), Operand::Constant(value))
508531
| box (Operand::Constant(value), Operand::Move(place) | Operand::Copy(place)),
509532
) => {
510-
let Some(conditions) = state.try_get_idx(lhs, &self.map) else { return };
511-
let Some(place) = self.map.find(place.as_ref()) else { return };
533+
let Some(conditions) = state.try_get_idx(lhs, &self.map) else { return Some(()) };
534+
let Some(place) = self.map.find(place.as_ref()) else { return Some(()) };
512535
let equals = match op {
513536
BinOp::Eq => ScalarInt::TRUE,
514537
BinOp::Ne => ScalarInt::FALSE,
515-
_ => return,
538+
_ => return Some(()),
516539
};
517540
if value.const_.ty().is_floating_point() {
518541
// Floating point equality does not follow bit-patterns.
519542
// -0.0 and NaN both have special rules for equality,
520543
// and therefore we cannot use integer comparisons for them.
521544
// Avoid handling them, though this could be extended in the future.
522-
return;
545+
return Some(());
523546
}
524547
let Some(value) = value.const_.try_eval_scalar_int(self.tcx, self.param_env) else {
525-
return;
548+
return Some(());
526549
};
527-
let conds = conditions.map(self.arena, |c| Condition {
528-
value,
529-
polarity: if c.matches(equals) { Polarity::Eq } else { Polarity::Ne },
530-
..c
531-
});
550+
let conds = conditions.map(self.arena, |c| {
551+
Some(Condition {
552+
value,
553+
polarity: if c.matches(equals) { Polarity::Eq } else { Polarity::Ne },
554+
..c
555+
})
556+
})?;
532557
state.insert_value_idx(place, conds, &self.map);
533558
}
534559

535560
_ => {}
536561
}
562+
Some(())
537563
}
538564

539565
#[instrument(level = "trace", skip(self))]
@@ -542,7 +568,7 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
542568
bb: BasicBlock,
543569
stmt: &Statement<'tcx>,
544570
state: &mut State<ConditionSet<'a>>,
545-
) {
571+
) -> Option<()> {
546572
let register_opportunity = |c: Condition| {
547573
debug!(?bb, ?c.target, "register");
548574
self.opportunities.push(ThreadingOpportunity { chain: vec![bb], target: c.target })
@@ -555,13 +581,15 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
555581
// If we expect `discriminant(place) ?= A`,
556582
// we have an opportunity if `variant_index ?= A`.
557583
StatementKind::SetDiscriminant { box place, variant_index } => {
558-
let Some(discr_target) = self.map.find_discr(place.as_ref()) else { return };
584+
let Some(discr_target) = self.map.find_discr(place.as_ref()) else {
585+
return Some(());
586+
};
559587
let enum_ty = place.ty(self.body, self.tcx).ty;
560588
// `SetDiscriminant` may be a no-op if the assigned variant is the untagged variant
561589
// of a niche encoding. If we cannot ensure that we write to the discriminant, do
562590
// nothing.
563591
let Ok(enum_layout) = self.ecx.layout_of(enum_ty) else {
564-
return;
592+
return Some(());
565593
};
566594
let writes_discriminant = match enum_layout.variants {
567595
Variants::Single { index } => {
@@ -575,26 +603,26 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
575603
} => *variant_index != untagged_variant,
576604
};
577605
if writes_discriminant {
578-
let Some(discr) =
579-
self.ecx.discriminant_for_variant(enum_ty, *variant_index).discard_err()
580-
else {
581-
return;
582-
};
606+
let discr =
607+
self.ecx.discriminant_for_variant(enum_ty, *variant_index).discard_err()?;
583608
self.process_immediate(bb, discr_target, discr, state);
584609
}
585610
}
586611
// If we expect `lhs ?= true`, we have an opportunity if we assume `lhs == true`.
587612
StatementKind::Intrinsic(box NonDivergingIntrinsic::Assume(
588613
Operand::Copy(place) | Operand::Move(place),
589614
)) => {
590-
let Some(conditions) = state.try_get(place.as_ref(), &self.map) else { return };
615+
let Some(conditions) = state.try_get(place.as_ref(), &self.map) else {
616+
return Some(());
617+
};
591618
conditions.iter_matches(ScalarInt::TRUE).for_each(register_opportunity);
592619
}
593620
StatementKind::Assign(box (lhs_place, rhs)) => {
594-
self.process_assign(bb, lhs_place, rhs, state);
621+
self.process_assign(bb, lhs_place, rhs, state)?;
595622
}
596623
_ => {}
597624
}
625+
Some(())
598626
}
599627

600628
#[instrument(level = "trace", skip(self, state, cost))]
@@ -605,7 +633,7 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
605633
state: impl FnOnce() -> State<ConditionSet<'a>>,
606634
cost: &CostChecker<'_, 'tcx>,
607635
depth: usize,
608-
) {
636+
) -> Option<()> {
609637
let term = self.body.basic_blocks[bb].terminator();
610638
let place_to_flood = match term.kind {
611639
// We come from a target, so those are not possible.
@@ -620,9 +648,9 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
620648
| TerminatorKind::FalseUnwind { .. }
621649
| TerminatorKind::Yield { .. } => bug!("{term:?} invalid"),
622650
// Cannot reason about inline asm.
623-
TerminatorKind::InlineAsm { .. } => return,
651+
TerminatorKind::InlineAsm { .. } => return Some(()),
624652
// `SwitchInt` is handled specially.
625-
TerminatorKind::SwitchInt { .. } => return,
653+
TerminatorKind::SwitchInt { .. } => return Some(()),
626654
// We can recurse, no thing particular to do.
627655
TerminatorKind::Goto { .. } => None,
628656
// Flood the overwritten place, and progress through.
@@ -637,7 +665,7 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
637665
if let Some(place_to_flood) = place_to_flood {
638666
state.flood_with(place_to_flood.as_ref(), &self.map, ConditionSet::BOTTOM);
639667
}
640-
self.find_opportunity(bb, state, cost.clone(), depth + 1);
668+
self.find_opportunity(bb, state, cost.clone(), depth + 1)
641669
}
642670

643671
#[instrument(level = "trace", skip(self))]

0 commit comments

Comments
 (0)