Skip to content

Remove early exits from JumpThreading. #140024

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
135 changes: 64 additions & 71 deletions compiler/rustc_mir_transform/src/jump_threading.rs
Original file line number Diff line number Diff line change
Expand Up @@ -90,11 +90,7 @@ impl<'tcx> crate::MirPass<'tcx> for JumpThreading {
};

for bb in body.basic_blocks.indices() {
let old_len = finder.opportunities.len();
// If we have any const-eval errors discard any opportunities found
if finder.start_from_switch(bb).is_none() {
finder.opportunities.truncate(old_len);
}
finder.start_from_switch(bb);
}

let opportunities = finder.opportunities;
Expand Down Expand Up @@ -201,28 +197,26 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {

/// Recursion entry point to find threading opportunities.
#[instrument(level = "trace", skip(self))]
fn start_from_switch(&mut self, bb: BasicBlock) -> Option<()> {
fn start_from_switch(&mut self, bb: BasicBlock) {
let bbdata = &self.body[bb];
if bbdata.is_cleanup || self.loop_headers.contains(bb) {
return Some(());
return;
}
let Some((discr, targets)) = bbdata.terminator().kind.as_switch() else { return Some(()) };
let Some(discr) = discr.place() else { return Some(()) };
let Some((discr, targets)) = bbdata.terminator().kind.as_switch() else { return };
let Some(discr) = discr.place() else { return };
debug!(?discr, ?bb);

let discr_ty = discr.ty(self.body, self.tcx).ty;
let Ok(discr_layout) = self.ecx.layout_of(discr_ty) else {
return Some(());
};
let Ok(discr_layout) = self.ecx.layout_of(discr_ty) else { return };

let Some(discr) = self.map.find(discr.as_ref()) else { return Some(()) };
let Some(discr) = self.map.find(discr.as_ref()) else { return };
debug!(?discr);

let cost = CostChecker::new(self.tcx, self.typing_env, None, self.body);
let mut state = State::new_reachable();

let conds = if let Some((value, then, else_)) = targets.as_static_if() {
let value = ScalarInt::try_from_uint(value, discr_layout.size)?;
let Some(value) = ScalarInt::try_from_uint(value, discr_layout.size) else { return };
self.arena.alloc_from_iter([
Condition { value, polarity: Polarity::Eq, target: then },
Condition { value, polarity: Polarity::Ne, target: else_ },
Expand All @@ -248,27 +242,27 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
mut state: State<ConditionSet<'a>>,
mut cost: CostChecker<'_, 'tcx>,
depth: usize,
) -> Option<()> {
) {
// Do not thread through loop headers.
if self.loop_headers.contains(bb) {
return Some(());
return;
}

debug!(cost = ?cost.cost());
for (statement_index, stmt) in
self.body.basic_blocks[bb].statements.iter().enumerate().rev()
{
if self.is_empty(&state) {
return Some(());
return;
}

cost.visit_statement(stmt, Location { block: bb, statement_index });
if cost.cost() > MAX_COST {
return Some(());
return;
}

// Attempt to turn the `current_condition` on `lhs` into a condition on another place.
self.process_statement(bb, stmt, &mut state)?;
self.process_statement(bb, stmt, &mut state);

// When a statement mutates a place, assignments to that place that happen
// above the mutation cannot fulfill a condition.
Expand All @@ -280,7 +274,7 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
}

if self.is_empty(&state) || depth >= MAX_BACKTRACK {
return Some(());
return;
}

let last_non_rec = self.opportunities.len();
Expand All @@ -293,9 +287,9 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
match term.kind {
TerminatorKind::SwitchInt { ref discr, ref targets } => {
self.process_switch_int(discr, targets, bb, &mut state);
self.find_opportunity(pred, state, cost, depth + 1)?;
self.find_opportunity(pred, state, cost, depth + 1);
}
_ => self.recurse_through_terminator(pred, || state, &cost, depth)?,
_ => self.recurse_through_terminator(pred, || state, &cost, depth),
}
} else if let &[ref predecessors @ .., last_pred] = &predecessors[..] {
for &pred in predecessors {
Expand All @@ -320,13 +314,12 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
let first = &mut new_tos[0];
*first = ThreadingOpportunity { chain: vec![bb], target: first.target };
self.opportunities.truncate(last_non_rec + 1);
return Some(());
return;
}

for op in self.opportunities[last_non_rec..].iter_mut() {
op.chain.push(bb);
}
Some(())
}

/// Extract the mutated place from a statement.
Expand Down Expand Up @@ -440,23 +433,23 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
lhs: PlaceIndex,
rhs: &Operand<'tcx>,
state: &mut State<ConditionSet<'a>>,
) -> Option<()> {
) {
match rhs {
// If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`.
Operand::Constant(constant) => {
let constant = self
.ecx
.eval_mir_constant(&constant.const_, constant.span, None)
.discard_err()?;
let Some(constant) =
self.ecx.eval_mir_constant(&constant.const_, constant.span, None).discard_err()
else {
return;
};
self.process_constant(bb, lhs, constant, state);
}
// Transfer the conditions on the copied rhs.
Operand::Move(rhs) | Operand::Copy(rhs) => {
let Some(rhs) = self.map.find(rhs.as_ref()) else { return Some(()) };
let Some(rhs) = self.map.find(rhs.as_ref()) else { return };
state.insert_place_idx(rhs, lhs, &self.map);
}
}
Some(())
}

#[instrument(level = "trace", skip(self))]
Expand All @@ -466,26 +459,22 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
lhs_place: &Place<'tcx>,
rhs: &Rvalue<'tcx>,
state: &mut State<ConditionSet<'a>>,
) -> Option<()> {
let Some(lhs) = self.map.find(lhs_place.as_ref()) else {
return Some(());
};
) {
let Some(lhs) = self.map.find(lhs_place.as_ref()) else { return };
match rhs {
Rvalue::Use(operand) => self.process_operand(bb, lhs, operand, state)?,
Rvalue::Use(operand) => self.process_operand(bb, lhs, operand, state),
// Transfer the conditions on the copy rhs.
Rvalue::CopyForDeref(rhs) => {
self.process_operand(bb, lhs, &Operand::Copy(*rhs), state)?
}
Rvalue::CopyForDeref(rhs) => self.process_operand(bb, lhs, &Operand::Copy(*rhs), state),
Rvalue::Discriminant(rhs) => {
let Some(rhs) = self.map.find_discr(rhs.as_ref()) else { return Some(()) };
let Some(rhs) = self.map.find_discr(rhs.as_ref()) else { return };
state.insert_place_idx(rhs, lhs, &self.map);
}
// If we expect `lhs ?= A`, we have an opportunity if we assume `constant == A`.
Rvalue::Aggregate(box kind, operands) => {
let agg_ty = lhs_place.ty(self.body, self.tcx).ty;
let lhs = match kind {
// Do not support unions.
AggregateKind::Adt(.., Some(_)) => return Some(()),
AggregateKind::Adt(.., Some(_)) => return,
AggregateKind::Adt(_, variant_index, ..) if agg_ty.is_enum() => {
if let Some(discr_target) = self.map.apply(lhs, TrackElem::Discriminant)
&& let Some(discr_value) = self
Expand All @@ -498,31 +487,33 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
if let Some(idx) = self.map.apply(lhs, TrackElem::Variant(*variant_index)) {
idx
} else {
return Some(());
return;
}
}
_ => lhs,
};
for (field_index, operand) in operands.iter_enumerated() {
if let Some(field) = self.map.apply(lhs, TrackElem::Field(field_index)) {
self.process_operand(bb, field, operand, state)?;
self.process_operand(bb, field, operand, state);
}
}
}
// Transfer the conditions on the copy rhs, after inverting the value of the condition.
Rvalue::UnaryOp(UnOp::Not, Operand::Move(place) | Operand::Copy(place)) => {
let layout = self.ecx.layout_of(place.ty(self.body, self.tcx).ty).unwrap();
let Some(conditions) = state.try_get_idx(lhs, &self.map) else { return Some(()) };
let Some(place) = self.map.find(place.as_ref()) else { return Some(()) };
let conds = conditions.map(self.arena, |mut cond| {
let Some(conditions) = state.try_get_idx(lhs, &self.map) else { return };
let Some(place) = self.map.find(place.as_ref()) else { return };
let Some(conds) = conditions.map(self.arena, |mut cond| {
cond.value = self
.ecx
.unary_op(UnOp::Not, &ImmTy::from_scalar_int(cond.value, layout))
.discard_err()?
.to_scalar_int()
.discard_err()?;
Some(cond)
})?;
}) else {
return;
};
state.insert_value_idx(place, conds, &self.map);
}
// We expect `lhs ?= A`. We found `lhs = Eq(rhs, B)`.
Expand All @@ -532,34 +523,38 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
box (Operand::Move(place) | Operand::Copy(place), Operand::Constant(value))
| box (Operand::Constant(value), Operand::Move(place) | Operand::Copy(place)),
) => {
let Some(conditions) = state.try_get_idx(lhs, &self.map) else { return Some(()) };
let Some(place) = self.map.find(place.as_ref()) else { return Some(()) };
let Some(conditions) = state.try_get_idx(lhs, &self.map) else { return };
let Some(place) = self.map.find(place.as_ref()) else { return };
let equals = match op {
BinOp::Eq => ScalarInt::TRUE,
BinOp::Ne => ScalarInt::FALSE,
_ => return Some(()),
_ => return,
};
if value.const_.ty().is_floating_point() {
// Floating point equality does not follow bit-patterns.
// -0.0 and NaN both have special rules for equality,
// and therefore we cannot use integer comparisons for them.
// Avoid handling them, though this could be extended in the future.
return Some(());
return;
}
let value = value.const_.try_eval_scalar_int(self.tcx, self.typing_env)?;
let conds = conditions.map(self.arena, |c| {
let Some(value) = value.const_.try_eval_scalar_int(self.tcx, self.typing_env)
else {
return;
};
let Some(conds) = conditions.map(self.arena, |c| {
Some(Condition {
value,
polarity: if c.matches(equals) { Polarity::Eq } else { Polarity::Ne },
..c
})
})?;
}) else {
return;
};
state.insert_value_idx(place, conds, &self.map);
}

_ => {}
}
Some(())
}

#[instrument(level = "trace", skip(self))]
Expand All @@ -568,7 +563,7 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
bb: BasicBlock,
stmt: &Statement<'tcx>,
state: &mut State<ConditionSet<'a>>,
) -> Option<()> {
) {
let register_opportunity = |c: Condition| {
debug!(?bb, ?c.target, "register");
self.opportunities.push(ThreadingOpportunity { chain: vec![bb], target: c.target })
Expand All @@ -581,32 +576,30 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
// If we expect `discriminant(place) ?= A`,
// we have an opportunity if `variant_index ?= A`.
StatementKind::SetDiscriminant { box place, variant_index } => {
let Some(discr_target) = self.map.find_discr(place.as_ref()) else {
return Some(());
};
let Some(discr_target) = self.map.find_discr(place.as_ref()) else { return };
let enum_ty = place.ty(self.body, self.tcx).ty;
// `SetDiscriminant` guarantees that the discriminant is now `variant_index`.
// Even if the discriminant write does nothing due to niches, it is UB to set the
// discriminant when the data does not encode the desired discriminant.
let discr =
self.ecx.discriminant_for_variant(enum_ty, *variant_index).discard_err()?;
self.process_immediate(bb, discr_target, discr, state);
let Some(discr) =
self.ecx.discriminant_for_variant(enum_ty, *variant_index).discard_err()
else {
return;
};
self.process_immediate(bb, discr_target, discr, state)
}
// If we expect `lhs ?= true`, we have an opportunity if we assume `lhs == true`.
StatementKind::Intrinsic(box NonDivergingIntrinsic::Assume(
Operand::Copy(place) | Operand::Move(place),
)) => {
let Some(conditions) = state.try_get(place.as_ref(), &self.map) else {
return Some(());
};
conditions.iter_matches(ScalarInt::TRUE).for_each(register_opportunity);
let Some(conditions) = state.try_get(place.as_ref(), &self.map) else { return };
conditions.iter_matches(ScalarInt::TRUE).for_each(register_opportunity)
}
StatementKind::Assign(box (lhs_place, rhs)) => {
self.process_assign(bb, lhs_place, rhs, state)?;
self.process_assign(bb, lhs_place, rhs, state)
}
_ => {}
}
Some(())
}

#[instrument(level = "trace", skip(self, state, cost))]
Expand All @@ -617,7 +610,7 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
state: impl FnOnce() -> State<ConditionSet<'a>>,
cost: &CostChecker<'_, 'tcx>,
depth: usize,
) -> Option<()> {
) {
let term = self.body.basic_blocks[bb].terminator();
let place_to_flood = match term.kind {
// We come from a target, so those are not possible.
Expand All @@ -632,9 +625,9 @@ impl<'a, 'tcx> TOFinder<'a, 'tcx> {
| TerminatorKind::FalseUnwind { .. }
| TerminatorKind::Yield { .. } => bug!("{term:?} invalid"),
// Cannot reason about inline asm.
TerminatorKind::InlineAsm { .. } => return Some(()),
TerminatorKind::InlineAsm { .. } => return,
// `SwitchInt` is handled specially.
TerminatorKind::SwitchInt { .. } => return Some(()),
TerminatorKind::SwitchInt { .. } => return,
// We can recurse, no thing particular to do.
TerminatorKind::Goto { .. } => None,
// Flood the overwritten place, and progress through.
Expand Down
Loading