Skip to content

Commit 48f7b97

Browse files
committed
Refactor: move heap from ThisRef for sharing more information
1 parent f4aa49b commit 48f7b97

File tree

2 files changed

+67
-49
lines changed

2 files changed

+67
-49
lines changed

compiler/src/dotty/tools/dotc/transform/init/Checker.scala

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,8 +63,9 @@ class Checker extends MiniPhase {
6363

6464
import semantic._
6565
val tpl = tree.rhs.asInstanceOf[Template]
66-
val thisRef = ThisRef(cls)(fields = mutable.Map.empty)
67-
val res = eval(tpl, thisRef, cls)(using ctx, Vector.empty)
66+
val thisRef = ThisRef(cls)
67+
val heap = Objekt(fields = mutable.Map.empty)
68+
val res = eval(tpl, thisRef, cls)(using heap, ctx, Vector.empty)
6869
res.errors.foreach(_.issue)
6970
}
7071

compiler/src/dotty/tools/dotc/transform/init/Semantic.scala

Lines changed: 64 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -38,13 +38,8 @@ class Semantic {
3838
case object Cold extends Value
3939

4040
/** Object referred by `this` which stores abstract values for all fields
41-
*
42-
* Note: the mutable `fields` plays the role of heap. Thanks to monotonicity
43-
* of the heap, we may handle it in a simple way.
4441
*/
45-
case class ThisRef(klass: ClassSymbol)(val fields: mutable.Map[Symbol, Value]) extends Value {
46-
var allFieldsInitialized: Boolean = false
47-
}
42+
case class ThisRef(klass: ClassSymbol) extends Value
4843

4944
/** An object with all fields initialized but reaches objects under initialization
5045
*
@@ -61,6 +56,21 @@ class Semantic {
6156
*/
6257
case class RefSet(refs: List[Warm | Fun | ThisRef]) extends Value
6358

59+
/** The current object under initialization
60+
*/
61+
case class Objekt(val fields: mutable.Map[Symbol, Value]) {
62+
var allFieldsInitialized: Boolean = false
63+
}
64+
65+
/** Abstract heap stores abstract objects
66+
*
67+
* As in the OOPSLA paper, the abstract heap is monotonistic.
68+
*
69+
* This is only one object we need to care about, hence it's just `Objekt`.
70+
*/
71+
type Heap = Objekt
72+
def heap(using h: Heap): Heap = h
73+
6474
/** Interpreter configuration
6575
*
6676
* The (abstract) interpreter can be seen as a push-down automaton
@@ -103,28 +113,30 @@ class Semantic {
103113

104114
def +(error: Error): Result = this.copy(errors = this.errors :+ error)
105115

106-
def ensureHot(msg: String, source: Tree)(using Context, Trace): Result =
116+
def ensureHot(msg: String, source: Tree): Contextual[Result] =
107117
this ++ value.promote(msg, source)
108118

109-
def select(f: Symbol, source: Tree)(using Context, Trace): Result =
119+
def select(f: Symbol, source: Tree): Contextual[Result] =
110120
value.select(f, source) ++ errors
111121

112-
def call(meth: Symbol, superType: Type, source: Tree)(using Context, Trace): Result =
122+
def call(meth: Symbol, superType: Type, source: Tree): Contextual[Result] =
113123
value.call(meth, superType, source) ++ errors
114124

115-
def instantiate(klass: ClassSymbol, ctor: Symbol, source: Tree)(using Context, Trace): Result =
125+
def instantiate(klass: ClassSymbol, ctor: Symbol, source: Tree): Contextual[Result] =
116126
value.instantiate(klass, ctor, source) ++ errors
117127
}
118128

129+
/** The state that threads through the interpreter */
130+
type Contextual[T] = (Heap, Context, Trace) ?=> T
131+
119132
// ----- Error Handling -----------------------------------
120133
type Trace = Vector[Tree]
121-
122-
val noErrors = Nil
134+
def trace(using t: Trace): Trace = t
123135

124136
extension (trace: Trace)
125137
def add(node: Tree): Trace = trace :+ node
126138

127-
def trace(using t: Trace): Trace = t
139+
val noErrors = Nil
128140

129141
// ----- Operations on domains -----------------------------
130142
extension (a: Value)
@@ -149,7 +161,7 @@ class Semantic {
149161
else values.reduce { (v1, v2) => v1.join(v2) }
150162

151163
extension (value: Value)
152-
def select(field: Symbol, source: Tree)(using Context, Trace): Result =
164+
def select(field: Symbol, source: Tree): Contextual[Result] =
153165
value match {
154166
case Hot =>
155167
Result(Hot, noErrors)
@@ -161,8 +173,8 @@ class Semantic {
161173
case thisRef: ThisRef =>
162174
val target = resolve(thisRef.klass, field)
163175
if target.is(Flags.Lazy) then value.call(target, superType = NoType, source)
164-
else if thisRef.fields.contains(target) then
165-
Result(thisRef.fields(target), Nil)
176+
else if heap.fields.contains(target) then
177+
Result(heap.fields(target), Nil)
166178
else
167179
val error = AccessNonInit(target, trace.add(source))
168180
Result(Hot, error :: Nil)
@@ -186,7 +198,7 @@ class Semantic {
186198
Result(value2, errors)
187199
}
188200

189-
def call(meth: Symbol, superType: Type, source: Tree)(using Context, Trace): Result =
201+
def call(meth: Symbol, superType: Type, source: Tree): Contextual[Result] =
190202
value match {
191203
case Hot =>
192204
Result(Hot, noErrors)
@@ -216,8 +228,8 @@ class Semantic {
216228
else
217229
val error = CallUnknown(target, source, trace)
218230
Result(Hot, error :: Nil)
219-
else if thisRef.fields.contains(target) then
220-
Result(thisRef.fields(target), Nil)
231+
else if heap.fields.contains(target) then
232+
Result(heap.fields(target), Nil)
221233
else
222234
val error = AccessNonInit(target, trace.add(source))
223235
Result(Hot, error :: Nil)
@@ -257,7 +269,7 @@ class Semantic {
257269
Result(value2, errors)
258270
}
259271

260-
def instantiate(klass: ClassSymbol, ctor: Symbol, source: Tree)(using Context, Trace): Result =
272+
def instantiate(klass: ClassSymbol, ctor: Symbol, source: Tree): Contextual[Result] =
261273
value match {
262274
case Hot =>
263275
Result(Hot, noErrors)
@@ -289,8 +301,17 @@ class Semantic {
289301
}
290302
end extension
291303

304+
extension (ref: ThisRef | Warm)
305+
def updateField(field: Symbol, value: Value): Contextual[Unit] =
306+
ref match
307+
case thisRef: ThisRef => heap.fields(field) = value
308+
case warm: Warm => // ignore
309+
end extension
310+
311+
// ----- Promotion ----------------------------------------------------
312+
292313
extension (value: Value)
293-
def canDirectlyPromote(using Context): Boolean =
314+
def canDirectlyPromote(using Heap, Context): Boolean =
294315
value match
295316
case Hot => true
296317
case Cold => false
@@ -299,13 +320,13 @@ class Semantic {
299320
warm.outer.canDirectlyPromote
300321

301322
case thisRef: ThisRef =>
302-
thisRef.allFieldsInitialized || {
323+
heap.allFieldsInitialized || {
303324
// If we have all fields initialized, then we can promote This to hot.
304-
thisRef.allFieldsInitialized = thisRef.klass.appliedRef.fields.forall { denot =>
325+
heap.allFieldsInitialized = thisRef.klass.appliedRef.fields.forall { denot =>
305326
val sym = denot.symbol
306-
sym.isOneOf(Flags.Lazy | Flags.Deferred) || thisRef.fields.contains(sym)
327+
sym.isOneOf(Flags.Lazy | Flags.Deferred) || heap.fields.contains(sym)
307328
}
308-
thisRef.allFieldsInitialized
329+
heap.allFieldsInitialized
309330
}
310331

311332
case fun: Fun => false
@@ -316,7 +337,7 @@ class Semantic {
316337
end canDirectlyPromote
317338

318339
/** Promotion of values to hot */
319-
def promote(msg: String, source: Tree)(using Context, Trace): List[Error] =
340+
def promote(msg: String, source: Tree): Contextual[List[Error]] =
320341
value match
321342
case Hot => Nil
322343

@@ -354,7 +375,7 @@ class Semantic {
354375
* promote the field value
355376
*
356377
*/
357-
def tryPromote(msg: String, source: Tree)(using Context, Trace): List[Error] =
378+
def tryPromote(msg: String, source: Tree): Contextual[List[Error]] = log("promote " + warm.show, printer) {
358379
val classRef = warm.klass.appliedRef
359380
if classRef.memberClasses.nonEmpty then
360381
return PromoteWarm(source, trace) :: Nil
@@ -381,16 +402,10 @@ class Semantic {
381402

382403
if buffer.isEmpty then Nil
383404
else UnsafePromotion(source, trace, buffer.toList) :: Nil
405+
}
384406

385407
end extension
386408

387-
extension (ref: ThisRef | Warm)
388-
def updateField(field: Symbol, value: Value): Unit =
389-
ref match
390-
case thisRef: ThisRef => thisRef.fields(field) = value
391-
case warm: Warm => // ignore
392-
end extension
393-
394409
// ----- Policies ------------------------------------------------------
395410
extension (value: Warm | ThisRef)
396411
/** Can the method call on `value` be ignored?
@@ -409,7 +424,7 @@ class Semantic {
409424
*
410425
* This method only handles cache logic and delegates the work to `cases`.
411426
*/
412-
def eval(expr: Tree, thisV: Value, klass: ClassSymbol, cacheResult: Boolean = false)(using Context, Trace): Result = log("evaluating " + expr.show + ", this = " + thisV.show, printer, res => res.asInstanceOf[Result].show) {
427+
def eval(expr: Tree, thisV: Value, klass: ClassSymbol, cacheResult: Boolean = false): Contextual[Result] = log("evaluating " + expr.show + ", this = " + thisV.show, printer, res => res.asInstanceOf[Result].show) {
413428
val innerMap = cache.getOrElseUpdate(thisV, new EqHashMap[Tree, Value])
414429
if (innerMap.contains(expr)) Result(innerMap(expr), noErrors)
415430
else {
@@ -425,11 +440,11 @@ class Semantic {
425440
}
426441

427442
/** Evaluate a list of expressions */
428-
def eval(exprs: List[Tree], thisV: Value, klass: ClassSymbol)(using Context, Trace): List[Result] =
443+
def eval(exprs: List[Tree], thisV: Value, klass: ClassSymbol): Contextual[List[Result]] =
429444
exprs.map { expr => eval(expr, thisV, klass) }
430445

431446
/** Evaluate arguments of methods */
432-
def evalArgs(args: List[Arg], thisV: Value, klass: ClassSymbol)(using Context, Trace): List[Error] =
447+
def evalArgs(args: List[Arg], thisV: Value, klass: ClassSymbol): Contextual[List[Error]] =
433448
val ress = args.map { arg =>
434449
val res =
435450
if arg.isByName then
@@ -449,7 +464,7 @@ class Semantic {
449464
*
450465
* Note: Recursive call should go to `eval` instead of `cases`.
451466
*/
452-
def cases(expr: Tree, thisV: Value, klass: ClassSymbol)(using Context, Trace): Result =
467+
def cases(expr: Tree, thisV: Value, klass: ClassSymbol): Contextual[Result] =
453468
expr match {
454469
case Ident(nme.WILDCARD) =>
455470
// TODO: disallow `var x: T = _`
@@ -471,15 +486,17 @@ class Semantic {
471486
// check args
472487
val errors = evalArgs(argss.flatten, thisV, klass)
473488

489+
val trace2: Trace = trace.add(expr)
490+
474491
ref match
475492
case Select(supert: Super, _) =>
476493
val SuperType(thisTp, superTp) = supert.tpe
477494
val thisValue2 = resolveThis(thisTp.classSymbol.asClass, thisV, klass)
478-
Result(thisValue2, errors).call(ref.symbol, superTp, expr)(using ctx, trace.add(expr))
495+
Result(thisValue2, errors).call(ref.symbol, superTp, expr)(using heap, ctx, trace2)
479496

480497
case Select(qual, _) =>
481498
val res = eval(qual, thisV, klass) ++ errors
482-
res.call(ref.symbol, superType = NoType, source = expr)(using ctx, trace.add(expr))
499+
res.call(ref.symbol, superType = NoType, source = expr)(using heap, ctx, trace2)
483500

484501
case id: Ident =>
485502
id.tpe match
@@ -491,10 +508,10 @@ class Semantic {
491508
case Hot => Result(Hot, errors)
492509
case _ =>
493510
val rhs = id.symbol.defTree.asInstanceOf[DefDef].rhs
494-
eval(rhs, thisValue2, enclosingClass, cacheResult = true)(using ctx, trace.add(expr))
511+
eval(rhs, thisValue2, enclosingClass, cacheResult = true)(using heap, ctx, trace2)
495512
case TermRef(prefix, _) =>
496513
val res = cases(prefix, thisV, klass, id) ++ errors
497-
res.call(id.symbol, superType = NoType, source = expr)(using ctx, trace.add(expr))
514+
res.call(id.symbol, superType = NoType, source = expr)(using heap, ctx, trace2)
498515

499516
case Select(qualifier, name) =>
500517
eval(qualifier, thisV, klass).select(expr.symbol, expr)
@@ -611,7 +628,7 @@ class Semantic {
611628
}
612629

613630
/** Handle semantics of leaf nodes */
614-
def cases(tp: Type, thisV: Value, klass: ClassSymbol, source: Tree)(using Context, Trace): Result = log("evaluating " + tp.show, printer, res => res.asInstanceOf[Result].show) {
631+
def cases(tp: Type, thisV: Value, klass: ClassSymbol, source: Tree): Contextual[Result] = log("evaluating " + tp.show, printer, res => res.asInstanceOf[Result].show) {
615632
tp match {
616633
case _: ConstantType =>
617634
Result(Hot, noErrors)
@@ -638,7 +655,7 @@ class Semantic {
638655
}
639656

640657
/** Resolve C.this that appear in `klass` */
641-
def resolveThis(target: ClassSymbol, thisV: Value, klass: ClassSymbol)(using Context, Trace): Value = log("resolving " + target.show + ", this = " + thisV.show + " in " + klass.show, printer, res => res.asInstanceOf[Value].show) {
658+
def resolveThis(target: ClassSymbol, thisV: Value, klass: ClassSymbol): Contextual[Value] = log("resolving " + target.show + ", this = " + thisV.show + " in " + klass.show, printer, res => res.asInstanceOf[Value].show) {
642659
if target == klass then thisV
643660
else
644661
thisV match
@@ -660,7 +677,7 @@ class Semantic {
660677
}
661678

662679
/** Compute the outer value that correspond to `tref.prefix` */
663-
def outerValue(tref: TypeRef, thisV: Value, klass: ClassSymbol, source: Tree)(using Context, Trace): Result =
680+
def outerValue(tref: TypeRef, thisV: Value, klass: ClassSymbol, source: Tree): Contextual[Result] =
664681
val cls = tref.classSymbol.asClass
665682
if tref.prefix == NoPrefix then
666683
val enclosing = cls.owner.lexicallyEnclosingClass.asClass
@@ -670,7 +687,7 @@ class Semantic {
670687
cases(tref.prefix, thisV, klass, source)
671688

672689
/** Initialize part of an abstract object in `klass` of the inheritance chain */
673-
def init(tpl: Template, thisV: ThisRef | Warm, klass: ClassSymbol)(using Context, Trace): Result = log("init " + klass.show, printer, res => res.asInstanceOf[Result].show) {
690+
def init(tpl: Template, thisV: ThisRef | Warm, klass: ClassSymbol): Contextual[Result] = log("init " + klass.show, printer, res => res.asInstanceOf[Result].show) {
674691
val errorBuffer = new mutable.ArrayBuffer[Error]
675692

676693
// init param fields
@@ -748,7 +765,7 @@ class Semantic {
748765
*
749766
* This is intended to avoid type soundness issues in Dotty.
750767
*/
751-
def checkTermUsage(tpt: Tree, thisV: Value, klass: ClassSymbol)(using Context, Trace): List[Error] =
768+
def checkTermUsage(tpt: Tree, thisV: Value, klass: ClassSymbol): Contextual[List[Error]] =
752769
val buf = new mutable.ArrayBuffer[Error]
753770
val traverser = new TypeTraverser {
754771
def traverse(tp: Type): Unit = tp match {

0 commit comments

Comments
 (0)