@@ -133,18 +133,29 @@ impl<'tcx> crate::MirPass<'tcx> for EarlyOtherwiseBranch {
133
133
134
134
let mut patch = MirPatch :: new ( body) ;
135
135
136
- // create temp to store second discriminant in, `_s` in example above
137
- let second_discriminant_temp =
138
- patch. new_temp ( opt_data. child_ty , opt_data. child_source . span ) ;
136
+ let ( second_discriminant_temp, second_operand) = if opt_data. need_hoist_discriminant {
137
+ // create temp to store second discriminant in, `_s` in example above
138
+ let second_discriminant_temp =
139
+ patch. new_temp ( opt_data. child_ty , opt_data. child_source . span ) ;
139
140
140
- patch. add_statement ( parent_end, StatementKind :: StorageLive ( second_discriminant_temp) ) ;
141
+ patch. add_statement (
142
+ parent_end,
143
+ StatementKind :: StorageLive ( second_discriminant_temp) ,
144
+ ) ;
141
145
142
- // create assignment of discriminant
143
- patch. add_assign (
144
- parent_end,
145
- Place :: from ( second_discriminant_temp) ,
146
- Rvalue :: Discriminant ( opt_data. child_place ) ,
147
- ) ;
146
+ // create assignment of discriminant
147
+ patch. add_assign (
148
+ parent_end,
149
+ Place :: from ( second_discriminant_temp) ,
150
+ Rvalue :: Discriminant ( opt_data. child_place ) ,
151
+ ) ;
152
+ (
153
+ Some ( second_discriminant_temp) ,
154
+ Operand :: Move ( Place :: from ( second_discriminant_temp) ) ,
155
+ )
156
+ } else {
157
+ ( None , Operand :: Copy ( opt_data. child_place ) )
158
+ } ;
148
159
149
160
// create temp to store inequality comparison between the two discriminants, `_t` in
150
161
// example above
@@ -153,11 +164,9 @@ impl<'tcx> crate::MirPass<'tcx> for EarlyOtherwiseBranch {
153
164
let comp_temp = patch. new_temp ( comp_res_type, opt_data. child_source . span ) ;
154
165
patch. add_statement ( parent_end, StatementKind :: StorageLive ( comp_temp) ) ;
155
166
156
- // create inequality comparison between the two discriminants
157
- let comp_rvalue = Rvalue :: BinaryOp (
158
- nequal,
159
- Box :: new ( ( parent_op. clone ( ) , Operand :: Move ( Place :: from ( second_discriminant_temp) ) ) ) ,
160
- ) ;
167
+ // create inequality comparison
168
+ let comp_rvalue =
169
+ Rvalue :: BinaryOp ( nequal, Box :: new ( ( parent_op. clone ( ) , second_operand) ) ) ;
161
170
patch. add_statement (
162
171
parent_end,
163
172
StatementKind :: Assign ( Box :: new ( ( Place :: from ( comp_temp) , comp_rvalue) ) ) ,
@@ -193,8 +202,13 @@ impl<'tcx> crate::MirPass<'tcx> for EarlyOtherwiseBranch {
193
202
TerminatorKind :: if_ ( Operand :: Move ( Place :: from ( comp_temp) ) , true_case, false_case) ,
194
203
) ;
195
204
196
- // generate StorageDead for the second_discriminant_temp not in use anymore
197
- patch. add_statement ( parent_end, StatementKind :: StorageDead ( second_discriminant_temp) ) ;
205
+ if let Some ( second_discriminant_temp) = second_discriminant_temp {
206
+ // generate StorageDead for the second_discriminant_temp not in use anymore
207
+ patch. add_statement (
208
+ parent_end,
209
+ StatementKind :: StorageDead ( second_discriminant_temp) ,
210
+ ) ;
211
+ }
198
212
199
213
// Generate a StorageDead for comp_temp in each of the targets, since we moved it into
200
214
// the switch
@@ -222,6 +236,7 @@ struct OptimizationData<'tcx> {
222
236
child_place : Place < ' tcx > ,
223
237
child_ty : Ty < ' tcx > ,
224
238
child_source : SourceInfo ,
239
+ need_hoist_discriminant : bool ,
225
240
}
226
241
227
242
fn evaluate_candidate < ' tcx > (
@@ -235,70 +250,128 @@ fn evaluate_candidate<'tcx>(
235
250
return None ;
236
251
} ;
237
252
let parent_ty = parent_discr. ty ( body. local_decls ( ) , tcx) ;
238
- if !bbs[ targets. otherwise ( ) ] . is_empty_unreachable ( ) {
239
- // Someone could write code like this:
240
- // ```rust
241
- // let Q = val;
242
- // if discriminant(P) == otherwise {
243
- // let ptr = &mut Q as *mut _ as *mut u8;
244
- // // It may be difficult for us to effectively determine whether values are valid.
245
- // // Invalid values can come from all sorts of corners.
246
- // unsafe { *ptr = 10; }
247
- // }
248
- //
249
- // match P {
250
- // A => match Q {
251
- // A => {
252
- // // code
253
- // }
254
- // _ => {
255
- // // don't use Q
256
- // }
257
- // }
258
- // _ => {
259
- // // don't use Q
260
- // }
261
- // };
262
- // ```
263
- //
264
- // Hoisting the `discriminant(Q)` out of the `A` arm causes us to compute the discriminant
265
- // of an invalid value, which is UB.
266
- // In order to fix this, **we would either need to show that the discriminant computation of
267
- // `place` is computed in all branches**.
268
- // FIXME(#95162) For the moment, we adopt a conservative approach and
269
- // consider only the `otherwise` branch has no statements and an unreachable terminator.
270
- return None ;
271
- }
272
253
let ( _, child) = targets. iter ( ) . next ( ) ?;
273
- let child_terminator = & bbs[ child] . terminator ( ) ;
274
- let TerminatorKind :: SwitchInt { targets : child_targets, discr : child_discr } =
275
- & child_terminator. kind
254
+
255
+ let Terminator {
256
+ kind : TerminatorKind :: SwitchInt { targets : child_targets, discr : child_discr } ,
257
+ source_info,
258
+ } = bbs[ child] . terminator ( )
276
259
else {
277
260
return None ;
278
261
} ;
279
262
let child_ty = child_discr. ty ( body. local_decls ( ) , tcx) ;
280
263
if child_ty != parent_ty {
281
264
return None ;
282
265
}
283
- let Some ( StatementKind :: Assign ( boxed) ) = & bbs[ child] . statements . first ( ) . map ( |x| & x. kind ) else {
266
+
267
+ // We only handle:
268
+ // ```
269
+ // bb4: {
270
+ // _8 = discriminant((_3.1: Enum1));
271
+ // switchInt(move _8) -> [2: bb7, otherwise: bb1];
272
+ // }
273
+ // ```
274
+ // and
275
+ // ```
276
+ // bb2: {
277
+ // switchInt((_3.1: u64)) -> [1: bb5, otherwise: bb1];
278
+ // }
279
+ // ```
280
+ if bbs[ child] . statements . len ( ) > 1 {
284
281
return None ;
282
+ }
283
+
284
+ // When thie BB has exactly one statement, this statement should be discriminant.
285
+ let need_hoist_discriminant = bbs[ child] . statements . len ( ) == 1 ;
286
+ let child_place = if need_hoist_discriminant {
287
+ if !bbs[ targets. otherwise ( ) ] . is_empty_unreachable ( ) {
288
+ // Someone could write code like this:
289
+ // ```rust
290
+ // let Q = val;
291
+ // if discriminant(P) == otherwise {
292
+ // let ptr = &mut Q as *mut _ as *mut u8;
293
+ // // It may be difficult for us to effectively determine whether values are valid.
294
+ // // Invalid values can come from all sorts of corners.
295
+ // unsafe { *ptr = 10; }
296
+ // }
297
+ //
298
+ // match P {
299
+ // A => match Q {
300
+ // A => {
301
+ // // code
302
+ // }
303
+ // _ => {
304
+ // // don't use Q
305
+ // }
306
+ // }
307
+ // _ => {
308
+ // // don't use Q
309
+ // }
310
+ // };
311
+ // ```
312
+ //
313
+ // Hoisting the `discriminant(Q)` out of the `A` arm causes us to compute the discriminant of an
314
+ // invalid value, which is UB.
315
+ // In order to fix this, **we would either need to show that the discriminant computation of
316
+ // `place` is computed in all branches**.
317
+ // FIXME(#95162) For the moment, we adopt a conservative approach and
318
+ // consider only the `otherwise` branch has no statements and an unreachable terminator.
319
+ return None ;
320
+ }
321
+ // Handle:
322
+ // ```
323
+ // bb4: {
324
+ // _8 = discriminant((_3.1: Enum1));
325
+ // switchInt(move _8) -> [2: bb7, otherwise: bb1];
326
+ // }
327
+ // ```
328
+ let [
329
+ Statement {
330
+ kind : StatementKind :: Assign ( box ( _, Rvalue :: Discriminant ( child_place) ) ) ,
331
+ ..
332
+ } ,
333
+ ] = bbs[ child] . statements . as_slice ( )
334
+ else {
335
+ return None ;
336
+ } ;
337
+ * child_place
338
+ } else {
339
+ // Handle:
340
+ // ```
341
+ // bb2: {
342
+ // switchInt((_3.1: u64)) -> [1: bb5, otherwise: bb1];
343
+ // }
344
+ // ```
345
+ let Operand :: Copy ( child_place) = child_discr else {
346
+ return None ;
347
+ } ;
348
+ * child_place
285
349
} ;
286
- let ( _, Rvalue :: Discriminant ( child_place) ) = & * * boxed else {
287
- return None ;
350
+ let destination = if need_hoist_discriminant || bbs[ targets. otherwise ( ) ] . is_empty_unreachable ( )
351
+ {
352
+ child_targets. otherwise ( )
353
+ } else {
354
+ targets. otherwise ( )
288
355
} ;
289
- let destination = child_targets. otherwise ( ) ;
290
356
291
357
// Verify that the optimization is legal for each branch
292
358
for ( value, child) in targets. iter ( ) {
293
- if !verify_candidate_branch ( & bbs[ child] , value, * child_place, destination) {
359
+ if !verify_candidate_branch (
360
+ & bbs[ child] ,
361
+ value,
362
+ child_place,
363
+ destination,
364
+ need_hoist_discriminant,
365
+ ) {
294
366
return None ;
295
367
}
296
368
}
297
369
Some ( OptimizationData {
298
370
destination,
299
- child_place : * child_place ,
371
+ child_place,
300
372
child_ty,
301
- child_source : child_terminator. source_info ,
373
+ child_source : * source_info,
374
+ need_hoist_discriminant,
302
375
} )
303
376
}
304
377
@@ -307,31 +380,48 @@ fn verify_candidate_branch<'tcx>(
307
380
value : u128 ,
308
381
place : Place < ' tcx > ,
309
382
destination : BasicBlock ,
383
+ need_hoist_discriminant : bool ,
310
384
) -> bool {
311
- // In order for the optimization to be correct, the branch must...
312
- // ...have exactly one statement
313
- if let [ statement] = branch. statements . as_slice ( )
314
- // ...assign the discriminant of `place` in that statement
315
- && let StatementKind :: Assign ( boxed) = & statement. kind
316
- && let ( discr_place, Rvalue :: Discriminant ( from_place) ) = & * * boxed
317
- && * from_place == place
318
- // ...make that assignment to a local
319
- && discr_place. projection . is_empty ( )
320
- // ...terminate on a `SwitchInt` that invalidates that local
321
- && let TerminatorKind :: SwitchInt { discr : switch_op, targets, .. } =
322
- & branch. terminator ( ) . kind
323
- && * switch_op == Operand :: Move ( * discr_place)
324
- // ...fall through to `destination` if the switch misses
325
- && destination == targets. otherwise ( )
326
- // ...have a branch for value `value`
327
- && let mut iter = targets. iter ( )
328
- && let Some ( ( target_value, _) ) = iter. next ( )
329
- && target_value == value
330
- // ...and have no more branches
331
- && iter. next ( ) . is_none ( )
332
- {
333
- true
385
+ // In order for the optimization to be correct, the terminator must be a `SwitchInt`.
386
+ let TerminatorKind :: SwitchInt { discr : switch_op, targets } = & branch. terminator ( ) . kind else {
387
+ return false ;
388
+ } ;
389
+ if need_hoist_discriminant {
390
+ // If we need hoist discriminant, the branch must have exactly one statement.
391
+ let [ statement] = branch. statements . as_slice ( ) else {
392
+ return false ;
393
+ } ;
394
+ // The statement must assign the discriminant of `place`.
395
+ let StatementKind :: Assign ( box ( discr_place, Rvalue :: Discriminant ( from_place) ) ) =
396
+ statement. kind
397
+ else {
398
+ return false ;
399
+ } ;
400
+ if from_place != place {
401
+ return false ;
402
+ }
403
+ // The assignment must invalidate a local that terminate on a `SwitchInt`.
404
+ if !discr_place. projection . is_empty ( ) || * switch_op != Operand :: Move ( discr_place) {
405
+ return false ;
406
+ }
334
407
} else {
335
- false
408
+ // If we don't need hoist discriminant, the branch must not have any statements.
409
+ if !branch. statements . is_empty ( ) {
410
+ return false ;
411
+ }
412
+ // The place on `SwitchInt` must be the same.
413
+ if * switch_op != Operand :: Copy ( place) {
414
+ return false ;
415
+ }
336
416
}
417
+ // It must fall through to `destination` if the switch misses.
418
+ if destination != targets. otherwise ( ) {
419
+ return false ;
420
+ }
421
+ // It must have exactly one branch for value `value` and have no more branches.
422
+ let mut iter = targets. iter ( ) ;
423
+ let ( Some ( ( target_value, _) ) , None ) = ( iter. next ( ) , iter. next ( ) ) else {
424
+ return false ;
425
+ } ;
426
+ target_value == value
337
427
}
0 commit comments