Skip to content

Commit 6b690bf

Browse files
committed
Optimize unsafe nulls subtype test
1 parent 4f04826 commit 6b690bf

File tree

12 files changed

+63
-67
lines changed

12 files changed

+63
-67
lines changed

compiler/src/dotty/tools/dotc/config/Feature.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,9 @@ object Feature:
6161
def dynamicsEnabled(using Context): Boolean =
6262
enabled(nme.dynamics)
6363

64+
def unsafeNullsEnabled(using Context) =
65+
ctx.explicitNulls && enabled(nme.unsafeNulls)
66+
6467
def dependentEnabled(using Context) =
6568
enabled(dependent, defn.LanguageExperimentalModule.moduleClass)
6669

compiler/src/dotty/tools/dotc/core/Mode.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -119,4 +119,7 @@ object Mode {
119119

120120
/** Should we try to convert values ignoring Null type? */
121121
val UnsafeNullConversion: Mode = newMode(27, "UnsafeNullConversion")
122+
123+
/** Unsafe Nulls SubType */
124+
val UnsafeNullsSubType: Mode = newMode(28, "UnsafeNullsSubType")
122125
}

compiler/src/dotty/tools/dotc/core/NullOpsDecorator.scala

Lines changed: 14 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ object NullOpsDecorator {
5959
}
6060

6161
/** Can the type has null value after erasure?
62+
* TODO
6263
*/
6364
def isNullableAfterErasure(using Context): Boolean = self match {
6465
case tp: ClassInfo => tp.cls.isNullableClassAfterErasure
@@ -69,20 +70,20 @@ object NullOpsDecorator {
6970
self.isNullType || self <:< defn.ObjectType
7071
}
7172

72-
def isUnsafelyNulltoAnyRef(pt: Type)(using Context): Boolean =
73-
self.isNullType && pt.isNullableAfterErasure
73+
// def isUnsafelyNulltoAnyRef(pt: Type)(using Context): Boolean =
74+
// self.isNullType && pt.isNullableAfterErasure
7475

75-
def isUnsafeSubtype(pt: Type, relaxedSubtype: Boolean = false)(using Context): Boolean =
76-
val selfs = self.stripAllNulls
77-
val pts = pt.stripAllNulls
78-
if relaxedSubtype then
79-
selfs relaxed_<:< pts
80-
else
81-
selfs <:< pts
76+
// def isUnsafeSubtype(pt: Type, relaxedSubtype: Boolean = false)(using Context): Boolean =
77+
// val selfs = self.stripAllNulls
78+
// val pts = pt.stripAllNulls
79+
// if relaxedSubtype then
80+
// selfs relaxed_<:< pts
81+
// else
82+
// selfs <:< pts
8283

83-
/** Can we convert a tree with type `self` to type `pt` unsafely.
84-
*/
85-
def isUnsafelyConvertible(pt: Type, relaxedSubtype: Boolean = false)(using Context): Boolean =
86-
self.isUnsafelyNulltoAnyRef(pt) || self.isUnsafeSubtype(pt, relaxedSubtype)
84+
// /** Can we convert a tree with type `self` to type `pt` unsafely.
85+
// */
86+
// def unsafeNullsSubType(pt: Type, relaxed: Boolean = false)(using Context): Boolean =
87+
// self.isUnsafelyNulltoAnyRef(pt) || self.isUnsafeSubtype(pt, relaxed)
8788
}
8889
}

compiler/src/dotty/tools/dotc/core/TypeComparer.scala

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -757,7 +757,11 @@ class TypeComparer(@constructorOnly initctx: Context) extends ConstraintHandling
757757
isSubType(hi1, tp2, approx.addLow) || compareGADT || tryLiftedToThis1
758758
case _ =>
759759
def isNullable(tp: Type): Boolean = tp.widenDealias match {
760-
case tp: TypeRef => tp.symbol.isNullableClass
760+
case tp: TypeRef =>
761+
if ctx.mode.is(Mode.UnsafeNullsSubType) then
762+
tp.symbol.isNullableClassAfterErasure
763+
else
764+
tp.symbol.isNullableClass
761765
case tp: RefinedOrRecType => isNullable(tp.parent)
762766
case tp: AppliedType => isNullable(tp.tycon)
763767
case AndType(tp1, tp2) => isNullable(tp1) && isNullable(tp2)

compiler/src/dotty/tools/dotc/core/tasty/TreeUnpickler.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -366,7 +366,7 @@ class TreeUnpickler(reader: TastyReader,
366366
else
367367
val hi0 = readVariances(readType())
368368
val hi =
369-
if ctx.explicitNulls && lo.isNullType && hi0.isNullableAfterErasure
369+
if ctx.explicitNulls && lo.isBottomTypeAfterErasure && hi0.isNullableAfterErasure
370370
then OrNull(hi0) else hi0
371371
TypeBounds(lo, hi)
372372
case ANNOTATEDtype =>
@@ -1247,7 +1247,7 @@ class TreeUnpickler(reader: TastyReader,
12471247
val lo = readTpt()
12481248
val hi0 = if currentAddr == end then lo else readTpt()
12491249
val hi =
1250-
if ctx.explicitNulls && lo.tpe.isNullType && hi0.tpe.isNullableAfterErasure
1250+
if ctx.explicitNulls && lo.tpe.isBottomTypeAfterErasure && hi0.tpe.isNullableAfterErasure
12511251
then TypeTree(OrNull(hi0.tpe)) else hi0
12521252
val alias = if currentAddr == end then EmptyTree else readTpt()
12531253
TypeBoundsTree(lo, hi, alias)

compiler/src/dotty/tools/dotc/core/unpickleScala2/Scala2Unpickler.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -773,7 +773,7 @@ class Scala2Unpickler(bytes: Array[Byte], classRoot: ClassDenotation, moduleClas
773773
val lo = readTypeRef()
774774
val hi0 = readTypeRef()
775775
val hi =
776-
if ctx.explicitNulls && lo.isNullType && hi0.isNullableAfterErasure
776+
if ctx.explicitNulls && lo.isBottomTypeAfterErasure && hi0.isNullableAfterErasure
777777
then OrNull(hi0) else hi0
778778
TypeBounds(lo, hi)
779779
case REFINEDtpe =>
@@ -1255,7 +1255,7 @@ class Scala2Unpickler(bytes: Array[Byte], classRoot: ClassDenotation, moduleClas
12551255
val lo = readTreeRef()
12561256
val hi0 = readTreeRef()
12571257
val hi =
1258-
if ctx.explicitNulls && lo.tpe.isNullType && hi0.tpe.isNullableAfterErasure
1258+
if ctx.explicitNulls && lo.tpe.isBottomTypeAfterErasure && hi0.tpe.isNullableAfterErasure
12591259
then TypeTree(OrNull(hi0.tpe)) else hi0
12601260
TypeBoundsTree(lo, hi)
12611261

compiler/src/dotty/tools/dotc/typer/Implicits.scala

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -133,9 +133,8 @@ object Implicits:
133133
else if (mt.paramInfos.lengthCompare(1) == 0 && {
134134
var formal = widenSingleton(mt.paramInfos.head)
135135
if (approx) formal = wildApprox(formal)
136-
explore((argType relaxed_<:< formal.widenExpr) ||
137-
Nullables.convertUnsafeNulls &&
138-
argType.isUnsafeSubtype(formal.widenExpr, true))
136+
Nullables.useUnsafeNullsSubTypeIf(ctx.mode.is(Mode.UnsafeNullConversion))(
137+
explore(argType relaxed_<:< formal.widenExpr))
139138
})
140139
Candidate.Conversion
141140
else
@@ -1314,7 +1313,8 @@ trait Implicits:
13141313

13151314
/** All available implicits, without ranking */
13161315
def allImplicits: Set[TermRef] = {
1317-
val contextuals = ctx.implicits.eligible(wildProto, ctx.mode.is(Mode.UnsafeNullConversion)).map(tryImplicit(_, contextual = true))
1316+
val contextuals = ctx.implicits.eligible(wildProto, ctx.mode.is(Mode.UnsafeNullConversion))
1317+
.map(tryImplicit(_, contextual = true))
13181318
val inscope = implicitScope(wildProto).eligible.map(tryImplicit(_, contextual = false))
13191319
(contextuals.toSet ++ inscope).collect {
13201320
case success: SearchSuccess => success.ref

compiler/src/dotty/tools/dotc/typer/Nullables.scala

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,9 @@ import ast.Trees.mods
2020
object Nullables:
2121
import ast.tpd._
2222

23-
/** Should we try to convert values ignoring Null type at this moment? */
24-
def convertUnsafeNulls(using Context): Boolean =
25-
ctx.explicitNulls && (
26-
config.Feature.enabled(nme.unsafeNulls) ||
27-
ctx.mode.is(Mode.UnsafeNullConversion))
23+
inline def useUnsafeNullsSubTypeIf[T](cond: Boolean)(inline op: Context ?=> T)(using Context): T =
24+
val c = if cond then ctx.addMode(Mode.UnsafeNullsSubType) else ctx
25+
op(using c)
2826

2927
/** A set of val or var references that are known to be not null, plus a set of
3028
* variable references that are not known (anymore) to be not null

compiler/src/dotty/tools/dotc/typer/ProtoTypes.scala

Lines changed: 10 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,10 @@ import Trees._
1010
import Constants._
1111
import util.{Stats, SimpleIdentityMap}
1212
import Decorators._
13+
import Nullables.useUnsafeNullsSubTypeIf
1314
import NullOpsDecorator._
1415
import Uniques._
16+
import config.Feature
1517
import config.Printers.typr
1618
import util.SourceFile
1719
import util.Property
@@ -39,10 +41,10 @@ object ProtoTypes {
3941
def isCompatible(tp: Type, pt: Type)(using Context): Boolean =
4042
val tpw = tp.widenExpr
4143
val ptw = pt.widenExpr
42-
(tpw relaxed_<:< ptw)
43-
// If unsafeNulls is enabled, we relax the condition by
44-
// striping all nulls from the types before subtype check.
45-
|| Nullables.convertUnsafeNulls && tpw.isUnsafelyConvertible(ptw, true)
44+
useUnsafeNullsSubTypeIf(
45+
Feature.unsafeNullsEnabled
46+
|| ctx.mode.is(Mode.UnsafeNullConversion))(
47+
tpw relaxed_<:< ptw)
4648
|| viewExists(tp, pt)
4749

4850
/** Like isCompatibe, but using a subtype comparison with necessary eithers
@@ -51,16 +53,11 @@ object ProtoTypes {
5153
def necessarilyCompatible(tp: Type, pt: Type)(using Context): Boolean =
5254
val tpw = tp.widenExpr
5355
val ptw = pt.widenExpr
54-
necessarySubType(tpw, ptw)
56+
useUnsafeNullsSubTypeIf(
57+
Feature.unsafeNullsEnabled
58+
|| ctx.mode.is(Mode.UnsafeNullConversion))(
59+
necessarySubType(tpw, ptw))
5560
|| tpw.isValueSubType(ptw)
56-
|| Nullables.convertUnsafeNulls && {
57-
// See comments in `isCompatible`
58-
val tpwsn = tpw.stripAllNulls
59-
val ptwsn = ptw.stripAllNulls
60-
necessarySubType(tpwsn, ptwsn)
61-
|| tpwsn.isValueSubType(ptwsn)
62-
|| tpwsn.isUnsafelyNulltoAnyRef(ptwsn)
63-
}
6461
|| viewExists(tp, pt)
6562

6663
/** Test compatibility after normalization.

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -547,7 +547,7 @@ class Typer extends Namer
547547

548548
def typeSelectOnTerm(using Context): Tree =
549549
val qual = typedExpr(tree.qualifier, selectionProto(tree.name, pt, this))
550-
val qual1 = if Nullables.convertUnsafeNulls then
550+
val qual1 = if unsafeNullsEnabled then
551551
qual.tpe match {
552552
case OrNull(tpe1) =>
553553
qual.cast(AndType(qual.tpe, tpe1))
@@ -1264,9 +1264,8 @@ class Typer extends Namer
12641264
val pt1 = if ctx.explicitNulls then pt.stripNull else pt
12651265
pt1 match {
12661266
case SAMType(sam)
1267-
if !defn.isFunctionType(pt1) && (
1268-
mt <:< sam ||
1269-
Nullables.convertUnsafeNulls && mt.stripAllNulls <:< sam.stripAllNulls) =>
1267+
if !defn.isFunctionType(pt1)
1268+
&& useUnsafeNullsSubTypeIf(unsafeNullsEnabled)(mt <:< sam) =>
12701269
// SAMs of the form C[?] where C is a class cannot be conversion targets.
12711270
// The resulting class `class $anon extends C[?] {...}` would be illegal,
12721271
// since type arguments to `C`'s super constructor cannot be constructed.
@@ -3474,9 +3473,7 @@ class Typer extends Namer
34743473
val pt1 = if ctx.explicitNulls then pt.stripNull else pt
34753474
pt1 match {
34763475
case SAMType(sam)
3477-
if wtp <:< sam.toFunctionType()
3478-
|| (Nullables.convertUnsafeNulls
3479-
&& wtp.stripAllNulls <:< sam.toFunctionType().stripAllNulls) =>
3476+
if useUnsafeNullsSubTypeIf(unsafeNullsEnabled)(wtp <:< sam.toFunctionType()) =>
34803477
// was ... && isFullyDefined(pt, ForceDegree.flipBottom)
34813478
// but this prevents case blocks from implementing polymorphic partial functions,
34823479
// since we do not know the result parameter a priori. Have to wait until the
@@ -3548,10 +3545,9 @@ class Typer extends Namer
35483545
val treeTpe = tree.tpe
35493546

35503547
def tryUnsafeNullConver(fail: => Tree)(using Context): Tree =
3551-
// If explicitNulls and unsafeNulls are enabled, and
3552-
if ctx.mode.is(Mode.UnsafeNullConversion)
3553-
&& pt.isValueType
3554-
&& treeTpe.isUnsafelyConvertible(pt)
3548+
if pt.isValueType
3549+
&& useUnsafeNullsSubTypeIf(ctx.mode.is(Mode.UnsafeNullConversion))(
3550+
treeTpe <:< pt)
35553551
then tree.cast(pt)
35563552
else fail
35573553

@@ -3566,16 +3562,16 @@ class Typer extends Namer
35663562
else recover(failure.reason)
35673563

35683564
val searchCtx =
3569-
if ctx.explicitNulls && config.Feature.enabled(nme.unsafeNulls) then
3565+
if unsafeNullsEnabled then
35703566
ctx.addMode(Mode.UnsafeNullConversion)
35713567
else ctx
35723568

35733569
inContext(searchCtx) {
35743570
if ctx.mode.is(Mode.ImplicitsEnabled) && tree.typeOpt.isValueType then
35753571
if pt.isRef(defn.AnyValClass) || pt.isRef(defn.ObjectClass) then
35763572
// We want to allow `null` to `AnyRef` if UnsafeNullConversion is enabled
3577-
if !(ctx.mode.is(Mode.UnsafeNullConversion)
3578-
&& treeTpe.isUnsafelyConvertible(pt)) then
3573+
if !(useUnsafeNullsSubTypeIf(ctx.mode.is(Mode.UnsafeNullConversion))(
3574+
treeTpe <:< pt)) then
35793575
report.error(em"the result of an implicit conversion must be more specific than $pt", tree.srcPos)
35803576
tree.cast(pt)
35813577
else

tests/explicit-nulls/unsafe-common/unsafe-cast.scala

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -65,10 +65,4 @@ class S {
6565
n1(Array("a", null))
6666
n2(Array("a", null))
6767
}
68-
69-
def test[T <: AnyRef](x: T | Null): T = {
70-
val y: T = x // error
71-
val z: T = null // error
72-
x // error
73-
}
7468
}

tests/explicit-nulls/unsafe-common/unsafe-implicit3.scala

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -96,20 +96,20 @@ class S {
9696
val z8: Array[String | Null] | Null = y2 // error
9797
}
9898

99-
def test5[T <: AnyRef] = {
99+
def test5[T >: Null <: AnyRef | Null] = {
100100
given Conversion[T, Array[T]] = _ => ???
101101

102102
val y1: T = ???
103103
val y2: T | Null = ???
104104

105105
val z1: Array[T] = y1
106-
val z2: Array[T | Null] = y1 // error
106+
val z2: Array[T | Null] = y1
107107
val z3: Array[T] | Null = y1
108-
val z4: Array[T | Null] | Null = y1 // error
108+
val z4: Array[T | Null] | Null = y1
109109

110-
val z5: Array[T] = y2 // error
111-
val z6: Array[T | Null] = y2 // error
112-
val z7: Array[T] | Null = y2 // error
113-
val z8: Array[T | Null] | Null = y2 // error
110+
val z5: Array[T] = y2
111+
val z6: Array[T | Null] = y2
112+
val z7: Array[T] | Null = y2
113+
val z8: Array[T | Null] | Null = y2
114114
}
115115
}

0 commit comments

Comments
 (0)