Skip to content

Refactor init checker: Extract reusable code #16705

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jan 29, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
185 changes: 185 additions & 0 deletions compiler/src/dotty/tools/dotc/transform/init/Cache.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
package dotty.tools.dotc
package transform
package init

import core.*
import Contexts.*

import ast.tpd
import tpd.Tree

/** The co-inductive cache used for analysis
*
* The cache contains two maps from `(Config, Tree)` to `Res`:
*
* - input cache (`this.last`)
* - output cache (`this.current`)
*
* The two caches are required because we want to make sure in a new iteration,
* an expression is evaluated exactly once. The monotonicity of the analysis
* ensures that the cache state goes up the lattice of the abstract domain,
* consequently the algorithm terminates.
*
* The general skeleton for usage of the cache is as follows
*
* def analysis(entryExp: Expr) = {
* def iterate(entryExp: Expr)(using Cache) =
* eval(entryExp, initConfig)
* if cache.hasChanged && noErrors then
* cache.last = cache.current
* cache.current = Empty
* cache.changed = false
* iterate(entryExp)
* else
* reportErrors
*
*
* def eval(expr: Expr, config: Config)(using Cache) =
* cache.cachedEval(config, expr) {
* // Actual recursive evaluation of expression.
* //
* // Only executed if the entry `(exp, config)` is not in the output cache.
* }
*
* iterate(entryExp)(using new Cache)
* }
*
* See the documentation for the method `Cache.cachedEval` for more information.
*
* What goes to the configuration (`Config`) and what goes to the result (`Res`)
* need to be decided by the specific analysis and justified by reasoning about
* soundness.
*
* @param Config The analysis state that matters for evaluating an expression.
* @param Res The result from the evaluation the given expression.
*/
class Cache[Config, Res]:
import Cache.*

/** The cache for expression values from last iteration */
protected var last: ExprValueCache[Config, Res] = Map.empty

/** The output cache for expression values
*
* The output cache is computed based on the cache values `last` from the
* last iteration.
*
* Both `last` and `current` are required to make sure an encountered
* expression is evaluated once in each iteration.
*/
protected var current: ExprValueCache[Config, Res] = Map.empty

/** Whether the current heap is different from the last heap?
*
* `changed == false` implies that the fixed point has been reached.
*/
protected var changed: Boolean = false
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is changed part of the cache? Is the heap now intended to be independent of the cache? Should changed and hasChanged be renamed to indicate that they are about the heap?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, changed is part of the cache, it's not related to the heap.


/** Used to avoid allocation, its state does not matter */
protected given MutableTreeWrapper = new MutableTreeWrapper

def get(config: Config, expr: Tree): Option[Res] =
current.get(config, expr)

/** Evaluate an expression with cache
*
* The algorithmic skeleton is as follows:
*
* if this.current.contains(config, expr) then
* return cached value
* else
* val assumed = this.last(config, expr) or bottom value if absent
* this.current(config, expr) = assumed
* val actual = eval(exp)
*
* if assumed != actual then
* this.changed = true
* this.current(config, expr) = actual
*
*/
def cachedEval(config: Config, expr: Tree, cacheResult: Boolean, default: Res)(eval: Tree => Res): Res =
this.get(config, expr) match
case Some(value) => value
case None =>
val assumeValue: Res =
this.last.get(config, expr) match
case Some(value) => value
case None =>
this.last = this.last.updatedNested(config, expr, default)
default

this.current = this.current.updatedNested(config, expr, assumeValue)

val actual = eval(expr)
if actual != assumeValue then
// println("Changed! from = " + assumeValue + ", to = " + actual)
this.changed = true
// TODO: respect cacheResult to reduce cache size
this.current = this.current.updatedNested(config, expr, actual)
// this.current = this.current.removed(config, expr)
end if

actual
end cachedEval

def hasChanged = changed

/** Prepare cache for the next iteration
*
* 1. Reset changed flag.
*
* 2. Use current cache as last cache and set current cache to be empty.
*/
def prepareForNextIteration()(using Context) =
this.changed = false
this.last = this.current
this.current = Map.empty
end Cache

object Cache:
type ExprValueCache[Config, Res] = Map[Config, Map[TreeWrapper, Res]]

/** A wrapper for trees for storage in maps based on referential equality of trees. */
abstract class TreeWrapper:
def tree: Tree

override final def equals(other: Any): Boolean =
other match
case that: TreeWrapper => this.tree eq that.tree
case _ => false

override final def hashCode = tree.hashCode

/** The immutable wrapper is intended to be stored as key in the heap. */
class ImmutableTreeWrapper(val tree: Tree) extends TreeWrapper

/** For queries on the heap, reuse the same wrapper to avoid unnecessary allocation.
*
* A `MutableTreeWrapper` is only ever used temporarily for querying a map,
* and is never inserted to the map.
*/
class MutableTreeWrapper extends TreeWrapper:
var queryTree: Tree | Null = null
def tree: Tree = queryTree match
case tree: Tree => tree
case null => ???

extension [Config, Res](cache: ExprValueCache[Config, Res])
def get(config: Config, expr: Tree)(using queryWrapper: MutableTreeWrapper): Option[Res] =
queryWrapper.queryTree = expr
cache.get(config).flatMap(_.get(queryWrapper))

def removed(config: Config, expr: Tree)(using queryWrapper: MutableTreeWrapper) =
queryWrapper.queryTree = expr
val innerMap2 = cache(config).removed(queryWrapper)
cache.updated(config, innerMap2)

def updatedNested(config: Config, expr: Tree, result: Res): ExprValueCache[Config, Res] =
val wrapper = new ImmutableTreeWrapper(expr)
updatedNestedWrapper(config, wrapper, result)

def updatedNestedWrapper(config: Config, wrapper: ImmutableTreeWrapper, result: Res): ExprValueCache[Config, Res] =
val innerMap = cache.getOrElse(config, Map.empty[TreeWrapper, Res])
val innerMap2 = innerMap.updated(wrapper, result)
cache.updated(config, innerMap2)
end extension
71 changes: 13 additions & 58 deletions compiler/src/dotty/tools/dotc/transform/init/Errors.scala
Original file line number Diff line number Diff line change
Expand Up @@ -5,106 +5,61 @@ package init

import ast.tpd._
import core._
import util.SourcePosition
import util.Property
import Decorators._, printing.SyntaxHighlighting
import util.SourcePosition
import Types._, Symbols._, Contexts._

import scala.collection.mutable
import Trace.Trace

object Errors:
private val IsFromPromotion = new Property.Key[Boolean]

sealed trait Error:
def trace: Seq[Tree]
def trace: Trace
def show(using Context): String

def pos(using Context): SourcePosition = trace.last.sourcePos
def pos(using Context): SourcePosition = Trace.position(using trace).sourcePos

def stacktrace(using Context): String =
val preamble: String =
if ctx.property(IsFromPromotion).nonEmpty
then " Promotion trace:\n"
else " Calling trace:\n"
buildStacktrace(trace, preamble)
Trace.buildStacktrace(trace, preamble)

def issue(using Context): Unit =
report.warning(show, this.pos)
end Error

def buildStacktrace(trace: Seq[Tree], preamble: String)(using Context): String = if trace.isEmpty then "" else preamble + {
var lastLineNum = -1
var lines: mutable.ArrayBuffer[String] = new mutable.ArrayBuffer
trace.foreach { tree =>
val pos = tree.sourcePos
val prefix = "-> "
val line =
if pos.source.exists then
val loc = "[ " + pos.source.file.name + ":" + (pos.line + 1) + " ]"
val code = SyntaxHighlighting.highlight(pos.lineContent.trim.nn)
i"$code\t$loc"
else
tree.show
val positionMarkerLine =
if pos.exists && pos.source.exists then
positionMarker(pos)
else ""

// always use the more precise trace location
if lastLineNum == pos.line then
lines.dropRightInPlace(1)

lines += (prefix + line + "\n" + positionMarkerLine)

lastLineNum = pos.line
}
val sb = new StringBuilder
for line <- lines do sb.append(line)
sb.toString
}

/** Used to underline source positions in the stack trace
* pos.source must exist
*/
private def positionMarker(pos: SourcePosition): String =
val trimmed = pos.lineContent.takeWhile(c => c.isWhitespace).length
val padding = pos.startColumnPadding.substring(trimmed).nn + " "
val carets =
if (pos.startLine == pos.endLine)
"^" * math.max(1, pos.endColumn - pos.startColumn)
else "^"

s"$padding$carets\n"

override def toString() = this.getClass.getName.nn

/** Access non-initialized field */
case class AccessNonInit(field: Symbol)(val trace: Seq[Tree]) extends Error:
def source: Tree = trace.last
case class AccessNonInit(field: Symbol)(val trace: Trace) extends Error:
def source: Tree = Trace.position(using trace)
def show(using Context): String =
"Access non-initialized " + field.show + "." + stacktrace

override def pos(using Context): SourcePosition = field.sourcePos

/** Promote a value under initialization to fully-initialized */
case class PromoteError(msg: String)(val trace: Seq[Tree]) extends Error:
case class PromoteError(msg: String)(val trace: Trace) extends Error:
def show(using Context): String = msg + stacktrace

case class AccessCold(field: Symbol)(val trace: Seq[Tree]) extends Error:
case class AccessCold(field: Symbol)(val trace: Trace) extends Error:
def show(using Context): String =
"Access field " + field.show + " on a cold object." + stacktrace

case class CallCold(meth: Symbol)(val trace: Seq[Tree]) extends Error:
case class CallCold(meth: Symbol)(val trace: Trace) extends Error:
def show(using Context): String =
"Call method " + meth.show + " on a cold object." + stacktrace

case class CallUnknown(meth: Symbol)(val trace: Seq[Tree]) extends Error:
case class CallUnknown(meth: Symbol)(val trace: Trace) extends Error:
def show(using Context): String =
val prefix = if meth.is(Flags.Method) then "Calling the external method " else "Accessing the external field"
prefix + meth.show + " may cause initialization errors." + stacktrace

/** Promote a value under initialization to fully-initialized */
case class UnsafePromotion(msg: String, error: Error)(val trace: Seq[Tree]) extends Error:
case class UnsafePromotion(msg: String, error: Error)(val trace: Trace) extends Error:
def show(using Context): String =
msg + stacktrace + "\n" +
"Promoting the value to hot (transitively initialized) failed due to the following problem:\n" + {
Expand All @@ -116,7 +71,7 @@ object Errors:
*
* Invariant: argsIndices.nonEmpty
*/
case class UnsafeLeaking(error: Error, nonHotOuterClass: Symbol, argsIndices: List[Int])(val trace: Seq[Tree]) extends Error:
case class UnsafeLeaking(error: Error, nonHotOuterClass: Symbol, argsIndices: List[Int])(val trace: Trace) extends Error:
def show(using Context): String =
"Problematic object instantiation: " + argumentInfo() + stacktrace + "\n" +
"It leads to the following error during object initialization:\n" +
Expand Down
Loading