Skip to content

Commit e81ef44

Browse files
committed
Add support for companion in MacroAnnotations
1 parent 54d67e0 commit e81ef44

File tree

70 files changed

+596
-325
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

70 files changed

+596
-325
lines changed

compiler/src/dotty/tools/dotc/transform/Inlining.scala

Lines changed: 142 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,23 @@
11
package dotty.tools.dotc
22
package transform
33

4+
import ast.tpd
5+
import ast.Trees.*
6+
import ast.TreeMapWithImplicits
47
import core.*
58
import Flags.*
9+
import Decorators.*
610
import Contexts.*
711
import Symbols.*
12+
import Decorators.*
13+
import config.Printers.inlining
14+
import DenotTransformers.IdentityDenotTransformer
15+
import inlines.Inlines
16+
import quoted.*
17+
import staging.StagingLevel
18+
import util.Property
819

9-
import dotty.tools.dotc.ast.tpd
10-
import dotty.tools.dotc.ast.Trees.*
11-
import dotty.tools.dotc.quoted.*
12-
import dotty.tools.dotc.inlines.Inlines
13-
import dotty.tools.dotc.ast.TreeMapWithImplicits
14-
import dotty.tools.dotc.core.DenotTransformers.IdentityDenotTransformer
15-
import dotty.tools.dotc.staging.StagingLevel
16-
17-
import scala.collection.mutable.ListBuffer
20+
import scala.collection.mutable
1821

1922
/** Inlines all calls to inline methods that are not in an inline method or a quote */
2023
class Inlining extends MacroTransform, IdentityDenotTransformer {
@@ -56,38 +59,94 @@ class Inlining extends MacroTransform, IdentityDenotTransformer {
5659

5760
def newTransformer(using Context): Transformer = new Transformer {
5861
override def transform(tree: tpd.Tree)(using Context): tpd.Tree =
59-
new InliningTreeMap().transform(tree)
62+
InliningTreeMap().transform(tree)
63+
}
64+
65+
private class MemoizeStatsTreeMap extends TreeMapWithImplicits {
66+
67+
// It is safe to assume that the companion of a tree is in the same scope
68+
// Therefore, when expanding MacroAnnotations, we will only keep track of
69+
// the trees in the same scope as the current transformed tree
70+
71+
private val TrackedTrees = new Property.Key[mutable.Map[Symbol, tpd.MemberDef]]
72+
73+
private def trackedTreesCtx(trees: mutable.Map[Symbol, tpd.MemberDef])(using Context): Context =
74+
ctx.fresh.setProperty(TrackedTrees, trees)
75+
76+
private def trackedTrees(using Context): mutable.Map[Symbol, MemberDef] =
77+
ctx.property(TrackedTrees).getOrElse(mutable.Map.empty)
78+
79+
protected final def getTracked(sym: Symbol)(using Context): Option[MemberDef] =
80+
for trees <- ctx.property(TrackedTrees); tree <- trees.get(sym) yield tree
81+
82+
protected final def updateTracked(sym: Symbol, tree: MemberDef)(using Context): Unit =
83+
for trees <- ctx.property(TrackedTrees) do trees.update(sym, tree)
84+
85+
private final def withUpdatedTree(stats: List[Tree], trackedTrees: mutable.Map[Symbol, MemberDef])(using Context) =
86+
stats.mapConserve:
87+
case tree: MemberDef if trackedTrees.contains(tree.symbol) =>
88+
trackedTrees(tree.symbol)
89+
case stat => stat
90+
91+
override def transform(tree: Tree)(using Context): Tree =
92+
tree match
93+
case PackageDef(_, stats) =>
94+
// Step I: Collect and memoize all the stats
95+
val treesToTrack = stats.collect { case m: MemberDef => (m.symbol, m) }
96+
val withTrackedTreeCtx = trackedTreesCtx(mutable.Map(treesToTrack*))
97+
// Step II & III: Transform the tree
98+
val result = super.transform(tree)(using withTrackedTreeCtx)
99+
// Step III: Reconcile between the symbols in syms and the tree
100+
(result: @unchecked) match
101+
case pkg@PackageDef(pid, stats) =>
102+
cpy.PackageDef(pkg)(
103+
pid = pid,
104+
stats = withUpdatedTree(stats, trackedTrees(using withTrackedTreeCtx))
105+
)
106+
case tree => tree
107+
case block: Block =>
108+
// Step I: Fetch all the member definitions in the block
109+
val trees = block.stats.collect { case m: MemberDef => (m.symbol, m) }
110+
val withTrackedTreeCtx = trackedTreesCtx(mutable.Map(trees*))
111+
// Step II: Transform the tree
112+
val result = super.transform(tree)(using withTrackedTreeCtx)
113+
// Step III: Reconcile between the symbols in syms and the tree
114+
(result: @unchecked) match
115+
case b@Block(stats, expr) =>
116+
val trackedTree = trackedTrees(using withTrackedTreeCtx)
117+
cpy.Block(b)(
118+
expr = expr,
119+
stats = withUpdatedTree(stats, trackedTrees(using withTrackedTreeCtx))
120+
)
121+
case tree => tree
122+
case TypeDef(_, impl: Template) =>
123+
// Step I: Collect and memoize all the stats
124+
val trees = impl.body.collect { case m: MemberDef => (m.symbol, m) }
125+
val withTrackedTreeCtx = trackedTreesCtx(mutable.Map(trees*))
126+
// Step II: Transform the tree
127+
val result = super.transform(tree)(using withTrackedTreeCtx)
128+
// Step III: Reconcile between the symbols in syms and the tree
129+
(result : @unchecked) match
130+
case tree@TypeDef(name, impl: Template) =>
131+
val trackedTree = trackedTrees(using withTrackedTreeCtx)
132+
cpy.TypeDef(tree)(
133+
name = name,
134+
rhs = cpy.Template(impl)(
135+
body = withUpdatedTree(impl.body, trackedTrees(using withTrackedTreeCtx))
136+
))
137+
case _ => super.transform(tree)
60138
}
61139

62-
private class InliningTreeMap extends TreeMapWithImplicits {
140+
private class InliningTreeMap extends MemoizeStatsTreeMap {
63141

64142
/** List of top level classes added by macro annotation in a package object.
65143
* These are added to the PackageDef that owns this particular package object.
66144
*/
67-
private val newTopClasses = MutableSymbolMap[ListBuffer[Tree]]()
145+
private val newTopClasses = MutableSymbolMap[mutable.ListBuffer[Tree]]()
68146

69147
override def transform(tree: Tree)(using Context): Tree = {
70148
tree match
71-
case tree: MemberDef =>
72-
if tree.symbol.is(Inline) then tree
73-
else if tree.symbol.is(Param) then super.transform(tree)
74-
else if
75-
!tree.symbol.isPrimaryConstructor
76-
&& StagingLevel.level == 0
77-
&& MacroAnnotations.hasMacroAnnotation(tree.symbol)
78-
then
79-
val trees = (new MacroAnnotations(self)).expandAnnotations(tree)
80-
val trees1 = trees.map(super.transform)
81-
82-
// Find classes added to the top level from a package object
83-
val (topClasses, trees2) =
84-
if ctx.owner.isPackageObject then trees1.partition(_.symbol.owner == ctx.owner.owner)
85-
else (Nil, trees1)
86-
if topClasses.nonEmpty then
87-
newTopClasses.getOrElseUpdate(ctx.owner.owner, new ListBuffer) ++= topClasses
88-
89-
flatTree(trees2)
90-
else super.transform(tree)
149+
case tree: MemberDef => transformMemberDef(tree)
91150
case _: Typed | _: Block =>
92151
super.transform(tree)
93152
case _: PackageDef =>
@@ -113,7 +172,58 @@ class Inlining extends MacroTransform, IdentityDenotTransformer {
113172
else Inlines.inlineCall(tree1)
114173
else super.transform(tree)
115174
}
175+
176+
private def transformMemberDef(tree: MemberDef)(using Context) : Tree =
177+
if tree.symbol.is(Inline) then tree
178+
else if tree.symbol.is(Param) then
179+
super.transform(tree)
180+
else if
181+
!tree.symbol.isPrimaryConstructor
182+
&& StagingLevel.level == 0
183+
&& MacroAnnotations.hasMacroAnnotation(tree.symbol)
184+
then
185+
// Fetch the companion's tree
186+
val companionTree =
187+
getTracked(
188+
sym = if tree.symbol.is(ModuleClass) then tree.symbol.companionClass
189+
else if tree.symbol.is(ModuleVal) then NoSymbol
190+
else tree.symbol.companionModule.moduleClass)
191+
192+
// Fetch the latest tracked tree (It might have already been processed by its companion)
193+
val latestTree = getTracked(tree.symbol).getOrElse(tree)
194+
195+
// Expand and process MacroAnnotations
196+
val (trees, companion) = MacroAnnotations.expandAnnotations(latestTree, companionTree)
197+
// Enter and register the new symbols in their owner
198+
for tree <- trees do MacroAnnotations.enterMissingSymbols(tree, self)
199+
// Update the tracked trees
200+
for case tree : MemberDef <- trees do updateTracked(tree.symbol, tree)
201+
for tree <- companion do
202+
updateTracked(tree.symbol, tree)
203+
MacroAnnotations.enterMissingSymbols(tree, self)
204+
205+
// Perform inlining on the expansion of the annotations
206+
val trees1 = trees.map(super.transform)
207+
208+
for case tree: MemberDef <- trees1 do updateTracked(tree.symbol, tree)
209+
210+
// Find classes added to the top level from a package object
211+
val (topClasses, trees2) =
212+
if ctx.owner.isPackageObject then trees1.partition(_.symbol.owner == ctx.owner.owner)
213+
else (Nil, trees1)
214+
if topClasses.nonEmpty then
215+
newTopClasses.getOrElseUpdate(ctx.owner.owner, new mutable.ListBuffer) ++= topClasses
216+
flatTree(trees2)
217+
else
218+
super.transform(tree) match
219+
case tree: MemberDef =>
220+
updateTracked(tree.symbol, tree)
221+
tree
222+
case tree => tree
223+
end transformMemberDef
224+
116225
}
226+
117227
}
118228

119229
object Inlining:

0 commit comments

Comments
 (0)