3
3
use std:: fmt;
4
4
use tracing:: debug;
5
5
6
- use hir_def:: { DefWithBodyId , EnumVariantId , HasModule , LocalFieldId , ModuleId , VariantId } ;
6
+ use hir_def:: { DefWithBodyId , EnumId , EnumVariantId , HasModule , LocalFieldId , ModuleId , VariantId } ;
7
7
use rustc_hash:: FxHashMap ;
8
8
use rustc_pattern_analysis:: {
9
9
constructor:: { Constructor , ConstructorSet , VariantVisibility } ,
@@ -36,6 +36,24 @@ pub(crate) type WitnessPat<'p> = rustc_pattern_analysis::pat::WitnessPat<MatchCh
36
36
#[ derive( Copy , Clone , Debug , PartialEq , Eq ) ]
37
37
pub ( crate ) enum Void { }
38
38
39
+ /// An index type for enum variants. This ranges from 0 to `variants.len()`, whereas `EnumVariantId`
40
+ /// can take arbitrary large values (and hence mustn't be used with `IndexVec`/`BitSet`).
41
+ #[ derive( Copy , Clone , Debug , PartialEq , Eq , Hash ) ]
42
+ pub ( crate ) struct EnumVariantContiguousIndex ( usize ) ;
43
+
44
+ impl EnumVariantContiguousIndex {
45
+ fn from_enum_variant_id ( db : & dyn HirDatabase , target_evid : EnumVariantId ) -> Self {
46
+ // Find the index of this variant in the list of variants.
47
+ use hir_def:: Lookup ;
48
+ let i = target_evid. lookup ( db. upcast ( ) ) . index as usize ;
49
+ EnumVariantContiguousIndex ( i)
50
+ }
51
+
52
+ fn to_enum_variant_id ( self , db : & dyn HirDatabase , eid : EnumId ) -> EnumVariantId {
53
+ db. enum_data ( eid) . variants [ self . 0 ] . 0
54
+ }
55
+ }
56
+
39
57
#[ derive( Clone ) ]
40
58
pub ( crate ) struct MatchCheckCtx < ' p > {
41
59
module : ModuleId ,
@@ -89,9 +107,18 @@ impl<'p> MatchCheckCtx<'p> {
89
107
}
90
108
}
91
109
92
- fn variant_id_for_adt ( ctor : & Constructor < Self > , adt : hir_def:: AdtId ) -> Option < VariantId > {
110
+ fn variant_id_for_adt (
111
+ db : & ' p dyn HirDatabase ,
112
+ ctor : & Constructor < Self > ,
113
+ adt : hir_def:: AdtId ,
114
+ ) -> Option < VariantId > {
93
115
match ctor {
94
- & Variant ( id) => Some ( id. into ( ) ) ,
116
+ Variant ( id) => {
117
+ let hir_def:: AdtId :: EnumId ( eid) = adt else {
118
+ panic ! ( "bad constructor {ctor:?} for adt {adt:?}" )
119
+ } ;
120
+ Some ( id. to_enum_variant_id ( db, eid) . into ( ) )
121
+ }
95
122
Struct | UnionField => match adt {
96
123
hir_def:: AdtId :: EnumId ( _) => None ,
97
124
hir_def:: AdtId :: StructId ( id) => Some ( id. into ( ) ) ,
@@ -175,19 +202,24 @@ impl<'p> MatchCheckCtx<'p> {
175
202
ctor = Struct ;
176
203
arity = 1 ;
177
204
}
178
- & TyKind :: Adt ( adt, _) => {
205
+ & TyKind :: Adt ( AdtId ( adt) , _) => {
179
206
ctor = match pat. kind . as_ref ( ) {
180
- PatKind :: Leaf { .. } if matches ! ( adt. 0 , hir_def:: AdtId :: UnionId ( _) ) => {
207
+ PatKind :: Leaf { .. } if matches ! ( adt, hir_def:: AdtId :: UnionId ( _) ) => {
181
208
UnionField
182
209
}
183
210
PatKind :: Leaf { .. } => Struct ,
184
- PatKind :: Variant { enum_variant, .. } => Variant ( * enum_variant) ,
211
+ PatKind :: Variant { enum_variant, .. } => {
212
+ Variant ( EnumVariantContiguousIndex :: from_enum_variant_id (
213
+ self . db ,
214
+ * enum_variant,
215
+ ) )
216
+ }
185
217
_ => {
186
218
never ! ( ) ;
187
219
Wildcard
188
220
}
189
221
} ;
190
- let variant = Self :: variant_id_for_adt ( & ctor, adt. 0 ) . unwrap ( ) ;
222
+ let variant = Self :: variant_id_for_adt ( self . db , & ctor, adt) . unwrap ( ) ;
191
223
arity = variant. variant_data ( self . db . upcast ( ) ) . fields ( ) . len ( ) ;
192
224
}
193
225
_ => {
@@ -239,7 +271,7 @@ impl<'p> MatchCheckCtx<'p> {
239
271
PatKind :: Deref { subpattern : subpatterns. next ( ) . unwrap ( ) }
240
272
}
241
273
TyKind :: Adt ( adt, substs) => {
242
- let variant = Self :: variant_id_for_adt ( pat. ctor ( ) , adt. 0 ) . unwrap ( ) ;
274
+ let variant = Self :: variant_id_for_adt ( self . db , pat. ctor ( ) , adt. 0 ) . unwrap ( ) ;
243
275
let subpatterns = self
244
276
. list_variant_fields ( pat. ty ( ) , variant)
245
277
. zip ( subpatterns)
@@ -277,7 +309,7 @@ impl<'p> MatchCheckCtx<'p> {
277
309
impl < ' p > PatCx for MatchCheckCtx < ' p > {
278
310
type Error = ( ) ;
279
311
type Ty = Ty ;
280
- type VariantIdx = EnumVariantId ;
312
+ type VariantIdx = EnumVariantContiguousIndex ;
281
313
type StrLit = Void ;
282
314
type ArmData = ( ) ;
283
315
type PatData = PatData < ' p > ;
@@ -303,7 +335,7 @@ impl<'p> PatCx for MatchCheckCtx<'p> {
303
335
// patterns. If we're here we can assume this is a box pattern.
304
336
1
305
337
} else {
306
- let variant = Self :: variant_id_for_adt ( ctor, adt) . unwrap ( ) ;
338
+ let variant = Self :: variant_id_for_adt ( self . db , ctor, adt) . unwrap ( ) ;
307
339
variant. variant_data ( self . db . upcast ( ) ) . fields ( ) . len ( )
308
340
}
309
341
}
@@ -343,7 +375,7 @@ impl<'p> PatCx for MatchCheckCtx<'p> {
343
375
let subst_ty = substs. at ( Interner , 0 ) . assert_ty_ref ( Interner ) . clone ( ) ;
344
376
single ( subst_ty)
345
377
} else {
346
- let variant = Self :: variant_id_for_adt ( ctor, adt) . unwrap ( ) ;
378
+ let variant = Self :: variant_id_for_adt ( self . db , ctor, adt) . unwrap ( ) ;
347
379
let ( adt, _) = ty. as_adt ( ) . unwrap ( ) ;
348
380
349
381
let adt_is_local =
@@ -421,15 +453,15 @@ impl<'p> PatCx for MatchCheckCtx<'p> {
421
453
ConstructorSet :: NoConstructors
422
454
} else {
423
455
let mut variants = FxHashMap :: default ( ) ;
424
- for & ( variant, _) in enum_data. variants . iter ( ) {
456
+ for ( i , & ( variant, _) ) in enum_data. variants . iter ( ) . enumerate ( ) {
425
457
let is_uninhabited =
426
458
is_enum_variant_uninhabited_from ( variant, subst, cx. module , cx. db ) ;
427
459
let visibility = if is_uninhabited {
428
460
VariantVisibility :: Empty
429
461
} else {
430
462
VariantVisibility :: Visible
431
463
} ;
432
- variants. insert ( variant , visibility) ;
464
+ variants. insert ( EnumVariantContiguousIndex ( i ) , visibility) ;
433
465
}
434
466
435
467
ConstructorSet :: Variants {
@@ -453,10 +485,10 @@ impl<'p> PatCx for MatchCheckCtx<'p> {
453
485
f : & mut fmt:: Formatter < ' _ > ,
454
486
pat : & rustc_pattern_analysis:: pat:: DeconstructedPat < Self > ,
455
487
) -> fmt:: Result {
488
+ let db = pat. data ( ) . db ;
456
489
let variant =
457
- pat. ty ( ) . as_adt ( ) . and_then ( |( adt, _) | Self :: variant_id_for_adt ( pat. ctor ( ) , adt) ) ;
490
+ pat. ty ( ) . as_adt ( ) . and_then ( |( adt, _) | Self :: variant_id_for_adt ( db , pat. ctor ( ) , adt) ) ;
458
491
459
- let db = pat. data ( ) . db ;
460
492
if let Some ( variant) = variant {
461
493
match variant {
462
494
VariantId :: EnumVariantId ( v) => {
0 commit comments