@@ -87,56 +87,42 @@ class LabelDefs extends MiniPhase {
87
87
override def transformDefDef (tree : tpd.DefDef )(implicit ctx : Context ): tpd.Tree = {
88
88
if (tree.symbol is Label ) tree
89
89
else {
90
- collectLabelDefs.clear()
91
- val newRhs = collectLabelDefs.transform(tree.rhs)
92
- var labelDefs = collectLabelDefs.labelDefs
90
+ val labelDefs = collectLabelDefs(tree.rhs)
93
91
94
92
def putLabelDefsNearCallees = new TreeMap () {
95
-
96
93
override def transform (tree : tpd.Tree )(implicit ctx : Context ): tpd.Tree = {
97
94
tree match {
95
+ case t : Template => t
98
96
case t : Apply if labelDefs.contains(t.symbol) =>
99
97
val labelDef = labelDefs(t.symbol)
100
98
labelDefs -= t.symbol
101
-
102
- val labelDef2 = transform(labelDef)
99
+ val labelDef2 = cpy.DefDef (labelDef)(rhs = transform(labelDef.rhs))
103
100
Block (labelDef2:: Nil , t)
104
-
101
+ case t : DefDef =>
102
+ assert(t.symbol is Label )
103
+ EmptyTree
105
104
case _ => if (labelDefs.nonEmpty) super .transform(tree) else tree
106
105
}
107
106
}
108
107
}
109
108
110
- val res = cpy.DefDef (tree)(rhs = putLabelDefsNearCallees.transform(newRhs))
111
-
112
- res
109
+ cpy.DefDef (tree)(rhs = putLabelDefsNearCallees.transform(tree.rhs))
113
110
}
114
111
}
115
112
116
- private object collectLabelDefs extends TreeMap () {
117
-
113
+ private def collectLabelDefs (tree : Tree )(implicit ctx : Context ): mutable.HashMap [Symbol , DefDef ] = {
118
114
// labelSymbol -> Defining tree
119
- val labelDefs = new mutable.HashMap [Symbol , Tree ]()
120
-
121
- def clear (): Unit = {
122
- labelDefs.clear()
123
- }
124
-
125
- override def transform (tree : tpd.Tree )(implicit ctx : Context ): tpd.Tree = tree match {
126
- case t : Template => t
127
- case t : Block =>
128
- val r = super .transform(t)
129
- r match {
130
- case t : Block if t.stats.isEmpty => t.expr
131
- case _ => r
132
- }
133
- case t : DefDef =>
134
- assert(t.symbol is Label )
135
- val r = super .transform(tree)
136
- labelDefs(r.symbol) = r
137
- EmptyTree
138
- case _ =>
139
- super .transform(tree)
140
- }
115
+ val labelDefs = new mutable.HashMap [Symbol , DefDef ]()
116
+ new TreeTraverser {
117
+ override def traverse (tree : tpd.Tree )(implicit ctx : Context ): Unit = tree match {
118
+ case _ : Template =>
119
+ case t : DefDef =>
120
+ assert(t.symbol is Label )
121
+ labelDefs(t.symbol) = t
122
+ traverseChildren(t)
123
+ case _ => traverseChildren(tree)
124
+ }
125
+ }.traverse(tree)
126
+ labelDefs
141
127
}
142
128
}
0 commit comments