diff --git a/compiler/src/dotty/tools/dotc/core/Types.scala b/compiler/src/dotty/tools/dotc/core/Types.scala index 6650771963f9..ae7ec91e7108 100644 --- a/compiler/src/dotty/tools/dotc/core/Types.scala +++ b/compiler/src/dotty/tools/dotc/core/Types.scala @@ -853,18 +853,23 @@ object Types { def goAnd(l: Type, r: Type) = go(l).meet(go(r), pre, safeIntersection = ctx.base.pendingMemberSearches.contains(name)) - def goOr(tp: OrType) = tp match { - case OrNull(tp1) if Nullables.unsafeNullsEnabled => - // Selecting `name` from a type `T | Null` is like selecting `name` from `T`, if - // unsafeNulls is enabled. This can throw at runtime, but we trade soundness for usability. - tp1.findMember(name, pre.stripNull, required, excluded) - case _ => + def goOr(tp: OrType) = + inline def searchAfterJoin = // we need to keep the invariant that `pre <: tp`. Branch `union-types-narrow-prefix` // achieved that by narrowing `pre` to each alternative, but it led to merge errors in // lots of places. The present strategy is instead of widen `tp` using `join` to be a // supertype of `pre`. go(tp.join) - } + + if Nullables.unsafeNullsEnabled then tp match + case OrNull(tp1) if tp1 <:< defn.ObjectType => + // Selecting `name` from a type `T | Null` is like selecting `name` from `T`, if + // unsafeNulls is enabled and T is a subtype of AnyRef. + // This can throw at runtime, but we trade soundness for usability. + tp1.findMember(name, pre.stripNull, required, excluded) + case _ => + searchAfterJoin + else searchAfterJoin val recCount = ctx.base.findMemberCount if (recCount >= Config.LogPendingFindMemberThreshold) diff --git a/compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala b/compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala index f7e33bd4a5f7..3dcec413540f 100644 --- a/compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala +++ b/compiler/src/dotty/tools/dotc/typer/TypeAssigner.scala @@ -162,7 +162,7 @@ trait TypeAssigner { val qualType = qual.tpe.widenIfUnstable def kind = if tree.isType then "type" else "value" val foundWithoutNull = qualType match - case OrNull(qualType1) => + case OrNull(qualType1) if qualType1 <:< defn.ObjectType => val name = tree.name val pre = maybeSkolemizePrefix(qualType1, name) reallyExists(qualType1.findMember(name, pre)) diff --git a/compiler/src/dotty/tools/dotc/typer/Typer.scala b/compiler/src/dotty/tools/dotc/typer/Typer.scala index 433075f73fd9..ff8641f4b9c3 100644 --- a/compiler/src/dotty/tools/dotc/typer/Typer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Typer.scala @@ -626,7 +626,7 @@ class Typer extends Namer val qual = typedExpr(tree.qualifier, shallowSelectionProto(tree.name, pt, this)) val qual1 = if Nullables.unsafeNullsEnabled then qual.tpe match { - case OrNull(tpe1) => + case OrNull(tpe1) if tpe1 <:< defn.ObjectType => qual.cast(AndType(qual.tpe, tpe1)) case tp => if tp.isNullType diff --git a/tests/explicit-nulls/neg/AnyValOrNullSelect.scala b/tests/explicit-nulls/neg/AnyValOrNullSelect.scala new file mode 100644 index 000000000000..44e8b4e7edfb --- /dev/null +++ b/tests/explicit-nulls/neg/AnyValOrNullSelect.scala @@ -0,0 +1,13 @@ +case class MyVal(i: Int) extends AnyVal: + def printVal: Unit = + println(i) + +class Test: + val v: MyVal | Null = MyVal(1) + + def f1 = + v.printVal // error: value printVal is not a member of MyVal | Null + + def f1 = + import scala.language.unsafeNulls + v.printVal // error: value printVal is not a member of MyVal | Null diff --git a/tests/explicit-nulls/pos/AnyValOrNull.scala b/tests/explicit-nulls/pos/AnyValOrNull.scala new file mode 100644 index 000000000000..098d3eba973d --- /dev/null +++ b/tests/explicit-nulls/pos/AnyValOrNull.scala @@ -0,0 +1,36 @@ +case class MyVal(i: Boolean) extends AnyVal + +class Test1: + + def test1 = + val v: AnyVal | Null = null + if v == null then + println("null") + + def test2 = + val v: Int | Null = 1 + if v != null then + println(v) + + def test3 = + val v: MyVal | Null = MyVal(false) + if v != null then + println(v) + +class Test2: + import scala.language.unsafeNulls + + def test1 = + val v: AnyVal | Null = null + if v == null then + println("null") + + def test2 = + val v: Int | Null = 1 + if v != null then + println(v) + + def test3 = + val v: MyVal | Null = MyVal(false) + if v != null then + println(v)