diff --git a/src/dotty/tools/backend/jvm/BottomTypes.scala b/src/dotty/tools/backend/jvm/BottomTypes.scala new file mode 100644 index 000000000000..07e1315e515a --- /dev/null +++ b/src/dotty/tools/backend/jvm/BottomTypes.scala @@ -0,0 +1,81 @@ +package dotty.tools.backend.jvm + +import dotty.tools.dotc.ast.Trees.Thicket +import dotty.tools.dotc.ast.{Trees, tpd} +import dotty.tools.dotc.core.Contexts.Context +import dotty.tools.dotc.core.Types +import dotty.tools.dotc.transform.TreeTransforms.{TransformerInfo, TreeTransform, MiniPhase, MiniPhaseTransform} +import dotty.tools.dotc +import dotty.tools.dotc.backend.jvm.DottyPrimitives +import dotty.tools.dotc.core.Flags.FlagSet +import dotty.tools.dotc.transform.Erasure +import dotty.tools.dotc.transform.SymUtils._ +import java.io.{File => JFile} + +import scala.collection.generic.Clearable +import scala.collection.mutable +import scala.collection.mutable.{ListBuffer, ArrayBuffer} +import scala.reflect.ClassTag +import scala.reflect.internal.util.WeakHashSet +import scala.reflect.io.{Directory, PlainDirectory, AbstractFile} +import scala.tools.asm.{ClassVisitor, FieldVisitor, MethodVisitor} +import scala.tools.nsc.backend.jvm.{BCodeHelpers, BackendInterface} +import dotty.tools.dotc.core._ +import Periods._ +import SymDenotations._ +import Contexts._ +import Types._ +import Symbols._ +import Denotations._ +import Phases._ +import java.lang.AssertionError +import dotty.tools.dotc.util.Positions.Position +import Decorators._ +import tpd._ +import Flags._ +import StdNames.nme + +/** + * Ensures that tree does not contain type subsumptions where subsumed type is bottom type + * of our typesystem, but not the bottom type of JVM typesystem. + */ +class BottomTypes extends MiniPhaseTransform { + def phaseName: String = "bottomTypes" + + + def adaptBottom(treeOfBottomType: tpd.Tree, expectedType: Type)(implicit ctx: Context) = { + if (defn.isBottomType(treeOfBottomType.tpe) && (treeOfBottomType.tpe ne expectedType)) + Erasure.Boxing.adaptToType(treeOfBottomType, expectedType) + else treeOfBottomType + } + + override def transformDefDef(tree: tpd.DefDef)(implicit ctx: Context, info: TransformerInfo): tpd.Tree = { + val returnTp = tree.symbol.info.dealias.finalResultType + cpy.DefDef(tree)(rhs = adaptBottom(tree.rhs, returnTp)) + } + + + override def transformAssign(tree: tpd.Assign)(implicit ctx: Context, info: TransformerInfo): tpd.Tree = { + val returnTp = tree.lhs.symbol.info.dealias + cpy.Assign(tree)(tree.lhs, adaptBottom(tree.rhs, returnTp)) + } + + + override def transformTyped(tree: tpd.Typed)(implicit ctx: Context, info: TransformerInfo): tpd.Tree = { + cpy.Typed(tree)(adaptBottom(tree.expr, tree.tpt.tpe), tree.tpt) + } + + override def transformApply(tree: tpd.Apply)(implicit ctx: Context, info: TransformerInfo): tpd.Tree = { + val fun = tree.fun + val newArgs: List[tpd.Tree] = tree.args.zip(fun.tpe.dealias.firstParamTypes).map(x => adaptBottom(x._1, x._2)) + val changeNeeded = tree.args != newArgs // cpy.Apply does not check if elements are the same, + // it only does `eq` on lists as whole + if (changeNeeded) cpy.Apply(tree)(fun = fun, args = newArgs) + else tree + } + + override def transformValDef(tree: tpd.ValDef)(implicit ctx: Context, info: TransformerInfo): tpd.Tree = { + val returnTp = tree.symbol.info.dealias + cpy.ValDef(tree)(rhs = adaptBottom(tree.rhs, returnTp)) + } +} diff --git a/src/dotty/tools/dotc/Compiler.scala b/src/dotty/tools/dotc/Compiler.scala index 42d223fe9e83..4b06d3bdf83b 100644 --- a/src/dotty/tools/dotc/Compiler.scala +++ b/src/dotty/tools/dotc/Compiler.scala @@ -15,7 +15,7 @@ import transform.TreeTransforms.{TreeTransform, TreeTransformer} import core.DenotTransformers.DenotTransformer import core.Denotations.SingleDenotation -import dotty.tools.backend.jvm.{LabelDefs, GenBCode} +import dotty.tools.backend.jvm.{BottomTypes, LabelDefs, GenBCode} class Compiler { @@ -82,6 +82,7 @@ class Compiler { List(/*new PrivateToStatic,*/ new ExpandPrivate, new CollectEntryPoints, + new BottomTypes, new LabelDefs), List(new GenBCode) ) diff --git a/src/dotty/tools/dotc/transform/Erasure.scala b/src/dotty/tools/dotc/transform/Erasure.scala index 3445b4c444ec..355b588b9ca4 100644 --- a/src/dotty/tools/dotc/transform/Erasure.scala +++ b/src/dotty/tools/dotc/transform/Erasure.scala @@ -246,7 +246,7 @@ object Erasure extends TypeTestsCasts{ * e -> unbox(e, PT) otherwise, if `PT` is an erased value type * e -> box(e) if `e` is of primitive type and `PT` is not a primitive type * e -> unbox(e, PT) if `PT` is a primitive type and `e` is not of primitive type - * e -> cast(e, PT) otherwise + * e -> cast(e, PT) otherwise, including if `PT` is a bottom type. */ def adaptToType(tree: Tree, pt: Type)(implicit ctx: Context): Tree = if (pt.isInstanceOf[FunProto]) tree @@ -254,7 +254,7 @@ object Erasure extends TypeTestsCasts{ case MethodType(Nil, _) if tree.isTerm => adaptToType(tree.appliedToNone, pt) case tpw => - if (pt.isInstanceOf[ProtoType] || tree.tpe <:< pt) + if (pt.isInstanceOf[ProtoType] || ((!pt.isValueType || !defn.isBottomType(tree.tpe)) && (tree.tpe <:< pt))) tree else if (tpw.isErasedValueType) adaptToType(box(tree), pt) @@ -267,6 +267,7 @@ object Erasure extends TypeTestsCasts{ else cast(tree, pt) } + } class Typer extends typer.ReTyper with NoChecking { diff --git a/tests/pos/BottomOr.scala b/tests/pos/BottomOr.scala new file mode 100644 index 000000000000..23626cd38bd7 --- /dev/null +++ b/tests/pos/BottomOr.scala @@ -0,0 +1,6 @@ +object Test{ + val a: Nothing = ??? + val b: Null = ??? + val c: Null | Nothing = if(a == b) a else b + val d: Nothing | Nothing = if(a == b) a else b +} diff --git a/tests/pos/i828.scala b/tests/pos/i828.scala new file mode 100644 index 000000000000..9bf90552e18d --- /dev/null +++ b/tests/pos/i828.scala @@ -0,0 +1,18 @@ +object X { + val x: Int = null.asInstanceOf[Nothing] + def d: Int = null.asInstanceOf[Nothing] + var s: Int = 0 + s = null.asInstanceOf[Nothing] + def takeInt(i: Int): Unit + takeInt(null.asInstanceOf[Nothing]) +} + +object Y { + val n: Null = null + val x: Object = n + def d: Object = n + var s: Object = 0 + s = n + def takeInt(i: Object): Unit + takeInt(n) +}