Skip to content

Commit 41a5d8e

Browse files
committed
JumpThreading: Bail out on interp errors
1 parent 9f35fe4 commit 41a5d8e

File tree

1 file changed

+91
-67
lines changed

1 file changed

+91
-67
lines changed

compiler/rustc_mir_transform/src/jump_threading.rs

+91-67
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,11 @@ impl<'tcx> crate::MirPass<'tcx> for JumpThreading {
9090
};
9191

9292
for bb in body.basic_blocks.indices() {
93-
finder.start_from_switch(bb);
93+
let old_len = finder.opportunities.len();
94+
// If we have any const-eval errors discard any opportunities found
95+
if finder.start_from_switch(bb).is_none() {
96+
finder.opportunities.truncate(old_len);
97+
}
9498
}
9599

96100
let opportunities = finder.opportunities;
@@ -172,8 +176,21 @@ impl<'a> ConditionSet<'a> {
172176
self.iter().filter(move |c| c.matches(value))
173177
}
174178

175-
fn map(self, arena: &'a DroplessArena, f: impl Fn(Condition) -> Condition) -> ConditionSet<'a> {
176-
ConditionSet(arena.alloc_from_iter(self.iter().map(f)))
179+
fn map(
180+
self,
181+
arena: &'a DroplessArena,
182+
f: impl Fn(Condition) -> Option<Condition>,
183+
) -> Option<ConditionSet<'a>> {
184+
let mut all_ok = true;
185+
let set = arena.alloc_from_iter(self.iter().map_while(|c| {
186+
if let Some(c) = f(c) {
187+
Some(c)
188+
} else {
189+
all_ok = false;
190+
None
191+
}
192+
}));
193+
all_ok.then_some(ConditionSet(set))
177194
}
178195
}
179196

@@ -184,28 +201,28 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
184201

185202
/// Recursion entry point to find threading opportunities.
186203
#[instrument(level = "trace", skip(self))]
187-
fn start_from_switch(&mut self, bb: BasicBlock) {
204+
fn start_from_switch(&mut self, bb: BasicBlock) -> Option<()> {
188205
let bbdata = &self.body[bb];
189206
if bbdata.is_cleanup || self.loop_headers.contains(bb) {
190-
return;
207+
return Some(());
191208
}
192-
let Some((discr, targets)) = bbdata.terminator().kind.as_switch() else { return };
193-
let Some(discr) = discr.place() else { return };
209+
let Some((discr, targets)) = bbdata.terminator().kind.as_switch() else { return Some(()) };
210+
let Some(discr) = discr.place() else { return Some(()) };
194211
debug!(?discr, ?bb);
195212

196213
let discr_ty = discr.ty(self.body, self.tcx).ty;
197214
let Ok(discr_layout) = self.ecx.layout_of(discr_ty) else {
198-
return;
215+
return Some(());
199216
};
200217

201-
let Some(discr) = self.map.find(discr.as_ref()) else { return };
218+
let Some(discr) = self.map.find(discr.as_ref()) else { return Some(()) };
202219
debug!(?discr);
203220

204221
let cost = CostChecker::new(self.tcx, self.typing_env, None, self.body);
205222
let mut state = State::new_reachable();
206223

207224
let conds = if let Some((value, then, else_)) = targets.as_static_if() {
208-
let Some(value) = ScalarInt::try_from_uint(value, discr_layout.size) else { return };
225+
let value = ScalarInt::try_from_uint(value, discr_layout.size)?;
209226
self.arena.alloc_from_iter([
210227
Condition { value, polarity: Polarity::Eq, target: then },
211228
Condition { value, polarity: Polarity::Ne, target: else_ },
@@ -219,7 +236,7 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
219236
let conds = ConditionSet(conds);
220237
state.insert_value_idx(discr, conds, &self.map);
221238

222-
self.find_opportunity(bb, state, cost, 0);
239+
self.find_opportunity(bb, state, cost, 0)
223240
}
224241

225242
/// Recursively walk statements backwards from this bb's terminator to find threading
@@ -231,27 +248,27 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
231248
mut state: State<ConditionSet<'a>>,
232249
mut cost: CostChecker<'_, 'tcx>,
233250
depth: usize,
234-
) {
251+
) -> Option<()> {
235252
// Do not thread through loop headers.
236253
if self.loop_headers.contains(bb) {
237-
return;
254+
return Some(());
238255
}
239256

240257
debug!(cost = ?cost.cost());
241258
for (statement_index, stmt) in
242259
self.body.basic_blocks[bb].statements.iter().enumerate().rev()
243260
{
244261
if self.is_empty(&state) {
245-
return;
262+
return Some(());
246263
}
247264

248265
cost.visit_statement(stmt, Location { block: bb, statement_index });
249266
if cost.cost() > MAX_COST {
250-
return;
267+
return Some(());
251268
}
252269

253270
// Attempt to turn the `current_condition` on `lhs` into a condition on another place.
254-
self.process_statement(bb, stmt, &mut state);
271+
self.process_statement(bb, stmt, &mut state)?;
255272

256273
// When a statement mutates a place, assignments to that place that happen
257274
// above the mutation cannot fulfill a condition.
@@ -263,7 +280,7 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
263280
}
264281

265282
if self.is_empty(&state) || depth >= MAX_BACKTRACK {
266-
return;
283+
return Some(());
267284
}
268285

269286
let last_non_rec = self.opportunities.len();
@@ -276,9 +293,9 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
276293
match term.kind {
277294
TerminatorKind::SwitchInt { ref discr, ref targets } => {
278295
self.process_switch_int(discr, targets, bb, &mut state);
279-
self.find_opportunity(pred, state, cost, depth + 1);
296+
self.find_opportunity(pred, state, cost, depth + 1)?;
280297
}
281-
_ => self.recurse_through_terminator(pred, || state, &cost, depth),
298+
_ => self.recurse_through_terminator(pred, || state, &cost, depth)?,
282299
}
283300
} else if let &[ref predecessors @ .., last_pred] = &predecessors[..] {
284301
for &pred in predecessors {
@@ -303,12 +320,13 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
303320
let first = &mut new_tos[0];
304321
*first = ThreadingOpportunity { chain: vec![bb], target: first.target };
305322
self.opportunities.truncate(last_non_rec + 1);
306-
return;
323+
return Some(());
307324
}
308325

309326
for op in self.opportunities[last_non_rec..].iter_mut() {
310327
op.chain.push(bb);
311328
}
329+
Some(())
312330
}
313331

314332
/// Extract the mutated place from a statement.
@@ -422,23 +440,23 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
422440
lhs: PlaceIndex,
423441
rhs: &Operand<'tcx>,
424442
state: &mut State<ConditionSet<'a>>,
425-
) {
443+
) -> Option<()> {
426444
match rhs {
427445
// If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`.
428446
Operand::Constant(constant) => {
429-
let Some(constant) =
430-
self.ecx.eval_mir_constant(&constant.const_, constant.span, None).discard_err()
431-
else {
432-
return;
433-
};
447+
let constant = self
448+
.ecx
449+
.eval_mir_constant(&constant.const_, constant.span, None)
450+
.discard_err()?;
434451
self.process_constant(bb, lhs, constant, state);
435452
}
436453
// Transfer the conditions on the copied rhs.
437454
Operand::Move(rhs) | Operand::Copy(rhs) => {
438-
let Some(rhs) = self.map.find(rhs.as_ref()) else { return };
455+
let Some(rhs) = self.map.find(rhs.as_ref()) else { return Some(()) };
439456
state.insert_place_idx(rhs, lhs, &self.map);
440457
}
441458
}
459+
Some(())
442460
}
443461

444462
#[instrument(level = "trace", skip(self))]
@@ -448,22 +466,26 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
448466
lhs_place: &Place<'tcx>,
449467
rhs: &Rvalue<'tcx>,
450468
state: &mut State<ConditionSet<'a>>,
451-
) {
452-
let Some(lhs) = self.map.find(lhs_place.as_ref()) else { return };
469+
) -> Option<()> {
470+
let Some(lhs) = self.map.find(lhs_place.as_ref()) else {
471+
return Some(());
472+
};
453473
match rhs {
454-
Rvalue::Use(operand) => self.process_operand(bb, lhs, operand, state),
474+
Rvalue::Use(operand) => self.process_operand(bb, lhs, operand, state)?,
455475
// Transfer the conditions on the copy rhs.
456-
Rvalue::CopyForDeref(rhs) => self.process_operand(bb, lhs, &Operand::Copy(*rhs), state),
476+
Rvalue::CopyForDeref(rhs) => {
477+
self.process_operand(bb, lhs, &Operand::Copy(*rhs), state)?
478+
}
457479
Rvalue::Discriminant(rhs) => {
458-
let Some(rhs) = self.map.find_discr(rhs.as_ref()) else { return };
480+
let Some(rhs) = self.map.find_discr(rhs.as_ref()) else { return Some(()) };
459481
state.insert_place_idx(rhs, lhs, &self.map);
460482
}
461483
// If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`.
462484
Rvalue::Aggregate(box kind, operands) => {
463485
let agg_ty = lhs_place.ty(self.body, self.tcx).ty;
464486
let lhs = match kind {
465487
// Do not support unions.
466-
AggregateKind::Adt(.., Some(_)) => return,
488+
AggregateKind::Adt(.., Some(_)) => return Some(()),
467489
AggregateKind::Adt(_, variant_index, ..) if agg_ty.is_enum() => {
468490
if let Some(discr_target) = self.map.apply(lhs, TrackElem::Discriminant)
469491
&& let Some(discr_value) = self
@@ -476,31 +498,31 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
476498
if let Some(idx) = self.map.apply(lhs, TrackElem::Variant(*variant_index)) {
477499
idx
478500
} else {
479-
return;
501+
return Some(());
480502
}
481503
}
482504
_ => lhs,
483505
};
484506
for (field_index, operand) in operands.iter_enumerated() {
485507
if let Some(field) = self.map.apply(lhs, TrackElem::Field(field_index)) {
486-
self.process_operand(bb, field, operand, state);
508+
self.process_operand(bb, field, operand, state)?;
487509
}
488510
}
489511
}
490512
// Transfer the conditions on the copy rhs, after inverting the value of the condition.
491513
Rvalue::UnaryOp(UnOp::Not, Operand::Move(place) | Operand::Copy(place)) => {
492514
let layout = self.ecx.layout_of(place.ty(self.body, self.tcx).ty).unwrap();
493-
let Some(conditions) = state.try_get_idx(lhs, &self.map) else { return };
494-
let Some(place) = self.map.find(place.as_ref()) else { return };
515+
let Some(conditions) = state.try_get_idx(lhs, &self.map) else { return Some(()) };
516+
let Some(place) = self.map.find(place.as_ref()) else { return Some(()) };
495517
let conds = conditions.map(self.arena, |mut cond| {
496518
cond.value = self
497519
.ecx
498520
.unary_op(UnOp::Not, &ImmTy::from_scalar_int(cond.value, layout))
499-
.unwrap()
521+
.discard_err()?
500522
.to_scalar_int()
501-
.unwrap();
502-
cond
503-
});
523+
.discard_err()?;
524+
Some(cond)
525+
})?;
504526
state.insert_value_idx(place, conds, &self.map);
505527
}
506528
// We expect `lhs ?= A`. We found `lhs = Eq(rhs, B)`.
@@ -510,34 +532,34 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
510532
box (Operand::Move(place) | Operand::Copy(place), Operand::Constant(value))
511533
| box (Operand::Constant(value), Operand::Move(place) | Operand::Copy(place)),
512534
) => {
513-
let Some(conditions) = state.try_get_idx(lhs, &self.map) else { return };
514-
let Some(place) = self.map.find(place.as_ref()) else { return };
535+
let Some(conditions) = state.try_get_idx(lhs, &self.map) else { return Some(()) };
536+
let Some(place) = self.map.find(place.as_ref()) else { return Some(()) };
515537
let equals = match op {
516538
BinOp::Eq => ScalarInt::TRUE,
517539
BinOp::Ne => ScalarInt::FALSE,
518-
_ => return,
540+
_ => return Some(()),
519541
};
520542
if value.const_.ty().is_floating_point() {
521543
// Floating point equality does not follow bit-patterns.
522544
// -0.0 and NaN both have special rules for equality,
523545
// and therefore we cannot use integer comparisons for them.
524546
// Avoid handling them, though this could be extended in the future.
525-
return;
547+
return Some(());
526548
}
527-
let Some(value) = value.const_.try_eval_scalar_int(self.tcx, self.typing_env)
528-
else {
529-
return;
530-
};
531-
let conds = conditions.map(self.arena, |c| Condition {
532-
value,
533-
polarity: if c.matches(equals) { Polarity::Eq } else { Polarity::Ne },
534-
..c
535-
});
549+
let value = value.const_.try_eval_scalar_int(self.tcx, self.typing_env)?;
550+
let conds = conditions.map(self.arena, |c| {
551+
Some(Condition {
552+
value,
553+
polarity: if c.matches(equals) { Polarity::Eq } else { Polarity::Ne },
554+
..c
555+
})
556+
})?;
536557
state.insert_value_idx(place, conds, &self.map);
537558
}
538559

539560
_ => {}
540561
}
562+
Some(())
541563
}
542564

543565
#[instrument(level = "trace", skip(self))]
@@ -546,7 +568,7 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
546568
bb: BasicBlock,
547569
stmt: &Statement<'tcx>,
548570
state: &mut State<ConditionSet<'a>>,
549-
) {
571+
) -> Option<()> {
550572
let register_opportunity = |c: Condition| {
551573
debug!(?bb, ?c.target, "register");
552574
self.opportunities.push(ThreadingOpportunity { chain: vec![bb], target: c.target })
@@ -559,30 +581,32 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
559581
// If we expect `discriminant(place) ?= A`,
560582
// we have an opportunity if `variant_index ?= A`.
561583
StatementKind::SetDiscriminant { box place, variant_index } => {
562-
let Some(discr_target) = self.map.find_discr(place.as_ref()) else { return };
584+
let Some(discr_target) = self.map.find_discr(place.as_ref()) else {
585+
return Some(());
586+
};
563587
let enum_ty = place.ty(self.body, self.tcx).ty;
564588
// `SetDiscriminant` guarantees that the discriminant is now `variant_index`.
565589
// Even if the discriminant write does nothing due to niches, it is UB to set the
566590
// discriminant when the data does not encode the desired discriminant.
567-
let Some(discr) =
568-
self.ecx.discriminant_for_variant(enum_ty, *variant_index).discard_err()
569-
else {
570-
return;
571-
};
591+
let discr =
592+
self.ecx.discriminant_for_variant(enum_ty, *variant_index).discard_err()?;
572593
self.process_immediate(bb, discr_target, discr, state);
573594
}
574595
// If we expect `lhs ?= true`, we have an opportunity if we assume `lhs == true`.
575596
StatementKind::Intrinsic(box NonDivergingIntrinsic::Assume(
576597
Operand::Copy(place) | Operand::Move(place),
577598
)) => {
578-
let Some(conditions) = state.try_get(place.as_ref(), &self.map) else { return };
599+
let Some(conditions) = state.try_get(place.as_ref(), &self.map) else {
600+
return Some(());
601+
};
579602
conditions.iter_matches(ScalarInt::TRUE).for_each(register_opportunity);
580603
}
581604
StatementKind::Assign(box (lhs_place, rhs)) => {
582-
self.process_assign(bb, lhs_place, rhs, state);
605+
self.process_assign(bb, lhs_place, rhs, state)?;
583606
}
584607
_ => {}
585608
}
609+
Some(())
586610
}
587611

588612
#[instrument(level = "trace", skip(self, state, cost))]
@@ -593,7 +617,7 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
593617
state: impl FnOnce() -> State<ConditionSet<'a>>,
594618
cost: &CostChecker<'_, 'tcx>,
595619
depth: usize,
596-
) {
620+
) -> Option<()> {
597621
let term = self.body.basic_blocks[bb].terminator();
598622
let place_to_flood = match term.kind {
599623
// We come from a target, so those are not possible.
@@ -608,9 +632,9 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
608632
| TerminatorKind::FalseUnwind { .. }
609633
| TerminatorKind::Yield { .. } => bug!("{term:?} invalid"),
610634
// Cannot reason about inline asm.
611-
TerminatorKind::InlineAsm { .. } => return,
635+
TerminatorKind::InlineAsm { .. } => return Some(()),
612636
// `SwitchInt` is handled specially.
613-
TerminatorKind::SwitchInt { .. } => return,
637+
TerminatorKind::SwitchInt { .. } => return Some(()),
614638
// We can recurse, no thing particular to do.
615639
TerminatorKind::Goto { .. } => None,
616640
// Flood the overwritten place, and progress through.
@@ -625,7 +649,7 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
625649
if let Some(place_to_flood) = place_to_flood {
626650
state.flood_with(place_to_flood.as_ref(), &self.map, ConditionSet::BOTTOM);
627651
}
628-
self.find_opportunity(bb, state, cost.clone(), depth + 1);
652+
self.find_opportunity(bb, state, cost.clone(), depth + 1)
629653
}
630654

631655
#[instrument(level = "trace", skip(self))]

0 commit comments

Comments
 (0)