@@ -548,6 +548,9 @@ fn simplify_to_copy<'tcx>(
548
548
switch_bb_idx : BasicBlock ,
549
549
param_env : ParamEnv < ' tcx > ,
550
550
) -> Option < ( ) > {
551
+ if switch_bb_idx != START_BLOCK {
552
+ return None ;
553
+ }
551
554
let bbs = & body. basic_blocks ;
552
555
// Check if the copy source matches the following pattern.
553
556
// _2 = discriminant(*_1); // "*_1" is the expected the copy source.
@@ -563,6 +566,11 @@ fn simplify_to_copy<'tcx>(
563
566
if !expected_src_ty. ty . is_enum ( ) || expected_src_ty. variant_index . is_some ( ) {
564
567
return None ;
565
568
}
569
+ let expected_dest_place = Place :: return_place ( ) ;
570
+ let expected_dest_ty = expected_dest_place. ty ( body. local_decls ( ) , tcx) ;
571
+ if expected_dest_ty. ty != expected_src_ty. ty || expected_dest_ty. variant_index . is_some ( ) {
572
+ return None ;
573
+ }
566
574
let targets = match bbs[ switch_bb_idx] . terminator ( ) . kind {
567
575
TerminatorKind :: SwitchInt { ref discr, ref targets, .. }
568
576
if discr. place ( ) == Some ( discr_place) =>
@@ -589,7 +597,7 @@ fn simplify_to_copy<'tcx>(
589
597
590
598
let borrowed_locals = borrowed_locals ( body) ;
591
599
let mut live = None ;
592
- let mut expected_dest_place = None ;
600
+
593
601
for ( index, target_bb) in targets. iter ( ) {
594
602
let stmts = & bbs[ target_bb] . statements ;
595
603
if stmts. is_empty ( ) {
@@ -605,7 +613,7 @@ fn simplify_to_copy<'tcx>(
605
613
let ty:: Adt ( def, _) = dest_ty. ty . kind ( ) else {
606
614
return None ;
607
615
} ;
608
- if * expected_dest_place. get_or_insert ( * place ) != * place {
616
+ if expected_dest_place != * place {
609
617
return None ;
610
618
}
611
619
match rvalue {
@@ -637,7 +645,7 @@ fn simplify_to_copy<'tcx>(
637
645
// If the BB contains more than one statement, we have to check if these statements can be ignored.
638
646
let mut lived_stmts: BitSet < usize > =
639
647
BitSet :: new_filled ( bbs[ target_bb] . statements . len ( ) ) ;
640
- let mut dest_place = None ;
648
+ let mut expected_copy_stmt = None ;
641
649
for ( statement_index, statement) in bbs[ target_bb] . statements . iter ( ) . enumerate ( ) . rev ( ) {
642
650
let loc = Location { block : target_bb, statement_index } ;
643
651
if let StatementKind :: Assign ( assign) = & statement. kind {
@@ -661,13 +669,16 @@ fn simplify_to_copy<'tcx>(
661
669
live. seek_before_primary_effect ( loc) ;
662
670
if !live. get ( ) . contains ( place. local ) {
663
671
lived_stmts. remove ( statement_index) ;
664
- } else if matches ! (
665
- & statement. kind,
666
- StatementKind :: Assign ( box ( _, Rvalue :: Use ( Operand :: Copy ( _) ) ) )
667
- ) && dest_place. is_none ( )
672
+ } else if let StatementKind :: Assign ( box (
673
+ _,
674
+ Rvalue :: Use ( Operand :: Copy ( src_place) ) ,
675
+ ) ) = statement. kind
676
+ && expected_copy_stmt. is_none ( )
677
+ && expected_src_place == src_place
678
+ && expected_dest_place == * place
668
679
{
669
680
// There is only one statement that cannot be ignored that can be used as an expected copy statement.
670
- dest_place = Some ( * place ) ;
681
+ expected_copy_stmt = Some ( statement_index ) ;
671
682
} else {
672
683
return None ;
673
684
}
@@ -687,21 +698,19 @@ fn simplify_to_copy<'tcx>(
687
698
}
688
699
}
689
700
}
690
- let dest_place = dest_place?;
691
- if * expected_dest_place. get_or_insert ( dest_place) != dest_place {
692
- return None ;
693
- }
701
+ let expected_copy_stmt = expected_copy_stmt?;
694
702
// We can ignore the paired StorageLive and StorageDead.
695
703
let mut storage_live_locals: BitSet < Local > = BitSet :: new_empty ( body. local_decls . len ( ) ) ;
696
704
for stmt_index in lived_stmts. iter ( ) {
697
705
let statement = & bbs[ target_bb] . statements [ stmt_index] ;
698
706
match & statement. kind {
699
- StatementKind :: Assign ( box ( place, Rvalue :: Use ( Operand :: Copy ( src_place) ) ) )
700
- if * place == dest_place && * src_place == expected_src_place => { }
707
+ StatementKind :: Assign ( _) if expected_copy_stmt == stmt_index => { }
701
708
StatementKind :: StorageLive ( local)
702
- if * local != dest_place. local && storage_live_locals. insert ( * local) => { }
709
+ if * local != expected_dest_place. local
710
+ && storage_live_locals. insert ( * local) => { }
703
711
StatementKind :: StorageDead ( local)
704
- if * local != dest_place. local && storage_live_locals. remove ( * local) => { }
712
+ if * local != expected_dest_place. local
713
+ && storage_live_locals. remove ( * local) => { }
705
714
StatementKind :: Nop => { }
706
715
_ => return None ,
707
716
}
@@ -711,7 +720,6 @@ fn simplify_to_copy<'tcx>(
711
720
}
712
721
}
713
722
}
714
- let expected_dest_place = expected_dest_place?;
715
723
let statement_index = bbs[ switch_bb_idx] . statements . len ( ) ;
716
724
let parent_end = Location { block : switch_bb_idx, statement_index } ;
717
725
let mut patch = MirPatch :: new ( body) ;
0 commit comments