Skip to content

Commit c7ed89d

Browse files
committed
Merge pull request #224 from dotty-staging/refactor/treeTransformInits
Refactor/tree transform inits
2 parents 7978a5f + e02db34 commit c7ed89d

File tree

7 files changed

+420
-392
lines changed

7 files changed

+420
-392
lines changed

src/dotty/tools/dotc/transform/CapturedVars.scala

Lines changed: 73 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -17,86 +17,94 @@ import SymUtils._
1717
import collection.{ mutable, immutable }
1818
import collection.mutable.{ LinkedHashMap, LinkedHashSet, TreeSet }
1919

20-
class CapturedVars extends MiniPhaseTransform with SymTransformer { thisTransform =>
20+
class CapturedVars extends MiniPhase with IdentityDenotTransformer { thisTransform =>
2121
import ast.tpd._
2222

2323
/** the following two members override abstract members in Transform */
2424
val phaseName: String = "capturedVars"
25+
val treeTransform = new Transform(Set())
2526

26-
override def treeTransformPhase = thisTransform.next
27+
class Transform(captured: collection.Set[Symbol]) extends TreeTransform {
28+
def phase = thisTransform
29+
override def treeTransformPhase = thisTransform.next
2730

28-
private var captured: mutable.HashSet[Symbol] = _
29-
30-
private class CollectCaptured(implicit ctx: Context) extends EnclosingMethodTraverser {
31-
def traverse(enclMeth: Symbol, tree: Tree) = tree match {
32-
case id: Ident =>
33-
val sym = id.symbol
34-
if (sym.is(Mutable, butNot = Method) && sym.owner.isTerm && sym.enclosingMethod != enclMeth) {
35-
ctx.log(i"capturing $sym in ${sym.enclosingMethod}, referenced from $enclMeth")
36-
captured += sym
37-
}
38-
case _ =>
39-
foldOver(enclMeth, tree)
40-
}
41-
def runOver(tree: Tree) = {
42-
captured = mutable.HashSet()
43-
apply(NoSymbol, tree)
31+
private class CollectCaptured(implicit ctx: Context) extends EnclosingMethodTraverser {
32+
private val captured = mutable.HashSet[Symbol]()
33+
def traverse(enclMeth: Symbol, tree: Tree) = tree match {
34+
case id: Ident =>
35+
val sym = id.symbol
36+
if (sym.is(Mutable, butNot = Method) && sym.owner.isTerm && sym.enclosingMethod != enclMeth) {
37+
ctx.log(i"capturing $sym in ${sym.enclosingMethod}, referenced from $enclMeth")
38+
captured += sym
39+
}
40+
case _ =>
41+
foldOver(enclMeth, tree)
42+
}
43+
def runOver(tree: Tree): collection.Set[Symbol] = {
44+
apply(NoSymbol, tree)
45+
captured
46+
}
4447
}
45-
}
4648

47-
override def init(implicit ctx: Context, info: TransformerInfo): Unit =
48-
(new CollectCaptured)(ctx.withPhase(thisTransform)).runOver(ctx.compilationUnit.tpdTree)
49+
override def prepareForUnit(tree: Tree)(implicit ctx: Context) = {
50+
val captured = (new CollectCaptured)(ctx.withPhase(thisTransform))
51+
.runOver(ctx.compilationUnit.tpdTree)
52+
new Transform(captured)
53+
}
4954

50-
override def transformSym(sd: SymDenotation)(implicit ctx: Context): SymDenotation =
51-
if (captured(sd.symbol)) {
52-
val newd = sd.copySymDenotation(
53-
info = refCls(sd.info.classSymbol, sd.hasAnnotation(defn.VolatileAnnot)).typeRef)
54-
newd.removeAnnotation(defn.VolatileAnnot)
55-
newd
56-
} else sd
55+
/** The {Volatile|}{Int|Double|...|Object}Ref class corresponding to the class `cls`,
56+
* depending on whether the reference should be @volatile
57+
*/
58+
def refCls(cls: Symbol, isVolatile: Boolean)(implicit ctx: Context): Symbol = {
59+
val refMap = if (isVolatile) defn.volatileRefClass else defn.refClass
60+
refMap.getOrElse(cls, refMap(defn.ObjectClass))
61+
}
5762

58-
/** The {Volatile|}{Int|Double|...|Object}Ref class corresponding to the class `cls`,
59-
* depending on whether the reference should be @volatile
60-
*/
61-
def refCls(cls: Symbol, isVolatile: Boolean)(implicit ctx: Context): Symbol = {
62-
val refMap = if (isVolatile) defn.volatileRefClass else defn.refClass
63-
refMap.getOrElse(cls, refMap(defn.ObjectClass))
64-
}
63+
def capturedType(vble: Symbol)(implicit ctx: Context): Type = {
64+
val oldInfo = vble.denot(ctx.withPhase(thisTransform)).info
65+
refCls(oldInfo.classSymbol, vble.isVolatile).typeRef
66+
}
6567

66-
def capturedType(vble: Symbol)(implicit ctx: Context): Type = {
67-
val oldInfo = vble.denot(ctx.withPhase(thisTransform)).info
68-
refCls(oldInfo.classSymbol, vble.isVolatile).typeRef
69-
}
68+
override def prepareForValDef(vdef: ValDef)(implicit ctx: Context) = {
69+
val sym = vdef.symbol
70+
if (captured contains sym) {
71+
val newd = sym.denot(ctx.withPhase(thisTransform)).copySymDenotation(
72+
info = refCls(sym.info.classSymbol, sym.hasAnnotation(defn.VolatileAnnot)).typeRef)
73+
newd.removeAnnotation(defn.VolatileAnnot)
74+
newd.installAfter(thisTransform)
75+
}
76+
this
77+
}
7078

71-
override def transformValDef(vdef: ValDef)(implicit ctx: Context, info: TransformerInfo): Tree = {
72-
val vble = vdef.symbol
73-
if (captured contains vble) {
74-
def boxMethod(name: TermName): Tree =
75-
ref(vble.info.classSymbol.companionModule.info.member(name).symbol)
76-
cpy.ValDef(vdef)(
77-
rhs = vdef.rhs match {
78-
case EmptyTree => boxMethod(nme.zero).appliedToNone.withPos(vdef.pos)
79-
case arg => boxMethod(nme.create).appliedTo(arg)
80-
},
81-
tpt = TypeTree(vble.info).withPos(vdef.tpt.pos))
79+
override def transformValDef(vdef: ValDef)(implicit ctx: Context, info: TransformerInfo): Tree = {
80+
val vble = vdef.symbol
81+
if (captured contains vble) {
82+
def boxMethod(name: TermName): Tree =
83+
ref(vble.info.classSymbol.companionModule.info.member(name).symbol)
84+
cpy.ValDef(vdef)(
85+
rhs = vdef.rhs match {
86+
case EmptyTree => boxMethod(nme.zero).appliedToNone.withPos(vdef.pos)
87+
case arg => boxMethod(nme.create).appliedTo(arg)
88+
},
89+
tpt = TypeTree(vble.info).withPos(vdef.tpt.pos))
90+
} else vdef
8291
}
83-
else vdef
84-
}
8592

86-
override def transformIdent(id: Ident)(implicit ctx: Context, info: TransformerInfo): Tree = {
87-
val vble = id.symbol
88-
if (captured(vble))
89-
(id select nme.elem).ensureConforms(vble.denot(ctx.withPhase(thisTransform)).info)
90-
else id
91-
}
93+
override def transformIdent(id: Ident)(implicit ctx: Context, info: TransformerInfo): Tree = {
94+
val vble = id.symbol
95+
if (captured(vble))
96+
(id select nme.elem).ensureConforms(vble.denot(ctx.withPhase(thisTransform)).info)
97+
else id
98+
}
9299

93-
override def transformAssign(tree: Assign)(implicit ctx: Context, info: TransformerInfo): Tree = {
94-
val lhs1 = tree.lhs match {
95-
case TypeApply(Select(qual @ Select(qual2, nme.elem), nme.asInstanceOf_), _) =>
96-
assert(captured(qual2.symbol))
97-
qual
98-
case _ => tree.lhs
100+
override def transformAssign(tree: Assign)(implicit ctx: Context, info: TransformerInfo): Tree = {
101+
val lhs1 = tree.lhs match {
102+
case TypeApply(Select(qual @ Select(qual2, nme.elem), nme.asInstanceOf_), _) =>
103+
assert(captured(qual2.symbol))
104+
qual
105+
case _ => tree.lhs
106+
}
107+
cpy.Assign(tree)(lhs1, tree.rhs)
99108
}
100-
cpy.Assign(tree)(lhs1, tree.rhs)
101109
}
102110
}

src/dotty/tools/dotc/transform/CollectEntryPoints.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,10 @@ import dotty.tools.dotc.config.JavaPlatform
2626
class CollectEntryPoints extends MiniPhaseTransform {
2727

2828
/** perform context-dependant initialization */
29-
override def init(implicit ctx: Context, info: TransformerInfo): Unit = {
29+
override def prepareForUnit(tree: tpd.Tree)(implicit ctx: Context) = {
3030
entryPoints = collection.immutable.TreeSet.empty[Symbol](new SymbolOrdering())
3131
assert(ctx.platform.isInstanceOf[JavaPlatform], "Java platform specific phase")
32+
this
3233
}
3334

3435
private var entryPoints: Set[Symbol] = _

src/dotty/tools/dotc/transform/InterceptedMethods.scala

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,11 +53,12 @@ class InterceptedMethods extends MiniPhaseTransform {
5353
private var primitiveGetClassMethods: Set[Symbol] = _
5454

5555
/** perform context-dependant initialization */
56-
override def init(implicit ctx: Context, info: TransformerInfo): Unit = {
56+
override def prepareForUnit(tree: Tree)(implicit ctx: Context) = {
5757
poundPoundMethods = Set(defn.Any_##)
5858
Any_comparisons = Set(defn.Any_==, defn.Any_!=)
5959
interceptedMethods = poundPoundMethods ++ Any_comparisons
6060
primitiveGetClassMethods = Set[Symbol]() ++ defn.ScalaValueClasses.map(x => x.requiredMethod(nme.getClass_))
61+
this
6162
}
6263

6364
// this should be removed if we have guarantee that ## will get Apply node

0 commit comments

Comments
 (0)