@@ -3,8 +3,10 @@ use std::iter;
3
3
use rustc_index:: IndexSlice ;
4
4
use rustc_middle:: mir:: patch:: MirPatch ;
5
5
use rustc_middle:: mir:: * ;
6
+ use rustc_middle:: ty:: layout:: { IntegerExt , TyAndLayout } ;
6
7
use rustc_middle:: ty:: { ParamEnv , ScalarInt , Ty , TyCtxt } ;
7
- use rustc_target:: abi:: Size ;
8
+ use rustc_target:: abi:: Integer ;
9
+ use rustc_type_ir:: TyKind :: * ;
8
10
9
11
use super :: simplify:: simplify_cfg;
10
12
@@ -264,33 +266,56 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToIf {
264
266
}
265
267
}
266
268
269
+ /// Check if the cast constant using `IntToInt` is equal to the target constant.
270
+ fn can_cast (
271
+ tcx : TyCtxt < ' _ > ,
272
+ src_val : impl Into < u128 > ,
273
+ src_layout : TyAndLayout < ' _ > ,
274
+ cast_ty : Ty < ' _ > ,
275
+ target_scalar : ScalarInt ,
276
+ ) -> bool {
277
+ let from_scalar = ScalarInt :: try_from_uint ( src_val. into ( ) , src_layout. size ) . unwrap ( ) ;
278
+ let v = match src_layout. ty . kind ( ) {
279
+ Uint ( _) => from_scalar. to_uint ( src_layout. size ) ,
280
+ Int ( _) => from_scalar. to_int ( src_layout. size ) as u128 ,
281
+ _ => unreachable ! ( "invalid int" ) ,
282
+ } ;
283
+ let size = match * cast_ty. kind ( ) {
284
+ Int ( t) => Integer :: from_int_ty ( & tcx, t) . size ( ) ,
285
+ Uint ( t) => Integer :: from_uint_ty ( & tcx, t) . size ( ) ,
286
+ _ => unreachable ! ( "invalid int" ) ,
287
+ } ;
288
+ let v = size. truncate ( v) ;
289
+ let cast_scalar = ScalarInt :: try_from_uint ( v, size) . unwrap ( ) ;
290
+ cast_scalar == target_scalar
291
+ }
292
+
267
293
#[ derive( Default ) ]
268
294
struct SimplifyToExp {
269
- transfrom_types : Vec < TransfromType > ,
295
+ transfrom_kinds : Vec < TransfromKind > ,
270
296
}
271
297
272
298
#[ derive( Clone , Copy ) ]
273
- enum CompareType < ' tcx , ' a > {
299
+ enum ExpectedTransformKind < ' tcx , ' a > {
274
300
/// Identical statements.
275
301
Same ( & ' a StatementKind < ' tcx > ) ,
276
302
/// Assignment statements have the same value.
277
- Eq ( & ' a Place < ' tcx > , Ty < ' tcx > , ScalarInt ) ,
303
+ SameByEq { place : & ' a Place < ' tcx > , ty : Ty < ' tcx > , scalar : ScalarInt } ,
278
304
/// Enum variant comparison type.
279
- Discr { place : & ' a Place < ' tcx > , ty : Ty < ' tcx > , is_signed : bool } ,
305
+ Cast { place : & ' a Place < ' tcx > , ty : Ty < ' tcx > } ,
280
306
}
281
307
282
- enum TransfromType {
308
+ enum TransfromKind {
283
309
Same ,
284
- Eq ,
285
- Discr ,
310
+ Cast ,
286
311
}
287
312
288
- impl From < CompareType < ' _ , ' _ > > for TransfromType {
289
- fn from ( compare_type : CompareType < ' _ , ' _ > ) -> Self {
313
+ impl From < ExpectedTransformKind < ' _ , ' _ > > for TransfromKind {
314
+ fn from ( compare_type : ExpectedTransformKind < ' _ , ' _ > ) -> Self {
290
315
match compare_type {
291
- CompareType :: Same ( _) => TransfromType :: Same ,
292
- CompareType :: Eq ( _ , _ , _ ) => TransfromType :: Eq ,
293
- CompareType :: Discr { .. } => TransfromType :: Discr ,
316
+ ExpectedTransformKind :: Same ( _) => TransfromKind :: Same ,
317
+ ExpectedTransformKind :: SameByEq { .. } => TransfromKind :: Same ,
318
+ ExpectedTransformKind :: Cast { .. } => TransfromKind :: Cast ,
294
319
}
295
320
}
296
321
}
@@ -354,7 +379,7 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
354
379
return None ;
355
380
}
356
381
let mut target_iter = targets. iter ( ) ;
357
- let ( first_val , first_target) = target_iter. next ( ) . unwrap ( ) ;
382
+ let ( first_case_val , first_target) = target_iter. next ( ) . unwrap ( ) ;
358
383
let first_terminator_kind = & bbs[ first_target] . terminator ( ) . kind ;
359
384
// Check that destinations are identical, and if not, then don't optimize this block
360
385
if !targets
@@ -364,24 +389,20 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
364
389
return None ;
365
390
}
366
391
367
- let discr_size = tcx. layout_of ( param_env. and ( discr_ty) ) . unwrap ( ) . size ;
392
+ let discr_layout = tcx. layout_of ( param_env. and ( discr_ty) ) . unwrap ( ) ;
368
393
let first_stmts = & bbs[ first_target] . statements ;
369
- let ( second_val , second_target) = target_iter. next ( ) . unwrap ( ) ;
394
+ let ( second_case_val , second_target) = target_iter. next ( ) . unwrap ( ) ;
370
395
let second_stmts = & bbs[ second_target] . statements ;
371
396
if first_stmts. len ( ) != second_stmts. len ( ) {
372
397
return None ;
373
398
}
374
399
375
- fn int_equal ( l : ScalarInt , r : impl Into < u128 > , size : Size ) -> bool {
376
- l. to_bits_unchecked ( ) == ScalarInt :: try_from_uint ( r, size) . unwrap ( ) . to_bits_unchecked ( )
377
- }
378
-
379
400
// We first compare the two branches, and then the other branches need to fulfill the same conditions.
380
- let mut compare_types = Vec :: new ( ) ;
401
+ let mut expected_transform_kinds = Vec :: new ( ) ;
381
402
for ( f, s) in iter:: zip ( first_stmts, second_stmts) {
382
403
let compare_type = match ( & f. kind , & s. kind ) {
383
404
// If two statements are exactly the same, we can optimize.
384
- ( f_s, s_s) if f_s == s_s => CompareType :: Same ( f_s) ,
405
+ ( f_s, s_s) if f_s == s_s => ExpectedTransformKind :: Same ( f_s) ,
385
406
386
407
// If two statements are assignments with the match values to the same place, we can optimize.
387
408
(
@@ -395,22 +416,29 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
395
416
f_c. const_ . try_eval_scalar_int ( tcx, param_env) ,
396
417
s_c. const_ . try_eval_scalar_int ( tcx, param_env) ,
397
418
) {
398
- ( Some ( f) , Some ( s) ) if f == s => CompareType :: Eq ( lhs_f, f_c. const_ . ty ( ) , f) ,
399
- // Enum variants can also be simplified to an assignment statement if their values are equal.
400
- // We need to consider both unsigned and signed scenarios here.
419
+ ( Some ( f) , Some ( s) ) if f == s => ExpectedTransformKind :: SameByEq {
420
+ place : lhs_f,
421
+ ty : f_c. const_ . ty ( ) ,
422
+ scalar : f,
423
+ } ,
424
+ // Enum variants can also be simplified to an assignment statement,
425
+ // if we can use `IntToInt` cast to get an equal value.
401
426
( Some ( f) , Some ( s) )
402
- if ( ( f_c. const_ . ty ( ) . is_signed ( ) || discr_ty. is_signed ( ) )
403
- && int_equal ( f, first_val, discr_size)
404
- && int_equal ( s, second_val, discr_size) )
405
- || ( Some ( f) == ScalarInt :: try_from_uint ( first_val, f. size ( ) )
406
- && Some ( s)
407
- == ScalarInt :: try_from_uint ( second_val, s. size ( ) ) ) =>
427
+ if ( can_cast (
428
+ tcx,
429
+ first_case_val,
430
+ discr_layout,
431
+ f_c. const_ . ty ( ) ,
432
+ f,
433
+ ) && can_cast (
434
+ tcx,
435
+ second_case_val,
436
+ discr_layout,
437
+ f_c. const_ . ty ( ) ,
438
+ s,
439
+ ) ) =>
408
440
{
409
- CompareType :: Discr {
410
- place : lhs_f,
411
- ty : f_c. const_ . ty ( ) ,
412
- is_signed : f_c. const_ . ty ( ) . is_signed ( ) || discr_ty. is_signed ( ) ,
413
- }
441
+ ExpectedTransformKind :: Cast { place : lhs_f, ty : f_c. const_ . ty ( ) }
414
442
}
415
443
_ => {
416
444
return None ;
@@ -421,47 +449,36 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
421
449
// Otherwise we cannot optimize. Try another block.
422
450
_ => return None ,
423
451
} ;
424
- compare_types . push ( compare_type) ;
452
+ expected_transform_kinds . push ( compare_type) ;
425
453
}
426
454
427
455
// All remaining BBs need to fulfill the same pattern as the two BBs from the previous step.
428
456
for ( other_val, other_target) in target_iter {
429
457
let other_stmts = & bbs[ other_target] . statements ;
430
- if compare_types . len ( ) != other_stmts. len ( ) {
458
+ if expected_transform_kinds . len ( ) != other_stmts. len ( ) {
431
459
return None ;
432
460
}
433
- for ( f, s) in iter:: zip ( & compare_types , other_stmts) {
461
+ for ( f, s) in iter:: zip ( & expected_transform_kinds , other_stmts) {
434
462
match ( * f, & s. kind ) {
435
- ( CompareType :: Same ( f_s) , s_s) if f_s == s_s => { }
463
+ ( ExpectedTransformKind :: Same ( f_s) , s_s) if f_s == s_s => { }
436
464
(
437
- CompareType :: Eq ( lhs_f, f_ty, val ) ,
465
+ ExpectedTransformKind :: SameByEq { place : lhs_f, ty : f_ty, scalar } ,
438
466
StatementKind :: Assign ( box ( lhs_s, Rvalue :: Use ( Operand :: Constant ( s_c) ) ) ) ,
439
467
) if lhs_f == lhs_s
440
468
&& s_c. const_ . ty ( ) == f_ty
441
- && s_c. const_ . try_eval_scalar_int ( tcx, param_env) == Some ( val ) => { }
469
+ && s_c. const_ . try_eval_scalar_int ( tcx, param_env) == Some ( scalar ) => { }
442
470
(
443
- CompareType :: Discr { place : lhs_f, ty : f_ty, is_signed } ,
471
+ ExpectedTransformKind :: Cast { place : lhs_f, ty : f_ty } ,
444
472
StatementKind :: Assign ( box ( lhs_s, Rvalue :: Use ( Operand :: Constant ( s_c) ) ) ) ,
445
- ) if lhs_f == lhs_s && s_c. const_ . ty ( ) == f_ty => {
446
- let Some ( f) = s_c. const_ . try_eval_scalar_int ( tcx, param_env) else {
447
- return None ;
448
- } ;
449
- if is_signed
450
- && s_c. const_ . ty ( ) . is_signed ( )
451
- && int_equal ( f, other_val, discr_size)
452
- {
453
- continue ;
454
- }
455
- if Some ( f) == ScalarInt :: try_from_uint ( other_val, f. size ( ) ) {
456
- continue ;
457
- }
458
- return None ;
459
- }
473
+ ) if let Some ( f) = s_c. const_ . try_eval_scalar_int ( tcx, param_env)
474
+ && lhs_f == lhs_s
475
+ && s_c. const_ . ty ( ) == f_ty
476
+ && can_cast ( tcx, other_val, discr_layout, f_ty, f) => { }
460
477
_ => return None ,
461
478
}
462
479
}
463
480
}
464
- self . transfrom_types = compare_types . into_iter ( ) . map ( |c| c. into ( ) ) . collect ( ) ;
481
+ self . transfrom_kinds = expected_transform_kinds . into_iter ( ) . map ( |c| c. into ( ) ) . collect ( ) ;
465
482
Some ( ( ) )
466
483
}
467
484
@@ -479,13 +496,13 @@ impl<'tcx> SimplifyMatch<'tcx> for SimplifyToExp {
479
496
let ( _, first) = targets. iter ( ) . next ( ) . unwrap ( ) ;
480
497
let first = & bbs[ first] ;
481
498
482
- for ( t, s) in iter:: zip ( & self . transfrom_types , & first. statements ) {
499
+ for ( t, s) in iter:: zip ( & self . transfrom_kinds , & first. statements ) {
483
500
match ( t, & s. kind ) {
484
- ( TransfromType :: Same , _ ) | ( TransfromType :: Eq , _) => {
501
+ ( TransfromKind :: Same , _) => {
485
502
patch. add_statement ( parent_end, s. kind . clone ( ) ) ;
486
503
}
487
504
(
488
- TransfromType :: Discr ,
505
+ TransfromKind :: Cast ,
489
506
StatementKind :: Assign ( box ( lhs, Rvalue :: Use ( Operand :: Constant ( f_c) ) ) ) ,
490
507
) => {
491
508
let operand = Operand :: Copy ( Place :: from ( discr_local) ) ;
0 commit comments