diff --git a/compiler/src/dotty/tools/dotc/transform/init/Cache.scala b/compiler/src/dotty/tools/dotc/transform/init/Cache.scala new file mode 100644 index 000000000000..14a52d995131 --- /dev/null +++ b/compiler/src/dotty/tools/dotc/transform/init/Cache.scala @@ -0,0 +1,185 @@ +package dotty.tools.dotc +package transform +package init + +import core.* +import Contexts.* + +import ast.tpd +import tpd.Tree + +/** The co-inductive cache used for analysis + * + * The cache contains two maps from `(Config, Tree)` to `Res`: + * + * - input cache (`this.last`) + * - output cache (`this.current`) + * + * The two caches are required because we want to make sure in a new iteration, + * an expression is evaluated exactly once. The monotonicity of the analysis + * ensures that the cache state goes up the lattice of the abstract domain, + * consequently the algorithm terminates. + * + * The general skeleton for usage of the cache is as follows + * + * def analysis(entryExp: Expr) = { + * def iterate(entryExp: Expr)(using Cache) = + * eval(entryExp, initConfig) + * if cache.hasChanged && noErrors then + * cache.last = cache.current + * cache.current = Empty + * cache.changed = false + * iterate(entryExp) + * else + * reportErrors + * + * + * def eval(expr: Expr, config: Config)(using Cache) = + * cache.cachedEval(config, expr) { + * // Actual recursive evaluation of expression. + * // + * // Only executed if the entry `(exp, config)` is not in the output cache. + * } + * + * iterate(entryExp)(using new Cache) + * } + * + * See the documentation for the method `Cache.cachedEval` for more information. + * + * What goes to the configuration (`Config`) and what goes to the result (`Res`) + * need to be decided by the specific analysis and justified by reasoning about + * soundness. + * + * @param Config The analysis state that matters for evaluating an expression. + * @param Res The result from the evaluation the given expression. + */ +class Cache[Config, Res]: + import Cache.* + + /** The cache for expression values from last iteration */ + protected var last: ExprValueCache[Config, Res] = Map.empty + + /** The output cache for expression values + * + * The output cache is computed based on the cache values `last` from the + * last iteration. + * + * Both `last` and `current` are required to make sure an encountered + * expression is evaluated once in each iteration. + */ + protected var current: ExprValueCache[Config, Res] = Map.empty + + /** Whether the current heap is different from the last heap? + * + * `changed == false` implies that the fixed point has been reached. + */ + protected var changed: Boolean = false + + /** Used to avoid allocation, its state does not matter */ + protected given MutableTreeWrapper = new MutableTreeWrapper + + def get(config: Config, expr: Tree): Option[Res] = + current.get(config, expr) + + /** Evaluate an expression with cache + * + * The algorithmic skeleton is as follows: + * + * if this.current.contains(config, expr) then + * return cached value + * else + * val assumed = this.last(config, expr) or bottom value if absent + * this.current(config, expr) = assumed + * val actual = eval(exp) + * + * if assumed != actual then + * this.changed = true + * this.current(config, expr) = actual + * + */ + def cachedEval(config: Config, expr: Tree, cacheResult: Boolean, default: Res)(eval: Tree => Res): Res = + this.get(config, expr) match + case Some(value) => value + case None => + val assumeValue: Res = + this.last.get(config, expr) match + case Some(value) => value + case None => + this.last = this.last.updatedNested(config, expr, default) + default + + this.current = this.current.updatedNested(config, expr, assumeValue) + + val actual = eval(expr) + if actual != assumeValue then + // println("Changed! from = " + assumeValue + ", to = " + actual) + this.changed = true + // TODO: respect cacheResult to reduce cache size + this.current = this.current.updatedNested(config, expr, actual) + // this.current = this.current.removed(config, expr) + end if + + actual + end cachedEval + + def hasChanged = changed + + /** Prepare cache for the next iteration + * + * 1. Reset changed flag. + * + * 2. Use current cache as last cache and set current cache to be empty. + */ + def prepareForNextIteration()(using Context) = + this.changed = false + this.last = this.current + this.current = Map.empty +end Cache + +object Cache: + type ExprValueCache[Config, Res] = Map[Config, Map[TreeWrapper, Res]] + + /** A wrapper for trees for storage in maps based on referential equality of trees. */ + abstract class TreeWrapper: + def tree: Tree + + override final def equals(other: Any): Boolean = + other match + case that: TreeWrapper => this.tree eq that.tree + case _ => false + + override final def hashCode = tree.hashCode + + /** The immutable wrapper is intended to be stored as key in the heap. */ + class ImmutableTreeWrapper(val tree: Tree) extends TreeWrapper + + /** For queries on the heap, reuse the same wrapper to avoid unnecessary allocation. + * + * A `MutableTreeWrapper` is only ever used temporarily for querying a map, + * and is never inserted to the map. + */ + class MutableTreeWrapper extends TreeWrapper: + var queryTree: Tree | Null = null + def tree: Tree = queryTree match + case tree: Tree => tree + case null => ??? + + extension [Config, Res](cache: ExprValueCache[Config, Res]) + def get(config: Config, expr: Tree)(using queryWrapper: MutableTreeWrapper): Option[Res] = + queryWrapper.queryTree = expr + cache.get(config).flatMap(_.get(queryWrapper)) + + def removed(config: Config, expr: Tree)(using queryWrapper: MutableTreeWrapper) = + queryWrapper.queryTree = expr + val innerMap2 = cache(config).removed(queryWrapper) + cache.updated(config, innerMap2) + + def updatedNested(config: Config, expr: Tree, result: Res): ExprValueCache[Config, Res] = + val wrapper = new ImmutableTreeWrapper(expr) + updatedNestedWrapper(config, wrapper, result) + + def updatedNestedWrapper(config: Config, wrapper: ImmutableTreeWrapper, result: Res): ExprValueCache[Config, Res] = + val innerMap = cache.getOrElse(config, Map.empty[TreeWrapper, Res]) + val innerMap2 = innerMap.updated(wrapper, result) + cache.updated(config, innerMap2) + end extension diff --git a/compiler/src/dotty/tools/dotc/transform/init/Errors.scala b/compiler/src/dotty/tools/dotc/transform/init/Errors.scala index 7d92d2b2a921..762e029ba36f 100644 --- a/compiler/src/dotty/tools/dotc/transform/init/Errors.scala +++ b/compiler/src/dotty/tools/dotc/transform/init/Errors.scala @@ -5,106 +5,61 @@ package init import ast.tpd._ import core._ -import util.SourcePosition import util.Property -import Decorators._, printing.SyntaxHighlighting +import util.SourcePosition import Types._, Symbols._, Contexts._ -import scala.collection.mutable +import Trace.Trace object Errors: private val IsFromPromotion = new Property.Key[Boolean] sealed trait Error: - def trace: Seq[Tree] + def trace: Trace def show(using Context): String - def pos(using Context): SourcePosition = trace.last.sourcePos + def pos(using Context): SourcePosition = Trace.position(using trace).sourcePos def stacktrace(using Context): String = val preamble: String = if ctx.property(IsFromPromotion).nonEmpty then " Promotion trace:\n" else " Calling trace:\n" - buildStacktrace(trace, preamble) + Trace.buildStacktrace(trace, preamble) def issue(using Context): Unit = report.warning(show, this.pos) end Error - def buildStacktrace(trace: Seq[Tree], preamble: String)(using Context): String = if trace.isEmpty then "" else preamble + { - var lastLineNum = -1 - var lines: mutable.ArrayBuffer[String] = new mutable.ArrayBuffer - trace.foreach { tree => - val pos = tree.sourcePos - val prefix = "-> " - val line = - if pos.source.exists then - val loc = "[ " + pos.source.file.name + ":" + (pos.line + 1) + " ]" - val code = SyntaxHighlighting.highlight(pos.lineContent.trim.nn) - i"$code\t$loc" - else - tree.show - val positionMarkerLine = - if pos.exists && pos.source.exists then - positionMarker(pos) - else "" - - // always use the more precise trace location - if lastLineNum == pos.line then - lines.dropRightInPlace(1) - - lines += (prefix + line + "\n" + positionMarkerLine) - - lastLineNum = pos.line - } - val sb = new StringBuilder - for line <- lines do sb.append(line) - sb.toString - } - - /** Used to underline source positions in the stack trace - * pos.source must exist - */ - private def positionMarker(pos: SourcePosition): String = - val trimmed = pos.lineContent.takeWhile(c => c.isWhitespace).length - val padding = pos.startColumnPadding.substring(trimmed).nn + " " - val carets = - if (pos.startLine == pos.endLine) - "^" * math.max(1, pos.endColumn - pos.startColumn) - else "^" - - s"$padding$carets\n" - override def toString() = this.getClass.getName.nn /** Access non-initialized field */ - case class AccessNonInit(field: Symbol)(val trace: Seq[Tree]) extends Error: - def source: Tree = trace.last + case class AccessNonInit(field: Symbol)(val trace: Trace) extends Error: + def source: Tree = Trace.position(using trace) def show(using Context): String = "Access non-initialized " + field.show + "." + stacktrace override def pos(using Context): SourcePosition = field.sourcePos /** Promote a value under initialization to fully-initialized */ - case class PromoteError(msg: String)(val trace: Seq[Tree]) extends Error: + case class PromoteError(msg: String)(val trace: Trace) extends Error: def show(using Context): String = msg + stacktrace - case class AccessCold(field: Symbol)(val trace: Seq[Tree]) extends Error: + case class AccessCold(field: Symbol)(val trace: Trace) extends Error: def show(using Context): String = "Access field " + field.show + " on a cold object." + stacktrace - case class CallCold(meth: Symbol)(val trace: Seq[Tree]) extends Error: + case class CallCold(meth: Symbol)(val trace: Trace) extends Error: def show(using Context): String = "Call method " + meth.show + " on a cold object." + stacktrace - case class CallUnknown(meth: Symbol)(val trace: Seq[Tree]) extends Error: + case class CallUnknown(meth: Symbol)(val trace: Trace) extends Error: def show(using Context): String = val prefix = if meth.is(Flags.Method) then "Calling the external method " else "Accessing the external field" prefix + meth.show + " may cause initialization errors." + stacktrace /** Promote a value under initialization to fully-initialized */ - case class UnsafePromotion(msg: String, error: Error)(val trace: Seq[Tree]) extends Error: + case class UnsafePromotion(msg: String, error: Error)(val trace: Trace) extends Error: def show(using Context): String = msg + stacktrace + "\n" + "Promoting the value to hot (transitively initialized) failed due to the following problem:\n" + { @@ -116,7 +71,7 @@ object Errors: * * Invariant: argsIndices.nonEmpty */ - case class UnsafeLeaking(error: Error, nonHotOuterClass: Symbol, argsIndices: List[Int])(val trace: Seq[Tree]) extends Error: + case class UnsafeLeaking(error: Error, nonHotOuterClass: Symbol, argsIndices: List[Int])(val trace: Trace) extends Error: def show(using Context): String = "Problematic object instantiation: " + argumentInfo() + stacktrace + "\n" + "It leads to the following error during object initialization:\n" + diff --git a/compiler/src/dotty/tools/dotc/transform/init/Semantic.scala b/compiler/src/dotty/tools/dotc/transform/init/Semantic.scala index eb1692e00a12..286e3a124d12 100644 --- a/compiler/src/dotty/tools/dotc/transform/init/Semantic.scala +++ b/compiler/src/dotty/tools/dotc/transform/init/Semantic.scala @@ -15,10 +15,19 @@ import config.Printers.init as printer import reporting.trace as log import Errors.* +import Trace.* +import Util.* +import Cache.* import scala.collection.mutable import scala.annotation.tailrec +/** + * Checks safe initialization of objects + * + * This algorithm cannot handle safe access of global object names. That part + * is handled by the check in `Objects` (@see Objects). + */ object Semantic: // ----- Domain definitions -------------------------------- @@ -117,7 +126,7 @@ object Semantic: assert(!populatingParams, "the object is already populating parameters") populatingParams = true val tpl = klass.defTree.asInstanceOf[TypeDef].rhs.asInstanceOf[Template] - extendTrace(klass.defTree) { this.callConstructor(ctor, args.map(arg => ArgInfo(arg, trace))) } + extendTrace(klass.defTree) { this.callConstructor(ctor, args.map(arg => new ArgInfo(arg, trace))) } populatingParams = false this } @@ -207,7 +216,7 @@ object Semantic: object Cache: /** Cache for expressions * - * Ref -> Tree -> Value + * Value -> Tree -> Value * * The first key is the value of `this` for the expression. * @@ -233,66 +242,27 @@ object Semantic: * that could be reused to check other classes. We employ this trick to * improve performance of the analysis. */ - private type ExprValueCache = Map[Value, Map[TreeWrapper, Value]] /** The heap for abstract objects * - * The heap objects are immutable. - */ - private type Heap = Map[Ref, Objekt] - - /** A wrapper for trees for storage in maps based on referential equality of trees. */ - private abstract class TreeWrapper: - def tree: Tree - - override final def equals(other: Any): Boolean = - other match - case that: TreeWrapper => this.tree eq that.tree - case _ => false - - override final def hashCode = tree.hashCode - - /** The immutable wrapper is intended to be stored as key in the heap. */ - private class ImmutableTreeWrapper(val tree: Tree) extends TreeWrapper - - /** For queries on the heap, reuse the same wrapper to avoid unnecessary allocation. + * The heap objects are immutable and its values are essentially derived + * from the cache, thus they are not part of the configuration. * - * A `MutableTreeWrapper` is only ever used temporarily for querying a map, - * and is never inserted to the map. + * The only exception is the object correspond to `ThisRef`, where the + * object remembers the set of initialized fields. That information is reset + * in each iteration thus is harmless. */ - private class MutableTreeWrapper extends TreeWrapper: - var queryTree: Tree | Null = null - def tree: Tree = queryTree match - case tree: Tree => tree - case null => ??? - - class Cache: - /** The cache for expression values from last iteration */ - private var last: ExprValueCache = Map.empty + private type Heap = Map[Ref, Objekt] - /** The output cache for expression values - * - * The output cache is computed based on the cache values `last` from the - * last iteration. - * - * Both `last` and `current` are required to make sure an encountered - * expression is evaluated once in each iteration. - */ - private var current: ExprValueCache = Map.empty + class Data extends Cache[Value, Value]: /** Global cached values for expressions * * The values are only added when a fixed point is reached. * * It is intended to improve performance for computation related to warm values. */ - private var stable: ExprValueCache = Map.empty - - /** Whether the current heap is different from the last heap? - * - * `changed == false` implies that the fixed point has been reached. - */ - private var changed: Boolean = false + private var stable: ExprValueCache[Value, Value] = Map.empty /** Abstract heap stores abstract objects * @@ -320,77 +290,38 @@ object Semantic: /** Used to revert heap to last stable heap. */ private var heapStable: Heap = Map.empty - /** Used to avoid allocation, its state does not matter */ - private given MutableTreeWrapper = new MutableTreeWrapper - - def hasChanged = changed - - def get(value: Value, expr: Tree): Option[Value] = - current.get(value, expr) match - case None => stable.get(value, expr) + override def get(value: Value, expr: Tree): Option[Value] = + stable.get(value, expr) match + case None => super.get(value, expr) case res => res /** Backup the state of the cache * * All the shared data structures must be immutable. */ - def backup(): Cache = - val cache = new Cache - cache.last = this.last - cache.current = this.current + def backup(): Data = + val cache = new Data cache.stable = this.stable cache.heap = this.heap cache.heapStable = this.heapStable cache.changed = this.changed + cache.last = this.last + cache.current = this.current cache /** Restore state from a backup */ - def restore(cache: Cache) = + def restore(cache: Data) = + this.changed = cache.changed this.last = cache.last this.current = cache.current this.stable = cache.stable this.heap = cache.heap this.heapStable = cache.heapStable - this.changed = cache.changed - - /** Copy the value of `(value, expr)` from the last cache to the current cache - * - * It assumes the value is `Hot` if it doesn't exist in the last cache. - * - * It updates the current caches if the values change. - * - * The two caches are required because we want to make sure in a new iteration, an expression is evaluated once. - */ - def assume(value: Value, expr: Tree, cacheResult: Boolean)(fun: => Value): Contextual[Value] = - val assumeValue: Value = - last.get(value, expr) match - case Some(value) => value - case None => - this.last = last.updatedNested(value, expr, Hot) - Hot - - this.current = current.updatedNested(value, expr, assumeValue) - - val actual = fun - if actual != assumeValue then - this.changed = true - this.current = this.current.updatedNested(value, expr, actual) - else - // It's tempting to cache the value in stable, but it's unsound. - // The reason is that the current value may depend on other values - // which might change. - // - // stable.put(value, expr, actual) - () - end if - - actual - end assume /** Commit current cache to stable cache. */ private def commitToStableCache() = for - (v, m) <- current + (v, m) <- this.current if v.isWarm // It's useless to cache value for ThisRef. (wrapper, res) <- m do @@ -404,10 +335,8 @@ object Semantic: * * 3. Revert heap to stable. */ - def prepareForNextIteration()(using Context) = - this.changed = false - this.last = this.current - this.current = Map.empty + override def prepareForNextIteration()(using Context) = + super.prepareForNextIteration() this.heap = this.heapStable /** Prepare for checking next class @@ -421,15 +350,15 @@ object Semantic: * 4. Reset last cache. */ def prepareForNextClass()(using Context) = - if this.changed then - this.changed = false + if this.hasChanged then this.heap = this.heapStable else this.commitToStableCache() this.heapStable = this.heap - this.last = Map.empty - this.current = Map.empty + // reset changed and cache + super.prepareForNextIteration() + def updateObject(ref: Ref, obj: Objekt) = assert(!this.heapStable.contains(ref)) @@ -438,59 +367,19 @@ object Semantic: def containsObject(ref: Ref) = heap.contains(ref) def getObject(ref: Ref) = heap(ref) - end Cache - - extension (cache: ExprValueCache) - private def get(value: Value, expr: Tree)(using queryWrapper: MutableTreeWrapper): Option[Value] = - queryWrapper.queryTree = expr - cache.get(value).flatMap(_.get(queryWrapper)) - - private def removed(value: Value, expr: Tree)(using queryWrapper: MutableTreeWrapper) = - queryWrapper.queryTree = expr - val innerMap2 = cache(value).removed(queryWrapper) - cache.updated(value, innerMap2) - - private def updatedNested(value: Value, expr: Tree, result: Value): ExprValueCache = - val wrapper = new ImmutableTreeWrapper(expr) - updatedNestedWrapper(value, wrapper, result) - - private def updatedNestedWrapper(value: Value, wrapper: ImmutableTreeWrapper, result: Value): ExprValueCache = - val innerMap = cache.getOrElse(value, Map.empty[TreeWrapper, Value]) - val innerMap2 = innerMap.updated(wrapper, result) - cache.updated(value, innerMap2) - end extension - end Cache + end Data - import Cache.* + end Cache - inline def cache(using c: Cache): Cache = c + inline def cache(using c: Cache.Data): Cache.Data = c // ----- Checker State ----------------------------------- /** The state that threads through the interpreter */ - type Contextual[T] = (Context, Trace, Promoted, Cache, Reporter) ?=> T + type Contextual[T] = (Context, Trace, Promoted, Cache.Data, Reporter) ?=> T // ----- Error Handling ----------------------------------- - object Trace: - opaque type Trace = Vector[Tree] - - val empty: Trace = Vector.empty - - extension (trace: Trace) - def add(node: Tree): Trace = trace :+ node - def toVector: Vector[Tree] = trace - - def show(using trace: Trace, ctx: Context): String = buildStacktrace(trace, "\n") - - def position(using trace: Trace): Tree = trace.last - type Trace = Trace.Trace - - import Trace.* - def trace(using t: Trace): Trace = t - inline def withTrace[T](t: Trace)(op: Trace ?=> T): T = op(using t) - inline def extendTrace[T](node: Tree)(using t: Trace)(op: Trace ?=> T): T = op(using t.add(node)) - /** Error reporting */ trait Reporter: def report(err: Error): Unit @@ -508,7 +397,7 @@ object Semantic: /** * Revert the cache to previous state. */ - def abort()(using Cache): Unit + def abort()(using Cache.Data): Unit def errors: List[Error] object Reporter: @@ -517,8 +406,8 @@ object Semantic: def errors = buf.toList def report(err: Error) = buf += err - class TryBufferedReporter(backup: Cache) extends BufferedReporter with TryReporter: - def abort()(using Cache): Unit = cache.restore(backup) + class TryBufferedReporter(backup: Cache.Data) extends BufferedReporter with TryReporter: + def abort()(using Cache.Data): Unit = cache.restore(backup) class ErrorFound(val error: Error) extends Exception class StopEarlyReporter extends Reporter: @@ -529,7 +418,7 @@ object Semantic: * The TryReporter cannot be thrown away: either `abort` must be called or * the errors must be reported. */ - def errorsIn(fn: Reporter ?=> Unit)(using Cache): TryReporter = + def errorsIn(fn: Reporter ?=> Unit)(using Cache.Data): TryReporter = val reporter = new TryBufferedReporter(cache.backup()) fn(using reporter) reporter @@ -544,7 +433,7 @@ object Semantic: catch case ex: ErrorFound => ex.error :: Nil - def hasErrors(fn: Reporter ?=> Unit)(using Cache): Boolean = + def hasErrors(fn: Reporter ?=> Unit)(using Cache.Data): Boolean = val backup = cache.backup() val errors = stopEarly(fn) cache.restore(backup) @@ -606,14 +495,14 @@ object Semantic: case _ => cache.getObject(ref) - def ensureObjectExists()(using Cache): ref.type = + def ensureObjectExists()(using Cache.Data): ref.type = if cache.containsObject(ref) then printer.println("object " + ref + " already exists") ref else ensureFresh() - def ensureFresh()(using Cache): ref.type = + def ensureFresh()(using Cache.Data): ref.type = val obj = Objekt(ref.klass, fields = Map.empty, outers = Map(ref.klass -> ref.outer)) printer.println("reset object " + ref) cache.updateObject(ref, obj) @@ -664,7 +553,7 @@ object Semantic: Hot case Cold => - val error = AccessCold(field)(trace.toVector) + val error = AccessCold(field)(trace) reporter.report(error) Hot @@ -689,11 +578,11 @@ object Semantic: val rhs = target.defTree.asInstanceOf[ValOrDefDef].rhs eval(rhs, ref, target.owner.asClass, cacheResult = true) else - val error = CallUnknown(field)(trace.toVector) + val error = CallUnknown(field)(trace) reporter.report(error) Hot else - val error = AccessNonInit(target)(trace.toVector) + val error = AccessNonInit(target)(trace) reporter.report(error) Hot else @@ -779,7 +668,7 @@ object Semantic: case Cold => promoteArgs() - val error = CallCold(meth)(trace.toVector) + val error = CallCold(meth)(trace) reporter.report(error) Hot @@ -820,7 +709,7 @@ object Semantic: // try promoting the receiver as last resort val hasErrors = Reporter.hasErrors { ref.promote("try promote value to hot") } if hasErrors then - val error = CallUnknown(target)(trace.toVector) + val error = CallUnknown(target)(trace) reporter.report(error) Hot else if target.exists then @@ -899,7 +788,7 @@ object Semantic: Hot else // no source code available - val error = CallUnknown(ctor)(trace.toVector) + val error = CallUnknown(ctor)(trace) reporter.report(error) Hot } @@ -922,7 +811,7 @@ object Semantic: yield i + 1 - val error = UnsafeLeaking(errors.head, nonHotOuterClass, indices)(trace.toVector) + val error = UnsafeLeaking(errors.head, nonHotOuterClass, indices)(trace) reporter.report(error) Hot else @@ -947,7 +836,7 @@ object Semantic: tryLeak(warm, NoSymbol, args2) case Cold => - val error = CallCold(ctor)(trace.toVector) + val error = CallCold(ctor)(trace) reporter.report(error) Hot @@ -1078,7 +967,7 @@ object Semantic: case Hot => case Cold => - reporter.report(PromoteError(msg)(trace.toVector)) + reporter.report(PromoteError(msg)(trace)) case thisRef: ThisRef => val emptyFields = thisRef.nonInitFields() @@ -1086,7 +975,7 @@ object Semantic: promoted.promoteCurrent(thisRef) else val fields = "Non initialized field(s): " + emptyFields.map(_.show).mkString(", ") + "." - reporter.report(PromoteError(msg + "\n" + fields)(trace.toVector)) + reporter.report(PromoteError(msg + "\n" + fields)(trace)) case warm: Warm => if !promoted.contains(warm) then @@ -1106,7 +995,7 @@ object Semantic: res.promote("The function return value is not hot. Found = " + res.show + ".") } if errors.nonEmpty then - reporter.report(UnsafePromotion(msg, errors.head)(trace.toVector)) + reporter.report(UnsafePromotion(msg, errors.head)(trace)) else promoted.add(fun) @@ -1156,12 +1045,12 @@ object Semantic: if !isHotSegment then for member <- klass.info.decls do if member.isClass then - val error = PromoteError("Promotion cancelled as the value contains inner " + member.show + ".")(Vector.empty) + val error = PromoteError("Promotion cancelled as the value contains inner " + member.show + ".")(Trace.empty) reporter.report(error) else if !member.isType && !member.isConstructor && !member.is(Flags.Deferred) then given Trace = Trace.empty if member.is(Flags.Method, butNot = Flags.Accessor) then - val args = member.info.paramInfoss.flatten.map(_ => ArgInfo(Hot, Trace.empty)) + val args = member.info.paramInfoss.flatten.map(_ => new ArgInfo(Hot: Value, Trace.empty)) val res = warm.call(member, args, receiver = warm.klass.typeRef, superType = NoType) withTrace(trace.add(member.defTree)) { res.promote("Cannot prove that the return value of " + member.show + " is hot. Found = " + res.show + ".") @@ -1189,7 +1078,7 @@ object Semantic: } if errors.isEmpty then Nil - else UnsafePromotion(msg, errors.head)(trace.toVector) :: Nil + else UnsafePromotion(msg, errors.head)(trace) :: Nil } end extension @@ -1212,7 +1101,7 @@ object Semantic: * * The class to be checked must be an instantiable concrete class. */ - private def checkClass(classSym: ClassSymbol)(using Cache, Context): Unit = + private def checkClass(classSym: ClassSymbol)(using Cache.Data, Context): Unit = val thisRef = ThisRef(classSym) val tpl = classSym.defTree.asInstanceOf[TypeDef].rhs.asInstanceOf[Template] @@ -1246,16 +1135,16 @@ object Semantic: * Check the specified concrete classes */ def checkClasses(classes: List[ClassSymbol])(using Context): Unit = - given Cache() + given Cache.Data() for classSym <- classes if isConcreteClass(classSym) do checkClass(classSym) // ----- Semantic definition -------------------------------- + type ArgInfo = TraceValue[Value] - /** Utility definition used for better error-reporting of argument errors */ - case class ArgInfo(value: Value, trace: Trace): - def promote: Contextual[Unit] = withTrace(trace) { - value.promote("Cannot prove the method argument is hot. Only hot values are safe to leak.\nFound = " + value.show + ".") + extension (arg: ArgInfo) + def promote: Contextual[Unit] = withTrace(arg.trace) { + arg.value.promote("Cannot prove the method argument is hot. Only hot values are safe to leak.\nFound = " + arg.value.show + ".") } /** Evaluate an expression with the given value for `this` in a given class `klass` @@ -1279,10 +1168,7 @@ object Semantic: * @param cacheResult It is used to reduce the size of the cache. */ def eval(expr: Tree, thisV: Ref, klass: ClassSymbol, cacheResult: Boolean = false): Contextual[Value] = log("evaluating " + expr.show + ", this = " + thisV.show + " in " + klass.show, printer, (_: Value).show) { - cache.get(thisV, expr) match - case Some(value) => value - case None => - cache.assume(thisV, expr, cacheResult) { cases(expr, thisV, klass) } + cache.cachedEval(thisV, expr, cacheResult, default = Hot) { expr => cases(expr, thisV, klass) } } /** Evaluate a list of expressions */ @@ -1299,7 +1185,7 @@ object Semantic: else eval(arg.tree, thisV, klass) - argInfos += ArgInfo(res, trace.add(arg.tree)) + argInfos += new ArgInfo(res, trace.add(arg.tree)) } argInfos.toList @@ -1667,7 +1553,7 @@ object Semantic: // The parameter check of traits comes late in the mixin phase. // To avoid crash we supply hot values for erroneous parent calls. // See tests/neg/i16438.scala. - val args: List[ArgInfo] = ctor.info.paramInfoss.flatten.map(_ => ArgInfo(Hot, Trace.empty)) + val args: List[ArgInfo] = ctor.info.paramInfoss.flatten.map(_ => new ArgInfo(Hot, Trace.empty)) extendTrace(superParent) { superCall(tref, ctor, args, tasks) } @@ -1726,85 +1612,3 @@ object Semantic: traverseChildren(tp) traverser.traverse(tpt.tpe) - -// ----- Utility methods and extractors -------------------------------- - - def typeRefOf(tp: Type)(using Context): TypeRef = tp.dealias.typeConstructor match - case tref: TypeRef => tref - case hklambda: HKTypeLambda => typeRefOf(hklambda.resType) - - - opaque type Arg = Tree | ByNameArg - case class ByNameArg(tree: Tree) - - extension (arg: Arg) - def isByName = arg.isInstanceOf[ByNameArg] - def tree: Tree = arg match - case t: Tree => t - case ByNameArg(t) => t - - object Call: - - def unapply(tree: Tree)(using Context): Option[(Tree, List[List[Arg]])] = - tree match - case Apply(fn, args) => - val argTps = fn.tpe.widen match - case mt: MethodType => mt.paramInfos - val normArgs: List[Arg] = args.zip(argTps).map { - case (arg, _: ExprType) => ByNameArg(arg) - case (arg, _) => arg - } - unapply(fn) match - case Some((ref, args0)) => Some((ref, args0 :+ normArgs)) - case None => None - - case TypeApply(fn, targs) => - unapply(fn) - - case ref: RefTree if ref.tpe.widenSingleton.isInstanceOf[MethodicType] => - Some((ref, Nil)) - - case _ => None - - object NewExpr: - def unapply(tree: Tree)(using Context): Option[(TypeRef, New, Symbol, List[List[Arg]])] = - tree match - case Call(fn @ Select(newTree: New, init), argss) if init == nme.CONSTRUCTOR => - val tref = typeRefOf(newTree.tpe) - Some((tref, newTree, fn.symbol, argss)) - case _ => None - - object PolyFun: - def unapply(tree: Tree)(using Context): Option[Tree] = - tree match - case Block((cdef: TypeDef) :: Nil, Typed(NewExpr(tref, _, _, _), _)) - if tref.symbol.isAnonymousClass && tref <:< defn.PolyFunctionType - => - val body = cdef.rhs.asInstanceOf[Template].body - val apply = body.head.asInstanceOf[DefDef] - Some(apply.rhs) - case _ => - None - - extension (symbol: Symbol) def hasSource(using Context): Boolean = - !symbol.defTree.isEmpty - - def resolve(cls: ClassSymbol, sym: Symbol)(using Context): Symbol = log("resove " + cls + ", " + sym, printer, (_: Symbol).show) { - if (sym.isEffectivelyFinal || sym.isConstructor) sym - else sym.matchingMember(cls.appliedRef) - } - - private def isConcreteClass(cls: ClassSymbol)(using Context) = { - val instantiable: Boolean = - cls.is(Flags.Module) || - !cls.isOneOf(Flags.AbstractOrTrait) && { - // see `Checking.checkInstantiable` in typer - val tp = cls.appliedRef - val stp = SkolemType(tp) - val selfType = cls.givenSelfType.asSeenFrom(stp, cls) - !selfType.exists || stp <:< selfType - } - - // A concrete class may not be instantiated if the self type is not satisfied - instantiable && cls.enclosingPackageClass != defn.StdLibPatchesPackage.moduleClass - } diff --git a/compiler/src/dotty/tools/dotc/transform/init/Trace.scala b/compiler/src/dotty/tools/dotc/transform/init/Trace.scala new file mode 100644 index 000000000000..7dfbc0b6cfa5 --- /dev/null +++ b/compiler/src/dotty/tools/dotc/transform/init/Trace.scala @@ -0,0 +1,82 @@ +package dotty.tools.dotc +package transform +package init + +import core.* +import Contexts.* +import ast.tpd.* +import util.SourcePosition + +import Decorators._, printing.SyntaxHighlighting + +import scala.collection.mutable + +/** Logic related to evaluation trace for showing friendly error messages + * + * A trace is a sequence of program positions which tells the evaluation order + * that leads to an error. It is usually more informative than the stack trace + * by tracking the exact sub-expression in the trace instead of only methods. + */ +object Trace: + opaque type Trace = Vector[Tree] + + val empty: Trace = Vector.empty + + extension (trace: Trace) + def add(node: Tree): Trace = trace :+ node + def toVector: Vector[Tree] = trace + def ++(trace2: Trace): Trace = trace ++ trace2 + + def show(using trace: Trace, ctx: Context): String = buildStacktrace(trace, "\n") + + def position(using trace: Trace): Tree = trace.last + + def trace(using t: Trace): Trace = t + + inline def withTrace[T](t: Trace)(op: Trace ?=> T): T = op(using t) + + inline def extendTrace[T](node: Tree)(using t: Trace)(op: Trace ?=> T): T = op(using t.add(node)) + + def buildStacktrace(trace: Trace, preamble: String)(using Context): String = if trace.isEmpty then "" else preamble + { + var lastLineNum = -1 + var lines: mutable.ArrayBuffer[String] = new mutable.ArrayBuffer + trace.foreach { tree => + val pos = tree.sourcePos + val prefix = "-> " + val line = + if pos.source.exists then + val loc = "[ " + pos.source.file.name + ":" + (pos.line + 1) + " ]" + val code = SyntaxHighlighting.highlight(pos.lineContent.trim.nn) + i"$code\t$loc" + else + tree.show + val positionMarkerLine = + if pos.exists && pos.source.exists then + positionMarker(pos) + else "" + + // always use the more precise trace location + if lastLineNum == pos.line then + lines.dropRightInPlace(1) + + lines += (prefix + line + "\n" + positionMarkerLine) + + lastLineNum = pos.line + } + val sb = new StringBuilder + for line <- lines do sb.append(line) + sb.toString + } + + /** Used to underline source positions in the stack trace + * pos.source must exist + */ + private def positionMarker(pos: SourcePosition): String = + val trimmed = pos.lineContent.takeWhile(c => c.isWhitespace).length + val padding = pos.startColumnPadding.substring(trimmed).nn + " " + val carets = + if (pos.startLine == pos.endLine) + "^" * math.max(1, pos.endColumn - pos.startColumn) + else "^" + + s"$padding$carets\n" diff --git a/compiler/src/dotty/tools/dotc/transform/init/Util.scala b/compiler/src/dotty/tools/dotc/transform/init/Util.scala new file mode 100644 index 000000000000..4e60c1325b09 --- /dev/null +++ b/compiler/src/dotty/tools/dotc/transform/init/Util.scala @@ -0,0 +1,102 @@ +package dotty.tools.dotc +package transform +package init + +import core.* +import Contexts.* +import Types.* +import Symbols.* +import StdNames.* +import ast.tpd.* + +import reporting.trace as log +import config.Printers.init as printer + +import Trace.* + +object Util: + /** Utility definition used for better error-reporting of argument errors */ + case class TraceValue[T](value: T, trace: Trace) + + def typeRefOf(tp: Type)(using Context): TypeRef = tp.dealias.typeConstructor match + case tref: TypeRef => tref + case hklambda: HKTypeLambda => typeRefOf(hklambda.resType) + + + opaque type Arg = Tree | ByNameArg + case class ByNameArg(tree: Tree) + + extension (arg: Arg) + def isByName = arg.isInstanceOf[ByNameArg] + def tree: Tree = arg match + case t: Tree => t + case ByNameArg(t) => t + + object Call: + + def unapply(tree: Tree)(using Context): Option[(Tree, List[List[Arg]])] = + tree match + case Apply(fn, args) => + val argTps = fn.tpe.widen match + case mt: MethodType => mt.paramInfos + val normArgs: List[Arg] = args.zip(argTps).map { + case (arg, _: ExprType) => ByNameArg(arg) + case (arg, _) => arg + } + unapply(fn) match + case Some((ref, args0)) => Some((ref, args0 :+ normArgs)) + case None => None + + case TypeApply(fn, targs) => + unapply(fn) + + case ref: RefTree if ref.tpe.widenSingleton.isInstanceOf[MethodicType] => + Some((ref, Nil)) + + case _ => None + + object NewExpr: + def unapply(tree: Tree)(using Context): Option[(TypeRef, New, Symbol, List[List[Arg]])] = + tree match + case Call(fn @ Select(newTree: New, init), argss) if init == nme.CONSTRUCTOR => + val tref = typeRefOf(newTree.tpe) + Some((tref, newTree, fn.symbol, argss)) + case _ => None + + object PolyFun: + def unapply(tree: Tree)(using Context): Option[Tree] = + tree match + case Block((cdef: TypeDef) :: Nil, Typed(NewExpr(tref, _, _, _), _)) + if tref.symbol.isAnonymousClass && tref <:< defn.PolyFunctionType + => + val body = cdef.rhs.asInstanceOf[Template].body + val apply = body.head.asInstanceOf[DefDef] + Some(apply.rhs) + case _ => + None + + def resolve(cls: ClassSymbol, sym: Symbol)(using Context): Symbol = log("resove " + cls + ", " + sym, printer, (_: Symbol).show) { + if (sym.isEffectivelyFinal || sym.isConstructor) sym + else sym.matchingMember(cls.appliedRef) + } + + + extension (sym: Symbol) + def hasSource(using Context): Boolean = !sym.defTree.isEmpty + + def isStaticObject(using Context) = + sym.is(Flags.Module, butNot = Flags.Package) && sym.isStatic + + def isConcreteClass(cls: ClassSymbol)(using Context) = + val instantiable: Boolean = + cls.is(Flags.Module) || + !cls.isOneOf(Flags.AbstractOrTrait) && { + // see `Checking.checkInstantiable` in typer + val tp = cls.appliedRef + val stp = SkolemType(tp) + val selfType = cls.givenSelfType.asSeenFrom(stp, cls) + !selfType.exists || stp <:< selfType + } + + // A concrete class may not be instantiated if the self type is not satisfied + instantiable && cls.enclosingPackageClass != defn.StdLibPatchesPackage.moduleClass