1
- use rustc_index:: IndexVec ;
1
+ use rustc_index:: IndexSlice ;
2
+ use rustc_middle:: mir:: patch:: MirPatch ;
2
3
use rustc_middle:: mir:: * ;
3
4
use rustc_middle:: ty:: { ParamEnv , ScalarInt , Ty , TyCtxt } ;
4
5
use rustc_target:: abi:: Size ;
@@ -17,9 +18,10 @@ impl<'tcx> MirPass<'tcx> for MatchBranchSimplification {
17
18
let def_id = body. source . def_id ( ) ;
18
19
let param_env = tcx. param_env_reveal_all_normalized ( def_id) ;
19
20
20
- let bbs = body. basic_blocks . as_mut ( ) ;
21
21
let mut should_cleanup = false ;
22
- for bb_idx in bbs. indices ( ) {
22
+ for i in 0 ..body. basic_blocks . len ( ) {
23
+ let bbs = & * body. basic_blocks ;
24
+ let bb_idx = BasicBlock :: from_usize ( i) ;
23
25
if !tcx. consider_optimizing ( || format ! ( "MatchBranchSimplification {def_id:?} " ) ) {
24
26
continue ;
25
27
}
@@ -35,12 +37,11 @@ impl<'tcx> MirPass<'tcx> for MatchBranchSimplification {
35
37
_ => continue ,
36
38
} ;
37
39
38
- if SimplifyToIf . simplify ( tcx, & mut body. local_decls , bbs , bb_idx, param_env) {
40
+ if SimplifyToIf . simplify ( tcx, body, bb_idx, param_env) {
39
41
should_cleanup = true ;
40
42
continue ;
41
43
}
42
- if SimplifyToExp :: default ( ) . simplify ( tcx, & mut body. local_decls , bbs, bb_idx, param_env)
43
- {
44
+ if SimplifyToExp :: default ( ) . simplify ( tcx, body, bb_idx, param_env) {
44
45
should_cleanup = true ;
45
46
continue ;
46
47
}
@@ -58,41 +59,39 @@ trait SimplifyMatch<'tcx> {
58
59
fn simplify (
59
60
& mut self ,
60
61
tcx : TyCtxt < ' tcx > ,
61
- local_decls : & mut IndexVec < Local , LocalDecl < ' tcx > > ,
62
- bbs : & mut IndexVec < BasicBlock , BasicBlockData < ' tcx > > ,
62
+ body : & mut Body < ' tcx > ,
63
63
switch_bb_idx : BasicBlock ,
64
64
param_env : ParamEnv < ' tcx > ,
65
65
) -> bool {
66
+ let bbs = & body. basic_blocks ;
66
67
let ( discr, targets) = match bbs[ switch_bb_idx] . terminator ( ) . kind {
67
68
TerminatorKind :: SwitchInt { ref discr, ref targets, .. } => ( discr, targets) ,
68
69
_ => unreachable ! ( ) ,
69
70
} ;
70
71
71
- let discr_ty = discr. ty ( local_decls, tcx) ;
72
+ let discr_ty = discr. ty ( body . local_decls ( ) , tcx) ;
72
73
if !self . can_simplify ( tcx, targets, param_env, bbs, discr_ty) {
73
74
return false ;
74
75
}
75
76
77
+ let mut patch = MirPatch :: new ( body) ;
78
+
76
79
// Take ownership of items now that we know we can optimize.
77
80
let discr = discr. clone ( ) ;
78
81
79
82
// Introduce a temporary for the discriminant value.
80
83
let source_info = bbs[ switch_bb_idx] . terminator ( ) . source_info ;
81
- let discr_local = local_decls . push ( LocalDecl :: new ( discr_ty, source_info. span ) ) ;
84
+ let discr_local = patch . new_temp ( discr_ty, source_info. span ) ;
82
85
83
- let new_stmts = self . new_stmts ( tcx, targets, param_env, bbs, discr_local, discr_ty) ;
84
86
let ( _, first) = targets. iter ( ) . next ( ) . unwrap ( ) ;
85
- let ( from, first) = bbs. pick2_mut ( switch_bb_idx, first) ;
86
- from. statements
87
- . push ( Statement { source_info, kind : StatementKind :: StorageLive ( discr_local) } ) ;
88
- from. statements . push ( Statement {
89
- source_info,
90
- kind : StatementKind :: Assign ( Box :: new ( ( Place :: from ( discr_local) , Rvalue :: Use ( discr) ) ) ) ,
91
- } ) ;
92
- from. statements . extend ( new_stmts) ;
93
- from. statements
94
- . push ( Statement { source_info, kind : StatementKind :: StorageDead ( discr_local) } ) ;
95
- from. terminator_mut ( ) . kind = first. terminator ( ) . kind . clone ( ) ;
87
+ let statement_index = bbs[ switch_bb_idx] . statements . len ( ) ;
88
+ let parent_end = Location { block : switch_bb_idx, statement_index } ;
89
+ patch. add_statement ( parent_end, StatementKind :: StorageLive ( discr_local) ) ;
90
+ patch. add_assign ( parent_end, Place :: from ( discr_local) , Rvalue :: Use ( discr) ) ;
91
+ self . new_stmts ( tcx, targets, param_env, & mut patch, parent_end, bbs, discr_local, discr_ty) ;
92
+ patch. add_statement ( parent_end, StatementKind :: StorageDead ( discr_local) ) ;
93
+ patch. patch_terminator ( switch_bb_idx, bbs[ first] . terminator ( ) . kind . clone ( ) ) ;
94
+ patch. apply ( body) ;
96
95
true
97
96
}
98
97
@@ -104,7 +103,7 @@ trait SimplifyMatch<'tcx> {
104
103
tcx : TyCtxt < ' tcx > ,
105
104
targets : & SwitchTargets ,
106
105
param_env : ParamEnv < ' tcx > ,
107
- bbs : & IndexVec < BasicBlock , BasicBlockData < ' tcx > > ,
106
+ bbs : & IndexSlice < BasicBlock , BasicBlockData < ' tcx > > ,
108
107
discr_ty : Ty < ' tcx > ,
109
108
) -> bool ;
110
109
@@ -113,10 +112,12 @@ trait SimplifyMatch<'tcx> {
113
112
tcx : TyCtxt < ' tcx > ,
114
113
targets : & SwitchTargets ,
115
114
param_env : ParamEnv < ' tcx > ,
116
- bbs : & IndexVec < BasicBlock , BasicBlockData < ' tcx > > ,
115
+ patch : & mut MirPatch < ' tcx > ,
116
+ parent_end : Location ,
117
+ bbs : & IndexSlice < BasicBlock , BasicBlockData < ' tcx > > ,
117
118
discr_local : Local ,
118
119
discr_ty : Ty < ' tcx > ,
119
- ) -> Vec < Statement < ' tcx > > ;
120
+ ) ;
120
121
}
121
122
122
123
struct SimplifyToIf ;
@@ -158,7 +159,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToIf {
158
159
tcx : TyCtxt < ' tcx > ,
159
160
targets : & SwitchTargets ,
160
161
param_env : ParamEnv < ' tcx > ,
161
- bbs : & IndexVec < BasicBlock , BasicBlockData < ' tcx > > ,
162
+ bbs : & IndexSlice < BasicBlock , BasicBlockData < ' tcx > > ,
162
163
_discr_ty : Ty < ' tcx > ,
163
164
) -> bool {
164
165
if targets. iter ( ) . len ( ) != 1 {
@@ -209,20 +210,23 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToIf {
209
210
tcx : TyCtxt < ' tcx > ,
210
211
targets : & SwitchTargets ,
211
212
param_env : ParamEnv < ' tcx > ,
212
- bbs : & IndexVec < BasicBlock , BasicBlockData < ' tcx > > ,
213
+ patch : & mut MirPatch < ' tcx > ,
214
+ parent_end : Location ,
215
+ bbs : & IndexSlice < BasicBlock , BasicBlockData < ' tcx > > ,
213
216
discr_local : Local ,
214
217
discr_ty : Ty < ' tcx > ,
215
- ) -> Vec < Statement < ' tcx > > {
218
+ ) {
216
219
let ( val, first) = targets. iter ( ) . next ( ) . unwrap ( ) ;
217
220
let second = targets. otherwise ( ) ;
218
221
// We already checked that first and second are different blocks,
219
222
// and bb_idx has a different terminator from both of them.
220
223
let first = & bbs[ first] ;
221
224
let second = & bbs[ second] ;
222
-
223
- let new_stmts = iter:: zip ( & first. statements , & second. statements ) . map ( |( f, s) | {
225
+ for ( f, s) in iter:: zip ( & first. statements , & second. statements ) {
224
226
match ( & f. kind , & s. kind ) {
225
- ( f_s, s_s) if f_s == s_s => ( * f) . clone ( ) ,
227
+ ( f_s, s_s) if f_s == s_s => {
228
+ patch. add_statement ( parent_end, f. kind . clone ( ) ) ;
229
+ }
226
230
227
231
(
228
232
StatementKind :: Assign ( box ( lhs, Rvalue :: Use ( Operand :: Constant ( f_c) ) ) ) ,
@@ -233,7 +237,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToIf {
233
237
let s_b = s_c. const_ . try_eval_bool ( tcx, param_env) . unwrap ( ) ;
234
238
if f_b == s_b {
235
239
// Same value in both blocks. Use statement as is.
236
- ( * f ) . clone ( )
240
+ patch . add_statement ( parent_end , f . kind . clone ( ) ) ;
237
241
} else {
238
242
// Different value between blocks. Make value conditional on switch condition.
239
243
let size = tcx. layout_of ( param_env. and ( discr_ty) ) . unwrap ( ) . size ;
@@ -248,17 +252,13 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToIf {
248
252
op,
249
253
Box :: new ( ( Operand :: Copy ( Place :: from ( discr_local) ) , const_cmp) ) ,
250
254
) ;
251
- Statement {
252
- source_info : f. source_info ,
253
- kind : StatementKind :: Assign ( Box :: new ( ( * lhs, rhs) ) ) ,
254
- }
255
+ patch. add_assign ( parent_end, * lhs, rhs) ;
255
256
}
256
257
}
257
258
258
259
_ => unreachable ! ( ) ,
259
260
}
260
- } ) ;
261
- new_stmts. collect ( )
261
+ }
262
262
}
263
263
}
264
264
@@ -335,7 +335,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
335
335
tcx : TyCtxt < ' tcx > ,
336
336
targets : & SwitchTargets ,
337
337
param_env : ParamEnv < ' tcx > ,
338
- bbs : & IndexVec < BasicBlock , BasicBlockData < ' tcx > > ,
338
+ bbs : & IndexSlice < BasicBlock , BasicBlockData < ' tcx > > ,
339
339
discr_ty : Ty < ' tcx > ,
340
340
) -> bool {
341
341
if targets. iter ( ) . len ( ) < 2 || targets. iter ( ) . len ( ) > 64 {
@@ -372,6 +372,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
372
372
== ScalarInt :: try_from_uint ( r, size) . unwrap ( ) . try_to_int ( size) . unwrap ( )
373
373
}
374
374
375
+ // We first compare the two branches, and then the other branches need to fulfill the same conditions.
375
376
let mut compare_types = Vec :: new ( ) ;
376
377
for ( f, s) in iter:: zip ( first_stmts, second_stmts) {
377
378
let compare_type = match ( & f. kind , & s. kind ) {
@@ -391,6 +392,8 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
391
392
s_c. const_ . try_eval_scalar_int ( tcx, param_env) ,
392
393
) {
393
394
( Some ( f) , Some ( s) ) if f == s => CompareType :: Eq ( lhs_f, f_c. const_ . ty ( ) , f) ,
395
+ // Enum variants can also be simplified to an assignment statement if their values are equal.
396
+ // We need to consider both unsigned and signed scenarios here.
394
397
( Some ( f) , Some ( s) )
395
398
if ( ( f_c. const_ . ty ( ) . is_signed ( ) || discr_ty. is_signed ( ) )
396
399
&& int_equal ( f, first_val, discr_size)
@@ -463,16 +466,20 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
463
466
_tcx : TyCtxt < ' tcx > ,
464
467
targets : & SwitchTargets ,
465
468
_param_env : ParamEnv < ' tcx > ,
466
- bbs : & IndexVec < BasicBlock , BasicBlockData < ' tcx > > ,
469
+ patch : & mut MirPatch < ' tcx > ,
470
+ parent_end : Location ,
471
+ bbs : & IndexSlice < BasicBlock , BasicBlockData < ' tcx > > ,
467
472
discr_local : Local ,
468
473
discr_ty : Ty < ' tcx > ,
469
- ) -> Vec < Statement < ' tcx > > {
474
+ ) {
470
475
let ( _, first) = targets. iter ( ) . next ( ) . unwrap ( ) ;
471
476
let first = & bbs[ first] ;
472
477
473
- let new_stmts =
474
- iter:: zip ( & self . transfrom_types , & first. statements ) . map ( |( t, s) | match ( t, & s. kind ) {
475
- ( TransfromType :: Same , _) | ( TransfromType :: Eq , _) => ( * s) . clone ( ) ,
478
+ for ( t, s) in iter:: zip ( & self . transfrom_types , & first. statements ) {
479
+ match ( t, & s. kind ) {
480
+ ( TransfromType :: Same , _) | ( TransfromType :: Eq , _) => {
481
+ patch. add_statement ( parent_end, s. kind . clone ( ) ) ;
482
+ }
476
483
(
477
484
TransfromType :: Discr ,
478
485
StatementKind :: Assign ( box ( lhs, Rvalue :: Use ( Operand :: Constant ( f_c) ) ) ) ,
@@ -483,13 +490,10 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
483
490
} else {
484
491
Rvalue :: Cast ( CastKind :: IntToInt , operand, f_c. const_ . ty ( ) )
485
492
} ;
486
- Statement {
487
- source_info : s. source_info ,
488
- kind : StatementKind :: Assign ( Box :: new ( ( * lhs, r_val) ) ) ,
489
- }
493
+ patch. add_assign ( parent_end, * lhs, r_val) ;
490
494
}
491
495
_ => unreachable ! ( ) ,
492
- } ) ;
493
- new_stmts . collect ( )
496
+ }
497
+ }
494
498
}
495
499
}
0 commit comments