From 708a2e02527f04909b7dc1c68bc4a89d79a50bb5 Mon Sep 17 00:00:00 2001 From: Allan Renucci Date: Tue, 10 Apr 2018 10:55:25 +0200 Subject: [PATCH 1/5] Fix #1692: Null out fields after use in lazy initialization Private fields that are only used during lazy val initialization can be assigned null once the lazy val is initialized. This is not just an optimization, but is needed for correctness to prevent memory leaks. --- compiler/src/dotty/tools/dotc/Compiler.scala | 3 +- .../src/dotty/tools/dotc/core/Phases.scala | 3 + .../transform/CollectNullableFields.scala | 107 ++++++++++++++++ .../dotty/tools/dotc/transform/LazyVals.scala | 65 ++++++++-- tests/run/i1692.scala | 121 ++++++++++++++++++ 5 files changed, 286 insertions(+), 13 deletions(-) create mode 100644 compiler/src/dotty/tools/dotc/transform/CollectNullableFields.scala create mode 100644 tests/run/i1692.scala diff --git a/compiler/src/dotty/tools/dotc/Compiler.scala b/compiler/src/dotty/tools/dotc/Compiler.scala index 94779bca680c..e15a941b4e47 100644 --- a/compiler/src/dotty/tools/dotc/Compiler.scala +++ b/compiler/src/dotty/tools/dotc/Compiler.scala @@ -72,6 +72,7 @@ class Compiler { new LiftTry, // Put try expressions that might execute on non-empty stacks into their own methods new HoistSuperArgs, // Hoist complex arguments of supercalls to enclosing scope new ClassOf, // Expand `Predef.classOf` calls. + new CollectNullableFields, // Collect fields that can be null out after use in lazy initialization new RefChecks) :: // Various checks mostly related to abstract members and overriding List(new TryCatchPatterns, // Compile cases in try/catch new PatternMatcher, // Compile pattern matches @@ -97,7 +98,7 @@ class Compiler { List(new Erasure) :: // Rewrite types to JVM model, erasing all type parameters, abstract types and refinements. List(new ElimErasedValueType, // Expand erased value types to their underlying implmementation types new VCElideAllocations, // Peep-hole optimization to eliminate unnecessary value class allocations - new Mixin, // Expand trait fields and trait initializers + new Mixin, // Expand trait fields and trait initializers new LazyVals, // Expand lazy vals new Memoize, // Add private fields to getters and setters new NonLocalReturns, // Expand non-local returns diff --git a/compiler/src/dotty/tools/dotc/core/Phases.scala b/compiler/src/dotty/tools/dotc/core/Phases.scala index 239e8a4fe9d1..3894ef445c47 100644 --- a/compiler/src/dotty/tools/dotc/core/Phases.scala +++ b/compiler/src/dotty/tools/dotc/core/Phases.scala @@ -211,6 +211,7 @@ object Phases { private[this] var myTyperPhase: Phase = _ private[this] var mySbtExtractDependenciesPhase: Phase = _ private[this] var myPicklerPhase: Phase = _ + private[this] var myCollectNullableFieldsPhase: Phase = _ private[this] var myRefChecksPhase: Phase = _ private[this] var myPatmatPhase: Phase = _ private[this] var myElimRepeatedPhase: Phase = _ @@ -226,6 +227,7 @@ object Phases { final def typerPhase = myTyperPhase final def sbtExtractDependenciesPhase = mySbtExtractDependenciesPhase final def picklerPhase = myPicklerPhase + final def collectNullableFieldsPhase = myCollectNullableFieldsPhase final def refchecksPhase = myRefChecksPhase final def patmatPhase = myPatmatPhase final def elimRepeatedPhase = myElimRepeatedPhase @@ -244,6 +246,7 @@ object Phases { myTyperPhase = phaseOfClass(classOf[FrontEnd]) mySbtExtractDependenciesPhase = phaseOfClass(classOf[sbt.ExtractDependencies]) myPicklerPhase = phaseOfClass(classOf[Pickler]) + myCollectNullableFieldsPhase = phaseOfClass(classOf[CollectNullableFields]) myRefChecksPhase = phaseOfClass(classOf[RefChecks]) myElimRepeatedPhase = phaseOfClass(classOf[ElimRepeated]) myExtensionMethodsPhase = phaseOfClass(classOf[ExtensionMethods]) diff --git a/compiler/src/dotty/tools/dotc/transform/CollectNullableFields.scala b/compiler/src/dotty/tools/dotc/transform/CollectNullableFields.scala new file mode 100644 index 000000000000..dc910d33a987 --- /dev/null +++ b/compiler/src/dotty/tools/dotc/transform/CollectNullableFields.scala @@ -0,0 +1,107 @@ +package dotty.tools.dotc.transform + +import dotty.tools.dotc.ast.tpd +import dotty.tools.dotc.core.Contexts.Context +import dotty.tools.dotc.core.Flags._ +import dotty.tools.dotc.core.Symbols.Symbol +import dotty.tools.dotc.core.Types.{Type, ExprType} +import dotty.tools.dotc.transform.MegaPhase.MiniPhase +import dotty.tools.dotc.transform.SymUtils._ + +import scala.collection.JavaConverters._ +import scala.collection.mutable + +import java.util.IdentityHashMap + +object CollectNullableFields { + val name = "collectNullableFields" +} + +/** Collect fields that can be null out after use in lazy initialization. + * + * This information is used during lazy val transformation to assign null to private + * fields that are only used within a lazy val initializer. This is not just an optimization, + * but is needed for correctness to prevent memory leaks. E.g. + * + * {{{ + * class TestByNameLazy(byNameMsg: => String) { + * lazy val byLazyValMsg = byNameMsg + * } + * }}} + * + * Here `byNameMsg` should be null out once `byLazyValMsg` is + * initialised. + * + * A field is nullable if all the conditions below hold: + * - is private + * - is not lazy + * - its type is nullable, or is an expression type (e.g. => Int) + * - is on used in a lazy val initializer + * - defined in the same class as the lazy val + * - TODO from Scalac? from a non-trait class + */ +class CollectNullableFields extends MiniPhase { + import tpd._ + + override def phaseName = CollectNullableFields.name + + private[this] sealed trait FieldInfo + private[this] case object NotNullable extends FieldInfo + private[this] case class Nullable(by: Symbol) extends FieldInfo + + /** Whether or not a field is nullable */ + private[this] var nullability: IdentityHashMap[Symbol, FieldInfo] = _ + + override def prepareForUnit(tree: Tree)(implicit ctx: Context) = { + nullability = new IdentityHashMap + ctx + } + + private def recordUse(tree: Tree)(implicit ctx: Context): Tree = { + val sym = tree.symbol + + def isNullableType(tpe: Type) = + tpe.isInstanceOf[ExprType] || + tpe.widenDealias.typeSymbol.isNullableClass + val isNullablePrivateField = sym.isField && sym.is(Private, butNot = Lazy) && isNullableType(sym.info) + + if (isNullablePrivateField) + nullability.get(sym) match { + case Nullable(from) if from != ctx.owner => // used in multiple lazy val initializers + nullability.put(sym, NotNullable) + case null => // not in the map + val from = ctx.owner + val isNullable = + from.is(Lazy) && from.isField && // used in lazy field initializer + from.owner.eq(sym.owner) // lazy val and field in the same class + val info = if (isNullable) Nullable(from) else NotNullable + nullability.put(sym, info) + case _ => + // Do nothing for: + // - NotNullable + // - Nullable(ctx.owner) + } + + tree + } + + override def transformIdent(tree: Ident)(implicit ctx: Context) = + recordUse(tree) + + override def transformSelect(tree: Select)(implicit ctx: Context) = + recordUse(tree) + + /** Map lazy values to the fields they should null after initialization. */ + def lazyValNullables(implicit ctx: Context): Map[Symbol, List[Symbol]] = { + val result = new mutable.HashMap[Symbol, mutable.ListBuffer[Symbol]] + + nullability.forEach { + case (sym, Nullable(from)) => + val bldr = result.getOrElseUpdate(from, new mutable.ListBuffer) + bldr += sym + case _ => + } + + result.mapValues(_.toList).toMap + } +} diff --git a/compiler/src/dotty/tools/dotc/transform/LazyVals.scala b/compiler/src/dotty/tools/dotc/transform/LazyVals.scala index 620bd53bc948..eede157c6f9a 100644 --- a/compiler/src/dotty/tools/dotc/transform/LazyVals.scala +++ b/compiler/src/dotty/tools/dotc/transform/LazyVals.scala @@ -39,7 +39,7 @@ class LazyVals extends MiniPhase with IdentityDenotTransformer { /** List of names of phases that should have finished processing of tree * before this phase starts processing same tree */ - override def runsAfter = Set(Mixin.name) + override def runsAfter = Set(Mixin.name, CollectNullableFields.name) override def changesMembers = true // the phase adds lazy val accessors @@ -50,6 +50,18 @@ class LazyVals extends MiniPhase with IdentityDenotTransformer { val containerFlagsMask = Flags.Method | Flags.Lazy | Flags.Accessor | Flags.Module + /** A map of lazy values to the fields they should null after initialization. */ + private[this] var lazyValNullables: Map[Symbol, List[Symbol]] = _ + private def nullableFor(sym: Symbol)(implicit ctx: Context) = + if (sym.is(Flags.Module)) Nil + else lazyValNullables.getOrElse(sym, Nil) + + + override def prepareForUnit(tree: Tree)(implicit ctx: Context) = { + lazyValNullables = ctx.collectNullableFieldsPhase.asInstanceOf[CollectNullableFields].lazyValNullables + ctx + } + override def transformDefDef(tree: tpd.DefDef)(implicit ctx: Context): tpd.Tree = transformLazyVal(tree) @@ -150,7 +162,7 @@ class LazyVals extends MiniPhase with IdentityDenotTransformer { val initBody = adaptToType( ref(holderSymbol).select(defn.Object_synchronized).appliedTo( - adaptToType(mkNonThreadSafeDef(result, flag, initer), defn.ObjectType)), + adaptToType(mkNonThreadSafeDef(result, flag, initer, nullables = Nil), defn.ObjectType)), tpe) val initTree = DefDef(initSymbol, initBody) val holderTree = ValDef(holderSymbol, New(holderImpl.typeRef, List())) @@ -176,21 +188,33 @@ class LazyVals extends MiniPhase with IdentityDenotTransformer { holders:::stats } + private def nullOut(nullables: List[Symbol])(implicit ctx: Context): List[Tree] = { + val nullConst = Literal(Constants.Constant(null)) + nullables.map { sym => + val field = if (sym.isGetter) sym.field else sym + assert(field.isField) + field.setFlag(Flags.Mutable) + ref(field).becomes(nullConst) + } + } + /** Create non-threadsafe lazy accessor equivalent to such code * def methodSymbol() = { * if (flag) target * else { * target = rhs * flag = true + * nullable = null * target * } * } */ - def mkNonThreadSafeDef(target: Tree, flag: Tree, rhs: Tree)(implicit ctx: Context) = { + def mkNonThreadSafeDef(target: Tree, flag: Tree, rhs: Tree, nullables: List[Symbol])(implicit ctx: Context) = { val setFlag = flag.becomes(Literal(Constants.Constant(true))) - val setTargets = if (isWildcardArg(rhs)) Nil else target.becomes(rhs) :: Nil - val init = Block(setFlag :: setTargets, target.ensureApplied) + val setNullables = nullOut(nullables) + val setTargetAndNullable = if (isWildcardArg(rhs)) setNullables else target.becomes(rhs) :: setNullables + val init = Block(setFlag :: setTargetAndNullable, target.ensureApplied) If(flag.ensureApplied, target.ensureApplied, init) } @@ -198,15 +222,17 @@ class LazyVals extends MiniPhase with IdentityDenotTransformer { * def methodSymbol() = { * if (target eq null) { * target = rhs + * nullable = null * target * } else target * } */ - def mkDefNonThreadSafeNonNullable(target: Symbol, rhs: Tree)(implicit ctx: Context) = { + def mkDefNonThreadSafeNonNullable(target: Symbol, rhs: Tree, nullables: List[Symbol])(implicit ctx: Context) = { val cond = ref(target).select(nme.eq).appliedTo(Literal(Constant(null))) val exp = ref(target) val setTarget = exp.becomes(rhs) - val init = Block(List(setTarget), exp) + val setNullables = nullOut(nullables) + val init = Block(setTarget :: setNullables, exp) If(cond, init, exp) } @@ -222,14 +248,14 @@ class LazyVals extends MiniPhase with IdentityDenotTransformer { val containerTree = ValDef(containerSymbol, defaultValue(tpe)) if (x.tpe.isNotNull && tpe <:< defn.ObjectType) { // can use 'null' value instead of flag - val slowPath = DefDef(x.symbol.asTerm, mkDefNonThreadSafeNonNullable(containerSymbol, x.rhs)) + val slowPath = DefDef(x.symbol.asTerm, mkDefNonThreadSafeNonNullable(containerSymbol, x.rhs, nullableFor(x.symbol))) Thicket(containerTree, slowPath) } else { val flagName = LazyBitMapName.fresh(x.name.asTermName) val flagSymbol = ctx.newSymbol(x.symbol.owner, flagName, containerFlags | Flags.Private, defn.BooleanType).enteredAfter(this) val flag = ValDef(flagSymbol, Literal(Constants.Constant(false))) - val slowPath = DefDef(x.symbol.asTerm, mkNonThreadSafeDef(ref(containerSymbol), ref(flagSymbol), x.rhs)) + val slowPath = DefDef(x.symbol.asTerm, mkNonThreadSafeDef(ref(containerSymbol), ref(flagSymbol), x.rhs, nullableFor(x.symbol))) Thicket(containerTree, flag, slowPath) } } @@ -263,10 +289,23 @@ class LazyVals extends MiniPhase with IdentityDenotTransformer { * result = $target * } * } + * nullable = null * result * } */ - def mkThreadSafeDef(methodSymbol: TermSymbol, claz: ClassSymbol, ord: Int, target: Symbol, rhs: Tree, tp: Types.Type, offset: Tree, getFlag: Tree, stateMask: Tree, casFlag: Tree, setFlagState: Tree, waitOnLock: Tree)(implicit ctx: Context) = { + def mkThreadSafeDef(methodSymbol: TermSymbol, + claz: ClassSymbol, + ord: Int, + target: Symbol, + rhs: Tree, + tp: Types.Type, + offset: Tree, + getFlag: Tree, + stateMask: Tree, + casFlag: Tree, + setFlagState: Tree, + waitOnLock: Tree, + nullables: List[Symbol])(implicit ctx: Context) = { val initState = Literal(Constants.Constant(0)) val computeState = Literal(Constants.Constant(1)) val notifyState = Literal(Constants.Constant(2)) @@ -330,7 +369,8 @@ class LazyVals extends MiniPhase with IdentityDenotTransformer { val whileBody = List(ref(flagSymbol).becomes(getFlag.appliedTo(thiz, offset)), cases) val cycle = WhileDo(methodSymbol, whileCond, whileBody) - DefDef(methodSymbol, Block(resultDef :: retryDef :: flagDef :: cycle :: Nil, ref(resultSymbol))) + val setNullables = nullOut(nullables) + DefDef(methodSymbol, Block(resultDef :: retryDef :: flagDef :: cycle :: setNullables, ref(resultSymbol))) } def transformMemberDefVolatile(x: ValOrDefDef)(implicit ctx: Context) = { @@ -390,8 +430,9 @@ class LazyVals extends MiniPhase with IdentityDenotTransformer { val wait = Select(ref(helperModule), lazyNme.RLazyVals.wait4Notification) val state = Select(ref(helperModule), lazyNme.RLazyVals.state) val cas = Select(ref(helperModule), lazyNme.RLazyVals.cas) + val nullables = nullableFor(x.symbol) - val accessor = mkThreadSafeDef(x.symbol.asTerm, claz, ord, containerSymbol, x.rhs, tpe, offset, getFlag, state, cas, setFlag, wait) + val accessor = mkThreadSafeDef(x.symbol.asTerm, claz, ord, containerSymbol, x.rhs, tpe, offset, getFlag, state, cas, setFlag, wait, nullables) if (flag eq EmptyTree) Thicket(containerTree, accessor) else Thicket(containerTree, flag, accessor) diff --git a/tests/run/i1692.scala b/tests/run/i1692.scala new file mode 100644 index 000000000000..fb7d03fbee94 --- /dev/null +++ b/tests/run/i1692.scala @@ -0,0 +1,121 @@ +class VCInt(val x: Int) extends AnyVal +class VCString(val x: String) extends AnyVal + +class LazyNullable(a: => Int) { + lazy val l0 = a // null out a + + private val b = "B" + lazy val l1 = b // null out b + + private val c = "C" + @volatile lazy val l2 = c // null out c + + private val d = "D" + lazy val l3 = d + d // null out d (Scalac require single use?) +} + +object LazyNullable2 { + private val a = "A" + lazy val l0 = a // null out a +} + +class LazyNotNullable { + private val a = 'A'.toInt // not nullable type + lazy val l0 = a + + private val b = new VCInt('B'.toInt) // not nullable type + lazy val l1 = b + + private val c = new VCString("C") // should be nullable but is not?? + lazy val l2 = c + + private lazy val d = "D" // not nullable because lazy + lazy val l3 = d + + val e = "E" // not nullable because not private + lazy val l4 = e + + private val f = "F" // not nullable because used in mutiple lazy vals + lazy val l5 = f + lazy val l6 = f + + private val g = "G" // not nullable because used outside a lazy val initializer + def foo = g + lazy val l7 = g + + private val h = "H" // not nullable because field and lazy val not defined in the same class + class Inner { + lazy val l8 = h + } +} + +object Test { + def main(args: Array[String]): Unit = { + nullableTests() + notNullableTests() + } + + def nullableTests() = { + val lz = new LazyNullable('A'.toInt) + + def assertNull(fieldName: String) = { + val value = readField(fieldName, lz) + assert(value == null, s"$fieldName was $value, null expected") + } + + assert(lz.l0 == 'A'.toInt) + assertNull("a") + + assert(lz.l1 == "B") + assertNull("b") + + assert(lz.l2 == "C") + assertNull("c") + + assert(lz.l3 == "DD") + assertNull("d") + + assert(LazyNullable2.l0 == "A") + assert(readField("a", LazyNullable2) == null) + } + + def notNullableTests() = { + val lz = new LazyNotNullable + + def assertNotNull(fieldName: String) = { + val value = readField(fieldName, lz) + assert(value != null, s"$fieldName was null") + } + + assert(lz.l0 == 'A'.toInt) + assertNotNull("a") + + assert(lz.l1 == new VCInt('B'.toInt)) + assertNotNull("b") + + assert(lz.l2 == new VCString("C")) + assertNotNull("c") + + assert(lz.l3 == "D") + + assert(lz.l4 == "E") + assertNotNull("e") + + assert(lz.l5 == "F") + assert(lz.l6 == "F") + assertNotNull("f") + + assert(lz.l7 == "G") + assertNotNull("g") + + val inner = new lz.Inner + assert(inner.l8 == "H") + assertNotNull("h") + } + + def readField(fieldName: String, target: Any): Any = { + val field = target.getClass.getDeclaredField(fieldName) + field.setAccessible(true) + field.get(target) + } +} From f060565b036fce721670f3c46c866d98bf72922f Mon Sep 17 00:00:00 2001 From: Allan Renucci Date: Wed, 11 Apr 2018 18:03:35 +0200 Subject: [PATCH 2/5] Don't null out private field in trait --- .../tools/dotc/transform/CollectNullableFields.scala | 8 ++++++-- tests/run/i1692.scala | 9 +++++++++ 2 files changed, 15 insertions(+), 2 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/transform/CollectNullableFields.scala b/compiler/src/dotty/tools/dotc/transform/CollectNullableFields.scala index dc910d33a987..477ff69d292b 100644 --- a/compiler/src/dotty/tools/dotc/transform/CollectNullableFields.scala +++ b/compiler/src/dotty/tools/dotc/transform/CollectNullableFields.scala @@ -33,12 +33,12 @@ object CollectNullableFields { * initialised. * * A field is nullable if all the conditions below hold: + * - belongs to a non trait-class * - is private * - is not lazy * - its type is nullable, or is an expression type (e.g. => Int) * - is on used in a lazy val initializer * - defined in the same class as the lazy val - * - TODO from Scalac? from a non-trait class */ class CollectNullableFields extends MiniPhase { import tpd._ @@ -63,7 +63,11 @@ class CollectNullableFields extends MiniPhase { def isNullableType(tpe: Type) = tpe.isInstanceOf[ExprType] || tpe.widenDealias.typeSymbol.isNullableClass - val isNullablePrivateField = sym.isField && sym.is(Private, butNot = Lazy) && isNullableType(sym.info) + val isNullablePrivateField = + sym.isField && + sym.is(Private, butNot = Lazy) && + !sym.owner.is(Trait) && + isNullableType(sym.info) if (isNullablePrivateField) nullability.get(sym) match { diff --git a/tests/run/i1692.scala b/tests/run/i1692.scala index fb7d03fbee94..81cafd0e2c2f 100644 --- a/tests/run/i1692.scala +++ b/tests/run/i1692.scala @@ -49,6 +49,11 @@ class LazyNotNullable { } } +trait LazyTrait { + private val a = "A" + lazy val l0 = a +} + object Test { def main(args: Array[String]): Unit = { nullableTests() @@ -111,6 +116,10 @@ object Test { val inner = new lz.Inner assert(inner.l8 == "H") assertNotNull("h") + + val fromTrait = new LazyTrait {} + assert(fromTrait.l0 == "A") + assert(readField("LazyTrait$$a", fromTrait) != null) // fragile: test will break if compiler generated name change } def readField(fieldName: String, target: Any): Any = { From fefbe4c3a4662a9a7a2cf8ec35b3dab2c1d05c50 Mon Sep 17 00:00:00 2001 From: Allan Renucci Date: Mon, 30 Apr 2018 11:41:38 +0200 Subject: [PATCH 3/5] Fix typos and indentation --- compiler/src/dotty/tools/dotc/Compiler.scala | 2 +- .../transform/CollectNullableFields.scala | 6 +- .../dotty/tools/dotc/transform/LazyVals.scala | 537 +++++++++--------- 3 files changed, 275 insertions(+), 270 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/Compiler.scala b/compiler/src/dotty/tools/dotc/Compiler.scala index e15a941b4e47..b4372a00b235 100644 --- a/compiler/src/dotty/tools/dotc/Compiler.scala +++ b/compiler/src/dotty/tools/dotc/Compiler.scala @@ -72,7 +72,7 @@ class Compiler { new LiftTry, // Put try expressions that might execute on non-empty stacks into their own methods new HoistSuperArgs, // Hoist complex arguments of supercalls to enclosing scope new ClassOf, // Expand `Predef.classOf` calls. - new CollectNullableFields, // Collect fields that can be null out after use in lazy initialization + new CollectNullableFields, // Collect fields that can be nulled out after use in lazy initialization new RefChecks) :: // Various checks mostly related to abstract members and overriding List(new TryCatchPatterns, // Compile cases in try/catch new PatternMatcher, // Compile pattern matches diff --git a/compiler/src/dotty/tools/dotc/transform/CollectNullableFields.scala b/compiler/src/dotty/tools/dotc/transform/CollectNullableFields.scala index 477ff69d292b..622f6988115d 100644 --- a/compiler/src/dotty/tools/dotc/transform/CollectNullableFields.scala +++ b/compiler/src/dotty/tools/dotc/transform/CollectNullableFields.scala @@ -17,17 +17,17 @@ object CollectNullableFields { val name = "collectNullableFields" } -/** Collect fields that can be null out after use in lazy initialization. +/** Collect fields that can be nulled out after use in lazy initialization. * * This information is used during lazy val transformation to assign null to private * fields that are only used within a lazy val initializer. This is not just an optimization, * but is needed for correctness to prevent memory leaks. E.g. * - * {{{ + * ```scala * class TestByNameLazy(byNameMsg: => String) { * lazy val byLazyValMsg = byNameMsg * } - * }}} + * ``` * * Here `byNameMsg` should be null out once `byLazyValMsg` is * initialised. diff --git a/compiler/src/dotty/tools/dotc/transform/LazyVals.scala b/compiler/src/dotty/tools/dotc/transform/LazyVals.scala index eede157c6f9a..3ac4c49f12d4 100644 --- a/compiler/src/dotty/tools/dotc/transform/LazyVals.scala +++ b/compiler/src/dotty/tools/dotc/transform/LazyVals.scala @@ -129,51 +129,51 @@ class LazyVals extends MiniPhase with IdentityDenotTransformer { Thicket(field, getter) } - /** Replace a local lazy val inside a method, - * with a LazyHolder from - * dotty.runtime(eg dotty.runtime.LazyInt) - */ - def transformLocalDef(x: ValOrDefDef)(implicit ctx: Context) = { - val valueInitter = x.rhs - val xname = x.name.asTermName - val holderName = LazyLocalName.fresh(xname) - val initName = LazyLocalInitName.fresh(xname) - val tpe = x.tpe.widen.resultType.widen - - val holderType = - if (tpe isRef defn.IntClass) "LazyInt" - else if (tpe isRef defn.LongClass) "LazyLong" - else if (tpe isRef defn.BooleanClass) "LazyBoolean" - else if (tpe isRef defn.FloatClass) "LazyFloat" - else if (tpe isRef defn.DoubleClass) "LazyDouble" - else if (tpe isRef defn.ByteClass) "LazyByte" - else if (tpe isRef defn.CharClass) "LazyChar" - else if (tpe isRef defn.ShortClass) "LazyShort" - else "LazyRef" - - - val holderImpl = ctx.requiredClass("dotty.runtime." + holderType) - - val holderSymbol = ctx.newSymbol(x.symbol.owner, holderName, containerFlags, holderImpl.typeRef, coord = x.pos) - val initSymbol = ctx.newSymbol(x.symbol.owner, initName, initFlags, MethodType(Nil, tpe), coord = x.pos) - val result = ref(holderSymbol).select(lazyNme.value).withPos(x.pos) - val flag = ref(holderSymbol).select(lazyNme.initialized) - val initer = valueInitter.changeOwnerAfter(x.symbol, initSymbol, this) - val initBody = - adaptToType( - ref(holderSymbol).select(defn.Object_synchronized).appliedTo( - adaptToType(mkNonThreadSafeDef(result, flag, initer, nullables = Nil), defn.ObjectType)), - tpe) - val initTree = DefDef(initSymbol, initBody) - val holderTree = ValDef(holderSymbol, New(holderImpl.typeRef, List())) - val methodBody = tpd.If(flag.ensureApplied, - result.ensureApplied, - ref(initSymbol).ensureApplied).ensureConforms(tpe) - - val methodTree = DefDef(x.symbol.asTerm, methodBody) - ctx.debuglog(s"found a lazy val ${x.show},\nrewrote with ${holderTree.show}") - Thicket(holderTree, initTree, methodTree) - } + /** Replace a local lazy val inside a method, + * with a LazyHolder from + * dotty.runtime(eg dotty.runtime.LazyInt) + */ + def transformLocalDef(x: ValOrDefDef)(implicit ctx: Context) = { + val valueInitter = x.rhs + val xname = x.name.asTermName + val holderName = LazyLocalName.fresh(xname) + val initName = LazyLocalInitName.fresh(xname) + val tpe = x.tpe.widen.resultType.widen + + val holderType = + if (tpe isRef defn.IntClass) "LazyInt" + else if (tpe isRef defn.LongClass) "LazyLong" + else if (tpe isRef defn.BooleanClass) "LazyBoolean" + else if (tpe isRef defn.FloatClass) "LazyFloat" + else if (tpe isRef defn.DoubleClass) "LazyDouble" + else if (tpe isRef defn.ByteClass) "LazyByte" + else if (tpe isRef defn.CharClass) "LazyChar" + else if (tpe isRef defn.ShortClass) "LazyShort" + else "LazyRef" + + + val holderImpl = ctx.requiredClass("dotty.runtime." + holderType) + + val holderSymbol = ctx.newSymbol(x.symbol.owner, holderName, containerFlags, holderImpl.typeRef, coord = x.pos) + val initSymbol = ctx.newSymbol(x.symbol.owner, initName, initFlags, MethodType(Nil, tpe), coord = x.pos) + val result = ref(holderSymbol).select(lazyNme.value).withPos(x.pos) + val flag = ref(holderSymbol).select(lazyNme.initialized) + val initer = valueInitter.changeOwnerAfter(x.symbol, initSymbol, this) + val initBody = + adaptToType( + ref(holderSymbol).select(defn.Object_synchronized).appliedTo( + adaptToType(mkNonThreadSafeDef(result, flag, initer, nullables = Nil), defn.ObjectType)), + tpe) + val initTree = DefDef(initSymbol, initBody) + val holderTree = ValDef(holderSymbol, New(holderImpl.typeRef, List())) + val methodBody = tpd.If(flag.ensureApplied, + result.ensureApplied, + ref(initSymbol).ensureApplied).ensureConforms(tpe) + + val methodTree = DefDef(x.symbol.asTerm, methodBody) + ctx.debuglog(s"found a lazy val ${x.show},\nrewrote with ${holderTree.show}") + Thicket(holderTree, initTree, methodTree) + } override def transformStats(trees: List[tpd.Tree])(implicit ctx: Context): List[tpd.Tree] = { @@ -199,244 +199,249 @@ class LazyVals extends MiniPhase with IdentityDenotTransformer { } /** Create non-threadsafe lazy accessor equivalent to such code - * def methodSymbol() = { - * if (flag) target - * else { - * target = rhs - * flag = true - * nullable = null - * target - * } - * } - */ - - def mkNonThreadSafeDef(target: Tree, flag: Tree, rhs: Tree, nullables: List[Symbol])(implicit ctx: Context) = { - val setFlag = flag.becomes(Literal(Constants.Constant(true))) - val setNullables = nullOut(nullables) - val setTargetAndNullable = if (isWildcardArg(rhs)) setNullables else target.becomes(rhs) :: setNullables - val init = Block(setFlag :: setTargetAndNullable, target.ensureApplied) - If(flag.ensureApplied, target.ensureApplied, init) + * ``` + * def methodSymbol() = { + * if (flag) target + * else { + * target = rhs + * flag = true + * nullable = null + * target + * } + * } + * } + * ``` + */ + def mkNonThreadSafeDef(target: Tree, flag: Tree, rhs: Tree, nullables: List[Symbol])(implicit ctx: Context) = { + val setFlag = flag.becomes(Literal(Constants.Constant(true))) + val setNullables = nullOut(nullables) + val setTargetAndNullable = if (isWildcardArg(rhs)) setNullables else target.becomes(rhs) :: setNullables + val init = Block(setFlag :: setTargetAndNullable, target.ensureApplied) + If(flag.ensureApplied, target.ensureApplied, init) + } + + /** Create non-threadsafe lazy accessor for not-nullable types equivalent to such code + * ``` + * def methodSymbol() = { + * if (target eq null) { + * target = rhs + * nullable = null + * target + * } else target + * } + * ``` + */ + def mkDefNonThreadSafeNonNullable(target: Symbol, rhs: Tree, nullables: List[Symbol])(implicit ctx: Context) = { + val cond = ref(target).select(nme.eq).appliedTo(Literal(Constant(null))) + val exp = ref(target) + val setTarget = exp.becomes(rhs) + val setNullables = nullOut(nullables) + val init = Block(setTarget :: setNullables, exp) + If(cond, init, exp) + } + + def transformMemberDefNonVolatile(x: ValOrDefDef)(implicit ctx: Context) = { + val claz = x.symbol.owner.asClass + val tpe = x.tpe.widen.resultType.widen + assert(!(x.symbol is Flags.Mutable)) + val containerName = LazyLocalName.fresh(x.name.asTermName) + val containerSymbol = ctx.newSymbol(claz, containerName, + x.symbol.flags &~ containerFlagsMask | containerFlags | Flags.Private, + tpe, coord = x.symbol.coord + ).enteredAfter(this) + + val containerTree = ValDef(containerSymbol, defaultValue(tpe)) + if (x.tpe.isNotNull && tpe <:< defn.ObjectType) { // can use 'null' value instead of flag + val slowPath = DefDef(x.symbol.asTerm, mkDefNonThreadSafeNonNullable(containerSymbol, x.rhs, nullableFor(x.symbol))) + Thicket(containerTree, slowPath) + } + else { + val flagName = LazyBitMapName.fresh(x.name.asTermName) + val flagSymbol = ctx.newSymbol(x.symbol.owner, flagName, containerFlags | Flags.Private, defn.BooleanType).enteredAfter(this) + val flag = ValDef(flagSymbol, Literal(Constants.Constant(false))) + val slowPath = DefDef(x.symbol.asTerm, mkNonThreadSafeDef(ref(containerSymbol), ref(flagSymbol), x.rhs, nullableFor(x.symbol))) + Thicket(containerTree, flag, slowPath) } + } - /** Create non-threadsafe lazy accessor for not-nullable types equivalent to such code - * def methodSymbol() = { - * if (target eq null) { - * target = rhs - * nullable = null - * target - * } else target - * } - */ - def mkDefNonThreadSafeNonNullable(target: Symbol, rhs: Tree, nullables: List[Symbol])(implicit ctx: Context) = { - val cond = ref(target).select(nme.eq).appliedTo(Literal(Constant(null))) - val exp = ref(target) - val setTarget = exp.becomes(rhs) - val setNullables = nullOut(nullables) - val init = Block(setTarget :: setNullables, exp) - If(cond, init, exp) + /** Create a threadsafe lazy accessor equivalent to such code + * ``` + * def methodSymbol(): Int = { + * val result: Int = 0 + * val retry: Boolean = true + * var flag: Long = 0L + * while retry do { + * flag = dotty.runtime.LazyVals.get(this, $claz.$OFFSET) + * dotty.runtime.LazyVals.STATE(flag, 0) match { + * case 0 => + * if dotty.runtime.LazyVals.CAS(this, $claz.$OFFSET, flag, 1, $ord) { + * try {result = rhs} catch { + * case x: Throwable => + * dotty.runtime.LazyVals.setFlag(this, $claz.$OFFSET, 0, $ord) + * throw x + * } + * $target = result + * dotty.runtime.LazyVals.setFlag(this, $claz.$OFFSET, 3, $ord) + * retry = false + * } + * case 1 => + * dotty.runtime.LazyVals.wait4Notification(this, $claz.$OFFSET, flag, $ord) + * case 2 => + * dotty.runtime.LazyVals.wait4Notification(this, $claz.$OFFSET, flag, $ord) + * case 3 => + * retry = false + * result = $target + * } + * } + * nullable = null + * result + * } + * ``` + */ + def mkThreadSafeDef(methodSymbol: TermSymbol, + claz: ClassSymbol, + ord: Int, + target: Symbol, + rhs: Tree, + tp: Types.Type, + offset: Tree, + getFlag: Tree, + stateMask: Tree, + casFlag: Tree, + setFlagState: Tree, + waitOnLock: Tree, + nullables: List[Symbol])(implicit ctx: Context) = { + val initState = Literal(Constants.Constant(0)) + val computeState = Literal(Constants.Constant(1)) + val notifyState = Literal(Constants.Constant(2)) + val computedState = Literal(Constants.Constant(3)) + val flagSymbol = ctx.newSymbol(methodSymbol, lazyNme.flag, containerFlags, defn.LongType) + val flagDef = ValDef(flagSymbol, Literal(Constant(0L))) + + val thiz = This(claz)(ctx.fresh.setOwner(claz)) + + val resultSymbol = ctx.newSymbol(methodSymbol, lazyNme.result, containerFlags, tp) + val resultDef = ValDef(resultSymbol, defaultValue(tp)) + + val retrySymbol = ctx.newSymbol(methodSymbol, lazyNme.retry, containerFlags, defn.BooleanType) + val retryDef = ValDef(retrySymbol, Literal(Constants.Constant(true))) + + val whileCond = ref(retrySymbol) + + val compute = { + val handlerSymbol = ctx.newSymbol(methodSymbol, nme.ANON_FUN, Flags.Synthetic, + MethodType(List(nme.x_1), List(defn.ThrowableType), defn.IntType)) + val caseSymbol = ctx.newSymbol(methodSymbol, nme.DEFAULT_EXCEPTION_NAME, Flags.Synthetic, defn.ThrowableType) + val triggerRetry = setFlagState.appliedTo(thiz, offset, initState, Literal(Constant(ord))) + val complete = setFlagState.appliedTo(thiz, offset, computedState, Literal(Constant(ord))) + + val handler = CaseDef(Bind(caseSymbol, ref(caseSymbol)), EmptyTree, + Block(List(triggerRetry), Throw(ref(caseSymbol)) + )) + + val compute = ref(resultSymbol).becomes(rhs) + val tr = Try(compute, List(handler), EmptyTree) + val assign = ref(target).becomes(ref(resultSymbol)) + val noRetry = ref(retrySymbol).becomes(Literal(Constants.Constant(false))) + val body = If(casFlag.appliedTo(thiz, offset, ref(flagSymbol), computeState, Literal(Constant(ord))), + Block(tr :: assign :: complete :: noRetry :: Nil, Literal(Constant(()))), + Literal(Constant(()))) + + CaseDef(initState, EmptyTree, body) } - def transformMemberDefNonVolatile(x: ValOrDefDef)(implicit ctx: Context) = { - val claz = x.symbol.owner.asClass - val tpe = x.tpe.widen.resultType.widen - assert(!(x.symbol is Flags.Mutable)) - val containerName = LazyLocalName.fresh(x.name.asTermName) - val containerSymbol = ctx.newSymbol(claz, containerName, - x.symbol.flags &~ containerFlagsMask | containerFlags | Flags.Private, - tpe, coord = x.symbol.coord - ).enteredAfter(this) - - val containerTree = ValDef(containerSymbol, defaultValue(tpe)) - if (x.tpe.isNotNull && tpe <:< defn.ObjectType) { // can use 'null' value instead of flag - val slowPath = DefDef(x.symbol.asTerm, mkDefNonThreadSafeNonNullable(containerSymbol, x.rhs, nullableFor(x.symbol))) - Thicket(containerTree, slowPath) - } - else { - val flagName = LazyBitMapName.fresh(x.name.asTermName) - val flagSymbol = ctx.newSymbol(x.symbol.owner, flagName, containerFlags | Flags.Private, defn.BooleanType).enteredAfter(this) - val flag = ValDef(flagSymbol, Literal(Constants.Constant(false))) - val slowPath = DefDef(x.symbol.asTerm, mkNonThreadSafeDef(ref(containerSymbol), ref(flagSymbol), x.rhs, nullableFor(x.symbol))) - Thicket(containerTree, flag, slowPath) - } + val waitFirst = { + val wait = waitOnLock.appliedTo(thiz, offset, ref(flagSymbol), Literal(Constant(ord))) + CaseDef(computeState, EmptyTree, wait) } - /** Create a threadsafe lazy accessor equivalent to such code - * - * def methodSymbol(): Int = { - * val result: Int = 0 - * val retry: Boolean = true - * var flag: Long = 0L - * while retry do { - * flag = dotty.runtime.LazyVals.get(this, $claz.$OFFSET) - * dotty.runtime.LazyVals.STATE(flag, 0) match { - * case 0 => - * if dotty.runtime.LazyVals.CAS(this, $claz.$OFFSET, flag, 1, $ord) { - * try {result = rhs} catch { - * case x: Throwable => - * dotty.runtime.LazyVals.setFlag(this, $claz.$OFFSET, 0, $ord) - * throw x - * } - * $target = result - * dotty.runtime.LazyVals.setFlag(this, $claz.$OFFSET, 3, $ord) - * retry = false - * } - * case 1 => - * dotty.runtime.LazyVals.wait4Notification(this, $claz.$OFFSET, flag, $ord) - * case 2 => - * dotty.runtime.LazyVals.wait4Notification(this, $claz.$OFFSET, flag, $ord) - * case 3 => - * retry = false - * result = $target - * } - * } - * nullable = null - * result - * } - */ - def mkThreadSafeDef(methodSymbol: TermSymbol, - claz: ClassSymbol, - ord: Int, - target: Symbol, - rhs: Tree, - tp: Types.Type, - offset: Tree, - getFlag: Tree, - stateMask: Tree, - casFlag: Tree, - setFlagState: Tree, - waitOnLock: Tree, - nullables: List[Symbol])(implicit ctx: Context) = { - val initState = Literal(Constants.Constant(0)) - val computeState = Literal(Constants.Constant(1)) - val notifyState = Literal(Constants.Constant(2)) - val computedState = Literal(Constants.Constant(3)) - val flagSymbol = ctx.newSymbol(methodSymbol, lazyNme.flag, containerFlags, defn.LongType) - val flagDef = ValDef(flagSymbol, Literal(Constant(0L))) - - val thiz = This(claz)(ctx.fresh.setOwner(claz)) - - val resultSymbol = ctx.newSymbol(methodSymbol, lazyNme.result, containerFlags, tp) - val resultDef = ValDef(resultSymbol, defaultValue(tp)) - - val retrySymbol = ctx.newSymbol(methodSymbol, lazyNme.retry, containerFlags, defn.BooleanType) - val retryDef = ValDef(retrySymbol, Literal(Constants.Constant(true))) - - val whileCond = ref(retrySymbol) - - val compute = { - val handlerSymbol = ctx.newSymbol(methodSymbol, nme.ANON_FUN, Flags.Synthetic, - MethodType(List(nme.x_1), List(defn.ThrowableType), defn.IntType)) - val caseSymbol = ctx.newSymbol(methodSymbol, nme.DEFAULT_EXCEPTION_NAME, Flags.Synthetic, defn.ThrowableType) - val triggerRetry = setFlagState.appliedTo(thiz, offset, initState, Literal(Constant(ord))) - val complete = setFlagState.appliedTo(thiz, offset, computedState, Literal(Constant(ord))) - - val handler = CaseDef(Bind(caseSymbol, ref(caseSymbol)), EmptyTree, - Block(List(triggerRetry), Throw(ref(caseSymbol)) - )) - - val compute = ref(resultSymbol).becomes(rhs) - val tr = Try(compute, List(handler), EmptyTree) - val assign = ref(target).becomes(ref(resultSymbol)) - val noRetry = ref(retrySymbol).becomes(Literal(Constants.Constant(false))) - val body = If(casFlag.appliedTo(thiz, offset, ref(flagSymbol), computeState, Literal(Constant(ord))), - Block(tr :: assign :: complete :: noRetry :: Nil, Literal(Constant(()))), - Literal(Constant(()))) - - CaseDef(initState, EmptyTree, body) - } + val waitSecond = { + val wait = waitOnLock.appliedTo(thiz, offset, ref(flagSymbol), Literal(Constant(ord))) + CaseDef(notifyState, EmptyTree, wait) + } - val waitFirst = { - val wait = waitOnLock.appliedTo(thiz, offset, ref(flagSymbol), Literal(Constant(ord))) - CaseDef(computeState, EmptyTree, wait) - } + val computed = { + val noRetry = ref(retrySymbol).becomes(Literal(Constants.Constant(false))) + val result = ref(resultSymbol).becomes(ref(target)) + val body = Block(noRetry :: result :: Nil, Literal(Constant(()))) + CaseDef(computedState, EmptyTree, body) + } - val waitSecond = { - val wait = waitOnLock.appliedTo(thiz, offset, ref(flagSymbol), Literal(Constant(ord))) - CaseDef(notifyState, EmptyTree, wait) - } + val default = CaseDef(Underscore(defn.LongType), EmptyTree, Literal(Constant(()))) - val computed = { - val noRetry = ref(retrySymbol).becomes(Literal(Constants.Constant(false))) - val result = ref(resultSymbol).becomes(ref(target)) - val body = Block(noRetry :: result :: Nil, Literal(Constant(()))) - CaseDef(computedState, EmptyTree, body) - } + val cases = Match(stateMask.appliedTo(ref(flagSymbol), Literal(Constant(ord))), + List(compute, waitFirst, waitSecond, computed, default)) //todo: annotate with @switch - val default = CaseDef(Underscore(defn.LongType), EmptyTree, Literal(Constant(()))) + val whileBody = List(ref(flagSymbol).becomes(getFlag.appliedTo(thiz, offset)), cases) + val cycle = WhileDo(methodSymbol, whileCond, whileBody) + val setNullables = nullOut(nullables) + DefDef(methodSymbol, Block(resultDef :: retryDef :: flagDef :: cycle :: setNullables, ref(resultSymbol))) + } - val cases = Match(stateMask.appliedTo(ref(flagSymbol), Literal(Constant(ord))), - List(compute, waitFirst, waitSecond, computed, default)) //todo: annotate with @switch + def transformMemberDefVolatile(x: ValOrDefDef)(implicit ctx: Context) = { + assert(!(x.symbol is Flags.Mutable)) + + val tpe = x.tpe.widen.resultType.widen + val claz = x.symbol.owner.asClass + val thizClass = Literal(Constant(claz.info)) + val helperModule = ctx.requiredModule("dotty.runtime.LazyVals") + val getOffset = Select(ref(helperModule), lazyNme.RLazyVals.getOffset) + var offsetSymbol: TermSymbol = null + var flag: Tree = EmptyTree + var ord = 0 + + def offsetName(id: Int) = (StdNames.nme.LAZY_FIELD_OFFSET + (if(x.symbol.owner.is(Flags.Module)) "_m_" else "") + id.toString).toTermName + + // compute or create appropriate offsetSymol, bitmap and bits used by current ValDef + appendOffsetDefs.get(claz) match { + case Some(info) => + val flagsPerLong = (64 / dotty.runtime.LazyVals.BITS_PER_LAZY_VAL).toInt + info.ord += 1 + ord = info.ord % flagsPerLong + val id = info.ord / flagsPerLong + val offsetById = offsetName(id) + if (ord != 0) { // there are unused bits in already existing flag + offsetSymbol = claz.info.decl(offsetById) + .suchThat(sym => (sym is Flags.Synthetic) && sym.isTerm) + .symbol.asTerm + } else { // need to create a new flag + offsetSymbol = ctx.newSymbol(claz, offsetById, Flags.Synthetic, defn.LongType).enteredAfter(this) + offsetSymbol.addAnnotation(Annotation(defn.ScalaStaticAnnot)) + val flagName = (StdNames.nme.BITMAP_PREFIX + id.toString).toTermName + val flagSymbol = ctx.newSymbol(claz, flagName, containerFlags, defn.LongType).enteredAfter(this) + flag = ValDef(flagSymbol, Literal(Constants.Constant(0L))) + val offsetTree = ValDef(offsetSymbol, getOffset.appliedTo(thizClass, Literal(Constant(flagName.toString)))) + info.defs = offsetTree :: info.defs + } - val whileBody = List(ref(flagSymbol).becomes(getFlag.appliedTo(thiz, offset)), cases) - val cycle = WhileDo(methodSymbol, whileCond, whileBody) - val setNullables = nullOut(nullables) - DefDef(methodSymbol, Block(resultDef :: retryDef :: flagDef :: cycle :: setNullables, ref(resultSymbol))) + case None => + offsetSymbol = ctx.newSymbol(claz, offsetName(0), Flags.Synthetic, defn.LongType).enteredAfter(this) + offsetSymbol.addAnnotation(Annotation(defn.ScalaStaticAnnot)) + val flagName = (StdNames.nme.BITMAP_PREFIX + "0").toTermName + val flagSymbol = ctx.newSymbol(claz, flagName, containerFlags, defn.LongType).enteredAfter(this) + flag = ValDef(flagSymbol, Literal(Constants.Constant(0L))) + val offsetTree = ValDef(offsetSymbol, getOffset.appliedTo(thizClass, Literal(Constant(flagName.toString)))) + appendOffsetDefs += (claz -> new OffsetInfo(List(offsetTree), ord)) } - def transformMemberDefVolatile(x: ValOrDefDef)(implicit ctx: Context) = { - assert(!(x.symbol is Flags.Mutable)) - - val tpe = x.tpe.widen.resultType.widen - val claz = x.symbol.owner.asClass - val thizClass = Literal(Constant(claz.info)) - val helperModule = ctx.requiredModule("dotty.runtime.LazyVals") - val getOffset = Select(ref(helperModule), lazyNme.RLazyVals.getOffset) - var offsetSymbol: TermSymbol = null - var flag: Tree = EmptyTree - var ord = 0 - - def offsetName(id: Int) = (StdNames.nme.LAZY_FIELD_OFFSET + (if(x.symbol.owner.is(Flags.Module)) "_m_" else "") + id.toString).toTermName - - // compute or create appropriate offsetSymol, bitmap and bits used by current ValDef - appendOffsetDefs.get(claz) match { - case Some(info) => - val flagsPerLong = (64 / dotty.runtime.LazyVals.BITS_PER_LAZY_VAL).toInt - info.ord += 1 - ord = info.ord % flagsPerLong - val id = info.ord / flagsPerLong - val offsetById = offsetName(id) - if (ord != 0) { // there are unused bits in already existing flag - offsetSymbol = claz.info.decl(offsetById) - .suchThat(sym => (sym is Flags.Synthetic) && sym.isTerm) - .symbol.asTerm - } else { // need to create a new flag - offsetSymbol = ctx.newSymbol(claz, offsetById, Flags.Synthetic, defn.LongType).enteredAfter(this) - offsetSymbol.addAnnotation(Annotation(defn.ScalaStaticAnnot)) - val flagName = (StdNames.nme.BITMAP_PREFIX + id.toString).toTermName - val flagSymbol = ctx.newSymbol(claz, flagName, containerFlags, defn.LongType).enteredAfter(this) - flag = ValDef(flagSymbol, Literal(Constants.Constant(0L))) - val offsetTree = ValDef(offsetSymbol, getOffset.appliedTo(thizClass, Literal(Constant(flagName.toString)))) - info.defs = offsetTree :: info.defs - } - - case None => - offsetSymbol = ctx.newSymbol(claz, offsetName(0), Flags.Synthetic, defn.LongType).enteredAfter(this) - offsetSymbol.addAnnotation(Annotation(defn.ScalaStaticAnnot)) - val flagName = (StdNames.nme.BITMAP_PREFIX + "0").toTermName - val flagSymbol = ctx.newSymbol(claz, flagName, containerFlags, defn.LongType).enteredAfter(this) - flag = ValDef(flagSymbol, Literal(Constants.Constant(0L))) - val offsetTree = ValDef(offsetSymbol, getOffset.appliedTo(thizClass, Literal(Constant(flagName.toString)))) - appendOffsetDefs += (claz -> new OffsetInfo(List(offsetTree), ord)) - } - - val containerName = LazyLocalName.fresh(x.name.asTermName) - val containerSymbol = ctx.newSymbol(claz, containerName, x.symbol.flags &~ containerFlagsMask | containerFlags, tpe, coord = x.symbol.coord).enteredAfter(this) + val containerName = LazyLocalName.fresh(x.name.asTermName) + val containerSymbol = ctx.newSymbol(claz, containerName, x.symbol.flags &~ containerFlagsMask | containerFlags, tpe, coord = x.symbol.coord).enteredAfter(this) - val containerTree = ValDef(containerSymbol, defaultValue(tpe)) + val containerTree = ValDef(containerSymbol, defaultValue(tpe)) - val offset = ref(offsetSymbol) - val getFlag = Select(ref(helperModule), lazyNme.RLazyVals.get) - val setFlag = Select(ref(helperModule), lazyNme.RLazyVals.setFlag) - val wait = Select(ref(helperModule), lazyNme.RLazyVals.wait4Notification) - val state = Select(ref(helperModule), lazyNme.RLazyVals.state) - val cas = Select(ref(helperModule), lazyNme.RLazyVals.cas) - val nullables = nullableFor(x.symbol) + val offset = ref(offsetSymbol) + val getFlag = Select(ref(helperModule), lazyNme.RLazyVals.get) + val setFlag = Select(ref(helperModule), lazyNme.RLazyVals.setFlag) + val wait = Select(ref(helperModule), lazyNme.RLazyVals.wait4Notification) + val state = Select(ref(helperModule), lazyNme.RLazyVals.state) + val cas = Select(ref(helperModule), lazyNme.RLazyVals.cas) + val nullables = nullableFor(x.symbol) - val accessor = mkThreadSafeDef(x.symbol.asTerm, claz, ord, containerSymbol, x.rhs, tpe, offset, getFlag, state, cas, setFlag, wait, nullables) - if (flag eq EmptyTree) - Thicket(containerTree, accessor) - else Thicket(containerTree, flag, accessor) - } + val accessor = mkThreadSafeDef(x.symbol.asTerm, claz, ord, containerSymbol, x.rhs, tpe, offset, getFlag, state, cas, setFlag, wait, nullables) + if (flag eq EmptyTree) + Thicket(containerTree, accessor) + else Thicket(containerTree, flag, accessor) + } } object LazyVals { From 78f83c9ddc4478c8b9f0a957645adfe14c536238 Mon Sep 17 00:00:00 2001 From: Allan Renucci Date: Mon, 30 Apr 2018 14:22:43 +0200 Subject: [PATCH 4/5] Address review comments --- compiler/src/dotty/tools/dotc/Compiler.scala | 2 +- .../transform/CollectNullableFields.scala | 32 +++++++++++-------- .../dotty/tools/dotc/transform/Getters.scala | 6 +++- .../dotty/tools/dotc/transform/LazyVals.scala | 15 ++++++--- tests/run/i1692.scala | 2 +- 5 files changed, 36 insertions(+), 21 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/Compiler.scala b/compiler/src/dotty/tools/dotc/Compiler.scala index b4372a00b235..596b1d6bf45f 100644 --- a/compiler/src/dotty/tools/dotc/Compiler.scala +++ b/compiler/src/dotty/tools/dotc/Compiler.scala @@ -72,7 +72,6 @@ class Compiler { new LiftTry, // Put try expressions that might execute on non-empty stacks into their own methods new HoistSuperArgs, // Hoist complex arguments of supercalls to enclosing scope new ClassOf, // Expand `Predef.classOf` calls. - new CollectNullableFields, // Collect fields that can be nulled out after use in lazy initialization new RefChecks) :: // Various checks mostly related to abstract members and overriding List(new TryCatchPatterns, // Compile cases in try/catch new PatternMatcher, // Compile pattern matches @@ -88,6 +87,7 @@ class Compiler { new InterceptedMethods, // Special handling of `==`, `|=`, `getClass` methods new Getters, // Replace non-private vals and vars with getter defs (fields are added later) new ElimByName, // Expand by-name parameter references + new CollectNullableFields, // Collect fields that can be nulled out after use in lazy initialization new ElimOuterSelect, // Expand outer selections new AugmentScala2Traits, // Expand traits defined in Scala 2.x to simulate old-style rewritings new ResolveSuper, // Implement super accessors and add forwarders to trait methods diff --git a/compiler/src/dotty/tools/dotc/transform/CollectNullableFields.scala b/compiler/src/dotty/tools/dotc/transform/CollectNullableFields.scala index 622f6988115d..939896866a9f 100644 --- a/compiler/src/dotty/tools/dotc/transform/CollectNullableFields.scala +++ b/compiler/src/dotty/tools/dotc/transform/CollectNullableFields.scala @@ -45,6 +45,13 @@ class CollectNullableFields extends MiniPhase { override def phaseName = CollectNullableFields.name + /** Running after `ElimByName` to see by names as nullable types. + * + * We don't necessary need to run after `Getters`, but the implementation + * could be simplified if we were to run before. + */ + override def runsAfter = Set(Getters.name, ElimByName.name) + private[this] sealed trait FieldInfo private[this] case object NotNullable extends FieldInfo private[this] case class Nullable(by: Symbol) extends FieldInfo @@ -53,21 +60,20 @@ class CollectNullableFields extends MiniPhase { private[this] var nullability: IdentityHashMap[Symbol, FieldInfo] = _ override def prepareForUnit(tree: Tree)(implicit ctx: Context) = { - nullability = new IdentityHashMap + if (nullability == null) nullability = new IdentityHashMap ctx } private def recordUse(tree: Tree)(implicit ctx: Context): Tree = { - val sym = tree.symbol + def isField(sym: Symbol) = + sym.isField || sym.isGetter // running after phase Getters - def isNullableType(tpe: Type) = - tpe.isInstanceOf[ExprType] || - tpe.widenDealias.typeSymbol.isNullableClass + val sym = tree.symbol val isNullablePrivateField = - sym.isField && + isField(sym) && sym.is(Private, butNot = Lazy) && !sym.owner.is(Trait) && - isNullableType(sym.info) + sym.info.widenDealias.typeSymbol.isNullableClass if (isNullablePrivateField) nullability.get(sym) match { @@ -76,8 +82,8 @@ class CollectNullableFields extends MiniPhase { case null => // not in the map val from = ctx.owner val isNullable = - from.is(Lazy) && from.isField && // used in lazy field initializer - from.owner.eq(sym.owner) // lazy val and field in the same class + from.is(Lazy) && isField(from) && // used in lazy field initializer + from.owner.eq(sym.owner) // lazy val and field defined in the same class val info = if (isNullable) Nullable(from) else NotNullable nullability.put(sym, info) case _ => @@ -96,16 +102,16 @@ class CollectNullableFields extends MiniPhase { recordUse(tree) /** Map lazy values to the fields they should null after initialization. */ - def lazyValNullables(implicit ctx: Context): Map[Symbol, List[Symbol]] = { - val result = new mutable.HashMap[Symbol, mutable.ListBuffer[Symbol]] + def lazyValNullables(implicit ctx: Context): IdentityHashMap[Symbol, mutable.ListBuffer[Symbol]] = { + val result = new IdentityHashMap[Symbol, mutable.ListBuffer[Symbol]] nullability.forEach { case (sym, Nullable(from)) => - val bldr = result.getOrElseUpdate(from, new mutable.ListBuffer) + val bldr = result.computeIfAbsent(from, _ => new mutable.ListBuffer) bldr += sym case _ => } - result.mapValues(_.toList).toMap + result } } diff --git a/compiler/src/dotty/tools/dotc/transform/Getters.scala b/compiler/src/dotty/tools/dotc/transform/Getters.scala index 55e68d69ff11..3c554e4230af 100644 --- a/compiler/src/dotty/tools/dotc/transform/Getters.scala +++ b/compiler/src/dotty/tools/dotc/transform/Getters.scala @@ -49,7 +49,7 @@ import ValueClasses._ class Getters extends MiniPhase with SymTransformer { import ast.tpd._ - override def phaseName = "getters" + override def phaseName = Getters.name override def transformSym(d: SymDenotation)(implicit ctx: Context): SymDenotation = { def noGetterNeeded = @@ -74,3 +74,7 @@ class Getters extends MiniPhase with SymTransformer { override def transformAssign(tree: Assign)(implicit ctx: Context): Tree = if (tree.lhs.symbol is Method) tree.lhs.becomes(tree.rhs).withPos(tree.pos) else tree } + +object Getters { + val name = "getters" +} diff --git a/compiler/src/dotty/tools/dotc/transform/LazyVals.scala b/compiler/src/dotty/tools/dotc/transform/LazyVals.scala index 3ac4c49f12d4..bd871294661c 100644 --- a/compiler/src/dotty/tools/dotc/transform/LazyVals.scala +++ b/compiler/src/dotty/tools/dotc/transform/LazyVals.scala @@ -26,6 +26,8 @@ import dotty.tools.dotc.core.SymDenotations.SymDenotation import dotty.tools.dotc.core.DenotTransformers.{SymTransformer, IdentityDenotTransformer, DenotTransformer} import Erasure.Boxing.adaptToType +import java.util.IdentityHashMap + class LazyVals extends MiniPhase with IdentityDenotTransformer { import LazyVals._ import tpd._ @@ -51,14 +53,17 @@ class LazyVals extends MiniPhase with IdentityDenotTransformer { val containerFlagsMask = Flags.Method | Flags.Lazy | Flags.Accessor | Flags.Module /** A map of lazy values to the fields they should null after initialization. */ - private[this] var lazyValNullables: Map[Symbol, List[Symbol]] = _ - private def nullableFor(sym: Symbol)(implicit ctx: Context) = - if (sym.is(Flags.Module)) Nil - else lazyValNullables.getOrElse(sym, Nil) + private[this] var lazyValNullables: IdentityHashMap[Symbol, mutable.ListBuffer[Symbol]] = _ + private def nullableFor(sym: Symbol)(implicit ctx: Context) = { + val nullables = lazyValNullables.remove(sym) + if (nullables == null || sym.is(Flags.Module)) Nil + else nullables.toList + } override def prepareForUnit(tree: Tree)(implicit ctx: Context) = { - lazyValNullables = ctx.collectNullableFieldsPhase.asInstanceOf[CollectNullableFields].lazyValNullables + if (lazyValNullables == null) + lazyValNullables = ctx.collectNullableFieldsPhase.asInstanceOf[CollectNullableFields].lazyValNullables ctx } diff --git a/tests/run/i1692.scala b/tests/run/i1692.scala index 81cafd0e2c2f..12ab7835385d 100644 --- a/tests/run/i1692.scala +++ b/tests/run/i1692.scala @@ -119,7 +119,7 @@ object Test { val fromTrait = new LazyTrait {} assert(fromTrait.l0 == "A") - assert(readField("LazyTrait$$a", fromTrait) != null) // fragile: test will break if compiler generated name change + assert(readField("LazyTrait$$a", fromTrait) != null) // fragile: test will break if compiler generated names change } def readField(fieldName: String, target: Any): Any = { From d800e310a210440452ed56e3d702d86155530e10 Mon Sep 17 00:00:00 2001 From: Allan Renucci Date: Wed, 9 May 2018 20:25:22 +0200 Subject: [PATCH 5/5] Apply optimisation on private[this] fields only --- .../transform/CollectNullableFields.scala | 27 ++++++-------- .../dotty/tools/dotc/transform/Getters.scala | 6 +-- .../dotty/tools/dotc/transform/LazyVals.scala | 6 +-- tests/run/i1692.scala | 37 ++++++++++++------- 4 files changed, 39 insertions(+), 37 deletions(-) diff --git a/compiler/src/dotty/tools/dotc/transform/CollectNullableFields.scala b/compiler/src/dotty/tools/dotc/transform/CollectNullableFields.scala index 939896866a9f..201ae4e21def 100644 --- a/compiler/src/dotty/tools/dotc/transform/CollectNullableFields.scala +++ b/compiler/src/dotty/tools/dotc/transform/CollectNullableFields.scala @@ -34,10 +34,10 @@ object CollectNullableFields { * * A field is nullable if all the conditions below hold: * - belongs to a non trait-class - * - is private + * - is private[this] * - is not lazy - * - its type is nullable, or is an expression type (e.g. => Int) - * - is on used in a lazy val initializer + * - its type is nullable + * - is only used in a lazy val initializer * - defined in the same class as the lazy val */ class CollectNullableFields extends MiniPhase { @@ -45,12 +45,8 @@ class CollectNullableFields extends MiniPhase { override def phaseName = CollectNullableFields.name - /** Running after `ElimByName` to see by names as nullable types. - * - * We don't necessary need to run after `Getters`, but the implementation - * could be simplified if we were to run before. - */ - override def runsAfter = Set(Getters.name, ElimByName.name) + /** Running after `ElimByName` to see by names as nullable types. */ + override def runsAfter = Set(ElimByName.name) private[this] sealed trait FieldInfo private[this] case object NotNullable extends FieldInfo @@ -65,14 +61,12 @@ class CollectNullableFields extends MiniPhase { } private def recordUse(tree: Tree)(implicit ctx: Context): Tree = { - def isField(sym: Symbol) = - sym.isField || sym.isGetter // running after phase Getters - val sym = tree.symbol val isNullablePrivateField = - isField(sym) && - sym.is(Private, butNot = Lazy) && + sym.isField && + !sym.is(Lazy) && !sym.owner.is(Trait) && + sym.initial.is(PrivateLocal) && sym.info.widenDealias.typeSymbol.isNullableClass if (isNullablePrivateField) @@ -82,8 +76,9 @@ class CollectNullableFields extends MiniPhase { case null => // not in the map val from = ctx.owner val isNullable = - from.is(Lazy) && isField(from) && // used in lazy field initializer - from.owner.eq(sym.owner) // lazy val and field defined in the same class + from.is(Lazy, butNot = Module) && // is lazy val + from.owner.isClass && // is field + from.owner.eq(sym.owner) // is lazy val and field defined in the same class val info = if (isNullable) Nullable(from) else NotNullable nullability.put(sym, info) case _ => diff --git a/compiler/src/dotty/tools/dotc/transform/Getters.scala b/compiler/src/dotty/tools/dotc/transform/Getters.scala index 3c554e4230af..55e68d69ff11 100644 --- a/compiler/src/dotty/tools/dotc/transform/Getters.scala +++ b/compiler/src/dotty/tools/dotc/transform/Getters.scala @@ -49,7 +49,7 @@ import ValueClasses._ class Getters extends MiniPhase with SymTransformer { import ast.tpd._ - override def phaseName = Getters.name + override def phaseName = "getters" override def transformSym(d: SymDenotation)(implicit ctx: Context): SymDenotation = { def noGetterNeeded = @@ -74,7 +74,3 @@ class Getters extends MiniPhase with SymTransformer { override def transformAssign(tree: Assign)(implicit ctx: Context): Tree = if (tree.lhs.symbol is Method) tree.lhs.becomes(tree.rhs).withPos(tree.pos) else tree } - -object Getters { - val name = "getters" -} diff --git a/compiler/src/dotty/tools/dotc/transform/LazyVals.scala b/compiler/src/dotty/tools/dotc/transform/LazyVals.scala index bd871294661c..7e5cc6a277d9 100644 --- a/compiler/src/dotty/tools/dotc/transform/LazyVals.scala +++ b/compiler/src/dotty/tools/dotc/transform/LazyVals.scala @@ -55,8 +55,9 @@ class LazyVals extends MiniPhase with IdentityDenotTransformer { /** A map of lazy values to the fields they should null after initialization. */ private[this] var lazyValNullables: IdentityHashMap[Symbol, mutable.ListBuffer[Symbol]] = _ private def nullableFor(sym: Symbol)(implicit ctx: Context) = { + // optimisation: value only used once, we can remove the value from the map val nullables = lazyValNullables.remove(sym) - if (nullables == null || sym.is(Flags.Module)) Nil + if (nullables == null) Nil else nullables.toList } @@ -195,8 +196,7 @@ class LazyVals extends MiniPhase with IdentityDenotTransformer { private def nullOut(nullables: List[Symbol])(implicit ctx: Context): List[Tree] = { val nullConst = Literal(Constants.Constant(null)) - nullables.map { sym => - val field = if (sym.isGetter) sym.field else sym + nullables.map { field => assert(field.isField) field.setFlag(Flags.Mutable) ref(field).becomes(nullConst) diff --git a/tests/run/i1692.scala b/tests/run/i1692.scala index 12ab7835385d..f70cd1b2ed16 100644 --- a/tests/run/i1692.scala +++ b/tests/run/i1692.scala @@ -4,46 +4,46 @@ class VCString(val x: String) extends AnyVal class LazyNullable(a: => Int) { lazy val l0 = a // null out a - private val b = "B" + private[this] val b = "B" lazy val l1 = b // null out b - private val c = "C" + private[this] val c = "C" @volatile lazy val l2 = c // null out c - private val d = "D" + private[this] val d = "D" lazy val l3 = d + d // null out d (Scalac require single use?) } object LazyNullable2 { - private val a = "A" + private[this] val a = "A" lazy val l0 = a // null out a } class LazyNotNullable { - private val a = 'A'.toInt // not nullable type + private[this] val a = 'A'.toInt // not nullable type lazy val l0 = a - private val b = new VCInt('B'.toInt) // not nullable type + private[this] val b = new VCInt('B'.toInt) // not nullable type lazy val l1 = b - private val c = new VCString("C") // should be nullable but is not?? + private[this] val c = new VCString("C") // should be nullable but is not?? lazy val l2 = c - private lazy val d = "D" // not nullable because lazy + private[this] lazy val d = "D" // not nullable because lazy lazy val l3 = d - val e = "E" // not nullable because not private + private val e = "E" // not nullable because not private[this] lazy val l4 = e - private val f = "F" // not nullable because used in mutiple lazy vals + private[this] val f = "F" // not nullable because used in mutiple lazy vals lazy val l5 = f lazy val l6 = f - private val g = "G" // not nullable because used outside a lazy val initializer + private[this] val g = "G" // not nullable because used outside a lazy val initializer def foo = g lazy val l7 = g - private val h = "H" // not nullable because field and lazy val not defined in the same class + private[this] val h = "H" // not nullable because field and lazy val not defined in the same class class Inner { lazy val l8 = h } @@ -54,6 +54,13 @@ trait LazyTrait { lazy val l0 = a } +class Foo(val x: String) + +class LazyNotNullable2(x: String) extends Foo(x) { + lazy val y = x // not nullable. Here x is super.x +} + + object Test { def main(args: Array[String]): Unit = { nullableTests() @@ -115,11 +122,15 @@ object Test { val inner = new lz.Inner assert(inner.l8 == "H") - assertNotNull("h") + assertNotNull("LazyNotNullable$$h") // fragile: test will break if compiler generated names change val fromTrait = new LazyTrait {} assert(fromTrait.l0 == "A") assert(readField("LazyTrait$$a", fromTrait) != null) // fragile: test will break if compiler generated names change + + val lz2 = new LazyNotNullable2("Hello") + assert(lz2.y == "Hello") + assert(lz2.x == "Hello") } def readField(fieldName: String, target: Any): Any = {