1
1
package dotty .tools .dotc
2
2
package transform
3
3
4
+ import ast .tpd
5
+ import ast .Trees .*
6
+ import ast .TreeMapWithImplicits
4
7
import core .*
5
8
import Flags .*
9
+ import Decorators .*
6
10
import Contexts .*
7
11
import Symbols .*
12
+ import Decorators .*
13
+ import config .Printers .inlining
14
+ import DenotTransformers .IdentityDenotTransformer
15
+ import inlines .Inlines
16
+ import quoted .*
17
+ import staging .StagingLevel
18
+ import util .Property
8
19
9
- import dotty .tools .dotc .ast .tpd
10
- import dotty .tools .dotc .ast .Trees .*
11
- import dotty .tools .dotc .quoted .*
12
- import dotty .tools .dotc .inlines .Inlines
13
- import dotty .tools .dotc .ast .TreeMapWithImplicits
14
- import dotty .tools .dotc .core .DenotTransformers .IdentityDenotTransformer
15
- import dotty .tools .dotc .staging .StagingLevel
16
-
17
- import scala .collection .mutable .ListBuffer
20
+ import scala .collection .mutable
18
21
19
22
/** Inlines all calls to inline methods that are not in an inline method or a quote */
20
23
class Inlining extends MacroTransform , IdentityDenotTransformer {
@@ -56,38 +59,94 @@ class Inlining extends MacroTransform, IdentityDenotTransformer {
56
59
57
60
def newTransformer (using Context ): Transformer = new Transformer {
58
61
override def transform (tree : tpd.Tree )(using Context ): tpd.Tree =
59
- new InliningTreeMap ().transform(tree)
62
+ InliningTreeMap ().transform(tree)
63
+ }
64
+
65
+ private class MemoizeStatsTreeMap extends TreeMapWithImplicits {
66
+
67
+ // It is safe to assume that the companion of a tree is in the same scope
68
+ // Therefore, when expanding MacroAnnotations, we will only keep track of
69
+ // the trees in the same scope as the current transformed tree
70
+
71
+ private val TrackedTrees = new Property .Key [mutable.Map [Symbol , tpd.MemberDef ]]
72
+
73
+ private def trackedTreesCtx (trees : mutable.Map [Symbol , tpd.MemberDef ])(using Context ): Context =
74
+ ctx.fresh.setProperty(TrackedTrees , trees)
75
+
76
+ private def trackedTrees (using Context ): mutable.Map [Symbol , MemberDef ] =
77
+ ctx.property(TrackedTrees ).getOrElse(mutable.Map .empty)
78
+
79
+ protected final def getTracked (sym : Symbol )(using Context ): Option [MemberDef ] =
80
+ for trees <- ctx.property(TrackedTrees ); tree <- trees.get(sym) yield tree
81
+
82
+ protected final def updateTracked (sym : Symbol , tree : MemberDef )(using Context ): Unit =
83
+ for trees <- ctx.property(TrackedTrees ) do trees.update(sym, tree)
84
+
85
+ private final def withUpdatedTree (stats : List [Tree ], trackedTrees : mutable.Map [Symbol , MemberDef ])(using Context ) =
86
+ stats.mapConserve:
87
+ case tree : MemberDef if trackedTrees.contains(tree.symbol) =>
88
+ trackedTrees(tree.symbol)
89
+ case stat => stat
90
+
91
+ override def transform (tree : Tree )(using Context ): Tree =
92
+ tree match
93
+ case PackageDef (_, stats) =>
94
+ // Step I: Collect and memoize all the stats
95
+ val treesToTrack = stats.collect { case m : MemberDef => (m.symbol, m) }
96
+ val withTrackedTreeCtx = trackedTreesCtx(mutable.Map (treesToTrack* ))
97
+ // Step II & III: Transform the tree
98
+ val result = super .transform(tree)(using withTrackedTreeCtx)
99
+ // Step III: Reconcile between the symbols in syms and the tree
100
+ (result : @ unchecked) match
101
+ case pkg@ PackageDef (pid, stats) =>
102
+ cpy.PackageDef (pkg)(
103
+ pid = pid,
104
+ stats = withUpdatedTree(stats, trackedTrees(using withTrackedTreeCtx))
105
+ )
106
+ case tree => tree
107
+ case block : Block =>
108
+ // Step I: Fetch all the member definitions in the block
109
+ val trees = block.stats.collect { case m : MemberDef => (m.symbol, m) }
110
+ val withTrackedTreeCtx = trackedTreesCtx(mutable.Map (trees* ))
111
+ // Step II: Transform the tree
112
+ val result = super .transform(tree)(using withTrackedTreeCtx)
113
+ // Step III: Reconcile between the symbols in syms and the tree
114
+ (result : @ unchecked) match
115
+ case b@ Block (stats, expr) =>
116
+ val trackedTree = trackedTrees(using withTrackedTreeCtx)
117
+ cpy.Block (b)(
118
+ expr = expr,
119
+ stats = withUpdatedTree(stats, trackedTrees(using withTrackedTreeCtx))
120
+ )
121
+ case tree => tree
122
+ case TypeDef (_, impl : Template ) =>
123
+ // Step I: Collect and memoize all the stats
124
+ val trees = impl.body.collect { case m : MemberDef => (m.symbol, m) }
125
+ val withTrackedTreeCtx = trackedTreesCtx(mutable.Map (trees* ))
126
+ // Step II: Transform the tree
127
+ val result = super .transform(tree)(using withTrackedTreeCtx)
128
+ // Step III: Reconcile between the symbols in syms and the tree
129
+ (result : @ unchecked) match
130
+ case tree@ TypeDef (name, impl : Template ) =>
131
+ val trackedTree = trackedTrees(using withTrackedTreeCtx)
132
+ cpy.TypeDef (tree)(
133
+ name = name,
134
+ rhs = cpy.Template (impl)(
135
+ body = withUpdatedTree(impl.body, trackedTrees(using withTrackedTreeCtx))
136
+ ))
137
+ case _ => super .transform(tree)
60
138
}
61
139
62
- private class InliningTreeMap extends TreeMapWithImplicits {
140
+ private class InliningTreeMap extends MemoizeStatsTreeMap {
63
141
64
142
/** List of top level classes added by macro annotation in a package object.
65
143
* These are added to the PackageDef that owns this particular package object.
66
144
*/
67
- private val newTopClasses = MutableSymbolMap [ListBuffer [Tree ]]()
145
+ private val newTopClasses = MutableSymbolMap [mutable. ListBuffer [Tree ]]()
68
146
69
147
override def transform (tree : Tree )(using Context ): Tree = {
70
148
tree match
71
- case tree : MemberDef =>
72
- if tree.symbol.is(Inline ) then tree
73
- else if tree.symbol.is(Param ) then super .transform(tree)
74
- else if
75
- ! tree.symbol.isPrimaryConstructor
76
- && StagingLevel .level == 0
77
- && MacroAnnotations .hasMacroAnnotation(tree.symbol)
78
- then
79
- val trees = (new MacroAnnotations (self)).expandAnnotations(tree)
80
- val trees1 = trees.map(super .transform)
81
-
82
- // Find classes added to the top level from a package object
83
- val (topClasses, trees2) =
84
- if ctx.owner.isPackageObject then trees1.partition(_.symbol.owner == ctx.owner.owner)
85
- else (Nil , trees1)
86
- if topClasses.nonEmpty then
87
- newTopClasses.getOrElseUpdate(ctx.owner.owner, new ListBuffer ) ++= topClasses
88
-
89
- flatTree(trees2)
90
- else super .transform(tree)
149
+ case tree : MemberDef => transformMemberDef(tree)
91
150
case _ : Typed | _ : Block =>
92
151
super .transform(tree)
93
152
case _ : PackageDef =>
@@ -113,7 +172,58 @@ class Inlining extends MacroTransform, IdentityDenotTransformer {
113
172
else Inlines .inlineCall(tree1)
114
173
else super .transform(tree)
115
174
}
175
+
176
+ private def transformMemberDef (tree : MemberDef )(using Context ) : Tree =
177
+ if tree.symbol.is(Inline ) then tree
178
+ else if tree.symbol.is(Param ) then
179
+ super .transform(tree)
180
+ else if
181
+ ! tree.symbol.isPrimaryConstructor
182
+ && StagingLevel .level == 0
183
+ && MacroAnnotations .hasMacroAnnotation(tree.symbol)
184
+ then
185
+ // Fetch the companion's tree
186
+ val companionTree =
187
+ getTracked(
188
+ sym = if tree.symbol.is(ModuleClass ) then tree.symbol.companionClass
189
+ else if tree.symbol.is(ModuleVal ) then NoSymbol
190
+ else tree.symbol.companionModule.moduleClass)
191
+
192
+ // Fetch the latest tracked tree (It might have already been processed by its companion)
193
+ val latestTree = getTracked(tree.symbol).getOrElse(tree)
194
+
195
+ // Expand and process MacroAnnotations
196
+ val (trees, companion) = MacroAnnotations .expandAnnotations(latestTree, companionTree)
197
+ // Enter and register the new symbols in their owner
198
+ for tree <- trees do MacroAnnotations .enterMissingSymbols(tree, self)
199
+ // Update the tracked trees
200
+ for case tree : MemberDef <- trees do updateTracked(tree.symbol, tree)
201
+ for tree <- companion do
202
+ updateTracked(tree.symbol, tree)
203
+ MacroAnnotations .enterMissingSymbols(tree, self)
204
+
205
+ // Perform inlining on the expansion of the annotations
206
+ val trees1 = trees.map(super .transform)
207
+
208
+ for case tree : MemberDef <- trees1 do updateTracked(tree.symbol, tree)
209
+
210
+ // Find classes added to the top level from a package object
211
+ val (topClasses, trees2) =
212
+ if ctx.owner.isPackageObject then trees1.partition(_.symbol.owner == ctx.owner.owner)
213
+ else (Nil , trees1)
214
+ if topClasses.nonEmpty then
215
+ newTopClasses.getOrElseUpdate(ctx.owner.owner, new mutable.ListBuffer ) ++= topClasses
216
+ flatTree(trees2)
217
+ else
218
+ super .transform(tree) match
219
+ case tree : MemberDef =>
220
+ updateTracked(tree.symbol, tree)
221
+ tree
222
+ case tree => tree
223
+ end transformMemberDef
224
+
116
225
}
226
+
117
227
}
118
228
119
229
object Inlining :
0 commit comments