Skip to content

Commit b7bfe15

Browse files
committed
Fix #1692: Null out fields after use in lazy initialization
Private fields that are only used during lzyy val initialization can be assigned null once the lazy val is initialized. This is not just an optimization, but is needed for correctness to prevent memory leaks.
1 parent 07fa870 commit b7bfe15

File tree

4 files changed

+142
-13
lines changed

4 files changed

+142
-13
lines changed

compiler/src/dotty/tools/dotc/Compiler.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,6 +73,7 @@ class Compiler {
7373
new LiftTry, // Put try expressions that might execute on non-empty stacks into their own methods
7474
new HoistSuperArgs, // Hoist complex arguments of supercalls to enclosing scope
7575
new ClassOf, // Expand `Predef.classOf` calls.
76+
new CollectNullableFields, // Collect fields that can be null out after use in lazy initialization
7677
new RefChecks) :: // Various checks mostly related to abstract members and overriding
7778
List(new TryCatchPatterns, // Compile cases in try/catch
7879
new PatternMatcher, // Compile pattern matches
@@ -97,7 +98,7 @@ class Compiler {
9798
List(new Erasure) :: // Rewrite types to JVM model, erasing all type parameters, abstract types and refinements.
9899
List(new ElimErasedValueType, // Expand erased value types to their underlying implmementation types
99100
new VCElideAllocations, // Peep-hole optimization to eliminate unnecessary value class allocations
100-
new Mixin, // Expand trait fields and trait initializers
101+
new Mixin, // Expand trait fields and trait initializers
101102
new LazyVals, // Expand lazy vals
102103
new Memoize, // Add private fields to getters and setters
103104
new NonLocalReturns, // Expand non-local returns

compiler/src/dotty/tools/dotc/core/Phases.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,7 @@ object Phases {
210210

211211
private[this] var myTyperPhase: Phase = _
212212
private[this] var myPicklerPhase: Phase = _
213+
private[this] var myCollectNullableFieldsPhase: Phase = _
213214
private[this] var myRefChecksPhase: Phase = _
214215
private[this] var myPatmatPhase: Phase = _
215216
private[this] var myElimRepeatedPhase: Phase = _
@@ -224,6 +225,7 @@ object Phases {
224225

225226
final def typerPhase = myTyperPhase
226227
final def picklerPhase = myPicklerPhase
228+
final def collectNullableFieldsPhase = myCollectNullableFieldsPhase
227229
final def refchecksPhase = myRefChecksPhase
228230
final def patmatPhase = myPatmatPhase
229231
final def elimRepeatedPhase = myElimRepeatedPhase
@@ -241,6 +243,7 @@ object Phases {
241243

242244
myTyperPhase = phaseOfClass(classOf[FrontEnd])
243245
myPicklerPhase = phaseOfClass(classOf[Pickler])
246+
myCollectNullableFieldsPhase = phaseOfClass(classOf[CollectNullableFields])
244247
myRefChecksPhase = phaseOfClass(classOf[RefChecks])
245248
myElimRepeatedPhase = phaseOfClass(classOf[ElimRepeated])
246249
myExtensionMethodsPhase = phaseOfClass(classOf[ExtensionMethods])
Lines changed: 93 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,93 @@
1+
package dotty.tools.dotc.transform
2+
3+
import dotty.tools.dotc.ast.tpd
4+
import dotty.tools.dotc.core.Contexts.Context
5+
import dotty.tools.dotc.core.Flags._
6+
import dotty.tools.dotc.core.Symbols.Symbol
7+
import dotty.tools.dotc.core.Types.ExprType
8+
import dotty.tools.dotc.transform.MegaPhase.MiniPhase
9+
import dotty.tools.dotc.transform.SymUtils._
10+
11+
import scala.collection.JavaConverters._
12+
import scala.collection.mutable
13+
14+
import java.util.IdentityHashMap
15+
16+
object CollectNullableFields {
17+
val name = "collectNullableFields"
18+
}
19+
20+
/** Collect fields that can be null out after use in lazy initialization.
21+
*
22+
* This information is used during lazy val transformation to assign null to private
23+
* fields that are only used within a lazy val initializer.
24+
* This is necessary to prevent memory leaks. E.g.
25+
*
26+
* {{{
27+
* class TestByNameLazy(byNameMsg: => String) {
28+
* lazy val byLazyValMsg = byNameMsg
29+
* }
30+
* }}}
31+
*
32+
* Here `byNameMsg` should be null out once `byLazyValMsg` is
33+
* initialised.
34+
*/
35+
class CollectNullableFields extends MiniPhase {
36+
import tpd._
37+
38+
override def phaseName = CollectNullableFields.name
39+
40+
private sealed trait FieldInfo
41+
private case object NotNullable extends FieldInfo
42+
private case class Nullable(by: Symbol) extends FieldInfo
43+
44+
private[this] var fieldInfo: IdentityHashMap[Symbol, FieldInfo] = _
45+
46+
override def prepareForUnit(tree: Tree)(implicit ctx: Context) = {
47+
fieldInfo = new IdentityHashMap
48+
ctx
49+
}
50+
51+
private def recordUse(tree: Tree)(implicit ctx: Context): Tree = {
52+
val sym = tree.symbol
53+
54+
def isNullable =
55+
sym.info.isInstanceOf[ExprType] ||
56+
sym.info.widenDealias.typeSymbol.isNullableClass
57+
val isNullablePrivateField = sym.isField && sym.is(Private) && isNullable
58+
59+
if (isNullablePrivateField)
60+
fieldInfo.get(sym) match {
61+
case Nullable(from) if from != ctx.owner => // used in multiple lazy val initializers
62+
fieldInfo.put(sym, NotNullable)
63+
case null => // not in the map
64+
val from = ctx.owner
65+
val inLazyValInitializer = from.is(Lazy, butNot = Module)
66+
val info = if (inLazyValInitializer) Nullable(from) else NotNullable
67+
fieldInfo.put(sym, info)
68+
case _ =>
69+
}
70+
71+
tree
72+
}
73+
74+
override def transformIdent(tree: Ident)(implicit ctx: Context) =
75+
recordUse(tree)
76+
77+
override def transformSelect(tree: Select)(implicit ctx: Context) =
78+
recordUse(tree)
79+
80+
/** Map lazy values to the fields they should null after initialization. */
81+
def lazyValNullables(implicit ctx: Context): Map[Symbol, List[Symbol]] = {
82+
val result = new mutable.HashMap[Symbol, mutable.ListBuffer[Symbol]]
83+
84+
fieldInfo.forEach {
85+
case (sym, Nullable(from)) =>
86+
val bldr = result.getOrElseUpdate(from, new mutable.ListBuffer)
87+
bldr += sym
88+
case _ =>
89+
}
90+
91+
result.mapValues(_.toList).toMap
92+
}
93+
}

compiler/src/dotty/tools/dotc/transform/LazyVals.scala

Lines changed: 44 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ class LazyVals extends MiniPhase with IdentityDenotTransformer {
3939

4040
/** List of names of phases that should have finished processing of tree
4141
* before this phase starts processing same tree */
42-
override def runsAfter = Set(Mixin.name)
42+
override def runsAfter = Set(Mixin.name, CollectNullableFields.name)
4343

4444
override def changesMembers = true // the phase adds lazy val accessors
4545

@@ -50,6 +50,15 @@ class LazyVals extends MiniPhase with IdentityDenotTransformer {
5050

5151
val containerFlagsMask = Flags.Method | Flags.Lazy | Flags.Accessor | Flags.Module
5252

53+
/** A map of lazy values to the fields they should null after initialization. */
54+
private[this] var lazyValNullables: Map[Symbol, List[Symbol]] = _
55+
private def nullableFor(sym: Symbol) = lazyValNullables.getOrElse(sym, Nil)
56+
57+
override def prepareForUnit(tree: Tree)(implicit ctx: Context) = {
58+
lazyValNullables = ctx.collectNullableFieldsPhase.asInstanceOf[CollectNullableFields].lazyValNullables
59+
ctx
60+
}
61+
5362
override def transformDefDef(tree: tpd.DefDef)(implicit ctx: Context): tpd.Tree =
5463
transformLazyVal(tree)
5564

@@ -150,7 +159,7 @@ class LazyVals extends MiniPhase with IdentityDenotTransformer {
150159
val initBody =
151160
adaptToType(
152161
ref(holderSymbol).select(defn.Object_synchronized).appliedTo(
153-
adaptToType(mkNonThreadSafeDef(result, flag, initer), defn.ObjectType)),
162+
adaptToType(mkNonThreadSafeDef(result, flag, initer, nullableFor(x.symbol)), defn.ObjectType)),
154163
tpe)
155164
val initTree = DefDef(initSymbol, initBody)
156165
val holderTree = ValDef(holderSymbol, New(holderImpl.typeRef, List()))
@@ -176,37 +185,46 @@ class LazyVals extends MiniPhase with IdentityDenotTransformer {
176185
holders:::stats
177186
}
178187

188+
private def nullOut(nullables: List[Symbol])(implicit ctx: Context): List[Tree] = {
189+
val nullConst = Literal(Constants.Constant(null))
190+
nullables.map(sym => ref(sym).becomes(nullConst))
191+
}
192+
179193
/** Create non-threadsafe lazy accessor equivalent to such code
180194
* def methodSymbol() = {
181195
* if (flag) target
182196
* else {
183197
* target = rhs
184198
* flag = true
199+
* nullable = null
185200
* target
186201
* }
187202
* }
188203
*/
189204

190-
def mkNonThreadSafeDef(target: Tree, flag: Tree, rhs: Tree)(implicit ctx: Context) = {
205+
def mkNonThreadSafeDef(target: Tree, flag: Tree, rhs: Tree, nullables: List[Symbol])(implicit ctx: Context) = {
191206
val setFlag = flag.becomes(Literal(Constants.Constant(true)))
192-
val setTargets = if (isWildcardArg(rhs)) Nil else target.becomes(rhs) :: Nil
193-
val init = Block(setFlag :: setTargets, target.ensureApplied)
207+
val setNullables = nullOut(nullables)
208+
val setTargetAndNullable = if (isWildcardArg(rhs)) setNullables else target.becomes(rhs) :: setNullables
209+
val init = Block(setFlag :: setTargetAndNullable, target.ensureApplied)
194210
If(flag.ensureApplied, target.ensureApplied, init)
195211
}
196212

197213
/** Create non-threadsafe lazy accessor for not-nullable types equivalent to such code
198214
* def methodSymbol() = {
199215
* if (target eq null) {
200216
* target = rhs
217+
* nullable = null
201218
* target
202219
* } else target
203220
* }
204221
*/
205-
def mkDefNonThreadSafeNonNullable(target: Symbol, rhs: Tree)(implicit ctx: Context) = {
222+
def mkDefNonThreadSafeNonNullable(target: Symbol, rhs: Tree, nullables: List[Symbol])(implicit ctx: Context) = {
206223
val cond = ref(target).select(nme.eq).appliedTo(Literal(Constant(null)))
207224
val exp = ref(target)
208225
val setTarget = exp.becomes(rhs)
209-
val init = Block(List(setTarget), exp)
226+
val setNullables = nullOut(nullables)
227+
val init = Block(setTarget :: setNullables, exp)
210228
If(cond, init, exp)
211229
}
212230

@@ -222,14 +240,14 @@ class LazyVals extends MiniPhase with IdentityDenotTransformer {
222240

223241
val containerTree = ValDef(containerSymbol, defaultValue(tpe))
224242
if (x.tpe.isNotNull && tpe <:< defn.ObjectType) { // can use 'null' value instead of flag
225-
val slowPath = DefDef(x.symbol.asTerm, mkDefNonThreadSafeNonNullable(containerSymbol, x.rhs))
243+
val slowPath = DefDef(x.symbol.asTerm, mkDefNonThreadSafeNonNullable(containerSymbol, x.rhs, nullableFor(x.symbol)))
226244
Thicket(containerTree, slowPath)
227245
}
228246
else {
229247
val flagName = LazyBitMapName.fresh(x.name.asTermName)
230248
val flagSymbol = ctx.newSymbol(x.symbol.owner, flagName, containerFlags | Flags.Private, defn.BooleanType).enteredAfter(this)
231249
val flag = ValDef(flagSymbol, Literal(Constants.Constant(false)))
232-
val slowPath = DefDef(x.symbol.asTerm, mkNonThreadSafeDef(ref(containerSymbol), ref(flagSymbol), x.rhs))
250+
val slowPath = DefDef(x.symbol.asTerm, mkNonThreadSafeDef(ref(containerSymbol), ref(flagSymbol), x.rhs, nullableFor(x.symbol)))
233251
Thicket(containerTree, flag, slowPath)
234252
}
235253
}
@@ -263,10 +281,23 @@ class LazyVals extends MiniPhase with IdentityDenotTransformer {
263281
* result = $target
264282
* }
265283
* }
284+
* nullable = null
266285
* result
267286
* }
268287
*/
269-
def mkThreadSafeDef(methodSymbol: TermSymbol, claz: ClassSymbol, ord: Int, target: Symbol, rhs: Tree, tp: Types.Type, offset: Tree, getFlag: Tree, stateMask: Tree, casFlag: Tree, setFlagState: Tree, waitOnLock: Tree)(implicit ctx: Context) = {
288+
def mkThreadSafeDef(methodSymbol: TermSymbol,
289+
claz: ClassSymbol,
290+
ord: Int,
291+
target: Symbol,
292+
rhs: Tree,
293+
tp: Types.Type,
294+
offset: Tree,
295+
getFlag: Tree,
296+
stateMask: Tree,
297+
casFlag: Tree,
298+
setFlagState: Tree,
299+
waitOnLock: Tree,
300+
nullables: List[Symbol])(implicit ctx: Context) = {
270301
val initState = Literal(Constants.Constant(0))
271302
val computeState = Literal(Constants.Constant(1))
272303
val notifyState = Literal(Constants.Constant(2))
@@ -330,7 +361,8 @@ class LazyVals extends MiniPhase with IdentityDenotTransformer {
330361

331362
val whileBody = List(ref(flagSymbol).becomes(getFlag.appliedTo(thiz, offset)), cases)
332363
val cycle = WhileDo(methodSymbol, whileCond, whileBody)
333-
DefDef(methodSymbol, Block(resultDef :: retryDef :: flagDef :: cycle :: Nil, ref(resultSymbol)))
364+
val setNullables = nullOut(nullables)
365+
DefDef(methodSymbol, Block(resultDef :: retryDef :: flagDef :: cycle :: setNullables, ref(resultSymbol)))
334366
}
335367

336368
def transformMemberDefVolatile(x: ValOrDefDef)(implicit ctx: Context) = {
@@ -391,7 +423,7 @@ class LazyVals extends MiniPhase with IdentityDenotTransformer {
391423
val state = Select(ref(helperModule), lazyNme.RLazyVals.state)
392424
val cas = Select(ref(helperModule), lazyNme.RLazyVals.cas)
393425

394-
val accessor = mkThreadSafeDef(x.symbol.asTerm, claz, ord, containerSymbol, x.rhs, tpe, offset, getFlag, state, cas, setFlag, wait)
426+
val accessor = mkThreadSafeDef(x.symbol.asTerm, claz, ord, containerSymbol, x.rhs, tpe, offset, getFlag, state, cas, setFlag, wait, nullableFor(x.symbol))
395427
if (flag eq EmptyTree)
396428
Thicket(containerTree, accessor)
397429
else Thicket(containerTree, flag, accessor)

0 commit comments

Comments
 (0)