@@ -259,7 +259,11 @@ object Types {
259
259
/** True if this type is an instance of the given `cls` or an instance of
260
260
* a non-bottom subclass of `cls`.
261
261
*/
262
- final def derivesFrom (cls : Symbol , afterErasure : Boolean = false )(using Context ): Boolean = {
262
+ final def derivesFrom (cls : Symbol , isErased : Boolean = false )(using Context ): Boolean = {
263
+ def isLowerBottomType (tp : Type ) =
264
+ (if isErased then tp.isBottomTypeAfterErasure else tp.isBottomType)
265
+ && (tp.hasClassSymbol(defn.NothingClass )
266
+ || cls != defn.NothingClass && ! cls.isValueClass)
263
267
def loop (tp : Type ): Boolean = tp match {
264
268
case tp : TypeRef =>
265
269
val sym = tp.symbol
@@ -276,10 +280,6 @@ object Types {
276
280
// If the type is `T | Null` or `T | Nothing`, the class is != Nothing,
277
281
// and `T` derivesFrom the class, then the OrType derivesFrom the class.
278
282
// Otherwise, we need to check both sides derivesFrom the class.
279
- def isLowerBottomType (tp : Type ) =
280
- (if afterErasure then t.isBottomTypeAfterErasure else t.isBottomType)
281
- && (tp.hasClassSymbol(defn.NothingClass )
282
- || cls != defn.NothingClass && ! cls.isValueClass)
283
283
if isLowerBottomType(tp.tp1) then
284
284
loop(tp.tp2)
285
285
else if isLowerBottomType(tp.tp2) then
@@ -463,28 +463,45 @@ object Types {
463
463
* instance, or NoSymbol if none exists (either because this type is not a
464
464
* value type, or because superclasses are ambiguous).
465
465
*/
466
- final def classSymbol (using Context ): Symbol = this match {
467
- case tp : TypeRef =>
468
- val sym = tp.symbol
469
- if (sym.isClass) sym else tp.superType.classSymbol
470
- case tp : TypeProxy =>
471
- tp.underlying.classSymbol
472
- case tp : ClassInfo =>
473
- tp.cls
474
- case AndType (l, r) =>
475
- val lsym = l.classSymbol
476
- val rsym = r.classSymbol
477
- if (lsym isSubClass rsym) lsym
478
- else if (rsym isSubClass lsym) rsym
479
- else NoSymbol
480
- case tp : OrType =>
481
- tp.join.classSymbol
482
- case _ : JavaArrayType =>
483
- defn.ArrayClass
484
- case _ =>
485
- NoSymbol
486
- }
466
+ final def classSymbol (using Context ): Symbol = classSymbolWith(false )
467
+ final def classSymbolAfterErasure (using Context ): Symbol = classSymbolWith(true )
468
+
469
+ final private def classSymbolWith (isErased : Boolean )(using Context ): Symbol = {
470
+ def loop (tp: Type ): Symbol = tp match {
471
+ case tp : TypeRef =>
472
+ val sym = tp.symbol
473
+ if (sym.isClass) sym else loop(tp.superType)
474
+ case tp : TypeProxy =>
475
+ loop(tp.underlying)
476
+ case tp : ClassInfo =>
477
+ tp.cls
478
+ case AndType (l, r) =>
479
+ val lsym = loop(l)
480
+ val rsym = loop(r)
481
+ if (lsym isSubClass rsym) lsym
482
+ else if (rsym isSubClass lsym) rsym
483
+ else NoSymbol
484
+ case tp : OrType =>
485
+ if tp.tp1.hasClassSymbol(defn.NothingClass ) then
486
+ loop(tp.tp2)
487
+ else if tp.tp2.hasClassSymbol(defn.NothingClass ) then
488
+ loop(tp.tp1)
489
+ else
490
+ val tp1Null = tp.tp1.hasClassSymbol(defn.NullClass )
491
+ val tp2Null = tp.tp2.hasClassSymbol(defn.NullClass )
492
+ if isErased && (tp1Null || tp2Null) then
493
+ val otherSide = if tp1Null then loop(tp.tp2) else loop(tp.tp1)
494
+ if otherSide.isValueClass then defn.AnyClass else otherSide
495
+ else
496
+ loop(tp.join)
497
+ case _ : JavaArrayType =>
498
+ defn.ArrayClass
499
+ case _ =>
500
+ NoSymbol
501
+ }
487
502
503
+ loop(this )
504
+ }
488
505
/** The least (wrt <:<) set of symbols satisfying the `include` predicate of which this type is a subtype
489
506
*/
490
507
final def parentSymbols (include : Symbol => Boolean )(using Context ): List [Symbol ] = this match {
0 commit comments