Skip to content

Commit 6c521ae

Browse files
authored
Fix approximateOr of (A & Double) | Null (#16241)
2 parents a210b7f + 36a9a9f commit 6c521ae

File tree

3 files changed

+31
-10
lines changed

3 files changed

+31
-10
lines changed

compiler/src/dotty/tools/dotc/core/TypeOps.scala

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -225,16 +225,18 @@ object TypeOps:
225225
*/
226226
def orDominator(tp: Type)(using Context): Type = {
227227

228-
/** a faster version of cs1 intersect cs2 that treats bottom types correctly */
228+
/** a faster version of cs1 intersect cs2 */
229229
def intersect(cs1: List[ClassSymbol], cs2: List[ClassSymbol]): List[ClassSymbol] =
230-
if cs1.head == defn.NothingClass then cs2
231-
else if cs2.head == defn.NothingClass then cs1
232-
else if cs1.head == defn.NullClass && !ctx.explicitNulls && cs2.head.derivesFrom(defn.ObjectClass) then cs2
233-
else if cs2.head == defn.NullClass && !ctx.explicitNulls && cs1.head.derivesFrom(defn.ObjectClass) then cs1
234-
else
235-
val cs2AsSet = new util.HashSet[ClassSymbol](128)
236-
cs2.foreach(cs2AsSet += _)
237-
cs1.filter(cs2AsSet.contains)
230+
val cs2AsSet = BaseClassSet(cs2)
231+
cs1.filter(cs2AsSet.contains)
232+
233+
/** a version of Type#baseClasses that treats bottom types correctly */
234+
def orBaseClasses(tp: Type): List[ClassSymbol] = tp.stripTypeVar match
235+
case OrType(tp1, tp2) =>
236+
if tp1.isBottomType && (tp1 frozen_<:< tp2) then orBaseClasses(tp2)
237+
else if tp2.isBottomType && (tp2 frozen_<:< tp1) then orBaseClasses(tp1)
238+
else intersect(orBaseClasses(tp1), orBaseClasses(tp2))
239+
case _ => tp.baseClasses
238240

239241
/** The minimal set of classes in `cs` which derive all other classes in `cs` */
240242
def dominators(cs: List[ClassSymbol], accu: List[ClassSymbol]): List[ClassSymbol] = (cs: @unchecked) match {
@@ -369,7 +371,7 @@ object TypeOps:
369371
}
370372

371373
// Step 3: Intersect base classes of both sides
372-
val commonBaseClasses = tp.mapReduceOr(_.baseClasses)(intersect)
374+
val commonBaseClasses = orBaseClasses(tp)
373375
val doms = dominators(commonBaseClasses, Nil)
374376
def baseTp(cls: ClassSymbol): Type =
375377
tp.baseType(cls).mapReduceOr(identity)(mergeRefinedOrApplied)

tests/explicit-nulls/pos/i16236.scala

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
// Copy of tests/pos/i16236.scala
2+
trait A
3+
4+
def consume[T](t: T): Unit = ()
5+
6+
def fails(p: (Double & A) | Null): Unit = consume(p) // was: assertion failed: <notype> & A
7+
8+
def switchedOrder(p: (A & Double) | Null): Unit = consume(p) // ok
9+
def nonPrimitive(p: (String & A) | Null): Unit = consume(p) // ok
10+
def notNull(p: (Double & A)): Unit = consume(p) // ok

tests/pos/i16236.scala

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
trait A
2+
3+
def consume[T](t: T): Unit = ()
4+
5+
def fails(p: (Double & A) | Null): Unit = consume(p) // was: assertion failed: <notype> & A
6+
7+
def switchedOrder(p: (A & Double) | Null): Unit = consume(p) // ok
8+
def nonPrimitive(p: (String & A) | Null): Unit = consume(p) // ok
9+
def notNull(p: (Double & A)): Unit = consume(p) // ok

0 commit comments

Comments
 (0)