Skip to content

Commit a772336

Browse files
committed
Auto merge of rust-lang#129047 - DianQK:early_otherwise_branch_scalar, r=cjgillot
Apply `EarlyOtherwiseBranch` to scalar value In the future, I'm thinking of hoisting discriminant via GVN so that we only need to write very little code here. r? `@cjgillot`
2 parents 702987f + e3a9eaf commit a772336

5 files changed

+422
-85
lines changed

compiler/rustc_mir_transform/src/early_otherwise_branch.rs

+175-85
Original file line numberDiff line numberDiff line change
@@ -133,18 +133,29 @@ impl<'tcx> crate::MirPass<'tcx> for EarlyOtherwiseBranch {
133133

134134
let mut patch = MirPatch::new(body);
135135

136-
// create temp to store second discriminant in, `_s` in example above
137-
let second_discriminant_temp =
138-
patch.new_temp(opt_data.child_ty, opt_data.child_source.span);
136+
let (second_discriminant_temp, second_operand) = if opt_data.need_hoist_discriminant {
137+
// create temp to store second discriminant in, `_s` in example above
138+
let second_discriminant_temp =
139+
patch.new_temp(opt_data.child_ty, opt_data.child_source.span);
139140

140-
patch.add_statement(parent_end, StatementKind::StorageLive(second_discriminant_temp));
141+
patch.add_statement(
142+
parent_end,
143+
StatementKind::StorageLive(second_discriminant_temp),
144+
);
141145

142-
// create assignment of discriminant
143-
patch.add_assign(
144-
parent_end,
145-
Place::from(second_discriminant_temp),
146-
Rvalue::Discriminant(opt_data.child_place),
147-
);
146+
// create assignment of discriminant
147+
patch.add_assign(
148+
parent_end,
149+
Place::from(second_discriminant_temp),
150+
Rvalue::Discriminant(opt_data.child_place),
151+
);
152+
(
153+
Some(second_discriminant_temp),
154+
Operand::Move(Place::from(second_discriminant_temp)),
155+
)
156+
} else {
157+
(None, Operand::Copy(opt_data.child_place))
158+
};
148159

149160
// create temp to store inequality comparison between the two discriminants, `_t` in
150161
// example above
@@ -153,11 +164,9 @@ impl<'tcx> crate::MirPass<'tcx> for EarlyOtherwiseBranch {
153164
let comp_temp = patch.new_temp(comp_res_type, opt_data.child_source.span);
154165
patch.add_statement(parent_end, StatementKind::StorageLive(comp_temp));
155166

156-
// create inequality comparison between the two discriminants
157-
let comp_rvalue = Rvalue::BinaryOp(
158-
nequal,
159-
Box::new((parent_op.clone(), Operand::Move(Place::from(second_discriminant_temp)))),
160-
);
167+
// create inequality comparison
168+
let comp_rvalue =
169+
Rvalue::BinaryOp(nequal, Box::new((parent_op.clone(), second_operand)));
161170
patch.add_statement(
162171
parent_end,
163172
StatementKind::Assign(Box::new((Place::from(comp_temp), comp_rvalue))),
@@ -193,8 +202,13 @@ impl<'tcx> crate::MirPass<'tcx> for EarlyOtherwiseBranch {
193202
TerminatorKind::if_(Operand::Move(Place::from(comp_temp)), true_case, false_case),
194203
);
195204

196-
// generate StorageDead for the second_discriminant_temp not in use anymore
197-
patch.add_statement(parent_end, StatementKind::StorageDead(second_discriminant_temp));
205+
if let Some(second_discriminant_temp) = second_discriminant_temp {
206+
// generate StorageDead for the second_discriminant_temp not in use anymore
207+
patch.add_statement(
208+
parent_end,
209+
StatementKind::StorageDead(second_discriminant_temp),
210+
);
211+
}
198212

199213
// Generate a StorageDead for comp_temp in each of the targets, since we moved it into
200214
// the switch
@@ -222,6 +236,7 @@ struct OptimizationData<'tcx> {
222236
child_place: Place<'tcx>,
223237
child_ty: Ty<'tcx>,
224238
child_source: SourceInfo,
239+
need_hoist_discriminant: bool,
225240
}
226241

227242
fn evaluate_candidate<'tcx>(
@@ -235,70 +250,128 @@ fn evaluate_candidate<'tcx>(
235250
return None;
236251
};
237252
let parent_ty = parent_discr.ty(body.local_decls(), tcx);
238-
if !bbs[targets.otherwise()].is_empty_unreachable() {
239-
// Someone could write code like this:
240-
// ```rust
241-
// let Q = val;
242-
// if discriminant(P) == otherwise {
243-
// let ptr = &mut Q as *mut _ as *mut u8;
244-
// // It may be difficult for us to effectively determine whether values are valid.
245-
// // Invalid values can come from all sorts of corners.
246-
// unsafe { *ptr = 10; }
247-
// }
248-
//
249-
// match P {
250-
// A => match Q {
251-
// A => {
252-
// // code
253-
// }
254-
// _ => {
255-
// // don't use Q
256-
// }
257-
// }
258-
// _ => {
259-
// // don't use Q
260-
// }
261-
// };
262-
// ```
263-
//
264-
// Hoisting the `discriminant(Q)` out of the `A` arm causes us to compute the discriminant
265-
// of an invalid value, which is UB.
266-
// In order to fix this, **we would either need to show that the discriminant computation of
267-
// `place` is computed in all branches**.
268-
// FIXME(#95162) For the moment, we adopt a conservative approach and
269-
// consider only the `otherwise` branch has no statements and an unreachable terminator.
270-
return None;
271-
}
272253
let (_, child) = targets.iter().next()?;
273-
let child_terminator = &bbs[child].terminator();
274-
let TerminatorKind::SwitchInt { targets: child_targets, discr: child_discr } =
275-
&child_terminator.kind
254+
255+
let Terminator {
256+
kind: TerminatorKind::SwitchInt { targets: child_targets, discr: child_discr },
257+
source_info,
258+
} = bbs[child].terminator()
276259
else {
277260
return None;
278261
};
279262
let child_ty = child_discr.ty(body.local_decls(), tcx);
280263
if child_ty != parent_ty {
281264
return None;
282265
}
283-
let Some(StatementKind::Assign(boxed)) = &bbs[child].statements.first().map(|x| &x.kind) else {
266+
267+
// We only handle:
268+
// ```
269+
// bb4: {
270+
// _8 = discriminant((_3.1: Enum1));
271+
// switchInt(move _8) -> [2: bb7, otherwise: bb1];
272+
// }
273+
// ```
274+
// and
275+
// ```
276+
// bb2: {
277+
// switchInt((_3.1: u64)) -> [1: bb5, otherwise: bb1];
278+
// }
279+
// ```
280+
if bbs[child].statements.len() > 1 {
284281
return None;
282+
}
283+
284+
// When thie BB has exactly one statement, this statement should be discriminant.
285+
let need_hoist_discriminant = bbs[child].statements.len() == 1;
286+
let child_place = if need_hoist_discriminant {
287+
if !bbs[targets.otherwise()].is_empty_unreachable() {
288+
// Someone could write code like this:
289+
// ```rust
290+
// let Q = val;
291+
// if discriminant(P) == otherwise {
292+
// let ptr = &mut Q as *mut _ as *mut u8;
293+
// // It may be difficult for us to effectively determine whether values are valid.
294+
// // Invalid values can come from all sorts of corners.
295+
// unsafe { *ptr = 10; }
296+
// }
297+
//
298+
// match P {
299+
// A => match Q {
300+
// A => {
301+
// // code
302+
// }
303+
// _ => {
304+
// // don't use Q
305+
// }
306+
// }
307+
// _ => {
308+
// // don't use Q
309+
// }
310+
// };
311+
// ```
312+
//
313+
// Hoisting the `discriminant(Q)` out of the `A` arm causes us to compute the discriminant of an
314+
// invalid value, which is UB.
315+
// In order to fix this, **we would either need to show that the discriminant computation of
316+
// `place` is computed in all branches**.
317+
// FIXME(#95162) For the moment, we adopt a conservative approach and
318+
// consider only the `otherwise` branch has no statements and an unreachable terminator.
319+
return None;
320+
}
321+
// Handle:
322+
// ```
323+
// bb4: {
324+
// _8 = discriminant((_3.1: Enum1));
325+
// switchInt(move _8) -> [2: bb7, otherwise: bb1];
326+
// }
327+
// ```
328+
let [
329+
Statement {
330+
kind: StatementKind::Assign(box (_, Rvalue::Discriminant(child_place))),
331+
..
332+
},
333+
] = bbs[child].statements.as_slice()
334+
else {
335+
return None;
336+
};
337+
*child_place
338+
} else {
339+
// Handle:
340+
// ```
341+
// bb2: {
342+
// switchInt((_3.1: u64)) -> [1: bb5, otherwise: bb1];
343+
// }
344+
// ```
345+
let Operand::Copy(child_place) = child_discr else {
346+
return None;
347+
};
348+
*child_place
285349
};
286-
let (_, Rvalue::Discriminant(child_place)) = &**boxed else {
287-
return None;
350+
let destination = if need_hoist_discriminant || bbs[targets.otherwise()].is_empty_unreachable()
351+
{
352+
child_targets.otherwise()
353+
} else {
354+
targets.otherwise()
288355
};
289-
let destination = child_targets.otherwise();
290356

291357
// Verify that the optimization is legal for each branch
292358
for (value, child) in targets.iter() {
293-
if !verify_candidate_branch(&bbs[child], value, *child_place, destination) {
359+
if !verify_candidate_branch(
360+
&bbs[child],
361+
value,
362+
child_place,
363+
destination,
364+
need_hoist_discriminant,
365+
) {
294366
return None;
295367
}
296368
}
297369
Some(OptimizationData {
298370
destination,
299-
child_place: *child_place,
371+
child_place,
300372
child_ty,
301-
child_source: child_terminator.source_info,
373+
child_source: *source_info,
374+
need_hoist_discriminant,
302375
})
303376
}
304377

@@ -307,31 +380,48 @@ fn verify_candidate_branch<'tcx>(
307380
value: u128,
308381
place: Place<'tcx>,
309382
destination: BasicBlock,
383+
need_hoist_discriminant: bool,
310384
) -> bool {
311-
// In order for the optimization to be correct, the branch must...
312-
// ...have exactly one statement
313-
if let [statement] = branch.statements.as_slice()
314-
// ...assign the discriminant of `place` in that statement
315-
&& let StatementKind::Assign(boxed) = &statement.kind
316-
&& let (discr_place, Rvalue::Discriminant(from_place)) = &**boxed
317-
&& *from_place == place
318-
// ...make that assignment to a local
319-
&& discr_place.projection.is_empty()
320-
// ...terminate on a `SwitchInt` that invalidates that local
321-
&& let TerminatorKind::SwitchInt { discr: switch_op, targets, .. } =
322-
&branch.terminator().kind
323-
&& *switch_op == Operand::Move(*discr_place)
324-
// ...fall through to `destination` if the switch misses
325-
&& destination == targets.otherwise()
326-
// ...have a branch for value `value`
327-
&& let mut iter = targets.iter()
328-
&& let Some((target_value, _)) = iter.next()
329-
&& target_value == value
330-
// ...and have no more branches
331-
&& iter.next().is_none()
332-
{
333-
true
385+
// In order for the optimization to be correct, the terminator must be a `SwitchInt`.
386+
let TerminatorKind::SwitchInt { discr: switch_op, targets } = &branch.terminator().kind else {
387+
return false;
388+
};
389+
if need_hoist_discriminant {
390+
// If we need hoist discriminant, the branch must have exactly one statement.
391+
let [statement] = branch.statements.as_slice() else {
392+
return false;
393+
};
394+
// The statement must assign the discriminant of `place`.
395+
let StatementKind::Assign(box (discr_place, Rvalue::Discriminant(from_place))) =
396+
statement.kind
397+
else {
398+
return false;
399+
};
400+
if from_place != place {
401+
return false;
402+
}
403+
// The assignment must invalidate a local that terminate on a `SwitchInt`.
404+
if !discr_place.projection.is_empty() || *switch_op != Operand::Move(discr_place) {
405+
return false;
406+
}
334407
} else {
335-
false
408+
// If we don't need hoist discriminant, the branch must not have any statements.
409+
if !branch.statements.is_empty() {
410+
return false;
411+
}
412+
// The place on `SwitchInt` must be the same.
413+
if *switch_op != Operand::Copy(place) {
414+
return false;
415+
}
336416
}
417+
// It must fall through to `destination` if the switch misses.
418+
if destination != targets.otherwise() {
419+
return false;
420+
}
421+
// It must have exactly one branch for value `value` and have no more branches.
422+
let mut iter = targets.iter();
423+
let (Some((target_value, _)), None) = (iter.next(), iter.next()) else {
424+
return false;
425+
};
426+
target_value == value
337427
}

0 commit comments

Comments
 (0)