Skip to content

WIP - Fix #9176: fast check of cyclic object initialization #11913

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 170 additions & 0 deletions compiler/src/dotty/tools/dotc/transform/init/CheckGlobal.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,170 @@
package dotty.tools.dotc
package transform
package init

import core._
import Flags._
import Contexts._
import Types._
import Symbols._
import Decorators._
import printing.SyntaxHighlighting
import reporting.trace
import config.Printers.init

import ast.tpd._

import scala.collection.mutable


/** Check that static objects can be initialized without cycles
*
* For the check to be fast, the algorithm uses coarse approximation.
* We construct a dependency graph as follows:
*
* - if a static object `O` is used in another class/static-object `B`,
* then O -> B
* - if a class `C` is instantiated in a another class/static-object `B`,
* then C -> B
* - if a static-object/class `A` extends another class `B`,
* then A -> B
*
* Given the graph above, we check if there exists cycles.
*
* This check does not need to care about objects in libraries, as separate
* compilation ensures that there cannot be cyles between two separately
* compiled projects.
*/
class CheckGlobal {
case class Dependency(sym: Symbol, source: Tree)

/** Checking state */
case class State(visited: mutable.Set[Symbol], path: Vector[Tree], obj: Symbol) {
def cyclicPath(using Context): String = if (path.isEmpty) "" else " Cyclic path:\n" + {
var indentCount = 0
var last: String = ""
val sb = new StringBuilder
path.foreach { tree =>
indentCount += 1
val pos = tree.sourcePos
val prefix = s"${ " " * indentCount }-> "
val line =
if pos.source.exists then
val loc = "[ " + pos.source.file.name + ":" + (pos.line + 1) + " ]"
val code = SyntaxHighlighting.highlight(pos.lineContent.trim)
i"$code\t$loc"
else
tree.show

if (last != line) sb.append(prefix + line + "\n")

last = line
}
sb.toString
}
}

case class Error(state: State) {
def issue(using Context): Unit =
report.warning("Cylic object dependencies detected." + state.cyclicPath, state.obj.defTree.srcPos)
}

/** Summary of dependencies */
private val summaryCache = mutable.Map.empty[Symbol, List[Dependency]]

def check(obj: Symbol)(using Context): Unit = trace("checking " + obj.show, init) {
checkDependencies(obj, State(visited = mutable.Set.empty, path = Vector.empty, obj)) match
case Some(err) => err.issue
case _ =>
}

private def check(sym: Symbol, state: State)(using Context): Option[Error] = trace("checking " + sym.show, init) {
if sym == state.obj then
Some(Error(state))
else if state.visited.contains(sym) then
None
else
state.visited += sym
checkDependencies(sym, state)
}

private def checkDependencies(sym: Symbol, state: State)(using Context): Option[Error] = trace("checking dependencies of " + sym.show, init) {
val cls = if sym.is(Module) then sym.moduleClass.asClass else sym.asClass
val deps = analyze(cls)
Util.traceIndented("dependencies of " + sym.show + " = " + deps.map(_.sym.show).mkString(","), init)
var res: Option[Error] = None
// TODO: stop early
deps.foreach { dep =>
if res.isEmpty then
val state2: State = state.copy(path = state.path :+ dep.source)
res = check(dep.sym, state2)
}
res
}

private def analyze(cls: ClassSymbol)(using Context): List[Dependency] =
def isStaticObjectRef(sym: Symbol) =
sym.isTerm && !sym.is(Package) && sym.is(Module)
&& sym.isStatic && sym.moduleClass != cls

if (cls.defTree.isEmpty) Nil
else if (summaryCache.contains(cls)) summaryCache(cls)
else {
val cdef = cls.defTree.asInstanceOf[TypeDef]
val tpl = cdef.rhs.asInstanceOf[Template]

// ignore separately compiled classes
if !tpl.unforced.isInstanceOf[List[_]] then return Nil

var dependencies: List[Dependency] = Nil
val traverser = new TreeTraverser {
override def traverse(tree: Tree)(using Context): Unit =
tree match {
case tree: RefTree if isStaticObjectRef(tree.symbol) =>
dependencies = Dependency(tree.symbol, tree) :: dependencies

case tdef: TypeDef =>
// don't go into nested classes

case tree: New =>
dependencies = Dependency(tree.tpe.classSymbol, tree) :: dependencies

case _ =>
traverseChildren(tree)
}
}

def typeRefOf(tp: Type): TypeRef = tp.dealias.typeConstructor match {
case tref: TypeRef => tref
case hklambda: HKTypeLambda => typeRefOf(hklambda.resType)
}

def addStaticOuterDep(tp: Type, source: Tree): Unit =
tp match
case NoPrefix =>
case tmref: TermRef =>
if isStaticObjectRef(tmref.symbol) then
dependencies = Dependency(tmref.symbol, source) :: dependencies
case ThisType(tref) =>
val obj = tref.symbol.sourceModule
if isStaticObjectRef(obj) then
dependencies = Dependency(obj, source) :: dependencies
case _ =>
throw new Exception("unexpected type: " + tp)

// TODO: the traverser might create duplicate entries for parents
tpl.parents.foreach { tree =>
val tp = tree.tpe
val tref = typeRefOf(tp)
dependencies = Dependency(tp.classSymbol, tree) :: dependencies
addStaticOuterDep(tref.prefix, tree)
}

traverser.traverse(tpl)
summaryCache(cls) = dependencies
dependencies
}

def debugCache(using Context) =
summaryCache.map(_.show + " -> " + _.map(_.sym.show).mkString(",")).mkString("\n")
}
6 changes: 6 additions & 0 deletions compiler/src/dotty/tools/dotc/transform/init/Checker.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ class Checker extends MiniPhase {
// cache of class summary
private val baseEnv = Env(null)

val globalChecker = new CheckGlobal

override val runsAfter = Set(Pickler.name)

override def isEnabled(using Context): Boolean =
Expand Down Expand Up @@ -58,6 +60,10 @@ class Checker extends MiniPhase {
)

Checking.checkClassBody(tree)

// check cycles of object dependencies
if cls.is(Flags.Module) && cls.isStatic then
globalChecker.check(cls.sourceModule)
}

tree
Expand Down
34 changes: 23 additions & 11 deletions compiler/src/dotty/tools/dotc/transform/init/Checking.scala
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,18 @@ object Checking {
safePromoted: mutable.Set[Potential], // Potentials that can be safely promoted
env: Env
) {
def withOwner(sym: Symbol): State = copy(env = env.withOwner(sym))
def withOwner[T](sym: Symbol)(op: State ?=> T): T =
val state = this.copy(env = env.withOwner(sym))
val res = op(using state)
this.visited = state.visited
res


def withStep[T](step: Tree)(op: State ?=> T): T =
val state: State = this.copy(path = path :+ step)
val res = op(using state)
this.visited = state.visited
res

def test(op: State ?=> Errors): Errors = {
val savedVisited = visited
Expand All @@ -60,11 +71,12 @@ object Checking {
}
else {
state.visited = state.visited + eff
val state2: State = state.copy(path = state.path :+ eff.source)
eff match {
case eff: Promote => Checking.checkPromote(eff)(using state2)
case eff: FieldAccess => Checking.checkFieldAccess(eff)(using state2)
case eff: MethodCall => Checking.checkMethodCall(eff)(using state2)
state.withStep(eff.source) {
eff match {
case eff: Promote => Checking.checkPromote(eff)
case eff: FieldAccess => Checking.checkFieldAccess(eff)
case eff: MethodCall => Checking.checkMethodCall(eff)
}
}
}
}
Expand Down Expand Up @@ -118,11 +130,11 @@ object Checking {
def checkConstructor(ctor: Symbol, tp: Type, source: Tree)(using state: State): Unit = traceOp("checking " + ctor.show, init) {
val cls = ctor.owner
val classDef = cls.defTree
if (!classDef.isEmpty) {
given State = state.withOwner(cls)
if (ctor.isPrimaryConstructor) checkClassBody(classDef.asInstanceOf[TypeDef])
else checkSecondaryConstructor(ctor)
}
if (!classDef.isEmpty)
state.withOwner(cls) {
if (ctor.isPrimaryConstructor) checkClassBody(classDef.asInstanceOf[TypeDef])
else checkSecondaryConstructor(ctor)
}
}

def checkSecondaryConstructor(ctor: Symbol)(using state: State): Unit = traceOp("checking " + ctor.show, init) {
Expand Down
9 changes: 9 additions & 0 deletions tests/init/neg/i9176.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
class Foo(val opposite: Foo)
case object A extends Foo(B) // error
case object B extends Foo(A) // error
object Test {
def main(args: Array[String]): Unit = {
println(A.opposite)
println(B.opposite)
}
}
15 changes: 15 additions & 0 deletions tests/init/neg/t5366.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
class IdAndMsg(val id: Int, val msg: String = "")

case object ObjA extends IdAndMsg(1) // error
case object ObjB extends IdAndMsg(2) // error

object IdAndMsg { // error
val values = List(ObjA , ObjB)
}

object Test {
def main(args: Array[String]): Unit = {
ObjA
println(IdAndMsg.values)
}
}
8 changes: 8 additions & 0 deletions tests/init/neg/t9115.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
object D { // error
def aaa = 1 //that’s the reason
class Z (depends: Any)
case object D1 extends Z(aaa) // 'null' when calling D.D1 first time // error
case object D2 extends Z(aaa) // 'null' when calling D.D2 first time // error
println(D1)
println(D2)
}
3 changes: 3 additions & 0 deletions tests/init/neg/t9261.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
sealed abstract class OrderType(val reverse: OrderType)
case object Buy extends OrderType(Sell) // error
case object Sell extends OrderType(Buy) // error
23 changes: 23 additions & 0 deletions tests/init/neg/t9312.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
object DeadLockTest {
def main(args: Array[String]): Unit = {
def run(block: => Unit): Unit =
new Thread(new Runnable {def run(): Unit = block}).start()

run {println(Parent.Child1)}
run {println(Parent.Child2)}

}

object Parent { // error
trait Child {
Thread.sleep(2000) // ensure concurrent behavior
val parent = Parent
def siblings = parent.children - this
}

object Child1 extends Child // error
object Child2 extends Child // error

final val children = Set(Child1, Child2) // error
}
}