@@ -88,56 +88,42 @@ class LabelDefs extends MiniPhaseTransform {
88
88
override def transformDefDef (tree : tpd.DefDef )(implicit ctx : Context , info : TransformerInfo ): tpd.Tree = {
89
89
if (tree.symbol is Label ) tree
90
90
else {
91
- collectLabelDefs.clear()
92
- val newRhs = collectLabelDefs.transform(tree.rhs)
93
- var labelDefs = collectLabelDefs.labelDefs
91
+ val labelDefs = collectLabelDefs(tree.rhs)
94
92
95
93
def putLabelDefsNearCallees = new TreeMap () {
96
-
97
94
override def transform (tree : tpd.Tree )(implicit ctx : Context ): tpd.Tree = {
98
95
tree match {
96
+ case t : Template => t
99
97
case t : Apply if labelDefs.contains(t.symbol) =>
100
98
val labelDef = labelDefs(t.symbol)
101
99
labelDefs -= t.symbol
102
-
103
- val labelDef2 = transform(labelDef)
100
+ val labelDef2 = cpy.DefDef (labelDef)(rhs = transform(labelDef.rhs))
104
101
Block (labelDef2:: Nil , t)
105
-
102
+ case t : DefDef =>
103
+ assert(t.symbol is Label )
104
+ EmptyTree
106
105
case _ => if (labelDefs.nonEmpty) super .transform(tree) else tree
107
106
}
108
107
}
109
108
}
110
109
111
- val res = cpy.DefDef (tree)(rhs = putLabelDefsNearCallees.transform(newRhs))
112
-
113
- res
110
+ cpy.DefDef (tree)(rhs = putLabelDefsNearCallees.transform(tree.rhs))
114
111
}
115
112
}
116
113
117
- private object collectLabelDefs extends TreeMap () {
118
-
114
+ private def collectLabelDefs (tree : Tree )(implicit ctx : Context ): mutable.HashMap [Symbol , DefDef ] = {
119
115
// labelSymbol -> Defining tree
120
- val labelDefs = new mutable.HashMap [Symbol , Tree ]()
121
-
122
- def clear (): Unit = {
123
- labelDefs.clear()
124
- }
125
-
126
- override def transform (tree : tpd.Tree )(implicit ctx : Context ): tpd.Tree = tree match {
127
- case t : Template => t
128
- case t : Block =>
129
- val r = super .transform(t)
130
- r match {
131
- case t : Block if t.stats.isEmpty => t.expr
132
- case _ => r
133
- }
134
- case t : DefDef =>
135
- assert(t.symbol is Label )
136
- val r = super .transform(tree)
137
- labelDefs(r.symbol) = r
138
- EmptyTree
139
- case _ =>
140
- super .transform(tree)
141
- }
116
+ val labelDefs = new mutable.HashMap [Symbol , DefDef ]()
117
+ new TreeTraverser {
118
+ override def traverse (tree : tpd.Tree )(implicit ctx : Context ): Unit = tree match {
119
+ case _ : Template =>
120
+ case t : DefDef =>
121
+ assert(t.symbol is Label )
122
+ labelDefs(t.symbol) = t
123
+ traverseChildren(t)
124
+ case _ => traverseChildren(tree)
125
+ }
126
+ }.traverse(tree)
127
+ labelDefs
142
128
}
143
129
}
0 commit comments