Skip to content

Adding fix point evaluation in initialization checker #13379

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 18 additions & 4 deletions compiler/src/dotty/tools/dotc/transform/init/Checker.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import Phases._


import scala.collection.mutable
import util.EqHashMap


class Checker extends Phase {
Expand Down Expand Up @@ -78,11 +79,24 @@ class Checker extends Phase {

val paramValues = tpl.constr.termParamss.flatten.map(param => param.symbol -> Hot).toMap

given Promoted = Promoted.empty
given Trace = Trace.empty
given Env = Env(paramValues)
// A wrapper for eval that uses two-cache method to compute fix-point after evaluating a class body
def fixPointEval(expr: Tree, thisV: Addr, klass: ClassSymbol, inputCache: evalCache, outputCache: evalCache): Result = {
import semantic.Heap._
given Promoted = Promoted.empty
given Trace = Trace.empty
given Env = Env(paramValues)
given (evalCache, evalCache) = (inputCache, outputCache)
val res = eval(expr, thisV, klass)
if !res.errors.isEmpty then res
else if inputCache.equal(outputCache) then
inputCache.commitEvalCache(thisV)
res
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

During the meeting, @olhotak mentioned it's preferable to use a flag to remember whether something has changed, which should be faster than testing equality of caches.

else
thisV.emptyField()
fixPointEval(expr, thisV, klass, outputCache, new evalCache)
}
Copy link
Contributor

@liufengyun liufengyun Aug 25, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I might have missed something, we mentioned it's preferable to compute fixed points for field initializers & when accessing a missing key of warm objects. Otherwise, it will not integrate well with the heap-level immutable caches.


val res = eval(tpl, thisRef, cls)
val res = fixPointEval(tpl, thisRef, cls, new evalCache, new evalCache)
res.errors.foreach(_.issue)
}

Expand Down
56 changes: 36 additions & 20 deletions compiler/src/dotty/tools/dotc/transform/init/Semantic.scala
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ class Semantic {
*/
def updateField(field: Symbol, value: Value): Contextual[Unit] =
val fields = heap(ref).fields
assert(!fields.contains(field), field.show + " already init, new = " + value + ", ref =" + ref)
// assert(!fields.contains(field), field.show + " already init, new = " + value + ", ref =" + ref)
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The fix-point evaluation needs to update fields that are already initialized, but this line doesn't allow to do that. Could we remove this assertion, or shall we find other ways to empty the heap before re-evaluation?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is an issue that we will have to consider in the overall design, in the theory. The assumption that the heap does not change once set was dependent on a single pass of evaluation. Now that we are doing a fixed-point computation, I don't think this assumption holds anymore.

So we might need to include the heap in the fixed-point evaluation, i.e. in the cache.

More generally, I think this points to a need to discuss the overall design more first before doing the implementation.

Still, the implementation was valuable in that it uncovered this theoretical issue.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

More thoughts on the overall design:

One design that we know is correct (but probably inefficient) is to do the fixed-point computation at the very outermost level. That is, we do one pass of the whole analysis, for all classes. We then clear the cache, clear the heap, clear everything, and do the whole analysis again, except this time, instead of using Hot as a starting point for all unseen expressions, we use the result from the first phase. Repeat until a fixed point.

Maybe we should just implement this correct-but-inefficient strategy first and then look at ways to improve efficiency.

If we do not track any dependences, it is quite easy to see that at least two passes over every expression are necessary: the first eval of any expression could depend on something that turns out to be non-hot (we can't know for sure since we don't track dependences), so to be safe we need to analyze the expression at least once more.

If we decide that two passes over everything are too expensive, then we must look at ways to track dependences. Then, I think our target should be to find cheap, i.e. coarse-grained ways to track dependences. For example, if eval of some expression depends on some field of the heap, it may be cheaper to just record that it depends on the heap in general and redo the eval of that expression than to precisely record the exact set of fields of the heap that it depends on.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can compute fixed points for field initializers and then store the result in the heap --- which follows the paper. Actually, that's one of the main motivations of computing fixed points for all values stored in the heap. We can view the heap as a global cache.

fields(field) = value

/** Update the immediate outer of the given `klass` of the abstract object
Expand All @@ -133,6 +133,11 @@ class Semantic {
*/
def updateOuter(klass: ClassSymbol, value: Value): Contextual[Unit] =
heap(ref).outers(klass) = value

/** Eliminate the field is necessary when computing fix-point
*
*/
def emptyField(): Contextual[Unit] = heap(ref).fields.clear()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure we need this. Breaking invariants in code is not good.

end extension
}
type Heap = Heap.Heap
Expand Down Expand Up @@ -268,7 +273,7 @@ class Semantic {
}

/** The state that threads through the interpreter */
type Contextual[T] = (Env, Context, Trace, Promoted) ?=> T
type Contextual[T] = (Env, (evalCache, evalCache), Context, Trace, Promoted) ?=> T
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In Scala, types need to be capitalized. Meanwhile, opaque types can help here to make the interface nicer.


// ----- Error Handling -----------------------------------

Expand Down Expand Up @@ -337,7 +342,7 @@ class Semantic {
if target.is(Flags.Lazy) then
given Trace = trace1
val rhs = target.defTree.asInstanceOf[ValDef].rhs
eval(rhs, addr, target.owner.asClass, cacheResult = true)
eval(rhs, addr, target.owner.asClass)
else
val obj = heap(addr)
if obj.fields.contains(target) then
Expand All @@ -351,7 +356,7 @@ class Semantic {
Result(Hot, Nil)
else if target.hasSource then
val rhs = target.defTree.asInstanceOf[ValOrDefDef].rhs
eval(rhs, addr, target.owner.asClass, cacheResult = true)
eval(rhs, addr, target.owner.asClass)
else
val error = CallUnknown(field, source, trace.toVector)
Result(Hot, error :: Nil)
Expand Down Expand Up @@ -404,15 +409,15 @@ class Semantic {
if target.isPrimaryConstructor then
given Env = env2
val tpl = cls.defTree.asInstanceOf[TypeDef].rhs.asInstanceOf[Template]
val res = withTrace(trace.add(cls.defTree)) { eval(tpl, addr, cls, cacheResult = true) }
val res = withTrace(trace.add(cls.defTree)) { eval(tpl, addr, cls) }
Result(addr, res.errors)
else if target.isConstructor then
given Env = env2
eval(ddef.rhs, addr, cls, cacheResult = true)
eval(ddef.rhs, addr, cls)
else
// normal method call
withEnv(if isLocal then env else Env.empty) {
eval(ddef.rhs, addr, cls, cacheResult = true) ++ checkArgs
eval(ddef.rhs, addr, cls) ++ checkArgs
}
else if addr.canIgnoreMethodCall(target) then
Result(Hot, Nil)
Expand All @@ -433,7 +438,7 @@ class Semantic {
if meth.name.toString == "tupled" then Result(value, Nil) // a call like `fun.tupled`
else
withEnv(env) {
eval(body, thisV, klass, cacheResult = true) ++ checkArgs
eval(body, thisV, klass) ++ checkArgs
}

case RefSet(refs) =>
Expand Down Expand Up @@ -702,19 +707,30 @@ class Semantic {
*
* This method only handles cache logic and delegates the work to `cases`.
*/
def eval(expr: Tree, thisV: Addr, klass: ClassSymbol, cacheResult: Boolean = false): Contextual[Result] = log("evaluating " + expr.show + ", this = " + thisV.show, printer, res => res.asInstanceOf[Result].show) {

type evalCache = EqHashMap[Tree, Result]
def inputCache(using caches: (evalCache, evalCache)) = caches._1
def outputCache(using caches: (evalCache, evalCache)) = caches._2
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Type definition needs to be capitalized. I think opaque types will make the API nicer.

BTW, let's move domain definitions to the domain section.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The cache key is incorrect: from the paper, it says Expr -> Value -> Value instead of Expr -> Result.

extension (evalcache: evalCache)
def equal(evalcache2: evalCache): Boolean =
evalcache.toSeq.forall((key, result) => evalcache2.contains(key) && evalcache2(key) == result)

def commitEvalCache(thisV: Addr): Unit =
val innerMap = cache.getOrElseUpdate(thisV, new EqHashMap[Tree, Value])
evalcache.toSeq.foreach((key, result) => innerMap(key) = result.value)
end extension

def eval(expr: Tree, thisV: Addr, klass: ClassSymbol): Contextual[Result] = log("evaluating " + expr.show + ", this = " + thisV.show, printer, res => res.asInstanceOf[Result].show) {
val innerMap = cache.getOrElseUpdate(thisV, new EqHashMap[Tree, Value])
if (innerMap.contains(expr)) Result(innerMap(expr), Errors.empty)
else {
// no need to compute fix-point, because
// 1. the result is decided by `cfg` for a legal program
// (heap change is irrelevant thanks to monotonicity)
// 2. errors will have been reported for an illegal program
innerMap(expr) = Hot
if innerMap.contains(expr) then Result(innerMap(expr), Errors.empty)
else if outputCache.contains(expr) then outputCache(expr)
else
// need to compute fix-point for soundness
if !inputCache.contains(expr) then inputCache(expr) = Result(Hot, Errors.empty)
outputCache(expr) = inputCache(expr)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's simpler to write the code as follows:

outputCache(expr) = inputCache.getOrElse(expr, Hot)

val res = cases(expr, thisV, klass)
if cacheResult then innerMap(expr) = res.value else innerMap.remove(expr)
outputCache(expr) = res
res
}
}

/** Evaluate a list of expressions */
Expand Down Expand Up @@ -889,7 +905,7 @@ class Semantic {
case vdef : ValDef =>
// local val definition
// TODO: support explicit @cold annotation for local definitions
eval(vdef.rhs, thisV, klass, cacheResult = true)
eval(vdef.rhs, thisV, klass)

case ddef : DefDef =>
// local method
Expand Down Expand Up @@ -1109,7 +1125,7 @@ class Semantic {
tpl.body.foreach {
case vdef : ValDef if !vdef.symbol.is(Flags.Lazy) && !vdef.rhs.isEmpty =>
given Env = Env.empty
val res = eval(vdef.rhs, thisV, klass, cacheResult = true)
val res = eval(vdef.rhs, thisV, klass)
errorBuffer ++= res.errors
thisV.updateField(vdef.symbol, res.value)
fieldsChanged = true
Expand Down
4 changes: 2 additions & 2 deletions tests/init/neg/enum-desugared.check
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,6 @@
| Cannot prove that the value is fully-initialized. May only use initialized value as method arguments.
|
| The unsafe promotion may cause the following problem:
| Calling the external method method ordinal may cause initialization errors. Calling trace:
| Calling the external method method name may cause initialization errors. Calling trace:
| -> Array(this.LazyErrorId, this.NoExplanationID) // error // error [ enum-desugared.scala:17 ]
| -> def errorNumber: Int = this.ordinal() - 2 [ enum-desugared.scala:8 ]
| -> override def productPrefix: String = this.name() [ enum-desugared.scala:29 ]
2 changes: 1 addition & 1 deletion tests/init/neg/inner-loop.scala
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
class Outer { outer =>
class Inner extends Outer {
val x = 5 + outer.n // error
val x = 5 + outer.n
}
val inner = new Inner
val n = 6 // error
Expand Down
5 changes: 2 additions & 3 deletions tests/init/neg/local-warm4.check
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,5 @@
| -> class A(x: Int) extends Foo(x) { [ local-warm4.scala:6 ]
| -> val b = new B(y) [ local-warm4.scala:10 ]
| -> class B(x: Int) extends A(x) { [ local-warm4.scala:13 ]
| -> class A(x: Int) extends Foo(x) { [ local-warm4.scala:6 ]
| -> increment() [ local-warm4.scala:9 ]
| -> updateA() [ local-warm4.scala:21 ]
| -> if y < 10 then increment() [ local-warm4.scala:23 ]
| -> updateA() [ local-warm4.scala:21 ]
4 changes: 4 additions & 0 deletions tests/init/neg/unsound1.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
-- Error: tests/init/neg/unsound1.scala:2:35 ---------------------------------------------------------------------------
2 | if (m > 0) println(foo(m - 1).a2.n) // error
| ^^^^^^^^^^^^^^^
| Access field A.this.foo(A.this.m.-(1)).a2.n on a value with an unknown initialization status.
11 changes: 11 additions & 0 deletions tests/init/neg/unsound1.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
class A(m: Int) {
if (m > 0) println(foo(m - 1).a2.n) // error
def foo(n: Int): B =
if (n % 2 == 0)
new B(new A(n - 1), foo(n - 1).a1)
else
new B(this, new A(n - 1))
var n: Int = 10
}

class B(val a1: A, val a2: A)
6 changes: 6 additions & 0 deletions tests/init/neg/unsound2.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
-- Error: tests/init/neg/unsound2.scala:5:26 ---------------------------------------------------------------------------
5 | def getN: Int = a.n // error
| ^^^
| Access field B.this.a.n on a value with an unknown initialization status. Calling trace:
| -> println(foo(x).getB) [ unsound2.scala:8 ]
| -> def foo(y: Int): B = if (y > 10) then B(bar(y - 1), foo(y - 1).getN) else B(bar(y), 10) [ unsound2.scala:2 ]
10 changes: 10 additions & 0 deletions tests/init/neg/unsound2.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
case class A(x: Int) {
def foo(y: Int): B = if (y > 10) then B(bar(y - 1), foo(y - 1).getN) else B(bar(y), 10)
def bar(y: Int): A = if (y > 10) then A(y - 1) else this
class B(a: A, b: Int) {
def getN: Int = a.n // error
def getB: Int = b
}
println(foo(x).getB)
val n: Int = 10
}
5 changes: 5 additions & 0 deletions tests/init/neg/unsound3.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
-- Error: tests/init/neg/unsound3.scala:10:38 --------------------------------------------------------------------------
10 | if (x < 12) then foo().getC().b else newB // error
| ^^^^^^^^^^^^^^
| Access field C.this.foo().getC().b on a value with an unknown initialization status. Calling trace:
| -> val b = foo() [ unsound3.scala:12 ]
13 changes: 13 additions & 0 deletions tests/init/neg/unsound3.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
class B(c: C) {
def getC() = c
}

class C {
var x = 10
def foo(): B = {
x += 1
val newB = new B(this)
if (x < 12) then foo().getC().b else newB // error
}
val b = foo()
}
6 changes: 6 additions & 0 deletions tests/init/neg/unsound4.check
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
-- Error: tests/init/neg/unsound4.scala:3:8 ----------------------------------------------------------------------------
3 | val aAgain = foo(5) // error
| ^
| Access non-initialized value aAgain. Calling trace:
| -> val aAgain = foo(5) // error [ unsound4.scala:3 ]
| -> def foo(x: Int): A = if (x < 5) then this else foo(x - 1).aAgain [ unsound4.scala:2 ]
4 changes: 4 additions & 0 deletions tests/init/neg/unsound4.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
class A {
def foo(x: Int): A = if (x < 5) then this else foo(x - 1).aAgain
val aAgain = foo(5) // error
}