@@ -3,37 +3,34 @@ package dotty.tools.dotc.transform
3
3
import dotty .tools .dotc .CompilationUnit
4
4
import dotty .tools .dotc .ast .tpd
5
5
import dotty .tools .dotc .core .Contexts .Context
6
- import dotty .tools .dotc .core .Phases .Phase
7
6
8
7
/** Set the `rootTreeOrProvider` property of class symbols. */
9
- class SetRootTree extends Phase {
8
+ class SetRootTree extends MacroTransform {
10
9
11
10
override val phaseName : String = SetRootTree .name
12
11
override def isRunnable (implicit ctx : Context ) =
13
12
super .isRunnable && ctx.settings.YretainTrees .value
14
13
15
- override def run (implicit ctx : Context ): Unit = {
16
- val tree = ctx.compilationUnit.tpdTree
17
- traverser.traverse(tree)
18
- }
19
-
20
- private def traverser = new tpd.TreeTraverser {
21
- override def traverse (tree : tpd.Tree )(implicit ctx : Context ): Unit = {
22
- tree match {
23
- case pkg : tpd.PackageDef =>
24
- traverseChildren(pkg)
25
- case td : tpd.TypeDef =>
26
- if (td.symbol.isClass) {
27
- val sym = td.symbol.asClass
28
- tpd.sliceTopLevel(ctx.compilationUnit.tpdTree, sym) match {
29
- case (pkg : tpd.PackageDef ) :: Nil =>
30
- sym.rootTreeOrProvider = pkg
31
- case _ =>
32
- sym.rootTreeOrProvider = td
14
+ override protected def newTransformer (implicit ctx : Context ): Transformer = {
15
+ new Transformer {
16
+ override def transform (tree : tpd.Tree )(implicit ctx : Context ): tpd.Tree = {
17
+ tree match {
18
+ case pkg : tpd.PackageDef =>
19
+ super .transform(pkg)
20
+ case td : tpd.TypeDef =>
21
+ if (td.symbol.isClass) {
22
+ val sym = td.symbol.asClass
23
+ tpd.sliceTopLevel(ctx.compilationUnit.tpdTree, sym) match {
24
+ case (pkg : tpd.PackageDef ) :: Nil =>
25
+ sym.rootTreeOrProvider = pkg
26
+ case _ =>
27
+ sym.rootTreeOrProvider = td
28
+ }
33
29
}
34
- }
35
- case _ =>
36
- ()
30
+ td
31
+ case other =>
32
+ other
33
+ }
37
34
}
38
35
}
39
36
}
0 commit comments