@@ -3,16 +3,17 @@ use std::fmt;
3
3
use arrayvec:: ArrayVec ;
4
4
use either:: Either ;
5
5
use rustc_abi as abi;
6
- use rustc_abi:: { Align , BackendRepr , Size } ;
6
+ use rustc_abi:: { Align , BackendRepr , FIRST_VARIANT , Primitive , Size , TagEncoding , Variants } ;
7
7
use rustc_middle:: mir:: interpret:: { Pointer , Scalar , alloc_range} ;
8
8
use rustc_middle:: mir:: { self , ConstValue } ;
9
9
use rustc_middle:: ty:: Ty ;
10
10
use rustc_middle:: ty:: layout:: { LayoutOf , TyAndLayout } ;
11
11
use rustc_middle:: { bug, span_bug} ;
12
- use tracing:: debug;
12
+ use tracing:: { debug, instrument } ;
13
13
14
14
use super :: place:: { PlaceRef , PlaceValue } ;
15
15
use super :: { FunctionCx , LocalRef } ;
16
+ use crate :: common:: IntPredicate ;
16
17
use crate :: traits:: * ;
17
18
use crate :: { MemFlags , size_of_val} ;
18
19
@@ -415,6 +416,140 @@ impl<'a, 'tcx, V: CodegenObject> OperandRef<'tcx, V> {
415
416
416
417
OperandRef { val, layout : field }
417
418
}
419
+
420
+ /// Obtain the actual discriminant of a value.
421
+ #[ instrument( level = "trace" , skip( fx, bx) ) ]
422
+ pub fn codegen_get_discr < Bx : BuilderMethods < ' a , ' tcx , Value = V > > (
423
+ self ,
424
+ fx : & mut FunctionCx < ' a , ' tcx , Bx > ,
425
+ bx : & mut Bx ,
426
+ cast_to : Ty < ' tcx > ,
427
+ ) -> V {
428
+ let dl = & bx. tcx ( ) . data_layout ;
429
+ let cast_to_layout = bx. cx ( ) . layout_of ( cast_to) ;
430
+ let cast_to = bx. cx ( ) . immediate_backend_type ( cast_to_layout) ;
431
+ if self . layout . is_uninhabited ( ) {
432
+ return bx. cx ( ) . const_poison ( cast_to) ;
433
+ }
434
+ let ( tag_scalar, tag_encoding, tag_field) = match self . layout . variants {
435
+ Variants :: Empty => unreachable ! ( "we already handled uninhabited types" ) ,
436
+ Variants :: Single { index } => {
437
+ let discr_val =
438
+ if let Some ( discr) = self . layout . ty . discriminant_for_variant ( bx. tcx ( ) , index) {
439
+ discr. val
440
+ } else {
441
+ assert_eq ! ( index, FIRST_VARIANT ) ;
442
+ 0
443
+ } ;
444
+ return bx. cx ( ) . const_uint_big ( cast_to, discr_val) ;
445
+ }
446
+ Variants :: Multiple { tag, ref tag_encoding, tag_field, .. } => {
447
+ ( tag, tag_encoding, tag_field)
448
+ }
449
+ } ;
450
+
451
+ // Read the tag/niche-encoded discriminant from memory.
452
+ let tag_op = match self . val {
453
+ OperandValue :: ZeroSized => bug ! ( ) ,
454
+ OperandValue :: Immediate ( _) | OperandValue :: Pair ( _, _) => {
455
+ self . extract_field ( fx, bx, tag_field)
456
+ }
457
+ OperandValue :: Ref ( place) => {
458
+ let tag = place. with_type ( self . layout ) . project_field ( bx, tag_field) ;
459
+ bx. load_operand ( tag)
460
+ }
461
+ } ;
462
+ let tag_imm = tag_op. immediate ( ) ;
463
+
464
+ // Decode the discriminant (specifically if it's niche-encoded).
465
+ match * tag_encoding {
466
+ TagEncoding :: Direct => {
467
+ let signed = match tag_scalar. primitive ( ) {
468
+ // We use `i1` for bytes that are always `0` or `1`,
469
+ // e.g., `#[repr(i8)] enum E { A, B }`, but we can't
470
+ // let LLVM interpret the `i1` as signed, because
471
+ // then `i1 1` (i.e., `E::B`) is effectively `i8 -1`.
472
+ Primitive :: Int ( _, signed) => !tag_scalar. is_bool ( ) && signed,
473
+ _ => false ,
474
+ } ;
475
+ bx. intcast ( tag_imm, cast_to, signed)
476
+ }
477
+ TagEncoding :: Niche { untagged_variant, ref niche_variants, niche_start } => {
478
+ // Cast to an integer so we don't have to treat a pointer as a
479
+ // special case.
480
+ let ( tag, tag_llty) = match tag_scalar. primitive ( ) {
481
+ // FIXME(erikdesjardins): handle non-default addrspace ptr sizes
482
+ Primitive :: Pointer ( _) => {
483
+ let t = bx. type_from_integer ( dl. ptr_sized_integer ( ) ) ;
484
+ let tag = bx. ptrtoint ( tag_imm, t) ;
485
+ ( tag, t)
486
+ }
487
+ _ => ( tag_imm, bx. cx ( ) . immediate_backend_type ( tag_op. layout ) ) ,
488
+ } ;
489
+
490
+ let relative_max = niche_variants. end ( ) . as_u32 ( ) - niche_variants. start ( ) . as_u32 ( ) ;
491
+
492
+ // We have a subrange `niche_start..=niche_end` inside `range`.
493
+ // If the value of the tag is inside this subrange, it's a
494
+ // "niche value", an increment of the discriminant. Otherwise it
495
+ // indicates the untagged variant.
496
+ // A general algorithm to extract the discriminant from the tag
497
+ // is:
498
+ // relative_tag = tag - niche_start
499
+ // is_niche = relative_tag <= (ule) relative_max
500
+ // discr = if is_niche {
501
+ // cast(relative_tag) + niche_variants.start()
502
+ // } else {
503
+ // untagged_variant
504
+ // }
505
+ // However, we will likely be able to emit simpler code.
506
+ let ( is_niche, tagged_discr, delta) = if relative_max == 0 {
507
+ // Best case scenario: only one tagged variant. This will
508
+ // likely become just a comparison and a jump.
509
+ // The algorithm is:
510
+ // is_niche = tag == niche_start
511
+ // discr = if is_niche {
512
+ // niche_start
513
+ // } else {
514
+ // untagged_variant
515
+ // }
516
+ let niche_start = bx. cx ( ) . const_uint_big ( tag_llty, niche_start) ;
517
+ let is_niche = bx. icmp ( IntPredicate :: IntEQ , tag, niche_start) ;
518
+ let tagged_discr =
519
+ bx. cx ( ) . const_uint ( cast_to, niche_variants. start ( ) . as_u32 ( ) as u64 ) ;
520
+ ( is_niche, tagged_discr, 0 )
521
+ } else {
522
+ // The special cases don't apply, so we'll have to go with
523
+ // the general algorithm.
524
+ let relative_discr = bx. sub ( tag, bx. cx ( ) . const_uint_big ( tag_llty, niche_start) ) ;
525
+ let cast_tag = bx. intcast ( relative_discr, cast_to, false ) ;
526
+ let is_niche = bx. icmp (
527
+ IntPredicate :: IntULE ,
528
+ relative_discr,
529
+ bx. cx ( ) . const_uint ( tag_llty, relative_max as u64 ) ,
530
+ ) ;
531
+ ( is_niche, cast_tag, niche_variants. start ( ) . as_u32 ( ) as u128 )
532
+ } ;
533
+
534
+ let tagged_discr = if delta == 0 {
535
+ tagged_discr
536
+ } else {
537
+ bx. add ( tagged_discr, bx. cx ( ) . const_uint_big ( cast_to, delta) )
538
+ } ;
539
+
540
+ let discr = bx. select (
541
+ is_niche,
542
+ tagged_discr,
543
+ bx. cx ( ) . const_uint ( cast_to, untagged_variant. as_u32 ( ) as u64 ) ,
544
+ ) ;
545
+
546
+ // In principle we could insert assumes on the possible range of `discr`, but
547
+ // currently in LLVM this seems to be a pessimization.
548
+
549
+ discr
550
+ }
551
+ }
552
+ }
418
553
}
419
554
420
555
impl < ' a , ' tcx , V : CodegenObject > OperandValue < V > {
0 commit comments