@@ -32,6 +32,7 @@ import java.lang.AssertionError
32
32
import dotty .tools .dotc .util .Positions .Position
33
33
import Decorators ._
34
34
import tpd ._
35
+ import Flags ._
35
36
import StdNames .nme
36
37
37
38
/**
@@ -80,54 +81,68 @@ class LabelDefs extends MiniPhaseTransform {
80
81
81
82
val queue = new ArrayBuffer [Tree ]()
82
83
83
-
84
-
85
- override def transformBlock (tree : tpd.Block )(implicit ctx : Context , info : TransformerInfo ): tpd.Tree = {
86
- collectLabelDefs.clear
87
- val newStats = collectLabelDefs.transformStats(tree.stats)
88
- val newExpr = collectLabelDefs.transform(tree.expr)
89
- val labelCalls = collectLabelDefs.labelCalls
90
- val entryPoints = collectLabelDefs.parentLabelCalls
91
- val labelDefs = collectLabelDefs.labelDefs
92
-
93
- // make sure that for every label there's a single location it should return and single entry point
94
- // if theres already a location that it returns to that's a failure
95
- val disallowed = new mutable.HashMap [Symbol , Tree ]()
96
- queue.sizeHint(labelCalls.size + entryPoints.size)
97
- def moveLabels (entryPoint : Tree ): List [Tree ] = {
98
- if ((entryPoint.symbol is Flags .Label ) && labelDefs.contains(entryPoint.symbol)) {
99
- val visitedNow = new mutable.HashMap [Symbol , Tree ]()
100
- val treesToAppend = new ArrayBuffer [Tree ]() // order matters. parents should go first
101
- queue.clear()
102
-
103
- var visited = 0
104
- queue += entryPoint
105
- while (visited < queue.size) {
106
- val owningLabelDefSym = queue(visited).symbol
107
- val owningLabelDef = labelDefs(owningLabelDefSym)
108
- for (call <- labelCalls(owningLabelDefSym))
109
- if (disallowed.contains(call.symbol)) {
110
- val oldCall = disallowed(call.symbol)
111
- ctx.error(s " Multiple return locations for Label $oldCall and $call" , call.symbol.pos)
112
- } else {
113
- if ((! visitedNow.contains(call.symbol)) && labelDefs.contains(call.symbol)) {
114
- val df = labelDefs(call.symbol)
115
- visitedNow.put(call.symbol, labelDefs(call.symbol))
116
- queue += call
84
+ override def transformDefDef (tree : tpd.DefDef )(implicit ctx : Context , info : TransformerInfo ): tpd.Tree = {
85
+ if (tree.symbol is Flags .Label ) tree
86
+ else {
87
+ collectLabelDefs.clear
88
+ val newRhs = collectLabelDefs.transform(tree.rhs)
89
+ val labelCalls = collectLabelDefs.labelCalls
90
+ var entryPoints = collectLabelDefs.parentLabelCalls
91
+ var labelDefs = collectLabelDefs.labelDefs
92
+
93
+ // make sure that for every label there's a single location it should return and single entry point
94
+ // if theres already a location that it returns to that's a failure
95
+ val disallowed = new mutable.HashMap [Symbol , Tree ]()
96
+ queue.sizeHint(labelCalls.size + entryPoints.size)
97
+ def moveLabels (entryPoint : Tree ): List [Tree ] = {
98
+ if ((entryPoint.symbol is Flags .Label ) && labelDefs.contains(entryPoint.symbol)) {
99
+ val visitedNow = new mutable.HashMap [Symbol , Tree ]()
100
+ val treesToAppend = new ArrayBuffer [Tree ]() // order matters. parents should go first
101
+ queue.clear()
102
+
103
+ var visited = 0
104
+ queue += entryPoint
105
+ while (visited < queue.size) {
106
+ val owningLabelDefSym = queue(visited).symbol
107
+ val owningLabelDef = labelDefs(owningLabelDefSym)
108
+ for (call <- labelCalls(owningLabelDefSym))
109
+ if (disallowed.contains(call.symbol)) {
110
+ val oldCall = disallowed(call.symbol)
111
+ ctx.error(s " Multiple return locations for Label $oldCall and $call" , call.symbol.pos)
112
+ } else {
113
+ if ((! visitedNow.contains(call.symbol)) && labelDefs.contains(call.symbol)) {
114
+ visitedNow.put(call.symbol, labelDefs(call.symbol))
115
+ queue += call
116
+ }
117
117
}
118
+ if (! treesToAppend.contains(owningLabelDef)) {
119
+ treesToAppend += owningLabelDef
118
120
}
119
- if (! treesToAppend.contains(owningLabelDef))
120
- treesToAppend += owningLabelDef
121
- visited += 1
121
+ visited += 1
122
+ }
123
+ disallowed ++= visitedNow
124
+
125
+ treesToAppend.toList
126
+ } else Nil
127
+ }
128
+
129
+ val putLabelDefsNearCallees = new TreeMap () {
130
+
131
+ override def transform (tree : tpd.Tree )(implicit ctx : Context ): tpd.Tree = {
132
+ tree match {
133
+ case t : Apply if (entryPoints.contains(t)) =>
134
+ entryPoints = entryPoints - t
135
+ Block (moveLabels(t), t)
136
+ case _ => if (entryPoints.nonEmpty && labelDefs.nonEmpty) super .transform(tree) else tree
137
+ }
122
138
}
123
- disallowed ++= visitedNow
139
+ }
124
140
125
- treesToAppend.toList
126
- } else Nil
127
- }
128
141
129
- cpy.Block (tree)(entryPoints.flatMap(moveLabels).toList ++ newStats, newExpr )
142
+ val res = cpy.DefDef (tree)(rhs = putLabelDefsNearCallees.transform(newRhs) )
130
143
144
+ res
145
+ }
131
146
}
132
147
133
148
val collectLabelDefs = new TreeMap () {
@@ -137,13 +152,12 @@ class LabelDefs extends MiniPhaseTransform {
137
152
var isInsideLabel = false
138
153
var isInsideBlock = false
139
154
140
- def shouldMoveLabel = ! isInsideBlock
155
+ def shouldMoveLabel = true
141
156
142
157
// labelSymbol -> Defining tree
143
158
val labelDefs = new mutable.HashMap [Symbol , Tree ]()
144
159
// owner -> all calls by this owner
145
160
val labelCalls = new mutable.HashMap [Symbol , mutable.Set [Tree ]]()
146
- val labelCallCounts = new mutable.HashMap [Symbol , Int ]()
147
161
148
162
def clear = {
149
163
parentLabelCalls.clear()
@@ -175,7 +189,6 @@ class LabelDefs extends MiniPhaseTransform {
175
189
} else r
176
190
case t : Apply if t.symbol is Flags .Label =>
177
191
parentLabelCalls = parentLabelCalls + t
178
- labelCallCounts.get(t.symbol)
179
192
super .transform(tree)
180
193
case _ =>
181
194
super .transform(tree)
0 commit comments